torch.nn

Note

若API“是否支持”为“是”,“限制与说明”为“-”,说明此API和原生API支持度保持一致。

API名称 是否支持 限制与说明
torch.nn.parameter.Parameter 支持fp32
torch.nn.parameter.UninitializedParameter -
torch.nn.parameter.UninitializedParameter.cls_to_become -
torch.nn.parameter.UninitializedBuffer -
torch.nn.Module 支持fp32
torch.nn.Module.add_module 支持fp32
torch.nn.Module.apply 支持fp32
torch.nn.Module.bfloat16 -
torch.nn.Module.buffers -
torch.nn.Module.children 支持fp32
torch.nn.Module.compile -
torch.nn.Module.cpu 支持fp32
torch.nn.Module.cuda 支持fp32
torch.nn.Module.double -
torch.nn.Module.eval 支持fp32,int64
torch.nn.Module.extra_repr 支持fp32
torch.nn.Module.float 支持fp16,fp32
torch.nn.Module.forward 支持fp32
torch.nn.Module.get_buffer -
torch.nn.Module.get_extra_state -
torch.nn.Module.get_parameter 支持fp32
torch.nn.Module.get_submodule 支持fp32
torch.nn.Module.half 支持fp16,fp32
torch.nn.Module.ipu -
torch.nn.Module.load_state_dict 支持fp32
torch.nn.Module.modules 支持fp32
torch.nn.Module.named_buffers -
torch.nn.Module.named_children 支持fp32
torch.nn.Module.named_modules 支持fp32
torch.nn.Module.named_parameters -
torch.nn.Module.parameters -
torch.nn.Module.register_backward_hook 支持fp32
torch.nn.Module.register_buffer 支持fp32
torch.nn.Module.register_forward_hook 支持fp32
torch.nn.Module.register_forward_pre_hook 支持fp32
torch.nn.Module.register_full_backward_hook 支持fp32
torch.nn.Module.register_full_backward_pre_hook 支持fp32
torch.nn.Module.register_load_state_dict_post_hook 支持fp32
torch.nn.Module.register_module 支持fp32
torch.nn.Module.register_parameter -
torch.nn.Module.register_state_dict_pre_hook -
torch.nn.Module.requires_grad_ -
torch.nn.Module.set_extra_state -
torch.nn.Module.share_memory -
torch.nn.Module.state_dict 支持fp32
torch.nn.Module.to 支持fp32
torch.nn.Module.to_empty 支持fp32
torch.nn.Module.train 支持fp32
torch.nn.Module.type 支持fp16,fp32,int64
torch.nn.Module.xpu -
torch.nn.Module.zero_grad 支持fp32
torch.nn.Sequential 支持fp32
torch.nn.Sequential.append 支持fp32
torch.nn.ModuleList 支持fp32
torch.nn.ModuleList.append 支持fp32
torch.nn.ModuleList.extend 支持fp32
torch.nn.ModuleList.insert 支持fp32
torch.nn.ModuleDict 支持fp32
torch.nn.ModuleDict.clear 支持fp32
torch.nn.ModuleDict.items 支持fp32
torch.nn.ModuleDict.keys 支持fp32
torch.nn.ModuleDict.pop 支持fp32
torch.nn.ModuleDict.update 支持fp32
torch.nn.ModuleDict.values 支持fp32
torch.nn.ParameterList 支持fp32
torch.nn.ParameterList.append 支持fp32
torch.nn.ParameterList.extend 支持fp32
torch.nn.ParameterDict 支持fp32
torch.nn.ParameterDict.clear 支持fp32
torch.nn.ParameterDict.copy 支持fp32
torch.nn.ParameterDict.fromkeys 支持fp32
torch.nn.ParameterDict.get 支持fp32
torch.nn.ParameterDict.items 支持fp32
torch.nn.ParameterDict.keys 支持fp32
torch.nn.ParameterDict.pop 支持fp32
torch.nn.ParameterDict.popitem 支持fp32
torch.nn.ParameterDict.setdefault 支持fp32
torch.nn.ParameterDict.update 支持fp32
torch.nn.ParameterDict.values 支持fp32
torch.nn.modules.module.register_module_forward_pre_hook 支持fp32
torch.nn.modules.module.register_module_forward_hook 支持fp32
torch.nn.modules.module.register_module_backward_hook 支持fp32
torch.nn.modules.module.register_module_full_backward_pre_hook -
torch.nn.modules.module.register_module_full_backward_hook 支持fp32
torch.nn.modules.module.register_module_buffer_registration_hook -
torch.nn.modules.module.register_module_module_registration_hook -
torch.nn.modules.module.register_module_parameter_registration_hook -
torch.nn.Conv1d 支持fp16,fp32
torch.nn.Conv2d 支持bf16,fp16,fp32
Atlas A2 训练系列产品默认场景下,如果频繁触发编译,建议手动设置torch.npu.config.allow_internal_format为False,控制入参不使能内部格式,避免在线编译
torch.nn.Conv3d 支持bf16,fp16,fp32
torch.nn.ConvTranspose1d 支持fp32
torch.nn.ConvTranspose2d 支持fp16,fp32
Atlas 训练系列产品/Atlas A2 训练系列产品,需手动设置torch.npu.config.allow_internal_format为False,才可支持3维输入
torch.nn.ConvTranspose3d 支持bf16,fp16,fp32
torch.nn.LazyConv1d 支持fp16,fp32
torch.nn.LazyConv1d.cls_to_become -
torch.nn.LazyConv2d 支持fp16,fp32
torch.nn.LazyConv2d.cls_to_become -
torch.nn.LazyConv3d.cls_to_become -
torch.nn.LazyConvTranspose1d 支持fp16
torch.nn.LazyConvTranspose1d.cls_to_become -
torch.nn.LazyConvTranspose2d 支持fp16,fp32
torch.nn.LazyConvTranspose2d.cls_to_become -
torch.nn.LazyConvTranspose3d.cls_to_become -
torch.nn.Unfold 支持bf16,fp16,fp32
torch.nn.Fold 支持fp16
torch.nn.MaxPool1d -
torch.nn.MaxPool2d 支持bf16,fp16,fp32
通过设置torch_npu.npu.use_compatible_impl(True),保证与社区同名接口在内存一致性上对齐
torch.nn.MaxPool3d -
torch.nn.MaxUnpool1d 支持fp16,fp32
torch.nn.MaxUnpool2d 支持fp16,fp32
torch.nn.MaxUnpool3d -
torch.nn.AvgPool1d 支持bf16,fp16,fp32
torch.nn.AvgPool2d 支持bf16,fp16,fp32
torch.nn.AvgPool3d -
torch.nn.LPPool1d 支持fp16,fp32,uint8,int8,int16,int32,int64,bool
torch.nn.LPPool2d 支持fp16,fp32,int16,int32,int64,bool
torch.nn.AdaptiveMaxPool1d -
torch.nn.AdaptiveMaxPool2d -
torch.nn.AdaptiveMaxPool3d 支持fp32,fp64
torch.nn.AdaptiveAvgPool1d 支持fp16,fp32
torch.nn.AdaptiveAvgPool2d 支持fp16,fp32
torch.nn.AdaptiveAvgPool3d 支持bf16,fp16,fp32
torch.nn.ReflectionPad1d 支持fp16,fp32
torch.nn.ReflectionPad2d 支持fp16,fp32
torch.nn.ReflectionPad3d -
torch.nn.ReplicationPad1d 支持fp16,fp32,complex64,complex128
torch.nn.ReplicationPad2d 支持fp16,fp32,complex64,complex128
torch.nn.ReplicationPad3d -
torch.nn.ZeroPad1d 支持bf16,fp16,fp32,fp64,complex64,complex128
支持2-3维
torch.nn.ZeroPad2d 可能回退至CPU执行
torch.nn.ZeroPad3d 支持bf16,fp16,fp32,fp64,complex64,complex128
支持5-6维
torch.nn.ConstantPad1d 支持int8,bool
在输入x为六维以上时可能会出现性能下降问题
torch.nn.ConstantPad2d 支持bf16,fp16,fp32,fp64,uint8,int8,int16,int32,int64,bool,complex64,complex128
在输入x为六维以上时可能会出现性能下降问题
torch.nn.ConstantPad3d 支持fp16,fp32,uint8,int8,int16,int32,int64,bool,complex64,complex128
在输入x为六维以上时可能会出现性能下降问题
torch.nn.ELU 支持bf16,fp16,fp32,fp64
torch.nn.Hardshrink 支持fp16,fp32
可能回退至CPU执行
torch.nn.Hardsigmoid 支持fp16,fp32,int32
可能回退至CPU执行
torch.nn.Hardtanh 支持bf16,fp16,fp32,fp64,uint8,int8,int16,int32,int64
torch.nn.Hardswish 支持fp16,fp32
torch.nn.LeakyReLU 支持bf16,fp16,fp32,fp64
torch.nn.LogSigmoid 支持fp16,fp32
torch.nn.MultiheadAttention 支持bf16,fp16,fp32
torch.nn.MultiheadAttention.forward 支持bf16,fp16,fp32
torch.nn.PReLU 支持fp32
torch.nn.ReLU 支持bf16,fp16,fp32,uint8,int8,int32,int64
torch.nn.ReLU6 支持bf16,fp16,fp32,fp64,uint8,int8,int16,int32,int64
torch.nn.RReLU -
torch.nn.SELU 支持fp16,fp32,fp64,uint8,int8,int16,int32,int64,bool
torch.nn.CELU 支持fp16,fp32
torch.nn.GELU 支持bf16,fp16,fp32
approximate参数仅支持设置为tanh
torch.nn.Sigmoid 支持bf16,fp16,fp32,uint8,int8,int16,int32,int64,bool,complex64,complex128
torch.nn.SiLU 支持bf16,fp16,fp32
torch.nn.Mish 支持fp16,fp32
torch.nn.Softplus 支持bf16,fp16,fp32
torch.nn.Softshrink 支持bf16,fp16,fp32
torch.nn.Softsign 支持bf16,fp16,fp32,uint8,int8,int16,int32,int64
torch.nn.Tanh 支持bf16,fp16,fp32,uint8,int8,int16,int32,int64,bool
torch.nn.Tanhshrink 支持fp16,fp32,uint8,int8,int16,int32,int64
可能回退至CPU执行
torch.nn.Threshold 支持fp16,fp32,uint8,int8,int16,int32,int64
torch.nn.GLU 支持fp16,fp32
torch.nn.Softmin 支持bf16,fp16,fp32
torch.nn.Softmax 支持bf16,fp16,fp32,fp64
torch.nn.Softmax2d 支持bf16,fp16,fp32
torch.nn.LogSoftmax 支持bf16,fp16,fp32
torch.nn.AdaptiveLogSoftmaxWithLoss -
torch.nn.AdaptiveLogSoftmaxWithLoss.log_prob -
torch.nn.AdaptiveLogSoftmaxWithLoss.predict -
torch.nn.BatchNorm1d 支持fp16,fp32
torch.nn.BatchNorm2d 支持fp16,fp32
torch.nn.BatchNorm3d 支持fp16,fp32
torch.nn.LazyBatchNorm1d.cls_to_become -
torch.nn.LazyBatchNorm2d.cls_to_become -
torch.nn.LazyBatchNorm3d.cls_to_become -
torch.nn.GroupNorm 支持fp32
eps参数需大于0
不支持jit_compile=True的场景
该API仅支持2维及以上的输入input。该API反向不支持输入input不为4维,或输入num_groups非32整除,或C轴维度非(10 * num_groups)整除的场景
torch.nn.SyncBatchNorm 支持fp16,fp32
torch.nn.SyncBatchNorm.convert_sync_batchnorm -
torch.nn.LazyInstanceNorm1d.cls_to_become -
torch.nn.LazyInstanceNorm2d.cls_to_become -
torch.nn.LazyInstanceNorm3d.cls_to_become -
torch.nn.LayerNorm 支持bf16,fp16,fp32
通过torch_npu.npu.use_compatible_impl(True),设置该接口从aclnnLayerNorm算子切换为aclnnFastLayerNorm算子,保证与社区同名接口在内存一致性上对齐。
torch.nn.RNNBase -
torch.nn.RNNBase.flatten_parameters -
torch.nn.RNN -
torch.nn.LSTM 支持fp32
不支持proj_size参数
不支持dropout参数
入参input不支持2维
torch.nn.GRU -
torch.nn.RNNCell -
torch.nn.LSTMCell 接口暂不支持jit_compile=False,需要在该模式下使用时请将"DynamicGRUV2"添加至"NPU_FUZZY_COMPILE_BLACKLIST"选项内,具体操作可参考添加二进制黑名单示例
torch.nn.GRUCell 支持fp16,fp32
torch.nn.Transformer -
torch.nn.Transformer.forward 支持fp32
torch.nn.TransformerEncoder -
torch.nn.TransformerEncoder.forward 支持fp32
torch.nn.TransformerDecoder -
torch.nn.TransformerDecoder.forward -
torch.nn.TransformerEncoderLayer.forward -
torch.nn.TransformerDecoderLayer.forward -
torch.nn.Identity 支持fp32
torch.nn.Linear 支持fp16,fp32
torch.nn.Bilinear 支持bf16,fp16,fp32
torch.nn.LazyLinear 支持fp16,fp32
torch.nn.LazyLinear.cls_to_become -
torch.nn.Dropout 支持bf16,fp16,fp32
torch.nn.Dropout2d 支持fp16,fp32,int64,bool
torch.nn.AlphaDropout 支持fp16,fp32
torch.nn.FeatureAlphaDropout 支持fp16,fp32
torch.nn.Embedding 支持int32,int64
属性max_norm不支持nan,仅支持非负值
torch.nn.Embedding.from_pretrained 支持fp64
torch.nn.EmbeddingBag 支持int32,int64
仅支持max_norm大于等于0
torch.nn.EmbeddingBag.forward 支持int64
torch.nn.EmbeddingBag.from_pretrained 支持int64
torch.nn.L1Loss 支持fp16,fp32,int64
torch.nn.MSELoss 支持fp16,fp32
torch.nn.CrossEntropyLoss 支持fp16,fp32
torch.nn.CTCLoss 支持fp32,fp64
不支持log_probs 2D输入
torch.nn.NLLLoss 支持fp16,fp32
target每一维的维度应该大于等于0且小于input的类别数
torch.nn.PoissonNLLLoss 支持bf16,fp16,fp32,int64
torch.nn.GaussianNLLLoss 支持bf16,fp16,fp32,int16,int32,int64
torch.nn.KLDivLoss 支持bf16,fp16,fp32
当前log_target参数仅支持False
torch.nn.BCELoss 支持fp16,fp32
torch.nn.BCEWithLogitsLoss 支持bf16,fp16,fp32
入参target不支持反向计算
torch.nn.MarginRankingLoss 支持bf16,fp16,fp32,int8,int32,int64
torch.nn.HingeEmbeddingLoss 支持bf16,fp16,fp32,uint8,int8,int16,int32,int64
torch.nn.MultiLabelMarginLoss -
torch.nn.HuberLoss input支持fp32,fp64
target支持bf16,fp16,fp32,fp64,uint8,int8,int16,int32,int64,bool
可能回退至CPU执行
torch.nn.SmoothL1Loss 支持bf16,fp16,fp32
torch.nn.MultiLabelSoftMarginLoss 支持fp16,fp32
torch.nn.CosineEmbeddingLoss -
torch.nn.MultiMarginLoss input支持fp32,fp64
target支持int64
可能回退至CPU执行
torch.nn.TripletMarginLoss 支持fp16,fp32,uint8,int8,int16,int32,int64
可能回退至CPU执行
torch.nn.TripletMarginWithDistanceLoss 支持bf16,fp16,fp32
torch.nn.PixelShuffle 支持bf16,fp16,fp32,fp64,uint8,int8,int16,int32,int64,bool
torch.nn.PixelUnshuffle 支持fp16,fp32,fp64,uint8,int8,int16,int32,int64,bool
torch.nn.Upsample 支持bf16,fp16,fp32,fp64
torch.nn.UpsamplingNearest2d 支持fp16,fp32,uint8
可能回退至CPU执行
torch.nn.ChannelShuffle 支持bf16,fp16,fp32,uint8,int8,int16,int32,int64,bool,complex64,complex128
torch.nn.DataParallel -
torch.nn.parallel.DistributedDataParallel -
torch.nn.parallel.DistributedDataParallel.join -
torch.nn.parallel.DistributedDataParallel.join_hook -
torch.nn.parallel.DistributedDataParallel.no_sync -
torch.nn.parallel.DistributedDataParallel.register_comm_hook -
torch.nn.utils.clip_grad_norm_ 支持bf16,fp16,fp32,fp64,uint8,int8,int16,int32,int64,bool
torch.nn.utils.clip_grad_norm -
torch.nn.utils.clip_grad_value_ 支持bf16,fp16,fp32
torch.nn.utils.vector_to_parameters 支持bf16,fp16,fp32,fp64,complex64
torch.nn.utils.weight_norm -
torch.nn.utils.spectral_norm -
torch.nn.utils.remove_spectral_norm -
torch.nn.utils.skip_init -
torch.nn.utils.prune.BasePruningMethod -
torch.nn.utils.prune.BasePruningMethod.apply -
torch.nn.utils.prune.BasePruningMethod.apply_mask 支持fp32
torch.nn.utils.prune.BasePruningMethod.compute_mask -
torch.nn.utils.prune.BasePruningMethod.prune 支持fp32
torch.nn.utils.prune.BasePruningMethod.remove 支持fp32
torch.nn.utils.prune.PruningContainer -
torch.nn.utils.prune.PruningContainer.add_pruning_method -
torch.nn.utils.prune.PruningContainer.apply -
torch.nn.utils.prune.PruningContainer.apply_mask -
torch.nn.utils.prune.PruningContainer.compute_mask 支持fp32
torch.nn.utils.prune.PruningContainer.prune 支持fp32
torch.nn.utils.prune.PruningContainer.remove 支持fp32
torch.nn.utils.prune.Identity 支持fp32
torch.nn.utils.prune.Identity.apply 支持fp32
torch.nn.utils.prune.Identity.apply_mask 支持fp32
torch.nn.utils.prune.Identity.prune 支持fp32
torch.nn.utils.prune.Identity.remove 支持fp32
torch.nn.utils.prune.RandomUnstructured 支持fp32
torch.nn.utils.prune.RandomUnstructured.apply 支持fp32
torch.nn.utils.prune.RandomUnstructured.apply_mask 支持fp32
torch.nn.utils.prune.RandomUnstructured.prune 支持fp32
torch.nn.utils.prune.RandomUnstructured.remove -
torch.nn.utils.prune.L1Unstructured 支持fp32
torch.nn.utils.prune.L1Unstructured.apply 支持fp32
torch.nn.utils.prune.L1Unstructured.apply_mask 支持fp32
torch.nn.utils.prune.L1Unstructured.prune 支持fp32
torch.nn.utils.prune.L1Unstructured.remove 支持fp32
torch.nn.utils.prune.RandomStructured 支持fp32
torch.nn.utils.prune.RandomStructured.apply 支持fp32
torch.nn.utils.prune.RandomStructured.apply_mask 支持fp32
torch.nn.utils.prune.RandomStructured.compute_mask 支持fp32
torch.nn.utils.prune.RandomStructured.prune -
torch.nn.utils.prune.RandomStructured.remove -
torch.nn.utils.prune.LnStructured 支持fp32
torch.nn.utils.prune.LnStructured.apply 支持fp32
torch.nn.utils.prune.LnStructured.apply_mask 支持fp32
torch.nn.utils.prune.LnStructured.compute_mask 支持fp32
torch.nn.utils.prune.LnStructured.prune 支持fp32
torch.nn.utils.prune.LnStructured.remove 支持fp32
torch.nn.utils.prune.CustomFromMask 支持int64
torch.nn.utils.prune.CustomFromMask.apply 支持int64
torch.nn.utils.prune.CustomFromMask.apply_mask -
torch.nn.utils.prune.CustomFromMask.prune -
torch.nn.utils.prune.CustomFromMask.remove -
torch.nn.utils.prune.random_unstructured -
torch.nn.utils.prune.l1_unstructured -
torch.nn.utils.prune.random_structured -
torch.nn.utils.prune.ln_structured -
torch.nn.utils.prune.global_unstructured -
torch.nn.utils.prune.custom_from_mask 支持int64
torch.nn.utils.prune.remove -
torch.nn.utils.prune.is_pruned -
torch.nn.utils.parametrizations.orthogonal -
torch.nn.utils.parametrizations.spectral_norm -
torch.nn.utils.parametrize.register_parametrization -
torch.nn.utils.parametrize.remove_parametrizations -
torch.nn.utils.parametrize.cached -
torch.nn.utils.parametrize.is_parametrized -
torch.nn.utils.parametrize.ParametrizationList -
torch.nn.utils.parametrize.ParametrizationList.right_inverse 支持fp32
torch.nn.utils.stateless.functional_call -
torch.nn.utils.rnn.PackedSequence 支持fp32,int64
torch.nn.utils.rnn.PackedSequence.batch_sizes -
torch.nn.utils.rnn.PackedSequence.count 支持fp32
torch.nn.utils.rnn.PackedSequence.data -
torch.nn.utils.rnn.PackedSequence.index 支持fp32
torch.nn.utils.rnn.PackedSequence.is_cuda -
torch.nn.utils.rnn.PackedSequence.is_pinned -
torch.nn.utils.rnn.PackedSequence.sorted_indices -
torch.nn.utils.rnn.PackedSequence.to 支持fp32,int64
torch.nn.utils.rnn.PackedSequence.unsorted_indices -
torch.nn.utils.rnn.pack_padded_sequence -
torch.nn.utils.rnn.pad_packed_sequence -
torch.nn.utils.rnn.pad_sequence 支持fp16,fp32
torch.nn.utils.rnn.pack_sequence -
torch.nn.utils.rnn.unpack_sequence -
torch.nn.utils.rnn.unpad_sequence -
torch.nn.Flatten 支持bf16,fp16,fp32,uint8,int8,int16,int32,int64,bool,complex64,complex128
torch.nn.Unflatten 支持fp16,fp32,fp64,uint8,int8,int16,int32,int64,bool
torch.nn.modules.lazy.LazyModuleMixin 支持fp32
torch.nn.modules.lazy.LazyModuleMixin.has_uninitialized_params 支持fp32
torch.nn.modules.lazy.LazyModuleMixin.initialize_parameters 支持fp32