import itertools
from typing import Iterator
from dataclasses import dataclass
import torch_npu
from torch.utils.data.dataset import IterableDataset
import torch
from torchrec.streamable import Pipelineable
from torchrec import KeyedJaggedTensor, JaggedTensor
@dataclass
class Batch(Pipelineable):
    """A single batch: keyed sparse features together with their labels.

    Every transfer-style method (``to``, ``pin_memory``) returns a *new*
    ``Batch`` built from the converted members; the receiver itself is not
    reassigned.
    """

    sparse_features: KeyedJaggedTensor
    labels: torch.Tensor

    def __init__(self, sparse_features, labels) -> None:
        """Store the two members as given, without copying or validating."""
        # An explicitly defined __init__ is kept by @dataclass rather than
        # being replaced by a generated one.
        self.sparse_features = sparse_features
        self.labels = labels

    def to(self, device: torch.device, non_blocking: bool = False) -> "Batch":
        """Return a new ``Batch`` whose members have been sent to ``device``.

        ``non_blocking`` is forwarded unchanged to both underlying ``to``
        calls.
        """
        moved_features = self.sparse_features.to(device, non_blocking=non_blocking)
        moved_labels = self.labels.to(device, non_blocking=non_blocking)
        return Batch(sparse_features=moved_features, labels=moved_labels)

    def record_stream(self, stream: torch_npu.npu.streams.Stream) -> None:
        """Forward ``record_stream(stream)`` to both members, features first."""
        for member in (self.sparse_features, self.labels):
            member.record_stream(stream)

    def pin_memory(self) -> "Batch":
        """Return a new ``Batch`` built from the pinned version of each member."""
        pinned_features = self.sparse_features.pin_memory()
        pinned_labels = self.labels.pin_memory()
        return Batch(sparse_features=pinned_features, labels=pinned_labels)
class RandomRecDataset(IterableDataset[Batch]):
    """In-memory dataset of pre-generated random recommendation batches.

    All batches are created eagerly in ``__init__`` and kept in ``self.data``;
    iterating the dataset simply walks that list.
    """

    def __init__(self, batch_size, batch_num, feat_names, id_ranges):
        """Build ``batch_num`` batches of ``batch_size`` samples each.

        Args:
            batch_size: Number of samples per generated batch.
            batch_num: Number of batches to pre-generate.
            feat_names: Iterable of iterables of feature names; it is
                flattened into a single list of names.
            id_ranges: Per-feature exclusive upper bound for the random ids,
                paired positionally with the flattened names.
        """
        super().__init__()
        self.index = 0
        self.names = [name for group in feat_names for name in group]
        self.id_ranges = id_ranges

        batches = []
        for _ in range(batch_num):
            batches.append(self.generate_one_batch(batch_size))
        self.data = batches

    def generate_one_batch(self, batch_size) -> Batch:
        """Create one ``Batch`` of random sparse ids and random 0/1 labels.

        NOTE: the global torch RNG is re-seeded with 1 on every call, so each
        call with the same arguments yields the same values, and the caller's
        global RNG state is reset as a side effect.
        """
        torch.manual_seed(1)

        per_feature = {}
        # zip() stops at the shorter of the two sequences, so surplus names or
        # ranges are silently ignored.
        for feat_name, upper_bound in zip(self.names, self.id_ranges):
            # Exactly one id per sample for every feature.
            one_per_sample = torch.ones(batch_size, dtype=torch.long)
            random_ids = torch.randint(0, upper_bound, (batch_size,))
            per_feature[feat_name] = JaggedTensor(
                values=random_ids,
                lengths=one_per_sample,
            )

        sparse = KeyedJaggedTensor.from_jt_dict(per_feature)
        # Labels are drawn after all feature ids, keeping the RNG draw order.
        targets = torch.randint(0, 2, (batch_size,))
        return Batch(sparse, targets)

    def __iter__(self) -> Iterator[Batch]:
        """Iterate over the pre-generated batches in creation order."""
        return iter(self.data)

    def __len__(self) -> int:
        """Return the number of pre-generated batches."""
        return len(self.data)