import subprocess
import os
import re
import time
import pandas as pd
import pytest
import logging
import sqlite3
from typing import List, Optional, Tuple
# Process exit status that execute_cmd treats as success; any other return
# code causes the command's captured stderr to be logged as an error.
COMMAND_SUCCESS = 0
def detect_free_npu_card(max_devices: int = 4) -> Optional[List[int]]:
    """
    Find the NPU card(s) with the most free device memory.

    Every device id in ``range(max_devices)`` is queried with
    ``npu-smi info -i {device_id} -t usages``. From the command output the
    capacity (MB) and usage rate (%) are parsed (when several lines match, the
    last match wins) and the free memory is estimated as
    ``capacity * (1 - usage / 100)``. Devices whose query fails, times out
    (5 s) or yields incomplete data are skipped.

    Cards whose free memory is at least 90% of the best card are treated as
    equivalent. When more than one such card exists, the list is rotated by
    ``(pid + current time in ms) % count`` so that concurrent callers tend to
    start from different cards instead of all picking the same one.

    Args:
        max_devices: Number of device ids to probe, starting from 0.

    Returns:
        Card ids in priority order (preferred card first), intended to be
        tried one after another on retry; ``None`` when no card could be
        measured.
    """
    free_by_device: List[Tuple[int, float]] = []
    for device_id in range(max_devices):
        try:
            proc = subprocess.run(
                ["npu-smi", "info", "-i", str(device_id), "-t", "usages"],
                capture_output=True,
                text=True,
                timeout=5,
            )
            if proc.returncode != 0:
                continue

            capacity_mb = None
            usage_pct = None
            for raw_line in (proc.stdout or "").splitlines():
                text = raw_line.strip()
                found = re.search(r"Capacity\(MB\)\s*:\s*(\d+)", text, re.I)
                if found:
                    capacity_mb = int(found.group(1))
                else:
                    found = re.search(
                        r"(?:Usage\s+Rate|Usage\s+Rate\(%\))\s*:\s*(\d+)", text, re.I
                    )
                    if found:
                        usage_pct = int(found.group(1))

            # Both figures are required to estimate the free memory.
            if capacity_mb is None or usage_pct is None:
                continue
            free_by_device.append(
                (device_id, capacity_mb * (1 - usage_pct / 100.0))
            )
        except (subprocess.TimeoutExpired, ValueError, OSError):
            # Best effort: an unreachable or unparsable device is just skipped.
            continue

    if not free_by_device:
        return None

    # Most free memory first.
    free_by_device.sort(key=lambda item: item[1], reverse=True)
    threshold = free_by_device[0][1] * 0.9
    near_best = [device for device, free_mb in free_by_device if free_mb >= threshold]
    if len(near_best) <= 1:
        return near_best

    # Rotate the near-equivalent cards to spread concurrent users.
    offset = (os.getpid() + int(time.time() * 1000)) % len(near_best)
    return near_best[offset:] + near_best[:offset]
def execute_cmd(cmd):
    """
    Run a command (argument list, no shell) and return its exit status.

    The command line is logged at INFO level before it runs. The child's
    stderr is captured; when the exit status differs from COMMAND_SUCCESS the
    captured stderr is decoded and logged at ERROR level.

    Args:
        cmd: Sequence of strings forming the command and its arguments.

    Returns:
        The return code of the finished process.
    """
    logging.info('Execute command:%s', " ".join(cmd))
    result = subprocess.run(cmd, shell=False, stderr=subprocess.PIPE)
    exit_code = result.returncode
    if exit_code == COMMAND_SUCCESS:
        return exit_code
    logging.error(result.stderr.decode())
    return exit_code
def execute_script(cmd):
    """
    Run a command (argument list, no shell), streaming its output to the log.

    stdout and stderr of the child are merged and every non-empty line is
    logged at DEBUG level as text. The output is consumed until end-of-file,
    so lines still buffered in the pipe when the process exits are not lost,
    and the pipe is closed when the function returns.

    Args:
        cmd: Sequence of strings forming the command and its arguments.

    Returns:
        The return code of the finished process.
    """
    logging.info('Execute command:%s', " ".join(cmd))
    # The context manager closes the stdout pipe and waits for the child on
    # exit, which guarantees returncode is populated afterwards.
    with subprocess.Popen(
        cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
    ) as process:
        # Iterating the pipe reads until EOF instead of stopping as soon as
        # poll() reports the process has exited.
        for raw_line in process.stdout:
            # Decode leniently so undecodable bytes cannot abort the run.
            line = raw_line.decode(errors="replace").strip()
            if line:
                logging.debug(line)
    return process.returncode
def check_result_file(out_path):
    """
    Placeholder for result-file validation.

    No check is implemented: the argument is ignored and ``None`` is returned.

    Args:
        out_path: Currently unused.
    """
    return None
def select_count(db_path: str, query: str):
    """
    Execute a SQL query and return the first column of its first result row.

    Intended for ``SELECT COUNT(...)`` style queries. The connection and
    cursor are always released, even when executing the query fails.

    Args:
        db_path: Path of the SQLite database file.
        query: SQL statement to execute.

    Returns:
        The first column of the first row produced by the query.

    Raises:
        sqlite3.Error: If the query cannot be executed.
        AttributeError: If the database could not be opened
            (create_connect_db returned ``None`` for the cursor).
        TypeError: If the query produced no rows.
    """
    conn, cursor = create_connect_db(db_path)
    try:
        cursor.execute(query)
        count = cursor.fetchone()
    finally:
        # Release the handles on both the success and the error path.
        destroy_db_connect(conn, cursor)
    return count[0]
def select_by_query(db_path: str, query: str, db_class):
    """
    Execute a SQL query and return the first record as an instance of db_class.

    All rows are fetched and each one is turned into ``db_class(*row)``; only
    the first instance is returned. The connection and cursor are always
    released, even when executing the query fails.

    Args:
        db_path: Path of the SQLite database file.
        query: SQL statement to execute.
        db_class: Callable invoked with the columns of a row as positional
            arguments.

    Returns:
        The ``db_class`` instance built from the first row.

    Raises:
        sqlite3.Error: If the query cannot be executed.
        AttributeError: If the database could not be opened
            (create_connect_db returned ``None`` for the cursor).
        IndexError: If the query produced no rows.
    """
    conn, cursor = create_connect_db(db_path)
    try:
        cursor.execute(query)
        rows = cursor.fetchall()
    finally:
        # Release the handles on both the success and the error path.
        destroy_db_connect(conn, cursor)
    dbs = [db_class(*row) for row in rows]
    return dbs[0]
def create_connect_db(db_file: str) -> tuple:
    """
    Open a SQLite database and create a cursor on it.

    Args:
        db_file: Path of the database file (anything ``sqlite3.connect``
            accepts).

    Returns:
        ``(connection, cursor)`` on success. If opening the database or
        creating the cursor raises ``sqlite3.Error``, the error is logged and
        ``(None, None)`` is returned instead.
    """
    try:
        connection = sqlite3.connect(db_file)
        cursor = connection.cursor()
    except sqlite3.Error as e:
        logging.error("Unable to connect to database: %s", e)
        return None, None
    return connection, cursor
def destroy_db_connect(conn: any, curs: any) -> None:
    """
    Close a SQLite cursor and connection, tolerating invalid handles.

    The cursor is closed first, then the connection. An argument that is not
    a ``sqlite3.Cursor`` / ``sqlite3.Connection`` (for example ``None``) is
    ignored. A ``sqlite3.Error`` raised while closing is logged and does not
    prevent the other handle from being closed.

    Args:
        conn: Connection to close.
        curs: Cursor to close.
    """
    handles = ((curs, sqlite3.Cursor), (conn, sqlite3.Connection))
    for handle, expected_type in handles:
        if not isinstance(handle, expected_type):
            continue
        try:
            handle.close()
        except sqlite3.Error as err:
            logging.error("%s", err)
def _assign_path(node, keys, value):
    """Set ``value`` at ``keys`` below ``node``; return ``node`` or its list copy."""
    key = keys[0]
    is_last = len(keys) == 1

    if isinstance(node, dict):
        if is_last:
            node[key] = value
            return node
        if key not in node:
            # Missing intermediate dict levels are created on the fly.
            node[key] = {}
        child = node[key]
        replaced = _assign_path(child, keys[1:], value)
        if replaced is not child:
            node[key] = replaced
        return node

    if isinstance(node, (list, tuple)):
        # Only the final key may equal len(node), which means "append".
        upper = len(node) if is_last else len(node) - 1
        if not isinstance(key, int) or key < 0 or key > upper:
            raise IndexError(f"Invalid list index: {key}")
        if is_last:
            if isinstance(node, tuple):
                # Tuples are immutable: write into a list copy instead.
                node = list(node)
            if key == len(node):
                node.append(value)
            else:
                node[key] = value
            return node
        child = node[key]
        replaced = _assign_path(child, keys[1:], value)
        if replaced is not child:
            # A nested tuple was replaced by a list; store the replacement.
            if isinstance(node, tuple):
                node = list(node)
            node[key] = replaced
        return node

    if is_last:
        raise TypeError(f"Cannot set key '{key}' in object of type {type(node)}")
    raise TypeError(f"Cannot access key '{key}' in object of type {type(node)}")


def change_dict(obj, *keys, value=None):
    """
    Set the value at a given path inside a nested data structure.

    Dicts and lists are modified in place. A tuple that has to be written to
    (because it holds the final key, or because one of its elements had to be
    replaced) is replaced by a list copy carrying the change; tuples that are
    only traversed are left untouched. Missing intermediate dict keys are
    created as empty dicts. For a list or tuple, a final index equal to its
    length appends the value.

    Args:
        obj: Structure to modify (dict, list or tuple).
        *keys: Sequence of keys / indexes forming the path.
        value: Value to store at the path.

    Returns:
        The modified structure; this is ``obj`` itself unless ``obj`` is a
        tuple that had to be converted to a list. When no keys are given,
        ``value`` is returned.

    Raises:
        IndexError: If a list/tuple index is not an int or is out of range.
        TypeError: If the path runs through an object that is not a dict,
            list or tuple.
    """
    if not keys:
        return value
    return _assign_path(obj, keys, value)
def check_column_actual(actual_columns, expected_columns, context):
    """
    Check that every expected column name is present in the actual columns.

    Args:
        actual_columns: Column names that were found.
        expected_columns: Column names that must be present.
        context: Label used in the error log message.

    Returns:
        True when all expected names are present; otherwise the first missing
        name is logged as an error and False is returned.
    """
    not_found = object()
    col = next(
        (name for name in expected_columns if name not in actual_columns),
        not_found,
    )
    if col is not_found:
        return True
    logging.error(f"在 {context} 中未找到预期列名: {col}")
    return False
def check_no_empty_lines_before_first_line(dataframe, context=""):
    """
    Soft-assert that the table does not start with empty rows.

    Counts the consecutive rows at the top of the frame whose cells are all
    null and reports, via ``pytest.assume``, a failure when that count is not
    zero.

    Args:
        dataframe: Table to inspect.
        context: Label used in the failure message.
    """
    all_null_rows = dataframe.isnull().all(axis=1)
    empty_line = 0
    for is_empty in all_null_rows:
        if not is_empty:
            break
        empty_line += 1
    pytest.assume(empty_line == 0, f"{context} table has {empty_line} empty lines before first line.")
def check_no_empty_lines_between_first_last_line(dataframe, context=""):
    """
    Soft-assert that the table contains no blank rows.

    A row counts as blank when every one of its cells equals the empty string
    ``''``. A failure is reported via ``pytest.assume`` when at least one such
    row exists.

    Args:
        dataframe: Table to inspect.
        context: Label used in the failure message.
    """
    blank_cells = dataframe.eq('')
    blank_rows = blank_cells.all(axis=1)
    has_blank_row = blank_rows.any()
    pytest.assume(not has_blank_row, f"{context} table has empty lines.")
def check_during_time(dataframe, context=""):
    """
    Verify that ``during_time(ms)`` matches ``end_time(ms) - start_time(ms)``.

    Every row except the last one is checked; a row fails when the absolute
    difference between ``end - start`` and the recorded duration exceeds 1.

    Args:
        dataframe: Table with the columns ``end_time(ms)``,
            ``start_time(ms)`` and ``during_time(ms)``.
        context: Label used in the error log messages.

    Returns:
        False (after logging an error) when a required column is missing or a
        checked row is inconsistent; True otherwise.
    """
    required_columns = ('end_time(ms)', 'start_time(ms)', 'during_time(ms)')
    for col in required_columns:
        if col in dataframe.columns:
            continue
        logging.error(f"The column {col} not found in {context}.")
        return False

    # The last row is deliberately left out of the comparison.
    body = dataframe.iloc[:-1]
    rows = zip(
        body.index,
        body['end_time(ms)'],
        body['start_time(ms)'],
        body['during_time(ms)'],
    )
    for index, end_time, start_time, during_time in rows:
        if abs((end_time - start_time) - during_time) > 1:
            logging.error(f"In row {index} of {context}, the during_time is not correct.")
            return False
    return True
def check_split_csv_content(output_path, csv_file_name):
    """
    Validate the header and content of one split CSV file.

    The checks are soft assertions (``pytest.assume``): the file must exist,
    contain all expected columns (plus ``rid`` when the file's base name is
    ``prefill``), have no empty rows at the top, have no blank rows, and have
    consistent ``during_time(ms)`` values. When the file does not exist the
    failure is recorded and the remaining checks are skipped, because there
    is nothing to read.

    Args:
        output_path: Directory containing the CSV file.
        csv_file_name: File name of the CSV inside ``output_path``.
    """
    csv_file = os.path.join(output_path, csv_file_name)
    file_exists = os.path.exists(csv_file)
    pytest.assume(file_exists, f"CSV file not found: {csv_file}")
    if not file_exists:
        # The missing file is already recorded as a failed assumption;
        # reading it would only raise FileNotFoundError on top of that.
        return

    task_name = os.path.splitext(csv_file_name)[0]
    expected_header = [
        'name', 'during_time(ms)', 'max', 'min', 'mean', 'std',
        'pid', 'tid', 'start_time(ms)', 'end_time(ms)',
    ]
    if task_name == 'prefill':
        expected_header.append('rid')

    df = pd.read_csv(csv_file)
    result = check_column_actual(df.columns.tolist(), expected_header, context=csv_file_name)
    pytest.assume(result, f"{csv_file_name} column check failed")
    check_no_empty_lines_before_first_line(df, context=csv_file_name)
    check_no_empty_lines_between_first_last_line(df, context=csv_file_name)
    result = check_during_time(df, context=csv_file_name)
    pytest.assume(result, f"{csv_file_name} execution time validation failed")
def check_row(df, expected_columns, numeric_columns):
    """
    Check that the 'Metric' column holds strings and given columns hold numbers.

    Every value of the ``Metric`` column must be a ``str``. For each column
    in ``numeric_columns`` that exists in the frame, every value must be
    convertible with ``float()``; values such as ``None`` or arbitrary objects
    count as invalid. A column listed in ``numeric_columns`` that is missing
    from the frame is logged as an error and skipped.

    Args:
        df: Data frame to validate.
        expected_columns: Unused; kept for interface compatibility.
        numeric_columns: Names of the columns that must contain numbers.

    Returns:
        True when all checks pass; False (after logging an error) on the
        first invalid value or when the ``Metric`` column is missing.
    """
    for row_index in df.index:
        try:
            value = df.at[row_index, 'Metric']
        except KeyError:
            logging.error(f"数据框中不存在 'Metric' 列")
            return False
        if not isinstance(value, str):
            logging.error(f"在Metric列的第{row_index}行,值 '{value}' 不是字符串类型")
            return False

    for column in numeric_columns:
        if column not in df.columns:
            logging.error(f"数据框中不存在 {column} 列")
            continue
        for row_index in df.index:
            # Pre-bind so the error message below never hits an unbound name.
            cell_value = None
            try:
                cell_value = df.at[row_index, column]
                float(cell_value)
            except (ValueError, TypeError, KeyError):
                # TypeError covers None and other objects float() rejects.
                logging.error(
                    f"在 {column} 列的第 {row_index} 行,值 {cell_value} 不是有效的数字")
                return False
    return True