/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file common.h
 * \brief
 */

#ifndef MC2_MOE_DISPATCH_COMM_H
#define MC2_MOE_DISPATCH_COMM_H

namespace Mc2Kernel{
constexpr uint32_t NEED_ONE_HUNDRED_AND_TWENTY_SEVEN = 127;
constexpr uint32_t RIGHT_SHIFT_BIT_SEVEN = 7;
constexpr uint32_t NEED_THIRTY_FIRST = 31;
constexpr uint32_t ALIGN_UP_TO_2_MASK = 1;
constexpr uint32_t ALIGN_UP_TO_32_MASK = 31;
constexpr uint32_t ALIGN_UP_TO_64_MASK = 64;
constexpr uint32_t ALIGN_UP_TO_128_MASK = 127;
constexpr uint32_t ALIGN_UP_TO_256_MASK = 255;
constexpr uint32_t ALIGN_UP_TO_512_MASK = 511;
constexpr uint32_t RIGHT_SHIFT_BIT_FIVE = 5;
constexpr uint32_t FIVE_HUNDRED_AND_ELEVEN = 511;
constexpr uint32_t RIGHT_SHIFT_BIT_NINE = 9;
}

namespace AscendC {
template <typename T1, typename T2>
__aicore__ inline T2 Ceil(T1 x, T1 y)
{
    return (x + y - 1) / y;
}

template <typename T>
__aicore__ inline T Ceil32(T x)
{
    return (x + Mc2Kernel::NEED_THIRTY_FIRST) >> Mc2Kernel::RIGHT_SHIFT_BIT_FIVE;
}

template <typename T>
__aicore__ inline T Ceil128(T x)
{
    return (x + Mc2Kernel::NEED_ONE_HUNDRED_AND_TWENTY_SEVEN) >> Mc2Kernel::RIGHT_SHIFT_BIT_SEVEN;
}

template <typename T>
__aicore__ inline T Ceil512(T x)
{
    return (x + Mc2Kernel::FIVE_HUNDRED_AND_ELEVEN) >> Mc2Kernel::RIGHT_SHIFT_BIT_NINE;
}

template <typename T1, typename T2>
__aicore__ inline T2 Align(T1 x, T1 y)
{
    return Ceil<T1, T2>(x, y) * y;
}

template <typename T>
__aicore__ inline T Align2(T x)
{
    return (x + Mc2Kernel::ALIGN_UP_TO_2_MASK) & (~Mc2Kernel::ALIGN_UP_TO_2_MASK);
}

template <typename T>
__aicore__ inline T Align32(T x)
{
    return (x + Mc2Kernel::ALIGN_UP_TO_32_MASK) & (~Mc2Kernel::ALIGN_UP_TO_32_MASK);
}

template <typename T>
__aicore__ inline T Align64(T x)
{
    return (x + Mc2Kernel::ALIGN_UP_TO_64_MASK) & (~Mc2Kernel::ALIGN_UP_TO_64_MASK);
}

template <typename T>
__aicore__ inline T Align128(T x)
{
    return (x + Mc2Kernel::ALIGN_UP_TO_128_MASK) & (~Mc2Kernel::ALIGN_UP_TO_128_MASK);
}

template <typename T>
__aicore__ inline T Align256(T x)
{
    return (x + Mc2Kernel::ALIGN_UP_TO_256_MASK) & (~Mc2Kernel::ALIGN_UP_TO_256_MASK);
}

template <typename T>
__aicore__ inline T Align512(T x)
{
    return (x + Mc2Kernel::ALIGN_UP_TO_512_MASK) & (~Mc2Kernel::ALIGN_UP_TO_512_MASK);
}

template <MicroAPI::HistogramsType htype, typename T, typename U>
static __aicore__ inline void HistogramsVf(__local_mem__ U* dst, __local_mem__ T* src, uint16_t repeatElm,
                                           uint16_t halfRepeat, uint32_t totalElm, uint16_t repeatTimes)
{
    AscendC::MicroAPI::RegTensor<T> srcReg;
    AscendC::MicroAPI::RegTensor<U> dst0Reg;
    AscendC::MicroAPI::RegTensor<U> dst1Reg;
    AscendC::MicroAPI::MaskReg pregOut = AscendC::MicroAPI::CreateMask<T>();
    MicroAPI::Duplicate(dst0Reg, 0);
    MicroAPI::Duplicate(dst1Reg, 0);
    for (uint16_t i = 0; i < repeatTimes; ++i) {
        MicroAPI::MaskReg preg = MicroAPI::UpdateMask<T>(totalElm);
        MicroAPI::DataCopy(srcReg, src + repeatElm * i);
        MicroAPI::Histograms<T, U, MicroAPI::HistogramsBinType::BIN0, htype>(dst0Reg, srcReg, preg);
        MicroAPI::Histograms<T, U, MicroAPI::HistogramsBinType::BIN1, htype>(dst1Reg, srcReg, preg);
    }
    MicroAPI::DataCopy(dst, dst0Reg, pregOut);
    MicroAPI::DataCopy(dst + halfRepeat, dst1Reg, pregOut);
}

__aicore__ inline void GetExpertFreq(LocalTensor<uint16_t>& dstLocal, LocalTensor<uint8_t>& srcLocal, uint32_t totalElm)
{
    uint32_t repeatElm = GetVecLen();
    uint16_t repeatTimes = Ceil<uint32_t, uint16_t>(totalElm, repeatElm);
    __local_mem__ uint8_t* src = (__local_mem__ uint8_t*)srcLocal.GetPhyAddr();
    __local_mem__ uint16_t* dst = (__local_mem__ uint16_t*)dstLocal.GetPhyAddr();
    VF_CALL<HistogramsVf<MicroAPI::HistogramsType::FREQUENCY, uint8_t, uint16_t>>(dst, src, repeatElm, repeatElm >> 1,
                                                                                  totalElm, repeatTimes);
    PipeBarrier<PIPE_V>();
}

__aicore__ inline void GetExpertCumSum(LocalTensor<uint16_t>& dstLocal, LocalTensor<uint8_t>& srcLocal,
                                       uint32_t totalElm)
{
    uint32_t repeatElm = GetVecLen();
    uint16_t repeatTimes = Ceil<uint32_t, uint16_t>(totalElm, repeatElm);
    __local_mem__ uint8_t *src = (__local_mem__ uint8_t *)srcLocal.GetPhyAddr();
    __local_mem__ uint16_t *dst = (__local_mem__ uint16_t *)dstLocal.GetPhyAddr();
    VF_CALL<HistogramsVf<MicroAPI::HistogramsType::ACCUMULATE, uint8_t, uint16_t>>(dst, src, repeatElm, repeatElm >> 1,
                                                                                   totalElm, repeatTimes);
    PipeBarrier<PIPE_V>();
}

static __aicore__ inline void ReduceLoop(__local_mem__ int32_t* dst, __local_mem__ int32_t* src, uint16_t repeat0Times,
                                         uint32_t repeat0SrcStride, uint16_t repeat1Times, uint32_t repeat1Stride,
                                         uint32_t repeat1Element, uint32_t repeat0DstStride)
{
    MicroAPI::MaskReg maskFirst = MicroAPI::CreateMask<int32_t, MicroAPI::MaskPattern::VL1>();
    for (uint16_t i0 = 0; i0 < repeat0Times; ++i0) {
        MicroAPI::RegTensor<int32_t> sumReg;
        uint32_t elements = repeat1Element;
        MicroAPI::Duplicate(sumReg, 0);
        for (uint16_t i1 = 0; i1 < repeat1Times; ++i1) {
            MicroAPI::RegTensor<int32_t> srcReg;
            MicroAPI::RegTensor<int32_t> dstReg;
            MicroAPI::MaskReg mask = MicroAPI::UpdateMask<int32_t>(elements);
            MicroAPI::DataCopy(srcReg, src + i0 * repeat0SrcStride + i1 * repeat1Stride);
            MicroAPI::ReduceSum(dstReg, srcReg, mask);
            MicroAPI::Add(sumReg, sumReg, dstReg, maskFirst);
        }
        MicroAPI::DataCopy<int32_t, MicroAPI::StoreDist::DIST_FIRST_ELEMENT_B32>(dst + i0 * repeat0DstStride, sumReg,
                                                                                 maskFirst);
    }
}

__aicore__ inline void GetReduceSum(LocalTensor<int32_t>& dstLocal, LocalTensor<int32_t>& srcLocal,
                                    uint16_t repeat0Times, uint32_t repeat0SrcStride, uint32_t repeat1Element,
                                    uint32_t repeat0DstStride)
{
    if (repeat0Times == 0 || repeat1Element == 0) {
        return;
    }

    uint32_t repeat1Stride = GetVecLen() / sizeof(int32_t);
    uint16_t repeat1Times = Ceil<uint32_t, uint16_t>(repeat1Element, repeat1Stride);
    __local_mem__ int32_t* src = (__local_mem__ int32_t*)srcLocal.GetPhyAddr();
    __local_mem__ int32_t* dst = (__local_mem__ int32_t*)dstLocal.GetPhyAddr();
    VF_CALL<ReduceLoop>(dst, src, repeat0Times, repeat0SrcStride, repeat1Times, repeat1Stride, repeat1Element,
                        repeat0DstStride);
    PipeBarrier<PIPE_V>();
}

}  // namespace AscendC

#endif