import abc
import os
import sys
import re
from tqdm import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoConfig, AutoModelForSequenceClassification
from peft import get_peft_model, LoraConfig, TaskType
from mindspeed_llm.tasks.checkpoint.models import ModelBase
def register_functions(self):
    """Bind per-key accessor methods generated from ``self.module_mapping``.

    ``self.get_module_mapping()`` is called first. Then, for every
    ``key -> path`` entry of ``self.module_mapping``, the following bound
    methods are attached to this instance (all take keyword arguments only):

    * ``get_<key>_module`` / ``get_<key>_weight`` / ``get_<key>_bias``
    * ``set_<key>_module`` / ``set_<key>_weight`` / ``set_<key>_bias``
      (the new value is passed as ``data=...``)
    * ``has_<key>_module`` / ``has_<key>_bias``

    ``path`` is read as a sequence of attribute names, each optionally
    followed by ``[ident]``. Resolution starts at
    ``self.get_model_item(**kwargs)``; ``ident`` is looked up in
    ``self.kwargs_idx`` (refreshed via ``self.update_kwargs_idx(**kwargs)``)
    and used to index the attribute.

    Getters and ``has_*`` resolve ``layer_idx`` from ``src_layer_idx`` when
    ``layer_idx`` is not given; setters fall back to ``dst_layer_idx``.

    Raises:
        AssertionError: from a generated accessor, when a ``[ident]``
            placeholder in the path is not a key of ``self.kwargs_idx``.
    """
    self.get_module_mapping()

    # One path segment: an attribute name, optionally followed by "[ident]".
    segment_re = re.compile(r'(\w+)(?:\[(\w+)\])?')

    def _get_obj(self, value, **kwargs):
        """Walk ``value`` from the model item; return None if an attribute is missing."""
        self.update_kwargs_idx(**kwargs)
        obj = self.get_model_item(**kwargs)
        for attr, attr_ident in segment_re.findall(value):
            if not hasattr(obj, attr):
                return None
            obj = getattr(obj, attr)
            if not attr_ident:
                continue
            if attr_ident not in self.kwargs_idx:
                raise AssertionError(f"check {self.__class__.__name__}.module_mapping **{attr_ident}**.")
            obj = obj[self.kwargs_idx[attr_ident]]
        return obj

    def _get_dst_obj(self, value, **kwargs):
        """Resolve ``value`` on the destination side (``dst_layer_idx`` fallback)."""
        if kwargs.get("layer_idx") is None:
            kwargs["layer_idx"] = kwargs.get("dst_layer_idx")
        return _get_obj(self, value, **kwargs)

    def _get_src_obj(self, value, **kwargs):
        """Resolve ``value`` on the source side (``src_layer_idx`` fallback)."""
        if kwargs.get("layer_idx") is None:
            kwargs["layer_idx"] = kwargs.get("src_layer_idx")
        return _get_obj(self, value, **kwargs)

    def _cast_like(data, reference):
        """Return ``data`` converted to ``reference.dtype`` when the dtypes differ."""
        if data.dtype != reference.dtype:
            data = data.to(dtype=reference.dtype)
        return data

    def _func_generator_get_module(value):
        def func(self, **kwargs):
            return _get_src_obj(self, value, **kwargs)
        return func

    def _func_generator_get_weight(value):
        def func(self, **kwargs):
            return _get_src_obj(self, value, **kwargs).weight.data
        return func

    def _func_generator_get_bias(value):
        def func(self, **kwargs):
            return _get_src_obj(self, value, **kwargs).bias.data
        return func

    def _func_generator_set_weight(value):
        def func(self, **kwargs):
            target = _get_dst_obj(self, value, **kwargs)
            target.weight.data = _cast_like(kwargs.get('data'), target.weight)
            return target.weight.data
        return func

    def _func_generator_set_module(value):
        def func(self, **kwargs):
            # In-place copy into the resolved object's existing storage.
            return _get_dst_obj(self, value, **kwargs).data.copy_(kwargs.get('data'))
        return func

    def _func_generator_set_bias(value):
        def func(self, **kwargs):
            target = _get_dst_obj(self, value, **kwargs)
            # Cast against the bias's own dtype: the weight dtype is not a
            # reliable stand-in when the two parameters differ in precision.
            target.bias.data = _cast_like(kwargs.get('data'), target.bias)
            return target.bias.data
        return func

    def _func_generator_has_module(value):
        def func(self, **kwargs):
            # _get_obj signals "missing" with None. Truthiness would raise for
            # objects with an ambiguous bool and mis-report empty containers.
            return _get_src_obj(self, value, **kwargs) is not None
        return func

    def _func_generator_has_bias(value):
        def func(self, **kwargs):
            bias = getattr(_get_src_obj(self, value, **kwargs), 'bias', None)
            return bias is not None
        return func

    # Method-name template -> factory producing the unbound accessor.
    generators = {
        "get_{}_module": _func_generator_get_module,
        "set_{}_module": _func_generator_set_module,
        "get_{}_weight": _func_generator_get_weight,
        "get_{}_bias": _func_generator_get_bias,
        "set_{}_weight": _func_generator_set_weight,
        "set_{}_bias": _func_generator_set_bias,
        "has_{}_module": _func_generator_has_module,
        "has_{}_bias": _func_generator_has_bias,
    }

    if self.module_mapping:
        owner = type(self)
        for key, value in self.module_mapping.items():
            for template, make in generators.items():
                # function.__get__ binds `self`; the owner argument is ignored
                # for plain functions, so the result is an ordinary bound method.
                setattr(self, template.format(key), make(value).__get__(self, owner))
def get_modules_from_pretrained(self, device_map="cpu", trust_remote_code=True):
    """Load the Hugging Face causal-LM checkpoint into ``self.module``.

    The checkpoint directory is ``self.args_cmd.save_dir`` when
    ``self.args_cmd.save_model_type == "hf"``, otherwise
    ``self.args_cmd.load_dir``. The model is loaded from local files only and
    stored as a one-element list in ``self.module``.

    When ``self.args_cmd.save_lora_to_hf`` is set, the model is wrapped with a
    PEFT LoRA adapter built from ``lora_r``, ``lora_alpha`` and
    ``self.target_lora_modules_hf``. When ``self.args.torch_dtype`` is
    ``"float16"`` or ``"bfloat16"``, the model is converted to that dtype.

    Args:
        device_map: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
        trust_remote_code: Forwarded to ``from_pretrained``.
            NOTE(review): the default of True lets custom modelling code
            shipped with the checkpoint run; confirm checkpoints are trusted.
    """
    if self.args_cmd.save_model_type == "hf":
        load_dir = self.args_cmd.save_dir
    else:
        load_dir = self.args_cmd.load_dir

    base_model = AutoModelForCausalLM.from_pretrained(
        load_dir,
        device_map=device_map,
        trust_remote_code=trust_remote_code,
        local_files_only=True,
        low_cpu_mem_usage=False,
    )
    self.module = [base_model]

    if self.args_cmd.save_lora_to_hf:
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=self.args_cmd.lora_r,
            lora_alpha=self.args_cmd.lora_alpha,
            target_modules=self.target_lora_modules_hf,
            lora_dropout=0.0,
            bias="none",
        )
        self.module = [get_peft_model(self.module[0], lora_config)]

    # Only the two whitelisted half-precision names trigger a conversion.
    # The dtype is looked up as an attribute of torch rather than via eval().
    dtype_name = getattr(self.args, "torch_dtype", None)
    if dtype_name in ("float16", "bfloat16"):
        self.module[0] = self.module[0].to(getattr(torch, dtype_name))