"""CPU offloading tests adapted for Ascend NPU.
These tests validate the CPU offloading functionality (OffloadableLayerState,
DefaultOffloadSynchronizer, ManualOffloadSynchronizer) on NPU, exercising
tensor offload/reload roundtrips, memory accounting, and integration with
TE-NPU modules (Linear, LayerNormLinear, LayerNormMLP, GroupedLinear).
Adapted from NVIDIA's test_cpu_offloading.py with the following NPU-specific
changes:
- Device: cuda → npu
- Streams: torch.cuda.Stream → torch.npu.Stream
- Events: torch.cuda.Event → torch.npu.Event
- Memory: torch.cuda.memory_allocated → torch.npu.memory_allocated
- Layer types: reduced to TE-NPU available modules
- Quantization: only MXFP8 (no DelayedScaling, Float8CurrentScaling, NVFP4)
- No CUDA graphs, no attention backend parametrization
"""
from __future__ import annotations
import contextlib
import gc
import random
from typing import Optional
import pytest
import torch
from transformer_engine.common import recipe
from transformer_engine.pytorch import (
Linear,
GroupedLinear,
LayerNormLinear,
autocast,
)
from transformer_engine.pytorch.constants import NPUVersion
from transformer_engine.pytorch.cpu_offload import (
get_cpu_offload_context,
OffloadableLayerState,
DefaultOffloadSynchronizer,
start_offload,
mark_not_offload,
)
from transformer_engine.pytorch.ops import L2Normalization
from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage
from transformer_engine.pytorch.quantization.utils import check_npu_version
from utils import npu_available
# Module-wide pytest marker: every test in this file is skipped when
# npu_available() reports that no NPU device is present.
pytestmark = pytest.mark.skipif(not npu_available(), reason="NPU device is required")
# Recipes the parametrized tests run with. ``None`` means "no quantization
# recipe"; the MXFP8 entry carries its own skip marker driven by
# check_npu_version(NPUVersion.A5).
_quantization_recipes = [
None,
pytest.param(
recipe.MXFP8BlockScaling(),
marks=pytest.mark.skipif(
not check_npu_version(NPUVersion.A5),
reason="FP8 (MXFP8BlockScaling) requires Ascend 950 or newer (Ascend 910B does not support FP8)",
),
),
]
# Feature size: used as both in- and out-features of every layer built by
# _Utils.create_layer and as the last dimension of _Utils.create_tensor.
_D = 256
# First positional argument of GroupedLinear and the number of equal chunks
# the tests pass as ``m_splits``.
_H = 4
# Leading two dimensions of the tensors produced by _Utils.create_tensor.
_B = 128
_S = 512
# Slack (in MB) allowed in the before/after offload memory comparison of
# TestTELayers.test_manual_synchronization.
_EPSILON = 0.1
# Turns the cyclic garbage collector off for the whole process at import time.
# NOTE(review): presumably so that memory measurements are not perturbed by
# collection running at arbitrary points -- confirm.
gc.disable()
class _Utils:
    """Namespace of static helpers shared by the CPU offloading tests.

    Every member is a ``staticmethod``; the class is never instantiated.
    """

    # Scratch tensor mutated by ``long_job``. It is created lazily on first use:
    # allocating it here in the class body would touch the NPU while the module
    # is being imported, which raises during test collection on hosts without
    # an NPU -- before the module-level skip marker can take effect.
    tensor1 = None

    @staticmethod
    def _get_scratch_tensor() -> torch.Tensor:
        """Return the shared (1024, 1024) bfloat16 NPU tensor, creating it once."""
        if _Utils.tensor1 is None:
            _Utils.tensor1 = torch.randn((1024, 1024), device="npu", dtype=torch.bfloat16)
        return _Utils.tensor1

    @staticmethod
    def long_job(stream: Optional[torch.npu.Stream] = None):
        """Enqueue 6000 in-place ``normal_()`` calls on the scratch tensor.

        Args:
            stream: Stream to enqueue the work on. Defaults to the current NPU
                stream when ``None``.
        """
        NUM_ITERS = 6000
        if stream is None:
            stream = torch.npu.current_stream()
        # Fetch (and, on first use, allocate) the scratch tensor outside the
        # stream context, matching where the tensor used to be created.
        scratch = _Utils._get_scratch_tensor()
        with torch.npu.stream(stream):
            for _ in range(NUM_ITERS):
                scratch.normal_()

    @staticmethod
    def measure_time(func):
        """Return how long ``func()`` takes, in milliseconds.

        The device is synchronized before the clock starts and again after
        ``func`` returns, so work queued by ``func`` is part of the measurement.
        """
        import time

        torch.npu.synchronize()
        # perf_counter is monotonic; time.time() can jump if the system clock
        # is adjusted mid-measurement.
        start = time.perf_counter()
        func()
        torch.npu.synchronize()
        end = time.perf_counter()
        return (end - start) * 1000

    @staticmethod
    def get_npu_memory_mb():
        """Return ``torch.npu.memory_allocated()`` converted to MiB."""
        return torch.npu.memory_allocated() / (1024**2)

    @staticmethod
    def get_max_npu_memory_mb():
        """Return ``torch.npu.max_memory_allocated()`` converted to MiB."""
        return torch.npu.max_memory_allocated() / (1024**2)

    @staticmethod
    def get_cpu_memory_mb() -> float:
        """Return the resident set size of the current process in MiB.

        Skips the calling test (via ``pytest.skip``) when ``psutil`` is not
        installed.
        """
        # Only the optional dependency can legitimately fail to import, so it
        # is the only statement guarded by the try.
        try:
            import psutil
        except ImportError:
            pytest.skip("psutil not installed")
        import os

        return psutil.Process(os.getpid()).memory_info().rss / (1024**2)

    @staticmethod
    def get_layer_names():
        """Return the layer-type keys accepted by ``create_layer``."""
        return [
            "linear",
            "layernorm_linear",
            "layernorm_mlp",
            "grouped_linear",
        ]

    @staticmethod
    def create_layer(layer_type: str):
        """Build a bfloat16 TE module on the NPU for the given layer-type key.

        Args:
            layer_type: One of the strings returned by ``get_layer_names``.

        Raises:
            ValueError: If ``layer_type`` is not a known key.
        """
        if layer_type == "linear":
            return Linear(_D, _D, params_dtype=torch.bfloat16, device="npu")
        if layer_type == "layernorm_linear":
            return LayerNormLinear(_D, _D, params_dtype=torch.bfloat16, device="npu")
        if layer_type == "layernorm_mlp":
            # Imported on demand; this module is not among the file-level imports.
            from transformer_engine.pytorch.module.layernorm_mlp import LayerNormMLP

            return LayerNormMLP(_D, _D, params_dtype=torch.bfloat16, device="npu")
        if layer_type == "grouped_linear":
            return GroupedLinear(_H, _D, _D, params_dtype=torch.bfloat16, device="npu")
        raise ValueError(f"Unknown layer type: {layer_type}")

    @staticmethod
    def create_tensor(
        fp8_recipe: Optional[recipe.Recipe],
        requires_grad: bool = False,
    ) -> torch.Tensor:
        """Create a random ``(_B, _S, _D)`` tensor on the NPU.

        Args:
            fp8_recipe: ``None`` for a plain bfloat16 tensor; an MXFP8 recipe
                to get the tensor quantized with ``MXFP8Quantizer``.
            requires_grad: Only honoured when ``fp8_recipe`` is ``None``; the
                quantized path ignores it.

        Raises:
            ValueError: If a recipe other than MXFP8 is passed.
        """
        shape = (_B, _S, _D)
        tensor = torch.randn(shape, device="npu", dtype=torch.bfloat16)
        if fp8_recipe is None:
            if requires_grad:
                # requires_grad_ flips the flag in place and returns the tensor.
                return tensor.requires_grad_(requires_grad)
            return tensor
        if fp8_recipe.mxfp8():
            from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

            quantizer = MXFP8Quantizer(fp8_dtype=torch.float8_e4m3fn)
            return quantizer(tensor)
        raise ValueError(f"Unsupported recipe: {fp8_recipe}")

    @staticmethod
    def create_recipe_ctx(fp8_recipe: Optional[recipe.Recipe]):
        """Return a zero-argument factory producing a fresh context manager.

        The factory yields ``contextlib.nullcontext()`` when ``fp8_recipe`` is
        ``None`` and ``autocast(enabled=True, recipe=fp8_recipe)`` otherwise.
        """
        if fp8_recipe is None:
            return lambda: contextlib.nullcontext()
        return lambda: autocast(enabled=True, recipe=fp8_recipe)

    @staticmethod
    def get_tensor_size_mb(tensor):
        """Return the size of ``tensor`` in MiB.

        ``None`` counts as 0. For a ``QuantizedTensorStorage`` the sizes of all
        entries of ``get_data_tensors()`` are summed (recursively).
        """
        if tensor is None:
            return 0
        if isinstance(tensor, QuantizedTensorStorage):
            return sum(_Utils.get_tensor_size_mb(t) for t in tensor.get_data_tensors())
        return tensor.numel() * tensor.element_size() / (1024**2)

    @staticmethod
    def memory_leak_check():
        """Fail if more than 1000 MiB are already allocated on the NPU.

        Raises:
            RuntimeError: When the allocation exceeds the threshold. A garbage
                collection pass is run first (and the collector disabled again)
                before raising.
        """
        # Sample once so the reported figure is the one that tripped the check.
        memory_num = _Utils.get_npu_memory_mb()
        if memory_num > 1000:
            gc.collect()
            gc.disable()
            raise RuntimeError(f"Memory leak: {memory_num} MB")
class TestOffloadableLayerState:
    """Tests that drive ``OffloadableLayerState`` directly on an NPU stream."""

    @pytest.mark.parametrize("random_num_tensors", [True, False])
    @pytest.mark.parametrize("fp8_recipe", _quantization_recipes)
    def test_general(self, random_num_tensors, fp8_recipe):
        """Push, offload, reload and pop tensors; each must come back unchanged."""
        _Utils.memory_leak_check()
        rounds = 10
        offload_stream = torch.npu.Stream()
        layer_state = OffloadableLayerState(offload_stream=offload_stream)

        for _round in range(rounds):
            references = []
            handles = []
            # A single tensor, or a random pick between 1 and 20 tensors.
            count = random.choice([1, 20]) if random_num_tensors else 1
            for _slot in range(count):
                pushed = _Utils.create_tensor(fp8_recipe)
                references.append(pushed)
                handle = layer_state.push_tensor(pushed)
                # The pushed tensor is still expected to report an NPU device.
                assert pushed.device.type == "npu"
                handles.append(handle)

            layer_state.start_offload()
            layer_state.release_activation_forward_gpu_memory()
            layer_state.start_reload()

            # Pop in push order and compare against the tensors kept aside.
            for slot, popped_handle in enumerate(handles):
                restored = layer_state.pop_tensor(popped_handle)
                assert restored.device.type == "npu"
                assert restored.shape == references[slot].shape
                assert restored.dtype == references[slot].dtype
                torch.testing.assert_close(restored, references[slot])
            layer_state.release_all_memory()
        torch.npu.synchronize()

    def test_offload_base_tensor(self):
        """Offload two strided views of one tensor with ``offload_base_tensor=True``.

        Checks that the offloaded size is about one full tensor, that both
        views pop back with the right contents, and that after dropping the
        base tensor the device holds about one tensor more than at the start.
        """
        _Utils.memory_leak_check()
        offload_stream = torch.npu.Stream()
        layer_state = OffloadableLayerState(offload_stream=offload_stream)
        baseline_mb = _Utils.get_npu_memory_mb()

        base = _Utils.create_tensor(None)
        base_mb = _Utils.get_tensor_size_mb(base)
        even_view = base[::2]
        odd_view = base[1::2]

        start_offload(even_view, offload_base_tensor=True)
        start_offload(odd_view, offload_base_tensor=True)
        even_id = layer_state.push_tensor(even_view)
        odd_id = layer_state.push_tensor(odd_view)
        # Drop the local references to the views before offloading.
        del even_view, odd_view

        layer_state.start_offload()
        layer_state.release_activation_forward_gpu_memory()
        offloaded_mb = layer_state.get_offloaded_total_size_mb()
        assert offloaded_mb == pytest.approx(base_mb, 0.1)

        layer_state.start_reload()
        even_view = layer_state.pop_tensor(even_id)
        odd_view = layer_state.pop_tensor(odd_id)
        assert even_view.device.type == "npu"
        assert odd_view.device.type == "npu"
        assert torch.allclose(even_view, base[::2])
        assert torch.allclose(odd_view, base[1::2])

        del base
        assert _Utils.get_npu_memory_mb() == pytest.approx(baseline_mb + base_mb, 0.1)
class TestDefaultOffloadSynchronizer:
    """Tests for ``DefaultOffloadSynchronizer`` forward/backward stepping."""

    @pytest.mark.parametrize("random_num_tensors", [True, False])
    @pytest.mark.parametrize("fp8_recipe", _quantization_recipes)
    def test_general(self, random_num_tensors, fp8_recipe):
        """Run repeated forward/backward sweeps and verify every popped tensor."""
        _Utils.memory_leak_check()
        layer_count = 10
        rounds = 10
        synchronizer = DefaultOffloadSynchronizer(
            num_layers=layer_count,
            num_offloaded_layers=layer_count - 1,
        )

        for _round in range(rounds):
            references = []  # per layer: list of pushed tensors
            handles = []  # per layer: list of ids returned by push_tensor
            layer_ids = []

            # Forward sweep: one fwd_step per layer, then push its tensors.
            for _layer in range(layer_count):
                per_layer = random.randint(1, 10) if random_num_tensors else 1
                layer_refs = []
                layer_handles = []
                layer_id = synchronizer.fwd_step()
                for _slot in range(per_layer):
                    pushed = _Utils.create_tensor(fp8_recipe)
                    layer_refs.append(pushed)
                    handle = synchronizer.push_tensor(pushed)
                    assert pushed.device.type == "npu"
                    layer_handles.append(handle)
                layer_ids.append(layer_id)
                handles.append(layer_handles)
                references.append(layer_refs)

            # Backward sweep: last layer first.
            for pos in reversed(range(layer_count)):
                synchronizer.bwd_step(layer_ids[pos])
                for slot, popped_handle in enumerate(handles[pos]):
                    restored = synchronizer.pop_tensor(popped_handle)
                    assert restored.device.type == "npu"
                    assert restored.shape == references[pos][slot].shape
                    assert restored.dtype == references[pos][slot].dtype
                    torch.testing.assert_close(restored, references[pos][slot])
            synchronizer.finish_part_of_bwd()
        torch.npu.synchronize()

    @pytest.mark.parametrize("fp8_recipe", _quantization_recipes)
    def test_memory(self, fp8_recipe):
        """Check device memory after the forward sweep and after the backward sweep.

        After forward, about one tensor is expected to remain allocated; after
        backward, allocation is expected to be back at the starting level. The
        peak-memory bound is only asserted when no recipe is used.
        """
        torch.npu.synchronize()
        _Utils.memory_leak_check()
        layer_count = 10
        torch.npu.reset_peak_memory_stats()
        synchronizer = DefaultOffloadSynchronizer(
            num_layers=layer_count,
            num_offloaded_layers=layer_count - 1,
        )
        baseline_mb = _Utils.get_npu_memory_mb()
        handles = []
        torch.npu.synchronize()

        for _layer in range(layer_count):
            synchronizer.fwd_step()
            pushed = _Utils.create_tensor(fp8_recipe)
            tensor_mb = _Utils.get_tensor_size_mb(pushed)
            handle = synchronizer.push_tensor(pushed)
            assert pushed.device.type == "npu"
            handles.append(handle)
            # Drop the local references so the test itself keeps nothing alive.
            del pushed, handle
        torch.npu.synchronize()

        if fp8_recipe is None:
            assert _Utils.get_max_npu_memory_mb() <= baseline_mb + tensor_mb * 4
        assert _Utils.get_npu_memory_mb() == pytest.approx(baseline_mb + tensor_mb, 0.1)

        for pos in reversed(range(layer_count)):
            synchronizer.bwd_step(pos)
            restored = synchronizer.pop_tensor(handles[pos])
            assert restored.device.type == "npu"
            # ``pos`` is always the last remaining index at this point.
            del restored, handles[pos]
        synchronizer.finish_part_of_bwd()

        del handles
        torch.npu.synchronize()

        if fp8_recipe is None:
            assert _Utils.get_max_npu_memory_mb() <= baseline_mb + tensor_mb * 4
        assert _Utils.get_npu_memory_mb() == pytest.approx(baseline_mb, 0.1)

    @pytest.mark.parametrize("fp8_recipe", _quantization_recipes)
    def test_multiple_tensor_offload(self, fp8_recipe):
        """Push the same tensor three times in one layer and check memory figures.

        Host memory is expected to grow by about one tensor, and device memory
        to return to its starting level after the backward steps.
        """
        _Utils.memory_leak_check()
        cpu_baseline_mb = _Utils.get_cpu_memory_mb()
        npu_baseline_mb = _Utils.get_npu_memory_mb()
        synchronizer = DefaultOffloadSynchronizer(
            num_layers=2,
            num_offloaded_layers=1,
        )

        shared = _Utils.create_tensor(fp8_recipe)
        shared_mb = _Utils.get_tensor_size_mb(shared)

        synchronizer.fwd_step()
        for _repeat in range(3):
            synchronizer.push_tensor(shared)
        if fp8_recipe is not None:
            # Result is discarded; the call only has to succeed after the pushes.
            shared.dequantize()
        synchronizer.fwd_step()

        expected_cpu_mb = cpu_baseline_mb + 1 * shared_mb
        assert _Utils.get_cpu_memory_mb() == pytest.approx(expected_cpu_mb, 0.1)

        del shared
        synchronizer.bwd_step(1)
        synchronizer.bwd_step(0)
        synchronizer.finish_part_of_bwd()
        assert _Utils.get_npu_memory_mb() == pytest.approx(npu_baseline_mb, 0.1)
class TestTELayers:
    """Integration tests: CPU offloading around TE-NPU modules."""

    @staticmethod
    def _grouped_kwargs(layer_type):
        """Return the ``m_splits`` kwarg for grouped_linear, else no kwargs."""
        if layer_type == "grouped_linear":
            return {"m_splits": [_B * _S // _H] * _H}
        return {}

    @pytest.mark.parametrize("layer_type", _Utils.get_layer_names())
    @pytest.mark.parametrize("fp8_recipe", _quantization_recipes)
    def test_sanity(self, layer_type, fp8_recipe):
        """Forward and backward through ten stacked layers with offloading on."""
        _Utils.memory_leak_check()
        recipe_ctx = _Utils.create_recipe_ctx(fp8_recipe)
        offloaded_layers = 6
        total_layers = 10
        offload_ctx, sync_function = get_cpu_offload_context(
            enabled=True,
            num_layers=offloaded_layers,
            model_layers=total_layers,
        )
        layers = [_Utils.create_layer(layer_type) for _ in range(total_layers)]
        inp = _Utils.create_tensor(None)
        extra_kwargs = TestTELayers._grouped_kwargs(layer_type)

        out = inp
        for idx in range(total_layers):
            with offload_ctx, recipe_ctx():
                out = layers[idx](out, is_first_microbatch=False, **extra_kwargs)
            out = sync_function(out)
        out.sum().backward()
        torch.npu.synchronize()
        del out, inp, layers

    @pytest.mark.parametrize("layer_type", _Utils.get_layer_names())
    @pytest.mark.parametrize("fp8_recipe", _quantization_recipes)
    def test_memory(self, layer_type, fp8_recipe):
        """Compare device memory of one layer run with and without offloading.

        Sequence: a warm-up forward/backward, a run without offloading to
        record the reference footprint, then a run inside the offload context
        whose device footprint plus offloaded size is compared to it.
        """
        _Utils.memory_leak_check()
        offload_ctx, sync_function = get_cpu_offload_context(
            enabled=True,
            num_layers=1,
            model_layers=2,
            offload_activations=True,
            offload_weights=False,
        )
        recipe_ctx = _Utils.create_recipe_ctx(fp8_recipe)
        layer = _Utils.create_layer(layer_type)
        inp = _Utils.create_tensor(None)
        extra_kwargs = TestTELayers._grouped_kwargs(layer_type)

        # Warm-up pass (the only one with is_first_microbatch=True).
        with recipe_ctx():
            out = layer(inp, is_first_microbatch=True, **extra_kwargs)
        out.sum().backward()
        del inp
        baseline_mb = _Utils.get_npu_memory_mb()

        # Reference pass without offloading.
        inp = _Utils.create_tensor(None)
        with recipe_ctx():
            out = layer(inp, is_first_microbatch=False, **extra_kwargs)
        with recipe_ctx():
            out = out + 1
        del inp
        no_offload_mb = _Utils.get_npu_memory_mb()
        out.sum().backward()

        # Same computation, this time inside the offload context.
        inp = _Utils.create_tensor(None)
        with offload_ctx, recipe_ctx():
            out = layer(inp, is_first_microbatch=False, **extra_kwargs)
        out = sync_function(out)
        with offload_ctx, recipe_ctx():
            out = out + 1
        out = sync_function(out)
        del inp

        assert _Utils.get_npu_memory_mb() == pytest.approx(baseline_mb, 0.1)
        offloaded_mb = offload_ctx.offload_synchronizer.get_offloaded_total_size_mb()
        # NOTE(review): the second positional argument of pytest.approx is the
        # relative tolerance, so 1 accepts a 100% deviation -- confirm this
        # looseness is intended.
        assert _Utils.get_npu_memory_mb() + offloaded_mb == pytest.approx(
            no_offload_mb, 1
        )
        out.sum().backward()

    def test_l2_normalization_saved_activation_offloads_to_cpu(self):
        """L2Normalization under manual offloading must offload something and match a reference."""
        _Utils.memory_leak_check()
        offload_ctx, sync_function, manual_controller = get_cpu_offload_context(
            enabled=True,
            model_layers=1,
            offload_activations=True,
            manual_synchronization=True,
        )
        x_ref = torch.randn((512, 512), device="npu", dtype=torch.bfloat16, requires_grad=True)
        x_test = x_ref.detach().clone().requires_grad_()
        grad_out = torch.randn_like(x_ref)

        # Reference result, computed outside any offload context.
        y_ref = L2Normalization()(x_ref)
        y_ref.backward(grad_out)

        with offload_ctx:
            y_test = L2Normalization()(x_test)
        y_test = sync_function(y_test)

        manual_controller.start_offload_layer(0)
        manual_controller.release_activation_forward_gpu_memory(0)
        assert offload_ctx.offload_synchronizer.get_offloaded_total_size_mb() > 0

        manual_controller.start_reload_layer(0)
        y_test.backward(grad_out)

        torch.testing.assert_close(y_test, y_ref)
        torch.testing.assert_close(x_test.grad, x_ref.grad)

    @pytest.mark.parametrize("layer_type", _Utils.get_layer_names())
    @pytest.mark.parametrize("fp8_recipe", _quantization_recipes)
    def test_manual_synchronization(self, fp8_recipe, layer_type):
        """Offload and reload two layers by hand and watch device memory."""
        _Utils.memory_leak_check()
        offload_ctx, sync_function, manual_controller = get_cpu_offload_context(
            enabled=True,
            model_layers=6,
            offload_activations=True,
            manual_synchronization=True,
        )
        layer_1 = _Utils.create_layer(layer_type)
        layer_2 = _Utils.create_layer(layer_type)
        inp1 = _Utils.create_tensor(None)
        inp2 = _Utils.create_tensor(None)
        recipe_ctx = _Utils.create_recipe_ctx(fp8_recipe)
        extra_kwargs = TestTELayers._grouped_kwargs(layer_type)

        with offload_ctx, recipe_ctx():
            out_1 = layer_1(inp1, **extra_kwargs)
        out_1 = sync_function(out_1)
        with offload_ctx, recipe_ctx():
            out_2 = layer_2(inp2, **extra_kwargs)
        out_2 = sync_function(out_2)

        mark_not_offload(out_1, out_2)
        del inp1, inp2
        before_offload_mb = _Utils.get_npu_memory_mb()

        for layer_idx in (0, 1):
            manual_controller.start_offload_layer(layer_idx)
            manual_controller.release_activation_forward_gpu_memory(layer_idx)
        after_offload_mb = _Utils.get_npu_memory_mb()
        # Offloading must not increase device memory (beyond _EPSILON of slack).
        assert after_offload_mb - _EPSILON <= before_offload_mb

        manual_controller.start_reload_layer(0)
        manual_controller.start_reload_layer(1)
        after_reload_mb = _Utils.get_npu_memory_mb()
        # NOTE(review): pytest.approx(..., 1) is a relative tolerance of 100%
        # -- confirm whether an absolute tolerance was meant.
        assert after_reload_mb == pytest.approx(before_offload_mb, 1)

        out_1.sum().backward()
        out_2.sum().backward()

    @pytest.mark.parametrize("fp8_recipe", _quantization_recipes)
    @pytest.mark.parametrize("layer_type", _Utils.get_layer_names())
    def test_numerics(self, fp8_recipe, layer_type):
        """Outputs of a two-layer model must match with and without offloading."""
        recipe_ctx = _Utils.create_recipe_ctx(fp8_recipe)
        offload_ctx, sync_function = get_cpu_offload_context(
            enabled=True,
            num_layers=1,
            model_layers=2,
            offload_activations=True,
            offload_weights=False,
        )

        class TwoLayerModel(torch.nn.Module):
            """Two layers of ``layer_type``, each run inside the given context."""

            def __init__(self, offload_ctx=None, sync_function=None):
                super().__init__()
                built = [_Utils.create_layer(layer_type) for _ in range(2)]
                self.layers = torch.nn.ModuleList(built)
                self.offload_ctx = offload_ctx
                self.sync_function = sync_function

            def forward(self, x):
                extra_kwargs = TestTELayers._grouped_kwargs(layer_type)
                for layer in self.layers:
                    with self.offload_ctx, recipe_ctx():
                        x = layer(x, is_first_microbatch=False, **extra_kwargs)
                    if self.sync_function is not None:
                        x = self.sync_function(x)
                return x

        callable_offload = TwoLayerModel(offload_ctx=offload_ctx, sync_function=sync_function)
        callable_no_offload = TwoLayerModel(
            offload_ctx=contextlib.nullcontext(), sync_function=None
        )

        # Give both models identical weights.
        param_pairs = zip(callable_offload.parameters(), callable_no_offload.parameters())
        for dst, src in param_pairs:
            dst.data.copy_(src.data)

        x = _Utils.create_tensor(None)

        # Warm-up: four forward/backward passes through each model.
        for _warmup in range(4):
            out = callable_offload(x)
            out.sum().backward()
            out = callable_no_offload(x)
            out.sum().backward()
        callable_offload.zero_grad(set_to_none=True)

        out_offload = callable_offload(x)
        out_offload.sum().backward()
        # NOTE(review): this snapshots parameter values, not gradients. The
        # parameters were copied above and no optimizer step runs, so only the
        # output comparison below is non-trivial -- confirm whether ``.grad``
        # was intended.
        offload_outs = [out_offload, *(p.detach().clone() for p in callable_offload.parameters())]

        torch.npu.reset_peak_memory_stats()
        out_no_offload = callable_no_offload(x)
        out_no_offload.sum().backward()
        no_offload_outs = [
            out_no_offload,
            *(p.detach().clone() for p in callable_no_offload.parameters()),
        ]

        for i, got in enumerate(offload_outs):
            assert torch.allclose(got, no_offload_outs[i]), f"Error in tensor {i}."
        torch.npu.synchronize()

    def test_example_from_doc(self):
        """Manual-synchronization flow over ten Linear layers with an explicit offload stream."""
        offload_stream = torch.npu.Stream()
        num_layers = 10
        layers = [_Utils.create_layer("linear") for _ in range(num_layers)]
        inp = [_Utils.create_tensor(None) for _ in range(num_layers)]
        out = [None] * num_layers
        cpu_offload_context, sync_function, manual_controller = get_cpu_offload_context(
            enabled=True,
            model_layers=num_layers,
            manual_synchronization=True,
            offload_stream=offload_stream,
        )

        # Forward: run each layer, then kick off its offload right away.
        for idx in range(num_layers):
            with cpu_offload_context:
                out[idx] = layers[idx].forward(inp[idx])
            out[idx] = sync_function(out[idx])
            manual_controller.start_offload_layer(idx)

        # Wait for the offload stream before releasing the device copies.
        offload_stream.synchronize()
        for idx in range(num_layers):
            manual_controller.release_activation_forward_gpu_memory(idx)

        # Reload everything, last layer first.
        for idx in reversed(range(num_layers)):
            manual_controller.start_reload_layer(idx)

        # Wait for the reloads before running backward.
        offload_stream.synchronize()
        for idx in range(num_layers):
            out[idx].sum().backward()