# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0  (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from msprof_analyze.compare_tools.compare_backend.compare_bean.profiling_info import ProfilingInfo


class TestProfilingInfo(unittest.TestCase):
    """Unit tests for the setters, accumulators and derived totals of ProfilingInfo."""

    def test_calculate_schedule_time(self):
        """Scheduling time is what remains of e2e after compute and non-overlapped comm."""
        info = ProfilingInfo("NPU")
        info.e2e_time = 10
        info.compute_time = 5
        info.communication_not_overlapped = 3
        info.calculate_schedule_time()
        # 10 - 5 - 3 == 2
        self.assertEqual(info.scheduling_time, 2)

    def test_update_fa_fwd_info(self):
        """Forward flash-attention totals combine the cube and vector components."""
        info = ProfilingInfo("NPU")
        info.fa_time_fwd_cube = 5
        info.fa_time_fwd_vector = 5
        info.fa_num_fwd_cube = 1
        info.fa_num_fwd_vector = 1
        # The expected value implies the aggregated time is reported scaled by
        # 1/1000 relative to the component fields (presumably us -> ms; confirm
        # against ProfilingInfo). Floats are compared approximately because
        # the result depends on summation/division order.
        self.assertAlmostEqual(info.fa_time_fwd, 0.01)
        self.assertEqual(info.fa_num_fwd, 2)

    def test_update_fa_bwd_info(self):
        """Backward flash-attention totals combine the cube and vector components."""
        info = ProfilingInfo("NPU")
        info.fa_time_bwd_cube = 5
        info.fa_time_bwd_vector = 5
        info.fa_num_bwd_cube = 1
        info.fa_num_bwd_vector = 1
        # Same 1/1000 scaling as the forward case; compare floats approximately.
        self.assertAlmostEqual(info.fa_time_bwd, 0.01)
        self.assertEqual(info.fa_num_bwd, 2)

    def test_update_sdma_info(self):
        """SDMA totals combine the tensor-move and stream components."""
        info = ProfilingInfo("NPU")
        info.sdma_time_tensor_move = 5
        info.sdma_time_stream = 5
        info.sdma_num_tensor_move = 5
        info.sdma_num_stream = 5
        self.assertAlmostEqual(info.sdma_time, 0.01)
        self.assertEqual(info.sdma_num, 10)

    def test_update_cube_info(self):
        """Cube totals combine matmul (cube + vector) and other cube kernels."""
        info = ProfilingInfo("NPU")
        info.matmul_time_cube = 1
        info.matmul_time_vector = 1
        info.other_cube_time = 1
        info.matmul_num_cube = 5
        info.matmul_num_vector = 5
        info.other_cube_num = 5
        # Summing three scaled terms is a classic source of float rounding
        # error, so exact equality would be brittle here.
        self.assertAlmostEqual(info.cube_time, 0.003)
        self.assertEqual(info.cube_num, 15)

    def test_update_vec_info(self):
        """Vector totals combine the trans and non-trans components."""
        info = ProfilingInfo("NPU")
        info.vector_time_trans = 1
        info.vector_time_notrans = 1
        info.vector_num_trans = 2
        info.vector_num_notrans = 2
        self.assertAlmostEqual(info.vec_time, 0.002)
        self.assertEqual(info.vec_num, 4)

    def test_set_compute_time(self):
        """set_compute_time overwrites any previously accumulated value."""
        info = ProfilingInfo("NPU")
        info.update_compute_time(1)
        info.set_compute_time(5)
        self.assertEqual(info.compute_time, 5)

    def test_update_compute_time(self):
        """update_compute_time accumulates across calls."""
        info = ProfilingInfo("NPU")
        info.update_compute_time(5)
        info.update_compute_time(5)
        self.assertEqual(info.compute_time, 10)

    def test_set_e2e_time(self):
        """set_e2e_time stores the given end-to-end time."""
        info = ProfilingInfo("NPU")
        info.set_e2e_time(5)
        self.assertEqual(info.e2e_time, 5)

    def test_set_comm_not_overlap(self):
        """set_comm_not_overlap overwrites any previously accumulated value."""
        info = ProfilingInfo("NPU")
        info.update_comm_not_overlap(10)
        info.set_comm_not_overlap(5)
        self.assertEqual(info.communication_not_overlapped, 5)

    def test_update_comm_not_overlap(self):
        """update_comm_not_overlap accumulates across calls."""
        info = ProfilingInfo("NPU")
        info.update_comm_not_overlap(5)
        info.update_comm_not_overlap(5)
        self.assertEqual(info.communication_not_overlapped, 10)

    def test_set_memory_used(self):
        """set_memory_used stores the given memory figure."""
        info = ProfilingInfo("NPU")
        info.set_memory_used(10)
        self.assertEqual(info.memory_used, 10)

    def test_is_not_minimal_profiling(self):
        """Only a non-minimal NPU profile reports is_not_minimal_profiling() as True."""
        with self.subTest(platform="GPU", minimal_profiling=False):
            info = ProfilingInfo("GPU")
            info.minimal_profiling = False
            self.assertFalse(info.is_not_minimal_profiling())
        info = ProfilingInfo("NPU")
        with self.subTest(platform="NPU", minimal_profiling=True):
            info.minimal_profiling = True
            self.assertFalse(info.is_not_minimal_profiling())
        with self.subTest(platform="NPU", minimal_profiling=False):
            info.minimal_profiling = False
            self.assertTrue(info.is_not_minimal_profiling())