#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/utils/OpUtils.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
namespace {
// Sentinel stored in StridedSliceDenseSpec::final_shape_gather_indices for a new axis:
// an output dimension of size 1 with no input counterpart.
static constexpr int32_t K_STRIDED_SLICE_NEW_AXIS{-2};
// Sentinel stored in final_shape_gather_indices for a dimension removed by shrink_axis_mask.
static constexpr int32_t K_SHRINK_AXIS{-1};
// Value produced by begin/end canonicalization when the index itself is not known.
static constexpr int64_t INVALID_IDX{-3};
// Host-side description of one npu_indexing (strided slice) call in begin/end/strides +
// bit-mask form. infer_shape_internal rewrites begin/end/strides/input_shape in place while
// inferring the output shape.
// Every scalar member has a default initializer so a default-constructed instance (as created
// by npu_indexing_output_size before its fields are assigned) never holds indeterminate values.
struct StridedSliceParams {
    std::vector<int64_t> input_shape;
    std::vector<int64_t> begin;
    std::vector<int64_t> end;
    std::vector<int64_t> strides;
    // Bit i of each mask refers to entry i of the caller-supplied begin/end/strides.
    uint64_t begin_mask = 0;
    uint64_t end_mask = 0;
    uint64_t ellipsis_mask = 0;
    uint64_t new_axis_mask = 0;
    uint64_t shrink_axis_mask = 0;
    // Whether begin/end/strides were supplied (non-empty).
    bool begin_valid = false;
    bool end_valid = false;
    bool stride_valid = false;
    // When false, non-masked begin/end values are treated as unknown (canonicalized to INVALID_IDX).
    bool real_begin_valid = true;
    bool real_end_valid = true;
    // NOTE(review): not referenced by the code visible in this file.
    std::vector<int64_t> dy_shape;
    // One-line debug dump of every field; booleans are rendered as 0/1.
    [[nodiscard]] std::string to_string() const
    {
        std::string result = "input_shape:" + op_plugin::utils::get_vector_str(input_shape);
        result += " begin:" + op_plugin::utils::get_vector_str(begin);
        result += " end:" + op_plugin::utils::get_vector_str(end);
        result += " strides:" + op_plugin::utils::get_vector_str(strides);
        result += " begin_mask:" + std::to_string(begin_mask);
        result += " end_mask:" + std::to_string(end_mask);
        result += " ellipsis_mask:" + std::to_string(ellipsis_mask);
        result += " new_axis_mask:" + std::to_string(new_axis_mask);
        result += " shrink_axis_mask:" + std::to_string(shrink_axis_mask);
        result += " begin_valid:" + std::to_string(static_cast<int32_t>(begin_valid));
        result += " end_valid:" + std::to_string(static_cast<int32_t>(end_valid));
        result += " stride_valid:" + std::to_string(static_cast<int32_t>(stride_valid));
        result += " real_begin_valid:" + std::to_string(static_cast<int32_t>(real_begin_valid));
        result += " real_end_valid:" + std::to_string(static_cast<int32_t>(real_end_valid));
        return result;
    }
};
// Per-input-dimension results of build_processing_data: the number of selected elements
// (processing_shape) and the begin/end/stride actually used for each dimension.
struct ProcessingData {
    std::vector<int64_t> processing_shape;
    std::vector<int64_t> processing_begin;
    std::vector<int64_t> processing_end;
    std::vector<int64_t> processing_strides;
    // Renders the four vectors on one line for debug logging.
    [[nodiscard]] std::string to_string() const
    {
        std::string text("processing_shape:");
        text.append(op_plugin::utils::get_vector_str(processing_shape));
        text.append(" processing_begin:");
        text.append(op_plugin::utils::get_vector_str(processing_begin));
        text.append(" processing_end:");
        text.append(op_plugin::utils::get_vector_str(processing_end));
        text.append(" processing_strides:");
        text.append(op_plugin::utils::get_vector_str(processing_strides));
        return text;
    }
};
// Per-dimension values handed from build_processing_data to build_processing_shape.
// Member order is load-bearing: instances are brace-initialized positionally as
// {begin_i, end_i, stride_i, dim_i, shrink_i}.
struct InputParamUnit {
// Begin index; canonicalized by build_processing_data only when dense begin/end are valid.
int64_t begin;
// End index; canonicalized under the same condition as begin.
int64_t end;
// Stride for this dimension (checked to be non-zero before use).
int64_t stride;
// Size of the corresponding input dimension.
int64_t dim;
// True when this dimension's bit is set in the dense shrink_axis_mask.
bool shrink;
};
// Slice spec in the caller's ("sparse") form: one entry per element of begin/end/strides,
// where an entry may be a real index, an ellipsis or a new axis.
// Aggregate-initialized positionally in infer_shape_internal, so member order is load-bearing.
struct StridedSliceSparseSpec {
// Number of sparse entries; build_sparse_spec adds one for an implicit trailing ellipsis.
int64_t dims;
// Number of new-axis entries that appear after the ellipsis (counted by build_sparse_spec).
int32_t num_add_axis_after_ellipsis;
// Copies of the caller-supplied begin/end/strides.
const std::vector<int64_t> begin;
const std::vector<int64_t> end;
const std::vector<int64_t> strides;
const uint64_t begin_mask;
const uint64_t end_mask;
// Non-const: build_sparse_spec may set the bit of an implicit trailing ellipsis.
uint64_t ellipsis_mask;
const uint64_t new_axis_mask;
const uint64_t shrink_axis_mask;
};
// Slice spec in "dense" form: exactly one begin/end/stride entry per input dimension, with
// masks re-indexed by input dimension. final_shape_gather_indices lists, in output order,
// the dense dimension feeding each output dimension (or a new-axis / shrink sentinel).
// Aggregate-initialized positionally in infer_shape_internal, so member order is load-bearing.
struct StridedSliceDenseSpec {
    const int64_t dims;
    uint64_t begin_mask;
    uint64_t end_mask;
    bool begin_valid;
    bool end_valid;
    std::vector<int64_t> begin;
    std::vector<int64_t> end;
    std::vector<int64_t> strides;
    std::vector<int64_t> final_shape_gather_indices;
    uint64_t shrink_axis_mask;
    // One-line debug dump of every field; booleans are rendered as 0/1.
    [[nodiscard]] std::string to_string() const
    {
        const auto flag_text = [](bool flag) { return std::to_string(static_cast<int32_t>(flag)); };
        const auto vec_text = [](const std::vector<int64_t>& values) {
            return op_plugin::utils::get_vector_str(values);
        };
        std::string text = "dims:" + std::to_string(dims);
        text += " begin_mask:" + std::to_string(begin_mask);
        text += " end_mask:" + std::to_string(end_mask);
        text += " begin_valid:" + flag_text(begin_valid);
        text += " end_valid:" + flag_text(end_valid);
        text += " begin:" + vec_text(begin);
        text += " end:" + vec_text(end);
        text += " strides:" + vec_text(strides);
        text += " final_shape_gather_indices:" + vec_text(final_shape_gather_indices);
        text += " shrink_axis_mask:" + std::to_string(shrink_axis_mask);
        return text;
    }
};
// Returns a 64-bit mask with only bit `i` set.
// Positions outside [0, 64) cannot be represented in a 64-bit mask, and shifting by them is
// undefined behaviour. For such positions an empty mask (0) is returned instead, so
// `mask & bit_1_value(i)` simply reports "not set".
static inline uint64_t bit_1_value(int i)
{
    constexpr int k_mask_bits = 64;
    if (i < 0 || i >= k_mask_bits) {
        return 0;
    }
    return static_cast<uint64_t>(1) << static_cast<uint64_t>(i);
}
// Maps a possibly negative index onto [0, dim) Python-style (-1 is the last element).
// No clamping is done: indices below -dim stay negative.
static inline int64_t normalize_index(int64_t x, int64_t dim)
{
    if (x >= 0) {
        return x;
    }
    return dim + x;
}
// True when `fwd` lies outside the half-open interval [lower, upper).
static inline bool fwd_out_of_bound(int64_t fwd, int64_t lower, int64_t upper)
{
    const bool inside = lower <= fwd && fwd < upper;
    return !inside;
}
// Completes the sparse spec from the caller's parameters:
// - sets dims to the number of begin/end/strides entries;
// - counts the new-axis entries that follow the ellipsis (needed to size the ellipsis
//   expansion in build_dense_spec);
// - when no ellipsis was given, appends an implicit one after the last entry, which
//   build_dense_spec expands to cover all remaining input dimensions.
static void build_sparse_spec(const StridedSliceParams& params, StridedSliceSparseSpec& sparse_spec)
{
    const int64_t entry_count = static_cast<int64_t>(params.strides.size());
    sparse_spec.dims = entry_count;
    bool found_ellipsis = false;
    for (int32_t pos = 0; pos < entry_count; ++pos) {
        const uint64_t pos_bit = bit_1_value(pos);
        const bool is_new_axis = (params.new_axis_mask & pos_bit) != 0;
        // Test before updating found_ellipsis: an entry only counts if the ellipsis came earlier.
        if (found_ellipsis && is_new_axis) {
            ++sparse_spec.num_add_axis_after_ellipsis;
        }
        found_ellipsis = found_ellipsis || ((params.ellipsis_mask & pos_bit) != 0);
    }
    if (found_ellipsis) {
        return;
    }
    sparse_spec.ellipsis_mask |= bit_1_value(entry_count);
    ++sparse_spec.dims;
}
// Expands the sparse slice spec (one entry per user-supplied index, including ellipsis and
// new-axis entries) into a dense spec with exactly one begin/end/stride entry per input
// dimension (dense.dims entries). Masks are re-indexed from sparse position i to dense
// position full_index.
// - The ellipsis entry expands into full-range dimensions (begin = end = 0 with both mask
//   bits set, stride 1). It covers as many dimensions as needed for the real entries
//   after it to land on the trailing input dimensions.
// - New-axis entries consume no input dimension. They are recorded only in
//   final_shape_gather_indices as K_STRIDED_SLICE_NEW_AXIS.
// - Every other entry copies its begin/end/stride. A shrink entry is recorded as
//   K_SHRINK_AXIS in final_shape_gather_indices and sets its dense shrink_axis_mask bit.
// final_shape_gather_indices thus lists, in output order, the source of each output dimension.
// Throws (TORCH_CHECK_INDEX) if the spec addresses more dimensions than the input has.
static void build_dense_spec(const StridedSliceSparseSpec& sparse, StridedSliceDenseSpec& dense)
{
dense.begin.resize(dense.dims);
dense.end.resize(dense.dims);
dense.strides.resize(dense.dims);
dense.begin_mask = 0;
dense.end_mask = 0;
dense.shrink_axis_mask = 0;
// Next dense (input) dimension to be filled.
int full_index = 0;
for (int i = 0; i < sparse.dims; i++) {
const uint64_t bit_i = bit_1_value(i);
if ((bit_i & sparse.ellipsis_mask) != 0) {
// Sparse entries after the ellipsis number (sparse.dims - i - 1), of which
// num_add_axis_after_ellipsis are new axes. The ellipsis therefore stops at
// dense.dims minus the count of real entries that follow it.
int32_t next_index =
std::min(dense.dims - (sparse.dims - i) + 1 + sparse.num_add_axis_after_ellipsis, dense.dims);
for (; full_index < next_index; full_index++) {
dense.begin[full_index] = dense.end[full_index] = 0;
dense.strides[full_index] = 1;
dense.begin_mask |= bit_1_value(full_index);
dense.end_mask |= bit_1_value(full_index);
dense.final_shape_gather_indices.push_back(full_index);
}
} else if ((bit_i & sparse.new_axis_mask) != 0) {
// New axes do not advance full_index: they have no input dimension.
dense.final_shape_gather_indices.push_back(K_STRIDED_SLICE_NEW_AXIS);
} else {
TORCH_CHECK_INDEX(
static_cast<size_t>(full_index) < dense.begin.size(), "Index out of range using input dim ", full_index,
"; input has only ", dense.dims, " dims.");
dense.begin[full_index] = sparse.begin[i];
dense.end[full_index] = sparse.end[i];
dense.strides[full_index] = sparse.strides[i];
if ((sparse.begin_mask & bit_i) != 0) {
dense.begin_mask |= bit_1_value(full_index);
}
if ((sparse.end_mask & bit_i) != 0) {
dense.end_mask |= bit_1_value(full_index);
}
if ((sparse.shrink_axis_mask & bit_i) != 0) {
dense.final_shape_gather_indices.push_back(K_SHRINK_AXIS);
dense.shrink_axis_mask |= bit_1_value(full_index);
} else {
dense.final_shape_gather_indices.push_back(full_index);
}
full_index++;
}
}
}
// Appends to processing_shape the number of elements selected in one input dimension,
// or -1 when that count cannot be determined.
// The interval length (end - begin) is known when:
// - dense begin/end are valid (taken from the canonicalized begin/end);
// - the dimension is shrunk (always length 1);
// - both begin and end are masked and the dimension size is known (whole dimension,
//   signed like the stride).
// An empty interval, or one whose direction disagrees with the stride's sign, selects 0
// elements. Otherwise the count is ceil(length / stride).
static void build_processing_shape(
    const StridedSliceDenseSpec& dense_spec, const InputParamUnit& input_param_unit, bool begin_and_end_masked,
    std::vector<int64_t>& processing_shape)
{
    const int64_t stride = input_param_unit.stride;
    int64_t span = 0;
    bool span_known = true;
    if (dense_spec.begin_valid && dense_spec.end_valid) {
        span = input_param_unit.end - input_param_unit.begin;
    } else if (input_param_unit.shrink) {
        span = 1;
    } else if (begin_and_end_masked && input_param_unit.dim >= 0) {
        span = (stride < 0) ? -input_param_unit.dim : input_param_unit.dim;
    } else {
        span_known = false;
    }

    if (!span_known) {
        processing_shape.push_back(-1);
        return;
    }
    const bool degenerate = (span == 0) || ((span < 0) != (stride < 0));
    if (degenerate) {
        processing_shape.push_back(0);
        return;
    }
    // Span and stride share a sign here, so truncating division plus one for any
    // remainder is a ceiling division.
    const int64_t remainder_step = (span % stride != 0) ? 1 : 0;
    processing_shape.push_back(span / stride + remainder_step);
}
// Canonicalizes the dense begin/end of every input dimension (written back into params)
// and records per-dimension processing data for build_final_shape.
// For each input dimension i:
// - the stride must be non-zero, and a shrink dimension must have a positive stride;
// - a dimension of size -1 is passed through without canonicalization;
// - when dense begin/end are valid:
//   - a shrink index is normalized (negative counts from the end), bounds-checked and
//     turned into the one-element range [begin, begin + 1);
//   - other indices are clamped by `canonical` to the range reachable with the stride's
//     sign, and masked ends take the full range.
// - processing_shape receives the selected element count (see build_processing_shape).
// Throws (TORCH_CHECK_VALUE / TORCH_CHECK_INDEX) on a zero stride, a non-positive shrink
// stride, an out-of-range shrink index, or an unresolvable index.
static void build_processing_data(
    const StridedSliceDenseSpec& dense_spec, StridedSliceParams& params, ProcessingData& processing_data)
{
    for (int i = 0; i < static_cast<int>(params.input_shape.size()); ++i) {
        // References: canonicalized values are written back into params.
        auto& begin_i = params.begin[i];
        auto& end_i = params.end[i];
        auto& stride_i = params.strides[i];
        auto dim_i = params.input_shape[i];
        TORCH_CHECK_VALUE(stride_i != 0, "strides[", i, "] must be non-zero");
        const uint64_t bit_i = bit_1_value(i);
        const bool shrink_i = (dense_spec.shrink_axis_mask & bit_i) != 0;
        const std::array<uint64_t, 2> masks = {{dense_spec.begin_mask & bit_i, dense_spec.end_mask & bit_i}};
        if (dim_i == -1) {
            // Unknown dimension size: nothing to canonicalize against.
            processing_data.processing_shape.push_back(shrink_i ? 1 : -1);
            processing_data.processing_begin.push_back(begin_i);
            processing_data.processing_end.push_back(shrink_i ? (begin_i + 1) : end_i);
            processing_data.processing_strides.push_back(shrink_i ? 1 : stride_i);
            continue;
        }
        const std::array<bool, 2> real_valid = {params.real_begin_valid, params.real_end_valid};
        // Inclusive clamp range: [0, dim] for positive strides, [-1, dim - 1] for negative ones.
        const std::array<int64_t, 2> valid_range = {{stride_i > 0 ? 0 : -1, stride_i > 0 ? dim_i : dim_i - 1}};
        // c == 0 canonicalizes a begin index, c == 1 an end index.
        auto canonical = [stride_i, dim_i, masks, valid_range, real_valid](int64_t x, int c) {
            if (masks[c]) {
                // Masked: take the full range in the stride's direction.
                return stride_i > 0 ? valid_range[c] :
                                      valid_range[static_cast<uint64_t>(c + 1) & static_cast<uint64_t>(1)];
            } else {
                if (!real_valid[c]) {
                    return INVALID_IDX;
                }
                int64_t x_fwd = normalize_index(x, dim_i);
                return x_fwd < valid_range[0] ? valid_range[0] : std::min(x_fwd, valid_range[1]);
            }
        };
        TORCH_CHECK_VALUE(!(shrink_i && stride_i <= 0), "only stride 1 allowed on non-range indexing.");
        const bool begin_and_end_masked =
            ((dense_spec.begin_mask & bit_i) != 0) && ((dense_spec.end_mask & bit_i) != 0);
        if (dense_spec.begin_valid && dense_spec.end_valid) {
            if (shrink_i) {
                if (real_valid[0]) {
                    int64_t x_fwd = normalize_index(begin_i, dim_i);
                    begin_i = x_fwd;
                    end_i = begin_i + 1;
                    TORCH_CHECK_INDEX(
                        !fwd_out_of_bound(x_fwd, 0, dim_i), "slice index ", begin_i, " of dimension ", i,
                        " out of bounds.");
                } else {
                    begin_i = -2;
                    end_i = begin_i + 1;
                }
            } else {
                begin_i = canonical(begin_i, 0);
                end_i = canonical(end_i, 1);
            }
            TORCH_CHECK_VALUE(
                !((!real_valid[0] || !real_valid[1]) && (begin_i == INVALID_IDX || end_i == INVALID_IDX)),
                "begin_i:", begin_i, " end_i:", end_i, " is invalid while unconst begin or end, shrink_i:", shrink_i,
                " masks:", masks[0], masks[1]);
            processing_data.processing_begin.push_back(begin_i);
            processing_data.processing_end.push_back(end_i);
            processing_data.processing_strides.push_back(stride_i);
        } else {
            processing_data.processing_begin.push_back(begin_i);
            processing_data.processing_end.push_back(end_i);
            processing_data.processing_strides.push_back(1);
        }
        InputParamUnit input_param_unit = {begin_i, end_i, stride_i, dim_i, shrink_i};
        build_processing_shape(dense_spec, input_param_unit, begin_and_end_masked, processing_data.processing_shape);
    }
}
// Builds the output shape from the per-dimension processing data and rewrites params so that
// input_shape/begin/end/strides describe the slice in output-dimension order:
// - a gathered dense dimension contributes its processing_shape entry to the output;
// - a new axis contributes an output dimension of size 1 and adds a slice entry only when
//   the input is 0-d;
// - a shrunk dimension contributes no output dimension but keeps a one-element slice entry
//   [begin, begin + 1) with stride 1.
static std::vector<int64_t> build_final_shape(
    const ProcessingData& processing_data, const StridedSliceDenseSpec& dense_spec, StridedSliceParams& params)
{
    params.begin.clear();
    params.end.clear();
    params.strides.clear();
    auto append_slice = [&params](int64_t first, int64_t last, int64_t step) {
        params.begin.push_back(first);
        params.end.push_back(last);
        params.strides.push_back(step);
    };

    std::vector<int64_t> out_shape;
    std::vector<int64_t> final_shape_input;
    // Dense dimension consumed by the next shrink sentinel: one past the last gathered one.
    int next_dense_dim = 0;
    for (const auto gather_index : dense_spec.final_shape_gather_indices) {
        if (gather_index == K_STRIDED_SLICE_NEW_AXIS) {
            out_shape.push_back(1);
            if (params.input_shape.empty()) {
                final_shape_input.push_back(1);
                append_slice(0, 1, 1);
            }
            continue;
        }
        if (gather_index < 0) {
            // Shrunk dimension.
            const int64_t start = processing_data.processing_begin[next_dense_dim];
            final_shape_input.push_back(params.input_shape[next_dense_dim]);
            append_slice(start, start + 1, 1);
            ++next_dense_dim;
            continue;
        }
        out_shape.push_back(processing_data.processing_shape[gather_index]);
        final_shape_input.push_back(params.input_shape[gather_index]);
        append_slice(
            processing_data.processing_begin[gather_index], processing_data.processing_end[gather_index],
            processing_data.processing_strides[gather_index]);
        next_dense_dim = gather_index + 1;
    }
    params.input_shape = final_shape_input;
    return out_shape;
}
// Runs strided-slice shape inference for `params` and returns the output sizes.
// Pipeline: validate -> sparse spec (implicit ellipsis) -> dense spec (one entry per input
// dimension) -> per-dimension processing data -> final shape.
// `params` is modified in place: afterwards begin/end/strides/input_shape describe the slice
// in output-dimension order.
// Throws (TORCH_CHECK_VALUE) on mismatched begin/end/strides lengths, more than one ellipsis,
// or an output rank above SIZE.
static inline c10::SmallVector<int64_t, SIZE> infer_shape_internal(StridedSliceParams& params)
{
    ASCEND_LOGD("input params: %s", params.to_string().c_str());

    const size_t begin_len = params.begin.size();
    const size_t end_len = params.end.size();
    const size_t strides_len = params.strides.size();
    TORCH_CHECK_VALUE(
        begin_len == end_len && end_len == strides_len,
        "Expected begin, end, and strides to be 1D equal size tensors, but got shapes [", begin_len, "], [", end_len,
        "], [", strides_len, "] instead.");

    // x & (x - 1) clears the lowest set bit, so it is zero iff at most one bit is set.
    const uint64_t ellipsis_bits = params.ellipsis_mask;
    TORCH_CHECK_VALUE((ellipsis_bits & (ellipsis_bits - 1)) == 0, "Multiple ellipses in slice spec not allowed.");

    StridedSliceSparseSpec sparse_spec{
        0,  // dims, filled by build_sparse_spec
        0,  // num_add_axis_after_ellipsis
        params.begin,
        params.end,
        params.strides,
        params.begin_mask,
        params.end_mask,
        params.ellipsis_mask,
        params.new_axis_mask,
        params.shrink_axis_mask};
    build_sparse_spec(params, sparse_spec);

    StridedSliceDenseSpec dense_spec{
        static_cast<int64_t>(params.input_shape.size()),  // dims: one per input dimension
        0,                                                // begin_mask
        0,                                                // end_mask
        params.begin_valid,
        params.end_valid,
        params.begin,
        params.end,
        params.strides,
        {},  // final_shape_gather_indices
        0};  // shrink_axis_mask
    build_dense_spec(sparse_spec, dense_spec);
    ASCEND_LOGD("dense spec: %s", dense_spec.to_string().c_str());

    // From here on params holds the dense (per-input-dimension) slice description.
    params.begin = dense_spec.begin;
    params.end = dense_spec.end;
    params.strides = dense_spec.strides;

    ProcessingData per_dim_data;
    build_processing_data(dense_spec, params, per_dim_data);
    ASCEND_LOGD("processing data: %s", per_dim_data.to_string().c_str());

    auto final_shape = build_final_shape(per_dim_data, dense_spec, params);
    ASCEND_LOGD("after infershape params: %s", params.to_string().c_str());
    ASCEND_LOGI("[npu_indexing] output shape: %s", op_plugin::utils::get_vector_str(final_shape).c_str());
    TORCH_CHECK_VALUE(final_shape.size() <= SIZE, "The output tensor cannot be larger than ", SIZE, " dimensions");
    return c10::SmallVector<int64_t, SIZE>(final_shape);
}
// Computes the output sizes of npu_indexing(self, begin, end, strides, masks...) on the host.
// begin/end/strides are "valid" when non-empty. The signed int64 masks from the op signature
// are reinterpreted as raw 64-bit bit sets.
static c10::SmallVector<int64_t, SIZE> npu_indexing_output_size(
    const at::Tensor& self, at::IntArrayRef begin, at::IntArrayRef end, at::IntArrayRef strides, int64_t begin_mask,
    int64_t end_mask, int64_t ellipsis_mask, int64_t new_axis_mask, int64_t shrink_axis_mask)
{
    const auto as_bits = [](int64_t mask) { return static_cast<uint64_t>(mask); };

    StridedSliceParams slice_params;
    slice_params.input_shape = self.sizes().vec();

    slice_params.begin = begin.vec();
    slice_params.begin_valid = !begin.empty();
    slice_params.end = end.vec();
    slice_params.end_valid = !end.empty();
    slice_params.strides = strides.vec();
    slice_params.stride_valid = !strides.empty();

    slice_params.begin_mask = as_bits(begin_mask);
    slice_params.end_mask = as_bits(end_mask);
    slice_params.ellipsis_mask = as_bits(ellipsis_mask);
    slice_params.new_axis_mask = as_bits(new_axis_mask);
    slice_params.shrink_axis_mask = as_bits(shrink_axis_mask);

    return infer_shape_internal(slice_params);
}
}
// Out-variant of npu_indexing: writes the strided slice of `self` described by
// begin/end/strides and the five bit masks into `out`, then returns `out`.
// Dispatch:
// - on SoCs older than Ascend950 the acl_op implementation is used directly;
// - otherwise DO_COMPATIBILITY guards aclnnStridedSlice (NOTE(review): macro defined
//   elsewhere; presumably falls back to acl_op::npu_indexing_out when the aclnn API is
//   unavailable — confirm).
// On the aclnn path the expected output size is computed on the host
// (npu_indexing_output_size) and passed to npu_preparation::check_tensor together with `out`
// before aclnnStridedSlice is launched.
at::Tensor& npu_indexing_out(
const at::Tensor& self, at::IntArrayRef begin, at::IntArrayRef end, at::IntArrayRef strides, int64_t begin_mask,
int64_t end_mask, int64_t ellipsis_mask, int64_t new_axis_mask, int64_t shrink_axis_mask, at::Tensor& out)
{
if (c10_npu::GetSocVersion() < c10_npu::SocVersion::Ascend950) {
return acl_op::npu_indexing_out(
self, begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask, out);
}
DO_COMPATIBILITY(
aclnnStridedSlice,
acl_op::npu_indexing_out(
self, begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask, out));
// Host-side shape inference; throws on an invalid slice spec.
auto out_size = npu_indexing_output_size(
self, begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask);
npu_preparation::check_tensor({self}, out, out.scalar_type(), out_size);
EXEC_NPU_CMD(
aclnnStridedSlice, self, begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask,
shrink_axis_mask, out);
return out;
}
// Returns a new tensor holding the strided slice of `self` described by begin/end/strides
// and the five bit masks.
// Dispatch:
// - on SoCs older than Ascend950 the acl_op implementation is used directly;
// - otherwise DO_COMPATIBILITY guards aclnnStridedSlice (NOTE(review): macro defined
//   elsewhere; presumably falls back to acl_op::npu_indexing when the aclnn API is
//   unavailable — confirm).
// On the aclnn path the output size is inferred on the host, a tensor of that size is
// created with self's options, and aclnnStridedSlice fills it.
at::Tensor npu_indexing(
const at::Tensor& self, at::IntArrayRef begin, at::IntArrayRef end, at::IntArrayRef strides, int64_t begin_mask,
int64_t end_mask, int64_t ellipsis_mask, int64_t new_axis_mask, int64_t shrink_axis_mask)
{
if (c10_npu::GetSocVersion() < c10_npu::SocVersion::Ascend950) {
return acl_op::npu_indexing(
self, begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask);
}
DO_COMPATIBILITY(
aclnnStridedSlice,
acl_op::npu_indexing(
self, begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask));
// Host-side shape inference; throws on an invalid slice spec.
auto out_size = npu_indexing_output_size(
self, begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask);
at::Tensor out = npu_preparation::apply_tensor_without_format(out_size, self.options());
EXEC_NPU_CMD(
aclnnStridedSlice, self, begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask,
shrink_axis_mask, out);
return out;
}
}