#include "csrc/OpApiCommon.h"
thread_local char g_hashBuf[kHashBufSize];
thread_local int g_hashOffset = 0;
constexpr int g_mix64Shift = 33;
typedef void (*AddTensorAddrToCachedList)(void* addr);
void AddParamToBuf(const at::Tensor& at_tensor)
{
static const auto addTensorAddrToCachedListAddr = GetOpApiFuncAddr("AddTensorAddrToCachedList");
AddTensorAddrToCachedList addTensorAddrToCachedListFunc =
reinterpret_cast<AddTensorAddrToCachedList>(addTensorAddrToCachedListAddr);
if (!at_tensor.defined()) {
MEMCPY_TO_BUF(",", 1);
return;
}
MEMCPY_TO_BUF(at_tensor.sizes().data(), at_tensor.sizes().size() * sizeof(int64_t));
auto st = at_tensor.scalar_type();
MEMCPY_TO_BUF(&st, sizeof(st));
MEMCPY_TO_BUF(",", 1);
MEMCPY_TO_BUF(at_tensor.strides().data(), at_tensor.sizes().size() * sizeof(int64_t));
auto so = at_tensor.storage_offset();
MEMCPY_TO_BUF(&so, sizeof(so));
aclDataType acl_data_type = kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(st)];
c10::SmallVector<int64_t, 5> storageDims;
if (acl_data_type != ACL_STRING) {
storageDims.push_back(at_tensor.storage().nbytes() / at_tensor.itemsize());
}
MEMCPY_TO_BUF(storageDims.data(), storageDims.size() * sizeof(int64_t));
addTensorAddrToCachedListFunc(const_cast<void*>(at_tensor.storage().data()));
}
void AddParamToBuf(const at::Scalar& at_scalar)
{
at::ScalarType scalar_data_type = at_scalar.type();
switch (scalar_data_type) {
case at::ScalarType::Double: {
double value = at_scalar.toDouble();
MEMCPY_TO_BUF(&value, sizeof(double));
break;
}
case at::ScalarType::Long: {
int64_t value = at_scalar.toLong();
MEMCPY_TO_BUF(&value, sizeof(int64_t));
break;
}
case at::ScalarType::Bool: {
bool value = at_scalar.toBool();
MEMCPY_TO_BUF(&value, sizeof(bool));
break;
}
case at::ScalarType::ComplexDouble: {
auto value = at_scalar.toComplexDouble();
MEMCPY_TO_BUF(&value, sizeof(value));
break;
}
default: {
break;
}
}
}
void AddParamToBuf(const at::IntArrayRef& at_array)
{
MEMCPY_TO_BUF(at_array.data(), at_array.size() * sizeof(int64_t));
}
void AddParamToBuf(const at::ArrayRef<bool>& at_array)
{
MEMCPY_TO_BUF(at_array.data(), at_array.size() * sizeof(bool));
}
void AddParamToBuf(const at::TensorList& at_tensor_list)
{
for (size_t i = 0; i < at_tensor_list.size(); i++) {
AddParamToBuf(at_tensor_list[i]);
}
auto counter = at_tensor_list.size();
MEMCPY_TO_BUF(&counter, sizeof(counter));
}
void AddParamToBuf(const c10::optional<at::Tensor>& opt_tensor)
{
if (opt_tensor.has_value() && opt_tensor.value().defined()) {
AddParamToBuf(opt_tensor.value());
} else {
MEMCPY_TO_BUF(",", 1);
}
}
void AddParamToBuf(const c10::optional<at::IntArrayRef>& opt_array)
{
if (opt_array.has_value()) {
AddParamToBuf(opt_array.value());
} else {
MEMCPY_TO_BUF(",", 1);
}
}
void AddParamToBuf(const c10::optional<at::Scalar>& opt_scalar)
{
if (opt_scalar.has_value()) {
AddParamToBuf(opt_scalar.value());
} else {
MEMCPY_TO_BUF(",", 1);
}
}
void AddParamToBuf(const at::ScalarType scalar_type)
{
MEMCPY_TO_BUF(&scalar_type, sizeof(scalar_type));
}
void AddParamToBuf(const string& s)
{
MEMCPY_TO_BUF(s.c_str(), s.size());
}
void AddParamToBuf() {}
inline uint64_t rotl64(uint64_t x, int8_t r)
{
return (x << r) | (x >> (64 - r));
}
#define ROTL64(x, y) rotl64(x, y)
#define BIG_CONSTANT(x) (x##LLU)
inline uint64_t GetBlock64(const uint64_t* p, int i)
{
return p[i];
}
inline uint64_t fmix64(uint64_t k)
{
k ^= k >> g_mix64Shift;
k *= BIG_CONSTANT(0xff51afd7ed558ccd);
k ^= k >> g_mix64Shift;
k *= BIG_CONSTANT(0xc4ceb9fe1a85ec53);
k ^= k >> g_mix64Shift;
return k;
}
uint64_t MurmurHash(const void* key, const int len, const uint32_t seed = 0xdeadb0d7)
{
const uint8_t* data = (const uint8_t*)key;
const int nblocks = len / 16;
uint64_t h1 = seed;
uint64_t h2 = seed;
const uint64_t c1 = BIG_CONSTANT(0x87c37b91114253d5);
const uint64_t c2 = BIG_CONSTANT(0x4cf5ad432745937f);
const uint64_t* blocks = (const uint64_t*)(data);
for (int i = 0; i < nblocks; i++) {
int even_num = 2;
int odd_num = 1;
uint64_t k1 = GetBlock64(blocks, i * even_num);
uint64_t k2 = GetBlock64(blocks, i * even_num + odd_num);
int8_t k1_shift = 31;
k1 *= c1;
k1 = ROTL64(k1, k1_shift);
k1 *= c2;
h1 ^= k1;
int8_t h1_shift = 27;
h1 = ROTL64(h1, h1_shift);
h1 += h2;
h1 = h1 * 5 + 0x52dce729;
int8_t k2_shift = 33;
k2 *= c2;
k2 = ROTL64(k2, k2_shift);
k2 *= c1;
h2 ^= k2;
int8_t h2_shift = 31;
h2 = ROTL64(h2, h2_shift);
h2 += h1;
h2 = h2 * 5 + 0x38495ab5;
}
const uint8_t* tail = (const uint8_t*)(data + nblocks * 16);
uint64_t k1 = 0;
uint64_t k2 = 0;
switch (len & 15) {
case 15:
k2 ^= ((uint64_t)tail[14]) << 48;
[[fallthrough]];
case 14:
k2 ^= ((uint64_t)tail[13]) << 40;
[[fallthrough]];
case 13:
k2 ^= ((uint64_t)tail[12]) << 32;
[[fallthrough]];
case 12:
k2 ^= ((uint64_t)tail[11]) << 24;
[[fallthrough]];
case 11:
k2 ^= ((uint64_t)tail[10]) << 16;
[[fallthrough]];
case 10:
k2 ^= ((uint64_t)tail[9]) << 8;
[[fallthrough]];
case 9:
k2 ^= ((uint64_t)tail[8]) << 0;
k2 *= c2;
k2 = ROTL64(k2, 33);
k2 *= c1;
h2 ^= k2;
[[fallthrough]];
case 8:
k1 ^= ((uint64_t)tail[7]) << 56;
[[fallthrough]];
case 7:
k1 ^= ((uint64_t)tail[6]) << 48;
[[fallthrough]];
case 6:
k1 ^= ((uint64_t)tail[5]) << 40;
[[fallthrough]];
case 5:
k1 ^= ((uint64_t)tail[4]) << 32;
[[fallthrough]];
case 4:
k1 ^= ((uint64_t)tail[3]) << 24;
[[fallthrough]];
case 3:
k1 ^= ((uint64_t)tail[2]) << 16;
[[fallthrough]];
case 2:
k1 ^= ((uint64_t)tail[1]) << 8;
[[fallthrough]];
case 1:
k1 ^= ((uint64_t)tail[0]) << 0;
k1 *= c1;
k1 = ROTL64(k1, 31);
k1 *= c2;
h1 ^= k1;
[[fallthrough]];
default:
break;
};
h1 ^= len;
h2 ^= len;
h1 += h2;
h2 += h1;
h1 = fmix64(h1);
h2 = fmix64(h2);
h1 += h2;
h2 += h1;
return h2;
}
uint64_t CalcHashId()
{
if (g_hashOffset == kHashBufMaxSize) {
return 0;
}
uint64_t hash_id = MurmurHash(g_hashBuf, g_hashOffset);
return hash_id;
}