/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

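/*!
 * \file bitwise_shift.h
 * \brief Tile-level bitwise left/right shift operators: tile-by-tile, tile-by-scalar and scalar-by-tile variants.
 */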

#ifndef TILEOP_TILE_OPERATOR_BITWISE_SHIFT__H
#define TILEOP_TILE_OPERATOR_BITWISE_SHIFT__H
#include "pto_tile.h"
#include "utils/layout.h"
#include "utils/tile_tensor.h"
#include "tileop_common.h"
#include "unary.h"

/*!
 * \brief Dispatch a tile-by-tile shift to the pto primitive selected by `op`.
 *
 * BITWISERIGHTSHIFT forwards to pto::TSHR and BITWISELEFTSHIFT forwards to pto::TSHL. The three operands are
 * passed through unchanged and in the same order.
 *
 * \tparam op   Shift direction; must be BITWISERIGHTSHIFT or BITWISELEFTSHIFT.
 * \param dst   Tile handed to the primitive as its destination operand.
 * \param src0  Tile handed to the primitive as its first source operand.
 * \param src1  Tile handed to the primitive as its second source operand (callers in this file pass the
 *              clamped shift-amount tile here).
 */
template <BitwiseShiftOp op, typename T0, typename T1, typename T2>
TILEOP void BitwiseShiftComputeImpl(T0 dst, T1 src0, T2 src1)
{
    // Any other enumerator previously compiled to an empty body and silently left dst unwritten;
    // turn that into a compile-time error instead.
    static_assert(op == BitwiseShiftOp::BITWISERIGHTSHIFT || op == BitwiseShiftOp::BITWISELEFTSHIFT,
                  "BitwiseShiftComputeImpl: unsupported BitwiseShiftOp");

    if constexpr (op == BitwiseShiftOp::BITWISERIGHTSHIFT) {
        pto::TSHR(dst, src0, src1);
    } else {
        pto::TSHL(dst, src0, src1);
    }
}

/*!
 * \brief Dispatch a tile-by-scalar shift to the pto primitive selected by `op`.
 *
 * BITWISERIGHTSHIFT forwards to pto::TSHRS and BITWISELEFTSHIFT forwards to pto::TSHLS. The operands are passed
 * through unchanged and in the same order.
 *
 * \tparam op   Shift direction; must be BITWISERIGHTSHIFT or BITWISELEFTSHIFT.
 * \param dst   Tile handed to the primitive as its destination operand.
 * \param src0  Tile handed to the primitive as its source operand.
 * \param src1  Scalar handed to the primitive as the shift amount (the caller in this file clamps it first).
 */
template <BitwiseShiftOp op, typename T0, typename T1, typename Scalar>
TILEOP void BitwiseShiftScalarComputeImpl(T0 dst, T1 src0, Scalar src1)
{
    // Any other enumerator previously compiled to an empty body and silently left dst unwritten;
    // turn that into a compile-time error instead.
    static_assert(op == BitwiseShiftOp::BITWISERIGHTSHIFT || op == BitwiseShiftOp::BITWISELEFTSHIFT,
                  "BitwiseShiftScalarComputeImpl: unsupported BitwiseShiftOp");

    if constexpr (op == BitwiseShiftOp::BITWISERIGHTSHIFT) {
        pto::TSHRS(dst, src0, src1);
    } else {
        pto::TSHLS(dst, src0, src1);
    }
}

/*!
 * \brief Clamp every element of the shift-amount tile `src1` to a usable shift count, in place and branch-free.
 *
 * The step comments below read the pto op names at face value (TEXPANDS = fill with a scalar; TSUB / TOR / TAND /
 * TNOT = element-wise subtract / or / and / not; TSHRS = shift right by a scalar) and assume that `tmp`/`tmpSigned`
 * and `src1`/`src1Signed` view the same data, which is how the callers in this file construct them.
 * NOTE(review): confirm both assumptions against the pto documentation.
 *
 * Under those assumptions, an element of src1 that is negative or greater than MAX_SHIFT_NUM is replaced by
 * MAX_SHIFT_NUM; every other element is left unchanged.
 *
 * Side effects:
 *  - `src1` is overwritten with the clamped amounts.
 *  - `dst` and `tmp` are used as scratch and hold intermediate values on return.
 *    NOTE(review): because `dst` is overwritten here before the caller performs the actual shift, `dst`
 *    presumably must not alias the value operand of that shift — confirm with the callers/allocator.
 *  - There is no SyncV() after the final op; the callers in this file issue one immediately after the call.
 *
 * \tparam MAX_SHIFT_NUM  Upper bound (and replacement value) for the shift amounts.
 * \param dst         Scratch tile (mask complement, then a MAX_SHIFT_NUM fill).
 * \param src1        Shift-amount tile, clamped in place.
 * \param tmp         Scratch tile used to build the out-of-range mask.
 * \param src1Signed  View of the shift amounts used for the subtraction.
 * \param tmpSigned   View of the scratch data used for the subtraction and the scalar right shift.
 */
template <size_t MAX_SHIFT_NUM, typename T, typename U, typename V, typename USigned, typename VSigned>
TILEOP void GetValidShiftTile(T& dst, U& src1, V& tmp, USigned& src1Signed, VSigned& tmpSigned)
{
    // tmp = MAX_SHIFT_NUM
    pto::TEXPANDS(tmp, MAX_SHIFT_NUM);
    SyncV();
    // tmp = MAX_SHIFT_NUM - src1 (through the signed views): negative when src1 > MAX_SHIFT_NUM
    pto::TSUB(tmpSigned, tmpSigned, src1Signed);
    SyncV();
    // tmp |= src1: the top bit is now set if src1 > MAX_SHIFT_NUM or if src1 itself is negative
    pto::TOR(tmp, tmp, src1);
    SyncV();
    // Smear the top bit across the element: tmp becomes all ones (out of range) or all zeros (in range).
    // NOTE(review): relies on TSHRS on the signed view being an arithmetic shift that is well defined for a
    // shift count of MAX_SHIFT_NUM — confirm.
    pto::TSHRS(tmpSigned, tmpSigned, MAX_SHIFT_NUM);
    SyncV();
    // dst = ~tmp: mask selecting the in-range elements
    pto::TNOT(dst, tmp);
    SyncV();
    // Keep in-range amounts, zero the out-of-range ones
    pto::TAND(src1, src1, dst);
    SyncV();
    // dst = MAX_SHIFT_NUM (dst is reused as scratch; the in-range mask is no longer needed)
    pto::TEXPANDS(dst, MAX_SHIFT_NUM);
    SyncV();
    // tmp = MAX_SHIFT_NUM for out-of-range elements, 0 otherwise
    pto::TAND(tmp, tmp, dst);
    SyncV();
    // Merge: in-range elements keep their value, out-of-range elements become MAX_SHIFT_NUM
    pto::TOR(src1, src1, tmp);
}

/*!
 * \brief Shift one tile by per-element amounts: sanitize the amounts, then run the shift primitive.
 *
 * \tparam op             Shift direction forwarded to BitwiseShiftComputeImpl.
 * \tparam MAX_SHIFT_NUM  Clamp bound forwarded to GetValidShiftTile.
 * \param dst         Result tile; also lent to GetValidShiftTile before the shift writes it.
 * \param src0        First source operand of the shift.
 * \param src1        Shift-amount tile; passed to GetValidShiftTile, then used as the second source operand.
 * \param tmp         Scratch tile for GetValidShiftTile.
 * \param src1Signed  Additional view of the shift amounts, forwarded to GetValidShiftTile.
 * \param tmpSigned   Additional view of the scratch data, forwarded to GetValidShiftTile.
 */
template <BitwiseShiftOp op, size_t MAX_SHIFT_NUM, typename T0, typename T1, typename T2, typename T3,
          typename T2Signed, typename T3Signed>
TILEOP void BitwiseShiftImpl(T0& dst, T1& src0, T2& src1, T3& tmp, T2Signed& src1Signed, T3Signed& tmpSigned)
{
    // Phase 1: sanitize the shift amounts.
    GetValidShiftTile<MAX_SHIFT_NUM>(dst, src1, tmp, src1Signed, tmpSigned);
    SyncV();

    // Phase 2: the shift itself, reading src1 after phase 1 has run.
    BitwiseShiftComputeImpl<op>(dst, src0, src1);
}

/*!
 * \brief Tile-by-tile shift driver: walks the leading dimensions of `dst` and runs BitwiseShiftImpl per tile.
 *
 * The loop bounds are the first three shape dimensions reported by dst's layout. On every iteration all tile
 * views are re-pointed with the same TileOffset before BitwiseShiftImpl is invoked.
 *
 * \tparam op   Shift direction.
 * \param dst   Destination tensor.
 * \param src0  Tensor supplying the first source operand of the shift.
 * \param src1  Tensor supplying the shift amounts; it is passed on as a mutable view to BitwiseShiftImpl.
 * \param tmp   Scratch tensor.
 */
template <BitwiseShiftOp op, typename T0, typename T1, typename T2, typename T3>
TILEOP void BitwiseShiftCompute(T0 dst, T1 src0, T2 src1, T3 tmp)
{
    // Signed counterparts of the element types PtoTile picks for src1 and tmp.
    using Src1SignedElem = std::make_signed_t<typename PtoTile<T2>::Dtype>;
    using TmpSignedElem = std::make_signed_t<typename PtoTile<T3>::Dtype>;

    // NOTE(review): this is the element width in bits only if TileOp::BLOCK_NELEM_B32 == 8 — confirm.
    constexpr auto MAX_SHIFT_NUM = sizeof(typename T0::Type) * TileOp::BLOCK_NELEM_B32;

    const auto layout = dst.GetLayout();
    auto extent0 = layout.template GetShapeDim<DIM_1ST, MAX_DIMS>();
    auto extent1 = layout.template GetShapeDim<DIM_2ND, MAX_DIMS>();
    auto extent2 = layout.template GetShapeDim<DIM_3RD, MAX_DIMS>();

    auto outTile = PtoTile<T0>(dst);
    auto valueTile = PtoTile<T1>(src0);
    auto shiftTile = PtoTile<T2>(src1);
    auto scratchTile = PtoTile<T3>(tmp);
    // Second pair of views over the same src1 / tmp tensors, typed with the signed element types.
    auto shiftSignedTile = PtoTile<T2, pto::BLayout::RowMajor, false, Src1SignedElem>(src1);
    auto scratchSignedTile = PtoTile<T3, pto::BLayout::RowMajor, false, TmpSignedElem>(tmp);

    for (LoopVar i0 = 0; i0 < extent0; ++i0) {
        for (LoopVar i1 = 0; i1 < extent1; ++i1) {
            for (LoopVar i2 = 0; i2 < extent2; ++i2) {
                auto offsets = TileOffset(i0, i1, i2);

                // Re-point every view at the current tile.
                outTile.Assign(dst, offsets);
                valueTile.Assign(src0, offsets);
                shiftTile.Assign(src1, offsets);
                scratchTile.Assign(tmp, offsets);
                shiftSignedTile.Assign(src1, offsets);
                scratchSignedTile.Assign(tmp, offsets);

                BitwiseShiftImpl<op, MAX_SHIFT_NUM>(
                    outTile.Data(), valueTile.Data(), shiftTile.Data(), scratchTile.Data(),
                    shiftSignedTile.Data(), scratchSignedTile.Data());
            }
        }
    }
}

/*!
 * \brief Tile-by-scalar shift driver: clamps the scalar amount once, then shifts every tile of `src0` by it.
 *
 * A shift amount that is negative or greater than MAX_SHIFT_NUM is replaced by MAX_SHIFT_NUM before any tile is
 * processed. The loop bounds are the first three shape dimensions reported by dst's layout.
 *
 * \tparam op   Shift direction.
 * \param dst   Destination tensor.
 * \param src0  Tensor supplying the source operand of the shift.
 * \param src1  Scalar shift amount (taken by value; the clamp only affects the local copy).
 */
template <BitwiseShiftOp op, typename T0, typename T1, typename Scalar>
TILEOP void BitwiseShiftScalarCompute(T0 dst, T1 src0, Scalar src1)
{
    // NOTE(review): this is the element width in bits only if TileOp::BLOCK_NELEM_B32 == 8 — confirm.
    constexpr auto MAX_SHIFT_NUM = sizeof(typename T0::Type) * TileOp::BLOCK_NELEM_B32;

    // The amount is loop-invariant, so sanitize it before touching any tile.
    const bool amountOutOfRange = (src1 < 0) || (src1 > MAX_SHIFT_NUM);
    if (amountOutOfRange) {
        src1 = MAX_SHIFT_NUM;
    }

    const auto layout = dst.GetLayout();
    auto extent0 = layout.template GetShapeDim<DIM_1ST, MAX_DIMS>();
    auto extent1 = layout.template GetShapeDim<DIM_2ND, MAX_DIMS>();
    auto extent2 = layout.template GetShapeDim<DIM_3RD, MAX_DIMS>();

    auto outTile = PtoTile<T0>(dst);
    auto valueTile = PtoTile<T1>(src0);

    for (LoopVar i0 = 0; i0 < extent0; ++i0) {
        for (LoopVar i1 = 0; i1 < extent1; ++i1) {
            for (LoopVar i2 = 0; i2 < extent2; ++i2) {
                auto offsets = TileOffset(i0, i1, i2);
                outTile.Assign(dst, offsets);
                valueTile.Assign(src0, offsets);
                BitwiseShiftScalarComputeImpl<op>(outTile.Data(), valueTile.Data(), src1);
            }
        }
    }
}

/*!
 * \brief Shift a scalar by per-element amounts for one tile.
 *
 * Three phases: sanitize the amounts in `src1`, fill `dst` from the scalar `src0` via pto::TEXPANDS, then run the
 * shift primitive with `dst` as both destination and first source operand.
 *
 * \tparam op             Shift direction forwarded to BitwiseShiftComputeImpl.
 * \tparam MAX_SHIFT_NUM  Clamp bound forwarded to GetValidShiftTile.
 * \param dst         Result tile; lent to GetValidShiftTile first, then overwritten by the TEXPANDS fill.
 * \param src0        Scalar passed to pto::TEXPANDS.
 * \param src1        Shift-amount tile; passed to GetValidShiftTile, then used as the second source operand.
 * \param tmp         Scratch tile for GetValidShiftTile.
 * \param src1Signed  Additional view of the shift amounts, forwarded to GetValidShiftTile.
 * \param tmpSigned   Additional view of the scratch data, forwarded to GetValidShiftTile.
 */
template <BitwiseShiftOp op, size_t MAX_SHIFT_NUM, typename T0, typename Scalar, typename T1, typename T2,
          typename T1Signed, typename T2Signed>
TILEOP void ScalarBitwiseShiftImpl(
    T0& dst, Scalar& src0, T1& src1, T2& tmp, T1Signed& src1Signed, T2Signed& tmpSigned)
{
    // Phase 1: sanitize the shift amounts.
    GetValidShiftTile<MAX_SHIFT_NUM>(dst, src1, tmp, src1Signed, tmpSigned);
    SyncV();

    // Phase 2: fill dst from the scalar operand.
    pto::TEXPANDS(dst, src0);
    SyncV();

    // Phase 3: shift with dst serving as both destination and first source.
    BitwiseShiftComputeImpl<op>(dst, dst, src1);
}

/*!
 * \brief Scalar-by-tile shift driver: walks the leading dimensions of `dst` and runs ScalarBitwiseShiftImpl per tile.
 *
 * The loop bounds are the first three shape dimensions reported by dst's layout. On every iteration all tile
 * views are re-pointed with the same TileOffset before ScalarBitwiseShiftImpl is invoked.
 *
 * \tparam op   Shift direction.
 * \param dst   Destination tensor.
 * \param src0  Scalar operand, forwarded unchanged to ScalarBitwiseShiftImpl.
 * \param src1  Tensor supplying the shift amounts; it is passed on as a mutable view to ScalarBitwiseShiftImpl.
 * \param tmp   Scratch tensor.
 */
template <BitwiseShiftOp op, typename T0, typename Scalar, typename T1, typename T2>
TILEOP void ScalarBitwiseShiftCompute(T0 dst, Scalar src0, T1 src1, T2 tmp)
{
    // Signed counterparts of the element types PtoTile picks for src1 and tmp.
    using Src1SignedElem = std::make_signed_t<typename PtoTile<T1>::Dtype>;
    using TmpSignedElem = std::make_signed_t<typename PtoTile<T2>::Dtype>;

    // NOTE(review): this is the element width in bits only if TileOp::BLOCK_NELEM_B32 == 8 — confirm.
    constexpr auto MAX_SHIFT_NUM = sizeof(typename T0::Type) * TileOp::BLOCK_NELEM_B32;

    const auto layout = dst.GetLayout();
    auto extent0 = layout.template GetShapeDim<DIM_1ST, MAX_DIMS>();
    auto extent1 = layout.template GetShapeDim<DIM_2ND, MAX_DIMS>();
    auto extent2 = layout.template GetShapeDim<DIM_3RD, MAX_DIMS>();

    auto outTile = PtoTile<T0>(dst);
    auto shiftTile = PtoTile<T1>(src1);
    auto scratchTile = PtoTile<T2>(tmp);
    // Second pair of views over the same src1 / tmp tensors, typed with the signed element types.
    auto shiftSignedTile = PtoTile<T1, pto::BLayout::RowMajor, false, Src1SignedElem>(src1);
    auto scratchSignedTile = PtoTile<T2, pto::BLayout::RowMajor, false, TmpSignedElem>(tmp);

    for (LoopVar i0 = 0; i0 < extent0; ++i0) {
        for (LoopVar i1 = 0; i1 < extent1; ++i1) {
            for (LoopVar i2 = 0; i2 < extent2; ++i2) {
                auto offsets = TileOffset(i0, i1, i2);

                // Re-point every view at the current tile.
                outTile.Assign(dst, offsets);
                shiftTile.Assign(src1, offsets);
                scratchTile.Assign(tmp, offsets);
                shiftSignedTile.Assign(src1, offsets);
                scratchSignedTile.Assign(tmp, offsets);

                ScalarBitwiseShiftImpl<op, MAX_SHIFT_NUM>(
                    outTile.Data(), src0, shiftTile.Data(), scratchTile.Data(),
                    shiftSignedTile.Data(), scratchSignedTile.Data());
            }
        }
    }
}

/*!
 * \brief Tile-by-tile bitwise right shift entry point.
 *
 * OP_TILE_OP_BITWISERIGHTSHIFT expands to this function's name. The call forwards to BitwiseShiftCompute with
 * BITWISERIGHTSHIFT; see that function for how `src1` and `tmp` are used.
 */
#define OP_TILE_OP_BITWISERIGHTSHIFT TBitrshift
template <typename T0, typename T1, typename T2, typename T3>
TILEOP void TBitrshift(T0 dst, T1 src0, T2 src1, T3 tmp)
{
    constexpr auto shiftOp = BitwiseShiftOp::BITWISERIGHTSHIFT;
    BitwiseShiftCompute<shiftOp>(dst, src0, src1, tmp);
}

/*!
 * \brief Tile-by-tile bitwise left shift entry point.
 *
 * OP_TILE_OP_BITWISELEFTSHIFT expands to this function's name. The call forwards to BitwiseShiftCompute with
 * BITWISELEFTSHIFT; see that function for how `src1` and `tmp` are used.
 */
#define OP_TILE_OP_BITWISELEFTSHIFT TBitlshift
template <typename T0, typename T1, typename T2, typename T3>
TILEOP void TBitlshift(T0 dst, T1 src0, T2 src1, T3 tmp)
{
    constexpr auto shiftOp = BitwiseShiftOp::BITWISELEFTSHIFT;
    BitwiseShiftCompute<shiftOp>(dst, src0, src1, tmp);
}

/*!
 * \brief Tile-by-scalar bitwise right shift entry point.
 *
 * OP_TILE_OP_BITWISERIGHTSHIFTS expands to this function's name. `Scalar` is the leading template parameter, so
 * it can be spelled explicitly while T0/T1 are deduced. Forwards to BitwiseShiftScalarCompute with
 * BITWISERIGHTSHIFT.
 */
#define OP_TILE_OP_BITWISERIGHTSHIFTS TBitrshiftS
template <typename Scalar, typename T0, typename T1>
TILEOP void TBitrshiftS(T0 dst, T1 src0, Scalar src1)
{
    constexpr auto shiftOp = BitwiseShiftOp::BITWISERIGHTSHIFT;
    BitwiseShiftScalarCompute<shiftOp>(dst, src0, src1);
}

/*!
 * \brief Tile-by-scalar bitwise left shift entry point.
 *
 * OP_TILE_OP_BITWISELEFTSHIFTS expands to this function's name. `Scalar` is the leading template parameter, so
 * it can be spelled explicitly while T0/T1 are deduced. Forwards to BitwiseShiftScalarCompute with
 * BITWISELEFTSHIFT.
 */
#define OP_TILE_OP_BITWISELEFTSHIFTS TBitlshiftS
template <typename Scalar, typename T0, typename T1>
TILEOP void TBitlshiftS(T0 dst, T1 src0, Scalar src1)
{
    constexpr auto shiftOp = BitwiseShiftOp::BITWISELEFTSHIFT;
    BitwiseShiftScalarCompute<shiftOp>(dst, src0, src1);
}

/*!
 * \brief Scalar-by-tile bitwise right shift entry point.
 *
 * OP_TILE_OP_SBITWISERIGHTSHIFT expands to this function's name. Note the argument order: the tile `src1` comes
 * before the scalar `src0` here, and the two are swapped back when forwarding to ScalarBitwiseShiftCompute, which
 * takes (dst, src0, src1, tmp).
 */
#define OP_TILE_OP_SBITWISERIGHTSHIFT TSBitrshift
template <typename Scalar, typename T0, typename T1, typename T2>
TILEOP void TSBitrshift(T0 dst, T1 src1, Scalar src0, T2 tmp)
{
    constexpr auto shiftOp = BitwiseShiftOp::BITWISERIGHTSHIFT;
    ScalarBitwiseShiftCompute<shiftOp>(dst, src0, src1, tmp);
}

/*!
 * \brief Scalar-by-tile bitwise left shift entry point.
 *
 * OP_TILE_OP_SBITWISELEFTSHIFT expands to this function's name. Note the argument order: the tile `src1` comes
 * before the scalar `src0` here, and the two are swapped back when forwarding to ScalarBitwiseShiftCompute, which
 * takes (dst, src0, src1, tmp).
 */
#define OP_TILE_OP_SBITWISELEFTSHIFT TSBitlshift
template <typename Scalar, typename T0, typename T1, typename T2>
TILEOP void TSBitlshift(T0 dst, T1 src1, Scalar src0, T2 tmp)
{
    constexpr auto shiftOp = BitwiseShiftOp::BITWISELEFTSHIFT;
    ScalarBitwiseShiftCompute<shiftOp>(dst, src0, src1, tmp);
}
#endif