//===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains cross-dialect canonicalization patterns that cannot be
// actual canonicalization patterns due to undesired additional dependencies.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/AffineCanonicalizationUtils.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::scf;

/// A simple, conservative analysis to determine if the loop is shape
/// preserving. I.e., the type of the arg-th yielded value is the same as the
/// type of the corresponding basic block argument of the loop.
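///
/// For instance (an illustrative example; the names %t, %x, and the slice
/// bounds are hypothetical), a loop whose yielded value is an insert_slice
/// into its own iter_arg is shape preserving, since insert_slice returns a
/// tensor of the same type as its destination:
///
/// ```
/// %r = scf.for ... iter_args(%arg0 = %t) -> (tensor<?x?xf32>) {
///   %s = tensor.insert_slice %x into %arg0[0, 0][1, 1][1, 1]
///       : tensor<1x1xf32> into tensor<?x?xf32>
///   scf.yield %s : tensor<?x?xf32>
/// }
/// ```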
/// Note: This function handles only simple cases. Expand as needed.
static bool isShapePreserving(ForOp forOp, int64_t arg) {
  auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
  assert(arg < static_cast<int64_t>(yieldOp.getResults().size()) &&
         "arg is out of bounds");
  Value value = yieldOp.getResults()[arg];
  while (value) {
    // The loop is trivially shape preserving if the yielded value is the
    // corresponding region iter arg itself.
    if (value == forOp.getRegionIterArgs()[arg])
      return true;
    OpResult opResult = value.dyn_cast<OpResult>();
    if (!opResult)
      return false;

    // Otherwise, follow shape-preserving producers (insert_slice dests and
    // nested shape-preserving scf.for results) back towards the iter arg.
    using tensor::InsertSliceOp;
    value =
        llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
            .template Case<InsertSliceOp>(
                [&](InsertSliceOp op) { return op.dest(); })
            .template Case<ForOp>([&](ForOp forOp) {
              return isShapePreserving(forOp, opResult.getResultNumber())
                         ? forOp.getIterOperands()[opResult.getResultNumber()]
                         : Value();
            })
            .Default([&](auto op) { return Value(); });
  }
  return false;
}

namespace {
/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///   ...
/// }
/// ```
///
/// is folded to:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
///   ...
/// }
/// ```
///
/// Note: Dim ops are folded only if it can be proven that the runtime type of
/// the iter arg does not change with loop iterations.
template <typename OpTy>
struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy dimOp,
                                PatternRewriter &rewriter) const override {
    auto blockArg = dimOp.source().template dyn_cast<BlockArgument>();
    if (!blockArg)
      return failure();
    auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
    if (!forOp)
      return failure();
    // Subtract 1 to skip over the induction variable, which is the first
    // block argument of the loop body.
    if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1))
      return failure();

    Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get();
    rewriter.updateRootInPlace(
        dimOp, [&]() { dimOp.sourceMutable().assign(initArg); });

    return success();
  }
};

/// Fold dim ops of loop results to dim ops of their respective init args.
/// E.g.:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   ...
/// }
/// %1 = tensor.dim %r, %c0 : tensor<?x?xf32>
/// ```
///
/// is folded to:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   ...
/// }
/// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
/// ```
///
/// Note: Dim ops are folded only if it can be proven that the runtime type of
/// the iter arg does not change with loop iterations.
template <typename OpTy>
struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy dimOp,
                                PatternRewriter &rewriter) const override {
    auto forOp = dimOp.source().template getDefiningOp<scf::ForOp>();
    if (!forOp)
      return failure();
    auto opResult = dimOp.source().template cast<OpResult>();
    unsigned resultNumber = opResult.getResultNumber();
    if (!isShapePreserving(forOp, resultNumber))
      return failure();
    rewriter.updateRootInPlace(dimOp, [&]() {
      dimOp.sourceMutable().assign(forOp.getIterOperands()[resultNumber]);
    });
    return success();
  }
};

/// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
/// and scf.parallel loops with a known range.
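///
/// For example (an illustrative sketch; %iv, %ub, and the map are
/// hypothetical): with lower bound 0 and step 4, if %ub is statically known
/// to be a multiple of the step (e.g. a constant 16), then %ub - %iv >= 4 on
/// every iteration, so the affine.min below always evaluates to 4 and can be
/// replaced with that constant:
///
/// ```
/// scf.for %iv = %c0 to %ub step %c4 {
///   %m = affine.min affine_map<(d0, d1) -> (4, d1 - d0)>(%iv, %ub)
///   ...
/// }
/// ```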
template <typename OpTy, bool IsMin>
struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Retrieve the bounds and step of the scf.for or scf.parallel loop that
    // owns the given induction variable, if any.
    auto loopMatcher = [](Value iv, Value &lb, Value &ub, Value &step) {
      if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) {
        lb = forOp.getLowerBound();
        ub = forOp.getUpperBound();
        step = forOp.getStep();
        return success();
      }
      if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) {
        for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) {
          if (parOp.getInductionVars()[idx] == iv) {
            lb = parOp.getLowerBound()[idx];
            ub = parOp.getUpperBound()[idx];
            step = parOp.getStep()[idx];
            return success();
          }
        }
        return failure();
      }
      return failure();
    };

    return scf::canonicalizeMinMaxOpInLoop(rewriter, op, op.getAffineMap(),
                                           op.operands(), IsMin, loopMatcher);
  }
};

struct SCFForLoopCanonicalization
    : public SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> {
  void runOnFunction() override {
    FuncOp funcOp = getFunction();
    MLIRContext *ctx = funcOp.getContext();
    RewritePatternSet patterns(ctx);
    scf::populateSCFForLoopCanonicalizationPatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

void mlir::scf::populateSCFForLoopCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns
      .insert<AffineOpSCFCanonicalizationPattern<AffineMinOp, /*IsMin=*/true>,
              AffineOpSCFCanonicalizationPattern<AffineMaxOp, /*IsMin=*/false>,
              DimOfIterArgFolder<tensor::DimOp>,
              DimOfIterArgFolder<memref::DimOp>,
              DimOfLoopResultFolder<tensor::DimOp>,
              DimOfLoopResultFolder<memref::DimOp>>(ctx);
}

std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() {
  return std::make_unique<SCFForLoopCanonicalization>();
}
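
// Note (illustrative usage, not part of the original file): assuming the
// usual pass registration in the SCF pass definitions, this pass can be
// exercised from the command line via:
//   mlir-opt -for-loop-canonicalization input.mlir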