#pragma once
#include <atomic>
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>

#include "torch_npu/csrc/core/npu/npu_log.h"
#include "torch_npu/csrc/core/npu/sys_ctrl/npu_sys_ctrl.h"
#include "torch_npu/csrc/core/npu/NPUException.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"

#include <ATen/ATen.h>
#include <c10/core/Device.h>
#include <c10/util/Optional.h>
#include "third_party/hccl/inc/hccl/hccl.h"
#include "third_party/hccl/inc/hccl/hccl_types.h"

/**
 * Checks the result of an HCCL call and fails through TORCH_CHECK when it is not HCCL_SUCCESS.
 *
 * @param err_code  Expression yielding the HCCL result. It is evaluated exactly once (stored in
 *                  the local `Error`) and additionally stringified (#err_code) for the message.
 * @param ...       Optional extra arguments forwarded to getErrorFunction() together with the
 *                  stringified expression.
 *
 * Failure path, in order:
 *  1. CHECK_AND_THROW_ERROR_WITH_SPECIFIC_MESSAGE(Error) is invoked first
 *     (NOTE(review): judging by its name it may throw on its own — see its definition).
 *  2. If OptionsManager::IsCompactErrorOutput() is true, a short message (function, error code,
 *     DIST_ERROR(ErrCode::HCCL)) is logged with ASCEND_LOGE, and TORCH_CHECK(false, ...) is raised
 *     with c10_npu_get_error_message() when that is non-empty, otherwise with the short message.
 *  3. Otherwise the message is prefixed with __func__:__FILE__:__LINE__ and has
 *     c10_npu_get_error_message() appended. When c10_npu::isCannOOM() accepts that message, the
 *     OOM observer is called if IsOomSnapshotEnable() is true, the message is logged, and the
 *     failure is raised as OutOfMemoryError via TORCH_CHECK_WITH; every other failure is raised
 *     with TORCH_CHECK(false, ...).
 *
 * The body is wrapped in do { ... } while (0) so the macro behaves as a single statement.
 * The local variable is named `Error`; an identifier with the same name inside `err_code`
 * would refer to the macro's own local.
 */
#define HCCL_CHECK_ERROR(err_code, ...)                                      \
    do {                                                                     \
        auto Error = err_code;                                               \
        if ((Error) != HCCL_SUCCESS) {                                       \
            CHECK_AND_THROW_ERROR_WITH_SPECIFIC_MESSAGE(Error);              \
            if (c10_npu::option::OptionsManager::IsCompactErrorOutput()) {   \
                std::ostringstream oss;                                      \
                oss << " HCCL function error: " << getErrorFunction(#err_code, ##__VA_ARGS__)    \
                   << ", error code is " << Error << " "                    \
                   << DIST_ERROR(ErrCode::HCCL) + ".\n";                     \
                std::string err_msg = oss.str();                          \
                ASCEND_LOGE("%s", err_msg.c_str());                       \
                std::string errmsg(c10_npu::c10_npu_get_error_message());    \
                TORCH_CHECK(                                                 \
                    false,                                                   \
                    errmsg.empty() ? err_msg : errmsg);                      \
            } else {                                                         \
                auto retmsg = std::string(__func__) + ":" + __FILE__ + ":" + std::to_string(__LINE__) +    \
                    " HCCL function error: " + getErrorFunction(#err_code, ##__VA_ARGS__) +                \
                    ", error code is " + std::to_string(Error) + " " + DIST_ERROR(ErrCode::HCCL) + ".\n" + \
                    c10_npu::c10_npu_get_error_message();                                                  \
                if (c10_npu::isCannOOM(retmsg)) {                                 \
                    if (c10_npu::option::OptionsManager::IsOomSnapshotEnable()) { \
                        c10_npu::option::oom_observer();                          \
                    }                                                             \
                    ASCEND_LOGE("%s", retmsg.c_str());                            \
                    TORCH_CHECK_WITH(OutOfMemoryError, false, retmsg.c_str());    \
                }                                                                 \
                TORCH_CHECK(false, retmsg);                                       \
            }                                                                     \
        }                                                                         \
    } while (0)

// Flag macro, defined unconditionally by this header (no value).
// NOTE(review): presumably tested with #ifdef by the HCCL process-group code to compile in
// asynchronous error checking — confirm against its users.
#define ENABLE_HCCL_ERROR_CHECKING

namespace c10d_npu {
// HCCL entry points declared here and defined in another translation unit.
// NOTE(review): the names mirror the HCCL C API (HcclCommInitRootInfo, HcclBroadcast, ...);
// presumably these are thin wrappers that forward to the HCCL library — confirm against the
// implementation file. Every function reports its outcome as an HcclResult.

// Communicator status query.
extern HcclResult hcclGetCommAsyncError(HcclComm comm, HcclResult* asyncError);

// Communicator creation (with and without an explicit HcclCommConfig).
extern HcclResult hcclCommInitRootInfoConfig(
    uint32_t nRanks,
    const HcclRootInfo *rootInfo,
    uint32_t rank,
    HcclCommConfig* config,
    HcclComm *comm);
extern HcclResult hcclCommInitClusterInfoConfig(
    const char *clusterInfo,
    uint32_t rank,
    HcclCommConfig *config,
    HcclComm *comm);
extern HcclResult hcclCreateSubCommConfig(
    HcclComm *comm,
    uint32_t rankNum,
    uint32_t *rankIds,
    uint64_t subCommId,
    uint32_t subCommRankId,
    HcclCommConfig* config,
    HcclComm *subComm);

// Communicator maintenance and teardown.
extern HcclResult hcclCommWorkingDevNicSet(
    HcclComm comm,
    uint32_t *ranks,
    bool *useBackup,
    uint32_t nRanks);
extern HcclResult hcclCommDestroy(HcclComm comm);

// Collective operation.
extern HcclResult hcclBroadcast(
    void *buf,
    uint64_t count,
    HcclDataType dataType,
    uint32_t root,
    HcclComm comm,
    aclrtStream stream);

// Communicator creation without a config argument.
extern HcclResult hcclCommInitRootInfo(
    uint32_t nRanks,
    const HcclRootInfo *rootInfo,
    uint32_t rank,
    HcclComm *comm);
extern HcclResult hcclCommInitAll(uint32_t ndev, int32_t *devices, HcclComm *comms);

// Helper declarations; the definitions live outside this header, so the notes below describe
// only what the signatures show.

// Maps an ATen scalar type to an HcclDataType.
// NOTE(review): handling of scalar types without an HCCL counterpart is decided by the
// implementation — confirm.
HcclDataType getHcclDataType(at::ScalarType type);

// Returns a string form of the given HcclDataType.
// NOTE(review): exact format ("serial string") is defined by the implementation — confirm.
std::string getHcclDataTypeSerialString(HcclDataType type);

// Reports whether `path` exists (per the name; semantics for directories etc. not visible here).
bool isFileExists(const std::string& path);

// Reports whether `file` is a readable path (per the name; exact checks not visible here).
bool checkFilePathReadable(const std::string& file);

// Reports whether HCCL communicator names are supported.
// NOTE(review): presumably a capability probe of the installed HCCL library — confirm.
bool isSupportHcclCommName();

// Looks up the NPU stream registered under `name` for `device_index`; the optional return type
// lets the lookup report "no stream".
c10::optional<c10_npu::NPUStream> getHcclStreamByBufferName(const std::string &name, c10::DeviceIndex device_index);

// Registers `stream` under `name` for `device_index`; the bool result reports the outcome
// (NOTE(review): meaning of `false` is defined by the implementation — confirm).
bool setHcclStreamByBufferName(const std::string &name, c10::DeviceIndex device_index, c10_npu::NPUStream stream);

// RAII wrapper for HCCL communicator
//
// Holds a single HcclComm handle. The type is move-only: copying is deleted so that two wrappers
// never refer to the same handle. Constructor, destructor, move operations and the factory
// functions are defined out of line.
//
// All data members carry default member initializers so that every constructor (including the
// out-of-line move constructor) starts from a determinate state even if it does not mention a
// member explicitly; a constructor's own member-initializer list still takes precedence.
class C10_NPU_API HCCLComm {
public:
    // Wraps an existing communicator handle.
    explicit HCCLComm(HcclComm hcclComm);
    // Creates a wrapper around a null handle.
    HCCLComm() : HCCLComm(nullptr) {}
    ~HCCLComm();

    // Factory: builds a communicator for `rank` out of `numRanks` from `rootInfo`.
    static std::shared_ptr<HCCLComm> create(
        int numRanks,
        int rank,
        HcclRootInfo& rootInfo);

    // Factory: same inputs as create(), plus an explicit communicator `config`.
    static std::shared_ptr<HCCLComm> create_config(
        int numRanks,
        int rank,
        HcclRootInfo& rootInfo,
        HcclCommConfig* config);

    // Factory: builds a communicator from a cluster description.
    // NOTE(review): `clusterInfo` is presumably a path to / contents of a rank table — confirm.
    static std::shared_ptr<HCCLComm> createGlobalHcclComm(
        const char *clusterInfo,
        uint32_t rank,
        HcclCommConfig* config);

    // Factory: derives a sub-communicator from `comm` for the ranks listed in `rankIds`.
    static std::shared_ptr<HCCLComm> createSubHcclComm(
        std::shared_ptr<HCCLComm> comm,
        uint32_t rankNum,
        uint32_t *rankIds,
        uint64_t subCommId,
        uint32_t subCommRankId,
        HcclCommConfig* config);

    // Public bookkeeping fields written by the code that creates the communicator.
    // NOTE(review): 0 is assumed to be the neutral value for both (default comm type / no peer);
    // confirm against the enum used by the process group.
    int hcclCommType = 0;
    int p2pPeer = 0;

    // Must not be copyable
    HCCLComm(const HCCLComm&) = delete;
    HCCLComm& operator=(const HCCLComm&) = delete;

    // Move constructable
    HCCLComm(HCCLComm&& other);

    // Move assignable
    HCCLComm& operator=(HCCLComm&& other);

    // Returns the raw handle without transferring ownership.
    HcclComm getHcclComm() const
    {
        return hcclComm_;
    }

    // Releases the wrapped communicator (defined out of line).
    void destroyHcclComm();

    // Returns an HcclResult describing the communicator's error state (defined out of line).
    HcclResult checkForHcclError();

protected:
    // Wrapped handle; null when no communicator is held.
    HcclComm hcclComm_ = nullptr;
    // Declared mutable so const member functions can lock it.
    mutable std::mutex mutex_;
    // Asynchronous error state; starts out as "no error".
    HcclResult hcclAsyncErr_ = HCCL_SUCCESS;
};

// Polymorphic sink for HCCL trace text.
//
// A process-wide writer instance is kept in the static `writer_` member; getWriter() and
// registerWriter() (both defined out of line) give access to it. Subclasses can override write()
// and getWriterTarget() to redirect the output.
class TORCH_API DebugInfoWriter {
public:
    virtual ~DebugInfoWriter() = default;

    // Emits `hcclTrace`; the base implementation is defined out of line.
    virtual void write(const std::string &hcclTrace);

    // Accessor for the writer used by `rank` (defined out of line).
    static DebugInfoWriter &getWriter(int rank);

    // Installs `writer`, taking ownership of it (defined out of line).
    static void registerWriter(std::unique_ptr<DebugInfoWriter> writer);

    // Returns the target this writer emits to; for the base class that is `filename_`.
    virtual std::string getWriterTarget()
    {
        return filename_;
    }

protected:
    // Only subclasses and the class's own static functions may construct a writer.
    // The target name is `namePrefix` immediately followed by `rank`, as produced by c10::str.
    DebugInfoWriter(std::string namePrefix, int rank) : filename_(c10::str(namePrefix, rank)) {}

    std::string filename_;

private:
    static std::unique_ptr<DebugInfoWriter> writer_;
    static std::atomic<bool> hasWriterRegistered_;
};
} // namespace c10d_npu