# yezhenhui: init
# Commit 297fea2a, created 2024-02-02.
import pickle
from bisect import bisect
from copy import deepcopy
import numpy as np
import gzip


def int2bytes(i: int, *, signed: bool = False) -> bytes:
    """Encode ``i`` as the shortest little-endian byte string that can hold it.

    Unsigned zero encodes to ``b''``. With ``signed=True`` one extra bit is
    reserved for the sign, so e.g. 128 needs two bytes while -128 needs one.
    A negative ``i`` with ``signed=False`` raises ``OverflowError`` (from
    ``int.to_bytes``).
    """
    # In two's complement -(2**k) fits in k bits plus the sign bit, so a
    # negative signed value is sized by i + 1 rather than by i itself.
    sized_value = i + 1 if signed and i < 0 else i
    n_bits = sized_value.bit_length() + int(signed)
    n_bytes = (n_bits + 7) // 8
    return i.to_bytes(n_bytes, byteorder='little', signed=signed)


def bytes2int(b: bytes, *, signed: bool = False) -> int:
    """Decode a little-endian byte string into an int (inverse of ``int2bytes``).

    Trailing zero bytes do not change the result, so a zero-padded fixed-width
    header decodes to the same value. An empty input decodes to 0.
    """
    value = int.from_bytes(b, 'little', signed=signed)
    return value


def load_index_data(data_file):
    """Read the index stored at the head of a dataset's root ``.data`` file.

    Reads from the file's current position (callers position it at offset 0).
    The first 32 bytes hold the little-endian byte length of a pickled dict,
    and the pickled dict follows immediately.

    Returns:
        ``(offsets, id2pos, meta)``: the dict's ``'offsets'`` entry, plus its
        ``'id2pos'`` and ``'meta'`` entries (``{}`` when absent).

    Raises:
        KeyError: if the index has no ``'offsets'`` entry.
        Any exception raised by ``pickle.loads`` for a malformed index.
    """
    index_data_size = bytes2int(data_file.read(32))
    # SECURITY: pickle.loads can execute arbitrary code; only open trusted files.
    index_data = pickle.loads(data_file.read(index_data_size))
    # pickle.loads already returns freshly built objects that nothing else
    # references, so a defensive deepcopy would only double time and memory
    # for large offset tables.
    offsets = index_data['offsets']
    id2pos = index_data.get('id2pos', {})
    meta = index_data.get('meta', {})
    return offsets, id2pos, meta


class IndexedDataset:
    """Random-access reader for a dataset written by ``IndexedDatasetBuilder``.

    Items are stored back to back in ``{path}.data`` and, after the builder
    rolls over to new chunks, in ``{path}.1.data``, ``{path}.2.data``, ...
    The index (byte offsets, an optional id -> position map and metadata) is
    read from the head of ``{path}.data``. Files without a readable embedded
    index fall back to the legacy layout with a separate ``{path}.idx`` file.

    When the index carries a non-empty ``id2pos`` map, every key passed to
    ``__getitem__`` is translated through it, so items are then addressed by
    the id they were added with rather than by position.
    """

    def __init__(self, path, unpickle=True):
        """Open the dataset at path prefix ``path`` (without ``.data``).

        Args:
            path: path prefix of the dataset files.
            unpickle: if False, ``__getitem__`` returns the stored bytes
                unchanged (still gzip-compressed if the dataset was built with
                ``gzip``) instead of unpickling them.
        """
        self.path = path
        self.root_data_file = open(f"{path}.data", 'rb', buffering=-1)
        try:
            self.byte_offsets, self.id2pos, self.meta = load_index_data(self.root_data_file)
            self.data_files = [self.root_data_file]
        except Exception:
            # No usable embedded index: assume the legacy layout. Close our
            # handle first; __init__old opens its own, and the one above
            # would otherwise leak until garbage collection.
            self.root_data_file.close()
            self.__init__old(path)
            self.root_data_file = self.data_files[0]
            self.meta = {}
        self.gzip = self.meta.get('gzip', False)
        # Single-file datasets have no chunk table; chunk 0 begins at offset 0.
        if 'chunk_begin' not in self.meta:
            self.meta['chunk_begin'] = [0]
        for chunk_id in range(1, len(self.meta['chunk_begin'])):
            self.data_files.append(open(f"{self.path}.{chunk_id}.data", 'rb'))
        self.unpickle = unpickle

    def __init__old(self, path):
        """Load the legacy layout: index in ``{path}.idx`` (a pickled dict saved via numpy)."""
        self.path = path
        # allow_pickle=True unpickles the file; only load trusted datasets.
        index_data = np.load(f"{path}.idx", allow_pickle=True).item()
        self.byte_offsets = index_data['offsets']
        self.id2pos = index_data.get('id2pos', {})
        self.data_files = [open(f"{path}.data", 'rb', buffering=-1)]

    def __getitem__(self, i):
        """Return item ``i`` (a position, or an id when ``id2pos`` is non-empty).

        Raises:
            KeyError: for an id missing from a non-empty ``id2pos``.
            IndexError: for a position outside ``[0, len(self))``.
        """
        if self.id2pos is not None and len(self.id2pos) > 0:
            i = self.id2pos[i]
        self.check_index(i)
        begin, end = self.byte_offsets[i], self.byte_offsets[i + 1]
        # Offsets are global across chunks: pick the chunk whose start is the
        # last one <= begin, then convert to an offset within that file.
        chunk_id = bisect(self.meta['chunk_begin'][1:], begin)
        data_file = self.data_files[chunk_id]
        data_file.seek(begin - self.meta['chunk_begin'][chunk_id])
        b = data_file.read(end - begin)
        if not self.unpickle:
            return b
        if self.gzip:
            b = gzip.decompress(b)
        # SECURITY: pickle.loads can execute arbitrary code; only read trusted datasets.
        return pickle.loads(b)

    def __del__(self):
        # getattr: __init__ may have failed before data_files was assigned.
        for data_file in getattr(self, 'data_files', ()):
            data_file.close()

    def check_index(self, i):
        """Raise IndexError unless ``0 <= i < len(self)`` (negative indices are not supported)."""
        if i < 0 or i >= len(self.byte_offsets) - 1:
            raise IndexError('index out of range')

    def __len__(self):
        # byte_offsets holds one more entry than there are items (the end offset).
        return len(self.byte_offsets) - 1

    def __iter__(self):
        # NOTE: the cursor lives on the instance, so nested/concurrent
        # iteration over the same dataset object shares one position.
        self.iter_i = 0
        return self

    def __next__(self):
        if self.iter_i == len(self):
            raise StopIteration
        item = self[self.iter_i]
        self.iter_i += 1
        return item


class IndexedDatasetBuilder:
    """Writer for the chunked dataset format read by ``IndexedDataset``.

    Layout of the root file ``{path}.data``:

    * bytes ``[0, 32)``: little-endian byte length of the pickled index;
    * bytes ``[32, 32 + length)``: pickled ``{'offsets', 'id2pos', 'meta'}``;
    * bytes ``[default_idx_size, ...)``: item payloads, back to back.

    ``offsets`` are global byte positions: item ``k`` spans
    ``offsets[k]:offsets[k + 1]``. Once the running offset exceeds
    ``meta['chunk_begin'][-1] + max_size``, later items go to a new file
    ``{path}.{n}.data``, and that file's starting global offset is appended to
    ``meta['chunk_begin']``. The index is only written by ``finalize()``.
    """

    # Bytes at the head of the root file reserved for the index length.
    _HEADER_SIZE = 32

    def __init__(self, path, append=False, max_size=1024 * 1024 * 1024 * 64,
                 default_idx_size=1024 * 1024 * 16, gzip=False):
        """Open ``{path}.data`` for writing.

        Args:
            path: path prefix; files are ``{path}.data`` and ``{path}.{n}.data``.
            append: reopen an existing dataset written by this builder and add
                items after its last one. The on-disk index is cleared until
                ``finalize()`` rewrites it.
            max_size: byte budget per data file before rolling over to a new chunk.
            default_idx_size: bytes reserved at the head of the root file for
                the header and index. Ignored when ``append`` is true: the
                existing region (which ends at the first item's offset) is
                reused, because zeroing a larger region would destroy items.
            gzip: gzip-compress pickled items (new datasets only; appending
                keeps the stored setting). Shadows the ``gzip`` module here.
        """
        self.path = self.root_path = path
        self.max_size = max_size
        if append:
            self._open_for_append(path)
        else:
            self.data_file = open(f"{path}.data", 'wb')
            # Leave room for the header + index; finalize() fills it in.
            self.data_file.seek(default_idx_size)
            self.byte_offsets = [default_idx_size]
            self.id2pos = {}
            self.meta = {'chunk_begin': [0], 'gzip': gzip}
            self.gzip = gzip
            self.root_data_file = self.data_file
            self.data_chunk_id = 0
            self.default_idx_size = default_idx_size

    def _open_for_append(self, path):
        """Reopen an existing dataset and position the writer after its last item."""
        self.root_data_file = open(f"{path}.data", 'r+b')
        self.root_data_file.seek(0)
        self.byte_offsets, self.id2pos, self.meta = load_index_data(self.root_data_file)
        # The index region of an existing file ends where its first item starts.
        self.default_idx_size = self.byte_offsets[0]
        # Wipe the stale index; finalize() writes the updated one.
        self.root_data_file.seek(0)
        self.root_data_file.write(bytes(self.default_idx_size))
        self.meta.setdefault('chunk_begin', [0])
        self.gzip = self.meta.setdefault('gzip', False)
        # Resume in the last chunk and keep numbering from it, so the next
        # rollover creates a fresh file instead of truncating an existing one.
        self.data_chunk_id = len(self.meta['chunk_begin']) - 1
        if self.data_chunk_id == 0:
            self.data_file = self.root_data_file
        else:
            self.data_file = open(f"{path}.{self.data_chunk_id}.data", 'r+b')
        self.data_file.seek(self.byte_offsets[-1] - self.meta['chunk_begin'][-1])

    def _start_new_chunk(self):
        """Close the current chunk file (unless it is the root) and open the next one."""
        if self.data_file is not self.root_data_file:
            self.data_file.close()
        self.data_chunk_id += 1
        self.data_file = open(f"{self.path}.{self.data_chunk_id}.data", 'wb')
        self.meta['chunk_begin'].append(self.byte_offsets[-1])

    def add_item(self, item, id=None, use_pickle=True):
        """Append one item.

        Args:
            item: the object to store; with ``use_pickle=False`` it must already
                be bytes-like and is written as-is (no gzip applied).
            id: optional key; when given, ``id2pos[id]`` records this item's position.
            use_pickle: pickle (and gzip, if enabled) ``item`` before writing.
        """
        if self.byte_offsets[-1] > self.meta['chunk_begin'][-1] + self.max_size:
            self._start_new_chunk()
        if use_pickle:
            payload = pickle.dumps(item)
            if self.gzip:
                payload = gzip.compress(payload, 1)
        else:
            payload = item
        n_written = self.data_file.write(payload)
        if id is not None:
            self.id2pos[id] = len(self.byte_offsets) - 1
        self.byte_offsets.append(self.byte_offsets[-1] + n_written)

    def finalize(self):
        """Write the index into the root file and close all open data files.

        Raises:
            ValueError: if the 32-byte header plus the pickled index does not
                fit before the first item (i.e. within ``default_idx_size``);
                nothing is written in that case.
        """
        s = pickle.dumps({'offsets': self.byte_offsets, 'id2pos': self.id2pos, 'meta': self.meta})
        # The index starts after the header, so the header counts against the
        # reserved region; otherwise the tail of the index would overwrite item 0.
        if self._HEADER_SIZE + len(s) > self.default_idx_size:
            raise ValueError(
                f"index of {len(s)} bytes does not fit in the "
                f"{self.default_idx_size - self._HEADER_SIZE}-byte index region; "
                f"use a larger default_idx_size")
        if self.data_file is not self.root_data_file:
            self.data_file.close()
        self.root_data_file.seek(0)
        self.root_data_file.write(int2bytes(len(s)))
        self.root_data_file.seek(self._HEADER_SIZE)
        self.root_data_file.write(s)
        self.root_data_file.close()


if __name__ == "__main__":
    # Smoke test: build a multi-chunk dataset of random arrays, then read
    # random items back by id and compare them with the originals.
    import os
    import random
    import tempfile
    from tqdm import tqdm

    # gettempdir() instead of a hard-coded '/tmp' so the demo also runs on
    # platforms without /tmp.
    ds_path = os.path.join(tempfile.gettempdir(), 'indexed_ds_example')
    size = 100
    # Two 10000x10 float64 arrays per item (~1.6 MB pickled) against a 40 MB
    # chunk budget, so the builder rolls over to several chunk files.
    items = [{"a": np.random.normal(size=[10000, 10]),
              "b": np.random.normal(size=[10000, 10])} for _ in range(size)]
    builder = IndexedDatasetBuilder(ds_path, max_size=1024 * 1024 * 40)
    builder.meta['lengths'] = [1, 2, 3]
    for i in tqdm(range(size)):
        # Pre-pickled payload stored as raw bytes; the reader unpickles it.
        builder.add_item(pickle.dumps(items[i]), i, use_pickle=False)
    builder.finalize()

    ds = IndexedDataset(ds_path)
    assert ds.meta['lengths'] == [1, 2, 3]
    for _ in tqdm(range(1000)):
        idx = random.randint(0, size - 1)
        item = ds[idx]
        assert (item['a'] == items[idx]['a']).all()
        assert (item['b'] == items[idx]['b']).all()