from typing import List, Optional, Tuple, Union

import numpy as np
import torch


def scatter_nd(
    indices: torch.Tensor,
    updates: torch.Tensor,
    shape: Union[List[int], Tuple[int, ...], torch.Tensor],
) -> torch.Tensor:
    """PyTorch counterpart of TensorFlow's ``scatter_nd``.

    Builds a zero-filled tensor of size ``shape`` and writes ``updates`` into
    the positions addressed by ``indices``.

    Args:
        indices: Integer tensor whose last dimension (length ``ndim``) holds
            coordinates into the first ``ndim`` dimensions of the output.
            Any leading dimensions are flattened.
        updates: Values to write. Its element count must equal
            ``prod(indices.shape[:-1]) * prod(shape[ndim:])``.
        shape: Size of the output; any sequence of ints (list, tuple,
            ``torch.Size``) or a 1-D integer tensor.

    Returns:
        A new tensor of size ``shape`` with the dtype and device of
        ``updates``.

    Note:
        No validation is performed. Values are written with plain indexed
        assignment, so repeated indices are NOT accumulated the way
        TensorFlow's ``scatter_nd`` does.
    """
    # Normalise to a list of Python ints so the slicing and concatenation
    # below work regardless of the sequence type the caller passed.
    dims = [int(d) for d in shape]
    ret = torch.zeros(dims, dtype=updates.dtype, device=updates.device)
    ndim = indices.shape[-1]
    # reshape (unlike view) also accepts non-contiguous inputs.
    flat_idx = indices.reshape(-1, ndim).long()
    # The indexed assignment target has one leading dimension per flattened
    # index row, so the updates must be flattened the same way.
    flat_updates = updates.reshape([flat_idx.shape[0]] + dims[ndim:])
    # Multi-dimensional indexing needs a tuple; indexing with a list is
    # deprecated in PyTorch.
    index = tuple(flat_idx[:, i] for i in range(ndim)) + (Ellipsis,)
    ret[index] = flat_updates
    return ret


class IndiceData:
    """Plain record bundling the values handed to the constructor.

    Every argument is stored verbatim on the instance under the same name;
    no copying, conversion or validation takes place.
    """

    def __init__(
        self,
        origin_spatial_shape,
        origin_indices,
        unique_indices_offset,
        sorted_idx_to_former_indices,
        outidx_pair,
        is_subm: bool,
    ):
        # Inputs as supplied by the caller.
        self.origin_spatial_shape = origin_spatial_shape
        self.origin_indices = origin_indices

        # Index bookkeeping, kept exactly as received.
        self.unique_indices_offset = unique_indices_offset
        self.sorted_idx_to_former_indices = sorted_idx_to_former_indices
        self.outidx_pair = outidx_pair

        # Flag supplied by the caller; not interpreted here.
        self.is_subm = is_subm


class SparseConvTensor:
    """Sparse tensor stored as per-point features plus integer coordinates.

    ``dense()`` interprets each row of ``indices`` as a coordinate into a
    ``[batch_size, *spatial_shape]`` grid and the second dimension of
    ``features`` as the channel axis.
    """

    def __init__(
        self,
        features: torch.Tensor,
        indices: torch.Tensor,
        spatial_shape: Union[List, Tuple],
        batch_size: int,
        grid: Optional[torch.Tensor] = None,
    ):
        """Store the tensor components.

        Args:
            features: Per-point feature tensor; ``features.shape[1]`` is used
                as the channel count by ``dense()``.
            indices: Per-point integer coordinates. Converted to ``int32``
                when given in any other dtype.
            spatial_shape: Sizes of the spatial dimensions.
            batch_size: Number of samples in the batch.
            grid: Optional tensor stored as-is; not used by this class.
        """
        self.features = features
        # Tensor.int() returns a new tensor instead of converting in place,
        # so the result has to be kept for the conversion to take effect.
        if indices.dtype != torch.int32:
            indices = indices.int()
        self.indices = indices
        self.spatial_shape = spatial_shape
        self.batch_size = batch_size
        # Lookup table read by find_indice_pair(); starts empty.
        self.indice_dict: dict = {}
        self.grid = grid

    @property
    def spatial_size(self):
        """Product of all entries of ``spatial_shape``."""
        return np.prod(self.spatial_shape)

    def find_indice_pair(self, key):
        """Return the ``indice_dict`` entry for ``key``, or ``None``.

        ``None`` is returned both for a ``None`` key and for a key that is
        not present in ``indice_dict``.
        """
        if key is None:
            return None
        return self.indice_dict.get(key)

    def dense(self, channels_first: bool = True) -> torch.Tensor:
        """Materialise the sparse tensor as a dense, zero-filled tensor.

        Args:
            channels_first: When true (default) the channel axis is moved to
                position 1, giving ``[batch, C, *spatial_shape]``. Otherwise
                the result keeps the layout ``[batch, *spatial_shape, C]``.

        Returns:
            A dense tensor holding ``features`` at the positions given by
            ``indices`` and zeros elsewhere.
        """
        output_shape = [self.batch_size] + list(self.spatial_shape) + [self.features.shape[1]]
        res = scatter_nd(self.indices.long(), self.features, output_shape)
        if not channels_first:
            return res
        ndim = len(self.spatial_shape)
        # Axis order [0, ndim + 1, 1, ..., ndim]: batch, channels, spatial.
        trans_params = list(range(0, ndim + 1))
        trans_params.insert(1, ndim + 1)
        return res.permute(*trans_params).contiguous()

    @property
    def sparity(self):
        """Number of index rows divided by ``prod(spatial_shape) * batch_size``.

        NOTE(review): the name looks like a misspelling of "sparsity"; it is
        kept unchanged because it is part of the public interface.
        """
        return self.indices.shape[0] / np.prod(self.spatial_shape) / self.batch_size