* -------------------------------------------------------------------------
* This file is part of the Vision SDK project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* Vision SDK is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
* Description: CrnnPostProcess private interface for internal use only.
* Author: MindX SDK
* Create: 2020
* History: NA
*/
#ifndef CRNN_POST_PROCESS_DPTR_H
#define CRNN_POST_PROCESS_DPTR_H
#include "TextGenerationPostProcessors/CrnnPostProcess.h"
#include "MxBase/Log/Log.h"
namespace MxBase {
/**
 * @brief Private implementation object used by CrnnPostProcess.
 *
 * Holds the decoding settings (objectNum_, blankIdx_, withArgmax_) and a raw back pointer to the
 * CrnnPostProcess it works for. The pointer is never deleted by this class (the destructor is
 * defaulted).
 */
class SDK_UNAVAILABLE_FOR_OTHER CrnnPostProcessDptr {
public:
    // ---- construction / copy ----
    /// Stores pCrnnPostProcess in qPtr_.
    explicit CrnnPostProcessDptr(CrnnPostProcess *pCrnnPostProcess);
    /// Copies the settings only; see operator=.
    CrnnPostProcessDptr(const CrnnPostProcessDptr &others);
    /// Copies objectNum_, blankIdx_ and withArgmax_; qPtr_ is left untouched.
    CrnnPostProcessDptr &operator=(const CrnnPostProcessDptr &other);
    ~CrnnPostProcessDptr() = default;

    // ---- tensor validation ----
    /// Returns true when tensors holds exactly one tensor with the expected layout.
    bool IsValidTensors(const std::vector<TensorBase> &tensors) const;
    /// Optionally validates tensors, then forwards to qPtr_->CheckAndMoveTensors.
    APP_ERROR CheckAndMoveTensors(std::vector<TensorBase> &tensors);

    // ---- decoding ----
    /// Appends one TextsInfo per batch entry of tensors[0] to textsInfos.
    void TextGenerationOutput(const std::vector<TensorBase> &tensors, std::vector<TextsInfo> &textsInfos);
    /// Decodes one batch entry from per-class float scores (argmax computed here).
    std::string CalcOutputArgmax(TensorBase &tensor, uint32_t batchNum);
    /// Decodes one batch entry from ready-made class indices.
    std::string CalcOutputIndex(TensorBase &tensor, uint32_t batchNum);

    // ---- state (public so the owning CrnnPostProcess can fill it in) ----
    uint32_t objectNum_ = 0;            // number of positions decoded per batch entry
    uint32_t blankIdx_ = 0;             // class index that is skipped while decoding
    bool withArgmax_ = false;           // true: tensor already contains indices (CalcOutputIndex path)
    CrnnPostProcess* qPtr_ = nullptr;   // back pointer to the public object; not owned
};
CrnnPostProcessDptr::CrnnPostProcessDptr(CrnnPostProcess *pCrnnPostProcess)
: qPtr_(pCrnnPostProcess)
{}
/**
 * @brief Copy-construct by reusing the copy-assignment operator.
 *
 * Only the fields handled by operator= are taken from @p others; qPtr_ keeps its in-class
 * default (nullptr) in the new object.
 *
 * @param others Object whose settings are copied.
 */
CrnnPostProcessDptr::CrnnPostProcessDptr(const CrnnPostProcessDptr &others)
{
    this->operator=(others);
}
/**
 * @brief Copy the decoding settings from another instance.
 *
 * objectNum_, blankIdx_ and withArgmax_ are copied. qPtr_ is intentionally not assigned here,
 * so the target keeps whatever back pointer it already had.
 *
 * @param other Source of the settings.
 * @return Reference to this object.
 */
CrnnPostProcessDptr& CrnnPostProcessDptr::operator=(const CrnnPostProcessDptr &other)
{
    if (this != &other) {
        withArgmax_ = other.withArgmax_;
        blankIdx_ = other.blankIdx_;
        objectNum_ = other.objectNum_;
    }
    return *this;
}
void CrnnPostProcessDptr::TextGenerationOutput(const std::vector<TensorBase> &tensors,
std::vector<TextsInfo> &textsInfos)
{
LogDebug << "CrnnPostProcess start to write results.";
auto shape = tensors[0].GetShape();
auto tensor = tensors[0];
uint32_t batchSize = shape[0];
for (uint32_t i = 0; i < batchSize; i++) {
TextsInfo textsInfo;
std::string result = "";
if (!withArgmax_) {
result = CalcOutputArgmax(tensor, i);
} else {
result = CalcOutputIndex(tensor, i);
}
textsInfo.text.push_back(result);
textsInfos.push_back(textsInfo);
LogDebug << "CrnnPostProcessor output string(" << textsInfos[i].text[0] << ").";
}
LogDebug << "CrnnPostProcess end to write results.";
}
bool CrnnPostProcessDptr::IsValidTensors(const std::vector <TensorBase> &tensors) const
{
if (tensors.size() != 1) {
LogError << "The number of tensors (" << tensors.size() << ") " << "is not equal to required (" << 1 << ")"
<< GetErrorInfo(APP_ERR_COMM_INVALID_PARAM);
return false;
}
auto tensor = tensors[0];
auto shape = tensor.GetShape();
if (!withArgmax_) {
for (size_t i = 0; i < tensors.size(); i++) {
if (tensors[i].GetDataTypeSize() != qPtr_->FOUR_BYTE) {
LogError << "The tensor type(" << TensorDataTypeStr[tensors[i].GetTensorType()]
<< ") mismatched. requires(" << qPtr_->FOUR_BYTE << ") bytes tensortype."
<< GetErrorInfo(APP_ERR_COMM_INVALID_PARAM);
return false;
}
}
if (shape.size() != 0x3 && shape.size() != 0x4) {
LogError << "The input size of tensor0(" << shape.size() << ") " << "mismatch the size:"
<< 0x3 << " or " << 0x4 << GetErrorInfo(APP_ERR_COMM_INVALID_PARAM);
return false;
}
if (shape.size() == 0x3 && shape[shape.size() - 0x2] != objectNum_) {
LogError << "The input size(" << shape[shape.size() - 0x2] << ") not equal to objectNum("
<< objectNum_ << ")" << GetErrorInfo(APP_ERR_COMM_INVALID_PARAM);
return false;
} else if (shape.size() == 0x4 && shape[shape.size() - 0x3] != objectNum_) {
LogError << "The input size(" << shape[shape.size() - 0x3] << ") not equal to objectNum("
<< objectNum_ << ")" << GetErrorInfo(APP_ERR_COMM_INVALID_PARAM);
return false;
}
if (shape[shape.size() - 0x1] != qPtr_->classNum_) {
LogError << "The input size(" << shape[shape.size() - 0x1] << ") " << "not equal to classNum("
<< qPtr_->classNum_ << ")" << GetErrorInfo(APP_ERR_COMM_INVALID_PARAM);
return false;
}
} else {
if (shape.size() != 0x2) {
LogError << "The input size of tensor0(" << shape.size() << ") " << "mismatch the size:"
<< 0x2 << GetErrorInfo(APP_ERR_COMM_INVALID_PARAM);
return false;
}
if (shape[1] != objectNum_) {
LogError << "The input size(" << shape[1] << ") " << "not equal to objectNum(" << objectNum_ << ")"
<< GetErrorInfo(APP_ERR_COMM_INVALID_PARAM);
return false;
}
}
return true;
}
/**
 * @brief Validate the tensors when model checking is enabled, then hand them to the public object.
 *
 * Validation via IsValidTensors runs only if qPtr_->checkModelFlag_ is set. On success (or when
 * validation is disabled) the call is forwarded to qPtr_->CheckAndMoveTensors.
 *
 * @param tensors Model output tensors.
 * @return APP_ERR_COMM_INVALID_PARAM when validation fails; otherwise the result of
 *         qPtr_->CheckAndMoveTensors.
 */
APP_ERROR CrnnPostProcessDptr::CheckAndMoveTensors(std::vector<TensorBase> &tensors)
{
    if (qPtr_->checkModelFlag_) {
        const bool tensorsOk = IsValidTensors(tensors);
        if (!tensorsOk) {
            LogError << "Input tensors are invalid." << GetErrorInfo(APP_ERR_COMM_INVALID_PARAM);
            return APP_ERR_COMM_INVALID_PARAM;
        }
    }
    return qPtr_->CheckAndMoveTensors(tensors);
}
/**
 * @brief Decode one batch entry from per-class float scores.
 *
 * The buffer returned by qPtr_->GetBuffer is read as objectNum_ consecutive rows of
 * qPtr_->classNum_ floats. For every row the index of the largest value is taken; the class
 * name for that index is appended unless the index equals blankIdx_ or equals the index chosen
 * for the previous row.
 *
 * @param tensor   Tensor holding the scores.
 * @param batchNum Batch entry to decode.
 * @return The decoded text, or an empty string when the buffer is nullptr.
 */
std::string CrnnPostProcessDptr::CalcOutputArgmax(TensorBase &tensor, uint32_t batchNum)
{
    LogDebug << "Start to Process CalcOutputArgmax.";
    const float *outputInfo = static_cast<const float *>(qPtr_->GetBuffer(tensor, batchNum));
    if (outputInfo == nullptr) {
        LogError << "The outputInfo is nullptr." << GetErrorInfo(APP_ERR_COMM_INVALID_POINTER);
        return "";
    }
    const auto classNum = qPtr_->classNum_;
    uint32_t previousIdx = blankIdx_;
    std::string result;
    for (uint32_t i = 0; i < objectNum_; i++) {
        // Search the row in place instead of copying it into a temporary vector; the offset is
        // computed in size_t so that large objectNum_ * classNum products cannot wrap at 32 bits.
        const float *rowBegin = outputInfo + static_cast<size_t>(i) * classNum;
        const float *rowEnd = rowBegin + classNum;
        const uint32_t argmaxIndex = static_cast<uint32_t>(std::max_element(rowBegin, rowEnd) - rowBegin);
        if (argmaxIndex != blankIdx_ && argmaxIndex != previousIdx) {
            result += (qPtr_->configData_).GetClassName(argmaxIndex);
        }
        previousIdx = argmaxIndex;
    }
    LogDebug << "End to Process CalcOutputArgmax.";
    return result;
}
/**
 * @brief Decode one batch entry whose buffer already contains class indices.
 *
 * The buffer returned by qPtr_->GetBuffer is read as objectNum_ values of type long long int.
 * The class name of a value is appended unless the value equals blankIdx_ or equals the value
 * at the previous position.
 *
 * @param tensor   Tensor holding the indices.
 * @param batchNum Batch entry to decode.
 * @return The decoded text, or an empty string when the buffer is nullptr.
 */
std::string CrnnPostProcessDptr::CalcOutputIndex(TensorBase &tensor, uint32_t batchNum)
{
    LogDebug << "Start to Process CalcOutputIndex.";
    auto *indices = static_cast<long long int *>(qPtr_->GetBuffer(tensor, batchNum));
    if (indices == nullptr) {
        LogError << "The objectInfo is nullptr." << GetErrorInfo(APP_ERR_COMM_INVALID_POINTER);
        return "";
    }
    std::string decoded;
    long long int lastIdx = blankIdx_;
    for (uint32_t step = 0; step < objectNum_; ++step) {
        const long long int current = indices[step];
        const bool isBlank = (current == blankIdx_);
        const bool isRepeat = (current == lastIdx);
        if (!isBlank && !isRepeat) {
            decoded += (qPtr_->configData_).GetClassName(current);
        }
        lastIdx = current;
    }
    LogDebug << "End to Process CalcOutputIndex.";
    return decoded;
}
}
#endif