#ifndef PULGIN_UTILS_ADVANCED_INDEX
#define PULGIN_UTILS_ADVANCED_INDEX
#include <ATen/native/IndexingUtils.h>
#include <ATen/ExpandUtils.h>
#include "op_plugin/utils/Export.h"
namespace op_plugin {
/**
 * State bundle describing one advanced-indexing operation.
 *
 * Filled in by the out-of-line constructor from a source tensor and a list of
 * index tensors.
 * NOTE(review): the per-field descriptions below are inferred from field names
 * and from the parameter names of the AdvanceIndex helpers (e.g. reshape_indexer's
 * dims_before/dims_after). Confirm them against the constructor's definition.
 */
struct OP_PLUGIN_HIDDEN AdvancedIndex {
    AdvancedIndex(const at::Tensor& src, at::TensorList list_indices);

    // Tensor being indexed (presumably as prepared by the constructor; TODO confirm).
    at::Tensor src;
    // Index tensors used for the operation (assumed one per indexed dim; verify).
    std::vector<at::Tensor> indices;
    // Sizes / strides of the indexed dimensions (assumed from names; verify).
    at::DimVector indexed_sizes;
    at::DimVector indexed_strides;
    // Dimension counts before / after the indexed block (assumed from names).
    // Default member initializers guarantee these never hold indeterminate
    // values, even if a constructor path does not set them explicitly.
    // A constructor's own mem-initializer still overrides these defaults.
    int64_t dims_before = 0;
    int64_t dims_after = 0;
};
class OP_PLUGIN_HIDDEN AdvanceIndex {
public:
static bool all_strides_match(at::TensorList tensor_list);
static at::Tensor reshape_indexer(const at::Tensor& index, int64_t dims_before, int64_t dims_after);
static at::Tensor restride_src(const at::Tensor& src, int64_t before_dims, int64_t dims_indexed,
at::IntArrayRef replacement_shape);
static std::string shapes_as_str(at::TensorList tensors);
static AdvancedIndex make_info(at::Tensor self, const torch::List<c10::optional<at::Tensor>>& orig);
static std::vector<at::Tensor> npu_expand_tensors(
const at::Tensor& self,
const torch::List<c10::optional<at::Tensor>>& indices,
bool needCast,
bool flag_aclnn = false);
static std::vector<at::Tensor> npu_broadcast_tensors(std::vector<at::Tensor> to_broadcast);
static bool is_expandable_to(c10::IntArrayRef shape, c10::IntArrayRef desired);
static bool checkIndexTensorTypes(const torch::List<c10::optional<at::Tensor>> &indices);
};
}
#endif