from functools import wraps
import itertools
import os
import types
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
Tuple,
Union,
)
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_distributed import skipIfUnsupportMultiNPU
from torch_npu.npu.amp.sharded_grad_scaler import _ShardedGradScaler as NpuShardedGradScaler
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp.wrap import _Policy
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import (
BackwardPrefetch,
CPUOffload,
FullOptimStateDictConfig,
FullStateDictConfig,
LocalOptimStateDictConfig,
LocalStateDictConfig,
MixedPrecision,
OptimStateDictConfig,
ShardedOptimStateDictConfig,
ShardedStateDictConfig,
ShardingStrategy,
StateDictConfig,
StateDictSettings,
StateDictType,
)
class FSDPTestConfig:
    """Mutable bag of FSDP constructor options used by the config-sweep test.

    Every attribute starts at the default set below and is overwritten per
    test case via ``setattr``. ``__repr__`` lists all attributes in definition
    order so a failing combination can be identified from the log.
    """

    def __init__(self) -> None:
        self.sharding_strategy: Optional[ShardingStrategy] = None
        self.cpu_offload: Optional[CPUOffload] = None
        self.auto_wrap_policy: Optional[Union[Callable, _Policy]] = None
        self.backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE
        self.mixed_precision: Optional[MixedPrecision] = None
        # Assigned exactly once (it used to be assigned twice with the same value).
        self.ignored_modules: Optional[Iterable[torch.nn.Module]] = None
        self.param_init_fn: Optional[Callable[[nn.Module], None]] = None
        self.device_id: Optional[Union[int, torch.device]] = None
        self.sync_module_states: bool = False
        self.forward_prefetch: bool = False
        self.limit_all_gathers: bool = False
        self.use_orig_params: bool = False

    def __repr__(self) -> str:
        # One " name: value" line per attribute, wrapped in "FSDP Config:[ ... ]".
        body = "".join(f" {k}: {v}\n" for k, v in vars(self).items())
        return f"FSDP Config:[\n{body}]"
def get_wrap_policies() -> List[_Policy]:
def always_true_policy(module: torch.nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
return True
def always_false_policy(module: torch.nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
return False
return [None, always_true_policy, always_false_policy]
def get_mixed_precision_configs() -> List[Optional[MixedPrecision]]:
    """Return the ``mixed_precision`` values to sweep over.

    ``None`` passes no mixed-precision config to FSDP. The other two entries
    both use float16 parameters; one enables ``cast_forward_inputs`` and the
    other disables it.
    """
    return [
        None,
        MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
        MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=False),
    ]
def get_cpu_offload_configs() -> List[Optional[CPUOffload]]:
    """Return the ``cpu_offload`` values to sweep over.

    The list holds ``None`` (no CPUOffload config passed), parameter
    offloading enabled, and an explicit ``CPUOffload`` with offloading
    disabled.
    """
    return [None, CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
def named_product(**items: Iterable[Any]) -> Generator[Dict[str, Any], None, None]:
    """Yield the Cartesian product of keyword iterables as ``{name: value}`` dicts.

    Combinations are produced in ``itertools.product`` order: the last keyword
    varies fastest. With no keyword arguments, a single empty dict is yielded.
    """
    names = list(items)
    # Use a distinct name for each combination instead of rebinding the
    # iterable being consumed (the original reused ``values`` for both).
    for combo in itertools.product(*items.values()):
        yield dict(zip(names, combo))
def testcase_configs() -> Generator[FSDPTestConfig, None, None]:
    """Lazily yield one ``FSDPTestConfig`` for every combination of the sweep axes.

    The axis order below determines iteration order: ``limit_all_gathers``
    (the last key) varies fastest.
    """
    search_space = {
        "use_orig_params": [True, False],
        "auto_wrap_policy": get_wrap_policies(),
        "forward_prefetch": [True, False],
        "backward_prefetch": [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST],
        "mixed_precision": get_mixed_precision_configs(),
        "sync_module_states": [True, False],
        "cpu_offload": get_cpu_offload_configs(),
        "limit_all_gathers": [True, False],
    }
    for combination in named_product(**search_space):
        config = FSDPTestConfig()
        # Override only the swept fields; everything else keeps its default.
        for attr_name, attr_value in combination.items():
            setattr(config, attr_name, attr_value)
        yield config
class MyModule(torch.nn.Module):
    """Minimal model for the FSDP tests: two stacked 2->2 linear layers, no activation."""

    def __init__(self):
        super().__init__()
        # fc1 is created before fc2 so parameter registration/initialization order is fixed.
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 2)

    def forward(self, x):
        return self.fc2(self.fc1(x))
def run_testcase_with_config(device, config: FSDPTestConfig) -> None:
    """Train an FSDP-wrapped ``MyModule`` briefly under ``config``, then round-trip its state dict.

    Args:
        device: device string of the form ``"npu:<index>"``. The index also
            tags the per-rank checkpoint filename.
        config: FSDP constructor options to exercise.

    Any exception from FSDP construction, training, or save/load propagates
    to the caller.
    """
    model = FSDP(
        MyModule(),
        device_id=torch.device(device),
        sharding_strategy=config.sharding_strategy,
        cpu_offload=config.cpu_offload,
        auto_wrap_policy=config.auto_wrap_policy,
        backward_prefetch=config.backward_prefetch,
        mixed_precision=config.mixed_precision,
        # Forward the remaining config fields instead of hard-coding them; all
        # default to None in FSDPTestConfig, which matches FSDP's own defaults.
        ignored_modules=config.ignored_modules,
        param_init_fn=config.param_init_fn,
        sync_module_states=config.sync_module_states,
        forward_prefetch=config.forward_prefetch,
        limit_all_gathers=config.limit_all_gathers,
        use_orig_params=config.use_orig_params,
    )
    scaler = NpuShardedGradScaler(growth_interval=1)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    for _ in range(5):
        # Reset gradients each step; without this they accumulate across iterations.
        optimizer.zero_grad()
        loss = model(torch.ones(2, 2).to(device)).sum()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    rank = device.split(":")[1]
    ckpt_path = f"test_fsdp_with_many_config{rank}.pt"
    try:
        torch.save(model.state_dict(), ckpt_path)
        model.load_state_dict(torch.load(ckpt_path))
    finally:
        # Do not leave per-rank checkpoint files behind in the working directory.
        if os.path.exists(ckpt_path):
            os.remove(ckpt_path)
class TestFSDP(TestCase):
    """Multi-process FSDP test on NPUs.

    In the parent process (``rank == MAIN_PROCESS_RANK``), ``setUp`` spawns one
    child per rank, and the test method is replaced by a wrapper that joins
    the children and checks their exit codes. Each child rebuilds this test
    case in ``_run``, sets its own ``rank`` and runs the real test body.
    """

    MAIN_PROCESS_RANK = -1

    def join_or_run(self, fn):
        """Wrap test method ``fn``: the parent waits for the children, each child runs ``fn``."""

        @wraps(fn)
        def wrapper(self):
            if self.rank == self.MAIN_PROCESS_RANK:
                for p in self.processes:
                    p.join()
                self._check_child_exit_codes()
            else:
                fn()

        return types.MethodType(wrapper, self)

    def _check_child_exit_codes(self) -> None:
        """Report abnormal child exits as a skip (not enough NPUs) or a failure.

        A failure inside a child only prints a traceback in that child, so
        without this check the parent test would pass no matter what happened.
        """
        failed = [(rank, p.exitcode) for rank, p in enumerate(self.processes) if p.exitcode != 0]
        if not failed:
            return
        # NOTE(review): the multi-NPU skip decorator runs inside the children,
        # so a skipped run may surface here as a non-zero exit code; re-check
        # the device requirement in the parent before reporting a failure.
        if torch.npu.device_count() < self.world_size:
            self.skipTest(f"Need {self.world_size} NPUs; children exited with (rank, exitcode) {failed}")
        self.fail(f"Child process(es) exited abnormally, (rank, exitcode): {failed}")

    def __init__(self, method_name: str = "runTest") -> None:
        super().__init__(method_name)
        fn = getattr(self, method_name)
        setattr(self, method_name, self.join_or_run(fn))

    def setUp(self):
        # NOTE(review): this skips torch_npu's TestCase.setUp and calls the next
        # base class's setUp directly, while tearDown uses plain super();
        # confirm the bypass is intentional.
        super(TestCase, self).setUp()
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29588'
        os.environ["BACKEND"] = dist.Backend.HCCL
        self.processes = []
        self.rank = self.MAIN_PROCESS_RANK
        proc = torch.multiprocessing.get_context("spawn").Process
        for rank in range(int(self.world_size)):
            process = proc(
                target=self.__class__._run,
                name="process " + str(rank),
                args=(rank, self._current_test_name()),
            )
            process.start()
            self.processes.append(process)

    def tearDown(self):
        super().tearDown()
        for p in self.processes:
            p.terminate()
        self.processes = []

    def _current_test_name(self) -> str:
        return self.id().split(".")[-1]

    @property
    def world_size(self) -> int:
        return 2

    @classmethod
    def _run(cls, rank: int, test_name: str) -> None:
        """Child-process entry point: rebuild the test case as ``rank`` and run ``test_name``."""
        self = cls(test_name)
        self.rank = rank
        getattr(self, test_name)()

    def dist_init(self):
        torch.npu.set_device(self.rank)
        dist.init_process_group(backend="hccl", rank=self.rank, world_size=self.world_size)

    @skipIfUnsupportMultiNPU(2)
    def test_fsdp_with_many_config(self):
        '''
        Run FSDP under every config combination and expect none of them to raise.
        '''
        self.dist_init()
        device = f"npu:{self.rank}"
        for config in testcase_configs():
            if self.rank == 0:
                print(f"Running FSDP testcase with config: {config}")
            try:
                run_testcase_with_config(device, config)
            except Exception as e:
                self.fail(f"Error when running with config: {config}, and error is {e}")
# Entry point: hand control to torch_npu's test runner when executed directly.
if __name__ == "__main__":
    run_tests()