#!/usr/bin/env python3

# coding: utf-8

# Copyright (c) 2026 Huawei Technologies Co., Ltd.

# This program is free software, you can redistribute it and/or modify it under the terms and conditions of

# CANN Open Software License Agreement Version 2.0 (the "License").

# Please refer to the License for details. You may not use this file except in compliance with the License.

# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,

# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.

# See LICENSE in the root of the software repository for the full text of the License.

# -----------------------------------------------------------------------------------------------------------

"""

"""

import os

import math

import pypto

import pytest

from numpy.testing import assert_allclose

import torch

import torch_npu





def test_pow_onboard():

    device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))

    torch.npu.set_device(device_id)

    shape = (64, 64)

    view_shape = (32, 32)

    tile_shape = (32, 32)

    pypto.runtime._device_init()



    input1 = pypto.tensor(shape, pypto.DT_FP32, "PTO_TENSOR_input1")

    input2 = pypto.tensor(shape, pypto.DT_FP32, "PTO_TENSOR_input2")

    output = pypto.tensor(shape, pypto.DT_FP32, "PTO_TENSOR_output")



    b_loop_num = math.ceil(shape[0] / view_shape[0])

    s_loop_num = math.ceil(shape[1] / view_shape[1])

    with pypto.function("MAIN", input1, input2, output):

        for b_idx in pypto.loop(b_loop_num, name="b0", idx_name="bidx"):

            for s_idx in pypto.loop(s_loop_num, name="s0", idx_name="sidx"):

                view_tensor_a = pypto.view(

                    input1, view_shape, [b_idx * view_shape[0], s_idx * view_shape[1]],

                    valid_shape=[

                        pypto.min(pypto.symbolic_scalar(shape[0]) - b_idx * view_shape[0],

                                pypto.symbolic_scalar(view_shape[0])),

                        pypto.min(pypto.symbolic_scalar(shape[1]) - s_idx * view_shape[1],

                                pypto.symbolic_scalar(view_shape[1])),

                    ]

                )

                view_tensor_b = pypto.view(

                    input2, view_shape, [b_idx * view_shape[0], s_idx * view_shape[1]],

                    valid_shape=[

                        pypto.min(pypto.symbolic_scalar(shape[0]) - b_idx * view_shape[0],

                                pypto.symbolic_scalar(view_shape[0])),

                        pypto.min(pypto.symbolic_scalar(shape[1]) - s_idx * view_shape[1],

                                pypto.symbolic_scalar(view_shape[1])),

                    ]

                )

                pypto.set_vec_tile_shapes(tile_shape[0], tile_shape[1])



                result = pypto.pow(view_tensor_a, view_tensor_b)

                pypto.assemble(

                    result, [b_idx * view_shape[0], s_idx * view_shape[1]], output)



    assert isinstance(output, pypto.tensor)



    a_tensor = torch.randint(low=-10, high=10, size=[shape[0], shape[1]], dtype=torch.float32)

    b_tensor = torch.randint(low=-4, high=4, size=[shape[0], shape[1]], dtype=torch.float32)

    c_tensor = torch.zeros(shape[0], shape[1], dtype=torch.float32)



    pto_a_tensor = pypto.from_torch(a_tensor, "a_tensor")

    pto_b_tensor = pypto.from_torch(b_tensor, "b_tensor")

    pto_c_tensor = pypto.from_torch(c_tensor, "c_tensor")



    pypto.runtime._device_run_once_data_from_host(pto_a_tensor, pto_b_tensor, pto_c_tensor)



    golden = torch.pow(a_tensor, b_tensor)

    assert_allclose(c_tensor.flatten(), golden.flatten(), rtol=3e-3, atol=3e-3)

    pypto.runtime._device_fini()