/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/**
 * @file block_ops/out_matmul.cpp
 * \brief Block matrix multiplication operations with explicit output tiles.
 *
 * All operations receive a pre-allocated output tile as the last argument.
 *   block.matmul       (lhs, rhs, out)
 *   block.matmul_acc   (acc, lhs, rhs, out)
 *   block.matmul_bias  (lhs, rhs, bias, out)
 *   block.gemv         (lhs, rhs, out)
 *   block.gemv_acc     (acc, lhs, rhs, out)
 *   block.gemv_bias    (lhs, rhs, bias, out)
 */

#include <any>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "core/logging.h"
#include "ir/kind_traits.h"
#include "ir/op_registry.h"
#include "ir/type.h"
#include "ir/type_inference.h"

namespace pypto {
namespace ir {

// ---------------------------------------------------------------------------
// Op registration
// ---------------------------------------------------------------------------

// block.matmul: (lhs, rhs, out) -> out's type
REGISTER_OP("block.matmul")
    .set_op_category("BlockOp")
    .set_description("Block explicit-output matrix multiplication: out = lhs @ rhs")
    .add_argument("lhs", "Left matrix tile [M,K] (TileType)")
    .add_argument("rhs", "Right matrix tile [K,N] (TileType)")
    .add_argument("out", "Pre-allocated output tile [M,N] (TileType)")
    .f_deduce_type([]([[maybe_unused]] const std::vector<ExprPtr>& args,
                      [[maybe_unused]] const std::vector<std::pair<std::string, std::any>>& kwargs) {
        return DeduceBlockOutTileType(args, kwargs, "block.matmul", 3);
    });

// block.matmul_acc: (acc, lhs, rhs, out) -> out's type
REGISTER_OP("block.matmul_acc")
    .set_op_category("BlockOp")
    .set_description("Block explicit-output matmul with accumulation: out = acc + lhs @ rhs")
    .add_argument("acc", "Accumulator tile [M,N] (TileType)")
    .add_argument("lhs", "Left matrix tile [M,K] (TileType)")
    .add_argument("rhs", "Right matrix tile [K,N] (TileType)")
    .add_argument("out", "Pre-allocated output tile [M,N] (TileType)")
    .f_deduce_type([]([[maybe_unused]] const std::vector<ExprPtr>& args,
                      [[maybe_unused]] const std::vector<std::pair<std::string, std::any>>& kwargs) {
        return DeduceBlockOutTileType(args, kwargs, "block.matmul_acc", 4);
    });

// block.matmul_bias: (lhs, rhs, bias, out) -> out's type
REGISTER_OP("block.matmul_bias")
    .set_op_category("BlockOp")
    .set_description("Block explicit-output matmul with bias: out = lhs @ rhs + bias")
    .add_argument("lhs", "Left matrix tile [M,K] (TileType)")
    .add_argument("rhs", "Right matrix tile [K,N] (TileType)")
    .add_argument("bias", "Bias tile [1,N] (TileType)")
    .add_argument("out", "Pre-allocated output tile [M,N] (TileType)")
    .f_deduce_type([]([[maybe_unused]] const std::vector<ExprPtr>& args,
                      [[maybe_unused]] const std::vector<std::pair<std::string, std::any>>& kwargs) {
        return DeduceBlockOutTileType(args, kwargs, "block.matmul_bias", 4);
    });

// block.gemv: (lhs, rhs, out) -> out's type
REGISTER_OP("block.gemv")
    .set_op_category("BlockOp")
    .set_description("Block explicit-output GEMV: out[1,N] = lhs[1,K] @ rhs[K,N]")
    .add_argument("lhs", "Row vector tile [1,K] (TileType)")
    .add_argument("rhs", "Matrix tile [K,N] (TileType)")
    .add_argument("out", "Pre-allocated output tile [1,N] (TileType)")
    .f_deduce_type([]([[maybe_unused]] const std::vector<ExprPtr>& args,
                      [[maybe_unused]] const std::vector<std::pair<std::string, std::any>>& kwargs) {
        return DeduceBlockOutTileType(args, kwargs, "block.gemv", 3);
    });

// block.gemv_acc: (acc, lhs, rhs, out) -> out's type
REGISTER_OP("block.gemv_acc")
    .set_op_category("BlockOp")
    .set_description("Block explicit-output GEMV with accumulation: out += lhs @ rhs")
    .add_argument("acc", "Accumulator tile [1,N] (TileType)")
    .add_argument("lhs", "Row vector tile [1,K] (TileType)")
    .add_argument("rhs", "Matrix tile [K,N] (TileType)")
    .add_argument("out", "Pre-allocated output tile [1,N] (TileType)")
    .f_deduce_type([]([[maybe_unused]] const std::vector<ExprPtr>& args,
                      [[maybe_unused]] const std::vector<std::pair<std::string, std::any>>& kwargs) {
        return DeduceBlockOutTileType(args, kwargs, "block.gemv_acc", 4);
    });

// block.gemv_bias: (lhs, rhs, bias, out) -> out's type
REGISTER_OP("block.gemv_bias")
    .set_op_category("BlockOp")
    .set_description("Block explicit-output GEMV with bias: out = lhs @ rhs + bias")
    .add_argument("lhs", "Row vector tile [1,K] (TileType)")
    .add_argument("rhs", "Matrix tile [K,N] (TileType)")
    .add_argument("bias", "Bias tile [1,N] (TileType)")
    .add_argument("out", "Pre-allocated output tile [1,N] (TileType)")
    .f_deduce_type([]([[maybe_unused]] const std::vector<ExprPtr>& args,
                      [[maybe_unused]] const std::vector<std::pair<std::string, std::any>>& kwargs) {
        return DeduceBlockOutTileType(args, kwargs, "block.gemv_bias", 4);
    });

} // namespace ir
} // namespace pypto