/*
 * Copyright(C) 2020. Huawei Technologies Co.,Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "Yolov3PostProcess.h"

#include <algorithm>
#include <stdexcept>

#include "MxBase/CV/ObjectDetection/Nms/Nms.h"
#include "MxBase/Log/Log.h"
#include "MxBase/Maths/FastMath.h"

namespace
{
// Multiplied with an output grid width and divided by the network input width to get the layer's grid ratio
// (presumably the stride of the coarsest output layer).
constexpr int SCALE = 32;
// Number of values per anchor in the BIASES list: (width, height).
constexpr int BIASESDIM = 2;
// Attribute offsets of the box width / height within one anchor's data.
constexpr int OFFSETWIDTH = 2;
constexpr int OFFSETHEIGHT = 3;
// Offset of the anchor height inside a (width, height) anchor pair.
constexpr int OFFSETBIASES = 1;
// Class scores start this many attributes after the objectness value.
constexpr int OFFSETOBJECTNESS = 1;

// Positions of the spatial dimensions in an output tensor shape, per layout.
constexpr int NHWC_HEIGHTINDEX = 1;
constexpr int NHWC_WIDTHINDEX = 2;
constexpr int NCHW_HEIGHTINDEX = 2;
constexpr int NCHW_WIDTHINDEX = 3;
// Per-anchor values besides the class scores (presumably box coordinates + objectness).
constexpr int YOLO_INFO_DIM = 5;

// No-op deleter: lets a std::shared_ptr<void> wrap a tensor buffer without ever freeing it.
auto uint8Deleter = [](uint8_t * /* unused */) {};
}  // namespace
namespace MxBase
{
/*
 * @description: Copy-assign the base-class state and every YOLO-specific setting from other
 * @param other  Post-processor to copy from; assigning an object to itself changes nothing
 * @return *this
 */
Yolov3PostProcess &Yolov3PostProcess::operator=(const Yolov3PostProcess &other)
{
    if (this != &other)
    {
        ObjectPostProcessBase::operator=(other);
        objectnessThresh_ = other.objectnessThresh_;  // minimum objectness for a candidate box
        iouThresh_ = other.iouThresh_;                // IoU threshold handed to NmsSort
        anchorDim_ = other.anchorDim_;
        biasesNum_ = other.biasesNum_;
        yoloType_ = other.yoloType_;
        modelType_ = other.modelType_;
        yoloVersion_ = other.yoloVersion_;
        inputType_ = other.inputType_;
        biases_ = other.biases_;
    }
    return *this;
}

/*
 * @description: Initialize the base post-processor, then read the YOLO-specific settings from the config file
 * @param postConfig  Post-processing configuration forwarded to ObjectPostProcessBase::Init
 * @return APP_ERR_OK on success; otherwise the error of ObjectPostProcessBase::Init or of GetBiases
 */
APP_ERROR Yolov3PostProcess::Init(const std::map<std::string, std::string> &postConfig)
{
    LogDebug << "Start to Init Yolov3PostProcess.";
    APP_ERROR ret = ObjectPostProcessBase::Init(postConfig);
    if (ret != APP_ERR_OK)
    {
        LogError << GetErrorInfo(ret) << "Fail to superInit in ObjectPostProcessBase.";
        return ret;
    }

    // The return values of GetFileValue are not checked; only the BIASES string is validated (by GetBiases).
    std::string biasesText;
    configData_.GetFileValue<int>("BIASES_NUM", biasesNum_);
    configData_.GetFileValue<std::string>("BIASES", biasesText);
    configData_.GetFileValue<float>("OBJECTNESS_THRESH", objectnessThresh_);
    configData_.GetFileValue<float>("IOU_THRESH", iouThresh_);
    configData_.GetFileValue<int>("YOLO_TYPE", yoloType_);
    configData_.GetFileValue<int>("MODEL_TYPE", modelType_);
    configData_.GetFileValue<int>("YOLO_VERSION", yoloVersion_);
    configData_.GetFileValue<int>("INPUT_TYPE", inputType_);
    configData_.GetFileValue<int>("ANCHOR_DIM", anchorDim_);

    ret = GetBiases(biasesText);
    if (ret == APP_ERR_OK)
    {
        LogDebug << "End to Init Yolov3PostProcess.";
        return APP_ERR_OK;
    }
    LogError << GetErrorInfo(ret) << "Failed to get biases.";
    return ret;
}

/*
 * @description: Release hook of the post-processor; this implementation does nothing and always succeeds
 * @return APP_ERR_OK
 */
APP_ERROR Yolov3PostProcess::DeInit()
{
    return APP_ERR_OK;
}

bool Yolov3PostProcess::IsValidTensors(const std::vector<TensorBase> &tensors) const
{
    if (tensors.size() != (size_t)yoloType_)
    {
        LogError << "number of tensors (" << tensors.size() << ") " << "is unequal to yoloType_(" << yoloType_ << ")";
        return false;
    }
    if (yoloVersion_ == YOLOV3_VERSION)
    {
        for (size_t i = 0; i < tensors.size(); i++)
        {
            auto shape = tensors[i].GetShape();
            if (shape.size() < VECTOR_FIFTH_INDEX)
            {
                LogError << "dimensions of tensor [" << i << "] is less than " << VECTOR_FIFTH_INDEX << ".";
                return false;
            }
            uint32_t channelNumber = 1;
            int startIndex = modelType_ ? VECTOR_SECOND_INDEX : VECTOR_FOURTH_INDEX;
            int endIndex = modelType_ ? (shape.size() - VECTOR_THIRD_INDEX) : shape.size();
            for (int i = startIndex; i < endIndex; i++)
            {
                channelNumber *= shape[i];
            }
            if (channelNumber != anchorDim_ * (classNum_ + YOLO_INFO_DIM))
            {
                LogError << "channelNumber(" << channelNumber << ") != anchorDim_ * (classNum_ + 5).";
                return false;
            }
        }
        return true;
    }
    else
    {
        return true;
    }
}

void Yolov3PostProcess::ObjectDetectionOutput(const std::vector<TensorBase> &tensors,
                                              std::vector<std::vector<ObjectInfo>> &objectInfos,
                                              const std::vector<ResizedImageInfo> &resizedImageInfos)
{
    LogDebug << "Yolov3PostProcess start to write results.";
    if (tensors.size() == 0)
    {
        return;
    }
    auto shape = tensors[0].GetShape();
    if (shape.size() == 0)
    {
        return;
    }
    uint32_t batchSize = shape[0];
    for (uint32_t i = 0; i < batchSize; i++)
    {
        std::vector<std::shared_ptr<void>> featLayerData = {};
        std::vector<std::vector<size_t>> featLayerShapes = {};
        for (uint32_t j = 0; j < tensors.size(); j++)
        {
            auto dataPtr = (uint8_t *)GetBuffer(tensors[j], i);
            std::shared_ptr<void> tmpPointer;
            tmpPointer.reset(dataPtr, uint8Deleter);
            featLayerData.push_back(tmpPointer);
            shape = tensors[j].GetShape();
            std::vector<size_t> featLayerShape = {};
            for (auto s : shape)
            {
                featLayerShape.push_back((size_t)s);
            }
            featLayerShapes.push_back(featLayerShape);
        }
        std::vector<ObjectInfo> objectInfo;
        GenerateBbox(featLayerData, objectInfo, featLayerShapes, resizedImageInfos[i].widthResize,
                     resizedImageInfos[i].heightResize);
        MxBase::NmsSort(objectInfo, iouThresh_);
        objectInfos.push_back(objectInfo);
    }
    LogDebug << "Yolov3PostProcess write results successed.";
}

/*
 * @description: Post-process a batch of model outputs into detected objects in original-image coordinates
 * @param tensors  Model output tensors
 * @param objectInfos  Output; one vector of ObjectInfo is appended per decoded image
 * @param resizedImageInfos  Resize info of each image in the batch; must not be empty
 * @param paramMap  Not used by this post-processor
 * @return APP_ERR_OK on success, APP_ERR_INPUT_NOT_MATCH when resizedImageInfos is empty,
 *         or the error returned by CheckAndMoveTensors
 */
APP_ERROR Yolov3PostProcess::Process(const std::vector<TensorBase> &tensors,
                                     std::vector<std::vector<ObjectInfo>> &objectInfos,
                                     const std::vector<ResizedImageInfo> &resizedImageInfos,
                                     const std::map<std::string, std::shared_ptr<void>> &paramMap)
{
    LogDebug << "Start to Process Yolov3PostProcess.";
    APP_ERROR ret = APP_ERR_OK;
    if (resizedImageInfos.size() == 0)
    {
        ret = APP_ERR_INPUT_NOT_MATCH;
        LogError << GetErrorInfo(ret) << "resizedImageInfos is not provided which is necessary for Yolov3PostProcess.";
        return ret;
    }
    auto inputs = tensors;
    ret = CheckAndMoveTensors(inputs);
    if (ret != APP_ERR_OK)
    {
        LogError << GetErrorInfo(ret) << "CheckAndMoveTensors failed.";
        return ret;
    }

    ObjectDetectionOutput(inputs, objectInfos, resizedImageInfos);

    // ObjectDetectionOutput can return early or decode a smaller batch, so objectInfos may hold fewer entries
    // than resizedImageInfos; only reduce coordinates for images that actually have a result vector.
    const size_t imageNum = std::min(resizedImageInfos.size(), objectInfos.size());
    if (imageNum < resizedImageInfos.size())
    {
        LogWarn << "Only " << imageNum << " of " << resizedImageInfos.size() << " images have detection results.";
    }
    for (uint32_t i = 0; i < imageNum; i++)
    {
        CoordinatesReduction(i, resizedImageInfos[i], objectInfos[i]);
    }
    LogObjectInfos(objectInfos);
    LogDebug << "End to Process Yolov3PostProcess.";
    return APP_ERR_OK;
}

/*
 * @description: Keep a running maximum: when classProb beats maxProb, record it and the class index that produced it
 * @param classID  In/out: index of the best class so far; overwritten with classNum when classProb wins
 * @param maxProb  In/out: best probability so far
 * @param classProb  Probability of the candidate class
 * @param classNum  Index of the candidate class
 */
void Yolov3PostProcess::CompareProb(int &classID, float &maxProb, float classProb, int classNum)
{
    const bool isBetter = classProb > maxProb;
    if (!isBetter)
    {
        return;
    }
    classID = classNum;
    maxProb = classProb;
}

/*
 * @description: Decode one channel-major (NCHW) YOLOv3 output layer and append every box that passes the
 *               objectness threshold, the score threshold and the per-class threshold
 * @param netout  Feature data of the layer, read as float: for anchor k and cell j, attribute a is at
 *                ((bboxDim + 1 + classNum) * k + a) * stride + j
 * @param info  Yolo layer info which contains anchor dim, box dim, class number and network input size
 * @param detBoxes  Output; accepted boxes are appended with coordinates normalized by the grid/network size
 *                  and clipped to [0, 1]
 * @param stride  Number of cells in the layer (layer.width * layer.height)
 * @param layer  Yolo output layer: grid width/height and the anchors assigned to this layer
 */
void Yolov3PostProcess::SelectClassNCHW(std::shared_ptr<void> netout, NetInfo info,
                                        std::vector<MxBase::ObjectInfo> &detBoxes, int stride, OutputLayer layer)
{
    const float *data = static_cast<const float *>(netout.get());
    const int attrNum = info.bboxDim + 1 + info.classNum;  // box coordinates + objectness + class scores
    for (int j = 0; j < stride; ++j)
    {
        for (int k = 0; k < info.anchorDim; ++k)
        {
            int bIdx = attrNum * stride * k + j;      // begin index of anchor k at cell j
            int oIdx = bIdx + info.bboxDim * stride;  // objectness index
            float objectness = fastmath::sigmoid(data[oIdx]);
            if (objectness <= objectnessThresh_)
            {
                continue;
            }
            // Pick the class with the highest objectness-weighted probability; it must exceed scoreThresh_
            int classID = -1;
            float maxProb = scoreThresh_;
            for (int c = 0; c < info.classNum; ++c)
            {
                float classProb =
                    fastmath::sigmoid(data[bIdx + (info.bboxDim + OFFSETOBJECTNESS + c) * stride]) * objectness;
                CompareProb(classID, maxProb, classProb, c);
            }
            // Reject before decoding the box so filtered candidates cost no extra work
            if (classID < 0 || maxProb < separateScoreThresh_[classID])
            {
                continue;
            }
            int row = j / layer.width;
            int col = j % layer.width;
            float x = (col + fastmath::sigmoid(data[bIdx])) / layer.width;
            float y = (row + fastmath::sigmoid(data[bIdx + stride])) / layer.height;
            float width =
                fastmath::exp(data[bIdx + OFFSETWIDTH * stride]) * layer.anchors[BIASESDIM * k] / info.netWidth;
            float height = fastmath::exp(data[bIdx + OFFSETHEIGHT * stride]) *
                           layer.anchors[BIASESDIM * k + OFFSETBIASES] / info.netHeight;
            MxBase::ObjectInfo det;
            det.x0 = std::max(0.0f, x - width / COORDINATE_PARAM);
            det.x1 = std::min(1.0f, x + width / COORDINATE_PARAM);
            det.y0 = std::max(0.0f, y - height / COORDINATE_PARAM);
            det.y1 = std::min(1.0f, y + height / COORDINATE_PARAM);
            det.classId = classID;
            det.className = configData_.GetClassName(classID);
            det.confidence = maxProb;
            detBoxes.emplace_back(det);
        }
    }
}

/*
 * @description: Decode one output layer whose per-anchor attributes are stored contiguously per cell (chosen by
 *               GenerateBbox when modelType_ != 0 and yoloVersion_ != YOLOV3_VERSION) and append every box that
 *               passes the objectness threshold, the score threshold and the per-class threshold
 * @param netout  Feature data of the layer, read as float: for anchor k and cell j, attribute a is at
 *                (k * stride + j) * (bboxDim + 1 + classNum) + a
 * @param info  Yolo layer info which contains anchor dim, box dim, class number and network input size
 * @param detBoxes  Output; accepted boxes are appended with coordinates normalized by the grid/network size
 *                  and clipped to [0, 1]
 * @param stride  Number of cells in the layer (layer.width * layer.height)
 * @param layer  Yolo output layer: grid width/height and the anchors assigned to this layer
 * Box decoding: center = (cell + sigmoid(t) * COORDINATE_PARAM - MEAN_PARAM) / grid size,
 *               size = (sigmoid(t) * COORDINATE_PARAM)^2 * anchor / network input size
 */
void Yolov3PostProcess::SelectClassNCHWC(std::shared_ptr<void> netout, NetInfo info,
                                         std::vector<MxBase::ObjectInfo> &detBoxes, int stride, OutputLayer layer)
{
    const float *data = static_cast<const float *>(netout.get());
    const int attrNum = info.bboxDim + 1 + info.classNum;  // box coordinates + objectness + class scores
    const int offsetY = 1;
    for (int j = 0; j < stride; ++j)
    {
        for (int k = 0; k < info.anchorDim; ++k)
        {
            int bIdx = attrNum * stride * k + j * attrNum;  // begin index of anchor k at cell j
            int oIdx = bIdx + info.bboxDim;                 // objectness index
            float objectness = fastmath::sigmoid(data[oIdx]);
            if (objectness <= objectnessThresh_)
            {
                continue;
            }
            // Pick the class with the highest objectness-weighted probability; it must exceed scoreThresh_
            int classID = -1;
            float maxProb = scoreThresh_;
            for (int c = 0; c < info.classNum; ++c)
            {
                float classProb = fastmath::sigmoid(data[bIdx + (info.bboxDim + OFFSETOBJECTNESS + c)]) * objectness;
                CompareProb(classID, maxProb, classProb, c);
            }
            // Reject before decoding the box so filtered candidates cost no extra work
            if (classID < 0 || maxProb < separateScoreThresh_[classID])
            {
                continue;
            }
            int row = j / layer.width;
            int col = j % layer.width;
            float x = (col + fastmath::sigmoid(data[bIdx]) * COORDINATE_PARAM - MEAN_PARAM) / layer.width;
            float y = (row + fastmath::sigmoid(data[bIdx + offsetY]) * COORDINATE_PARAM - MEAN_PARAM) / layer.height;
            // Evaluate each sigmoid once and square it; `auto` keeps the type of the original product expression
            const auto scaledW = fastmath::sigmoid(data[bIdx + OFFSETWIDTH]) * COORDINATE_PARAM;
            const auto scaledH = fastmath::sigmoid(data[bIdx + OFFSETHEIGHT]) * COORDINATE_PARAM;
            float width = scaledW * scaledW * layer.anchors[BIASESDIM * k] / info.netWidth;
            float height = scaledH * scaledH * layer.anchors[BIASESDIM * k + OFFSETBIASES] / info.netHeight;
            MxBase::ObjectInfo det;
            det.x0 = std::max(0.0f, x - width / COORDINATE_PARAM);
            det.x1 = std::min(1.0f, x + width / COORDINATE_PARAM);
            det.y0 = std::max(0.0f, y - height / COORDINATE_PARAM);
            det.y1 = std::min(1.0f, y + height / COORDINATE_PARAM);
            det.classId = classID;
            det.className = configData_.GetClassName(classID);
            det.confidence = maxProb;
            detBoxes.emplace_back(det);
        }
    }
}

/*
 * @description: Decode one channel-last (NHWC) output layer and append every box that passes the objectness
 *               threshold, the score threshold and the per-class threshold
 * @param netout  Feature data of the layer, read as float: for cell j and anchor k, attribute a is at
 *                (j * anchorDim + k) * (bboxDim + 1 + classNum) + a
 * @param info  Yolo layer info which contains anchor dim, box dim, class number and network input size
 * @param detBoxes  Output; accepted boxes are appended with coordinates normalized by the grid/network size
 *                  and clipped to [0, 1]
 * @param stride  Number of cells in the layer (layer.width * layer.height)
 * @param layer  Yolo output layer: grid width/height and the anchors assigned to this layer
 */
void Yolov3PostProcess::SelectClassNHWC(std::shared_ptr<void> netout, NetInfo info,
                                        std::vector<MxBase::ObjectInfo> &detBoxes, int stride, OutputLayer layer)
{
    const float *data = static_cast<const float *>(netout.get());
    const int attrNum = info.bboxDim + 1 + info.classNum;  // box coordinates + objectness + class scores
    const int offsetY = 1;
    for (int j = 0; j < stride; ++j)
    {
        for (int k = 0; k < info.anchorDim; ++k)
        {
            int bIdx = attrNum * info.anchorDim * j + k * attrNum;  // begin index of anchor k at cell j
            int oIdx = bIdx + info.bboxDim;                         // objectness index
            float objectness = fastmath::sigmoid(data[oIdx]);
            if (objectness <= objectnessThresh_)
            {
                continue;
            }
            // Pick the class with the highest objectness-weighted probability; it must exceed scoreThresh_
            int classID = -1;
            float maxProb = scoreThresh_;
            for (int c = 0; c < info.classNum; ++c)
            {
                float classProb = fastmath::sigmoid(data[bIdx + (info.bboxDim + OFFSETOBJECTNESS + c)]) * objectness;
                CompareProb(classID, maxProb, classProb, c);
            }
            // Reject before decoding the box so filtered candidates cost no extra work
            if (classID < 0 || maxProb < separateScoreThresh_[classID])
            {
                continue;
            }
            int row = j / layer.width;
            int col = j % layer.width;
            float x = (col + fastmath::sigmoid(data[bIdx])) / layer.width;
            float y = (row + fastmath::sigmoid(data[bIdx + offsetY])) / layer.height;
            float width = fastmath::exp(data[bIdx + OFFSETWIDTH]) * layer.anchors[BIASESDIM * k] / info.netWidth;
            float height =
                fastmath::exp(data[bIdx + OFFSETHEIGHT]) * layer.anchors[BIASESDIM * k + OFFSETBIASES] / info.netHeight;
            MxBase::ObjectInfo det;
            det.x0 = std::max(0.0f, x - width / COORDINATE_PARAM);
            det.x1 = std::min(1.0f, x + width / COORDINATE_PARAM);
            det.y0 = std::max(0.0f, y - height / COORDINATE_PARAM);
            det.y1 = std::min(1.0f, y + height / COORDINATE_PARAM);
            det.classId = classID;
            det.className = configData_.GetClassName(classID);
            det.confidence = maxProb;
            detBoxes.emplace_back(det);
        }
    }
}

/*
 * @description: Decode every YOLO output layer of one image into detBoxes, using the anchor group of that layer
 * @param featLayerData  Per-layer feature data of one image (non-owning pointers into the output tensors)
 * @param detBoxes  Output; the boxes decoded from all layers are appended
 * @param featLayerShapes  Shape of each output tensor; the grid width/height are read from it per modelType_
 * @param netWidth  Network input width; with SCALE it determines which anchor group a layer uses
 * @param netHeight  Network input height
 * Layers with an unusable shape, a grid smaller than netWidth / SCALE, or no matching anchors in biases_
 * are skipped with an error log instead of reading out of bounds.
 */
void Yolov3PostProcess::GenerateBbox(std::vector<std::shared_ptr<void>> featLayerData,
                                     std::vector<MxBase::ObjectInfo> &detBoxes,
                                     const std::vector<std::vector<size_t>> &featLayerShapes, const int netWidth,
                                     const int netHeight)
{
    // netWidth is used as an integer divisor below; zero would be undefined behavior.
    if (netWidth <= 0 || netHeight <= 0)
    {
        LogError << "Invalid network input size (" << netWidth << " x " << netHeight << "), no bbox is generated.";
        return;
    }
    NetInfo netInfo;
    netInfo.anchorDim = anchorDim_;
    netInfo.bboxDim = BOX_DIM;
    netInfo.classNum = classNum_;
    netInfo.netWidth = netWidth;
    netInfo.netHeight = netHeight;

    const size_t widthIndex = modelType_ ? NCHW_WIDTHINDEX : NHWC_WIDTHINDEX;
    const size_t heightIndex = modelType_ ? NCHW_HEIGHTINDEX : NHWC_HEIGHTINDEX;
    const int anchorValueNum = netInfo.anchorDim * BIASESDIM;  // (w, h) values of all anchors of one layer
    // Never index past the data actually provided, even if yoloType_ disagrees with it.
    const size_t layerNum =
        yoloType_ > 0 ? std::min({static_cast<size_t>(yoloType_), featLayerData.size(), featLayerShapes.size()}) : 0;

    for (size_t i = 0; i < layerNum; ++i)
    {
        const std::vector<size_t> &shape = featLayerShapes[i];
        if (shape.size() <= std::max(widthIndex, heightIndex))
        {
            LogError << "Output layer " << i << " has only " << shape.size() << " dimensions, skip it.";
            continue;
        }
        OutputLayer layer = {shape[widthIndex], shape[heightIndex]};
        // Grid width relative to netWidth / SCALE (e.g. 1, 2, 4 for 13*13, 26*26, 52*52 grids).
        // logOrder = floor(log2(ratio)) is computed with integers: the previous log(x) / log(2) formula could
        // truncate badly after rounding and hit log(0) (-inf -> int, undefined) for a too-small grid.
        const size_t ratio = shape[widthIndex] * SCALE / static_cast<size_t>(netWidth);
        if (ratio == 0)
        {
            LogError << "Output layer " << i << " width (" << shape[widthIndex]
                     << ") is too small for network width (" << netWidth << "), skip it.";
            continue;
        }
        int logOrder = 0;
        for (size_t r = ratio; r > 1; r >>= 1)
        {
            ++logOrder;
        }
        // Finer grids (larger logOrder) take earlier anchor groups from biases_.
        const int startIdx = (yoloType_ - 1 - logOrder) * anchorValueNum;
        const int endIdx = startIdx + anchorValueNum;
        if (startIdx < 0 || static_cast<size_t>(endIdx) > biases_.size())
        {
            LogError << "No anchors configured for output layer " << i << " (biases index [" << startIdx << ", "
                     << endIdx << ") of " << biases_.size() << "), skip it.";
            continue;
        }
        int idx = 0;
        for (int j = startIdx; j < endIdx; ++j)
        {
            layer.anchors[idx++] = biases_[j];
        }

        int stride = layer.width * layer.height;  // 13*13 26*26 52*52
        std::shared_ptr<void> netout = featLayerData[i];
        if (modelType_ == 0)
        {
            SelectClassNHWC(netout, netInfo, detBoxes, stride, layer);
        }
        else if (yoloVersion_ == YOLOV3_VERSION)
        {
            SelectClassNCHW(netout, netInfo, detBoxes, stride, layer);
        }
        else
        {
            SelectClassNCHWC(netout, netInfo, detBoxes, stride, layer);
        }
    }
}

/*
 * @description: Parse the comma-separated BIASES config value (anchor sizes) into biases_
 * @param strBiases  Comma-separated numbers, e.g. "10,13,16,30"; it must contain exactly biasesNum_ values.
 *                   The string is left unchanged.
 * @return APP_ERR_OK on success; APP_ERR_COMM_INVALID_PARAM when biasesNum_ is not positive, the number of values
 *         differs from biasesNum_, or a value is not a valid float (biases_ is left empty in the last two cases)
 */
APP_ERROR Yolov3PostProcess::GetBiases(std::string &strBiases)
{
    if (biasesNum_ <= 0)
    {
        LogError << GetErrorInfo(APP_ERR_COMM_INVALID_PARAM) << "Failed to get biasesNum (" << biasesNum_ << ").";
        return APP_ERR_COMM_INVALID_PARAM;
    }
    biases_.clear();

    // Split on ',' in one pass. An empty input, an empty field or a trailing comma yields an empty token,
    // which std::stof rejects below.
    std::vector<std::string> tokens;
    size_t begin = 0;
    for (size_t comma = strBiases.find(','); comma != std::string::npos; comma = strBiases.find(',', begin))
    {
        tokens.push_back(strBiases.substr(begin, comma - begin));
        begin = comma + 1;
    }
    tokens.push_back(strBiases.substr(begin));

    if (tokens.size() != static_cast<size_t>(biasesNum_))
    {
        LogError << GetErrorInfo(APP_ERR_COMM_INVALID_PARAM) << "biasesNum (" << biasesNum_
                 << ") is not equal to total number of biases (" << tokens.size() << ").";
        return APP_ERR_COMM_INVALID_PARAM;
    }
    try
    {
        for (const auto &token : tokens)
        {
            biases_.push_back(std::stof(token));
        }
    }
    catch (const std::logic_error &e)  // std::invalid_argument / std::out_of_range thrown by std::stof
    {
        biases_.clear();
        LogError << GetErrorInfo(APP_ERR_COMM_INVALID_PARAM) << "Failed to parse biases (" << strBiases
                 << "): " << e.what();
        return APP_ERR_COMM_INVALID_PARAM;
    }
    return APP_ERR_OK;
}

extern "C"
{
    std::shared_ptr<MxBase::Yolov3PostProcess> GetObjectInstance()
    {
        LogInfo << "Begin to get Yolov3PostProcess instance.";
        auto instance = std::make_shared<MxBase::Yolov3PostProcess>();
        LogInfo << "End to get Yolov3PostProcess instance.";
        return instance;
    }
}
}  // namespace MxBase