| 文件 | 最后提交记录 | 最后更新时间 |
|---|---|---|
| 2 个月前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 9 天前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 2 个月前 | ||
| 6 个月前 |
TransdataOp使用说明
TransdataOp
功能
TransdataOp 函数用于创建并执行一个 transdata 操作,将输入的 ND 格式的 tensor 转换为 FRACTAL_NZ (后文简称为NZ)格式的 tensor。
函数介绍
atb::Status TransdataOp(atb::Context *contextPtr, aclrtStream stream, const atb::Tensor inTensor,
const aclDataType tensorType, atb::Tensor &outTensor, std::vector<int64_t> shape)
输入参数
| 参数名称 | 含义 |
|---|---|
contextPtr |
指向atb::Context 的指针,用于提供上下文信息 |
stream |
用于执行操作的流 |
inTensor |
待转换的 ND 格式的tensor |
tensorType |
输入 tensor 的数据类型 |
outTensor |
输出 tensor |
outTensor |
输出的FRACTAL_NZ格式的 tensor,函数执行后将被填充 |
shape |
输出 tensor的维度 |
返回值
atb::Status: 函数执行的状态码,表示操作是否成功。如果成功,返回 atb::ErrorType::NO_ERROR,否则返回相应的错误码。
GetShape
功能
GetShape 函数用于根据输入的 inShape 判断其是 NZ格式的shape 还是ND 格式的 shape,然后分别计算出两种格式下的 shape。
函数介绍
atb::Status GetShape(const aclDataType tensorType, const std::vector<int64_t> &inShape, std::vector<int64_t> &ndShape,
std::vector<int64_t> &nzShape)
输入参数
| 参数名称 | 含义 |
|---|---|
tensorType |
输入 tensor 的数据类型,不同数据类型的tensor的格式转换方式不同 |
inShape |
输入 tensor 的维度 |
ndShape |
输出的 ND 格式的 shape |
nzShape |
输出的 NZ 格式的 shape |
返回值
atb::Status: 函数执行的状态码,表示操作是否成功。如果成功,返回 atb::ErrorType::NO_ERROR,否则返回相应的错误码。
Transdata使用说明
注意:
- 使用
TransdataOp转换ND格式的tensor时,需要保证传入的tensor是ND格式,因此创建初始的tensor时,应当使用ND格式下的shape,而使用TransdataOp函数时传入最终想得到的NZ格式下tensor的shape; - 在
CreateTensorFromVector中只传入了一种格式下的shape,因此需要使用GetShape函数得到ND和NZ两种数据格式下的shape
使用TransdataOp进行ND格式tensor转NZ格式tensor的CreateTensorFromVector的内容如下:
/**
* @brief 简单封装,拷贝vector data中数据以创建tensor
* @details 用于创建outTensorType类型的tensor
* @param contextPtr context指针
* @param stream stream
* @param data 输入vector数据
* @param outTensorType 期望输出tensor数据类型
* @param format 输出tensor的格式,即NZ,ND等
* @param shape 输出tensor的shape
* @param outTensor 返回创建的tensor
*/
template <typename T>
atb::Status CreateTensorFromVector(atb::Context *contextPtr, aclrtStream stream, std::vector<T> data,
const aclDataType outTensorType, const aclFormat format, std::vector<int64_t> shape,
atb::Tensor &outTensor, const aclDataType inTensorType = ACL_DT_UNDEFINED)
{
atb::Tensor tensor;
aclDataType intermediateType;
switch (outTensorType) {
case aclDataType::ACL_FLOAT16:
case aclDataType::ACL_BF16:
case aclDataType::ACL_DOUBLE:
intermediateType = aclDataType::ACL_FLOAT;
break;
default:
intermediateType = outTensorType;
}
if (inTensorType == outTensorType && inTensorType != ACL_DT_UNDEFINED) {
intermediateType = outTensorType;
}
aclFormat tensorFormat = format;
std::vector<int64_t> ndShape, nzShape;
if (intermediateType != outTensorType && format == aclFormat::ACL_FORMAT_FRACTAL_NZ) {
tensorFormat = aclFormat::ACL_FORMAT_ND;
CHECK_STATUS(GetShape(outTensorType, shape, ndShape, nzShape));
} else {
ndShape.assign(shape.begin(), shape.end());
}
CHECK_STATUS(CreateTensor(intermediateType, tensorFormat, shape, tensor));
CHECK_STATUS(aclrtMemcpy(tensor.deviceData, tensor.dataSize, data.data(), sizeof(T) * data.size(),
ACL_MEMCPY_HOST_TO_DEVICE));
CHECK_STATUS(CreateTensor(outTensorType, aclFormat::ACL_FORMAT_ND, ndShape, outTensor));
if (intermediateType == outTensorType) {
// 原始创建的tensor类型,不需要转换
outTensor = tensor;
return atb::ErrorType::NO_ERROR;
}
CHECK_STATUS(CastOp(contextPtr, stream, tensor, outTensorType, outTensor));
if (format == aclFormat::ACL_FORMAT_FRACTAL_NZ) {
CHECK_STATUS(TransdataOp(contextPtr, stream, outTensor, outTensorType, outTensor, nzShape));
}
return atb::ErrorType::NO_ERROR;
}