import json
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable
import torch
import torch.nn.functional as F
import torch.distributed as dist
from megatron.core import mpu
from megatron.core.parallel_state import get_data_parallel_rank, get_data_parallel_group
from megatron.training import get_args
from mindspeed_mm.utils.data_balance.balance_sorting_algo import SORTING_ALGO_FUNC
from mindspeed_mm.utils.data_balance.batch_generator import PrefetchMicroBatchIterator
from mindspeed_mm.utils.utils import EncoderBalanceComm
# Names of the text-side batch entries. The list fixes the order in which
# they are moved into the split batch; the set holds the same names for
# membership tests.
TXT_ELEM_LIST = ['input_ids', 'attention_mask', 'labels']
TXT_ELEM_SET = set(TXT_ELEM_LIST)
class BaseDataBalance(ABC):
    """Shared machinery for re-distributing samples across data-parallel ranks.

    Subclasses supply how a batch is split into per-sample pieces, how the
    per-sample lengths are gathered, and how the sorting algorithm's output is
    mapped to per-rank index lists. ``_data_balance`` ties these together and
    exchanges the data with an all-to-all.

    Attributes are set in ``__init__``:
        sorting_algo: callable looked up in ``SORTING_ALGO_FUNC``.
        state_buffer: dict keyed by ``data_type`` holding bookkeeping tensors
            and lists produced by ``_data_balance``.
        need_rev: when true, extra metadata needed to undo the exchange is
            stored in ``state_buffer``.
        merge_down_ratio: divisor applied to lengths stored for the reverse
            exchange.
    """

    def __init__(self, sorting_algo_name, need_rev=False, merge_down_ratio=1.):
        self.sorting_algo = self._get_sorting_algo(sorting_algo_name)
        self.state_buffer = {}
        self.need_rev = need_rev
        self.merge_down_ratio = merge_down_ratio

    @staticmethod
    @abstractmethod
    def _rank_table_mapping(rank_table, dp_rank):
        """Return ``(data_list, rank_table)`` for the current rank.

        ``data_list`` is consumed by ``_data_reorganization`` (one index
        tensor per peer rank); ``rank_table`` must be indexable by rank.
        """
        raise NotImplementedError("method '_rank_table_mapping' must be implemented")

    @staticmethod
    @abstractmethod
    def _split_batch_data(datas: dict):
        """Split a batch dict into per-sample data and their lengths."""
        raise NotImplementedError("method '_split_batch_data' must be implemented")

    @staticmethod
    def _all_to_all_communication(data, balanced_data_lengths, data_dim, dp_process_group, require_grad=False):
        """Exchange ``data`` with every rank of ``dp_process_group``.

        Args:
            data: sequence of tensors to send, one per peer rank.
            balanced_data_lengths: iterable yielding, per peer rank, the
                leading dimensions of the tensor to receive.
            data_dim: trailing dimensions appended to each receive buffer.
            dp_process_group: process group used for the exchange.
            require_grad: when true the autograd-aware
                ``torch.distributed.nn.functional.all_to_all`` is used.

        Returns:
            The list of receive buffers, filled by the collective.
        """
        # A trailing dimension of size 1 is squeezed away from each buffer.
        balanced_data_cache = [
            torch.empty(
                (*new_length, *data_dim), dtype=data[0].dtype, device=data[0].device
            ).squeeze(-1) for new_length in balanced_data_lengths
        ]
        if require_grad:
            from torch.distributed.nn.functional import all_to_all
        else:
            from torch.distributed import all_to_all
        all_to_all(balanced_data_cache, data, group=dp_process_group)
        return balanced_data_cache

    @staticmethod
    def _data_reorganization(data, data_list):
        """Group ``data`` by the index tensors in ``data_list``.

        Args:
            data: either a tensor (indexed along dim 0) or a sequence of
                tensors (selected entries are concatenated).
            data_list: iterable of index tensors, one per output group.

        Returns:
            A list with one tensor per entry of ``data_list``. For sequence
            input an empty index tensor yields an empty 1-D tensor with the
            dtype and device of ``data[0]``.
        """
        if isinstance(data, torch.Tensor):
            return [data[new_group_idxs] for new_group_idxs in data_list]
        return [
            torch.cat([data[idx] for idx in new_group_idxs])
            if new_group_idxs.numel() != 0
            else torch.tensor([], dtype=data[0].dtype, device=data[0].device)
            for new_group_idxs in data_list
        ]

    @staticmethod
    def _get_sorting_algo(sorting_algo_name):
        """Look up the sorting function registered under ``sorting_algo_name``."""
        return SORTING_ALGO_FUNC[sorting_algo_name]

    @abstractmethod
    def _all_gather_data_lengths(self, data_lengths, num_replicas, dp_process_group):
        """Return ``(gathered_lengths, samples_lengths)`` collected from all ranks."""
        raise NotImplementedError("method '_all_gather_data_lengths' must be implemented")

    def _data_balance(
        self,
        data_lengths: torch.Tensor,
        datas: dict[str, torch.Tensor],
        dp_process_group=None,
        data_type='Unknown data',
        **kwargs
    ):
        """Balance ``datas`` across the data-parallel group.

        Args:
            data_lengths: per-sample lengths; column ``i`` belongs to the
                ``i``-th entry of ``datas`` (in dict order).
            datas: mapping from data name to per-sample data.
            dp_process_group: process group to balance over. Defaults to the
                Megatron data-parallel group when ``None``.
            data_type: key under which bookkeeping is stored in
                ``self.state_buffer``.
            **kwargs: forwarded to the sorting algorithm.

        Returns:
            Dict mapping each data name to the list of tensors received from
            the all-to-all (one per peer rank).
        """
        if dp_process_group is None:
            dp_process_group = get_data_parallel_group()
        # Create the bookkeeping entry on demand so callers need not pre-seed it.
        state = self.state_buffer.setdefault(data_type, {})

        dp_rank = get_data_parallel_rank()
        num_replicas = dp_process_group.size()
        gathered_lengths, samples_lengths = self._all_gather_data_lengths(
            data_lengths, num_replicas, dp_process_group)
        rank_table = self.sorting_algo(samples_lengths, num_replicas, **kwargs)
        data_list, rank_table = self._rank_table_mapping(rank_table, dp_rank)

        local_table = rank_table[dp_rank]
        # Column 0 holds a rank id per row; presumably the rank each received
        # sample originates from -- TODO confirm against the sorting algorithm.
        peer_ranks = local_table[:, 0]
        state['balance_data_mapping_index'] = torch.cat(
            [torch.where(peer_ranks == r)[0] for r in range(num_replicas)]
        )

        balanced_datas = {}
        # Reused for every data entry; consumed immediately by the all-to-all.
        balanced_data_lengths = torch.empty(
            num_replicas, 2, dtype=local_table.dtype, device=local_table.device
        )
        sample_num_per_rank = torch.bincount(peer_ranks, minlength=num_replicas)

        for i, (data_name, data) in enumerate(datas.items()):
            reorganized_data = self._data_reorganization(data, data_list)
            if data_name != 'pixel_values':
                balanced_data_dim = (*data[0].shape[1:],)
                balanced_data_lengths[:, 0] = sample_num_per_rank
                if isinstance(gathered_lengths, torch.Tensor):
                    balanced_data_lengths[:, 1] = gathered_lengths[:, 0, i]
                else:
                    balanced_data_lengths[:, 1] = torch.stack([gl[0, i] for gl in gathered_lengths])
                if self.need_rev:
                    origin_data = torch.cat(reorganized_data)
                    # Cast to long: a float merge_down_ratio would otherwise
                    # produce float sizes, which Tensor.split rejects.
                    state[f"{data_name}_origin_split"] = (
                        origin_data[:, 0] * origin_data[:, 1] * origin_data[:, 2] // self.merge_down_ratio
                    ).to(torch.long).tolist()
            else:
                balanced_data_dim = ()
                # Sum the per-sample lengths (column 2 + i) per peer rank.
                balanced_data_lengths[:, 0] = 0
                balanced_data_lengths[:, 0].index_add_(0, peer_ranks, local_table[:, 2 + i])
                balanced_data_lengths[:, 1] = data[0].shape[-1]
            if self.need_rev:
                # Sizes are stored as Python ints so they can be used directly
                # as split sections and buffer shapes in the reverse exchange.
                state[f"{data_name}_split"] = (
                    balanced_data_lengths[:, 0] // self.merge_down_ratio
                ).to(torch.long).tolist()
                state[f"{data_name}_origin"] = [
                    (int(d.shape[0] // self.merge_down_ratio),)
                    for d in reorganized_data
                ]
                state[f"{data_name}_data_list"] = torch.cat(data_list)
            balanced_datas[data_name] = self._all_to_all_communication(
                reorganized_data, balanced_data_lengths, balanced_data_dim, dp_process_group)
        return balanced_datas
class GBSImageDataBalance(BaseDataBalance):
    """Balance a whole global batch (image + text) across data-parallel ranks.

    One call to ``build_balanced_train_data_iterator`` pulls a batch from the
    training iterator, redistributes its samples across the data-parallel
    group, and returns an iterator over the resulting micro-batches.
    """

    def __init__(
        self,
        virtual_pipeline_model_parallel_size,
        model_config_path,
        sorting_algo_name,
        len_model,
        train_data_iterator
    ):
        """Set up the balancer.

        Args:
            virtual_pipeline_model_parallel_size: when truthy, one micro-batch
                iterator is built per model chunk.
            model_config_path: path to a JSON file; its ``image_encoder``
                ``tp``/``pp``/``cp`` entries are read in hetero-parallel mode.
            sorting_algo_name: key into ``SORTING_ALGO_FUNC``.
            len_model: number of iterators to build for virtual pipelining.
            train_data_iterator: iterator yielding batch dicts.
        """
        super().__init__(sorting_algo_name)
        with open(model_config_path, 'r', encoding='utf-8') as f:
            model_config = json.load(f)
        world_size = dist.get_world_size()
        if get_args().hetero_parallel:
            encoder_config = model_config['image_encoder']
            self.image_encoder_dp = world_size // int(
                encoder_config['tp'] * encoder_config['pp'] * encoder_config['cp']
            )
        else:
            self.image_encoder_dp = mpu.get_data_parallel_world_size()
        if self.image_encoder_dp <= 1:
            warnings.warn(
                "Image data balance is enabled, but the image encoder's data parallelism (dp) is set to 1. "
                "In this case, the data balance feature has no effect and may introduce unnecessary overhead. "
                "Consider disabling it by removing --use-data-balance.",
                UserWarning,
                stacklevel=2
            )
        # NOTE(review): the collator is fetched through the frame locals of
        # the generator behind the iterator (a local named 'dl'). This depends
        # on how that generator is written -- confirm it stays in sync.
        data_collator = train_data_iterator.iterable.gi_frame.f_locals['dl'].collate_fn.data_collator
        self.txt_padding_dict = {
            'input_ids': data_collator.tokenizer.pad_token_id,
            'labels': data_collator.label_pad_token_id,
            'attention_mask': 0,
        }
        self.train_data_iterator = train_data_iterator
        self.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size
        self.len_model = len_model

    @staticmethod
    def _rank_table_mapping(rank_table, dp_rank):
        """Stack the per-rank tables and pick the rows tagged with ``dp_rank``.

        Returns:
            ``(index_lists, stacked_table)`` where ``index_lists[i]`` holds
            column 1 of the rows of table ``i`` whose column 0 equals
            ``dp_rank``.
        """
        rank_table = torch.stack(rank_table)
        rank_mask = rank_table[:, :, 0] == dp_rank
        rank_table_for_current_rank = [
            table[mask][:, 1] for table, mask in zip(rank_table, rank_mask)
        ]
        return rank_table_for_current_rank, rank_table

    @staticmethod
    def _split_batch_data(datas: dict):
        """Move a batch to the NPU and compute per-sample lengths.

        ``datas`` is consumed: handled keys are popped from it.

        Returns:
            ``(all_data_batchs, lengths)`` where ``lengths`` stacks one column
            per entry of ``all_data_batchs``, in the same order.
        """
        all_data_batchs = {}
        all_data_lengths = []
        if 'pixel_values' in datas:
            image_grid_thw = datas.pop('image_grid_thw')
            pixel_values = datas.pop('pixel_values')
            pixel_values_length = image_grid_thw[:, 0] * image_grid_thw[:, 1] * image_grid_thw[:, 2]
            # One chunk of pixel rows per grid entry.
            all_data_batchs['pixel_values'] = pixel_values.npu().split(pixel_values_length.tolist(), dim=0)
            all_data_batchs['image_grid_thw'] = image_grid_thw.npu()
            all_data_lengths.extend([
                pixel_values_length.npu(),
                torch.empty(
                    pixel_values_length.shape[0], dtype=torch.long, device='npu'
                ).fill_(image_grid_thw.shape[-1])
            ])
        input_idxs_length = torch.empty(
            datas['input_ids'].shape[0], dtype=torch.long, device='npu'
        ).fill_(datas['input_ids'].shape[-1])
        for key in TXT_ELEM_LIST:
            all_data_batchs[key] = datas.pop(key).npu()
        all_data_lengths.extend([input_idxs_length] * len(TXT_ELEM_LIST))
        # Whatever is left is carried along; its length is the size of dim 1,
        # or 1 for one-dimensional values.
        for key, value in datas.items():
            value = value.npu()
            all_data_batchs[key] = value
            all_data_lengths.append(
                torch.empty(
                    value.shape[0], dtype=torch.long, device=value.device
                ).fill_(value.shape[1] if len(value.shape) > 1 else 1)
            )
        return all_data_batchs, torch.stack(all_data_lengths, dim=-1)

    def build_balanced_train_data_iterator(
        self,
        is_vit_last_stage=False,
        max_batch_capacity=None,
        micro_batch_size=None,
        num_microbatches=None,
        data_type='Unknown data',
        **kwargs
    ):
        """Fetch the next batch, balance it, and wrap it in micro-batch iterators.

        Args:
            is_vit_last_stage: enables the encoder balance communication on
                stages other than the first pipeline stage.
            max_batch_capacity: forwarded to the sorting algorithm.
            micro_batch_size: samples per micro-batch.
            num_microbatches: number of micro-batches to produce.
            data_type: key for bookkeeping in ``self.state_buffer``.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            A ``PrefetchMicroBatchIterator``, or a list of ``len_model`` of
            them when virtual pipeline parallelism is configured.
        """
        batch = next(self.train_data_iterator)
        has_video = 'pixel_values_videos' in batch and 'video_grid_thw' in batch
        if has_video:
            # Video entries are stored under the image keys from here on.
            batch['pixel_values'] = batch.pop('pixel_values_videos')
            batch['image_grid_thw'] = batch.pop('video_grid_thw')
        if (mpu.is_pipeline_first_stage() or is_vit_last_stage) and get_args().encoder_dp_balance:
            batch['pixel_values'], batch['tranfer'] = EncoderBalanceComm.apply(
                batch['pixel_values'],
                mpu.get_data_parallel_group())

        state = self.state_buffer.setdefault(data_type, {})
        for data_name, value in batch.items():
            if not isinstance(value, Iterable) or isinstance(value, (str, bytes)):
                non_balanced = state.setdefault("non_balanced_data", {})
                warnings.warn(
                    f"find un-iterable data: {data_name}, type:{type(value)}. To ensure "
                    f"correct decomposition into mbs individual samples, it has been moved to non_balanced_data. "
                    f"Please verify the actual purpose of this data and apply appropriate adjustments."
                )
                non_balanced[data_name] = value
        # Names recorded by an earlier batch may be missing from this one, so
        # a missing key is tolerated instead of raising KeyError.
        for data_name in state.get("non_balanced_data", {}):
            batch.pop(data_name, None)

        split_batch, split_lengths = self._split_batch_data(batch)
        balanced_datas = self._data_balance(
            data_lengths=split_lengths,
            datas=split_batch,
            dp_process_group=get_data_parallel_group(),
            data_type=data_type,
            max_batch_capacity=max_batch_capacity,
            image_encoder_dp=self.image_encoder_dp
        )
        balanced_global_batchs = self.get_global_balanced_data(
            balanced_datas,
            micro_batch_size,
            num_microbatches,
            data_type
        )
        micro_batchs = self.collate_fn(balanced_global_batchs, data_type)
        if self.virtual_pipeline_model_parallel_size:
            return [PrefetchMicroBatchIterator(micro_batchs) for _ in range(self.len_model)]
        return PrefetchMicroBatchIterator(micro_batchs)

    def get_global_balanced_data(
        self,
        balanced_data_batch: dict,
        micro_batch_size: int,
        num_microbatches: int,
        data_type: str = 'image',
    ):
        """Reorder received data and cut it into micro-batches.

        ``balanced_data_batch`` is modified in place: the image entries are
        popped and the text entries are truncated or padded to a common width.

        Returns:
            Dict mapping each data name to a tuple of per-micro-batch tensors.
        """
        split_balanced_data = {}
        split_list = [micro_batch_size] * num_microbatches
        # Converted once so the loop below indexes with plain ints instead of
        # reading one tensor element at a time.
        mapping_index = self.state_buffer[data_type]['balance_data_mapping_index'].tolist()

        pixel_values = balanced_data_batch.pop('pixel_values')
        image_grid_thws = torch.cat(balanced_data_batch.pop('image_grid_thw'))
        split_balanced_data['image_grid_thw'] = self.divide_data_based_on_split(
            image_grid_thws, data_type
        ).split(split_list)

        pixel_split = image_grid_thws[:, 0] * image_grid_thws[:, 1] * image_grid_thws[:, 2]
        pixel_values_list = torch.cat(pixel_values).split(pixel_split.tolist())
        merge_pixels = [None] * len(pixel_values_list)
        for position, target in enumerate(mapping_index):
            merge_pixels[target] = pixel_values_list[position]

        # Pixel rows per micro-batch: sum of t*h*w over its grid entries.
        grid_per_microbatch = torch.stack(split_balanced_data['image_grid_thw'])
        pixels_per_microbatch = (
            grid_per_microbatch[:, :, 0] * grid_per_microbatch[:, :, 1] * grid_per_microbatch[:, :, 2]
        ).sum(-1)
        split_balanced_data['pixel_values'] = torch.cat(merge_pixels).split(
            pixels_per_microbatch.tolist(), dim=0
        )

        # Largest attention-mask sum over all received samples, as a Python
        # int so it can be used for comparisons, slicing and padding.
        max_dim = int(torch.cat([mask.sum(-1) for mask in balanced_data_batch['attention_mask']]).max())
        for new_rank in range(len(balanced_data_batch['attention_mask'])):
            width = balanced_data_batch['attention_mask'][new_rank].shape[-1]
            if width > max_dim:
                for key in TXT_ELEM_LIST:
                    balanced_data_batch[key][new_rank] = balanced_data_batch[key][new_rank][:, :max_dim]
            elif width < max_dim:
                pad_shape = (0, max_dim - balanced_data_batch['input_ids'][new_rank].size(-1))
                for key in TXT_ELEM_LIST:
                    balanced_data_batch[key][new_rank] = F.pad(
                        balanced_data_batch[key][new_rank],
                        pad_shape,
                        value=self.txt_padding_dict[key]
                    )

        for name, data in balanced_data_batch.items():
            split_data = self.divide_data_based_on_split(torch.cat(data), data_type)
            split_balanced_data[name] = split_data.split(split_list)
        return split_balanced_data

    def divide_data_based_on_split(self, datas, data_type="image") -> torch.Tensor:
        """Scatter the rows of ``datas`` to the positions in the stored mapping index."""
        merge_datas = torch.empty_like(datas)
        merge_datas[self.state_buffer[data_type]['balance_data_mapping_index']] = datas
        return merge_datas

    def collate_fn(self, balanced_global_batchs, data_type):
        """Turn a dict of per-micro-batch tuples into a list of micro-batch dicts.

        Values recorded as non-balanced data for ``data_type`` are copied into
        every micro-batch.
        """
        names = list(balanced_global_batchs.keys())
        micro_batchs = [dict(zip(names, row)) for row in zip(*balanced_global_batchs.values())]
        non_balanced = self.state_buffer.get(data_type, {}).get("non_balanced_data")
        if non_balanced:
            for batch in micro_batchs:
                batch.update(non_balanced)
        return micro_batchs

    def _all_gather_data_lengths(self, data_lengths, num_replicas, dp_process_group):
        """Gather equally-shaped length tables from every rank.

        Returns:
            ``(gathered_lengths, samples_lengths)``. ``gathered_lengths`` is
            the stack of every rank's table. ``samples_lengths`` has one row
            per sample: rank id, index within that rank, then its lengths.
        """
        gathered_lengths = [
            torch.empty(data_lengths.shape, dtype=data_lengths.dtype, device=data_lengths.device)
            for _ in range(num_replicas)
        ]
        dist.all_gather(gathered_lengths, data_lengths, group=dp_process_group)
        gathered_lengths = torch.stack(gathered_lengths)

        num_samples = gathered_lengths.shape[1]
        rank_ids = torch.arange(
            num_replicas, dtype=gathered_lengths.dtype, device=gathered_lengths.device
        ).view(-1, 1).expand(num_replicas, num_samples).unsqueeze(-1)
        sample_ids = torch.arange(
            num_samples, dtype=gathered_lengths.dtype, device=gathered_lengths.device
        ).view(1, -1).expand(num_replicas, num_samples).unsqueeze(-1)
        samples_lengths = torch.cat([rank_ids, sample_ids, gathered_lengths], dim=-1).flatten(0, 1)
        return gathered_lengths, samples_lengths
class MBSImageDataBalance(BaseDataBalance):
    """Balance image data of a single micro-batch and undo it afterwards.

    ``get_image_balance_data`` redistributes pixel values and grids across
    the data-parallel group; ``reverse_img_balance_data`` sends the resulting
    hidden states back to the ranks and order they came from.
    """

    def __init__(self, sorting_algo_name, spatial_merge_size=1.):
        """Set up the balancer.

        Args:
            sorting_algo_name: key into ``SORTING_ALGO_FUNC``.
            spatial_merge_size: its square is used as ``merge_down_ratio``.
        """
        merge_down_ratio = spatial_merge_size ** 2
        # A whole-valued float (e.g. the default 1.0) is stored as int so the
        # sizes derived from it stay integers and remain usable as split
        # sections and buffer shapes.
        if isinstance(merge_down_ratio, float) and merge_down_ratio.is_integer():
            merge_down_ratio = int(merge_down_ratio)
        super().__init__(sorting_algo_name, need_rev=True, merge_down_ratio=merge_down_ratio)

    @staticmethod
    def _split_batch_data(datas: dict):
        """Move image data to the NPU and compute per-image lengths.

        Returns:
            ``(image_data_batchs, image_data_lengths)``; the lengths are a
            two-element list (pixel rows per image, grid width per image).
        """
        image_grid_thw = datas['image_grid_thw']
        pixel_values = datas['pixel_values']
        pixel_values_length = image_grid_thw[:, 0] * image_grid_thw[:, 1] * image_grid_thw[:, 2]
        image_data_batchs = {
            'pixel_values': pixel_values.npu().split(pixel_values_length.tolist(), dim=0),
            'image_grid_thw': image_grid_thw.npu(),
        }
        image_data_lengths = [
            pixel_values_length.npu(),
            torch.empty(
                pixel_values_length.shape[0], dtype=torch.long, device='npu'
            ).fill_(image_grid_thw.shape[-1])
        ]
        return image_data_batchs, image_data_lengths

    @staticmethod
    def _rank_table_mapping(rank_table, dp_rank):
        """Stack each per-rank table and pick the rows tagged with ``dp_rank``.

        Returns:
            ``(index_lists, tables)`` where ``index_lists[i]`` holds column 1
            of the rows of table ``i`` whose column 0 equals ``dp_rank``.
        """
        rank_table = [torch.stack(rt) for rt in rank_table]
        rank_table_for_current_rank = [rt[rt[:, 0] == dp_rank][:, 1] for rt in rank_table]
        return rank_table_for_current_rank, rank_table

    def get_image_balance_data(self, image_batch, data_type='image'):
        """Balance ``image_batch`` and return ``(pixel_values, image_grid_thw)``.

        Args:
            image_batch: dict with ``pixel_values`` and ``image_grid_thw``.
            data_type: key for bookkeeping in ``self.state_buffer``.
        """
        self.state_buffer.setdefault(data_type, {})
        split_batch, split_lengths = self._split_batch_data(image_batch)
        balanced_datas = self._data_balance(
            data_lengths=torch.stack(split_lengths, dim=-1),
            datas=split_batch,
            dp_process_group=get_data_parallel_group(),
            data_type=data_type,
        )
        image_grid_thw = torch.cat(balanced_datas['image_grid_thw'])
        pixel_values = torch.cat(balanced_datas['pixel_values'])
        return pixel_values, image_grid_thw

    def reverse_img_balance_data(self, hidden_state, deepstack_feature_lists, require_grad=False,
                                 data_type='image'):
        """Undo the balancing for encoder outputs.

        Args:
            hidden_state: tensor produced from the balanced pixel values.
            deepstack_feature_lists: optional iterable of further tensors to
                recover the same way; falsy values yield an empty list.
            require_grad: use the autograd-aware all-to-all.
            data_type: must match the value passed to
                ``get_image_balance_data``.

        Returns:
            ``(recovered_hidden_state, recovered_deepstack_feature_lists)``.
        """
        dp_process_group = get_data_parallel_group()
        recovered_hidden_state = self._recover_image_data(
            hidden_state, dp_process_group, require_grad, data_type)
        recovered_deepstack_feature_lists = [
            self._recover_image_data(df, dp_process_group, require_grad, data_type)
            for df in deepstack_feature_lists
        ] if deepstack_feature_lists else []
        return recovered_hidden_state, recovered_deepstack_feature_lists

    def _all_gather_data_lengths(self, data_lengths, num_replicas, dp_process_group):
        """Gather length tables whose row count may differ per rank.

        Returns:
            ``(gathered_lengths, samples_lengths)``. ``gathered_lengths`` is a
            list with one table per rank. ``samples_lengths`` concatenates
            them, each row prefixed with the rank id and the row's index
            within that rank.
        """
        # Shape (1,) so the send tensor matches the receive buffers exactly.
        cur_bs = torch.tensor([data_lengths.shape[0]], dtype=torch.long, device=data_lengths.device)
        all_gather_bs = [torch.empty(1, dtype=torch.long, device=data_lengths.device) for _ in range(num_replicas)]
        dist.all_gather(all_gather_bs, cur_bs, group=dp_process_group)
        batch_sizes = torch.cat(all_gather_bs).tolist()

        gathered_lengths = [
            torch.empty((batch_size, *data_lengths.shape[1:]), dtype=data_lengths.dtype,
                        device=data_lengths.device)
            for batch_size in batch_sizes
        ]
        dist.all_gather(gathered_lengths, data_lengths, group=dp_process_group)

        samples_lengths = []
        for rank, batch in enumerate(gathered_lengths):
            sample_ids = torch.arange(len(batch), dtype=batch.dtype, device=batch.device).view(-1, 1)
            # Left-pad one column filled with the rank id.
            samples_lengths.append(
                F.pad(torch.cat([sample_ids, batch], dim=-1), pad=(1, 0), value=rank)
            )
        return gathered_lengths, torch.cat(samples_lengths)

    def _recover_image_data(self, hidden_state, dp_process_group, require_grad=False, data_type='image'):
        """Send ``hidden_state`` back to its origin ranks and restore sample order.

        Uses the ``pixel_values_split``, ``pixel_values_origin``,
        ``image_grid_thw_origin_split`` and ``pixel_values_data_list``
        entries stored under ``data_type`` by the forward balancing step.
        """
        state = self.state_buffer[data_type]
        recovered_hidden_state = self._all_to_all_communication(
            list(hidden_state.split(state["pixel_values_split"])),
            state["pixel_values_origin"],
            (hidden_state.shape[-1],),
            dp_process_group,
            require_grad=require_grad
        )
        recovered_hidden_state = torch.cat(recovered_hidden_state).split(
            state["image_grid_thw_origin_split"])
        origin_hidden_state = [None] * len(recovered_hidden_state)
        # One host transfer instead of reading the index tensor per element.
        for position, target in enumerate(state['pixel_values_data_list'].tolist()):
            origin_hidden_state[target] = recovered_hidden_state[position]
        return torch.cat(origin_hidden_state)