# Commit bee271e1, created on March 5, 2024 (from the file's commit history).
from typing import List, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from open_clip.transform import PreprocessCfg, image_transform_v2
from PIL import Image
from transformers import PretrainedConfig, PreTrainedModel


class MatryoshkaNllbClipConfig(PretrainedConfig):
    """Configuration for ``MatryoshkaNllbClip``.

    Args:
        clip_model_name: open_clip model name, passed to ``create_model`` and
            ``get_tokenizer``.
        target_resolution: Input dimension of every Matryoshka projection head
            (i.e. the backbone embedding size fed into the heads).
        mrl_resolutions: Output dimensions of the Matryoshka projection heads.
            ``None`` (the default) means no heads.
        preprocess_cfg: Image preprocessing settings. ``MatryoshkaNllbClip``
            reads the keys ``size``, ``mean``, ``std``, ``interpolation`` and
            ``resize_mode`` from it.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    def __init__(
        self,
        clip_model_name: str = "",
        target_resolution: int = -1,
        mrl_resolutions: Union[List[int], None] = None,
        preprocess_cfg: Union[dict, None] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.clip_model_name = clip_model_name
        self.target_resolution = target_resolution
        # Always store a fresh list: a shared `[]` default (or an aliased caller
        # list) would let one config's mutations leak into every other config.
        self.mrl_resolutions = (
            [] if mrl_resolutions is None else list(mrl_resolutions)
        )
        self.preprocess_cfg = preprocess_cfg


class MatryoshkaLayer(nn.Module):
    """Bank of linear heads that project embeddings to several smaller sizes.

    One ``nn.Linear(target_resolution, r)`` head is created for each ``r`` in
    ``resolutions``. The heads are stored in ``self.layers``, an
    ``nn.ModuleDict`` keyed by ``str(r)``.
    """

    def __init__(self, resolutions: List[int], target_resolution: int = 768):
        """Create one projection head per output resolution.

        Args:
            resolutions: Output dimensions, one linear head each.
            target_resolution: Input feature dimension shared by all heads.
        """
        super().__init__()
        # Keep a private copy so later mutation of the caller's list (e.g. a
        # config's mrl_resolutions) cannot desynchronise it from self.layers.
        self.resolutions = list(resolutions)
        self.layers = nn.ModuleDict(
            {str(r): nn.Linear(target_resolution, r) for r in self.resolutions}
        )

    def forward(self, x, resolution: Union[int, None] = None):
        """Project ``x`` with a single head, or with every head.

        Args:
            x: Tensor whose last dimension equals ``target_resolution``.
            resolution: Which head to apply. If ``None``, all heads are applied.

        Returns:
            A single tensor when ``resolution`` is given. Otherwise a list of
            tensors, one per head, in the order of ``self.resolutions``.

        Raises:
            ValueError: If ``resolution`` has no matching head.
        """
        if resolution is None:
            return [self.layers[str(r)](x) for r in self.resolutions]
        if resolution not in self.resolutions:
            raise ValueError(f"Resolution {resolution} not in {self.resolutions}")
        return self.layers[str(resolution)](x)


class MatryoshkaNllbClip(PreTrainedModel):
    """NLLB-CLIP backbone with Matryoshka (MRL) projection heads.

    Wraps an open_clip model created from ``config.clip_model_name``, together
    with its eval-time image transform and tokenizer. It adds a
    ``MatryoshkaLayer`` that maps ``config.target_resolution``-dimensional
    embeddings to each size in ``config.mrl_resolutions``.

    The ``encode_*``, ``*_features`` and ``get_logits`` helpers run under
    ``torch.inference_mode()``, so their outputs are not differentiable.
    ``forward`` does not use inference mode.
    """

    config_class = MatryoshkaNllbClipConfig

    def __init__(self, config: MatryoshkaNllbClipConfig, device):
        """Build the backbone, image transform, MRL heads and tokenizer.

        Args:
            config: Model configuration. ``config.preprocess_cfg`` must be a
                dict with ``size``, ``mean``, ``std``, ``interpolation`` and
                ``resize_mode`` keys.
            device: ``torch.device`` or device string (e.g. ``"cuda"``). The
                backbone and MRL heads are moved there, and inputs are moved
                there before inference.
        """
        super().__init__(config)
        if isinstance(device, str):
            device = torch.device(device)
        self.config = config
        self.model = create_model(config.clip_model_name, output_dict=True)
        pp_cfg = PreprocessCfg(
            size=config.preprocess_cfg["size"],
            mean=config.preprocess_cfg["mean"],
            std=config.preprocess_cfg["std"],
            interpolation=config.preprocess_cfg["interpolation"],
            resize_mode=config.preprocess_cfg["resize_mode"],
        )
        self.transform = image_transform_v2(
            pp_cfg,
            is_train=False,
        )
        self._device = device
        self.model.to(device)
        self.matryoshka_layer = MatryoshkaLayer(
            config.mrl_resolutions, config.target_resolution
        )
        self.matryoshka_layer.to(device)
        self.tokenizer = get_tokenizer(config.clip_model_name)

    def _check_resolution(self, resolution: Union[int, None]) -> None:
        """Raise ``ValueError`` if ``resolution`` is given but has no MRL head.

        Called before running the backbone so bad arguments fail fast.
        """
        resolutions = self.matryoshka_layer.resolutions
        if resolution is not None and resolution not in resolutions:
            raise ValueError(f"Resolution {resolution} not in {resolutions}")

    def _project(self, features, normalize: bool, resolution: Union[int, None]):
        """Apply the optional MRL head, then the optional L2 normalisation."""
        if resolution is not None:
            features = self.matryoshka_layer(features, resolution)
        return F.normalize(features, dim=-1) if normalize else features

    def forward(self, image_inputs, input_ids, resolution: Union[int, None] = None):
        """Run the full backbone on images and token ids.

        Args:
            image_inputs: Image tensor accepted by the backbone. It is moved to
                the model device.
            input_ids: Token-id tensor accepted by the backbone. It is moved to
                the model device.
            resolution: If given, also project both feature sets with the
                matching MRL head.

        Returns:
            Dict with ``image_features``, ``text_features``,
            ``mrl_image_features``, ``mrl_text_features`` (``None`` unless
            ``resolution`` is given), ``logit_scale`` and ``logit_bias``
            (``None`` if the backbone output does not provide one).

        Raises:
            ValueError: If ``resolution`` has no MRL head.
        """
        image_inputs = image_inputs.to(self._device)
        input_ids = input_ids.to(self._device)
        outputs = self.model(
            image=image_inputs,
            text=input_ids,
        )
        mrl_image_features = None
        mrl_text_features = None
        if resolution is not None:
            # NOTE(review): these heads see the backbone's forward() features,
            # while encode_*/*_features feed them raw visual/text outputs. If
            # forward() L2-normalises, the two paths differ — confirm which
            # input the heads were trained on.
            mrl_image_features = self.matryoshka_layer(
                outputs["image_features"], resolution
            )
            mrl_text_features = self.matryoshka_layer(
                outputs["text_features"], resolution
            )
        return {
            "image_features": outputs["image_features"],
            "text_features": outputs["text_features"],
            "mrl_image_features": mrl_image_features,
            "mrl_text_features": mrl_text_features,
            "logit_scale": outputs["logit_scale"],
            # logit_bias is optional (get_logits checks for None), and the
            # output dict may omit the key entirely, so avoid a KeyError.
            "logit_bias": outputs.get("logit_bias"),
        }

    def encode_image(
        self,
        image,
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        """Embed an already-preprocessed image batch.

        Args:
            image: Image tensor accepted by ``self.model.visual`` (e.g. stacked
                outputs of ``self.transform``). It is moved to the model device.
            normalize: L2-normalise the embeddings along the last dimension.
            resolution: If given, project with the matching MRL head.

        Returns:
            The image embeddings.

        Raises:
            ValueError: If ``resolution`` has no MRL head.
        """
        self._check_resolution(resolution)
        image = image.to(self._device)
        with torch.inference_mode():
            return self._project(self.model.visual(image), normalize, resolution)

    def encode_text(
        self,
        text,
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        """Embed an already-tokenised text batch.

        Args:
            text: Token-id tensor accepted by ``self.model.text`` (e.g. as built
                by ``text_features``). It is moved to the model device.
            normalize: L2-normalise the embeddings along the last dimension.
            resolution: If given, project with the matching MRL head.

        Returns:
            The text embeddings.

        Raises:
            ValueError: If ``resolution`` has no MRL head.
        """
        self._check_resolution(resolution)
        text = text.to(self._device)
        with torch.inference_mode():
            return self._project(self.model.text(text), normalize, resolution)

    def image_features(
        self,
        images: List[Image.Image],
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        """Preprocess PIL images with ``self.transform`` and embed them.

        Args:
            images: Images to embed. They are stacked into one batch, so the
                list must be non-empty.
            normalize: L2-normalise the embeddings along the last dimension.
            resolution: If given, project with the matching MRL head.

        Returns:
            One embedding per image, in input order.

        Raises:
            ValueError: If ``resolution`` has no MRL head.
        """
        self._check_resolution(resolution)
        batch = torch.stack([self.transform(image) for image in images], dim=0)
        return self.encode_image(batch, normalize=normalize, resolution=resolution)

    def text_features(
        self,
        texts: List[str],
        langs: Union[List[str], None] = None,
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        """Tokenise texts with their language codes and embed them.

        Each text is prefixed directly (no separator) with its language code
        before tokenisation. Special tokens are not added, and sequences are
        padded to the longest one in the batch.

        Args:
            texts: Texts to embed.
            langs: One language code per text. Defaults to ``"eng_Latn"`` for
                every text.
            normalize: L2-normalise the embeddings along the last dimension.
            resolution: If given, project with the matching MRL head.

        Returns:
            One embedding per text, in input order.

        Raises:
            ValueError: If ``langs`` and ``texts`` differ in length, or if
                ``resolution`` has no MRL head.
        """
        if langs is None:
            langs = ["eng_Latn"] * len(texts)
        elif len(langs) != len(texts):
            # zip() would silently drop the unmatched tail otherwise.
            raise ValueError(
                f"Got {len(langs)} language codes for {len(texts)} texts"
            )
        self._check_resolution(resolution)
        prefixed = [f"{lang}{text}" for lang, text in zip(langs, texts)]
        input_ids = self.tokenizer.tokenizer.batch_encode_plus(
            prefixed, return_tensors="pt", padding="longest", add_special_tokens=False
        )["input_ids"]
        return self.encode_text(input_ids, normalize=normalize, resolution=resolution)

    def get_logits(
        self,
        images: List[Image.Image],
        texts: List[str],
        langs: Union[List[str], None] = None,
        resolution: Union[int, None] = None,
    ):
        """Compute image-text similarity logits.

        Both feature sets are L2-normalised. ``image_logits[i, j]`` is therefore
        ``exp(logit_scale)`` times the cosine similarity of image ``i`` and
        text ``j``, plus ``logit_bias`` when the backbone has one.

        Args:
            images: Images to compare.
            texts: Texts to compare.
            langs: Optional per-text language codes (see ``text_features``).
            resolution: If given, compare in that MRL head's embedding space.

        Returns:
            ``(image_logits, text_logits)``, where ``text_logits`` is the
            transpose of ``image_logits``.
        """
        image_features = self.image_features(
            images, normalize=True, resolution=resolution
        )
        text_features = self.text_features(
            texts, langs, normalize=True, resolution=resolution
        )
        with torch.inference_mode():
            image_logits = (
                self.model.logit_scale.exp() * image_features @ text_features.T
            )
            if self.model.logit_bias is not None:
                image_logits += self.model.logit_bias
        text_logits = image_logits.T
        return image_logits, text_logits