# Commit de51d76f, created on 2025-08-20 (from the commit history view).
import os
from abc import ABC

import torch
from torch import nn


class SoraGRPOModel(nn.Module, ABC):
    """Base ``nn.Module`` for Sora GRPO models.

    Holds four component slots (``ae``, ``diffuser``, ``reward``,
    ``text_encoder``) and provides a helper for building the reward model.
    Subclasses are expected to override :meth:`get_split_modules`, which
    raises ``NotImplementedError`` here.
    """

    def __init__(self):
        super().__init__()
        # Component slots; nothing in this class assigns them.
        self.ae = None
        self.diffuser = None
        self.reward = None
        self.text_encoder = None

    def initialize_reward_model(self, args, device):
        """Build the reward model selected by ``args.mm.model.reward.model_id``.

        Args:
            args: Config object exposing ``args.mm.model.reward``.
            device: Target device, forwarded to :meth:`initialize_hps_model`.

        Returns:
            The dict returned by :meth:`initialize_hps_model`.

        Raises:
            ValueError: If ``model_id`` is not ``"CLIP-ViT-H-14"``, the only
                id handled here.
        """
        reward_model_config = args.mm.model.reward
        if reward_model_config.model_id == "CLIP-ViT-H-14":
            return self.initialize_hps_model(device, reward_model_config)
        raise ValueError("reward model id is wrong.")

    @staticmethod
    def _resolve_map_location(device):
        """Return a ``map_location`` value for ``torch.load`` matching ``device``.

        A bare integer is treated as a device index and becomes
        ``"cuda:<index>"``. Anything else (a device string or a
        ``torch.device``) is returned unchanged, because prefixing it with
        ``"cuda:"`` would produce an invalid value such as ``"cuda:cuda:0"``.
        """
        if isinstance(device, int):
            return f'cuda:{device}'
        return device

    def initialize_hps_model(self, device, reward_model_config):
        """Create the HPS reward model, load its checkpoint and tokenizer.

        Args:
            device: Device index, device string or ``torch.device`` on which
                the model is created and to which it is moved.
            reward_model_config: Config object; this method reads its
                ``ckpt_dir``, ``pretrained``, ``model_name`` and ``load_pt``
                attributes.

        Returns:
            dict with the keys:
                ``'model'``: the model from ``create_model_and_transforms``.
                ``'preprocess_val'``: the validation transform returned with it.
                ``'processor'``: the result of ``get_tokenizer(model_name)``.
                ``'reward_model'``: the result of ``model.to(device)``, set to
                eval mode.
        """
        # Imported here rather than at module level, so hpsv2 is only needed
        # when this method actually runs.
        from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer

        pretrained = os.path.join(reward_model_config.ckpt_dir, reward_model_config.pretrained)
        model, _preprocess_train, preprocess_val = create_model_and_transforms(
            reward_model_config.model_name,
            pretrained,
            precision='amp',
            device=device,
            jit=False,
            force_quick_gelu=False,
            force_custom_text=False,
            force_patch_dropout=False,
            force_image_size=None,
            pretrained_image=False,
            image_mean=None,
            image_std=None,
            light_augmentation=True,
            aug_cfg={},
            output_dict=True,
            with_score_predictor=False,
            with_region_predictor=False
        )

        checkpoint_path = os.path.join(reward_model_config.ckpt_dir, reward_model_config.load_pt)
        # NOTE(review): torch.load unpickles the file; only point load_pt at
        # trusted checkpoints (or confirm weights_only loading is acceptable).
        checkpoint = torch.load(
            checkpoint_path,
            map_location=self._resolve_map_location(device),
        )
        model.load_state_dict(checkpoint['state_dict'])

        processor = get_tokenizer(reward_model_config.model_name)
        reward_model = model.to(device)
        reward_model.eval()

        return {
            'model': model,
            'preprocess_val': preprocess_val,
            'processor': processor,
            'reward_model': reward_model,
        }

    def get_split_modules(self):
        """Placeholder that subclasses must override; always raises here."""
        raise NotImplementedError("Subclasses must implement this method")