//===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains cross-dialect canonicalization patterns that cannot be
// actual canonicalization patterns due to undesired additional dependencies.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::scf;

/// A simple, conservative analysis to determine if the loop is shape
/// preserving. I.e., the type of the arg-th yielded value is the same as the
/// type of the corresponding basic block argument of the loop.
/// Note: This function handles only simple cases. Expand as needed.
static bool isShapePreserving(ForOp forOp, int64_t arg) {
  auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
  assert(arg < static_cast<int64_t>(yieldOp.getResults().size()) &&
         "arg is out of bounds");
  Value value = yieldOp.getResults()[arg];
  while (value) {
    if (value == forOp.getRegionIterArgs()[arg])
      return true;
    OpResult opResult = value.dyn_cast<OpResult>();
    if (!opResult)
      return false;

    using tensor::InsertSliceOp;
    value =
        llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
            .template Case<InsertSliceOp>(
                [&](InsertSliceOp op) { return op.getDest(); })
            .template Case<ForOp>([&](ForOp forOp) {
              return isShapePreserving(forOp, opResult.getResultNumber())
                         ? forOp.getIterOperands()[opResult.getResultNumber()]
                         : Value();
            })
            .Default([&](auto op) { return Value(); });
  }
  return false;
}

namespace {
/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///   ...
/// }
/// ```
///
/// is folded to:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
///   ...
/// }
/// ```
///
/// Note: Dim ops are folded only if it can be proven that the runtime type of
/// the iter arg does not change with loop iterations.
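///
/// A hypothetical illustration (not from the original file) of a body that
/// the shape-preservation analysis above accepts: the yielded value traces
/// back to the iter arg through the destination of a tensor.insert_slice
/// (%slice is a made-up value of type tensor<1x1xf32>):
///
/// ```
/// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   %2 = tensor.insert_slice %slice into %arg0[0, 0] [1, 1] [1, 1]
///       : tensor<1x1xf32> into tensor<?x?xf32>
///   scf.yield %2 : tensor<?x?xf32>
/// }
/// ```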
template <typename OpTy>
struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy dimOp,
                                PatternRewriter &rewriter) const override {
    auto blockArg = dimOp.getSource().template dyn_cast<BlockArgument>();
    if (!blockArg)
      return failure();
    auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
    if (!forOp)
      return failure();
    // Region iter args start at block arg #1; arg #0 is the induction var.
    if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1))
      return failure();

    Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get();
    rewriter.updateRootInPlace(
        dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });

    return success();
  }
};

/// Fold dim ops of loop results to dim ops of their respective init args.
/// E.g.:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   ...
/// }
/// %1 = tensor.dim %r, %c0 : tensor<?x?xf32>
/// ```
///
/// is folded to:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   ...
/// }
/// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
/// ```
///
/// Note: Dim ops are folded only if it can be proven that the runtime type of
/// the corresponding iter arg does not change with loop iterations.
template <typename OpTy>
struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy dimOp,
                                PatternRewriter &rewriter) const override {
    auto forOp = dimOp.getSource().template getDefiningOp<scf::ForOp>();
    if (!forOp)
      return failure();
    auto opResult = dimOp.getSource().template cast<OpResult>();
    unsigned resultNumber = opResult.getResultNumber();
    if (!isShapePreserving(forOp, resultNumber))
      return failure();
    rewriter.updateRootInPlace(dimOp, [&]() {
      dimOp.getSourceMutable().assign(forOp.getIterOperands()[resultNumber]);
    });
    return success();
  }
};

/// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for,
/// scf.parallel, and scf.foreach_thread loops with a known range.
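///
/// A sketch of the rewrite on hypothetical IR (not from the original file):
/// with lb = 0, ub = 16, and step = 4, the induction variable takes the
/// values 0, 4, 8, and 12, so -d0 + 16 is always >= 4 and the min folds to
/// the constant 4:
///
/// ```
/// scf.for %i = %c0 to %c16 step %c4 {
///   %m = affine.min affine_map<(d0) -> (4, -d0 + 16)>(%i)
///   ...
/// }
/// ```
///
/// canonicalizes %m to %c4.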
template <typename OpTy, bool IsMin>
struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    auto loopMatcher = [](Value iv, OpFoldResult &lb, OpFoldResult &ub,
                          OpFoldResult &step) {
      if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) {
        lb = forOp.getLowerBound();
        ub = forOp.getUpperBound();
        step = forOp.getStep();
        return success();
      }
      if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) {
        for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) {
          if (parOp.getInductionVars()[idx] == iv) {
            lb = parOp.getLowerBound()[idx];
            ub = parOp.getUpperBound()[idx];
            step = parOp.getStep()[idx];
            return success();
          }
        }
        return failure();
      }
      if (scf::ForeachThreadOp foreachThreadOp =
              scf::getForeachThreadOpThreadIndexOwner(iv)) {
        for (int64_t idx = 0; idx < foreachThreadOp.getRank(); ++idx) {
          if (foreachThreadOp.getThreadIndices()[idx] == iv) {
            lb = OpBuilder(iv.getContext()).getIndexAttr(0);
            ub = foreachThreadOp.getNumThreads()[idx];
            step = OpBuilder(iv.getContext()).getIndexAttr(1);
            return success();
          }
        }
        return failure();
      }
      return failure();
    };

    return scf::canonicalizeMinMaxOpInLoop(rewriter, op, op.getAffineMap(),
                                           op.operands(), IsMin, loopMatcher);
  }
};

struct SCFForLoopCanonicalization
    : public SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> {
  void runOnOperation() override {
    auto *parentOp = getOperation();
    MLIRContext *ctx = parentOp->getContext();
    RewritePatternSet patterns(ctx);
    scf::populateSCFForLoopCanonicalizationPatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(parentOp, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

void mlir::scf::populateSCFForLoopCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns
      .add<AffineOpSCFCanonicalizationPattern<AffineMinOp, /*IsMin=*/true>,
           AffineOpSCFCanonicalizationPattern<AffineMaxOp, /*IsMin=*/false>,
           DimOfIterArgFolder<tensor::DimOp>, DimOfIterArgFolder<memref::DimOp>,
           DimOfLoopResultFolder<tensor::DimOp>,
           DimOfLoopResultFolder<memref::DimOp>>(ctx);
}

std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() {
  return std::make_unique<SCFForLoopCanonicalization>();
}
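
// Usage sketch (hypothetical pass named MyPass, not part of this file): other
// passes can reuse these patterns without running the standalone pass by
// populating their own pattern set, mirroring runOnOperation() above:
//
//   void MyPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }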