from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from .configuration_thermo_qwen import ThermoQwenConfig
@dataclass
class ThermoQwenForecastOutput(ModelOutput):
    """Output container returned by ``ThermoQwenForForecasting.forward``.

    Attributes:
        loss: Huber loss between ``logits`` and the labels; ``None`` when no
            labels were supplied.
        logits: Raw regression-head outputs. ``denormalize_depth`` treats -1
            and 1 on this scale as ``depth_min`` and ``depth_max``.
        pred_depths: ``logits`` mapped onto ``[depth_min, depth_max]`` by
            ``ThermoQwenForForecasting.denormalize_depth``.
    """

    loss: Optional[torch.FloatTensor] = None
    # Annotated Optional because the dataclass default is None.
    logits: Optional[torch.FloatTensor] = None
    pred_depths: Optional[torch.FloatTensor] = None
def _config_from_dict(config_dict):
model_type = config_dict.get("model_type")
kwargs = dict(config_dict)
kwargs.pop("model_type", None)
return AutoConfig.for_model(model_type, **kwargs)
def _pool_hidden_states(hidden_states, attention_mask):
if attention_mask is None:
mean_pool = hidden_states.mean(dim=1)
last_pool = hidden_states[:, -1, :]
return torch.cat([last_pool, mean_pool], dim=-1)
mask = attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
mean_pool = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
seq_lengths = attention_mask.sum(dim=1).clamp(min=1) - 1
last_pool = hidden_states[
torch.arange(hidden_states.size(0), device=hidden_states.device),
seq_lengths,
]
return torch.cat([last_pool, mean_pool], dim=-1)
class ThermoQwenForForecasting(PreTrainedModel):
    """Depth forecaster built on a LoRA-adapted causal-LM backbone.

    The backbone's final hidden states are pooled per sequence (last real
    token concatenated with the masked mean, via ``_pool_hidden_states``) and
    fed to an MLP regression head emitting ``config.horizon`` values per
    sequence. ``denormalize_depth`` maps those values from a [-1, 1] scale to
    ``[config.depth_min, config.depth_max]``; the head itself is unbounded.
    """

    config_class = ThermoQwenConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"model\.lm_head\.weight"]
    # The loss below is a per-sample mean and never uses ``num_items_in_batch``.
    # Recent transformers ``Trainer`` versions treat a ``**kwargs`` forward as
    # "the model normalizes its own loss" and then skip dividing by the
    # gradient-accumulation steps; declaring False keeps the standard scaling.
    accepts_loss_kwargs = False
    # Huber-loss transition point between the quadratic and linear regimes.
    HUBER_DELTA = 0.2

    def __init__(self, config: ThermoQwenConfig):
        super().__init__(config)
        backbone_config = self._resolve_backbone_config(config)
        base_model = AutoModelForCausalLM.from_config(
            backbone_config,
            trust_remote_code=True,
        )
        lora_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            target_modules=config.lora_target_modules,
            bias="none",
            task_type="SEQ_CLS",
        )
        # Keep the LoRA-injected causal LM itself rather than the PEFT wrapper.
        self.model = get_peft_model(base_model, lora_config).model
        # The forecasting wrapper never uses lm_head logits. Keep the outer
        # tied-weight list empty so Transformers does not look for a top-level
        # lm_head on this wrapper.
        self.all_tied_weights_keys = {}
        # Pooling concatenates two hidden-size vectors (last token + mean).
        self.regressor = self._build_regressor(
            backbone_config.hidden_size * 2,
            config.regressor_hidden_sizes,
            config.regressor_dropout,
            config.horizon,
        )

    @staticmethod
    def _resolve_backbone_config(config):
        """Return the backbone config, preferring the inline ``backbone_config``.

        Falls back to loading the config from ``base_model_name_or_path``.

        Raises:
            ValueError: If neither ``backbone_config`` nor
                ``base_model_name_or_path`` is set.
        """
        if config.backbone_config is not None:
            return _config_from_dict(config.backbone_config)
        if config.base_model_name_or_path:
            return AutoConfig.from_pretrained(
                config.base_model_name_or_path,
                trust_remote_code=True,
            )
        raise ValueError(
            "ThermoQwenConfig requires either backbone_config or "
            "base_model_name_or_path."
        )

    @staticmethod
    def _build_regressor(input_size, hidden_sizes, dropouts, output_size):
        """Build the MLP regression head.

        Layout: ``LayerNorm -> [Linear -> GELU -> (Dropout)]* -> Linear``.
        ``nn.Sequential`` names submodules by index, so this order is part of
        the checkpoint's state-dict keys and must stay stable.

        Args:
            input_size: Width of the pooled feature vector.
            hidden_sizes: Iterable of hidden-layer widths; ``None`` or empty
                yields a single linear projection after the LayerNorm.
            dropouts: Per-hidden-layer dropout rates (missing entries mean
                0.0), or a single number applied to every hidden layer.
            output_size: Number of regression outputs.

        Returns:
            The head as an ``nn.Sequential``.
        """
        hidden_sizes = list(hidden_sizes or ())
        if isinstance(dropouts, (int, float)):
            dropouts = [float(dropouts)] * len(hidden_sizes)
        else:
            dropouts = list(dropouts or ())
        layers = [nn.LayerNorm(input_size)]
        in_features = input_size
        for idx, width in enumerate(hidden_sizes):
            layers.extend([nn.Linear(in_features, width), nn.GELU()])
            rate = dropouts[idx] if idx < len(dropouts) else 0.0
            # A zero rate adds no module, keeping state-dict indices compact.
            if rate > 0:
                layers.append(nn.Dropout(rate))
            in_features = width
        layers.append(nn.Linear(in_features, output_size))
        return nn.Sequential(*layers)

    def denormalize_depth(self, logits):
        """Map values on the [-1, 1] scale to ``[depth_min, depth_max]``.

        The mapping is affine, so values outside [-1, 1] extrapolate linearly.
        """
        depth_range = self.config.depth_max - self.config.depth_min
        return (logits + 1.0) / 2.0 * depth_range + self.config.depth_min

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        output_hidden_states=True,
        return_dict=True,
        **kwargs,
    ):
        """Predict ``config.horizon`` regression values per sequence.

        Args:
            input_ids: Token ids forwarded to the backbone.
            attention_mask: Optional padding mask forwarded to the backbone and
                used for pooling (non-zero marks real tokens).
            labels: Optional regression targets compared directly against
                ``logits``; reshaped to the logits shape before the loss.
            output_hidden_states: Accepted for interface compatibility and
                ignored; the backbone is always asked for hidden states.
            return_dict: When False, return a tuple instead of a
                ``ThermoQwenForecastOutput``.
            **kwargs: Extra keyword arguments forwarded to the backbone.

        Returns:
            ``ThermoQwenForecastOutput(loss, logits, pred_depths)``, or when
            ``return_dict`` is False the tuple ``(loss, logits, pred_depths)``
            with ``loss`` omitted if no labels were given.

        Raises:
            RuntimeError: If ``labels`` cannot be reshaped to the logits shape.
        """
        # Run only the decoder stack: lm_head logits are never used here, and
        # projecting every position onto the full vocabulary materializes a
        # (batch, seq_len, vocab_size) tensor for nothing. ``base_model``
        # falls back to the causal LM itself when the backbone defines no
        # base_model_prefix, so ``hidden_states[-1]`` is available either way.
        backbone = self.model.base_model
        outputs = backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **kwargs,
        )
        pooled = _pool_hidden_states(outputs.hidden_states[-1], attention_mask)
        # Match the head's parameter dtype (the backbone may run in bf16/fp16).
        logits = self.regressor(pooled.to(self.regressor[0].weight.dtype))
        loss = None
        if labels is not None:
            # Align shapes explicitly: e.g. (batch,) labels against (batch, 1)
            # logits would otherwise broadcast to (batch, batch) and silently
            # compute the wrong loss.
            target = labels.float().reshape(logits.shape)
            loss = F.huber_loss(logits.float(), target, delta=self.HUBER_DELTA)
        pred_depths = self.denormalize_depth(logits.float())
        if not return_dict:
            output: Tuple[torch.Tensor, ...] = (logits, pred_depths)
            return ((loss,) + output) if loss is not None else output
        return ThermoQwenForecastOutput(
            loss=loss,
            logits=logits,
            pred_depths=pred_depths,
        )