import json

import transformers


class HunyuanMllmTokenizer:
    """Tokenizer wrapper that formats prompts with a template before tokenizing.

    The instance holds two attributes:

    * ``template_info`` -- the entry selected from the JSON template file; this
      class reads its ``"template"`` and ``"crop_start"`` keys.
    * ``tokenizer`` -- the object returned by
      ``transformers.AutoTokenizer.from_pretrained``.

    Any attribute not found on the wrapper itself is looked up on
    ``tokenizer`` (see ``__getattr__``).
    """

    def __init__(
        self,
        **config,
    ):
        """Load the prompt template and build the underlying tokenizer.

        Args:
            **config: Must contain ``template_file_path``, the path of a JSON
                file mapping template ids to template entries. May contain
                ``template_id`` (default ``"hyv-llm-encode-video"``). Both keys
                are removed from ``config``; everything left is forwarded
                unchanged to ``transformers.AutoTokenizer.from_pretrained``.

        Raises:
            KeyError: If ``template_file_path`` is missing from ``config`` or
                ``template_id`` is not a key of the loaded JSON object.
        """
        template_file_path = config.pop("template_file_path")
        template_id = config.pop("template_id", "hyv-llm-encode-video")
        # Decode explicitly as UTF-8 so loading does not depend on the
        # platform's locale encoding.
        with open(template_file_path, "r", encoding="utf-8") as f:
            templates = json.load(f)
        self.template_info = templates[template_id]
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(**config)

    @staticmethod
    def apply_template(text, template):
        """Insert each prompt into ``template`` via ``str.format``.

        Args:
            text: A single string, or a list/tuple of values to format.
            template: Format string; each prompt is passed to it as the sole
                positional argument of ``template.format``.

        Returns:
            A list of formatted strings. A single string input yields a
            one-element list.

        Raises:
            NotImplementedError: If ``text`` is not a str, list or tuple.
        """
        if isinstance(text, str):
            return [template.format(text)]
        if isinstance(text, (list, tuple)):
            return [template.format(t) for t in text]
        raise NotImplementedError(f"Not Support text type: {type(text)}")

    def __call__(
        self,
        prompt,
        padding: str = "max_length",
        max_length: int = 256,
        truncation: bool = True,
        return_attention_mask: bool = True,
        add_special_tokens: bool = True,
        return_tensors: str = "pt",
        **kwargs,
    ):
        """Apply the template to ``prompt`` and tokenize the result.

        Args:
            prompt: A string or a list/tuple of strings (see
                ``apply_template``).
            padding: Forwarded to the underlying tokenizer.
            max_length: Length budget for the prompt; the value passed to the
                underlying tokenizer is ``max_length`` plus
                ``template_info["crop_start"]``.
            truncation: Forwarded to the underlying tokenizer.
            return_attention_mask: Forwarded to the underlying tokenizer.
            add_special_tokens: Forwarded to the underlying tokenizer.
            return_tensors: Forwarded to the underlying tokenizer.
            **kwargs: Extra keyword arguments forwarded to the underlying
                tokenizer.

        Returns:
            Whatever the underlying tokenizer returns for the formatted
            prompts.

        Raises:
            NotImplementedError: If ``prompt`` is not a str, list or tuple.
        """
        prompt = HunyuanMllmTokenizer.apply_template(prompt, self.template_info["template"])
        text_inputs = self.tokenizer(
            prompt,
            padding=padding,
            # NOTE(review): presumably crop_start is the number of tokens the
            # template adds ahead of the user text, so the budget is enlarged
            # to keep max_length tokens for the prompt itself -- confirm
            # against the code that consumes crop_start.
            max_length=max_length + self.template_info["crop_start"],
            truncation=truncation,
            return_attention_mask=return_attention_mask,
            add_special_tokens=add_special_tokens,
            return_tensors=return_tensors,
            **kwargs,
        )
        return text_inputs

    def __getattr__(self, name):
        """Delegate attributes missing on the wrapper to the wrapped tokenizer.

        Python calls this only after normal attribute lookup has failed.

        Raises:
            AttributeError: If ``name`` is defined on the class (its own lookup
                already failed), if ``tokenizer`` has not been set on this
                instance yet, or if the wrapped tokenizer lacks ``name``.
        """
        missing = f"{type(self).__name__!r} object has no attribute {name!r}"
        # A name defined on the class that still failed normal lookup (e.g. a
        # descriptor raised AttributeError) must not be masked by an attribute
        # of the same name on the tokenizer.
        if any(name in vars(klass) for klass in type(self).__mro__):
            raise AttributeError(missing)
        # Read the instance dict directly: evaluating ``self.tokenizer`` while
        # it is unset (instance created via __new__, copy or unpickling) would
        # re-enter __getattr__ and recurse until RecursionError.
        try:
            tokenizer = self.__dict__["tokenizer"]
        except KeyError:
            raise AttributeError(missing) from None
        return getattr(tokenizer, name)