* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef EXAMPLES_COMMON_GOLDEN_CONV2D_HPP
#define EXAMPLES_COMMON_GOLDEN_CONV2D_HPP
#include <vector>
#include "catlass/conv_coord.hpp"
#include "catlass/layout/layout.hpp"
namespace Catlass::golden {
template <class ElementFmap, class LayoutFmap, class ElementFilter, class LayoutFilter, class ElementGolden, class LayoutGolden>
void ComputeConv2d(
const Conv2dParams ¶ms,
const std::vector<ElementFmap> &dataFmap,
const LayoutFmap &layoutFmap,
const std::vector<ElementFilter> &dataFilter,
const LayoutFilter &layoutFilter,
std::vector<ElementGolden> &dataGolden,
const LayoutGolden &layoutGolden
)
{
for (uint32_t batch = 0; batch < params.batch(); batch++) {
for (uint32_t ho = 0; ho < params.ho(); ho++) {
for (uint32_t wo = 0; wo < params.wo(); wo++) {
for (uint32_t cout = 0; cout < params.cout(); cout++) {
ElementGolden accumulator = 0;
for (uint32_t kh = 0; kh < params.kh(); kh++) {
for (uint32_t kw = 0; kw < params.kw(); kw++) {
uint32_t hi = ho * params.strideH() + kh * params.dilationH() - params.padTop();
uint32_t wi = wo * params.strideW() + kw * params.dilationW() - params.padLeft();
if (hi < 0 || hi > params.hi() - 1 || wi < 0 || wi > params.wi() - 1)
continue;
for (uint32_t cin1 = 0; cin1 < params.cin1(); cin1++) {
for (uint32_t c0 = 0; c0 < params.C0; c0++) {
accumulator += static_cast<ElementGolden>(
dataFmap
[batch * params.cin1() * params.hi() * params.wi() * params.C0
+ cin1 * params.hi() * params.wi() * params.C0
+ hi * params.wi() * params.C0 + wi * params.C0 + c0]
)
* static_cast<ElementGolden>(
dataFilter
[cin1 * params.kh() * params.kw() * params.cout() * params.C0
+ kh * params.kw() * params.cout() * params.C0
+ kw * params.cout() * params.C0 + cout * params.C0 + c0]
);
}
}
}
}
uint32_t cout1 = cout / params.C0;
uint32_t c0 = cout - cout1 * params.C0;
dataGolden
[batch * params.cout1() * params.ho() * params.wo() * params.C0
+ cout1 * params.ho() * params.wo() * params.C0 + ho * params.wo() * params.C0 + wo * params.C0
+ c0] = static_cast<ElementGolden>(accumulator);
}
}
}
}
}
template <class Element>
void ClearInvalidOutput(std::vector<Element>& output, const Conv2dParams& params)
{
uint32_t c0InvalidStart = params.cout() % params.C0;
if (c0InvalidStart == 0) {
return;
}
for (uint32_t batch = 0; batch < params.batch(); batch++) {
uint32_t baseOffset = batch * params.cout1() * params.ho() * params.wo() * params.C0 +
(params.cout1() - 1) * params.ho() * params.wo() * params.C0;
for (uint32_t ho = 0; ho < params.ho(); ho++) {
for (uint32_t wo = 0; wo < params.wo(); wo++) {
for (uint32_t c0 = c0InvalidStart; c0 < params.C0; c0++) {
output[baseOffset + ho * params.wo() * params.C0 + wo * params.C0 + c0] = 0;
}
}
}
}
}
}
#endif