from dataclasses import dataclass
from typing import ClassVar, Dict
from enum import Enum
import math
import torch
import torch_npu
class TensorState(Enum):
    """ Lifecycle state of an activation tensor.

    NORMAL:   the tensor is in its regular, uncompressed form.
    COMPRESS: the tensor has been compressed.
    """

    NORMAL = "normal"
    COMPRESS = "compress"
class ListNode:
    """ Doubly linked list node used to record the dependency order
    between transformer layer computations.

    Attributes:
        order_layer_uuid: Identifier of the layer this node stands for.
        prev_layer_node: Node preceding this one, or None.
        next_layer_node: Node following this one, or None.
    """

    def __init__(
        self,
        order_layer_uuid,
        prev_layer_node=None,
        next_layer_node=None
    ) -> None:
        self.order_layer_uuid = order_layer_uuid
        self.prev_layer_node = prev_layer_node
        self.next_layer_node = next_layer_node

    def set_next_layer_node(self, next_layer_node) -> None:
        """ Point this node's forward link at `next_layer_node`.

        Only the forward link of this node is written; the other node's
        backward link is left untouched.
        """
        self.next_layer_node = next_layer_node

    def next(self):
        """ Return the following node (None when there is none).
        """
        return self.next_layer_node

    def prev(self):
        """ Return the preceding node (None when there is none).
        """
        return self.prev_layer_node
class ShareMemory:
    """ Holder for a shared swap tensor and a shared PDF tensor.

    Attributes:
        numel: Element count requested for the swap tensor.
        dtype: Element type requested for the swap tensor.
        min_host_size: Constant 2 * 1024 * 1024; not read anywhere in this
            class (presumably a minimum host buffer size in bytes - confirm).
        virtual_tensor: Swap tensor obtained from `get_swap_tensor`.
        can_be_used: Availability flag, True right after construction.
        pdf: Zero-initialised int32 tensor with 256 entries.
    """

    def __init__(self, numel: int, dtype: torch.dtype) -> None:
        self.numel, self.dtype = numel, dtype
        self.min_host_size = 2 * 1024 * 1024

        # Resolve the concrete device object of the current device by
        # allocating a scalar on it and reading back its `.device`.
        current_index = torch.cuda.current_device()
        device = torch.empty([], device=current_index).device

        # The swap tensor is allocated first, then the PDF tensor.
        self.virtual_tensor = get_swap_tensor(numel, device, dtype)
        self.can_be_used = True
        self.pdf = torch.zeros(256, dtype=torch.int32, device=device)
def get_swap_tensor(ts_numel: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """ Create a zero-filled one-dimensional swap tensor.

    Args:
        ts_numel: Number of elements of the swap tensor.
        device: Device passed to the allocator.
        dtype: Element type of the swap tensor.

    Returns:
        The tensor produced by `torch_npu.empty_with_swapped_memory`,
        zeroed in place.

    Raises:
        ModuleNotFoundError: The installed torch_npu does not expose
            `empty_with_swapped_memory`.
    """
    allocator_name = "empty_with_swapped_memory"
    if not hasattr(torch_npu, allocator_name):
        raise ModuleNotFoundError("PTA dose not support this func, please update to latest version.")

    shape = torch.Size([ts_numel])
    buffer = torch_npu.empty_with_swapped_memory(shape, dtype=dtype, device=device)
    # The allocator returns uninitialised memory ("empty"), so clear it.
    buffer.zero_()
    return buffer
class TensorManager:
    """ Drives compression/decompression of one activation tensor through the
    NPU operators npu_hans_encode / npu_hans_decode.

    Responsibilities:
        - Allocate the buffers that hold the compressed representation.
        - Track the tensor state (NORMAL / COMPRESS).
        - Issue the encode and decode operator calls.
    """

    def __init__(self, tensor: torch.Tensor, compress_ratio: float = 0.5) -> None:
        element_count = tensor.numel()
        element_bytes = tensor.element_size()

        self.tensor = tensor
        # ceil(numel * ratio) divided by the element size, then bumped by one
        # and floored to a multiple of two, so the value is always even.
        self.fixed_numel = (math.ceil(element_count * compress_ratio) // element_bytes + 1) // 2 * 2
        # numel scaled by (element_size - 1) / element_size.
        self.mantissa_numel = element_count * (element_bytes - 1) // element_bytes
        # Byte size of the original storage; used to re-grow it before decode.
        self.storage_size = element_count * element_bytes

        # Buffers and shared memory are attached later by `malloc`.
        self.var = None
        self.fixed = None
        self.mantissa = None
        self.state = TensorState.NORMAL
        self.statistic = True

    def malloc(self, var: ShareMemory, statistic: bool) -> None:
        """ Attach shared memory and allocate the compression output buffers.

        Args:
            var: Shared memory whose `pdf` and `virtual_tensor` are used by
                encode/decode.
            statistic: Flag forwarded to `npu_hans_encode` as its second
                positional argument.
        """
        self.var = var
        self.statistic = statistic
        like_tensor = {"dtype": self.tensor.dtype, "device": self.tensor.device}
        self.fixed = torch.zeros(self.fixed_numel, **like_tensor)
        self.mantissa = torch.zeros(self.mantissa_numel, **like_tensor)

    def encode(self) -> None:
        """ Launch the compression operator on the managed tensor.

        Results are written into (and re-bound from) the shared PDF, the
        mantissa buffer, the fixed buffer and the shared swap tensor.
        """
        out_buffers = (self.var.pdf, self.mantissa, self.fixed, self.var.virtual_tensor)
        encoded = torch_npu.npu_hans_encode(self.tensor, self.statistic, False, out=out_buffers)
        self.var.pdf, self.mantissa, self.fixed, self.var.virtual_tensor = encoded

    def encode_wait(self) -> None:
        """ Mark the tensor as compressed and free its original storage.

        NOTE(review): no synchronisation is performed here; presumably the
        caller has already waited for the encode task - confirm.
        """
        self.state = TensorState.COMPRESS
        self.tensor.untyped_storage().resize_(0)

    def pre_decode(self) -> None:
        """ Re-grow the activation storage to its original byte size ahead of
        decompression.
        """
        self.tensor.untyped_storage().resize_(self.storage_size)

    def decode(self) -> None:
        """ Launch the decompression operator, writing into the managed tensor.
        """
        self.tensor = torch_npu.npu_hans_decode(
            self.mantissa,
            self.fixed,
            self.var.virtual_tensor,
            self.var.pdf,
            False,
            out=self.tensor,
        )

    def release(self) -> None:
        """ Drop every compression buffer and return to the NORMAL state.

        When the attached shared memory exposes `can_be_used`, it is flagged
        as available again before the reference is dropped.
        """
        if hasattr(self.var, "can_be_used"):
            self.var.can_be_used = True
        self.var = None
        self.fixed = None
        self.mantissa = None
        self.state = TensorState.NORMAL

    def recover(self) -> None:
        """ Restore the activation in one go: re-grow storage, decode, then
        release the compression buffers.
        """
        self.pre_decode()
        self.decode()
        self.release()
@dataclass
class SimulationHyperParams:
    """ Hyperparameters for the theoretical time-cost model.

    The two throughput tables are required and are keyed by the
    communication group size rendered as a string. The remaining fields
    carry defaults expressed in multiples of the class-level unit constants
    (presumably bytes per second and FLOP per second - confirm).
    """

    # Unit constants: class-level, not part of the generated __init__.
    TFLOPS: ClassVar[int] = 10 ** 12
    GIGABYTE: ClassVar[int] = 1024 ** 3
    # Fallback bandwidth used when a group size is missing from a table.
    MAX_BANDWIDTH: ClassVar[int] = 1000 * GIGABYTE

    # Required lookup tables (no defaults).
    allgather_throughput: Dict[str, float]
    all2all_throughput: Dict[str, float]

    # Tunable fields with defaults.
    encode_throughput: float = 100.0 * GIGABYTE
    decode_throughput: float = 111.0 * GIGABYTE
    cube_tflops: float = 280.0 * TFLOPS
class SimulationBase:
    """ Theoretical time-cost model for various asynchronous operators.

    The model reads its throughput numbers from `simulation_config`
    (`all2all_throughput`, `allgather_throughput`, `MAX_BANDWIDTH`,
    `encode_throughput`, `decode_throughput`, `cube_tflops`).
    """

    def __init__(self, simulation_config: "SimulationHyperParams") -> None:
        self.simulation_config = simulation_config

    def time_cost(self, op_name: str, *args, **kwargs) -> float:
        """ Estimate the time cost of one operator call.

        Args:
            op_name: One of "matmul", "all2all" or "allgather".
            *args, **kwargs: The arguments of the operator call being
                modelled; forwarded unchanged to the matching estimator.

        Returns:
            The estimated cost, or 0.0 for an operator that is not modelled.
        """
        estimators = {
            "matmul": self._matmul,
            "all2all": self._all2all,
            "allgather": self._allgather,
        }
        estimator = estimators.get(op_name)
        if estimator is None:
            return 0.0
        return estimator(*args, **kwargs)

    def _matmul(self, *args, **kwargs) -> float:
        """ Matmul time cost: 2 * K * prod(output shape) / cube_tflops,
        where K is the last dimension of the first operand.
        """
        output_shape = infer_matmul_shape(args[0], args[1])
        total_flop = 2 * args[0].shape[-1]
        for dim in output_shape:
            total_flop *= dim
        return total_flop / self.simulation_config.cube_tflops

    def _collective_time(self, throughput_table, tensor, group) -> float:
        """ Shared estimator for collectives: tensor bytes / bandwidth.

        The bandwidth is looked up by the group's world size (as a string).
        A missing entry, or a non-positive one (e.g. a 0 placeholder for an
        unmeasured group size), falls back to MAX_BANDWIDTH so the estimate
        never divides by zero.
        """
        group_size = torch.distributed.get_world_size(group)
        bandwidth = throughput_table.get(str(group_size))
        if bandwidth is None or bandwidth <= 0:
            bandwidth = self.simulation_config.MAX_BANDWIDTH
        return tensor.numel() * tensor.element_size() / bandwidth

    def _all2all(self, *args, **kwargs) -> float:
        """ All2All time cost; 0.0 when no `group` keyword is supplied.
        """
        group = kwargs.get("group")
        if not group:
            return 0.0
        return self._collective_time(
            self.simulation_config.all2all_throughput, args[0], group)

    def _allgather(self, *args, **kwargs) -> float:
        """ AllGather time cost; 0.0 when no `group` keyword is supplied.
        """
        group = kwargs.get("group")
        if not group:
            return 0.0
        return self._collective_time(
            self.simulation_config.allgather_throughput, args[0], group)

    def _reducescatter(self, *args, **kwargs) -> float:
        """ ReduceScatter is not modelled.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def encode_max_numel(self, estimated_time) -> int:
        """ Upper bound of elements that can be encoded within
        `estimated_time`: encode_throughput * time / 2, truncated to int.

        NOTE(review): the division by 2 presumably converts bytes into
        2-byte elements - confirm.
        """
        return int(self.simulation_config.encode_throughput * estimated_time / 2)

    def decode_max_numel(self, estimated_time) -> int:
        """ Upper bound of elements that can be decoded within
        `estimated_time`: decode_throughput * time / 2, truncated to int.
        """
        return int(self.simulation_config.decode_throughput * estimated_time / 2)
class SimulationA2(SimulationBase):
    """ Simulation model preloaded with the A2 throughput tables.
    """

    def __init__(self):
        gigabyte = SimulationHyperParams.GIGABYTE
        # Figures are multiples of GIGABYTE, keyed by the communication
        # group size rendered as a string.
        allgather_in_gb = {"2": 36, "4": 73, "8": 147, "16": 138}
        all2all_in_gb = {
            "2": 37,
            "4": 69,
            "8": 119,
            "16": 40.1,
            "32": 30.2,
            "64": 27.0,
        }
        hyper_params = SimulationHyperParams(
            allgather_throughput={
                size: amount * gigabyte for size, amount in allgather_in_gb.items()
            },
            all2all_throughput={
                size: amount * gigabyte for size, amount in all2all_in_gb.items()
            },
        )
        super().__init__(hyper_params)
class SimulationA3(SimulationBase):
    """ Simulation model preloaded with the A3 throughput tables.
    """

    def __init__(self):
        gigabyte = SimulationHyperParams.GIGABYTE
        # Figures are multiples of GIGABYTE, keyed by the communication
        # group size rendered as a string.
        allgather_in_gb = {"2": 350.0, "4": 324.9, "8": 298.8, "16": 283.6}
        all2all_in_gb = {
            # NOTE(review): 0 looks like a placeholder for an unmeasured
            # group size - TODO confirm the real throughput.
            "2": 0,
            "4": 229.4,
            "8": 173.9,
            "16": 154.0,
            "32": 143.9,
            "64": 137.3,
        }
        hyper_params = SimulationHyperParams(
            allgather_throughput={
                size: amount * gigabyte for size, amount in allgather_in_gb.items()
            },
            all2all_throughput={
                size: amount * gigabyte for size, amount in all2all_in_gb.items()
            },
        )
        super().__init__(hyper_params)
def infer_matmul_shape(A: torch.Tensor, B: torch.Tensor):
    """ Infer the output shape of `A @ B` without performing the product.

    A 1-D `A` is treated as a row vector and a 1-D `B` as a column vector;
    the dimension inserted for that purpose is dropped from the result.
    Leading (batch) dimensions are broadcast against each other.

    Args:
        A: Left operand, at least 1-D.
        B: Right operand, at least 1-D.

    Returns:
        The output shape as a tuple of ints (empty for 1-D @ 1-D).

    Raises:
        ValueError: An operand is 0-dimensional, the contracted dimensions
            differ, or the batch dimensions cannot be broadcast.
    """
    # Reject scalars explicitly; indexing their empty shape below would
    # otherwise surface as an unrelated IndexError.
    if A.dim() == 0 or B.dim() == 0:
        raise ValueError(f"Incompatible shapes: {A.shape} @ {B.shape}")

    a_was_1d = A.dim() == 1
    b_was_1d = B.dim() == 1
    # Promote vectors to matrices: (k,) -> (1, k) on the left, (k,) -> (k, 1)
    # on the right.
    a_shape = [1, *A.shape] if a_was_1d else list(A.shape)
    b_shape = [*B.shape, 1] if b_was_1d else list(B.shape)

    if a_shape[-1] != b_shape[-2]:
        raise ValueError(f"Incompatible shapes: {A.shape} @ {B.shape}")

    batch_a = a_shape[:-2]
    batch_b = b_shape[:-2]
    try:
        broadcast_batch = torch.broadcast_shapes(tuple(batch_a), tuple(batch_b))
    except RuntimeError as e:
        raise ValueError(f"Cannot broadcast batch dimensions: {batch_a} vs {batch_b}") from e

    # Keep only the matrix dimensions that came from a genuine (non-promoted)
    # operand dimension.
    out_shape = list(broadcast_batch)
    if not a_was_1d:
        out_shape.append(a_shape[-2])
    if not b_was_1d:
        out_shape.append(b_shape[-1])
    return tuple(out_shape)