/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file calc_distributed.cpp
 * \brief
 */

#include <memory>
#include <iostream>
#include "interface/interpreter/operation.h"
#include "tensor/symbolic_scalar.h"
#include "tilefwk/error.h"
#include "tilefwk/tilefwk_op.h"
#include "tilefwk/comm_group_recorder.h"
#include "calc.h"
#include "communication.h"
#include "interface/operation/distributed/distributed_common.h"

namespace npu::tile_fwk {
void ExecuteOpBindTensor(ExecuteOperationContext *ctx) {
    (void) ctx;
}
REGISTER_CALC_OP(OP_BIND_TENSOR, Opcode::OP_BIND_TENSOR, ExecuteOpBindTensor);

void ExecuteOpShmemSet(ExecuteOperationContext *ctx) {
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 0x2);
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1 || ctx->ooperandInplaceDataViewList->size() == 0x2);
    auto &shm = ctx->ioperandDataViewList->at(1);

    Distributed::ShmemSetAttr attr;
    ctx->op->GetAttr(OpAttributeKey::distOpAttr, attr);
    
    std::shared_ptr<SimulationCommContext> context = SimulationCommManager::Instance().GetCommContext(attr.group);
    size_t slotSize = shm->GetSize() * BytesOf(shm->GetDataType());
    if (!attr.isSetData) {
        context->Signal(context->GetRank(), 0, slotSize, shm->GetShmStorageOffset());
    } else {
        context->Set(context->GetRank(), 0, slotSize, shm->GetShmStorageOffset());
    }
}
REGISTER_CALC_OP(OP_SHMEM_SET, Opcode::OP_SHMEM_SET, ExecuteOpShmemSet);

void ExecuteOpShmemPut(ExecuteOperationContext *ctx) {
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 0x3);
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1 || ctx->ooperandInplaceDataViewList->size() == 0x2);
    auto &in = ctx->ioperandDataViewList->at(1);
    auto &shm = ctx->ioperandDataViewList->at(0x2);

    Distributed::ShmemPutAttr attr;
    ctx->op->GetAttr(OpAttributeKey::distOpAttr, attr);
    
    std::shared_ptr<SimulationCommContext> context = SimulationCommManager::Instance().GetCommContext(attr.group);
    int dstRank = ctx->opInter->EvaluateSymbolicScalar(attr.ownerRank);
    int atomicType = 0;
    if (attr.atomicType == Distributed::AtomicType::ADD) {
        atomicType = 1;
    }
    if (attr.atomicType == Distributed::AtomicType::SET) {
        atomicType = 0;
    }

    if (shm->GetDataType() != in->GetDataType()) {
        auto castedIn = LogicalTensorData::CreateEmpty(shm->GetDataType(), shm->GetShape(), shm->GetValidShape(), shm->GetShape());
        calc::Cast(castedIn, in);
        context->Put(castedIn, dstRank, shm->GetShmStorageOffset(), atomicType);
    } else {
        context->Put(in, dstRank, shm->GetShmStorageOffset(), atomicType);
    }
}
REGISTER_CALC_OP(OP_SHMEM_PUT, Opcode::OP_SHMEM_PUT, ExecuteOpShmemPut);

void ExecuteOpShmemSignal(ExecuteOperationContext *ctx) {
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 0x2);
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 0x2 || ctx->ooperandInplaceDataViewList->size() == 1);
    auto &shm = ctx->ioperandDataViewList->at(1);

    Distributed::ShmemSignalAttr attr;
    ctx->op->GetAttr(OpAttributeKey::distOpAttr, attr);
    
    std::shared_ptr<SimulationCommContext> context = SimulationCommManager::Instance().GetCommContext(attr.group);
    int dstRank = ctx->opInter->EvaluateSymbolicScalar(attr.ownerRank);
    int atomicType = 0;
    if (attr.atomicType == Distributed::AtomicType::SET) {
        atomicType = 0;
    }
    if (attr.atomicType == Distributed::AtomicType::ADD) {
        atomicType = 1;
    }
    int value = attr.signalValue;
    bool notifyAll = attr.notifyAll;
    size_t slotSize = shm->GetSize() * BytesOf(shm->GetDataType());
    context->Signal(dstRank, value, slotSize, shm->GetShmStorageOffset(), atomicType, notifyAll);
}
REGISTER_CALC_OP(OP_SHMEM_SIGNAL, Opcode::OP_SHMEM_SIGNAL, ExecuteOpShmemSignal);

void ExecuteOpShmemWaitUntil(ExecuteOperationContext *ctx) {
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 0x2);
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    auto &shm = ctx->ioperandDataViewList->at(1);

    Distributed::ShmemWaitUntilAttr attr;
    ctx->op->GetAttr(OpAttributeKey::distOpAttr, attr);
    
    std::shared_ptr<SimulationCommContext> context = SimulationCommManager::Instance().GetCommContext(attr.group);
    int srcRank = context->GetRank();
    int expect = attr.expectedSum;
    bool reset = attr.resetSignal;
    size_t slotSize =  shm->GetSize() * BytesOf(shm->GetDataType());

    context->Wait(srcRank, expect, slotSize, shm->GetShmStorageOffset(), reset);
}
REGISTER_CALC_OP(OP_SHMEM_WAIT_UNTIL, Opcode::OP_SHMEM_WAIT_UNTIL, ExecuteOpShmemWaitUntil);

void ExecuteOpShmemGet(ExecuteOperationContext *ctx) {
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 0x2);
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 0x2 || ctx->ooperandInplaceDataViewList->size() == 1);
    auto &shm = ctx->ioperandDataViewList->at(1);
    auto out = ctx->ooperandInplaceDataViewList->at(0);

    Distributed::ShmemGetAttr attr;
    ctx->op->GetAttr(OpAttributeKey::distOpAttr, attr);
    
    std::shared_ptr<SimulationCommContext> context = SimulationCommManager::Instance().GetCommContext(attr.group);
    int srcRank = ctx->opInter->EvaluateSymbolicScalar(attr.ownerRank);
    LogicalTensorDataPtr tmp = context->Get(srcRank, out->GetDataType(), out->GetShape(), shm->GetShmStorageOffset());
    calc::Copy(out, tmp);
}
REGISTER_CALC_OP(OP_SHMEM_GET, Opcode::OP_SHMEM_GET, ExecuteOpShmemGet);
REGISTER_CALC_OP(OP_SHMEM_LOAD, Opcode::OP_SHMEM_LOAD, ExecuteOpShmemGet);
}