#!/usr/bin/env python
# coding=utf-8
# Copyright (c) Huawei Technologies Co., Ltd. 2024-2026. All rights reserved.
# MindIE is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#          http://license.coscl.org.cn/MulanPSL2
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.

import time
from collections.abc import Callable, Sequence

import torch

# Default number of untimed calls `benchmark` makes before it starts measuring.
WARMUP: int = 5
# Default number of timed calls `benchmark` makes.
REPEAT: int = 10
# Number of trailing timed samples `benchmark` averages into its result.
TAIL: int = 5


def benchmark(
    func: Callable,
    args: Sequence[torch.Tensor],
    warmup: int = WARMUP,
    repeat: int = REPEAT,
) -> float:
    """Return the mean wall-clock time of the last timed calls to ``func``.

    ``func(*args)`` is called ``warmup`` times without timing and then
    ``repeat`` times with timing. ``torch.npu.synchronize()`` is called after
    every invocation, and also right before each timed invocation, so the
    clock is only read at synchronization points. Only the last ``TAIL``
    timed samples contribute to the result.

    Args:
        func: Callable to measure; invoked as ``func(*args)``.
        args: Positional arguments unpacked into every call of ``func``.
        warmup: Number of untimed calls made before measuring.
        repeat: Number of timed calls.

    Returns:
        The arithmetic mean, in ``time.perf_counter()`` units (seconds), of
        the last ``min(TAIL, repeat)`` timed samples, or ``0.0`` when no
        timed call was made (``repeat <= 0``).
    """
    for _ in range(warmup):
        func(*args)
        torch.npu.synchronize()

    times = []
    for _ in range(repeat):
        # Synchronize first so earlier pending work is not charged to this sample.
        torch.npu.synchronize()
        start = time.perf_counter()
        func(*args)
        # Synchronize again so the sample is taken only after the device is idle.
        torch.npu.synchronize()
        times.append(time.perf_counter() - start)

    tail = times[-TAIL:]
    if not tail:
        # No timed run was made; keep the historical 0.0 result.
        return 0.0
    # Divide by the number of samples actually collected: when repeat < TAIL
    # the slice is shorter than TAIL, and dividing by TAIL would under-report
    # the mean.
    return sum(tail) / len(tail)