import torch

from vllm_ascend.utils import device_print


def compute_and_print(x: torch.Tensor) -> torch.Tensor:
    y = torch.square(x) - torch.cos(x)
    device_print("device_print from current execution mode")
    device_print(7)
    device_print(True)
    device_print(y)
    device_print(f"Compatible with f-strings: {x.dtype = }, {isinstance(x, torch.Tensor) = }")
    return y


def main() -> None:
    torch.npu.set_device(0)
    torch.npu.set_compile_mode(jit_compile=False)

    x = torch.arange(1, 28, dtype=torch.float32).reshape(3, 3, 3).npu()

    print("=== eager ===", flush=True)
    eager_out = compute_and_print(x)
    torch.npu.synchronize()

    print("=== torch.compile(backend='aot_eager') ===", flush=True)
    compiled_compute_and_print = torch.compile(compute_and_print, backend="aot_eager")
    compiled_out = compiled_compute_and_print(x)
    torch.npu.synchronize()

    assert torch.allclose(eager_out, compiled_out), "Outputs from eager and compiled modes do not match."

    graph = torch.npu.NPUGraph()
    capture_stream = torch.npu.Stream()
    x_capture = x.clone()

    with torch.npu.stream(capture_stream), torch.npu.graph(graph, stream=capture_stream):
        captured_out = compiled_compute_and_print(x_capture)

    print("=== replay graph ===", flush=True)
    graph.replay()
    torch.npu.synchronize()

    assert torch.allclose(eager_out, captured_out), "Outputs from eager and graph modes do not match."

    print("=== modify input and replay graph ===", flush=True)
    x_capture.copy_(torch.arange(28, 1, -1, dtype=torch.float32).reshape(3, 3, 3).npu())
    graph.replay()
    torch.npu.synchronize()

    assert not torch.allclose(eager_out, captured_out), "Outputs from eager and modified graph modes should not match."


if __name__ == "__main__":
    main()