import random
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices


class TestReshapeAndCacheSiso(TestCase):
    """Compare ``torch_npu._npu_reshape_and_cache_siso`` with a host-side reference."""

    # Problem dimensions used by the test below. ``head_size`` is not read by
    # the visible test, which draws its own random head size instead.
    num_tokens = 16
    num_head = 16
    head_size = 16
    block_size = 16
    num_blocks = 8

    def cal(self, key, key_cache, slot_mapping):
        """Build the expected cache contents on the host.

        The work is done on a clone of ``key_cache``, so the argument itself is
        not modified. For each token index ``i`` with a non-negative slot,
        ``key[i]`` is stored at
        ``[slot // block_size][slot % block_size]`` of the clone. Tokens whose
        slot is negative are skipped.

        Returns:
            The updated clone of ``key_cache``.
        """
        expected = key_cache.clone()
        page = self.block_size
        for token_idx, slot in enumerate(slot_mapping):
            if slot < 0:
                # Negative slots mean "do not write this token".
                continue
            expected[slot // page][slot % page] = key[token_idx]
        return expected

    @SupportedDevices(['Ascend910B'])
    def test_reshapeandcache_siso(self):
        """Scatter random keys into a random cache and check the NPU result.

        The head size is drawn from ``[1, 255]`` and every token receives a
        distinct slot (sampling without replacement). The NPU cache tensor is
        compared with the host reference after the call; the operator's return
        value is not used, so it is presumably expected to update the cache in
        place.
        """
        head_dim = np.random.randint(1, 256)
        token_shape = (self.num_tokens, self.num_head, head_dim)
        cache_shape = (self.num_blocks, self.block_size, self.num_head, head_dim)

        # Host-side inputs; the two torch.rand calls stay in this order so the
        # torch RNG stream is consumed exactly as before.
        key = torch.rand(token_shape, dtype=torch.float16)
        total_slots = self.block_size * self.num_blocks
        chosen_slots = random.sample(range(total_slots), self.num_tokens)
        slot_mapping = np.array(chosen_slots, dtype=np.int32)
        key_cache = torch.rand(cache_shape, dtype=torch.float16)

        # The reference result is computed before anything is moved to the device.
        key_expect = self.cal(key, key_cache, slot_mapping)

        key_npu = key.npu()
        cache_npu = key_cache.npu()
        slots_npu = torch.from_numpy(slot_mapping).to(torch.int32).npu()
        torch_npu._npu_reshape_and_cache_siso(key_npu, cache_npu, slots_npu)

        self.assertRtolEqual(key_expect, cache_npu)

# When the file is executed directly (not imported), hand control to the
# ``run_tests`` entry point imported from ``torch_npu.testing.testcase``.
if __name__ == '__main__':
    run_tests()