import torch
from ..passes.register_pattern_to_pass import PatternBase
# Probe for the NPU backend without assuming that `torch.npu` exists.
# Accessing `torch.npu` directly raises AttributeError when torch has no such
# attribute (presumably it is registered by torch_npu -- confirm), which would
# make this module impossible to import on such hosts.
# `npu_available` is always defined so later code can rely on it.
npu_available = bool(
    hasattr(torch, "npu")
    and hasattr(torch.npu, "is_available")
    and torch.npu.is_available()
)
if npu_available:
    # NPU-only dependencies; `mindiesd` is used by the replacement function below.
    import torch_npu
    import mindiesd
def create(dtype, epsilon=1e-6):
    """Build a pattern class pairing LayerNorm + scale/shift with a fused call.

    Args:
        dtype: torch dtype used for the example inputs; it is also embedded
            in the string returned by ``name()``.
        epsilon: ``eps`` value handed to the ``torch.nn.LayerNorm`` modules
            built in both the search pattern and the replacement.

    Returns:
        The ``AdaLayerNormZeroPatternDiffusers`` class itself (not an
        instance), closed over ``dtype`` and ``epsilon``.
    """

    def _meta_tensor(*shape):
        # Example inputs are allocated on the "meta" device with the
        # requested dtype.
        return torch.empty(*shape, dtype=dtype, device="meta")

    class AdaLayerNormZeroPatternDiffusers(PatternBase):
        """Search/replace pair for a non-affine LayerNorm followed by
        ``* (1 + scale) + shift`` modulation."""

        @staticmethod
        def name():
            """Return the class name suffixed with ``-<dtype>``."""
            return f"{__class__.__name__}-{dtype}"

        @staticmethod
        def inputs():
            """Return example ``[x, scale, shift]`` tensors.

            ``x`` has shape (2, 2, 2); ``scale`` and ``shift`` have shape (2, 2).
            """
            return [_meta_tensor(2, 2, 2), _meta_tensor(2, 2), _meta_tensor(2, 2)]

        @staticmethod
        def pattern(x, scale, shift):
            """Reference computation to search for.

            Applies a LayerNorm without learnable affine parameters over the
            last dimension of ``x``, then modulates the result with ``scale``
            and ``shift`` after inserting a new axis at position 1 in each.
            """
            layer_norm = torch.nn.LayerNorm(
                x.shape[-1],
                elementwise_affine=False,
                eps=epsilon,
                dtype=x.dtype,
                device=x.device,
            )
            normalized = layer_norm(x)
            # Operation order is kept exactly as written in the reference
            # expression: scale first, then shift.
            modulated = normalized * (1 + scale[:, None])
            return modulated + shift[:, None]

        @staticmethod
        def replacement(x, scale, shift):
            """Fused equivalent built on ``mindiesd.layernorm_scale_shift``.

            Requires ``mindiesd`` to be available as a module-level name.
            """
            # NOTE(review): this LayerNorm uses the default elementwise_affine
            # setting while the search pattern passes False -- confirm that
            # this is what mindiesd.layernorm_scale_shift expects.
            layer_norm = torch.nn.LayerNorm(
                x.shape[-1], eps=epsilon, dtype=x.dtype, device=x.device
            )
            return mindiesd.layernorm_scale_shift(
                layernorm=layer_norm,
                x=x,
                scale=scale[:, None],
                shift=shift[:, None],
                fused=True,
            )

    return AdaLayerNormZeroPatternDiffusers
# Pattern classes exposed by this module: currently a single bfloat16 variant
# built with epsilon 1e-6.
AdaLayerNormPatternGroup = [
    create(
        dtype=torch.bfloat16,
        epsilon=1e-6,
    ),
]