@@ -2959,6 +2959,7 @@ def Torch_AtenReciprocalOp : Torch_Op<"aten.reciprocal", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
+ let hasCanonicalizer = 1;
}
def Torch_AtenReciprocal_Op : Torch_Op<"aten.reciprocal_", [
@@ -10680,7 +10681,6 @@ def Torch_AtenOnesOp : Torch_Op<"aten.ones", [
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
- let hasFolder = 1;
}
def Torch_AtenNewOnesOp : Torch_Op<"aten.new_ones", [
@@ -10736,7 +10736,6 @@ def Torch_AtenZerosOp : Torch_Op<"aten.zeros", [
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
- let hasFolder = 1;
}
def Torch_AtenNewZerosOp : Torch_Op<"aten.new_zeros", [
@@ -11514,7 +11513,6 @@ def Torch_AtenCloneOp : Torch_Op<"aten.clone", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
- let hasFolder = 1;
}
def Torch_AtenLiftFreshCopyOp : Torch_Op<"aten.lift_fresh_copy", [
@@ -13775,7 +13773,6 @@ def Torch_AtenFullOp : Torch_Op<"aten.full", [
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
- let hasFolder = 1;
}
def Torch_AtenFullLikeOp : Torch_Op<"aten.full_like", [
@@ -18524,7 +18521,6 @@ def Torch_PrimNumToTensorScalarOp : Torch_Op<"prim.NumToTensor.Scalar", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
- let hasFolder = 1;
}
def Torch_PrimMinSelfIntOp : Torch_Op<"prim.min.self_int", [
@@ -2920,17 +2920,21 @@ void AtenMaskedFillTensorOp::getCanonicalizationPatterns(
}
//
-// AtenCloneOp
+// AtenReciprocalOp
//
-OpFoldResult AtenCloneOp::fold(FoldAdaptor adaptor) {
- // note: memory_format would be ignored
- if (getSelf().getType() == getResult().getType() &&
- llvm::dyn_cast<ValueTensorType>(getSelf().getType())) {
- // self should have value semantics
- return getSelf();
- }
- return {};
+void AtenReciprocalOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add(+[](AtenReciprocalOp op, PatternRewriter &rewriter) {
+ auto sqrtOp = op.getSelf().getDefiningOp<AtenSqrtOp>();
+ if (!sqrtOp)
+ return failure();
+
+ rewriter.replaceOpWithNewOp<AtenRsqrtOp>(op, op.getType(), sqrtOp.getSelf());
+ if (sqrtOp->use_empty())
+ rewriter.eraseOp(sqrtOp);
+ return success();
+ });
}
//
@@ -4835,130 +4839,6 @@ OpFoldResult AtenItemOp::fold(FoldAdaptor adaptor) {
return nullptr;
}
-//===----------------------------------------------------------------------===//
-// AtenOnesOp, AtenZerosOp, AtenFullOp
-//===----------------------------------------------------------------------===//
-OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
- SmallVector<int64_t> sizes;
- if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
- return nullptr;
- }
-
- Type resultType = getResult().getType();
- BaseTensorType resultTensorType = dyn_cast<BaseTensorType>(resultType);
- if (!resultTensorType || !resultTensorType.hasDtype() ||
- !resultTensorType.hasSizes()) {
- return nullptr;
- }
-
- for (auto sz : sizes)
- if (sz == Torch::kUnknownSize || sz < 0)
- return nullptr;
-
- for (auto sz : resultTensorType.getSizes())
- if (sz == Torch::kUnknownSize || sz < 0)
- return nullptr;
-
- ShapedType shapedty =
- mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
- sizes, resultTensorType.getDtype());
- if (!shapedty) {
- return nullptr;
- }
- auto elementType = shapedty.getElementType();
- if (isa<IntegerType>(elementType)) {
- Attribute attribute = IntegerAttr::get(elementType, 1);
- return DenseElementsAttr::get(shapedty, attribute);
- }
- if (isa<FloatType>(elementType)) {
- Attribute attribute = FloatAttr::get(elementType, 1.0);
- return DenseElementsAttr::get(shapedty, attribute);
- }
- return nullptr;
-}
-
-OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
- SmallVector<int64_t> sizes;
- if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
- return nullptr;
- }
-
- Type resultType = getResult().getType();
- BaseTensorType resultTensorType = dyn_cast<BaseTensorType>(resultType);
- if (!resultTensorType || !resultTensorType.hasDtype() ||
- !resultTensorType.hasSizes()) {
- return nullptr;
- }
-
- for (auto sz : sizes)
- if (sz == Torch::kUnknownSize || sz < 0)
- return nullptr;
-
- for (auto sz : resultTensorType.getSizes())
- if (sz == Torch::kUnknownSize || sz < 0)
- return nullptr;
-
- ShapedType shapedty =
- mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
- sizes, resultTensorType.getDtype());
- if (!shapedty) {
- return nullptr;
- }
-
- auto elementType = shapedty.getElementType();
- if (isa<IntegerType>(elementType)) {
- Attribute attribute = IntegerAttr::get(elementType, 0);
- return DenseElementsAttr::get(shapedty, attribute);
- }
- if (isa<FloatType>(elementType)) {
- Attribute attribute = FloatAttr::get(elementType, 0.0);
- return DenseElementsAttr::get(shapedty, attribute);
- }
-
- return nullptr;
-}
-
-OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
- SmallVector<int64_t> sizes;
- if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
- return nullptr;
- }
-
- Type resultType = getResult().getType();
- BaseTensorType resultTensorType = dyn_cast<BaseTensorType>(resultType);
- if (!resultTensorType || !resultTensorType.hasDtype() ||
- !resultTensorType.hasSizes()) {
- return nullptr;
- }
-
- for (auto sz : sizes)
- if (sz == Torch::kUnknownSize || sz < 0)
- return nullptr;
-
- for (auto sz : resultTensorType.getSizes())
- if (sz == Torch::kUnknownSize || sz < 0)
- return nullptr;
-
- ShapedType shapedty = mlir::RankedTensorType::get(
- resultTensorType.getSizes(), resultTensorType.getDtype());
-
- auto elementType = shapedty.getElementType();
- if (isa<IntegerType>(elementType)) {
- int64_t value = 0;
- if (matchPattern(getFillValue(), m_TorchConstantInt(&value))) {
- Attribute attribute = IntegerAttr::get(elementType, value);
- return DenseElementsAttr::get(shapedty, attribute);
- }
- }
- if (isa<FloatType>(elementType)) {
- double value = 0.0;
- if (matchPattern(getFillValue(), m_TorchConstantFloat(&value))) {
- Attribute attribute = FloatAttr::get(elementType, value);
- return DenseElementsAttr::get(shapedty, attribute);
- }
- }
- return nullptr;
-}
//
// AtenCeilFloatOp
//
@@ -5150,29 +5030,6 @@ OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) {
// PrimNumToTensorScalarOp
//
-OpFoldResult PrimNumToTensorScalarOp::fold(FoldAdaptor adaptor) {
- Attribute a = adaptor.getA();
- auto resultTy = dyn_cast<ValueTensorType>(getType());
- if (!a)
- return {};
- if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes())
- return {};
-
- auto dty = resultTy.getDtype();
- if (auto iattr = dyn_cast<IntegerAttr>(a)) {
- a = IntegerAttr::get(dty, iattr.getInt());
- } else if (auto fattr = dyn_cast<FloatAttr>(a)) {
- a = FloatAttr::get(dty, fattr.getValueAsDouble());
- } else {
- // doesn't handle other types, like complex type
- return {};
- }
-
- auto mlirTensorType =
- RankedTensorType::get(resultTy.getSizes(), resultTy.getDtype());
- return SplatElementsAttr::get(mlirTensorType, a);
-}
-
//
// PrimMinSelfIntOp
//