import sys
import os
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from torch.optim import AdamW
from transformers import AutoModel
# Put the sibling "deepseekocr" directory (one level above this file's own
# directory) at the front of sys.path, unless it is already listed there.
# NOTE(review): presumably needed so that modules inside that directory can be
# imported by bare name — confirm against the deepseekocr example's imports.
ocr_dir = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "deepseekocr",
)
if ocr_dir not in sys.path:
    sys.path.insert(0, ocr_dir)
from examples.deepseekocr.finetune_ocr import OCRTrainer, get_parser
class OCR2Trainer(OCRTrainer):
    """OCRTrainer subclass that supplies its own ``build_model_and_optimizer``.

    The build step loads the model with ``AutoModel.from_pretrained``,
    optionally freezes sub-modules according to the ``freeze_*`` config flags,
    wraps the model with ``fully_shard`` and creates an AdamW optimizer plus a
    cosine-with-warmup scheduler.
    """

    def __init__(self, config):
        """Forward ``config`` unchanged to the ``OCRTrainer`` constructor."""
        super().__init__(config)

    def build_model_and_optimizer(self, attn_implementation="eager"):
        """Populate ``self.model``, ``self.optimizer`` and ``self.scheduler``.

        Args:
            attn_implementation: Value handed to ``AutoModel.from_pretrained``
                as the ``_attn_implementation`` keyword.

        Config attributes read: ``load``, ``freeze_tokenizer``,
        ``freeze_vis_encoder``, ``freeze_decoder``, ``lr``, ``weight_decay``,
        ``warmup_ratio`` and ``train_iters``.
        """
        cfg = self.config

        # Load the checkpoint, then move it to the "cuda" device in bfloat16.
        # NOTE(review): "cuda" is presumably redirected to the NPU by the
        # transfer_to_npu import at the top of the file — confirm.
        loaded = AutoModel.from_pretrained(
            cfg.load,
            _attn_implementation=attn_implementation,
            trust_remote_code=True,
            use_safetensors=True,
        )
        self.model = loaded.to("cuda", dtype=torch.bfloat16)
        backbone = self.model.model

        # Optional freezing; each flag disables gradients for its sub-modules.
        if cfg.freeze_tokenizer:
            backbone.sam_model.requires_grad_(False)
        if cfg.freeze_vis_encoder:
            for frozen in (backbone.qwen2_model, backbone.projector):
                frozen.requires_grad_(False)
        if cfg.freeze_decoder:
            for frozen in (backbone.layers, backbone.embed_tokens):
                frozen.requires_grad_(False)

        # bfloat16 parameters with float32 gradient reduction.
        fsdp_kwargs = {
            "mp_policy": MixedPrecisionPolicy(
                param_dtype=torch.bfloat16,
                reduce_dtype=torch.float32,
            ),
        }

        # Shard inner modules first and the root model last.
        # NOTE(review): embed_tokens is the only module sharded without the
        # mixed-precision policy — confirm that this is intentional.
        fully_shard(backbone.embed_tokens)
        for stack_owner in ("decoder", "vision"):
            if stack_owner == "decoder":
                layer_stack = backbone.layers
            else:
                layer_stack = backbone.qwen2_model.model.model.layers
            for block in layer_stack:
                fully_shard(block, **fsdp_kwargs)
        fully_shard(self.model.lm_head, **fsdp_kwargs)
        fully_shard(self.model, **fsdp_kwargs)

        # Print the wrapped model structure once, from rank 0 only.
        if torch.distributed.get_rank() == 0:
            print(self.model)

        self.optimizer = AdamW(
            self.model.parameters(),
            lr=cfg.lr,
            weight_decay=cfg.weight_decay,
        )

        # Warmup length is a fraction of the total iterations, truncated to int.
        warmup_steps = int(cfg.warmup_ratio * cfg.train_iters)
        self.scheduler = OCR2Trainer.get_cosine_schedule_with_warmup(
            optimizer=self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=cfg.train_iters,
        )
def get_ocr2_parser():
    """Return the parser from ``get_parser()`` extended with freeze switches.

    Adds ``--freeze-tokenizer``, ``--freeze-vis-encoder`` and
    ``--freeze-decoder``. Each is a ``store_true`` flag that defaults to
    ``False``.
    """
    parser = get_parser()

    # (option string, help text) pairs, registered in this order.
    freeze_switches = (
        (
            '--freeze-tokenizer',
            'Whether or not to freeze tokenizer weight in training.',
        ),
        (
            '--freeze-vis-encoder',
            'Whether or not to freeze vision encoder weight in training.',
        ),
        (
            '--freeze-decoder',
            'Whether or not to freeze decoder weight in training.',
        ),
    )
    for option, description in freeze_switches:
        parser.add_argument(
            option,
            action='store_true',
            default=False,
            help=description,
        )
    return parser
def main():
    """Parse command-line arguments and run OCR2 fine-tuning.

    Distributed state is set up before the trainer is created and is torn
    down via ``OCR2Trainer.cleanup_distributed()`` even when trainer
    construction or training raises, so a failed run does not skip cleanup.
    Any exception still propagates to the caller after cleanup.
    """
    args = get_ocr2_parser().parse_args()

    # Setup stays outside the try block: if it fails there is nothing to
    # clean up, and a cleanup error must not mask the original setup error.
    OCR2Trainer.setup_distributed()
    try:
        ocr_trainer = OCR2Trainer(args)
        ocr_trainer.train()
    finally:
        OCR2Trainer.cleanup_distributed()
if __name__ == "__main__":
    # Set the NPU ``allow_internal_format`` config flag to False before the
    # training entry point runs.
    torch.npu.config.allow_internal_format = False
    main()