from pathlib import Path
import torch
from transformers import AutoConfig, AutoProcessor
from checkpoint.common.converter import Converter
from checkpoint.common.permissions import set_directory_permissions
from checkpoint.common.merge_dcp_to_hf import load_dcp_state_dict, save_hf_weights, merge_dcp_to_hf_sharded
from checkpoint.common.hf_to_dcp import hf_to_dcp_sharded
from checkpoint.vlm_model.hf_to_mm import load_from_hf, save_by_dcp
class Qwen3VLConverter(Converter):
    """
    Convert Qwen3-VL model checkpoints between Hugging Face (HF) and torch-dcp (DCP) formats.

    Supports:
    - HF -> DCP conversion (full model and LoRA adapter weights)
    - DCP -> HF merging
    - Merging a DCP LoRA adapter into an HF base checkpoint
    - Placeholder methods for megatron format and resharding operations.
    """

    # Prefix prepended to every HF key when writing a full-model DCP checkpoint.
    dcp_prefix = "model."
    # Leading prefix carried by HF keys; empty means HF keys are used as-is.
    hf_prefix = ""
    # Used instead of `dcp_prefix` by `hf_to_dcp(is_lora_base=True)`; also stripped from
    # LoRA layer names to recover the base HF weight key when merging.
    dcp_prefix_for_lora_base = "base_model.model.model."
    # LoRA adapter keys: the first `hf_prefix_lora` occurrence is rewritten to `dcp_prefix_lora`.
    dcp_prefix_lora = "base_model.model."
    hf_prefix_lora = "base_model."
    # target key -> source key; the source tensor is copied to the target when tie_weight=True.
    tie_weight_mapping = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
    # Substrings identifying fused weights: flattened to 2-D (-1, last_dim) for DCP and
    # viewed back as (num_experts, -1, last_dim) on export when the config defines num_experts.
    fused_linear_names = ["gate_up_proj", "down_proj"]

    def _is_fused_linear(self, key):
        """Return True if `key` names one of the fused linear weights in `fused_linear_names`."""
        return any(fused_linear_name in key for fused_linear_name in self.fused_linear_names)

    def hf_to_dcp(
        self,
        hf_dir: str = "Qwen3-VL-xxB",
        dcp_dir: str = "Qwen3-VL-xxB-dcp",
        tie_weight: bool = False,
        is_lora_base: bool = False
    ):
        """
        Convert a Hugging Face formatted checkpoint to torch-dcp format.

        Steps (performed by the convert function handed to `hf_to_dcp_sharded`):
        1. Optionally tie weights: copy each source tensor in `tie_weight_mapping` to its
           target key (e.g. share `lm_head` with `embed_tokens`).
        2. Flatten fused linear weights to 2-D `(-1, last_dim)`.
        3. Rename every key by adding the DCP prefix (and stripping a leading HF prefix).

        Loading the HF shards and saving the DCP output are delegated to `hf_to_dcp_sharded`.
        NOTE(review): this method does not call `set_directory_permissions` itself; presumably
        `hf_to_dcp_sharded` handles permissions — confirm.

        Args:
            hf_dir: Source HF checkpoint directory.
            dcp_dir: Destination DCP checkpoint directory.
            tie_weight: Copy tensors according to `tie_weight_mapping` before renaming.
            is_lora_base: Use `dcp_prefix_for_lora_base` instead of `dcp_prefix`, so the
                checkpoint can serve as the base of a LoRA-wrapped model.
        """
        # The destination prefix is identical for every key; resolve it once.
        dcp_prefix = self.dcp_prefix_for_lora_base if is_lora_base else self.dcp_prefix

        def to_dcp_key(hf_key):
            if not self.hf_prefix:
                return f"{dcp_prefix}{hf_key}"
            # Rewrite only a *leading* HF prefix (mirrors dcp_to_hf's startswith check);
            # keys without it are left unchanged.
            if hf_key.startswith(self.hf_prefix):
                return dcp_prefix + hf_key[len(self.hf_prefix):]
            return hf_key

        def state_dict_convert_func(state_dict):
            if tie_weight:
                # NOTE(review): only applies when the source key is present in the dict this
                # function receives; how hf_to_dcp_sharded splits shards is not visible here.
                for tgt_weight, src_weight in self.tie_weight_mapping.items():
                    if src_weight in state_dict:
                        state_dict[tgt_weight] = state_dict[src_weight]
            # Collect renamed entries separately so a new key can never clobber an entry
            # that has not been processed yet; popping keeps peak memory unchanged.
            converted = {}
            for ori_key in list(state_dict):
                value = state_dict.pop(ori_key)
                if self._is_fused_linear(ori_key):
                    value = value.view(-1, value.shape[-1])
                converted[to_dcp_key(ori_key)] = value
            # Mutate in place (and return it) to keep the original contract.
            state_dict.update(converted)
            return state_dict

        hf_to_dcp_sharded(
            hf_dir=hf_dir,
            dcp_dir=dcp_dir,
            state_dict_convert_func=state_dict_convert_func
        )

    def dcp_to_hf(
        self,
        load_dir: str = "mm_save_dir/release",
        save_dir: Path = "Qwen3-VL-xxB-hf",
        model_assets_dir: str = "Qwen3-VL-xxB",
        to_bf16: bool = False,
    ):
        """
        Merge torch-dcp shards and convert them back into a Hugging Face checkpoint.

        Typically used after training in torch-dcp format to export a model loadable with
        Hugging Face Transformers.

        Args:
            load_dir: DCP checkpoint directory to read.
            save_dir: Output directory for the HF checkpoint.
            model_assets_dir: Original HF model directory. Its config supplies
                `text_config.num_experts`, and it is forwarded to `merge_dcp_to_hf_sharded`.
            to_bf16: Cast every converted tensor to `torch.bfloat16`.
        """
        config = AutoConfig.from_pretrained(model_assets_dir)
        num_experts = getattr(config.text_config, "num_experts", None)

        def state_dict_convert_func(state_dict):
            # Build renamed entries in a separate dict (see hf_to_dcp) to avoid clobbering
            # unprocessed keys, while still freeing originals as we go.
            converted = {}
            for key in list(state_dict):
                value = state_dict.pop(key)
                # Undo the 2-D flattening done in hf_to_dcp for MoE expert weights.
                # NOTE(review): every key matching a fused name is reshaped when num_experts
                # is set; if a MoE config also has dense MLP layers (plain `down_proj`), this
                # filter would need narrowing — confirm.
                if num_experts and self._is_fused_linear(key):
                    value = value.view(num_experts, -1, value.shape[-1])
                if key.startswith(self.dcp_prefix):
                    new_key = self.hf_prefix + key[len(self.dcp_prefix):]
                else:
                    new_key = key
                if to_bf16:
                    value = value.to(dtype=torch.bfloat16)
                converted[new_key] = value
            state_dict.update(converted)
            return state_dict

        merge_dcp_to_hf_sharded(
            load_dir=Path(load_dir),
            save_dir=Path(save_dir),
            model_assets_dir=Path(model_assets_dir),
            select_key_convert_func=lambda key: f"model.{self.dcp_prefix}" + key,
            state_dict_convert_func=state_dict_convert_func
        )

    def lora_hf_to_dcp(
        self,
        hf_dir: str = "Qwen3-VL-xxB-lora",
        dcp_dir: str = "Qwen3-VL-xxB-dcp-lora",
    ):
        """
        Convert HF-format LoRA adapter weights to torch-dcp format.

        Each key has its first `hf_prefix_lora` occurrence replaced by `dcp_prefix_lora`,
        and its first `.weight` replaced by `.default.weight`.

        Args:
            hf_dir: Source directory with HF-format LoRA weights.
            dcp_dir: Destination DCP directory.
        """
        def state_dict_convert_func(state_dict):
            # Separate output dict: a renamed key must not overwrite an unprocessed entry.
            converted = {}
            for ori_key in list(state_dict):
                value = state_dict.pop(ori_key)
                new_key = ori_key.replace(self.hf_prefix_lora, self.dcp_prefix_lora, 1)
                new_key = new_key.replace(".weight", ".default.weight", 1)
                converted[new_key] = value
            state_dict.update(converted)
            return state_dict

        hf_to_dcp_sharded(
            hf_dir=hf_dir,
            dcp_dir=dcp_dir,
            state_dict_convert_func=state_dict_convert_func
        )

    def merge_mm_lora_dcp_weight_to_base_hf(
        self,
        base_hf_dir: str = "Qwen3-VL-xxB",
        lora_dcp_dir: str = "Qwen3-VL-xx-B-lora-dcp",
        lora_target_modules: str = "",
        save_merged_hf_dir: str = "Qwen3-VL-xxB-merged-hf",
        scaling: float = 1.0
    ):
        """
        Merge a DCP-format LoRA adapter into an HF base checkpoint and save the result as HF.

        For every selected layer that has both `lora_A` and `lora_B` weights, the base weight
        becomes `W + scaling * (B @ A)`, computed in float32 and stored back in W's dtype.
        Layers missing either factor are skipped.

        Args:
            base_hf_dir: HF base model directory; weights, config and processor are read here.
            lora_dcp_dir: DCP directory containing the LoRA adapter weights.
            lora_target_modules: Comma-separated substrings selecting which LoRA weights to
                merge. Empty (or only commas/whitespace) merges every LoRA layer.
            save_merged_hf_dir: Output directory for merged weights, config and processor.
            scaling: Multiplier applied to `B @ A`. It is not derived from any adapter
                config here.

        Raises:
            KeyError: If a selected LoRA layer has no matching weight in the base checkpoint.
        """
        # Drop empty entries: "" is a substring of every key, so a trailing/double comma
        # would otherwise silently select every layer.
        target_module_list = [
            module.strip() for module in lora_target_modules.split(",") if module.strip()
        ]

        def is_target(name):
            # An empty selection keeps the historical default of merging every LoRA layer.
            return not target_module_list or any(
                lora_target_module in name for lora_target_module in target_module_list
            )

        lora_state_dict = load_dcp_state_dict(lora_dcp_dir)
        merge_state_dict = load_from_hf(Path(base_hf_dir))

        # Layer name = LoRA key up to ".lora_", e.g. "<layer>.lora_A.default.weight" -> "<layer>".
        target_layers = {
            name.split('.lora_')[0]
            for name in lora_state_dict
            if 'weight' in name and is_target(name)
        }
        for target_layer in target_layers:
            lora_a_weight = lora_state_dict.get(target_layer + '.lora_A.default.weight', None)
            lora_b_weight = lora_state_dict.get(target_layer + '.lora_B.default.weight', None)
            if lora_a_weight is None or lora_b_weight is None:
                continue
            base_weight_key = f"{target_layer}.weight".replace(self.dcp_prefix_for_lora_base, "")
            try:
                base_weight = merge_state_dict[base_weight_key]
            except KeyError:
                raise KeyError(
                    f"LoRA layer '{target_layer}' maps to base weight '{base_weight_key}', "
                    f"which is not present in {base_hf_dir}"
                ) from None
            # Accumulate in float32; clone so the base tensor is never modified through an alias
            # (.to() returns the same tensor when it is already float32).
            merged_fp32 = base_weight.data.to(dtype=torch.float32).clone()
            # addmm_ requires every operand to share a dtype; the adapter may be stored in bf16.
            merged_fp32.addmm_(
                lora_b_weight.data.to(dtype=torch.float32),
                lora_a_weight.data.to(dtype=torch.float32),
                alpha=scaling,
            )
            base_weight.data = merged_fp32.to(dtype=base_weight.dtype)

        config = AutoConfig.from_pretrained(str(base_hf_dir))
        processor = AutoProcessor.from_pretrained(str(base_hf_dir), trust_remote_code=True)
        save_path = Path(save_merged_hf_dir)
        config.save_pretrained(save_path)
        processor.save_pretrained(save_path)
        save_hf_weights(
            save_dir=save_merged_hf_dir,
            model_assets_dir=str(base_hf_dir),
            state_dict=merge_state_dict,
            prefix="",
        )

    @staticmethod
    def hf_to_mm():
        """Placeholder: HF -> megatron conversion is not implemented for Qwen3-VL (no-op)."""
        pass

    @staticmethod
    def mm_to_hf():
        """Placeholder: megatron -> HF conversion is not implemented for Qwen3-VL (no-op)."""
        pass

    @staticmethod
    def resplit():
        """Placeholder: checkpoint resharding is not implemented for Qwen3-VL (no-op)."""
        pass