* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cmath>
#include <string>
#include "src/weight_decoder.h"
#include "src/huffman_decode.h"
namespace mindspore::lite {
constexpr int kBit8 = 8;
constexpr int kBit32 = 32;
std::vector<bool> StringToBitVector(const std::string &str) {
std::vector<bool> vec(str.size() * kBit8);
size_t index = 0;
for (auto ch : str) {
for (size_t shift = kBit8; shift > 0; shift--) {
vec[index++] = (ch >> static_cast<size_t>(shift - 1)) & 0x1;
}
}
return vec;
}
STATUS IndexingDecompress(const schema::Tensor &src_tensor, Tensor *dst_tensor) {
MS_LOG(DEBUG) << "un-index weight";
MS_CHECK_TRUE_MSG(src_tensor.quantParams() != nullptr, RET_ERROR, "quant params is nullptr");
MS_CHECK_TRUE_MSG((*src_tensor.quantParams()).size() > 0, RET_ERROR, "quant params size need bigger than 0");
MS_CHECK_TRUE_MSG(src_tensor.quantParams()->Get(0) != nullptr, RET_ERROR, "quant param is nullptr");
auto bit_num = src_tensor.quantParams()->Get(0)->numBits();
std::string str(reinterpret_cast<const char *>(src_tensor.data()->data()), src_tensor.data()->size());
auto bit_vec = StringToBitVector(str);
size_t index = 0;
size_t unique_value_cnt = 0;
for (int i = 0; i < bit_num; i++) {
bool bit = bit_vec[index++];
unique_value_cnt |= bit << static_cast<size_t>((bit_num - i - 1));
}
if (unique_value_cnt == 0) {
unique_value_cnt = 1 << bit_num;
}
std::vector<int> unique_values;
for (size_t i = 0; i < unique_value_cnt; i++) {
int unique_value = 0;
for (int j = 0; j < bit_num; j++) {
bool bit = bit_vec[index++];
unique_value |= bit << static_cast<size_t>((bit_num - j - 1));
}
unique_values.push_back(unique_value - (1 << static_cast<size_t>((bit_num - 1))));
}
std::vector<size_t> unique_value_index_vec;
auto elem_cnt = dst_tensor->ElementsNum();
size_t unique_value_bit = ceil(log2(unique_value_cnt));
for (int i = 0; i < elem_cnt; i++) {
size_t unique_value_index = 0;
for (size_t j = 0; j < unique_value_bit; j++) {
bool bit = bit_vec[index++];
unique_value_index |= bit << (static_cast<size_t>(unique_value_bit - j - 1));
}
unique_value_index_vec.push_back(unique_value_index);
}
if (dst_tensor->data() != nullptr) {
MS_LOG(ERROR) << "data_c not null";
return RET_ERROR;
}
auto ret = dst_tensor->MallocData();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Malloc tensor data failed";
return RET_NULL_PTR;
}
auto dst_data = dst_tensor->data();
if (bit_num <= kBit8) {
ret = UnIndexTensorData<int8_t>(unique_values, unique_value_index_vec, dst_data, dst_tensor->Size());
} else {
ret = UnIndexTensorData<int16_t>(unique_values, unique_value_index_vec, dst_data, dst_tensor->Size());
}
if (ret != RET_OK) {
MS_LOG(ERROR) << "UnIndexTensorData error";
return RET_ERROR;
}
return RET_OK;
}
STATUS SparseDecompress(const schema::Tensor &src_tensor, Tensor *dst_tensor) {
MS_LOG(DEBUG) << "un-sparse weight";
MS_CHECK_TRUE_MSG(src_tensor.quantParams() != nullptr, RET_ERROR, "quant params is nullptr");
MS_CHECK_TRUE_MSG((*src_tensor.quantParams()).size() > 0, RET_ERROR, "quant params size need bigger than 0");
MS_CHECK_TRUE_MSG(src_tensor.quantParams()->Get(0) != nullptr, RET_ERROR, "quant param is nullptr");
size_t bit_num = src_tensor.quantParams()->Get(0)->numBits();
std::string str(reinterpret_cast<const char *>(src_tensor.data()->data()), src_tensor.data()->size());
auto bit_vec = StringToBitVector(str);
size_t index = 0;
size_t coor_best_bit = 0;
for (size_t i = 0; i < kBit8; i++) {
bool bit = bit_vec[index++];
coor_best_bit |= bit << static_cast<size_t>((kBit8 - i - 1));
}
size_t nz_cnt = 0;
for (size_t i = 0; i < kBit32; i++) {
bool bit = bit_vec[index++];
nz_cnt |= bit << static_cast<size_t>((kBit32 - i - 1));
}
size_t unique_value_cnt = 0;
for (size_t i = 0; i < bit_num; i++) {
bool bit = bit_vec[index++];
unique_value_cnt |= bit << static_cast<size_t>((bit_num - i - 1));
}
if (unique_value_cnt == 0) {
unique_value_cnt = 1 << bit_num;
}
std::vector<int> unique_values;
for (size_t i = 0; i < unique_value_cnt; i++) {
int unique_value = 0;
for (size_t j = 0; j < bit_num; j++) {
bool bit = bit_vec[index++];
unique_value |= bit << static_cast<size_t>((bit_num - j - 1));
}
unique_values.push_back(unique_value - (1 << static_cast<size_t>((bit_num - 1))));
}
std::vector<size_t> unique_value_index_vec;
auto elem_cnt = dst_tensor->ElementsNum();
size_t unique_value_bit = static_cast<size_t>(ceil(log2(unique_value_cnt)));
for (size_t i = 0; i < nz_cnt; i++) {
size_t unique_value_index = 0;
for (size_t j = 0; j < unique_value_bit; j++) {
bool bit = bit_vec[index++];
unique_value_index |= bit << (unique_value_bit - j - 1);
}
unique_value_index_vec.push_back(unique_value_index);
}
std::vector<size_t> coor_vec;
for (size_t i = 0; i < nz_cnt; i++) {
size_t coor = 0;
for (size_t j = 0; j < coor_best_bit; j++) {
bool bit = bit_vec[index++];
coor |= bit << static_cast<size_t>((coor_best_bit - j - 1));
}
coor_vec.push_back(coor);
}
if (dst_tensor->data() != nullptr) {
MS_LOG(ERROR) << "data_c not null";
return RET_ERROR;
}
auto ret = dst_tensor->MallocData();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Malloc tensor data failed";
return RET_NULL_PTR;
}
auto dst_data = dst_tensor->data();
if (bit_num <= kBit8) {
ret = UnSparseTensorData<int8_t>(unique_values, unique_value_index_vec, coor_vec, src_tensor.quantParams(),
elem_cnt, coor_best_bit, dst_data, dst_tensor->Size());
} else {
ret = UnSparseTensorData<int16_t>(unique_values, unique_value_index_vec, coor_vec, src_tensor.quantParams(),
elem_cnt, coor_best_bit, dst_data, dst_tensor->Size());
}
if (ret != RET_OK) {
MS_LOG(ERROR) << "UnSparseTensorData error";
return RET_ERROR;
}
return RET_OK;
}
int WeightDecoder::DequantWeight(lite::Tensor *input_tensor, bool channel_first, TypeId dst_data_type) {
MS_ASSERT(input_tensor != nullptr);
if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) {
MS_LOG(ERROR) << "Conv weight input type error." << input_tensor->data_type();
return RET_ERROR;
}
if (input_tensor->quant_params().empty()) {
MS_LOG(ERROR) << "No quant param.";
return RET_ERROR;
}
if (input_tensor->data_type() == kNumberTypeInt16 && dst_data_type == kNumberTypeFloat32) {
auto new_const_data = DequantData<int16_t, float>(input_tensor, channel_first);
input_tensor->FreeData();
input_tensor->set_data(new_const_data);
input_tensor->set_own_data(true);
input_tensor->set_data_type(dst_data_type);
} else if (input_tensor->data_type() == kNumberTypeInt16 && dst_data_type == kNumberTypeFloat16) {
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
auto new_const_data = DequantData<int16_t, float16_t>(input_tensor, channel_first);
input_tensor->FreeData();
input_tensor->set_data(new_const_data);
input_tensor->set_own_data(true);
input_tensor->set_data_type(dst_data_type);
#else
MS_LOG(ERROR) << "Float16 is not supported";
return RET_NOT_SUPPORT;
#endif
} else if (input_tensor->data_type() == kNumberTypeInt8 && dst_data_type == kNumberTypeFloat32) {
auto new_const_data = DequantData<int8_t, float>(input_tensor, channel_first);
input_tensor->FreeData();
input_tensor->set_data(new_const_data);
input_tensor->set_own_data(true);
input_tensor->set_data_type(dst_data_type);
} else if (input_tensor->data_type() == kNumberTypeInt8 && dst_data_type == kNumberTypeFloat16) {
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
auto new_const_data = DequantData<int8_t, float16_t>(input_tensor, channel_first);
input_tensor->FreeData();
input_tensor->set_data(new_const_data);
input_tensor->set_own_data(true);
input_tensor->set_data_type(dst_data_type);
#else
MS_LOG(ERROR) << "Float16 is not supported";
return RET_NOT_SUPPORT;
#endif
} else {
MS_LOG(ERROR) << "Unsupported dequant from data_type(" << (input_tensor->data_type()) << ") to data_type("
<< dst_data_type << ")";
return RET_NOT_SUPPORT;
}
return RET_OK;
}
int WeightDecoder::DecodeHuffmanCode(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
MS_ASSERT(dst_tensor != nullptr);
if (!dst_tensor->IsConst() || !src_tensor.enableHuffmanCode()) {
return RET_NO_CHANGE;
}
if (src_tensor.data() == nullptr) {
return RET_NO_CHANGE;
}
auto data = reinterpret_cast<const char *>(src_tensor.data()->data());
if (data == nullptr) {
return RET_NO_CHANGE;
}
std::string encode_str(data, src_tensor.data()->size());
dst_tensor->FreeData();
dst_tensor->set_data(nullptr);
auto ret = dst_tensor->MallocData();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Malloc tensor data failed";
return RET_NULL_PTR;
}
auto dst_data = dst_tensor->data();
MS_ASSERT(dst_data != nullptr);
ret = HuffmanDecode::DoHuffmanDecode(encode_str, dst_data, dst_tensor->Size());
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoHuffmanDecode failed.";
return ret;
}
return RET_OK;
}
int WeightDecoder::UnPackToInt(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
MS_ASSERT(dst_tensor != nullptr);
auto quant_params = src_tensor.quantParams();
if (quant_params == nullptr || quant_params->size() == 0) {
return RET_NO_CHANGE;
}
auto quant_param = quant_params->Get(0);
if (quant_param == nullptr) {
return RET_NO_CHANGE;
}
auto dst_data = dst_tensor->data();
if (dst_data != nullptr) {
MS_LOG(ERROR) << "lite Tensor has already malloced data";
return RET_ERROR;
}
auto ret = dst_tensor->MallocData();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Malloc tensor data failed";
return RET_NULL_PTR;
}
dst_data = dst_tensor->data();
int origin_bit = quant_param->numBits();
if (origin_bit < kBitNum8 && origin_bit >= kBitNum1) {
UnPackUtil<int8_t, uint8_t>(&src_tensor, origin_bit, dst_data);
return RET_OK;
} else if (origin_bit < kBitNum16 && origin_bit > kBitNum8) {
UnPackUtil<int16_t, uint16_t>(&src_tensor, origin_bit, dst_data);
return RET_OK;
} else {
MS_LOG(ERROR) << "Unsupported bit number: " << origin_bit;
return RET_NOT_SUPPORT;
}
}
int WeightDecoder::UnPack(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
STATUS ret = RET_OK;
if (src_tensor.enableHuffmanCode()) {
ret = WeightDecoder::DecodeHuffmanCode(src_tensor, dst_tensor);
if (ret != RET_OK && ret != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Decode huffman code failed: " << ret;
}
} else {
ret = WeightDecoder::UnPackToInt(src_tensor, dst_tensor);
if (ret != RET_OK && ret != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Unpack to int8 failed: " << ret;
}
}
return ret;
}
int WeightDecoder::DequantNode(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors,
TypeId dst_data_type) {
if (op_parameter->quant_type_ != schema::QuantType_QUANT_WEIGHT) {
return RET_OK;
}
int index = 0;
for (auto &tensor : in_tensors) {
MS_CHECK_TRUE_RET(tensor != nullptr, RET_ERROR);
auto channel_first = IsChannelFirst(index++, op_parameter);
auto ret = WeightDecoder::DequantTensor(tensor, channel_first, dst_data_type);
if (ret != RET_OK && ret != RET_NO_CHANGE) {
MS_LOG(DEBUG) << "Dequant tensor failed";
return RET_ERROR;
}
}
return RET_OK;
}
int WeightDecoder::DequantTensor(Tensor *tensor, bool channel_first, TypeId dst_data_type) {
MS_ASSERT(tensor != nullptr);
if (!tensor->IsConst() ||
!(dst_data_type == TypeId::kNumberTypeFloat32 || dst_data_type == TypeId::kNumberTypeFloat16)) {
return RET_NO_CHANGE;
}
bool need_dequant = !tensor->quant_params().empty() && tensor->quant_params().front().inited &&
(tensor->data_type() == kNumberTypeInt8 || tensor->data_type() == kNumberTypeInt16);
if (!need_dequant) {
return RET_NO_CHANGE;
}
auto ret = WeightDecoder::DequantWeight(tensor, channel_first, dst_data_type);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Dequant data failed: " << ret;
return ret;
}
return RET_OK;
}
}