import triton.profiler as proton

import torch
import sys

from helper_kernels import custom_add, matmul_kernel


def main():
    a = torch.zeros(1, device="cuda")
    with proton.scope("test"):
        custom_add[(1, )](a)


def test_main():
    main()


def matmul():
    a = torch.randn((32, 32), device="cuda", dtype=torch.float16)
    b = torch.randn((32, 32), device="cuda", dtype=torch.float16)
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    matmul_kernel[(1, )](
        a, b, c,  #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
        128, 256, 64, 8)
    return c


if __name__ == "__main__":
    if sys.argv[1] == "test":
        main()
    elif sys.argv[1] == "test_matmul":
        matmul()