//===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains cross-dialect canonicalization patterns that cannot be
// registered as actual canonicalization patterns because doing so would add
// undesired dependencies between dialects.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::scf;

/// A simple, conservative analysis to determine whether the loop is shape
/// preserving for its `arg`-th loop-carried value, i.e., whether the runtime
/// shape of the `arg`-th yielded value is guaranteed to equal the runtime shape
/// of the corresponding region iter_arg. This is proven by tracing the yielded
/// value back to that iter_arg through tensor.insert_slice destinations and
/// through results of nested loops that are themselves shape preserving.
/// Note: This function handles only simple cases. Expand as needed.
32 static bool isShapePreserving(ForOp forOp, int64_t arg) { 33 auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator()); 34 assert(arg < static_cast<int64_t>(yieldOp.results().size()) && 35 "arg is out of bounds"); 36 Value value = yieldOp.results()[arg]; 37 while (value) { 38 if (value == forOp.getRegionIterArgs()[arg]) 39 return true; 40 OpResult opResult = value.dyn_cast<OpResult>(); 41 if (!opResult) 42 return false; 43 44 using tensor::InsertSliceOp; 45 value = 46 llvm::TypeSwitch<Operation *, Value>(opResult.getOwner()) 47 .template Case<InsertSliceOp>( 48 [&](InsertSliceOp op) { return op.dest(); }) 49 .template Case<ForOp>([&](ForOp forOp) { 50 return isShapePreserving(forOp, opResult.getResultNumber()) 51 ? forOp.getIterOperands()[opResult.getResultNumber()] 52 : Value(); 53 }) 54 .Default([&](auto op) { return Value(); }); 55 } 56 return false; 57 } 58 59 namespace { 60 /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.: 61 /// 62 /// ``` 63 /// %0 = ... : tensor<?x?xf32> 64 /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) { 65 /// %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32> 66 /// ... 67 /// } 68 /// ``` 69 /// 70 /// is folded to: 71 /// 72 /// ``` 73 /// %0 = ... : tensor<?x?xf32> 74 /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) { 75 /// %1 = tensor.dim %0, %c0 : tensor<?x?xf32> 76 /// ... 77 /// } 78 /// ``` 79 /// 80 /// Note: Dim ops are folded only if it can be proven that the runtime type of 81 /// the iter arg does not change with loop iterations. 
82 template <typename OpTy> 83 struct DimOfIterArgFolder : public OpRewritePattern<OpTy> { 84 using OpRewritePattern<OpTy>::OpRewritePattern; 85 86 LogicalResult matchAndRewrite(OpTy dimOp, 87 PatternRewriter &rewriter) const override { 88 auto blockArg = dimOp.source().template dyn_cast<BlockArgument>(); 89 if (!blockArg) 90 return failure(); 91 auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp()); 92 if (!forOp) 93 return failure(); 94 if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1)) 95 return failure(); 96 97 Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get(); 98 rewriter.updateRootInPlace( 99 dimOp, [&]() { dimOp.sourceMutable().assign(initArg); }); 100 101 return success(); 102 }; 103 }; 104 105 /// Fold dim ops of loop results to dim ops of their respective init args. E.g.: 106 /// 107 /// ``` 108 /// %0 = ... : tensor<?x?xf32> 109 /// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) { 110 /// ... 111 /// } 112 /// %1 = tensor.dim %r, %c0 : tensor<?x?xf32> 113 /// ``` 114 /// 115 /// is folded to: 116 /// 117 /// ``` 118 /// %0 = ... : tensor<?x?xf32> 119 /// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) { 120 /// ... 121 /// } 122 /// %1 = tensor.dim %0, %c0 : tensor<?x?xf32> 123 /// ``` 124 /// 125 /// Note: Dim ops are folded only if it can be proven that the runtime type of 126 /// the iter arg does not change with loop iterations. 
127 template <typename OpTy> 128 struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> { 129 using OpRewritePattern<OpTy>::OpRewritePattern; 130 131 LogicalResult matchAndRewrite(OpTy dimOp, 132 PatternRewriter &rewriter) const override { 133 auto forOp = dimOp.source().template getDefiningOp<scf::ForOp>(); 134 if (!forOp) 135 return failure(); 136 auto opResult = dimOp.source().template cast<OpResult>(); 137 unsigned resultNumber = opResult.getResultNumber(); 138 if (!isShapePreserving(forOp, resultNumber)) 139 return failure(); 140 rewriter.updateRootInPlace(dimOp, [&](){ 141 dimOp.sourceMutable().assign(forOp.getIterOperands()[resultNumber]); 142 }); 143 return success(); 144 } 145 }; 146 147 /// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for 148 /// and scf.parallel loops with a known range. 149 template <typename OpTy, bool IsMin> 150 struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> { 151 using OpRewritePattern<OpTy>::OpRewritePattern; 152 153 LogicalResult matchAndRewrite(OpTy op, 154 PatternRewriter &rewriter) const override { 155 auto loopMatcher = [](Value iv, Value &lb, Value &ub, Value &step) { 156 if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) { 157 lb = forOp.lowerBound(); 158 ub = forOp.upperBound(); 159 step = forOp.step(); 160 return success(); 161 } 162 if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) { 163 for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) { 164 if (parOp.getInductionVars()[idx] == iv) { 165 lb = parOp.lowerBound()[idx]; 166 ub = parOp.upperBound()[idx]; 167 step = parOp.step()[idx]; 168 return success(); 169 } 170 } 171 return failure(); 172 } 173 return failure(); 174 }; 175 176 return scf::canonicalizeMinMaxOpInLoop(rewriter, op, op.getAffineMap(), 177 op.operands(), IsMin, loopMatcher); 178 } 179 }; 180 181 struct SCFForLoopCanonicalization 182 : public SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> { 183 void 
runOnFunction() override { 184 FuncOp funcOp = getFunction(); 185 MLIRContext *ctx = funcOp.getContext(); 186 RewritePatternSet patterns(ctx); 187 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 188 if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) 189 signalPassFailure(); 190 } 191 }; 192 } // namespace 193 194 void mlir::scf::populateSCFForLoopCanonicalizationPatterns( 195 RewritePatternSet &patterns) { 196 MLIRContext *ctx = patterns.getContext(); 197 patterns 198 .insert<AffineOpSCFCanonicalizationPattern<AffineMinOp, /*IsMin=*/true>, 199 AffineOpSCFCanonicalizationPattern<AffineMaxOp, /*IsMin=*/false>, 200 DimOfIterArgFolder<tensor::DimOp>, 201 DimOfIterArgFolder<memref::DimOp>, 202 DimOfLoopResultFolder<tensor::DimOp>, 203 DimOfLoopResultFolder<memref::DimOp>>(ctx); 204 } 205 206 std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() { 207 return std::make_unique<SCFForLoopCanonicalization>(); 208 } 209