{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TransferQueue Tutorial — Basic Usage\n",
    "\n",
    "This notebook walks through the core **Key-Value (KV) interface** of\n",
    "[TransferQueue](https://github.com/Ascend/TransferQueue), an asynchronous\n",
    "streaming data management module for efficient post-training workflows.\n",
    "\n",
    "**What you will learn:**\n",
    "\n",
    "1. Initialise TransferQueue (with Ray)\n",
    "2. Store a single sample — `kv_put`\n",
    "3. Store a batch of samples — `kv_batch_put`\n",
    "4. Retrieve data — `kv_batch_get`\n",
    "5. List stored keys & tags — `kv_list`\n",
    "6. Partial-key and partial-field retrieval\n",
    "7. Updating fields incrementally\n",
    "8. Working with nested (variable-length) tensors\n",
    "9. Storing variable-size image data\n",
    "10. Storing non-tensor data (`NonTensorStack`)\n",
    "11. Lazy data parsing with `data_parser`\n",
    "12. Multiple partitions\n",
    "13. Clean up — `kv_clear` / `close`\n",
    "\n",
    "> **Prerequisites:** `pip install TransferQueue` (or install from source).  \n",
    "> Ray will be started automatically in this notebook."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Initialization\n\nTransferQueue runs on top of [Ray](https://www.ray.io/).  \nWe start Ray, then call `tq.init()` with a minimal configuration that uses the\nbuilt-in **SimpleStorage** backend."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import ray\n",
    "import torch\n",
    "from omegaconf import OmegaConf\n",
    "from tensordict import TensorDict\n",
    "from tensordict.tensorclass import NonTensorStack\n",
    "\n",
    "import transfer_queue as tq\n",
    "\n",
    "ray.init(ignore_reinit_error=True)\n",
    "\n",
    "config = OmegaConf.create(\n",
    "    {\n",
    "        \"controller\": {\"polling_mode\": True},\n",
    "        \"backend\": {\n",
    "            \"storage_backend\": \"SimpleStorage\",\n",
    "            \"SimpleStorage\": {\n",
    "                \"total_storage_size\": 200,\n",
    "                \"num_data_storage_units\": 2,\n",
    "            },\n",
    "        },\n",
    "    }\n",
    ")\n",
    "\n",
    "tq.init(config)\n",
    "print(\"TransferQueue is ready!\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TransferQueue is ready!\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Store a Single Sample — `kv_put`\n\n`kv_put` stores **one** key-value pair.  \n- `key` — a unique string identifier for the sample  \n- `partition_id` — a logical namespace (like a table name)  \n- `fields` — a `dict` of tensors **or** a `TensorDict`  \n- `tag` — optional metadata dict attached to the key"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "tq.kv_put(\n",
    "    key=\"sample_0\",\n",
    "    partition_id=\"train\",\n",
    "    fields={\"input_ids\": torch.tensor([1, 2, 3, 4])},\n",
    "    tag={\"source\": \"wikipedia\", \"score\": 0.95},\n",
    ")\n",
    "print(\"Stored sample_0\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stored sample_0\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also pass a pre-built `TensorDict` directly (the batch dimension\nmust be 1 for `kv_put`):"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "fields_td = TensorDict(\n",
    "    {\n",
    "        \"input_ids\": torch.tensor([[5, 6, 7, 8]]),\n",
    "        \"attention_mask\": torch.ones(1, 4, dtype=torch.long),\n",
    "    },\n",
    "    batch_size=1,\n",
    ")\n",
    "\n",
    "tq.kv_put(\n",
    "    key=\"sample_1\",\n",
    "    partition_id=\"train\",\n",
    "    fields=fields_td,\n",
    "    tag={\"source\": \"books\", \"score\": 0.88},\n",
    ")\n",
    "print(\"Stored sample_1\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stored sample_1\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Store a Batch of Samples — `kv_batch_put`\n\nWhen you have multiple samples, `kv_batch_put` is more efficient than\ncalling `kv_put` in a loop.  The `fields` TensorDict must have\n`batch_size == len(keys)`, and `tags` (if given) is a list with one dict per key."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "keys = [\"batch_0\", \"batch_1\", \"batch_2\"]\n",
    "\n",
    "fields = TensorDict(\n",
    "    {\n",
    "        \"input_ids\": torch.tensor([[10, 20], [30, 40], [50, 60]]),\n",
    "        \"attention_mask\": torch.ones(3, 2, dtype=torch.long),\n",
    "    },\n",
    "    batch_size=3,\n",
    ")\n",
    "\n",
    "tags = [\n",
    "    {\"split\": \"train\", \"idx\": 0},\n",
    "    {\"split\": \"train\", \"idx\": 1},\n",
    "    {\"split\": \"train\", \"idx\": 2},\n",
    "]\n",
    "\n",
    "tq.kv_batch_put(keys=keys, partition_id=\"train\", fields=fields, tags=tags)\n",
    "print(f\"Stored {len(keys)} samples in one call\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stored 3 samples in one call\n"
     ]
    }
   ],
   "execution_count": 4
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
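    "## 4. Retrieve Data — `kv_batch_get`\n\nRetrieve samples by key(s); `keys` accepts a single string or a list of strings.  The result is always a `TensorDict`, and tensor fields come back packed as jagged nested tensors even when every sample has the same length (see the `NestedTensor` entries in the outputs below)."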
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Retrieve a single key (pass a string)\n",
    "result = tq.kv_batch_get(keys=\"sample_0\", partition_id=\"train\")\n",
    "print(\"sample_0 →\", result)\n",
    "print(\"input_ids:\", result[\"input_ids\"])"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sample_0 → TensorDict(\n",
      "    fields={\n",
      "        input_ids: NestedTensor(shape=torch.Size([1, j1]), device=cpu, dtype=torch.int64, is_shared=False)},\n",
      "    batch_size=torch.Size([1]),\n",
      "    device=None,\n",
      "    is_shared=False)\n",
      "input_ids: NestedTensor(size=(1, j1), offsets=tensor([0, 4]), contiguous=True)\n"
     ]
    }
   ],
   "execution_count": 5
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Retrieve multiple keys at once\n",
    "result = tq.kv_batch_get(keys=keys, partition_id=\"train\")\n",
    "print(\"batch result →\", result)\n",
    "print(\"input_ids:\\n\", result[\"input_ids\"])\n",
    "print(\"attention_mask:\\n\", result[\"attention_mask\"])"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "batch result → TensorDict(\n",
      "    fields={\n",
      "        attention_mask: NestedTensor(shape=torch.Size([3, j2]), device=cpu, dtype=torch.int64, is_shared=False),\n",
      "        input_ids: NestedTensor(shape=torch.Size([3, j3]), device=cpu, dtype=torch.int64, is_shared=False)},\n",
      "    batch_size=torch.Size([3]),\n",
      "    device=None,\n",
      "    is_shared=False)\n",
      "input_ids:\n",
      " NestedTensor(size=(3, j3), offsets=tensor([0, 2, 4, 6]), contiguous=True)\n",
      "attention_mask:\n",
      " NestedTensor(size=(3, j2), offsets=tensor([0, 2, 4, 6]), contiguous=True)\n"
     ]
    }
   ],
   "execution_count": 6
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. List Keys & Tags — `kv_list`\n\n`kv_list` returns a nested dict:\n```\n{ partition_id: { key: tag_dict, ... }, ... }\n```\nPass `partition_id` to filter, or omit it to see everything."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "info = tq.kv_list(partition_id=\"train\")\n",
    "\n",
    "for partition, key_tags in info.items():\n",
    "    print(f\"\\nPartition: {partition}\")\n",
    "    for key, tag in key_tags.items():\n",
    "        print(f\"  {key}: {tag}\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Partition: train\n",
      "  sample_0: {'source': 'wikipedia', 'score': 0.95}\n",
      "  sample_1: {'source': 'books', 'score': 0.88}\n",
      "  batch_0: {'split': 'train', 'idx': 0}\n",
      "  batch_1: {'split': 'train', 'idx': 1}\n",
      "  batch_2: {'split': 'train', 'idx': 2}\n"
     ]
    }
   ],
   "execution_count": 7
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Partial-Key and Partial-Field Retrieval\n\nYou don't have to retrieve *all* keys or *all* fields at once."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6a. Partial Keys\n\nJust pass a subset of the keys you stored."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "partial = tq.kv_batch_get(keys=[\"batch_0\", \"batch_2\"], partition_id=\"train\")\n",
    "print(\"Partial-key input_ids:\\n\", partial[\"input_ids\"])\n",
    "assert partial[\"input_ids\"].shape[0] == 2  # only 2 rows"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Partial-key input_ids:\n",
      " NestedTensor(size=(2, j5), offsets=tensor([0, 2, 4]), contiguous=True)\n"
     ]
    }
   ],
   "execution_count": 8
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6b. Partial Fields\n\nUse the `select_fields` argument to select specific columns.  It accepts a single field name or a list of names."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Retrieve only input_ids (single field)\n",
    "result = tq.kv_batch_get(keys=\"sample_1\", partition_id=\"train\", select_fields=\"input_ids\")\n",
    "print(\"Fields returned:\", list(result.keys()))\n",
    "assert \"input_ids\" in result.keys()\n",
    "assert \"attention_mask\" not in result.keys()\n",
    "\n",
    "# Retrieve a specific set of fields\n",
    "result = tq.kv_batch_get(\n",
    "    keys=\"sample_1\",\n",
    "    partition_id=\"train\",\n",
    "    select_fields=[\"input_ids\", \"attention_mask\"],\n",
    ")\n",
    "print(\"Fields returned:\", list(result.keys()))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fields returned: ['input_ids']\n",
      "Fields returned: ['attention_mask', 'input_ids']\n"
     ]
    }
   ],
   "execution_count": 9
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Updating Fields Incrementally\n\nTransferQueue tracks each field (column) independently per key (row).  \nYou can **add new fields** to existing keys with another `kv_put` /\n`kv_batch_put` call — the earlier fields are preserved."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Add a \"response\" field to the batch keys\n",
    "response_fields = TensorDict(\n",
    "    {\"response\": torch.tensor([[100, 200], [300, 400], [500, 600]])},\n",
    "    batch_size=3,\n",
    ")\n",
    "\n",
    "tq.kv_batch_put(keys=keys, partition_id=\"train\", fields=response_fields)\n",
    "\n",
    "result = tq.kv_batch_get(keys=keys, partition_id=\"train\")\n",
    "print(\"All fields now:\", list(result.keys()))\n",
    "print(\"response:\\n\", result[\"response\"])"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All fields now: ['attention_mask', 'input_ids', 'response']\n",
      "response:\n",
      " NestedTensor(size=(3, j11), offsets=tensor([0, 2, 4, 6]), contiguous=True)\n"
     ]
    }
   ],
   "execution_count": 10
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Working with Nested (Variable-Length) Tensors\n\nIn many NLP and RL workloads each sample has a **different sequence length**\n(e.g. generated responses).  PyTorch represents these as\n[nested tensors](https://pytorch.org/docs/stable/nested.html) with the\n**jagged layout** (`layout=torch.jagged`), and TransferQueue handles them\nnatively.\n\n> **Note:** Because individual samples have different shapes, you must use\n> `kv_batch_put` (not `kv_put`) to store nested tensors."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "nested_keys = [\"nested_0\", \"nested_1\", \"nested_2\"]\n",
    "\n",
    "# Each sample has a different sequence length\n",
    "nested_responses = torch.nested.as_nested_tensor(\n",
    "    [\n",
    "        torch.tensor([10, 11, 12]),  # length 3\n",
    "        torch.tensor([20]),  # length 1\n",
    "        torch.tensor([30, 31]),  # length 2\n",
    "    ],\n",
    "    layout=torch.jagged,\n",
    ")\n",
    "\n",
    "nested_fields = TensorDict(\n",
    "    {\n",
    "        \"input_ids\": torch.tensor([[1, 2], [3, 4], [5, 6]]),\n",
    "        \"response\": nested_responses,\n",
    "    },\n",
    "    batch_size=3,\n",
    ")\n",
    "\n",
    "tq.kv_batch_put(\n",
    "    keys=nested_keys,\n",
    "    partition_id=\"train\",\n",
    "    fields=nested_fields,\n",
    "    tags=[{\"len\": 3}, {\"len\": 1}, {\"len\": 2}],\n",
    ")\n",
    "print(\"Stored 3 samples with variable-length responses\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stored 3 samples with variable-length responses\n"
     ]
    }
   ],
   "execution_count": 11
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Retrieve all nested samples\n",
    "result = tq.kv_batch_get(keys=nested_keys, partition_id=\"train\")\n",
    "print(\"input_ids:\\n\", result[\"input_ids\"])\n",
    "print(\"\\nresponse (nested tensor):\")\n",
    "for i, sample in enumerate(result[\"response\"]):\n",
    "    print(f\"  sample {i}: {sample}  (length {sample.shape[0]})\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input_ids:\n",
      " NestedTensor(size=(3, j13), offsets=tensor([0, 2, 4, 6]), contiguous=True)\n",
      "\n",
      "response (nested tensor):\n",
      "  sample 0: tensor([10, 11, 12])  (length 3)\n",
      "  sample 1: tensor([20])  (length 1)\n",
      "  sample 2: tensor([30, 31])  (length 2)\n"
     ]
    }
   ],
   "execution_count": 12
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Partial-key retrieval works the same way with nested tensors — only the\nrequested samples are returned:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Retrieve only the first and last sample\n",
    "partial = tq.kv_batch_get(keys=[\"nested_0\", \"nested_2\"], partition_id=\"train\")\n",
    "print(\"Partial-key input_ids:\\n\", partial[\"input_ids\"])\n",
    "print(\"\\nPartial-key responses:\")\n",
    "for i, sample in enumerate(partial[\"response\"]):\n",
    "    print(f\"  sample {i}: {sample}\")\n",
    "\n",
    "# Verify correctness\n",
    "assert torch.equal(partial[\"response\"][0], torch.tensor([10, 11, 12]))\n",
    "assert torch.equal(partial[\"response\"][1], torch.tensor([30, 31]))\n",
    "print(\"\\nAssertions passed!\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Partial-key input_ids:\n",
      " NestedTensor(size=(2, j15), offsets=tensor([0, 2, 4]), contiguous=True)\n",
      "\n",
      "Partial-key responses:\n",
      "  sample 0: tensor([10, 11, 12])\n",
      "  sample 1: tensor([30, 31])\n",
      "\n",
      "Assertions passed!\n"
     ]
    }
   ],
   "execution_count": 13
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Higher-dimensional nested tensors work too. Here each sample is a 3D\ntensor with a variable first dimension (e.g. a different number of\nattention heads or generated candidates):"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "nested_3d_keys = [\"nd3d_0\", \"nd3d_1\", \"nd3d_2\"]\n",
    "\n",
    "nested_3d = torch.nested.as_nested_tensor(\n",
    "    [\n",
    "        torch.randn(2, 3, 4),  # 2 heads\n",
    "        torch.randn(5, 3, 4),  # 5 heads\n",
    "        torch.randn(1, 3, 4),  # 1 head\n",
    "    ],\n",
    "    layout=torch.jagged,\n",
    ")\n",
    "\n",
    "fields_3d = TensorDict(\n",
    "    {\n",
    "        \"input_ids\": torch.tensor([[1, 2], [3, 4], [5, 6]]),\n",
    "        \"hidden_states\": nested_3d,\n",
    "    },\n",
    "    batch_size=3,\n",
    ")\n",
    "\n",
    "tq.kv_batch_put(keys=nested_3d_keys, partition_id=\"train\", fields=fields_3d)\n",
    "\n",
    "result_3d = tq.kv_batch_get(keys=nested_3d_keys, partition_id=\"train\")\n",
    "for i, sample in enumerate(result_3d[\"hidden_states\"]):\n",
    "    print(f\"sample {i} hidden_states shape: {sample.shape}\")\n",
    "\n",
    "# Clean up nested-tensor keys\n",
    "tq.kv_clear(keys=nested_keys, partition_id=\"train\")\n",
    "tq.kv_clear(keys=nested_3d_keys, partition_id=\"train\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sample 0 hidden_states shape: torch.Size([2, 3, 4])\n",
      "sample 1 hidden_states shape: torch.Size([5, 3, 4])\n",
      "sample 2 hidden_states shape: torch.Size([1, 3, 4])\n"
     ]
    }
   ],
   "execution_count": 14
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9. Storing Variable-Size Image Data\n\nA common multimodal scenario: each sample in a batch contains a\n**different number of images**, and each image has a **different\nresolution**.  We can model this with two jagged nested tensors wrapped\ninside a `TensorDict`.\n\nSince the data is doubly ragged (variable count *and* variable size),\nwe flatten each image to 1-D and concatenate all images of a sample into\none flat tensor, so `pixel_data` holds one variable-length element per\nsample.  A second jagged field, `image_shapes`, records the `(C, H, W)`\nshape of every image so the individual images can be reconstructed after\nretrieval."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "image_keys = [\"img_sample_0\", \"img_sample_1\", \"img_sample_2\"]\n",
    "\n",
    "# Sample 0: 2 images — 3\u00d732\u00d732 (RGB, 32\u00d732) and 3\u00d764\u00d764\n",
    "sample_0_images = [torch.randn(3, 32, 32), torch.randn(3, 64, 64)]\n",
    "\n",
    "# Sample 1: 1 image — 3\u00d748\u00d748\n",
    "sample_1_images = [torch.randn(3, 48, 48)]\n",
    "\n",
    "# Sample 2: 3 images — 3\u00d716\u00d716, 3\u00d724\u00d724, 3\u00d732\u00d764\n",
    "sample_2_images = [torch.randn(3, 16, 16), torch.randn(3, 24, 24), torch.randn(3, 32, 64)]\n",
    "\n",
    "\n",
    "# Flatten each image to 1-D and concatenate them, giving one flat tensor per sample\n",
    "def flatten_images(images):\n",
    "    return torch.cat([img.flatten() for img in images])\n",
    "\n",
    "\n",
    "pixel_data = torch.nested.as_nested_tensor(\n",
    "    [\n",
    "        flatten_images(sample_0_images),  # 3*32*32 + 3*64*64 = 15360\n",
    "        flatten_images(sample_1_images),  # 3*48*48            = 6912\n",
    "        flatten_images(sample_2_images),  # 3*16*16 + 3*24*24 + 3*32*64 = 8640\n",
    "    ],\n",
    "    layout=torch.jagged,\n",
    ")\n",
    "\n",
    "# Store each image's (C, H, W) shape so we can reconstruct the images later\n",
    "image_shapes = torch.nested.as_nested_tensor(\n",
    "    [\n",
    "        torch.tensor([[3, 32, 32], [3, 64, 64]]),  # 2 images\n",
    "        torch.tensor([[3, 48, 48]]),  # 1 image\n",
    "        torch.tensor([[3, 16, 16], [3, 24, 24], [3, 32, 64]]),  # 3 images\n",
    "    ],\n",
    "    layout=torch.jagged,\n",
    ")\n",
    "\n",
    "fields_img = TensorDict(\n",
    "    {\n",
    "        \"prompt\": torch.tensor([[101, 102], [201, 202], [301, 302]]),\n",
    "        \"pixel_data\": pixel_data,\n",
    "        \"image_shapes\": image_shapes,\n",
    "    },\n",
    "    batch_size=3,\n",
    ")\n",
    "\n",
    "tags_img = [\n",
    "    {\"num_images\": 2},\n",
    "    {\"num_images\": 1},\n",
    "    {\"num_images\": 3},\n",
    "]\n",
    "\n",
    "tq.kv_batch_put(keys=image_keys, partition_id=\"train\", fields=fields_img, tags=tags_img)\n",
    "print(\"Stored 3 samples with variable numbers of variable-size images\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stored 3 samples with variable numbers of variable-size images\n"
     ]
    }
   ],
   "execution_count": 15
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Retrieve and reconstruct the images\n",
    "result_img = tq.kv_batch_get(keys=image_keys, partition_id=\"train\")\n",
    "\n",
    "for i in range(result_img.batch_size[0]):\n",
    "    shapes = result_img[\"image_shapes\"][i]  # (num_images, 3) tensor\n",
    "    pixels = result_img[\"pixel_data\"][i]  # flat 1-D tensor of all pixels\n",
    "    num_images = shapes.shape[0]\n",
    "\n",
    "    offset = 0\n",
    "    reconstructed = []\n",
    "    for j in range(num_images):\n",
    "        c, h, w = shapes[j].tolist()\n",
    "        numel = c * h * w\n",
    "        img = pixels[offset : offset + numel].reshape(c, h, w)\n",
    "        reconstructed.append(img)\n",
    "        offset += numel\n",
    "\n",
    "    print(f\"Sample {i}: {num_images} image(s) → {[tuple(img.shape) for img in reconstructed]}\")\n",
    "\n",
    "# Clean up\n",
    "tq.kv_clear(keys=image_keys, partition_id=\"train\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sample 0: 2 image(s) → [(3, 32, 32), (3, 64, 64)]\n",
      "Sample 1: 1 image(s) → [(3, 48, 48)]\n",
      "Sample 2: 3 image(s) → [(3, 16, 16), (3, 24, 24), (3, 32, 64)]\n"
     ]
    }
   ],
   "execution_count": 16
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10. Storing Non-Tensor Data — `NonTensorStack`\n",
    "\n",
    "Not every field is a numeric tensor.  Prompts, file paths, JSON metadata,\n",
    "or arbitrary Python objects can be stored as **non-tensor data** using\n",
    "tensordict's `NonTensorStack`.\n",
    "\n",
    "- `NonTensorStack` wraps a **batch** of Python objects — one per sample.\n",
    "\n",
    "TransferQueue serialises them transparently alongside regular tensors."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "nt_keys = [\"nt_sample_0\", \"nt_sample_1\", \"nt_sample_2\"]\n",
    "\n",
    "# Build a TensorDict that mixes tensors with non-tensor data\n",
    "nt_fields = TensorDict(\n",
    "    {\n",
    "        \"input_ids\": torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),\n",
    "        # NonTensorStack: one string per sample in the batch\n",
    "        \"prompt_text\": NonTensorStack(\n",
    "            \"Summarise the following article.\",\n",
    "            \"Translate to French:\",\n",
    "            \"Write a poem about rain.\",\n",
    "        ),\n",
    "        # You can also store richer Python objects (dicts, lists, \u2026)\n",
    "        \"metadata\": NonTensorStack(\n",
    "            {\"source\": \"wiki\", \"lang\": \"en\"},\n",
    "            {\"source\": \"books\", \"lang\": \"fr\"},\n",
    "            {\"source\": \"user\", \"lang\": \"en\"},\n",
    "        ),\n",
    "    },\n",
    "    batch_size=3,\n",
    ")\n",
    "\n",
    "tq.kv_batch_put(keys=nt_keys, partition_id=\"train\", fields=nt_fields)\n",
    "print(\"Stored 3 samples with non-tensor fields\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stored 3 samples with non-tensor fields\n"
     ]
    }
   ],
   "execution_count": 17
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Retrieve and inspect non-tensor fields\n",
    "result_nt = tq.kv_batch_get(keys=nt_keys, partition_id=\"train\")\n",
    "\n",
    "print(\"input_ids:\\n\", result_nt[\"input_ids\"])\n",
    "\n",
    "print(\"\\nprompt_text (NonTensorStack):\")\n",
    "for i, text in enumerate(list(result_nt[\"prompt_text\"])):\n",
    "    print(f\"  [{i}] {text!r}\")\n",
    "\n",
    "print(\"\\nmetadata (NonTensorStack of dicts):\")\n",
    "for i, meta in enumerate(list(result_nt[\"metadata\"])):\n",
    "    print(f\"  [{i}] {meta}\")\n",
    "\n",
    "# You can also add a NonTensorStack field to a single key via kv_put\n",
    "tq.kv_put(\n",
    "    key=\"single_nt\",\n",
    "    partition_id=\"train\",\n",
    "    fields={\"label\": NonTensorStack(\"positive\"), \"score\": torch.tensor([0.99])},\n",
    ")\n",
    "single = tq.kv_batch_get(keys=\"single_nt\", partition_id=\"train\")\n",
    "print(f\"\\nSingle-key non-tensor field: label={list(single['label'])}\")\n",
    "\n",
    "# Clean up\n",
    "tq.kv_clear(keys=nt_keys, partition_id=\"train\")\n",
    "tq.kv_clear(keys=\"single_nt\", partition_id=\"train\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input_ids:\n",
      " NestedTensor(size=(3, j25), offsets=tensor([0, 3, 6, 9]), contiguous=True)\n",
      "\n",
      "prompt_text (NonTensorStack):\n",
      "  [0] 'Summarise the following article.'\n",
      "  [1] 'Translate to French:'\n",
      "  [2] 'Write a poem about rain.'\n",
      "\n",
      "metadata (NonTensorStack of dicts):\n",
      "  [0] {'source': 'wiki', 'lang': 'en'}\n",
      "  [1] {'source': 'books', 'lang': 'fr'}\n",
      "  [2] {'source': 'user', 'lang': 'en'}\n",
      "\n",
      "Single-key non-tensor field: label=[['positive']]\n"
     ]
    }
   ],
   "execution_count": 18
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 11. Lazy Data Parsing with `data_parser`\n",
    "\n",
    "Sometimes you want to store lightweight **references** (e.g. URLs or file paths)\n",
    "and defer the expensive loading / decoding until the data reaches the storage unit.\n",
    "\n",
    "`kv_put` / `kv_batch_put` accept an optional `data_parser` callable.  It is executed **inside**\n",
    "each `SimpleStorageUnit` at put time.  The callable receives a plain `dict` (not a TensorDict)\n",
    "mapping `field_name -> batched_values`.  For a regular tensor column the value is a batched tensor;\n",
    "for nested tensors (jagged or strided) and `NonTensorStack` columns the values are extracted into\n",
    "a `list`.  The parser must replace values in place under the original keys: do not add or remove\n",
    "keys, do not change the number of elements in any column, and do not reorder the values within a\n",
    "column.\n",
    "\n",
    "> **Design tip:** Separate the **core single-sample parser** from the **batch concurrency wrapper**.\n",
    "> The wrapper can use `asyncio` to process all samples concurrently while still appearing\n",
    "> synchronous to its caller.\n",
    "\n",
    "> **Note:** `data_parser` is only supported by the **SimpleStorage** backend."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import asyncio\n",
    "import time\n",
    "\n",
    "\n",
    "# 1. Define core single-sample parser (pure business logic, no asyncio, no batch)\n",
    "def parse_url(url: str) -> torch.Tensor:\n",
    "    \"\"\"Parse a URL-like descriptor 'dtype:HxW' into a random tensor.\"\"\"\n",
    "    dtype_str, shape_str = url.split(\":\")\n",
    "    dtype = getattr(torch, dtype_str)\n",
    "    shape = [int(dim) for dim in shape_str.split(\"x\")]\n",
    "    return torch.randn(shape, dtype=dtype)\n",
    "\n",
    "\n",
    "# 2. Define the batch-level parser: synchronous on the outside, concurrent on the inside\n",
    "def concurrent_batch_url_parser(field_data: dict) -> dict:\n",
    "    \"\"\"Batch-level data_parser executed inside SimpleStorageUnit.\n",
    "\n",
    "    It receives a ``dict`` (not a TensorDict) where each value is a\n",
    "    batched column.  For columns created from ``NonTensorStack`` the\n",
    "    value is a plain ``list`` of Python objects.\n",
    "\n",
    "    Workflow:\n",
    "    1. Spawns one async task per list element.\n",
    "    2. Waits until *all* tasks finish (``asyncio.gather``).\n",
    "    3. Replaces the list with the list of results.\n",
    "\n",
    "    Because ``asyncio.run`` blocks until the loop finishes, this function\n",
    "    is **synchronous** to its caller: when it returns, every sample has\n",
    "    been processed.\n",
    "\n",
    "    Args:\n",
    "        field_data: Mapping ``field_name -> batched_values``.  The dict\n",
    "            keys must stay exactly the same; only values may be\n",
    "            transformed in-place.\n",
    "\n",
    "    Returns:\n",
    "        The same dict with parsed values substituted.\n",
    "    \"\"\"\n",
    "    if \"data_to_be_parsed\" not in field_data:\n",
    "        return field_data\n",
    "\n",
    "    urls: list[str] = field_data[\"data_to_be_parsed\"]\n",
    "\n",
    "    async def _async_parse_single(url: str) -> torch.Tensor:\n",
    "        await asyncio.sleep(1.0)  # Fixed 1 s delay per sample, standing in for slow loading\n",
    "        return parse_url(url)\n",
    "\n",
    "    async def _process_all():\n",
    "        tasks = [asyncio.create_task(_async_parse_single(url)) for url in urls]\n",
    "        return await asyncio.gather(*tasks)\n",
    "\n",
    "    start = time.perf_counter()\n",
    "    field_data[\"data_to_be_parsed\"] = asyncio.run(_process_all())\n",
    "    elapsed = time.perf_counter() - start\n",
    "\n",
    "    print(f\"[data_parser] Processed {len(urls)} samples in {elapsed:.2f}s (serial would be ~{len(urls)}.0s)\")\n",
    "    return field_data\n",
    "\n",
    "\n",
    "# ---------------------------------------------------------------------------\n",
    "# Build the batch\n",
    "# ---------------------------------------------------------------------------\n",
    "batch_size = 32\n",
    "\n",
    "normal_data = torch.randn(batch_size, 2)\n",
    "\n",
    "# URL-like strings: all use the same dtype so TQ can pack them on get\n",
    "shapes = [(i % 4 + 1, i % 3 + 2) for i in range(batch_size)]\n",
    "urls = [f\"float32:{h}x{w}\" for h, w in shapes]\n",
    "\n",
    "parser_fields = TensorDict(\n",
    "    {\n",
    "        \"normal_data\": normal_data,\n",
    "        \"data_to_be_parsed\": NonTensorStack(*urls),\n",
    "    },\n",
    "    batch_size=batch_size,\n",
    ")\n",
    "\n",
    "data_parser_keys = [f\"data_parser_sample_{i}\" for i in range(batch_size)]\n",
    "\n",
    "put_start_time = time.perf_counter()\n",
    "meta = tq.kv_batch_put(\n",
    "    keys=data_parser_keys,\n",
    "    partition_id=\"train\",\n",
    "    fields=parser_fields,\n",
    "    data_parser=concurrent_batch_url_parser,\n",
    ")\n",
    "put_elapsed = time.perf_counter() - put_start_time\n",
    "print(f\"Put succeeded. Fields: {meta.fields}\")\n",
    "print(f\"Total kv_batch_put time: {put_elapsed:.2f}s (concurrency keeps it ~1s, not {batch_size}s)\\n\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Put succeeded. Fields: ['data_to_be_parsed', 'normal_data']\n",
      "Total kv_batch_put time: 1.02s (concurrency keeps it ~1s, not 32s)\n",
      "\n"
     ]
    }
   ],
   "execution_count": 19
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "When `kv_batch_put` returns, the user-defined data parser has also finished executing. We can then safely call `kv_batch_get` to retrieve the parsed data."
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "result = tq.kv_batch_get(keys=data_parser_keys, partition_id=\"train\")\n\n# normal_data should be unchanged\n# normal_data is packed as a nested tensor on retrieval; compare per-sample.\nfor t1, t2 in zip(result[\"normal_data\"], normal_data, strict=True):\n    torch.testing.assert_close(t1, t2)\nprint(\"[PASS] normal_data is unchanged.\")\n\n# data_to_be_parsed should now be tensors with the requested shapes\nexpected_shapes = [(i % 4 + 1, i % 3 + 2) for i in range(batch_size)]\nfor i, expected in enumerate(expected_shapes):\n    tensor = result[\"data_to_be_parsed\"][i]\n    assert tensor.dtype == torch.float32\n    actual = tuple(tensor.shape)\n    assert actual == expected, f\"Mismatch at {i}: {actual} != {expected}\"\nprint(f\"[PASS] All {batch_size} parsed tensors have correct dtype & shape.\")\n\n# Timing sanity check: serial would be ~batch_size seconds.\n# Because asyncio tasks run in parallel inside the parser, it should be ~1 s.\nassert put_elapsed < 2.0, f\"Expected concurrent execution (~1s), but took {put_elapsed:.2f}s.\"\nprint(f\"[PASS] Timing looks concurrent: {put_elapsed:.2f}s < 2.0s\")\n\n# wait for Ray log collect\ntime.sleep(2)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[PASS] normal_data is unchanged.\n",
      "[PASS] All 32 parsed tensors have correct dtype & shape.\n",
      "[PASS] Timing looks concurrent: 1.02s < 2.0s\n"
     ]
     }
   ],
   "execution_count": 20
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 12. Multiple Partitions\n\nPartitions provide logical isolation — the same key name can exist in\ndifferent partitions without conflict."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "tq.kv_put(\n",
    "    key=\"val_sample_0\",\n",
    "    partition_id=\"validation\",\n",
    "    fields={\"input_ids\": torch.tensor([99, 98, 97])},\n",
    "    tag={\"split\": \"val\"},\n",
    ")\n",
    "\n",
    "all_info = tq.kv_list()  # no partition_id → list everything\n",
    "print(\"All partitions:\", list(all_info.keys()))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All partitions: ['train', 'validation']\n"
     ]
    }
   ],
   "execution_count": 21
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 13. Clean Up — `kv_clear` and `close`\n\nRemove specific keys with `kv_clear`, then shut down the system with `tq.close()`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Clear individual keys\n",
    "tq.kv_clear(keys=\"sample_0\", partition_id=\"train\")\n",
    "tq.kv_clear(keys=\"sample_1\", partition_id=\"train\")\n",
    "tq.kv_clear(keys=keys, partition_id=\"train\")\n",
    "tq.kv_clear(keys=data_parser_keys, partition_id=\"train\")\n",
    "tq.kv_clear(keys=\"val_sample_0\", partition_id=\"validation\")\n",
    "print(\"All keys cleared.\")\n",
    "\n",
    "remaining = tq.kv_list()\n",
    "total_keys = sum(len(v) for v in remaining.values())\n",
    "print(f\"Remaining keys across all partitions: {total_keys}\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All keys cleared.\n",
      "Remaining keys across all partitions: 0\n"
     ]
    }
   ],
   "execution_count": 22
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "tq.close()\n",
    "ray.shutdown()\n",
    "print(\"TransferQueue and Ray shut down.\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TransferQueue and Ray shut down.\n"
     ]
    }
   ],
   "execution_count": 23
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n\n## Summary\n\n| Operation | Function | Notes |\n|---|---|---|\n| Init | `tq.init(config)` | Call once; subsequent processes auto-connect |\n| Put single | `tq.kv_put(key, partition_id, fields, tag)` | `fields` can be a plain dict |\n| Put batch | `tq.kv_batch_put(keys, partition_id, fields, tags)` | `fields` must be a `TensorDict` |\n| Put with parser | `tq.kv_batch_put(..., data_parser=fn)` | Only for **SimpleStorage**; receives dict, can use asyncio inside |\n| Get | `tq.kv_batch_get(keys, partition_id, select_fields=None)` | Returns a `TensorDict` |\n| List | `tq.kv_list(partition_id=None)` | Returns `{partition: {key: tag}}` |\n| Clear | `tq.kv_clear(keys, partition_id)` | Removes keys + data |\n| Close | `tq.close()` | Tears down controller & storage |\n\nFor **async** variants, use `async_kv_put`, `async_kv_batch_put`,\n`async_kv_batch_get`, `async_kv_list`, and `async_kv_clear`.\n\nFor low-level, metadata-based access, see `tq.get_client()` and the\n[official tutorials](https://github.com/Ascend/TransferQueue/tree/main/tutorial)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}