/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <iostream>
#include <memory>
#include <sstream>
#include <securec.h>
#include "compress.h"
#include "log.h"
#include "compress_weight.h"
/*
 * CHECK_NOTNULLPTR(val, msg, expr)
 *
 * Null-pointer guard: when `val` compares equal to nullptr, `msg` is reported
 * through LogFatal and `expr` is then executed as a statement. The call sites
 * in this file pass `return RET_ERROR` as `expr`, so the enclosing function is
 * left early.
 *
 * The body is wrapped in do { ... } while (0) so that the expansion forms one
 * statement and can be followed by a semicolon at the call site.
 */
#define CHECK_NOTNULLPTR(val, msg, expr) \
do { \
if ((val) == nullptr) { \
LogFatal(msg); \
expr; \
} \
} while (0)
namespace {
/**
 * Working state shared by the Conv2d weight rearrangement and compression helpers.
 * Every member is zero until SetCompressParametersConv2dConv2d fills it in.
 */
struct CompressParametersConv2d {
    /* Total weight size; copied from compressConfig.inputSize and used as the upper bound for offsets. */
    uint64_t weightSizeTotal{0};
    /* Weight shape entries, read from wShape[0] .. wShape[3]. */
    size_t weightK{0U};
    size_t weightN{0U};
    size_t weightCo0{0U};
    size_t weightCo1{0U};
    /* weightK / compressTilingK. */
    size_t kbFrac{0U};
    /* weightN / compressTilingN, incremented by one when nRemainders is non-zero. */
    size_t nbFrac{0U};
    /* weightN % compressTilingN. */
    size_t nRemainders{0U};
    /* weightSizeTotal divided by the product of the four shape entries. */
    size_t wDtypeSize{0U};
    /* Running sum of compressed lengths; also the write offset into the zip buffer. */
    size_t totalCompressedLength{0U};
    /* Read offset into the rearranged weight; advanced by inputSize after each compressed slice. */
    size_t dataBase{0U};
    /* Write offset into the info (index) buffer; advanced by the index size of each slice. */
    size_t indexBase{0U};
    /* nRemainders * compressTilingK * weightCo0 * weightCo1 * wDtypeSize. */
    size_t tailBlockWeightSize{0U};
    /* tailBlockWeightSize % blockSize. */
    size_t tailBlockSize{0U};
    /* Copy of compressConfig.fractalSize taken when the parameters are set. */
    size_t blockSize{0U};
    /* Size of the part covered by complete N tiles; only assigned when nRemainders is non-zero. */
    size_t firstWeightSize{0U};
    /* Not referenced by the code visible in this file. */
    size_t indexUnitSize{0U};
};
/* Multiplier applied to (inputSize / fractalSize) to obtain the index size of one compressed slice. */
constexpr int INDEX_SIZE_COMPACT_MODE = 8;
/* Positions of the individual dimensions inside CompressOpConfig::wShape. */
constexpr int IDX_WEIGHT_SHAPE_K = 0;
constexpr int IDX_WEIGHT_SHAPE_N = 1;
constexpr int IDX_WEIGHT_SHAPE_CO0 = 2;
constexpr int IDX_WEIGHT_SHAPE_CO1 = 3;
/* Sparsity check: at most NONZERO_THRESHOLD non-zero values per group of WEIGHT_SPARSE_BLOCK elements. */
constexpr size_t WEIGHT_SPARSE_BLOCK = 4;
constexpr size_t NONZERO_THRESHOLD = 2;
/* weight_size passed to the sparsity check must be a multiple of this value. */
constexpr size_t K0N0_PRODUCT = 512;
constexpr size_t ZERO_REMAINDER = 0;
/**
 * Copies the Conv2d weight into tile-major order ahead of compression.
 *
 * Loop nesting, outermost first: N tile (nbFrac), K tile (kbFrac), row inside the K tile
 * (compressTilingK), column inside the N tile (compressTilingN, or nRemainders once the last,
 * partial N tile is reached). For every (row, column) position a run of
 * weightCo0 * weightCo1 bytes is appended to weightRe.
 *
 * @param input              source weight bytes.
 *                           NOTE(review): assumed to hold at least weightSizeTotal bytes - confirm with callers.
 * @param weightRe           destination buffer; CompressWeightsConv2D allocates it with
 *                           weightSizeTotal (== compressConfig.inputSize) bytes.
 * @param compressTilingK    tile height along K, must be non-zero.
 * @param compressTilingN    tile width along N, must be non-zero.
 * @param compressParameters shape and tiling information prepared by SetCompressParametersConv2dConv2d.
 * @return RET_SUCCESS when the copy completed; RET_ERROR when a tiling value is zero or a copy
 *         would read or write beyond weightSizeTotal.
 */
CmpStatus RearrangeConv2d(const char *const input, char *const weightRe,
                          size_t compressTilingK, size_t compressTilingN,
                          CompressParametersConv2d& compressParameters)
{
    // Both values are unsigned, so the only invalid value is zero.
    if (compressTilingK == 0 || compressTilingN == 0) {
        LogFatal("compressTiling is invalid: compressTilingK("
                 << compressTilingK << ") compressTilingN(" << compressTilingN
                 << ")");
        return RET_ERROR;
    }
    const size_t runLength = compressParameters.weightCo0 * compressParameters.weightCo1;
    const uint64_t capacity = compressParameters.weightSizeTotal;
    // size_t instead of int: an int index overflows for weights of 2 GiB and more.
    size_t written = 0;
    size_t jRange = compressTilingN;
    for (size_t n = 0; n < compressParameters.nbFrac; n++) {
        // The last N tile is narrower when weightN is not a multiple of compressTilingN.
        if ((n == compressParameters.nbFrac - 1) && (compressParameters.nRemainders > 0)) {
            jRange = compressParameters.nRemainders;
        }
        for (size_t k = 0; k < compressParameters.kbFrac; k++) {
            for (size_t i = 0; i < compressTilingK; i++) {
                for (size_t j = 0; j < jRange; j++) {
                    const size_t start = (j + i * compressParameters.weightN + n * compressTilingN +
                                          k * compressTilingK * compressParameters.weightN) * runLength;
                    const size_t end = start + runLength;
                    // Refuse to touch memory past the buffers instead of corrupting the heap when the
                    // shape and inputSize disagree.
                    if (end > capacity || written + runLength > capacity) {
                        LogFatal("rearrange exceeds weight size: read end(" << end << ") write end("
                                 << (written + runLength) << ") weightSizeTotal(" << capacity << ")");
                        return RET_ERROR;
                    }
                    for (size_t x = start; x < end; x++) {
                        weightRe[written] = input[x];
                        written += 1;
                    }
                }
            }
        }
    }
    return RET_SUCCESS;
}
/**
 * Writes a human-readable dump of the compress operator configuration to the log.
 * The dump lists the weight shape, both tiling values and the fields of compressConfig.
 *
 * @param param configuration to dump; dereferenced without a null check.
 */
void LogCompressOpConfig(CompressOpConfig *const param)
{
    std::stringstream dump;
    dump << '\n' << "==== compress op param ====" << '\n';

    dump << "shape of weight: ";
    for (int dim = 0; dim < SHAPE_SIZE_WEIGHT; ++dim) {
        dump << param->wShape[dim] << " ";
    }
    dump << '\n';

    dump << "tiling k: " << param->compressTilingK << '\n'
         << "tiling n: " << param->compressTilingN << '\n';

    dump << "compress param:" << '\n'
         << " inputSize: " << param->compressConfig.inputSize << '\n'
         << " engineNum: " << param->compressConfig.engineNum << '\n'
         << " maxRatio: " << param->compressConfig.maxRatio << '\n'
         << " channel: " << param->compressConfig.channel << '\n'
         << " fractalSize: " << param->compressConfig.fractalSize << '\n'
         << " isTight: " << param->compressConfig.isTight << '\n'
         << " initOffset: " << param->compressConfig.init_offset << '\n';

    dump << "==== compress op param ====" << '\n';
    Log(dump.str());
}
/**
 * Compresses one slice of the rearranged weight and appends the result to the output buffers.
 *
 * The slice starts at weightRe + dataBase and is compressConfig.inputSize bytes long. Its
 * compressed data is copied to zipBuffer + totalCompressedLength and its index to
 * infoBuffer + indexBase. On success dataBase, totalCompressedLength and indexBase are advanced.
 * compressConfig.init_offset is overwritten with the current totalCompressedLength before
 * CompressWeights is called.
 *
 * @return RET_SUCCESS on success; RET_ERROR when inputSize or fractalSize is zero, an allocation
 *         fails, a range would exceed weightSizeTotal, CompressWeights fails or a copy fails.
 */
CmpStatus CompressConv2d(CompressOpConfig *const param, char *const weightRe, CompressParametersConv2d&
    compressParameters, char *const zipBuffer, char *const infoBuffer)
{
    if (param->compressConfig.inputSize == 0) {
        LogFatal("The inputSize of compressConfig is zero.");
        return RET_ERROR;
    }
    if (param->compressConfig.fractalSize == 0) {
        LogFatal("The fractalSize of compressConfig is zero.");
        return RET_ERROR;
    }

    // Scratch buffers: the index first, then the compressed payload.
    const size_t indexBytes =
        param->compressConfig.inputSize / param->compressConfig.fractalSize * INDEX_SIZE_COMPACT_MODE;
    std::unique_ptr<char[]> indexScratch(new (std::nothrow) char[indexBytes]());
    if (indexScratch.get() == nullptr) {
        LogFatal("compIdxBuffer is null pointers.");
        return RET_ERROR;
    }
    std::unique_ptr<char[]> dataScratch(new (std::nothrow) char[param->compressConfig.inputSize]());
    if (dataScratch.get() == nullptr) {
        LogFatal("compDataBuffer is null pointers.");
        return RET_ERROR;
    }

    param->compressConfig.init_offset = compressParameters.totalCompressedLength;

    // The slice to compress must lie inside the rearranged weight.
    if (compressParameters.dataBase + param->compressConfig.inputSize > compressParameters.weightSizeTotal) {
        LogFatal("weightSizeTotal is:"<<compressParameters.weightSizeTotal<<", datebase is:"<<
                 compressParameters.dataBase<<", size is:"<<param->compressConfig.inputSize<<".");
        return RET_ERROR;
    }

    char *sliceBegin = weightRe + compressParameters.dataBase;
    size_t packedBytes = 0;
    const auto compressStatus = CompressWeights(
        sliceBegin, param->compressConfig, indexScratch.get(), dataScratch.get(), packedBytes);
    if (compressStatus != 0) {
        LogFatal("weight compress of Conv2d failed");
        return RET_ERROR;
    }

    const auto indexCopyStatus =
        memcpy_s(infoBuffer + compressParameters.indexBase, indexBytes, indexScratch.get(), indexBytes);
    if (indexCopyStatus != 0) {
        LogFatal("copy compress index failed.");
        return RET_ERROR;
    }

    // The compressed output as a whole must not grow beyond the original weight size.
    if (compressParameters.totalCompressedLength + packedBytes > compressParameters.weightSizeTotal) {
        LogFatal("weightSizeTotal is:"<<compressParameters.weightSizeTotal<<", totalCompressedLength is:"<<
                 compressParameters.totalCompressedLength<<", compressedLength is:"<<packedBytes<<".");
        return RET_ERROR;
    }

    const auto dataCopyStatus = memcpy_s(
        zipBuffer + compressParameters.totalCompressedLength, packedBytes, dataScratch.get(), packedBytes);
    if (dataCopyStatus != 0) {
        LogFatal("copy compress data failed.");
        return RET_ERROR;
    }

    // Advance the cursors for the next slice.
    compressParameters.dataBase += param->compressConfig.inputSize;
    compressParameters.totalCompressedLength += packedBytes;
    compressParameters.indexBase += indexBytes;
    return RET_SUCCESS;
}
/**
 * Compresses the rearranged Conv2d weight.
 *
 * Without an N remainder the whole weight is handed to CompressWeights in a single call that
 * writes straight into infoBuffer and zipBuffer. With a remainder the weight is compressed in
 * slices through CompressConv2d: first the part made of complete N tiles (firstWeightSize), then,
 * once per K tile, the block-aligned part of the tail followed by the rest of the tail.
 *
 * Side effect: in the remainder path compressConfig.inputSize and compressConfig.fractalSize of
 * `param` are overwritten for every slice and are not restored.
 *
 * @return RET_SUCCESS when every slice was compressed, RET_ERROR otherwise.
 */
CmpStatus CompressWeightConv2d(char *const weightRe, char *const zipBuffer,
                               char *const infoBuffer, CompressOpConfig *const param,
                               CompressParametersConv2d &compressParameters)
{
    // Fast path: no partial N tile, one call covers everything.
    if (compressParameters.nRemainders == 0) {
        size_t packedBytes = 0;
        if (CompressWeights(weightRe, param->compressConfig, infoBuffer, zipBuffer, packedBytes) != 0) {
            LogFatal("weight compress of Conv2d failed");
            return RET_ERROR;
        }
        return RET_SUCCESS;
    }

    // Part covered by complete N tiles.
    if (compressParameters.firstWeightSize != 0) {
        param->compressConfig.inputSize = compressParameters.firstWeightSize;
        if (CompressConv2d(param, weightRe, compressParameters, zipBuffer, infoBuffer) != 0) {
            LogFatal("compress firstWeight failed");
            return RET_ERROR;
        }
    }

    // Tail of every K tile: block-aligned part first, then whatever is left over.
    const size_t leftoverBytes = compressParameters.tailBlockSize;
    const size_t alignedBytes = compressParameters.tailBlockWeightSize - leftoverBytes;
    for (size_t tile = 0; tile < compressParameters.kbFrac; ++tile) {
        param->compressConfig.inputSize = alignedBytes;
        param->compressConfig.fractalSize = compressParameters.blockSize;
        if (param->compressConfig.inputSize != 0 &&
            CompressConv2d(param, weightRe, compressParameters, zipBuffer, infoBuffer) != 0) {
            LogFatal("compress tail failed");
            return RET_ERROR;
        }

        param->compressConfig.inputSize = leftoverBytes;
        param->compressConfig.fractalSize = leftoverBytes;
        if (leftoverBytes != 0 &&
            CompressConv2d(param, weightRe, compressParameters, zipBuffer, infoBuffer) != 0) {
            LogFatal("CompressTailOfTail failed");
            return RET_ERROR;
        }
    }
    return RET_SUCCESS;
}
/**
 * Validates the operator configuration and derives the tiling / size bookkeeping from it.
 *
 * @param compressParameters output; only written once all validation has passed.
 * @param param              operator configuration (weight shape, tiling, compressConfig).
 * @return true when compressParameters was filled in; false (after logging) when a tiling value
 *         or shape entry is not positive, fractalSize is zero, the shape product is zero, or
 *         inputSize is smaller than the shape product.
 */
bool SetCompressParametersConv2dConv2d(CompressParametersConv2d& compressParameters, CompressOpConfig *const param)
{
    if (param->compressTilingK <= 0) {
        LogFatal("CompressTilingK in parameter must > 0.");
        return false;
    }
    if (param->compressTilingN <= 0) {
        LogFatal("CompressTilingN in parameter must > 0.");
        return false;
    }
    for (auto v : param->wShape) {
        if (v <= 0) {
            LogFatal("Dim of weight must > 0.");
            return false;
        }
    }
    // fractalSize becomes blockSize, which is used as a divisor for tailBlockSize below.
    if (param->compressConfig.fractalSize == 0) {
        LogFatal("The fractalSize of compressConfig is zero.");
        return false;
    }

    // Read the shape from param first: the fields of compressParameters are still zero here, so
    // deriving the element count from them would always yield zero and reject every input.
    const size_t weightK = param->wShape[IDX_WEIGHT_SHAPE_K];
    const size_t weightN = param->wShape[IDX_WEIGHT_SHAPE_N];
    const size_t weightCo0 = param->wShape[IDX_WEIGHT_SHAPE_CO0];
    const size_t weightCo1 = param->wShape[IDX_WEIGHT_SHAPE_CO1];
    const size_t compressWeightSize = weightK * weightN * weightCo0 * weightCo1;
    if (compressWeightSize <= 0) {
        LogFatal("compressWeightSize must > 0.");
        return false;
    }
    const uint64_t weightSizeTotal = param->compressConfig.inputSize;
    const size_t wDtypeSize = weightSizeTotal / compressWeightSize;
    // A zero element size means the buffer is smaller than the shape claims; every derived size
    // would collapse to zero and the rearrangement would run past the buffer.
    if (wDtypeSize == 0) {
        LogFatal("inputSize(" << weightSizeTotal << ") is smaller than the weight element count("
                 << compressWeightSize << ").");
        return false;
    }

    compressParameters.weightSizeTotal = weightSizeTotal;
    compressParameters.weightK = weightK;
    compressParameters.weightN = weightN;
    compressParameters.weightCo0 = weightCo0;
    compressParameters.weightCo1 = weightCo1;
    compressParameters.kbFrac = compressParameters.weightK / param->compressTilingK;
    compressParameters.nbFrac = compressParameters.weightN / param->compressTilingN;
    compressParameters.nRemainders = compressParameters.weightN % param->compressTilingN;
    compressParameters.wDtypeSize = wDtypeSize;
    if (compressParameters.nRemainders != 0) {
        // Size of everything that lies in complete N tiles; the partial tile is counted as one more tile.
        compressParameters.firstWeightSize = param->compressTilingN * param->compressTilingK *
            compressParameters.weightCo0 * compressParameters.weightCo1 * compressParameters.kbFrac *
            compressParameters.nbFrac * compressParameters.wDtypeSize;
        compressParameters.nbFrac += 1;
    }
    compressParameters.blockSize = param->compressConfig.fractalSize;
    compressParameters.tailBlockWeightSize = compressParameters.nRemainders * param->compressTilingK *
        compressParameters.weightCo0 * compressParameters.weightCo1 * compressParameters.wDtypeSize;
    compressParameters.tailBlockSize = compressParameters.tailBlockWeightSize % compressParameters.blockSize;
    compressParameters.dataBase = 0;
    compressParameters.indexBase = 0;
    compressParameters.totalCompressedLength = 0;
    return true;
}
}
/**
 * Entry point: rearranges a Conv2d weight into tile order and compresses it.
 *
 * @param input      source weight bytes.
 * @param zipBuffer  receives the compressed data.
 * @param infoBuffer receives the compression index.
 * @param param      operator configuration; its compressConfig may be modified while compressing.
 * @return RET_SUCCESS on success; RET_ERROR when an argument is null, inputSize is not positive,
 *         the scratch buffer cannot be allocated, or any of the later stages fails.
 */
CmpStatus CompressWeightsConv2D(const char *const input, char *const zipBuffer,
                                char *const infoBuffer,
                                CompressOpConfig *const param)
{
    const bool anyNull = (input == nullptr) || (zipBuffer == nullptr) ||
                         (infoBuffer == nullptr) || (param == nullptr);
    if (anyNull) {
        LogFatal("The input parameters contain null pointers.");
        return RET_ERROR;
    }
    if (param->compressConfig.inputSize <= 0) {
        LogFatal("inputSize in CompressConfig is invalid: "
                 << param->compressConfig.inputSize);
        return RET_ERROR;
    }
    LogCompressOpConfig(param);

    // Zero-initialised scratch buffer that receives the rearranged weight.
    std::unique_ptr<char[]> rearranged(
        new (std::nothrow) char[param->compressConfig.inputSize]());
    if (rearranged.get() == nullptr) {
        LogFatal("weightRe is null pointers.");
        return RET_ERROR;
    }

    CompressParametersConv2d compressParameters;
    if (!SetCompressParametersConv2dConv2d(compressParameters, param)) {
        return RET_ERROR;
    }

    auto status = RearrangeConv2d(input, rearranged.get(), param->compressTilingK,
                                  param->compressTilingN, compressParameters);
    if (status == RET_ERROR) {
        LogFatal("rearrange weight failed");
        return RET_ERROR;
    }

    status = CompressWeightConv2d(rearranged.get(), zipBuffer, infoBuffer, param, compressParameters);
    if (status == RET_ERROR) {
        LogFatal("weight compress of Conv2d failed");
        return RET_ERROR;
    }
    return RET_SUCCESS;
}
/**
 * Checks that a weight satisfies the sparsity constraint: every consecutive group of
 * WEIGHT_SPARSE_BLOCK (4) elements may hold at most NONZERO_THRESHOLD (2) non-zero values.
 *
 * @param input       weight data, int8 dtype.
 * @param weight_size the size of the weight data (Ci1HkWk*Co1*Co0*Ci0); must be a multiple of
 *                    K0N0_PRODUCT.
 * @return RET_SUCCESS when every group satisfies the constraint (including weight_size == 0);
 *         RET_ERROR when input is null, weight_size is not aligned, or a group holds more than
 *         2 non-zero elements.
 */
CmpStatus SparseWeightsConv2D(const char *const input, size_t weight_size)
{
    if (input == nullptr) {
        LogFatal("Input weight is a null pointer!");
        return RET_ERROR;
    }
    if (weight_size % K0N0_PRODUCT != ZERO_REMAINDER) {
        LogFatal("Invalid weight_size, which is not n0*k0 aligned!");
        return RET_ERROR;
    }
    // The alignment check above guarantees weight_size is a multiple of WEIGHT_SPARSE_BLOCK,
    // so the groups cover the weight exactly.
    for (size_t i = 0; i < weight_size / WEIGHT_SPARSE_BLOCK; i++) {
        size_t nonzeroNum = 0;
        for (size_t j = 0; j < WEIGHT_SPARSE_BLOCK; j++) {
            if (input[i * WEIGHT_SPARSE_BLOCK + j] != 0) {
                nonzeroNum += 1;
            }
        }
        if (nonzeroNum > NONZERO_THRESHOLD) {
            LogFatal("Weight sparse is not satisfied! Over 2 nonzero elements in 4 elements.");
            return RET_ERROR;
        }
    }
    return RET_SUCCESS;
}