/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file gcd_dag.h
 * \brief gcd dag
 */

#ifndef GCD_DAG_H
#define GCD_DAG_H

#include "atvoss/util/dag.h"
#include "atvoss/util/vec.h"
#include "atvoss/util/placeholder.h"

#ifdef __CCE_AICORE__
#include "simt_api/asc_simt.h"
#endif

using namespace Ops::Base;
using namespace AscendC;

namespace GcdOp {

union U {
    uint32_t i[2];
    uint64_t u;
};

#ifdef __CCE_AICORE__
template <typename T>
__simt_vf__ __aicore__
    LAUNCH_BOUND(512) inline void GcdVecInt64(__ubuf__ T* dst, __ubuf__ T* src1, __ubuf__ T* src2, int count)
{
    for (uint32_t index = static_cast<uint32_t>(threadIdx.x); index < count;
         index += static_cast<uint32_t>(blockDim.x)) {
        U a;
        U b;
        U c;
        a.u = static_cast<uint64_t>(src1[index]);
        b.u = static_cast<uint64_t>(src2[index]);

        uint64_t mask;
        mask = static_cast<uint64_t>(src1[index] >> 63);
        a.u = (a.u ^ mask) - mask;

        mask = static_cast<uint64_t>(src2[index] >> 63);
        b.u = (b.u ^ mask) - mask;

        if (a.u == 0) {
            dst[index] = b.u;
            continue;
        }
        if (b.u == 0) {
            dst[index] = a.u;
            continue;
        }

        uint8_t offset;

        offset = 0;
        c.u = a.u | b.u;
        if (c.i[0] != 0) {
            offset = __builtin_ffs(c.i[0]);
        } else if (c.i[1] != 0) {
            offset = __builtin_ffs(c.i[1]) + 32;
        }
        uint8_t shift = offset - 1;

        offset = 0;
        if (a.i[0] != 0) {
            offset = __builtin_ffs(a.i[0]);
        } else if (a.i[1] != 0) {
            offset = __builtin_ffs(a.i[1]) + 32;
        }
        a.u >>= offset - 1;

        offset = 0;
        if (b.i[0] != 0) {
            offset = __builtin_ffs(b.i[0]);
        } else if (b.i[1] != 0) {
            offset = __builtin_ffs(b.i[1]) + 32;
        }
        b.u >>= offset - 1;

        while (b.u != 0) {
            offset = 0;
            if (b.i[0] != 0) {
                offset = __builtin_ffs(b.i[0]);
            } else if (b.i[1] != 0) {
            offset = __builtin_ffs(b.i[1]) + 32;
            }
            b.u >>= offset - 1;
            if (a.u > b.u) {
                uint64_t temp = a.u;
                a.u = b.u;
                b.u = temp;
            }
            b.u = b.u - a.u;
        }
        dst[index] = a.u << shift;
    }
}
#endif

#ifdef __CCE_AICORE__
template <typename T1, typename T2>
__simt_vf__ __aicore__
    LAUNCH_BOUND(1024) inline void GcdVec(__ubuf__ T1* dst, __ubuf__ T1* src1, __ubuf__ T1* src2, uint8_t attr, int count)
{
    for (uint32_t index = static_cast<uint32_t>(threadIdx.x); index < count;
         index += static_cast<uint32_t>(blockDim.x)) {
        T2 a = static_cast<T2>(src1[index]);
        T2 b = static_cast<T2>(src2[index]);
        if constexpr (!IsSameType<T1, T2>::value) {
            T2 mask;
            mask = static_cast<T2>(src1[index] >> attr);
            a = (a ^ mask) - mask;

            mask = static_cast<T2>(src2[index] >> attr);
            b = (b ^ mask) - mask;
        }

        if (a == 0) {
            dst[index] = b;
            continue;
        }
        if (b == 0) {
            dst[index] = a;
            continue;
        }

        uint8_t shift = __builtin_ffs(a | b) - 1;
        a >>= __builtin_ffs(a) - 1;
        b >>= __builtin_ffs(b) - 1;
        while (b != 0) {
            b >>= __builtin_ffs(b) - 1;
            if (a > b) {
                T2 temp = a;
                a = b;
                b = temp;
            }
            b = b - a;
        }
        dst[index] = a << shift;
    }
}
#endif

template <class T>
struct GcdNode : public Vec::ElemwiseBinaryOP<T, T, T> {
    __aicore__ inline GcdNode(LocalTensor<T>& dst, LocalTensor<T>& src1, LocalTensor<T>& src2, int count)
    {
#ifdef __CCE_AICORE__
        __ubuf__ T* dst_1 = (__ubuf__ T*)dst.GetPhyAddr();
        __ubuf__ T* src1_1 = (__ubuf__ T*)src1.GetPhyAddr();
        __ubuf__ T* src2_1 = (__ubuf__ T*)src2.GetPhyAddr();
        if constexpr (IsSameType<T, int64_t>::value) {
            asc_vf_call<GcdVecInt64<int64_t>>(dim3(512), dst_1, src1_1, src2_1, count);
        } else if constexpr (IsSameType<T, int32_t>::value) {
            asc_vf_call<GcdVec<int32_t, uint32_t>>(dim3(1024), dst_1, src1_1, src2_1, static_cast<uint8_t>(31), count);
        } else if constexpr (IsSameType<T, int16_t>::value) {
            asc_vf_call<GcdVec<int16_t, uint16_t>>(dim3(1024), dst_1, src1_1, src2_1, static_cast<uint8_t>(15), count);
        } else if constexpr (IsSameType<T, int8_t>::value) {
            asc_vf_call<GcdVec<int8_t, uint8_t>>(dim3(1024), dst_1, src1_1, src2_1, static_cast<uint8_t>(7), count);
        } else if constexpr (IsSameType<T, uint8_t>::value) {
            asc_vf_call<GcdVec<uint8_t, uint8_t>>(dim3(1024), dst_1, src1_1, src2_1, static_cast<uint8_t>(7), count);
        }
#endif
    }
};


template <typename T>
struct GcdCompute {
    using InputX1 = Bind<Vec::CopyInBrc<T>, Placeholder::In0<T>>;
    using InputX2 = Bind<Vec::CopyInBrc<T>, Placeholder::In1<T>>;
    using GcdRes = Bind<GcdNode<T>, InputX1, InputX2>;
    using OpCopyOut = Bind<Vec::CopyOut<T>, Placeholder::Out0<T>, GcdRes>;
    using Outputs = Elems<OpCopyOut>;
    using MemCfg = MemOptCfg<MemLevel::LEVEL_2>;
    using OpDag = DAGSch<Outputs, void, MemCfg>;
};
} // namespace GcdOp

#endif // GCD_DAG_H