import json
import os
from dataclasses import dataclass
from typing import Any, Literal, Optional, Union
from huggingface_hub import hf_hub_download
# File name of the dataset registry; it is joined onto a local dataset directory
# or downloaded from a remote dataset repo by get_dataset_list().
DATA_CONFIG = "dataset_info.json"
@dataclass
class DatasetAttr:
    r"""Dataset attributes.

    Holds where a dataset is loaded from, its name, and the column/tag names
    used to read it. Instances are usually created with only ``load_from`` and
    ``dataset_name`` and then completed from a config dict via :meth:`join`.
    """

    # Source kind and identifier. "cloud_file" is included because
    # build_dataset_attr_from_dict() assigns it for "cloud_file_name" configs.
    load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file", "cloud_file"]
    dataset_name: str
    # General options, read from the top level of a config dict by join().
    formatting: Literal["alpaca", "sharegpt", "openai"] = "alpaca"
    ranking: bool = False
    subset: Optional[str] = None
    split: str = "train"
    folder: Optional[str] = None
    num_samples: Optional[int] = None
    # Column names, read from the "columns" section of a config dict by join().
    system: Optional[str] = None
    tools: Optional[str] = None
    chosen: Optional[str] = None
    rejected: Optional[str] = None
    kto_tag: Optional[str] = None
    prompt: Optional[str] = "instruction"
    query: Optional[str] = "input"
    response: Optional[str] = "output"
    history: Optional[str] = None
    messages: Optional[str] = "conversations"
    # Tag names, read from the "tags" section of a config dict by join().
    role_tag: Optional[str] = "from"
    content_tag: Optional[str] = "value"
    user_tag: Optional[str] = "human"
    assistant_tag: Optional[str] = "gpt"
    observation_tag: Optional[str] = "observation"
    function_tag: Optional[str] = "function_call"
    system_tag: Optional[str] = "system"

    def __repr__(self) -> str:
        r"""Return the dataset name as the representation."""
        return self.dataset_name

    def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
        r"""Set attribute ``key`` to ``obj[key]``, or to ``default`` when the key is absent."""
        setattr(self, key, obj.get(key, default))

    def join(self, attr: dict[str, Any]) -> None:
        r"""Update this object from a config dict.

        Top-level options are always (re)assigned, falling back to their
        defaults when missing. The ``columns`` and ``tags`` sections are only
        processed when present; within a present section, every name that is
        not listed is reset to ``None`` rather than keeping its class default.
        """
        self.set_attr("formatting", attr, default="alpaca")
        self.set_attr("ranking", attr, default=False)
        self.set_attr("subset", attr)
        self.set_attr("split", attr, default="train")
        self.set_attr("folder", attr)
        self.set_attr("num_samples", attr)

        if "columns" in attr:
            column_names = (
                "prompt",
                "query",
                "response",
                "history",
                "messages",
                "system",
                "tools",
                "chosen",
                "rejected",
                "kto_tag",
            )
            columns = attr["columns"]
            for column_name in column_names:
                self.set_attr(column_name, columns)

        if "tags" in attr:
            tag_names = (
                "role_tag",
                "content_tag",
                "user_tag",
                "assistant_tag",
                "observation_tag",
                "function_tag",
                "system_tag",
            )
            tags = attr["tags"]
            for tag in tag_names:
                self.set_attr(tag, tags)
def is_env_enabled(env_var: str, default: str = "0") -> bool:
    r"""Tell whether an environment variable holds an "enabled" value.

    The variable's value (or ``default`` when it is unset) is compared
    case-insensitively against ``true``, ``y`` and ``1``.
    """
    raw_value = os.getenv(env_var, default)
    return raw_value.lower() in ("true", "y", "1")
def use_modelscope() -> bool:
    r"""Check whether the ``USE_MODELSCOPE_HUB`` environment variable is enabled."""
    return is_env_enabled(env_var="USE_MODELSCOPE_HUB")
def use_openmind() -> bool:
    r"""Check whether the ``USE_OPENMIND_HUB`` environment variable is enabled."""
    return is_env_enabled(env_var="USE_OPENMIND_HUB")
def get_dataset_list(
    dataset_input: Optional[Union[dict[str, Any], str]],
    dataset_dir: Union[str, dict],
) -> list["DatasetAttr"]:
    r"""Resolve a dataset specification into a list of ``DatasetAttr`` objects.

    Args:
        dataset_input: ``None`` (no datasets), a single config dict that is
            converted directly, or a comma-separated string of dataset names.
        dataset_dir: Where the names are looked up. A dict is used as the
            registry itself; the string ``"ONLINE"`` means each name is treated
            as a hub dataset; a string starting with ``"REMOTE:"`` downloads the
            registry file from that dataset repo; any other string is a local
            directory containing the registry file.

    Returns:
        One ``DatasetAttr`` per requested dataset, in input order. Empty when
        ``dataset_input`` is ``None`` or contains no non-blank names.

    Raises:
        ValueError: If ``dataset_input`` has an unsupported type, the registry
            file cannot be opened or parsed, or a name is not in the registry.
            Errors raised by the remote download itself are not wrapped.
    """
    if dataset_input is None:
        return []

    if isinstance(dataset_input, dict):
        return [build_dataset_attr_from_dict(dataset_input)]

    if not isinstance(dataset_input, str):
        raise ValueError(f"Invalid dataset input: {dataset_input}")

    # Split on commas and drop blank entries such as those from "a,,b" or " , ".
    dataset_names = [name.strip() for name in dataset_input.split(",") if name.strip()]
    if not dataset_names:
        return []

    if isinstance(dataset_dir, dict):
        dataset_info = dataset_dir
    elif dataset_dir == "ONLINE":
        dataset_info = None
    else:
        if dataset_dir.startswith("REMOTE:"):
            repo_id = dataset_dir[len("REMOTE:"):]
            config_path = hf_hub_download(repo_id=repo_id, filename=DATA_CONFIG, repo_type="dataset")
        else:
            config_path = os.path.join(dataset_dir, DATA_CONFIG)

        try:
            # JSON is read as UTF-8 explicitly so the result does not depend on the platform locale.
            with open(config_path, encoding="utf-8") as f:
                dataset_info = json.load(f)
        except Exception as err:
            # Keep the original exception attached as the cause for debugging.
            raise ValueError(f"Cannot open {config_path} due to {str(err)}.") from err

    if dataset_info is None:
        # Without a registry every name is loaded from a hub; the hub choice is
        # the same for all names, so it is decided once.
        load_from = "ms_hub" if use_modelscope() else "om_hub" if use_openmind() else "hf_hub"
        return [DatasetAttr(load_from, dataset_name=name) for name in dataset_names]

    dataset_attrs = []
    for name in dataset_names:
        if name not in dataset_info:
            raise ValueError(f"Undefined dataset '{name}' in {DATA_CONFIG}.")

        dataset_attrs.append(build_dataset_attr_from_dict(dataset_info[name]))

    return dataset_attrs
def build_dataset_attr_from_dict(config: dict[str, Any]) -> "DatasetAttr":
    r"""Create a ``DatasetAttr`` from one config dictionary.

    The dictionary uses the same layout as an entry of ``dataset_info.json``.
    Exactly one source key decides ``load_from`` and ``dataset_name``; when
    several are present the first one in this order wins:
    ``file_name``, ``script_url``, ``cloud_file_name``, ``hf_hub_url``,
    ``ms_hub_url``, ``om_hub_url``. The remaining keys (``formatting``,
    ``ranking``, ``split``, ``num_samples``, ``subset``, ``folder``,
    ``columns``, ``tags``) are applied through ``DatasetAttr.join``.

    Example:
        config = {
            "file_name": "alpaca.json",
            "formatting": "alpaca",
            "columns": {"prompt": "instruction", "response": "output"}
        }

    Raises:
        ValueError: If none of the source keys is present in ``config``.
    """
    # (config key, load_from value) pairs, listed in lookup priority order.
    source_keys = (
        ("file_name", "file"),
        ("script_url", "script"),
        ("cloud_file_name", "cloud_file"),
        ("hf_hub_url", "hf_hub"),
        ("ms_hub_url", "ms_hub"),
        ("om_hub_url", "om_hub"),
    )
    for source_key, source_kind in source_keys:
        if source_key in config:
            load_from = source_kind
            dataset_name = config[source_key]
            break
    else:
        raise ValueError(
            "Config must contain one of: "
            "'file_name', 'script_url', 'cloud_file_name', "
            "'hf_hub_url', 'ms_hub_url', or 'om_hub_url'."
        )

    dataset_attr = DatasetAttr(load_from=load_from, dataset_name=dataset_name)
    dataset_attr.join(config)
    return dataset_attr