import os
import types
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.distributed as dist
import torchtitan.components.optimizer as tt_optimizer
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
from torchtitan_npu.patches.optimizer.virtual_optimizer import (
swap_tensor_copy_wrapper,
unwrap_dtensor,
virtual_optimizer_step_impl,
VirtualAllocator,
wrap_like_param,
)
@pytest.fixture(scope="module", autouse=True)
def setup_dist_env():
    """Provide a single-process gloo process group for this test module.

    The rendezvous variables are exported before initialisation and put back
    to their previous state on teardown, so they do not leak into the rest of
    the session. The process group is destroyed only when this fixture was
    the one that created it; a group that already existed is left alone.
    """
    rendezvous = {
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": "12355",
        "RANK": "0",
        "WORLD_SIZE": "1",
    }
    # Remember the prior values (None means "was unset") so teardown can
    # restore the environment exactly.
    previous = {key: os.environ.get(key) for key in rendezvous}
    os.environ.update(rendezvous)

    created_group = False
    try:
        if not dist.is_initialized():
            dist.init_process_group(backend="gloo", rank=0, world_size=1)
            created_group = True
        yield
    finally:
        # Only tear down what this fixture set up.
        if created_group and dist.is_initialized():
            dist.destroy_process_group()
        for key, value in previous.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value
def test_unwrap_dtensor_logic():
    """unwrap_dtensor passes plain tensors through and unwraps DTensors."""
    local = torch.randn(4, 4)
    # A regular tensor must come back as the very same object.
    assert unwrap_dtensor(local) is local

    cpu_mesh = DeviceMesh("cpu", [0])
    replicated = DTensor.from_local(local, cpu_mesh, [Replicate()])
    unwrapped = unwrap_dtensor(replicated)

    # The result is no longer a DTensor and carries the local values.
    assert not isinstance(unwrapped, DTensor)
    assert torch.equal(unwrapped, local)
def test_wrap_like_param():
    """wrap_like_param returns a DTensor mirroring the reference parameter.

    The wrapped tensor is expected to share the reference parameter's device
    mesh, placements and size.
    """
    cpu_mesh = DeviceMesh("cpu", [0])
    local_tensor = torch.randn(4, 4)
    reference = DTensor.from_local(torch.randn(4, 4), cpu_mesh, [Shard(0)])

    wrapped = wrap_like_param(local_tensor, reference)

    assert isinstance(wrapped, DTensor)
    for attribute in ("device_mesh", "placements"):
        assert getattr(wrapped, attribute) == getattr(reference, attribute)
    assert wrapped.size() == reference.size()
def test_virtual_allocator_memory_logic():
    """Check VirtualAllocator swap-size reporting and buffer creation.

    A scalar ``virtual_optimizer_size`` with two pipeline stages is reported
    as one entry per stage, a list is reported as given, and ``create``
    returns a zero-filled buffer with the parameter's shape.
    """
    scalar_alloc = VirtualAllocator(
        pp_rank=0, pp_stages=2, virtual_optimizer_size=10.0
    )
    assert scalar_alloc.get_swap_memory_sizes() == [10.0, 10.0]

    per_stage_sizes = [5.0, 15.0]
    list_alloc = VirtualAllocator(
        pp_rank=1, pp_stages=2, virtual_optimizer_size=per_stage_sizes
    )
    assert list_alloc.get_swap_memory_sizes() == [5.0, 15.0]

    param = torch.randn(2, 2)
    # torch_npu is replaced by an attribute-less mock while the buffer is made.
    npu_target = "torchtitan_npu.patches.optimizer.virtual_optimizer.torch_npu"
    with patch(npu_target, spec=[]):
        buffer = scalar_alloc.create(param)
        assert buffer.shape == param.shape
        assert (buffer == 0).all()
def test_build_optimizers_virtual_validation():
    """build_optimizers raises ValueError for invalid virtual-optimizer configs.

    Two configurations are exercised: virtual and swap optimizers enabled
    together, and a virtual optimizer without a size.
    """
    invalid_cases = (
        (
            {
                "virtual_optimizer": True,
                "swap_optimizer": True,
                "virtual_optimizer_size": 10.0,
            },
            "Virtual optimizer does not support swap_optimizer",
        ),
        (
            {
                "virtual_optimizer": True,
                "swap_optimizer": False,
                "virtual_optimizer_size": None,
            },
            "virtual_optimizer_size must be specified",
        ),
    )
    for options, expected_message in invalid_cases:
        config = types.SimpleNamespace(**options)
        with pytest.raises(ValueError, match=expected_message):
            tt_optimizer.build_optimizers([], config, None, None)
@patch("torchtitan_npu.patches.optimizer.virtual_optimizer.torch_npu")
@patch("torchtitan_npu.patches.optimizer.virtual_optimizer._original_build_optimizers")
def test_build_optimizers_success_path(mock_orig, mock_npu):
    """A valid virtual-optimizer config yields one configured optimizer.

    The original builder is mocked to return a single AdamW-spec'd mock; the
    returned optimizer must carry ``_allocator_config == (0, 1, 20.0)`` and
    ``torch.optim.AdamW.step`` must be named ``virtual_optimizer_step``.

    NOTE(review): the last assertion inspects the real ``torch.optim.AdamW``
    class, so the step replacement appears to be global and is not undone by
    this test -- confirm that this cannot affect other tests.
    """
    mock_npu.npu.current_device.return_value = 0

    fake_adamw = MagicMock(spec=torch.optim.AdamW)
    fake_adamw.param_groups = [{"params": []}]
    mock_orig.return_value = [fake_adamw]

    options = {
        "virtual_optimizer": True,
        "swap_optimizer": False,
        "virtual_optimizer_size": 20.0,
        "name": "AdamW",
    }
    config = types.SimpleNamespace(**options)

    built = tt_optimizer.build_optimizers(["part1"], config, None, None)

    assert len(built) == 1
    assert built[0]._allocator_config == (0, 1, 20.0)
    assert torch.optim.AdamW.step.__name__ == "virtual_optimizer_step"
@patch("torchtitan_npu.patches.optimizer.virtual_optimizer.torch_npu")
def test_virtual_optimizer_step_impl_logic(mock_npu, monkeypatch):
    """One virtual step populates DTensor state and calls the fused kernel.

    ``torch_npu`` is mocked, ``Tensor.npu`` is turned into an identity and
    ``torch._fused_adamw_`` is patched, so the step runs on CPU only.
    """
    mock_npu.npu.current_device.return_value = torch.device("cpu")
    # Make Tensor.npu() hand back the tensor itself; no NPU device is needed.
    monkeypatch.setattr(torch.Tensor, "npu", lambda self, *args, **kwargs: self)

    mesh = DeviceMesh("cpu", [0])
    weight = torch.randn(2, 2, requires_grad=True)
    weight.grad = torch.randn(2, 2)
    param = DTensor.from_local(weight.detach(), mesh, [Replicate()])
    param.grad = DTensor.from_local(weight.grad, mesh, [Replicate()])

    hyperparams = {
        "lr": 0.01,
        "betas": (0.9, 0.999),
        "eps": 1e-8,
        "weight_decay": 0.01,
        "amsgrad": False,
        "maximize": False,
        "decoupled_weight_decay": True,
    }

    class FakeOptimizer:
        """Minimal stand-in exposing the attributes set up for the step."""

        def __init__(self):
            self.param_groups = [{"params": [param], **hyperparams}]
            self.state = {param: {}}
            allocator = MagicMock()
            allocator.init_exp.return_value = (
                torch.zeros(2, 2),
                torch.zeros(2, 2),
            )
            self.virtual_allocator = allocator
            self.print_swap_flag = False

    optimizer = FakeOptimizer()
    with patch("torch._fused_adamw_") as fused_kernel:
        virtual_optimizer_step_impl(optimizer)
        param_state = optimizer.state[param]
        assert "exp_avg" in param_state
        assert isinstance(param_state["exp_avg"], DTensor)
        assert fused_kernel.called
@patch("torchtitan_npu.patches.optimizer.virtual_optimizer.torch_npu")
def test_virtual_optimizer_smoke_test(mock_npu, monkeypatch):
    """Smoke test: two consecutive virtual steps over two parameters.

    After the first step both parameters must hold a DTensor ``exp_avg``;
    after fresh gradients and a second step the patched fused kernel must
    have been called again.

    NOTE(review): the allocator mock returns the same pair of 2x2 zero
    tensors for every parameter, including the 4x4 one -- confirm this is
    intentional.
    """
    mock_npu.npu.current_device.return_value = torch.device("cpu")
    # Make Tensor.npu() hand back the tensor itself; no NPU device is needed.
    monkeypatch.setattr(torch.Tensor, "npu", lambda self, *args, **kwargs: self)

    mesh = DeviceMesh("cpu", [0])
    replicate = [Replicate()]

    def make_param(*shape):
        """Build a replicated DTensor parameter with a DTensor gradient."""
        local = torch.randn(*shape, requires_grad=True)
        local.grad = torch.randn(*shape)
        wrapped = DTensor.from_local(local.detach(), mesh, replicate)
        wrapped.grad = DTensor.from_local(local.grad, mesh, replicate)
        return wrapped

    p1_dtensor = make_param(4, 4)
    p2_dtensor = make_param(2, 2)
    params = [p1_dtensor, p2_dtensor]

    hyperparams = {
        "lr": 0.001,
        "betas": (0.9, 0.999),
        "eps": 1e-8,
        "weight_decay": 0.01,
        "amsgrad": False,
        "maximize": False,
        "decoupled_weight_decay": True,
    }

    class FakeOptimizer:
        """Minimal stand-in exposing the attributes set up for the step."""

        def __init__(self):
            self.param_groups = [{"params": list(params), **hyperparams}]
            self.state = {param: {} for param in params}
            allocator = MagicMock()
            allocator.init_exp.return_value = (
                torch.zeros(2, 2),
                torch.zeros(2, 2),
            )
            self.virtual_allocator = allocator
            self.print_swap_flag = False

    optimizer = FakeOptimizer()
    with patch("torch._fused_adamw_") as fused_kernel:
        # First step: state gets created for both parameters.
        virtual_optimizer_step_impl(optimizer)
        for param in params:
            assert "exp_avg" in optimizer.state[param]
            assert isinstance(optimizer.state[param]["exp_avg"], DTensor)
        calls_after_first_step = fused_kernel.call_count

        # Second step with fresh gradients.
        for param, shape in ((p1_dtensor, (4, 4)), (p2_dtensor, (2, 2))):
            param.grad = DTensor.from_local(torch.randn(*shape), mesh, replicate)
        virtual_optimizer_step_impl(optimizer)

        assert "exp_avg_sq" in optimizer.state[p1_dtensor]
        assert fused_kernel.call_count > calls_after_first_step
@pytest.fixture(autouse=True)
def patch_copy():
    """Install the swap-aware ``Tensor.copy_`` wrapper around each test.

    The unwrapped method is captured first and put back once the test is done.
    """
    native_copy = torch.Tensor.copy_
    setattr(torch.Tensor, "copy_", swap_tensor_copy_wrapper(native_copy))
    yield
    setattr(torch.Tensor, "copy_", native_copy)
def test_shape_mismatch():
    """Copying a 4-element tensor into a 3-element tensor raises RuntimeError."""
    target = torch.zeros(3)
    source = torch.zeros(4)
    with pytest.raises(RuntimeError):
        target.copy_(source)
def test_self_copy():
    """Copying a tensor onto itself returns that same tensor object."""
    tensor = torch.tensor([1.0, 2.0])
    returned = tensor.copy_(tensor)
    assert returned is tensor
def test_normal_cpu_to_cpu():
    """A CPU-to-CPU copy without swap flags transfers the values."""
    source = torch.tensor([1.0, 2.0, 3.0])
    target = torch.zeros_like(source)

    target.copy_(source)

    assert torch.allclose(target, source)
def test_swap_cpu_to_cpu():
    """A CPU-to-CPU copy with both tensors flagged as swap transfers the values."""
    source = torch.tensor([1.0, 2.0, 3.0])
    target = torch.zeros_like(source)
    # Flag both ends as swap tensors before copying.
    for tensor in (source, target):
        tensor.swap_tensor = True

    target.copy_(source)

    assert torch.allclose(target, source)
def test_swap_npu_to_cpu():
    """Copy from a swap-flagged NPU tensor into a plain CPU tensor.

    Reported as skipped, rather than silently passing, when no NPU is
    available.
    """
    if not torch.npu.is_available():
        pytest.skip("NPU device not available")
    src = torch.tensor([1.0, 2.0], device="npu")
    src.swap_tensor = True
    dst = torch.zeros_like(src).cpu()
    dst.copy_(src)
    assert torch.allclose(dst, src.cpu())
def test_cpu_to_swap_npu():
    """Copy from a plain CPU tensor into a swap-flagged NPU tensor.

    Reported as skipped, rather than silently passing, when no NPU is
    available.
    """
    if not torch.npu.is_available():
        pytest.skip("NPU device not available")
    src = torch.tensor([1.0, 2.0], device="cpu")
    dst = torch.zeros_like(src).npu()
    dst.swap_tensor = True
    dst.copy_(src)
    assert torch.allclose(dst, src.npu())
def test_swap_npu_to_swap_npu_same_device():
    """Copy between two swap-flagged tensors created on the default NPU.

    Reported as skipped, rather than silently passing, when no NPU is
    available.
    """
    if not torch.npu.is_available():
        pytest.skip("NPU device not available")
    src = torch.tensor([1.0, 2.0], device="npu")
    src.swap_tensor = True
    dst = torch.zeros_like(src).npu()
    dst.swap_tensor = True
    dst.copy_(src)
    assert torch.allclose(dst, src)
def test_swap_npu_to_swap_npu_cross_device():
    """Copy between swap-flagged tensors living on ``npu:0`` and ``npu:1``.

    Reported as skipped, rather than silently passing, when fewer than two
    NPU devices are available.
    """
    if not torch.npu.is_available() or torch.npu.device_count() < 2:
        pytest.skip("requires at least two NPU devices")
    src = torch.tensor([1.0, 2.0], device="npu:0")
    src.swap_tensor = True
    dst = torch.zeros(2, device="npu:1")
    dst.swap_tensor = True
    dst.copy_(src)
    assert torch.allclose(dst, src.to("npu:1"))
def test_cpu_to_npu_swap_non_blocking():
    """Non-blocking copy from a CPU tensor into a swap-flagged NPU tensor.

    Reported as skipped, rather than silently passing, when no NPU is
    available.
    """
    if not torch.npu.is_available():
        pytest.skip("NPU device not available")
    src = torch.tensor([1.0, 2.0], device="cpu")
    dst = torch.zeros_like(src).npu()
    dst.swap_tensor = True
    dst.copy_(src, non_blocking=True)
    assert torch.allclose(dst, src.npu())
def test_swap_npu_to_npu_non_blocking():
    """Non-blocking copy between two swap-flagged NPU tensors.

    Reported as skipped, rather than silently passing, when no NPU is
    available.
    """
    if not torch.npu.is_available():
        pytest.skip("NPU device not available")
    src = torch.tensor([1.0, 2.0], device="npu")
    src.swap_tensor = True
    dst = torch.zeros_like(src).npu()
    dst.swap_tensor = True
    dst.copy_(src, non_blocking=True)
    assert torch.allclose(dst, src)