#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/custom_functions/atb/AtbCommon.h"
#include <acl/acl.h>
using namespace std;
namespace atb {
using LinearParam = atb::infer::LinearParam;
/**
 * Runs the ATB "LinearOperation" with `C` bound both as the third input and
 * as the only output, so the result is written back into `C`.
 *
 * The operation is configured with transposeA = true, transposeB = false,
 * hasBias = false and enAccum = true.
 * NOTE(review): with these flags the op presumably computes
 * C += x^T * weight; the "fp32" suffix suggests a float32 accumulator.
 * Neither is checked here -- confirm against the ATB LinearParam docs.
 *
 * @param x      First operand (transposeA is set for it).
 * @param weight Second operand (transposeB is not set for it).
 * @param C      Tensor passed as the third input and as the output.
 */
void _npu_matmul_add_fp32(const at::Tensor &x, const at::Tensor &weight, at::Tensor & C)
{
    // Guard on the device of `x` for the lifetime of this call.
    const c10::OptionalDeviceGuard device_guard(device_of(x));

    // Describe the linear op variant that is requested from the param cache.
    LinearParam param;
    param.hasBias = false;
    param.enAccum = true;
    param.transposeA = true;
    param.transposeB = false;

    auto linearOp = OpParamCache<LinearParam>::getInstance().getOperation(param, "LinearOperation");

    // `C` appears twice on purpose: once as an input and once as the output.
    ParamSetter bindings;
    bindings.Input(x).Input(weight).Input(C).Output(C);

    RunAtbCmd(linearOp, bindings, "LinearOperation");
}
// Declares the schema of `_npu_matmul_add_fp32` in the `atb` operator library.
// In the schema, `Tensor(a!) C` marks `C` as mutated in place and `-> ()`
// means the operator returns nothing.
namespace {
TORCH_LIBRARY_FRAGMENT(atb, m)
{
m.def("_npu_matmul_add_fp32(Tensor x, Tensor weight, Tensor(a!) C) -> ()");
}
}
// Registers atb::_npu_matmul_add_fp32 as the implementation of the
// `_npu_matmul_add_fp32` operator for the PrivateUse1 dispatch key.
// NOTE(review): PrivateUse1 is presumably the key the NPU backend is
// registered under -- confirm against the backend registration code.
namespace {
TORCH_LIBRARY_IMPL(atb, PrivateUse1, m)
{
m.impl("_npu_matmul_add_fp32", TORCH_FN(atb::_npu_matmul_add_fp32));
}
}
}