# -------------------------------------------------------------------------
# This file is part of the MindStudio project.
# Copyright (c) 2025 Huawei Technologies Co.,Ltd.
#
# MindStudio is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
#          http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
# -------------------------------------------------------------------------

import subprocess
import os
import re
import time
import pandas as pd
import pytest
import logging
import sqlite3
from typing import List, Optional, Tuple

COMMAND_SUCCESS = 0


# "<label> Capacity(MB) : N" / "<label> Usage Rate(%) : N" lines of `npu-smi info -t usages`;
# group 1 captures the label prefix (e.g. "HBM", "DDR") so capacity and usage can be paired.
_NPU_CAPACITY_RE = re.compile(r"^(.*?)Capacity\(MB\)\s*:\s*(\d+)", re.I)
_NPU_USAGE_RE = re.compile(r"^(.*?)Usage\s+Rate(?:\(%\))?\s*:\s*(\d+)", re.I)


def _parse_npu_available_mb(output: str) -> Optional[float]:
    """Return the free memory (MB) described by one card's usages output, or None if unparsable."""
    capacities = {}  # normalized label -> capacity in MB
    usages = {}  # normalized label -> usage percentage
    last_capacity = None
    last_usage = None
    for raw_line in output.splitlines():
        line = raw_line.strip()
        cap_match = _NPU_CAPACITY_RE.search(line)
        if cap_match:
            last_capacity = int(cap_match.group(2))
            capacities[cap_match.group(1).strip().lower()] = last_capacity
            continue
        usage_match = _NPU_USAGE_RE.search(line)
        if usage_match:
            last_usage = int(usage_match.group(2))
            usages[usage_match.group(1).strip().lower()] = last_usage
    # Only combine a capacity with the usage rate that carries the same label, so unrelated
    # rates (e.g. "Aicore Usage Rate", "HBM Bandwidth Usage Rate") are never mixed in.
    paired = [label for label in capacities if label in usages]
    if paired:
        label = "hbm" if "hbm" in paired else paired[0]
        capacity_mb, usage_pct = capacities[label], usages[label]
    else:
        # No shared label: fall back to the last capacity / usage values seen.
        capacity_mb, usage_pct = last_capacity, last_usage
    if capacity_mb is None or usage_pct is None:
        return None
    return capacity_mb * (1 - usage_pct / 100.0)


def detect_free_npu_card(max_devices: int = 4) -> Optional[List[int]]:
    """
    Detect which NPU card(s) have the most free device memory.

    Each card is queried with ``npu-smi info -i {device_id} -t usages``; free memory is
    estimated as ``capacity * (1 - usage%)`` from a capacity/usage pair sharing the same
    label (the HBM pair is preferred when present). When several cards are within 10% of
    the best one, their order is rotated by pid + millisecond timestamp so that concurrent
    runs started at the same time tend to pick different cards.

    Args:
        max_devices: Number of device ids to probe (``0 .. max_devices - 1``).

    Returns:
        Device ids ordered by preference (best first), meant to be tried one after another
        on retry: the near-best cards (rotated) followed by the remaining cards in
        descending order of free memory. ``None`` if no card could be queried and parsed.
    """
    candidates: List[Tuple[int, float]] = []  # (device_id, available_mb)
    for device_id in range(max_devices):
        cmd = ["npu-smi", "info", "-i", str(device_id), "-t", "usages"]
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=5)
        except (subprocess.TimeoutExpired, ValueError, OSError):
            # Missing tool, hung query, or undecodable output (UnicodeDecodeError is a
            # ValueError): skip this card.
            continue
        if result.returncode != 0:
            continue
        available_mb = _parse_npu_available_mb(result.stdout or "")
        if available_mb is not None:
            candidates.append((device_id, available_mb))
    if not candidates:
        return None
    # Descending by available memory.
    candidates.sort(key=lambda item: item[1], reverse=True)
    threshold = candidates[0][1] * 0.9
    similar = [dev for dev, avail in candidates if avail >= threshold]
    others = [dev for dev, avail in candidates if avail < threshold]
    if len(similar) > 1:
        # Concurrent runs: spread near-equal cards by pid + ms timestamp to reduce collisions.
        offset = (os.getpid() + int(time.time() * 1000)) % len(similar)
        similar = similar[offset:] + similar[:offset]
    return similar + others


def execute_cmd(cmd):
    """
    Run ``cmd`` synchronously (no shell) and return its exit code.

    stdout is not captured (it goes to this process's stdout); stderr is captured and
    logged at ERROR level when the return code differs from ``COMMAND_SUCCESS``.

    Args:
        cmd: Argument list, e.g. ``["ls", "-l"]``; path-like elements are allowed.

    Returns:
        The child process return code.
    """
    # str() each element so path-like arguments (accepted by subprocess) can be logged.
    logging.info('Execute command:%s', " ".join(str(arg) for arg in cmd))
    completed_process = subprocess.run(cmd, shell=False, stderr=subprocess.PIPE)
    if completed_process.returncode != COMMAND_SUCCESS:
        # Tool stderr is not guaranteed to be valid UTF-8; never let decoding mask the failure.
        logging.error(completed_process.stderr.decode(errors="replace"))
    return completed_process.returncode


def execute_script(cmd):
    """
    Run ``cmd`` (no shell), log its combined stdout/stderr at DEBUG level, return its exit code.

    Output is read until EOF, so lines written just before the child exits are still
    logged. Blank lines are skipped; each line is decoded (invalid bytes replaced) and stripped.

    Args:
        cmd: Argument list, e.g. ``["bash", "run.sh"]``; path-like elements are allowed.

    Returns:
        The child process return code.
    """
    logging.info('Execute command:%s', " ".join(str(arg) for arg in cmd))
    # The context manager closes the pipe and waits for the child on exit.
    with subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as process:
        for raw_line in process.stdout:
            line = raw_line.decode(errors="replace").strip()
            if line:
                logging.debug(line)
    return process.returncode


def check_result_file(out_path):
    """
    Placeholder for validating result files under ``out_path``.

    No checks are implemented yet; the function always returns ``None``.
    """
    return None



def select_count(db_path: str, query: str):
    """
    Run a COUNT-style SQL query and return the first column of its first row.

    Args:
        db_path: Path of the SQLite database file.
        query: SQL text to execute (executed verbatim; do not build it from untrusted input).

    Returns:
        The first column of the first result row.

    Raises:
        TypeError: if the query returns no rows.
    """
    conn, cursor = create_connect_db(db_path)
    try:
        cursor.execute(query)
        count = cursor.fetchone()
    finally:
        # Release the connection even when the query fails.
        destroy_db_connect(conn, cursor)
    return count[0]


def select_by_query(db_path: str, query: str, db_class):
    """
    Run a SQL query and return its first row wrapped as ``db_class(*row)``.

    Args:
        db_path: Path of the SQLite database file.
        query: SQL text to execute (executed verbatim; do not build it from untrusted input).
        db_class: Callable accepting the row's columns positionally.

    Returns:
        ``db_class(*first_row)``.

    Raises:
        IndexError: if the query returns no rows.
    """
    conn, cursor = create_connect_db(db_path)
    try:
        cursor.execute(query)
        # Only the first row is needed; avoid fetching/wrapping the whole result set.
        first_row = cursor.fetchone()
    finally:
        # Release the connection even when the query fails.
        destroy_db_connect(conn, cursor)
    if first_row is None:
        raise IndexError(f"Query returned no rows: {query}")
    return db_class(*first_row)


def create_connect_db(db_file: str) -> tuple:
    """
    Open a SQLite connection and a cursor on it.

    Args:
        db_file: Database path (``":memory:"`` is accepted by sqlite3).

    Returns:
        ``(connection, cursor)`` on success, or ``(None, None)`` after logging the error
        if the database cannot be opened.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_file)
        curs = conn.cursor()
        return conn, curs
    except sqlite3.Error as e:
        logging.error("Unable to connect to database: %s", e)
        # Don't leak a connection that opened but whose cursor could not be created.
        if conn is not None:
            conn.close()
        return None, None


def destroy_db_connect(conn: object, curs: object) -> None:
    """
    Close a SQLite cursor and connection, logging (not raising) any sqlite3 error.

    Arguments that are not ``sqlite3.Cursor`` / ``sqlite3.Connection`` instances
    (e.g. the ``None`` placeholders from a failed connect) are ignored.
    """
    try:
        if isinstance(curs, sqlite3.Cursor):
            curs.close()
    except sqlite3.Error as err:
        logging.error("%s", err)
    # Attempt the connection close even if closing the cursor failed.
    try:
        if isinstance(conn, sqlite3.Connection):
            conn.close()
    except sqlite3.Error as err:
        logging.error("%s", err)


def change_dict(obj, *keys, value=None):
    """
    Set the value at a nested path inside dicts / lists / tuples.

    Args:
        obj: Root container (dict, list or tuple).
        *keys: Path of dict keys / sequence indices leading to the value to set.
        value: Value to store at the end of the path.

    Returns:
        The modified structure. Dicts and lists are modified in place and returned as
        the same objects; any tuple on the path (including ``obj`` itself) is replaced
        by a list copy, because tuples are immutable. With no ``keys``, ``value`` is
        returned as-is.

    Behavior:
        - Missing intermediate dict keys are created as empty dicts.
        - Sequence indices must be non-negative ints. Intermediate indices must exist;
          the final index may equal ``len(seq)``, which appends.

    Raises:
        IndexError: invalid or out-of-range sequence index.
        TypeError: a non-container is reached before the path is exhausted.
    """
    if not keys:
        return value
    return _set_nested(obj, keys, value)


def _set_nested(container, keys, value):
    """Set ``value`` at ``keys`` below ``container``; return the (possibly replaced) container."""
    key, rest = keys[0], keys[1:]
    if isinstance(container, dict):
        if rest:
            # Auto-create missing intermediate dicts.
            child = container[key] if key in container else {}
            container[key] = _set_nested(child, rest, value)
        else:
            container[key] = value
        return container
    if isinstance(container, (list, tuple)):
        # The final index may be len(container) (append); intermediate ones must exist.
        upper = len(container) - 1 if rest else len(container)
        if not isinstance(key, int) or key < 0 or key > upper:
            raise IndexError(f"Invalid list index: {key}")
        # Tuples are immutable: work on a list copy and hand it back to the parent.
        target = list(container) if isinstance(container, tuple) else container
        if rest:
            target[key] = _set_nested(target[key], rest, value)
        elif key == len(target):
            target.append(value)
        else:
            target[key] = value
        return target
    action = "access" if rest else "set"
    raise TypeError(f"Cannot {action} key '{key}' in object of type {type(container)}")


def check_column_actual(actual_columns, expected_columns, context):
    """
    Return True if every name in ``expected_columns`` appears in ``actual_columns``.

    Stops at the first missing name (in ``expected_columns`` order), logs it as an
    error mentioning ``context``, and returns False.
    """
    not_found = object()  # sentinel so a missing ``None`` column is still detected
    first_missing = next(
        (name for name in expected_columns if name not in actual_columns), not_found
    )
    if first_missing is not_found:
        return True
    logging.error(f"在 {context} 中未找到预期列名: {first_missing}")
    return False


def check_no_empty_lines_before_first_line(dataframe, context=""):
    """
    Soft-assert, via ``pytest.assume``, that ``dataframe`` does not start with empty rows.

    A row is empty when every cell is null. Only the leading run of empty rows is
    counted; empty rows after the first non-empty row are ignored here.
    """
    row_is_blank = dataframe.isnull().all(axis=1)
    # The running count of non-blank rows stays 0 until the first non-blank row appears.
    leading_blank_rows = int(((~row_is_blank).cumsum() == 0).sum())
    pytest.assume(leading_blank_rows == 0, f"{context} table has {leading_blank_rows} empty lines before first line.")


def check_no_empty_lines_between_first_last_line(dataframe, context=""):
    """
    Soft-assert, via ``pytest.assume``, that ``dataframe`` contains no empty rows.

    A row counts as empty when every cell is either null (NaN/None) or the empty
    string. ``pd.read_csv`` loads blank cells as NaN, so checking for ``''`` alone
    would never flag empty rows in a CSV-loaded frame.

    Note: despite the name, every row is checked, including the first and last.
    """
    # Count empty rows: a cell is blank when it is null or ''.
    blank_cells = dataframe.isna() | dataframe.eq('')
    num_empty_rows = blank_cells.all(axis=1).sum()
    pytest.assume(num_empty_rows == 0, f"{context} table has empty lines.")


def check_during_time(dataframe, context=""):
    """
    Check that ``during_time(ms)`` equals ``end_time(ms) - start_time(ms)`` within 1 ms.

    Returns False (after logging) if any of the three columns is missing, or at the
    first checked row whose difference exceeds 1; otherwise True.
    """
    needed = ('end_time(ms)', 'start_time(ms)', 'during_time(ms)')
    absent = [name for name in needed if name not in dataframe.columns]
    if absent:
        logging.error(f"The column {absent[0]} not found in {context}.")
        return False

    # NOTE(review): the final row is excluded from the check — confirm this is intentional.
    for idx, record in dataframe.iloc[:-1].iterrows():
        span = record['end_time(ms)'] - record['start_time(ms)']
        if abs(span - record['during_time(ms)']) > 1:
            logging.error(f"In row {idx} of {context}, the during_time is not correct.")
            return False
    return True


def check_split_csv_content(output_path, csv_file_name):
    """
    Validate a split CSV file under ``output_path`` using soft assertions (``pytest.assume``).

    Checks that the file exists and that its header contains the expected columns
    (plus ``rid`` for ``prefill.csv``). It also checks that there are no empty rows
    and that ``during_time(ms)`` matches ``end_time(ms) - start_time(ms)``.

    Args:
        output_path: Directory expected to contain the CSV.
        csv_file_name: CSV file name, e.g. ``"prefill.csv"``.
    """
    csv_file = os.path.join(output_path, csv_file_name)
    file_exists = os.path.isfile(csv_file)
    pytest.assume(file_exists, f"CSV file not found: {csv_file}")
    if not file_exists:
        # pytest.assume is a soft assertion and does not stop execution; reading a
        # missing file below would raise and abort the test.
        return
    task_name = os.path.splitext(csv_file_name)[0]
    expected_header = ['name', 'during_time(ms)', 'max', 'min', 'mean', 'std',
                       'pid', 'tid', 'start_time(ms)', 'end_time(ms)']
    if task_name == 'prefill':
        expected_header.append('rid')
    df = pd.read_csv(csv_file)
    # Column names.
    result = check_column_actual(df.columns.tolist(), expected_header, context=csv_file_name)
    pytest.assume(result, f"{csv_file_name} column check failed")
    # Empty rows.
    check_no_empty_lines_before_first_line(df, context=csv_file_name)
    check_no_empty_lines_between_first_last_line(df, context=csv_file_name)
    # Duration consistency.
    result = check_during_time(df, context=csv_file_name)
    pytest.assume(result, f"{csv_file_name} execution time validation failed")


def check_row(df, expected_columns, numeric_columns):
    """
    Check that ``Metric`` values are strings and ``numeric_columns`` values are numeric.

    Args:
        df: DataFrame to validate.
        expected_columns: Not used by this check; kept for caller compatibility.
        numeric_columns: Columns whose every cell must be accepted by ``float()``.
            Columns missing from ``df`` are logged and skipped.

    Returns:
        True if every check passes. Otherwise False, after logging the first failure:
        a missing ``Metric`` column, a non-string ``Metric`` value, or a non-numeric cell.
    """
    # A missing Metric column is a failure even when df has no rows.
    if 'Metric' not in df.columns:
        logging.error("数据框中不存在 'Metric' 列")
        return False
    for row_index, value in df['Metric'].items():
        if not isinstance(value, str):
            logging.error(f"在Metric列的第{row_index}行,值 '{value}' 不是字符串类型")
            return False

    for column in numeric_columns:
        if column not in df.columns:
            logging.error(f"数据框中不存在 {column} 列")
            continue
        for row_index, cell_value in df[column].items():
            try:
                float(cell_value)
            except (TypeError, ValueError, OverflowError):
                # TypeError: None / arbitrary objects; ValueError: unparsable strings.
                logging.error(
                    f"在 {column} 列的第 {row_index} 行,值 {cell_value} 不是有效的数字")
                return False
    return True