# coding=utf-8
# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import namedtuple
import heapq

import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
import acl

from meta_gauss_render import gaussian_sort, get_render_schedule

torch.npu.set_device('npu:0')
torch.set_printoptions(sci_mode=False)

option = {}
option['ACL_OP_DEBUG_LEVEL'] = 1
torch.npu.set_option(option)

ExecResults = namedtuple('ExecResults', ['sorted_gs_ids'])
Inputs = namedtuple(
    'Inputs', ['lb_scheds', 'tile_sums', 'tile_depths', 'tile_gaussian_ids', 'sorted_offset', 'max_tile_gauss']
)


def _gaussian_sort(lb_scheds, tile_sums, tile_depths, tile_gauss_ids, sorted_offset):
    B, C, T = tile_sums.shape
    sorted_gs_ids = torch.zeros(int(sorted_offset[-1].item()), dtype=torch.int64)
    vector_num = acl.get_device_capability(0, 1)[0]
    num_bins = min(vector_num, T)
    for b in range(B):
        for c in range(C):
            base_id = b * C + c
            base_offset = int(sorted_offset[base_id - 1].item()) if base_id > 0 else 0
            scheduled_tile_ids = lb_scheds[b, c, num_bins:num_bins + T]
            tile_offsets = lb_scheds[b, c, num_bins + T:]
            for sched_idx in range(T):
                tile_id = int(scheduled_tile_ids[sched_idx].item())
                sorted_tile_offset = int(tile_offsets[tile_id - 1].item()) if tile_id > 0 else 0
                tile_gauss_num = int(tile_sums[b, c, tile_id].item())
                if tile_gauss_num == 0:
                    continue
                depths_t = tile_depths[b, c, tile_id, :tile_gauss_num]
                gauss_ids_t = tile_gauss_ids[b, c, tile_id, :tile_gauss_num]
                sort_idx = torch.argsort(depths_t, stable=True)
                write_start = base_offset + sorted_tile_offset
                write_end = write_start + tile_gauss_num
                sorted_gs_ids[write_start:write_end] = gauss_ids_t[sort_idx]
    return sorted_gs_ids


class TestGaussianSort(TestCase):
    def setUp(self):
        self.batch_size = 2
        self.camera_num = 3
        self.test_cases = [[2, 8], [2, 123], [10, 1234], [112, 23456], [222, 234567]]
        self.test_results = self.gen_results()

    def gen_inputs(self, tile_num, gaussian_num):
        B, C, T, G = self.batch_size, self.camera_num, tile_num, gaussian_num
        mask = torch.rand(B, C, T, G) > 0.5
        tile_sums = mask.sum(dim=-1).to(torch.int64)
        tile_offsets = tile_sums.cumsum(dim=-1)
        sorted_cnts = tile_offsets[:, :, -1].flatten()
        sorted_offset = torch.cumsum(sorted_cnts, dim=0)

        gaussian_ids = torch.arange(G).view(1, 1, 1, G).expand(B, C, T, G)
        depths = torch.rand(B, C, T, G)

        tile_gaussian_ids = torch.zeros(B, C, T, G)
        tile_depths = torch.zeros(B, C, T, G)

        for b in range(B):
            for c in range(C):
                for t in range(T):
                    mask_t = mask[b, c, t]
                    if mask_t.any():
                        tile_gaussian_ids[b, c, t, : mask_t.sum()] = gaussian_ids[b, c, t, mask_t]
                        tile_depths[b, c, t, : mask_t.sum()] = depths[b, c, t, mask_t]
        vector_num = acl.get_device_capability(0, 1)[0]
        num_bins = min(vector_num, T)
        lb_scheds = get_render_schedule(tile_sums, num_bins)
        max_tile_gauss = tile_sums.max().item()

        return Inputs(lb_scheds, tile_sums, tile_depths, tile_gaussian_ids, sorted_offset, max_tile_gauss)

    def cpu_to_exec(self, inputs):
        sorted_gs_ids = _gaussian_sort(
            inputs.lb_scheds, inputs.tile_sums, inputs.tile_depths, inputs.tile_gaussian_ids, inputs.sorted_offset
        )
        return ExecResults(sorted_gs_ids=sorted_gs_ids.detach().to(torch.int64))

    def npu_to_exec(self, inputs):
        sorted_gs_ids = gaussian_sort(
            inputs.lb_scheds.npu(),
            inputs.tile_sums.npu().to(torch.int32),
            inputs.tile_depths.npu(),
            inputs.tile_gaussian_ids.npu(),
            inputs.sorted_offset.npu(),
            int(inputs.max_tile_gauss),
        )
        return ExecResults(sorted_gs_ids=sorted_gs_ids.detach().to(torch.int64))

    def gen_results(self):
        test_results = []
        for tile_num, gaussian_num in self.test_cases:
            inputs = self.gen_inputs(tile_num, gaussian_num)
            cpu_results = self.cpu_to_exec(inputs)
            npu_results = self.npu_to_exec(inputs)
            test_results.append((cpu_results, npu_results))
        return test_results

    def test_gaussian_sort(self):
        for cpu_results, npu_results in self.test_results:
            self.assertRtolEqual(cpu_results.sorted_gs_ids.numpy(), npu_results.sorted_gs_ids.cpu().numpy())


if __name__ == "__main__":
    run_tests()