* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file mod_dag.h
* \brief mod dag
*/
#ifndef MOD_DAG_H
#define MOD_DAG_H
#include "atvoss/util/dag.h"
#include "atvoss/util/placeholder.h"
#include "atvoss/util/vec.h"
#include "kernel_tiling/kernel_tiling.h"
#include "op_kernel/math_util.h"
#ifdef __CCE_AICORE__
#include "op_kernel/platform_util.h"
#include "simt_api/asc_simt.h"
#endif
namespace ModOp {
using namespace Ops::Base;
constexpr int CAST_NONE_MODE = 0;
constexpr int CAST_RINT_MODE = 1;
template <class T>
struct ModIntPostCompute : public Vec::ElemwiseTernaryOP<T, T, T, T> {
__aicore__ inline ModIntPostCompute(
const LocalTensor<T>& dst, const LocalTensor<T>& input1, const LocalTensor<T>& input2,
const LocalTensor<T>& div, const uint32_t& count)
{
#ifdef __CCE_AICORE__
constexpr uint32_t VECTOR_LENGTH = Ops::Base::GetVRegSize();
constexpr uint32_t VL_T = VECTOR_LENGTH / sizeof(T);
__local_mem__ T* input1Addr = (__local_mem__ T*)input1.GetPhyAddr();
__local_mem__ T* input2Addr = (__local_mem__ T*)input2.GetPhyAddr();
__local_mem__ T* divAddr = (__local_mem__ T*)div.GetPhyAddr();
__local_mem__ T* dstAddr = (__local_mem__ T*)dst.GetPhyAddr();
uint16_t loopTimes = Ops::Base::CeilDiv(count, VL_T);
__VEC_SCOPE__
{
MicroAPI::RegTensor<T> zeroValue;
MicroAPI::RegTensor<T> defaultValue;
MicroAPI::RegTensor<T> input1Value;
MicroAPI::RegTensor<T> input2Value;
MicroAPI::RegTensor<T> divValue;
MicroAPI::RegTensor<T> mulValue;
MicroAPI::RegTensor<T> subValue;
MicroAPI::RegTensor<T> resValue;
MicroAPI::MaskReg preg;
MicroAPI::MaskReg cmpValue;
uint32_t sregMask = count;
MicroAPI::Duplicate(zeroValue, T(0));
MicroAPI::Duplicate(defaultValue, T(-1));
for (uint16_t j = 0; j < loopTimes; j++) {
preg = MicroAPI::UpdateMask<T>(sregMask);
MicroAPI::DataCopy<T, MicroAPI::LoadDist::DIST_NORM>(input2Value, input2Addr + VL_T * j);
MicroAPI::DataCopy<T, MicroAPI::LoadDist::DIST_NORM>(divValue, divAddr + VL_T * j);
MicroAPI::Mul(mulValue, input2Value, divValue, preg);
MicroAPI::DataCopy<T, MicroAPI::LoadDist::DIST_NORM>(input1Value, input1Addr + VL_T * j);
MicroAPI::Sub(subValue, input1Value, mulValue, preg);
MicroAPI::Compare<T, CMPMODE::NE>(cmpValue, input2Value, zeroValue, preg);
MicroAPI::Select(resValue, subValue, defaultValue, cmpValue);
MicroAPI::DataCopy<T, MicroAPI::StoreDist::DIST_NORM>(dstAddr + VL_T * j, resValue, preg);
}
}
#endif
}
};
template <typename T1, typename T2>
struct ModCastIntPostCompute : public Vec::ElemwiseTernaryOP<T1, T2, T2, T2> {
__aicore__ inline ModCastIntPostCompute(
const LocalTensor<T1>& dst, const LocalTensor<T2>& input1, const LocalTensor<T2>& input2,
const LocalTensor<T2>& div, const uint32_t& count)
{
#ifdef __CCE_AICORE__
constexpr uint32_t VECTOR_LENGTH = Ops::Base::GetVRegSize();
constexpr uint32_t VL_T = VECTOR_LENGTH / sizeof(T2);
__local_mem__ T2* input1Addr = (__local_mem__ T2*)input1.GetPhyAddr();
__local_mem__ T2* input2Addr = (__local_mem__ T2*)input2.GetPhyAddr();
__local_mem__ T2* divAddr = (__local_mem__ T2*)div.GetPhyAddr();
__local_mem__ T1* dstAddr = (__local_mem__ T1*)dst.GetPhyAddr();
uint16_t loopTimes = CeilDiv(count, VL_T);
__VEC_SCOPE__
{
MicroAPI::RegTensor<T2> zeroValue;
MicroAPI::RegTensor<T2> defaultValue;
MicroAPI::RegTensor<T2> input1Value;
MicroAPI::RegTensor<T2> input2Value;
MicroAPI::RegTensor<T2> divValue;
MicroAPI::RegTensor<T2> mulValue;
MicroAPI::RegTensor<T2> subValue;
MicroAPI::RegTensor<T2> resValue;
MicroAPI::MaskReg preg;
MicroAPI::MaskReg cmpValue;
uint32_t sregMask = count;
MicroAPI::Duplicate(zeroValue, T2(0));
MicroAPI::Duplicate(defaultValue, T2(-1));
for (uint16_t j = 0; j < loopTimes; j++) {
preg = MicroAPI::UpdateMask<T2>(sregMask);
MicroAPI::DataCopy<T2, MicroAPI::LoadDist::DIST_NORM>(input2Value, input2Addr + VL_T * j);
MicroAPI::DataCopy<T2, MicroAPI::LoadDist::DIST_NORM>(divValue, divAddr + VL_T * j);
MicroAPI::Mul(mulValue, input2Value, divValue, preg);
MicroAPI::DataCopy<T2, MicroAPI::LoadDist::DIST_NORM>(input1Value, input1Addr + VL_T * j);
MicroAPI::Sub(subValue, input1Value, mulValue, preg);
MicroAPI::Compare<T2, CMPMODE::NE>(cmpValue, input2Value, zeroValue, preg);
MicroAPI::Select(resValue, subValue, defaultValue, cmpValue);
MicroAPI::DataCopy<T1, MicroAPI::StoreDist::DIST_PACK_B16>(
dstAddr + VL_T * j, (MicroAPI::RegTensor<T1>&)resValue, preg);
}
}
#endif
}
};
#ifdef __CCE_AICORE__
template <typename T>
__simt_vf__ __aicore__
LAUNCH_BOUND(1024) inline void ModIntSimt(__ubuf__ T* dst, __ubuf__ T* src1, __ubuf__ T* src2, int count)
{
for (uint32_t index = static_cast<uint32_t>(threadIdx.x); index < count;
index += static_cast<uint32_t>(blockDim.x)) {
const auto rem = src1[index] % src2[index];
dst[index] = rem;
}
}
#endif
template <class T>
struct ModInt : public Vec::ElemwiseBinaryOP<T, T, T> {
__aicore__ inline ModInt(LocalTensor<T>& dst, LocalTensor<T>& src1, LocalTensor<T>& src2, int count)
{
#ifdef __CCE_AICORE__
__ubuf__ T* dst_1 = (__ubuf__ T*)dst.GetPhyAddr();
__ubuf__ T* src1_1 = (__ubuf__ T*)src1.GetPhyAddr();
__ubuf__ T* src2_1 = (__ubuf__ T*)src2.GetPhyAddr();
asc_vf_call<ModIntSimt<T>>(dim3(1024), dst_1, src1_1, src2_1, count);
#endif
}
};
template <typename T>
struct ModFloatWithCastOp {
using OpInputX1 = Bind<Vec::CopyInBrc<T>, Placeholder::In0<T>>;
using OpInputX2 = Bind<Vec::CopyInBrc<T>, Placeholder::In1<T>>;
using OpCastX1 = Bind<Vec::Cast<float, T, CAST_NONE_MODE>, OpInputX1>;
using OpCastX2 = Bind<Vec::Cast<float, T, CAST_NONE_MODE>, OpInputX2>;
using FmodRes = Bind<Vec::FmodHighPrecision<float>, OpCastX1, OpCastX2>;
using OpCastRes = Bind<Vec::Cast<T, float, CAST_RINT_MODE>, FmodRes>;
using OpCopyOut = Bind<Vec::CopyOut<T>, Placeholder::Out0<T>, OpCastRes>;
using Outputs = Elems<OpCopyOut>;
using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
using OpDag = DAGSch<Outputs, void, MemCfg>;
};
template <typename T>
struct ModFloatOp {
using OpInputX1 = Bind<Vec::CopyInBrc<T>, Placeholder::In0<T>>;
using OpInputX2 = Bind<Vec::CopyInBrc<T>, Placeholder::In1<T>>;
using FmodRes = Bind<Vec::FmodHighPrecision<T>, OpInputX1, OpInputX2>;
using OpCopyOut = Bind<Vec::CopyOut<T>, Placeholder::Out0<T>, FmodRes>;
using Outputs = Elems<OpCopyOut>;
using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
using OpDag = DAGSch<Outputs, void, MemCfg>;
};
template <typename T1, typename T2>
struct ModIntWithCastOp {
using OpInputX1 = Bind<Vec::CopyInBrc<T1>, Placeholder::In0<T1>>;
using OpInputX2 = Bind<Vec::CopyInBrc<T1>, Placeholder::In1<T1>>;
using OpCastX1 = Bind<Vec::Cast<T2, T1, CAST_NONE_MODE>, OpInputX1>;
using OpCastX2 = Bind<Vec::Cast<T2, T1, CAST_NONE_MODE>, OpInputX2>;
using DivRes = Bind<Vec::Div<T2>, OpCastX1, OpCastX2>;
using Output = Bind<ModCastIntPostCompute<T1, T2>, OpCastX1, OpCastX2, DivRes>;
using OpCopyOut = Bind<Vec::CopyOut<T1>, Placeholder::Out0<T1>, Output>;
using Outputs = Elems<OpCopyOut>;
using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
using OpDag = DAGSch<Outputs, void, MemCfg>;
};
template <typename T>
struct ModIntOp {
using OpInputX1 = Bind<Vec::CopyInBrc<T>, Placeholder::In0<T>>;
using OpInputX2 = Bind<Vec::CopyInBrc<T>, Placeholder::In1<T>>;
using DivRes = Bind<Vec::Div<T>, OpInputX1, OpInputX2>;
using Output = Bind<ModIntPostCompute<T>, OpInputX1, OpInputX2, DivRes>;
using OpCopyOut = Bind<Vec::CopyOut<T>, Placeholder::Out0<T>, Output>;
using Outputs = Elems<OpCopyOut>;
using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
using OpDag = DAGSch<Outputs, void, MemCfg>;
};
template <typename T>
struct ModInt64Op {
using OpInputX1 = Bind<Vec::CopyInBrc<T>, Placeholder::In0<T>>;
using OpInputX2 = Bind<Vec::CopyInBrc<T>, Placeholder::In1<T>>;
using FmodRes = Bind<ModInt<T>, OpInputX1, OpInputX2>;
using OpCopyOut = Bind<Vec::CopyOut<T>, Placeholder::Out0<T>, FmodRes>;
using Outputs = Elems<OpCopyOut>;
using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
using OpDag = DAGSch<Outputs, void, MemCfg>;
};
}
#endif