# coding=utf-8
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
import transformers
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers.training_args import TrainingArguments


@dataclass
class ModelArguments:
    """Command-line arguments that select the model to fine-tune."""

    # Model identifier or local path handed to ``from_pretrained`` for both
    # the model and the tokenizer.
    model_name_or_path: Optional[str] = "gpt2"


@dataclass
class DataArguments:
    """Command-line arguments that describe the training data."""

    # Annotated as Optional because the default is ``None``; the previous
    # plain ``str`` annotation contradicted that default.
    data_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the training data."}
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Trainer arguments extended with the options specific to this script."""

    # Passed as ``cache_dir`` when the pretrained model is loaded.
    cache_dir: Optional[str] = None
    # Overrides the default declared by ``transformers.TrainingArguments``.
    optim: str = "adamw_torch"
    # Length every tokenized example is padded/truncated to.
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    # When set, the model is wrapped with a LoRA adapter before training.
    use_lora: bool = False


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning on chat-style conversations.

    The JSON file at ``data_path`` is read once and every example is tokenized
    eagerly in the constructor; ``__getitem__`` only indexes the cached result.

    Each example must provide a ``"conversations"`` list whose messages carry
    a ``"from"`` role and a ``"value"`` text. Two layouts are supported:

    * human mode (no message comes from ``"copilot"``): every turn is encoded
      as ``<role prefix> + value + <end marker>``; only the value and end
      marker of non-human turns are used as labels.
    * copilot mode (at least one ``"copilot"`` message): copilot turns are
      unlabeled context, every other turn is a labeled target followed by the
      tokenizer's EOS token.
    """

    def __init__(
        self,
        data_path,
        tokenizer,
        model_max_length,
        user_string="## human:",
        copilot_string="## copilot:",
        assistant_string="## assistant:",
        end_string=" |<end>| ",
    ):
        """Load ``data_path`` and tokenize all of its examples.

        Args:
            data_path: Path of a JSON file holding the list of examples.
            tokenizer: Object providing ``encode``, ``eos_token_id`` and
                ``pad_token_id``.
            model_max_length: Length every example is padded/truncated to.
            user_string: Prefix placed before human turns (human mode).
            copilot_string: Marker for copilot turns; it is tokenized and
                stored but not inserted into any sequence by this class.
            assistant_string: Prefix placed before non-human turns (human
                mode).
            end_string: Marker appended after every turn (human mode).
        """
        super().__init__()
        # Context manager so the handle is closed deterministically; JSON is
        # decoded as UTF-8 instead of depending on the platform locale.
        with open(data_path, encoding="utf-8") as data_file:
            self.data = json.load(data_file)
        self.tokenizer = tokenizer
        self.model_max_length = model_max_length
        self.user_string = user_string
        self.copilot_string = copilot_string
        self.assistant_string = assistant_string
        self.end_string = end_string
        self.user_tokens = self.tokenizer.encode(user_string)
        self.copilot_tokens = self.tokenizer.encode(copilot_string)
        self.assistant_tokens = self.tokenizer.encode(assistant_string)
        self.end_tokens = self.tokenizer.encode(end_string)
        # Label value used for positions that should not contribute targets.
        self.ignore_index = -100

        self.preprocessed_data = self.preprocessing()

    def __len__(self):
        """Return the number of preprocessed examples."""
        return len(self.preprocessed_data)

    def preprocessing(self):
        """Tokenize every raw example and return the list of kept examples."""
        preprocessed_data = []
        for example in tqdm(self.data, desc="Preprocessing"):
            preprocess_example = self.preprocess_one(example)
            # NOTE(review): ``preprocess_one`` always pads/truncates to
            # ``model_max_length``, so this compares the padded length, not
            # the number of real tokens. If the goal is to drop very short
            # conversations, the unpadded length would be needed - confirm.
            if len(preprocess_example["input_ids"]) <= 16:
                continue
            preprocessed_data.append(preprocess_example)
        return preprocessed_data

    def preprocess_one(self, example):
        """Turn one raw example into fixed-length training tensors.

        Args:
            example: Mapping with a ``"conversations"`` list of messages, each
                holding ``"from"`` and ``"value"``.

        Returns:
            Dict with ``input_ids``, ``labels`` (both ``torch.LongTensor`` of
            length ``model_max_length``) and a boolean ``attention_mask``.
        """
        conversations = example["conversations"]
        senders = [message["from"] for message in conversations]
        if "copilot" in senders:
            input_ids, labels = self._encode_copilot_chat(conversations)
        else:
            input_ids, labels = self._encode_human_chat(conversations)

        # The budget checks in the encoders can be exceeded by a few tokens,
        # so keep only the tail, then right-pad up to the fixed length.
        input_ids = input_ids[-self.model_max_length:]
        labels = labels[-self.model_max_length:]
        input_ids += [self.tokenizer.pad_token_id] * (self.model_max_length - len(input_ids))
        labels += [self.ignore_index] * (self.model_max_length - len(labels))
        input_ids = torch.LongTensor(input_ids)
        labels = torch.LongTensor(labels)
        # NOTE(review): this masks every position equal to ``eos_token_id``,
        # not padding specifically. The two coincide only when
        # ``pad_token_id == eos_token_id`` - confirm this is intended.
        attention_mask = input_ids.ne(self.tokenizer.eos_token_id)
        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
        }

    def _encode_human_chat(self, conversations):
        """Encode a human/assistant dialogue into (input_ids, labels) lists."""
        input_ids = []
        labels = []
        for idx, message in enumerate(conversations):
            if idx == 0:
                # The sequence starts with one unlabeled EOS token.
                input_ids.append(self.tokenizer.eos_token_id)
                labels.append(self.ignore_index)
            from_ = message["from"]
            value_ids = self.tokenizer.encode(message["value"])

            # The budget is estimated with the user prefix for every role;
            # stop at the first turn that would not fit.
            turn_length = len(self.user_tokens) + len(value_ids) + len(self.end_tokens)
            if len(input_ids) + turn_length > self.model_max_length:
                break

            if from_ == "human":
                turn_ids = self.user_tokens + value_ids + self.end_tokens
                input_ids += turn_ids
                labels += [self.ignore_index] * len(turn_ids)
            else:
                # Only the reply text and end marker are targets; the role
                # prefix is excluded from the labels.
                input_ids += self.assistant_tokens + value_ids + self.end_tokens
                labels += (
                    [self.ignore_index] * len(self.assistant_tokens)
                    + value_ids
                    + self.end_tokens
                )
        return input_ids, labels

    def _encode_copilot_chat(self, conversations):
        """Encode a copilot-style dialogue into (input_ids, labels) lists."""
        input_ids = []
        labels = []
        eos = [self.tokenizer.eos_token_id]
        for message in conversations:
            from_ = message["from"]
            value_ids = self.tokenizer.encode(message["value"])

            # The trailing EOS of a target turn is not counted here, so the
            # result may exceed the budget by one token.
            if len(input_ids) + len(value_ids) > self.model_max_length:
                break

            if from_ == "copilot":
                # Copilot turns are context only.
                input_ids += value_ids
                labels += [self.ignore_index] * len(value_ids)
            else:
                input_ids += value_ids + eos
                labels += value_ids + eos
        return input_ids, labels

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        """Return the preprocessed example stored at position ``idx``."""
        return self.preprocessed_data[idx]

    def print_dataset_example(self, num=3):
        """Log up to ``num`` preprocessed examples at INFO level.

        Fewer examples are logged when the dataset holds fewer than ``num``
        entries; nothing is emitted unless INFO logging is enabled.
        """
        logger = logging.getLogger(__name__)
        for example in self.preprocessed_data[:max(num, 0)]:
            logger.info("input_ids:\n%s", example["input_ids"])
            logger.info("labels:\n%s", example["labels"])


def train():
    """Fine-tune a causal language model on a supervised chat dataset.

    Parses ``ModelArguments``, ``DataArguments`` and ``TrainingArguments``
    from the command line, loads the model and tokenizer, optionally wraps the
    model with a LoRA adapter, builds the dataset, runs the Trainer and saves
    the trainer state and the model to ``training_args.output_dir``.

    Raises:
        ValueError: If no data path was given or no training example is left
            after preprocessing.
    """
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Fail fast: without this check a missing path only surfaces as an
    # obscure error from ``open`` after the model has already been loaded.
    if not data_args.data_path:
        raise ValueError("--data_path is required but was not provided.")

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
        cache_dir=training_args.cache_dir,
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=True,
        trust_remote_code=True,
        model_max_length=training_args.model_max_length,
    )

    # Fill in missing special tokens: EOS falls back to BOS, PAD falls back
    # to EOS, so the dataset always has ids to start and pad sequences with.
    if tokenizer.eos_token_id is None:
        tokenizer.eos_token_id = tokenizer.bos_token_id
    if tokenizer.eos_token is None:
        tokenizer.eos_token = tokenizer.bos_token
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if training_args.use_lora:
        # Imported lazily so ``peft`` is only required when LoRA is enabled.
        from peft import LoraConfig, TaskType, get_peft_model

        # NOTE(review): "c_attn" looks like a GPT-2 style attention module
        # name; other architectures presumably need different target modules
        # - confirm before using a non-default model.
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["c_attn"],
            inference_mode=False,
            r=1,
            lora_alpha=32,
            lora_dropout=0.1,
        )
        model.enable_input_require_grads()
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    dataset = SupervisedDataset(
        data_args.data_path, tokenizer, training_args.model_max_length
    )
    if len(dataset) == 0:
        raise ValueError(
            f"No training examples were loaded from {data_args.data_path!r}."
        )
    dataset.print_dataset_example()

    trainer = transformers.Trainer(
        model=model, args=training_args, train_dataset=dataset, tokenizer=tokenizer
    )
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)


# Start fine-tuning only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    train()