1 //===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains cross-dialect canonicalization patterns that cannot be 10 // actual canonicalization patterns due to undesired additional dependencies. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/MemRef/IR/MemRef.h" 17 #include "mlir/Dialect/SCF/Passes.h" 18 #include "mlir/Dialect/SCF/SCF.h" 19 #include "mlir/Dialect/SCF/Transforms.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 24 using namespace mlir; 25 using namespace mlir::scf; 26 27 namespace { 28 /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.: 29 /// 30 /// ``` 31 /// %0 = ... : tensor<?x?xf32> 32 /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) { 33 /// %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32> 34 /// ... 35 /// } 36 /// ``` 37 /// 38 /// is folded to: 39 /// 40 /// ``` 41 /// %0 = ... : tensor<?x?xf32> 42 /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) { 43 /// %1 = tensor.dim %0, %c0 : tensor<?x?xf32> 44 /// ... 45 /// } 46 /// ``` 47 template <typename OpTy> 48 struct DimOfIterArgFolder : public OpRewritePattern<OpTy> { 49 using OpRewritePattern<OpTy>::OpRewritePattern; 50 51 LogicalResult matchAndRewrite(OpTy dimOp, 52 PatternRewriter &rewriter) const override { 53 auto blockArg = dimOp.source().template dyn_cast<BlockArgument>(); 54 if (!blockArg) 55 return failure(); 56 auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp()); 57 if (!forOp) 58 return failure(); 59 60 Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get(); 61 rewriter.updateRootInPlace( 62 dimOp, [&]() { dimOp.sourceMutable().assign(initArg); }); 63 64 return success(); 65 }; 66 }; 67 68 /// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for 69 /// and scf.parallel loops with a known range. 70 template <typename OpTy, bool IsMin> 71 struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> { 72 using OpRewritePattern<OpTy>::OpRewritePattern; 73 74 LogicalResult matchAndRewrite(OpTy op, 75 PatternRewriter &rewriter) const override { 76 auto loopMatcher = [](Value iv, Value &lb, Value &ub, Value &step) { 77 if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) { 78 lb = forOp.lowerBound(); 79 ub = forOp.upperBound(); 80 step = forOp.step(); 81 return success(); 82 } 83 if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) { 84 for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) { 85 if (parOp.getInductionVars()[idx] == iv) { 86 lb = parOp.lowerBound()[idx]; 87 ub = parOp.upperBound()[idx]; 88 step = parOp.step()[idx]; 89 return success(); 90 } 91 } 92 return failure(); 93 } 94 return failure(); 95 }; 96 97 return scf::canonicalizeMinMaxOpInLoop(rewriter, op, op.getAffineMap(), 98 op.operands(), IsMin, loopMatcher); 99 } 100 }; 101 102 struct SCFForLoopCanonicalization 103 : public SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> { 104 void runOnFunction() override { 105 FuncOp funcOp = getFunction(); 106 MLIRContext *ctx = funcOp.getContext(); 107 RewritePatternSet patterns(ctx); 108 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 109 if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) 110 signalPassFailure(); 111 } 112 }; 113 } // namespace 114 115 void mlir::scf::populateSCFForLoopCanonicalizationPatterns( 116 RewritePatternSet &patterns) { 117 MLIRContext *ctx = patterns.getContext(); 118 patterns 119 .insert<AffineOpSCFCanonicalizationPattern<AffineMinOp, /*IsMin=*/true>, 120 AffineOpSCFCanonicalizationPattern<AffineMaxOp, /*IsMin=*/false>, 121 DimOfIterArgFolder<tensor::DimOp>, 122 DimOfIterArgFolder<memref::DimOp>>(ctx); 123 } 124 125 std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() { 126 return std::make_unique<SCFForLoopCanonicalization>(); 127 } 128