"""Run Muon optimizer accuracy test with configurable parameters via args"""
import argparse
import numpy as np
import mindspore as ms
from mindspore import nn, Tensor
from mindformers.core.context.build_context import build_context
from mindformers.core.optim.muon import Muon
# Seed NumPy's global random generator with a fixed value.
np.random.seed(1024)

# Fixed float32 initial parameters consumed by FakeNet's two Dense layers.
# FC1_WEIGHT has 4 rows of 8 values; FC1_BIAS has 4 values.
FC1_WEIGHT = np.array(
    [
        [0.72346634, 0.95608497, 0.4084163, 0.18627149,
         0.6942514, 0.39767185, 0.24918061, 0.4548748],
        [0.7203382, 0.19086994, 0.76286614, 0.87920564,
         0.3169892, 0.9462494, 0.62827677, 0.27504718],
        [0.3544535, 0.2524781, 0.5370583, 0.8313121,
         0.6670143, 0.0488653, 0.62225235, 0.7546456],
        [0.17985944, 0.05106374, 0.31064633, 0.4863033,
         0.848814, 0.5523157, 0.20295663, 0.7213356],
    ],
    dtype=np.float32,
)
FC1_BIAS = np.array(
    [0.79708564, 0.13728078, 0.66322654, 0.88128525],
    dtype=np.float32,
)

# FC2_WEIGHT is a single row of 4 values; FC2_BIAS is a single value.
FC2_WEIGHT = np.array(
    [[0.8473515, 0.50923985, 0.42287776, 0.29769543]],
    dtype=np.float32,
)
FC2_BIAS = np.array([0.09996348], dtype=np.float32)
class MockTransformerConfig:
    """Bare-bones transformer config used when testing the Muon optimizer.

    Carries exactly three attributes: ``multi_latent_attention`` (True) and
    the two parallel sizes, both fixed to 1.
    """

    def __init__(self):
        # Both parallel sizes share the same value, so assign them together.
        self.tensor_model_parallel_size = self.data_parallel_size = 1
        self.multi_latent_attention = True
class MockModel:
    """
    Stand-in model object handed to the Muon optimizer in this test.

    Every method returns a trivial, fixed-shape answer so the optimizer can
    be exercised without a real transformer model behind it.
    """

    def __init__(self):
        self.config = MockTransformerConfig()

    def get_gpt_transformer_config(self):
        """Return the mock transformer config held by this model."""
        return self.config

    def make_model_muon_fns(self):
        """Build and return the ``(split_fn, merge_fn)`` pair."""

        def muon_split_fn(param_name, tensor):
            """Wrap the tensor in a one-element list (no real splitting)."""
            pieces = [tensor]
            return pieces

        def muon_merge_fn(param_name, tensor_list):
            """Return the first element of the list (no real merging)."""
            head = tensor_list[0]
            return head

        return muon_split_fn, muon_merge_fn

    def apply_qk_clip_scaling(self, params, param_names, param_layer, logit_threshold,
                              muon_split_fn, muon_merge_fn):
        """Return a single ``(0, first_param)`` pair; other arguments are ignored."""
        first_param = params[0]
        return [(0, first_param)]

    def get_param_layer_indices(self, params):
        """Map every parameter name to layer index 0."""
        names = (param.name for param in params)
        return dict.fromkeys(names, 0)

    def get_muon_filter(self):
        """Return the predicate that decides which parameters use Muon."""

        def muon_filter(param):
            # Only 2-D parameters qualify; the shape is checked first so the
            # name is inspected only for 2-D parameters.
            if len(param.shape) != 2:
                return False
            return 'bias' not in param.name

        return muon_filter

    def get_tp_dims(self, params):
        """Return a tuple holding -1 for every parameter."""
        return tuple(-1 for _param in params)

    def get_op_groups_info(self, params, op):
        """Return ``(ops, op_groups)``: a 1 and an empty string per parameter."""
        per_param_ops = tuple(1 for _param in params)
        per_param_groups = tuple("" for _param in params)
        return per_param_ops, per_param_groups
class FakeNet(nn.Cell):
    """Small test network: Dense(8 -> 4), ReLU, then Dense(4 -> 1).

    Both Dense layers start from the fixed module-level weight/bias arrays.
    """

    def __init__(self):
        super().__init__()
        hidden_layer = nn.Dense(
            in_channels=8,
            out_channels=4,
            weight_init=Tensor(FC1_WEIGHT),
            bias_init=Tensor(FC1_BIAS),
        )
        output_layer = nn.Dense(
            in_channels=4,
            out_channels=1,
            weight_init=Tensor(FC2_WEIGHT),
            bias_init=Tensor(FC2_BIAS),
        )
        # Register the sub-cells in the same order: fc1, fc2, relu.
        self.fc1 = hidden_layer
        self.fc2 = output_layer
        self.relu = nn.ReLU()

    def construct(self, x):
        """Apply fc1, ReLU and fc2 in sequence and return the result."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)
class NetWithLoss(nn.Cell):
    """Wrap a network and a loss function into one cell that returns the loss."""

    def __init__(self, network, loss_fn):
        super().__init__()
        self.network = network
        self.loss = loss_fn

    def construct(self, x, label):
        """Run the network on ``x`` and return the loss against ``label``."""
        prediction = self.network(x)
        return self.loss(prediction, label)
def make_fake_data(num_samples=20):
    """Make fake data for testing.

    Sample ``i`` pairs a ``(2, 8)`` float32 input filled with ``i`` with a
    ``(2, 1)`` float32 label filled with ``i + 1``.

    Args:
        num_samples (int): Number of (input, label) pairs to generate.
            Defaults to 20.

    Returns:
        tuple[list, list]: ``(data, label)``, two equally long lists of
        ``ms.Tensor`` objects.
    """
    data = [
        ms.Tensor(np.full((2, 8), i, dtype=np.float32))
        for i in range(num_samples)
    ]
    label = [
        ms.Tensor(np.full((2, 1), i + 1, dtype=np.float32))
        for i in range(num_samples)
    ]
    return data, label
class MuonRunner:
    """Class to manage Muon optimizer test and training.

    Args:
        args_from_parser: Namespace-like object providing ``learning_rate``,
            ``weight_decay``, ``momentum``, ``nesterov``, ``num_steps`` and
            ``output_path`` attributes.
    """

    def __init__(self, args_from_parser):
        self.args = args_from_parser
        self.learning_rate = self.args.learning_rate
        self.weight_decay = self.args.weight_decay
        self.momentum = self.args.momentum
        self.nesterov = self.args.nesterov
        self.num_steps = self.args.num_steps

    @staticmethod
    def _resolved_npz_path(output_path):
        """Return the file name ``np.savez`` actually writes for a path.

        ``np.savez`` appends ``.npz`` when the given name does not already
        end with it, so the reported path must do the same.
        """
        path = str(output_path)
        return path if path.endswith(".npz") else path + ".npz"

    def build_network(self):
        """Build network with Muon optimizer.

        Returns:
            tuple: ``(network_with_loss, optimizer, mock_model)``.
        """
        net = FakeNet()
        mock_model = MockModel()
        loss_fn = nn.L1Loss(reduction='mean')
        networkwithloss = NetWithLoss(net, loss_fn)
        networkwithloss.set_train()
        params = networkwithloss.trainable_params()
        optimizer = Muon(
            params=params,
            learning_rate=self.learning_rate,
            weight_decay=self.weight_decay,
            matched_adamw_rms=0.2,
            momentum=self.momentum,
            nesterov=self.nesterov,
            adamw_betas=(0.95, 0.95),
            adamw_eps=1e-8,
            model=mock_model,
        )
        return networkwithloss, optimizer, mock_model

    def run(self):
        """Run the training with Muon optimizer and save the results.

        Trains for ``num_steps`` steps on the fake data, then stores the
        per-step losses, the sizes of the optimizer state collections and the
        ``muon_m`` entries of the Muon-filtered parameters in an ``.npz`` file.

        Raises:
            ValueError: If ``num_steps`` exceeds the number of fake samples.
        """
        networkwithloss, optimizer, mock_model = self.build_network()
        trainonestepcell = nn.TrainOneStepCell(networkwithloss, optimizer)

        data, label = make_fake_data()
        # Each step consumes one distinct sample, so fail with a clear message
        # up front instead of an IndexError part-way through training.
        if self.num_steps > len(data):
            raise ValueError(
                f"num_steps ({self.num_steps}) exceeds the number of "
                f"available fake samples ({len(data)})"
            )

        losses = []
        for i in range(self.num_steps):
            loss = trainonestepcell(data[i], label[i])
            losses.append(loss.asnumpy())

        output_dict = {
            "losses": np.array(losses),
            "num_muon_m": len(optimizer.muon_m),
            "num_moments1": len(optimizer.moments1),
            "num_moments2": len(optimizer.moments2),
        }

        # Export the muon_m entry for each parameter accepted by the filter,
        # keyed by the parameter's index in the optimizer.
        muon_filter = mock_model.get_muon_filter()
        for idx, param in enumerate(optimizer._parameters):
            if muon_filter(param):
                output_dict[f"muon_m_{idx}"] = optimizer.muon_m[idx].asnumpy()

        np.savez(self.args.output_path, **output_dict)
        # Report the file name np.savez really wrote (it may append ".npz").
        print(f"Results saved to {self._resolved_npz_path(self.args.output_path)}")
def _str_to_bool(value):
    """Parse a command-line boolean strictly.

    Accepts common true/false spellings (case-insensitive) and rejects
    anything else, so a typo is reported instead of silently becoming False.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (true/false), got {value!r}"
    )


def _build_parser():
    """Create the argument parser for the Muon optimizer test script."""
    parser = argparse.ArgumentParser(description="Run Muon optimizer test")
    parser.add_argument("--learning_rate", type=float, default=0.02)
    parser.add_argument("--weight_decay", type=float, default=0.1)
    parser.add_argument("--momentum", type=float, default=0.95)
    parser.add_argument("--nesterov", type=_str_to_bool, default=True)
    parser.add_argument("--num_steps", type=int, default=20)
    parser.add_argument("--output_path", type=str, default="output_muon.npz")
    return parser


def main():
    """Parse CLI arguments, configure MindSpore and run the Muon test."""
    args = _build_parser().parse_args()

    build_context({"use_legacy": False, "use_parallel": True})
    ms.set_deterministic(True)
    ms.set_context(mode=ms.GRAPH_MODE)
    ms.set_seed(42)

    runner = MuonRunner(args)
    runner.run()


if __name__ == "__main__":
    main()