# Last commit: Benedikt Schifferer - "modify README"
# Commit 16cdad8c, created on 2025-06-28 (commit history)
# --------------------------------------------------------
# Copyright (c) 2025 NVIDIA
# Licensed under customized NSCLv1 [see LICENSE.md for details]
# --------------------------------------------------------

# Based on https://github.com/OpenGVLab/InternVL/blob/main/streamlit_demo/model_worker.py
# https://github.com/OpenGVLab/InternVL/?tab=MIT-1-ov-file#readme

# Importing torch before transformers can cause `segmentation fault`

from transformers import  AutoTokenizer,  AutoConfig
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
import base64
import os
from io import BytesIO
from typing import Tuple
import math
import requests
import torch
from torch import Tensor
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from typing import Optional, Any, Union, Dict, List

from tqdm import tqdm
import torch.nn.functional as F
from datasets import Dataset
from torch.utils.data import DataLoader

from .modeling_eagle_chat import Eagle2ChatModel
from .configuration_eagle_chat import Eagle2ChatConfig
from .conversation import get_conv_template 

from .configuration_siglip import SiglipVisionConfig
from .modeling_siglip import SiglipVisionModel
from .flash_attention import *

from .llama_bidirectional_model import LlamaBidirectionalModel
from transformers import PreTrainedModel

# Per-channel (R, G, B) normalisation statistics consumed by build_transform,
# which converts every image to RGB before applying T.Normalize.
# ImageNet statistics, selected with norm_type='imagenet'.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# SigLIP statistics, selected with norm_type='siglip': (x - 0.5) / 0.5,
# i.e. maps the [0, 1] output of T.ToTensor to [-1, 1].
SIGLIP_MEAN = (0.5, 0.5, 0.5)
SIGLIP_STD = (0.5, 0.5, 0.5)

def load_image(image, timeout=30):
    """Return a PIL image for one of the supported image specifications.

    Args:
        image: one of
            * a ``PIL.Image.Image`` - returned unchanged;
            * a ``str`` path to an existing file - opened with ``Image.open``;
            * a dict with one of the keys (checked in this order)
              ``'disk_path'`` (file path), ``'base64'`` (base64-encoded bytes),
              ``'url'`` (downloaded with ``requests``) or ``'bytes'`` (raw bytes).
        timeout: seconds ``requests`` may wait on the server in the ``'url'``
            case; keyword added so a stalled server cannot block forever.

    Returns:
        The ``PIL.Image.Image`` (lazily opened for the file/bytes cases).

    Raises:
        ValueError: if ``image`` matches none of the supported forms.
        requests.HTTPError: if the ``'url'`` download returns an error status.
    """
    if isinstance(image, Image.Image):
        return image
    if isinstance(image, str) and os.path.exists(image):
        return Image.open(image)
    if isinstance(image, dict):
        if 'disk_path' in image:
            return Image.open(image['disk_path'])
        if 'base64' in image:
            return Image.open(BytesIO(base64.b64decode(image['base64'])))
        if 'url' in image:
            # Without a timeout, requests can block indefinitely. raise_for_status
            # reports HTTP errors directly instead of failing later inside PIL
            # with an opaque "cannot identify image file" error.
            response = requests.get(image['url'], timeout=timeout)
            response.raise_for_status()
            return Image.open(BytesIO(response.content))
        if 'bytes' in image:
            return Image.open(BytesIO(image['bytes']))
    raise ValueError(f'Invalid image: {image}')

def build_transform(input_size, norm_type='imagenet'):
    """Return the torchvision pipeline that turns one PIL image into a model input tensor.

    The pipeline converts the image to RGB if needed and resizes it to
    ``input_size`` x ``input_size`` with bicubic interpolation. It then applies
    ``T.ToTensor`` and normalises with the mean/std selected by *norm_type*.

    Args:
        input_size: side length in pixels of the square output.
        norm_type: ``'imagenet'`` (IMAGENET_MEAN / IMAGENET_STD) or
            ``'siglip'`` (SIGLIP_MEAN / SIGLIP_STD).

    Raises:
        ValueError: for any other *norm_type*. Previously this fell through
            and surfaced as an UnboundLocalError on the mean variable.
    """
    if norm_type == 'imagenet':
        mean, std = IMAGENET_MEAN, IMAGENET_STD
    elif norm_type == 'siglip':
        mean, std = SIGLIP_MEAN, SIGLIP_STD
    else:
        raise ValueError(
            f"Unsupported norm_type {norm_type!r}; expected 'imagenet' or 'siglip'"
        )

    return T.Compose([
        # Guarantee three channels, e.g. for RGBA, L or P mode inputs.
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ])

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tile grid from *target_ratios* that best fits an image.

    An earlier version focused mainly on the aspect ratio; this one also
    weighs the covered area. Each candidate ``ratio = (cols, rows)`` is scored
    as::

        min(area_ratio, 0.6) * min(t / aspect_ratio, aspect_ratio / t)

    where ``t = cols / rows`` and ``area_ratio`` is the grid's pixel area
    (``cols * rows * image_size**2``) divided by ``width * height``. The
    aspect term is 1.0 for a perfect match. Capping the area term at 0.6
    means a grid covering more than 60% of the original area is considered
    large enough; past that point only the aspect match matters.

    Args:
        aspect_ratio: ``width / height`` of the original image.
        target_ratios: iterable of ``(cols, rows)`` candidates.
        width, height: original image size in pixels.
        image_size: side length of one square tile in pixels.

    Returns:
        The highest-scoring candidate. On ties the earliest one wins; returns
        ``(1, 1)`` if *target_ratios* is empty.
    """
    best_factor = float('-inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        cols, rows = ratio[0], ratio[1]
        target_aspect_ratio = cols / rows
        area_ratio = (cols * rows * image_size * image_size) / area
        # New area > 60% of the original image area is enough.
        factor = min(area_ratio, 0.6) * min(target_aspect_ratio / aspect_ratio,
                                            aspect_ratio / target_aspect_ratio)
        if factor > best_factor:
            best_factor = factor
            best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    """Cut *image* into a grid of ``image_size`` x ``image_size`` tiles.

    Every ``(cols, rows)`` grid whose tile count lies in ``[min_num, max_num]``
    is a candidate, and ``find_closest_aspect_ratio`` picks one. The image is
    resized to ``cols * image_size`` by ``rows * image_size`` pixels and
    cropped into tiles in row-major order (left to right, top to bottom).

    Args:
        image: image exposing PIL-style ``size``, ``resize`` and ``crop``.
        min_num, max_num: inclusive bounds on the number of tiles.
        image_size: side length of each square tile in pixels.
        use_thumbnail: when more than one tile is produced, also append the
            whole image resized to a single tile.

    Returns:
        List of tiles, with the optional thumbnail last.
    """
    width, height = image.size
    aspect_ratio = width / height

    # Candidate (cols, rows) grids whose tile count is within
    # [min_num, max_num], ordered by tile count.
    candidates = set()
    for n in range(min_num, max_num + 1):
        for cols in range(1, n + 1):
            for rows in range(1, n + 1):
                if min_num <= cols * rows <= max_num:
                    candidates.add((cols, rows))
    candidates = sorted(candidates, key=lambda grid: grid[0] * grid[1])

    best = find_closest_aspect_ratio(aspect_ratio, candidates, width, height, image_size)
    grid_width = image_size * best[0]
    grid_height = image_size * best[1]
    tile_count = best[0] * best[1]

    resized = image.resize((grid_width, grid_height))
    tiles_per_row = grid_width // image_size
    tiles = []
    for idx in range(tile_count):
        row, col = divmod(idx, tiles_per_row)
        left, top = col * image_size, row * image_size
        tiles.append(resized.crop((left, top, left + image_size, top + image_size)))
    assert len(tiles) == tile_count

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles

def split_model(model_path, device):
    """Build a ``device_map`` that spreads the language-model layers over the visible GPUs.

    Let ``world_size = torch.cuda.device_count()``. Decoder layers are
    assigned in GPU-index order. Every GPU other than *device* gets
    ``floor(num_layers / (world_size - 1))`` layers, and *device* gets the
    remaining ``num_layers`` minus that total.

    The vision model, ``mlp1``, token embeddings, final norm, output head,
    rotary embedding and the last decoder layer are always mapped to
    *device*. With one GPU, or none visible, every layer is mapped to
    *device*.

    Args:
        model_path: path or hub id whose config is loaded with ``AutoConfig``.
        device: GPU that hosts the vision model and shared modules. With more
            than one GPU it is used as a list index, so it must be an int in
            ``range(world_size)``.

    Returns:
        dict mapping module-name prefixes to device indices.
    """
    device_map = {}
    world_size = torch.cuda.device_count()
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    num_layers = config.llm_config.num_hidden_layers

    print('world_size', world_size)
    if world_size <= 1:
        # Nothing to split across. The formula below divides by
        # (world_size - 1), which raised ZeroDivisionError with one GPU and
        # indexed an empty list with none.
        layer_devices = [device] * num_layers
    else:
        num_layers_per_gpu_ = math.floor(num_layers / (world_size - 1))
        num_layers_per_gpu = [num_layers_per_gpu_] * world_size
        num_layers_per_gpu[device] = num_layers - num_layers_per_gpu_ * (world_size - 1)
        print(num_layers_per_gpu)
        layer_devices = [gpu for gpu, count in enumerate(num_layers_per_gpu) for _ in range(count)]

    for layer_idx, gpu in enumerate(layer_devices):
        device_map[f'language_model.model.layers.{layer_idx}'] = gpu

    # Shared modules (and the last decoder layer) are pinned to `device`.
    for module_name in (
        'vision_model',
        'mlp1',
        'language_model.model.tok_embeddings',
        'language_model.model.embed_tokens',
        'language_model.output',
        'language_model.model.norm',
        'language_model.lm_head',
        'language_model.model.rotary_emb',
    ):
        device_map[module_name] = device
    device_map[f'language_model.model.layers.{num_layers - 1}'] = device
    return device_map

class llama_NemoRetrieverColEmbedConfig(Eagle2ChatConfig):
    """Configuration for the ``llama_nemoretrievercolembed`` model type.

    Adds retrieval-specific fields on top of ``Eagle2ChatConfig``. Every
    keyword not named below is forwarded unchanged to
    ``Eagle2ChatConfig.__init__``.
    """

    model_type = "llama_nemoretrievercolembed"

    # Declared types of the retrieval-specific fields (assigned in __init__).
    q_max_length: Optional[int]
    p_max_length: Optional[int]
    query_prefix: str
    passage_prefix: str
    pooling: str
    bidirectional_attention: bool

    def __init__(
        self,
        q_max_length: Optional[int] = 512,
        p_max_length: Optional[int] = 10240,
        query_prefix: str = "query:",
        passage_prefix: str = "passage:",
        pooling: str = "last",
        bidirectional_attention: bool = False,
        max_input_tiles: int = 2,
        img_context_token_id: int = 128258,  # tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
        out_dimension: int = -1,
        **kwargs,
    ):
        # Sequence-length limits for queries and passages.
        self.q_max_length, self.p_max_length = q_max_length, p_max_length
        # Text prefixes for queries and passages.
        self.query_prefix, self.passage_prefix = query_prefix, passage_prefix
        self.pooling = pooling
        self.bidirectional_attention = bidirectional_attention
        self.max_input_tiles = max_input_tiles
        self.img_context_token_id = img_context_token_id
        self.out_dimension = out_dimension
        # The parent initialiser runs last, after the fields above are stored.
        super().__init__(**kwargs)

class llama_NemoRetrieverColEmbed(Eagle2ChatModel):
    """Multi-vector (ColBERT-style) retriever built on ``Eagle2ChatModel``.

    Queries and documents (image and/or text) are encoded into one
    L2-normalised vector per token, taken from the model's last hidden
    layer (see ``_extract_embeddings``). Relevance is the late-interaction
    MaxSim score computed by ``colbert_score``.
    """

    config_class = llama_NemoRetrieverColEmbedConfig
    _supports_flash_attn_2 = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): these settings are hard-coded rather than read from the
        # config, whose own defaults differ (e.g. max_input_tiles=2 there vs 6
        # here). Left unchanged so the produced embeddings stay the same.
        self.padding = True
        self.q_max_length = 512
        self.p_max_length = 10240
        self.pad_to_multiple_of = None
        self.query_prefix = 'query:'
        self.passage_prefix = 'passage:'

        # The config may arrive positionally or as the `config` keyword. The
        # previous `args[0]` lookup raised IndexError for the keyword form.
        config = args[0] if args else kwargs.get('config')
        if isinstance(config, llama_NemoRetrieverColEmbedConfig):
            tokenizer = AutoTokenizer.from_pretrained(config._name_or_path, trust_remote_code=True)
            # These grounding tokens are filtered OUT of the tokenizer's
            # additional special tokens.
            tokens_to_keep = ['<box>', '</box>', '<ref>', '</ref>']
            tokenizer.additional_special_tokens = [
                item for item in tokenizer.additional_special_tokens if item not in tokens_to_keep
            ]
            tokenizer.padding_side = 'left'
            self.tokenizer = tokenizer

        self.norm_type = 'siglip'
        self.image_size = self.config.force_image_size
        self.max_input_tiles = 6
        self.system_message = ""
        self.use_visual_embedding = True

    def process_documents(self, documents: Union[Dict, List[Dict]], **kwargs):
        """Tokenise a batch of documents and preprocess their images.

        Args:
            documents: either a dict with parallel ``"images"`` / ``"texts"``
                lists, or a list of dicts with ``"image"`` / ``"text"`` keys.
                An image equal to ``''`` marks a text-only document; any other
                value is loaded with ``load_image``.
            **kwargs: ignored.

        Returns:
            dict with:
                * ``"input_ids"`` and ``"attention_mask"``: tokenizer output
                  as PyTorch tensors, truncated to ``p_max_length``;
                * ``"pixel_values"``: the bfloat16 image tiles of every
                  document concatenated along dim 0, or ``None`` if no
                  document has an image.

        Raises:
            ValueError: if *documents* is neither a dict nor a list, or the
                dict form has different numbers of images and texts.
        """
        if isinstance(documents, dict):
            images = documents["images"]
            texts = documents["texts"]
            if len(texts) != len(images):
                raise ValueError(
                    f"Got {len(images)} images but {len(texts)} texts; the counts must match"
                )
        elif isinstance(documents, list):
            images = [pair['image'] for pair in documents]
            texts = [pair['text'] for pair in documents]
        else:
            raise ValueError("The documents need to be a dict or list of dicts")

        if self.passage_prefix:
            # NOTE(review): the passage prefix is applied here AND again when
            # `content` is built below, giving "passage: [<image> ]passage: <text>".
            # Kept as-is because changing it changes the embeddings produced.
            texts = [self.passage_prefix + ' ' + t for t in texts]

        transform = build_transform(input_size=self.image_size, norm_type=self.norm_type)
        template = get_conv_template(self.config.template)
        template.system_message = self.system_message

        img_start_token = '<img>'
        img_end_token = '</img>'
        img_context_token = '<IMG_CONTEXT>'

        content_prompts = []
        pixel_values_list = []
        for image, text in zip(images, texts):
            content = text
            pixel_values = None
            if image != '':
                pil_image = load_image(image)
                if self.config.dynamic_image_size:
                    image_tiles = dynamic_preprocess(
                        pil_image, image_size=self.image_size, max_num=self.max_input_tiles,
                        use_thumbnail=self.config.use_thumbnail)
                else:
                    image_tiles = [pil_image]
                pixel_values = torch.stack([transform(tile) for tile in image_tiles]).to(dtype=torch.bfloat16)
                pixel_values_list.append(pixel_values)
                content = '<image> ' + content
            if self.passage_prefix:
                content = self.passage_prefix + ' ' + content

            # Reset the conversation, then render one user turn followed by an
            # empty assistant turn.
            template.messages.clear()
            template.append_message(template.roles[0], content)  # user
            template.append_message(template.roles[1], None)     # assistant
            content_prompt = template.get_prompt()

            if pixel_values is not None:
                # Expand the first '<image>' placeholder into num_image_token
                # context tokens per tile. Keying on pixel_values (not on the
                # text) avoids crashing on text-only documents that contain
                # the literal string '<image>'.
                num_patches = pixel_values.shape[0]
                image_tokens = (img_start_token
                                + img_context_token * self.num_image_token * num_patches
                                + img_end_token)
                content_prompt = content_prompt.replace('<image>', image_tokens, 1)

            content_prompts.append(content_prompt)

        model_inputs = self.tokenizer(content_prompts,
                                      truncation=True,
                                      max_length=self.p_max_length,
                                      padding=self.padding,
                                      pad_to_multiple_of=self.pad_to_multiple_of,
                                      return_tensors='pt')

        return {
            "input_ids": model_inputs['input_ids'],
            "attention_mask": model_inputs['attention_mask'],
            "pixel_values": torch.cat(pixel_values_list, dim=0) if pixel_values_list else None,
        }

    def process_queries(self, queries: List[str], **kwargs):
        """Wrap each query in the chat template and tokenise the batch.

        Each query is prefixed with ``query_prefix`` (when set), rendered as a
        single user turn followed by an empty assistant turn, then tokenised
        with truncation to ``q_max_length``. ``**kwargs`` is ignored.

        Returns:
            The tokenizer output (PyTorch tensors).
        """
        template = get_conv_template(self.config.template)
        template.system_message = self.system_message

        query_prompts = []
        for query in queries:
            if self.query_prefix:
                query = f"{self.query_prefix} {query}"

            # Reset the conversation messages.
            template.messages.clear()
            template.append_message(template.roles[0], query)    # user
            template.append_message(template.roles[1], None)     # assistant
            query_prompts.append(template.get_prompt())

        return self.tokenizer(
            query_prompts,
            truncation=True,
            max_length=self.q_max_length,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors='pt'
        )

    def get_scores(
        self,
        query_embeddings: Union[torch.Tensor, List[torch.Tensor]],
        passage_embeddings: Union[torch.Tensor, List[torch.Tensor]],
        batch_size: Optional[int] = 8,
    ) -> torch.Tensor:
        """Late-interaction (MaxSim) dot-product scores between queries and passages.

        Args:
            query_embeddings, passage_embeddings: a padded 3-D tensor, or a
                list of 3-D per-batch tensors or 2-D per-item tensors. 2-D
                items get a leading batch dim, because the ViDoRe framework
                removes it. Lists are zero-padded to a common length and
                concatenated.
            batch_size: chunk size forwarded to ``colbert_score``.

        Returns:
            ``(num_queries, num_passages)`` score tensor.

        Raises:
            ValueError: if either input is empty.
        """
        query_embeddings = self._stack_embeddings(query_embeddings)
        passage_embeddings = self._stack_embeddings(passage_embeddings)
        return self.colbert_score(query_embeddings, passage_embeddings, batch_size)

    def _stack_embeddings(self, embeddings):
        """Pad and concatenate a list of embeddings into one tensor; tensors and empty lists pass through."""
        if not isinstance(embeddings, list) or not embeddings:
            # Empty lists are left for colbert_score to reject with a clear ValueError.
            return embeddings
        if embeddings[0].dim() == 2:
            # Expand the batch dimension, since the ViDoRe framework removes it.
            embeddings = [e.unsqueeze(0) for e in embeddings]
        return self.padding_various_shape_tensor(embeddings)

    def colbert_score(
        self,
        qs: Union[torch.Tensor, List[torch.Tensor]],
        ps: Union[torch.Tensor, List[torch.Tensor]],
        batch_size: int = 128,
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """ColBERT MaxSim score of every query against every passage.

        For each query token, take the best dot product over the passage
        tokens, then sum over the query tokens.

        Args:
            qs: queries as a ``(num_queries, n, dim)`` tensor or a list of
                ``(n_i, dim)`` tensors. Each chunk is zero-padded to its
                longest item.
            ps: passages, in the same layouts as *qs*.
            batch_size: queries / passages scored per chunk; ``None`` means 128.
            device: where to compute. Defaults to CUDA when available, else
                CPU. Previously this argument was ignored and CUDA was
                always required.

        Returns:
            ``(num_queries, num_passages)`` tensor on *device*.

        Raises:
            ValueError: if *qs* or *ps* is empty.
        """
        if batch_size is None:
            batch_size = 128
        if len(qs) == 0:
            raise ValueError("No queries provided")
        if len(ps) == 0:
            raise ValueError("No passages provided")
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        def to_padded_batch(items):
            # A tensor slice is moved as a whole; list items one by one.
            if isinstance(items, torch.Tensor):
                items = items.to(device)
            else:
                items = [t.to(device) for t in items]
            return torch.nn.utils.rnn.pad_sequence(items, batch_first=True, padding_value=0)

        scores_list: List[torch.Tensor] = []
        for i in range(0, len(qs), batch_size):
            qs_batch = to_padded_batch(qs[i : i + batch_size])
            scores_batch = []
            for j in range(0, len(ps), batch_size):
                ps_batch = to_padded_batch(ps[j : j + batch_size])
                # (b, n, d) x (c, s, d) -> (b, c, n, s); max over passage
                # tokens s, then sum over query tokens n -> (b, c).
                scores_batch.append(
                    torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2)
                )
            # Scores stay on `device`.
            scores_list.append(torch.cat(scores_batch, dim=1))

        return torch.cat(scores_list, dim=0)

    def _extract_embeddings(self, dataloader: DataLoader, is_query: bool) -> torch.Tensor:
        """Run the model over *dataloader* and collect per-token embeddings.

        For each batch, the last hidden state is multiplied by the attention
        mask (padding positions become zero vectors) and L2-normalised along
        the last dim. Batches are then zero-padded along dim 1 to a common
        length and concatenated.

        Returns:
            A single CPU tensor.

        Raises:
            RuntimeError: if a batch's embeddings sum to zero or to a
                non-finite value. The previous ``assert`` missed NaN and -inf
                and was stripped under ``python -O``.
        """
        collected = []
        message = "query" if is_query else "document"
        for batch in tqdm(dataloader, desc=f"Extracting {message} embeddings..."):
            with torch.inference_mode():
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    # Text-only batches carry pixel_values=None, which has no .to().
                    if 'pixel_values' in batch and batch['pixel_values'] is None:
                        batch.pop('pixel_values')
                    batch = {k: v.to(self.device) for k, v in batch.items()}
                    embeddings = self(**batch, output_hidden_states=True).hidden_states[-1]
                    embeddings = embeddings * batch['attention_mask'].unsqueeze(-1)
                    embeddings = F.normalize(embeddings, dim=-1)

            # Detect abnormal outputs.
            total = torch.sum(embeddings).float().item()
            if total == 0.0 or not math.isfinite(total):
                raise RuntimeError(f"Abnormal {message} embeddings detected (sum={total})")
            # Move each batch off the GPU immediately, so GPU memory does not
            # grow with the size of the dataset. The result ended on the CPU anyway.
            collected.append(embeddings.cpu())

        return self.padding_various_shape_tensor(collected).detach()

    def forward_passages(self, passages, batch_size=8, **kwargs) -> torch.Tensor:
        """Embed passages given as images only (each paired with empty text).

        ``**kwargs`` is ignored.
        """
        corpus = [{"image": image, "text": ''} for image in passages]
        return self.forward_documents(corpus, batch_size)

    def forward_queries(self, queries: List, batch_size=8) -> torch.Tensor:
        """Embed text queries; returns the padded CPU tensor from ``_extract_embeddings``."""
        dataset = ListDataset[str](queries)
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            collate_fn=self.process_queries,
            shuffle=False,
            num_workers=8,
            pin_memory=True,
            drop_last=False,
        )
        return self._extract_embeddings(dataloader=dataloader, is_query=True)

    def forward_documents(self, corpus: List, batch_size=8) -> torch.Tensor:
        """Embed documents given as dicts with a ``"text"`` key and an optional ``"image"`` key.

        PIL images not already in RGB mode are converted to RGB. A missing
        image defaults to ``''`` (text-only document). Other image values,
        e.g. file paths, are passed through unchanged to ``process_documents``.
        NOTE(review): mixing PIL images and strings in one corpus may be
        rejected by ``Dataset.from_dict`` - unverified.

        Returns:
            The padded CPU tensor from ``_extract_embeddings``.
        """
        images = []
        texts = []
        for doc in corpus:
            text = doc["text"]
            image = doc.get("image", "")
            # Only PIL images have `.mode`. The old unconditional check crashed
            # on the '' default and on path strings.
            if isinstance(image, Image.Image) and image.mode != "RGB":
                image = image.convert("RGB")
            images.append(image)
            texts.append(text)

        dataset = Dataset.from_dict({"image": images, "text": texts})
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            collate_fn=self.process_documents,
            shuffle=False,
            num_workers=8,
            pin_memory=True,
            drop_last=False,
        )
        return self._extract_embeddings(dataloader=dataloader, is_query=False)

    def padding_various_shape_tensor(self, tensors: List[torch.Tensor]) -> torch.Tensor:
        """Zero-pad tensors along dim 1 to the longest one, then concatenate along dim 0.

        Used for ColBERT-like scoring. Raises ValueError (from ``max``) on an
        empty list.
        """
        max_seq_len = max(t.shape[1] for t in tensors)
        padded_tensors = [F.pad(t, (0, 0, 0, max_seq_len - t.shape[1]), mode="constant", value=0) for t in tensors]
        return torch.cat(padded_tensors, dim=0)


from typing import TypeVar
from torch.utils.data import Dataset as TorchDataset
# Element type of ListDataset. The TypeVar's name must match the variable it
# is bound to, or type checkers reject it (the previous code used "T").
TV = TypeVar("TV")


class ListDataset(TorchDataset[TV]):
    """Map-style torch ``Dataset`` that serves the items of a list by index."""

    def __init__(self, elements: List[TV]):
        self.elements = elements

    def __len__(self) -> int:
        """Return the number of stored elements."""
        return len(self.elements)

    def __getitem__(self, idx: int) -> TV:
        """Return the element at position *idx*."""
        return self.elements[idx]