import random
import json
from copy import deepcopy
import numpy as np
import torch
import torch.distributed as dist
from torch.nn.attention.flex_attention import or_masks, and_masks
from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
from mindspeed_mm.data.data_utils.data_transform import MaxLongEdgeMinShortEdgeResize
from mindspeed_mm.data.datasets.bagel_iterable_dataset import T2IIterableDataset, SftJSONLIterableDataset
# Maps a dataset-group name (as used in the data config) to the iterable
# dataset class that is instantiated for that group.
DATASET_REGISTRY = dict(
    t2i_pretrain=T2IIterableDataset,
    vlm_sft=SftJSONLIterableDataset,
)
def create_sparse_mask(document_lens, split_lens, attn_modes, device):
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
def full_and_noise_mask(b, h, q_idx, kv_idx):
return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0)
def remove_noise_mask(b, h, q_idx, kv_idx):
return (~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx])))
def sample_mask(b, h, q_idx, kv_idx):
return document_id[q_idx] == document_id[kv_idx]
full_and_noise_tmp = []
noise_tmp = []
for i, (length, model) in enumerate(zip(split_lens, attn_modes)):
value = i if model in ['full', 'noise'] else -1
full_and_noise_tmp.extend([value] * length)
value_noise = i if model == 'noise' else -1
noise_tmp.extend([value_noise] * length)
full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device)
noise_seq_id = torch.Tensor(noise_tmp).to(device)
document_id = torch.cat([torch.full((doc_len,), i) for i, doc_len in enumerate(document_lens, start=1)]).to(device)
return and_masks(or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask)
def add_special_tokens(tokenizer):
    """Ensure the chat/vision marker tokens exist in ``tokenizer`` and return their ids.

    Tokens already listed in ``tokenizer.special_tokens_map`` are not re-added;
    the remaining ones are passed to ``tokenizer.add_tokens`` in a fixed order.

    Args:
        tokenizer: Object exposing ``special_tokens_map``, ``add_tokens`` and
            ``convert_tokens_to_ids``. It is modified in place.

    Returns:
        Tuple ``(tokenizer, new_token_ids)`` where ``new_token_ids`` maps
        ``bos_token_id``, ``eos_token_id``, ``start_of_image`` and
        ``end_of_image`` to the ids of the corresponding marker tokens.
    """
    known_special_tokens = []
    for value in tokenizer.special_tokens_map.values():
        if isinstance(value, str):
            known_special_tokens.append(value)
        elif isinstance(value, list):
            known_special_tokens += value

    # Order matters: it determines the ids assigned to newly added tokens.
    required_tokens = ('<|im_start|>', '<|im_end|>', '<|vision_start|>', '<|vision_end|>')
    new_tokens = [token for token in required_tokens if token not in known_special_tokens]
    tokenizer.add_tokens(new_tokens)

    new_token_ids = dict(
        bos_token_id=tokenizer.convert_tokens_to_ids('<|im_start|>'),
        eos_token_id=tokenizer.convert_tokens_to_ids('<|im_end|>'),
        start_of_image=tokenizer.convert_tokens_to_ids('<|vision_start|>'),
        end_of_image=tokenizer.convert_tokens_to_ids('<|vision_end|>'),
    )
    return tokenizer, new_token_ids
def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
    """Return row-major position ids for the patch grid of an image.

    The patch at grid cell ``(row, col)`` gets id
    ``row * max_num_patches_per_side + col``; the grid has
    ``img_h // patch_size`` rows and ``img_w // patch_size`` columns.
    """
    row_ids = torch.arange(0, img_h // patch_size)
    col_ids = torch.arange(0, img_w // patch_size)
    grid = row_ids.unsqueeze(1) * max_num_patches_per_side + col_ids
    return grid.flatten()
def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side):
    """Return row-major position ids obtained by bucketing fractional patch coordinates.

    Each axis is divided into ``max_num_patches_per_side`` equal buckets over
    ``[0, 1)``. A patch's fractional coordinate (``index / num_patches``) selects
    its bucket, and the id is ``row_bucket * max_num_patches_per_side + col_bucket``.
    """
    bucket_width = 1 / max_num_patches_per_side
    boundaries = torch.arange(bucket_width, 1.0, bucket_width)

    def _bucket_ids(num_patches):
        # Fractional start coordinate of every patch along one axis.
        fractional_coords = torch.arange(0, 1 - 1e-6, 1 / num_patches)
        return torch.bucketize(fractional_coords, boundaries, right=True)

    row_buckets = _bucket_ids(img_h // patch_size)
    col_buckets = _bucket_ids(img_w // patch_size)
    grid = row_buckets.unsqueeze(1) * max_num_patches_per_side + col_buckets
    return grid.flatten()
def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"):
    """Build the additive attention mask for one sample made of several splits.

    Every split attends to all positions of the preceding splits. Inside a
    split, ``"causal"`` uses a lower-triangular pattern and any other mode uses
    full attention. Afterwards, the columns of every ``"noise"`` split are
    hidden from all rows except the rows of that same split.

    Args:
        split_lens: Length of each split within the sample, in tokens.
        attn_modes: Attention mode of each split, aligned with ``split_lens``.
        device: Device on which the mask is created.

    Returns:
        A float tensor of shape ``(sum(split_lens), sum(split_lens))`` holding
        ``0`` where attention is allowed and ``-inf`` where it is not.
    """
    sample_len = sum(split_lens)
    attention_mask = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device)

    start = 0
    for split_len, attn_mode in zip(split_lens, attn_modes):
        end = start + split_len
        # Every split sees everything that precedes it.
        attention_mask[start:end, :start] = True
        if attn_mode == "causal":
            attention_mask[start:end, start:end] = torch.ones((split_len, split_len), device=device).tril()
        else:
            attention_mask[start:end, start:end] = True
        start = end

    start = 0
    for split_len, attn_mode in zip(split_lens, attn_modes):
        end = start + split_len
        if attn_mode == "noise":
            # Hide the noise split from every row, then re-enable it for itself.
            attention_mask[:, start:end] = False
            attention_mask[start:end, start:end] = True
        start = end

    return torch.zeros_like(attention_mask, dtype=torch.float).masked_fill_(
        ~attention_mask, float("-inf")
    )
class BagelMultiDataset(torch.utils.data.IterableDataset):
    """Iterable dataset that packs samples from weighted dataset groups into token-budgeted sequences.

    Every yielded item is the dict produced by :meth:`to_tensor` plus a
    ``batch_data_indexes`` list holding the ``data_indexes`` of each packed sample.
    """

    def __init__(
        self,
        basic_param: dict,
        preprocess_param: dict,
        **kwargs,
    ):
        """Merge the configuration, load the tokenizer and build the dataset groups.

        Args:
            basic_param: Top-level settings plus one entry per dataset group.
                ``available_data`` is popped from it in place.
            preprocess_param: Accepted for interface compatibility; not used here.
            **kwargs: ``packed_parameters`` (dict) is merged with ``basic_param``
                into ``self.config``; other keyword arguments are ignored.

        Raises:
            TypeError: If ``packed_parameters`` is missing or is not a dict.
        """
        super().__init__()
        packed_parameters = kwargs.pop("packed_parameters", None)
        self.config = self.combine_dicts(basic_param, packed_parameters)
        # NOTE: despite the attribute name, this stores ``dist.get_rank()``.
        self.local_rank = dist.get_rank()
        self.world_size = dist.get_world_size()

        self.tokenizer = Qwen2Tokenizer.from_pretrained(self.config.pop('model_path', None))
        self.tokenizer, special_tokens = add_special_tokens(self.tokenizer)
        # Exposes bos_token_id / eos_token_id / start_of_image / end_of_image as attributes.
        for k, v in special_tokens.items():
            setattr(self, k, v)

        self.available_data = basic_param.pop("available_data", None)
        grouped_datasets, is_mandatory, grouped_weights = self.build_datasets(
            basic_param, self.config.pop('data_status', None)
        )
        self.dataset_iters = [iter(dataset) for dataset in grouped_datasets]
        self.is_mandatory = is_mandatory
        self.grouped_weights = grouped_weights

        if self.config.pop('interpolate_pos', False):
            self.get_flattened_position_ids = get_flattened_position_ids_interpolate
        else:
            self.get_flattened_position_ids = get_flattened_position_ids_extrapolate

    def build_datasets(self, datasets_metainfo, data_status):
        """Instantiate one dataset per enabled group.

        Args:
            datasets_metainfo: Mapping from group name to that group's arguments.
                Entries whose name is not in ``self.available_data`` are skipped.
                The mapping and its nested dicts are left unmodified.
            data_status: Optional mapping from group name to a resume state,
                forwarded to the dataset unless the group sets
                ``resume_data_status`` to a falsy value.

        Returns:
            Tuple ``(datasets, is_mandatory, grouped_weights)`` of parallel lists.
        """
        datasets = []
        is_mandatory = []
        grouped_weights = []

        for grouped_dataset_name, group_args in datasets_metainfo.items():
            if grouped_dataset_name not in self.available_data:
                continue

            # Work on a shallow copy so the caller's configuration is not
            # stripped of its keys while the constructor kwargs are assembled.
            dataset_args = dict(group_args)

            is_mandatory.append(dataset_args.pop('is_mandatory', False))
            grouped_weights.append(dataset_args.pop('weight', 0.0))

            if 'image_transform_args' in dataset_args:
                dataset_args['transform'] = MaxLongEdgeMinShortEdgeResize(
                    **dataset_args.pop('image_transform_args')
                )

            dataset_names = dataset_args.pop('dataset_names')
            dataset_args['data_dir_list'] = []
            for item in dataset_names:
                if self.local_rank == 0:
                    print(f'Preparing Dataset {grouped_dataset_name}/{item}')
                # The per-dataset entry is consumed here so that it is not
                # forwarded as a keyword argument to the dataset class.
                meta_info = dataset_args.pop(item)
                dataset_args['data_dir_list'].append(meta_info['data_dir'])
                if 'jsonl_path' in meta_info:
                    # Build a new list instead of appending to a configured one.
                    dataset_args['jsonl_path_list'] = (
                        dataset_args.get('jsonl_path_list', []) + [meta_info['jsonl_path']]
                    )

            resume_data_status = dataset_args.pop('resume_data_status', True)
            if data_status is not None and grouped_dataset_name in data_status and resume_data_status:
                data_status_per_group = data_status[grouped_dataset_name]
            else:
                data_status_per_group = None

            dataset = DATASET_REGISTRY[grouped_dataset_name](
                dataset_name=grouped_dataset_name,
                tokenizer=self.tokenizer,
                local_rank=self.local_rank,
                world_size=self.world_size,
                num_workers=self.config["num_workers"],
                data_status=data_status_per_group,
                **dataset_args
            )
            datasets.append(dataset)

        return datasets, is_mandatory, grouped_weights

    def combine_dicts(self, *args):
        """Merge several dicts into a new dict using :meth:`recursive_merge`.

        Raises:
            TypeError: If any argument is not a dict.
        """
        merged = {}
        for source_dict in args:
            if not isinstance(source_dict, dict):
                raise TypeError(f"Merged parameters must be of type dict, current type: {type(source_dict)}")
            self.recursive_merge(target=merged, source=source_dict)
        return merged

    def recursive_merge(self, target, source):
        """Merge ``source`` into ``target`` in place.

        New keys are copied by reference. When a key exists on both sides,
        two dicts are merged recursively; otherwise the values are collected
        into a list (an existing list is appended to).

        NOTE: because values are stored by reference, a later merge can modify
        a nested dict or list that still belongs to an earlier source.
        """
        for key, value in source.items():
            if key not in target:
                target[key] = value
            elif isinstance(target[key], dict) and isinstance(value, dict):
                self.recursive_merge(target=target[key], source=value)
            else:
                if not isinstance(target[key], list):
                    target[key] = [target[key]]
                target[key].append(value)

    def set_sequence_status(self):
        """Return a fresh, empty packing state (cursor ``curr`` at 0, all lists empty)."""
        sequence_status = dict(
            curr=0,
            sample_lens=list(),
            packed_position_ids=list(),
            nested_attention_masks=list(),
            split_lens=list(),
            attn_modes=list(),
            packed_text_ids=list(),
            packed_text_indexes=list(),
            packed_label_ids=list(),
            ce_loss_indexes=list(),
            ce_loss_weights=list(),
            vae_image_tensors=list(),
            packed_latent_position_ids=list(),
            vae_latent_shapes=list(),
            packed_vae_token_indexes=list(),
            packed_timesteps=list(),
            mse_loss_indexes=list(),
            packed_vit_tokens=list(),
            vit_token_seqlens=list(),
            packed_vit_position_ids=list(),
            packed_vit_token_indexes=list(),
        )
        return sequence_status

    def to_tensor(self, sequence_status):
        """Convert a packing state into the dict that is yielded to the consumer.

        ``sequence_status`` is not modified, so calling this twice on the same
        state returns equal results.

        With ``use_flex`` enabled, a trailing padding split/sample of length
        ``max_num_tokens - sequence_length`` with mode ``'causal'`` is appended
        to ``split_lens``, ``attn_modes`` and ``sample_lens``. Otherwise the
        per-sample attention masks are passed through.

        Optional groups of keys (VAE images, ViT tokens, timesteps, text
        labels) are only present when the state holds data for them.
        """
        # Copy so the padding entry appended below does not leak into the state.
        sample_lens = list(sequence_status['sample_lens'])
        data = dict(
            sequence_length=sum(sample_lens),
            sample_lens=sample_lens,
            packed_text_ids=torch.tensor(sequence_status['packed_text_ids']),
            packed_text_indexes=torch.tensor(sequence_status['packed_text_indexes']),
            packed_position_ids=torch.tensor(sequence_status['packed_position_ids']),
        )

        if not self.config["use_flex"]:
            data['nested_attention_masks'] = sequence_status['nested_attention_masks']
        else:
            pad_len = self.config["max_num_tokens"] - data['sequence_length']
            data['split_lens'] = sequence_status['split_lens'] + [pad_len]
            data['attn_modes'] = sequence_status['attn_modes'] + ['causal']
            sample_lens.append(pad_len)

        image_tensors = sequence_status['vae_image_tensors']
        if len(image_tensors) > 0:
            # Zero-pad every image up to the per-dimension maximum of the batch.
            image_sizes = [image.shape for image in image_tensors]
            max_image_size = [max(dims) for dims in zip(*image_sizes)]
            padded_images = torch.zeros(size=(len(image_tensors), *max_image_size))
            for i, image_tensor in enumerate(image_tensors):
                padded_images[i, :, :image_tensor.shape[1], :image_tensor.shape[2]] = image_tensor
            data['padded_images'] = padded_images
            data['patchified_vae_latent_shapes'] = sequence_status['vae_latent_shapes']
            data['packed_latent_position_ids'] = torch.cat(sequence_status['packed_latent_position_ids'], dim=0)
            data['packed_vae_token_indexes'] = torch.tensor(sequence_status['packed_vae_token_indexes'])

        if len(sequence_status['packed_vit_tokens']) > 0:
            data['packed_vit_tokens'] = torch.cat(sequence_status['packed_vit_tokens'], dim=0)
            data['packed_vit_position_ids'] = torch.cat(sequence_status['packed_vit_position_ids'], dim=0)
            data['packed_vit_token_indexes'] = torch.tensor(sequence_status['packed_vit_token_indexes'])
            data['vit_token_seqlens'] = torch.tensor(sequence_status['vit_token_seqlens'])

        if len(sequence_status['packed_timesteps']) > 0:
            data['packed_timesteps'] = torch.tensor(sequence_status['packed_timesteps'])
            data['mse_loss_indexes'] = torch.tensor(sequence_status['mse_loss_indexes'])

        if len(sequence_status['packed_label_ids']) > 0:
            data['packed_label_ids'] = torch.tensor(sequence_status['packed_label_ids'])
            data['ce_loss_indexes'] = torch.tensor(sequence_status['ce_loss_indexes'])
            data['ce_loss_weights'] = torch.tensor(sequence_status['ce_loss_weights'])

        return data

    def __iter__(self):
        """Yield packed batches indefinitely.

        A new pack starts with one sample from every mandatory group. Further
        samples come from the overflow buffer (while the pack is shorter than
        ``prefer_buffer_before``) or from a group drawn according to the group
        weights. A pack is yielded once it reaches ``expected_num_tokens`` or
        when a sample no longer fits and cannot be buffered.

        Raises:
            ValueError: If the group weights do not sum to a positive value.
        """
        total_weights = sum(self.grouped_weights)
        if total_weights <= 0:
            raise ValueError(
                f"grouped dataset weights must sum to a positive value, got {total_weights}"
            )
        group_cumprobs = [sum(self.grouped_weights[:i + 1]) / total_weights
                          for i in range(len(self.grouped_weights))]

        sequence_status = self.set_sequence_status()
        batch_data_indexes = []
        buffer = []

        while True:
            # An empty pack is first seeded with one sample per mandatory group.
            if sequence_status['curr'] == 0:
                for group_index, group_iter in enumerate(self.dataset_iters):
                    if not self.is_mandatory[group_index]:
                        continue
                    while True:
                        sample = next(group_iter)
                        # Two extra tokens are budgeted per sequence-plan item.
                        num_tokens = sample['num_tokens'] + 2 * len(sample['sequence_plan'])
                        # NOTE: strict '<' here, while the general path below
                        # only skips when the length is strictly greater.
                        if num_tokens < self.config["max_num_tokens_per_sample"]:
                            sequence_status = self.pack_sequence(sample, sequence_status)
                            batch_data_indexes.append(sample['data_indexes'])
                            break
                        print(f"skip a sample with length {num_tokens}")

            if sequence_status['curr'] < self.config["prefer_buffer_before"] and len(buffer) > 0:
                sample = buffer.pop(0)
                sample_from_buffer = True
            else:
                # Draw a group according to the cumulative weight distribution.
                n = random.random()
                group_index = 0
                for i, cumprob in enumerate(group_cumprobs):
                    if n < cumprob:
                        group_index = i
                        break
                sample = next(self.dataset_iters[group_index])
                sample_from_buffer = False

            num_tokens = sample['num_tokens'] + 2 * len(sample['sequence_plan'])
            if num_tokens > self.config["max_num_tokens_per_sample"]:
                print(f"skip a sample with length {num_tokens}")
                continue

            if sequence_status['curr'] + num_tokens > self.config["max_num_tokens"]:
                if len(buffer) < self.config["max_buffer_size"] and not sample_from_buffer:
                    # Keep the sample for a later pack.
                    buffer.append(sample)
                else:
                    # NOTE: the current sample is neither packed nor buffered
                    # in this branch, i.e. it is discarded.
                    print(f"Yielding data with length {sum(sequence_status['sample_lens'])}")
                    data = self.to_tensor(sequence_status)
                    data['batch_data_indexes'] = batch_data_indexes
                    yield data
                    sequence_status = self.set_sequence_status()
                    batch_data_indexes = []
                continue

            sequence_status = self.pack_sequence(sample, sequence_status)
            batch_data_indexes.append(sample['data_indexes'])

            if sequence_status['curr'] >= self.config["expected_num_tokens"]:
                data = self.to_tensor(sequence_status)
                data['batch_data_indexes'] = batch_data_indexes
                yield data
                sequence_status = self.set_sequence_status()
                batch_data_indexes = []

    def patchify(self, image, patch_size):
        """Split a ``(c, h, w)`` image into flattened ``patch_size`` x ``patch_size`` patches.

        Returns a tensor of shape ``((h // p) * (w // p), p * p * c)`` with
        patches in row-major order and channels as the fastest-varying axis.
        """
        p = patch_size
        c, h, w = image.shape
        image = image.reshape(c, h // p, p, w // p, p)
        image = torch.einsum("chpwq->hwpqc", image)
        image = image.reshape(-1, p ** 2 * c)
        return image

    def pack_sequence(self, sample, sequence_status):
        """Append one sample to the packing state and return the state.

        The sample's ``sequence_plan`` is walked item by item; each item consumes
        the next entry of ``text_ids_list`` (``'text'``) or ``image_tensor_list``
        (``'vit_image'`` / ``'vae_image'``). The sample's lists are read but not
        modified.

        Args:
            sample: Dict with ``image_tensor_list``, ``text_ids_list`` and
                ``sequence_plan``.
            sequence_status: Packing state created by :meth:`set_sequence_status`;
                updated in place.

        Returns:
            The updated ``sequence_status``.

        Raises:
            IndexError: If the plan references more texts or images than the
                sample provides.
        """
        image_tensor_list = sample['image_tensor_list']
        text_ids_list = sample['text_ids_list']
        sequence_plan = sample['sequence_plan']
        # Read cursors replace the destructive ``pop(0)`` on the sample's lists.
        image_cursor = 0
        text_cursor = 0

        split_lens, attn_modes = list(), list()
        curr = sequence_status['curr']
        curr_rope_id = 0
        sample_lens = 0
        # Defined up front so a plan whose first item has split_start=False
        # does not hit an unbound local.
        curr_split_len = 0

        for item in sequence_plan:
            split_start = item.get('split_start', True)
            if split_start:
                curr_split_len = 0

            if item['type'] == 'text':
                text_ids = text_ids_list[text_cursor]
                text_cursor += 1
                # Conditional dropout: the text is consumed but not packed.
                if item['enable_cfg'] == 1 and random.random() < self.config["text_cond_dropout_prob"]:
                    continue

                shifted_text_ids = [self.bos_token_id] + text_ids
                num_shifted = len(shifted_text_ids)
                sequence_status['packed_text_ids'].extend(shifted_text_ids)
                sequence_status['packed_text_indexes'].extend(range(curr, curr + num_shifted))
                if item['loss'] == 1:
                    sequence_status['ce_loss_indexes'].extend(range(curr, curr + num_shifted))
                    weight = 1 / (num_shifted ** 0.5)
                    sequence_status['ce_loss_weights'].extend([weight] * num_shifted)
                    sequence_status['packed_label_ids'].extend(text_ids + [self.eos_token_id])
                curr += num_shifted
                curr_split_len += num_shifted

                # Closing eos token of the text segment.
                sequence_status['packed_text_ids'].append(self.eos_token_id)
                sequence_status['packed_text_indexes'].append(curr)
                if item['special_token_loss'] == 1:
                    sequence_status['ce_loss_indexes'].append(curr)
                    sequence_status['ce_loss_weights'].append(1.0)
                    sequence_status['packed_label_ids'].append(item['special_token_label'])
                curr += 1
                curr_split_len += 1

                attn_modes.append("causal")
                sequence_status['packed_position_ids'].extend(range(curr_rope_id, curr_rope_id + curr_split_len))
                curr_rope_id += curr_split_len

            elif item['type'] == 'vit_image':
                image_tensor = image_tensor_list[image_cursor]
                image_cursor += 1
                if item['enable_cfg'] == 1 and random.random() < self.config["vit_cond_dropout_prob"]:
                    # Dropped image still advances the rope position by one.
                    curr_rope_id += 1
                    continue

                # Opening image marker.
                sequence_status['packed_text_ids'].append(self.start_of_image)
                sequence_status['packed_text_indexes'].append(curr)
                curr += 1
                curr_split_len += 1

                vit_tokens = self.patchify(image_tensor, self.config["vit_patch_size"])
                num_img_tokens = vit_tokens.shape[0]
                sequence_status['packed_vit_token_indexes'].extend(range(curr, curr + num_img_tokens))
                curr += num_img_tokens
                curr_split_len += num_img_tokens

                sequence_status['packed_vit_tokens'].append(vit_tokens)
                sequence_status['vit_token_seqlens'].append(num_img_tokens)
                sequence_status['packed_vit_position_ids'].append(
                    self.get_flattened_position_ids(
                        image_tensor.size(1), image_tensor.size(2),
                        self.config["vit_patch_size"],
                        max_num_patches_per_side=self.config["max_num_patch_per_side"]
                    )
                )

                # Closing image marker.
                sequence_status['packed_text_ids'].append(self.end_of_image)
                sequence_status['packed_text_indexes'].append(curr)
                if item['special_token_loss'] == 1:
                    sequence_status['ce_loss_indexes'].append(curr)
                    sequence_status['ce_loss_weights'].append(1.0)
                    sequence_status['packed_label_ids'].append(item['special_token_label'])
                curr += 1
                curr_split_len += 1

                attn_modes.append("full")
                # All tokens of the image share a single rope position.
                sequence_status['packed_position_ids'].extend([curr_rope_id] * curr_split_len)
                curr_rope_id += 1

            elif item['type'] == 'vae_image':
                image_tensor = image_tensor_list[image_cursor]
                image_cursor += 1
                if item['enable_cfg'] == 1 and random.random() < self.config["vae_cond_dropout_prob"]:
                    curr_rope_id += 1
                    continue

                # Opening image marker.
                sequence_status['packed_text_ids'].append(self.start_of_image)
                sequence_status['packed_text_indexes'].append(curr)
                curr += 1
                curr_split_len += 1

                sequence_status['vae_image_tensors'].append(image_tensor)
                sequence_status['packed_latent_position_ids'].append(
                    self.get_flattened_position_ids(
                        image_tensor.size(1), image_tensor.size(2),
                        self.config["vae_image_downsample"],
                        max_num_patches_per_side=self.config['max_latent_size']
                    )
                )
                H, W = image_tensor.shape[1:]
                h = H // self.config["vae_image_downsample"]
                w = W // self.config["vae_image_downsample"]
                sequence_status['vae_latent_shapes'].append((h, w))

                num_img_tokens = w * h
                sequence_status['packed_vae_token_indexes'].extend(range(curr, curr + num_img_tokens))
                if item['loss'] == 1:
                    sequence_status['mse_loss_indexes'].extend(range(curr, curr + num_img_tokens))
                    if split_start:
                        # NOTE(review): this reseeds NumPy's global RNG on every
                        # such image and pins the timestep to a constant. It looks
                        # like a determinism pin - confirm it is intended for training.
                        np.random.seed(42)
                        timestep = 0.5
                    # When the split continues, the timestep chosen at its
                    # start is reused.
                else:
                    timestep = float('-inf')
                sequence_status['packed_timesteps'].extend([timestep] * num_img_tokens)
                curr += num_img_tokens
                curr_split_len += num_img_tokens

                # Closing image marker.
                sequence_status['packed_text_ids'].append(self.end_of_image)
                sequence_status['packed_text_indexes'].append(curr)
                if item['special_token_loss'] == 1:
                    sequence_status['ce_loss_indexes'].append(curr)
                    sequence_status['ce_loss_weights'].append(1.0)
                    sequence_status['packed_label_ids'].append(item['special_token_label'])
                curr += 1
                curr_split_len += 1

                if split_start:
                    if item['loss'] == 1 and 'frame_delta' not in item:
                        attn_modes.append("noise")
                    else:
                        attn_modes.append("full")
                sequence_status['packed_position_ids'].extend([curr_rope_id] * (num_img_tokens + 2))
                if 'frame_delta' in item:
                    curr_rope_id += item['frame_delta']
                elif item['loss'] == 0:
                    curr_rope_id += 1

            if item.get('split_end', True):
                split_lens.append(curr_split_len)
                sample_lens += curr_split_len

        sequence_status['curr'] = curr
        sequence_status['sample_lens'].append(sample_lens)
        if not self.config["use_flex"]:
            sequence_status['nested_attention_masks'].append(
                prepare_attention_mask_per_sample(split_lens, attn_modes)
            )
        else:
            sequence_status['split_lens'].extend(split_lens)
            sequence_status['attn_modes'].extend(attn_modes)
        return sequence_status