"""Estimate HBM memory consumption for generated shape grid rows.
Used by ``grid_generator/runner.py`` to filter out shapes that would exceed
the available device memory during microbench replay.
Default budget: 32 GiB (50 % of Ascend 910B 64 GiB HBM).
"""
from __future__ import annotations
import math
import re
from typing import Sequence
# Element size in bytes, keyed by upper-case dtype name. Most names appear
# with a ``DT_`` prefix; a handful also have an unprefixed alias.
DTYPE_BYTES: dict[str, int] = {
    "DT_FLOAT": 4,
    "DT_FLOAT16": 2,
    "DT_BF16": 2,
    "DT_BFLOAT16": 2,
    "DT_FLOAT32": 4,
    "DT_FLOAT64": 8,
    "DT_DOUBLE": 8,
    "DT_INT8": 1,
    "DT_INT16": 2,
    "DT_INT32": 4,
    "DT_INT64": 8,
    "DT_UINT8": 1,
    "DT_UINT16": 2,
    "DT_UINT32": 4,
    "DT_UINT64": 8,
    "DT_BOOL": 1,
    "DT_COMPLEX64": 8,
    "DT_COMPLEX128": 16,
    "DT_FLOAT8_E4M3": 1,
    "DT_FLOAT8_E5M2": 1,
    "DT_FLOAT8": 1,
    "FLOAT": 4,
    "INT8": 1,
    "INT16": 2,
    "INT32": 4,
    "INT64": 8,
    "UINT8": 1,
    "BOOL": 1,
}

# Element size assumed when a dtype is missing or unrecognised (FP16-sized).
DEFAULT_BYTES_PER_ELEMENT = 2

# Default memory budget in bytes: 32 GiB.
DEFAULT_MAX_BYTES = 32 * 1024 ** 3


def dtype_to_bytes(dtype_name: str) -> int:
    """Convert a dtype string to its element size in bytes.

    The name is stripped of surrounding whitespace and upper-cased before
    lookup in ``DTYPE_BYTES``. If the exact name is not in the table, the
    lookup is retried with the ``DT_`` prefix toggled (added when absent,
    removed when present), so ``"FLOAT32"``, ``"float64"`` or ``"DOUBLE"``
    resolve to the same size as their ``DT_``-prefixed entries instead of
    silently falling back to the 2-byte default.

    Args:
        dtype_name: Dtype spelling such as ``"DT_FLOAT16"`` or ``"int32"``.
            ``None`` and the empty string are treated as unknown.

    Returns:
        Bytes per element, or ``DEFAULT_BYTES_PER_ELEMENT`` when the dtype
        cannot be resolved under either spelling.
    """
    key = str(dtype_name or "").strip().upper()
    if key in DTYPE_BYTES:
        return DTYPE_BYTES[key]
    # The table carries unprefixed aliases for only some dtypes; try the
    # other spelling before giving up so sizes are not under-estimated.
    alternate = key[3:] if key.startswith("DT_") else "DT_" + key
    return DTYPE_BYTES.get(alternate, DEFAULT_BYTES_PER_ELEMENT)
def _parse_dtype_list(dtype_cell: str) -> list[str]:
"""Parse the ``Input Data Types`` CSV column into a dtype list."""
raw = str(dtype_cell or "").strip().strip('"')
if not raw:
return []
return [part.strip() for part in re.split(r"[;\s]+", raw) if part.strip()]
def estimate_tensor_bytes(shape: tuple[int, ...], bytes_per_elem: int) -> int:
    """Estimate the size in bytes of a single tensor.

    The size is the product of all dimensions multiplied by
    ``bytes_per_elem``. An empty shape is counted as 0 bytes.
    """
    if shape:
        element_count = math.prod(shape)
        return element_count * bytes_per_elem
    return 0
def estimate_row_memory(
    input_shapes: Sequence[tuple[int, ...]],
    output_shapes: Sequence[tuple[int, ...]],
    input_dtypes: Sequence[str],
    output_dtypes: Sequence[str] | None = None,
) -> int:
    """Estimate total HBM memory for one shape-grid row (bytes).

    Each input/output tensor contributes
    ``product(shape) * bytes_per_element``; tensors are matched to dtypes
    by position.

    Args:
        input_shapes: Shapes of the input tensors.
        output_shapes: Shapes of the output tensors.
        input_dtypes: Dtype names for the inputs, by position.
        output_dtypes: Dtype names for the outputs, by position. When
            ``None`` or empty, ``input_dtypes`` is used instead.

    Returns:
        Sum of the estimated sizes of all input and output tensors.

    A tensor whose position has no dtype entry is sized with the default
    element size returned by ``dtype_to_bytes("")``.
    """
    # An empty list means "no output dtypes known", exactly like None:
    # parse_dtype_from_template_row returns [] when the column is absent,
    # and outputs should then inherit the input dtypes rather than all
    # collapsing to the 2-byte default.
    if not output_dtypes:
        output_dtypes = input_dtypes

    total = 0
    for shapes, dtypes in (
        (input_shapes, input_dtypes),
        (output_shapes, output_dtypes),
    ):
        for idx, shape in enumerate(shapes):
            # Positions beyond the dtype list fall back to the default size.
            dtype = dtypes[idx] if idx < len(dtypes) else ""
            total += estimate_tensor_bytes(shape, dtype_to_bytes(dtype))
    return total
def exceeds_memory_budget(
    input_shapes: Sequence[tuple[int, ...]],
    output_shapes: Sequence[tuple[int, ...]],
    input_dtypes: Sequence[str],
    output_dtypes: Sequence[str] | None = None,
    max_bytes: int = DEFAULT_MAX_BYTES,
) -> tuple[bool, int]:
    """Report whether a row's estimated memory is over the budget.

    The estimate comes from ``estimate_row_memory`` and is compared
    against ``max_bytes`` with a strict ``>``, so a row that lands exactly
    on the budget is not flagged.

    Returns:
        ``(exceeded, estimated_bytes)``.
    """
    total_bytes = estimate_row_memory(
        input_shapes,
        output_shapes,
        input_dtypes,
        output_dtypes,
    )
    over_budget = total_bytes > max_bytes
    return over_budget, total_bytes
def format_bytes(n: int) -> str:
    """Render a byte count for humans, e.g. ``'12.34 GiB'``.

    The largest binary unit (GiB, MiB, KiB) not exceeding ``n`` is used
    with two decimal places; anything below 1024 is shown as whole bytes.
    """
    units = (
        ("GiB", 1024 ** 3),
        ("MiB", 1024 ** 2),
        ("KiB", 1024),
    )
    for suffix, scale in units:
        if n >= scale:
            return f"{n / scale:.2f} {suffix}"
    return f"{n} B"
def parse_dtype_from_template_row(row: dict[str, str]) -> tuple[list[str], list[str]]:
    """Read the input and output dtype lists out of a CSV template row.

    The ``Input Data Types`` and ``Output Data Types`` columns are each
    parsed with ``_parse_dtype_list``; a missing column is treated as an
    empty cell.

    Returns:
        ``(input_dtypes, output_dtypes)``.
    """
    columns = ("Input Data Types", "Output Data Types")
    parsed_inputs, parsed_outputs = (
        _parse_dtype_list(row.get(column, "")) for column in columns
    )
    return parsed_inputs, parsed_outputs