import math
from typing import Optional
import torch
from ..utils import exact_division
def _get_sharded_shape(shape: torch.Tensor, dim: int, block_size: int):
sharded_shape = list(shape)
sharded_shape[dim] = block_size
return sharded_shape
def _get_partial_sharded_padding(tensor: torch.Tensor, world_size: int, rank: int, dim: int = 0):
    """
    Return this rank's shard of ``tensor`` along ``dim``, zero-padded to a uniform size.

    The sharded dimension is cut into blocks of ``ceil(size / world_size)`` entries.
    A rank whose block runs past the end of the tensor (or lies entirely beyond it)
    gets the available entries followed by zeros, so every rank receives a tensor
    of identical shape.

    Args:
        tensor: Input tensor to be sharded and padded.
        world_size: Total number of processes (number of shards).
        rank: Index of the current process (selects the shard to return).
        dim: Dimension along which to shard. Must be 0, -1, or ``tensor.dim() - 1``.

    Returns:
        torch.Tensor:
            A newly allocated tensor with the dtype and device of ``tensor`` whose
            size along ``dim`` is the block size.
    """
    assert dim in [0, -1, tensor.dim() - 1]

    block_size = math.ceil(tensor.shape[dim] / world_size)
    lo = rank * block_size
    hi = lo + block_size
    along_first = dim == 0

    # Slicing clamps to the tensor bounds, so the shard may be shorter than
    # block_size (or empty) for the trailing ranks.
    shard = tensor[lo:hi] if along_first else tensor[..., lo:hi]

    padded = torch.zeros(
        size=_get_sharded_shape(shard.shape, dim, block_size),
        dtype=shard.dtype,
        device=shard.device,
    )
    # Copy the real data into the leading part; the remainder stays zero.
    if along_first:
        padded[: shard.shape[0]] = shard
    else:
        padded[..., : shard.shape[-1]] = shard
    return padded
def _get_partial_sharded_by_unit(tensor: torch.Tensor, world_size: int, rank: int, dim: int = 0, unit_size: int = 1):
"""
Splits a PyTorch tensor along a specified dimension into shards across multiple processes,
with alignment to fixed-size units to ensure shards don't split units.
Args:
tensor: Input tensor to be sharded.
world_size: Total number of processes (total shards to split into).
rank: Index of current process (determines which shard to return).
dim: Dimension along which to shard. Must be 0, -1, or the last dimension (tensor.dim() - 1).
unit_size: Size of the fixed unit for alignment. Shard boundaries will always align with these units.
Returns:
torch.Tensor: Shard of the input tensor corresponding to the current rank.
"""
assert dim in [0, -1, tensor.dim() - 1]
size = tensor.shape[dim]
unit_num = size // unit_size
if unit_num >= world_size:
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
else:
start = (rank // (world_size // unit_num)) * unit_size
stop = ((rank // (world_size // unit_num)) + 1) * unit_size
if dim == 0:
tensor = tensor[start:stop]
else:
tensor = tensor[..., start:stop]
return tensor
def get_partial_sharded(
    tensor: torch.Tensor,
    world_size: int,
    rank: int,
    dim: int = 0,
    unit_num: Optional[int] = None,
):
    """
    Return the shard of ``tensor`` along ``dim`` that belongs to ``rank``.

    Two strategies are available:

    * ``unit_num is None``: the dimension is split into equally sized blocks and
      zero-padded where needed (see ``_get_partial_sharded_padding``).
    * ``unit_num`` given: the dimension is treated as ``unit_num`` units and is
      split via ``_get_partial_sharded_by_unit``. This requires that one of
      ``unit_num`` and ``world_size`` is a multiple of the other.

    Args:
        tensor: Input tensor to be sharded.
        world_size: Total number of processes (number of shards).
        rank: Index of the current process (selects the shard to return).
        dim: Dimension along which to shard.
        unit_num: Optional number of units along ``dim``.

    Returns:
        torch.Tensor: The shard for ``rank``.

    Raises:
        ValueError: If ``unit_num`` is given and neither ``unit_num % world_size``
            nor ``world_size % unit_num`` is zero.
    """
    size = tensor.shape[dim]
    if unit_num is None:
        return _get_partial_sharded_padding(tensor, world_size, rank, dim)

    # NOTE(review): exact_division is a project helper; presumably it rejects a
    # size that is not a multiple of unit_num -- confirm against ..utils.
    unit_size = exact_division(size, unit_num)
    if unit_num % world_size != 0 and world_size % unit_num != 0:
        # The trailing space on the first fragment matters: adjacent string
        # literals are concatenated without any separator.
        raise ValueError(
            f"The scenario where unit_num {unit_num} does not divide world_size {world_size} "
            f"and world_size {world_size} does not divide unit_num {unit_num} is not supported."
        )
    return _get_partial_sharded_by_unit(tensor, world_size, rank, dim, unit_size)
def apply_static_quant_linear(x: torch.Tensor, module: torch.nn.Module) -> torch.Tensor:
    """Call a Linear-like module, using its statically-quantized inner module when present.

    Contract:
        ``module`` is either a plain ``torch.nn.Module`` (typically
        ``torch.nn.Linear`` or a quantized linear exposing ``qweight``), or a
        :class:`ModelWrapperBase` whose ``_inner`` attribute holds such a
        module. Only this single level of wrapping is looked through; no other
        wrapper conventions are supported.

    Behavior:
        * The candidate is ``module._inner`` for a :class:`ModelWrapperBase`,
          otherwise ``module`` itself.
        * If the candidate has a ``qweight`` attribute that is not ``None``, the
          candidate is called directly, bypassing any wrapper-side dispatch.
        * Otherwise the original ``module`` is called, so wrapper-side logic
          such as dtype casting or hooks is preserved.

    Args:
        x: Input tensor passed to the selected callable.
        module: The (possibly wrapped) linear-like module.

    Returns:
        Whatever the selected module returns for ``x``.
    """
    candidate = module._inner if isinstance(module, ModelWrapperBase) else module
    if getattr(candidate, "qweight", None) is not None:
        return candidate(x)
    return module(x)
class ModelWrapperBase(torch.nn.Module):
    """Base class for modules that wrap another module and delegate to it.

    The wrapped object is stored on the instance as ``_inner``. Attribute
    lookups that fail on the wrapper itself are retried on ``_inner``, so the
    wrapper exposes the wrapped module's attributes transparently.
    """

    def __init__(self, wrapped: Optional[torch.nn.Module]):
        """Store ``wrapped`` (which may be ``None``) as ``_inner``."""
        super().__init__()
        self._inner = wrapped

    def unwrap(self) -> Optional[torch.nn.Module]:
        """Peel off every nested ``ModelWrapperBase`` layer.

        Returns:
            The first ``_inner`` object in the chain that is not itself a
            ``ModelWrapperBase``. This is ``None`` when the innermost wrapper
            was constructed with ``None``.
        """
        wrapped = self
        while isinstance(wrapped, ModelWrapperBase):
            wrapped = wrapped._inner
        return wrapped

    def __getattr__(self, item):
        """Resolve ``item`` on the wrapper, falling back to ``_inner``.

        Raises:
            AttributeError: If neither the wrapper nor ``_inner`` provides
                ``item``. The error raised is the wrapper's own lookup failure.
        """
        try:
            return super().__getattr__(item)
        except AttributeError:
            # Never fall back when `_inner` itself is the missing name: reading
            # `self._inner` below would re-enter this method and recurse until
            # RecursionError (e.g. on an instance whose __init__ has not run).
            if item != "_inner":
                try:
                    return getattr(self._inner, item)
                except AttributeError:
                    # Either `_inner` is not set or it lacks `item`; fall
                    # through and re-raise the wrapper's original error.
                    pass
            raise