/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
 * \file muls_dag.h
 * \brief DAG definitions for the muls operator: helper element-wise ops and the per-dtype DAG schedules.
 */
#ifndef MULS_DAG_H
#define MULS_DAG_H
#include "atvoss/util/dag.h"
#include "atvoss/util/vec.h"
#include "atvoss/util/placeholder.h"
namespace MulsDag
{
// Make the names declared in Ops::Base usable without qualification inside MulsDag.
// NOTE(review): Bind, Vec, Placeholder, Elems, MemOptCfg, MemLevel and DAGSch used below presumably
// resolve through this directive -- confirm against the atvoss headers.
using namespace Ops::Base;
// Integer selectors intended for the third template argument of Vec::Cast (see MulsInt16Op / MulsOp).
// Only CAST_MODE_NONE is referenced inside this header; the others are available to includers.
// NOTE(review): the values look like they mirror AscendC::RoundMode enumerators -- confirm.
constexpr int CAST_MODE_NONE{0};
constexpr int CAST_MODE_RINT{1};
constexpr int CAST_MODE_ROUND{4};
constexpr int CAST_MODE_TRUNC{5};

// Factor applied to the element count by the complex helpers before they call AscendC::Cast / AscendC::Muls
// on a float/half view of the tensors.
// NOTE(review): presumably one real and one imaginary lane per complex element -- confirm.
constexpr uint32_t COUNT_DOUBLE{2};

// Bit index passed as both template arguments of AscendC::SetCtrlSpr in CastFp32ToInt16.
constexpr int8_t REG_BIT_59{59};
/**
 * Element-wise unary op that converts a T2 tensor into a T1 tensor by working on scalar views:
 * the source is reinterpreted as half, the destination as float, and AscendC::Cast is applied to
 * count * COUNT_DOUBLE lanes with RoundMode::CAST_NONE.
 * NOTE(review): presumably T2 is a complex type built from two halfs and T1 one built from two floats,
 * which would explain the doubled lane count -- confirm against the instantiations.
 *
 * @tparam T1 element type of the destination tensor
 * @tparam T2 element type of the source tensor
 */
template <class T1, class T2>
struct CastComplex32ToComplex64 : public Vec::ElemwiseUnaryOP<T1, T2> {
    /**
     * @param dst   destination tensor, written through a float view
     * @param src   source tensor, read through a half view
     * @param count number of elements; multiplied by COUNT_DOUBLE to obtain the lane count
     *
     * The body is compiled only when __CCE_AICORE__ is defined; otherwise the constructor is empty.
     */
    __aicore__ inline CastComplex32ToComplex64(LocalTensor<T1>& dst, LocalTensor<T2>& src, uint32_t count)
    {
#ifdef __CCE_AICORE__
        const auto laneCount = count * COUNT_DOUBLE;
        const auto dstAsFloat = dst.template ReinterpretCast<float>();
        const auto srcAsHalf = src.template ReinterpretCast<half>();
        AscendC::Cast(dstAsFloat, srcAsHalf, RoundMode::CAST_NONE, laneCount);
#endif
    }
};
/**
 * Element-wise unary op that converts a T2 tensor into a T1 tensor by working on scalar views:
 * the source is reinterpreted as float, the destination as half, and AscendC::Cast is applied to
 * count * COUNT_DOUBLE lanes with RoundMode::CAST_RINT.
 * NOTE(review): presumably T2 is a complex type built from two floats and T1 one built from two halfs,
 * which would explain the doubled lane count -- confirm against the instantiations.
 *
 * @tparam T1 element type of the destination tensor
 * @tparam T2 element type of the source tensor
 */
template <class T1, class T2>
struct CastComplex64ToComplex32 : public Vec::ElemwiseUnaryOP<T1, T2> {
    /**
     * @param dst   destination tensor, written through a half view
     * @param src   source tensor, read through a float view
     * @param count number of elements; multiplied by COUNT_DOUBLE to obtain the lane count
     *
     * The body is compiled only when __CCE_AICORE__ is defined; otherwise the constructor is empty.
     */
    __aicore__ inline CastComplex64ToComplex32(LocalTensor<T1>& dst, LocalTensor<T2>& src, uint32_t count)
    {
#ifdef __CCE_AICORE__
        const auto laneCount = count * COUNT_DOUBLE;
        const auto dstAsHalf = dst.template ReinterpretCast<half>();
        const auto srcAsFloat = src.template ReinterpretCast<float>();
        AscendC::Cast(dstAsHalf, srcAsFloat, RoundMode::CAST_RINT, laneCount);
#endif
    }
};
/**
 * Element-wise binary op (tensor, scalar) that multiplies a T1 tensor by a scalar of type T2.
 * Both tensors are reinterpreted as float, the scalar is converted to float, and AscendC::Muls is
 * applied to count * COUNT_DOUBLE lanes.
 * NOTE(review): presumably T1 is a complex type built from two floats, so scaling every float lane
 * scales the real and imaginary parts alike -- confirm against the instantiations.
 *
 * @tparam T1 element type of the source and destination tensors
 * @tparam T2 type of the scalar operand
 */
template <class T1, class T2>
struct MulsComplex64Float : public Vec::ElemwiseBinaryOP<T1, T1, T2> {
    /**
     * @param dst    destination tensor, written through a float view
     * @param src    source tensor, read through a float view
     * @param scalar multiplier; converted with static_cast<float>
     * @param count  number of elements; multiplied by COUNT_DOUBLE to obtain the lane count
     *
     * The body is compiled only when __CCE_AICORE__ is defined; otherwise the constructor is empty.
     */
    __aicore__ inline MulsComplex64Float(const LocalTensor<T1>& dst, const LocalTensor<T1>& src, const T2& scalar,
                                         int count)
    {
#ifdef __CCE_AICORE__
        const float factor = static_cast<float>(scalar);
        // int * uint32_t: keep the promoted type of the original expression.
        const auto laneCount = count * COUNT_DOUBLE;
        const auto dstAsFloat = dst.template ReinterpretCast<float>();
        const auto srcAsFloat = src.template ReinterpretCast<float>();
        AscendC::Muls(dstAsFloat, srcAsFloat, factor, laneCount);
#endif
    }
};
/**
 * Element-wise unary op that casts src to dst with RoundMode::CAST_TRUNC.
 * The cast is bracketed by two AscendC::SetCtrlSpr<REG_BIT_59, REG_BIT_59> calls: value 1 before
 * the cast and value 0 after it, so the three statements must stay in this order.
 * NOTE(review): presumably bit 59 of the control register changes how the float -> int16 conversion
 * treats out-of-range values -- confirm against the Ascend C documentation.
 * NOTE(review): the base class is fixed to <int16_t, float> regardless of T1/T2; the only
 * instantiation in this header is CastFp32ToInt16<int16_t, float>.
 *
 * @tparam T1 element type of the destination tensor
 * @tparam T2 element type of the source tensor
 */
template <class T1, class T2>
struct CastFp32ToInt16 : public Vec::ElemwiseUnaryOP<int16_t, float> {
/**
 * @param dst   destination tensor
 * @param src   source tensor
 * @param count element count forwarded unchanged to AscendC::Cast
 *
 * The body is compiled only when __CCE_AICORE__ is defined; otherwise the constructor is empty.
 */
__aicore__ inline CastFp32ToInt16(LocalTensor<T1>& dst, LocalTensor<T2>& src, uint32_t count)
{
#ifdef __CCE_AICORE__
// Write 1 to the REG_BIT_59 field for the duration of the cast.
AscendC::SetCtrlSpr<REG_BIT_59, REG_BIT_59>(1);
AscendC::Cast(dst, src, RoundMode::CAST_TRUNC, count);
// Write 0 back to the same field once the cast has been issued.
AscendC::SetCtrlSpr<REG_BIT_59, REG_BIT_59>(0);
#endif
}
};
struct MulsInt16Op {
using InputX = Bind<Vec::CopyIn<int16_t>, Placeholder::In0<int16_t>>;
using CastX = Bind<Vec::Cast<float, int16_t, CAST_MODE_NONE>, InputX>;
using Y = Bind<Vec::Muls<float>, CastX, Placeholder::Var<float, 0>>;
using CastY = Bind<CastFp32ToInt16<int16_t, float>, Y>;
using OpCopyOut = Bind<Vec::CopyOut<int16_t>, Placeholder::Out0<int16_t>, CastY>;
using Outputs = Elems<OpCopyOut>;
using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
using OpDag = DAGSch<Outputs, void, MemCfg>;
};
template <typename T, int castMode1, int castMode2>
struct MulsOp {
using InputX = Bind<Vec::CopyIn<T>, Placeholder::In0<T>>;
using CastX = Bind<Vec::Cast<float, T, castMode1>, InputX>;
using Y = Bind<Vec::Muls<float>, CastX, Placeholder::Var<float, 0>>;
using CastY = Bind<Vec::Cast<T, float, castMode2>, Y>;
using OpCopyOut = Bind<Vec::CopyOut<T>, Placeholder::Out0<T>, CastY>;
using Outputs = Elems<OpCopyOut>;
using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
using OpDag = DAGSch<Outputs, void, MemCfg>;
};
template <typename T, typename PromoteT>
struct MulsComplex32Op {
using InputX = Bind<Vec::CopyIn<T>, Placeholder::In0<T>>;
using CastX = Bind<CastComplex32ToComplex64<PromoteT, T>, InputX>;
using Y = Bind<MulsComplex64Float<PromoteT, float>, CastX, Placeholder::Var<float, 0>>;
using CastY = Bind<CastComplex64ToComplex32<T, PromoteT>, Y>;
using OpCopyOut = Bind<Vec::CopyOut<T>, Placeholder::Out0<T>, CastY>;
using Outputs = Elems<OpCopyOut>;
using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
using OpDag = DAGSch<Outputs, void, MemCfg>;
};
template <typename T>
struct MulsComplex64Op {
using InputX = Bind<Vec::CopyIn<T>, Placeholder::In0<T>>;
using Y = Bind<MulsComplex64Float<T, float>, InputX, Placeholder::Var<float, 0>>;
using OpCopyOut = Bind<Vec::CopyOut<T>, Placeholder::Out0<T>, Y>;
using Outputs = Elems<OpCopyOut>;
using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
using OpDag = DAGSch<Outputs, void, MemCfg>;
};
}
#endif