/*
 * Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (C) 2026. Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DEVICE_UTILS_H
#define DEVICE_UTILS_H
#include "asc_fp16.h"
#include "asc_bf16.h"
// Function qualifiers shared by every SIMT-side TypeConvertFunc::convert.
// The preprocessor ignores namespaces, so the macro is defined ahead of
// `dyn_emb` rather than inside it; it is never #undef'd and therefore stays
// visible to every file that includes this header.
#define DEVICE_INLINE __simt_callee__ __forceinline__

namespace dyn_emb {

/**
 * @brief Scalar value conversion for the SIMD (`__aicore__`) code path.
 *
 * @tparam TOUT destination type
 * @tparam TIN  source type
 *
 * The primary template is declared but never defined, so only the
 * (TOUT, TIN) pairs specialized below are usable; any other pair is a
 * compile error.
 */
template <typename TOUT, typename TIN>
struct SimdTypeConvertFunc;

// float -> bfloat16_t, delegated to AscendC::ToBfloat16.
template <>
struct SimdTypeConvertFunc<bfloat16_t, float> {
    __aicore__ __inline__ static bfloat16_t convert(float src)
    {
        return AscendC::ToBfloat16(src);
    }
};

// float -> float: returns the argument unchanged.
template <>
struct SimdTypeConvertFunc<float, float> {
    __aicore__ __inline__ static float convert(float src)
    {
        return src;
    }
};

// float -> half via AscendC::ScalarCast with RoundMode::CAST_ODD.
// NOTE(review): the SIMT counterpart TypeConvertFunc<half, float> uses
// __float2half instead; confirm whether both paths are expected to round
// identically.
template <>
struct SimdTypeConvertFunc<half, float> {
    __aicore__ __inline__ static half convert(float src)
    {
        return AscendC::ScalarCast<float, half, AscendC::RoundMode::CAST_ODD>(src);
    }
};

/**
 * @brief Scalar value conversion for the SIMT (`DEVICE_INLINE`) code path.
 *
 * @tparam TOUT destination type
 * @tparam TIN  source type
 *
 * Like SimdTypeConvertFunc, the primary template has no definition; only
 * the specializations below exist.
 */
template <typename TOUT, typename TIN>
struct TypeConvertFunc;

// ---- float -> 16-bit floating point -----------------------------------

template <>
struct TypeConvertFunc<half, float> {
    static DEVICE_INLINE half convert(float src)
    {
        return __float2half(src);
    }
};

template <>
struct TypeConvertFunc<bfloat16_t, float> {
    static DEVICE_INLINE bfloat16_t convert(float src)
    {
        return __float2bfloat16(src);
    }
};

// ---- 16-bit floating point -> float -----------------------------------

template <>
struct TypeConvertFunc<float, half> {
    static DEVICE_INLINE float convert(half src)
    {
        return __half2float(src);
    }
};

template <>
struct TypeConvertFunc<float, bfloat16_t> {
    static DEVICE_INLINE float convert(bfloat16_t src)
    {
        return __bfloat162float(src);
    }
};

// ---- half <-> bfloat16_t: widened to float first, then narrowed --------

template <>
struct TypeConvertFunc<bfloat16_t, half> {
    static DEVICE_INLINE bfloat16_t convert(half src)
    {
        return __float2bfloat16(__half2float(src));
    }
};

template <>
struct TypeConvertFunc<half, bfloat16_t> {
    static DEVICE_INLINE half convert(bfloat16_t src)
    {
        return __float2half(__bfloat162float(src));
    }
};

// ---- identity: argument returned unchanged -----------------------------

template <>
struct TypeConvertFunc<float, float> {
    static DEVICE_INLINE float convert(float src)
    {
        return src;
    }
};

template <>
struct TypeConvertFunc<half, half> {
    static DEVICE_INLINE half convert(half src)
    {
        return src;
    }
};

template <>
struct TypeConvertFunc<bfloat16_t, bfloat16_t> {
    static DEVICE_INLINE bfloat16_t convert(bfloat16_t src)
    {
        return src;
    }
};

// ---- integer sources: plain static_cast, no range checking -------------

template <>
struct TypeConvertFunc<float, long long> {
    static DEVICE_INLINE float convert(long long src)
    {
        return static_cast<float>(src);
    }
};

template <>
struct TypeConvertFunc<float, unsigned int> {
    static DEVICE_INLINE float convert(unsigned int src)
    {
        return static_cast<float>(src);
    }
};

template <>
struct TypeConvertFunc<int, long long> {
    static DEVICE_INLINE int convert(long long src)
    {
        return static_cast<int>(src);
    }
};

template <>
struct TypeConvertFunc<int, unsigned int> {
    static DEVICE_INLINE int convert(unsigned int src)
    {
        return static_cast<int>(src);
    }
};

}  // namespace dyn_emb
#endif