import math
import os
import time
from time import sleep
from typing import List
import torch
import torch_npu
# Device string passed to torch when allocating benchmark buffers. The worker
# functions overwrite this global from their arguments.
g_local_type = "npu"

# Layered (data_dim == 2) layout: one K block and one V block per layer.
# The sizes are byte counts; blocks are allocated as uint8 tensors.
layers_num = 61
k_size = 128 * 1024
v_size = 16 * 1024
k_sizes = [k_size] * layers_num
v_sizes = [v_size] * layers_num
# Interleaved per-layer sizes: [k0, v0, k1, v1, ...].
layers_block_size = [size for kv_pair in zip(k_sizes, v_sizes) for size in kv_pair]

# Every store key starts with this prefix.
key_prefix: str = "key_"

# When True, workers print tensor sums so written and read data can be compared.
PRINT_DATA_SUM = True
# When True, each reader starts at a device-dependent rotated call index.
MISALIGNED_ACCESS = True
def set_device(device_id):
    """Initialise ACL and select the device the calling process will use.

    Args:
        device_id: Device index handed to ``acl.rt.set_device``.

    Raises:
        RuntimeError: If ``acl.rt.set_device`` returns a non-zero code.
    """
    # acl is imported inside the function, so it is only required when this
    # function is actually called.
    import acl

    acl.init()
    status = acl.rt.set_device(device_id)
    if status == 0:
        return
    raise RuntimeError("acl set device failed")
def tensor_sum(tensor: List[torch.Tensor], sizes: List[int] = None):
    """Add up the element sums of a list of tensors.

    Args:
        tensor: Tensors to sum, or ``None``.
        sizes: Optional per-tensor prefix lengths. When given, only the first
            ``sizes[i]`` entries along dimension 0 of ``tensor[i]`` are
            counted, and tensors without a matching size are ignored
            (``zip`` stops at the shorter input).

    Returns:
        The accumulated sum as a Python number; ``0`` when ``tensor`` is
        ``None`` or empty.
    """
    if tensor is None:
        return 0

    if sizes is not None:
        layers = (layer[:size] for layer, size in zip(tensor, sizes))
    else:
        layers = iter(tensor)

    total = 0
    for layer in layers:
        total += layer.sum().item()
    return total
def allocate_aligned_tensor(shape, dtype=torch.float32, alignment=2*1024*1024):
    """Allocate a randomly filled tensor whose data pointer is aligned.

    A slightly larger 1-D buffer is allocated on the device named by the
    module global ``g_local_type``; the returned tensor is a view into that
    buffer starting at the first address that is a multiple of ``alignment``.

    Args:
        shape: Iterable of ints giving the shape of the returned tensor.
        dtype: Element type. Must be a floating-point or integer dtype
            understood by ``torch.finfo`` / ``torch.iinfo``.
        alignment: Required address alignment in bytes; must be a positive
            power of two because the address is rounded with a bit mask.

    Returns:
        A tensor of the requested shape and dtype that shares storage with
        the over-allocated buffer.

    Raises:
        ValueError: If ``alignment`` is not a positive power of two.
    """
    # The mask arithmetic below silently yields a misaligned address for any
    # other alignment value, so reject it up front.
    if alignment <= 0 or alignment & (alignment - 1):
        raise ValueError(f"alignment must be a positive power of two, got {alignment}")

    num_elements = math.prod(shape)
    element_size = torch.finfo(dtype).bits // 8 if dtype.is_floating_point else torch.iinfo(dtype).bits // 8
    total_bytes = num_elements * element_size

    # Worst case the aligned address is (alignment - 1) bytes past the start.
    # The buffer length is counted in elements, so convert that slack from
    # bytes to elements (rounded up) instead of adding a byte count.
    padding_elements = -(-(alignment - 1) // element_size)
    buffer = torch.randint(0, 256, (num_elements + padding_elements,), dtype=dtype, device=g_local_type)

    address = buffer.data_ptr()
    aligned_address = (address + alignment - 1) & ~(alignment - 1)
    offset = (aligned_address - address) // element_size
    aligned_tensor = buffer[offset:offset + num_elements].view(*shape)
    print(f"==== Aligned tensor address: {aligned_tensor.data_ptr():x}, {num_elements=}, "
          f"{element_size=}, {total_bytes=}, {dtype=}, {shape=}")
    return aligned_tensor
def malloc_npu_blocks(min_block_size: int, layer_num: int, block_num: int):
    """Allocate an aligned uint8 tensor holding ``layer_num * block_num`` blocks.

    Args:
        min_block_size: Length of the last dimension (bytes per block).
        layer_num: Length of the first dimension.
        block_num: Length of the second dimension.

    Returns:
        A ``torch.uint8`` tensor of shape
        ``(layer_num, block_num, min_block_size)``.
    """
    block_shape = (layer_num, block_num, min_block_size)
    raw_blocks = allocate_aligned_tensor(block_shape, torch.uint8)
    # Synchronise the current NPU stream before handing the tensor out.
    stream = torch_npu.npu.current_stream()
    stream.synchronize()
    return raw_blocks
def get_col_tensors_by_index(tensors, layer_num, block_index):
    """Collect entry ``block_index`` from each of the first ``layer_num`` layers.

    Args:
        tensors: Indexable container of layers; each layer is indexable by
            block index.
        layer_num: Number of leading layers to visit.
        block_index: Index of the block to pick from every layer.

    Returns:
        A list with ``tensors[layer][block_index]`` for each visited layer.
    """
    return [tensors[layer][block_index] for layer in range(layer_num)]
def get_col_tensors_ptr_by_index(tensors, layer_num, block_index):
    """Collect the data pointer of block ``block_index`` in each layer.

    Args:
        tensors: Indexable container of layers whose entries expose
            ``data_ptr()`` (e.g. a 3-D tensor indexed as
            ``tensors[layer][block]``).
        layer_num: Number of leading layers to visit.
        block_index: Index of the block to pick from every layer.

    Returns:
        A list of ``data_ptr()`` values, one per visited layer.
    """
    return [tensors[layer][block_index].data_ptr() for layer in range(layer_num)]
def init_mooncake(device_id: int):
    """Create a Mooncake store client configured for the given device.

    The connection settings (protocol, host name, metadata server, master
    address and buffer sizes) are fixed values defined in this function.

    Args:
        device_id: Device index forwarded as ``MooncakeConfig.device``.

    Returns:
        The ``Mooncakestore`` instance built from the configuration.
    """
    # mooncake_store is imported here, so it is only required when the
    # Mooncake backend is selected.
    from mooncake_store import MooncakeConfig, Mooncakestore

    mib = 1024 * 1024
    gib = 1024 * mib
    settings = dict(
        device=device_id,
        protocol='rdma',
        device_name='',
        local_hostname='192.168.1.2',
        metadata_server='P2PHANDSHAKE',
        global_segment_size=gib * 64,
        local_buffer_size=128 * mib,
        master_server_address='192.168.1.1:50051',
    )
    return Mooncakestore(MooncakeConfig(**settings))
def write_worker(*args):
    """Benchmark worker that writes batches of blocks into the store.

    The worker binds to a device, creates the store client, allocates and
    registers device buffers, builds every request up front, then times
    ``call_count`` calls to ``store.batch_put_from_layers`` and prints the
    achieved bandwidth. It finally sleeps for 30 minutes before returning.

    Positional arguments (read from ``args`` by index):
        0 device_id: Device index; also embedded in every key.
        1 batch_size: Number of keys per ``batch_put_from_layers`` call.
        2 block_size: Bytes per block in the single-buffer layout.
        3 call_count: Number of put calls to issue.
        4 data_dim: ``2`` selects the layered K/V layout described by the
            module-level ``k_sizes`` / ``v_sizes``; any other value selects
            one buffer per key.
        5 backend: ``"mooncake"`` selects Mooncake; any other value selects
            ``memcache_hybrid.DistributedObjectStore``.
        6 local type: Stored into the module global ``g_local_type``.
        7 process_count: Accepted for signature compatibility; not used.
        8 sync: Object with a ``wait()`` method called right before and right
            after the timed section (presumably a barrier shared by all
            workers - confirm against the caller).

    Raises:
        RuntimeError: If device selection, store initialisation or any put
            fails.
    """
    global g_local_type

    device_id = args[0]
    batch_size = args[1]
    block_size = [args[2]]
    call_count = args[3]
    data_dim = args[4]
    backend = args[5]
    g_local_type = args[6]
    sync = args[8]

    set_device(device_id)
    print(f"npu:{device_id} 开始,PID: {os.getpid()}")

    # NOTE(review): the original indentation was lost; store.init() is assumed
    # to apply to the memcache backend only - confirm.
    if backend == "mooncake":
        store = init_mooncake(device_id)
        print(f"==== Start to init mooncake device:{device_id}")
    else:
        from memcache_hybrid import DistributedObjectStore
        store = DistributedObjectStore()
        print(f"==== Start to init memcache device:{device_id}")
        res = store.init(device_id)
        if res != 0:
            # Raising a plain string is itself a TypeError; raise a real
            # exception carrying the message instead.
            raise RuntimeError(f"Failed to start pid:{os.getpid()} deviceId:{device_id}")
    print(f"==== Success to init device:{device_id}")

    one_dim_tensor = None
    k_tensors = None
    v_tensors = None
    k_sum = 0
    v_sum = 0
    one_dim_sum = 0
    if data_dim == 2:
        k_tensors = malloc_npu_blocks(max(k_sizes, default=0), len(k_sizes), batch_size)
        v_tensors = malloc_npu_blocks(max(v_sizes, default=0), len(v_sizes), batch_size)
        store.register_buffer(k_tensors.data_ptr(), max(k_sizes, default=0) * len(k_sizes) * batch_size)
        store.register_buffer(v_tensors.data_ptr(), max(v_sizes, default=0) * len(v_sizes) * batch_size)
        # Sums of the source data, printed later for comparison with readers.
        k_sum = k_tensors.sum().item()
        v_sum = v_tensors.sum().item()
    else:
        one_dim_tensor = malloc_npu_blocks(max(block_size, default=0), 1, batch_size)
        store.register_buffer(one_dim_tensor.data_ptr(), max(block_size, default=0) * batch_size)
        one_dim_sum = one_dim_tensor.sum().item()

    # Build all requests before the timed section so only the puts are timed.
    keys_list = []
    buffs_list = []
    sizes_list = []
    for i in range(call_count):
        keys = []
        buffs = []
        sizes = []
        for j in range(batch_size):
            key = key_prefix + str(device_id) + '_' + str(i) + '_' + str(j)
            keys.append(key)
            if data_dim == 2:
                # Interleave K and V pointers per layer: [k0, v0, k1, v1, ...],
                # matching the order of layers_block_size.
                block_buffs = [item for pair in zip(get_col_tensors_ptr_by_index(k_tensors, len(k_sizes), j),
                                                    get_col_tensors_ptr_by_index(v_tensors, len(v_sizes), j))
                               for item in pair]
                sizes.append(layers_block_size)
            else:
                block_buffs = get_col_tensors_ptr_by_index(one_dim_tensor, 1, j)
                sizes.append(block_size)
            buffs.append(block_buffs)
        keys_list.append(keys)
        buffs_list.append(buffs)
        sizes_list.append(sizes)

    sync.wait()
    start = time.perf_counter()
    for keys, buffs, sizes in zip(keys_list, buffs_list, sizes_list):
        # The meaning of the trailing 0 is defined by the store API - TODO confirm.
        write_ret = store.batch_put_from_layers(keys, buffs, sizes, 0)
        if any(x != 0 for x in write_ret):
            raise RuntimeError(f"Failed to put pid:{os.getpid()} deviceId:{device_id}")
    end = time.perf_counter()
    sync.wait()

    duration_us = (end - start) * 1_000_000
    # Outer loop over calls, inner loop over the keys of each call.
    total_size_bytes = sum(sum(size) for sizes in sizes_list for size in sizes)
    total_size_gb = total_size_bytes / (1024 * 1024 * 1024)
    total_duration_seconds = duration_us / 1_000_000
    # Guard the divisions so an empty run reports zeros instead of crashing.
    bandwidth_gb_per_sec = total_size_gb / total_duration_seconds if total_duration_seconds > 0 else 0.0
    single_size = total_size_bytes / call_count if call_count else 0
    avg_time_us = duration_us / call_count if call_count else 0.0
    print(f"\033[91mdevice_id:{device_id} write_total_size:{total_size_bytes} bytes, "
          f"single_size:{single_size:.0f} bytes, call count:{call_count}, "
          f"total_time:{duration_us:.2f} us, avg_time:{avg_time_us:.2f} us, "
          f"bw:{bandwidth_gb_per_sec:.3f} GB/s\033[0m\n")
    sleep(1)
    if PRINT_DATA_SUM:
        print(f"write data sum for check data ==> {device_id=} {k_sum=}, {v_sum}, {one_dim_sum}\n")
    sleep(1)
    print(f"===== npu:{device_id} write finish, wait read testing ......\n")
    # Stay alive for 30 minutes; presumably so the written data remains
    # available while read workers run - confirm.
    sleep(30 * 60)
def read_worker(*args):
    """Benchmark worker that reads batches of blocks back from the store.

    The worker binds to a device, creates the store client, allocates and
    registers device buffers, builds every request up front, runs one untimed
    pass over all requests, then times a second identical pass of
    ``store.batch_get_into_layers`` calls and prints the achieved bandwidth.

    Positional arguments (read from ``args`` by index):
        0 device_id: Device index; also embedded in every key.
        1 batch_size: Number of keys per ``batch_get_into_layers`` call.
        2 block_size: Bytes per block in the single-buffer layout.
        3 call_count: Number of get calls per pass.
        4 data_dim: ``2`` selects the layered K/V layout described by the
            module-level ``k_sizes`` / ``v_sizes``; any other value selects
            one buffer per key.
        5 backend: ``"mooncake"`` selects Mooncake; any other value selects
            ``memcache_hybrid.DistributedObjectStore``.
        6 local type: Stored into the module global ``g_local_type``; also
            selects the direction flag passed to the store.
        7 process_count: Accepted for signature compatibility; not used.
        8 sync: Object with a ``wait()`` method called right before and right
            after the timed section (presumably a barrier shared by all
            workers - confirm against the caller).

    Raises:
        RuntimeError: If device selection, store initialisation or any get
            fails.
    """
    global g_local_type

    device_id = args[0]
    batch_size = args[1]
    block_size = [args[2]]
    call_count = args[3]
    data_dim = args[4]
    backend = args[5]
    g_local_type = args[6]
    sync = args[8]

    set_device(device_id)
    print(f"npu:{device_id} 开始,PID: {os.getpid()}")

    # NOTE(review): the original indentation was lost; store.init() is assumed
    # to apply to the memcache backend only - confirm.
    if backend == "mooncake":
        store = init_mooncake(device_id)
        print(f"==== Start to init mooncake device:{device_id}")
    else:
        from memcache_hybrid import DistributedObjectStore
        store = DistributedObjectStore()
        print(f"==== Start to init memcache device:{device_id}")
        res = store.init(device_id)
        if res != 0:
            # Raising a plain string is itself a TypeError; raise a real
            # exception carrying the message instead.
            raise RuntimeError(f"Failed to start pid:{os.getpid()} deviceId:{device_id}")
    print(f"==== Success to init device:{device_id}")

    one_dim_tensor = None
    k_tensors = None
    v_tensors = None
    k_sum = 0
    v_sum = 0
    one_dim_sum = 0
    if data_dim == 2:
        # Bytes fetched per key: every K block plus every V block.
        read_size = max(k_sizes, default=0) * len(k_sizes) + max(v_sizes, default=0) * len(v_sizes)
        k_tensors = malloc_npu_blocks(max(k_sizes, default=0), len(k_sizes), batch_size)
        v_tensors = malloc_npu_blocks(max(v_sizes, default=0), len(v_sizes), batch_size)
        store.register_buffer(k_tensors.data_ptr(), max(k_sizes, default=0) * len(k_sizes) * batch_size)
        store.register_buffer(v_tensors.data_ptr(), max(v_sizes, default=0) * len(v_sizes) * batch_size)
    else:
        read_size = max(block_size, default=0)
        one_dim_tensor = malloc_npu_blocks(max(block_size, default=0), 1, batch_size)
        store.register_buffer(one_dim_tensor.data_ptr(), max(block_size, default=0) * batch_size)

    # Direction flag for batch_get_into_layers; the meaning of 1 and 2 is
    # defined by the store API - TODO confirm.
    direct_t = 1 if g_local_type == "npu" else 2

    # With MISALIGNED_ACCESS each device starts at a different call index
    # (roughly one GiB worth of keys apart) and wraps around. The guards avoid
    # dividing or taking a modulo by zero for empty runs.
    if MISALIGNED_ACCESS and call_count > 0 and read_size > 0:
        start_index = (math.ceil((1 << 30) / read_size) * device_id) % call_count
    else:
        start_index = 0
    call_order = list(range(start_index, call_count)) + list(range(start_index))

    # Build all requests before the timed section so only the gets are timed.
    keys_list = []
    buffs_list = []
    sizes_list = []
    for i in call_order:
        keys = []
        buffs = []
        sizes = []
        for j in range(batch_size):
            key = key_prefix + str(device_id) + '_' + str(i) + '_' + str(j)
            keys.append(key)
            if data_dim == 2:
                # Interleave K and V pointers per layer: [k0, v0, k1, v1, ...],
                # matching the order of layers_block_size.
                block_buffs = [item for pair in zip(get_col_tensors_ptr_by_index(k_tensors, len(k_sizes), j),
                                                    get_col_tensors_ptr_by_index(v_tensors, len(v_sizes), j))
                               for item in pair]
                sizes.append(layers_block_size)
            else:
                block_buffs = get_col_tensors_ptr_by_index(one_dim_tensor, 1, j)
                sizes.append(block_size)
            buffs.append(block_buffs)
        keys_list.append(keys)
        buffs_list.append(buffs)
        sizes_list.append(sizes)

    # Untimed first pass over every request (presumably a warm-up).
    for keys, buffs, sizes in zip(keys_list, buffs_list, sizes_list):
        read_ret = store.batch_get_into_layers(keys, buffs, sizes, direct_t)
        if any(x != 0 for x in read_ret):
            raise RuntimeError(f"Failed to get pid:{os.getpid()} deviceId:{device_id}")

    sync.wait()
    start = time.perf_counter()
    for keys, buffs, sizes in zip(keys_list, buffs_list, sizes_list):
        read_ret = store.batch_get_into_layers(keys, buffs, sizes, direct_t)
        if any(x != 0 for x in read_ret):
            raise RuntimeError(f"Failed to get pid:{os.getpid()} deviceId:{device_id}")
    end = time.perf_counter()
    sync.wait()

    duration_us = (end - start) * 1_000_000
    # Outer loop over calls, inner loop over the keys of each call.
    total_size_bytes = sum(sum(size) for sizes in sizes_list for size in sizes)
    total_size_gb = total_size_bytes / (1024 * 1024 * 1024)
    total_duration_seconds = duration_us / 1_000_000
    # Guard the divisions so an empty run reports zeros instead of crashing.
    bandwidth_gb_per_sec = total_size_gb / total_duration_seconds if total_duration_seconds > 0 else 0.0
    single_size = total_size_bytes / call_count if call_count else 0
    avg_time_us = duration_us / call_count if call_count else 0.0

    # Sums of the received data, printed for comparison with the writer.
    if data_dim == 2:
        k_sum = k_tensors.sum().item()
        v_sum = v_tensors.sum().item()
    else:
        one_dim_sum = one_dim_tensor.sum().item()

    print(f"\033[91mdevice_id:{device_id} read_total_size:{total_size_bytes} bytes, "
          f"single_size:{single_size:.0f} bytes, call count:{call_count}, "
          f"total_time:{duration_us:.2f} us, avg_time:{avg_time_us:.2f} us, "
          f"bw:{bandwidth_gb_per_sec:.3f} GB/s\033[0m\n")
    sleep(1)
    if PRINT_DATA_SUM:
        print(f"read data sum for check data ==> {device_id=} {k_sum=}, {v_sum}, {one_dim_sum}")
    sleep(10)