#include "xsched/cuda/hal/level2/guardian.h"
#include "xsched/cuda/hal/level3/trap.h"
#include "xsched/cuda/hal/level3/cuda_queue.h"
#include "xsched/cuda/hal/common/levels.h"
#include "xsched/cuda/hal/common/driver.h"
#include "xsched/cuda/hal/common/cuda_assert.h"
#include "xsched/cuda/hal/arch/sm35.h"
#include "xsched/cuda/hal/arch/sm70.h"
#include "xsched/cuda/hal/arch/sm86.h"
using namespace xsched::cuda;
using namespace xsched::preempt;
/// Encodes the compute capability of `dev` as one integer: major * 10 + minor
/// (e.g. major 8 / minor 6 yields 86). Both attribute queries are wrapped in
/// CUDA_ASSERT.
static int32_t GetArch(CUdevice dev)
{
    int cc_major;
    int cc_minor;
    CUDA_ASSERT(Driver::DeviceGetAttribute(
        &cc_major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, dev));
    CUDA_ASSERT(Driver::DeviceGetAttribute(
        &cc_minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, dev));
    return 10 * cc_major + cc_minor;
}
/// Builds the Guardian implementation that matches the architecture code of
/// `dev` as reported by GetArch(). Codes 35, 70 and 86 have a dedicated
/// implementation; any other code yields nullptr.
std::shared_ptr<Guardian> Guardian::Instance(CUdevice dev)
{
    const int32_t arch = GetArch(dev);

    if (arch == 35) {
        return std::make_shared<GuardianSM35>();
    }
    if (arch == 70) {
        return std::make_shared<GuardianSM70>();
    }
    if (arch == 86) {
        return std::make_shared<GuardianSM86>();
    }

    // No Guardian is provided for this architecture code.
    return nullptr;
}
/// Builds the handler implementation that matches the architecture code of
/// `dev` as reported by GetArch(). Only codes 70 and 86 have an
/// implementation here; any other code yields nullptr.
std::shared_ptr<TarpHandler> TarpHandler::Instance(CUdevice dev)
{
    const int32_t arch = GetArch(dev);

    if (arch == 70) {
        return std::make_shared<TarpHandlerSM70>();
    }
    if (arch == 86) {
        return std::make_shared<TarpHandlerSM86>();
    }

    // No handler is provided for this architecture code.
    return nullptr;
}
/// Creates the HwQueue implementation to use for `stream`.
///
/// Selection order:
///   1. When built for Windows (_WIN32), always a CudaQueueLv1.
///   2. When GetCudaLv3Implementation() reports kCudaLv3ImplementationTsg,
///      a CudaQueueLv3Tsg.
///   3. Otherwise, by the architecture code of the current context's device:
///      35 -> CudaQueueLv2, 70/86 -> CudaQueueLv3Trap, anything else ->
///      CudaQueueLv1.
///
/// In case 3 the calling thread's current context must be the context that
/// owns `stream`; this is enforced with XASSERT.
std::shared_ptr<HwQueue> xsched::cuda::CudaQueueCreate(CUstream stream)
{
#if defined(_WIN32)
    return std::make_shared<CudaQueueLv1>(stream);
#endif

    if (GetCudaLv3Implementation() == kCudaLv3ImplementationTsg) {
        return std::make_shared<CudaQueueLv3Tsg>(stream);
    }

    CUdevice dev;
    CUcontext stream_ctx;
    CUcontext current_ctx;
    CUDA_ASSERT(Driver::StreamGetCtx(stream, &stream_ctx));
    // Fixed: the address-of expression here had been corrupted into a
    // mis-encoded character sequence; it must pass &current_ctx.
    CUDA_ASSERT(Driver::CtxGetCurrent(&current_ctx));
    XASSERT(current_ctx == stream_ctx,
            "create CudaQueue failed: current context (%p) does not match stream context (%p)",
            current_ctx, stream_ctx);

    // The device is queried from the current context, which the assertion
    // above has checked against the stream's context.
    CUDA_ASSERT(Driver::CtxGetDevice(&dev));

    switch (GetArch(dev)) {
    case 35:
        return std::make_shared<CudaQueueLv2>(stream);
    case 70:
    case 86:
        return std::make_shared<CudaQueueLv3Trap>(stream);
    default:
        // Architectures without a dedicated implementation use the Lv1 queue.
        return std::make_shared<CudaQueueLv1>(stream);
    }
}
/// Launches `kernel` on `stream` by forwarding to the DirectLaunch of the
/// queue implementation chosen for the platform / device.
///
/// Selection order:
///   1. When built for Windows (_WIN32), CudaQueueLv1::DirectLaunch.
///   2. When GetCudaLv3Implementation() reports kCudaLv3ImplementationTsg,
///      CudaQueueLv3Tsg::DirectLaunch.
///   3. Otherwise, by the architecture code of the current context's device:
///      35 -> CudaQueueLv2::DirectLaunch, 70/86 -> CudaQueueLv3Trap::DirectLaunch.
///
/// In case 3 the calling thread's current context must be the context that
/// owns `stream`; this is enforced with XASSERT.
///
/// Returns the CUresult of the selected DirectLaunch, or
/// CUDA_ERROR_NOT_SUPPORTED when case 3 meets any other architecture code.
CUresult xsched::cuda::DirectLaunch(std::shared_ptr<CudaKernelCommand> kernel, CUstream stream)
{
#if defined(_WIN32)
    return CudaQueueLv1::DirectLaunch(kernel, stream);
#endif

    if (GetCudaLv3Implementation() == kCudaLv3ImplementationTsg) {
        return CudaQueueLv3Tsg::DirectLaunch(kernel, stream);
    }

    CUdevice dev;
    CUcontext stream_ctx;
    CUcontext current_ctx;
    CUDA_ASSERT(Driver::StreamGetCtx(stream, &stream_ctx));
    // Fixed: the address-of expression here had been corrupted into a
    // mis-encoded character sequence; it must pass &current_ctx.
    CUDA_ASSERT(Driver::CtxGetCurrent(&current_ctx));
    XASSERT(current_ctx == stream_ctx,
            "direct launch kernel failed: current context (%p) does not match stream context (%p)",
            current_ctx, stream_ctx);

    // The device is queried from the current context, which the assertion
    // above has checked against the stream's context.
    CUDA_ASSERT(Driver::CtxGetDevice(&dev));

    switch (GetArch(dev)) {
    case 35:
        return CudaQueueLv2::DirectLaunch(kernel, current_ctx, stream);
    case 70:
    case 86:
        return CudaQueueLv3Trap::DirectLaunch(kernel, current_ctx, stream);
    default:
        // NOTE(review): CudaQueueCreate falls back to CudaQueueLv1 for other
        // architecture codes, while this path reports "not supported" --
        // confirm the asymmetry is intentional.
        return CUDA_ERROR_NOT_SUPPORTED;
    }
}