from transformers import PretrainedConfig


class ThermoQwenConfig(PretrainedConfig):
    """Configuration container for the ``thermo_qwen_tsf`` model type.

    Every constructor argument is stored on the instance under the same
    name; unrecognised keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        base_model_name_or_path: Stored as-is. Presumably the identifier or
            path of the backbone checkpoint -- confirm against the model code.
        backbone_config: Stored as-is. Presumably the backbone's own
            configuration (object or dict) -- confirm against the model code.
        horizon: Stored as-is (default ``5``). Presumably the number of
            forecast steps -- confirm against the model code.
        depth_min: Stored as-is (default ``-150.0``).
        depth_max: Stored as-is (default ``-25.0``).
        lora_r: Stored as-is (default ``8``).
        lora_alpha: Stored as-is (default ``16``).
        lora_dropout: Stored as-is (default ``0.05``).
        lora_target_modules: Module names to target. ``None`` or any other
            falsy value (e.g. an empty list) is replaced by
            ``["k_proj", "q_proj", "v_proj", "o_proj"]``.
        regressor_hidden_sizes: ``None`` or any other falsy value is replaced
            by ``[256, 64]``.
        regressor_dropout: Iterable of dropout values, stored as a new
            ``list``. ``None`` selects the default ``[0.2, 0.1]``.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "thermo_qwen_tsf"

    def __init__(
        self,
        base_model_name_or_path=None,
        backbone_config=None,
        horizon=5,
        depth_min=-150.0,
        depth_max=-25.0,
        lora_r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        lora_target_modules=None,
        regressor_hidden_sizes=None,
        regressor_dropout=(0.2, 0.1),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model_name_or_path = base_model_name_or_path
        self.backbone_config = backbone_config
        self.horizon = horizon
        self.depth_min = depth_min
        self.depth_max = depth_max
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        self.lora_target_modules = lora_target_modules or [
            "k_proj",
            "q_proj",
            "v_proj",
            "o_proj",
        ]
        self.regressor_hidden_sizes = regressor_hidden_sizes or [256, 64]
        # Accept None like the other sequence-valued options do, instead of
        # failing inside list(None) when a caller (or a serialized config
        # containing null) leaves the value unset.
        if regressor_dropout is None:
            regressor_dropout = (0.2, 0.1)
        # Always store a fresh list so the default tuple and a caller-supplied
        # sequence end up with the same type.
        self.regressor_dropout = list(regressor_dropout)