import numpy as np
import torch
from data_cache import golden_data_cache
from torch_npu.testing.testcase import TestCase, run_tests

import mx_driving
from mx_driving.ops.three_interpolate import ThreeInterpolateFunction

class TestThreeinterpolate(TestCase):
    @golden_data_cache(__file__)
    def cpu_op_exec(self, feat, idx, wt):
        bs, cs, ms = feat.shape
        ns = idx.shape[1]
        out = np.zeros((bs, cs, ns)).astype(feat.dtype)
        # forward
        for b in range(bs):
            for c in range(cs):
                for n in range(ns):
                    out[b, c, n] = feat[b, c, idx[b, n, 0]] * wt[b, n, 0] + \
                                    feat[b, c, idx[b, n, 1]] * wt[b, n, 1] + \
                                    feat[b, c, idx[b, n, 2]] * wt[b, n, 2]
        
        grad_out = np.zeros((bs, cs, ms)).astype(feat.dtype)
        grad_out = grad_out.transpose(0, 2, 1)
        for b in range(bs):
            for n in range(ns):
                ind = idx[b, n, :]
                weight = wt[b, n, :]
                grad_out[b, ind[0], :] += weight[0].repeat(cs)
                grad_out[b, ind[1], :] += weight[1].repeat(cs)
                grad_out[b, ind[2], :] += weight[2].repeat(cs)
        grad_out = grad_out.transpose(0, 2, 1)
        return out, grad_out
    
    def npu_op_exec(self, feat, idx, wt, grad_out_dtype=torch.float32):
        
        feat.requires_grad = True
        b, c, m = feat.size()
        out = mx_driving.three_interpolate(feat, idx, wt)

        class MockCtx:
            def __init__(self,saved_tensors):
                self.three_interpolate_for_backward=saved_tensors

        saved_tensors=(idx, wt,m)
        ctx = MockCtx(saved_tensors)
        
        grad_out = torch.ones_like(out, dtype=grad_out_dtype)
        grad_features,_,_=ThreeInterpolateFunction.backward(ctx,grad_out)

        grad_out = grad_features.detach().cpu().numpy()
        out = out.detach().cpu().numpy()
        return out, grad_out

    def test_three_interpolate_with_grad(self):
        bs = [2, 10, 224]
        cs = [3, 20, 45]
        ms = [4, 17, 224]
        ns = [5, 34, 150]

        for i in range(3):
            np.random.seed(i)
            features = np.random.uniform(-1000, 1000, size=(bs[i], cs[i], ms[i])).astype(np.float32)
            indices = np.random.randint(0, ms[i], size=(bs[i], ns[i], 3)).astype(np.int32)
            weights = np.random.uniform(0, 1, size=(bs[i], ns[i], 3)).astype(np.float32)
            
            npu_features = torch.from_numpy(features).to(torch.float32).npu()
            npu_indices = torch.from_numpy(indices).int().npu()
            npu_weights = torch.from_numpy(weights).to(torch.float32).npu()
            cpu_output = self.cpu_op_exec(features, indices, weights)
            npu_output = self.npu_op_exec(npu_features, npu_indices, npu_weights)
            self.assertRtolEqual(cpu_output[0], npu_output[0])
            self.assertRtolEqual(cpu_output[1], npu_output[1])

    def test_three_interpolate_half_dtype(self):
        bs = [2]
        cs = [3]
        ms = [4]
        ns = [5]

        for i in range(1):
            np.random.seed(i)
            features = np.random.uniform(-1000, 1000, size=(bs[i], cs[i], ms[i])).astype(np.float16)
            indices = np.random.randint(0, ms[i], size=(bs[i], ns[i], 3)).astype(np.int32)
            weights = np.random.uniform(0, 1, size=(bs[i], ns[i], 3)).astype(np.float16)
            
            npu_features = torch.from_numpy(features).to(torch.float16).npu()
            npu_indices = torch.from_numpy(indices).int().npu()
            npu_weights = torch.from_numpy(weights).to(torch.float16).npu()
            cpu_output = self.cpu_op_exec(features, indices, weights)
            npu_output = self.npu_op_exec(npu_features, npu_indices, npu_weights,torch.float16)

if __name__ == "__main__":
    run_tests()