from typing import List, Optional
import os

import torch
import torch.nn.functional as F

from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from transformers.modeling_utils import PreTrainedModel
from transformers.models.qwen3 import Qwen3Model
from peft import PeftMixedModel, PeftConfig
from .configuration_jina_embeddings_v5 import JinaEmbeddingsV5Config


class JinaEmbeddingsV5Model(PeftMixedModel):
    """Qwen3 backbone with one PEFT adapter per task.

    ``from_pretrained`` loads a ``Qwen3Model``, attaches every adapter listed in
    ``config.task_names`` from the checkpoint's ``adapters/<task_name>``
    directories, and stores a tokenizer on the model. ``encode`` activates the
    adapter for the requested task and returns L2-normalized last-token
    embeddings.
    """

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoModel"):
        """Register this class with a transformers auto class.

        This class does not inherit from ``PreTrainedModel``, so it borrows that
        class's implementation.
        """
        # ``__func__`` unwraps the classmethod so it is bound to ``cls`` rather
        # than to ``PreTrainedModel``.
        return PreTrainedModel.register_for_auto_class.__func__(cls, auto_class)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
        """Load the base model, every task adapter, and the tokenizer.

        Args:
            pretrained_model_name_or_path: Local directory or Hub repo id. Task
                adapters are read from its ``adapters/<task_name>`` subdirectories.
            *args: Accepted for signature compatibility; unused.
            **kwargs: The following keys are consumed here:
                ``config`` is a ``JinaEmbeddingsV5Config``; it is loaded from
                ``pretrained_model_name_or_path`` when missing or ``None``.
                ``dtype`` (or the legacy ``torch_dtype``) sets the base-model
                dtype and defaults to ``torch.bfloat16``.
                All remaining keys are forwarded to ``PeftConfig.from_pretrained``
                and ``load_adapter``.

        Returns:
            The wrapped model with every task adapter loaded and a ``tokenizer``
            attribute set.

        Raises:
            ValueError: If ``config.task_names`` is empty.
        """
        # Pop (not get) so an explicit ``config=None`` is not forwarded to the
        # PEFT calls below.
        base_config = kwargs.pop("config", None)
        if base_config is None:
            base_config = JinaEmbeddingsV5Config.from_pretrained(
                pretrained_model_name_or_path,
            )

        # Honor the legacy ``torch_dtype`` spelling as the fallback for
        # ``dtype``, and consume both so neither leaks into the adapter loaders.
        legacy_dtype = kwargs.pop("torch_dtype", None)
        default_dtype = torch.bfloat16 if legacy_dtype is None else legacy_dtype
        base_model = Qwen3Model.from_pretrained(
            pretrained_model_name_or_path,
            config=base_config,
            dtype=kwargs.pop("dtype", default_dtype),
        )

        # Read adapters from disk for a local checkpoint; otherwise download only
        # the ``adapters/`` subtree of the Hub repo.
        # NOTE(review): Hub options (e.g. revision/token) in kwargs are not passed
        # to Qwen3Model.from_pretrained or snapshot_download.
        if os.path.isdir(base_model.name_or_path):
            adapters_dir = os.path.join(base_model.name_or_path, "adapters")
        else:
            adapter_cache_path = snapshot_download(
                repo_id=base_model.name_or_path,
                allow_patterns=["adapters/*"],
            )
            adapters_dir = os.path.join(adapter_cache_path, "adapters")
        adapter_paths = {
            name: os.path.join(adapters_dir, name)
            for name in base_config.task_names
        }
        if not adapter_paths:
            raise ValueError("config.task_names must list at least one task adapter")

        # The wrapper is constructed from one adapter config. Prefer "retrieval"
        # (the historical default), falling back to the first configured task.
        initial_adapter = (
            "retrieval" if "retrieval" in adapter_paths else next(iter(adapter_paths))
        )
        peft_config = PeftConfig.from_pretrained(adapter_paths[initial_adapter], **kwargs)
        model = cls(base_model, peft_config, adapter_name=initial_adapter)
        model._pretrained_path = pretrained_model_name_or_path

        # Every task, including the initial one, goes through load_adapter --
        # presumably because ``cls(...)`` only builds the adapter from its config
        # without loading checkpoint weights. Verify before skipping it.
        for adapter_name, adapter_path in adapter_paths.items():
            model.load_adapter(
                adapter_path,
                adapter_name=adapter_name,
                **kwargs,
            )

        model.tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=True,
        )
        return model

    @staticmethod
    def _last_token_pool(
        hidden: torch.Tensor, attention_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Select each row's hidden state at its last non-padding position.

        The index is the position of the last ``1`` in each mask row, so this is
        correct for both left- and right-padded batches. ``mask.sum() - 1`` is
        only correct for right padding.
        """
        if attention_mask is None:
            return hidden[:, -1]
        seq_len = attention_mask.shape[1]
        # argmax returns the first maximal index, so on the flipped mask it finds
        # the last 1 counted from the right.
        flipped_first_one = attention_mask.long().flip(dims=[1]).argmax(dim=1)
        last_index = seq_len - 1 - flipped_first_one
        batch_index = torch.arange(hidden.shape[0], device=hidden.device)
        return hidden[batch_index, last_index]

    def encode(
        self,
        texts: List[str],
        task: str,
        prompt_name: Optional[str] = "document",
        truncate_dim: Optional[int] = None,
        max_length: Optional[int] = None,
    ) -> torch.Tensor:
        """Embed ``texts`` using the adapter for ``task``.

        Each text is processed as follows:
        1. Prefix it with ``"Query: "`` or ``"Document: "``.
        2. Tokenize it (with padding and truncation) and run it through the model.
        3. Pool at its last non-padding token.
        4. Optionally truncate the result to ``truncate_dim``.
        5. L2-normalize it.

        Side effects: activates the ``task`` adapter and puts the model in eval
        mode.

        Args:
            texts: Input strings.
            task: One of ``config.task_names``; selects the active adapter.
            prompt_name: ``"query"`` or ``"document"``; ``None`` means
                ``"document"``.
            truncate_dim: If given, keep only the first ``truncate_dim``
                embedding dimensions before normalizing. Must be positive.
            max_length: Tokenizer truncation length; defaults to
                ``config.max_position_embeddings``.

        Returns:
            A tensor with one L2-normalized embedding row per input text.

        Raises:
            ValueError: On an unknown ``task`` or ``prompt_name``, a
                non-positive ``truncate_dim``, or a missing tokenizer.
        """
        if task not in self.base_model.config.task_names:
            raise ValueError(f"Unknown task: {task}")

        if prompt_name is None:
            prompt_name = "document"
        if prompt_name not in {"query", "document"}:
            raise ValueError(f"Unknown prompt_name: {prompt_name}")

        # Fail fast: a negative slice bound would silently drop trailing dims.
        if truncate_dim is not None and truncate_dim < 1:
            raise ValueError(
                f"truncate_dim must be a positive integer, got {truncate_dim}"
            )

        prefix = "Query: " if prompt_name == "query" else "Document: "
        inputs = [f"{prefix}{text}" for text in texts]

        if getattr(self, "tokenizer", None) is None:
            raise ValueError("Tokenizer not found on model. Load with from_pretrained().")

        max_length = max_length or self.config.max_position_embeddings
        batch = self.tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        device = next(self.parameters()).device
        batch = {k: v.to(device) for k, v in batch.items()}
        self.set_adapter([task])
        self.eval()
        with torch.no_grad():
            outputs = self(**batch)
            pooled = self._last_token_pool(
                outputs.last_hidden_state, batch.get("attention_mask")
            )
            if truncate_dim is not None:
                pooled = pooled[:, :truncate_dim]
            embeddings = F.normalize(pooled, p=2, dim=-1)

        return embeddings