#include <ATen/native/LinearAlgebraUtils.h>
#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
namespace {
// Smallest rank accepted for either operand of linalg.solve_triangular.
constexpr int DIM_2D = 2;

// Validates the ranks of both operands before any transposition or kernel launch.
//
// self: the tensor reported to the user as "A".
// B:    the tensor reported to the user as "B".
//
// Raises through TORCH_CHECK (tagged with ErrCode::PARAM) when either tensor has
// fewer than DIM_2D dimensions. "A" is examined first, so it is the one reported
// when both operands are too small.
static void check_linalg_solve_triangular_inputs(const at::Tensor& self, const at::Tensor& B)
{
    TORCH_CHECK(self.dim() >= DIM_2D,
                "linalg.solve_triangular: The input tensor A must have at least 2 dimensions.",
                OPS_ERROR(ErrCode::PARAM));

    TORCH_CHECK(B.dim() >= DIM_2D,
                "linalg.solve_triangular: The input tensor B must have at least 2 dimensions.",
                OPS_ERROR(ErrCode::PARAM));
}
// Runs aclnnTriangularSolve on a (right-hand side, matrix) pair and returns the
// first output tensor handed to the kernel.
//
// NOTE on naming: here `self` is the right-hand side and `A` is the matrix; the
// callers in this file pass their own `B` as `self` and their own `self` as `A`.
//
// self:          right-hand side tensor.
// A:             matrix tensor; validated together with `self` by checkInputsSolver.
// upper:         forwarded unchanged to the kernel.
// transpose:     forwarded unchanged to the kernel.
// unitriangular: forwarded unchanged to the kernel.
at::Tensor exec_triangular_solve(
    const at::Tensor& self,
    const at::Tensor& A,
    bool upper,
    bool transpose,
    bool unitriangular)
{
    // Always validated as a left-sided solve; callers in this file transpose
    // their operands before calling when left == false.
    at::native::checkInputsSolver(A, self, true, "linalg.solve_triangular");

    // Broadcast both operands against each other; the name argument is nullptr.
    auto broadcast_pair = at::native::_linalg_broadcast_batch_dims(self, A, nullptr);
    at::Tensor rhs = std::get<0>(broadcast_pair);
    at::Tensor matrix = std::get<1>(broadcast_pair);

    // The matrix dtype wins: the right-hand side is converted to match it.
    const auto matrix_dtype = matrix.scalar_type();
    if (rhs.scalar_type() != matrix_dtype) {
        rhs = rhs.to(matrix_dtype);
    }

    // Tensors passed as the two trailing kernel arguments. The second one starts
    // out as a copy of the matrix (presumably an output buffer; confirm against
    // the aclnnTriangularSolve documentation).
    auto solution = npu_preparation::apply_tensor(rhs);
    auto matrix_copy = matrix.clone();

    EXEC_NPU_CMD(aclnnTriangularSolve, rhs, matrix, upper, transpose, unitriangular, solution, matrix_copy);
    return solution;
}
}
// Out-variant of linalg.solve_triangular: computes the solution into a fresh
// tensor, then resizes `out` to match it and copies the values over.
//
// self:          the triangular matrix (reported as "A" in error messages).
// B:             the right-hand side.
// upper:         triangle selector; inverted when the operands are transposed.
// left:          true  -> solve with the operands as given;
//                false -> solve the transposed problem and transpose the result
//                         back (X A = B rewritten as A^T X^T = B^T, following
//                         the torch.linalg.solve_triangular convention).
// unitriangular: forwarded unchanged to the kernel.
// out:           destination tensor; resized with resize_as_ and overwritten.
//
// Returns `out`.
at::Tensor& linalg_solve_triangular_out(
    const at::Tensor& self,
    const at::Tensor& B,
    bool upper,
    bool left,
    bool unitriangular,
    at::Tensor& out)
{
    check_linalg_solve_triangular_inputs(self, B);

    // The kernel's own transpose flag is never used; right-sided solves are
    // handled by transposing the operands explicitly instead.
    const bool transpose = false;

    const at::Tensor solution = left
        ? exec_triangular_solve(B, self, upper, transpose, unitriangular)
        : exec_triangular_solve(B.transpose(-2, -1), self.transpose(-2, -1), !upper, transpose, unitriangular)
              .transpose(-2, -1);

    out.resize_as_(solution).copy_(solution);
    return out;
}
// Functional variant of linalg.solve_triangular: returns the solution tensor.
//
// self:          the triangular matrix (reported as "A" in error messages).
// B:             the right-hand side.
// upper:         triangle selector; inverted when the operands are transposed.
// left:          true  -> solve with the operands as given;
//                false -> solve the transposed problem and return the result
//                         transposed back (X A = B rewritten as A^T X^T = B^T,
//                         following the torch.linalg.solve_triangular convention).
// unitriangular: forwarded unchanged to the kernel.
//
// For left == false the returned tensor is the transpose(-2, -1) of the kernel
// output rather than a freshly materialised copy.
at::Tensor linalg_solve_triangular(
    const at::Tensor& self,
    const at::Tensor& B,
    bool upper,
    bool left,
    bool unitriangular)
{
    check_linalg_solve_triangular_inputs(self, B);

    // The kernel's own transpose flag is never used; right-sided solves are
    // handled by transposing the operands explicitly instead.
    const bool transpose = false;

    if (left) {
        return exec_triangular_solve(B, self, upper, transpose, unitriangular);
    }

    const at::Tensor solved_transposed =
        exec_triangular_solve(B.transpose(-2, -1), self.transpose(-2, -1), !upper, transpose, unitriangular);
    return solved_transposed.transpose(-2, -1);
}
}