/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * MindIE is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

#include "infer_request_impl.h"

#include "check_utils.h"
#include "log.h"
using namespace mindie_llm;

namespace mindie_llm {
// Creates a request bound to the given identifier; all other members keep their in-class defaults.
InferRequestImpl::InferRequestImpl(InferRequestId requestId)
    : requestId_(requestId)
{
}

/**
 * @brief Stores (or replaces) the input tensor registered under @p tensorName.
 *
 * Unlike AddTensor, an existing entry with the same name is overwritten, and a nullptr
 * tensor is stored as-is (no null check is performed here).
 * If the name fails CheckStringInputLength(MAX_STRING_LENGTH), an error is logged and
 * the request's inputs are left unchanged.
 *
 * @param tensorName Key under which the tensor is stored.
 * @param tensor     Tensor to store.
 */
void InferRequestImpl::SetTensor(const std::string &tensorName, TensorPtr &tensor) {
    if (!CheckStringInputLength(tensorName, MAX_STRING_LENGTH)) {
        MINDIE_LLM_LOG_ERROR("The length of tensor name: " << tensorName << " is too long in 'SetTensor'.");
        return;
    }
    inputs_[tensorName] = tensor;
}

/**
 * @brief Registers a new named input tensor; never overwrites an existing entry.
 *
 * @param tensorName Key under which the tensor is stored; must pass
 *                   CheckStringInputLength(MAX_STRING_LENGTH).
 * @param tensor     Tensor to store; must not be nullptr.
 * @return OK on success; INVALID_ARG if the name is too long, the tensor is nullptr,
 *         or an input with the same name already exists.
 */
Status InferRequestImpl::AddTensor(const std::string &tensorName, TensorPtr &tensor) {
    if (!CheckStringInputLength(tensorName, MAX_STRING_LENGTH)) {
        MINDIE_LLM_LOG_ERROR("The length of tensor name: " << tensorName << " is too long in 'AddTensor'.");
        return Status(Error::Code::INVALID_ARG, "The length of tensor name: " + tensorName + " is too long");
    }
    if (tensor == nullptr) {
        return Status(Error::Code::INVALID_ARG, "tensor is nullptr in 'AddTensor' parameter");
    }
    // insert() leaves an existing entry untouched and reports that through the bool half of its result.
    const bool inserted = inputs_.insert(std::make_pair(tensorName, tensor)).second;
    if (!inserted) {
        return Status(Error::Code::INVALID_ARG, "input '" + tensorName + "' already exists in request");
    }
    return Status(Error::Code::OK, "Success");
}

/**
 * @brief Looks up the input tensor stored under @p tensorName.
 *
 * @param tensorName Name to look up; must pass CheckStringInputLength(MAX_STRING_LENGTH).
 * @param tensor     Out-param; receives the stored pointer whenever the name is found
 *                   (including when that stored pointer is nullptr).
 * @return OK when a non-null tensor was found; INVALID_ARG if the name is too long or the
 *         stored tensor is nullptr; NOT_FOUND if no input has that name.
 */
Status InferRequestImpl::GetTensorByName(const std::string &tensorName, TensorPtr &tensor) {
    if (CheckStringInputLength(tensorName, MAX_STRING_LENGTH)) {
        const auto match = inputs_.find(tensorName);
        if (match != inputs_.end()) {
            // The out-param is written before the null check, so callers see the stored value either way.
            tensor = match->second;
            if (tensor != nullptr) {
                return Status(Error::Code::OK, "Success");
            }
            return Status(Error::Code::INVALID_ARG, "tensor is nullptr in 'GetTensorByName' parameter");
        }
        return Status(Error::Code::NOT_FOUND, "input '" + tensorName + "' not found in request");
    }
    MINDIE_LLM_LOG_ERROR("The length of tensor name: " << tensorName << " is too long in 'GetTensorByName'.");
    return Status(Error::Code::INVALID_ARG, "The length of tensor name: " + tensorName + " is too long");
}

/**
 * @brief Removes the input tensor registered under @p name.
 *
 * @param name Name of the input to remove; must pass CheckStringInputLength(MAX_STRING_LENGTH).
 * @return OK when exactly one entry was erased; INVALID_ARG if the name is too long or
 *         no input with that name exists.
 */
Status InferRequestImpl::DelTensorByName(const std::string &name) {
    if (!CheckStringInputLength(name, MAX_STRING_LENGTH)) {
        MINDIE_LLM_LOG_ERROR("The length of tensor name: " << name << " is too long in 'DelTensorByName'.");
        return Status(Error::Code::INVALID_ARG, "The length of tensor name: " + name + " is too long");
    }
    // erase(key) returns the number of removed elements.
    if (inputs_.erase(name) != 1) {
        return Status(Error::Code::INVALID_ARG, "input '" + name + "' does not exist in request");
    }
    return Status(Error::Code::OK, "Success");
}

// Returns the identifier this request was constructed with.
InferRequestId InferRequestImpl::GetRequestId() const
{
    return this->requestId_;
}

/**
 * @brief Sets the maximum output length for this request.
 *
 * @param maxOutputLen New limit; zero is rejected and leaves the stored value unchanged.
 * @return OK on success, ERROR when @p maxOutputLen is zero.
 */
Status InferRequestImpl::SetMaxOutputLen(uint32_t maxOutputLen) {
    // For an unsigned value, "not greater than zero" can only mean exactly zero.
    if (maxOutputLen == 0) {
        MINDIE_LLM_LOG_ERROR("InferRequest SetMaxOutputLen failed due to invalid parameter");
        return Status(Error::Code::ERROR, "output length is invalid parameter");
    }
    maxOutputLen_ = maxOutputLen;
    return Status(Error::Code::OK, "Success");
}

uint32_t InferRequestImpl::GetMaxOutputLen() const { return maxOutputLen_; }

// Stores a copy of the response callback for this request.
void InferRequestImpl::SetSendResponseCallback(const mindie_llm::SendResponseCallback4Request &callback)
{
    this->responseCallback_ = callback;
}

// Returns a mutable reference to the stored response callback.
mindie_llm::SendResponseCallback4Request &InferRequestImpl::GetSendResponseCallback()
{
    return this->responseCallback_;
}

// Stores a copy of the release callback for this request.
void InferRequestImpl::SetReleaseCallback(const mindie_llm::ReleaseCallback &callback)
{
    this->releaseCallback_ = callback;
}

// Returns a mutable reference to the stored release callback.
mindie_llm::ReleaseCallback &InferRequestImpl::GetReleaseCallback()
{
    return this->releaseCallback_;
}

// Stores a copy of the engine-side response callback (kept separately from responseCallback_).
void InferRequestImpl::SetEngineResponseCallback(const mindie_llm::SendResponseCallback4Request &callback)
{
    this->engineResponseCallback_ = callback;
}

// Returns a mutable reference to the stored engine-side response callback.
mindie_llm::SendResponseCallback4Request &InferRequestImpl::GetEngineResponseCallback()
{
    return this->engineResponseCallback_;
}

// Read-only access to the tensors populated via SetTensor/AddTensor (and pruned by DelTensorByName).
const mindie_llm::TensorMap &InferRequestImpl::ImmutableInputs() const
{
    return this->inputs_;
}

// Records the request type; IsPrefillReq/IsDecodeReq are derived from this value.
void InferRequestImpl::SetReqType(mindie_llm::InferReqType reqType)
{
    this->reqType_ = reqType;
}

// Returns the currently recorded request type.
mindie_llm::InferReqType InferRequestImpl::GetReqType() const
{
    return this->reqType_;
}

void InferRequestImpl::SetRecompute(bool isRecompute) { isRecompute_ = isRecompute; }

bool InferRequestImpl::IsRecompute() const { return isRecompute_; }

bool InferRequestImpl::IsPrefillReq() const { return reqType_ == mindie_llm::InferReqType::REQ_PREFILL; }

bool InferRequestImpl::IsDecodeReq() const { return reqType_ == mindie_llm::InferReqType::REQ_DECODE; }

/**
 * @brief Stores a copy of @p dTarget.
 *
 * If the value fails CheckStringInputLength(MAX_STRING_LENGTH), an error is logged and the
 * previously stored target is kept.
 */
void InferRequestImpl::SetDTarget(std::string &dTarget) {
    if (CheckStringInputLength(dTarget, MAX_STRING_LENGTH)) {
        dTarget_ = dTarget;
        return;
    }
    MINDIE_LLM_LOG_ERROR("The Length of dTarget: " << dTarget << " is too long in SetDTarget.");
}

std::string InferRequestImpl::GetDTarget() const { return dTarget_; }

/**
 * @brief Stores a copy of @p prefillAddr.
 *
 * If the value fails CheckStringInputLength(MAX_STRING_LENGTH), an error is logged and the
 * previously stored address is kept. The rejected value itself is deliberately not echoed
 * into the log.
 */
void InferRequestImpl::SetPrefillAddr(std::string &prefillAddr) {
    if (!CheckStringInputLength(prefillAddr, MAX_STRING_LENGTH)) {
        // Previously this message named "dTarget" (copy-paste), misattributing the failure.
        MINDIE_LLM_LOG_ERROR("The length of prefillAddr is too long in SetPrefillAddr.");
        return;
    }
    prefillAddr_ = prefillAddr;
}

std::string InferRequestImpl::GetPrefillAddr() const { return prefillAddr_; }

// Replaces the stored source block table; only a single table is ever retained.
void InferRequestImpl::SetSrcBlockTable(const std::vector<int64_t> &srcBlockTable) {
    srcBlockTable_.clear();
    srcBlockTable_.emplace_back(srcBlockTable);
}

// Returns a copy of the stored source block table, or an empty vector if none was set.
std::vector<int64_t> InferRequestImpl::GetSrcBlockTable() const {
    return srcBlockTable_.empty() ? std::vector<int64_t>{} : srcBlockTable_.front();
}

void InferRequestImpl::SetDpInstanceIds(const std::vector<uint64_t> &dpInstanceIds) { dpInstanceIds_ = dpInstanceIds; }

std::vector<uint64_t> InferRequestImpl::GetDpInstanceIds() const { return dpInstanceIds_; }

void InferRequestImpl::SetSrcHmoTable(const std::vector<std::vector<int64_t>> &srcHmoTable) {
    srcHmoTable_ = srcHmoTable;
}

std::vector<std::vector<int64_t>> InferRequestImpl::GetSrcHmoTable() const { return srcHmoTable_; }

}  // namespace mindie_llm