def build_input_from_ids(text_a_ids=None,
text_b_ids=None,
answer_ids=None,
max_seq_length=16,
tokenizer=None,
args=None,
add_cls=True,
add_sep=False,
add_piece=False,
add_eos=True,
mask_id=None,
masked_lm=False):
if mask_id is None:
mask_id = tokenizer.get_command_id('mask')
eos_id = tokenizer.get_command_id('eos')
cls_id = tokenizer.get_command_id('cls')
sep_id = tokenizer.get_command_id('sep')
ids = []
types = []
paddings = []
if add_cls:
ids.append(cls_id)
types.append(0)
paddings.append(1)
len_text_a = len(text_a_ids)
ids.extend(text_a_ids)
types.extend([0] * len_text_a)
paddings.extend([1] * len_text_a)
if text_b_ids is not None:
if add_sep:
ids.append(sep_id)
types.append(0)
paddings.append(1)
len_text_b = len(text_b_ids)
ids.extend(text_b_ids)
types.extend([1] * len_text_b)
paddings.extend([1] * len_text_b)
eos_length = 1 if add_eos else 0
if len(ids) >= max_seq_length - eos_length:
max_seq_length_m1 = max_seq_length - 1
ids = ids[0:max_seq_length_m1]
types = types[0:max_seq_length_m1]
paddings = paddings[0:max_seq_length_m1]
end_type = 0 if text_b_ids is None else 1
if add_eos:
ids.append(eos_id)
types.append(end_type)
paddings.append(1)
sep = len(ids)
target_ids = [0] * len(ids)
loss_masks = [0] * len(ids)
position_ids = list(range(len(ids)))
block_position_ids = [0] * len(ids)
if add_piece or answer_ids is not None:
sop_id = tokenizer.get_command_id('sop')
mask_position = ids.index(
mask_id
) if not args.sentinel_token else args.max_position_embeddings
ids.append(sop_id)
types.append(end_type)
paddings.append(1)
position_ids.append(mask_position)
block_position_ids.append(1)
if answer_ids is not None:
len_answer = len(answer_ids)
ids.extend(answer_ids[:-1])
types.extend([end_type] * (len_answer - 1))
paddings.extend([1] * (len_answer - 1))
position_ids.extend([mask_position] * (len_answer - 1))
if not args.no_block_position:
block_position_ids.extend(range(2, len(answer_ids) + 1))
else:
block_position_ids.extend([1] * (len(answer_ids) - 1))
target_ids.extend(answer_ids)
loss_masks.extend([1] * len(answer_ids))
else:
target_ids.append(0)
loss_masks.append(1)
padding_length = max_seq_length - len(ids)
if padding_length > 0:
ids.extend([eos_id] * padding_length)
types.extend([eos_id] * padding_length)
paddings.extend([0] * padding_length)
position_ids.extend([0] * padding_length)
block_position_ids.extend([0] * padding_length)
target_ids.extend([0] * padding_length)
loss_masks.extend([0] * padding_length)
if masked_lm:
position_ids = [position_ids]
return ids, types, paddings, position_ids, sep, target_ids, loss_masks
class CollateArguments:
def __init__(self):
self.cloze_eval = True
self.pretrained_bert = False
self.task_mask = True
self.continuous_prompt = False
self.prefix_prompt = 0
self.sentinel_token = False
self.max_position_embeddings = 1024
self.no_block_position = False
self.masked_lm = False
self.pattern_id = 0
self.seq_length = 256
self.num_prompt_tokens = 0
self.multi_token = False
self.segment_length = 0
self.fast_decode = False
self.few_superglue = False
self.pattern_text = False
class Seq2SeqCollateArguments:
def __init__(self):
self.cloze_eval = True
self.pretrained_bert = False
self.task_mask = True
self.continuous_prompt = False
self.prefix_prompt = 0
self.max_src_length = 464
self.max_tgt_length = 48
self.min_tgt_length = 0
self.no_block_position = True
class PretrainDatasetArguments:
def __init__(self):
self.task_mask = True
self.block_mask_prob = 0.1
self.block_lm = True
self.masked_lm = False
self.pre_tokenize = True
self.no_lazy_loader = True
self.half_lazy_loader = False
self.sentinel_token = False