/**
 * This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of
* the software repository for the full text of the License.
*/
#ifndef OPTEST_TORCH_UTILS_H
#define OPTEST_TORCH_UTILS_H
#include <cstdint>
#include <string>
#include <vector>

#include <acl/acl.h>
#include <torch/torch.h>
#include <torch_npu/csrc/core/npu/DeviceUtils.h>
#include <torch_npu/csrc/core/npu/NPUFormat.h>
#include <torch_npu/csrc/core/npu/NPUFunctions.h>
#include <torch_npu/csrc/core/npu/NPUStream.h>
namespace CatlassKernelWrapper {
enum class TransposeStatus : uint32_t
{
NO_TRANSPOSE = 0,
TRANSPOSE = 1,
NON_CONTINUOUS = 2,
LAYOUT_ZN = 3,
LAYOUT_NZ = 4
};
* @brief Allocate an NPU output tensor with ND format.
* @param shape Output tensor shape.
* @param dtype Output tensor dtype.
* @return Newly allocated tensor on the active NPU device.
*/
torch::Tensor GetOutputTensor(const std::vector<int64_t>& shape, torch::Dtype dtype);
* @brief Convert a canonical dtype string to a torch dtype.
* @param typeStr Name such as "float16", "float", or "bfloat16".
* @return Matching torch dtype.
*/
torch::Dtype TypeStrToTorchDtype(const std::string& typeStr);
* @brief Convert a torch dtype to the corresponding ACL dtype.
* @param torchDtype PyTorch scalar dtype.
* @return Matching ACL dtype.
*/
aclDataType TorchDtypeToAclDtype(torch::Dtype torchDtype);
* @brief Convert an ACL dtype to the corresponding torch dtype.
* @param aclDtype ACL dtype value.
* @return Matching PyTorch scalar dtype.
*/
torch::Dtype AclDtypeToTorchDtype(aclDataType aclDtype);
* @brief Classify the logical layout of a 2D tensor view.
* @param mat Tensor whose last two dimensions are interpreted as a matrix.
* @return Whether the matrix is normal, transposed, NZ-formatted, or unsupported.
*/
TransposeStatus GetTransposeStatus(const at::Tensor& mat);
}
#endif