rms_norm对外接口
前向接口
输入:
- x:必选输入,数据类型float16, bfloat16, float32
- gamma:必选输入,数据类型float16, bfloat16, float32
输出:
- rstd:中间变量,数据类型float32
- out:必选输出,数据类型float16, bfloat16, float32
属性:
- epsilon:可选属性,数据类型float,缺省1e-6。
反向接口
输入:
- dy:必选输入,数据类型float16, bfloat16, float32
- x:必选输入,数据类型float16, bfloat16, float32
- rstd:必选输入,数据类型float32
- gamma:必选输入,数据类型float16, bfloat16, float32
输出:
- dx:必选输出,数据类型float16, bfloat16, float32
- dgamma:必选输出,数据类型float32
案例
import os
import torch
import torch_npu
import numpy as np
import math
from mindspeed.op_builder import RmsNormOpBuilder
x = torch.rand([2, 4]).to(torch.float).npu()
gamma = torch.rand([4]).to(torch.float).npu()
# 正向接口案例
mindspeed_ops = RmsNormOpBuilder().load()
out = mindspeed_ops.rms_norm(x, gamma)
# 反向接口案例
out.backward(torch.ones(out.shape).npu())