#include <ATen/ATen.h>
#include <torch/library.h>
#include "op_api_common.hpp"
namespace vision {
namespace ops {
namespace {
// Inline capacity used for the c10::SmallVector that carries output shapes.
constexpr int SIZE = 8;

// Fixed attribute values forwarded unchanged to the "DecodeJpeg" operator by
// decode_jpeg_aclop_kernel. Their meaning is defined by that operator, not here.
constexpr int64_t ratio = 1;
constexpr bool fancy_upscaling = true;
constexpr float acceptable_fraction = 1.0f;
const std::string dct_method = "";
const std::string dst_img_format = "CHW";
/// Builds the 4-entry shape {1, C, height, width} for a decoded image.
///
/// @param image_shape Must hold exactly three entries, read in the order
///                    (height, width, channels).
/// @param channels    Channel count to place in the result. The value 0 selects
///                    the channel entry of image_shape instead.
/// @return {1, C, height, width}, where C is `channels`, or image_shape[2]
///         when `channels` is 0.
/// @note TORCH_CHECK fails when image_shape does not have three entries.
c10::SmallVector<int64_t, SIZE> decode_jpeg_output_size(
    at::IntArrayRef image_shape,
    int64_t channels)
{
    TORCH_CHECK(image_shape.size() == 3,
        "Op[decode_jpeg_aclop] argument[image_shape] should have 3 elements: (height, width, channels).");

    const int64_t height = image_shape[0];
    const int64_t width = image_shape[1];
    const int64_t shape_channels = image_shape[2];

    // A request of 0 channels falls back to the count carried by image_shape.
    const int64_t out_channels = (channels == 0) ? shape_channels : channels;

    return c10::SmallVector<int64_t, SIZE>{1, out_channels, height, width};
}
/// Runs the "DecodeJpeg" operator through an OpCommand and returns its output.
///
/// @param self                  Tensor passed as the operator's only input,
///                              tagged with the type string "string"
///                              (presumably the encoded JPEG data; confirm
///                              against the caller).
/// @param image_shape           Three entries (height, width, channels); used
///                              only to size the output.
/// @param channels              Forwarded as the "channels" attribute; 0 makes
///                              the output use image_shape's channel entry.
/// @param try_recover_truncated Forwarded as the attribute of the same name.
/// @return A uint8 tensor of shape (C, height, width), allocated with the
///         remaining options of `self`.
at::Tensor decode_jpeg_aclop_kernel(const at::Tensor& self,
    at::IntArrayRef image_shape,
    int64_t channels,
    bool try_recover_truncated)
{
    // The shared helper yields {1, C, H, W}; this path allocates without the
    // leading 1, i.e. a 3-D (C, H, W) tensor.
    const auto nchw = decode_jpeg_output_size(image_shape, channels);
    const auto byte_options = self.options().dtype(at::kByte);
    at::Tensor result = at::empty({nchw[1], nchw[2], nchw[3]}, byte_options);

    at_npu::native::OpCommand cmd;
    cmd.Name("DecodeJpeg")
        .Input(self, "", c10::nullopt, "string")
        .Output(result)
        // Caller-controlled attributes.
        .Attr("channels", channels)
        // File-level fixed attributes, interleaved in the original order.
        .Attr("ratio", ratio)
        .Attr("fancy_upscaling", fancy_upscaling)
        .Attr("try_recover_truncated", try_recover_truncated)
        .Attr("acceptable_fraction", acceptable_fraction)
        .Attr("dct_method", dct_method)
        .Attr("dst_img_format", dst_img_format)
        .Run();

    return result;
}
/// Runs acldvppDecodeJpeg through EXEC_NPU_CMD and returns its output.
///
/// @param self                  Input tensor handed to acldvppDecodeJpeg
///                              (presumably the encoded JPEG data; confirm
///                              against the caller).
/// @param image_shape           Three entries (height, width, channels); used
///                              only to size the output.
/// @param channels              Passed to acldvppDecodeJpeg; 0 makes the output
///                              use image_shape's channel entry.
/// @param try_recover_truncated Passed to acldvppDecodeJpeg unchanged.
/// @return A uint8 tensor of shape (1, C, height, width), allocated with the
///         remaining options of `self`. Unlike the OpCommand path, the leading
///         1 is kept.
at::Tensor decode_jpeg_aclnn_kernel(const at::Tensor& self,
    at::IntArrayRef image_shape,
    int64_t channels,
    bool try_recover_truncated)
{
    const auto nchw_shape = decode_jpeg_output_size(image_shape, channels);
    const auto byte_options = self.options().dtype(at::kByte);

    at::Tensor decoded = at::empty(nchw_shape, byte_options);
    EXEC_NPU_CMD(acldvppDecodeJpeg, self, channels, try_recover_truncated, decoded);
    return decoded;
}
}
// Registers the two kernels above as the PrivateUse1 implementations of the
// torchvision operators named below. The operator schemas themselves are not
// declared in this block; they must be defined elsewhere.
TORCH_LIBRARY_IMPL(torchvision, PrivateUse1, m) {
// OpCommand-based path (issues the "DecodeJpeg" operator).
m.impl(TORCH_SELECTIVE_NAME("torchvision::_decode_jpeg_aclop"), TORCH_FN(decode_jpeg_aclop_kernel));
// EXEC_NPU_CMD-based path (invokes acldvppDecodeJpeg).
m.impl(TORCH_SELECTIVE_NAME("torchvision::_decode_jpeg_aclnn"), TORCH_FN(decode_jpeg_aclnn_kernel));
}
}
}