/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file block_scheduler_swizzle.h
 * \brief
 */

#include "include/tensor_api/tensor.h"
#include "basic_api/kernel_basic_intf.h"

#pragma once

namespace Blaze {
namespace Gemm {
namespace Block {

template <uint32_t SwizzleOffset = 1, uint32_t SwizzleDirection = 0>
class BlockSchedulerSwizzle {
public:
    using ProblemShape = AscendC::Te::Shape<int64_t, int64_t, int64_t>;
    using BlockShape = AscendC::Te::Shape<int64_t, int64_t, int64_t>;
    using BlockCoord = AscendC::Te::Coord<int64_t, int64_t, int64_t>;
    using TileShape = AscendC::Te::Shape<int64_t, int64_t>;

    struct Params {
        TileShape tileShape;
    };

    __aicore__ inline BlockSchedulerSwizzle(const ProblemShape& shape, const Params& params)
        : problemShape_(shape), tileShape_(params.tileShape)
    {
        if constexpr (SwizzleDirection == 0) {
            // m first
            loopFirst_ = AscendC::Std::ceil_division(get<IDX_M_IDX>(problemShape_), get<IDX_M_IDX>(tileShape_));
            loopSecond_ = AscendC::Std::ceil_division(get<IDX_N_IDX>(problemShape_), get<IDX_N_IDX>(tileShape_));
        } else if constexpr (SwizzleDirection == 1) {
            // n first
            loopSecond_ = AscendC::Std::ceil_division(get<IDX_M_IDX>(problemShape_), get<IDX_M_IDX>(tileShape_));
            loopFirst_ = AscendC::Std::ceil_division(get<IDX_N_IDX>(problemShape_), get<IDX_N_IDX>(tileShape_));
        }
    }

    __aicore__ inline int64_t GetTileNum()
    {
        return loopFirst_ * loopSecond_;
    }

    __aicore__ inline BlockShape GetBlockShape(const BlockCoord& blockCoord)
    {
        return {
            min(get<IDX_M_IDX>(tileShape_), get<IDX_M_IDX>(problemShape_) - get<IDX_M_IDX>(blockCoord)),
            min(get<IDX_N_IDX>(tileShape_), get<IDX_N_IDX>(problemShape_) - get<IDX_N_IDX>(blockCoord)),
            get<IDX_K_IDX>(problemShape_)};
    }

    __aicore__ inline BlockCoord GetBlockCoord(int tileIdx)
    {
        int64_t blockSpan = SwizzleOffset * loopSecond_;
        int64_t blockIdx = tileIdx / blockSpan;
        int64_t inBlockIdx = tileIdx % blockSpan;

        int64_t firstValid = Min(loopFirst_ - blockIdx * SwizzleOffset, static_cast<int64_t>(SwizzleOffset));
        int64_t firstIdx = blockIdx * SwizzleOffset + inBlockIdx % firstValid;
        int64_t secondIdx = inBlockIdx / firstValid;
        if (blockIdx & 1) {
            secondIdx = loopSecond_ - secondIdx - 1;
        }

        if constexpr (SwizzleDirection == 0) {
            return {firstIdx * get<IDX_M_IDX>(tileShape_), secondIdx * get<IDX_N_IDX>(tileShape_), 0};
        } else {
            return {secondIdx * get<IDX_M_IDX>(tileShape_), firstIdx * get<IDX_N_IDX>(tileShape_), 0};
        }
    }

private:
    ProblemShape problemShape_;
    TileShape tileShape_;
    int64_t loopFirst_;
    int64_t loopSecond_;
};

} // namespace Block
} // namespace Gemm
} // namespace Blaze