import unittest

import torch
from tensor_cast.device import TEST_DEVICE
from tensor_cast.performance_model.analytic import AnalyticPerformanceModel
from tensor_cast.runtime import Runtime


class CommAnalyticTestCase(unittest.TestCase):
    def test_all_to_all_excludes_local_chunk_from_network_bytes(self):
        x = torch.randn([16, 8], device="meta", dtype=torch.float16)
        perf_model = AnalyticPerformanceModel(TEST_DEVICE)
        with (
            Runtime(perf_model, TEST_DEVICE) as runtime,
            torch.no_grad(),
        ):
            torch.ops.tensor_cast.all_to_all(
                x,
                [4, 4, 4, 4],
                [4, 4, 4, 4],
                0,
                [0, 1, 2, 3],
            )

        stats = runtime.event_list[0].perf_results["analytic"].statistics
        self.assertEqual(stats["total_bytes_sent"], 192)
        self.assertEqual(stats["total_bytes_received"], 192)
        self.assertEqual(stats["message_size_bytes"], 192)