* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of
* the software repository for the full text of the License.
*/
#include "aclnn_catlass_basic_matmul.h"
#include "golden.hpp"
#include "helper.hpp"
using namespace Catlass;
/**
 * Wraps an existing device buffer in a rank-2 aclTensor descriptor.
 *
 * The view shape and strides are read from the given layout (dimension 0 and 1).
 * The same shape is passed as the storage shape, and the view offset is 0.
 *
 * @param dataPtr  Device buffer the tensor will describe (not copied).
 * @param layout   Layout object providing shape(i) and stride(i) for i in {0, 1}.
 * @param dataType Element type of the tensor.
 * @param format   Memory format of the tensor.
 * @return Whatever aclCreateTensor returns for these arguments.
 */
template <class Layout>
inline aclTensor *
Create2DAclTensorFromLayoutAndDataPtr(uint8_t *dataPtr, Layout layout, aclDataType dataType, aclFormat format)
{
    constexpr uint64_t kRank = 2;
    constexpr int64_t kViewOffset = 0;

    const int64_t viewShape[kRank] = {
        static_cast<int64_t>(layout.shape(0)),
        static_cast<int64_t>(layout.shape(1)),
    };
    const int64_t viewStrides[kRank] = {
        static_cast<int64_t>(layout.stride(0)),
        static_cast<int64_t>(layout.stride(1)),
    };

    // The storage shape is identical to the view shape for this helper.
    return aclCreateTensor(viewShape, kRank, dataType, viewStrides, kViewOffset, format, viewShape, kRank,
                           dataPtr);
}
using Options = GemmOptions;

/**
 * Runs one fp16 row-major matmul C = A * B through aclnnCatlassBasicMatmul and
 * compares the device result against a host-side golden reference.
 *
 * Flow:
 *  1. Initialize ACL, select the device and create a stream.
 *  2. Fill A and B with random values between -5.0 and 5.0 and upload them.
 *  3. Build aclTensor descriptors for A, B and C.
 *  4. Query, allocate, launch, synchronize and free the operator workspace.
 *  5. Download C, compare it with golden::ComputeMatmul, and release all resources.
 *
 * Every ACL call is wrapped in ACL_CHECK (defined in helper.hpp). How a failure is
 * reported or handled is decided by that macro.
 *
 * @param options Parsed command-line options. Supplies deviceId and the (m, n, k)
 *                problem shape.
 */
static void Run(const Options &options)
{
    aclrtStream stream{nullptr};
    ACL_CHECK(aclInit(nullptr));
    ACL_CHECK(aclrtSetDevice(options.deviceId));
    ACL_CHECK(aclrtCreateStream(&stream));

    uint32_t m = options.problemShape.m();
    uint32_t n = options.problemShape.n();
    uint32_t k = options.problemShape.k();

    // Widen to size_t before multiplying so element counts do not overflow 32 bits.
    size_t lenA = static_cast<size_t>(m) * k;
    size_t lenB = static_cast<size_t>(k) * n;
    size_t lenC = static_cast<size_t>(m) * n;
    size_t sizeA = lenA * sizeof(fp16_t);
    size_t sizeB = lenB * sizeof(fp16_t);
    size_t sizeC = lenC * sizeof(fp16_t);

    using LayoutA = layout::RowMajor;
    using LayoutB = layout::RowMajor;
    using LayoutC = layout::RowMajor;
    LayoutA layoutA{m, k};
    LayoutB layoutB{k, n};
    LayoutC layoutC{m, n};

    std::vector<fp16_t> hostA(lenA);
    std::vector<fp16_t> hostB(lenB);
    golden::FillRandomData<fp16_t>(hostA, -5.0f, 5.0f);
    golden::FillRandomData<fp16_t>(hostB, -5.0f, 5.0f);

    uint8_t *deviceA{nullptr};
    ACL_CHECK(aclrtMalloc(reinterpret_cast<void **>(&deviceA), sizeA, ACL_MEM_MALLOC_HUGE_FIRST));
    ACL_CHECK(aclrtMemcpy(deviceA, sizeA, hostA.data(), sizeA, ACL_MEMCPY_HOST_TO_DEVICE));
    uint8_t *deviceB{nullptr};
    ACL_CHECK(aclrtMalloc(reinterpret_cast<void **>(&deviceB), sizeB, ACL_MEM_MALLOC_HUGE_FIRST));
    ACL_CHECK(aclrtMemcpy(deviceB, sizeB, hostB.data(), sizeB, ACL_MEMCPY_HOST_TO_DEVICE));
    uint8_t *deviceC{nullptr};
    ACL_CHECK(aclrtMalloc(reinterpret_cast<void **>(&deviceC), sizeC, ACL_MEM_MALLOC_HUGE_FIRST));

    aclTensor *tensorA = Create2DAclTensorFromLayoutAndDataPtr(deviceA, layoutA, ACL_FLOAT16, ACL_FORMAT_ND);
    aclTensor *tensorB = Create2DAclTensorFromLayoutAndDataPtr(deviceB, layoutB, ACL_FLOAT16, ACL_FORMAT_ND);
    aclTensor *tensorC = Create2DAclTensorFromLayoutAndDataPtr(deviceC, layoutC, ACL_FLOAT16, ACL_FORMAT_ND);

    size_t sizeWorkspace = 0;
    // Initialized so the handle is never indeterminate if the size query fails
    // before writing it.
    aclOpExecutor *executor{nullptr};
    ACL_CHECK(aclnnCatlassBasicMatmulGetWorkspaceSize(tensorA, tensorB, tensorC, &sizeWorkspace, &executor));
    uint8_t *deviceWorkspace = nullptr;
    if (sizeWorkspace > 0) {
        ACL_CHECK(aclrtMalloc(reinterpret_cast<void **>(&deviceWorkspace), sizeWorkspace, ACL_MEM_MALLOC_HUGE_FIRST));
    }
    ACL_CHECK(aclnnCatlassBasicMatmul(deviceWorkspace, sizeWorkspace, executor, stream));
    ACL_CHECK(aclrtSynchronizeStream(stream));
    // Free only when a workspace buffer was actually allocated above.
    if (sizeWorkspace > 0) {
        ACL_CHECK(aclrtFree(deviceWorkspace));
    }

    // The stream has been synchronized, so the descriptors created with
    // aclCreateTensor can be released. The device buffers they describe are freed
    // separately with aclrtFree below.
    ACL_CHECK(aclDestroyTensor(tensorA));
    ACL_CHECK(aclDestroyTensor(tensorB));
    ACL_CHECK(aclDestroyTensor(tensorC));

    std::vector<fp16_t> hostC(lenC);
    ACL_CHECK(aclrtMemcpy(hostC.data(), sizeC, deviceC, sizeC, ACL_MEMCPY_DEVICE_TO_HOST));

    std::vector<float> hostGolden(lenC);
    golden::ComputeMatmul(options.problemShape, hostA, layoutA, hostB, layoutB, hostGolden, layoutC);
    // NOTE(review): k is presumably used by CompareData to scale its error tolerance;
    // confirm against golden.hpp.
    std::vector<uint64_t> errorIndices = golden::CompareData(hostC, hostGolden, k);
    if (errorIndices.empty()) {
        std::cout << "Compare success." << std::endl;
    } else {
        std::cerr << "Compare failed. Error count: " << errorIndices.size() << std::endl;
    }

    ACL_CHECK(aclrtFree(deviceA));
    ACL_CHECK(aclrtFree(deviceB));
    ACL_CHECK(aclrtFree(deviceC));
    ACL_CHECK(aclrtDestroyStream(stream));
    ACL_CHECK(aclrtResetDevice(options.deviceId));
    ACL_CHECK(aclFinalize());
}
int main(int argc, const char **argv)
{
Options options;
if (options.Parse(argc, argv) != 0) {
return -1;
}
Run(options);
return 0;
}