TransData

产品支持情况

产品

是否支持

Atlas A3 训练系列产品/Atlas A3 推理系列产品

Atlas A2 训练系列产品/Atlas A2 推理系列产品

功能说明

将输入数据的排布格式转换为目标排布格式。

本接口支持的数据格式转换场景包括以下四种,除维度顺序变换外,其中涉及到C轴和N轴的拆分,具体转换方式为,C轴拆分为C1轴、C0轴,N轴拆分为N1轴、N0轴。对于位宽为16的数据类型的数据,C0和N0固定为16,C1和N1的计算公式如下。

  • 场景1:NCDHW -> NDC1HWC0

    输入Tensor {shape:[N, C, D, H, W]},输出Tensor {shape:[N, D, C/16, H, W, 16]}。请注意,C0实际上等于16,为便于展示,下图中C0被设定为2。

    图 1 NCDHW格式转为NDC1HWC0格式示意图

  • 场景2:NDC1HWC0 -> NCDHW

    输入Tensor {shape:[N, D, C/16, H, W, 16]},输出Tensor {shape:[N, C, D, H, W]}。请注意,C0实际上等于16,为便于展示,下图中C0被设定为2。

    图 2 NDC1HWC0格式转为NCDHW格式示意图

  • 场景3:NCDHW -> FRACTAL_Z_3D

    输入Tensor {shape:[N, C, D, H, W]},输出Tensor {shape:[D, C/16, H, W, N/16, 16, 16]}。请注意,C0和N0实际上等于16,为便于展示,下图中C0和N0被设定为2。

    图 3 NCDHW格式转为FRACTAL_Z_3D格式示意图

  • 场景4:FRACTAL_Z_3D -> NCDHW

    输入Tensor {shape:[D, C/16, H, W, N/16, 16, 16]},输出Tensor {shape:[N, C, D, H, W]}。请注意,C0和N0实际上等于16,为便于展示,下图中C0和N0被设定为2。

    图 4 FRACTAL_Z_3D格式转为NCDHW格式示意图

函数原型

  • 通过sharedTmpBuffer入参传入临时空间

    template <const TransDataConfig& config, typename T, typename U, typename S>
    __aicore__ inline void TransData(const LocalTensor<T>& dstTensor, const LocalTensor<T>& srcTensor, const LocalTensor<uint8_t>& sharedTmpBuffer, const TransDataParams<U, S>& params)
    
  • 接口框架申请临时空间

    template <const TransDataConfig& config, typename T, typename U, typename S>
    __aicore__ inline void TransData(const LocalTensor<T>& dstTensor, const LocalTensor<T>& srcTensor, const TransDataParams<U, S>& params)
    

由于该接口的内部实现中,需要额外的临时空间来存储计算过程中的中间变量。临时空间支持开发者通过sharedTmpBuffer入参传入接口框架申请两种方式。

  • 通过sharedTmpBuffer入参传入,使用该tensor作为临时空间进行处理,接口框架不再申请。该方式开发者可以自行管理sharedTmpBuffer内存空间,并在接口调用完成后,复用该部分内存,内存不会反复申请释放,灵活性较高,内存利用率也较高。
  • 接口框架申请临时空间,开发者无需申请,但是需要预留临时空间的大小。

通过sharedTmpBuffer传入的情况,开发者需要为tensor申请空间;接口框架申请的方式,开发者需要预留临时空间。临时空间大小BufferSize的获取方式如下:通过GetTransDataMaxMinTmpSize中提供的接口获取需要预留空间范围的大小。

参数说明

表 1 模板参数说明

参数名

描述

config

指定数据格式转换的场景。当前支持的转换场景有如下四种:NCDHW -> NDC1HWC0、NDC1HWC0 -> NCDHW、NCDHW -> FRACTAL_Z_3D、FRACTAL_Z_3D -> NCDHW。该参数为TransDataConfig类型,具体定义如下。

struct TransDataConfig {
    DataFormat srcFormat;
    DataFormat dstFormat;
};

enum class DataFormat : uint8_t { ND = 0, NZ, NCHW, NC1HWC0, NHWC, NCDHW, NDC1HWC0, FRACTAL_Z_3D, };

配置示例如下。

constexpr AscendC::TransDataConfig config1 = {AscendC::DataFormat::NCDHW, AscendC::DataFormat::FRACTAL_Z_3D};

T

操作数的数据类型。

Atlas A3 训练系列产品/Atlas A3 推理系列产品,支持的数据类型为:int16_t、uint16_t、half、bfloat16_t。

Atlas A2 训练系列产品/Atlas A2 推理系列产品,支持的数据类型为:int16_t、uint16_t、half、bfloat16_t。

U

源操作数的Shape信息,Layout类型。

AscendC::Layout ncdhwLayout = AscendC::MakeLayout(AscendC::MakeShape(n, c, d, h, w), AscendC::MakeStride());

S

目的操作数的Shape信息,Layout类型。

AscendC::Layout fractalzLayout = AscendC::MakeLayout(AscendC::MakeShape(d, c1, h, w, n1, n0, c0), AscendC::MakeStride());

表 2 接口参数说明

参数名

输入/输出

描述

dstTensor

输出

目的操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

srcTensor

输入

源操作数。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

源操作数的数据类型需要与目的操作数保持一致。

sharedTmpBuffer

输入

临时缓存。

类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。

用于TransData内部复杂计算时存储中间变量,由开发者提供。

临时空间大小BufferSize的获取方式请参考GetTransDataMaxMinTmpSize

params

输入

源操作数和目的操作数的Shape信息。该参数为TransDataParams数据类型,具体定义如下,其中模板参数T、U必须为Layout类型。该参数指定的Shape维度必须与config中Format对应维度保持一致。

template <typename T, typename U>
struct TransDataParams {
    T srcLayout;
    U dstLayout;
};

配置示例如下。

AscendC::Layout ncdhwLayout = AscendC::MakeLayout(AscendC::MakeShape(n, c, d, h, w), AscendC::MakeStride());
AscendC::Layout fractalzLayout = AscendC::MakeLayout(AscendC::MakeShape(d, c1, h, w, n1, n0, c0), AscendC::MakeStride());
AscendC::TransDataParams<decltype(ncdhwLayout), decltype(fractalzLayout)> params = {ncdhwLayout, fractalzLayout};

返回值说明

约束说明

  • 操作数地址对齐要求请参见通用地址对齐约束
  • 不支持源操作数与目的操作数地址重叠。
  • 不支持sharedTmpBuffer与源操作数和目的操作数地址重叠。
  • 对于NCDHW格式的输入,如果H轴和W轴合并后的轴不是32字节对齐,则在调用此接口前,用户需要在合并后的轴上填充数据,使其达到32字节对齐。调用此接口时,在指定Shape信息的参数处,应传入原始Shape,即合轴前的Shape。例如,如果输入的原始Shape是[1, 16, 2, 3, 5],则用户需要将输入数据填充至Shape [1, 16, 2, 16],填充的数据为无效数据。
  • 对于NCDHW格式的输出,接口实现将H轴和W轴合并,并在合并后的轴上填充数据以达到32字节对齐;调用此接口时,在指定Shape信息的参数处,应传入原始Shape,即合并轴前的Shape。例如,如果原始NCDHW格式的目标Shape为[1, 16, 2, 3, 5],则实际输出Shape为[1, 16, 2, 16]的数据,其中接口填充的数据为无效数据。

调用示例

AscendC::LocalTensor<half> dstLocal = outQueue.AllocTensor<half>();
AscendC::LocalTensor<half> srcLocal = inQueue.DeQue<half>();
AscendC::LocalTensor<uint8_t> tmp = tbuf.Get<uint8_t>();
// 构造Layout方式
AscendC::Layout ncdhwLayout = AscendC::MakeLayout(AscendC::MakeShape(1, 32, 2, 2, 8), AscendC::MakeStride());
AscendC::Layout ndc1hwc0Layout = AscendC::MakeLayout(AscendC::MakeShape(1, 2, 2, 2, 8, 16), AscendC::MakeStride());
static constexpr AscendC::TransDataConfig config = {DataFormat::NCDHW, DataFormat::NDC1HWC0};
AscendC::TransDataParams<decltype(ncdhwLayout), decltype(ndc1hwc0Layout)> params = {ncdhwLayout, ndc1hwc0Layout};
AscendC::TransData<config>(dstLocal, srcLocal, tmp, params);

结果示例如下:

输入、输出的数据类型为half
输入数据(src): 
[[[[[  0   1   2   3   4   5   6   7]
    [  8   9  10  11  12  13  14  15]]]

  [[[ 16  17  18  19  20  21  22  23]
    [ 24  25  26  27  28  29  30  31]]]

  [[[ 32  33  34  35  36  37  38  39]
    [ 40  41  42  43  44  45  46  47]]]

  [[[ 48  49  50  51  52  53  54  55]
    [ 56  57  58  59  60  61  62  63]]]

  [[[ 64  65  66  67  68  69  70  71]
    [ 72  73  74  75  76  77  78  79]]]

  [[[ 80  81  82  83  84  85  86  87]
    [ 88  89  90  91  92  93  94  95]]]

  [[[ 96  97  98  99 100 101 102 103]
    [104 105 106 107 108 109 110 111]]]

  [[[112 113 114 115 116 117 118 119]
    [120 121 122 123 124 125 126 127]]]

  [[[128 129 130 131 132 133 134 135]
    [136 137 138 139 140 141 142 143]]]

  [[[144 145 146 147 148 149 150 151]
    [152 153 154 155 156 157 158 159]]]

  [[[160 161 162 163 164 165 166 167]
    [168 169 170 171 172 173 174 175]]]

  [[[176 177 178 179 180 181 182 183]
    [184 185 186 187 188 189 190 191]]]

  [[[192 193 194 195 196 197 198 199]
    [200 201 202 203 204 205 206 207]]]

  [[[208 209 210 211 212 213 214 215]
    [216 217 218 219 220 221 222 223]]]

  [[[224 225 226 227 228 229 230 231]
    [232 233 234 235 236 237 238 239]]]

  [[[240 241 242 243 244 245 246 247]
    [248 249 250 251 252 253 254 255]]]

  [[[256 257 258 259 260 261 262 263]
    [264 265 266 267 268 269 270 271]]]

  [[[272 273 274 275 276 277 278 279]
    [280 281 282 283 284 285 286 287]]]

  [[[288 289 290 291 292 293 294 295]
    [296 297 298 299 300 301 302 303]]]

  [[[304 305 306 307 308 309 310 311]
    [312 313 314 315 316 317 318 319]]]

  [[[320 321 322 323 324 325 326 327]
    [328 329 330 331 332 333 334 335]]]

  [[[336 337 338 339 340 341 342 343]
    [344 345 346 347 348 349 350 351]]]

  [[[352 353 354 355 356 357 358 359]
    [360 361 362 363 364 365 366 367]]]

  [[[368 369 370 371 372 373 374 375]
    [376 377 378 379 380 381 382 383]]]

  [[[384 385 386 387 388 389 390 391]
    [392 393 394 395 396 397 398 399]]]

  [[[400 401 402 403 404 405 406 407]
    [408 409 410 411 412 413 414 415]]]

  [[[416 417 418 419 420 421 422 423]
    [424 425 426 427 428 429 430 431]]]

  [[[432 433 434 435 436 437 438 439]
    [440 441 442 443 444 445 446 447]]]

  [[[448 449 450 451 452 453 454 455]
    [456 457 458 459 460 461 462 463]]]

  [[[464 465 466 467 468 469 470 471]
    [472 473 474 475 476 477 478 479]]]

  [[[480 481 482 483 484 485 486 487]
    [488 489 490 491 492 493 494 495]]]

  [[[496 497 498 499 500 501 502 503]
    [504 505 506 507 508 509 510 511]]]]]

输入config:{DataFormat::NCDHW, DataFormat::NDC1HWC0}
输入params:{(1, 32, 2, 2, 8), (1, 2, 2, 2, 8, 16)}

输出数据(dst): 
[[[[[[  0  16  32  48  64  80  96 112 128 144 160 176
      192 208 224 240]
     [  1  17  33  49  65  81  97 113 129 145 161 177
      193 209 225 241]
     [  2  18  34  50  66  82  98 114 130 146 162 178
      194 210 226 242]
     [  3  19  35  51  67  83  99 115 131 147 163 179
      195 211 227 243]]

    [[  4  20  36  52  68  84 100 116 132 148 164 180
      196 212 228 244]
     [  5  21  37  53  69  85 101 117 133 149 165 181
      197 213 229 245]
     [  6  22  38  54  70  86 102 118 134 150 166 182
      198 214 230 246]
     [  7  23  39  55  71  87 103 119 135 151 167 183
      199 215 231 247]]

    [[  8  24  40  56  72  88 104 120 136 152 168 184
      200 216 232 248]
     [  9  25  41  57  73  89 105 121 137 153 169 185
      201 217 233 249]
     [ 10  26  42  58  74  90 106 122 138 154 170 186
      202 218 234 250]
     [ 11  27  43  59  75  91 107 123 139 155 171 187
      203 219 235 251]]

    [[ 12  28  44  60  76  92 108 124 140 156 172 188
      204 220 236 252]
     [ 13  29  45  61  77  93 109 125 141 157 173 189
      205 221 237 253]
     [ 14  30  46  62  78  94 110 126 142 158 174 190
      206 222 238 254]
     [ 15  31  47  63  79  95 111 127 143 159 175 191
      207 223 239 255]]]


   [[[256 272 288 304 320 336 352 368 384 400 416 432
      448 464 480 496]
     [257 273 289 305 321 337 353 369 385 401 417 433
      449 465 481 497]
     [258 274 290 306 322 338 354 370 386 402 418 434
      450 466 482 498]
     [259 275 291 307 323 339 355 371 387 403 419 435
      451 467 483 499]]

    [[260 276 292 308 324 340 356 372 388 404 420 436
      452 468 484 500]
     [261 277 293 309 325 341 357 373 389 405 421 437
      453 469 485 501]
     [262 278 294 310 326 342 358 374 390 406 422 438
      454 470 486 502]
     [263 279 295 311 327 343 359 375 391 407 423 439
      455 471 487 503]]

    [[264 280 296 312 328 344 360 376 392 408 424 440
      456 472 488 504]
     [265 281 297 313 329 345 361 377 393 409 425 441
      457 473 489 505]
     [266 282 298 314 330 346 362 378 394 410 426 442
      458 474 490 506]
     [267 283 299 315 331 347 363 379 395 411 427 443
      459 475 491 507]]

    [[268 284 300 316 332 348 364 380 396 412 428 444
      460 476 492 508]
     [269 285 301 317 333 349 365 381 397 413 429 445
      461 477 493 509]
     [270 286 302 318 334 350 366 382 398 414 430 446
      462 478 494 510]
     [271 287 303 319 335 351 367 383 399 415 431 447
      463 479 495 511]]]]]]