Pattern ConvertAtenAbs {
  // Lowers torch.aten.abs to mfuse.abs; the operand goes through convertValue
  // and the result type through convertType.
  let root = op<torch.aten.abs>(input: Value) -> (resType: Type);
  rewrite root with {
    let src = convertValue(input);
    let lowered = op<mfuse.abs>(src) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenAddTensor {
  // Lowers torch.aten.add.Tensor to mfuse.aclnn.add, forwarding every operand
  // (converted as a range) unchanged in order.
  // NOTE(review): this presumably includes the `alpha` operand — confirm
  // mfuse.aclnn.add expects it.
  let root = op<torch.aten.add.Tensor>(operands: ValueRange) -> (resType: Type);
  rewrite root with {
    let srcs = convertValues(operands);
    let lowered = op<mfuse.aclnn.add>(srcs) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenAddScalar {
  // Lowers torch.aten.add.Scalar to the same mfuse.aclnn.add op used for the
  // Tensor overload, forwarding all operands via convertValues.
  let root = op<torch.aten.add.Scalar>(operands: ValueRange) -> (resType: Type);
  rewrite root with {
    let srcs = convertValues(operands);
    let lowered = op<mfuse.aclnn.add>(srcs) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenBmm {
  // Lowers torch.aten.bmm to mfuse.aclnn.batch_matmul with all operands
  // forwarded through convertValues.
  let root = op<torch.aten.bmm>(operands: ValueRange) -> (resType: Type);
  rewrite root with {
    let srcs = convertValues(operands);
    let lowered = op<mfuse.aclnn.batch_matmul>(srcs) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenCeil {
  // Lowers torch.aten.ceil to mfuse.ceil (single converted operand).
  let root = op<torch.aten.ceil>(input: Value) -> (resType: Type);
  rewrite root with {
    let src = convertValue(input);
    let lowered = op<mfuse.ceil>(src) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenExp {
  // Lowers single-operand torch.aten.exp to mfuse.exp.
  let root = op<torch.aten.exp>(input: Value) -> (resType: Type);
  rewrite root with {
    let src = convertValue(input);
    let lowered = op<mfuse.exp>(src) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenFloor {
  // Lowers torch.aten.floor to mfuse.floor (single converted operand).
  let root = op<torch.aten.floor>(input: Value) -> (resType: Type);
  rewrite root with {
    let src = convertValue(input);
    let lowered = op<mfuse.floor>(src) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenIsfinite {
  // Lowers torch.aten.isfinite to mfuse.is_finite (single converted operand).
  let root = op<torch.aten.isfinite>(input: Value) -> (resType: Type);
  rewrite root with {
    let src = convertValue(input);
    let lowered = op<mfuse.is_finite>(src) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenLog {
  // Lowers torch.aten.log to mfuse.log (single converted operand).
  let root = op<torch.aten.log>(input: Value) -> (resType: Type);
  rewrite root with {
    let src = convertValue(input);
    let lowered = op<mfuse.log>(src) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenExpTensor {
  // Lowers torch.aten.exp to mfuse.exp, forwarding the operand list as a range.
  // NOTE(review): this matches the same root op as ConvertAtenExp, so the two
  // patterns overlap. Consider removing one once no caller depends on this
  // pattern's generated name.
  let root = op<torch.aten.exp>(operands: ValueRange) -> (resType: Type);
  rewrite root with {
    let srcs = convertValues(operands);
    let lowered = op<mfuse.exp>(srcs) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenGelu {
  // Lowers torch.aten.gelu to mfuse.aclnn.gelu.
  // Operand 0 becomes the converted input. Operand 1 is turned into the
  // `approximate` attribute via GetStrAttr.
  let root = op<torch.aten.gelu>(inputs: ValueRange) -> (resType: Type);
  rewrite root with {
    let src = convertValue(Get0(inputs));
    let mode = GetStrAttr(Get1(inputs));
    let lowered = op<mfuse.aclnn.gelu>(src) { approximate = mode } -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenGeluBackward {
  // Lowers torch.aten.gelu_backward to mfuse.aclnn.gelu_backward.
  // Operands 0 and 1 are converted and forwarded in order. Operand 2 is
  // turned into the `approximate` attribute via GetStrAttr.
  let root = op<torch.aten.gelu_backward>(inputs: ValueRange) -> (resType: Type);
  rewrite root with {
    let first = convertValue(Get0(inputs));
    let second = convertValue(Get1(inputs));
    let mode = GetStrAttr(Get2(inputs));
    let lowered = op<mfuse.aclnn.gelu_backward>(first, second) { approximate = mode } -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenLogicalNot {
  // Lowers torch.aten.logical_not to mfuse.logical_not (single converted operand).
  let root = op<torch.aten.logical_not>(input: Value) -> (resType: Type);
  rewrite root with {
    let src = convertValue(input);
    let lowered = op<mfuse.logical_not>(src) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenMatmul {
  // Lowers torch.aten.matmul to mfuse.aclnn.matmul with all operands
  // forwarded through convertValues.
  let root = op<torch.aten.matmul>(operands: ValueRange) -> (resType: Type);
  rewrite root with {
    let srcs = convertValues(operands);
    let lowered = op<mfuse.aclnn.matmul>(srcs) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenMm {
  // Lowers torch.aten.mm to mfuse.aclnn.mm with all operands forwarded
  // through convertValues.
  let root = op<torch.aten.mm>(operands: ValueRange) -> (resType: Type);
  rewrite root with {
    let srcs = convertValues(operands);
    let lowered = op<mfuse.aclnn.mm>(srcs) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenNeg {
  // Lowers torch.aten.neg to mfuse.neg (single converted operand).
  let root = op<torch.aten.neg>(input: Value) -> (resType: Type);
  rewrite root with {
    let src = convertValue(input);
    let lowered = op<mfuse.neg>(src) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenRsqrt {
  // Lowers torch.aten.rsqrt to mfuse.rsqrt (single converted operand).
  let root = op<torch.aten.rsqrt>(input: Value) -> (resType: Type);
  rewrite root with {
    let src = convertValue(input);
    let lowered = op<mfuse.rsqrt>(src) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenReciprocal {
  // Lowers torch.aten.reciprocal to mfuse.reciprocal (single converted operand).
  let root = op<torch.aten.reciprocal>(input: Value) -> (resType: Type);
  rewrite root with {
    let src = convertValue(input);
    let lowered = op<mfuse.reciprocal>(src) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenRelu {
  // Lowers torch.aten.relu to mfuse.relu (single converted operand).
  let root = op<torch.aten.relu>(input: Value) -> (resType: Type);
  rewrite root with {
    let src = convertValue(input);
    let lowered = op<mfuse.relu>(src) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenSigmoid {
  // Lowers torch.aten.sigmoid to the aclnn-backed mfuse.aclnn.sigmoid op.
  let root = op<torch.aten.sigmoid>(input: Value) -> (resType: Type);
  rewrite root with {
    let src = convertValue(input);
    let lowered = op<mfuse.aclnn.sigmoid>(src) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenSqrt {
  // Lowers torch.aten.sqrt to mfuse.sqrt (single converted operand).
  let root = op<torch.aten.sqrt>(input: Value) -> (resType: Type);
  rewrite root with {
    let src = convertValue(input);
    let lowered = op<mfuse.sqrt>(src) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenSubTensor {
  // Lowers torch.aten.sub.Tensor to mfuse.aclnn.sub, forwarding every operand
  // (converted as a range) unchanged in order.
  // NOTE(review): this presumably includes the `alpha` operand — confirm
  // mfuse.aclnn.sub expects it.
  let root = op<torch.aten.sub.Tensor>(operands: ValueRange) -> (resType: Type);
  rewrite root with {
    let srcs = convertValues(operands);
    let lowered = op<mfuse.aclnn.sub>(srcs) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenSubScalar {
  // Lowers torch.aten.sub.Scalar to the same mfuse.aclnn.sub op used for the
  // Tensor overload, forwarding all operands via convertValues.
  let root = op<torch.aten.sub.Scalar>(operands: ValueRange) -> (resType: Type);
  rewrite root with {
    let srcs = convertValues(operands);
    let lowered = op<mfuse.aclnn.sub>(srcs) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenTanh {
  // Lowers torch.aten.tanh to the aclnn-backed mfuse.aclnn.tanh op.
  let root = op<torch.aten.tanh>(input: Value) -> (resType: Type);
  rewrite root with {
    let src = convertValue(input);
    let lowered = op<mfuse.aclnn.tanh>(src) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenToDtype {
  // Lowers torch.aten.to.dtype to mfuse.cast. Only operand 0 is forwarded;
  // the target type is taken from convertType(resType).
  // NOTE(review): the dropped operands are presumably dtype/non_blocking/
  // copy/memory_format — confirm none of them needs to be honored.
  let root = op<torch.aten.to.dtype>(inputs: ValueRange) -> (resType: Type);
  rewrite root with {
    let src = convertValue(Get0(inputs));
    let lowered = op<mfuse.cast>(src) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenTrunc {
  // Lowers torch.aten.trunc to mfuse.trunc (single converted operand).
  let root = op<torch.aten.trunc>(input: Value) -> (resType: Type);
  rewrite root with {
    let src = convertValue(input);
    let lowered = op<mfuse.trunc>(src) -> (convertType(resType));
    replace root with lowered;
  };
}
Pattern ConvertAtenWhereSelf {
  // Lowers torch.aten.where.self to mfuse.select. The three operands
  // (condition, self, other) are converted and forwarded in the same order.
  let root = op<torch.aten.where.self>(condition: Value, self: Value, other: Value) -> (resType: Type);
  rewrite root with {
    let cond = convertValue(condition);
    let lhs = convertValue(self);
    let rhs = convertValue(other);
    let lowered = op<mfuse.select>(cond, lhs, rhs) -> (convertType(resType));
    replace root with lowered;
  };
}