import bisect
import os
import copy
from abc import abstractmethod, ABC
from collections import defaultdict
from dataclasses import dataclass, field
from itertools import chain
from enum import Enum, unique
from typing import TYPE_CHECKING, Any, Optional, Union, Tuple, Literal, List, Dict, Type, TypedDict
import torch
from transformers import PreTrainedTokenizer, ProcessorMixin, AutoProcessor, AutoConfig, AutoTokenizer, PretrainedConfig
from mindspeed_mm.data.data_utils.func_utils.log import get_logger
from mindspeed_mm.data.data_utils.func_utils.model_args import ProcessorArguments
from mindspeed_mm.data.data_utils.video_processor import VideoProcessor
from mindspeed_mm.data.data_utils.video_reader import VideoReader
# Label value written over prompt and padding positions so they are excluded from the loss
# (assumes the training loss uses ignore_index=-100 — confirm against the trainer).
IGNORE_INDEX = -100
logger = get_logger(__file__)
@unique
class Role(str, Enum):
    r"""Role names written into the ``"role"`` field of aligned messages.

    Members subclass ``str``, so ``Role.USER == "user"``; ``@unique`` rejects duplicate values.
    """

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"
    OBSERVATION = "observation"
    TOOL_CALL = "tool_call"
    TOOL_RESPONSE = "tool_response"
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from mindspeed_mm.data.data_utils.func_utils.template import Template
from .mm_plugin import AudioInput, ImageInput, VideoInput
MediaType = Union[ImageInput, VideoInput, AudioInput]
class TokenizerModule(TypedDict):
    # Bundle of a text tokenizer and an optional multimodal processor
    # (e.g. what load_reward_tokenizer returns, indexed as ['tokenizer'] / ['processor']).
    tokenizer: "PreTrainedTokenizer"
    processor: Optional["ProcessorMixin"]
@dataclass
class DatasetConverter(ABC):
    r"""Base class for converters that map one raw dataset row to the aligned format.

    Subclasses implement ``__call__`` and return a dict with the ``_prompt``, ``_response``,
    ``_system``, ``_images``, ``_videos`` and ``_audios`` keys.
    """

    dataset_attr: "DatasetAttr"
    data_args: "DataArguments"

    def _resolve_media_path(self, media: Any) -> Any:
        r"""Prefix one media entry with ``data_args.dataset_dir`` when that file exists.

        Lists (e.g. a video given as frame paths) are resolved element-wise. Entries that are
        not filesystem paths (already-loaded images, dicts, bytes) are returned unchanged.
        """
        if isinstance(media, list):
            return [self._resolve_media_path(item) for item in media]
        if not isinstance(media, (str, os.PathLike)):
            # os.path.join only accepts paths; anything else would raise TypeError.
            return media
        media_path = os.path.join(self.data_args.dataset_dir, media)
        if os.path.isfile(media_path):
            return media_path
        logger.warning(f"Media {media} does not exist in `dataset_dir`. Use original path.")
        return media

    def _find_media_files(self, media_files: Union["MediaType", List["MediaType"], None]) -> Optional[List["MediaType"]]:
        r"""Optionally concatenate media path to media dir when loading from local disk.

        Args:
            media_files: a single media entry, a list of entries, or ``None``.

        Returns:
            ``None`` for ``None`` or an empty list; otherwise a new list (the input is never
            mutated) where each existing relative path is joined with ``data_args.dataset_dir``.
        """
        if media_files is None:
            return None
        if not isinstance(media_files, list):
            media_files = [media_files]
        elif len(media_files) == 0:
            return None
        return [self._resolve_media_path(media) for media in media_files]

    @abstractmethod
    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
        r"""Convert a single example in the dataset to the standard format."""
        ...
@dataclass
class AlpacaDatasetConverter(DatasetConverter):
    r"""Convert an Alpaca-style row (instruction / input / output columns) to the aligned format."""

    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
        attr = self.dataset_attr

        # Earlier turns live in the history column as (user, assistant) pairs.
        prompt = []
        history = example[attr.history] if attr.history else None
        if isinstance(history, list):
            for past_query, past_answer in history:
                prompt += [
                    {"role": Role.USER.value, "content": past_query},
                    {"role": Role.ASSISTANT.value, "content": past_answer},
                ]

        # The current user turn joins the non-empty instruction and input columns.
        parts = [example[column] for column in (attr.prompt, attr.query) if column and example[column]]
        prompt.append({"role": Role.USER.value, "content": "\n".join(parts)})

        is_pairwise = (
            attr.ranking
            and isinstance(example[attr.chosen], str)
            and isinstance(example[attr.rejected], str)
        )
        if is_pairwise:
            response = [
                {"role": Role.ASSISTANT.value, "content": example[attr.chosen]},
                {"role": Role.ASSISTANT.value, "content": example[attr.rejected]},
            ]
        elif attr.response and isinstance(example[attr.response], str):
            response = [{"role": Role.ASSISTANT.value, "content": example[attr.response]}]
        else:
            response = []

        def media(column):
            # Unconfigured media columns yield None rather than an empty list.
            return self._find_media_files(example[column]) if column else None

        return {
            "_prompt": prompt,
            "_response": response,
            "_system": example[attr.system] if attr.system else "",
            "_images": media(attr.images),
            "_videos": media(attr.videos),
            "_audios": media(attr.audios),
        }
@dataclass
class SharegptDatasetConverter(DatasetConverter):
    r"""Convert a ShareGPT-style row (a list of role-tagged messages) to the aligned format.

    After an optional leading system message, messages must alternate between a user-side
    tag (user / observation) and an assistant-side tag (assistant / function). Rows that
    violate this, or whose chosen/rejected responses carry a non-assistant-side tag, are
    logged and returned with empty ``_prompt``/``_response``. Downstream processors then
    drop them.
    """

    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
        attr = self.dataset_attr
        tag_mapping = {
            attr.user_tag: Role.USER.value,
            attr.assistant_tag: Role.ASSISTANT.value,
            attr.observation_tag: Role.OBSERVATION.value,
            attr.function_tag: Role.FUNCTION.value,
            attr.system_tag: Role.SYSTEM.value,
        }
        # Messages at index 0, 2, 4, ... need a user-side tag; 1, 3, 5, ... an assistant-side tag.
        odd_tags = (attr.user_tag, attr.observation_tag)
        even_tags = (attr.assistant_tag, attr.function_tag)
        accept_tags = (odd_tags, even_tags)
        messages = example[attr.messages]
        if (
            attr.system_tag
            and len(messages) != 0
            and messages[0][attr.role_tag] == attr.system_tag
        ):
            system = messages[0][attr.content_tag]
            messages = messages[1:]
        else:
            system = example[attr.system] if attr.system else ""

        aligned_messages = []
        broken_data = False
        for turn_idx, message in enumerate(messages):
            if message[attr.role_tag] not in accept_tags[turn_idx % 2]:
                logger.warning_rank0(f"Invalid role tag in {messages}.")
                broken_data = True
                break
            aligned_messages.append(
                {
                    "role": tag_mapping.get(message.get(attr.role_tag)),
                    "content": message.get(attr.content_tag),
                }
            )

        # Non-ranking rows need an even count (prompt turns + final response). Ranking rows need an
        # odd count: the responses come from the separate chosen/rejected columns.
        is_invalid_message_count = (not attr.ranking and len(aligned_messages) % 2 != 0) or \
            (attr.ranking and len(aligned_messages) % 2 == 0)
        if is_invalid_message_count:
            logger.warning_rank0(f"Invalid message count in {messages}.")
            broken_data = True

        prompt, response = [], []
        if broken_data:
            logger.warning_rank0("Skipping this abnormal example.")
        elif (
            attr.ranking
            and isinstance(example[attr.chosen], dict)
            and isinstance(example[attr.rejected], dict)
        ):
            chosen = example[attr.chosen]
            rejected = example[attr.rejected]
            if (
                chosen[attr.role_tag] not in accept_tags[-1]
                or rejected[attr.role_tag] not in accept_tags[-1]
            ):
                # A non-assistant-side tag would map to a wrong (or None) role; drop the row
                # instead of emitting a malformed preference pair.
                logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
                logger.warning_rank0("Skipping this abnormal example.")
            else:
                prompt = aligned_messages
                response = [
                    {
                        "role": tag_mapping.get(chosen.get(attr.role_tag)),
                        "content": chosen.get(attr.content_tag),
                    },
                    {
                        "role": tag_mapping.get(rejected.get(attr.role_tag)),
                        "content": rejected.get(attr.content_tag),
                    },
                ]
        else:
            prompt = aligned_messages[:-1]
            response = aligned_messages[-1:]

        output = {
            "_prompt": prompt,
            "_response": response,
            "_system": system,
            "_images": self._find_media_files(example[attr.images]) if attr.images else None,
            "_videos": self._find_media_files(example[attr.videos]) if attr.videos else None,
            "_audios": self._find_media_files(example[attr.audios]) if attr.audios else None,
        }
        return output
@dataclass
class MultiModalToolDatasetConverter(DatasetConverter):
    r"""Convert a multimodal tool-calling row to the aligned format.

    Messages are passed through as-is (no role-tag remapping or validation): the last one
    becomes ``_response`` and the rest ``_prompt``. An optional leading system message is
    split off into ``_system``. The row's ``tools`` value is forwarded as ``_tools``; it is
    ``None`` when the column is missing or empty.
    """

    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
        attr = self.dataset_attr
        messages = example[attr.messages]
        if (
            attr.system_tag
            and len(messages) != 0
            and messages[0][attr.role_tag] == attr.system_tag
        ):
            system = messages[0][attr.content_tag]
            messages = messages[1:]
        else:
            system = example[attr.system] if attr.system else ""
        output = {
            "_prompt": messages[:-1],
            "_response": messages[-1:],
            "_system": system,
            "_images": self._find_media_files(example[attr.images]) if attr.images else None,
            "_videos": self._find_media_files(example[attr.videos]) if attr.videos else None,
            "_audios": self._find_media_files(example[attr.audios]) if attr.audios else None,
            # .get: datasets without a ``tools`` column must not fail with KeyError.
            "_tools": example.get("tools") or None,
        }
        return output
# Registry from a ``DatasetAttr.formatting`` value to its converter class. Extend it with
# ``register_dataset_converter``; ``align_dataset`` looks entries up via ``get_dataset_converter``.
DATASET_CONVERTERS = {
    "alpaca": AlpacaDatasetConverter,
    "sharegpt": SharegptDatasetConverter,
    "multimodal_tool": MultiModalToolDatasetConverter
}
def register_dataset_converter(name: str, dataset_converter: Type["DatasetConverter"]) -> None:
    r"""Add ``dataset_converter`` to the converter registry under ``name``.

    Raises:
        ValueError: if ``name`` is already registered; existing entries are never replaced.
    """
    already_registered = name in DATASET_CONVERTERS
    if already_registered:
        raise ValueError(f"Dataset converter {name} already exists.")
    DATASET_CONVERTERS.update({name: dataset_converter})
def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter":
    r"""Instantiate the converter registered under ``name``.

    Raises:
        ValueError: if no converter is registered under ``name``.
    """
    if name in DATASET_CONVERTERS:
        converter_cls = DATASET_CONVERTERS[name]
        return converter_cls(dataset_attr, data_args)
    raise ValueError(f"Dataset converter {name} not found.")
def align_dataset(
    dataset: Union["Dataset", "IterableDataset"],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Union["Dataset", "IterableDataset"]:
    r"""Align the dataset to a specific format.

    Every row is passed through the converter selected by ``dataset_attr.formatting``, and all
    original columns are removed. Aligned dataset:
        _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
        _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
        _system: "..."
        _images: []
        _videos: []
        _audios: []
        _tools: ... (only produced by the "multimodal_tool" converter)
    """
    # Reading the first row works for both map-style and iterable datasets.
    first_row = next(iter(dataset))
    column_names = list(first_row.keys())

    if data_args.streaming:
        map_kwargs = {}
    else:
        # overwrite_cache is honoured only on LOCAL_RANK 0; every other rank reuses the cache.
        reuse_cache = (not data_args.overwrite_cache) or (int(os.getenv("LOCAL_RANK", -1)) != 0)
        map_kwargs = {
            "num_proc": data_args.preprocessing_num_workers,
            "load_from_cache_file": reuse_cache,
            "desc": "Converting format of dataset",
        }

    converter = get_dataset_converter(dataset_attr.formatting, dataset_attr, data_args)
    return dataset.map(converter, batched=False, remove_columns=column_names, **map_kwargs)
@dataclass
class DatasetAttr:
    r"""
    Dataset attributes.

    Describes how to read one dataset: which converter handles it (``formatting``), which raw
    columns hold each piece of data, and which tag values identify message roles.
    """

    packing: bool = False
    # Rows carry a chosen/rejected response pair instead of a single response.
    ranking: bool = False
    pretrain: bool = False
    # Raw column names read by the converters; None means the column is not read.
    system: Optional[str] = None
    images: Optional[str] = None
    videos: Optional[str] = None
    audios: Optional[str] = None
    # Alpaca-style columns: instruction, input, output and (user, assistant) history pairs.
    prompt: Optional[str] = None
    query: Optional[str] = None
    response: Optional[str] = None
    history: Optional[str] = None
    # ShareGPT-style column holding the message list, and the keys inside each message dict.
    messages: Optional[str] = "conversations"
    role_tag: Optional[str] = "from"
    content_tag: Optional[str] = "value"
    # Values of ``role_tag`` that identify each role.
    user_tag: Optional[str] = "human"
    assistant_tag: Optional[str] = "gpt"
    observation_tag: Optional[str] = "observation"
    function_tag: Optional[str] = "function_call"
    system_tag: Optional[str] = "system"
    # Columns holding the preferred / dispreferred responses of ranking data.
    chosen: Optional[str] = None
    rejected: Optional[str] = None
    # Key into DATASET_CONVERTERS; all three registered formats are valid.
    formatting: Literal["alpaca", "sharegpt", "multimodal_tool"] = "sharegpt"
@dataclass
class DataArguments:
    r"""
    Arguments pertaining to what data we are going to input our model for training and evaluation.
    """
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Directory to read/write data. Defaults to `~/.cache/huggingface/datasets`(env:HF_DATASETS_CACHE)"},
    )
    template: Optional[str] = field(
        default=None,
        metadata={
            "help": "Which template to use for constructing prompts in training and inference."},
    )
    enable_thinking: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
    )
    dataset_dir: str = field(
        default="data",
        metadata={"help": "Path to the folder containing the datasets."},
    )
    # Given as a comma-separated string; __post_init__ turns it into a list of names.
    dataset: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
    )
    cutoff_len: int = field(
        default=1024,
        metadata={
            "help": "The cutoff length of the tokenized inputs in the dataset."},
    )
    train_on_prompt: bool = field(
        default=False,
        metadata={"help": "Whether or not to disable the mask on the prompt."},
    )
    mask_history: bool = field(
        default=False,
        metadata={
            "help": "Whether or not to mask the history and train on the last turn only."},
    )
    streaming: bool = field(
        default=False,
        metadata={"help": "Enable dataset streaming."},
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets."},
    )
    preprocessing_batch_size: int = field(
        default=1000,
        metadata={"help": "The number of examples in one group in pre-processing."},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the pre-processing."},
    )
    max_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes, truncate the number of examples for each dataset."},
    )
    tool_format: Optional[str] = field(
        default=None,
        metadata={
            "help": "Tool format to use for constructing function calling examples."},
    )
    val_dataset: Optional[str] = field(
        default=None,
        metadata={
            "help": "Name of the validation dataset."},
    )
    val_max_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes, truncate the number of examples for each validation dataset."},
    )
    val_rate: Optional[float] = field(
        default=None,
        metadata={"help": "The proportion of the dataset to be used for validation."},
    )
    packing: Optional[bool] = field(
        default=None,
        metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
    )
    neat_packing: bool = field(
        default=False,
        metadata={"help": "Enable sequence packing without cross-attention."},
    )
    preprocess_on_fly: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to perform preprocess during training."},
    )
    async_preprocess: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to perform async preprocess during training."},
    )
    async_preprocess_buffer_size: Optional[int] = field(
        default=None,
        metadata={"help": "Buffer size for async preprocess. Defaults to 8 when not set and num_workers is unset."},
    )
    use_pmcc_data: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use PMCC dataset."},
    )

    @staticmethod
    def _split_comma_list(value: Any) -> Any:
        r"""Split a comma-separated string into stripped, non-empty items.

        Non-strings (``None``, or a list that was already split) are returned unchanged.
        """
        if isinstance(value, str):
            return [item.strip() for item in value.split(",") if item.strip()]
        return value

    def __post_init__(self):
        # ``dataset`` defaults to None and may already be a list; only split real strings.
        self.dataset = self._split_comma_list(self.dataset)
@dataclass
class DataArgumentsForRewardVideo:
    r"""
    Arguments pertaining to what data we are going to input our model for training and evaluation.
    """
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Directory to read/write data. Defaults to `~/.cache/huggingface/datasets`(env:HF_DATASETS_CACHE)"},
    )
    template: Optional[str] = field(
        default=None,
        metadata={
            "help": "Which template to use for constructing prompts in training and inference."},
    )
    data_folder: str = field(
        default="data",
        metadata={"help": "Path to the folder containing the datasets."},
    )
    # Given as a comma-separated string; __post_init__ turns it into a list of paths.
    data_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
    )
    # NOTE(review): documented as comma-separated but kept as a raw string; consumers must split it.
    data_path_val: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of dataset(s) to use for validation. Use commas to separate multiple datasets."},
    )
    train_on_prompt: bool = field(
        default=False,
        metadata={"help": "Whether or not to disable the mask on the prompt."},
    )
    mask_history: bool = field(
        default=False,
        metadata={
            "help": "Whether or not to mask the history and train on the last turn only."},
    )
    streaming: bool = field(
        default=False,
        metadata={"help": "Enable dataset streaming."},
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets."},
    )
    preprocessing_batch_size: int = field(
        default=1000,
        metadata={"help": "The number of examples in one group in pre-processing."},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the pre-processing."},
    )
    max_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes, truncate the number of examples for each dataset."},
    )
    tool_format: Optional[str] = field(
        default=None,
        metadata={
            "help": "Tool format to use for constructing function calling examples."},
    )
    val_dataset: Optional[str] = field(
        default=None,
        metadata={
            "help": "Name of the validation dataset."},
    )
    val_max_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes, truncate the number of examples for each validation dataset."},
    )
    val_rate: Optional[float] = field(
        default=None,
        metadata={"help": "The proportion of the dataset to be used for validation."},
    )
    packing: Optional[bool] = field(
        default=None,
        metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
    )
    neat_packing: bool = field(
        default=False,
        metadata={"help": "Enable sequence packing without cross-attention."},
    )

    @staticmethod
    def _split_comma_list(value: Any) -> Any:
        r"""Split a comma-separated string into stripped, non-empty items.

        Non-strings (``None``, or a list that was already split) are returned unchanged.
        """
        if isinstance(value, str):
            return [item.strip() for item in value.split(",") if item.strip()]
        return value

    def __post_init__(self):
        # ``data_path`` defaults to None and may already be a list; only split real strings.
        self.data_path = self._split_comma_list(self.data_path)
def search_for_fit(numbers: List[int], capacity: int) -> int:
    r"""Find the index of largest number that fits into the knapsack with the given capacity.

    ``numbers`` must be sorted ascending. Returns the index of the largest item ``<= capacity``,
    or -1 when even the smallest item does not fit.
    """
    index = bisect.bisect(numbers, capacity)
    return -1 if index == 0 else (index - 1)


def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
    r"""Implement efficient greedy algorithm with binary search for the knapsack problem.

    Each knapsack is filled by repeatedly taking the largest remaining item that still fits.

    Args:
        numbers: item sizes; the caller's list is not modified.
        capacity: maximum total size of one knapsack.

    Returns:
        A list of knapsacks, each a list of the item sizes packed into it.

    Raises:
        ValueError: if any item is larger than ``capacity``. Such an item can never be placed,
            and without this check the loop would append empty knapsacks forever.
    """
    oversized = [number for number in numbers if number > capacity]
    if oversized:
        raise ValueError(f"Cannot pack items larger than the capacity {capacity}: {sorted(set(oversized))}.")

    remaining = sorted(numbers)  # work on a sorted copy instead of mutating the argument
    knapsacks = []
    while remaining:
        current_knapsack = []
        remaining_capacity = capacity
        while True:
            index = search_for_fit(remaining, remaining_capacity)
            if index == -1:
                break
            remaining_capacity -= remaining[index]
            current_knapsack.append(remaining.pop(index))
        knapsacks.append(current_knapsack)
    return knapsacks
@dataclass
class DatasetProcessor(ABC):
    r"""A class for data processors.

    Holds the chat template, tokenizer, optional multimodal processor and data arguments
    shared by every concrete processor. Subclasses implement the two abstract hooks below.
    """
    template: "Template"
    tokenizer: "PreTrainedTokenizer"
    processor: Optional["ProcessorMixin"]
    data_args: "DataArguments"

    @abstractmethod
    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        r"""Build model inputs from the examples.

        Per the annotation, ``examples`` is column-oriented (column name -> list of values).
        Note that VideoRewardProcessor instead receives a list of rows.
        """
        ...

    @abstractmethod
    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        r"""Print a data example to stdout."""
        ...
@dataclass
class SupervisedDatasetProcessor(DatasetProcessor):
    r"""Processor for supervised fine-tuning rows with exactly one response per prompt."""

    def _encode_data_example(
        self,
        prompt: List[Dict[str, str]],
        response: List[Dict[str, str]],
        system: Optional[str],
        images: List["ImageInput"],
        videos: List["VideoInput"],
        audios: List["AudioInput"],
        tools: List[str]
    ) -> Tuple[List[int], List[int]]:
        r"""Encode one conversation into ``(input_ids, labels)`` within ``data_args.cutoff_len``.

        Each (source, target) turn pair is truncated by ``infer_seqlen`` against the remaining
        budget, and encoding stops once the budget is used up. Source tokens are labelled
        ``IGNORE_INDEX`` unless ``train_on_prompt`` is set. With ``mask_history``, turns are
        processed newest-first and prepended. The last turn therefore survives truncation and
        is the only one whose target keeps real labels.
        """
        if hasattr(self.data_args, "use_pmcc_data") and self.data_args.use_pmcc_data:
            # PMCC path: consecutive message contents are paired as (source, target) with no
            # template or multimodal processing. NOTE(review): the list concatenations below
            # require each content to already be a list of token ids — confirm with the loader.
            it = (item['content'] for item in prompt + response)
            input_ids, labels = [], []
            encoded_pairs = list(zip(it, it))
        else:
            messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
            # Called on empty lists; presumably returns any plugin-specific prefix ids/labels — confirm.
            input_ids, labels = self.template.mm_plugin.process_token_ids(
                [], [], images, videos, audios, self.tokenizer, self.processor
            )
            encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools)
        # Reserve one slot for the trailing EOS appended below when efficient_eos is enabled.
        total_length = len(input_ids) + (1 if self.template.efficient_eos else 0)
        if self.data_args.mask_history:
            # Newest turn first, so it gets the budget before older turns.
            encoded_pairs = encoded_pairs[::-1]
        for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
            if total_length >= self.data_args.cutoff_len:
                logger.info(
                    f"Maximum sequence length {self.data_args.cutoff_len} reached. "
                    f"Please increase seq_len or cutoff_len in config."
                )
                break
            source_len, target_len = infer_seqlen(
                len(source_ids), len(target_ids), self.data_args.cutoff_len - total_length
            )
            source_ids = source_ids[:source_len]
            target_ids = target_ids[:target_len]
            total_length += source_len + target_len
            if self.data_args.train_on_prompt:
                source_label = source_ids
            elif self.template.efficient_eos:
                # The first source position is labelled EOS (presumably the EOS that closes the
                # previous turn is folded into this source segment); the rest are ignored.
                source_label = [self.tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
            else:
                source_label = [IGNORE_INDEX] * source_len
            if self.data_args.mask_history and turn_idx != 0:
                target_label = [IGNORE_INDEX] * target_len
            else:
                target_label = target_ids
            if self.data_args.mask_history:
                # Prepend so the reversed iteration still yields chronological token order.
                input_ids = source_ids + target_ids + input_ids
                labels = source_label + target_label + labels
            else:
                input_ids += source_ids + target_ids
                labels += source_label + target_label
        if self.template.efficient_eos:
            input_ids += [self.tokenizer.eos_token_id]
            labels += [self.tokenizer.eos_token_id]
        return input_ids, labels

    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        r"""Encode a column-oriented batch of aligned rows into SFT model inputs.

        Rows with an even number of prompt messages, or without exactly one response, are
        dropped with a warning. Rows whose encoding raises ``OSError`` (presumably an unreadable
        media file) are skipped. ``_tools`` is optional and defaults to an empty schema.
        """
        model_inputs = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue
            try:
                tool_schema = []
                if '_tools' in examples:
                    tool_schema = examples['_tools'][i]
                input_ids, labels = self._encode_data_example(
                    prompt=examples["_prompt"][i],
                    response=examples["_response"][i],
                    system=examples["_system"][i],
                    images=examples["_images"][i] or [],
                    videos=examples["_videos"][i] or [],
                    audios=examples["_audios"][i] or [],
                    tools=tool_schema
                )
            except OSError as e:
                err_img = examples["_images"][i] if examples["_images"][i] else "No images"
                logger.warning(f"Skipping invalid sample: {err_img}. Error: {str(e)}")
                continue
            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append([1] * len(input_ids))
            model_inputs["labels"].append(labels)
            model_inputs["images"].append(examples["_images"][i])
            model_inputs["videos"].append(examples["_videos"][i])
            model_inputs["audios"].append(examples["_audios"][i])
        return model_inputs

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        r"""Print the ids and decoded text of one encoded example; labels exclude IGNORE_INDEX."""
        valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}")
@dataclass
class PretrainDatasetProcessor(DatasetProcessor):
    r"""Processor for raw-text pre-training data (only the first prompt message is used)."""

    def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
        r"""Tokenize documents, either truncating each one or packing them into fixed blocks.

        Each document is the content of the first ``_prompt`` message plus an EOS marker. With
        ``data_args.packing`` all documents are concatenated and cut into ``cutoff_len``-sized
        blocks, and the trailing remainder shorter than one block is discarded. Otherwise each
        document is tokenized and truncated to ``cutoff_len`` on its own.
        """
        # llama3 documents are terminated by its dedicated end-of-text marker.
        if self.data_args.template == "llama3":
            eos_token = "<|end_of_text|>"
        else:
            eos_token = self.tokenizer.eos_token
        documents = [conversation[0]["content"] + eos_token for conversation in examples["_prompt"]]

        if self.data_args.packing:
            encoded = self.tokenizer(documents, add_special_tokens=False)
            flat = {key: list(chain.from_iterable(encoded[key])) for key in encoded.keys()}
            block_size = self.data_args.cutoff_len
            first_key = list(flat.keys())[0]
            usable_length = len(flat[first_key]) // block_size * block_size
            result = {
                key: [ids[start: start + block_size] for start in range(0, usable_length, block_size)]
                for key, ids in flat.items()
            }
            if getattr(self.tokenizer, "add_bos_token", False):
                # Each packed block starts with BOS, overwriting its first token.
                for block in result["input_ids"]:
                    block[0] = self.tokenizer.bos_token_id
        else:
            if getattr(self.tokenizer, "add_bos_token", False):
                documents = [self.tokenizer.bos_token + document for document in documents]
            result = self.tokenizer(
                documents, add_special_tokens=False, truncation=True, max_length=self.data_args.cutoff_len
            )
        return result

    def print_data_example(self, example: dict[str, list[int]]) -> None:
        token_ids = example["input_ids"]
        print(f"input_ids:\n{token_ids}")
        print(f"inputs:\n{self.tokenizer.decode(token_ids, skip_special_tokens=False)}")
@dataclass
class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
    r"""Supervised processor that packs several encoded examples into fixed-length rows.

    Examples are encoded as in ``SupervisedDatasetProcessor`` and grouped with
    ``greedy_knapsack`` so that each group fits in ``cutoff_len`` tokens. Each group is
    concatenated and padded to exactly ``cutoff_len + 1`` tokens, so every row ends with at
    least one pad token. Position ids restart at 0 for each packed example. With
    ``neat_packing`` the attention mask holds the 1-based index of the example each token
    belongs to (0 for padding) instead of all ones.
    """

    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        valid_num = 0
        batch_input_ids, batch_labels, batch_images, batch_videos, batch_audios = [], [], [], [], []
        lengths = []
        # Encoded length -> indexes (into the batch_* lists) of examples with that length.
        length2indexes = defaultdict(list)
        has_tools = '_tools' in examples
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue
            tool_schema = examples['_tools'][i] if has_tools else []
            try:
                input_ids, labels = self._encode_data_example(
                    prompt=examples["_prompt"][i],
                    response=examples["_response"][i],
                    system=examples["_system"][i],
                    images=examples["_images"][i] or [],
                    videos=examples["_videos"][i] or [],
                    audios=examples["_audios"][i] or [],
                    tools=tool_schema
                )
            except OSError as e:
                # Same policy as SupervisedDatasetProcessor: an unreadable media file drops only
                # its own sample instead of aborting the whole batch.
                err_img = examples["_images"][i] if examples["_images"][i] else "No images"
                logger.warning(f"Skipping invalid sample: {err_img}. Error: {str(e)}")
                continue
            length = len(input_ids)
            if length > self.data_args.cutoff_len:
                logger.warning_rank0(f"Dropped lengthy example with length {length} > {self.data_args.cutoff_len}.")
            else:
                lengths.append(length)
                length2indexes[length].append(valid_num)
                batch_input_ids.append(input_ids)
                batch_labels.append(labels)
                batch_images.append(examples["_images"][i] or [])
                batch_videos.append(examples["_videos"][i] or [])
                batch_audios.append(examples["_audios"][i] or [])
                valid_num += 1

        model_inputs = defaultdict(list)
        knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
        for knapsack in knapsacks:
            packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], []
            packed_images, packed_videos, packed_audios = [], [], []
            for i, length in enumerate(knapsack):
                index = length2indexes[length].pop()
                packed_input_ids += batch_input_ids[index]
                packed_position_ids += list(range(len(batch_input_ids[index])))
                packed_labels += batch_labels[index]
                packed_images += batch_images[index]
                packed_videos += batch_videos[index]
                packed_audios += batch_audios[index]
                if self.data_args.neat_packing:
                    packed_attention_masks += [i + 1] * len(batch_input_ids[index])
                else:
                    packed_attention_masks += [1] * len(batch_input_ids[index])
            if len(packed_input_ids) < self.data_args.cutoff_len + 1:
                pad_length = self.data_args.cutoff_len - len(packed_input_ids) + 1
                packed_input_ids += [self.tokenizer.pad_token_id] * pad_length
                packed_position_ids += [0] * pad_length
                packed_labels += [IGNORE_INDEX] * pad_length
                if self.data_args.neat_packing:
                    packed_attention_masks += [0] * pad_length
                else:
                    packed_attention_masks += [1] * pad_length
            if len(packed_input_ids) != self.data_args.cutoff_len + 1:
                raise ValueError("The length of packed example should be identical to the cutoff length.")
            model_inputs["input_ids"].append(packed_input_ids)
            model_inputs["attention_mask"].append(packed_attention_masks)
            model_inputs["position_ids"].append(packed_position_ids)
            model_inputs["labels"].append(packed_labels)
            model_inputs["images"].append(packed_images or None)
            model_inputs["videos"].append(packed_videos or None)
            model_inputs["audios"].append(packed_audios or None)
        return model_inputs
@dataclass
class PairwiseDatasetProcessor(DatasetProcessor):
    r"""Processor for preference rows: one prompt with a chosen and a rejected response."""

    def _encode_data_example(
        self,
        prompt: List[Dict[str, str]],
        response: List[Dict[str, str]],
        system: Optional[str],
        images: List["ImageInput"],
        videos: List["VideoInput"],
        audios: List["AudioInput"],
    ) -> Tuple[List[int], List[int], List[int], List[int]]:
        r"""Encode ``prompt + response[0]`` (chosen) and ``prompt + response[1]`` (rejected).

        Both sequences share the same, possibly truncated, prompt ids, and prompt positions are
        labelled ``IGNORE_INDEX``. Both responses are cut to one common length, which
        ``infer_seqlen`` derives from the longer of the two.

        Returns:
            ``(chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels)``.
        """
        chosen_messages = self.template.mm_plugin.process_messages(
            prompt + [response[0]], images, videos, audios, self.processor
        )
        rejected_messages = self.template.mm_plugin.process_messages(
            prompt + [response[1]], images, videos, audios, self.processor
        )
        prompt_ids, chosen_ids = self.template.encode_oneturn(self.tokenizer, chosen_messages, system)
        _, rejected_ids = self.template.encode_oneturn(self.tokenizer, rejected_messages, system)
        if self.template.efficient_eos:
            chosen_ids += [self.tokenizer.eos_token_id]
            rejected_ids += [self.tokenizer.eos_token_id]
        prompt_ids, _ = self.template.mm_plugin.process_token_ids(
            prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
        )
        source_len, target_len = infer_seqlen(
            len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), self.data_args.cutoff_len
        )
        prompt_ids = prompt_ids[:source_len]
        chosen_ids = chosen_ids[:target_len]
        rejected_ids = rejected_ids[:target_len]
        chosen_input_ids = prompt_ids + chosen_ids
        chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
        rejected_input_ids = prompt_ids + rejected_ids
        rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
        return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels

    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        r"""Encode a column-oriented batch of preference rows.

        Rows with an even number of prompt messages, or fewer than two responses, are dropped
        with a warning. Rows whose encoding raises ``OSError`` (presumably an unreadable media
        file) are skipped, as in ``SupervisedDatasetProcessor``.
        """
        model_inputs = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue
            try:
                chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = self._encode_data_example(
                    prompt=examples["_prompt"][i],
                    response=examples["_response"][i],
                    system=examples["_system"][i],
                    images=examples["_images"][i] or [],
                    videos=examples["_videos"][i] or [],
                    audios=examples["_audios"][i] or [],
                )
            except OSError as e:
                err_img = examples["_images"][i] if examples["_images"][i] else "No images"
                logger.warning(f"Skipping invalid sample: {err_img}. Error: {str(e)}")
                continue
            model_inputs["chosen_input_ids"].append(chosen_input_ids)
            model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
            model_inputs["chosen_labels"].append(chosen_labels)
            model_inputs["rejected_input_ids"].append(rejected_input_ids)
            model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
            model_inputs["rejected_labels"].append(rejected_labels)
            model_inputs["images"].append(examples["_images"][i])
            model_inputs["videos"].append(examples["_videos"][i])
            model_inputs["audios"].append(examples["_audios"][i])
        return model_inputs

    def print_data_example(self, example: Dict[str, List[int]]) -> None:
        r"""Print ids and decoded text of the chosen and rejected sequences; labels exclude IGNORE_INDEX."""
        valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
        valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
        print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
        print(
            "chosen_inputs:\n{}".format(self.tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False))
        )
        print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
        print(f"chosen_labels:\n{self.tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
        print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
        print(
            "rejected_inputs:\n{}".format(
                self.tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)
            )
        )
        print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
        print(f"rejected_labels:\n{self.tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
class VideoRewardProcessor(DatasetProcessor):
def __init__(self, template, tokenizer, processor, data_args, video_reader, video_processor):
from mindspeed_mm.data.data_utils.reward_preprocess import clean_examples
super().__init__(template, tokenizer, processor, data_args, )
self.video_reader = video_reader
self.video_processor = video_processor
self.clean_examples = clean_examples
def _pad_sequence(self, sequences, attention_mask, max_len, padding_side='right'):
"""
Pad the sequences to the maximum length.
"""
if sequences.shape[1] >= max_len:
return sequences, attention_mask
pad_len = max_len - sequences.shape[1]
padding = (0, pad_len) if padding_side == 'right' else (pad_len, 0)
sequences_padded = torch.nn.functional.pad(sequences, padding, 'constant', self.processor.tokenizer.pad_token_id)
attention_mask_padded = torch.nn.functional.pad(attention_mask, padding, 'constant', 0)
return sequences_padded, attention_mask_padded
def _encode_data_example(
self,
) -> Tuple[List[int], List[int]]:
has_idx = "metainfo_idx" in self.examples and self.examples["metainfo_idx"] is not None
A_data = self.clean_examples(self.examples['A_data'])
B_data = self.clean_examples(self.examples['B_data'])
video_inputs_A = self.video_processor(self.video_reader(A_data[0]['content'][0]["video"])) / 255.0
video_inputs_B = self.video_processor(self.video_reader(B_data[0]['content'][0]["video"])) / 255.0
batch_A = self.processor(
text=self.processor.apply_chat_template([A_data], tokenize=False, add_generation_prompt=True),
images=None,
videos=[video_inputs_A],
padding=True,
return_tensors="pt",
videos_kwargs={"do_rescale": False},
)
batch_B = self.processor(
text=self.processor.apply_chat_template([B_data], tokenize=False, add_generation_prompt=True),
images=None,
videos=[video_inputs_B],
padding=True,
return_tensors="pt",
videos_kwargs={"do_rescale": False},
)
max_len = max(batch_A["input_ids"].shape[1], batch_B["input_ids"].shape[1])
batch_A["input_ids"], batch_A["attention_mask"] = self._pad_sequence(batch_A["input_ids"], batch_A["attention_mask"], max_len, "right")
batch_B["input_ids"], batch_B["attention_mask"] = self._pad_sequence(batch_B["input_ids"], batch_B["attention_mask"], max_len, "right")
batch = {
"A": batch_A,
"B": batch_B,
"return_loss": True,
}
if has_idx:
metainfo_idx = torch.tensor(self.examples["metainfo_idx"])
batch["metainfo_idx"] = metainfo_idx
return batch
def preprocess_dataset(self, examples: List) -> Dict[str, List[Any]]:
    """
    Encode every example and gather the results column-wise.

    Each example is assigned to ``self.examples`` before
    ``self._encode_data_example()`` is called, because that method reads its
    input from there. After the loop, ``self.examples`` holds the last example.

    Returns:
        A ``defaultdict(list)`` whose lists are filled in example order for
        ``"A"``, ``"B"``, ``"A_scores"``, ``"B_scores"`` and ``"chosen_label"``.
        ``"metainfo_idx"`` only receives entries for examples whose encoding
        produced one.
    """
    collected = defaultdict(list)
    for sample in examples:
        self.examples = sample
        encoded = self._encode_data_example()
        for side in ("A", "B"):
            collected[side].append(encoded[side])
        if "metainfo_idx" in encoded:
            collected["metainfo_idx"].append(encoded["metainfo_idx"])
        for label_key in ("A_scores", "B_scores", "chosen_label"):
            collected[label_key].append(sample[label_key])
    return collected
def print_data_example(self, example: Dict[str, List[int]]) -> None:
    """Print the chosen label, then the A-side and B-side scores of ``example``, each under its own header line."""
    for field_name in ("chosen_label", "A_scores", "B_scores"):
        print("{}:\n{}".format(field_name, example[field_name]))
def reward_setting_processor(preprocess_args_dict):
    """
    Build the video reader, video processor, tokenizer and processor for the reward setting.

    The known options below are popped from ``preprocess_args_dict``, so the
    caller's dict is modified in place; missing keys take the listed defaults.
    Whatever remains in the dict is passed to ``load_reward_tokenizer``.

    When ``split_special_tokens`` is truthy, three reward tokens are added to
    the tokenizer. Their ids and the new tokenizer length are then reported;
    otherwise both are None.

    Returns:
        ``(video_reader, video_processor, tokenizer, processor, model_args)``,
        where ``model_args`` holds ``special_token_ids``,
        ``token_embedding_length``, ``tokenizer_padding_side`` and
        ``pad_token_id``.
    """
    option_defaults = {
        # eval_dim is popped (so it is not forwarded) but not used further here.
        "eval_dim": ["VQ"],
        "train_pipeline": {},
        "video_reader_type": "DecordVideo",
        "video_processor_type": "RewardVideoProcessor",
        "sample_type": "uniform",
        "sample_nframe": None,
        "fps": 2.0,
        "video_max_pixels": 200704,
        "video_min_pixels": 100352,
        "split_special_tokens": False,
    }
    opts = {name: preprocess_args_dict.pop(name, default) for name, default in option_defaults.items()}

    video_reader = VideoReader(video_reader_type=opts["video_reader_type"])
    video_processor = VideoProcessor.create(
        video_processor_type=opts["video_processor_type"],
        fps=opts["fps"],
        video_min_pixels=opts["video_min_pixels"],
        video_max_pixels=opts["video_max_pixels"],
        sample_type=opts["sample_type"],
        sample_nframe=opts["sample_nframe"],
        train_pipeline=opts["train_pipeline"]
    )

    tokenizer_module = load_reward_tokenizer(preprocess_args_dict)
    tokenizer = tokenizer_module['tokenizer']
    processor = tokenizer_module['processor']

    if opts["split_special_tokens"]:
        reward_tokens = ["<|VQ_reward|>", "<|MQ_reward|>", "<|TA_reward|>"]
        tokenizer.add_special_tokens({"additional_special_tokens": reward_tokens})
        special_token_ids = tokenizer.convert_tokens_to_ids(reward_tokens)
        token_embedding_length = len(tokenizer)
    else:
        special_token_ids, token_embedding_length = None, None

    # pad_token_id is read after any special tokens have been added.
    model_args = {
        "special_token_ids": special_token_ids,
        "token_embedding_length": token_embedding_length,
        "tokenizer_padding_side": "right",
        "pad_token_id": tokenizer.pad_token_id,
    }
    return video_reader, video_processor, tokenizer, processor, model_args
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
    r"""
    Compute the source/target lengths that remain after truncating to ``cutoff_len``.

    Strategy:

    * If the target is shorter than half the budget, the whole target is kept
      and the source is truncated to what is left.
    * Otherwise, if the source is shorter than half the budget, the whole
      source is kept and the target is truncated to what is left.
    * Otherwise, the budget is split in proportion to the two lengths.

    For non-negative inputs both results are non-negative and sum to at most
    ``cutoff_len``. Two zero lengths yield ``(0, 0)`` for any ``cutoff_len``.

    Returns:
        ``(new_source_len, new_target_len)``.
    """
    if target_len * 2 < cutoff_len:
        max_target_len = cutoff_len
    elif source_len * 2 < cutoff_len:
        max_target_len = cutoff_len - source_len
    else:
        total_len = source_len + target_len
        # With non-negative lengths, total_len can only be 0 here when cutoff_len <= 0;
        # guard that degenerate case instead of raising ZeroDivisionError.
        max_target_len = int(cutoff_len * (target_len / total_len)) if total_len else 0
    new_target_len = min(max_target_len, target_len)
    max_source_len = max(cutoff_len - new_target_len, 0)
    new_source_len = min(max_source_len, source_len)
    return new_source_len, new_target_len
def get_vision_feature_select_strategy(config: "PretrainedConfig") -> str:
    r"""
    Return ``config.vision_feature_select_strategy``, or ``"default"`` when the config lacks it.

    The fallback is the string ``"default"``; the attribute is presumably a
    string as well (TODO confirm against the model configs in use).
    """
    return getattr(config, "vision_feature_select_strategy", "default")
def load_tokenizer(model_args: "ProcessorArguments") -> "TokenizerModule":
    r"""
    Load the pretrained tokenizer and, when available, a matching processor.

    The tokenizer is loaded from ``model_args.model_name_or_path`` with right
    padding and ``local_files_only=True``. A processor is then loaded from the
    same path. On success, the tokenizer and the image/video/audio settings
    from ``model_args`` are attached to the processor as attributes.

    NOTE(review): ``model_args`` is only read here; the in-place writes go to
    the loaded processor's attributes.

    Returns:
        ``{"tokenizer": tokenizer, "processor": processor}``. ``processor`` is
        None when loading it (or attaching the settings) raised, or when the
        loaded object's class name does not contain ``"Processor"``.
    """
    fix_mistral_regex = getattr(model_args, "fix_mistral_regex", False)
    trust_remote_code = getattr(model_args, "trust_remote_code", False)
    model_path = model_args.model_name_or_path
    # NOTE(review): the loaded config is not used in this function; the call is
    # retained so behaviour is unchanged — confirm whether it can be dropped.
    AutoConfig.from_pretrained(model_path, local_files_only=True)
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=trust_remote_code,
        fix_mistral_regex=fix_mistral_regex,
        use_fast=model_args.use_fast_tokenizer,
        split_special_tokens=model_args.split_special_tokens,
        padding_side="right",
        local_files_only=True,
    )
    # Settings copied verbatim from model_args onto the processor, in this order.
    forwarded_settings = (
        "image_max_pixels",
        "image_min_pixels",
        "image_do_pan_and_scan",
        "crop_to_patches",
        "video_max_pixels",
        "video_min_pixels",
        "video_fps",
        "video_maxlen",
        "audio_sampling_rate",
        "use_audio_in_video",
    )
    try:
        processor = AutoProcessor.from_pretrained(
            model_path,
            use_fast=model_args.use_fast_tokenizer,
            local_files_only=True)
        processor.tokenizer = tokenizer
        for setting in forwarded_settings:
            setattr(processor, setting, getattr(model_args, setting))
    except Exception as e:
        # Best-effort: a missing processor (or a failed attribute copy) degrades to None.
        logger.warning("Processor was not found: %s.", e)
        processor = None
    # Only objects whose class name contains "Processor" are kept (presumably
    # AutoProcessor can return e.g. a bare tokenizer for text-only checkpoints).
    if processor is not None and "Processor" not in processor.__class__.__name__:
        processor = None
    return {"tokenizer": tokenizer, "processor": processor}
def load_reward_tokenizer(model_args) -> "TokenizerModule":
    r"""
    Load the pretrained processor for the reward model and return it together with its tokenizer.

    Args:
        model_args: mapping with at least ``"model_name_or_path"``. An optional
            ``"local_files_only"`` key (default False) is forwarded to
            ``AutoProcessor.from_pretrained``.

    Returns:
        ``{"tokenizer": processor.tokenizer, "processor": processor}``.

    Raises:
        KeyError: if ``"model_name_or_path"`` is missing.
        ValueError: if no processor could be loaded, or the loaded object's
            class name does not contain ``"Processor"``. The tokenizer is taken
            from the processor, so there is nothing to return in that case.
    """
    model_path = model_args['model_name_or_path']
    # model_args is a mapping (indexed above), so read the flag as a key;
    # getattr() on a dict never sees its keys and always fell back to False.
    local_files_only = model_args.get('local_files_only', False)
    load_error = None
    try:
        processor = AutoProcessor.from_pretrained(model_path, padding_side="right", local_files_only=local_files_only)
    except Exception as e:
        logger.warning("Processor was not found: %s.", e)
        processor = None
        load_error = e
    if processor is not None and "Processor" not in processor.__class__.__name__:
        processor = None
    if processor is None:
        # Previously this surfaced as an opaque AttributeError on None.tokenizer.
        raise ValueError(
            f"Could not load a processor for the reward model from {model_path!r}; "
            "the reward pipeline takes its tokenizer from the processor."
        ) from load_error
    return {"tokenizer": processor.tokenizer, "processor": processor}