{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TransferQueue Tutorial — Basic Usage\n",
"\n",
"This notebook walks through the core **Key-Value (KV) interface** of\n",
"[TransferQueue](https://github.com/Ascend/TransferQueue), an asynchronous\n",
"streaming data management module for efficient post-training workflows.\n",
"\n",
"**What you will learn:**\n",
"\n",
"1. Initialise TransferQueue (with Ray)\n",
"2. Store a single sample — `kv_put`\n",
"3. Store a batch of samples — `kv_batch_put`\n",
"4. Retrieve data — `kv_batch_get`\n",
"5. List stored keys & tags — `kv_list`\n",
"6. Partial-key and partial-field retrieval\n",
"7. Updating fields incrementally\n",
"8. Working with nested (variable-length) tensors\n",
"9. Storing variable-size image data\n",
"10. Storing non-tensor data (`NonTensorStack`)\n",
"11. Lazy Data Parsing with `data_parser`\n",
"12. Multiple partitions\n",
"13. Clean up — `kv_clear` / `close`\n",
"\n",
"> **Prerequisites:** `pip install TransferQueue` (or install from source). \n",
"> Ray will be started automatically in this notebook."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Initialization\n\nTransferQueue runs on top of [Ray](https://www.ray.io/). \nWe start Ray, then call `tq.init()` with a minimal configuration that uses the\nbuilt-in **SimpleStorage** backend."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import ray\n",
"import torch\n",
"from omegaconf import OmegaConf\n",
"from tensordict import TensorDict\n",
"from tensordict.tensorclass import NonTensorStack\n",
"\n",
"import transfer_queue as tq\n",
"\n",
"ray.init(ignore_reinit_error=True)\n",
"\n",
"config = OmegaConf.create(\n",
" {\n",
" \"controller\": {\"polling_mode\": True},\n",
" \"backend\": {\n",
" \"storage_backend\": \"SimpleStorage\",\n",
" \"SimpleStorage\": {\n",
" \"total_storage_size\": 200,\n",
" \"num_data_storage_units\": 2,\n",
" },\n",
" },\n",
" }\n",
")\n",
"\n",
"tq.init(config)\n",
"print(\"TransferQueue is ready!\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TransferQueue is ready!\n"
]
}
],
"execution_count": 1
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Store a Single Sample — `kv_put`\n\n`kv_put` stores **one** key-value pair. \n- `key` — a string identifier for the sample, unique within its partition \n- `partition_id` — a logical namespace (like a table name) \n- `fields` — a `dict` of tensors **or** a `TensorDict` \n- `tag` — optional metadata dict attached to the key"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"tq.kv_put(\n",
" key=\"sample_0\",\n",
" partition_id=\"train\",\n",
" fields={\"input_ids\": torch.tensor([1, 2, 3, 4])},\n",
" tag={\"source\": \"wikipedia\", \"score\": 0.95},\n",
")\n",
"print(\"Stored sample_0\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Stored sample_0\n"
]
}
],
"execution_count": 2
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also pass a pre-built `TensorDict` directly (the batch dimension\nmust be 1 for `kv_put`):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"fields_td = TensorDict(\n",
" {\n",
" \"input_ids\": torch.tensor([[5, 6, 7, 8]]),\n",
" \"attention_mask\": torch.ones(1, 4, dtype=torch.long),\n",
" },\n",
" batch_size=1,\n",
")\n",
"\n",
"tq.kv_put(\n",
" key=\"sample_1\",\n",
" partition_id=\"train\",\n",
" fields=fields_td,\n",
" tag={\"source\": \"books\", \"score\": 0.88},\n",
")\n",
"print(\"Stored sample_1\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Stored sample_1\n"
]
}
],
"execution_count": 3
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Store a Batch of Samples — `kv_batch_put`\n\nWhen you have multiple samples, `kv_batch_put` stores them in one call, which is more\nefficient than calling `kv_put` in a loop. The `fields` TensorDict must have\n`batch_size == len(keys)`, and `tags`, if given, is a list with one dict per key, in the\nsame order as `keys`."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"keys = [\"batch_0\", \"batch_1\", \"batch_2\"]\n",
"\n",
"fields = TensorDict(\n",
" {\n",
" \"input_ids\": torch.tensor([[10, 20], [30, 40], [50, 60]]),\n",
" \"attention_mask\": torch.ones(3, 2, dtype=torch.long),\n",
" },\n",
" batch_size=3,\n",
")\n",
"\n",
"tags = [\n",
" {\"split\": \"train\", \"idx\": 0},\n",
" {\"split\": \"train\", \"idx\": 1},\n",
" {\"split\": \"train\", \"idx\": 2},\n",
"]\n",
"\n",
"tq.kv_batch_put(keys=keys, partition_id=\"train\", fields=fields, tags=tags)\n",
"print(f\"Stored {len(keys)} samples in one call\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Stored 3 samples in one call\n"
]
}
],
"execution_count": 4
},
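{
"cell_type": "markdown",
"metadata": {},
"source": [
"The shape rules above are easy to check before calling `kv_batch_put`. The helper below is a\n",
"hypothetical, stdlib-only sketch and is not part of TransferQueue. It enforces the documented\n",
"`batch_size == len(keys)` rule. It also expects one tag dict per key when `tags` is given, which\n",
"is an assumption drawn from the example above.\n",
"\n",
"```python\n",
"# Hypothetical pre-flight check; not a TransferQueue API.\n",
"def validate_batch_put(keys, batch_size, tags=None):\n",
"    # Documented rule: one row of `fields` per key.\n",
"    if batch_size != len(keys):\n",
"        raise ValueError(f'batch_size {batch_size} != len(keys) {len(keys)}')\n",
"    # Assumed rule: one tag dict per key, in the same order as `keys`.\n",
"    if tags is not None and len(tags) != len(keys):\n",
"        raise ValueError(f'got {len(tags)} tags for {len(keys)} keys')\n",
"\n",
"\n",
"validate_batch_put(['batch_0', 'batch_1', 'batch_2'], batch_size=3)  # passes silently\n",
"try:\n",
"    validate_batch_put(['a', 'b'], batch_size=3)\n",
"except ValueError as err:\n",
"    print('rejected:', err)  # rejected: batch_size 3 != len(keys) 2\n",
"```\n",
"\n",
"In real code, `batch_size` would come from `fields.batch_size[0]`."
]
},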
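{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Retrieve Data — `kv_batch_get`\n\nRetrieve samples by key: pass a single key as a string or several keys as a list.\nThe result is always a `TensorDict`. As the outputs below show, tensor fields come back\npacked as jagged `NestedTensor`s even when every sample has the same length; index along\nthe batch dimension (e.g. `result['input_ids'][0]`) to get one sample's tensor."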
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Retrieve a single key (pass a string)\n",
"result = tq.kv_batch_get(keys=\"sample_0\", partition_id=\"train\")\n",
"print(\"sample_0 →\", result)\n",
"print(\"input_ids:\", result[\"input_ids\"])"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sample_0 → TensorDict(\n",
" fields={\n",
" input_ids: NestedTensor(shape=torch.Size([1, j1]), device=cpu, dtype=torch.int64, is_shared=False)},\n",
" batch_size=torch.Size([1]),\n",
" device=None,\n",
" is_shared=False)\n",
"input_ids: NestedTensor(size=(1, j1), offsets=tensor([0, 4]), contiguous=True)\n"
]
}
],
"execution_count": 5
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Retrieve multiple keys at once\n",
"result = tq.kv_batch_get(keys=keys, partition_id=\"train\")\n",
"print(\"batch result →\", result)\n",
"print(\"input_ids:\\n\", result[\"input_ids\"])\n",
"print(\"attention_mask:\\n\", result[\"attention_mask\"])"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"batch result → TensorDict(\n",
" fields={\n",
" attention_mask: NestedTensor(shape=torch.Size([3, j2]), device=cpu, dtype=torch.int64, is_shared=False),\n",
" input_ids: NestedTensor(shape=torch.Size([3, j3]), device=cpu, dtype=torch.int64, is_shared=False)},\n",
" batch_size=torch.Size([3]),\n",
" device=None,\n",
" is_shared=False)\n",
"input_ids:\n",
" NestedTensor(size=(3, j3), offsets=tensor([0, 2, 4, 6]), contiguous=True)\n",
"attention_mask:\n",
" NestedTensor(size=(3, j2), offsets=tensor([0, 2, 4, 6]), contiguous=True)\n"
]
}
],
"execution_count": 6
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. List Keys & Tags — `kv_list`\n\n`kv_list` returns a nested dict:\n```\n{ partition_id: { key: tag_dict, ... }, ... }\n```\nPass `partition_id` to filter, or omit it to see everything."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"info = tq.kv_list(partition_id=\"train\")\n",
"\n",
"for partition, key_tags in info.items():\n",
" print(f\"\\nPartition: {partition}\")\n",
" for key, tag in key_tags.items():\n",
" print(f\" {key}: {tag}\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Partition: train\n",
" sample_0: {'source': 'wikipedia', 'score': 0.95}\n",
" sample_1: {'source': 'books', 'score': 0.88}\n",
" batch_0: {'split': 'train', 'idx': 0}\n",
" batch_1: {'split': 'train', 'idx': 1}\n",
" batch_2: {'split': 'train', 'idx': 2}\n"
]
}
],
"execution_count": 7
},
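{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because `kv_list` returns plain Python dicts, tags can drive key selection without touching the\n",
"stored tensors. The sketch below is illustrative only. `keys_where` and `count_keys` are\n",
"hypothetical helpers, and `info` is copied from the output above so the snippet runs on its own.\n",
"\n",
"```python\n",
"# Plain-Python helpers over the {partition_id: {key: tag}} dict returned by kv_list.\n",
"# `info` is copied from the printed output above so the snippet runs on its own.\n",
"info = {\n",
"    'train': {\n",
"        'sample_0': {'source': 'wikipedia', 'score': 0.95},\n",
"        'sample_1': {'source': 'books', 'score': 0.88},\n",
"        'batch_0': {'split': 'train', 'idx': 0},\n",
"        'batch_1': {'split': 'train', 'idx': 1},\n",
"        'batch_2': {'split': 'train', 'idx': 2},\n",
"    }\n",
"}\n",
"\n",
"\n",
"def keys_where(info, partition_id, predicate):\n",
"    # Keys in one partition whose tag satisfies predicate(tag); a missing tag counts as {}.\n",
"    return [key for key, tag in info.get(partition_id, {}).items() if predicate(tag or {})]\n",
"\n",
"\n",
"def count_keys(info):\n",
"    # Total number of keys across all partitions.\n",
"    return sum(len(key_tags) for key_tags in info.values())\n",
"\n",
"\n",
"high_quality = keys_where(info, 'train', lambda tag: tag.get('score', 0) >= 0.9)\n",
"print(high_quality)  # ['sample_0']\n",
"print(count_keys(info))  # 5\n",
"```\n",
"\n",
"A list such as `high_quality` can then be passed as the `keys` argument of `tq.kv_batch_get`\n",
"together with `partition_id='train'`."
]
},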
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Partial-Key and Partial-Field Retrieval\n\nYou don't have to retrieve *all* keys or *all* fields at once."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 6a. Partial Keys\n\nJust pass a subset of the keys you stored."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"partial = tq.kv_batch_get(keys=[\"batch_0\", \"batch_2\"], partition_id=\"train\")\n",
"print(\"Partial-key input_ids:\\n\", partial[\"input_ids\"])\n",
"assert partial[\"input_ids\"].shape[0] == 2 # only 2 rows"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Partial-key input_ids:\n",
" NestedTensor(size=(2, j5), offsets=tensor([0, 2, 4]), contiguous=True)\n"
]
}
],
"execution_count": 8
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 6b. Partial Fields\n\nUse the `select_fields` argument to choose specific columns: pass a single field name as a\nstring, or several names as a list."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Retrieve only input_ids (single field)\n",
"result = tq.kv_batch_get(keys=\"sample_1\", partition_id=\"train\", select_fields=\"input_ids\")\n",
"print(\"Fields returned:\", list(result.keys()))\n",
"assert \"input_ids\" in result.keys()\n",
"assert \"attention_mask\" not in result.keys()\n",
"\n",
"# Retrieve a specific set of fields\n",
"result = tq.kv_batch_get(\n",
" keys=\"sample_1\",\n",
" partition_id=\"train\",\n",
" select_fields=[\"input_ids\", \"attention_mask\"],\n",
")\n",
"print(\"Fields returned:\", list(result.keys()))"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fields returned: ['input_ids']\n",
"Fields returned: ['attention_mask', 'input_ids']\n"
]
}
],
"execution_count": 9
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Updating Fields Incrementally\n\nTransferQueue tracks each field (column) independently per key (row). \nYou can **add new fields** to existing keys with another `kv_put` /\n`kv_batch_put` call — the earlier fields are preserved."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Add a \"response\" field to the batch keys\n",
"response_fields = TensorDict(\n",
" {\"response\": torch.tensor([[100, 200], [300, 400], [500, 600]])},\n",
" batch_size=3,\n",
")\n",
"\n",
"tq.kv_batch_put(keys=keys, partition_id=\"train\", fields=response_fields)\n",
"\n",
"result = tq.kv_batch_get(keys=keys, partition_id=\"train\")\n",
"print(\"All fields now:\", list(result.keys()))\n",
"print(\"response:\\n\", result[\"response\"])"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"All fields now: ['attention_mask', 'input_ids', 'response']\n",
"response:\n",
" NestedTensor(size=(3, j11), offsets=tensor([0, 2, 4, 6]), contiguous=True)\n"
]
}
],
"execution_count": 10
},
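{
"cell_type": "markdown",
"metadata": {},
"source": [
"The merge behaviour can be pictured with a toy, in-memory model. This is a hypothetical stdlib\n",
"sketch, not TransferQueue's storage layout. Each `(partition_id, key)` pair owns a dict of\n",
"fields, and a later put merges new fields into it. The toy model overwrites a field that is\n",
"written again. This section does not show whether TransferQueue behaves the same way in that\n",
"case. Keying by `(partition_id, key)` also mirrors the partition isolation shown in section 12.\n",
"\n",
"```python\n",
"# Toy model for illustration only; TransferQueue's real storage is distributed.\n",
"store = {}  # (partition_id, key) -> {field_name: value}\n",
"\n",
"\n",
"def toy_put(partition_id, key, fields):\n",
"    # Merge: new fields are added, fields named again are overwritten (toy behaviour).\n",
"    store.setdefault((partition_id, key), {}).update(fields)\n",
"\n",
"\n",
"def toy_get(partition_id, key, select_fields=None):\n",
"    row = store[(partition_id, key)]\n",
"    if select_fields is None:\n",
"        names = sorted(row)\n",
"    elif isinstance(select_fields, str):\n",
"        names = [select_fields]\n",
"    else:\n",
"        names = list(select_fields)\n",
"    return {name: row[name] for name in names}\n",
"\n",
"\n",
"toy_put('train', 'batch_0', {'input_ids': [10, 20], 'attention_mask': [1, 1]})\n",
"toy_put('train', 'batch_0', {'response': [100, 200]})  # second put adds a column\n",
"print(toy_get('train', 'batch_0'))\n",
"# {'attention_mask': [1, 1], 'input_ids': [10, 20], 'response': [100, 200]}\n",
"print(toy_get('train', 'batch_0', select_fields='response'))\n",
"# {'response': [100, 200]}\n",
"```"
]
},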
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. Working with Nested (Variable-Length) Tensors\n\nIn many NLP and RL workloads each sample has a **different sequence length**\n(e.g. generated responses). PyTorch represents these as\n[nested tensors](https://pytorch.org/docs/stable/nested.html) with the\n**jagged layout** (`layout=torch.jagged`), and TransferQueue handles them\nnatively.\n\n> **Note:** Because individual samples have different shapes, you must use\n> `kv_batch_put` (not `kv_put`) to store nested tensors."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"nested_keys = [\"nested_0\", \"nested_1\", \"nested_2\"]\n",
"\n",
"# Each sample has a different sequence length\n",
"nested_responses = torch.nested.as_nested_tensor(\n",
" [\n",
" torch.tensor([10, 11, 12]), # length 3\n",
" torch.tensor([20]), # length 1\n",
" torch.tensor([30, 31]), # length 2\n",
" ],\n",
" layout=torch.jagged,\n",
")\n",
"\n",
"nested_fields = TensorDict(\n",
" {\n",
" \"input_ids\": torch.tensor([[1, 2], [3, 4], [5, 6]]),\n",
" \"response\": nested_responses,\n",
" },\n",
" batch_size=3,\n",
")\n",
"\n",
"tq.kv_batch_put(\n",
" keys=nested_keys,\n",
" partition_id=\"train\",\n",
" fields=nested_fields,\n",
" tags=[{\"len\": 3}, {\"len\": 1}, {\"len\": 2}],\n",
")\n",
"print(\"Stored 3 samples with variable-length responses\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Stored 3 samples with variable-length responses\n"
]
}
],
"execution_count": 11
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Retrieve all nested samples\n",
"result = tq.kv_batch_get(keys=nested_keys, partition_id=\"train\")\n",
"print(\"input_ids:\\n\", result[\"input_ids\"])\n",
"print(\"\\nresponse (nested tensor):\")\n",
"for i, sample in enumerate(result[\"response\"]):\n",
" print(f\" sample {i}: {sample} (length {sample.shape[0]})\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input_ids:\n",
" NestedTensor(size=(3, j13), offsets=tensor([0, 2, 4, 6]), contiguous=True)\n",
"\n",
"response (nested tensor):\n",
" sample 0: tensor([10, 11, 12]) (length 3)\n",
" sample 1: tensor([20]) (length 1)\n",
" sample 2: tensor([30, 31]) (length 2)\n"
]
}
],
"execution_count": 12
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Partial-key retrieval works the same way with nested tensors — only the\nrequested samples are returned:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Retrieve only the first and last sample\n",
"partial = tq.kv_batch_get(keys=[\"nested_0\", \"nested_2\"], partition_id=\"train\")\n",
"print(\"Partial-key input_ids:\\n\", partial[\"input_ids\"])\n",
"print(\"\\nPartial-key responses:\")\n",
"for i, sample in enumerate(partial[\"response\"]):\n",
" print(f\" sample {i}: {sample}\")\n",
"\n",
"# Verify correctness\n",
"assert torch.equal(partial[\"response\"][0], torch.tensor([10, 11, 12]))\n",
"assert torch.equal(partial[\"response\"][1], torch.tensor([30, 31]))\n",
"print(\"\\nAssertions passed!\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Partial-key input_ids:\n",
" NestedTensor(size=(2, j15), offsets=tensor([0, 2, 4]), contiguous=True)\n",
"\n",
"Partial-key responses:\n",
" sample 0: tensor([10, 11, 12])\n",
" sample 1: tensor([30, 31])\n",
"\n",
"Assertions passed!\n"
]
}
],
"execution_count": 13
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Higher-dimensional nested tensors work too. Here each sample is a 3D\ntensor with a variable first dimension (e.g. a different number of\nattention heads or generated candidates):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"nested_3d_keys = [\"nd3d_0\", \"nd3d_1\", \"nd3d_2\"]\n",
"\n",
"nested_3d = torch.nested.as_nested_tensor(\n",
" [\n",
" torch.randn(2, 3, 4), # 2 heads\n",
" torch.randn(5, 3, 4), # 5 heads\n",
" torch.randn(1, 3, 4), # 1 head\n",
" ],\n",
" layout=torch.jagged,\n",
")\n",
"\n",
"fields_3d = TensorDict(\n",
" {\n",
" \"input_ids\": torch.tensor([[1, 2], [3, 4], [5, 6]]),\n",
" \"hidden_states\": nested_3d,\n",
" },\n",
" batch_size=3,\n",
")\n",
"\n",
"tq.kv_batch_put(keys=nested_3d_keys, partition_id=\"train\", fields=fields_3d)\n",
"\n",
"result_3d = tq.kv_batch_get(keys=nested_3d_keys, partition_id=\"train\")\n",
"for i, sample in enumerate(result_3d[\"hidden_states\"]):\n",
" print(f\"sample {i} hidden_states shape: {sample.shape}\")\n",
"\n",
"# Clean up nested-tensor keys\n",
"tq.kv_clear(keys=nested_keys, partition_id=\"train\")\n",
"tq.kv_clear(keys=nested_3d_keys, partition_id=\"train\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sample 0 hidden_states shape: torch.Size([2, 3, 4])\n",
"sample 1 hidden_states shape: torch.Size([5, 3, 4])\n",
"sample 2 hidden_states shape: torch.Size([1, 3, 4])\n"
]
}
],
"execution_count": 14
},
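{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `offsets=tensor([0, 2, 4, 6])` printed by earlier cells shows how the jagged layout works.\n",
"The values of all samples sit in one flat buffer, and each consecutive pair of offsets delimits\n",
"one sample. The stdlib-only sketch below reproduces that bookkeeping with plain lists. `to_jagged`\n",
"and `row` are hypothetical helpers that illustrate the idea, not PyTorch's implementation. On a real\n",
"jagged nested tensor, the two buffers are exposed through its `.values()` and `.offsets()` methods.\n",
"\n",
"```python\n",
"from itertools import accumulate\n",
"\n",
"\n",
"def to_jagged(samples):\n",
"    # Pack variable-length lists into one flat `values` list plus cumulative `offsets`.\n",
"    values = [v for sample in samples for v in sample]\n",
"    offsets = [0, *accumulate(len(sample) for sample in samples)]\n",
"    return values, offsets\n",
"\n",
"\n",
"def row(values, offsets, i):\n",
"    # Sample i occupies values[offsets[i]:offsets[i + 1]].\n",
"    return values[offsets[i]:offsets[i + 1]]\n",
"\n",
"\n",
"values, offsets = to_jagged([[10, 11, 12], [20], [30, 31]])\n",
"print(offsets)  # [0, 3, 4, 6]\n",
"print(row(values, offsets, 2))  # [30, 31]\n",
"```\n",
"\n",
"With three rows of length 2, the same function gives `[0, 2, 4, 6]`. This matches the offsets\n",
"printed for `input_ids` in section 4."
]
},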
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9. Storing Variable-Size Image Data\n\nA common multimodal scenario: each sample in a batch contains a\n**different number of images**, and each image has a **different\nresolution**.\n\nBecause the data is doubly ragged (variable count *and* variable size),\nwe flatten every image to a 1-D tensor and concatenate all images of a\nsample into one row of a jagged nested tensor (`pixel_data`). A second\njagged field (`image_shapes`) records each image's `(C, H, W)` shape. Every\nsample is therefore still one element of the batch, and the individual\nimages can be rebuilt after retrieval."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"image_keys = [\"img_sample_0\", \"img_sample_1\", \"img_sample_2\"]\n",
"\n",
"# Sample 0: 2 images — 3\u00d732\u00d732 (RGB, 32\u00d732) and 3\u00d764\u00d764\n",
"sample_0_images = [torch.randn(3, 32, 32), torch.randn(3, 64, 64)]\n",
"\n",
"# Sample 1: 1 image — 3\u00d748\u00d748\n",
"sample_1_images = [torch.randn(3, 48, 48)]\n",
"\n",
"# Sample 2: 3 images — 3\u00d716\u00d716, 3\u00d724\u00d724, 3\u00d732\u00d764\n",
"sample_2_images = [torch.randn(3, 16, 16), torch.randn(3, 24, 24), torch.randn(3, 32, 64)]\n",
"\n",
"\n",
"# Flatten each image to 1-D so all images of a sample can be concatenated into one jagged row\n",
"def flatten_images(images):\n",
" return torch.cat([img.flatten() for img in images])\n",
"\n",
"\n",
"pixel_data = torch.nested.as_nested_tensor(\n",
" [\n",
" flatten_images(sample_0_images), # 3*32*32 + 3*64*64 = 15360\n",
" flatten_images(sample_1_images), # 3*48*48 = 6912\n",
" flatten_images(sample_2_images), # 3*16*16 + 3*24*24 + 3*32*64 = 8640\n",
" ],\n",
" layout=torch.jagged,\n",
")\n",
"\n",
"# Store each image's (C, H, W) shape so the flat pixel data can be split back into images\n",
"image_shapes = torch.nested.as_nested_tensor(\n",
" [\n",
" torch.tensor([[3, 32, 32], [3, 64, 64]]), # 2 images\n",
" torch.tensor([[3, 48, 48]]), # 1 image\n",
" torch.tensor([[3, 16, 16], [3, 24, 24], [3, 32, 64]]), # 3 images\n",
" ],\n",
" layout=torch.jagged,\n",
")\n",
"\n",
"fields_img = TensorDict(\n",
" {\n",
" \"prompt\": torch.tensor([[101, 102], [201, 202], [301, 302]]),\n",
" \"pixel_data\": pixel_data,\n",
" \"image_shapes\": image_shapes,\n",
" },\n",
" batch_size=3,\n",
")\n",
"\n",
"tags_img = [\n",
" {\"num_images\": 2},\n",
" {\"num_images\": 1},\n",
" {\"num_images\": 3},\n",
"]\n",
"\n",
"tq.kv_batch_put(keys=image_keys, partition_id=\"train\", fields=fields_img, tags=tags_img)\n",
"print(\"Stored 3 samples with variable numbers of variable-size images\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Stored 3 samples with variable numbers of variable-size images\n"
]
}
],
"execution_count": 15
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Retrieve and reconstruct the images\n",
"result_img = tq.kv_batch_get(keys=image_keys, partition_id=\"train\")\n",
"\n",
"for i in range(result_img.batch_size[0]):\n",
" shapes = result_img[\"image_shapes\"][i] # (num_images, 3) tensor\n",
" pixels = result_img[\"pixel_data\"][i] # flat 1-D tensor of all pixels\n",
" num_images = shapes.shape[0]\n",
"\n",
" offset = 0\n",
" reconstructed = []\n",
" for j in range(num_images):\n",
" c, h, w = shapes[j].tolist()\n",
" numel = c * h * w\n",
" img = pixels[offset : offset + numel].reshape(c, h, w)\n",
" reconstructed.append(img)\n",
" offset += numel\n",
"\n",
" print(f\"Sample {i}: {num_images} image(s) → {[tuple(img.shape) for img in reconstructed]}\")\n",
"\n",
"# Clean up\n",
"tq.kv_clear(keys=image_keys, partition_id=\"train\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sample 0: 2 image(s) → [(3, 32, 32), (3, 64, 64)]\n",
"Sample 1: 1 image(s) → [(3, 48, 48)]\n",
"Sample 2: 3 image(s) → [(3, 16, 16), (3, 24, 24), (3, 32, 64)]\n"
]
}
],
"execution_count": 16
},
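{
"cell_type": "markdown",
"metadata": {},
"source": [
"The reconstruction loop above relies on simple offset arithmetic. Each image occupies\n",
"`C * H * W` consecutive elements, so the stored shapes are enough to locate every slice. The\n",
"stdlib-only sketch below isolates that arithmetic; `image_slices` is a hypothetical helper.\n",
"\n",
"```python\n",
"from math import prod\n",
"\n",
"\n",
"def image_slices(shapes):\n",
"    # Yield (start, end, shape) for images stored back-to-back in one flat buffer.\n",
"    start = 0\n",
"    for shape in shapes:\n",
"        end = start + prod(shape)  # C * H * W elements for this image\n",
"        yield start, end, tuple(shape)\n",
"        start = end\n",
"\n",
"\n",
"sample_2_shapes = [(3, 16, 16), (3, 24, 24), (3, 32, 64)]\n",
"for start, end, shape in image_slices(sample_2_shapes):\n",
"    print(shape, start, end)\n",
"# (3, 16, 16) 0 768\n",
"# (3, 24, 24) 768 2496\n",
"# (3, 32, 64) 2496 8640\n",
"```\n",
"\n",
"The final `end` equals the sample's total element count. This gives a cheap consistency check\n",
"after retrieval: `pixels.numel()` should equal the last `end`."
]
},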
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 10. Storing Non-Tensor Data — `NonTensorStack`\n",
"\n",
"Not every field is a numeric tensor. Prompts, file paths, JSON metadata,\n",
"or arbitrary Python objects can be stored as **non-tensor data** using\n",
"tensordict's `NonTensorStack`.\n",
"\n",
"- `NonTensorStack` wraps a **batch** of Python objects — one per sample.\n",
"\n",
"TransferQueue serialises them transparently alongside regular tensors."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"nt_keys = [\"nt_sample_0\", \"nt_sample_1\", \"nt_sample_2\"]\n",
"\n",
"# Build a TensorDict that mixes tensors with non-tensor data\n",
"nt_fields = TensorDict(\n",
" {\n",
" \"input_ids\": torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),\n",
" # NonTensorStack: one string per sample in the batch\n",
" \"prompt_text\": NonTensorStack(\n",
" \"Summarise the following article.\",\n",
" \"Translate to French:\",\n",
" \"Write a poem about rain.\",\n",
" ),\n",
" # You can also store richer Python objects (dicts, lists, \u2026)\n",
" \"metadata\": NonTensorStack(\n",
" {\"source\": \"wiki\", \"lang\": \"en\"},\n",
" {\"source\": \"books\", \"lang\": \"fr\"},\n",
" {\"source\": \"user\", \"lang\": \"en\"},\n",
" ),\n",
" },\n",
" batch_size=3,\n",
")\n",
"\n",
"tq.kv_batch_put(keys=nt_keys, partition_id=\"train\", fields=nt_fields)\n",
"print(\"Stored 3 samples with non-tensor fields\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Stored 3 samples with non-tensor fields\n"
]
}
],
"execution_count": 17
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Retrieve and inspect non-tensor fields\n",
"result_nt = tq.kv_batch_get(keys=nt_keys, partition_id=\"train\")\n",
"\n",
"print(\"input_ids:\\n\", result_nt[\"input_ids\"])\n",
"\n",
"print(\"\\nprompt_text (NonTensorStack):\")\n",
"for i, text in enumerate(list(result_nt[\"prompt_text\"])):\n",
" print(f\" [{i}] {text!r}\")\n",
"\n",
"print(\"\\nmetadata (NonTensorStack of dicts):\")\n",
"for i, meta in enumerate(list(result_nt[\"metadata\"])):\n",
" print(f\" [{i}] {meta}\")\n",
"\n",
"# You can also add a NonTensorStack field to a single key via kv_put\n",
"tq.kv_put(\n",
" key=\"single_nt\",\n",
" partition_id=\"train\",\n",
" fields={\"label\": NonTensorStack(\"positive\"), \"score\": torch.tensor([0.99])},\n",
")\n",
"single = tq.kv_batch_get(keys=\"single_nt\", partition_id=\"train\")\n",
"print(f\"\\nSingle-key non-tensor field: label={list(single['label'])}\")\n",
"\n",
"# Clean up\n",
"tq.kv_clear(keys=nt_keys, partition_id=\"train\")\n",
"tq.kv_clear(keys=\"single_nt\", partition_id=\"train\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input_ids:\n",
" NestedTensor(size=(3, j25), offsets=tensor([0, 3, 6, 9]), contiguous=True)\n",
"\n",
"prompt_text (NonTensorStack):\n",
" [0] 'Summarise the following article.'\n",
" [1] 'Translate to French:'\n",
" [2] 'Write a poem about rain.'\n",
"\n",
"metadata (NonTensorStack of dicts):\n",
" [0] {'source': 'wiki', 'lang': 'en'}\n",
" [1] {'source': 'books', 'lang': 'fr'}\n",
" [2] {'source': 'user', 'lang': 'en'}\n",
"\n",
"Single-key non-tensor field: label=[['positive']]\n"
]
}
],
"execution_count": 18
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 11. Lazy Data Parsing with `data_parser`\n",
"\n",
"Sometimes you want to store lightweight **references** (e.g. URLs or file paths)\n",
"and defer the expensive loading / decoding until the data reaches the storage unit.\n",
"\n",
"`kv_put` / `kv_batch_put` accept an optional `data_parser` callable. It is executed **inside**\n",
"each `SimpleStorageUnit` at put time. The callable receives a plain `dict` (not a TensorDict)\n",
"mapping `field_name -> batched_values`. For a regular tensor column the value is a batched tensor;\n",
"for nested tensors (jagged or strided) and `NonTensorStack` columns the values are extracted into\n",
"a `list`. The callable must return that dict with the values transformed in place, following\n",
"three rules:\n",
"\n",
"- keep exactly the same keys (do not add or remove any);\n",
"- keep the same number of elements in every column;\n",
"- keep the order of the elements within each column.\n",
"\n",
"> **Design tip:** Separate the **core single-sample parser** from the **batch concurrency wrapper**.\n",
"> The wrapper can use `asyncio` to process all samples concurrently, yet it stays synchronous\n",
"> to its caller: it returns only after every sample has been processed.\n",
"\n",
"> **Note:** `data_parser` is only supported by the **SimpleStorage** backend."
]
},
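{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parser runs remotely inside the storage unit, so a contract violation can be harder to\n",
"debug there than locally. The wrapper below is a hypothetical sketch, not a TransferQueue feature.\n",
"It records the keys and per-column lengths before calling the parser, and it raises `ValueError`\n",
"if either changes. It cannot check the ordering rule generically. It also assumes every column\n",
"supports `len()`, which holds for lists and batched tensors.\n",
"\n",
"```python\n",
"# Hypothetical development-time guard for data_parser callables; not part of TransferQueue.\n",
"def check_parser_contract(parser):\n",
"    def guarded(field_data):\n",
"        # Snapshot the contract-relevant facts before the parser mutates the dict.\n",
"        keys_before = set(field_data)\n",
"        lengths_before = {name: len(column) for name, column in field_data.items()}\n",
"        out = parser(field_data)\n",
"        if set(out) != keys_before:\n",
"            raise ValueError(f'parser changed keys: {sorted(keys_before ^ set(out))}')\n",
"        for name, length in lengths_before.items():\n",
"            if len(out[name]) != length:\n",
"                raise ValueError(f'parser changed the length of {name!r}: {length} -> {len(out[name])}')\n",
"        return out\n",
"\n",
"    return guarded\n",
"\n",
"\n",
"def upper_parser(field_data):\n",
"    field_data['text'] = [s.upper() for s in field_data['text']]\n",
"    return field_data\n",
"\n",
"\n",
"print(check_parser_contract(upper_parser)({'text': ['a', 'b']}))  # {'text': ['A', 'B']}\n",
"```\n",
"\n",
"Wrapping the parser defined below would look like\n",
"`data_parser=check_parser_contract(concurrent_batch_url_parser)`. This assumes the wrapped\n",
"closure can be shipped to the storage unit just like a plain function."
]
},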
{
"cell_type": "code",
"metadata": {},
"source": [
"import asyncio\n",
"import time\n",
"\n",
"\n",
"# 1. Define core single-sample parser (pure business logic, no asyncio, no batch)\n",
"def parse_url(url: str) -> torch.Tensor:\n",
" \"\"\"Parse a URL-like descriptor 'dtype:HxW' into a random tensor.\"\"\"\n",
" dtype_str, shape_str = url.split(\":\")\n",
" dtype = getattr(torch, dtype_str)\n",
" shape = [int(dim) for dim in shape_str.split(\"x\")]\n",
" return torch.randn(shape, dtype=dtype)\n",
"\n",
"\n",
"# 2. Define Batch-level parser: sync on the outside, async-parallel on the inside\n",
"def concurrent_batch_url_parser(field_data: dict) -> dict:\n",
" \"\"\"Batch-level data_parser executed inside SimpleStorageUnit.\n",
"\n",
" It receives a ``dict`` (not a TensorDict) where each value is a\n",
" batched column. For columns created from ``NonTensorStack`` the\n",
" value is a plain ``list`` of Python objects.\n",
"\n",
" Workflow:\n",
" 1. Spawns one async task per list element.\n",
" 2. Waits until *all* tasks finish (``asyncio.gather``).\n",
" 3. Replaces the list with the list of results.\n",
"\n",
" Because ``asyncio.run`` blocks until the loop finishes, this function\n",
" is **synchronous** to its caller: when it returns, every sample has\n",
" been processed.\n",
"\n",
" Args:\n",
" field_data: Mapping ``field_name -> batched_values``. The dict\n",
" keys must stay exactly the same; only values may be\n",
" transformed in-place.\n",
"\n",
" Returns:\n",
" The same dict with parsed values substituted.\n",
" \"\"\"\n",
" if \"data_to_be_parsed\" not in field_data:\n",
" return field_data\n",
"\n",
" urls: list[str] = field_data[\"data_to_be_parsed\"]\n",
"\n",
" async def _async_parse_single(url: str) -> torch.Tensor:\n",
" await asyncio.sleep(1.0) # Add fixed delay per sample\n",
" return parse_url(url)\n",
"\n",
" async def _process_all():\n",
" tasks = [asyncio.create_task(_async_parse_single(url)) for url in urls]\n",
" return await asyncio.gather(*tasks)\n",
"\n",
" start = time.perf_counter()\n",
" field_data[\"data_to_be_parsed\"] = asyncio.run(_process_all())\n",
" elapsed = time.perf_counter() - start\n",
"\n",
" print(f\"[data_parser] Processed {len(urls)} samples in {elapsed:.2f}s (serial would be ~{len(urls)}.0s)\")\n",
" return field_data\n",
"\n",
"\n",
"# ---------------------------------------------------------------------------\n",
"# Build the batch\n",
"# ---------------------------------------------------------------------------\n",
"batch_size = 32\n",
"\n",
"normal_data = torch.randn(batch_size, 2)\n",
"\n",
"# URL-like strings: all use the same dtype so TQ can pack them on get\n",
"shapes = [(i % 4 + 1, i % 3 + 2) for i in range(batch_size)]\n",
"urls = [f\"float32:{h}x{w}\" for h, w in shapes]\n",
"\n",
"parser_fields = TensorDict(\n",
" {\n",
" \"normal_data\": normal_data,\n",
" \"data_to_be_parsed\": NonTensorStack(*urls),\n",
" },\n",
" batch_size=batch_size,\n",
")\n",
"\n",
"data_parser_keys = [f\"data_parser_sample_{i}\" for i in range(batch_size)]\n",
"\n",
"put_start_time = time.perf_counter()\n",
"meta = tq.kv_batch_put(\n",
" keys=data_parser_keys,\n",
" partition_id=\"train\",\n",
" fields=parser_fields,\n",
" data_parser=concurrent_batch_url_parser,\n",
")\n",
"put_elapsed = time.perf_counter() - put_start_time\n",
"print(f\"Put succeeded. Fields: {meta.fields}\")\n",
"print(f\"Total kv_batch_put time: {put_elapsed:.2f}s (concurrency keeps it ~1s, not {batch_size}s)\\n\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Put succeeded. Fields: ['data_to_be_parsed', 'normal_data']\n",
"Total kv_batch_put time: 1.02s (concurrency keeps it ~1s, not 32s)\n",
"\n"
]
}
],
"execution_count": 19
},
{
"cell_type": "markdown",
"metadata": {},
"source": "When `kv_batch_put` returns, the user-defined data parser has also finished executing. We can then safely call `kv_batch_get` to retrieve the parsed data."
},
{
"cell_type": "code",
"metadata": {},
"source": [
"result = tq.kv_batch_get(keys=data_parser_keys, partition_id=\"train\")\n\n# normal_data should be unchanged\n# normal_data is packed as a nested tensor on retrieval; compare per-sample.\nfor t1, t2 in zip(result[\"normal_data\"], normal_data, strict=True):\n torch.testing.assert_close(t1, t2)\nprint(\"[PASS] normal_data is unchanged.\")\n\n# data_to_be_parsed should now be tensors with the requested shapes\nexpected_shapes = [(i % 4 + 1, i % 3 + 2) for i in range(batch_size)]\nfor i, expected in enumerate(expected_shapes):\n tensor = result[\"data_to_be_parsed\"][i]\n assert tensor.dtype == torch.float32\n actual = tuple(tensor.shape)\n assert actual == expected, f\"Mismatch at {i}: {actual} != {expected}\"\nprint(f\"[PASS] All {batch_size} parsed tensors have correct dtype & shape.\")\n\n# Timing sanity check: serial would be ~batch_size seconds.\n# Because asyncio tasks run concurrently inside the parser, it should be ~1 s.\nassert put_elapsed < 2.0, f\"Expected concurrent execution (~1s), but took {put_elapsed:.2f}s.\"\nprint(f\"[PASS] Timing looks concurrent: {put_elapsed:.2f}s < 2.0s\")\n\n# Give Ray a moment to collect and print the parser's log output\ntime.sleep(2)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[PASS] normal_data is unchanged.\n",
"[PASS] All 32 parsed tensors have correct dtype & shape.\n",
"[PASS] Timing looks concurrent: 1.02s < 2.0s\n"
]
}
],
"execution_count": 20
},
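{
"cell_type": "markdown",
"metadata": {},
"source": [
"The tutorial's wrapper starts one task per sample all at once. When each parse opens a\n",
"connection or reads a large file, capping the number of in-flight tasks with `asyncio.Semaphore`\n",
"can be gentler on the remote service. The sketch below is a hypothetical variant;\n",
"`bounded_batch_map` and `max_in_flight` are illustrative names, not TransferQueue settings.\n",
"`asyncio.gather` keeps results in input order, which the `data_parser` contract requires.\n",
"Like the tutorial's parser, it calls `asyncio.run`, so it must run where no event loop is\n",
"already running. The tutorial's parser evidently satisfies this inside the storage unit.\n",
"Calling it directly in a Jupyter cell, however, raises `RuntimeError` because the kernel already\n",
"runs a loop.\n",
"\n",
"```python\n",
"import asyncio\n",
"import time\n",
"\n",
"\n",
"def bounded_batch_map(fn, items, max_in_flight=8, delay=0.0):\n",
"    # Apply the synchronous `fn` to every item, with at most `max_in_flight` tasks active.\n",
"    async def run_one(sem, item):\n",
"        async with sem:\n",
"            await asyncio.sleep(delay)  # stands in for real async I/O\n",
"            return fn(item)\n",
"\n",
"    async def run_all():\n",
"        sem = asyncio.Semaphore(max_in_flight)\n",
"        # gather returns results in the order of `items`, regardless of finish order.\n",
"        return await asyncio.gather(*(run_one(sem, item) for item in items))\n",
"\n",
"    return asyncio.run(run_all())\n",
"\n",
"\n",
"start = time.perf_counter()\n",
"out = bounded_batch_map(str.upper, ['a', 'b', 'c', 'd'], max_in_flight=2, delay=0.1)\n",
"elapsed = time.perf_counter() - start\n",
"print(out)  # ['A', 'B', 'C', 'D']\n",
"print(f'{elapsed:.2f}s')  # two rounds of 0.1 s, so roughly 0.2 s\n",
"```"
]
},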
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 12. Multiple Partitions\n\nPartitions provide logical isolation — the same key name can exist in\ndifferent partitions without conflict."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"tq.kv_put(\n",
" key=\"val_sample_0\",\n",
" partition_id=\"validation\",\n",
" fields={\"input_ids\": torch.tensor([99, 98, 97])},\n",
" tag={\"split\": \"val\"},\n",
")\n",
"\n",
"all_info = tq.kv_list() # no partition_id → list everything\n",
"print(\"All partitions:\", list(all_info.keys()))"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"All partitions: ['train', 'validation']\n"
]
}
],
"execution_count": 21
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 13. Clean Up — `kv_clear` and `close`\n\nRemove specific keys with `kv_clear`, then shut down the system with `tq.close()`."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Clear individual keys\n",
"tq.kv_clear(keys=\"sample_0\", partition_id=\"train\")\n",
"tq.kv_clear(keys=\"sample_1\", partition_id=\"train\")\n",
"tq.kv_clear(keys=keys, partition_id=\"train\")\n",
"tq.kv_clear(keys=data_parser_keys, partition_id=\"train\")\n",
"tq.kv_clear(keys=\"val_sample_0\", partition_id=\"validation\")\n",
"print(\"All keys cleared.\")\n",
"\n",
"remaining = tq.kv_list()\n",
"total_keys = sum(len(v) for v in remaining.values())\n",
"print(f\"Remaining keys across all partitions: {total_keys}\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"All keys cleared.\n",
"Remaining keys across all partitions: 0\n"
]
}
],
"execution_count": 22
},
{
"cell_type": "code",
"metadata": {},
"source": [
"tq.close()\n",
"ray.shutdown()\n",
"print(\"TransferQueue and Ray shut down.\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TransferQueue and Ray shut down.\n"
]
}
],
"execution_count": 23
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n\n## Summary\n\n| Operation | Function | Notes |\n|---|---|---|\n| Init | `tq.init(config)` | Call once; subsequent processes auto-connect |\n| Put single | `tq.kv_put(key, partition_id, fields, tag)` | `fields` can be a plain dict |\n| Put batch | `tq.kv_batch_put(keys, partition_id, fields, tags)` | `fields` must be a `TensorDict` |\n| Put with parser | `tq.kv_batch_put(..., data_parser=fn)` | Only for **SimpleStorage**; receives dict, can use asyncio inside |\n| Get | `tq.kv_batch_get(keys, partition_id, select_fields=None)` | Returns a `TensorDict` |\n| List | `tq.kv_list(partition_id=None)` | Returns `{partition: {key: tag}}` |\n| Clear | `tq.kv_clear(keys, partition_id)` | Removes keys + data |\n| Close | `tq.close()` | Tears down controller & storage |\n\nFor **async** variants, use `async_kv_put`, `async_kv_batch_put`,\n`async_kv_batch_get`, `async_kv_list`, and `async_kv_clear`.\n\nFor low-level, metadata-based access, see `tq.get_client()` and the\n[official tutorials](https://github.com/Ascend/TransferQueue/tree/main/tutorial)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.14"
}
},
"nbformat": 4,
"nbformat_minor": 4
}