import os
import threading
import time
import logging
from multiprocessing import Queue
import torch
import torch_npu
from mspti import (
KernelData,
KernelMonitor,
CommunicationData,
CommunicationMonitor
)
# Root logger configuration: timestamp, source path and line, level, message.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',
)

# Shared queue: the monitor callbacks put records here and the consumer
# thread reads them back out.
data_queue = Queue()
def kernel_parser(data: KernelData) -> None:
    """Callback handed to ``KernelMonitor.start``: enqueue one kernel record.

    The record is pushed onto the module-level ``data_queue`` unchanged.
    """
    data_queue.put(data)
def communication_parser(data: CommunicationData) -> None:
    """Callback handed to ``CommunicationMonitor.start``: enqueue one record.

    The record is pushed onto the module-level ``data_queue`` unchanged.
    """
    data_queue.put(data)
def consumer_func(consume_queue):
    """Log every record taken from ``consume_queue`` until a ``None`` sentinel.

    ``KernelData`` and ``CommunicationData`` records are written to the root
    logger as comma-separated fields; records of any other type are dropped
    silently. The function returns as soon as ``None`` is dequeued.

    Args:
        consume_queue: Queue-like object exposing a blocking ``get()``.
    """
    while True:
        # Block until something arrives instead of polling ``empty()`` and
        # sleeping: ``empty()`` is only a snapshot, and polling adds latency
        # and needless wake-ups.
        data = consume_queue.get()
        if data is None:
            # Sentinel from the producer: no more records will follow.
            break
        if isinstance(data, KernelData):
            logging.info(f'{data.kind}, {data.start}, {data.end}, {data.device_id}, {data.stream_id}, '
                         f'{data.correlation_id}, {data.type}, {data.name}')
        elif isinstance(data, CommunicationData):
            logging.info(f'{data.kind}, {data.start}, {data.end}, {data.device_id}, {data.stream_id}, '
                         f'{data.data_type}, {data.count}, {data.name}, {data.comm_name}, '
                         f'{data.alg_type}, {data.correlation_id}')
def init_process(backend="hccl"):
    """Initialise the default torch.distributed process group.

    Rendezvous uses the ``env://`` init method.

    Args:
        backend: Name of the distributed backend; defaults to ``"hccl"``.
    """
    dist = torch.distributed
    dist.init_process_group(backend, init_method='env://')
def test_monitor():
    """Run a small NPU workload under the kernel and communication monitors.

    A consumer thread logs every record the monitor callbacks enqueue. The
    workload is an elementwise add, a matmul and an ``all_reduce`` on
    256x256 float16 tensors. The ``LOCAL_RANK`` environment variable must be
    set, because it selects the NPU device.

    The consumer thread is always sent its ``None`` sentinel and joined, even
    when setup or the workload raises. Without that the non-daemon thread
    would wait forever and keep the process alive after a failure. The
    monitors are stopped on the success path only.
    """
    consumer = threading.Thread(target=consumer_func, args=(data_queue, ))
    consumer.start()
    try:
        kernel_monitor = KernelMonitor()
        kernel_monitor.start(kernel_parser)
        communication_monitor = CommunicationMonitor()
        communication_monitor.start(communication_parser)

        init_process()
        device = int(os.getenv('LOCAL_RANK'))
        torch.npu.set_device(device)

        width = 256
        x = torch.randn(width, width, dtype=torch.float16).npu()
        y = torch.randn(width, width, dtype=torch.float16).npu()
        # The sum is discarded; the add is issued only so that an elementwise
        # kernel runs alongside the matmul.
        result = x + y
        result = torch.matmul(x, y)
        torch.distributed.all_reduce(result)
        # Synchronize before stopping the monitors (presumably so that
        # outstanding device work has finished -- confirm against torch_npu).
        torch.npu.synchronize()

        kernel_monitor.stop()
        communication_monitor.stop()
    finally:
        # Always release the consumer thread, including on failure.
        data_queue.put(None)
        consumer.join()
# Script entry point: run the monitor demo only when executed directly.
if __name__ == '__main__':
    test_monitor()