#include "npu_decode_video_kernel.hpp"
#include <vector>
#include <map>
#include <sys/time.h>
#include <sys/prctl.h>
#include <ATen/ATen.h>
#include <torch/library.h>
#include "op_api_common.hpp"
#include <torch_npu/csrc/core/npu/NPUFormat.h>
#include <torch_npu/csrc/core/npu/NPUFunctions.h>
namespace vision {
namespace ops {
constexpr uint32_t MAX_CHN_HEIGHT = 4096;
constexpr uint32_t MAX_CHN_WIDTH = 4096;
constexpr int32_t SEND_TIMEOUT = 30;
constexpr uint32_t WAIT_TIMEOUT = 5000000;
constexpr uint32_t REF_FRAME_NUM = 16;
constexpr uint32_t DISPLAY_FRAME_NUM = 16;
constexpr uint32_t FRAME_BUF_CNT = REF_FRAME_NUM + DISPLAY_FRAME_NUM + 1;
pthread_t g_vdec_get_thread[VDEC_MAX_CHNL_NUM] = {0};
uint32_t g_get_exit_state[VDEC_MAX_CHNL_NUM] = {0};
std::vector<std::vector<at::Tensor>> g_out_queue(VDEC_MAX_CHNL_NUM);
std::mutex outTensorMapMutex[VDEC_MAX_CHNL_NUM];
std::map<hi_u64, at::Tensor> outTensorMap[VDEC_MAX_CHNL_NUM];
struct GetThreadPara {
uint32_t chnId;
uint32_t deviceId;
uint32_t totalFrame;
uint32_t successCnt;
};
GetThreadPara g_getPara[VDEC_MAX_CHNL_NUM];
static inline bool ValidChnNum(uint32_t chn)
{
return (chn < VDEC_MAX_CHNL_NUM);
}
static inline void get_current_time_us(uint64_t& timeUs)
{
struct timeval curTime;
gettimeofday(&curTime, NULL);
timeUs = static_cast<uint64_t>(curTime.tv_sec) * 1000000 + curTime.tv_usec;
}
template <class T>
static inline void LoadFunc(void* const handle, T& funPtr, const std::string& funName)
{
funPtr = reinterpret_cast<T>(dlsym(handle, funName.c_str()));
TORCH_CHECK(funPtr != nullptr, "vdec function not load, func name ", funName.c_str());
}
VideoDecode &VideoDecode::GetInstance()
{
static VideoDecode instance;
return instance;
}
VideoDecode::VideoDecode()
{
LoadFunctions();
for (uint32_t i = 0; i < VDEC_MAX_CHNL_NUM; ++i) {
channelStatus_[i] = ChnStatus::DESTROYED;
}
}
VideoDecode::~VideoDecode() {}
void VideoDecode::LoadFunctions()
{
void * const handle = dlopen("libacl_dvpp_mpi.so", RTLD_LAZY);
TORCH_CHECK(handle != nullptr, "dlopen libacl_dvpp_mpi.so fail");
LoadFunc(handle, sysInitFunPtr_, "hi_mpi_sys_init");
LoadFunc(handle, sysExitFunPtr_, "hi_mpi_sys_exit");
LoadFunc(handle, vdecGetPicBufferSizeFunPtr_, "hi_vdec_get_pic_buf_size");
LoadFunc(handle, vdecGetTmvBufferSizeFunPtr_, "hi_vdec_get_tmv_buf_size");
LoadFunc(handle, vdecCreateChnFunPtr_, "hi_mpi_vdec_create_chn");
LoadFunc(handle, vdecDestroyChnFunPtr_, "hi_mpi_vdec_destroy_chn");
LoadFunc(handle, sysSetChnCscMatrixFunPtr_, "hi_mpi_sys_set_chn_csc_matrix");
LoadFunc(handle, vdecStartRecvStreamFunPtr_, "hi_mpi_vdec_start_recv_stream");
LoadFunc(handle, vdecStopRecvStreamFunPtr_, "hi_mpi_vdec_stop_recv_stream");
LoadFunc(handle, vdecQueryStatusFunPtr_, "hi_mpi_vdec_query_status");
LoadFunc(handle, vdecResetChnFunPtr_, "hi_mpi_vdec_reset_chn");
LoadFunc(handle, vdecSendStreamFunPtr_, "hi_mpi_vdec_send_stream");
LoadFunc(handle, vdecGetFrameFunPtr_, "hi_mpi_vdec_get_frame");
LoadFunc(handle, vdecReleaseFrameFunPtr_, "hi_mpi_vdec_release_frame");
}
int32_t VideoDecode::GetUnusedChn(uint32_t& chn)
{
for (uint32_t i = 0; i < VDEC_MAX_CHNL_NUM; ++i) {
const std::lock_guard<std::mutex> guard(channelMutex_[i]);
if (channelStatus_[i] != ChnStatus::DESTROYED) {
continue;
} else {
channelStatus_[i] = ChnStatus::CREATED;
chn = i;
return 0;
}
}
return -1;
}
void VideoDecode::PutChn(uint32_t chn)
{
const std::lock_guard<std::mutex> guard(channelMutex_[chn]);
channelStatus_[chn] = ChnStatus::DESTROYED;
}
bool VideoDecode::ChannelCreated(uint32_t chn)
{
const std::lock_guard<std::mutex> guard(channelMutex_[chn]);
return (channelStatus_[chn] == ChnStatus::CREATED);
}
hi_s32 VideoDecode::sys_init(void)
{
return sysInitFunPtr_();
}
hi_s32 VideoDecode::sys_exit(void)
{
return sysExitFunPtr_();
}
hi_u32 VideoDecode::get_pic_buf_size(hi_payload_type type, hi_pic_buf_attr *buf_attr)
{
return vdecGetPicBufferSizeFunPtr_(type, buf_attr);
}
hi_u32 VideoDecode::get_tmv_buf_size(hi_payload_type type, hi_u32 width, hi_u32 height)
{
return vdecGetTmvBufferSizeFunPtr_(type, width, height);
}
hi_s32 VideoDecode::create_chn(hi_vdec_chn chn, const hi_vdec_chn_attr *attr)
{
return vdecCreateChnFunPtr_(chn, attr);
}
hi_s32 VideoDecode::destroy_chn(hi_vdec_chn chn)
{
return vdecDestroyChnFunPtr_(chn);
}
hi_s32 VideoDecode::sys_set_chn_csc_matrix(hi_vdec_chn chn)
{
return sysSetChnCscMatrixFunPtr_(HI_ID_VDEC, chn, HI_CSC_MATRIX_BT601_NARROW, nullptr);
}
hi_s32 VideoDecode::start_recv_stream(hi_vdec_chn chn)
{
return vdecStartRecvStreamFunPtr_(chn);
}
hi_s32 VideoDecode::stop_recv_stream(hi_vdec_chn chn)
{
return vdecStopRecvStreamFunPtr_(chn);
}
hi_s32 VideoDecode::query_status(hi_vdec_chn chn, hi_vdec_chn_status *status)
{
return vdecQueryStatusFunPtr_(chn, status);
}
hi_s32 VideoDecode::reset_chn(hi_vdec_chn chn)
{
return vdecResetChnFunPtr_(chn);
}
hi_s32 VideoDecode::send_stream(hi_vdec_chn chn, const hi_vdec_stream *stream,
hi_vdec_pic_info * vdec_pic_info, hi_s32 milli_sec)
{
return vdecSendStreamFunPtr_(chn, stream, vdec_pic_info, milli_sec);
}
hi_s32 VideoDecode::get_frame(hi_vdec_chn chn, hi_video_frame_info *frame_info,
hi_vdec_supplement_info *supplement, hi_vdec_stream *stream, hi_s32 milli_sec)
{
return vdecGetFrameFunPtr_(chn, frame_info, supplement, stream, milli_sec);
}
hi_s32 VideoDecode::release_frame(hi_vdec_chn chn, const hi_video_frame_info *frame_info)
{
return vdecReleaseFrameFunPtr_(chn, frame_info);
}
namespace {
static void vdec_reset_chn(uint32_t chn)
{
int32_t ret = VideoDecode::GetInstance().stop_recv_stream(chn);
TORCH_CHECK(ret == 0, "reset chn ", chn, ", hi_mpi_vdec_stop_recv_stream failed, ret = ", ret);
ret = VideoDecode::GetInstance().reset_chn(chn);
TORCH_CHECK(ret == 0, "reset chn ", chn, ", hi_mpi_vdec_reset_chn failed, ret = ", ret);
ret = VideoDecode::GetInstance().start_recv_stream(chn);
TORCH_CHECK(ret == 0, "reset chn ", chn, ", hi_mpi_vdec_start_recv_stream failed, ret = ", ret);
}
void* get_pic(void* args)
{
prctl(PR_SET_NAME, "VdecGetPic", 0, 0, 0);
GetThreadPara *para = (GetThreadPara*)args;
uint32_t chanId = para->chnId;
c10_npu::set_device(para->deviceId);
int32_t ret = HI_SUCCESS;
hi_video_frame_info frame{};
hi_vdec_stream stream{};
int32_t decResult = 0;
hi_u64 outputBuffer = 0;
uint32_t successCnt = 0;
uint32_t failCnt = 0;
int32_t timeOut = 0;
std::vector<at::Tensor> outQueue = std::vector<at::Tensor>(para->totalFrame);
g_get_exit_state[chanId] = 0;
while (g_get_exit_state[chanId] == 0) {
ret = VideoDecode::GetInstance().get_frame(chanId, &frame, nullptr, &stream, timeOut);
if (ret == HI_SUCCESS) {
outputBuffer = static_cast<hi_u64>(reinterpret_cast<uintptr_t>(frame.v_frame.virt_addr[0]));
decResult = frame.v_frame.frame_flag;
if (decResult == 0) {
const std::lock_guard<std::mutex> guard(outTensorMapMutex[chanId]);
auto iter = outTensorMap[chanId].find(outputBuffer);
if (iter != outTensorMap[chanId].end()) {
outQueue[successCnt] = iter->second;
outTensorMap[chanId].erase(iter);
successCnt++;
}
} else if (decResult == 1) {
failCnt++;
TORCH_WARN("chn ", chanId, "GetFrame Success, decode failed, fail count ", failCnt);
} else if (decResult == 2) {
} else if (decResult == 3) {
failCnt++;
TORCH_WARN("chn ", chanId, "GetFrame Success, refFrame num Error, fail count ", failCnt);
} else if (decResult == 4) {
failCnt++;
TORCH_WARN("chn ", chanId, "GetFrame Success, refFrame Size Error, fail count ", failCnt);
}
ret = VideoDecode::GetInstance().release_frame(chanId, &frame);
TORCH_CHECK(ret == 0, "chn ", chanId, ", hi_mpi_vdec_release_frame failed, ret = ", ret);
} else {
usleep(500);
}
}
g_out_queue[chanId] = outQueue;
para->successCnt = successCnt;
return (void*)HI_SUCCESS;
}
int64_t dvpp_sys_init()
{
return static_cast<int64_t>(VideoDecode::GetInstance().sys_init());
}
int64_t dvpp_sys_exit()
{
return static_cast<int64_t>(VideoDecode::GetInstance().sys_exit());
}
int64_t dvpp_vdec_create_chnl(int64_t pType)
{
TORCH_CHECK((pType == 96) || (pType == 265),
"invalid pType ", pType, ", should be H264:96, H265:265");
uint32_t chn = 0;
int32_t ret = VideoDecode::GetInstance().GetUnusedChn(chn);
TORCH_CHECK(ret == 0, "get unused chn failed");
hi_vdec_chn_attr chnAttr{};
chnAttr.type = static_cast<hi_payload_type>(pType);
chnAttr.mode = HI_VDEC_SEND_MODE_FRAME;
chnAttr.pic_width = MAX_CHN_WIDTH;
chnAttr.pic_height = MAX_CHN_HEIGHT;
chnAttr.stream_buf_size = MAX_CHN_WIDTH * MAX_CHN_HEIGHT * 3 / 2;
chnAttr.frame_buf_cnt = FRAME_BUF_CNT;
hi_pic_buf_attr buf_attr{MAX_CHN_WIDTH, MAX_CHN_HEIGHT, 0, HI_DATA_BIT_WIDTH_10, HI_PIXEL_FORMAT_YUV_SEMIPLANAR_420,
HI_COMPRESS_MODE_NONE};
chnAttr.frame_buf_size = VideoDecode::GetInstance().get_pic_buf_size(chnAttr.type, &buf_attr);
chnAttr.video_attr.ref_frame_num = REF_FRAME_NUM;
chnAttr.video_attr.temporal_mvp_en = HI_TRUE;
chnAttr.video_attr.tmv_buf_size = VideoDecode::GetInstance().get_tmv_buf_size(chnAttr.type, MAX_CHN_WIDTH,
MAX_CHN_HEIGHT);
ret = VideoDecode::GetInstance().create_chn(chn, &chnAttr);
if (ret != HI_SUCCESS) {
VideoDecode::GetInstance().PutChn(chn);
AT_ERROR("hi_mpi_vdec_create_chn ", chn, ", failed, ret = ", ret);
return -1;
}
ret = VideoDecode::GetInstance().sys_set_chn_csc_matrix(chn);
if (ret != HI_SUCCESS) {
int32_t result = VideoDecode::GetInstance().destroy_chn(chn);
VideoDecode::GetInstance().PutChn(chn);
AT_ERROR("chn ", chn, ", hi_mpi_sys_set_chn_csc_matrix failed, ret = ", ret);
return -1;
}
ret = VideoDecode::GetInstance().start_recv_stream(chn);
if (ret != HI_SUCCESS) {
int32_t result = VideoDecode::GetInstance().destroy_chn(chn);
VideoDecode::GetInstance().PutChn(chn);
AT_ERROR("chn ", chn, ", hi_mpi_vdec_start_recv_stream failed, ret = ", ret);
return -1;
}
return static_cast<int64_t>(chn);
}
int64_t dvpp_vdec_start_get_frame(int64_t chnId, int64_t totalFrame)
{
TORCH_CHECK(ValidChnNum(chnId), "invalid chn ", chnId);
int32_t deviceId = 0;
aclError aclRet = c10_npu::GetDevice(&deviceId);
TORCH_CHECK(aclRet == 0, "get device id failed, ret = ", aclRet);
g_getPara[chnId].chnId = chnId;
g_getPara[chnId].deviceId = deviceId;
g_getPara[chnId].totalFrame = totalFrame;
g_getPara[chnId].successCnt = 0;
g_vdec_get_thread[chnId] = 0;
int32_t ret = pthread_create(&g_vdec_get_thread[chnId], 0, get_pic, static_cast<void*>(&g_getPara[chnId]));
if (ret != 0) {
g_vdec_get_thread[chnId] = 0;
AT_ERROR("Chn ", chnId, ", create get pic thread failed, ret = ", ret);
return -1;
}
return 0;
}
int64_t dvpp_vdec_send_stream(int64_t chnId, const at::Tensor& self, int64_t outFormat, bool display, at::Tensor& out)
{
TORCH_CHECK(ValidChnNum(chnId), "invalid chn ", chnId);
hi_pixel_format outputFormat = static_cast<hi_pixel_format>(outFormat);
TORCH_CHECK(((outputFormat == HI_PIXEL_FORMAT_RGB_888) || (outputFormat == HI_PIXEL_FORMAT_BGR_888) ||
(outputFormat == HI_PIXEL_FORMAT_RGB_888_PLANAR) || (outputFormat == HI_PIXEL_FORMAT_BGR_888_PLANAR)),
"invalid outFormat ", outputFormat, ", should be ", HI_PIXEL_FORMAT_RGB_888, " or ", HI_PIXEL_FORMAT_BGR_888,
" or ", HI_PIXEL_FORMAT_RGB_888_PLANAR, " or ", HI_PIXEL_FORMAT_BGR_888_PLANAR);
auto selfSize = self.sizes();
int64_t selfNelements = c10::multiply_integers(selfSize);
auto selfDtype = self.dtype();
int64_t selfSizeBytes = selfNelements * selfDtype.itemsize();
auto outSize = out.sizes();
int64_t outNelements = c10::multiply_integers(outSize);
auto outDtype = out.dtype();
int64_t outSizeBytes = outNelements * outDtype.itemsize();
hi_vdec_stream stream{};
uint64_t currentSendTime = 0;
get_current_time_us(currentSendTime);
stream.pts = currentSendTime;
stream.addr = static_cast<hi_u8*>(self.data_ptr());
stream.len = selfSizeBytes;
stream.end_of_frame = HI_TRUE;
stream.end_of_stream = HI_FALSE;
stream.need_display = display ? HI_TRUE : HI_FALSE;
hi_vdec_pic_info outPicInfo{};
outPicInfo.height = 0;
outPicInfo.width = 0;
outPicInfo.width_stride = 0;
outPicInfo.height_stride = 0;
outPicInfo.pixel_format = outputFormat;
outPicInfo.vir_addr = 0;
outPicInfo.buffer_size = 0;
if (display) {
outPicInfo.vir_addr = static_cast<hi_u64>(reinterpret_cast<uintptr_t>(out.data_ptr()));
outPicInfo.buffer_size = outSizeBytes;
}
uint32_t sendOneFrameCnt = 0;
int32_t ret = 0;
do {
sendOneFrameCnt++;
ret = VideoDecode::GetInstance().send_stream(chnId, &stream, &outPicInfo, SEND_TIMEOUT);
if (sendOneFrameCnt > 30) {
if (ret != 0) {
vdec_reset_chn(chnId);
}
break;
}
} while (ret == HI_ERR_VDEC_BUF_FULL);
TORCH_CHECK(ret == 0, "chn ", chnId, ", hi_mpi_vdec_send_stream failed, ret = ", ret);
if (display) {
const std::lock_guard<std::mutex> guard(outTensorMapMutex[chnId]);
outTensorMap[chnId].insert(std::make_pair(static_cast<hi_u64>(reinterpret_cast<uintptr_t>(out.data_ptr())),
out));
}
return 0;
}
at::Tensor dvpp_vdec_stop_get_frame(int64_t chnId, int64_t totalFrame)
{
hi_vdec_chn_status status{};
hi_vdec_chn_status pre_status{};
hi_vdec_stream stream{};
hi_vdec_pic_info outPicInfo{};
stream.addr = NULL;
stream.len = 0;
stream.end_of_frame = HI_FALSE;
stream.end_of_stream = HI_TRUE;
outPicInfo.vir_addr = 0;
outPicInfo.buffer_size = 0;
int32_t ret = VideoDecode::GetInstance().send_stream(chnId, &stream, &outPicInfo, -1);
TORCH_CHECK(ret == 0, "chn ", chnId, ", hi_mpi_vdec_send_stream send end_of_stream failed, ret = ", ret);
uint32_t waitTimes = 0;
uint32_t sleepTime = 10000;
ret = VideoDecode::GetInstance().stop_recv_stream(chnId);
TORCH_CHECK(ret == 0, "chn ", chnId, ", hi_mpi_vdec_stop_recv_stream failed, ret = ", ret);
while (waitTimes < WAIT_TIMEOUT) {
ret = VideoDecode::GetInstance().query_status(chnId, &status);
TORCH_CHECK(ret == 0, "chn ", chnId, ", hi_mpi_vdec_query_status failed, ret = ", ret);
if (((status.left_stream_bytes == 0) && (status.left_decoded_frames == 0))) {
break;
}
if (status.left_decoded_frames == pre_status.left_decoded_frames) {
waitTimes += sleepTime;
} else {
waitTimes = 0;
}
pre_status = status;
usleep(sleepTime);
if (waitTimes >= WAIT_TIMEOUT) {
vdec_reset_chn(chnId);
break;
}
}
g_get_exit_state[chnId] = 1;
ret = pthread_join(g_vdec_get_thread[chnId], nullptr);
TORCH_CHECK(ret == 0, "chn ", chnId, ", pthread_join get_pic thread failed, ret = ", ret);
g_vdec_get_thread[chnId] = 0;
if (g_getPara[chnId].successCnt == totalFrame) {
g_out_queue[chnId].clear();
outTensorMap[chnId].clear();
return at::empty({0});
}
at::Tensor &tensor = g_out_queue[chnId][0];
at::Tensor result_tensor =
at::empty({g_getPara[chnId].successCnt, tensor.size(0), tensor.size(1), tensor.size(2)}, tensor.options());
for (int i = 0 ; i < g_getPara[chnId].successCnt; i++)
{
result_tensor[i] = g_out_queue[chnId][i].squeeze(0);
}
g_out_queue[chnId].clear();
outTensorMap[chnId].clear();
return result_tensor;
}
int64_t dvpp_vdec_destroy_chnl(int64_t chnId)
{
int32_t ret = VideoDecode::GetInstance().destroy_chn(chnId);
VideoDecode::GetInstance().PutChn(chnId);
TORCH_CHECK(ret == 0, "chn ", chnId, ", hi_mpi_vdec_destroy_chn failed, ret ", ret);
return 0;
}
}
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def("_dvpp_sys_init() -> int", &dvpp_sys_init);
m.def("_dvpp_sys_exit() -> int", &dvpp_sys_exit);
m.def("_decode_video_create_chn(int ptype) -> int", &dvpp_vdec_create_chnl);
m.def("_decode_video_start_get_frame(int chnId, int totalFrame) -> int", &dvpp_vdec_start_get_frame);
m.def("_decode_video_send_stream(int chnId, Tensor self, int outFormat, bool display, Tensor out) -> int", &dvpp_vdec_send_stream);
m.def("_decode_video_stop_get_frame(int chnId, int totalFrame) -> Tensor", &dvpp_vdec_stop_get_frame);
m.def("_decode_video_destroy_chnl(int chnId) -> int", &dvpp_vdec_destroy_chnl);
}
}
}