# Copyright 2025 the LlamaFactory team.

import json
import os
from dataclasses import dataclass
from typing import Any, Literal, Optional, Union
from huggingface_hub import hf_hub_download

# File name of the dataset registry; it is looked up inside a local dataset
# directory or downloaded from a remote dataset repo by `get_dataset_list`.
DATA_CONFIG = "dataset_info.json"


@dataclass
class DatasetAttr:
    r"""Dataset attributes.

    Holds where a dataset is loaded from (``load_from`` / ``dataset_name``)
    together with its formatting options, column names and sharegpt tags.
    Only ``load_from`` and ``dataset_name`` are required; every other field
    has a default and can be overridden from a config dict via :meth:`join`.
    """

    # basic configs
    # "cloud_file" is included because configs with a "cloud_file_name" key
    # produce that value for `load_from`.
    load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file", "cloud_file"]
    dataset_name: str
    formatting: Literal["alpaca", "sharegpt", "openai"] = "alpaca"
    ranking: bool = False
    # extra configs
    subset: Optional[str] = None
    split: str = "train"
    folder: Optional[str] = None
    num_samples: Optional[int] = None
    # common columns
    system: Optional[str] = None
    tools: Optional[str] = None
    # dpo columns
    chosen: Optional[str] = None
    rejected: Optional[str] = None
    kto_tag: Optional[str] = None
    # alpaca columns
    prompt: Optional[str] = "instruction"
    query: Optional[str] = "input"
    response: Optional[str] = "output"
    history: Optional[str] = None
    # sharegpt columns
    messages: Optional[str] = "conversations"
    # sharegpt tags
    role_tag: Optional[str] = "from"
    content_tag: Optional[str] = "value"
    user_tag: Optional[str] = "human"
    assistant_tag: Optional[str] = "gpt"
    observation_tag: Optional[str] = "observation"
    function_tag: Optional[str] = "function_call"
    system_tag: Optional[str] = "system"

    def __repr__(self) -> str:
        r"""Return the dataset name as the textual representation."""
        return self.dataset_name

    def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
        r"""Set attribute ``key`` from ``obj[key]``, or to ``default`` if the key is absent."""
        setattr(self, key, obj.get(key, default))

    def join(self, attr: dict[str, Any]) -> None:
        r"""Overwrite this object's fields with the values found in ``attr``.

        Args:
            attr: A config dict. Top-level keys ``formatting``, ``ranking``,
                ``subset``, ``split``, ``folder`` and ``num_samples`` are read
                directly; the optional ``columns`` and ``tags`` sub-dicts
                supply column names and sharegpt tags.

        Note:
            Top-level options missing from ``attr`` are reset to their
            defaults. When ``columns`` (or ``tags``) is present, every column
            (or tag) that is not listed in it is set to ``None`` rather than
            keeping its dataclass default.
        """
        self.set_attr("formatting", attr, default="alpaca")
        self.set_attr("ranking", attr, default=False)
        self.set_attr("subset", attr)
        self.set_attr("split", attr, default="train")
        self.set_attr("folder", attr)
        self.set_attr("num_samples", attr)

        if "columns" in attr:
            column_names = ["prompt", "query", "response", "history", "messages", "system", "tools"]
            column_names += ["chosen", "rejected", "kto_tag"]
            for column_name in column_names:
                # No default is passed, so unlisted columns become None.
                self.set_attr(column_name, attr["columns"])

        if "tags" in attr:
            tag_names = ["role_tag", "content_tag"]
            tag_names += ["user_tag", "assistant_tag", "observation_tag", "function_tag", "system_tag"]
            for tag in tag_names:
                # No default is passed, so unlisted tags become None.
                self.set_attr(tag, attr["tags"])


def is_env_enabled(env_var: str, default: str = "0") -> bool:
    r"""Tell whether an environment variable carries an "enabled" value.

    The variable's value (or ``default`` when it is unset) is lower-cased and
    counts as enabled only if it equals ``"true"``, ``"y"`` or ``"1"``.
    """
    raw_value = os.getenv(env_var, default)
    normalized = raw_value.lower()
    return normalized in ("true", "y", "1")


def use_modelscope() -> bool:
    r"""Return True when the ``USE_MODELSCOPE_HUB`` environment variable is enabled."""
    flag_name = "USE_MODELSCOPE_HUB"
    enabled = is_env_enabled(flag_name)
    return enabled


def use_openmind() -> bool:
    r"""Return True when the ``USE_OPENMIND_HUB`` environment variable is enabled."""
    flag_name = "USE_OPENMIND_HUB"
    enabled = is_env_enabled(flag_name)
    return enabled


def get_dataset_list(
    dataset_input: Optional[Union[dict[str, Any], str]],
    dataset_dir: Union[str, dict]
) -> list["DatasetAttr"]:
    r"""Resolve a dataset specification into a list of ``DatasetAttr`` objects.

    Args:
        dataset_input: One of
            - ``None``: no datasets, an empty list is returned;
            - a ``dict``: a single dataset config, converted directly with
              ``build_dataset_attr_from_dict`` (``dataset_dir`` is not used);
            - a ``str``: comma-separated dataset names; surrounding whitespace
              is stripped and empty names are dropped.
        dataset_dir: Where the named datasets are defined. One of
            - a ``dict``: used as the dataset registry itself;
            - ``"ONLINE"``: no registry; each name is used as a hub dataset
              name, with the hub chosen by ``use_modelscope()`` /
              ``use_openmind()`` and falling back to ``"hf_hub"``;
            - ``"REMOTE:<repo_id>"``: the registry file is fetched from that
              dataset repo with ``hf_hub_download``;
            - any other string: a local directory containing the registry file.

    Returns:
        One ``DatasetAttr`` per requested dataset, in input order.

    Raises:
        ValueError: If ``dataset_input`` is neither ``None``, ``dict`` nor
            ``str``; if the registry file cannot be opened or parsed; or if a
            requested name is missing from the registry.
    """
    if dataset_input is None:
        return []

    # Case 1: Local single dataset (dict)
    if isinstance(dataset_input, dict):
        return [build_dataset_attr_from_dict(dataset_input)]

    # Case 2: Dataset names from dataset_info.json (str)
    if not isinstance(dataset_input, str):
        raise ValueError(f"Invalid dataset input: {dataset_input}")

    dataset_names = [name.strip() for name in dataset_input.split(",") if name.strip()]
    if not dataset_names:
        return []

    # Load dataset_info.json
    if isinstance(dataset_dir, dict):
        dataset_info = dataset_dir
    elif dataset_dir == "ONLINE":
        dataset_info = None
    else:
        if dataset_dir.startswith("REMOTE:"):
            repo_id = dataset_dir.removeprefix("REMOTE:")
            config_path = hf_hub_download(repo_id=repo_id, filename=DATA_CONFIG, repo_type="dataset")
        else:
            config_path = os.path.join(dataset_dir, DATA_CONFIG)

        try:
            # JSON is UTF-8 by specification; do not depend on the locale's
            # default encoding, which differs across platforms.
            with open(config_path, encoding="utf-8") as f:
                dataset_info = json.load(f)
        except Exception as err:
            # Chain the cause so the original traceback stays visible.
            raise ValueError(f"Cannot open {config_path} due to {str(err)}.") from err

    # Build attrs from dataset_info
    dataset_attrs = []
    for name in dataset_names:
        if dataset_info is None:
            # No registry: treat the name as a hub dataset name.
            load_from = "ms_hub" if use_modelscope() else "om_hub" if use_openmind() else "hf_hub"
            attr = DatasetAttr(load_from, dataset_name=name)
        else:
            if name not in dataset_info:
                raise ValueError(f"Undefined dataset '{name}' in {DATA_CONFIG}.")
            attr = build_dataset_attr_from_dict(dataset_info[name])
        dataset_attrs.append(attr)

    return dataset_attrs


def build_dataset_attr_from_dict(config: dict[str, Any]) -> "DatasetAttr":
    r"""
    Create a DatasetAttr from a single dataset config dictionary.

    The dictionary uses the same layout as one entry of dataset_info.json.
    Exactly one source key decides ``load_from`` and ``dataset_name``; when
    several are present the first match in this order wins:
      file_name, script_url, cloud_file_name, hf_hub_url, ms_hub_url, om_hub_url

    All remaining options (formatting, ranking, split, num_samples, subset,
    folder, columns, tags) are applied through ``DatasetAttr.join``.

    Example:
        config = {
            "file_name": "alpaca.json",
            "formatting": "alpaca",
            "columns": {"prompt": "instruction", "response": "output"}
        }

    Raises:
        ValueError: If none of the source keys is present in ``config``.
    """
    # (config key, load_from value) pairs, listed in lookup priority order.
    source_table = (
        ("file_name", "file"),
        ("script_url", "script"),
        ("cloud_file_name", "cloud_file"),
        ("hf_hub_url", "hf_hub"),
        ("ms_hub_url", "ms_hub"),
        ("om_hub_url", "om_hub"),
    )

    matched = next(((key, kind) for key, kind in source_table if key in config), None)
    if matched is None:
        raise ValueError(
            "Config must contain one of: "
            "'file_name', 'script_url', 'cloud_file_name', "
            "'hf_hub_url', 'ms_hub_url', or 'om_hub_url'."
        )

    source_key, source_kind = matched

    # Start from the source information, then layer on every other option.
    dataset_attr = DatasetAttr(load_from=source_kind, dataset_name=config[source_key])
    dataset_attr.join(config)
    return dataset_attr