/*
* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef EXAMPLES_COMMON_HELPER_HPP
#define EXAMPLES_COMMON_HELPER_HPP
#ifdef ASCENDC_MODULE_OPERATOR_H
#undef inline
#endif
#include <cstdio>
#include <fstream>
#include <iostream>
#include <vector>
#include <acl/acl.h>
#include <opdev/bfloat16.h>
#include <opdev/fp16_t.h>
#include <runtime/rt_ffts.h>
#include <tiling/platform/platform_ascendc.h>
#ifdef ASCENDC_MODULE_OPERATOR_H
#define inline __inline__ __attribute__((always_inline))
#endif
#include "catlass/layout/layout.hpp"
#include "options.hpp"
using op::bfloat16;
using op::fp16_t;
// Evaluate an ACL API call exactly once and log (but do not abort on) any code other
// than ACL_ERROR_NONE, tagged with the file and line of the call site.
// The local has a deliberately unusual name and the argument is parenthesized so that
// an argument expression mentioning a caller variable (e.g. one named `error`) is not
// captured by the macro's own local, which would be a self-initialization.
#define ACL_CHECK(status)                                                                          \
    do {                                                                                           \
        const aclError aclCheckRet_ = (status);                                                    \
        if (aclCheckRet_ != ACL_ERROR_NONE) {                                                      \
            std::cerr << __FILE__ << ":" << __LINE__ << " aclError:" << aclCheckRet_ << std::endl; \
        }                                                                                          \
    } while (0)
// Evaluate a runtime (rt*) API call exactly once and log (but do not abort on) any code
// other than RT_ERROR_NONE, tagged with the file and line of the call site.
// The local has a deliberately unusual name and the argument is parenthesized so that
// an argument expression mentioning a caller variable (e.g. one named `error`) is not
// captured by the macro's own local, which would be a self-initialization.
#define RT_CHECK(status)                                                                          \
    do {                                                                                          \
        const rtError_t rtCheckRet_ = (status);                                                   \
        if (rtCheckRet_ != RT_ERROR_NONE) {                                                       \
            std::cerr << __FILE__ << ":" << __LINE__ << " rtError:" << rtCheckRet_ << std::endl;  \
        }                                                                                         \
    } while (0)
/**
 * Read the whole of a binary file into a caller-provided buffer.
 *
 * @param filePath   Path of the file to read.
 * @param buffer     Destination buffer; must not be nullptr.
 * @param bufferSize Capacity of @p buffer in bytes.
 * @return true when every byte of the file was copied into the start of @p buffer
 *         (bytes past the file size are left untouched); false, after printing a
 *         message to stdout, when @p buffer is nullptr, the file cannot be opened,
 *         its size cannot be determined, it is empty, it is larger than
 *         @p bufferSize, rewinding fails, or fewer bytes than its size could be read.
 */
inline bool ReadFile(const std::string &filePath, void *buffer, size_t bufferSize)
{
    if (buffer == nullptr) {
        printf("Read file %s failed. Buffer is nullptr.\n", filePath.c_str());
        return false;
    }
    std::ifstream fd(filePath, std::ios::binary);
    if (!fd) {
        printf("Open file failed. path = %s.\n", filePath.c_str());
        return false;
    }
    std::filebuf *buf = fd.rdbuf();
    // pubseekoff signals failure with pos_type(off_type(-1)). Check the signed offset
    // before converting, otherwise -1 wraps to SIZE_MAX and is misreported as "too large".
    const std::streamoff endPos = buf->pubseekoff(0, std::ios::end, std::ios::in);
    if (endPos < 0) {
        printf("Get file %s size failed.\n", filePath.c_str());
        return false;
    }
    const size_t size = static_cast<size_t>(endPos);
    if (size == 0) {
        printf("File %s size is 0\n", filePath.c_str());
        return false;
    }
    if (size > bufferSize) {
        printf("File %s size is larger than buffer size.\n", filePath.c_str());
        return false;
    }
    if (static_cast<std::streamoff>(buf->pubseekpos(0, std::ios::in)) != 0) {
        printf("Seek file %s failed.\n", filePath.c_str());
        return false;
    }
    // sgetn may deliver fewer characters than requested; reporting success in that
    // case would leave the tail of the expected data stale in the caller's buffer.
    const std::streamsize readSize =
        buf->sgetn(static_cast<char *>(buffer), static_cast<std::streamsize>(size));
    if (readSize != static_cast<std::streamsize>(size)) {
        printf("Read file %s failed. Expected %zu bytes but read %lld.\n", filePath.c_str(), size,
            static_cast<long long>(readSize));
        return false;
    }
    return true;
}
/**
 * Drive one kernel adapter through a complete launch on @p stream.
 *
 * Steps, in order:
 * 1. Query the workspace size for @p args.
 * 2. If it is non-zero, allocate that many bytes with aclrtMalloc (ACL_MEM_MALLOC_HUGE_FIRST).
 * 3. Call Initialize(args, workspace).
 * 4. Launch via operator()(stream, coreNum, fftsAddr).
 * 5. Block in aclrtSynchronizeStream.
 * 6. Free the workspace if one was requested.
 *
 * @tparam Adapter  Type providing Arguments, GetWorkspaceSize, Initialize and operator().
 * @param opAdapter Adapter instance (taken by value; the copy is the one launched).
 * @param args      Kernel arguments forwarded to GetWorkspaceSize and Initialize.
 * @param stream    ACL stream the kernel is launched on and synchronized with.
 * @param coreNum   Core count forwarded to the launch call.
 * @param fftsAddr  Value forwarded to the launch call; defaults to 0.
 *
 * ACL failures are reported through ACL_CHECK, which only logs, so execution continues.
 * NOTE(review): if aclrtMalloc fails, Initialize and the launch still run with a null
 * workspace pointer — confirm whether callers rely on this.
 */
template <class Adapter>
inline void RunAdapter(
    Adapter opAdapter,
    typename Adapter::Arguments args,
    aclrtStream stream,
    uint32_t coreNum,
    uint64_t fftsAddr = 0)
{
    const size_t workspaceBytes = opAdapter.GetWorkspaceSize(args);
    const bool needsWorkspace = workspaceBytes > 0;

    uint8_t *workspace = nullptr;
    if (needsWorkspace) {
        ACL_CHECK(aclrtMalloc(reinterpret_cast<void **>(&workspace), workspaceBytes, ACL_MEM_MALLOC_HUGE_FIRST));
    }

    opAdapter.Initialize(args, workspace);
    opAdapter(stream, coreNum, fftsAddr);
    ACL_CHECK(aclrtSynchronizeStream(stream));

    if (needsWorkspace) {
        ACL_CHECK(aclrtFree(workspace));
    }
}
namespace Catlass {
inline bool IsNeedPadding(layout::RowMajor layout, uint32_t align) {
if (layout.stride(0) < 65536) {
return layout.stride(0) % align != 0;
} else {
return true;
}
}
inline bool IsNeedPadding(layout::ColumnMajor layout, uint32_t align) {
if (layout.stride(1) < 65536) {
return layout.stride(1) % align != 0;
} else {
return true;
}
}
inline bool IsNeedPadding(layout::zN layout, uint32_t align) {
return false;
}
inline bool IsNeedPadding(layout::nZ layout, uint32_t align) {
return false;
}
}
#endif