#include <ATen/ATen.h>
#include <torch/library.h>
namespace vision {
namespace ops {
namespace {
// Dimensions read from a JPEG frame header.
// Kept as a plain aggregate so it can be brace-initialised in declaration
// order: {width, height, channels}.
struct ImageShapeInfo {
  // Samples per line (16-bit big-endian field in the frame header).
  int64_t width;
  // Number of lines (16-bit big-endian field in the frame header).
  int64_t height;
  // Component-count byte of the frame header.
  int64_t channels;
};
// Two-byte marker values (0xFF prefix included) that open a JPEG
// start-of-frame (SOF) segment. The range 0xFFC0..0xFFCF is covered except
// for 0xFFC4, 0xFFC8 and 0xFFCC, which the JPEG specification assigns to
// DHT, JPG (reserved) and DAC rather than to a frame header.
const std::unordered_set<uint16_t> marker_sof{
  0xFFC0,
  0xFFC1,
  0xFFC2,
  0xFFC3,
  // 0xFFC4 intentionally absent
  0xFFC5,
  0xFFC6,
  0xFFC7,
  // 0xFFC8 intentionally absent
  0xFFC9,
  0xFFCA,
  0xFFCB,
  // 0xFFCC intentionally absent
  0xFFCD,
  0xFFCE,
  0xFFCF,
};
/**
 * Reads height, width and component count from the first start-of-frame
 * (SOF) segment of an in-memory JPEG stream.
 *
 * The stream must begin with the SOI marker (0xFFD8). After it, segments are
 * walked one at a time using their 2-byte big-endian length field until a
 * marker contained in `marker_sof` is reached.
 *
 * @param jpegData   Pointer to the encoded bytes.
 * @param dataLength Number of readable bytes at `jpegData`.
 * @return The dimensions stored in the first SOF segment.
 *
 * Fails through TORCH_CHECK when:
 *  - the buffer is null, shorter than two bytes or does not start with SOI;
 *  - a byte that should start a marker is not 0xFF;
 *  - a segment declares a length below 2 (the length field counts itself);
 *  - the SOF segment is cut off before its component-count byte;
 *  - SOS, EOI or the end of the buffer is reached before any SOF marker
 *    (previously running off the end silently returned {0, 0, 0}).
 */
ImageShapeInfo get_jpeg_shape_from_memory(const uint8_t* jpegData, size_t dataLength)
{
  constexpr size_t kMarkerSize = 2;       // 0xFF + marker code
  constexpr size_t kLengthLastOffset = 3; // last byte of the length field, relative to the marker
  // SOF layout relative to the marker: marker(2) length(2) precision(1)
  // height(2) width(2) components(1).
  constexpr size_t kSofHeightOffset = 5;
  constexpr size_t kSofWidthOffset = 7;
  constexpr size_t kSofChannelsOffset = 9;
  constexpr uint8_t kFillByte = 0xFF;
  constexpr uint8_t kEoiCode = 0xD9; // end of image
  constexpr uint8_t kSosCode = 0xDA; // start of scan: entropy-coded data follows

  TORCH_CHECK(
      jpegData != nullptr && dataLength >= kMarkerSize && jpegData[0] == 0xFF && jpegData[1] == 0xD8,
      "invalid JPEG");

  // Big-endian 16-bit read; callers guarantee pos + 1 is in range.
  const auto read_be16 = [jpegData](size_t pos) -> uint16_t {
    return static_cast<uint16_t>((jpegData[pos] << 8) | jpegData[pos + 1]);
  };

  ImageShapeInfo info = {0, 0, 0};
  bool found_sof = false;
  size_t index = kMarkerSize;

  // `index < dataLength - 1` keeps both marker bytes readable; dataLength >= 2
  // was verified above, so the subtraction cannot wrap.
  while (index < dataLength - 1) {
    TORCH_CHECK(jpegData[index] == 0xFF, "invalid JPEG");

    const uint8_t code = jpegData[index + 1];
    if (code == kFillByte) {
      // A marker may be preceded by any number of 0xFF fill bytes; step over
      // one and re-examine instead of treating 0xFFFF as a segment.
      ++index;
      continue;
    }

    const uint16_t marker = read_be16(index);
    if (marker_sof.find(marker) != marker_sof.end()) {
      // Written as a subtraction so the bound check cannot overflow.
      TORCH_CHECK(
          dataLength - index > kSofChannelsOffset,
          "JPEG segment exceeds data length: invalid segment");
      info.height = read_be16(index + kSofHeightOffset);
      info.width = read_be16(index + kSofWidthOffset);
      info.channels = jpegData[index + kSofChannelsOffset];
      found_sof = true;
      break;
    }

    // Past SOS/EOI there are no more length-prefixed header segments, so a
    // frame header can no longer be found by segment walking.
    TORCH_CHECK(code != kSosCode && code != kEoiCode, "no marker found");

    // The length field must be fully inside the buffer.
    TORCH_CHECK(dataLength - index > kLengthLastOffset, "no marker found");
    const size_t segmentLength = read_be16(index + kMarkerSize);
    TORCH_CHECK(segmentLength >= 2, "invalid JPEG");

    index += kMarkerSize + segmentLength;
  }

  // Reaching the end of the data without an SOF is an error, not a 0x0x0 image.
  TORCH_CHECK(found_sof, "no marker found");
  return info;
}
/**
 * Operator entry point: returns the shape stored in the header of an encoded
 * JPEG without decoding the pixel data.
 *
 * @param data Non-empty 1-D uint8 tensor holding the encoded JPEG bytes.
 * @return Tuple ordered as (height, width, channels).
 *
 * Fails through TORCH_CHECK when the dtype or dimensionality is wrong, or
 * when the header parser rejects the byte stream.
 */
std::tuple<int64_t, int64_t, int64_t> _get_jpeg_image_shape(const at::Tensor &data)
{
  TORCH_CHECK(data.dtype() == at::kByte, "Expected a torch.uint8 tensor");
  TORCH_CHECK(data.dim() == 1 && data.numel() > 0, "Expected a non empty 1-dimensional tensor");

  // data_ptr() addresses the underlying storage and ignores strides, so a
  // strided 1-D view (e.g. every other byte of a buffer) would be parsed from
  // the wrong bytes. contiguous() packs such views and hands back the same
  // tensor when the input is already densely laid out.
  const at::Tensor packed = data.contiguous();
  const uint8_t* bytes = packed.data_ptr<uint8_t>();

  // numel() is known to be positive here, so the conversion to size_t is safe.
  const auto info = get_jpeg_shape_from_memory(bytes, static_cast<size_t>(packed.numel()));
  return std::make_tuple(info.height, info.width, info.channels);
}
}
// Declares the schema of `torchvision::_get_jpeg_image_shape` in the
// `torchvision` operator library: a single Tensor argument and a tuple of
// three ints as the result. The implementation in this file fills that tuple
// as (height, width, channels).
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_get_jpeg_image_shape(Tensor self) -> (int ,int , int)"));
}
// Binds `_get_jpeg_image_shape` as the implementation of
// `torchvision::_get_jpeg_image_shape` for the CPU dispatch key. No other
// dispatch key is registered in this file.
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::_get_jpeg_image_shape"), TORCH_FN(_get_jpeg_image_shape));
}
}
}