import logging
import operator
import os
from functools import reduce
import torch
from torch import Tensor
# Module-level logger. Its level is taken from the TQ_LOGGING_LEVEL environment
# variable and falls back to logging.WARNING when the variable is not set.
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
def allocate_empty_tensors(
dtypes: list[torch.dtype], shapes: list[tuple]
) -> tuple[list[Tensor], list[int], list[int], list[int]]:
"""Allocate empty tensors, grouping same dtypes into shared memory blocks.
Instead of allocating each tensor separately, this function groups tensors
by their dtype and allocates one large contiguous memory block per dtype.
Each tensor is then created as a view into this shared memory.
Args:
dtypes: List of torch dtypes for each tensor.
shapes: List of shapes (tuples) for each tensor.
Returns:
A tuple containing:
- List of tensors sharing memory within their dtype groups.
- List of memory pointers (data_ptr) for each tensor.
- List of base pointers for each allocated memory region (one per dtype).
- List of total bytes for each allocated memory region (one per dtype).
Example:
>>> dtypes = [torch.float32, torch.float32, torch.int32, torch.float32]
>>> shapes = [(10,), (20,), (5,), (15,)]
>>> tensors, ptrs, region_ptrs, region_sizes = allocate_empty_tensors(dtypes, shapes)
>>> # tensors[0], [1], [3] share the same dtype and memory block
"""
assert len(dtypes) == len(shapes), "dtypes and shapes must have the same length"
if len(dtypes) == 0:
return [], [], [], []
dtype_groups: dict[torch.dtype, list[int]] = {}
for i, dtype in enumerate(dtypes):
if dtype not in dtype_groups:
dtype_groups[dtype] = []
dtype_groups[dtype].append(i)
tensor_list = [torch.empty(()) for _ in range(len(dtypes))]
ptr_list = [0] * len(dtypes)
region_ptrs: list[int] = []
region_sizes: list[int] = []
for dtype, indices in dtype_groups.items():
total_elements = 0
shape_info = []
for idx in indices:
shape = tuple(shapes[idx])
num_elements = reduce(operator.mul, shape, 1)
shape_info.append((idx, shape, num_elements, total_elements))
total_elements += num_elements
big_tensor = torch.empty(total_elements, dtype=dtype)
region_ptrs.append(big_tensor.data_ptr())
region_sizes.append(big_tensor.nbytes)
for idx, shape, num_elements, offset in shape_info:
small_tensor = big_tensor.as_strided(size=shape, stride=compute_stride(shape), storage_offset=offset)
tensor_list[idx] = small_tensor
ptr_list[idx] = small_tensor.data_ptr()
return tensor_list, ptr_list, region_ptrs, region_sizes
def compute_stride(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Return the row-major (C-order) strides, in elements, for ``shape``.

    The innermost dimension gets stride 1; every other dimension's stride is
    the product of all dimension sizes to its right.

    Args:
        shape: The shape of the tensor.

    Returns:
        Stride tuple for contiguous storage, same length as ``shape``.

    Example:
        >>> compute_stride((2, 3, 4))
        (12, 4, 1)
    """
    ndim = len(shape)
    strides = [1] * ndim
    # Fill right-to-left: each stride is its right neighbour's stride times
    # that neighbour's dimension size.
    for axis in range(ndim - 2, -1, -1):
        strides[axis] = strides[axis + 1] * shape[axis + 1]
    return tuple(strides)
def get_nbytes(dtypes: list[torch.dtype], shapes: list[tuple]) -> list[int]:
    """Calculate number of bytes according to tensor dtypes and shapes.

    Args:
        dtypes: List of torch dtypes, one per tensor.
        shapes: List of shapes, one per tensor, aligned with ``dtypes``.

    Returns:
        List with the byte size (element size times number of elements) of
        each described tensor, in input order.

    Raises:
        ValueError: If ``dtypes`` and ``shapes`` do not have the same length.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(dtypes) != len(shapes):
        raise ValueError("dtypes and shapes must have the same length")

    # Element size is looked up once per distinct dtype rather than building
    # a temporary tensor for every entry.
    elem_sizes: dict[torch.dtype, int] = {}
    nbytes = []
    for dtype, shape in zip(dtypes, shapes):
        if dtype not in elem_sizes:
            elem_sizes[dtype] = torch.tensor([], dtype=dtype).element_size()
        numel = reduce(operator.mul, tuple(shape), 1)
        nbytes.append(elem_sizes[dtype] * numel)
    return nbytes
def merge_contiguous_memory(ptrs: list[int], sizes: list[int]) -> tuple[list[int], list[int]]:
    """Merge contiguous memory regions to reduce register_buffer overhead.

    Regions are sorted by start address; a region is folded into the previous
    one when it begins exactly where the previous one ends.

    Args:
        ptrs: List of memory pointers (starting addresses).
        sizes: List of memory region sizes corresponding to each pointer.

    Returns:
        A tuple of (merged_ptrs, merged_sizes) where contiguous regions
        have been merged into single regions, ordered by start address.

    Raises:
        ValueError: If ``ptrs`` and ``sizes`` differ in length.

    Example:
        >>> merge_contiguous_memory([0, 10, 30], [10, 20, 10])
        ([0], [40])
        >>> merge_contiguous_memory([0, 5, 20], [5, 5, 10])
        ([0, 20], [10, 10])
    """
    if len(ptrs) != len(sizes):
        raise ValueError("ptrs and sizes must have the same length")

    # Each entry is a mutable [start, length] pair; only the last entry can
    # still grow while we sweep the regions in address order.
    spans: list[list[int]] = []
    for start, length in sorted(zip(ptrs, sizes), key=operator.itemgetter(0)):
        if spans and spans[-1][0] + spans[-1][1] == start:
            spans[-1][1] += length
        else:
            spans.append([start, length])

    return [span[0] for span in spans], [span[1] for span in spans]