# 8209d3e4 created on 2025-02-25 (commit history)
import unittest
import numpy as nps
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices


class TestGroupTopk(TestCase):
    """Compare ``torch_npu._npu_group_topk`` with a CPU reference implementation."""

    def generalize_param(self):
        """Yield every valid ``(token_num, expert_num, group_num, k, k_inner)`` case.

        ``group_num`` ranges over the divisors of ``expert_num`` so that the
        experts always split evenly into groups. Combinations in which ``k``
        exceeds the number of groups, or ``k_inner`` exceeds the number of
        experts per group, are skipped.
        """
        token_nums = [1, 16, 33]
        expert_nums = [1, 4, 15, 16, 17, 32, 65, 128, 257, 1024]
        k_params = [1, 4, 15, 16, 17, 32, 65, 257, 1024]
        k_inner_params = [1, 2, 3, 4, 16, 32, 65, 1024]

        # Integer-only divisor search (no float division or float sqrt),
        # sorted so the cases come out in a predictable order.
        divisors_by_expert_num = {
            expert_num: sorted(
                d for d in range(1, expert_num + 1) if expert_num % d == 0
            )
            for expert_num in expert_nums
        }

        for token_num in token_nums:
            for expert_num, group_nums in divisors_by_expert_num.items():
                for group_num in group_nums:
                    group_size = expert_num // group_num
                    for k in k_params:
                        if k > group_num:
                            continue
                        for k_inner in k_inner_params:
                            if k_inner > group_size:
                                continue
                            yield token_num, expert_num, group_num, k, k_inner

    def golden_calc(self, input0, k, group_num, k_inner):
        """Compute the reference result on the CPU.

        Args:
            input0: 2-D CPU tensor of shape ``(token_num, expert_num)``.
            k: number of best-scoring groups whose values are kept per row.
            group_num: number of equally sized groups each row is split into.
            k_inner: number of largest values summed to score one group.

        Returns:
            A one-element list holding a tensor of shape
            ``(token_num, expert_num)`` in which every group ranked ``k`` or
            later (by descending score) is zeroed and the rest is unchanged.
        """
        # Imported locally: the module-level numpy import is bound to ``nps``,
        # so ``np`` would otherwise be undefined here.
        import numpy as np

        token_num, expert_num = input0.shape
        grouped = torch.reshape(
            input0, (token_num, group_num, expert_num // group_num)
        )
        # Clone before the float cast so the result keeps the input dtype.
        output = grouped.clone()

        # Score each group by the sum of its k_inner largest values (float32).
        group_scores = torch.topk(grouped.to(torch.float), k_inner).values
        group_scores = torch.sum(group_scores, dim=-1)

        # Stable ascending sort of the negated scores == descending order
        # where tied groups keep their original relative order.
        sort_index = torch.from_numpy(
            np.argsort(-group_scores.numpy(), kind='stable')
        )

        # Rank positions k .. group_num-1 are the groups to zero out.
        cols_to_use = torch.arange(k, group_num, dtype=torch.long)
        row_indices = torch.arange(sort_index.shape[0]).repeat_interleave(
            cols_to_use.shape[0]
        )
        col_indices = sort_index.index_select(1, cols_to_use).view(-1)
        output[row_indices, col_indices] = 0
        return [torch.reshape(output, (token_num, expert_num))]

    @unittest.skip("skip test_group_topk now")
    @SupportedDevices(['Ascend910B'])
    def test_group_topk(self):
        """Run the NPU op over all generated cases and compare with the reference."""
        for dtype in [torch.float16, torch.bfloat16]:
            for token_num, expert_num, group_num, k, k_inner in self.generalize_param():
                input0 = torch.empty(
                    (token_num, expert_num), dtype=dtype, device='npu'
                ).uniform_(-2, 2)
                # The reference is taken from a CPU copy *before* the op runs;
                # the op's result is then read back from input0 itself.
                expect_output = self.golden_calc(input0.cpu(), k, group_num, k_inner)
                torch_npu._npu_group_topk(input0, k=k, group_num=group_num, n=k_inner)

                self.assertRtolEqual(input0, expect_output[0])


# When executed as a script, hand control to run_tests (imported from
# torch_npu.testing.testcase).
if __name__ == "__main__":
    run_tests()