import importlib
import os
import sys
from typing import Any, Iterable, Mapping

import torch
from torch._inductor.compile_fx import clone_preserve_strides


def clone_for_accuracy(arg):
    """Return an independent copy of ``arg`` suitable for a reference computation.

    Tensors are cloned with their strides preserved; bfloat16 clones are
    additionally converted to float32. Non-tensor values are returned as-is.
    """
    if isinstance(arg, torch.Tensor):
        copy = clone_preserve_strides(arg)
        if copy.dtype == torch.bfloat16:
            return copy.float()
        return copy
    return arg


def compare_outputs(
    actual_outputs: Iterable[Any],
    expected_outputs: Iterable[Any],
    kernel_name: str,
    tolerances: Mapping[Any, Mapping[str, float]],
    dump_path: str = "",
):
    """Compare kernel outputs element-wise against reference outputs.

    Pairs are taken positionally (``zip``), so extra items in the longer
    iterable are ignored; pairs where either side is not a tensor are skipped.
    Each failing pair is reported via ``_report_mismatch``.

    Args:
        actual_outputs: Outputs produced by the kernel under test.
        expected_outputs: Reference outputs. When a dtype differs from the
            actual output's, the reference is cast to the actual dtype first.
        kernel_name: Kernel name included in mismatch reports.
        tolerances: Maps a ``torch.dtype`` to ``{"rtol": ..., "atol": ...}``.
            The ``"default"`` entry is used only for dtypes with no entry of
            their own.
        dump_path: Optional path appended to mismatch reports.

    Returns:
        True if every compared tensor pair is close, False otherwise.
    """
    failed_indices = []
    for idx, (actual, expected) in enumerate(zip(actual_outputs, expected_outputs)):
        if not isinstance(actual, torch.Tensor) or not isinstance(expected, torch.Tensor):
            continue
        if actual.dtype != expected.dtype:
            # Compare in the kernel's dtype; references may have been upcast
            # (e.g. clone_for_accuracy turns bfloat16 into float32).
            expected = expected.to(actual.dtype)

        # Fall back to "default" lazily, so a mapping that covers every dtype
        # it is used with does not also need a "default" key.
        if actual.dtype in tolerances:
            tol = tolerances[actual.dtype]
        else:
            tol = tolerances["default"]
        rtol, atol = tol["rtol"], tol["atol"]
        matches = torch.isclose(actual, expected, rtol=rtol, atol=atol, equal_nan=True)
        if not matches.all():
            _report_mismatch(idx, actual, expected, matches, rtol, atol, kernel_name, dump_path)
            failed_indices.append(idx)
        del matches

    return not failed_indices


def _widen_for_report(t):
    """Return ``t`` in a dtype where subtraction neither wraps around nor overflows.

    Integer/bool tensors and floating tensors narrower than 32 bits are cast to
    float32; 32/64-bit floating and complex tensors are returned unchanged.
    """
    if t.is_complex():
        return t
    if t.is_floating_point() and torch.finfo(t.dtype).bits >= 32:
        return t
    return t.to(torch.float32)


def _report_mismatch(idx, actual, expected, matches, rtol, atol, kernel_name, dump_path=""):
    """Print a one-line summary of how ``actual`` diverges from ``expected``.

    Differences are computed in a widened dtype (see ``_widen_for_report``),
    so unsigned integers do not wrap and half-precision values do not
    overflow or underflow the relative-difference denominator.

    Args:
        idx: Output index, included in the message.
        actual: Tensor produced by the kernel.
        expected: Reference tensor compared against ``actual``.
        matches: Boolean mask of element-wise closeness (``True`` = close).
        rtol, atol: Tolerances used, included in the message.
        kernel_name: Kernel name, included in the message.
        dump_path: Optional path appended to the message when non-empty.
    """
    actual_w = _widen_for_report(actual)
    expected_w = _widen_for_report(expected)
    abs_diff = torch.abs(actual_w - expected_w)
    # Elements accepted by isclose(equal_nan=True) -- NaN vs NaN, or equal
    # infinities -- still produce NaN differences. Zero them so they don't
    # turn max() into NaN.
    abs_diff.masked_fill_(matches & torch.isnan(abs_diff), 0)
    expected_abs = torch.abs(expected_w)
    rel_diff = abs_diff / torch.clamp(expected_abs, min=1e-20)
    rel_diff.masked_fill_(matches, 0)
    number_of_elements = matches.numel()
    total_mismatches = number_of_elements - int(torch.sum(matches))
    msg = (
        "CHECK ACCURACY FAILED! "
        f"Kernel: {kernel_name}, Output idx: {idx}, "
        f"Mismatched: {total_mismatches}/{number_of_elements} "
        f"({total_mismatches / number_of_elements:.1%}), "
        f"Greatest Rel Diff: {rel_diff.max().item()}, "
        f"Greatest Abs Diff: {abs_diff.max().item()}, "
        f"rtol: {rtol}, atol: {atol}"
    )
    if dump_path:
        msg += f", dump_path: {dump_path}"
    print(msg, flush=True)


def get_triton_fx_graph_call(inductor_meta, auto_fallback=False):
    """Import the traced FX graph dumped for a Triton kernel and wrap it as a callable.

    The module ``<traced_graph_dir>/<traced_graph_hash>/<traced_graph_hash>.py``
    is imported. It must define ``model``, ``num_inputs``, ``num_outputs``,
    ``non_contiguous_indices`` (with ``'inputs'``/``'outputs'``),
    ``mismatch_indices_shapes`` and ``call_args_mapping``.

    Args:
        inductor_meta: Kernel metadata; reads ``kernel_name``,
            ``traced_graph_hash`` and ``traced_graph_dir``.
        auto_fallback: If true, return a callable that accepts the kernel's
            launch arguments instead of FX-graph-ordered arguments.

    Returns:
        ``(None, None, None, None)`` when no hash/dir is given or the dump
        directory does not exist. Otherwise, ``(fx_graph_call, kernel_name,
        dump_path, fx_module)``, or ``(fallback_call, kernel_name, None, None)``
        when ``auto_fallback`` is set. ``fx_graph_call(*fx_args)`` takes inputs
        followed by output buffers and writes the graph results into those
        buffers. ``fallback_call(*args)`` first reorders ``args`` with
        ``fx_module.call_args_mapping``.
    """
    kernel_name = inductor_meta.get("kernel_name", "triton_")
    traced_graph_hash = inductor_meta.get("traced_graph_hash")
    if not traced_graph_hash:
        return None, None, None, None
    dump_dir = inductor_meta.get("traced_graph_dir", "")
    # Check the directory before joining so a None value is handled like "".
    if not dump_dir:
        return None, None, None, None
    dump_path = os.path.join(dump_dir, traced_graph_hash)
    if not os.path.exists(dump_path):
        return None, None, None, None
    # Make the dump importable only for the duration of the import, and restore
    # sys.path even if the generated module fails to import.
    sys.path.append(dump_path)
    try:
        fx_module = importlib.import_module(traced_graph_hash)
    finally:
        sys.path.remove(dump_path)

    model = fx_module.model
    num_inputs = fx_module.num_inputs
    num_outputs = fx_module.num_outputs
    non_contiguous_indices = fx_module.non_contiguous_indices
    mismatch_indices_shapes = fx_module.mismatch_indices_shapes

    def fx_graph_call(*fx_args):
        fx_inputs = [
            fx_args[idx].contiguous() if idx in non_contiguous_indices['inputs'] else fx_args[idx]
            for idx in range(num_inputs)
        ]
        for ind, shape in mismatch_indices_shapes.items():
            # Only input entries are reshaped here. Skip (rather than stop at)
            # other indices, so the result does not depend on dict key order.
            if ind >= num_inputs:
                continue
            fx_inputs[ind] = fx_inputs[ind].reshape(shape)
        model_outputs = model.forward(*fx_inputs)
        output_buffers = fx_args[num_inputs:(num_inputs + num_outputs)]
        for idx, (out1, out2) in enumerate(zip(model_outputs, output_buffers)):
            out1 = out1.reshape(out2.shape)
            if idx in non_contiguous_indices['outputs']:
                # Copy element-wise so the buffer keeps its own strides.
                out2.copy_(out1)
            else:
                out2.data = out1.data

    def fallback_call(*args):
        fx_args = [args[idx] for idx in fx_module.call_args_mapping]
        return fx_graph_call(*fx_args)

    if auto_fallback:
        return fallback_call, kernel_name, None, None
    return fx_graph_call, kernel_name, dump_path, fx_module


def check_accuracy_triton(*args, launcher, grid, stream, inductor_meta, **kwargs):
    """Launch a Triton kernel and compare its outputs with its traced FX graph.

    The FX graph runs first, on clones of the kernel arguments (reordered by
    ``fx_module.call_args_mapping``). Then ``launcher`` runs on the original
    ``args``, and the kernel's output arguments are compared against the cloned
    reference outputs using ``npu_config.acc_comp_tol``. Mismatches are printed.

    Returns:
        None if no traced graph is available; in that case the kernel is NOT
        launched. Otherwise True, after launching and comparing.
    """
    import torch_npu._inductor.config as npu_config
    fx_graph_call, kernel_name, dump_path, fx_module = get_triton_fx_graph_call(inductor_meta)
    if not fx_graph_call:
        return None
    num_inputs = fx_module.num_inputs
    call_outputs_indices = fx_module.call_args_mapping[num_inputs:]

    fx_args = []
    for idx in fx_module.call_args_mapping:
        arg = args[idx]
        if not isinstance(arg, torch.Tensor):
            # torch.tensor wraps the value itself. torch.Tensor(n) would instead
            # allocate an uninitialized n-element tensor and reject floats.
            # NOTE(review): confirm the FX graph expects a 0-d tensor here.
            arg = torch.tensor(arg).npu()
        fx_args.append(clone_for_accuracy(arg))

    # Writes the reference results into the cloned output buffers.
    fx_graph_call(*fx_args)

    launcher(*args, **kwargs, stream=stream)

    compare_outputs(
        [args[i] for i in call_outputs_indices],
        fx_args[num_inputs:],
        kernel_name=kernel_name,
        tolerances=npu_config.acc_comp_tol,
        dump_path=dump_path,
    )

    # Release the reference clones before returning.
    del fx_args
    return True


def check_accuracy_mlir(*args, kernel_name, launchers, num_outputs, dynamic, **kwargs):
    """Run an MLIR kernel next to its reference launcher and compare the outputs.

    ``launchers[1]`` runs on clones of ``args`` to produce reference outputs.
    ``launchers[0]`` then runs on the real arguments. The last ``num_outputs``
    arguments are treated as outputs and compared using
    ``anir_config.acc_comp_tol``.

    When ``dynamic`` is true, each tensor argument is expanded to
    ``(t, t, 0, *t.size(), *t.stride())`` before the launch. This is presumably
    a memref-style (allocated, aligned, offset, sizes, strides) descriptor;
    confirm against the launcher. Non-tensor arguments pass through unchanged.

    Returns:
        ``(launcher_output, all_close)``.
    """
    from torch_npu._inductor.ascend_npu_ir.ascend_npu_ir import config as anir_config
    launcher, launcher_fx = launchers[0], launchers[1]

    split = len(args) - num_outputs
    reference_outputs = [clone_for_accuracy(a) for a in args[split:]]
    reference_inputs = [clone_for_accuracy(a) for a in args[:split]]

    launcher_fx(*(reference_inputs + reference_outputs), **kwargs)

    if not dynamic:
        launch_args = args
    else:
        packed = []
        for a in args:
            if torch.is_tensor(a):
                packed.extend((a, a, 0))
                packed.extend(a.size())
                packed.extend(a.stride())
            else:
                packed.append(a)
        launch_args = tuple(packed)

    output = launcher(*launch_args, **kwargs)
    all_close = compare_outputs(
        args[split:],
        reference_outputs,
        kernel_name=kernel_name,
        tolerances=anir_config.acc_comp_tol,
    )
    return (output, all_close)


def _load_fx_model(acc_meta):
    """Return the traced FX model described by ``acc_meta``, importing it on first use.

    The module is imported from
    ``$TORCHINDUCTOR_CACHE_DIR/<traced_graph_cache>/<device_index>/<traced_graph_hash>/``
    and must define a class named ``traced_graph_hash``. An instance of that
    class is stored in ``acc_meta['_fx_model']``, so later calls return it
    directly.
    """
    cached = acc_meta.get('_fx_model')
    if cached is not None:
        return cached
    module_dir = os.path.join(
        os.getenv("TORCHINDUCTOR_CACHE_DIR"),
        acc_meta['traced_graph_cache'],
        str(acc_meta['device_index']),
        acc_meta['traced_graph_hash'],
    )
    graph_hash = acc_meta['traced_graph_hash']
    # Prepend so this dump wins module resolution, and always undo the change.
    sys.path.insert(0, module_dir)
    try:
        graph_module = importlib.import_module(graph_hash)
    finally:
        sys.path.remove(module_dir)
    model_cls = getattr(graph_module, graph_hash)
    acc_meta['_fx_model'] = model = model_cls()
    return model
 
 
def check_accuracy_dvm(kobj, acc_meta, kernel_name, args):
    """Run a DVM kernel, then compare its outputs against the traced FX reference.

    The reference model (see ``_load_fx_model``) runs on clones of the first
    ``len(args) - acc_meta['num_outputs']`` arguments. ``kobj.run(*args)`` then
    runs the kernel, and the trailing output arguments are compared with the
    reference results using ``anir_config.acc_comp_tol``. Mismatches are
    printed; nothing is returned.
    """
    from torch_npu._inductor.ascend_npu_ir.ascend_npu_ir import config as anir_config

    reference_model = _load_fx_model(acc_meta)

    first_output = len(args) - acc_meta['num_outputs']

    reference_outputs = reference_model.forward(
        *[clone_for_accuracy(a) for a in args[:first_output]]
    )
    if not isinstance(reference_outputs, (tuple, list)):
        # Wrap a single returned value so it pairs with the first output argument.
        reference_outputs = (reference_outputs,)

    kobj.run(*args)

    compare_outputs(
        args[first_output:],
        reference_outputs,
        kernel_name=kernel_name,
        tolerances=anir_config.acc_comp_tol,
    )