"""Optimizer util."""
import numpy as np
import mindspore
from mindspore import nn, Tensor
from mindspore.ops import operations as P
from mindformers.core.optim import build_optim
from mindformers.core.optim.muon import Muon
# Fix NumPy's global RNG seed so NumPy-based randomness after import is repeatable.
np.random.seed(1024)
# Fixed initial weight for FakeNet.fc1: 4 rows x 8 columns, float32
# (matches nn.Dense(in_channels=8, out_channels=4)).
fc1_weight = np.array([[0.72346634, 0.95608497, 0.4084163, 0.18627149,
0.6942514, 0.39767185, 0.24918061, 0.4548748],
[0.7203382, 0.19086994, 0.76286614, 0.87920564,
0.3169892, 0.9462494, 0.62827677, 0.27504718],
[0.3544535, 0.2524781, 0.5370583, 0.8313121,
0.6670143, 0.0488653, 0.62225235, 0.7546456],
[0.17985944, 0.05106374, 0.31064633, 0.4863033,
0.848814, 0.5523157, 0.20295663, 0.7213356]]).astype("float32")
# Fixed initial bias for FakeNet.fc1: 4 values, float32.
fc1_bias = np.array([0.79708564, 0.13728078, 0.66322654, 0.88128525]).astype("float32")
# Fixed initial weight for FakeNet.fc2: 1 row x 4 columns, float32
# (matches nn.Dense(in_channels=4, out_channels=1)).
fc2_weight = np.array([[0.8473515, 0.50923985, 0.42287776, 0.29769543]]).astype("float32")
# Fixed initial bias for FakeNet.fc2: a single float32 value.
fc2_bias = np.array([0.09996348]).astype("float32")
def make_fake_data():
    """
    Build the fixed fake dataset used to drive the training loops.

    Returns:
        tuple: ``(data, label)``, two lists of 20 float32 Tensors each.
        ``data[i]`` has shape (2, 8) and is filled with ``i``;
        ``label[i]`` has shape (2, 1) and is filled with ``i + 1``.
    """
    steps = range(20)
    data = [
        mindspore.Tensor(np.full((2, 8), step, dtype=np.float32))
        for step in steps
    ]
    label = [
        mindspore.Tensor(np.full((2, 1), step + 1, dtype=np.float32))
        for step in steps
    ]
    return data, label
class NetWithLoss(nn.Cell):
    """
    Cell that chains a network with a loss function.

    Args:
        network: Cell (or callable) mapping an input batch to predictions.
        loss_fn: Callable invoked as ``loss_fn(predictions, label)``.
    """

    def __init__(self, network, loss_fn):
        super().__init__()
        self.network = network
        self.loss = loss_fn

    def construct(self, x, label):
        """Run ``x`` through the network and return the loss against ``label``."""
        return self.loss(self.network(x), label)
class FakeNet(nn.Cell):
    """
    Toy two-layer dense network (8 -> 4 -> 1) with fixed initial parameters.

    Both Dense layers are initialised from the module-level ``fc1_*`` /
    ``fc2_*`` arrays.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Dense(
            in_channels=8,
            out_channels=4,
            weight_init=Tensor(fc1_weight),
            bias_init=Tensor(fc1_bias),
        )
        self.fc2 = nn.Dense(
            in_channels=4,
            out_channels=1,
            weight_init=Tensor(fc2_weight),
            bias_init=Tensor(fc2_bias),
        )
        self.relu = nn.ReLU()
        # Not used by construct(); kept so the cell exposes the same attributes.
        self.reducemean = P.ReduceMean()

    def construct(self, x):
        """Forward pass: fc1 -> ReLU -> fc2."""
        hidden = self.fc1(x)
        hidden = self.relu(hidden)
        return self.fc2(hidden)

    def _initialize_weights(self):
        """
        Call ``init_parameters_data()``, then set the fc1/fc2 weights and
        biases to the module-level fixed arrays.
        """
        self.init_parameters_data()
        fixed_values = {
            'fc1': (fc1_weight, fc1_bias),
            'fc2': (fc2_weight, fc2_bias),
        }
        for cell_name, cell in self.cells_and_names():
            if cell_name not in fixed_values:
                continue
            weight, bias = fixed_values[cell_name]
            cell.weight.set_data(Tensor(weight))
            cell.bias.set_data(Tensor(bias))
def build_network(opt_config, net, is_group=None, loss_fn=None):
    """
    Train ``net`` on the fake dataset with the optimizer described by ``opt_config``.

    Args:
        opt_config (dict): Optimizer config handed to ``build_optim``. Its
            ``'type'`` entry is read when ``is_group`` is truthy. The caller's
            dict is left untouched: ``'params'`` is added to a shallow copy.
        net: Network to train; it is wrapped in ``NetWithLoss``.
        is_group (bool, optional): When truthy, parameters are split into two
            groups (names containing ``'fc1'`` and all others) with per-group
            settings. ``None`` behaves like ``False``.
        loss_fn (optional): Loss callable; defaults to
            ``nn.L1Loss(reduction='mean')``.

    Returns:
        tuple: ``(losses, optimizer)`` where ``losses`` is a NumPy array with
        one entry per training step and ``optimizer`` is the object returned
        by ``build_optim``.
    """
    if loss_fn is None:
        loss_fn = nn.L1Loss(reduction='mean')

    networkwithloss = NetWithLoss(net, loss_fn)
    networkwithloss.set_train()

    trainable = networkwithloss.trainable_params()
    if is_group:
        fc1_params = [p for p in trainable if 'fc1' in p.name]
        fc2_params = [p for p in trainable if 'fc1' not in p.name]
        if opt_config['type'] == 'AdamW':
            params = [
                {'params': fc1_params, 'weight_decay': 0.01, 'lr': 0.01},
                {'params': fc2_params, 'lr': 0.1},
            ]
        else:
            params = [
                {'params': fc1_params, 'lr': 0.01},
                {'params': fc2_params, 'lr': 0.01},
            ]
    else:
        params = trainable

    # Shallow-copy the config so 'params' is not written into the caller's
    # dict (the original mutated the argument in place).
    optim_config = dict(opt_config)
    optim_config['params'] = params
    net_opt = build_optim(optim_config)

    trainonestepcell = mindspore.nn.TrainOneStepCell(networkwithloss, net_opt)

    # One optimizer step per (data, label) pair produced by make_fake_data().
    losses = []
    data, label = make_fake_data()
    for batch, target in zip(data, label):
        loss = trainonestepcell(batch, target)
        losses.append(loss.asnumpy())
    return np.array(losses), net_opt
# NOTE(review): judging by the name, presumably the expected AdamW first-moment ("m")
# state for fc1.weight after a build_network run - confirm against the tests using it.
# Shape (4, 8), float32; every row holds one repeated value.
default_fc1_weight_adamw_m = (
np.array([[0.75276935, 0.75276935, 0.75276935, 0.75276935, 0.75276935, 0.75276935, 0.75276935, 0.75276935],
[0.28740492, 0.28740492, 0.28740492, 0.28740492, 0.28740492, 0.28740492, 0.28740492, 0.28740492],
[0.12561864, 0.12561864, 0.12561864, 0.12561864, 0.12561864, 0.12561864, 0.12561864, 0.12561864],
[-0.06905057, -0.06905057, -0.06905057, -0.06905057, -0.06905057, -0.06905057, -0.06905057,
-0.06905057]], dtype=np.float32)
)
# NOTE(review): presumably the expected AdamW first-moment ("m") state for fc2.weight - confirm.
# Shape (1, 4), float32.
default_fc2_weight_adamw_m = (
np.array([[6.978479, 7.470356, 5.508465, 5.176325]], dtype=np.float32)
)
# NOTE(review): presumably the expected AdamW second-moment ("v") state for fc1.weight - confirm.
# Shape (4, 8), float32; every row holds one repeated value.
default_fc1_weight_adamw_v = (
np.array([[0.28913346, 0.28913346, 0.28913346, 0.28913346, 0.28913346, 0.28913346, 0.28913346, 0.28913346],
[0.01420226, 0.01420226, 0.01420226, 0.01420226, 0.01420226, 0.01420226, 0.01420226, 0.01420226],
[0.00199351, 0.00199351, 0.00199351, 0.00199351, 0.00199351, 0.00199351, 0.00199351, 0.00199351],
[0.04521008, 0.04521008, 0.04521008, 0.04521008, 0.04521008, 0.04521008, 0.04521008, 0.04521008]],
dtype=np.float32)
)
# NOTE(review): presumably the expected AdamW second-moment ("v") state for fc2.weight - confirm.
# Shape (1, 4), float32.
default_fc2_weight_adamw_v = (
np.array([[35.217834, 42.283375, 26.52298, 21.510029]], dtype=np.float32)
)
class MockTransformerConfig:
    """Minimal stand-in for a transformer config, used when testing the Muon optimizer."""

    def __init__(self):
        # Fixed settings: multi-latent attention on, both parallel sizes set to 1.
        fixed_settings = {
            'multi_latent_attention': True,
            'tensor_model_parallel_size': 1,
            'data_parallel_size': 1,
        }
        for attr_name, attr_value in fixed_settings.items():
            setattr(self, attr_name, attr_value)
class MockModel:
    """
    Mock model exposing the hooks that the Muon optimizer is handed in tests.

    Every hook returns a trivial, fixed answer derived only from its inputs.
    """

    def __init__(self):
        self.config = MockTransformerConfig()

    def get_gpt_transformer_config(self):
        """Return the config object created in ``__init__``."""
        return self.config

    def make_model_muon_fns(self):
        """Return the ``(split_fn, merge_fn)`` pair of trivial split/merge hooks."""

        def muon_split_fn(param_name, tensor):
            """Wrap ``tensor`` in a one-element list; ``param_name`` is ignored."""
            return [tensor]

        def muon_merge_fn(param_name, tensor_list):
            """Return the first entry of ``tensor_list``; ``param_name`` is ignored."""
            return tensor_list[0]

        return muon_split_fn, muon_merge_fn

    def get_param_layer_indices(self, params):
        """Map every parameter name in ``params`` to layer index 0."""
        return dict.fromkeys((param.name for param in params), 0)

    def get_muon_filter(self):
        """Return a predicate that accepts 2-D parameters whose name lacks 'bias'."""

        def muon_filter(param):
            if len(param.shape) != 2:
                return False
            return 'bias' not in param.name

        return muon_filter

    def get_tp_dims(self, params):
        """Return a tuple holding -1 for each entry of ``params``."""
        return tuple([-1 for _ in params])

    def get_op_groups_info(self, params, op):
        """
        Return ``(ops, op_groups)``: a tuple of 1s and a tuple of empty
        strings, each with one entry per parameter. ``op`` is ignored.
        """
        ones = tuple([1 for _ in params])
        group_names = tuple(["" for _ in params])
        return ones, group_names
def build_muon_network(net, mock_model, learning_rate=0.02):
    """
    Train ``net`` on the fake dataset with a Muon optimizer.

    Args:
        net: The network to train; it is wrapped in ``NetWithLoss`` with an
            ``nn.L1Loss(reduction='mean')`` loss.
        mock_model: Object passed to ``Muon`` as its ``model`` argument.
        learning_rate: Learning rate passed to ``Muon``.

    Returns:
        tuple: ``(losses, optimizer)`` where ``losses`` is a NumPy array with
        one entry per training step and ``optimizer`` is the ``Muon`` instance.
    """
    loss_net = NetWithLoss(net, nn.L1Loss(reduction='mean'))
    loss_net.set_train()

    optimizer = Muon(
        params=loss_net.trainable_params(),
        learning_rate=learning_rate,
        weight_decay=0.1,
        matched_adamw_rms=0.2,
        momentum=0.95,
        nesterov=True,
        adamw_betas=(0.95, 0.95),
        adamw_eps=1e-8,
        model=mock_model,
    )
    train_step = nn.TrainOneStepCell(loss_net, optimizer)

    # One optimizer step per (data, label) pair produced by make_fake_data().
    data, label = make_fake_data()
    step_losses = []
    for batch, target in zip(data, label):
        step_losses.append(train_step(batch, target).asnumpy())
    return np.array(step_losses), optimizer