import copy
import os
import json
import re
import shutil
from collections import OrderedDict
import logging as logger
import argparse
import torch
import safetensors.torch
logger.basicConfig(format="")
logger.getLogger().setLevel(logger.INFO)
class MambaConverter:
def __init__(self, args, convert=None):
self.args = args
self.args.mamba_d_inner = self.args.hidden_size * 2
self.args.mamba2_n_heads = self.args.mamba_d_inner // self.args.mamba_head_dim
self.tp_split_dim = {
'word_embeddings.weight': 0,
'norm.weight': -1,
'final_norm.weight': -1,
'output_layer.weight': 0,
'A_log': 0,
'D': 0,
'dt_bias': 0,
'in_proj.weight': 0,
'conv1d.weight': 0,
'conv1d.bias': 0,
'x_proj.weight': 1,
'dt_proj.weight': 0,
'dt_proj.bias': 0,
'out_proj.weight': 1,
'mixer.norm.weight': 0,
'linear_fc1.layer_norm_weight': -1,
'linear_fc1.weight': 0,
'linear_fc2.weight': 1,
'self_attention.linear_proj.weight': 1,
'self_attention.linear_qkv.layer_norm_weight': -1,
'self_attention.linear_qkv.weight': 0,
}
if convert == 'hf2mg':
self.args.target_tensor_parallel_size = self.args.tensor_model_parallel_size
self.args.target_pipeline_parallel_size = self.args.pipeline_model_parallel_size
self.args.load_model_type = 'hf'
self.args.save_model_type = 'mg'
self.args.load_dir = self.args.load
self.args.save_dir = self.args.mg_save_dir
elif convert == 'mg2hf':
self.args.target_tensor_parallel_size = self.args.tensor_model_parallel_size
self.args.target_pipeline_parallel_size = self.args.pipeline_model_parallel_size
self.args.load_model_type = 'mg'
self.args.save_model_type = 'hf'
self.args.load_dir = self.args.save
self.args.save_dir = self.args.hf_save_dir
self._valid_parameter()
def get_split_dim(self, tensor_name):
if 'norm.weight' in tensor_name:
if 'mixer.norm.weight' in tensor_name:
return self.tp_split_dim['mixer.norm.weight']
else:
return self.tp_split_dim['norm.weight']
for key in self.tp_split_dim.keys():
if key in tensor_name:
return self.tp_split_dim[key]
raise Exception(f"Unknown tensor name {tensor_name}")
def _valid_parameter(self):
if self.args.mamba_num_groups % self.args.target_tensor_parallel_size != 0:
raise ValueError("target_tensor_parallel_size must can divide n_groups. Please adjust values.")
@staticmethod
def load_hf_files_to_dict(directory_path):
model_dict = {}
loaded = False
for filename in os.listdir(directory_path):
file_path = os.path.join(directory_path, filename)
try:
if filename.endswith(".bin"):
cur_weights = torch.load(file_path, map_location=torch.device('cpu'))
model_dict.update(cur_weights)
print(f"Successfully loaded: {filename}")
loaded = True
elif filename.endswith(".safetensors"):
try:
from safetensors.torch import load_file as safe_load
except ImportError as e:
raise ImportError(
"Detected a .safetensors file but the 'safetensors' package is not installed. "
"Please install it using `pip install safetensors`."
) from e
cur_weights = safe_load(file_path)
model_dict.update(cur_weights)
print(f"Successfully loaded: {filename}")
loaded = True
except Exception as e:
print(f"Error loading {filename}: {e}")
if not loaded:
raise RuntimeError("No valid weight files (.bin or .safetensors) were found or loaded.")
return model_dict
@staticmethod
def load_config_and_get_n_layer(directory_path, n_layers):
if n_layers is not None:
return n_layers
config_file_path = os.path.join(directory_path, "config.json")
if not os.path.exists(config_file_path):
raise FileNotFoundError(f"{config_file_path} not found!")
with open(config_file_path, 'r') as f:
config = json.load(f)
n_layer = config.get("num_hidden_layers", config.get("n_layer", None))
if n_layer is None:
raise KeyError("Neither 'num_hidden_layers' nor 'n_layers' key found in the config file")
return n_layer
@staticmethod
def modify_keys_hf2mg(model_dict):
modified_dict = {}
for key, value in model_dict.items():
new_key = key
if "norm_f" in key:
new_key = key.replace("backbone.norm_f", "decoder.final_norm")
elif "lm_head" in key:
new_key = key.replace("lm_head", "output_layer")
elif "embedding" in key:
if "backbone.embeddings" in key:
new_key = key.replace("backbone.embeddings", "embedding.word_embeddings")
else:
new_key = key.replace("backbone.embedding", "embedding.word_embeddings")
elif "backbone" in key:
new_key = key.replace("backbone", "decoder")
modified_dict[new_key] = value
return modified_dict
@staticmethod
def modify_keys_mg2hf(modified_dict):
restored_dict = {}
for key, value in modified_dict.items():
new_key = key
if "final_norm" in key:
new_key = key.replace("decoder.final_norm", "backbone.norm_f")
elif "output_layer" in key:
new_key = key.replace("output_layer", "lm_head")
elif "embedding" in key:
new_key = key.replace("embedding.word_embeddings", "backbone.embeddings")
elif "decoder" in key:
new_key = key.replace("decoder", "backbone")
restored_dict[new_key] = value
return restored_dict
def combine_tp_tensors(self, params, key, dim, tensors):
tp_size = len(tensors)
if 'mixer.in_proj.weight' in key:
xs = []
zs = []
Bs = []
Cs = []
dts = []
for tensor in tensors:
x, z, B, C, dt = torch.split(tensor, [params.mamba_d_inner // tp_size,
params.mamba_d_inner // tp_size,
(params.mamba_num_groups // tp_size) * self.args.mamba_state_dim,
(params.mamba_num_groups // tp_size) * self.args.mamba_state_dim,
params.mamba2_n_heads // tp_size], dim=dim)
xs.append(x)
zs.append(z)
Bs.append(B)
Cs.append(C)
dts.append(dt)
for tensor_index, (B_tensor, C_tensor) in enumerate(zip(Bs, Cs)):
Bs[tensor_index] = torch.reshape(B_tensor, (-1, params.mamba_state_dim, B_tensor.shape[-1]))
Cs[tensor_index] = torch.reshape(C_tensor, (-1, params.mamba_state_dim, C_tensor.shape[-1]))
B = torch.cat(Bs, dim=dim)
C = torch.cat(Cs, dim=dim)
x = torch.cat(xs, dim=dim)
z = torch.cat(zs, dim=dim)
dt = torch.cat(dts, dim=dim)
return torch.cat([x, z, B.flatten(0, 1), C.flatten(0, 1), dt], dim=dim)
elif 'mixer.conv1d' in key:
xs = []
Bs = []
Cs = []
for tensor in tensors:
x, B, C = torch.split(tensor, [params.mamba_d_inner // tp_size,
(params.mamba_num_groups // tp_size) * params.mamba_state_dim,
(params.mamba_num_groups // tp_size) * params.mamba_state_dim], dim=dim)
xs.append(x)
Bs.append(B)
Cs.append(C)
for tensor_index, (B_tensor, C_tensor) in enumerate(zip(Bs, Cs)):
if 'weight' in key:
Bs[tensor_index] = torch.reshape(B_tensor, (-1, params.mamba_state_dim, B_tensor.shape[-2], B_tensor.shape[-1]))
Cs[tensor_index] = torch.reshape(C_tensor, (-1, params.mamba_state_dim, C_tensor.shape[-2], C_tensor.shape[-1]))
elif 'bias' in key:
Bs[tensor_index] = torch.reshape(B_tensor, (-1, params.mamba_state_dim))
Cs[tensor_index] = torch.reshape(C_tensor, (-1, params.mamba_state_dim))
else:
raise Exception("Unknown key")
B = torch.cat(Bs, dim=dim)
C = torch.cat(Cs, dim=dim)
x = torch.cat(xs, dim=dim)
return torch.cat([x, B.flatten(0, 1), C.flatten(0, 1)], dim=dim)
else:
return torch.cat(tensors, dim=dim)
@staticmethod
def split_tensor_for_tp(params, key, dim, tensor):
tp_size = params.target_tensor_parallel_size
tensor_sliced = []
if 'mixer.in_proj.weight' in key:
x, z, B, C, dt = torch.split(tensor, [params.mamba_d_inner, params.mamba_d_inner,
params.mamba_num_groups * params.mamba_state_dim,
params.mamba_num_groups * params.mamba_state_dim,
params.mamba2_n_heads], dim=dim)
B = torch.reshape(B, (-1, params.mamba_state_dim, B.shape[-1]))
C = torch.reshape(C, (-1, params.mamba_state_dim, C.shape[-1]))
B_sliced = torch.chunk(B, tp_size, dim=dim)
C_sliced = torch.chunk(C, tp_size, dim=dim)
x_sliced = torch.chunk(x, tp_size, dim=dim)
z_sliced = torch.chunk(z, tp_size, dim=dim)
dt_sliced = torch.chunk(dt, tp_size, dim=dim)
tensor_sliced = []
for (x, z, B, C, dt) in zip(x_sliced, z_sliced, B_sliced, C_sliced, dt_sliced):
tensor_sliced.append(torch.cat((x, z, B.flatten(0, 1), C.flatten(0, 1), dt), dim=dim))
elif 'mixer.conv1d' in key:
x, B, C = torch.split(tensor, [params.mamba_d_inner,
params.mamba_num_groups * params.mamba_state_dim,
params.mamba_num_groups * params.mamba_state_dim], dim=dim)
if 'weight' in key:
B = torch.reshape(B, (-1, params.mamba_state_dim, B.shape[-2], B.shape[-1]))
C = torch.reshape(C, (-1, params.mamba_state_dim, C.shape[-2], C.shape[-1]))
elif 'bias' in key:
B = torch.reshape(B, (-1, params.mamba_state_dim))
C = torch.reshape(C, (-1, params.mamba_state_dim))
else:
raise Exception("Unknown key")
B_sliced = torch.chunk(B, tp_size, dim=dim)
C_sliced = torch.chunk(C, tp_size, dim=dim)
x_sliced = torch.chunk(x, tp_size, dim=dim)
tensor_sliced = []
for (x, B, C) in zip(x_sliced, B_sliced, C_sliced):
tensor_sliced.append(torch.cat((x, B.flatten(0, 1), C.flatten(0, 1)), dim=dim))
else:
tensor_sliced = torch.chunk(tensor, tp_size, dim=dim)
return tensor_sliced
@staticmethod
def finalize_checkpoint(src_model, model, params, verbose=False):
if 'args' in src_model.keys():
model['args'] = copy.deepcopy(src_model['args'])
model['args'].tensor_model_parallel_size = params.target_tensor_parallel_size
model['args'].pipeline_model_parallel_size = params.target_pipeline_parallel_size
if 'checkpoint_version' in src_model.keys():
model['checkpoint_version'] = copy.deepcopy(src_model['checkpoint_version'])
else:
model['checkpoint_version'] = 3.0
if 'iteration' in src_model.keys():
model['iteration'] = copy.deepcopy(src_model['iteration'])
else:
model['iteration'] = 1
if 'opt_param_scheduler' in src_model.keys():
model['opt_param_scheduler'] = copy.deepcopy(src_model['opt_param_scheduler'])
if 'rng_state' in src_model.keys():
model['rng_state'] = copy.deepcopy(src_model['rng_state'])
if verbose:
original_args = src_model['args'].__dict__
final_args = model['args'].__dict__
for key in original_args:
if key in final_args:
if final_args[key] != original_args[key]:
logger.info("KEY MISMATCH: {}".format(key))
logger.info("\toriginal: {}\n\tfinal: {}".format(original_args[key], final_args[key]))
else:
logger.info("KEY MISSING from final: {}, value {}".format(key, original_args[key]))
for key in final_args:
if key not in original_args:
logger.info("KEY ADDED to final: {}, value {}".format(key, final_args[key]))
return model
@staticmethod
def get_latest_checkpoint_model_file(load_dir, model_filename="model_optim_rng.pt"):
"""
Get the sub-model path for the latest iteration (any sub-model) to extract structure or parameter information.
Parameters:
- load_dir: Main directory to load the model
- model_filename: Sub-model filename (default is "model_optim_rng.pt")
Returns:
- out_iteration: Latest iteration number
- input_model_dir: Iteration directory path
- src_model_file: Full file path of any sub-model
"""
tracker_filename = os.path.join(load_dir, 'latest_checkpointed_iteration.txt')
try:
with open(tracker_filename, 'r') as f:
metastring = f.read().strip()
iteration = int(metastring)
except FileNotFoundError as e:
raise FileNotFoundError(f"Checkpoint tracker file not found at {tracker_filename}") from e
except ValueError as e:
raise ValueError(f"Invalid iteration value in {tracker_filename}: '{metastring}' is not an integer.") from e
out_iteration = iteration
input_model_dir = os.path.join(load_dir, f'iter_{iteration:07d}')
input_sub_models = os.listdir(input_model_dir)
if not input_sub_models:
raise RuntimeError(f"No sub-models found under {input_model_dir}")
src_model_file = os.path.join(input_model_dir, input_sub_models[0], model_filename)
return out_iteration, input_model_dir, src_model_file
def load_mg_model(self):
"""
Load the latest iteration number and retrieve a model checkpoint from the model directory to obtain structure and other information.
Returns:
- out_iteration: Iteration number
- input_model_dir: Corresponding iteration directory
- src_model: Sample model dict
"""
out_iteration, input_model_dir, src_model_file = self.get_latest_checkpoint_model_file(self.args.load_dir)
src_model = torch.load(src_model_file, map_location='cpu', weights_only=False)
logger.info(f"Sample model {src_model_file} is loaded.\n")
return out_iteration, input_model_dir, src_model
@staticmethod
def get_model_file_path(input_model_dir, tp_index, pp_index=0, input_pp_rank=1, filename="model_optim_rng.pt"):
dir_name = f"mp_rank_{tp_index:02d}"
if input_pp_rank > 1:
dir_name += f"_{pp_index:03d}"
return os.path.join(input_model_dir, dir_name, filename)
def merge_checkpoints(self, input_model_dir, input_tp_rank, input_pp_rank, num_layers_per_pipeline_rank, args):
"""
Load and merge Megatron-style TP+PP model weights, and return the merged full model (OrderedDict)
"""
full_model = OrderedDict()
for pp_index in range(input_pp_rank):
logger.info(f"Processing input pipeline rank {pp_index}")
tp_models = []
for tp_index in range(input_tp_rank):
model_file = self.get_model_file_path(
input_model_dir,
tp_index,
pp_index,
input_pp_rank
)
tp_models.append(torch.load(model_file, map_location='cpu', weights_only=False))
logger.info(f"Model {model_file} is loaded.")
if input_tp_rank > 1:
combined_tp_model = OrderedDict()
for key, tensor in tp_models[0]['model'].items():
if "_extra_state" in key:
combined_tp_model[key] = tensor
continue
split_dim = self.get_split_dim(key)
if split_dim != -1:
combined_tensor = self.combine_tp_tensors(
args, key, split_dim,
[tp_models[i]['model'][key].cpu() for i in range(input_tp_rank)]
)
combined_tp_model[key] = combined_tensor
else:
combined_tp_model[key] = tensor
else:
combined_tp_model = tp_models[0]['model']
for key, tensor in combined_tp_model.items():
try:
layer_num = int(re.findall(r'\d+', key)[0])
new_key = key.replace(str(layer_num), str(layer_num + pp_index * num_layers_per_pipeline_rank), 1)
except Exception:
new_key = key
full_model[new_key] = tensor
logger.info("Loaded combined model.")
return full_model
def split_model_by_pp_tp(self, full_model, args):
"""
split Megatron model by pipeline rank and tensor parallel rank
"""
split_models = {}
layers_per_pp = args.num_layers // args.target_pipeline_parallel_size
for pp_idx in range(args.target_pipeline_parallel_size):
tp_models = [{'model': OrderedDict()} for _ in range(args.target_tensor_parallel_size)]
for key, tensor in full_model.items():
try:
layer_num = int(re.findall(r'\d+', key)[0])
if layer_num >= layers_per_pp * (pp_idx + 1):
continue
new_key = key.replace(str(layer_num), str(layer_num - layers_per_pp * pp_idx), 1)
except Exception:
new_key = key
if "_extra_state" in key:
for tp_model in tp_models:
tp_model['model'][new_key] = tensor
continue
split_dim = self.get_split_dim(new_key)
if split_dim != -1:
slices = self.split_tensor_for_tp(args, new_key, split_dim, tensor)
for i in range(args.target_tensor_parallel_size):
tp_models[i]['model'][new_key] = slices[i]
else:
for tp_model in tp_models:
tp_model['model'][new_key] = tensor
for tp_idx, tp_model in enumerate(tp_models):
split_models[(pp_idx, tp_idx)] = tp_model['model']
return split_models
@staticmethod
def build_model_save_path(args, out_iteration, tp_idx, pp_idx=0):
"""
Constructs and creates the path to save the model, and returns the full model file path.
Parameters:
- save_dir: Top-level directory to save
- out_iteration: Current iteration number (int)
- tp_idx: Tensor parallel rank
- pp_idx: Pipeline parallel rank (default is 0)
- target_pipeline_parallel_size: Pipeline parallel size (default is 1)
- filename: Filename (default is 'model_optim_rng.pt')
Returns:
- model_file: The full model file path in the save directory
"""
filename = "model_optim_rng.pt"
dir_name = f"mp_rank_{tp_idx:02d}"
if args.target_pipeline_parallel_size > 1:
dir_name += f"_{pp_idx:03d}"
save_path = os.path.join(args.save_dir, f"iter_{out_iteration:07d}", dir_name)
os.makedirs(save_path, exist_ok=True)
return os.path.join(save_path, filename)
def save_split_models(self, split_models, src_model, args, out_iteration):
"""
save splited model
"""
for (pp_idx, tp_idx), model_dict in split_models.items():
model_file = self.build_model_save_path(
args,
out_iteration,
tp_idx,
pp_idx,
)
finalized = self.finalize_checkpoint(src_model, {'model': model_dict}, args, verbose=False)
torch.save(finalized, model_file)
logger.info(f"Model {model_file} is saved.")
tracker_file = os.path.join(args.save_dir, "latest_checkpointed_iteration.txt")
with open(tracker_file, "w") as f:
f.write(str(out_iteration))
def convert_mg_checkpoint(self, args):
logger.info("====RUNNING CHECKPOINT CONVERSION====")
out_iteration, input_model_dir, src_model = self.load_mg_model()
if 'args' in src_model.keys():
input_tp_rank = src_model['args'].tensor_model_parallel_size
input_pp_rank = src_model['args'].pipeline_model_parallel_size
num_layers_per_pipeline_rank = src_model['args'].num_layers // input_pp_rank
else:
input_tp_rank = self.args.input_tp_rank
input_pp_rank = self.args.input_pp_rank
num_layers = self.args.num_layers
num_layers_per_pipeline_rank = num_layers // input_pp_rank
full_model = self.merge_checkpoints(
input_model_dir=input_model_dir,
input_tp_rank=input_tp_rank,
input_pp_rank=input_pp_rank,
num_layers_per_pipeline_rank=num_layers_per_pipeline_rank,
args=args)
if args.save_model_type == 'hf':
hf_model = self.modify_keys_mg2hf(full_model)
model_file = os.path.join(args.save_dir, "torch_model.bin")
os.makedirs(os.path.dirname(model_file), exist_ok=True)
torch.save(hf_model, model_file)
logger.info(f"Model saved to {model_file}.")
else:
if 'args' in src_model:
args.num_layers = src_model['args'].num_layers
split_models = self.split_model_by_pp_tp(full_model, args)
self.save_split_models(split_models, src_model, args, out_iteration)
def convert_hf2mg(self, args):
logger.info("====RUNNING CHECKPOINT CONVERSION====")
hf_model = self.load_hf_files_to_dict(args.load_dir)
hf_model_new = self.modify_keys_hf2mg(hf_model)
args.num_layers = self.load_config_and_get_n_layer(args.load_dir, args.num_layers)
split_models = self.split_model_by_pp_tp(hf_model_new, args)
self.save_split_models(split_models, hf_model, args, out_iteration=1)
def run(self):
if self.args.load_model_type == "mg":
self.convert_mg_checkpoint(self.args)
else:
self.convert_hf2mg(self.args)