import os
import math
import json
from collections.abc import Mapping
from typing import TypeVar, Optional, Iterator, List, Optional, Sequence, Union
from functools import partial
import numpy as np
import torch
import torch.utils.data
import torch.distributed as dist
from torch.utils.data import DataLoader, Sampler, BatchSampler, DistributedSampler
from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler
from ..dataset.utils.dynamic_dataset import UniformBucketingDynamicDataset
from ..dataset.utils.dynamic_sampler import DynamicDistributedSampler
class Collater:
def __init__(self, follow_batch, exclude_keys):
self.follow_batch = follow_batch
self.exclude_keys = exclude_keys
def __call__(self, batch):
from torch_geometric.data import Batch
from torch_geometric.data.data import BaseData
elem = batch[0]
if isinstance(elem, BaseData):
return Batch.from_data_list(batch, self.follow_batch,
self.exclude_keys)
elif isinstance(elem, torch.Tensor):
return default_collate(batch)
elif isinstance(elem, float):
return torch.tensor(batch, dtype=torch.float)
elif isinstance(elem, int):
return torch.tensor(batch)
elif isinstance(elem, str):
return batch
elif isinstance(elem, Mapping):
return {key: self([data[key] for data in batch]) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, '_fields'):
return type(elem)(*(self(s) for s in zip(*batch)))
elif isinstance(elem, Sequence) and not isinstance(elem, str):
return [self(s) for s in zip(*batch)]
raise TypeError(f'DataLoader found invalid type: {type(elem)}')
def collate(self, batch):
return self(batch)
class AgentDynamicDataset(UniformBucketingDynamicDataset):
def get_agents_num(self):
'''
Get agents number from json file.
Used to sort dataset by cluster based on agents number.
'''
if self.split != 'train':
return
agents_num_file_name = f"{self.split}_agents_num.json"
agents_num_file_path = os.path.join(self.processed_dir, agents_num_file_name)
if os.path.exists(agents_num_file_path):
with open(agents_num_file_path, "r") as handle:
self.agents_num = json.load(handle)
return
for idx in tqdm(range(self._num_samples)):
data = self.get(idx)
self.agents_num[idx] = data["agent"]['num_nodes']
with open(agents_num_file_path, 'w') as handle:
json.dump(self.agents_num, handle)
def sorting(self):
self.indice = torch.arange(self._num_samples)
self.agents_indices = torch.tensor(self.agents_num)[self.indice]
self.clusters = self.agents_indices // 10
self.max_cluster = int(self.clusters.max().item()) + 1
self.sorted_ids = self.indice[self.clusters.argsort()]
self.sorted_clusters = self.clusters.sort()[0]
def bucketing(self):
self.buckets = []
start_idx = 0
for clst_idx in range(self.max_cluster):
end_idx = torch.searchsorted(self.sorted_clusters, clst_idx + 1, side='left')
if end_idx > start_idx:
current_cluster = self.sorted_ids[start_idx:end_idx]
self.buckets.append(current_cluster.tolist())
start_idx = end_idx
class DynamicBatchSampler(BatchSampler):
def __init__(self,
dataset,
sampler,
batch_size: int,
drop_last: bool = False) -> None:
super().__init__(sampler, batch_size, drop_last)
self.lengths = dataset.agents_num
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if batch and not self.drop_last:
yield batch
class AgentDynamicBatchSampler(DynamicDistributedSampler):
def __init__(self,
dataset,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False) -> None:
super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
self.drop_last = drop_last
if self.drop_last and len(self.dataset) % self.num_replicas != 0:
self.num_samples = math.ceil((len(self.dataset) - self.num_replicas) / self.num_replicas)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
self.total_size = self.num_samples * self.num_replicas
self.seed = seed
def __iter__(self):
if self.shuffle:
indices = self.bucket_arange()
else:
indices = []
for bct in self.dataset.buckets:
indices.extend(bct)
if not self.drop_last:
padding_size = self.total_size - len(indices)
indices = torch.tensor(indices)
indices = torch.cat([indices, indices[:padding_size]])
else:
indices = torch.tensor(indices)
indices = indices[:self.total_size]
indices = indices[self.rank: self.total_size: self.num_replicas]
if not len(indices) == self.num_samples:
raise ValueError(f"indices length should be equal to num_samples, but got indices length: {len(indices)}, expected: {self.num_samples}")
return iter(indices)
class AgentDynamicBatchDataLoader(DataLoader):
def __init__(self,
dataset,
batch_size: int,
train_batch_size: int,
shuffle: bool = True,
follow_batch: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
**kwargs) -> None:
kwargs.pop('collate_fn', None)
kwargs.pop('batch_sampler', None)
self.follow_batch = follow_batch
self.exclude_keys = exclude_keys
sampler = AgentDynamicBatchSampler(dataset, shuffle=True)
super().__init__(
dataset, collate_fn=Collater(follow_batch, exclude_keys),
batch_sampler=DynamicBatchSampler(dataset, sampler, train_batch_size), **kwargs)