import unittest
import torch
from tensor_cast.device import (
A5,
ATLAS_800,
CommGrid,
DeviceProfile,
InterconnectTopology,
InterconnectType,
StaticCost,
)
from tensor_cast.utils import DTYPE_FP4, DTYPE_FP8
# Throughput tables shared by the ATLAS 800 expectations below, keyed by torch
# dtype. Each tier is labelled with the torch.half figure of its "mma_ops"
# table (e.g. "376T" -> torch.half: 376 * 1e12).
_ATLAS_800_COMPUTE_TIERS = {
    "376T": {
        "mma_ops": {
            torch.float32: 99.5 * 1e12,
            torch.bfloat16: 353.9 * 1e12,
            torch.half: 376 * 1e12,
            torch.int8: 752 * 1e12,
        },
        "gp_ops": {
            torch.float32: 22 / 2 * 1e12,
            torch.bfloat16: 22 * 1e12,
            torch.half: 22 * 1e12,
        },
    },
    "313T": {
        "mma_ops": {
            torch.float32: 83 * 1e12,
            torch.bfloat16: 294.9 * 1e12,
            torch.half: 313 * 1e12,
            torch.int8: 626 * 1e12,
        },
        "gp_ops": {
            torch.float32: 18 / 2 * 1e12,
            torch.bfloat16: 18 * 1e12,
            torch.half: 18 * 1e12,
        },
    },
    "280T": {
        "mma_ops": {
            torch.float32: 75 * 1e12,
            torch.bfloat16: 245.8 * 1e12,
            torch.half: 280 * 1e12,
            torch.int8: 560 * 1e12,
        },
        "gp_ops": {
            torch.float32: 16 / 2 * 1e12,
            torch.bfloat16: 16 * 1e12,
            torch.half: 16 * 1e12,
        },
    },
}


def _atlas_800_spec(name, comm_grid, tier, memory_gib=64, bandwidth_tib=1.6):
    """Build one expected-value record for an ATLAS 800 device profile.

    Args:
        name: Key of the profile in ``DeviceProfile.all_device_profiles``.
        comm_grid: The ``ATLAS_800`` interconnect object the profile must use.
        tier: Key into ``_ATLAS_800_COMPUTE_TIERS`` selecting the ops tables.
        memory_gib: Memory size, multiplied by ``1024**3`` to get bytes.
        bandwidth_tib: Memory bandwidth factor, multiplied by ``1024**4``.

    Returns:
        A dict with the keys the spec-driven test cases read.
    """
    tables = _ATLAS_800_COMPUTE_TIERS[tier]
    return {
        "name": name,
        "comm_grid": comm_grid,
        # Copy the tables so every record owns its own dicts, exactly as
        # separate literals would.
        "mma_ops": dict(tables["mma_ops"]),
        "gp_ops": dict(tables["gp_ops"]),
        "memory_size_bytes": memory_gib * (1024**3),
        "memory_bandwidth_bytes_ps": bandwidth_tib * (1024**4),
        "compute_efficiency": 0.7,
        "memory_efficiency": 0.6,
    }


# Expected attributes of the ATLAS 800 profiles, one record per profile name;
# DeviceProfileTestCase iterates this list.
_DEVICE_PROFILE_SPECS = [
    _atlas_800_spec("ATLAS_800_A2_376T_64G", ATLAS_800.A2_INTERCONNECT, "376T"),
    _atlas_800_spec("ATLAS_800_A2_313T_64G", ATLAS_800.A2_INTERCONNECT, "313T"),
    _atlas_800_spec("ATLAS_800_A2_280T_64G", ATLAS_800.A2_INTERCONNECT, "280T"),
    _atlas_800_spec(
        "ATLAS_800_A2_280T_64G_PCIE",
        ATLAS_800.A2_INTERCONNECT_PCIE,
        "280T",
    ),
    _atlas_800_spec(
        "ATLAS_800_A2_280T_32G_PCIE",
        ATLAS_800.A2_INTERCONNECT_PCIE,
        "280T",
        memory_gib=32,
        bandwidth_tib=0.8,
    ),
    _atlas_800_spec("ATLAS_800_A3_752T_128G_DIE", ATLAS_800.A3_INTERCONNECT, "376T"),
    _atlas_800_spec("ATLAS_800_A3_560T_128G_DIE", ATLAS_800.A3_INTERCONNECT, "280T"),
    _atlas_800_spec(
        "ATLAS_800_A3_560T_128G_DIE_ROCE",
        ATLAS_800.A3_INTERCONNECT_ROCE,
        "280T",
    ),
]
# Expected attributes of the A5 profiles iterated by A5DeviceProfileTestCase.
# The two records differ only in name and memory size, so they are generated
# from (name, GiB) pairs; each iteration builds fresh dicts.
_A5_DEVICE_PROFILE_SPECS = [
    {
        "name": profile_name,
        "comm_grid": A5.Interconnect.PCIE2_UB4,
        "mma_ops": {
            torch.float32: 189 * 1e12,
            torch.bfloat16: 378 * 1e12,
            torch.half: 378 * 1e12,
            DTYPE_FP8: 756 * 1e12,
            torch.int8: 756 * 1e12,
            DTYPE_FP4: 1512 * 1e12,
        },
        "gp_ops": {
            torch.float32: 24 * 1e12,
            torch.bfloat16: 47 * 1e12,
            torch.half: 47 * 1e12,
        },
        "memory_size_bytes": memory_gib * (1024**3),
        "memory_bandwidth_bytes_ps": 1.4 * (1024**4),
        "compute_efficiency": 0.9,
        "memory_efficiency": 0.8,
    }
    for profile_name, memory_gib in (
        ("ATLAS_350_425T_112G", 112),
        ("ATLAS_350_425T_84G", 84),
    )
]
class A3InterconnectRoceTestCase(unittest.TestCase):
    """Checks ATLAS_800.A3_INTERCONNECT_ROCE against literals and A3_INTERCONNECT."""

    def setUp(self):
        """Bind the RoCE grid under test and the grid it is compared against."""
        self.roce = ATLAS_800.A3_INTERCONNECT_ROCE
        self.orig = ATLAS_800.A3_INTERCONNECT

    def _assert_tier_matches_original(self, tier):
        """Assert bandwidth, then latency, of ``tier`` equal A3_INTERCONNECT's."""
        self.assertEqual(
            self.roce.topologies[tier].bandwidth_bytes_ps,
            self.orig.topologies[tier].bandwidth_bytes_ps,
        )
        self.assertEqual(
            self.roce.topologies[tier].latency_s,
            self.orig.topologies[tier].latency_s,
        )

    def test_grid_shape_dual_node_only(self):
        """The grid shape is exactly (2, 8, 2)."""
        expected_shape = (2, 8, 2)
        self.assertEqual(self.roce.grid.shape, expected_shape)

    def test_grid_ndim(self):
        """The grid has three dimensions."""
        self.assertEqual(self.roce.grid.ndim, 3)

    def test_topologies_count_matches_ndim(self):
        """There is one topology entry per grid dimension."""
        topology_count = len(self.roce.topologies)
        self.assertEqual(topology_count, self.roce.grid.ndim)

    def test_tier0_is_roce_bandwidth(self):
        """Tier 0 bandwidth equals 196 * 1e9 / 8."""
        tier0 = self.roce.topologies[0]
        self.assertEqual(tier0.bandwidth_bytes_ps, 196 * 1e9 / 8)

    def test_tier0_latency(self):
        """Tier 0 latency equals 5.5 * 1e-6."""
        tier0 = self.roce.topologies[0]
        self.assertEqual(tier0.latency_s, 5.5 * 1e-6)

    def test_tier0_comm_efficiency(self):
        """Tier 0 communication efficiency equals 0.7."""
        tier0 = self.roce.topologies[0]
        self.assertEqual(tier0.comm_efficiency, 0.7)

    def test_tier1_same_as_original(self):
        """Tier 1 bandwidth and latency match A3_INTERCONNECT."""
        self._assert_tier_matches_original(1)

    def test_tier2_same_as_original(self):
        """Tier 2 bandwidth and latency match A3_INTERCONNECT."""
        self._assert_tier_matches_original(2)

    def test_total_devices_number(self):
        """The grid holds 32 elements in total."""
        self.assertEqual(self.roce.grid.numel(), 32)
class DeviceProfileTestCase(unittest.TestCase):
    """Spec-driven checks: one subTest per entry of _DEVICE_PROFILE_SPECS."""

    @staticmethod
    def _registered(spec):
        """Return the registered profile stored under the spec's name."""
        return DeviceProfile.all_device_profiles[spec["name"]]

    def _check_field(self, field):
        """Assert ``profile.<field> == spec[field]`` for every spec."""
        for spec in _DEVICE_PROFILE_SPECS:
            with self.subTest(device=spec["name"]):
                registered = self._registered(spec)
                self.assertEqual(getattr(registered, field), spec[field])

    def _check_ops_table(self, table):
        """Assert every dtype entry of ``spec[table]`` matches the profile's table."""
        for spec in _DEVICE_PROFILE_SPECS:
            with self.subTest(device=spec["name"]):
                registered = self._registered(spec)
                for dtype, expected in spec[table].items():
                    # Nested subTest so each dtype is reported separately.
                    with self.subTest(dtype=str(dtype)):
                        self.assertEqual(getattr(registered, table)[dtype], expected)

    def test_registered_in_all_device_profiles(self):
        """Every spec name is a key of DeviceProfile.all_device_profiles."""
        registry = DeviceProfile.all_device_profiles
        for spec in _DEVICE_PROFILE_SPECS:
            with self.subTest(device=spec["name"]):
                self.assertIn(spec["name"], registry)

    def test_name(self):
        """The profile's name attribute equals its registry key."""
        self._check_field("name")

    def test_vendor(self):
        """Every profile reports the vendor "HUAWEI"."""
        for spec in _DEVICE_PROFILE_SPECS:
            with self.subTest(device=spec["name"]):
                registered = self._registered(spec)
                self.assertEqual(registered.vendor, "HUAWEI")

    def test_comm_grid(self):
        """The profile's comm_grid is the very object named in the spec."""
        for spec in _DEVICE_PROFILE_SPECS:
            with self.subTest(device=spec["name"]):
                registered = self._registered(spec)
                self.assertIs(registered.comm_grid, spec["comm_grid"])

    def test_mma_ops(self):
        """The mma_ops entries match the spec for each listed dtype."""
        self._check_ops_table("mma_ops")

    def test_gp_ops(self):
        """The gp_ops entries match the spec for each listed dtype."""
        self._check_ops_table("gp_ops")

    def test_memory_size_bytes(self):
        """memory_size_bytes matches the spec."""
        self._check_field("memory_size_bytes")

    def test_memory_bandwidth(self):
        """memory_bandwidth_bytes_ps matches the spec."""
        self._check_field("memory_bandwidth_bytes_ps")

    def test_compute_efficiency(self):
        """compute_efficiency matches the spec."""
        self._check_field("compute_efficiency")

    def test_memory_efficiency(self):
        """memory_efficiency matches the spec."""
        self._check_field("memory_efficiency")

    def test_static_cost(self):
        """Every profile shares the ATLAS_800.STATIC_COST object."""
        for spec in _DEVICE_PROFILE_SPECS:
            with self.subTest(device=spec["name"]):
                registered = self._registered(spec)
                self.assertIs(registered.static_cost, ATLAS_800.STATIC_COST)
class DeviceProfileRegistrationTestCase(unittest.TestCase):
    """Registration behaviour of DeviceProfile for an already-used name."""

    def test_duplicate_name_raises(self):
        """Re-using a registered name raises ValueError and leaves the key set as is."""
        duplicate_name = "ATLAS_800_A3_560T_128G_DIE_ROCE"
        names_before = set(DeviceProfile.all_device_profiles.keys())

        with self.assertRaises(ValueError):
            DeviceProfile(
                name=duplicate_name,
                vendor="TEST",
                comm_grid=ATLAS_800.A3_INTERCONNECT_ROCE,
            )

        # Checked after the with-block: a statement placed behind the raising
        # call inside it would never execute.
        names_after = set(DeviceProfile.all_device_profiles.keys())
        self.assertEqual(names_after, names_before)
class CommGridValidationTestCase(unittest.TestCase):
    """CommGrid construction cases that are expected to raise ValueError."""

    def test_zero_dim_grid_raises(self):
        """A grid built from torch.tensor(0) with no topologies is rejected."""
        with self.assertRaises(ValueError):
            CommGrid(grid=torch.tensor(0), topologies={})

    def test_ndim_topologies_mismatch_raises(self):
        """A 2x2 grid with a single topology entry is rejected."""
        with self.assertRaises(ValueError):
            # Inputs are built inside the context, so a ValueError from any of
            # these calls satisfies the assertion.
            square_grid = torch.arange(4).reshape(2, 2)
            only_tier0 = {
                0: InterconnectTopology(bandwidth_bytes_ps=1e9, latency_s=1e-6),
            }
            CommGrid(grid=square_grid, topologies=only_tier0)

    def test_dimension_less_than_two_raises(self):
        """A one-element 1-D grid with one topology entry is rejected."""
        with self.assertRaises(ValueError):
            single_device_grid = torch.arange(1).reshape(1)
            only_tier0 = {
                0: InterconnectTopology(bandwidth_bytes_ps=1e9, latency_s=1e-6),
            }
            CommGrid(grid=single_device_grid, topologies=only_tier0)
class InterconnectTopologyTestCase(unittest.TestCase):
    """Defaults and explicit ``type`` handling of InterconnectTopology."""

    def test_default_comm_efficiency(self):
        """comm_efficiency is 1.0 when it is not passed."""
        topology = InterconnectTopology(bandwidth_bytes_ps=1e9, latency_s=1e-6)
        self.assertEqual(topology.comm_efficiency, 1.0)

    def test_default_type_is_clos(self):
        """type is InterconnectType.CLOS when it is not passed."""
        topology = InterconnectTopology(bandwidth_bytes_ps=1e9, latency_s=1e-6)
        self.assertEqual(topology.type, InterconnectType.CLOS)

    def test_full_mesh_type(self):
        """An explicitly passed FULL_MESH type is kept on the instance."""
        kwargs = {
            "bandwidth_bytes_ps": 1e9,
            "latency_s": 1e-6,
            "type": InterconnectType.FULL_MESH,
        }
        topology = InterconnectTopology(**kwargs)
        self.assertEqual(topology.type, InterconnectType.FULL_MESH)
class StaticCostTestCase(unittest.TestCase):
    """Values of a default StaticCost and of ATLAS_800.STATIC_COST."""

    def test_default_values(self):
        """A StaticCost built without arguments has all three costs equal to 0."""
        default_cost = StaticCost()
        self.assertEqual(default_cost.mma_op_cost_s, 0)
        self.assertEqual(default_cost.gp_op_cost_s, 0)
        self.assertEqual(default_cost.comm_op_cost_s, 0)

    def test_a3_static_cost_values(self):
        """ATLAS_800.STATIC_COST carries the expected per-op costs."""
        atlas_cost = ATLAS_800.STATIC_COST
        # Expected values are written as products on purpose: the assertions
        # compare floats exactly.
        self.assertEqual(atlas_cost.mma_op_cost_s, 5 * 1e-6)
        self.assertEqual(atlas_cost.gp_op_cost_s, 2 * 1e-6)
        self.assertEqual(atlas_cost.comm_op_cost_s, 10 * 1e-6)
class A5ChipTestCase(unittest.TestCase):
    """Literal checks on the A5.Chip.C425T mapping."""

    def test_c425t_mma_ops(self):
        """The "mma_ops" table holds the expected value for each dtype key."""
        mma_table = A5.Chip.C425T["mma_ops"]
        self.assertEqual(mma_table[torch.float32], 189 * 1e12)
        self.assertEqual(mma_table[torch.bfloat16], 378 * 1e12)
        self.assertEqual(mma_table[torch.half], 378 * 1e12)
        # NOTE(review): this key is the raw torch dtype, whereas
        # _A5_DEVICE_PROFILE_SPECS keys the same value by DTYPE_FP8 --
        # confirm the two refer to the same dtype.
        self.assertEqual(mma_table[torch.float8_e5m2], 756 * 1e12)
        self.assertEqual(mma_table[torch.int8], 756 * 1e12)
        self.assertEqual(mma_table[DTYPE_FP4], 1512 * 1e12)

    def test_c425t_gp_ops(self):
        """The "gp_ops" table holds the expected value for each dtype key."""
        gp_table = A5.Chip.C425T["gp_ops"]
        self.assertEqual(gp_table[torch.float32], 24 * 1e12)
        self.assertEqual(gp_table[torch.bfloat16], 47 * 1e12)
        self.assertEqual(gp_table[torch.half], 47 * 1e12)

    def test_c425t_compute_efficiency(self):
        """The "compute_efficiency" entry is 0.9."""
        efficiency = A5.Chip.C425T["compute_efficiency"]
        self.assertEqual(efficiency, 0.9)
class A5MemTestCase(unittest.TestCase):
    """Literal checks on the A5.Mem mappings."""

    def _assert_mem(self, mem, size_gib):
        """Assert size (``size_gib * 1024**3``), bandwidth and efficiency of ``mem``."""
        self.assertEqual(mem["memory_size_bytes"], size_gib * (1024**3))
        self.assertEqual(mem["memory_bandwidth_bytes_ps"], 1.4 * (1024**4))
        self.assertEqual(mem["memory_efficiency"], 0.8)

    def test_m112g_1_4t(self):
        """M112G_1_4T: 112 * 1024**3 bytes, 1.4 * 1024**4 bandwidth, 0.8 efficiency."""
        self._assert_mem(A5.Mem.M112G_1_4T, 112)

    def test_m84g_1_4t(self):
        """M84G_1_4T: 84 * 1024**3 bytes, 1.4 * 1024**4 bandwidth, 0.8 efficiency."""
        self._assert_mem(A5.Mem.M84G_1_4T, 84)
class A5InterconnectTestCase(unittest.TestCase):
    """Literal checks on the A5.Interconnect.PCIE2_UB4 communication grid."""

    def test_pcie2_ub4_grid_shape(self):
        """The grid shape is exactly (2, 2, 4)."""
        grid = A5.Interconnect.PCIE2_UB4.grid
        self.assertEqual(grid.shape, (2, 2, 4))

    def test_pcie2_ub4_ndim(self):
        """The grid has three dimensions."""
        grid = A5.Interconnect.PCIE2_UB4.grid
        self.assertEqual(grid.ndim, 3)

    def test_pcie2_ub4_topologies_count(self):
        """There is one topology entry per grid dimension."""
        interconnect = A5.Interconnect.PCIE2_UB4
        topology_count = len(interconnect.topologies)
        self.assertEqual(topology_count, interconnect.grid.ndim)

    def test_pcie2_ub4_tier0_cpu_pcie(self):
        """Tier 0 bandwidth, latency and efficiency match the expected literals."""
        tier0 = A5.Interconnect.PCIE2_UB4.topologies[0]
        self.assertEqual(tier0.bandwidth_bytes_ps, 24 * 1e9)
        self.assertEqual(tier0.latency_s, 4.5 * 1e-6)
        self.assertEqual(tier0.comm_efficiency, 0.75 * 0.7)

    def test_pcie2_ub4_tier1_group_pcie(self):
        """Tier 1 bandwidth, latency and efficiency match the expected literals."""
        tier1 = A5.Interconnect.PCIE2_UB4.topologies[1]
        self.assertEqual(tier1.bandwidth_bytes_ps, 32 * 1e9)
        self.assertEqual(tier1.latency_s, 3 * 1e-6)
        self.assertEqual(tier1.comm_efficiency, 0.8 * 0.7)

    def test_pcie2_ub4_tier2_ub_fullmesh(self):
        """Tier 2 matches the expected literals and is of type FULL_MESH."""
        tier2 = A5.Interconnect.PCIE2_UB4.topologies[2]
        self.assertEqual(tier2.bandwidth_bytes_ps, 53 * 3 * 1e9)
        self.assertEqual(tier2.latency_s, 1.5 * 1e-6)
        self.assertEqual(tier2.comm_efficiency, 0.85)
        self.assertEqual(tier2.type, InterconnectType.FULL_MESH)

    def test_pcie2_ub4_max_devices(self):
        """The grid holds 16 elements in total."""
        grid = A5.Interconnect.PCIE2_UB4.grid
        self.assertEqual(grid.numel(), 16)
class A5DeviceProfileTestCase(unittest.TestCase):
    """Spec-driven checks: one subTest per entry of _A5_DEVICE_PROFILE_SPECS."""

    @staticmethod
    def _registered(spec):
        """Return the registered profile stored under the spec's name."""
        return DeviceProfile.all_device_profiles[spec["name"]]

    def _check_field(self, field):
        """Assert ``profile.<field> == spec[field]`` for every A5 spec."""
        for spec in _A5_DEVICE_PROFILE_SPECS:
            with self.subTest(device=spec["name"]):
                registered = self._registered(spec)
                self.assertEqual(getattr(registered, field), spec[field])

    def _check_ops_table(self, table):
        """Assert every dtype entry of ``spec[table]`` matches the profile's table."""
        for spec in _A5_DEVICE_PROFILE_SPECS:
            with self.subTest(device=spec["name"]):
                registered = self._registered(spec)
                for dtype, expected in spec[table].items():
                    # Nested subTest so each dtype is reported separately.
                    with self.subTest(dtype=str(dtype)):
                        self.assertEqual(getattr(registered, table)[dtype], expected)

    def test_registered_in_all_device_profiles(self):
        """Every A5 spec name is a key of DeviceProfile.all_device_profiles."""
        registry = DeviceProfile.all_device_profiles
        for spec in _A5_DEVICE_PROFILE_SPECS:
            with self.subTest(device=spec["name"]):
                self.assertIn(spec["name"], registry)

    def test_name(self):
        """The profile's name attribute equals its registry key."""
        self._check_field("name")

    def test_vendor(self):
        """Every A5 profile reports the vendor "HUAWEI"."""
        for spec in _A5_DEVICE_PROFILE_SPECS:
            with self.subTest(device=spec["name"]):
                registered = self._registered(spec)
                self.assertEqual(registered.vendor, "HUAWEI")

    def test_comm_grid(self):
        """The profile's comm_grid is the very object named in the spec."""
        for spec in _A5_DEVICE_PROFILE_SPECS:
            with self.subTest(device=spec["name"]):
                registered = self._registered(spec)
                self.assertIs(registered.comm_grid, spec["comm_grid"])

    def test_mma_ops(self):
        """The mma_ops entries match the spec for each listed dtype."""
        self._check_ops_table("mma_ops")

    def test_gp_ops(self):
        """The gp_ops entries match the spec for each listed dtype."""
        self._check_ops_table("gp_ops")

    def test_memory_size_bytes(self):
        """memory_size_bytes matches the spec."""
        self._check_field("memory_size_bytes")

    def test_memory_bandwidth(self):
        """memory_bandwidth_bytes_ps matches the spec."""
        self._check_field("memory_bandwidth_bytes_ps")

    def test_compute_efficiency(self):
        """compute_efficiency matches the spec."""
        self._check_field("compute_efficiency")

    def test_memory_efficiency(self):
        """memory_efficiency matches the spec."""
        self._check_field("memory_efficiency")

    def test_static_cost(self):
        """Every A5 profile shares the A5.STATIC_COST object."""
        for spec in _A5_DEVICE_PROFILE_SPECS:
            with self.subTest(device=spec["name"]):
                registered = self._registered(spec)
                self.assertIs(registered.static_cost, A5.STATIC_COST)
class A5StaticCostTestCase(unittest.TestCase):
    """Literal checks on A5.STATIC_COST."""

    def test_values(self):
        """A5.STATIC_COST carries the expected per-op costs."""
        a5_cost = A5.STATIC_COST
        # Expected values stay written as products: the assertions compare
        # floats exactly.
        self.assertEqual(a5_cost.mma_op_cost_s, 5 * 1e-6)
        self.assertEqual(a5_cost.gp_op_cost_s, 2 * 1e-6)
        self.assertEqual(a5_cost.comm_op_cost_s, 5 * 1e-6)
class A5RegistrationTestCase(unittest.TestCase):
    """Registration behaviour of DeviceProfile for an already-used A5 name."""

    def test_duplicate_name_raises(self):
        """Re-using a registered name raises ValueError and leaves the key set as is."""
        duplicate_name = "ATLAS_350_425T_112G"
        names_before = set(DeviceProfile.all_device_profiles.keys())

        with self.assertRaises(ValueError):
            DeviceProfile(
                name=duplicate_name,
                vendor="TEST",
                comm_grid=A5.Interconnect.PCIE2_UB4,
            )

        # Checked after the with-block: a statement placed behind the raising
        # call inside it would never execute.
        names_after = set(DeviceProfile.all_device_profiles.keys())
        self.assertEqual(names_after, names_before)
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()