/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025-2026. All rights reserved.
 * MindIE is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */
#include <gtest/gtest.h>
#include "mockcpp/mockcpp.hpp"
#include <filesystem>
#include <iostream>
#include <string>
 
#include "dmi_msg_sender.h"
#include <grpcpp/channel.h>
#include <grpcpp/create_channel.h>
#include "health_checker.h"
#include "config/config_manager.h"
#include "env_util.h"
 
using namespace prefillAndDecodeCommunication;
 
namespace grpc {
/**
 * Value equality for grpc::Status, provided for the tests in this file, which
 * hand grpc::Status values to mockcpp's returnValue().
 * NOTE(review): presumably mockcpp requires stubbed return types to be
 * equality-comparable - confirm against the mockcpp version in use.
 *
 * Two statuses are equal only when the status code, the error message and the
 * error details payload all match, so statuses that differ solely in their
 * details are not treated as identical.
 */
bool operator==(const Status& lhs, const Status& rhs)
{
    return lhs.error_code() == rhs.error_code() &&
           lhs.error_message() == rhs.error_message() &&
           lhs.error_details() == rhs.error_details();
}

/** Negation of the operator== defined above. */
bool operator!=(const Status& lhs, const Status& rhs)
{
    return !(lhs == rhs);
}
}  // namespace grpc
 
namespace mindie_llm {
 
// Partial equality for ServerConfig: only inferMode and npuUsageThreshold are
// compared; every other field is ignored by this operator.
bool operator==(const ServerConfig& lhs, const ServerConfig& rhs)
{
    if (!(lhs.inferMode == rhs.inferMode)) {
        return false;
    }
    return lhs.npuUsageThreshold == rhs.npuUsageThreshold;
}
 
// Forward declaration of the accessor that the fixture's SetUp() replaces with
// a mockcpp stub (see MOCKER_CPP(GetServerConfig, ...)).
// NOTE(review): the real definition is expected to come from the production
// code linked into the test binary - confirm.
const ServerConfig& GetServerConfig();
 
// Fixture shared by all sender tests: sets the environment variables and the
// ConfigManager instance before each test and stubs GetServerConfig().
class DmiMsgSenderTest : public testing::Test {
protected:
    // Exports the environment variables, creates the ConfigManager instance
    // from config_grpc.json and stubs GetServerConfig() with a fixed config.
    void SetUp() override
    {
        auto&& env = EnvUtil::GetInstance();
        const std::string workDir = GetParentDirectory();

        env.SetEnvVar("RANK_TABLE_FILE", workDir + "/../../config_manager/conf/ranktable.json");
        env.SetEnvVar("MIES_CONTAINER_IP", "127.0.0.1");
        env.SetEnvVar("HOST_IP", "127.0.0.1");
        env.SetEnvVar("MINDIE_CHECK_INPUTFILES_PERMISSION", "1");

        // A fatal failure here aborts SetUp before the stub below is installed.
        ASSERT_TRUE(ConfigManager::CreateInstance(workDir + "/../../config_manager/conf/config_grpc.json"));

        ServerConfig stubbedConfig;
        stubbedConfig.distDPServerEnabled = false;
        stubbedConfig.npuUsageThreshold = 0;
        stubbedConfig.inferMode = "standard";

        MOCKER_CPP(GetServerConfig, const ServerConfig& (*)()).stubs().will(returnValue(stubbedConfig));
    }

    // Clears every environment variable set in SetUp, then calls
    // GlobalMockObject::verify().
    void TearDown() override
    {
        auto&& env = EnvUtil::GetInstance();
        for (const char* varName :
             {"RANK_TABLE_FILE", "MINDIE_CHECK_INPUTFILES_PERMISSION", "MIES_CONTAINER_IP", "HOST_IP"}) {
            env.ClearEnvVar(varName);
        }
        GlobalMockObject::verify();
    }

    // Despite its name, returns the current working directory as a string.
    // On a filesystem error the error is logged to stderr and an empty string
    // is returned.
    static std::string GetParentDirectory()
    {
        std::string cwd;
        try {
            cwd = std::filesystem::current_path().string();
        } catch (const std::filesystem::filesystem_error& err) {
            std::cerr << "Error getting current directory: " << err.what() << std::endl;
            cwd.clear();
        }
        return cwd;
    }
};

// Init() is expected to succeed when the third constructor argument is false
// and no options object is supplied (the "without TLS" case per the test name).
TEST_F(DmiMsgSenderTest, DecodeRequestSender_InitWithoutTlsSuccess)
{
    const char* const receiverAddr = "127.0.0.1:50051";
    const char* const localAddr = "127.0.0.1:50052";

    DecodeRequestSender sender(receiverAddr, localAddr, false, nullptr);
    const auto initResult = sender.Init();

    EXPECT_TRUE(initResult);
}

// Init() is expected to succeed when the TLS flag is true and a
// default-constructed TlsChannelCredentialsOptions object is handed over.
TEST_F(DmiMsgSenderTest, DecodeRequestSender_InitWithTlsSuccess)
{
    const char* const receiverAddr = "127.0.0.1:50051";
    const char* const localAddr = "127.0.0.1:50052";
    auto options = std::make_unique<grpc::experimental::TlsChannelCredentialsOptions>();

    DecodeRequestSender sender(receiverAddr, localAddr, true, std::move(options));
    const auto initResult = sender.Init();

    EXPECT_TRUE(initResult);
}

// Init() is expected to fail when the TLS flag is true but the options
// argument is nullptr.
TEST_F(DmiMsgSenderTest, DecodeRequestSender_InitWithTlsButNoOptions)
{
    const char* const receiverAddr = "127.0.0.1:50051";
    const char* const localAddr = "127.0.0.1:50052";

    DecodeRequestSender sender(receiverAddr, localAddr, true, nullptr);
    const auto initResult = sender.Init();

    EXPECT_FALSE(initResult);
}

// After CreateStub() the sender must still report the receiver and local
// addresses it was constructed with.
TEST_F(DmiMsgSenderTest, DecodeRequestSender_CreateStubSuccess)
{
    const char* const receiverAddr = "127.0.0.1:50051";
    const char* const localAddr = "127.0.0.1:50052";

    auto channel = grpc::CreateChannel(receiverAddr, grpc::InsecureChannelCredentials());
    DecodeRequestSender sender(receiverAddr, localAddr, false, nullptr);
    sender.CreateStub(channel);

    EXPECT_EQ(sender.receiverAddr_, receiverAddr);
    EXPECT_EQ(sender.localAddr_, localAddr);
}

// With a stub in place and the RPC mocked to return OK, sending a
// default-constructed request is expected to fail while leaving the error
// message empty.
TEST_F(DmiMsgSenderTest, DecodeRequestSender_ResponseIsNotValiddecodeParameters)
{
    using DecodeRpcFn = grpc::Status (*)(grpc::ClientContext*, const DecodeParameters&, DecodeRequestResponse*);
    const char* const receiverAddr = "127.0.0.1:50051";
    const char* const localAddr = "127.0.0.1:50052";

    auto channel = grpc::CreateChannel(receiverAddr, grpc::InsecureChannelCredentials());
    DecodeRequestSender sender(receiverAddr, localAddr, false, nullptr);
    sender.CreateStub(channel);

    MOCKER_CPP(&DecodeService::Stub::DecodeRequestChannel, DecodeRpcFn)
        .stubs()
        .will(returnValue(grpc::Status::OK));

    DecodeParameters request;
    std::string reqId = "req-";
    std::string errMsg;
    const auto sent = sender.SendDecodeRequestMsg(request, reqId, errMsg);

    EXPECT_FALSE(sent);
    EXPECT_EQ(errMsg, "");
}

// CreateStub() is never called here, so sending is expected to fail with the
// "stub_ is nullptr" error message.
TEST_F(DmiMsgSenderTest, DecodeRequestSender_StubNull)
{
    using DecodeRpcFn = grpc::Status (*)(grpc::ClientContext*, const DecodeParameters&, DecodeRequestResponse*);
    const char* const receiverAddr = "127.0.0.1:50051";
    const char* const localAddr = "127.0.0.1:50052";

    DecodeRequestSender sender(receiverAddr, localAddr, false, nullptr);

    MOCKER_CPP(&DecodeService::Stub::DecodeRequestChannel, DecodeRpcFn)
        .stubs()
        .will(returnValue(grpc::Status::OK));

    DecodeParameters request;
    std::string reqId = "req-";
    std::string errMsg;
    const auto sent = sender.SendDecodeRequestMsg(request, reqId, errMsg);

    EXPECT_FALSE(sent);
    EXPECT_EQ(errMsg, "The stub_ is nullptr");
}

// When the mocked RPC returns CANCELLED, sending is expected to fail and the
// error message must carry the status code, receiver address and request id.
TEST_F(DmiMsgSenderTest, DecodeRequestSender_DecodeRequestChannelError)
{
    using DecodeRpcFn = grpc::Status (*)(grpc::ClientContext*, const DecodeParameters&, DecodeRequestResponse*);
    const char* const receiverAddr = "127.0.0.1:50051";
    const char* const localAddr = "127.0.0.1:50052";

    auto channel = grpc::CreateChannel(receiverAddr, grpc::InsecureChannelCredentials());
    DecodeRequestSender sender(receiverAddr, localAddr, false, nullptr);
    sender.CreateStub(channel);

    MOCKER_CPP(&DecodeService::Stub::DecodeRequestChannel, DecodeRpcFn)
        .stubs()
        .will(returnValue(grpc::Status::CANCELLED));

    DecodeParameters request;
    std::string reqId = "req-";
    std::string errMsg;
    const auto sent = sender.SendDecodeRequestMsg(request, reqId, errMsg);

    EXPECT_FALSE(sent);
    EXPECT_EQ(errMsg, "Failed to send decode request msg because[1]  receiverAddr is 127.0.0.1:50051. RequestId is req-");
}

// After CreateStub() the KV-release sender must still report the receiver and
// local addresses it was constructed with.
TEST_F(DmiMsgSenderTest, KvReleaseSender_CreateStubSuccess)
{
    const char* const receiverAddr = "127.0.0.1:50051";
    const char* const localAddr = "127.0.0.1:50052";

    auto channel = grpc::CreateChannel(receiverAddr, grpc::InsecureChannelCredentials());
    KvReleaseSender sender(receiverAddr, localAddr, false, nullptr);
    sender.CreateStub(channel);

    EXPECT_EQ(sender.receiverAddr_, receiverAddr);
    EXPECT_EQ(sender.localAddr_, localAddr);
}

// With a stub in place and ReleaseKVCacheChannel mocked to return OK,
// SendKvReleaseMsg() is expected to report success.
TEST_F(DmiMsgSenderTest, KvReleaseSender_Success)
{
    using ReleaseRpcFn = grpc::Status (*)(grpc::ClientContext*, const RequestId&, google::protobuf::Empty*);
    const char* const receiverAddr = "127.0.0.1:50051";
    const char* const localAddr = "127.0.0.1:50052";

    auto channel = grpc::CreateChannel(receiverAddr, grpc::InsecureChannelCredentials());
    KvReleaseSender sender(receiverAddr, localAddr, false, nullptr);
    sender.CreateStub(channel);

    MOCKER_CPP(&PrefillService::Stub::ReleaseKVCacheChannel, ReleaseRpcFn)
        .stubs()
        .will(returnValue(grpc::Status::OK));

    prefillAndDecodeCommunication::RequestId request;
    const auto released = sender.SendKvReleaseMsg(request);

    EXPECT_TRUE(released);
}

// CreateStub() is never called here, so SendKvReleaseMsg() is expected to
// fail even though the RPC itself is mocked to return OK.
TEST_F(DmiMsgSenderTest, KvReleaseSender_StubNull)
{
    using ReleaseRpcFn = grpc::Status (*)(grpc::ClientContext*, const RequestId&, google::protobuf::Empty*);
    const char* const receiverAddr = "127.0.0.1:50051";
    const char* const localAddr = "127.0.0.1:50052";

    KvReleaseSender sender(receiverAddr, localAddr, false, nullptr);

    MOCKER_CPP(&PrefillService::Stub::ReleaseKVCacheChannel, ReleaseRpcFn)
        .stubs()
        .will(returnValue(grpc::Status::OK));

    prefillAndDecodeCommunication::RequestId request;
    const auto released = sender.SendKvReleaseMsg(request);

    EXPECT_FALSE(released);
}

// When the mocked ReleaseKVCacheChannel RPC returns CANCELLED,
// SendKvReleaseMsg() is expected to report failure.
// NOTE(review): the test name mentions DecodeRequestChannel, but the RPC
// mocked here is ReleaseKVCacheChannel.
TEST_F(DmiMsgSenderTest, KvReleaseSender_DecodeRequestChannelError)
{
    using ReleaseRpcFn = grpc::Status (*)(grpc::ClientContext*, const RequestId&, google::protobuf::Empty*);
    const char* const receiverAddr = "127.0.0.1:50051";
    const char* const localAddr = "127.0.0.1:50052";

    auto channel = grpc::CreateChannel(receiverAddr, grpc::InsecureChannelCredentials());
    KvReleaseSender sender(receiverAddr, localAddr, false, nullptr);
    sender.CreateStub(channel);

    MOCKER_CPP(&PrefillService::Stub::ReleaseKVCacheChannel, ReleaseRpcFn)
        .stubs()
        .will(returnValue(grpc::Status::CANCELLED));

    prefillAndDecodeCommunication::RequestId request;
    const auto released = sender.SendKvReleaseMsg(request);

    EXPECT_FALSE(released);
}

} // namespace mindie_llm