import os
import ast
import concurrent.futures
import dataclasses
import inspect
import logging
import random
import socket
import string
import sys
import time
import threading
import traceback
from functools import wraps
from typing import Callable, Dict, List, Any
from ctypes import CDLL
import torch
import torch_npu
import acl
import memfabric_hybrid
from memfabric_hybrid import bm, shm, TransferEngine, create_config_store, set_log_level
globals_device_id = -1
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@dataclasses.dataclass
class CliCommand:
def __init__(self, cmd_name: str, cmd_description: str, func: Callable):
self.cmd_name = cmd_name
self.cmd_description = cmd_description
self.func = func
self.args_num = self._get_user_param_count(func)
@staticmethod
def _get_user_param_count(func):
sig = inspect.signature(func)
params = sig.parameters
return len(params)
@dataclasses.dataclass
class TransferEngineData:
buffer: torch.Tensor
buffer_bytes: int
batch_buffers: List[torch.Tensor]
batch_bytes: List[int]
class TestServer:
def __init__(self, socket_id: int):
self.socket_file = f"/home/CI_HOME/memfabric_hybrid/bin/mb_{socket_id}.socket"
self.server_socket = None
self._commands: Dict[str:CliCommand] = {}
self.generated_keys = set()
self.thread_local = threading.local()
self.thread_local.client_socket = None
self._register_inner_command()
self._stream = 0
def _register_inner_command(self):
self._commands = {
"help": CliCommand("help", "show command list information", self._help),
"getServerCommands": CliCommand("getServerCommands", "ServerCommands, get the registered Commands",
self._get_server_commands),
}
def register_command(self, cmds: List[CliCommand]):
for cmd in cmds:
self._commands[cmd.cmd_name] = cmd
@staticmethod
def _convert_argument(arg_str: str, param_type):
try:
if param_type == int:
return int(arg_str)
elif param_type == float:
return float(arg_str)
elif param_type == str:
return str(arg_str)
elif param_type == bool:
return arg_str.lower() in ['true', '1', 'yes']
else:
return ast.literal_eval(arg_str)
except (ValueError, SyntaxError):
return arg_str
@staticmethod
def _parse_arguments(func, arg_strs):
signature = inspect.signature(func)
parsed_args = []
param_iter = iter(signature.parameters.values())
for _, arg in enumerate(arg_strs):
param = next(param_iter)
param_type = param.annotation
parsed_args.append(TestServer._convert_argument(arg, param_type))
return parsed_args
def _execute(self, cmd):
"""执行命令。"""
parts = cmd.split()
if not parts:
return
cmd_str = parts[0]
params = parts[1:]
if cmd_str in self._commands:
command = self._commands[cmd_str]
if len(params) == command.args_num:
parsed_params = self._parse_arguments(command.func, params)
command.func(*parsed_params)
self._cli_end_line()
return
self.cli_print(f"Invalid input args num:{len(params)}, need:{command.args_num}.")
else:
self.cli_print(f"Unknown command: {cmd_str}")
self._help()
self._cli_end_line()
def start(self):
self.server_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
os.unlink(self.socket_file)
except OSError as e:
pass
self.server_socket.bind(self.socket_file)
self.server_socket.listen(5)
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
while True:
client_socket, _ = self.server_socket.accept()
executor.submit(self._handle_client, client_socket)
def close(self):
self.server_socket.close()
os.remove(self.socket_file)
def _handle_client(self, client_socket: socket.socket):
self.thread_local.client_socket = client_socket
if self._stream == 0:
self._stream = set_device(globals_device_id)
current_thread = threading.current_thread()
logging.info(f"current thread: {current_thread.ident}")
buffer = b""
try:
while True:
data = client_socket.recv(1024)
if not data:
self.thread_local.client_socket = None
break
buffer += data
if b"\0" in data:
msg = buffer.decode('utf-8').replace("\0", "")
buffer = b""
logging.info(f"[mftest] received command: {msg}")
try:
self._execute(msg.strip())
except Exception:
traceback.print_exc()
finally:
client_socket.close()
def cli_print(self, msg: str):
self.thread_local.client_socket.send(f"{msg}\n".encode('utf-8'))
def _cli_end_line(self):
self.thread_local.client_socket.send("\0".encode('utf-8'))
def _help(self):
"""显示帮助信息。"""
col_widths = max(len(item) for item in self._commands.keys()) + 1
self.cli_print("Available commands:")
for cmd in self._commands.values():
self.cli_print(f": {cmd.cmd_name: >{col_widths}}: {cmd.cmd_description}")
def _get_server_commands(self):
msg = ",".join(self._commands.keys())
self.cli_print(f"{msg}")
def generate_unique_key(self) -> str:
while True:
key = ''.join(random.choices(string.ascii_letters, k=26))
if key not in self.generated_keys:
self.generated_keys.add(key)
return key
def result_handler(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
try:
func(self, *args, **kwargs)
self.cli_print(f"{func.__name__} success")
except Exception as e:
self.cli_print(f"{func.__name__} raised exception: {e}")
return wrapper
def set_device(device_id: int):
torch.npu.set_device(device=device_id)
_stream, ret = acl.rt.create_stream()
return _stream
class MfTest(TestServer):
def __init__(self, socket_id: int):
super().__init__(socket_id)
self._init_cmds()
self._config_dic = {}
self._bm_handle_dic = {}
self._shm_handle_dic = {}
self._cpu_src_tensor_dic = {}
self._npu_tensor_dic = {}
self._transfer_engine_dic = {}
self._transfer_engine_data_dic = {}
def _init_cmds(self):
cmds = [
CliCommand("show",
"Show local object directory, show",
self.show),
CliCommand("clear_all",
"Clear local object directory, clear",
self.clear_all),
CliCommand("clear_object",
"Clear specify id object directory, clear_object [id]",
self.clear_object),
CliCommand("generate_bm_config",
"Generate a bm config, generate_bm_config "
"[init_timeout] [create_timeout] [operation_timeout] [start_store] [start_store_only] "
"[dynamic_world_size] [unified_address_space] [rank_id] [auto_ranking] [flags]",
self.generate_bm_config),
CliCommand("generate_bm_config_default",
"Generate a bm config, generate_bm_config_default [rank_id] [auto_ranking]",
self.generate_bm_config_default),
CliCommand("generate_tensor_addr",
"Generate cpu and npu tensor, generate_tensor_addr [copy_size]",
self.generate_tensor_addr),
CliCommand("generate_shm_config",
"Generate a shm config, generate_shm_config "
"[init_timeout] [create_timeout] [operation_timeout] [start_store] [flags]",
self.generate_shm_config),
CliCommand("generate_shm_config_default",
"Generate a shm config, generate_shm_config [start_store]",
self.generate_shm_config_default),
CliCommand("mf_smem_init",
"Initialize the smem running environment, mf_smem_init",
self.mf_smem_init),
CliCommand("smem_set_log_level",
"Set log print level, smem_set_log_level [level]",
self.smem_set_log_level),
CliCommand("mf_smem_uninit",
"Uninitialize the smem running environment, mf_smem_uninit",
self.mf_smem_uninit),
CliCommand("bm_init",
"Initialize smem big memory library, bm_init "
"[store_url] [world_size] [device_id] [config_id]",
self.bm_init),
CliCommand("bm_uninitialize",
"uninitialize smem big memory library, bm_uninitialize [flags]",
self.bm_uninitialize),
CliCommand("bm_rank_id",
"Get the rank id assigned during initialize, bm_rank_id ",
self.bm_rank_id),
CliCommand("bm_create",
"Create a big memory object locally after initialized, bm_create "
"[mem_id] [local_dram_size] [local_hbm_size] [data_op_type] [flags]",
self.bm_create),
CliCommand("bm_create2",
"Create a big memory object locally after initialized, bm_create2 "
"[mem_id] [local_dram_size] [max_dram_size] [local_hbm_size] [max_hbm_size] [data_op_type] "
"[enable_56bits_gva] [flags]",
self.bm_create2),
CliCommand("bm_destroy",
"Destroy big memory by a big memory handle. bm_destroy [handle_id]",
self.bm_destroy),
CliCommand("bm_join",
"Join to global Big Memory space actively, after this, "
"we operate on the global space, bm_join [handle_id] [flags]",
self.bm_join),
CliCommand("bm_leave",
"Leave the global Big Memory space actively, after this, "
"we cannot operate on the global space any more, bm_leave [handle_id] [flags]",
self.bm_leave),
CliCommand("bm_extend_local_mem",
"Alloc an extend memory for rank, all alloc memory must range in reserved memory. "
"bm_extend_local_mem [mem_type] [size]",
self.bm_extend_local_mem),
CliCommand("bm_local_mem_size",
"Get size of local memory that contributed to global space, "
"bm_local_mem_size [handle_id] [mem_type]",
self.bm_local_mem_size),
CliCommand("bm_peer_rank_ptr",
"Get peer gva by rank id, bm_peer_rank_ptr [handle_id] [peer_rank] [mem_type]",
self.bm_peer_rank_ptr),
CliCommand("bm_get_rank_id_by_gva",
"Get rank id of gva that belongs to, bm_get_rank_id_by_gva [handle_id] [gva]",
self.bm_get_rank_id_by_gva),
CliCommand("bm_register",
"Register user mem, bm_register [handle_id] [addr] [size]",
self.bm_register),
CliCommand("bm_unregister",
"Unregister user mem, bm_unregister [handle_id] [addr]",
self.bm_unregister),
CliCommand("bm_copy_data",
"Data operation on Big Memory object, copy_data "
"[handle_id] [src_ptr] [dst_ptr] [size] [copy_type] [flags]",
self.bm_copy_data),
CliCommand("bm_copy_data_batch",
"Data operation on Big Memory object, copy_data_batch "
"[handle_id] [src_addrs_str] [dst_addrs_str] [sizes] [count] [copy_type] [flags]",
self.bm_copy_data_batch),
CliCommand("bm_wait",
"Wait all issued async copy(s) finish, bm_wait [handle_id]",
self.bm_wait),
CliCommand("delete_bm_handle", "delete a bm handle, delete_bm_handle [handle_id]",
self.delete_bm_handle),
CliCommand("shm_init",
"Initialize smem shm library, shm_init "
"[store_url] [world_size] [rank_id] [device_id] [config_id]",
self.shm_init),
CliCommand("shm_uninit",
"Uninitialize smem shm library, shm_uninit [flags]",
self.shm_uninit),
CliCommand("shm_create",
"Create a shm handle, shm_create "
"[shm_id] [rank_size] [rank_id] [local_mem_size] [data_op_type] [flags]",
self.shm_create),
CliCommand("delete_shm_handle", "delete a bm handle, delete_shm_handle [handle_id]",
self.delete_shm_handle),
CliCommand("shm_destroy",
"Destroy shm by a shm handle.shm_destroy [handle_id] [flags]",
self.shm_destroy),
CliCommand("shm_set_context",
"Set user extra context of shm object, shm_set_context [handle_id] [context]",
self.shm_set_context),
CliCommand("shm_local_rank",
"Get local rank of a shm object, shm_local_rank [handle_id]",
self.shm_local_rank),
CliCommand("rank_size",
"Get rank size of a shm object, rank_size [handle_id]",
self.rank_size),
CliCommand("shm_barrier",
"Do barrier on a shm object, using control network, shm_barrier [handle_id]",
self.shm_barrier),
CliCommand("shm_all_gather",
"Do all gather on a shm object, using control network, shm_all_gather [handle_id] [local_data]",
self.shm_all_gather),
CliCommand("shm_gva",
"Get global virtual address created, "
"it can be passed to kernel to data operations, sh_gva [handle_id]",
self.shm_gva),
CliCommand("generate_transfer_engine_npu_tensor",
"Generate the buffer which is to transfer data between engines, "
"generate_transfer_engine_npu_tensor [is_src]",
self.generate_transfer_engine_npu_tensor),
CliCommand("generate_transfer_engine_npu_buffer_sum",
"Calculate the sum of buffer, generate_transfer_engine_npu_buffer_sum "
"[transfer_data_id] [is_batch]",
self.generate_transfer_engine_npu_buffer_sum),
CliCommand("generate_transfer_engine_cpu_tensor",
"Generate the buffer which is to transfer data between engines, "
"generate_transfer_engine_cpu_tensor [is_src]",
self.generate_transfer_engine_cpu_tensor),
CliCommand("generate_transfer_engine_cpu_buffer_sum",
"Calculate the sum of buffer, generate_transfer_engine_cpu_buffer_sum "
"[transfer_data_id] [is_batch]",
self.generate_transfer_engine_cpu_buffer_sum),
CliCommand("transfer_engine_get_rpc_port",
"Generate a random port [handle_id]",
self.transfer_engine_get_rpc_port),
CliCommand("transfer_engine_create_config_store",
"Create the config store server, transfer_engine_create_config_store [store_url]",
self.transfer_engine_create_config_store),
CliCommand("transfer_engine_set_log_level",
"Set log level for transfer engine, transfer_engine_set_log_level [level]",
self.transfer_engine_set_log_level),
CliCommand("transfer_engine_initialize",
"Initialize the transfer engine, transfer_engine_initialize "
"[store_url] [session_id] [role] [device_id] [op_type]",
self.transfer_engine_initialize),
CliCommand("transfer_engine_register_memory",
"Register memory for transfer engine, "
"transfer_engine_register_memory [handle_id] [transfer_data_id]",
self.transfer_engine_register_memory),
CliCommand("transfer_engine_transfer_sync_write",
"Transfer data between transfer engine, transfer_engine_transfer_sync_write "
"[handle_id] [transfer_data_id] [peer_buffer_addresses] [destSession]",
self.transfer_engine_transfer_sync_write),
CliCommand("transfer_engine_transfer_sync_read",
"Transfer data between transfer engine, transfer_engine_transfer_sync_read "
"[handle_id] [transfer_data_id] [peer_buffer_addresses] [destSession]",
self.transfer_engine_transfer_sync_read),
CliCommand("transfer_engine_batch_register_memory",
"Register batch memory for transfer engine, "
"transfer_engine_batch_register_memory [handle_id] [transfer_data_id]",
self.transfer_engine_batch_register_memory),
CliCommand("transfer_engine_batch_transfer_sync_write",
"Transfer batch data between transfer engine, transfer_engine_batch_transfer_sync_write "
"[handle_id] [transfer_data_id] [peer_buffer_addresses] [destSession]",
self.transfer_engine_batch_transfer_sync_write),
CliCommand("transfer_engine_batch_transfer_sync_read",
"Transfer batch data between transfer engine, transfer_engine_batch_transfer_sync_read "
"[handle_id] [transfer_data_id] [peer_buffer_addresses] [destSession]",
self.transfer_engine_batch_transfer_sync_read),
CliCommand("transfer_engine_transfer_async_write_submit",
"Transfer data between transfer engine, transfer_engine_transfer_async_write_submit "
"[handle_id] [transfer_data_id] [peer_buffer_address] [destSession] [stream]",
self.transfer_engine_transfer_async_write_submit),
CliCommand("transfer_engine_transfer_async_read_submit",
"Transfer data between transfer engine, transfer_engine_transfer_async_read_submit "
"[handle_id] [transfer_data_id] [peer_buffer_address] [destSession] [stream]",
self.transfer_engine_transfer_async_read_submit),
CliCommand("transfer_engine_unregister_memory",
"Destroy transfer engine handle, "
"transfer_engine_unregister_memory [handle_id] [transfer_data_id]",
self.transfer_engine_unregister_memory),
CliCommand("transfer_engine_destroy",
"Destroy transfer engine handle, transfer_engine_destroy [handle_id]",
self.transfer_engine_destroy),
CliCommand("transfer_engine_uninitialize",
"Uninitialize transfer engine, transfer_engine_uninitialize [handle_id]",
self.transfer_engine_uninitialize),
CliCommand("delete_trans_handle", "delete a delete_trans_handle handle, delete_bm_handle [handle_id]",
self.delete_trans_handle),
CliCommand("close", "release the bound UDS socket file", self.close)
]
self.register_command(cmds)
def show(self):
self.cli_print("config objects:")
for key, value in self._config_dic.items():
self.cli_print(f":\tid:{key}-{value}")
self.cli_print("bm handle objects:")
for key, value in self._bm_handle_dic.items():
self.cli_print(f":\tid:{key}")
self.cli_print("shm handle objects:")
for key, value in self._shm_handle_dic.items():
self.cli_print(f":\tid:{key}")
self.cli_print("cpu tensor objects:")
for key, value in self._cpu_src_tensor_dic.items():
self.cli_print(f":\tdevice:{value.device}, id:{key},addr:{value.data_ptr()}")
self.cli_print("npu tensor objects:")
for key, value in self._npu_tensor_dic.items():
self.cli_print(f":\tdevice:{value.device}, id:{key},addr:{value.data_ptr()}")
self.cli_print("transfer engine objects:")
for key, value in self._transfer_engine_dic.items():
self.cli_print(f":\tid:{key}")
@result_handler
def clear_all(self):
self._config_dic.clear()
self._bm_handle_dic.clear()
self._shm_handle_dic.clear()
self._cpu_src_tensor_dic.clear()
self._npu_tensor_dic.clear()
self._transfer_engine_dic.clear()
self._transfer_engine_data_dic.clear()
@result_handler
def clear_object(self, obj_id):
if obj_id in self._config_dic:
self._config_dic.pop(obj_id, None)
return
if obj_id in self._bm_handle_dic:
self._bm_handle_dic.pop(obj_id, None)
return
if obj_id in self._shm_handle_dic:
self._shm_handle_dic.pop(obj_id, None)
return
if obj_id in self._cpu_src_tensor_dic:
self._cpu_src_tensor_dic.pop(obj_id, None)
return
if obj_id in self._npu_tensor_dic:
self._npu_tensor_dic.pop(obj_id, None)
return
if obj_id in self._transfer_engine_dic:
self._transfer_engine_dic.pop(obj_id, None)
@result_handler
def generate_tensor_addr(self, copy_size: int):
count = copy_size // 4
src_tensor = torch.empty([count], dtype=torch.int32)
base = globals_device_id
mod = 32767
for i in range(0, count):
base = (base * 23 + 17) % mod
if ((i + globals_device_id) % 3) == 0:
src_tensor[i] = -base
else:
src_tensor[i] = base
src_tensor_id = id(src_tensor)
self._cpu_src_tensor_dic[src_tensor_id] = src_tensor
npu_tensor = src_tensor.npu(globals_device_id)
npu_tensor_id = id(npu_tensor)
self._npu_tensor_dic[npu_tensor_id] = npu_tensor
self.cli_print(f"cpu src tensor addr:{src_tensor.data_ptr()}, npu tensor addr:{npu_tensor.data_ptr()}")
def _get_config(self, config_id: int) -> Any:
"""
:param config_id: 保存config的key值
"""
config = self._config_dic.get(config_id, None)
return config
def _get_transfer_engine_data(self, data_id: int) -> Any:
"""
:param data_id: 保存transfer_data的key值
"""
transfer_data = self._transfer_engine_data_dic.get(data_id, None)
return transfer_data
@result_handler
def generate_bm_config(self, init_timeout: int, create_timeout: int, operation_timeout: int, start_store: bool,
start_store_only: bool, dynamic_world_size: bool, unified_address_space: bool,
auto_ranking: bool, rank_id: int, flags: int, hcom_url: str):
config = bm.BmConfig()
config.init_timeout = init_timeout
config.create_timeout = create_timeout
config.operation_timeout = operation_timeout
config.start_store = start_store
config.start_store_only = start_store_only
config.dynamic_world_size = dynamic_world_size
config.unified_address_space = unified_address_space
config.auto_ranking = auto_ranking
config.rank_id = rank_id
config.flags = flags
config.set_nic(hcom_url)
config_id = id(config)
self._config_dic[config_id] = config
logging.info(f"generate bm config id:{config_id}")
self.cli_print(f"generate bm config id:{config_id}")
@result_handler
def generate_bm_config_default(self, rank_id: int, auto_ranking: bool):
config = bm.BmConfig()
config.rank_id = rank_id
config.auto_ranking = auto_ranking
config_id = id(config)
self._config_dic[config_id] = config
self.cli_print(f"generate bm config id:{config_id}")
@result_handler
def generate_shm_config(self, init_timeout: int, create_timeout: int, operation_timeout: int, start_store: bool,
flags: int):
config = shm.ShmConfig()
config.init_timeout = init_timeout
config.create_timeout = create_timeout
config.operation_timeout = operation_timeout
config.start_store = start_store
config.flags = flags
config_id = id(config)
self._config_dic[config_id] = config
self.cli_print(f"generate shm config id:{config_id}")
@result_handler
def generate_shm_config_default(self, start_store: bool):
config = shm.ShmConfig()
config.start_store = start_store
config_id = id(config)
self._config_dic[config_id] = config
self.cli_print(f"generate shm config id:{config_id}")
@result_handler
def mf_smem_init(self, flags: int = 0):
ret = memfabric_hybrid.initialize(flags)
self.cli_print(f"mf smem init ret({ret})")
@result_handler
def smem_set_log_level(self, level: int):
ret = memfabric_hybrid.set_log_level(level)
memfabric_hybrid.set_conf_store_tls(0, "0")
self.cli_print(f"set log level{level} ret({ret})")
@result_handler
def mf_smem_uninit(self):
memfabric_hybrid.uninitialize()
@result_handler
def bm_init(self, store_url: str, world_size: int, device_id: int, config_id: int):
logging.info("============================bm_init============")
config = self._get_config(config_id)
logging.info(f"config={config} config_id={config_id}")
ret = bm.initialize(store_url=store_url, world_size=world_size, device_id=device_id, config=config)
self.cli_print(f"smem BM initialize ret({ret})")
@result_handler
def bm_uninitialize(self, flags: int):
bm.uninitialize(flags)
@result_handler
def bm_rank_id(self):
rank_id = bm.bm_rank_id()
self.cli_print(f"bm rank id:{rank_id}")
@result_handler
def bm_create(self, mem_id: int, local_dram_size: int, local_hbm_size: int, data_op_type: int, flags: int):
handle = bm.create(id=mem_id, local_dram_size=local_dram_size,
local_hbm_size=local_hbm_size, data_op_type=bm.BmDataOpType(data_op_type), flags=flags)
self.cli_print(f"id={mem_id}, local_dram_size={local_dram_size}, local_hbm_size={local_hbm_size}, "
f"data_op_type={bm.BmDataOpType(data_op_type),} flags={flags}")
if handle is not None:
addr = id(handle)
self._bm_handle_dic[addr] = handle
else:
addr = 0
self.cli_print(f"create a big memory object:{addr}")
@result_handler
def bm_create2(self, mem_id: int, local_dram_size: int, max_dram_size: int, local_hbm_size: int, max_hbm_size: int,
data_op_type: int, enable_56bits_gva: bool, flags: int):
handle = bm.create2(id=mem_id, local_dram_size=local_dram_size, max_dram_size=max_dram_size,
local_hbm_size=local_hbm_size, max_hbm_size=max_hbm_size,
data_op_type=bm.BmDataOpType(data_op_type), enable_56bits_gva=enable_56bits_gva,
flags=flags)
self.cli_print(f"id={mem_id}, local_dram_size={local_dram_size}, local_hbm_size={local_hbm_size}, "
f"data_op_type={bm.BmDataOpType(data_op_type),} flags={flags}")
if handle is not None:
addr = id(handle)
self._bm_handle_dic[addr] = handle
else:
addr = 0
self.cli_print(f"create a big memory object:{addr}")
@result_handler
def bm_destroy(self, handle_id: int):
handle = self._bm_handle_dic[handle_id]
handle.destroy()
@result_handler
def bm_join(self, handle_id: int, flags: int):
handle = self._bm_handle_dic[handle_id]
ret = handle.join(flags=flags)
self.cli_print(f"bm join, ret:{ret}")
@result_handler
def bm_leave(self, handle_id: int, flags: int):
handle = self._bm_handle_dic[handle_id]
ret = handle.leave(flags)
self.cli_print(f"bm leave, ret:{ret}")
@result_handler
def bm_extend_local_mem(self, handle_id: int, mem_type: int, size: int):
handle = self._bm_handle_dic[handle_id]
ret = handle.extend_local_mem(bm.BmMemType(mem_type), size)
self.cli_print(f"bm_extend_local_mem, ret:{ret}")
@result_handler
def bm_local_mem_size(self, handle_id: int, mem_type: int):
handle = self._bm_handle_dic[handle_id]
mem_size = handle.local_mem_size(mem_type=bm.BmMemType(mem_type))
self.cli_print(f"local memory size:{mem_size}, type:{bm.BmMemType(mem_type)}")
@result_handler
def bm_peer_rank_ptr(self, handle_id: int, peer_rank: int, mem_type: int):
handle = self._bm_handle_dic[handle_id]
ptr = handle.peer_rank_ptr(peer_rank=peer_rank, mem_type=bm.BmMemType(mem_type))
self.cli_print(f"peer rack ptr16:{hex(ptr)}, type:{bm.BmMemType(mem_type)}")
self.cli_print(f"peer rack ptr:{ptr}, type:{bm.BmMemType(mem_type)}")
@result_handler
def bm_get_rank_id_by_gva(self, handle_id: int, gva: int):
handle = self._bm_handle_dic[handle_id]
rank_id = handle.get_rank_id_by_gva(gva)
self.cli_print(f"bm rank id:{rank_id}")
@result_handler
def bm_register(self, handle_id: int, addr: int, size: int):
handle = self._bm_handle_dic[handle_id]
ret = handle.register(addr=addr, size=size)
self.cli_print(f"bm register mem, ret:{ret}")
@result_handler
def bm_unregister(self, handle_id: int, addr: int):
handle = self._bm_handle_dic[handle_id]
ret = handle.unregister(addr=addr)
self.cli_print(f"bm unregister mem, ret:{ret}")
@result_handler
def bm_copy_data(self, handle_id: int, src_ptr: int, dst_ptr: int, size: int, copy_type: int, flags: int):
handle = self._bm_handle_dic[handle_id]
self.cli_print(
f"src_ptr={src_ptr}, dst_ptr={hex(dst_ptr)}, size={size}, type={bm.BmCopyType(copy_type)}, flags={flags}")
ret = handle.copy_data(src_ptr=src_ptr, dst_ptr=dst_ptr, size=size, type=bm.BmCopyType(copy_type), flags=flags)
self.cli_print(f"bm copy_data, ret:{ret}")
@result_handler
def bm_copy_data_batch(self, handle_id: int, src_addrs_str: int, dst_addrs_str: int, sizes_str: int,
count: int, copy_type: int, flags: int):
handle = self._bm_handle_dic[handle_id]
src_addrs = list(map(int, src_addrs_str.split(",")))
dst_addrs = list(map(int, dst_addrs_str.split(",")))
sizes = list(map(int, sizes_str.split(",")))
ret = handle.copy_data_batch(src_addrs=src_addrs, dst_addrs=dst_addrs, sizes=sizes, count=count,
type=bm.BmCopyType(copy_type), flags=flags)
self.cli_print(f"bm copy_data_batch, ret:{ret}")
@result_handler
def bm_wait(self, handle_id: int):
handle = self._bm_handle_dic[handle_id]
ret = handle.wait()
self.cli_print(f"bm wait, ret:{ret}")
@result_handler
def delete_bm_handle(self, handle_id: int):
handle = self._bm_handle_dic[handle_id]
del handle
@result_handler
def delete_shm_handle(self, handle_id: int):
handle = self._shm_handle_dic[handle_id]
del handle
@result_handler
def shm_init(self, store_url: str, world_size: int, rank_id: int, device_id: int, config_id: int):
config = self._get_config(config_id)
ret = shm.initialize(store_url=store_url, world_size=world_size, rank_id=rank_id, device_id=device_id,
config=config)
self.cli_print(f"smem SHM initialize ret({ret})")
@result_handler
def shm_uninit(self, flags: int):
shm.uninitialize(flags)
@result_handler
def shm_create(self, shm_id: int, rank_size: int, rank_id: int, local_mem_size: int, data_op_type: int, flags: int):
shm_handle = shm.create(id=shm_id, rank_size=rank_size, rank_id=rank_id, local_mem_size=local_mem_size,
data_op_type=shm.ShmDataOpType(data_op_type), flags=flags)
if shm_handle is not None:
addr = id(shm_handle)
self._shm_handle_dic[addr] = shm_handle
else:
addr = 0
self.cli_print(f"create shm object:{addr}")
@result_handler
def shm_destroy(self, handle_id: int, flags: int):
handle = self._shm_handle_dic[handle_id]
self.cli_print(f"before shm_destroy")
handle.destroy(flags)
@result_handler
def shm_set_context(self, handle_id: int, context: str):
handle = self._shm_handle_dic[handle_id]
handle.set_context(bytes(context, encoding="utf-8"))
@result_handler
def shm_local_rank(self, handle_id: int):
handle = self._shm_handle_dic[handle_id]
rank_id = handle.local_rank()
self.cli_print(f"shm object local rank:(\d+):{rank_id}")
@result_handler
def rank_size(self, handle_id: int):
handle = self._shm_handle_dic[handle_id]
size = handle.local_rank()
self.cli_print(f"shm object size:{size}")
@result_handler
def shm_barrier(self, handle_id: int):
handle = self._shm_handle_dic[handle_id]
handle.barrier()
@result_handler
def shm_all_gather(self, handle_id: int, local_data: str):
handle = self._shm_handle_dic[handle_id]
handle.all_gather(local_data.encode())
@result_handler
def shm_gva(self, handle_id: int):
handle = self._shm_handle_dic[handle_id]
addr = handle.gva()
self.cli_print(f"get shm object gva:{addr}")
@result_handler
def generate_transfer_engine_npu_tensor(self, is_src: bool):
total_size_bytes = 4 * 1024 * 1024
device = 'npu'
if is_src:
big_buffer = torch.ones((total_size_bytes,), dtype=torch.uint8, device=device)
else:
big_buffer = torch.zeros((total_size_bytes,), dtype=torch.uint8, device=device)
buffer = big_buffer[0:1 * 1024 * 1024].reshape(1024, 1024)
buffer_bytes = buffer.element_size() * buffer.numel()
batch_buffers = [
big_buffer[2 * 1024 * 1024:3 * 1024 * 1024].reshape(1024, 1024),
big_buffer[3 * 1024 * 1024:4 * 1024 * 1024].reshape(1024, 1024)
]
batch_bytes = [b.element_size() * b.numel() for b in batch_buffers]
transfer_data = TransferEngineData(
buffer=buffer,
buffer_bytes=buffer_bytes,
batch_buffers=batch_buffers,
batch_bytes=batch_bytes
)
transfer_data_id = id(transfer_data)
self._transfer_engine_data_dic[transfer_data_id] = transfer_data
buffer_addrs = [b.data_ptr() for b in transfer_data.batch_buffers]
self.cli_print(f"generate transfer engine data id:{transfer_data_id}")
self.cli_print(f"generate transfer engine data buffer addr is:{buffer.data_ptr()}")
self.cli_print(f"generate transfer engine data buffer bytes is:{buffer_bytes}")
self.cli_print(f"generate transfer engine data batch buffer addr is:{buffer_addrs[0]},{buffer_addrs[1]}")
self.cli_print(f"generate transfer engine data batch buffer bytes is:{batch_bytes[0]},{batch_bytes[1]}")
@result_handler
def generate_transfer_engine_cpu_tensor(self, is_src: bool):
total_size_bytes = 4 * 1024 * 1024
device = 'cpu'
if is_src:
big_buffer = torch.ones((total_size_bytes,), dtype=torch.uint8, device=device)
else:
big_buffer = torch.zeros((total_size_bytes,), dtype=torch.uint8, device=device)
buffer = big_buffer[0:1 * 1024 * 1024].reshape(1024, 1024)
buffer_bytes = buffer.element_size() * buffer.numel()
batch_buffers = [
big_buffer[2 * 1024 * 1024:3 * 1024 * 1024].reshape(1024, 1024),
big_buffer[3 * 1024 * 1024:4 * 1024 * 1024].reshape(1024, 1024)
]
batch_bytes = [b.element_size() * b.numel() for b in batch_buffers]
transfer_data = TransferEngineData(
buffer=buffer,
buffer_bytes=buffer_bytes,
batch_buffers=batch_buffers,
batch_bytes=batch_bytes
)
transfer_data_id = id(transfer_data)
self._transfer_engine_data_dic[transfer_data_id] = transfer_data
buffer_addrs = [b.data_ptr() for b in transfer_data.batch_buffers]
self.cli_print(f"generate transfer engine data id:{transfer_data_id}")
self.cli_print(f"generate transfer engine data buffer addr is:{buffer.data_ptr()}")
self.cli_print(f"generate transfer engine data buffer bytes is:{buffer_bytes}")
self.cli_print(f"generate transfer engine data batch buffer addr is:{buffer_addrs[0]},{buffer_addrs[1]}")
self.cli_print(f"generate transfer engine data batch buffer bytes is:{batch_bytes[0]},{batch_bytes[1]}")
@result_handler
def generate_transfer_engine_npu_buffer_sum(self, transfer_data_id: int, is_batch: bool):
if transfer_data_id not in self._transfer_engine_data_dic:
self.cli_print(f"Transfer data with ID {transfer_data_id} not found")
return
transfer_data = self._transfer_engine_data_dic[transfer_data_id]
if is_batch:
buffer_sum = torch.sum(torch.cat(transfer_data.batch_buffers))
else:
buffer_sum = torch.sum(transfer_data.buffer)
self.cli_print(f"Calculate the buffer sum is:{buffer_sum}")
@result_handler
def generate_transfer_engine_cpu_buffer_sum(self, transfer_data_id: int, is_batch: bool):
if transfer_data_id not in self._transfer_engine_data_dic:
self.cli_print(f"Transfer data with ID {transfer_data_id} not found")
return
transfer_data = self._transfer_engine_data_dic[transfer_data_id]
if is_batch:
buffer_sum = torch.sum(torch.cat(transfer_data.batch_buffers))
else:
buffer_sum = torch.sum(transfer_data.buffer)
self.cli_print(f"Calculate the buffer sum is:{buffer_sum}")
@result_handler
def transfer_engine_get_rpc_port(self, handle_id: int):
engine = self._transfer_engine_dic[handle_id]
port = engine.get_rpc_port()
self.cli_print(f"generate rpc port is:{port}")
@result_handler
def transfer_engine_create_config_store(self, store_url: str):
create_config_store(store_url)
self.cli_print(f"*********store_url is : {store_url}")
self.cli_print("create transfer engine successfully.")
@result_handler
def transfer_engine_set_log_level(self, level: int):
set_log_level(level)
memfabric_hybrid.set_conf_store_tls(0, "0")
self.cli_print("set log level for transfer engine successfully.")
@result_handler
def transfer_engine_initialize(self, store_url: str, session_id: str, role: str, device_id: int, op_type: int):
engine = TransferEngine()
port = engine.get_rpc_port()
if op_type == 1:
data_op_type = TransferEngine.TransDataOpType.SDMA
elif op_type == 2:
data_op_type = TransferEngine.TransDataOpType.DEVICE_RDMA
else:
logging.error("Invalid optype: %d", op_type)
return
self.cli_print(f"TransDataOpType:{data_op_type}")
ret_value = engine.initialize(
store_url,
session_id,
role,
device_id,
data_op_type
)
if ret_value == 0:
addr = id(engine)
self._transfer_engine_dic[addr] = engine
else:
addr = 0
self.cli_print(f"transfer engine initialize finish:{addr}")
@result_handler
def transfer_engine_register_memory(self, handle_id: int, buffers: int, capacities: int):
engine = self._transfer_engine_dic[handle_id]
ret_value = engine.register_memory(buffers, capacities)
self.cli_print(f"register memory for transfer engine result is:{ret_value}")
@result_handler
def transfer_engine_transfer_sync_write(self, handle_id: int, dest_session: str, buffers: int,
peer_buffer_addresses: int, length: int):
engine = self._transfer_engine_dic[handle_id]
ret_value = engine.transfer_sync_write(dest_session, buffers, peer_buffer_addresses, length)
time.sleep(5)
self.cli_print(f"write data for transfer engine result is:{ret_value}")
@result_handler
def transfer_engine_transfer_sync_read(self, handle_id: int, dest_session: str, buffers: int,
peer_buffer_addresses: int, length: int):
engine = self._transfer_engine_dic[handle_id]
ret_value = engine.transfer_sync_read(dest_session, buffers, peer_buffer_addresses, length)
time.sleep(5)
self.cli_print(f"read data for transfer engine result is:{ret_value}")
@result_handler
def transfer_engine_batch_register_memory(self, handle_id: int, buffers_str: str, capacities_str: str):
engine = self._transfer_engine_dic[handle_id]
buffers = list(map(int, buffers_str.split(",")))
capacities = list(map(int, capacities_str.split(",")))
ret_value = engine.batch_register_memory(buffers, capacities)
self.cli_print(f"register batch memory for transfer engine result is:{ret_value}")
@result_handler
def transfer_engine_batch_transfer_sync_write(self, handle_id: int, dest_session: str, buffers_str: str,
peer_buffer_addresses_str: str, lengths_str: str):
buffers = list(map(int, buffers_str.split(",")))
peer_buffer_addresses = list(map(int, peer_buffer_addresses_str.split(",")))
lengths = list(map(int, lengths_str.split(",")))
engine = self._transfer_engine_dic[handle_id]
ret_value = engine.batch_transfer_sync_write(dest_session, buffers, peer_buffer_addresses, lengths)
time.sleep(5)
self.cli_print(f"write batch data for transfer engine result is:{ret_value}")
@result_handler
def transfer_engine_batch_transfer_sync_read(self, handle_id: int, dest_session: str, buffers_str: str,
peer_buffer_addresses_str: str, lengths_str: str):
buffers = list(map(int, buffers_str.split(",")))
peer_buffer_addresses = list(map(int, peer_buffer_addresses_str.split(",")))
lengths = list(map(int, lengths_str.split(",")))
engine = self._transfer_engine_dic[handle_id]
ret_value = engine.batch_transfer_sync_read(dest_session, buffers, peer_buffer_addresses, lengths)
time.sleep(5)
self.cli_print(f"read batch data for transfer engine result is:{ret_value}")
@result_handler
def transfer_engine_transfer_async_write_submit(self, handle_id: int, dest_session: str, buffers: int,
peer_buffer_addresses: int, length: int, stream: int):
engine = self._transfer_engine_dic[handle_id]
if stream != 0:
stream = self._stream
ret_value = engine.transfer_async_write_submit(dest_session, buffers, peer_buffer_addresses, length, stream)
time.sleep(5)
self.cli_print(f"async write data submit for transfer engine result is:{ret_value}")
@result_handler
def transfer_engine_transfer_async_read_submit(self, handle_id: int, dest_session: str, buffers: int,
peer_buffer_addresses: int, length: int, stream: int):
engine = self._transfer_engine_dic[handle_id]
if stream != 0:
stream = self._stream
ret_value = engine.transfer_async_read_submit(dest_session, buffers, peer_buffer_addresses, length, stream)
time.sleep(5)
self.cli_print(f"async read data submit for transfer engine result is:{ret_value}")
@result_handler
def transfer_engine_unregister_memory(self, handle_id: int, buffers: int):
engine = self._transfer_engine_dic[handle_id]
ret_value = engine.unregister_memory(buffers)
self.cli_print(f"unregister memory for transfer engine result is:{ret_value}")
@result_handler
def transfer_engine_destroy(self, handle_id: int):
self.cli_print("start transfer_engine_destroy")
engine = self._transfer_engine_dic[handle_id]
engine.destroy()
self.cli_print("finish transfer_engine_destroy")
@result_handler
def transfer_engine_uninitialize(self, handle_id: int):
self.cli_print("start transfer_engine_uninitialize")
engine = self._transfer_engine_dic[handle_id]
engine.unInitialize()
self.cli_print("finish transfer_engine_uninitialize")
@result_handler
def delete_trans_handle(self, handle_id: int):
handle = self._transfer_engine_dic[handle_id]
del handle
if __name__ == "__main__":
if len(sys.argv) <= 1:
logging.error("Please input app_id and device_id when starting the process"
"usage: memfabric_test.py [app_id] [device_id].")
sys.exit(1)
app_id = int(sys.argv[1])
globals_device_id = int(sys.argv[2])
logging.info(f"Start app_id: {app_id}, device:{globals_device_id}")
server = MfTest(app_id)
server.start()