//===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains cross-dialect canonicalization patterns that cannot be
// actual canonicalization patterns due to undesired additional dependencies.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::scf;

namespace {
/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///   ...
/// }
/// ```
///
/// is folded to:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
///   ...
/// }
/// ```
///
/// Note: Dim ops are folded only if it can be proven that the runtime type of
/// the iter arg does not change with loop iterations.
template <typename OpTy>
struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  /// A simple, conservative analysis to determine if the loop is shape
  /// conserving. I.e., the type of the arg-th yielded value is the same as the
  /// type of the corresponding basic block argument of the loop.
  /// Note: This function handles only simple cases. Expand as needed.
  static bool isShapePreserving(ForOp forOp, int64_t arg) {
    auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
    assert(arg < static_cast<int64_t>(yieldOp.results().size()) &&
           "arg is out of bounds");
    Value value = yieldOp.results()[arg];
    while (value) {
      if (value == forOp.getRegionIterArgs()[arg])
        return true;
      OpResult opResult = value.dyn_cast<OpResult>();
      if (!opResult)
        return false;

      using tensor::InsertSliceOp;
      value =
          llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
              .template Case<InsertSliceOp>(
                  [&](InsertSliceOp op) { return op.dest(); })
              .template Case<ForOp>([&](ForOp forOp) {
                return isShapePreserving(forOp, opResult.getResultNumber())
                           ? forOp.getIterOperands()[opResult.getResultNumber()]
                           : Value();
              })
              .Default([&](auto op) { return Value(); });
    }
    return false;
  }

  LogicalResult matchAndRewrite(OpTy dimOp,
                                PatternRewriter &rewriter) const override {
    auto blockArg = dimOp.source().template dyn_cast<BlockArgument>();
    if (!blockArg)
      return failure();
    auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
    if (!forOp)
      return failure();
    if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1))
      return failure();

    Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get();
    rewriter.updateRootInPlace(
        dimOp, [&]() { dimOp.sourceMutable().assign(initArg); });

    return success();
  };
};

/// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
/// and scf.parallel loops with a known range.
template <typename OpTy, bool IsMin>
struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    auto loopMatcher = [](Value iv, Value &lb, Value &ub, Value &step) {
      if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) {
        lb = forOp.lowerBound();
        ub = forOp.upperBound();
        step = forOp.step();
        return success();
      }
      if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) {
        for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) {
          if (parOp.getInductionVars()[idx] == iv) {
            lb = parOp.lowerBound()[idx];
            ub = parOp.upperBound()[idx];
            step = parOp.step()[idx];
            return success();
          }
        }
        return failure();
      }
      return failure();
    };

    return scf::canonicalizeMinMaxOpInLoop(rewriter, op, op.getAffineMap(),
                                           op.operands(), IsMin, loopMatcher);
  }
};

struct SCFForLoopCanonicalization
    : public SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> {
  void runOnFunction() override {
    FuncOp funcOp = getFunction();
    MLIRContext *ctx = funcOp.getContext();
    RewritePatternSet patterns(ctx);
    scf::populateSCFForLoopCanonicalizationPatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

void mlir::scf::populateSCFForLoopCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns
      .insert<AffineOpSCFCanonicalizationPattern<AffineMinOp, /*IsMin=*/true>,
              AffineOpSCFCanonicalizationPattern<AffineMaxOp, /*IsMin=*/false>,
              DimOfIterArgFolder<tensor::DimOp>,
              DimOfIterArgFolder<memref::DimOp>>(ctx);
}

std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() {
  return std::make_unique<SCFForLoopCanonicalization>();
}