设置线性module里面的mm和bmm算子是否用ND格式。
torch_npu.npu.set_mm_bmm_format_nd(bool)
import torch import torch_npu torch_npu.npu.set_mm_bmm_format_nd(True)