import os
import hashlib
import torch


# Directory that receives generated regression scripts when the caller does
# not pass ``output_dir``.
DEFAULT_OUTPUT_DIR = "./dvm_fx_regression_cases"


def generate_dvm_fx_case(
    gm: torch.fx.GraphModule,
    output_dir: str = DEFAULT_OUTPUT_DIR,
    fusion_type: str = "graph",
):
    """Write a standalone regression-test script for ``gm`` and return its path.

    The script embeds ``gm.code`` as the ``forward`` of a ``TestModel`` class,
    creates one input per placeholder (shape, stride and dtype are read from
    ``node.meta["val"]``) on device ``'npu'``, runs the model eagerly as the
    reference, then compiles it with the inductor backend twice -- once as-is
    and once with deterministic algorithms enabled -- and compares both
    results to the reference with ``atol=rtol=2e-3``.

    The file name is ``test_<fusion_type>_<hash>.py`` where the hash covers the
    fusion type, the readable graph dump and the input signatures, so an
    identical case maps to the same file.

    Args:
        gm: Graph module whose placeholders carry ``meta["val"]`` objects
            exposing ``shape``, ``stride()`` and ``dtype``.
        output_dir: Directory for the generated script; created if missing.
        fusion_type: ``"graph"`` wraps compilation in ``DvmGraphFusionPatch``;
            ``"mlir"`` sets ``TORCHINDUCTOR_NPU_BACKEND=dvm`` in the script and
            imports ``mlir_fusion``.

    Returns:
        The path of the newly written script, or ``None`` when a script with
        the same name already exists (nothing is written in that case).

    Raises:
        ValueError: If ``fusion_type`` is neither ``"graph"`` nor ``"mlir"``.
        KeyError: If a placeholder node has no ``meta["val"]`` entry.
    """
    if fusion_type not in ("graph", "mlir"):
        raise ValueError(f"unsupported fusion_type: {fusion_type}")

    def _indent(code: str, spaces: int) -> str:
        """Prefix every non-blank line of ``code`` with ``spaces`` spaces."""
        pad = " " * spaces
        return "\n".join(
            pad + line if line.strip() else line
            for line in code.split("\n")
        )

    def _fill_call(dtype: torch.dtype) -> str:
        """Return the in-place initializer call emitted for an input of ``dtype``."""
        if dtype == torch.bool:
            return "random_()"
        if dtype.is_floating_point or dtype.is_complex:
            # Complex dtypes keep the call this generator has always emitted.
            return "uniform_(0, 1)"
        # Tensor.uniform_ is not implemented for integer dtypes, so the
        # generated script would fail before running the model.  Draw from
        # {0, 1} instead: same value range as the floating inputs, and small
        # enough to stay harmless if the graph uses the tensor as an index.
        return "random_(0, 2)"

    # One pass over the placeholders yields both the hash signature and the
    # source line that recreates each input in the generated script.
    inputs = [n for n in gm.graph.nodes if n.op == "placeholder"]
    sig = []
    input_lines = []
    for i, n in enumerate(inputs):
        try:
            v = n.meta["val"]
        except KeyError:
            raise KeyError(
                f"placeholder {n.name!r} has no meta['val']; "
                "cannot derive its shape/stride/dtype"
            ) from None
        shape = tuple(v.shape)
        stride = tuple(v.stride())
        sig.append(f"arg{i}:{shape},{stride},{v.dtype}")
        input_lines.append(
            f"arg{i} = torch.empty_strided("
            f"torch.Size({shape}), "
            f"{stride}, "
            f"dtype={v.dtype}, device='npu').{_fill_call(v.dtype)}"
        )

    readable = gm.print_readable(print_output=False)

    # The hash deliberately excludes the fill strategy so that case names stay
    # stable across generator versions.
    h = hashlib.sha256(
        (fusion_type + "\n" + readable + "\n" + "\n".join(sig)).encode("utf-8")
    ).hexdigest()[:16]
    case_name = f"test_{fusion_type}_{h}"

    file_path = os.path.join(output_dir, f"{case_name}.py")
    if os.path.exists(file_path):
        print(f"[skip] {file_path}")
        return None

    class_name = "TestModel"
    input_code = "\n    ".join(input_lines)
    fwd_args = ", ".join(f"arg{i}" for i in range(len(inputs)))

    # Statements shared by both fusion modes: a plain compiled run, then a
    # second compiled run with deterministic algorithms forced on, restoring
    # the previous global setting afterwards.
    run_lines = [
        "compiled = torch.compile(model, backend=\"inductor\", dynamic=False)",
        f"out = compiled({fwd_args})",
        "deterministic_state = torch.are_deterministic_algorithms_enabled()",
        "deterministic_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()",
        "try:",
        "    torch.use_deterministic_algorithms(True)",
        "    deterministic_compiled = torch.compile(model, backend=\"inductor\", dynamic=False)",
        f"    deterministic_out = deterministic_compiled({fwd_args})",
        "finally:",
        "    torch.use_deterministic_algorithms(deterministic_state, warn_only=deterministic_warn_only)",
    ]

    if fusion_type == "graph":
        fusion_env = ""
        fusion_imports = (
            "from torch_npu._inductor.dvm.graph_fusion "
            "import DvmGraphFusionPatch"
        )
        # Graph fusion runs the shared statements inside the patch context.
        compile_lines = ["with DvmGraphFusionPatch():"]
        compile_lines.extend("    " + line for line in run_lines)
    else:
        fusion_env = 'os.environ["TORCHINDUCTOR_NPU_BACKEND"] = "dvm"'
        fusion_imports = "from torch_npu._inductor.dvm import mlir_fusion"
        compile_lines = run_lines
    # Continuation lines carry the 4-space indent of the ``test_case`` body;
    # the first line gets it from the template below.
    compile_code = "\n    ".join(compile_lines)

    test_code = f"""import torch
import torch_npu
from torch import device
from torch.utils._pytree import tree_flatten
import os
{fusion_env}


class {class_name}(torch.nn.Module):
    def __init__(self):
        super().__init__()
{_indent(gm.code, 4)}


def _assert_close(ref, out, atol=2e-3, rtol=2e-3):
    rf, rs = tree_flatten(ref)
    of, os = tree_flatten(out)
    assert rs == os, f"pytree mismatch\\nref={{rs}}\\nout={{os}}"
    for r, o in zip(rf, of):
        torch.testing.assert_close(r, o, atol=atol, rtol=rtol, equal_nan=True)


def test_case():
    {input_code}

    model = {class_name}().npu()
    ref = model({fwd_args})

    {fusion_imports}

    {compile_code}

    _assert_close(ref, out)
    _assert_close(ref, deterministic_out)


if __name__ == "__main__":
    test_case()
    print("PASS")
"""

    # Create the directory only once there is something to write.
    os.makedirs(output_dir, exist_ok=True)
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(test_code)

    print(f"[ok] generated: {file_path}")
    return file_path