from pathlib import Path
from typing import Dict, Iterable
import numpy as np
# Conversion factors between bytes and megabytes (1 MB is taken as 2**20 bytes).
MB2B = 1 << 20
B2MB = 1.0 / MB2B
# Size in bytes that NpzWriter uses as its flush threshold.
FLUSH_THRESHOLD_B = MB2B * 256
def pad_except_batch_axis(data: np.ndarray, target_shape_with_batch_axis: Iterable[int]):
    """
    Pad every axis of ``data`` except the first (batch) axis up to a target shape.

    Padding is appended at the end of each axis and filled with ``np.nan``.

    Args:
        data: array to pad.
        target_shape_with_batch_axis: desired size of every axis, including the
            batch axis at position 0. The batch-axis entry takes part in the
            size check but is never used for padding. Any iterable of ints is
            accepted (tuple, list, numpy array, generator, ...).

    Returns:
        A new array whose axis 0 has the same length as in ``data`` and whose
        remaining axes have the requested target sizes.

    Raises:
        AssertionError: if any axis of ``data`` is larger than the matching
            target size.
    """
    # Materialize once: the shape is iterated for the check and sliced for the
    # padding below, which a one-shot iterable (e.g. a generator) cannot survive.
    target_shape = tuple(target_shape_with_batch_axis)
    assert all(
        current_size <= target_size for target_size, current_size in zip(target_shape, data.shape)
    ), "target_shape should have equal or greater all dimensions comparing to data.shape"

    padding = [(0, 0)] + [  # (0, 0) keeps the batch axis untouched
        (0, target_size - current_size)
        for target_size, current_size in zip(target_shape[1:], data.shape[1:])
    ]

    if not any(pad_after for _, pad_after in padding):
        # Nothing to add. Skip np.pad so the NaN fill value is never applied to
        # arrays whose dtype cannot represent it (e.g. integer dtypes).
        return np.array(data)

    return np.pad(data, padding, "constant", constant_values=np.nan)
class NpzWriter:
    """
    Dumps dicts of numpy arrays into npz files

    It can/shall be used as context manager:
    ```
    with NpzWriter('mydir') as writer:
        writer.write(outputs={'classes': np.zeros(8), 'probs': np.zeros((8, 4))},
                     labels={'classes': np.zeros(8)},
                     inputs={'input': np.zeros((8, 240, 240, 3))})
    ```

    Data is accumulated in memory per keyword (prefix) and written to
    ``<prefix>-<index>.npz`` files when the cache grows past the flush
    threshold, on an explicit ``flush()`` call, or when the context exits.

    ## Variable size data

    Arrays written under the same name are concatenated along axis 0. When
    their remaining axes differ in size, the arrays are padded with np.nan
    value (see ``pad_except_batch_axis``) before concatenation.
    Also each generated file may have different size of dynamic axis.
    """

    def __init__(self, output_dir, compress=False):
        """
        Args:
            output_dir: directory in which the npz files are created.
            compress: when true, files are written with ``np.savez_compressed``
                instead of ``np.savez``.
        """
        self._output_dir = Path(output_dir)
        # prefix -> {array name -> all rows cached so far}
        self._items_cache: Dict[str, Dict[str, np.ndarray]] = {}
        # prefix -> number of items already dumped; used as the file index
        self._items_counters: Dict[str, int] = {}
        self._flush_threshold_b = FLUSH_THRESHOLD_B
        self._compress = compress

    @property
    def cache_size(self) -> Dict[str, int]:
        """Number of bytes currently cached for each prefix."""
        return {name: sum(a.nbytes for a in data.values()) for name, data in self._items_cache.items()}

    def _append_to_cache(self, prefix, data) -> None:
        """
        Append the arrays of ``data`` to the cache kept for ``prefix``.

        Args:
            prefix: name of the group (later the npz file name prefix).
            data: dict mapping array names to lists or np.ndarrays; ``None`` is
                silently ignored.

        Raises:
            ValueError: if ``data`` is neither ``None`` nor a dict.
            AssertionError: if a value is not a list/np.ndarray or contains NaN.
        """
        if data is None:
            return

        if not isinstance(data, dict):
            raise ValueError(f"{prefix} data to store shall be dict")

        cached_data = self._items_cache.get(prefix, {})
        for name, value in data.items():
            assert isinstance(
                value, (list, np.ndarray)
            ), f"Values shall be lists or np.ndarrays; current type {type(value)}"
            if not isinstance(value, np.ndarray):
                value = np.array(value)

            # NaN is reserved as the padding marker, so user data must not contain it.
            # Byte/unicode string arrays are exempt from the check.
            assert value.dtype.kind in ["S", "U"] or not np.any(
                np.isnan(value)
            ), f"Values with np.nan is not supported; {name}={value}"

            cached_value = cached_data.get(name, None)
            if cached_value is not None:
                # Bring both arrays to a common shape on the non-batch axes,
                # then stack the new rows below the cached ones.
                target_shape = np.max([cached_value.shape, value.shape], axis=0)
                cached_value = pad_except_batch_axis(cached_value, target_shape)
                value = pad_except_batch_axis(value, target_shape)
                value = np.concatenate((cached_value, value))
            cached_data[name] = value
        self._items_cache[prefix] = cached_data

    def write(self, **kwargs) -> None:
        """
        Writes named list of dictionaries of np.ndarrays.
        Finally keyword names will be later prefixes of npz files where those dictionaries will be stored.

        ex. writer.write(inputs={'input': np.zeros((2, 10))},
                         outputs={'classes': np.zeros((2,)), 'probabilities': np.zeros((2, 32))},
                         labels={'classes': np.zeros((2,))})

        The data is cached first; everything cached is flushed to disk once the
        biggest per-prefix cache exceeds the flush threshold.

        Args:
            **kwargs: named list of dictionaries of np.ndarrays to store
        """
        for prefix, data in kwargs.items():
            self._append_to_cache(prefix, data)

        # default=0 keeps a call that cached nothing (no kwargs, or only None
        # payloads on an empty cache) from failing on an empty sequence.
        biggest_item_size = max(self.cache_size.values(), default=0)
        if biggest_item_size > self._flush_threshold_b:
            self.flush()

    def flush(self) -> None:
        """Dump every cached prefix to its own npz file and empty the cache."""
        for prefix, data in self._items_cache.items():
            self._dump(prefix, data)
        self._items_cache = {}

    def _dump(self, prefix, data) -> None:
        """
        Save ``data`` as ``<prefix>-<index>.npz`` in the output directory.

        ``<index>`` is the zero-padded count of items dumped so far for this
        prefix. An empty ``data`` dict is skipped without creating a file.

        Raises:
            AssertionError: if the arrays do not all have the same length on
                axis 0. The check runs before anything is written to disk.
        """
        if not data:
            # Nothing was cached under this prefix, so there is nothing to store.
            return

        nitems = len(next(iter(data.values())))

        msg_for_labels = (
            "If these are correct shapes - consider moving loading of them into metrics.py."
            if prefix == "labels"
            else ""
        )
        shapes = {name: value.shape if isinstance(value, np.ndarray) else (len(value),) for name, value in data.items()}

        # Validate first so that an inconsistent batch never ends up on disk.
        assert all(len(v) == nitems for v in data.values()), (
            f'All items in "{prefix}" shall have same size on 0 axis equal to batch size. {msg_for_labels}'
            f'{", ".join(f"{name}: {shape}" for name, shape in shapes.items())}'
        )

        idx = self._items_counters.setdefault(prefix, 0)
        filename = f"{prefix}-{idx:012d}.npz"
        output_path = self._output_dir / filename
        if self._compress:
            np.savez_compressed(output_path, **data)
        else:
            np.savez(output_path, **data)

        self._items_counters[prefix] += nitems

    def __enter__(self):
        """
        Create the output directory if needed and return the writer.

        Raises:
            ValueError: if the output directory already exists and is not empty.
        """
        if self._output_dir.exists() and any(self._output_dir.iterdir()):
            raise ValueError(f"{self._output_dir.as_posix()} is not empty")
        self._output_dir.mkdir(parents=True, exist_ok=True)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Flush whatever is still cached when the context is left."""
        self.flush()