* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/huffman_decode.h"
#include <queue>
namespace mindspore {
namespace lite {
/**
 * Decodes a Huffman-compressed payload into a caller-provided buffer.
 *
 * input_str is split on '#': the text before the first '#' holds the symbol keys, the text between the
 * first and second '#' holds the matching bit codes, and everything after the second '#' is the packed
 * bit stream (which may itself contain '#').
 *
 * @param input_str serialized "<keys>#<codes>#<encoded data>" string.
 * @param decoded_data destination buffer; must not be null.
 * @param data_len capacity of decoded_data in bytes.
 * @return RET_OK on success; RET_ERROR on a null buffer, missing separators, a malformed tree, a decode
 *         failure, or when the decoded output is longer than data_len; RET_MEMORY_FAILED if a tree node
 *         cannot be allocated.
 */
STATUS HuffmanDecode::DoHuffmanDecode(const std::string &input_str, void *decoded_data, size_t data_len) {
  if (decoded_data == nullptr) {
    MS_LOG(ERROR) << "decoded_data is nullptr.";
    return RET_ERROR;
  }
  // Both separators are mandatory. Without these checks the npos arithmetic below silently yields
  // garbage keys/codes instead of reporting malformed input.
  const auto key_pos = input_str.find('#');
  if (key_pos == std::string::npos) {
    MS_LOG(ERROR) << "huffman encoded string is missing the key separator '#'.";
    return RET_ERROR;
  }
  const auto code_pos = input_str.find('#', key_pos + 1);
  if (code_pos == std::string::npos) {
    MS_LOG(ERROR) << "huffman encoded string is missing the code separator '#'.";
    return RET_ERROR;
  }
  auto key = input_str.substr(0, key_pos);
  auto code = input_str.substr(key_pos + 1, code_pos - key_pos - 1);
  auto encoded_data = input_str.substr(code_pos + 1);

  auto root = new (std::nothrow) HuffmanNode();
  if (root == nullptr) {
    MS_LOG(ERROR) << "new HuffmanNode failed.";
    return RET_MEMORY_FAILED;
  }
  root->left = nullptr;
  root->right = nullptr;
  root->parent = nullptr;

  // RebuildHuffmanTree attaches nodes to root as it goes, so every exit path must release the whole
  // tree (FreeHuffmanNodeTree), not just the root node, or the partially built tree leaks.
  auto status = RebuildHuffmanTree(std::move(key), std::move(code), root);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Rebuild huffman tree failed.";
    FreeHuffmanNodeTree(root);
    return status;
  }
  std::string huffman_decoded_str;
  status = DoHuffmanDecompress(root, std::move(encoded_data), &huffman_decoded_str);
  // The tree is no longer needed regardless of the outcome.
  FreeHuffmanNodeTree(root);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "DoHuffmanDecompress failed.";
    return status;
  }
  const size_t len = huffman_decoded_str.length();
  if (len > data_len) {
    MS_LOG(ERROR) << "decoded data length " << len << " exceeds the output buffer size " << data_len << ".";
    return RET_ERROR;
  }
  memcpy(decoded_data, huffman_decoded_str.data(), len);
  return RET_OK;
}
/**
 * Rebuilds a Huffman tree under root from parallel lists of symbol keys and bit codes.
 *
 * keys and codes are split with Str2Vec, and the i-th key is paired with the i-th code. Each code is a
 * string of '0' (descend left) and '1' (descend right). The node reached by a code's last character
 * becomes a leaf that stores the key and the code.
 *
 * Nodes are linked under root as soon as they are created. On failure the partially built tree stays
 * attached to root, and the caller is responsible for freeing it.
 *
 * @param keys serialized symbol keys; each token is parsed with std::stoi (NOTE: std::stoi throws on
 *             non-numeric or out-of-range tokens).
 * @param codes serialized bit codes, one per key.
 * @param root pre-allocated root node, expected to start with null children.
 * @return RET_OK on success; RET_ERROR if there are fewer keys than codes, a code contains a character
 *         other than '0'/'1', or one code duplicates or is a prefix of another; RET_MEMORY_FAILED if a
 *         node cannot be allocated.
 */
STATUS HuffmanDecode::RebuildHuffmanTree(std::string keys, std::string codes, const HuffmanNodePtr &root) {
  auto huffman_keys = Str2Vec(std::move(keys));
  auto huffman_codes = Str2Vec(std::move(codes));
  // Every code is paired with huffman_keys[i]; reading past the end of the key list is undefined behavior.
  if (huffman_keys.size() < huffman_codes.size()) {
    MS_LOG(ERROR) << "huffman keys count " << huffman_keys.size() << " is less than huffman codes count "
                  << huffman_codes.size() << ".";
    return RET_ERROR;
  }
  for (size_t i = 0; i < huffman_codes.size(); ++i) {
    auto key = std::stoi(huffman_keys[i]);
    const auto &code = huffman_codes[i];
    const auto code_len = code.length();
    HuffmanNodePtr cur_node = root;
    for (size_t j = 0; j < code_len; ++j) {
      HuffmanNodePtr tmp_node = nullptr;
      if (code[j] == '0') {
        tmp_node = cur_node->left;
      } else if (code[j] == '1') {
        tmp_node = cur_node->right;
      } else {
        MS_LOG(ERROR) << "find huffman code is not 0 or 1";
        return RET_ERROR;
      }
      const bool is_last_bit = (j + 1 == code_len);
      if (tmp_node == nullptr) {
        auto new_node = new (std::nothrow) HuffmanNode();
        if (new_node == nullptr) {
          MS_LOG(ERROR) << "new HuffmanNode failed.";
          return RET_MEMORY_FAILED;
        }
        new_node->left = nullptr;
        new_node->right = nullptr;
        new_node->parent = cur_node;
        if (is_last_bit) {
          new_node->key = key;
          new_node->code = code;
        }
        if (code[j] == '0') {
          cur_node->left = new_node;
        } else {
          cur_node->right = new_node;
        }
        tmp_node = new_node;
      } else if (is_last_bit) {
        // The code ends on a node that already exists: it duplicates, or is a prefix of, an earlier code.
        MS_LOG(ERROR) << "the huffman code is incomplete.";
        return RET_ERROR;
      } else if (tmp_node->left == nullptr && tmp_node->right == nullptr) {
        // The path runs through an existing leaf: an earlier code is a prefix of this one.
        MS_LOG(ERROR) << "the huffman code is incomplete";
        return RET_ERROR;
      }
      cur_node = tmp_node;
    }
  }
  return RET_OK;
}
/**
 * Walks the Huffman tree bit by bit over encoded_data and appends each decoded symbol to decoded_str.
 *
 * Bits are consumed most-significant first within each byte ('1' selects the right child, '0' the
 * left). Decoding stops at the first leaf whose key equals PSEUDO_EOF, or when the input runs out.
 * Leaves with any other key append static_cast<char>(key) to the output.
 *
 * @param root root of the rebuilt Huffman tree.
 * @param encoded_data packed bit stream.
 * @param decoded_str output string; cleared before decoding.
 * @return RET_OK on success; RET_ERROR if root or decoded_str is null, or if the bit stream leads to
 *         a child that does not exist in the tree.
 */
STATUS HuffmanDecode::DoHuffmanDecompress(HuffmanNodePtr root, std::string encoded_data, std::string *decoded_str) {
  if (root == nullptr || decoded_str == nullptr) {
    MS_LOG(ERROR) << "root or decoded_str is nullptr.";
    return RET_ERROR;
  }
  decoded_str->clear();
  HuffmanNodePtr cur_node = root;
  for (const char byte : encoded_data) {
    const auto u_char = static_cast<unsigned char>(byte);
    for (unsigned char flag = 0x80; flag != 0; flag >>= 1) {
      cur_node = (u_char & flag) ? cur_node->right : cur_node->left;
      // Corrupt data, or a tree with a missing branch (e.g. a lone root), can select a child that does
      // not exist. Dereferencing it on the next step would crash.
      if (cur_node == nullptr) {
        MS_LOG(ERROR) << "the huffman encoded data does not match the huffman tree.";
        return RET_ERROR;
      }
      if (cur_node->left != nullptr || cur_node->right != nullptr) {
        continue;  // Still on an internal node; keep consuming bits.
      }
      if (cur_node->key == PSEUDO_EOF) {
        return RET_OK;  // End-of-stream marker: any trailing padding bits are ignored.
      }
      *decoded_str += static_cast<char>(cur_node->key);
      cur_node = root;
    }
  }
  return RET_OK;
}
/**
 * Deletes every node of a Huffman tree, breadth-first from root, left child before right child.
 * A null root is a no-op.
 *
 * @param root root of the tree to release; the pointer (and every node under it) is dangling afterwards.
 */
void HuffmanDecode::FreeHuffmanNodeTree(HuffmanNodePtr root) {
  std::queue<HuffmanNodePtr> pending;
  if (root != nullptr) {
    pending.push(root);
  }
  while (!pending.empty()) {
    HuffmanNodePtr node = pending.front();
    pending.pop();
    // Queue the children while the parent's links are still readable, then release the parent.
    HuffmanNodePtr children[] = {node->left, node->right};
    for (HuffmanNodePtr child : children) {
      if (child != nullptr) {
        pending.push(child);
      }
    }
    delete node;
  }
}
}
}