/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file mul_dag.h
* \brief mul dag
*/
#ifndef MUL_DAG_H
#define MUL_DAG_H
#include "atvoss/util/dag.h"
#include "atvoss/util/vec.h"
#include "atvoss/util/placeholder.h"
#ifdef __CCE_AICORE__
#include "simt_api/asc_simt.h"
#endif
namespace MulDag {
using namespace AscendC;
using namespace Ops::Base;
// Rounding-mode selectors used as the third template argument of Vec::Cast in the DAGs below.
constexpr int CAST_MODE_NONE = 0;
constexpr int CAST_MODE_RINT = 1;
constexpr int CAST_MODE_FLOOR = 2;
constexpr int CAST_MODE_ROUND = 4;

// A complex element is cast as two scalar lanes, so cast counts are multiplied by this factor.
constexpr uint32_t COUNT_DOUBLE = 2U;
// Low-byte mask applied by AndFF.
constexpr std::int32_t FF = 0xFF;

// Layout of the high 32-bit word of a binary64 value as used by MulDouble_vf:
// bit 31 = sign, bits 30..20 = exponent field, bits 19..0 = upper mantissa bits.
constexpr int EXP_OFFSET = 20;
constexpr int MAX_EXP = 0x7FF;
constexpr int EXP_BIAS = 0x3FF;
constexpr int SIGNMASK = 0x80000000;
constexpr int EXPMASK = MAX_EXP << EXP_OFFSET;
constexpr int MMASK = (1 << EXP_OFFSET) - 1;
// High word written for NaN results: all-ones exponent plus the top mantissa bit.
constexpr int NANI = EXPMASK | (1 << (EXP_OFFSET - 1));
// Everything except the sign bit of a 32-bit word.
constexpr int INT32_MMAXMASK = 0x7FFFFFFF;
// Upper 16 bits of a 32-bit word; MulDouble_vf uses it to choose a 16-bit normalisation step.
constexpr int SUBNORMAL_MASK = 0xFFFF0000;

// Bit distances used while normalising and packing the 128-bit mantissa product.
constexpr int BIT_ALIGN = 8;
constexpr int OFFSET_16 = 16;
constexpr int OFFSET_32 = 2 * OFFSET_16;
constexpr int OFFSET_64 = 2 * OFFSET_32;
constexpr int NOSIGN_OFFSET_32 = OFFSET_32 - 1;
constexpr int HIGH_BIT_OFFSET_32 = OFFSET_32 - 2;
constexpr int MAX_SAFE_LSHIFT_20 = EXP_OFFSET - 1;

// Result-exponent thresholds of the subnormal path: flush to zero / move by 32 bits / move by 16 bits.
constexpr int SUBN_BOUND_52 = -53;
constexpr int SUBN_BOUND_32 = -33;
constexpr int SUBN_BOUND_16 = -17;
/*!
 * \brief Overlay giving word-level access to the bit pattern of a double.
 *
 * MulDouble_vf stores a value through `d` and then reads or writes `i[0]` / `i[1]`. It treats
 * `i[1]` as the word carrying sign, exponent and upper mantissa bits and `i[0]` as the lower
 * mantissa bits, so this assumes 32-bit unsigned int and little-endian word order.
 */
union U {
unsigned int i[2]; // i[0]: low word, i[1]: high word (as used by MulDouble_vf)
double d;          // the value itself
unsigned long long u; // the whole bit pattern as one integer
};
/*!
 * \brief 64-bit unsigned accumulator whose two 32-bit halves can be addressed separately.
 *
 * MulDouble_vf accumulates partial products in `il` and moves `i[1]` of one accumulator into
 * `i[0]` of the next one as a carry, i.e. `i[0]` is used as the low half and `i[1]` as the high half
 * (assumes 32-bit unsigned int and little-endian word order).
 */
union ll {
unsigned int i[2];     // i[0]: low half, i[1]: high half (as used by MulDouble_vf)
unsigned long long il; // the full 64-bit value
};
/*!
 * \brief Element-wise op that widens a tensor by casting its storage from half to float.
 *
 * dst is reinterpreted as float and src as half; `count * COUNT_DOUBLE` scalars are cast with
 * RoundMode::CAST_NONE. The doubled count matches the name: presumably every element is a
 * (real, imag) pair - confirm against the complex types used by the callers.
 *
 * \tparam T1 element type of dst
 * \tparam T2 element type of src
 */
template <class T1, class T2>
struct CastComplex32ToComplex64 : public Vec::ElemwiseUnaryOP<T1, T2> {
    __aicore__ inline CastComplex32ToComplex64(LocalTensor<T1>& dst, LocalTensor<T2>& src, uint32_t count)
    {
#ifdef __CCE_AICORE__
        const uint32_t scalarCount = count * COUNT_DOUBLE;
        auto dstAsFloat = dst.template ReinterpretCast<float>();
        auto srcAsHalf = src.template ReinterpretCast<half>();
        AscendC::Cast(dstAsFloat, srcAsHalf, RoundMode::CAST_NONE, scalarCount);
#endif
    }
};
/*!
 * \brief Element-wise op that narrows a tensor by casting its storage from float to half.
 *
 * dst is reinterpreted as half and src as float; `count * COUNT_DOUBLE` scalars are cast with
 * RoundMode::CAST_RINT. The doubled count matches the name: presumably every element is a
 * (real, imag) pair - confirm against the complex types used by the callers.
 *
 * \tparam T1 element type of dst
 * \tparam T2 element type of src
 */
template <class T1, class T2>
struct CastComplex64ToComplex32 : public Vec::ElemwiseUnaryOP<T1, T2> {
    __aicore__ inline CastComplex64ToComplex32(LocalTensor<T1>& dst, LocalTensor<T2>& src, uint32_t count)
    {
#ifdef __CCE_AICORE__
        const uint32_t scalarCount = count * COUNT_DOUBLE;
        auto dstAsHalf = dst.template ReinterpretCast<half>();
        auto srcAsFloat = src.template ReinterpretCast<float>();
        AscendC::Cast(dstAsHalf, srcAsFloat, RoundMode::CAST_RINT, scalarCount);
#endif
    }
};
/*!
 * \brief Element-wise op computing dst = src & FF (keeps the low 8 bits of each element).
 *
 * dst is first filled with the mask and then combined with src in place.
 * NOTE(review): because dst is overwritten before src is read, this relies on dst and src not
 * sharing storage - confirm with the DAG memory planner.
 *
 * \tparam T element type of both tensors
 */
template <class T>
struct AndFF : public Vec::ElemwiseUnaryOP<T, T> {
    __aicore__ inline AndFF(const LocalTensor<T>& dst, const LocalTensor<T>& src, uint32_t count)
    {
#ifdef __CCE_AICORE__
        const T lowByteMask = static_cast<T>(FF);
        AscendC::Duplicate(dst, lowByteMask, count);
        AscendC::And(dst, src, dst, count);
#endif
    }
};
#ifdef __CCE_AICORE__
/*!
 * \brief SIMT kernel: element-wise dst[i] = src1[i] * src2[i], computed with integer arithmetic only.
 *
 * Every operand is moved through the `double` member of U and handled as two 32-bit words; i[1] is
 * treated as the word holding the sign, the 11-bit exponent and the upper 20 mantissa bits.
 * T is therefore expected to be double (NOTE(review): not enforced here - confirm at the call sites).
 *
 * Work distribution: a thread starts at threadIdx.x and advances by blockDim.x until `count`.
 *
 * Per element:
 *   1. NaN operand, or infinity * zero -> NaN (high word NANI, low word 0).
 *   2. infinity * anything else        -> infinity with sign(a) ^ sign(b).
 *   3. zero operand                    -> zero with sign(a) ^ sign(b).
 *   4. otherwise the two 53-bit significands are multiplied into a 128-bit product s3:s2:s1:s0,
 *      the product is normalised in 16-bit steps and packed as infinity (overflow), as a
 *      subnormal (exponent field 0), as a signed zero (underflow) or as a normal number.
 *
 * Rounding adds half of the last kept bit and does not inspect the bits below it, so an exact tie
 * is rounded away from zero rather than to even.
 *
 * \param dst   output buffer; indices [0, count) are written
 * \param src1  first operand buffer; indices [0, count) are read
 * \param src2  second operand buffer; indices [0, count) are read
 * \param count number of elements; values <= 0 result in no work
 */
template <typename T>
__simt_vf__ __aicore__
LAUNCH_BOUND(1024) inline void MulDouble_vf(__ubuf__ T* dst, __ubuf__ T* src1, __ubuf__ T* src2, int count)
{
    // Clamp once so that a negative count cannot be converted into a huge unsigned loop bound.
    const uint32_t total = (count > 0) ? static_cast<uint32_t>(count) : 0U;
    for (uint32_t index = static_cast<uint32_t>(threadIdx.x); index < total;
         index += static_cast<uint32_t>(blockDim.x)) {
        U a;
        U b;
        a.d = src1[index];
        b.d = src2[index];
        int sa = a.i[1] & SIGNMASK;
        int sb = b.i[1] & SIGNMASK;
        int ea = (a.i[1] & EXPMASK) >> EXP_OFFSET;
        int eb = (b.i[1] & EXPMASK) >> EXP_OFFSET;
        // Sign bit of every non-NaN result.
        const int sc = sa ^ sb;
        // Start from an all-zero pattern so that the early exits below never emit an
        // uninitialised low word.
        U c;
        c.u = 0;

        // ---- Classification of the operands ----
        const bool aIsNan = (ea == MAX_EXP) && ((a.i[1] & MMASK) > 0 || a.i[0] > 0);
        const bool bIsNan = (eb == MAX_EXP) && ((b.i[1] & MMASK) > 0 || b.i[0] > 0);
        const bool aIsInf = (ea == MAX_EXP) && !aIsNan;
        const bool bIsInf = (eb == MAX_EXP) && !bIsNan;
        const bool aIsZero = ((a.i[1] & INT32_MMAXMASK) == 0) && (a.i[0] == 0);
        const bool bIsZero = ((b.i[1] & INT32_MMAXMASK) == 0) && (b.i[0] == 0);

        // NaN in, or infinity * zero -> NaN.
        if (aIsNan || bIsNan || (aIsInf && bIsZero) || (bIsInf && aIsZero)) {
            c.i[1] = NANI;
            dst[index] = c.d;
            continue;
        }
        // infinity * (non-zero, non-NaN) -> signed infinity.
        if (aIsInf || bIsInf) {
            c.i[1] = static_cast<unsigned int>(sc) | static_cast<unsigned int>(EXPMASK);
            dst[index] = c.d;
            continue;
        }
        // zero * finite -> zero that keeps the product sign, consistent with the infinity path.
        if (aIsZero || bIsZero) {
            c.i[1] = static_cast<unsigned int>(sc);
            dst[index] = c.d;
            continue;
        }

        // ---- Unbiased exponents and 53-bit significands ----
        // Subnormal inputs (exponent field 0) have no implicit leading one and use exponent 1 - bias.
        bool aflag = (ea > 0);
        bool bflag = (eb > 0);
        ea = ea - EXP_BIAS + 1 - aflag;
        eb = eb - EXP_BIAS + 1 - bflag;
        unsigned int ma_high = (a.i[1] & MMASK) + (aflag << EXP_OFFSET);
        unsigned int ma_low = a.i[0];
        unsigned int mb_high = (b.i[1] & MMASK) + (bflag << EXP_OFFSET);
        unsigned int mb_low = b.i[0];

        // ---- 128-bit product built from four 32x32 partial products ----
        // The high half of each accumulator is carried into the low half of the next one.
        ll res[3];
        res[0].il = static_cast<unsigned long long>(ma_low) * static_cast<unsigned long long>(mb_low);
        res[1].i[0] = res[0].i[1];
        res[1].i[1] = 0;
        res[1].il += static_cast<unsigned long long>(ma_low) * static_cast<unsigned long long>(mb_high);
        res[1].il += static_cast<unsigned long long>(ma_high) * static_cast<unsigned long long>(mb_low);
        res[2].i[0] = res[1].i[1];
        res[2].i[1] = 0;
        res[2].il += static_cast<unsigned long long>(ma_high) * static_cast<unsigned long long>(mb_high);
        unsigned int s3 = res[2].i[1];
        unsigned int s2 = res[2].i[0];
        unsigned int s1 = res[1].i[0];
        unsigned int s0 = res[0].i[0];

        // ---- Normalisation: move the product up in 16-bit steps until s3 is non-zero ----
        int total_shift = 0;
        if (s3 != 0) {
            total_shift = 0;
        } else if ((s2 & SUBNORMAL_MASK) != 0) {
            total_shift = OFFSET_16;
        } else if (s2 != 0) {
            total_shift = OFFSET_32;
        } else if ((s1 & SUBNORMAL_MASK) != 0) {
            total_shift = OFFSET_32 + OFFSET_16;
        } else if (s1 != 0) {
            total_shift = OFFSET_64;
        } else if ((s0 & SUBNORMAL_MASK) != 0) {
            total_shift = OFFSET_64 + OFFSET_16;
        } else if (s0 != 0) {
            total_shift = OFFSET_64 + OFFSET_32;
        } else {
            // All-zero product; not expected here because zero operands were handled above.
            total_shift = 0;
            s3 = 0;
        }
        if (total_shift == OFFSET_16) {
            s3 = (s3 << OFFSET_16) | (s2 >> OFFSET_16);
            s2 = (s2 << OFFSET_16) | (s1 >> OFFSET_16);
            s1 = (s1 << OFFSET_16) | (s0 >> OFFSET_16);
            s0 = (s0 << OFFSET_16);
        } else if (total_shift == OFFSET_32) {
            s3 = s2;
            s2 = s1;
            s1 = s0;
            s0 = 0;
        } else if (total_shift == (OFFSET_32 + OFFSET_16)) {
            s3 = (s1 >> OFFSET_16);
            s2 = (s1 << OFFSET_16) | (s0 >> OFFSET_16);
            s1 = (s0 << OFFSET_16);
            s0 = 0;
        } else if (total_shift == OFFSET_64) {
            s3 = s1;
            s2 = s0;
            s1 = 0;
            s0 = 0;
        } else if (total_shift == (OFFSET_64 + OFFSET_16)) {
            s3 = (s0 >> OFFSET_16);
            s2 = (s0 << OFFSET_16);
            s1 = 0;
            s0 = 0;
        } else if (total_shift == (OFFSET_64 + OFFSET_32)) {
            s3 = s0;
            s2 = 0;
            s1 = 0;
            s0 = 0;
        }

        // ---- Result exponent ----
        // n is the index of the leading one inside s3; the normalisation shift is compensated by bias.
        int bias = -total_shift;
        int n = NOSIGN_OFFSET_32 - __builtin_clz(s3);
        unsigned int mc_high = 0, mc_low = 0;
        // Rounded contribution of the bits just below the kept 64-bit window; it is added to
        // mc_low with carry into mc_high after the branch below.
        unsigned int roundedTail = 0;
        int ec = ea + (n - BIT_ALIGN) + eb + EXP_BIAS + bias;
        if (ec >= MAX_EXP) {
            // Overflow: exponent all ones with a zero mantissa encodes infinity.
            ec = MAX_EXP;
        } else if (ec <= 0) {
            // ---- Subnormal result ----
            if (ec <= SUBN_BOUND_52) {
                // Too small even for the lowest mantissa bit: underflow to a signed zero.
                c.i[1] = static_cast<unsigned int>(sc);
                dst[index] = c.d;
                continue;
            }
            // Pre-shift right by 32 or 16 bits so that the remaining shift fits in one word.
            // n is deliberately left untouched: ec grows by the same amount, which keeps
            // lshift below consistent with the moved data.
            if (ec <= SUBN_BOUND_32) {
                s0 = s1;
                s1 = s2;
                s2 = s3;
                s3 = 0;
                ec += OFFSET_32;
            } else if (ec <= SUBN_BOUND_16) {
                s0 = (s0 >> OFFSET_16) | (s1 << OFFSET_16);
                s1 = (s1 >> OFFSET_16) | (s2 << OFFSET_16);
                s2 = (s2 >> OFFSET_16) | (s3 << OFFSET_16);
                s3 = s3 >> OFFSET_16;
                ec += OFFSET_16;
            }
            int lshift = MAX_SAFE_LSHIFT_20 + ec - n;
            if (lshift > 0) {
                mc_high = s3 << lshift;
                mc_high += s2 >> (OFFSET_32 - lshift);
                unsigned int roundvalue = 1 << (HIGH_BIT_OFFSET_32 - lshift);
                mc_low = (s2 << lshift);
                // s1 is halved first so that adding the rounding constant cannot wrap.
                roundedTail = ((s1 >> 1) + roundvalue) >> (NOSIGN_OFFSET_32 - lshift);
            } else if (lshift < 0) {
                int rshift = -lshift;
                mc_high = s3 >> rshift;
                mc_low = s3 << (OFFSET_32 - rshift);
                // Same value as (s2 + (1 << (rshift - 1))) >> rshift, but without 32-bit wrap-around.
                roundedTail = (s2 >> rshift) + ((s2 >> (rshift - 1)) & 1U);
            } else {
                mc_high = s3;
                mc_low = s2;
                roundedTail = ((s1 >> 1) + (1 << HIGH_BIT_OFFSET_32)) >> NOSIGN_OFFSET_32;
            }
            // Exponent field 0; a mantissa that rounds up to 2^52 carries into the exponent
            // field through the final addition and yields the smallest normal number.
            ec = 0;
        } else {
            // ---- Normal result: align the leading one with bit 20 and drop it (implicit bit) ----
            int lshift = EXP_OFFSET - n;
            if (lshift > 0) {
                mc_high = (s3 << lshift) - (1 << EXP_OFFSET);
                mc_high += s2 >> (OFFSET_32 - lshift);
                mc_low = s2 << lshift;
                unsigned int roundvalue = 1 << (HIGH_BIT_OFFSET_32 - lshift);
                roundedTail = (((s1 >> 1) + roundvalue) >> (NOSIGN_OFFSET_32 - lshift));
            } else if (lshift < 0) {
                int rshift = -lshift;
                mc_high = (s3 >> rshift) - (1 << EXP_OFFSET);
                mc_low = s3 << (OFFSET_32 - rshift);
                // Same value as (s2 + (1 << (rshift - 1))) >> rshift, but without 32-bit wrap-around.
                roundedTail = (s2 >> rshift) + ((s2 >> (rshift - 1)) & 1U);
            } else {
                mc_high = s3 - (1 << EXP_OFFSET);
                mc_low = s2;
                // The rounding increment is applied exactly once (by the shared code below).
                roundedTail = ((s1 >> 1) + (1 << HIGH_BIT_OFFSET_32)) >> NOSIGN_OFFSET_32;
            }
        }

        // ---- Rounding: add the tail to the low word and propagate the carry ----
        // mc_low + roundedTail wraps exactly when roundedTail > ~mc_low.
        mc_high += (~mc_low < roundedTail);
        mc_low += roundedTail;

        // ---- Pack sign, exponent and mantissa ----
        // A mantissa carry out of bit 19 intentionally increments the exponent field.
        c.i[1] = sc + (ec << EXP_OFFSET) + mc_high;
        c.i[0] = mc_low;
        dst[index] = c.d;
    }
}
#endif
/*!
 * \brief Element-wise binary op that multiplies two tensors by launching the MulDouble_vf kernel.
 *
 * The physical addresses of the three tensors are handed to MulDouble_vf<T>, launched through
 * asc_vf_call with dim3(1024), which matches the LAUNCH_BOUND(1024) declared on the kernel.
 *
 * \tparam T element type of all three tensors
 */
template <class T>
struct MulDouble : public Vec::ElemwiseBinaryOP<T, T, T> {
    __aicore__ inline MulDouble(LocalTensor<T>& dst, LocalTensor<T>& src1, LocalTensor<T>& src2, int count)
    {
#ifdef __CCE_AICORE__
        __ubuf__ T* outAddr = (__ubuf__ T*)dst.GetPhyAddr();
        __ubuf__ T* lhsAddr = (__ubuf__ T*)src1.GetPhyAddr();
        __ubuf__ T* rhsAddr = (__ubuf__ T*)src2.GetPhyAddr();
        asc_vf_call<MulDouble_vf<T>>(dim3(1024), outAddr, lhsAddr, rhsAddr, count);
#endif
    }
};
/*!
 * \brief DAG description of a multiply carried out directly in element type T.
 *
 * Out0 = Vec::Mul<T>(In0, In1), with no intermediate casts.
 *
 * \tparam T element type of both inputs and of the output
 */
template <typename T>
struct MulOp {
// Both inputs are read through Vec::CopyInBrc (presumably a broadcast-capable copy-in - confirm in atvoss).
using InputX1 = Bind<Vec::CopyInBrc<T>, Placeholder::In0<T>>;
using InputX2 = Bind<Vec::CopyInBrc<T>, Placeholder::In1<T>>;
// The multiplication itself.
using Y = Bind<Vec::Mul<T>, InputX1, InputX2>;
// The product is copied to output placeholder 0.
using OpCopyOut = Bind<Vec::CopyOut<T>, Placeholder::Out0<T>, Y>;
using Outputs = Elems<OpCopyOut>;
using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
// Scheduler type built from the output list and the memory configuration.
using OpDag = DAGSch<Outputs, void, MemCfg>;
};
/*!
 * \brief DAG description of a multiply that is evaluated in float and stored back as T.
 *
 * Out0 = Cast<T>(Cast<float>(In0) * Cast<float>(In1)); the widening casts use CAST_MODE_NONE and
 * the narrowing cast uses CAST_MODE_RINT. Judging by the name this targets 16-bit floating-point
 * types - confirm at the dispatch site.
 *
 * \tparam T element type of both inputs and of the output
 */
template <typename T>
struct MulXfp16Op {
using InputX1 = Bind<Vec::CopyInBrc<T>, Placeholder::In0<T>>;
using InputX2 = Bind<Vec::CopyInBrc<T>, Placeholder::In1<T>>;
// Widen both operands to float.
using CastX1 = Bind<Vec::Cast<float, T, CAST_MODE_NONE>, InputX1>;
using CastX2 = Bind<Vec::Cast<float, T, CAST_MODE_NONE>, InputX2>;
// Multiply in float.
using Y = Bind<Vec::Mul<float>, CastX1, CastX2>;
// Narrow the product back to T.
using YB16 = Bind<Vec::Cast<T, float, CAST_MODE_RINT>, Y>;
using OpCopyOut = Bind<Vec::CopyOut<T>, Placeholder::Out0<T>, YB16>;
using Outputs = Elems<OpCopyOut>;
using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
using OpDag = DAGSch<Outputs, void, MemCfg>;
};
/*!
 * \brief DAG description of a multiply that promotes T to PromoteT, multiplies, and converts back.
 *
 * Out0 = CastComplex64ToComplex32(Vec::Mul<PromoteT>(CastComplex32ToComplex64(In0),
 *                                                    CastComplex32ToComplex64(In1))).
 *
 * \tparam T        element type of both inputs and of the output
 * \tparam PromoteT element type in which the multiplication is evaluated
 */
template <typename T, typename PromoteT>
struct MulComplex32Op {
using InputX1 = Bind<Vec::CopyInBrc<T>, Placeholder::In0<T>>;
using InputX2 = Bind<Vec::CopyInBrc<T>, Placeholder::In1<T>>;
// Promote both operands (half -> float on the underlying storage).
using CastX1 = Bind<CastComplex32ToComplex64<PromoteT, T>, InputX1>;
using CastX2 = Bind<CastComplex32ToComplex64<PromoteT, T>, InputX2>;
// Multiply in the promoted type.
using Y = Bind<Vec::Mul<PromoteT>, CastX1, CastX2>;
// Convert the product back (float -> half on the underlying storage, CAST_RINT).
using YCast = Bind<CastComplex64ToComplex32<T, PromoteT>, Y>;
using OpCopyOut = Bind<Vec::CopyOut<T>, Placeholder::Out0<T>, YCast>;
using Outputs = Elems<OpCopyOut>;
using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
using OpDag = DAGSch<Outputs, void, MemCfg>;
};
/*!
 * \brief DAG description of a multiply whose two inputs have different element types.
 *
 * Out0 = Cast<PromoteT>(In0) * Cast<PromoteT>(In1); the result is written as PromoteT
 * (there is no cast back to either input type).
 *
 * \tparam T1       element type of input 0
 * \tparam T2       element type of input 1
 * \tparam PromoteT element type of the multiplication and of the output
 */
template <typename T1, typename T2, typename PromoteT>
struct MulMixFpOp {
using InputX1 = Bind<Vec::CopyInBrc<T1>, Placeholder::In0<T1>>;
using InputX2 = Bind<Vec::CopyInBrc<T2>, Placeholder::In1<T2>>;
// Bring both operands to the common type.
using CastX1 = Bind<Vec::Cast<PromoteT, T1, CAST_MODE_NONE>, InputX1>;
using CastX2 = Bind<Vec::Cast<PromoteT, T2, CAST_MODE_NONE>, InputX2>;
using Y = Bind<Vec::Mul<PromoteT>, CastX1, CastX2>;
using OpCopyOut = Bind<Vec::CopyOut<PromoteT>, Placeholder::Out0<PromoteT>, Y>;
using Outputs = Elems<OpCopyOut>;
using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
using OpDag = DAGSch<Outputs, void, MemCfg>;
};
/*!
 * \brief DAG description of an int8_t multiply evaluated in int32_t.
 *
 * Both inputs are cast to int32_t and multiplied; the product is masked with FF (AndFF), cast to
 * uint8_t, and copied out through an int8_t CopyOut. Presumably the mask plus the uint8_t cast
 * produce the wrapped low byte, which is then stored as int8_t - confirm against the expected
 * overflow behaviour.
 */
struct MulInt8Op {
using InputX1 = Bind<Vec::CopyInBrc<int8_t>, Placeholder::In0<int8_t>>;
using InputX2 = Bind<Vec::CopyInBrc<int8_t>, Placeholder::In1<int8_t>>;
// Widen to int32_t.
using CastX1 = Bind<Vec::Cast<int32_t, int8_t, CAST_MODE_NONE>, InputX1>;
using CastX2 = Bind<Vec::Cast<int32_t, int8_t, CAST_MODE_NONE>, InputX2>;
using Y = Bind<Vec::Mul<int32_t>, CastX1, CastX2>;
// Keep only the low 8 bits of the product.
using Y1 = Bind<AndFF<int32_t>, Y>;
// Narrow to a byte; note the node type is uint8_t while the copy-out below is int8_t.
using Y2 = Bind<Vec::Cast<uint8_t, int32_t, CAST_MODE_NONE>, Y1>;
using OpCopyOut = Bind<Vec::CopyOut<int8_t>, Placeholder::Out0<int8_t>, Y2>;
using Outputs = Elems<OpCopyOut>;
using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
using OpDag = DAGSch<Outputs, void, MemCfg>;
};
/*!
 * \brief DAG description of a uint8_t multiply evaluated in uint32_t.
 *
 * Both inputs are cast to uint32_t and multiplied; the product is masked with FF (AndFF), cast to
 * uint8_t and copied out.
 */
struct MulUint8Op {
using InputX1 = Bind<Vec::CopyInBrc<uint8_t>, Placeholder::In0<uint8_t>>;
using InputX2 = Bind<Vec::CopyInBrc<uint8_t>, Placeholder::In1<uint8_t>>;
// Widen to uint32_t.
using CastX1 = Bind<Vec::Cast<uint32_t, uint8_t, CAST_MODE_NONE>, InputX1>;
using CastX2 = Bind<Vec::Cast<uint32_t, uint8_t, CAST_MODE_NONE>, InputX2>;
using Y = Bind<Vec::Mul<uint32_t>, CastX1, CastX2>;
// Keep only the low 8 bits of the product.
// NOTE(review): AndFF is instantiated with int32_t although the neighbouring nodes use uint32_t -
// confirm this is intentional (e.g. because of the types supported by AscendC::And).
using Y1 = Bind<AndFF<int32_t>, Y>;
using Y2 = Bind<Vec::Cast<uint8_t, uint32_t, CAST_MODE_NONE>, Y1>;
using OpCopyOut = Bind<Vec::CopyOut<uint8_t>, Placeholder::Out0<uint8_t>, Y2>;
using Outputs = Elems<OpCopyOut>;
using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
using OpDag = DAGSch<Outputs, void, MemCfg>;
};
/*!
 * \brief DAG description of a multiply on int8_t storage evaluated in half.
 *
 * Both inputs are cast to half and multiplied; the product is cast back to int8_t with
 * CAST_MODE_ROUND. Judging by the name the int8_t data carries boolean values - confirm at the
 * dispatch site.
 */
struct MulBoolOp {
using InputX1 = Bind<Vec::CopyInBrc<int8_t>, Placeholder::In0<int8_t>>;
using InputX2 = Bind<Vec::CopyInBrc<int8_t>, Placeholder::In1<int8_t>>;
// Widen to half.
using CastX1 = Bind<Vec::Cast<half, int8_t, CAST_MODE_NONE>, InputX1>;
using CastX2 = Bind<Vec::Cast<half, int8_t, CAST_MODE_NONE>, InputX2>;
using Y = Bind<Vec::Mul<half>, CastX1, CastX2>;
// Back to int8_t.
using YCast = Bind<Vec::Cast<int8_t, half, CAST_MODE_ROUND>, Y>;
using OpCopyOut = Bind<Vec::CopyOut<int8_t>, Placeholder::Out0<int8_t>, YCast>;
using Outputs = Elems<OpCopyOut>;
using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
using OpDag = DAGSch<Outputs, void, MemCfg>;
};
/*!
 * \brief DAG description of a multiply that uses the MulDouble op instead of Vec::Mul.
 *
 * Out0 = MulDouble<T>(In0, In1); MulDouble launches the integer-arithmetic kernel MulDouble_vf.
 *
 * \tparam T element type of both inputs and of the output (the kernel handles values as double)
 */
template <typename T>
struct MulDoubleOp {
using InputX1 = Bind<Vec::CopyInBrc<T>, Placeholder::In0<T>>;
using InputX2 = Bind<Vec::CopyInBrc<T>, Placeholder::In1<T>>;
// Software double-precision multiply.
using Y = Bind<MulDouble<T>, InputX1, InputX2>;
using OpCopyOut = Bind<Vec::CopyOut<T>, Placeholder::Out0<T>, Y>;
using Outputs = Elems<OpCopyOut>;
using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
using OpDag = DAGSch<Outputs, void, MemCfg>;
};
}
#endif