# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
import pypto
from pypto import pil, ir, logging

# ---------- Ops compile tests ----------


def _compile_ops(func, *args):
    """Compile ``func`` via ``pil.compile``, passing ``args`` as its example inputs.

    Thin wrapper so the ops tests share one entry point; the compiled object
    is handed back unchanged.
    """
    compiled = pil.compile(func, *args)
    return compiled


def _tensor_ops_of(ops):
    return [op for op in ops if isinstance(op, ir.TensorOpStmt)]


def test_tensor_binary_ops():
    """Tensor ``/`` and ``**`` are lowered to DIV and POW tensor ops."""
    def f(x, y):
        pypto.set_vec_tile_shapes(16, 16)
        z = x / y
        z = x ** y

    lhs = pypto.Tensor(shape=(4, 4), dtype=pypto.DT_FP32)
    rhs = pypto.Tensor(shape=(4, 4), dtype=pypto.DT_FP32)
    compiled = _compile_ops(f, lhs, rhs)
    seen = {stmt.opcode for stmt in _tensor_ops_of(compiled.body)}
    for expected in ('DIV', 'POW'):
        assert expected in seen


def test_scalar_binary_bitwise_ops():
    """Bitwise operators on plain Python ints compile without emitting tensor ops.

    Covers ``|``, ``^``, ``&``, ``<<`` and ``>>`` once each (the previous
    copy-pasted duplicates of ``|`` and ``^`` added no coverage).
    """
    def f(x, y):
        z = x | y
        z = x ^ y
        z = x & y
        z = x << y
        z = x >> y

    func = _compile_ops(f, 1, 2)
    ts = _tensor_ops_of(func.body)
    # Scalar-only inputs must not produce any TensorOpStmt at the top level.
    assert not ts


def test_tensor_matmul_op():
    """``@`` on two 32x32 FP32 tensors is lowered to an A_MUL_B tensor op."""
    def f(x, y):
        pypto.set_cube_tile_shapes([16, 16], [16, 16], [16, 16])
        z = x @ y

    operands = [pypto.Tensor(shape=(32, 32), dtype=pypto.DT_FP32) for _ in range(2)]
    compiled = _compile_ops(f, *operands)
    opcodes = {stmt.opcode for stmt in _tensor_ops_of(compiled.body)}
    assert 'A_MUL_B' in opcodes


def test_tensor_unary_ops():
    """Unary ``-`` and ``+`` on an INT32 tensor compile; negation shows up as MULS."""
    def f(x):
        pypto.set_vec_tile_shapes(16, 16)
        z = -x
        z = +x

    src = pypto.Tensor(shape=(4, 4), dtype=pypto.DT_INT32)
    compiled = _compile_ops(f, src)
    opcodes = {stmt.opcode for stmt in _tensor_ops_of(compiled.body)}
    # pypto implements negation with a MULS op, while unary plus simply
    # returns its operand, so MULS is the only opcode we can expect here.
    assert 'MULS' in opcodes


def test_pypto2ir():
    """Compile a small view/add/assemble program over tensors with a -1 first dim.

    Only checks that ``pil.compile`` succeeds; no IR structure is asserted.
    """
    def f(x, y):
        pypto.set_vec_tile_shapes(16, 16)
        # Intermediate tensor whose first dim follows x.shape[0].
        z = pypto.Tensor((x.shape[0], 64), dtype=pypto.DT_FP32, name="z")
        i = 0
        # x -> z: take a 16x16 view of x, add it to itself, assemble into z.
        tx = pypto.view(x, [16, 16], [i * 32, 0])
        tx2 = pypto.add(tx, tx)
        pypto.assemble(tx2, [i * 32, 0], z)

        # z -> y: same pattern, now reading from the intermediate z.
        tz = pypto.view(z, [16, 16], [i * 32, 0])
        tz2 = pypto.add(tz, tz)
        pypto.assemble(tz2, [i * 32, 0], y)
        # Plain Python builtin call inside the traced function; presumably
        # exercises the compiler's handling of non-pypto calls — confirm.
        min(1, 2, 3, 4, 5)

    # NOTE(review): -1 appears to mark a dynamic dimension — confirm.
    x = pypto.Tensor(shape=(-1, 64), dtype=pypto.DT_FP32, name="x")
    y = pypto.Tensor(shape=(-1, 64), dtype=pypto.DT_FP32, name="y")
    # NOTE(review): the compiled result is unused; consider asserting on func.body.
    func = pil.compile(f, x, y)


def test_ir_range():
    """Compiling a function with a plain ``range`` loop reproduces its eager side effects.

    ``foo`` records what it does into ``results`` (nested if/else, an inner
    loop with ``continue``/``break``, an outer ``break``). It is run once
    eagerly to capture the expected records, then compiled with the same
    argument; the records appended during compilation must match exactly.
    """
    results = []

    def foo(n):
        total = 0
        for i in range(n):
            # nested if
            if i < 3:
                if i == 0:
                    results.append(("nested_if", i))
                else:
                    results.append(("nested_else", i))
                    total += i
            # nested loop with continue/break
            if i == 5:
                for j in range(4):
                    if j == 1:
                        continue
                    if j == 2:
                        break
                    results.append(("inner", j))
                    total += j
            # break in outer loop
            if i == 7:
                break
            total += i
        # For n=10, i is 7 here because of the break above.
        results.append(("final", i, total))

    # Eager run: capture the reference records.
    foo(10)
    # Swap in a fresh list; foo sees the rebinding because it reads
    # ``results`` through its closure.
    expected, results = results, []
    pil.compile(foo, 10)
    assert results == expected


def test_ir_loop():
    """Smoke test: a ``pypto.loop`` accumulating into ``ans``, then a use of ``ans``, compiles.

    No IR is inspected; test_ir_loop_carry_used_after compiles the same
    function and checks the resulting ForStmt.
    """
    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            ans += i
        # Read the carried value after the loop.
        ans = ans + 10
    pil.compile(foo, 10)

# ---------- Loop IR tests ----------


def _for_ops_of(ops):
    return [op for op in ops if isinstance(op, ir.ForStmt)]


def test_ir_loop_basic():
    """One ``pypto.loop(n)``: constant bounds [0, 10) step 1, carrying ``ans`` and ``i``."""
    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            ans += i
    compiled = pil.compile(foo, 10)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    (loop,) = loops
    # Two carried values: the accumulator and the induction variable.
    assert len(loop.iter_args) == 2
    assert len(loop.return_vars) == 2
    for attr, expected in (('start', 0), ('stop', 10), ('step', 1)):
        bound = getattr(loop, attr)
        assert isinstance(bound, ir.ConstInt) and bound.value == expected


def test_ir_loop_no_carry():
    """An empty loop body still carries exactly one value: the loop variable."""
    def foo(n):
        for _ in pypto.loop(n):
            pass
    compiled = pil.compile(foo, 5)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    (loop,) = loops
    assert len(loop.iter_args) == 1
    assert len(loop.return_vars) == 1


def test_ir_loop_two_carries():
    """Two accumulators plus the loop variable give three carried values."""
    def foo(n):
        ans = 0
        count = 1
        for i in pypto.loop(n):
            ans += i
            count += 1
    compiled = pil.compile(foo, 10)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    (loop,) = loops
    # ans, count and i.
    assert len(loop.iter_args) == 3
    assert len(loop.return_vars) == 3


def test_ir_loop_range_two_args():
    """``pypto.loop(start, stop)`` yields constant bounds [2, 10) with default step 1."""
    def foo():
        ans = 0
        for i in pypto.loop(2, 10):
            ans += i
    compiled = pil.compile(foo)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    (loop,) = loops
    for attr, expected in (('start', 2), ('stop', 10), ('step', 1)):
        bound = getattr(loop, attr)
        assert isinstance(bound, ir.ConstInt) and bound.value == expected
    # ans and i.
    assert len(loop.iter_args) == 2


def test_ir_loop_range_three_args():
    """``pypto.loop(start, stop, step)`` keeps all three bounds as constants."""
    def foo():
        ans = 0
        for i in pypto.loop(1, 20, 3):
            ans += i
    compiled = pil.compile(foo)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    (loop,) = loops
    for attr, expected in (('start', 1), ('stop', 20), ('step', 3)):
        bound = getattr(loop, attr)
        assert isinstance(bound, ir.ConstInt) and bound.value == expected
    # ans and i.
    assert len(loop.iter_args) == 2


def test_ir_loop_sequential():
    """Two back-to-back loops each carry ``ans`` and their own loop variable."""
    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            ans += i
        for j in pypto.loop(n):
            ans += j
    compiled = pil.compile(foo, 5)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 2
    assert all(len(loop.iter_args) == 2 for loop in loops)
    assert all(len(loop.return_vars) == 2 for loop in loops)


def test_ir_loop_nested():
    """Nested loops: the inner ForStmt is emitted inside the outer loop's body."""
    def foo(n, m):
        ans = 0
        for i in pypto.loop(n):
            for j in pypto.loop(m):
                if j % 2 == 0:
                    ans += i + j
                else:
                    ans += i - j
            ans += n
    func = pil.compile(foo, 4, 5)
    # Only the outer loop sits at the top level of the function body...
    outer_loops = _for_ops_of(func.body)
    assert len(outer_loops) == 1
    # ...and exactly one loop (the inner one) is found inside its body.
    inner_loops = _collect_stmts(outer_loops[0].body, ir.ForStmt)
    assert len(inner_loops) == 1


def test_ir_loop_nested1():
    """Nested loops with a ``break`` in the inner loop's if-branch.

    The inner ForStmt must be emitted inside the outer loop's body, and the
    ``break`` must appear as a BreakStmt inside the inner loop.
    """
    def foo(n, m):
        ans = 0
        for i in pypto.loop(n):
            for j in pypto.loop(m):
                if j % 2 == 0:
                    ans += i + j
                    break
                else:
                    ans += i - j
            ans += n
    func = pil.compile(foo, 4, 5)
    outer_loops = _for_ops_of(func.body)
    assert len(outer_loops) == 1
    inner_loops = _collect_stmts(outer_loops[0].body, ir.ForStmt)
    assert len(inner_loops) == 1
    # The break belongs to the inner loop, not the outer one.
    inner_breaks = _collect_stmts(inner_loops[0].body, ir.BreakStmt)
    assert len(inner_breaks) == 1


def test_ir_loop_carry_used_after():
    """Reading ``ans`` after the loop still yields a loop carrying ``i`` and ``ans``."""
    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            ans += i
        ans = ans + 10
    compiled = pil.compile(foo, 10)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    carried = loops[0].iter_args
    assert len(carried) == 2  # i and ans


def test_ir_loop_body_multiple_ops():
    """``a`` is read after the loop, so ``i``, ``a`` and ``ans`` are all carried."""
    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            a = i * 2
            ans += a
        ans = ans + a
    compiled = pil.compile(foo, 10)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    (loop,) = loops
    assert len(loop.iter_args) == 3
    assert len(loop.return_vars) == 3
    # TODO: also assert that the loop body holds the scalar mul (i * 2) and
    # add (ans += a) ops once a helper for scalar-op statements exists.


def test_ir_deadcode():
    """Unreachable statements after ``break`` in a loop body must not break compilation.

    Only checks that ``pil.compile`` succeeds.
    """
    def foo(n):
        for i in pypto.loop(n):
            x = i + 1
            break
            return x  # unreachable: follows the break

    pil.compile(foo, 10)

# ---------- Loop control flow IR tests ----------


def _collect_stmts(stmt, cls):
    """Return every statement of type ``cls`` reachable from ``stmt``, in pre-order.

    ``stmt`` itself is included when it matches. The search descends into
    ``SeqStmts.stmts``, both branches of an ``IfStmt`` and the body of a
    ``ForStmt``; nodes of any other type are not descended into.
    """
    children = []
    if isinstance(stmt, ir.SeqStmts):
        children.extend(stmt.stmts)
    if isinstance(stmt, ir.IfStmt):
        children += [stmt.then_body, stmt.else_body]
    if isinstance(stmt, ir.ForStmt):
        children.append(stmt.body)

    found = [stmt] if isinstance(stmt, cls) else []
    for child in children:
        found.extend(_collect_stmts(child, cls))
    return found


def test_ir_loop_break():
    """A ``break`` in the loop body is emitted as exactly one BreakStmt."""
    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            ans += i
            break
    compiled = pil.compile(foo, 10)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    found = _collect_stmts(loops[0].body, ir.BreakStmt)
    assert len(found) == 1


def test_ir_loop_continue():
    """A ``continue`` in the loop body is emitted as exactly one ContinueStmt."""
    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            continue
            ans += i
    compiled = pil.compile(foo, 10)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    found = _collect_stmts(loops[0].body, ir.ContinueStmt)
    assert len(found) == 1


def test_ir_loop_return():
    """A ``return`` in the loop body is emitted as exactly one ReturnStmt."""
    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            ans += i
            return
    compiled = pil.compile(foo, 10)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    found = _collect_stmts(loops[0].body, ir.ReturnStmt)
    assert len(found) == 1


def test_ir_loop_break_in_if():
    """A conditional ``break`` becomes an IfStmt whose then-branch holds the BreakStmt."""
    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            if i:
                break
            ans += i
    compiled = pil.compile(foo, 10)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    branches = _collect_stmts(loops[0].body, ir.IfStmt)
    assert branches
    # The first IfStmt in pre-order is the one guarding the break.
    assert len(_collect_stmts(branches[0].then_body, ir.BreakStmt)) == 1


def test_ir_loop_continue_in_if():
    """A conditional ``continue`` becomes an IfStmt whose then-branch holds the ContinueStmt."""
    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            if i:
                continue
            ans += i
    compiled = pil.compile(foo, 10)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    branches = _collect_stmts(loops[0].body, ir.IfStmt)
    assert branches
    # The first IfStmt in pre-order is the one guarding the continue.
    assert len(_collect_stmts(branches[0].then_body, ir.ContinueStmt)) == 1


def test_ir_loop_if_else_both_branches():
    """An if/else inside the loop body compiles and shows up as at least one IfStmt."""
    def foo(n):
        ans = 0
        for i in pypto.loop(n):
            if i:
                ans += i
            else:
                ans += 1
    compiled = pil.compile(foo, 10)
    loops = _for_ops_of(compiled.body)
    assert len(loops) == 1
    assert _collect_stmts(loops[0].body, ir.IfStmt)


def test_tensor_add_dyn():
    """Elementwise add/sub on tensors with a dynamic (-1) first dim inside ``pypto.loop``.

    The loop walks ``x`` in 32-row tiles: tile ``i`` is read from ``x`` and
    ``tile + 1`` (even ``i``) or ``tile - 1`` (odd ``i``) is written into
    ``y`` starting at the same row offset.
    """
    def foo(x, y):
        for i in pypto.loop(x.shape[0] // 32):
            pypto.set_vec_tile_shapes(32, 32)
            # The loop counts 32-row tiles, so scale the index to a row offset;
            # slicing from ``i`` directly read overlapping windows and left
            # most rows of x/y untouched.
            row = i * 32
            ta = x[row:row + 32, :]
            if i % 2 == 0:
                y[row:, :] = ta + 1
            else:
                y[row:, :] = ta - 1

    x = pypto.Tensor((-1, 32), pypto.DT_FP32, 'x')
    y = pypto.Tensor((-1, 32), pypto.DT_FP32, 'y')
    pil.compile(foo, x, y)


def test_fstring():
    """f-strings inside a compiled function must evaluate exactly as in CPython.

    The expected strings assume x=10 and y=20, the arguments passed to
    ``pil.compile`` below.
    """
    def foo(x, y):
        # self-documenting expressions: ``{expr=}`` renders the expression text, '=', then the value
        assert f"{x + y=} , {x - y=}" == "x + y=30 , x - y=-10"
        # simple variable reference
        assert f"x={x}, y={y}" == "x=10, y=20"
        # arithmetic in f-string
        assert f"sum={x + y}, diff={x - y}, prod={x * y}" == "sum=30, diff=-10, prod=200"
        # format specifiers (zero-padded width 5)
        assert f"x={x:05d}, y={y:05d}" == "x=00010, y=00020"
        # modulo and power inside replacement fields
        assert f"mod={x % 3}, pow={x ** 2}" == "mod=1, pow=100"
        # mixed literal and expression parts
        assert f"result: ({x} + {y}) = {x + y}" == "result: (10 + 20) = 30"
        # conversion flags: !r (repr) and !s (str); identical for ints
        assert f"{x!r}, {y!s}" == "10, 20"
    pil.compile(foo, 10, 20)