"""
Dataset utils for training
"""
import os
from abc import ABC, abstractmethod
import pickle
import numpy as np
from mindspore.mindrecord import FileWriter
class Dataset(ABC):
    """Abstract interface for a sized, index-addressable collection of samples."""

    @abstractmethod
    def __len__(self):
        """Return the number of samples; concrete subclasses must override."""

    @abstractmethod
    def __getitem__(self, idx):
        """Return the sample at position ``idx``; concrete subclasses must override."""
class LMDBDataset(Dataset):
    """Read-only, index-addressable view over an LMDB database.

    Database layout, as read by this class:

    * key ``b"length"`` holds the number of samples as UTF-8 decimal text;
    * key ``str(idx)`` (UTF-8 encoded) holds one pickled sample.

    Args:
        path: Location of the LMDB environment, passed to ``lmdb.open``.
        process_fn: Optional callable applied to every unpickled row before
            it is returned from ``__getitem__``.

    Raises:
        IOError: If the environment cannot be opened or has no ``length`` entry.
    """

    def __init__(self, path, process_fn=None):
        # Local import: lmdb is only needed once a dataset is actually opened.
        import lmdb

        self.path = path
        self.process_fn = process_fn
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        if not self.env:
            raise IOError("Cannot open lmdb dataset", path)

        with self.env.begin(write=False) as txn:
            raw_length = txn.get("length".encode("utf-8"))
        if raw_length is None:
            # Fail with a clear message instead of AttributeError on None.decode.
            raise IOError("lmdb dataset has no 'length' entry", path)
        self.length = int(raw_length.decode("utf-8"))

    def __len__(self):
        """Return the sample count stored under the ``length`` key."""
        return self.length

    def __getitem__(self, idx):
        """Return the sample stored under key ``str(idx)``.

        Raises:
            IndexError: If no record exists for ``idx``. Plain ``for`` loops
                over this object rely on that to terminate.
        """
        key = str(idx).encode("utf-8")
        with self.env.begin(write=False) as txn:
            payload = txn.get(key)
            if payload is None:
                raise IndexError("Index out of range")
            # Only a missing key maps to IndexError. Errors raised while
            # unpickling an existing record propagate, so a corrupt record
            # cannot silently end iteration early.
            # NOTE(security): pickle.loads executes arbitrary code from the
            # payload; only open LMDB files from a trusted source.
            row = pickle.loads(payload)
        if self.process_fn:
            return self.process_fn(row)
        return row
class PadDataset(Dataset):
    """Wrap a dataset so every sample has exactly ``seq_len + 1`` tokens.

    Longer samples are truncated; shorter ones are right-padded with
    ``eod_id``.

    Args:
        dataset: Indexable dataset whose items are sequences; only element
            ``[0]`` of each item is used.
        seq_len: Target sequence length; one extra token is kept on top of it.
        eod_id: Token id used for padding.
    """

    def __init__(self, dataset, seq_len, eod_id):
        self.dataset = dataset
        # Stored length is seq_len + 1 (presumably so inputs and shifted
        # labels can both be cut from one sample; confirm with the trainer).
        self.seq_len = seq_len + 1
        self.eod_id = eod_id

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return a 1-tuple holding the truncated or padded token sequence."""
        item = self.dataset[idx][0]
        if self.seq_len <= len(item):
            return (item[:self.seq_len],)
        # np.full takes its dtype from eod_id, so integer token ids stay
        # integers. The previous np.ones(...) * eod_id padding was float64
        # and promoted the whole padded sample to float.
        padding = np.full(self.seq_len - len(item), self.eod_id)
        return (np.concatenate((item, padding), axis=0),)
def get_code_data_train(code_data_path, args, process_fn=None):
    """Open the LMDB training data and wrap it with fixed-length padding.

    Args:
        code_data_path: Directory expected to contain an LMDB ``data.mdb`` file.
        args: Object providing ``seq_length`` and ``eod_id`` attributes.
        process_fn: Optional callable forwarded to ``LMDBDataset``.

    Returns:
        A ``PadDataset`` over the LMDB records.

    Raises:
        FileNotFoundError: If ``data.mdb`` is not present under
            ``code_data_path``.
    """
    if not os.path.exists(os.path.join(code_data_path, 'data.mdb')):
        # Fail early with a clear message instead of using an unbound `data`.
        raise FileNotFoundError(
            f"No LMDB file 'data.mdb' found under {code_data_path!r}")

    full_path = os.path.join(code_data_path)
    print(f"Loading code data {full_path}")
    data = LMDBDataset(
        full_path,
        process_fn=process_fn,
    )
    return PadDataset(
        data,
        args.seq_length,
        args.eod_id,
    )
def generate_mindrecord(args, file_name="codegeex.mindrecord", batch_size=100):
    """Convert the LMDB training data into a single-shard MindRecord file.

    Args:
        args: Object providing ``code_data``, ``seq_length`` and ``eod_id``.
        file_name: Path of the MindRecord file to write (overwritten if present).
        batch_size: Number of samples buffered in memory before each
            ``write_raw_data`` call.
    """
    # Open the input first so a bad input path fails before the writer is
    # created with overwrite=True.
    train_data = get_code_data_train(args.code_data, args)

    data_schema = {"input_ids": {"type": "int64", "shape": [-1]}}
    writer = FileWriter(file_name, shard_num=1, overwrite=True)
    writer.add_schema(data_schema, "it is a code dataset")

    data = []
    for i, input_id in enumerate(train_data):
        print(i)
        # Each item is a 1-tuple holding the token sequence; squeeze drops
        # the tuple axis so a flat int64 vector is stored.
        sample = {"input_ids": np.array(input_id).squeeze().astype(np.int64)}
        data.append(sample)
        # Flush on buffer size. The previous `i > 100` test stayed true after
        # the first flush, so every later sample was written on its own.
        if len(data) >= batch_size:
            writer.write_raw_data(data)
            data = []
    if data:
        # Write the final partial batch.
        writer.write_raw_data(data)
    writer.commit()
if __name__ == "__main__":
    # Command-line entry point: parse the options, then write the MindRecord file.
    import argparse

    parser = argparse.ArgumentParser(description="PanguAlpha training")
    parser.add_argument("--seq_length",
                        type=int,
                        default=2048,
                        help="sequence length, default is 2048.")
    parser.add_argument("--eod_id",
                        type=int, default=50256,
                        help="The id of end of document")
    parser.add_argument("--eod_reset",
                        type=int,
                        default=1,
                        help="Enable eod mask, default is 1.")
    # Required: without it, None reaches os.path.join and fails with TypeError.
    parser.add_argument('--code_data',
                        type=str,
                        required=True,
                        help='Location of code data.')
    args_opt = parser.parse_args()
    print(args_opt)
    generate_mindrecord(args_opt)