(beta)torch_npu.contrib.module.ChannelShuffle

产品支持情况

产品 是否支持
Atlas A3 训练系列产品
Atlas A2 训练系列产品
Atlas 推理系列产品
Atlas 训练系列产品

功能说明

  • API功能:应用NPU兼容的通道shuffle操作。

  • 等价计算逻辑:

    split_shuffle=False场景可使用cpu_channel_shuffle等价替换torch_npu.contrib.module.ChannelShuffle,两者计算逻辑一致。

    import torch
    def cpu_channel_shuffle(x, groups, split_shuffle):
        # cpu仅支持 split_shuffle=False场景
        batchsize, num_channels, height, width = x.size()
        channels_per_group = num_channels // groups
        x.requires_grad_(True)
        # reshape
        x = x.view(batchsize, groups, channels_per_group, height, width)
      
        x = torch.transpose(x, 1, 2).contiguous()
      
        # flatten
        x = x.view(batchsize, -1, height, width)
        output = x.view(batchsize, -1, height, width)
        return output
    

函数原型

torch_npu.contrib.module.ChannelShuffle(in_channels, groups=2, split_shuffle=True)

参数说明

计算参数

  • in_channels (int):必选参数。输入张量中的通道总数。
  • groups (int):可选参数。shuffle组数。默认值为2。
  • split_shuffle (bool):可选参数。shuffle后是否执行chunk操作。默认值为True。

计算输入

  • x1 (Tensor):输入张量。 shape为(N,Cin,∗)(N, C_{in}, *)
  • x2 (Tensor):输入张量。 shape为(N,Cin,∗)(N, C_{in}, *)

返回值说明

  • out1 (Tensor):输出张量。 shape为(N,Cout,∗)(N, C_{out}, *)
  • out2 (Tensor):输出张量。 shape为(N,Cout,∗)(N, C_{out}, *)

约束说明

只实现了groups为2场景,请自行修改其他groups场景。

调用示例

>>> import torch, torch_npu
>>> from torch_npu.contrib.module import ChannelShuffle
>>> x1 = torch.randn(2, 32, 7, 7).npu()
>>> x2 = torch.randn(2, 32, 7, 7).npu()
>>> m = ChannelShuffle(64, split_shuffle=True)
>>> out1, out2 = m(x1, x2)
>>> out1.shape
torch.Size([2, 32, 7, 7])
>>> out2.shape
torch.Size([2, 32, 7, 7])