#!/usr/bin/env python
# coding=utf-8
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
# MemCache_Hybrid is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#          http://license.coscl.org.cn/MulanPSL2
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
import math
import os
import time
from time import sleep
from typing import List
import torch
import torch_npu


g_local_type = "npu"

# the block size of 128 token, the model is deepseek r1 w8a8
layers_num = 61
k_size = 128 * 1024
v_size = 16 * 1024
k_sizes = [k_size for _ in range(layers_num)]
v_sizes = [v_size for _ in range(layers_num)]
layers_block_size = [item for pair in zip(k_sizes, v_sizes) for item in pair]
key_prefix: str = "key_"
PRINT_DATA_SUM = True
MISALIGNED_ACCESS = True


def set_device(device_id):
    import acl
    acl.init()
    ret = acl.rt.set_device(device_id)
    if ret != 0:
        raise RuntimeError("acl set device failed")


def tensor_sum(tensor: List[torch.Tensor], sizes: List[int] = None):
    if tensor is None:
        return 0
    if sizes is None:
        return sum(layer.sum().item() for layer in tensor)
    return sum(layer[:size].sum().item() for layer, size in zip(tensor, sizes))


def allocate_aligned_tensor(shape, dtype=torch.float32, alignment=2*1024*1024):
    num_elements = torch.prod(torch.tensor(shape)).item()
    element_size = torch.finfo(dtype).bits // 8 if dtype.is_floating_point else torch.iinfo(dtype).bits // 8
    total_bytes = num_elements * element_size

    padding = alignment - 1
    buffer = torch.randint(0, 256, (total_bytes + padding,), dtype=dtype, device=g_local_type)

    address = buffer.data_ptr()
    aligned_address = (address + alignment - 1) & ~(alignment - 1)
    offset = (aligned_address - address) // element_size

    aligned_tensor = buffer[offset:offset + num_elements].view(*shape)
    print(f"==== Aligned tensor address: {aligned_tensor.data_ptr():x}, {num_elements=}, "
          f"{element_size=}, {total_bytes=}, {dtype=}, {shape=}")
    return aligned_tensor


def malloc_npu_blocks(min_block_size: int, layer_num: int, block_num: int):
    raw_blocks = allocate_aligned_tensor((layer_num, block_num, min_block_size), torch.uint8)
    torch_npu.npu.current_stream().synchronize()
    return raw_blocks


def get_col_tensors_by_index(tensors, layer_num, block_index):
    block_tensor = []
    for li in range(layer_num):
        block_tensor.append(tensors[li][block_index])
    return block_tensor


def get_col_tensors_ptr_by_index(tensors, layer_num, block_index):
    block_ptrs = []
    for li in range(layer_num):
        block_ptrs.append(tensors[li][block_index].data_ptr())
    return block_ptrs


def init_mooncake(device_id: int):
    from mooncake_store import Mooncakestore, MooncakeConfig
    config = MooncakeConfig(
        device=device_id,
        protocol='rdma',
        device_name='',
        local_hostname='192.168.1.2', # Change to your local IP
        metadata_server='P2PHANDSHAKE',
        global_segment_size=1024 * 1024 * 1024 * 64,
        local_buffer_size=128 * 1024 * 1024,
        master_server_address='192.168.1.1:50051') # Change to your master server
    store = Mooncakestore(config)
    return store


def write_worker(*args):
    device_id = args[0]
    batch_size = args[1]
    block_size = [args[2]]
    call_count = args[3]
    data_dim = args[4]
    backend = args[5]
    global g_local_type
    g_local_type = args[6]
    process_count = args[7]
    sync = args[8]

    set_device(device_id)
    print(f"npu:{device_id} 开始,PID: {os.getpid()}")
    if backend == "mooncake":
        store = init_mooncake(device_id)
        print(f"==== Start to init mooncake device:{device_id}")
    else:
        from memcache_hybrid import DistributedObjectStore
        store = DistributedObjectStore()
        print(f"==== Start to init memcache device:{device_id}")
        res = store.init(device_id)
        if res != 0:
            raise f"Failed to start pid:{os.getpid()} deviceId:{device_id}"
    print(f"==== Success to init device:{device_id}")
    one_dim_tensor = None
    k_tensors = None
    v_tensors = None
    k_sum = 0
    v_sum = 0
    one_dim_sum = 0
    if data_dim == 2:
        k_tensors = malloc_npu_blocks(max(k_sizes, default=0), len(k_sizes), batch_size)
        v_tensors = malloc_npu_blocks(max(v_sizes, default=0), len(v_sizes), batch_size)
        store.register_buffer(k_tensors.data_ptr(), max(k_sizes, default=0) * len(k_sizes) * batch_size)
        store.register_buffer(v_tensors.data_ptr(), max(v_sizes, default=0) * len(v_sizes) * batch_size)
        k_sum = k_tensors.sum().item()
        v_sum = v_tensors.sum().item()
    else:
        one_dim_tensor = malloc_npu_blocks(max(block_size, default=0), 1, batch_size)
        store.register_buffer(one_dim_tensor.data_ptr(), max(block_size, default=0) * batch_size)
        one_dim_sum = one_dim_tensor.sum().item()
    keys_list = []
    buffs_list = []
    sizes_list = []
    for i in range(call_count):
        keys = []
        buffs = []
        sizes = []
        for j in range(batch_size):
            key = key_prefix + str(device_id) + '_' + str(i) + '_' + str(j)
            keys.append(key)
            if data_dim == 2:
                block_buffs = [item for pair in zip(get_col_tensors_ptr_by_index(k_tensors, len(k_sizes), j),
                                                    get_col_tensors_ptr_by_index(v_tensors, len(v_sizes), j))
                                                    for item in pair]
                sizes.append(layers_block_size)
            else:
                block_buffs = get_col_tensors_ptr_by_index(one_dim_tensor, 1, j)
                sizes.append(block_size)
            buffs.append(block_buffs)
        keys_list.append(keys)
        buffs_list.append(buffs)
        sizes_list.append(sizes)

    sync.wait()
    start = time.perf_counter()
    for keys, buffs, sizes in zip(keys_list, buffs_list, sizes_list):
        write_ret = store.batch_put_from_layers(keys, buffs, sizes, 0)
        if any(x != 0 for x in write_ret):
            raise f"Failed to put pid:{os.getpid()} deviceId:{device_id}"
    end = time.perf_counter()
    sync.wait()
    duration_us = (end - start) * 1_000_000
    total_size_bytes = sum(sum(size) for size in sizes for sizes in sizes_list)
    total_size_gb = total_size_bytes / (1024 * 1024 * 1024)
    total_duration_seconds = duration_us / 1_000_000
    bandwidth_gb_per_sec = total_size_gb / total_duration_seconds
    print(f"\033[91mdevice_id:{device_id} write_total_size:{total_size_bytes} bytes, "
          f"single_size:{total_size_bytes / call_count:.0f} bytes, call count:{call_count}, "
          f"total_time:{duration_us:.2f} us, avg_time:{duration_us / call_count:.2f} us, "
          f"bw:{bandwidth_gb_per_sec:.3f} GB/s\033[0m\n")

    sleep(1)
    if PRINT_DATA_SUM:
        print(f"write data sum for check data ==> {device_id=} {k_sum=}, {v_sum}, {one_dim_sum}\n")
    sleep(1)
    print(f"===== npu:{device_id} write finish, wait read testing ......\n")
    sleep(30 * 60)


def read_worker(*args):
    device_id = args[0]
    batch_size = args[1]
    block_size = [args[2]]
    call_count = args[3]
    data_dim = args[4]
    backend = args[5]
    global g_local_type
    g_local_type = args[6]
    process_count = args[7]
    sync = args[8]

    set_device(device_id)
    print(f"npu:{device_id} 开始,PID: {os.getpid()}")
    if backend == "mooncake":
        store = init_mooncake(device_id)
        print(f"==== Start to init mooncake device:{device_id}")
    else:
        from memcache_hybrid import DistributedObjectStore
        store = DistributedObjectStore()
        print(f"==== Start to init memcache device:{device_id}")
        res = store.init(device_id)
        if res != 0:
            raise f"Failed to start pid:{os.getpid()} deviceId:{device_id}"
    print(f"==== Success to init device:{device_id}")
    one_dim_tensor = None
    k_tensors = None
    v_tensors = None
    k_sum = 0
    v_sum = 0
    one_dim_sum = 0
    if data_dim == 2:
        read_size = max(k_sizes, default=0) * len(k_sizes) + max(v_sizes, default=0) * len(v_sizes)
        k_tensors = malloc_npu_blocks(max(k_sizes, default=0), len(k_sizes), batch_size)
        v_tensors = malloc_npu_blocks(max(v_sizes, default=0), len(v_sizes), batch_size)
        store.register_buffer(k_tensors.data_ptr(), max(k_sizes, default=0) * len(k_sizes) * batch_size)
        store.register_buffer(v_tensors.data_ptr(), max(v_sizes, default=0) * len(v_sizes) * batch_size)
    else:
        read_size = max(block_size, default=0)
        one_dim_tensor = malloc_npu_blocks(max(block_size, default=0), 1, batch_size)
        store.register_buffer(one_dim_tensor.data_ptr(), max(block_size, default=0) * batch_size)
    if g_local_type == "npu":
        direct_t = 1
    else:
        direct_t = 2
    # 实测此步骤很耗时,提前准备
    keys_list = []
    buffs_list = []
    sizes_list = []
    # 错开读取位置,防止流量都到一个打到numa
    start_index = (math.ceil((1 << 30) / read_size) * device_id) % call_count if MISALIGNED_ACCESS else 0
    for i in range(start_index, call_count):
        keys = []
        buffs = []
        sizes = []
        for j in range(batch_size):
            key = key_prefix + str(device_id) + '_' + str(i) + '_' + str(j)
            keys.append(key)
            if data_dim == 2:
                block_buffs = [item for pair in zip(get_col_tensors_ptr_by_index(k_tensors, len(k_sizes), j),
                                                    get_col_tensors_ptr_by_index(v_tensors, len(v_sizes), j))
                                                    for item in pair]
                sizes.append(layers_block_size)
            else:
                block_buffs = get_col_tensors_ptr_by_index(one_dim_tensor, 1, j)
                sizes.append(block_size)
            buffs.append(block_buffs)

        keys_list.append(keys)
        buffs_list.append(buffs)
        sizes_list.append(sizes)

    for i in range(start_index):
        keys = []
        buffs = []
        sizes = []
        for j in range(batch_size):
            key = key_prefix + str(device_id) + '_' + str(i) + '_' + str(j)
            keys.append(key)
            if data_dim == 2:
                block_buffs = [item for pair in zip(get_col_tensors_ptr_by_index(k_tensors, len(k_sizes), j),
                                                    get_col_tensors_ptr_by_index(v_tensors, len(v_sizes), j))
                                                    for item in pair]
                sizes.append(layers_block_size)
            else:
                block_buffs = get_col_tensors_ptr_by_index(one_dim_tensor, 1, j)
                sizes.append(block_size)
            buffs.append(block_buffs)
        keys_list.append(keys)
        buffs_list.append(buffs)
        sizes_list.append(sizes)

    for keys, buffs, sizes in zip(keys_list, buffs_list, sizes_list):
        read_ret = store.batch_get_into_layers(keys, buffs, sizes, direct_t)
        if any(x != 0 for x in read_ret):
            raise f"Failed to get pid:{os.getpid()} deviceId:{device_id}"

    sync.wait()
    start = time.perf_counter()
    for keys, buffs, sizes in zip(keys_list, buffs_list, sizes_list):
        read_ret = store.batch_get_into_layers(keys, buffs, sizes, direct_t)
        if any(x != 0 for x in read_ret):
            raise f"Failed to get pid:{os.getpid()} deviceId:{device_id}"
    end = time.perf_counter()
    sync.wait()
    duration_us = (end - start) * 1_000_000
    total_size_bytes = sum(sum(size) for size in sizes for sizes in sizes_list)
    total_size_gb = total_size_bytes / (1024 * 1024 * 1024)
    total_duration_seconds = duration_us / 1_000_000
    bandwidth_gb_per_sec = total_size_gb / total_duration_seconds
    if data_dim == 2:
        k_sum = k_tensors.sum().item()
        v_sum = v_tensors.sum().item()
    else:
        one_dim_sum = one_dim_tensor.sum().item()

    print(f"\033[91mdevice_id:{device_id} read_total_size:{total_size_bytes} bytes, "
          f"single_size:{total_size_bytes / call_count:.0f} bytes, call count:{call_count}, "
          f"total_time:{duration_us:.2f} us, avg_time:{duration_us / call_count:.2f} us, "
          f"bw:{bandwidth_gb_per_sec:.3f} GB/s\033[0m\n")

    sleep(1)
    if PRINT_DATA_SUM:
        print(f"read data sum for check data ==> {device_id=} {k_sum=}, {v_sum}, {one_dim_sum}")
    sleep(10)