class DynBszBuffer:
    """
    A buffer to store samples for dynamic batch size.

    Samples are kept in insertion order together with their token counts.
    ``get_samples`` marks the samples it hands out as consumed; they are only
    physically removed (and subtracted from ``all_token_cnt``) by ``flush``.
    """

    def __init__(self):
        # Stored samples, in insertion order.
        self._buffer = []
        # Token count of each stored sample; parallel to ``_buffer``.
        self._buffer_sample_lens = []
        # Indices (into ``_buffer``) handed out since the last flush.
        self.del_idxs = []
        # Scan position where the next ``get_samples`` call resumes.
        self.cur_idx = 0
        # Sum of the token counts of every sample still stored, including
        # consumed-but-not-yet-flushed ones.
        self.all_token_cnt = 0

    def append(self, item: dict):
        """
        Append a sample to the buffer.

        Args:
            item: a sample to append to the buffer.
                The sample should be a dict with the following keys:
                - input_ids: torch.Tensor of shape (seq_len, )
                - attention_mask: torch.Tensor of shape (seq_len, )
                Only ``attention_mask`` is read here; its ``sum()`` is used
                as the sample's token count.
        """
        sample_len = item["attention_mask"].sum()
        self._buffer.append(item)
        self._buffer_sample_lens.append(sample_len)
        self.all_token_cnt += sample_len

    def get_samples(self, max_seq_len: int, force: bool = True):
        """
        Get samples from the buffer, up to a token budget.

        Scanning resumes at ``cur_idx`` and walks forward. A sample that does
        not fit into the remaining budget is skipped (it stays in the buffer)
        and the scan continues with the following samples.

        Args:
            max_seq_len: the token budget for the returned samples.
            force: if True, a sample is taken whenever nothing has been
                accumulated yet, even if it is longer than ``max_seq_len``.
                Only the literal ``True`` enables this.

        Returns:
            samples: a list of samples.

        Raises:
            ValueError: if no sample could be selected.
        """
        cum_seq_len = 0
        samples = []
        # Set snapshot of the consumed indices so the membership test in the
        # loop is O(1) instead of a linear scan of the list on every step.
        consumed = set(self.del_idxs)
        while self.cur_idx < len(self._buffer) and cum_seq_len < max_seq_len:
            idx = self.cur_idx
            seq_len = self._buffer_sample_lens[idx]
            is_valid_sample = idx not in consumed
            # A sample can be added to the sequence when either:
            # 1. force is True and nothing has been accumulated yet, or
            # 2. its length <= the remaining budget (max_seq_len - cum_seq_len).
            could_add_to_seq = (force is True and cum_seq_len == 0) or (
                seq_len <= max_seq_len - cum_seq_len
            )
            if is_valid_sample and could_add_to_seq:
                cum_seq_len += seq_len
                samples.append(self._buffer[idx])
                self.del_idxs.append(idx)
                consumed.add(idx)
            self.cur_idx += 1
        if not samples:
            raise ValueError("Could not get samples from buffer")
        return samples

    def __len__(self):
        """Return the number of stored samples (consumed ones included until flushed)."""
        return len(self._buffer)

    def flush(self):
        """
        Flush the buffer.

        Drops every sample recorded in ``del_idxs``, subtracts their token
        counts from ``all_token_cnt`` and resets the scan position to 0.
        """
        self.cur_idx = 0
        self.all_token_cnt -= sum(self._buffer_sample_lens[idx] for idx in self.del_idxs)
        # Set lookup keeps the filtering linear in the buffer size.
        removed = set(self.del_idxs)
        kept = [
            (sample, sample_len)
            for idx, (sample, sample_len) in enumerate(
                zip(self._buffer, self._buffer_sample_lens)
            )
            if idx not in removed
        ]
        self._buffer = [sample for sample, _ in kept]
        self._buffer_sample_lens = [sample_len for _, sample_len in kept]
        self.del_idxs = []

    def merge(self, buffer_to_merge: "DynBszBuffer"):
        """
        Merge the buffer with another buffer.

        Both buffers are flushed first, then every sample remaining in
        ``buffer_to_merge`` is appended to this buffer. ``buffer_to_merge``
        keeps its samples.

        Args:
            buffer_to_merge: the buffer to merge.
        """
        self.flush()
        buffer_to_merge.flush()
        # Iterate over a snapshot: if ``buffer_to_merge`` is this very buffer,
        # appending while iterating the live list would never terminate.
        for item in list(buffer_to_merge._buffer):
            self.append(item)
class TextBatchingStrategy:
    """
    Batching strategy for text data.

    Items are collected in a ``DynBszBuffer`` and handed back as micro
    batches whose total token count is bounded by ``max_seq_len``.

    Args:
        max_seq_len: the max number of tokens to get for each sequence.
        buffer_size: the size of the buffer.
    """

    def __init__(
        self,
        max_seq_len,
        buffer_size: int = 500,
    ) -> None:
        super().__init__()
        self.max_seq_len = max_seq_len
        self.buffer_size = buffer_size
        self._buffer = DynBszBuffer()

    def is_full_filled(self) -> bool:
        """
        Tell whether the buffer holds enough data to build a micro batch.

        Returns:
            True when the buffer holds at least ``buffer_size`` samples and
            at least ``max_seq_len`` tokens in total.
        """
        enough_samples = len(self._buffer) >= self.buffer_size
        # bool(): the token counter need not be a plain int, so the comparison
        # (and therefore the `and` expression) may yield a non-bool object.
        return bool(enough_samples and self._buffer.all_token_cnt >= self.max_seq_len)

    def put_item(self, item: dict):
        """
        Add a sample to the buffer.

        Args:
            item: the sample to store; it is passed to the buffer's ``append``.
        """
        self._buffer.append(item)

    def get_micro_batch(self) -> list:
        """
        Take one micro batch out of the buffer.

        Returns:
            The samples selected by the buffer for a budget of
            ``max_seq_len`` tokens; they are removed from the buffer.
        """
        samples = self._buffer.get_samples(self.max_seq_len)
        # Physically drop the samples that were just handed out.
        self._buffer.flush()
        return samples

    def empty(self) -> bool:
        """Return True when the buffer holds no samples."""
        return len(self._buffer) == 0