#!/bin/bash
#
# Run the FlagGems operator tests in parallel on Ascend devices, append a
# summary to SUMMARY_FILE and archive the per-device logs into PR_LOG_DIR.
#
# Usage: <script> [device_count] [threads_per_device]
#   device_count        devices to use, ids 0..device_count-1 (default: 1)
#   threads_per_device  worker loops started per device (default: 64)
#
# Environment:
#   WORKSPACE  prefix of the summary path
#              ${WORKSPACE}/triton-ascend/ascend/examples/summary.txt
device_count="${1:-1}"
threads_per_device="${2:-64}"
# Both values drive arithmetic loops, where a non-numeric word is evaluated
# as a variable name (an unset name is 0) and the run would silently do
# nothing, so reject anything that is not a positive integer up front.
if [[ ! "$device_count" =~ ^[1-9][0-9]*$ ]]; then
  echo "设备数量必须为正整数: $device_count"
  exit 1
fi
if [[ ! "$threads_per_device" =~ ^[1-9][0-9]*$ ]]; then
  echo "每设备线程数必须为正整数: $threads_per_device"
  exit 1
fi
DIR_TESTS="tests"
# NOTE(review): DIR_BENCHMARK is not referenced elsewhere in this script.
DIR_BENCHMARK="benchmark"
PR_LOG_DIR="/home/pr_test_log"
TIMESTAMP=$(date +"%Y%m%d")
LOG_ARCHIVE="test_flaggems_logs_${TIMESTAMP}.tar.gz"
SUMMARY_FILE="${WORKSPACE}/triton-ascend/ascend/examples/summary.txt"
mkdir -p "$PR_LOG_DIR" || { echo "无法创建日志目录 $PR_LOG_DIR"; exit 1; }
# Per-run temp files. The progress lock was previously the fixed path
# /tmp/op_test_run.lock, which concurrent runs would share and delete out
# from under each other; mktemp gives every run its own lock file.
COUNTER_FILE=$(mktemp) || { echo "无法创建临时文件"; exit 1; }
LOCK_FILE=$(mktemp "${TMPDIR:-/tmp}/op_test_run.lock.XXXXXX") \
  || { echo "无法创建锁文件"; exit 1; }
STATS_DIR=$(mktemp -d) || { echo "无法创建统计目录"; exit 1; }
# One counter file per device with all four outcomes starting at 0.
for ((device_id = 0; device_id < device_count; device_id++)); do
  printf '%s=0\n' success failure skipped error \
    > "${STATS_DIR}/device_${device_id}.stats"
done
#######################################
# Increment one outcome counter in a device's stats file.
# The read-modify-write is serialized with an exclusive flock on
# "<stats file>.lock" so concurrent worker threads do not lose updates.
# Globals:   STATS_DIR (read)
# Arguments: $1 - device id
#            $2 - counter key (success | failure | skipped | error)
#######################################
record_stats() {
  local dev=$1 key=$2
  local file="${STATS_DIR}/device_${dev}.stats"
  (
    flock -x 20
    local count
    count=$(grep "^${key}=" "$file" | cut -d= -f2)
    sed -i "s/^${key}=.*/${key}=$(( count + 1 ))/" "$file"
  ) 20>"${file}.lock"
}
#######################################
# Load the shared task queue from an array and reset the progress counters.
# Globals:   TASK_FILE (written: new temp file, one task name per line)
#            COUNTER_FILE (read: prefix for the .total / .completed files)
# Arguments: $1 - NAME of an array variable holding the task names
# Returns:   1 if the queue file cannot be created
#######################################
init_task_queue() {
  local -n arr_ref=$1
  TASK_FILE=$(mktemp) || return 1
  # printf with no arguments still prints one "\n", which would leave a
  # phantom empty task line for an empty array; write nothing instead.
  if (( ${#arr_ref[@]} > 0 )); then
    printf "%s\n" "${arr_ref[@]}" > "$TASK_FILE"
  else
    : > "$TASK_FILE"
  fi
  echo 0 > "$TASK_FILE.counter"
  echo "${#arr_ref[@]}" > "$COUNTER_FILE.total"
  echo 0 > "$COUNTER_FILE.completed"
}
#######################################
# Atomically pop the next task name from the shared queue.
# An exclusive flock on "$TASK_FILE.lock" serializes the counter update
# across all worker threads.
# Globals:   TASK_FILE (read; "$TASK_FILE.counter" is read and rewritten)
# Outputs:   the next task name, or an empty line once the queue is drained
#######################################
get_next_task() {
  (
    flock -x 9
    local counter total_tasks next task_name
    counter=$(< "$TASK_FILE.counter")
    total_tasks=$(wc -l < "$TASK_FILE")
    if (( counter >= total_tasks )); then
      echo ""
      return
    fi
    # The counter is the number of tasks already handed out, so the next
    # task sits on 1-based line counter+1.
    next=$(( counter + 1 ))
    task_name=$(sed -n "${next}p" "$TASK_FILE")
    echo "$next" > "$TASK_FILE.counter"
    echo "$task_name"
  ) 9> "$TASK_FILE.lock"
}
#######################################
# Atomically increment the completed-task counter.
# Globals:   COUNTER_FILE (read; "$COUNTER_FILE.completed" is rewritten)
#            LOCK_FILE (exclusive flock target)
# Outputs:   the new completed count
#######################################
update_progress() {
  (
    flock -x 11
    local current
    current=$(< "$COUNTER_FILE.completed")
    echo $(( current + 1 )) > "$COUNTER_FILE.completed"
    echo $(( current + 1 ))
  ) 11> "$LOCK_FILE"
}
#######################################
# Read the progress counters under a shared flock.
# Globals:   COUNTER_FILE (read: .completed and .total)
#            LOCK_FILE (shared flock target)
# Outputs:   "<completed> <total>" on one line
#######################################
get_progress() {
  (
    flock -s 11
    local completed total
    completed=$(< "$COUNTER_FILE.completed")
    total=$(< "$COUNTER_FILE.total")
    echo "$completed $total"
  ) 11> "$LOCK_FILE"
}
#######################################
# Remove the task-queue and progress temp files.
# Each path is only added when its variable is non-empty. Previously an
# empty COUNTER_FILE turned "$COUNTER_FILE*" into a bare "*" and deleted
# every file in the current directory.
# Globals:   TASK_FILE, LOCK_FILE, COUNTER_FILE (read)
#######################################
cleanup_tasks() {
  local -a victims=()
  if [[ -n "${TASK_FILE:-}" ]]; then
    victims+=("$TASK_FILE" "$TASK_FILE.counter" "$TASK_FILE.lock")
  fi
  if [[ -n "${LOCK_FILE:-}" ]]; then
    victims+=("$LOCK_FILE")
  fi
  if [[ -n "${COUNTER_FILE:-}" ]]; then
    victims+=("$COUNTER_FILE" "$COUNTER_FILE.total" "$COUNTER_FILE.completed")
  fi
  if (( ${#victims[@]} > 0 )); then
    rm -f -- "${victims[@]}"
  fi
}
# Operator markers to run; each name is handed to `pytest -m`.
# NOTE(review): "CrossEntryLoss" looks like a misspelling of
# CrossEntropyLoss; confirm it matches a marker registered in the tests.
OPS=(
  abs add addmm all amax argmax
  bitwise_and bitwise_not bitwise_or bmm
  cos CrossEntryLoss div dropout eq exp fill ge gelu group_norm gt isinf
  isnan rsub le linear log_softmax lt max mean min mm mul mv
  native_dropout ne neg pow prod reciprocal relu rsqrt sigmoid silu
  sin softmax sub sum tanh triu
)
total_ops=${#OPS[@]}
# Run banner; ${OPS[*]} joins the names with single spaces.
printf '%s\n' \
  "======================================" \
  "测试算子列表: ${OPS[*]}" \
  "算子总数: $total_ops" \
  "使用设备数量: $device_count" \
  "每设备线程数: $threads_per_device" \
  "======================================"
start_time=$(date +%s)
#######################################
# Worker loop: pull operator markers from the shared queue and run the
# matching pytest selection until the queue is empty.
# Globals:   none directly; uses get_next_task, record_stats,
#            update_progress and get_progress.
# Arguments: $1 - device id (used for stats and log labels)
#            $2 - thread id within the device
#            $3 - per-device log directory
# Outputs:   one progress line per task on stdout; the pytest output for each
#            task goes to <log dir>/thread_<id>/result_<marker>.log
#######################################
run_tests_thread() {
  local device_id=$1
  local thread_id=$2
  local device_log_dir=$3
  local thread_log_dir="$device_log_dir/thread_${thread_id}"
  local task_name log_file start_op exit_code duration status completed total
  mkdir -p "$thread_log_dir"
  while true; do
    task_name=$(get_next_task)
    [[ -z "$task_name" ]] && break
    echo "[设备 $device_id-线程 $thread_id] 正在执行: pytest -m $task_name --ref cpu -sv"
    log_file="${thread_log_dir}/result_${task_name}.log"
    start_op=$(date +%s)
    python -m pytest -m "$task_name" --dist=loadfile --ref cpu -sv &> "$log_file"
    exit_code=$?
    duration=$(( $(date +%s) - start_op ))
    # pytest exit codes: 0 all passed, 1 some tests failed, 5 no tests
    # collected (the marker selected nothing). 2 (interrupted, including
    # collection errors), 3 (internal error) and 4 (usage error) are errors.
    case $exit_code in
      0) status="success" ;;
      1) status="failure" ;;
      5) status="skipped" ;;
      *) status="error" ;;
    esac
    record_stats "$device_id" "$status"
    # Report the count returned by our own increment; re-reading the shared
    # counter could pick up other threads' updates made in between.
    completed=$(update_progress)
    read -r _ total < <(get_progress)
    case $status in
      success)
        echo "[成功] [$device_id-$thread_id] $task_name 完成! (用时 ${duration}s, 进度: $completed/$total)"
        ;;
      skipped)
        echo "[跳过] [$device_id-$thread_id] $task_name 未收集到用例 (用时 ${duration}s, 进度: $completed/$total)"
        ;;
      *)
        echo "[错误] [$device_id-$thread_id] $task_name 失败! (用时 ${duration}s, 进度: $completed/$total)"
        ;;
    esac
  done
}
#######################################
# Start threads_per_device background worker loops for one device and
# block until all of them have returned.
# Globals:   threads_per_device (read)
# Arguments: $1 - device id
# Outputs:   creates device_<id>_logs/ in the current directory; prints a
#            completion banner once every worker has finished
#######################################
run_device() {
  local device_id=$1
  local device_log_dir="device_${device_id}_logs"
  local worker=0
  mkdir -p "$device_log_dir"
  while (( worker < threads_per_device )); do
    run_tests_thread "$device_id" "$worker" "$device_log_dir" &
    worker=$(( worker + 1 ))
  done
  wait
  echo "======== 设备 $device_id 上所有任务完成 ========"
}
# Work from inside the tests directory; the per-device log directories are
# created relative to it.
if ! cd "$DIR_TESTS"; then
  echo "无法进入目录 $DIR_TESTS"
  exit 1
fi
init_task_queue OPS
# One background subshell per device. Each exports ASCEND_RT_VISIBLE_DEVICES
# set to its own device id, then runs that device's workers against the
# shared queue.
device_id=0
while (( device_id < device_count )); do
  (
    export ASCEND_RT_VISIBLE_DEVICES=$device_id
    run_device "$device_id"
  ) &
  device_id=$(( device_id + 1 ))
done
wait
cleanup_tasks
# Sum the per-device counters written by record_stats and report each device,
# then build the wall-clock and per-operator timing strings.
total_success=0
total_failure=0
total_skipped=0
total_error=0
device_id=0
while (( device_id < device_count )); do
  stats_file="${STATS_DIR}/device_${device_id}.stats"
  if [ ! -f "$stats_file" ]; then
    echo "警告: 设备 $device_id 的统计文件未找到"
  else
    d_success=$(grep '^success=' "$stats_file" | cut -d= -f2)
    d_failure=$(grep '^failure=' "$stats_file" | cut -d= -f2)
    d_skipped=$(grep '^skipped=' "$stats_file" | cut -d= -f2)
    d_error=$(grep '^error=' "$stats_file" | cut -d= -f2)
    total_success=$(( total_success + d_success ))
    total_failure=$(( total_failure + d_failure ))
    total_skipped=$(( total_skipped + d_skipped ))
    total_error=$(( total_error + d_error ))
    echo "设备 $device_id 完成情况: $d_success 成功, $d_failure 失败, $d_skipped 跳过, $d_error 错误"
  fi
  device_id=$(( device_id + 1 ))
done
rm -rf "$STATS_DIR"
total_time=$(( $(date +%s) - start_time ))
hours=$(( total_time / 3600 ))
minutes=$(( total_time % 3600 / 60 ))
seconds=$(( total_time % 60 ))
printf -v time_str "%02dh %02dm %02ds" "$hours" "$minutes" "$seconds"
# Average = total wall-clock time / finished (non-skipped) operators.
avg_str="N/A"
if (( total_ops > 0 )); then
  completed_ops=$(( total_success + total_failure + total_error ))
  if (( completed_ops > 0 )); then
    avg_time=$(( total_time / completed_ops ))
    avg_min=$(( avg_time / 60 ))
    avg_sec=$(( avg_time % 60 ))
    printf -v avg_str "%02dm %02ds" "$avg_min" "$avg_sec"
  fi
fi
# Print the run summary and append it to SUMMARY_FILE. "Completed" counts
# success + failure + error; skipped runs are excluded from it.
finished_ops=$(( total_success + total_failure + total_error ))
{
  echo "===================== flaggems测试统计摘要 ====================="
  echo "开始时间: $(date -d "@$start_time" '+%Y-%m-%d %H:%M:%S')"
  echo "结束时间: $(date '+%Y-%m-%d %H:%M:%S')"
  echo "测试日期: $(date '+%Y-%m-%d')"
  echo "总耗时: $time_str"
  echo "--------------------------------------------------------"
  echo "总算子数: $total_ops"
  echo "成功用例数: $total_success"
  echo "失败用例数: $total_failure"
  echo "跳过用例数: $total_skipped"
  echo "错误用例数: $total_error"
  echo "完成用例数: $finished_ops"
  if (( total_ops > 0 )); then
    echo "完成率: $(( finished_ops * 100 / total_ops ))%"
  else
    echo "完成率: N/A"
  fi
  if (( finished_ops > 0 )); then
    echo "成功率: $(( total_success * 100 / finished_ops ))%"
  else
    echo "成功率: N/A"
  fi
  echo "平均耗时/算子: $avg_str"
  echo "--------------------------------------------------------"
  echo "设备数量: $device_count"
  echo "每设备线程数: $threads_per_device"
  echo "========================================================"
  echo ""
} | tee -a "$SUMMARY_FILE"
# Quoted above: an unquoted path would be split on whitespace in WORKSPACE
# and tee would append to the wrong files.
# Archive the per-device log directories and move the tarball to PR_LOG_DIR.
# The directories are only removed after tar has succeeded, so a failed
# archive never destroys the only copy of the logs. If the move fails, the
# tarball stays in the current directory.
mapfile -t log_dirs < <(find . -maxdepth 1 -type d -name "device_*_logs" 2>/dev/null)
if (( ${#log_dirs[@]} > 0 )); then
  echo "归档日志文件到 $LOG_ARCHIVE"
  if tar -czf "$LOG_ARCHIVE" "${log_dirs[@]}"; then
    if mv "$LOG_ARCHIVE" "$PR_LOG_DIR"; then
      echo "日志已保存到: $PR_LOG_DIR/$LOG_ARCHIVE"
    else
      echo "警告:日志移动到 $PR_LOG_DIR 失败"
    fi
    rm -rf "${log_dirs[@]}"
  else
    echo "警告:日志归档失败,保留日志目录: ${log_dirs[*]}"
  fi
else
  echo "警告:未找到任何日志目录,跳过归档"
fi
echo "所有算子测试执行完成!"
echo "详细统计信息已追加到: $SUMMARY_FILE"
# Always exit 0; per-operator failures are only reported in the summary.
exit 0