"""
Invoke this script in `pypto` project root dir
"""
from dataclasses import dataclass
from typing import List, Tuple, Optional
import argparse
import os
import struct
import numpy as np
import torch
from thread_task_runner import ThreadTaskRunner
# Default value of the `--max_workers` command-line option.
DEFAULT_MAX_WORKERS = 16
# Element count at or above which `main` flags a tensor as "huge" for the task runner.
HUGE_TENSOR_THRESHOLD = 1024 ** 2
# Number of leading scalar fields in a `[DumpTensor]` record; the remaining
# comma-separated fields are the shape dimensions.
PROPERTY_NUM = 6
def parse_task_id(task_id: int) -> Tuple[int, int]:
    """Split a packed task id into its upper part and its low 16 bits.

    Args:
        task_id: Packed integer identifier.

    Returns:
        ``(task_id >> 16, task_id & 0xFFFF)``. ``RawTensorDesc.name`` treats
        the pair as ``(func_id, subtask_id)``.
    """
    low_width = 16
    low_mask = (1 << low_width) - 1
    upper = task_id >> low_width
    lower = task_id & low_mask
    return (upper, lower)
@dataclass(frozen=True)
class RawTensorDesc:
    """Immutable, hashable description of one tensor record from the control log."""

    # Sequence number of the record; also selects the `seqNo-<n>` output directory.
    seq_no: int
    # Packed task identifier; decoded with `parse_task_id` when building the name.
    task_id: int
    # Integer tag taken verbatim from the log record; part of the file name.
    raw_magic: int
    # Start address of the tensor data inside the dumped memory blocks.
    address: int
    # pypto dtype name as written in the log (e.g. "FP32").
    dtype: str
    # Size in bytes of one element of `dtype`.
    bytes_of_dtype: int
    # Tensor dimensions; an empty tuple denotes a scalar.
    shape: Tuple[int, ...]
    # Label such as "i0" or "o1-i2" when `address` matches a known input/output
    # address, otherwise None.
    io_mark: Optional[str]
    # Name of an earlier tensor with the same `tensor_key()`; when set, the
    # dump is written as a symlink to that tensor's file.
    symlink_src: Optional[str] = None

    def numel(self) -> int:
        """Return the element count as a plain Python ``int``.

        A pure-Python product is used instead of ``np.prod`` because
        ``np.prod(())`` yields the float ``1.0`` and NumPy integer products
        can silently wrap; both would corrupt the byte-range arithmetic
        performed by callers.
        """
        count = 1
        for dim in self.shape:
            count *= int(dim)
        return count

    def tensor_key(self) -> Tuple[int, str, Tuple[int, ...]]:
        """Return the ``(address, dtype, shape)`` key used to detect duplicate tensors."""
        return (self.address, self.dtype, tuple(self.shape))

    def name(self) -> str:
        """Return ``<seq_no>-<func_id>-<subtask_id>-<raw_magic>[-<io_mark>]``."""
        func_id, subtask_id = parse_task_id(self.task_id)
        name = f"{self.seq_no}-{func_id}-{subtask_id}-{self.raw_magic}"
        if self.io_mark is not None:
            name += f"-{self.io_mark}"
        return name

    def copy_with_symlink_src(self, symlink_src) -> "RawTensorDesc":
        """Return a copy of this descriptor with ``symlink_src`` replaced."""
        return RawTensorDesc(
            seq_no=self.seq_no,
            task_id=self.task_id,
            raw_magic=self.raw_magic,
            address=self.address,
            dtype=self.dtype,
            bytes_of_dtype=self.bytes_of_dtype,
            shape=self.shape,
            io_mark=self.io_mark,
            symlink_src=symlink_src,
        )
def parse_tile_fwk_aicpu_ctrl(filename, inputs, outputs) -> List[RawTensorDesc]:
    """Parse `[DumpTensor]` records from the control log into tensor descriptors.

    Only lines containing ``[DumpTensor]`` are considered. Records are read
    between a marker line starting with ``>>>`` and one starting with ``<<<``.
    Each record is a comma-separated list:
    ``seq_no,task_id,raw_magic,address,dtype,bytes_of_dtype,<shape dims...>``.

    Args:
        filename: Path to the control log text file.
        inputs: Addresses of input tensors; the index becomes the ``i<idx>`` mark.
        outputs: Addresses of output tensors; the index becomes the ``o<idx>`` mark.

    Returns:
        Unique descriptors in order of first appearance in the log. A record
        whose ``(address, dtype, shape)`` was already seen under a different
        name carries ``symlink_src`` pointing at the first such tensor.

    Raises:
        ValueError: If a record has too few fields or a non-integer field.
    """
    io_map = {}

    def preprocess_io_map(datas, class_str: str):
        # The same address may occur several times (and as both input and
        # output); all of its marks are joined with '-'.
        for idx, data in enumerate(datas):
            mark = class_str + str(idx)
            io_map[data] = f"{io_map[data]}-{mark}" if data in io_map else mark

    preprocess_io_map(inputs, 'i')
    preprocess_io_map(outputs, 'o')

    raw_tensors: List[RawTensorDesc] = []
    # Maps tensor_key() -> name of the first tensor seen with that key.
    tensor_cache = {}
    with open(filename, 'r') as f:
        start_recording = False
        for line in f:
            if '[DumpTensor]' not in line:
                continue
            # Keep only the payload after the marker, without quotes/padding.
            record = line.strip().split('[DumpTensor]')[-1].strip(' "')
            if record.startswith(">>>"):
                start_recording = True
                continue
            if record.startswith("<<<"):
                start_recording = False
                continue
            if not start_recording:
                continue

            splits = record.strip().split(',')
            seq_no, task_id, raw_magic, address, dtype, bytes_of_dtype = splits[:PROPERTY_NUM]
            seq_no, task_id, raw_magic, address, bytes_of_dtype = map(int, (
                seq_no, task_id, raw_magic, address, bytes_of_dtype
            ))
            # int() tolerates surrounding whitespace; give dtype the same tolerance.
            dtype = dtype.strip()
            # Drop parentheses and skip empty tokens so "()" and the tail of
            # "(2,)" do not reach int('').
            shape_tokens = (tok.strip().strip('()').strip() for tok in splits[PROPERTY_NUM:])
            shape = tuple(int(tok) for tok in shape_tokens if tok)

            rt = RawTensorDesc(
                seq_no=seq_no,
                task_id=task_id,
                raw_magic=raw_magic,
                address=address,
                dtype=dtype,
                bytes_of_dtype=bytes_of_dtype,
                shape=shape,
                io_mark=io_map.get(address),
            )
            key = rt.tensor_key()
            if key in tensor_cache:
                symlink_src = tensor_cache[key]
                # A repeated record of the very same tensor must not link to itself.
                if symlink_src != rt.name():
                    rt = rt.copy_with_symlink_src(symlink_src)
            else:
                tensor_cache[key] = rt.name()
            raw_tensors.append(rt)

    # De-duplicate while keeping log order; a set would give an order that
    # changes between runs because of string-hash randomisation.
    return list(dict.fromkeys(raw_tensors))
class ByteTable:
    """Lookup table over the memory blocks stored in a binary dump.

    The dump is a sequence of blocks, each encoded as two little-endian
    uint64 values (base address, payload size) followed by the payload bytes.
    """

    def __init__(self, binary_data, offset=0):
        """Parse all blocks found in ``binary_data`` starting at ``offset``."""
        self.blocks = []
        self._parse(binary_data, offset)

    def query(self, addr_start, addr_end):
        """Return a copy of the bytes in the address range ``[addr_start, addr_end)``.

        Raises:
            Exception: If the range overlaps a block without being fully
                contained in it, or if no block overlaps the range.
        """
        for base, size, payload in self.blocks:
            limit = base + size
            # Skip blocks that do not overlap the requested range at all.
            if addr_start >= limit or addr_end <= base:
                continue
            if addr_start < base or addr_end > limit:
                raise Exception("Unexpected memrange spanning multiple memblocks")
            begin = addr_start - base
            span = addr_end - addr_start
            return bytearray(payload[begin:begin + span])
        raise Exception("Address mismatching")

    def _parse(self, data, offset=0):
        """Append every ``(base_addr, size, payload)`` block in ``data`` to ``self.blocks``."""
        header_fmt = '<QQ'
        header_len = struct.calcsize(header_fmt)
        total = len(data)
        cursor = offset
        while cursor < total:
            base_addr, size = struct.unpack_from(header_fmt, data, cursor)
            payload_at = cursor + header_len
            cursor = payload_at + size
            self.blocks.append((base_addr, size, data[payload_at:cursor]))
            print(f"Parsed a binary data block | addr=0x{base_addr:X}, size={size}")
def read_uint64_list(binary_data: bytes, offset: int):
    """Read a length-prefixed list of little-endian uint64 values.

    The encoding is one uint64 element count followed by that many uint64
    values.

    Args:
        binary_data: Buffer to read from.
        offset: Byte position of the count field.

    Returns:
        ``(values, next_offset)`` where ``next_offset`` points just past the
        last value read.
    """
    (count,) = struct.unpack_from('<Q', binary_data, offset)
    body_at = offset + 8
    values = struct.unpack_from(f'<{count}Q', binary_data, body_at)
    return list(values), body_at + count * 8
def parse_dump_tensor_binary(filename) -> Tuple[ByteTable, List[int], List[int]]:
    """Load a binary tensor dump.

    The file starts with two length-prefixed uint64 lists (input values, then
    output values); everything after them is parsed as memory blocks.

    Args:
        filename: Path to the binary dump file.

    Returns:
        ``(byte_table, inputs, outputs)``.
    """
    with open(filename, 'rb') as dump_file:
        blob = dump_file.read()
    inputs, cursor = read_uint64_list(blob, 0)
    outputs, cursor = read_uint64_list(blob, cursor)
    table = ByteTable(blob, cursor)
    return table, inputs, outputs
def pypto_dtype_to_torch_dtype(dtype: str) -> Optional[torch.dtype]:
    """Translate a pypto dtype name into the matching ``torch.dtype``.

    Args:
        dtype: pypto dtype name such as ``"FP16"``.

    Returns:
        The torch dtype, or ``None`` (after printing a message) when the name
        is unknown or has no torch counterpart in the table below.
    """
    # A value of None marks a known pypto dtype without a torch mapping here.
    known = {
        "INT4": None,
        "INT8": torch.int8,
        "INT16": torch.int16,
        "INT32": torch.int32,
        "INT64": torch.int64,
        "FP8": None,
        "FP16": torch.float16,
        "FP32": torch.float32,
        "BF16": torch.bfloat16,
        "HF4": None,
        "HF8": None,
        "UINT8": torch.uint8,
        "UINT16": None,
        "UINT32": None,
        "UINT64": None,
        "BOOL": torch.bool,
        "DOUBLE": torch.double,
    }
    try:
        mapped = known[dtype]
    except KeyError:
        print(f"Invalid pypto dtype: {dtype}")
        return None
    if mapped is None:
        print(f"Cannot convert pypto dtype: {dtype} to corresponding torch dtype")
    return mapped
def parse_arguments():
    """Parse the command line.

    Returns:
        Namespace with ``dump_tensor_filename``, ``tile_fwk_aicpu_ctrl_filename``
        and ``max_workers``.
    """
    cli = argparse.ArgumentParser(description="Process process_dump_tensor.")
    positionals = (
        ("dump_tensor_filename", "Path to dump_tensor.txt"),
        ("tile_fwk_aicpu_ctrl_filename", "Path to tile_fwk_aicpu_ctrl.txt"),
    )
    for arg_name, arg_help in positionals:
        cli.add_argument(arg_name, type=str, help=arg_help)
    workers_help = f"Maximum number of threading workers, {DEFAULT_MAX_WORKERS} by default"
    cli.add_argument("--max_workers", type=int, default=DEFAULT_MAX_WORKERS, help=workers_help)
    return cli.parse_args()
def main():
    """Dump every tensor listed in the control log as a text file.

    Output goes to ``<dir of dump file>/dump_tensor/seqNo-<n>/<name>.txt``.
    A tensor that duplicates an earlier one (same address, dtype and shape)
    is written as a relative symlink to the earlier file instead.
    """
    args = parse_arguments()
    binary_table, inputs, outputs = parse_dump_tensor_binary(args.dump_tensor_filename)
    raw_tensors = parse_tile_fwk_aicpu_ctrl(args.tile_fwk_aicpu_ctrl_filename, inputs, outputs)

    base_dir = os.path.dirname(args.dump_tensor_filename)
    dump_tensor_dir = os.path.join(base_dir, "dump_tensor")
    os.makedirs(dump_tensor_dir, exist_ok=True)

    # Very large limits so str(tensor) prints every element without
    # summarising or wrapping rows.
    torch.set_printoptions(
        threshold=1024**3,
        linewidth=1024**2,
    )
    print(f"In total {len(raw_tensors)} raw tensors to be processed")

    def seq_no_dir_str(seq_no: int):
        return f"seqNo-{seq_no}"

    # Create each per-sequence directory once, before the workers start.
    for seq_no in {rt.seq_no for rt in raw_tensors}:
        os.makedirs(os.path.join(dump_tensor_dir, seq_no_dir_str(seq_no)), exist_ok=True)

    def dump_raw_tensor(rt: RawTensorDesc):
        """Write one tensor (or a symlink to its duplicate) under `dump_tensor_dir`."""
        dst_file = os.path.join(dump_tensor_dir, seq_no_dir_str(rt.seq_no), f"{rt.name()}.txt")
        if rt.symlink_src is not None:
            assert rt.symlink_src != rt.name(), f"Invalid symlink to self: {rt.name()}"
            # Tensor names start with their seq_no, which selects the source directory.
            src_seq_no = int(rt.symlink_src.split('-')[0])
            # Relative target so the output tree stays valid if it is moved.
            src_file = f"../{seq_no_dir_str(src_seq_no)}/{rt.symlink_src}.txt"
            # Replace leftovers from a previous run; islink() also catches dangling links.
            if os.path.islink(dst_file) or os.path.exists(dst_file):
                os.remove(dst_file)
            os.symlink(src=src_file, dst=dst_file)
            return

        # Coerce to a plain int: a NumPy scalar (or the float that np.prod
        # returns for an empty shape) must not leak into address arithmetic
        # and slicing.
        mem_req = int(rt.numel()) * rt.bytes_of_dtype
        binary_data = binary_table.query(rt.address, rt.address + mem_req)
        torch_dtype = pypto_dtype_to_torch_dtype(rt.dtype)
        if torch_dtype is None:
            # Unsupported dtype: already reported by the converter, nothing to write.
            return
        tensor = torch.frombuffer(binary_data, dtype=torch_dtype).reshape(rt.shape)
        with open(dst_file, 'w') as f:
            f.write(f"address=0x{rt.address:X}\nshape={tensor.shape}\ndtype={tensor.dtype}\n\n")
            f.write(str(tensor))

    def is_huge_tensor(rt: RawTensorDesc):
        return rt.numel() >= HUGE_TENSOR_THRESHOLD

    def task_info(rt: RawTensorDesc):
        return f"{rt.name()}: shape={list(rt.shape)}, dtype={rt.dtype}"

    # NOTE(review): ThreadTaskRunner is project-local; presumably it runs
    # dump_raw_tensor over raw_tensors on a thread pool and uses the two
    # callbacks for scheduling and reporting. Confirm against its definition.
    runner = ThreadTaskRunner(title="Dump Tensor", max_workers=args.max_workers)
    runner.run_batch(raw_tensors, dump_raw_tensor, is_huge_tensor, task_info)
    print(f"Output files location: `{dump_tensor_dir}`")
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()