#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/utils/custom_functions/opapi/FFTCommonOpApi.h"
namespace op_api {
#if VERSION_BETWEEN(V2R1, VERSION_NEWEST)
// Short alias for the NPU tensor-preparation helper; in this file it is used for
// apply_tensor_without_format allocations of outputs and scratch buffers.
using npu_preparation = at_npu::native::OpPreparation;
// Normalization codes understood by _fft_normalization_scale:
//   none      -> scale 1 (no normalization)
//   by_root_n -> scale 1 / sqrt(number of transformed elements)
//   by_n      -> scale 1 / number of transformed elements
enum class fft_norm_mode {
    none = 0,
    by_root_n = 1,
    by_n = 2,
};

// Transform kinds. The entry points in this file pass them as integer mode codes:
// 0 = complex-to-complex, 1 = real-to-complex, 2 = complex-to-real.
enum class fft_mode {
    c2c = 0,
    r2c = 1,
    c2r = 2,
};
double _fft_normalization_scale(int64_t normalization, at::IntArrayRef sizes, at::IntArrayRef dims)
{
auto norm = static_cast<fft_norm_mode>(normalization);
if (norm == fft_norm_mode::none) {
return 1.0;
}
int64_t signal_numel = 1;
for (auto dim : dims) {
signal_numel *= sizes[dim];
}
const double scale_denom = (norm == fft_norm_mode::by_root_n) ?
std::sqrt(signal_numel) : static_cast<double>(signal_numel);
return 1.0 / scale_denom;
}
// Scales `self` in place by _fft_normalization_scale(normalization, sizes, dims).
// When the scale is exactly 1.0 the tensor is left untouched. Returns `self`.
const at::Tensor &_fft_apply_normalization(const at::Tensor &self, int64_t normalization,
    at::IntArrayRef sizes, at::IntArrayRef dims)
{
    const double factor = _fft_normalization_scale(normalization, sizes, dims);
    if (factor != 1.0) {
        self.mul_(factor);
    }
    return self;
}
// Reinterprets a complex tensor, in place, as its real value type with one extra
// trailing dim of size 2 (real, imag). No data is moved. Existing strides and the
// storage offset are doubled, and the new trailing dim gets stride 1.
// Only this tensor's TensorImpl metadata is changed; the storage is shared.
static void HackComplexintoFloat(at::Tensor& self)
{
    const auto complex_sizes = self.sym_sizes();
    const auto complex_strides = self.sym_strides();
    const size_t rank = complex_sizes.size();
    at::SymDimVector real_sizes(rank + 1);
    at::SymDimVector real_strides(rank + 1);
    for (size_t d = 0; d < rank; ++d) {
        real_sizes[d] = complex_sizes[d];
        real_strides[d] = complex_strides[d] * 2;
    }
    real_sizes[rank] = 2;
    real_strides[rank] = 1;
    auto *impl = self.unsafeGetTensorImpl();
    impl->set_storage_and_dtype(self.storage(), c10::scalarTypeToTypeMeta(c10::toRealValueType(self.scalar_type())));
    impl->set_sizes_and_strides(real_sizes, real_strides, self.sym_storage_offset() * 2);
}
// In-place inverse of HackComplexintoFloat. It reinterprets a real tensor as the
// matching complex type and drops the trailing dim. The remaining strides and the
// storage offset are halved. No data is moved.
// Expects (not checked here) a trailing dim holding (real, imag) pairs with stride 1.
static void HackFloatintoComplex(at::Tensor& self)
{
    const auto real_sizes = self.sym_sizes();
    const auto real_strides = self.sym_strides();
    const size_t complex_rank = real_sizes.size() - 1;
    at::SymDimVector complex_sizes(real_sizes.begin(), real_sizes.begin() + complex_rank);
    at::SymDimVector complex_strides(complex_rank);
    for (size_t d = 0; d < complex_rank; ++d) {
        complex_strides[d] = real_strides[d] / 2;
    }
    auto *impl = self.unsafeGetTensorImpl();
    impl->set_storage_and_dtype(self.storage(), c10::scalarTypeToTypeMeta(c10::toComplexType(self.scalar_type())));
    impl->set_sizes_and_strides(complex_sizes, complex_strides, self.sym_storage_offset() / 2);
}
// Returns the transformed dims in the order _exec_fft processes them.
// Dims are ordered by decreasing stride of `self`. For positive mode codes
// (r2c / c2r), dim.back() is excluded from that sort. For c2r it is then
// appended last; for r2c it is placed first.
static at::DimVector _sort_dims(const at::Tensor& self, at::IntArrayRef dim, int64_t mode_code = 0)
{
    const auto kind = static_cast<fft_mode>(mode_code);
    const bool hold_back_last = mode_code > 0;
    const auto sortable_end = hold_back_last ? dim.end() - 1 : dim.end();
    at::DimVector ordered(dim.begin(), sortable_end);
    const auto strides = self.strides();
    auto by_stride_desc = [&strides](int64_t lhs, int64_t rhs) { return strides[lhs] > strides[rhs]; };
    std::sort(ordered.begin(), ordered.end(), by_stride_desc);
    switch (kind) {
        case fft_mode::c2r:
            ordered.push_back(dim.back());
            break;
        case fft_mode::r2c:
            ordered.insert(ordered.begin(), dim.back());
            break;
        default:
            break;
    }
    return ordered;
}
// Selects the plan type for the idx-th transformed dim, in _exec_fft processing order.
//   c2c: every dim uses a complex-to-complex plan.
//   r2c: only idx == 0 uses a real-to-complex plan. `oneside` picks the half-spectrum
//        plan (r2c) or the two-sided one (r2c_bothside); all later dims use c2c.
//   c2r: only idx == signal_ndim - 1 uses a complex-to-real plan; earlier dims use c2c.
// Raises for a mode value outside fft_mode's enumerators.
static op_api::PlanMode get_plan_mode(fft_mode mode, int idx, int signal_ndim, bool oneside)
{
    switch (mode) {
        case fft_mode::c2c:
            return op_api::PlanMode::c2c;
        case fft_mode::r2c:
            if (idx != 0) {
                return op_api::PlanMode::c2c;
            }
            return oneside ? op_api::PlanMode::r2c : op_api::PlanMode::r2c_bothside;
        case fft_mode::c2r:
            if (idx != signal_ndim - 1) {
                return op_api::PlanMode::c2c;
            }
            return op_api::PlanMode::c2r;
    }
    // `mode` comes from a static_cast of an integer code, so an unexpected value can
    // reach this point. Falling off the end of a non-void function would be UB.
    TORCH_CHECK(false, "get_plan_mode: unsupported fft mode ", static_cast<int>(mode),
        OPS_ERROR(ErrCode::PARAM));
}
// Generic NPU FFT executor. It transforms one signal dim at a time by multiplying with
// the coefficient matrices from op_api::get_plan(...).get_rotate_matrices() (aclnnMatmul),
// alternating between two scratch buffers (tmp_pingpong).
//
//   out_         : destination. It is restrided in place with as_strided_ to out_sizes.
//   self_        : input (complex for c2c / c2r, real for r2c).
//   out_sizes    : logical output sizes. They differ from the input only along
//                  dim.back() for r2c / c2r.
//   dim          : transformed dims. dim.back() is the dim whose length changes for r2c / c2r.
//   normalization: fft_norm_mode code applied to the result.
//   forward      : transform direction passed to the plan.
//   mode_code    : fft_mode code (0 = c2c, 1 = r2c, 2 = c2r).
//
// Complex data is handled as real data with an extra dim of size 2 (real, imag), via
// HackComplexintoFloat / HackFloatintoComplex.
static at::Tensor& _exec_fft(at::Tensor& out_, const at::Tensor& self_, at::IntArrayRef out_sizes,
    at::IntArrayRef dim, int64_t normalization, bool forward, int64_t mode_code = 0)
{
    auto mode = static_cast<fft_mode>(mode_code);
    // Fresh views, so the in-place dtype/stride reinterpretation below only alters these views.
    auto self = self_.view(self_.sizes());
    auto out = out_.view(out_.sizes());
    const auto ndim = self.dim();
    const auto signal_ndim = dim.size();
    const auto batch_dims = ndim - signal_ndim;
    // Permutation: transformed dims first (in _sort_dims processing order), then batch
    // dims ordered by decreasing input stride.
    at::DimVector dim_permute(ndim);
    std::iota(dim_permute.begin(), dim_permute.end(), int64_t{0});
    std::vector<bool> is_transformed_dim(ndim);
    for (const auto& d : dim) {
        is_transformed_dim[d] = true;
    }
    auto batch_end = std::partition(dim_permute.begin(), dim_permute.end(),
        [&](int64_t d) {return is_transformed_dim[d]; });
    auto self_strides = self.strides();
    at::DimVector sorted_dims = _sort_dims(self, dim, mode_code);
    std::copy(sorted_dims.begin(), sorted_dims.end(), dim_permute.begin());
    std::sort(batch_end, dim_permute.end(),
        [&](int64_t a, int64_t b) { return self_strides[a] > self_strides[b]; });
    if (mode != fft_mode::r2c) {
        // Complex input: view it as real and move the (real, imag) dim (index ndim) to the front.
        dim_permute.insert(dim_permute.begin(), ndim);
        HackComplexintoFloat(self);
    }
    if (mode != fft_mode::c2r) {
        HackComplexintoFloat(out);
    }
    self = self.permute(dim_permute);
    // final_sizes: [batch sizes..., signal sizes (processing order)..., (2 if complex output)].
    // The -1 entry absorbs the length change of the r2c / c2r dim.
    at::DimVector final_sizes(ndim);
    auto self_begin = self.sizes().begin();
    if (mode != fft_mode::r2c) {
        self_begin++;
    }
    std::copy(self_begin + signal_ndim, self.sizes().end(), final_sizes.begin());
    std::copy(self_begin, self_begin + signal_ndim, final_sizes.begin() + batch_dims);
    if (mode == fft_mode::c2r) {
        final_sizes[ndim - 1] = -1;
    } else if (mode == fft_mode::r2c) {
        final_sizes[batch_dims] = -1;
    }
    if (mode != fft_mode::c2r) {
        final_sizes.push_back(2);
    }
    // Two equally sized real scratch buffers used alternately (ping = current data, pong = target).
    at::Tensor tmp_pingpong[2];
    int64_t numel_buffer = self.numel();
    if (mode != fft_mode::c2c) {
        numel_buffer *= 2;
    }
    tmp_pingpong[0] = npu_preparation::apply_tensor_without_format(
        numel_buffer, self_.options().dtype(c10::toRealValueType(self_.scalar_type())));
    tmp_pingpong[1] = npu_preparation::apply_tensor_without_format(
        numel_buffer, self_.options().dtype(c10::toRealValueType(self_.scalar_type())));
    tmp_pingpong[0].resize_(self.sizes());
    tmp_pingpong[0].copy_(self);
    int64_t signal_total = 1;
    for (int64_t i = 0; i < signal_ndim; i++) {
        signal_total *= *(self_begin + i);
    }
    int64_t batch_total = 1;
    for (int64_t i = 0; i < batch_dims; i++) {
        batch_total *= final_sizes[i];
    }
    at::DimVector collapsed_sizes(3);
    uint32_t ping = 0;
    uint32_t pong = 1;
    for (int64_t i = 0; i < signal_ndim; i++) {
        if (*(self_begin + i) == 1) {
            // A length-1 transform is the identity; only the data layout is adjusted.
            if (mode == fft_mode::r2c && i == 0) {
                // Real input: append a zero imaginary plane (out is used as scratch for it).
                collapsed_sizes[0] = 1;
                collapsed_sizes[1] = -1;
                collapsed_sizes[2] = batch_total;
                tmp_pingpong[ping] = tmp_pingpong[ping].reshape(collapsed_sizes);
                collapsed_sizes[1] = tmp_pingpong[ping].size(1);
                out.resize_(collapsed_sizes);
                at::zeros_like_out(out, tmp_pingpong[ping]);
                collapsed_sizes[0] = 2;
                tmp_pingpong[pong].resize_(collapsed_sizes);
                at::cat_out(tmp_pingpong[pong], {tmp_pingpong[ping], out}, 0);
                ping = 1 - ping;
                pong = 1 - pong;
            } else {
                // Collapse to [2, signal, batch]. The result must be kept: the r2c trim after
                // the loop slices dim 1 and relies on it spanning all signal elements.
                collapsed_sizes[0] = 2;
                collapsed_sizes[1] = -1;
                collapsed_sizes[2] = batch_total;
                tmp_pingpong[ping] = tmp_pingpong[ping].reshape(collapsed_sizes);
            }
            continue;
        }
        auto plan_mode = get_plan_mode(mode, i, signal_ndim, out_sizes[dim.back()] != *(self_begin + i));
        int64_t radix = *(self_begin + i);
        if (plan_mode == op_api::PlanMode::c2r) {
            radix = out_sizes[dim.back()];
        }
        auto plan_item = op_api::get_plan(radix, forward, plan_mode, c10::toRealValueType(self_.scalar_type()));
        auto coefficient_matrix_list = plan_item.get_rotate_matrices();
        auto factors_list = plan_item.get_factors();
        uint32_t factors_num = coefficient_matrix_list.size();
        if ((plan_mode == op_api::PlanMode::c2r) && (out_sizes[dim.back()] != *(self_begin + i))) {
            // Extend the half spectrum to full length using conjugate symmetry: take the
            // interior bins, flip them along dim 1, negate the imaginary plane, and append.
            collapsed_sizes[0] = 2;
            collapsed_sizes[1] = *(self_begin + i);
            collapsed_sizes[2] = -1;
            auto buffer_0 = tmp_pingpong[ping].reshape(collapsed_sizes);
            auto buffer_1 = at::slice(buffer_0, 1, 1, collapsed_sizes[1]
                - ((out_sizes[dim.back()] == 2 * (collapsed_sizes[1] - 1)) ? 1 : 0), 1);
            out.resize_(buffer_1.sizes());
            at::flip_out(out, buffer_1, 1);
            auto imag_buffer = at::select(out, 0, 1);
            at::neg_(imag_buffer);
            collapsed_sizes[1] = out_sizes[dim.back()];
            collapsed_sizes[2] = buffer_0.size(2);
            tmp_pingpong[pong].resize_(collapsed_sizes);
            at::cat_out(tmp_pingpong[pong], {buffer_0, out}, 1);
            ping = 1 - ping;
            pong = 1 - pong;
        }
        // One batched matmul per coefficient matrix of the plan.
        for (auto& coefficient_matrix : coefficient_matrix_list) {
            collapsed_sizes[0] = coefficient_matrix.size(0);
            collapsed_sizes[1] = coefficient_matrix.size(2);
            collapsed_sizes[2] = -1;
            tmp_pingpong[ping] = tmp_pingpong[ping].reshape(collapsed_sizes);
            collapsed_sizes[1] = coefficient_matrix.size(1);
            collapsed_sizes[2] = tmp_pingpong[ping].size(2);
            tmp_pingpong[pong].resize_(collapsed_sizes);
            uint8_t cube_math_type = 0;
            EXEC_NPU_CMD(aclnnMatmul, coefficient_matrix, tmp_pingpong[ping], tmp_pingpong[pong], cube_math_type);
            ping = 1 - ping;
            pong = 1 - pong;
        }
        // Result is laid out as [factors..., 2 (or 1 after c2r), other signal, batch]. Permute
        // to [2, other signal, factors reversed..., batch] so the next signal dim comes first.
        at::DimVector reshape_sizes(factors_num + 3);
        if ((mode == fft_mode::c2r) && (i == (signal_ndim - 1))) {
            reshape_sizes[factors_num] = 1;
        } else {
            reshape_sizes[factors_num] = 2;
        }
        reshape_sizes[factors_num + 1] = signal_total / *(self_begin + i);
        reshape_sizes[factors_num + 2] = batch_total;
        std::copy(factors_list.cbegin(), factors_list.cend(), reshape_sizes.begin());
        if (plan_mode == op_api::PlanMode::r2c) {
            // The one-sided r2c plan keeps only factors_list.back() / 2 + 1 bins of its last factor.
            signal_total /= factors_list.back();
            signal_total *= (factors_list.back() / 2 + 1);
            reshape_sizes[factors_num - 1] = reshape_sizes[factors_num - 1] / 2 + 1;
        }
        tmp_pingpong[ping] = tmp_pingpong[ping].reshape(reshape_sizes);
        at::DimVector dim_permute_(factors_num + 3);
        std::iota(dim_permute_.rbegin() + 1, dim_permute_.rbegin() + factors_num + 1, int64_t{0});
        dim_permute_[0] = factors_num;
        dim_permute_[1] = factors_num + 1;
        dim_permute_[factors_num + 2] = factors_num + 2;
        tmp_pingpong[ping] = tmp_pingpong[ping].permute(dim_permute_);
        if (i != (signal_ndim - 1)) {
            // Materialize the permutation so the next iteration can reshape freely.
            tmp_pingpong[pong].resize_(tmp_pingpong[ping].sizes());
            tmp_pingpong[pong].copy_(tmp_pingpong[ping]);
            ping = 1 - ping;
            pong = 1 - pong;
        }
    }
    if ((mode == fft_mode::r2c) && (signal_ndim > 1)) {
        // One-sided multi-dim r2c: trim the surplus bins the plan produced for the halved dim.
        if (out_sizes[dim.back()] != *(self_begin)) {
            int64_t signal_total_final = out_sizes[dim.back()];
            for (int64_t i = 1; i < signal_ndim - 1; i++) {
                signal_total_final *= *(self_begin + i);
            }
            tmp_pingpong[ping] = at::slice(tmp_pingpong[ping], 1, 0, signal_total_final, 1);
        }
    }
    // Swap the leading (real, imag) dim with the trailing batch dim, then copy into out.
    int out_ndim = tmp_pingpong[ping].dim();
    at::DimVector dim_permute_out(out_ndim);
    std::iota(dim_permute_out.begin(), dim_permute_out.end(), int64_t{0});
    dim_permute_out[0] = out_ndim - 1;
    dim_permute_out[out_ndim - 1] = 0;
    tmp_pingpong[ping] = tmp_pingpong[ping].permute(dim_permute_out);
    out.resize_(tmp_pingpong[ping].sizes());
    out.copy_(tmp_pingpong[ping]);
    out = out.reshape(final_sizes);
    if ((mode == fft_mode::r2c) && (signal_ndim == 1) && (out_sizes[dim.back()] != *(self_begin))) {
        out = at::slice(out, ndim - 1, 0, out_sizes[dim.back()], 1);
    }
    // r2c normalizes over the (real) input sizes; c2c / c2r over the output sizes.
    if (mode != fft_mode::r2c) {
        _fft_apply_normalization(out, normalization, out_sizes, dim);
    } else {
        _fft_apply_normalization(out, normalization, self_.sizes(), dim);
    }
    if (mode != fft_mode::c2r) {
        HackFloatintoComplex(out);
    }
    if (mode != fft_mode::r2c) {
        dim_permute.erase(dim_permute.begin());
    }
    // Map strides of the [batch..., signal...] layout back to the caller's dim order.
    at::DimVector out_strides(ndim);
    auto now_strides_ = out.strides();
    for (const auto i : c10::irange(0, signal_ndim)) {
        out_strides[dim_permute[i]] = now_strides_[i + batch_dims];
    }
    for (const auto i : c10::irange(signal_ndim, ndim)) {
        out_strides[dim_permute[i]] = now_strides_[i - signal_ndim];
    }
    out_.as_strided_(out_sizes, out_strides, out.storage_offset());
    return out_;
}
// FFT through the ASDSIP library (EXEC_ASDSIP_FFT_NPU_CMD). In this file it is reached for
// c2c / r2c / c2r only when at most one dim is transformed and the dtype is not
// (Complex)Half. param.fftYSize is filled only when signal_ndim == 2.
//   out_         : destination. It is restrided in place with as_strided_ to out_sizes.
//   self_        : input (complex for c2c / c2r, real for r2c).
//   out_sizes    : logical output sizes.
//   dim          : transformed dims. For r2c / c2r, dim.back() is the dim whose length changes.
//   normalization: fft_norm_mode code, applied after the library call.
//   forward      : selects ASCEND_FFT_FORWARD or ASCEND_FFT_BACKWARD.
//   mode_code    : fft_mode code (0 = c2c, 1 = r2c, 2 = c2r).
static at::Tensor& _exec_fft_asdsip(at::Tensor& out_, const at::Tensor& self_, at::IntArrayRef out_sizes,
at::IntArrayRef dim, int64_t normalization, bool forward, int64_t mode_code = 0)
{
auto mode = static_cast<fft_mode>(mode_code);
// Fresh views, so the in-place dtype/stride reinterpretation below only alters these views.
auto self = self_.view(self_.sizes());
auto out = out_.view(out_.sizes());
const auto ndim = self.dim();
const auto signal_ndim = dim.size();
const auto batch_dims = ndim - signal_ndim;
// Permutation: batch dims first (decreasing input stride), then transformed dims
// (decreasing input stride; for r2c / c2r, dim.back() is placed last).
at::DimVector dim_permute(ndim);
std::iota(dim_permute.begin(), dim_permute.end(), int64_t{0});
std::vector<bool> is_transformed_dim(ndim);
for (const auto& d : dim) {
is_transformed_dim[d] = true;
}
auto batch_end = std::partition(dim_permute.begin(), dim_permute.end(),
[&](int64_t d) {return !is_transformed_dim[d]; });
auto self_strides = self.strides();
at::DimVector sorted_dims(dim.begin(), dim.end() - (mode != fft_mode::c2c));
std::sort(sorted_dims.begin(), sorted_dims.end(),
[&](int64_t a, int64_t b) { return self_strides[a] > self_strides[b]; });
if (mode != fft_mode::c2c) {
sorted_dims.push_back(dim.back());
}
std::copy(sorted_dims.begin(), sorted_dims.end(), batch_end);
std::sort(dim_permute.begin(), batch_end,
[&](int64_t a, int64_t b) { return self_strides[a] > self_strides[b]; });
// Make the input contiguous in permuted order. Complex input is temporarily viewed as
// real, so its (real, imag) dim (index ndim) stays innermost; it is then viewed back as complex.
at::Tensor asdfft_input;
if (mode != fft_mode::r2c) {
HackComplexintoFloat(self);
dim_permute.push_back(ndim);
asdfft_input = self.permute(dim_permute).contiguous();
dim_permute.pop_back();
HackFloatintoComplex(asdfft_input);
} else {
asdfft_input = self.permute(dim_permute).contiguous();
}
auto out_ori = out.permute(dim_permute);
// Collapse all batch dims into one leading dim: [batch, signal dims...].
at::DimVector asdsip_sizes(signal_ndim + 1);
asdsip_sizes[0] = -1;
std::copy(asdfft_input.sizes().begin() + batch_dims, asdfft_input.sizes().end(), asdsip_sizes.begin() + 1);
asdfft_input = asdfft_input.reshape(asdsip_sizes);
// Keep the uncollapsed permuted output shape so it can be restored after the call.
at::DimVector out_sizes_(out_ori.sizes());
std::copy(out_ori.sizes().begin() + batch_dims, out_ori.sizes().end(), asdsip_sizes.begin() + 1);
out_ori = out_ori.reshape(asdsip_sizes);
// out now has the collapsed [batch, signal...] shape and receives the library result.
out.resize_(out_ori.sizes());
FFTParam param;
param.batchSize = asdfft_input.size(0);
if (forward) {
param.direction = asdFftDirection::ASCEND_FFT_FORWARD;
} else {
param.direction = asdFftDirection::ASCEND_FFT_BACKWARD;
}
// Signal lengths come from the complex/real input for c2c / r2c, and from the real output for c2r.
switch (mode) {
case fft_mode::c2c:
param.fftXSize = asdfft_input.size(1);
if (signal_ndim == 2) {
param.fftYSize = asdfft_input.size(2);
}
param.fftType = asdFftType::ASCEND_FFT_C2C;
EXEC_ASDSIP_FFT_NPU_CMD(C2C, asdfft_input, out, param);
break;
case fft_mode::r2c:
param.fftXSize = asdfft_input.size(1);
if (signal_ndim == 2) {
param.fftYSize = asdfft_input.size(2);
}
param.fftType = asdFftType::ASCEND_FFT_R2C;
EXEC_ASDSIP_FFT_NPU_CMD(R2C, asdfft_input, out, param);
break;
case fft_mode::c2r:
param.fftXSize = out.size(1);
if (signal_ndim == 2) {
param.fftYSize = out.size(2);
}
param.fftType = asdFftType::ASCEND_FFT_C2R;
EXEC_ASDSIP_FFT_NPU_CMD(C2R, asdfft_input, out, param);
break;
}
// r2c normalizes over the (real) input sizes, the others over out_sizes. Complex outputs
// are scaled through their real view.
auto norm_out_sizes = out_sizes;
if (mode == fft_mode::r2c) {
norm_out_sizes = self_.sizes();
}
if (mode != fft_mode::c2r) {
HackComplexintoFloat(out);
}
_fft_apply_normalization(out, normalization, norm_out_sizes, dim);
if (mode != fft_mode::c2r) {
HackFloatintoComplex(out);
}
// Undo the batch collapse, then map permuted strides back to the caller's dim order.
out = out.reshape(out_sizes_);
auto now_strides_ = out.strides();
at::DimVector out_strides(ndim);
for (const auto i : c10::irange(0, ndim)) {
out_strides[dim_permute[i]] = now_strides_[i];
}
out_.as_strided_(out_sizes, out_strides, out.storage_offset());
return out_;
}
// Complex-to-complex FFT of `self` over `dim`, returning a new tensor of the same shape.
// Multi-dim or ComplexHalf inputs use the matmul-based _exec_fft; otherwise the
// ASDSIP path is used (subject to DO_ASDSIP_COMPATIBILITY).
at::Tensor _fft_c2c(const at::Tensor& self, at::IntArrayRef dim, int64_t normalization, bool forward)
{
    TORCH_CHECK(self.is_complex(), OPS_ERROR(ErrCode::PARAM));
    // No transformed dims (e.g. torch.fft.fftn(x, dim=())): the transform is the identity, and
    // normalization over an empty set of dims is 1. This matches ATen's CPU/CUDA kernels.
    // Neither executor below handles an empty `dim`.
    if (dim.empty()) {
        return self.clone();
    }
    auto output_size = op_infer::input_same_output_size(self);
    auto out = npu_preparation::apply_tensor_without_format(output_size, self.options().dtype(self.scalar_type()));
    DO_ASDSIP_COMPATIBILITY(C2C, _exec_fft(out, self, self.sizes(), dim, normalization, forward, 0));
    if (dim.size() > 1 || self.scalar_type() == at::ScalarType::ComplexHalf) {
        _exec_fft(out, self, self.sizes(), dim, normalization, forward, 0);
    } else {
        _exec_fft_asdsip(out, self, self.sizes(), dim, normalization, forward, 0);
    }
    return out;
}
// Out-variant of _fft_c2c: writes the complex-to-complex FFT of `self` over `dim` into `out`.
// `out` is reshaped/restrided by the executors, so its incoming shape is not relied upon.
at::Tensor& _fft_c2c_out(const at::Tensor& self, at::IntArrayRef dim, int64_t normalization,
    bool forward, at::Tensor& out)
{
    TORCH_CHECK(self.is_complex(), OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(out.is_complex(), OPS_ERROR(ErrCode::PARAM));
    // No transformed dims: the result is just a copy of the input. Neither executor
    // below handles an empty `dim`.
    if (dim.empty()) {
        out.resize_(self.sizes());
        out.copy_(self);
        return out;
    }
    DO_ASDSIP_COMPATIBILITY(C2C, _exec_fft(out, self, self.sizes(), dim, normalization, forward, 0));
    if (dim.size() > 1 || self.scalar_type() == at::ScalarType::ComplexHalf) {
        _exec_fft(out, self, self.sizes(), dim, normalization, forward, 0);
    } else {
        _exec_fft_asdsip(out, self, self.sizes(), dim, normalization, forward, 0);
    }
    return out;
}
// Real-to-complex forward FFT of `self` over `dim`, returning a new complex tensor.
// With `onesided`, only the first n / 2 + 1 bins along dim.back() are kept.
// Multi-dim or Half inputs use _exec_fft; otherwise the ASDSIP path is used.
at::Tensor _fft_r2c(const at::Tensor& self, at::IntArrayRef dim, int64_t normalization, bool onesided)
{
    TORCH_CHECK(self.is_floating_point(), OPS_ERROR(ErrCode::PARAM));
    const auto input_shape = self.sizes();
    const int64_t halved_dim = dim.back();
    const int64_t half_bins = input_shape[halved_dim] / 2 + 1;
    at::DimVector out_sizes(input_shape.begin(), input_shape.end());
    if (onesided) {
        out_sizes[halved_dim] = half_bins;
    }
    const auto complex_dtype = c10::toComplexType(self.scalar_type());
    auto out = npu_preparation::apply_tensor_without_format(out_sizes, self.options().dtype(complex_dtype));
    DO_ASDSIP_COMPATIBILITY(R2C, _exec_fft(out, self, out_sizes, dim, normalization, true, 1));
    const bool asdsip_eligible = dim.size() <= 1 && self.scalar_type() != at::ScalarType::Half;
    if (asdsip_eligible) {
        _exec_fft_asdsip(out, self, out_sizes, dim, normalization, true, 1);
    } else {
        _exec_fft(out, self, out_sizes, dim, normalization, true, 1);
    }
    return out;
}
// Out-variant of _fft_r2c: writes the real-to-complex forward FFT of `self` over `dim`
// into the complex tensor `out`. With `onesided`, dim.back() keeps n / 2 + 1 bins.
at::Tensor &_fft_r2c_out(const at::Tensor &self, at::IntArrayRef dim,
    int64_t normalization, bool onesided, at::Tensor &out)
{
    TORCH_CHECK(self.is_floating_point(), OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(out.is_complex(), OPS_ERROR(ErrCode::PARAM));
    const auto input_shape = self.sizes();
    const int64_t halved_dim = dim.back();
    const int64_t half_bins = input_shape[halved_dim] / 2 + 1;
    at::DimVector out_sizes(input_shape.begin(), input_shape.end());
    if (onesided) {
        out_sizes[halved_dim] = half_bins;
    }
    DO_ASDSIP_COMPATIBILITY(R2C, _exec_fft(out, self, out_sizes, dim, normalization, true, 1));
    const bool asdsip_eligible = dim.size() <= 1 && self.scalar_type() != at::ScalarType::Half;
    if (asdsip_eligible) {
        _exec_fft_asdsip(out, self, out_sizes, dim, normalization, true, 1);
    } else {
        _exec_fft(out, self, out_sizes, dim, normalization, true, 1);
    }
    return out;
}
// Complex-to-real FFT of `self` over `dim`, returning a new real tensor. The output
// length along dim.back() is `lastdim`; all other sizes follow `self`.
// Multi-dim or ComplexHalf inputs use _exec_fft; otherwise the ASDSIP path is used.
at::Tensor _fft_c2r(const at::Tensor& self, at::IntArrayRef dim, int64_t normalization, int64_t lastdim)
{
    TORCH_CHECK(self.is_complex(), OPS_ERROR(ErrCode::PARAM));
    const auto input_shape = self.sizes();
    at::DimVector out_sizes(input_shape.begin(), input_shape.end());
    out_sizes[dim.back()] = lastdim;
    const auto real_dtype = c10::toRealValueType(self.scalar_type());
    auto out = npu_preparation::apply_tensor_without_format(out_sizes, self.options().dtype(real_dtype));
    // NOTE(review): the direction flag is self.is_conj() rather than a constant. Presumably this
    // compensates for a lazily conjugated input whose raw storage is read unconjugated; confirm.
    DO_ASDSIP_COMPATIBILITY(C2R, _exec_fft(out, self, out_sizes, dim, normalization, self.is_conj(), 2));
    const bool asdsip_eligible = dim.size() <= 1 && self.scalar_type() != at::ScalarType::ComplexHalf;
    if (asdsip_eligible) {
        _exec_fft_asdsip(out, self, out_sizes, dim, normalization, self.is_conj(), 2);
    } else {
        _exec_fft(out, self, out_sizes, dim, normalization, self.is_conj(), 2);
    }
    return out;
}
// Out-variant of _fft_c2r: writes the complex-to-real FFT of `self` over `dim` into the
// real tensor `out`. The output length along dim.back() is `lastdim`.
at::Tensor &_fft_c2r_out(const at::Tensor &self, at::IntArrayRef dim,
    int64_t normalization, int64_t lastdim, at::Tensor &out)
{
    TORCH_CHECK(self.is_complex(), OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(out.is_floating_point(), OPS_ERROR(ErrCode::PARAM));
    const auto input_shape = self.sizes();
    at::DimVector out_sizes(input_shape.begin(), input_shape.end());
    out_sizes[dim.back()] = lastdim;
    // NOTE(review): direction flag is self.is_conj(), as in _fft_c2r; presumably a workaround
    // for lazily conjugated inputs; confirm.
    DO_ASDSIP_COMPATIBILITY(C2R, _exec_fft(out, self, out_sizes, dim, normalization, self.is_conj(), 2));
    const bool asdsip_eligible = dim.size() <= 1 && self.scalar_type() != at::ScalarType::ComplexHalf;
    if (asdsip_eligible) {
        _exec_fft_asdsip(out, self, out_sizes, dim, normalization, self.is_conj(), 2);
    } else {
        _exec_fft(out, self, out_sizes, dim, normalization, self.is_conj(), 2);
    }
    return out;
}
#endif
}