from typing import List, Optional
import os
import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from transformers.modeling_utils import PreTrainedModel
from transformers.models.qwen3 import Qwen3Model
from peft import PeftMixedModel, PeftConfig
from .configuration_jina_embeddings_v5 import JinaEmbeddingsV5Config
class JinaEmbeddingsV5Model(PeftMixedModel):
    """Qwen3 backbone wrapped with one PEFT adapter per embedding task.

    ``from_pretrained`` loads the backbone, attaches every adapter listed in
    ``config.task_names`` and stores a tokenizer on the instance.  ``encode``
    then activates the adapter for the requested task and returns
    L2-normalised last-token embeddings.
    """

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoModel"):
        """Register this class with a transformers ``Auto*`` class.

        This class derives from ``PeftMixedModel`` rather than
        ``PreTrainedModel``, so the function underlying
        ``PreTrainedModel.register_for_auto_class`` is borrowed and invoked
        with ``cls`` explicitly.
        """
        return PreTrainedModel.register_for_auto_class.__func__(cls, auto_class)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
        """Build the model, all task adapters and the tokenizer.

        Args:
            pretrained_model_name_or_path: Local directory or hub repo id.
            *args: Accepted for signature compatibility; not used.
            **kwargs: ``config`` (optional pre-built config) and ``dtype``
                (defaults to ``torch.bfloat16``) are consumed here.  All
                remaining entries are forwarded to
                ``PeftConfig.from_pretrained`` and to every ``load_adapter``
                call.

        Returns:
            The model with every adapter in ``config.task_names`` loaded and
            a ``tokenizer`` attribute set.

        Raises:
            KeyError: If ``"retrieval"`` is not among ``config.task_names``.
        """
        # Pop unconditionally so that an explicit ``config=None`` is not
        # forwarded to the PEFT calls below via ``**kwargs``.
        base_config = kwargs.pop("config", None)
        if base_config is None:
            base_config = JinaEmbeddingsV5Config.from_pretrained(
                pretrained_model_name_or_path,
            )

        base_model = Qwen3Model.from_pretrained(
            pretrained_model_name_or_path,
            config=base_config,
            dtype=kwargs.pop("dtype", torch.bfloat16),
        )

        # Adapters live in an ``adapters/`` sub-folder next to the weights.
        # For a non-local path only that sub-folder is fetched from the hub.
        if os.path.isdir(base_model.name_or_path):
            adapters_dir = os.path.join(base_model.name_or_path, "adapters")
        else:
            adapter_cache_path = snapshot_download(
                repo_id=base_model.name_or_path,
                allow_patterns=["adapters/*"],
            )
            adapters_dir = os.path.join(adapter_cache_path, "adapters")

        adapter_paths = {
            name: os.path.join(adapters_dir, name)
            for name in base_config.task_names
        }

        # The wrapper is constructed around the "retrieval" adapter config;
        # the loop below then loads every adapter, "retrieval" included.
        peft_config = PeftConfig.from_pretrained(adapter_paths["retrieval"], **kwargs)
        model = cls(base_model, peft_config, adapter_name="retrieval")
        model._pretrained_path = pretrained_model_name_or_path

        for adapter_name in base_config.task_names:
            model.load_adapter(
                adapter_paths[adapter_name],
                adapter_name=adapter_name,
                **kwargs,
            )

        model.tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=True,
        )
        return model

    @staticmethod
    def _last_token_pool(hidden, mask):
        """Pick the hidden state of the last real token of each sequence.

        Works for both padding sides: when every row of ``mask`` ends in a
        real token (left padding, or no padding at all) the last position is
        used directly; otherwise the index is derived from the per-row token
        count, which is correct for right-padded batches.
        """
        if mask is None:
            return hidden[:, -1]
        if bool(mask[:, -1].all()):
            return hidden[:, -1]
        last_positions = mask.sum(dim=1) - 1
        rows = torch.arange(hidden.shape[0], device=hidden.device)
        return hidden[rows, last_positions]

    def encode(
        self,
        texts: List[str],
        task: str,
        prompt_name: Optional[str] = "document",
        truncate_dim: Optional[int] = None,
        max_length: Optional[int] = None,
    ) -> torch.Tensor:
        """Embed ``texts`` with the adapter for ``task``.

        Args:
            texts: Strings to embed.  A single bare ``str`` is treated as a
                one-item batch instead of being iterated per character.
            task: Adapter to activate; must be in ``config.task_names``.
            prompt_name: ``"query"`` or ``"document"``; selects the
                ``"Query: "`` / ``"Document: "`` prefix.  ``None`` means
                ``"document"``.
            truncate_dim: If given, keep only the first ``truncate_dim``
                components before normalisation.
            max_length: Tokenizer truncation length; falls back to
                ``config.max_position_embeddings`` when unset.

        Returns:
            A 2-D tensor with one L2-normalised row per input text.

        Raises:
            ValueError: On an unknown ``task`` or ``prompt_name``, or when no
                tokenizer is attached to the model.

        Note:
            The model is switched to eval mode and the active adapter is
            changed to ``task``; neither is restored afterwards.
        """
        if task not in self.base_model.config.task_names:
            raise ValueError(f"Unknown task: {task}")
        if prompt_name is None:
            prompt_name = "document"
        if prompt_name not in {"query", "document"}:
            raise ValueError(f"Unknown prompt_name: {prompt_name}")
        if getattr(self, "tokenizer", None) is None:
            raise ValueError("Tokenizer not found on model. Load with from_pretrained().")

        # Iterating a bare string would embed every character separately.
        if isinstance(texts, str):
            texts = [texts]

        prefix = "Query: " if prompt_name == "query" else "Document: "
        inputs = [f"{prefix}{text}" for text in texts]

        reference_param = next(self.parameters())
        device = reference_param.device

        if not inputs:
            # Nothing to tokenize: produce an empty batch and let it flow
            # through the same truncation / normalisation as real output.
            # NOTE(review): assumes the backbone config exposes
            # ``hidden_size`` -- confirm against the config class.
            pooled = torch.empty(
                (0, self.config.hidden_size),
                dtype=reference_param.dtype,
                device=device,
            )
        else:
            max_length = max_length or self.config.max_position_embeddings
            batch = self.tokenizer(
                inputs,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            )
            batch = {k: v.to(device) for k, v in batch.items()}

            self.set_adapter([task])
            self.eval()
            with torch.no_grad():
                outputs = self(**batch)

            pooled = self._last_token_pool(
                outputs.last_hidden_state,
                batch.get("attention_mask"),
            )

        if truncate_dim is not None:
            pooled = pooled[:, :truncate_dim]
        # Normalise after truncation so the returned vectors are unit length.
        return F.normalize(pooled, p=2, dim=-1)