#include "yuv420p_rgb.h"
namespace vision {
namespace ops {
/// Runs one of the kp_yuv420p_to_rgb_* kernels over the supplied Y/U/V buffers
/// and returns the result in a newly allocated {height, width, 3} byte tensor.
///
/// @param nthreads     Forwarded unchanged to the kernel (presumably the worker
///                     thread count - confirm against the kernel).
/// @param width        Second dimension of the output tensor; also forwarded to the kernel.
/// @param height       First dimension of the output tensor; also forwarded to the kernel.
/// @param color_range  When equal to KP_AVCOL_RANGE_JPEG, kp_yuv420p_to_rgb_full is
///                     called; any other value calls kp_yuv420p_to_rgb_limit.
/// @param color_space  Currently unused by this function.
/// @param ydata        Tensor whose raw storage is handed to the kernel as the Y input.
/// @param src_stride   Its backing array is handed to the kernel (presumably
///                     per-plane line strides - confirm against the kernel).
/// @param udata        Tensor whose raw storage is handed to the kernel as the U input.
/// @param vdata        Tensor whose raw storage is handed to the kernel as the V input.
/// @return             The byte tensor the selected kernel wrote into.
///
/// NOTE(review): dtype, contiguity and size of the input tensors, and the length
/// of src_stride, are not validated here before raw pointers are passed on.
at::Tensor kp_yuv2rgb(int64_t nthreads, int64_t width, int64_t height, int64_t color_range, int64_t color_space,
                      const at::Tensor ydata, std::vector<int64_t> src_stride, const at::Tensor udata,
                      const at::Tensor vdata)
{
    // Destination buffer: three bytes per pixel.
    const int64_t rgb_channels = 3;
    at::Tensor rgb = at::empty({height, width, rgb_channels}, at::kByte);

    // The kernels work on raw byte pointers, so resolve them once up front.
    const uint8_t* y_plane = static_cast<const uint8_t*>(ydata.data_ptr());
    const uint8_t* u_plane = static_cast<const uint8_t*>(udata.data_ptr());
    const uint8_t* v_plane = static_cast<const uint8_t*>(vdata.data_ptr());
    uint8_t* rgb_out = static_cast<uint8_t*>(rgb.data_ptr());

    const bool use_full_kernel = (color_range == KP_AVCOL_RANGE_JPEG);
    if (!use_full_kernel) {
        kp_yuv420p_to_rgb_limit(nthreads, width, height,
                                y_plane, src_stride.data(), u_plane, v_plane, rgb_out);
        return rgb;
    }

    kp_yuv420p_to_rgb_full(nthreads, width, height,
                           y_plane, src_stride.data(), u_plane, v_plane, rgb_out);
    return rgb;
}
// Registers kp_yuv2rgb as the implementation of the "torchvision::yuv2rgb"
// operator for the CPU dispatch key. Only the implementation is bound here;
// no operator schema definition appears in the visible code.
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
// TORCH_FN wraps the function so it can be passed to m.impl.
m.impl(TORCH_SELECTIVE_NAME("torchvision::yuv2rgb"), TORCH_FN(kp_yuv2rgb));
}
}
}