from typing import Dict, Iterable
import numpy as np
import tensorflow as tf
def gen_mock_dataset(batch_num: int = 100, batch_size: int = 4096) -> tf.compat.v1.data.Dataset:
    """Build a generator-backed dataset of random mock batches.

    Each element is a dict of already-batched arrays:

    * ``"mock_ids"``: int64, shape ``(batch_size, 8)``
    * ``"mock_labels"``: int32, shape ``(batch_size, 1)``
    * ``"mock_timestamp"``: int32, shape ``(batch_size, 1)``

    Every value is drawn uniformly from ``[0, 100)`` using NumPy's global
    random state, so seeding ``np.random`` beforehand makes the data
    reproducible.

    Args:
        batch_num: Number of batches the dataset yields. A value <= 0 yields
            an empty dataset.
        batch_size: Number of rows in every batch.

    Returns:
        A ``tf.compat.v1.data.Dataset`` whose elements match the dtypes and
        static shapes listed above.
    """

    def data_generator() -> Iterable[Dict[str, np.ndarray]]:
        for _ in range(batch_num):
            # Draw explicitly as int64 rather than NumPy's platform-dependent
            # default integer, then narrow the int32 fields here so each array
            # already matches the dtype declared in `output_types` below.
            mock_ids = np.random.randint(low=0, high=100, size=(batch_size, 8), dtype=np.int64)
            mock_labels = np.random.randint(
                low=0, high=100, size=(batch_size, 1), dtype=np.int64
            ).astype(np.int32)
            mock_timestamp = np.random.randint(
                low=0, high=100, size=(batch_size, 1), dtype=np.int64
            ).astype(np.int32)
            yield {"mock_ids": mock_ids, "mock_labels": mock_labels, "mock_timestamp": mock_timestamp}

    dataset = tf.compat.v1.data.Dataset.from_generator(
        generator=data_generator,
        output_types={"mock_ids": tf.int64, "mock_labels": tf.int32, "mock_timestamp": tf.int32},
        output_shapes={
            "mock_ids": tf.TensorShape([batch_size, 8]),
            "mock_labels": tf.TensorShape([batch_size, 1]),
            "mock_timestamp": tf.TensorShape([batch_size, 1]),
        },
    )
    return dataset