// Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.

#pragma once

#include <cstring>
#include <functional>
#include <memory>

#include "kcal/core/context.h"
#include "kcal/yacl_linker.h"

namespace kcal {

/**
 * @brief Extension wrapper that owns a kcal::Context together with the
 *        transport used to exchange data with peers.
 *
 * The transport is either a pair of user-supplied send/receive callbacks
 * (see Create()) or a YaclLinker (see CreateFromYaclLinker()). The
 * constructors taking a transport are private, so instances with a transport
 * are only obtainable through those factories, which hand out shared_ptr
 * handles.
 *
 * The class is neither copyable nor movable. It declares its own destructor,
 * and a pointer to an instance is kept in the process-wide static
 * currentContext_ (exposed via GetCurrentContext()). A copy would share
 * callbacks, the kcal context and the linker with the original, while the
 * static pointer can only refer to one of them.
 *
 * NOTE(review): currentContext_ is a plain static pointer with no
 * synchronization. That suggests only one active ContextExt per process at a
 * time and no concurrent creation/destruction — confirm against the
 * implementation.
 */
class ContextExt {
public:
    /// Send hook: (peer node info, pointer to outgoing bytes, byte count) -> status.
    /// NOTE(review): the meaning of the returned int (e.g. 0 == success) is
    /// defined by the implementation/caller — confirm.
    using SendCallback = std::function<int(const TeeNodeInfo &, const uint8_t *, size_t)>;
    /// Receive hook: (peer node info, destination buffer, buffer size) -> status.
    /// NOTE(review): whether the callback must fill exactly `size` bytes is not
    /// visible from this header — confirm.
    using RecvCallback = std::function<int(const TeeNodeInfo &, uint8_t *, size_t)>;

    /// Builds a context whose data exchange goes through the given callbacks.
    /// @param config  kcal configuration (taken by value).
    /// @param sendCb  callback used for outgoing data.
    /// @param recvCb  callback used for incoming data.
    /// @return shared handle to the new context (failure behaviour is defined
    ///         in the implementation — TODO confirm whether nullptr is possible).
    static std::shared_ptr<ContextExt> Create(Config config, SendCallback sendCb, RecvCallback recvCb);

    /// Builds a context whose data exchange goes through an existing YaclLinker.
    /// @param config  kcal configuration (taken by value).
    /// @param linker  linker to share; the new context keeps a reference to it.
    static std::shared_ptr<ContextExt> CreateFromYaclLinker(Config config, std::shared_ptr<YaclLinker> linker);

    ContextExt() = default;
    /// Defined out of line. NOTE(review): presumably releases/unregisters
    /// currentContext_ — confirm in the implementation.
    ~ContextExt();

    // Identity matters (see class comment), so forbid copies and moves.
    ContextExt(const ContextExt &) = delete;
    ContextExt &operator=(const ContextExt &) = delete;
    ContextExt(ContextExt &&) = delete;
    ContextExt &operator=(ContextExt &&) = delete;

    /// Returns the wrapped kcal context (may be empty if never initialized).
    std::shared_ptr<Context> GetKcalContext() const { return kcalCtx_; }

    /// Returns the yacl link context of the underlying linker, or nullptr when
    /// this instance was not created from a YaclLinker.
    std::shared_ptr<yacl::link::Context> GetYaclContext() const
    {
        return yaclLinker_ ? yaclLinker_->GetYaclContext() : nullptr;
    }

    /// Returns the instance currently stored in the process-wide static
    /// pointer (non-owning; may be nullptr). Not synchronized.
    static ContextExt *GetCurrentContext() { return currentContext_; }

private:
    ContextExt(SendCallback sendCb, RecvCallback recvCb);
    explicit ContextExt(std::shared_ptr<YaclLinker> linker);

    // Non-owning pointer to the "current" instance. The static thunks below
    // receive no `this`, so they presumably reach the instance through this
    // pointer — confirm in the implementation.
    static ContextExt *currentContext_;

    SendCallback sendCallback_;
    RecvCallback recvCallback_;
    std::shared_ptr<Context> kcalCtx_;

    // Set only when constructed from a YaclLinker; GetYaclContext() checks it.
    std::shared_ptr<YaclLinker> yaclLinker_;

    // Static trampolines with raw-pointer signatures, presumably handed to the
    // kcal core as send/recv hooks. RecvDataThunk takes `len` by pointer,
    // possibly as an in/out length — TODO confirm against kcal/core/context.h.
    static int SendDataThunk(TeeNodeInfo *nodeInfo, unsigned char *buf, u64 len);
    static int RecvDataThunk(TeeNodeInfo *nodeInfo, unsigned char *buf, u64 *len);
};

} // namespace kcal