#pragma once
#include <cstring>
#include <functional>
#include <memory>
#include "kcal/core/context.h"
#include "kcal/yacl_linker.h"
namespace kcal {
class ContextExt {
public:
using SendCallback = std::function<int(const TeeNodeInfo &, const uint8_t *, size_t)>;
using RecvCallback = std::function<int(const TeeNodeInfo &, uint8_t *, size_t)>;
static std::shared_ptr<ContextExt> Create(Config config, SendCallback sendCb, RecvCallback recvCb);
static std::shared_ptr<ContextExt> CreateFromYaclLinker(Config config, std::shared_ptr<YaclLinker> linker);
ContextExt() = default;
~ContextExt();
std::shared_ptr<Context> GetKcalContext() const { return kcalCtx_; }
std::shared_ptr<yacl::link::Context> GetYaclContext() const
{
return yaclLinker_ ? yaclLinker_->GetYaclContext() : nullptr;
}
static ContextExt *GetCurrentContext() { return currentContext_; }
private:
ContextExt(SendCallback sendCb, RecvCallback recvCb);
explicit ContextExt(std::shared_ptr<YaclLinker> linker);
static ContextExt *currentContext_;
SendCallback sendCallback_;
RecvCallback recvCallback_;
std::shared_ptr<Context> kcalCtx_;
std::shared_ptr<YaclLinker> yaclLinker_;
static int SendDataThunk(TeeNodeInfo *nodeInfo, unsigned char *buf, u64 len);
static int RecvDataThunk(TeeNodeInfo *nodeInfo, unsigned char *buf, u64 *len);
};
}