1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic and helpers to expose Linalg transforms as rewrite 10 // patterns. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/Linalg/Utils/Utils.h" 18 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 19 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 20 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h" 21 #include "mlir/Dialect/Vector/VectorOps.h" 22 #include "mlir/IR/AffineExpr.h" 23 #include "mlir/IR/Matchers.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/Pass/Pass.h" 26 #include "mlir/Support/LLVM.h" 27 #include "llvm/Support/Debug.h" 28 #include "llvm/Support/raw_ostream.h" 29 #include <type_traits> 30 31 #define DEBUG_TYPE "linalg-transforms" 32 33 using namespace mlir; 34 using namespace mlir::edsc; 35 using namespace mlir::edsc::intrinsics; 36 using namespace mlir::linalg; 37 38 using llvm::dbgs; 39 40 #define DEBUG_TYPE "linalg-transforms" 41 42 //===----------------------------------------------------------------------===// 43 // Transformations exposed as rewrite patterns. 44 //===----------------------------------------------------------------------===// 45 // Marker used as attribute name in generated Linalg rewriting transformations. 46 const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = 47 "__internal_linalg_transform__"; 48 49 mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<Identifier> matchDisjunction, 50 Optional<Identifier> replacement) 51 : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), 52 replacement(replacement) {} 53 54 LogicalResult 55 mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter, 56 Operation *op) const { 57 auto attr = op->template getAttrOfType<StringAttr>( 58 LinalgTransforms::kLinalgTransformMarker); 59 60 if (!attr) { 61 // 1. Has no marker case and matchDisjunction is empty. 62 if (matchDisjunction.empty()) 63 return success(); 64 65 // 2. Has no marker but was expecting a marker. 66 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 67 diag << " does not have any marker from list: "; 68 interleaveComma(matchDisjunction, diag); 69 }); 70 } 71 72 // 4. Match explicit marker. 73 for (auto marker : matchDisjunction) 74 if (attr.getValue() == marker) 75 return success(); 76 77 // 5. Fail to match. 78 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 79 diag << " does not have any marker from list: "; 80 interleaveComma(matchDisjunction, diag); 81 }); 82 } 83 84 void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter, 85 Operation *op) const { 86 if (replacement.hasValue()) 87 op->setAttr(LinalgTransforms::kLinalgTransformMarker, 88 rewriter.getStringAttr(replacement.getValue())); 89 else 90 op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker, 91 rewriter.getContext())); 92 } 93 94 LinalgTilingOptions & 95 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 96 SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end()); 97 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 98 OpBuilder::InsertionGuard guard(b); 99 b.setInsertionPointToStart( 100 &op->getParentOfType<FuncOp>().getBody().front()); 101 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 102 Value v = b.create<ConstantIndexOp>(op->getLoc(), s); 103 return v; 104 })); 105 }; 106 return *this; 107 }; 108 109 /// Linalg base tiling pattern. 110 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( 111 StringRef opName, MLIRContext *context, LinalgTilingOptions options, 112 LinalgMarker marker, PatternBenefit benefit) 113 : RewritePattern(opName, {}, benefit, context), marker(marker), 114 options(options) {} 115 116 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite( 117 Operation *op, PatternRewriter &rewriter) const { 118 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 119 if (!linalgOp) 120 return failure(); 121 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 122 return failure(); 123 Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options); 124 125 if (!res) 126 return failure(); 127 128 // New marker if specified. 129 marker.replaceLinalgMarker(rewriter, res->op.getOperation()); 130 131 rewriter.eraseOp(op); 132 return success(); 133 } 134 135 /// Linalg base interchange pattern. 136 mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern( 137 StringRef opName, MLIRContext *context, 138 ArrayRef<unsigned> interchangeVector, LinalgMarker marker, 139 PatternBenefit benefit) 140 : RewritePattern(opName, {}, benefit, context), marker(marker), 141 interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} 142 143 LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite( 144 Operation *op, PatternRewriter &rewriter) const { 145 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 146 if (!linalgOp) 147 return failure(); 148 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 149 return failure(); 150 if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector))) 151 return failure(); 152 153 // TODO: figure out how this interplays with named ops. In particular this 154 // should break the named op property. 155 rewriter.updateRootInPlace(op, [&]() { 156 interchange(linalgOp, interchangeVector); 157 // New marker if specified. 158 marker.replaceLinalgMarker(rewriter, op); 159 }); 160 return success(); 161 } 162 163 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( 164 StringRef opName, MLIRContext *context, LinalgPromotionOptions options, 165 LinalgMarker marker, PatternBenefit benefit) 166 : RewritePattern(opName, {}, benefit, context), marker(marker), 167 options(options) {} 168 169 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite( 170 Operation *op, PatternRewriter &rewriter) const { 171 if (failed(marker.checkAndNotify(rewriter, op))) 172 return failure(); 173 if (failed(promoteSubviewsPrecondition(op, options))) 174 return failure(); 175 176 // TODO: We cannot use root update here. This pattern is creating other ops, 177 // so if the promotion fails, those need to be cleaned up, which doesnt seem 178 // to be happening here. So to fail properly, we should be cloning the op and 179 // deleting the previous op. This needs more investigation. 180 rewriter.startRootUpdate(op); 181 Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options); 182 if (!promotedOp) { 183 rewriter.cancelRootUpdate(op); 184 return op->emitError("subview promotion failed"); 185 } 186 rewriter.finalizeRootUpdate(op); 187 marker.replaceLinalgMarker(rewriter, op); 188 return success(); 189 } 190 191 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern( 192 StringRef opName, MLIRContext *context, LinalgMarker marker, 193 PatternBenefit benefit) 194 : RewritePattern(opName, {}, benefit, context), marker(marker) {} 195 196 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( 197 Operation *op, PatternRewriter &rewriter) const { 198 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 199 if (!linalgOp) 200 return failure(); 201 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 202 return failure(); 203 if (failed(vectorizeLinalgOpPrecondition(op))) 204 return failure(); 205 vectorizeLinalgOp(rewriter, op); 206 rewriter.eraseOp(op); 207 return success(); 208 } 209 210 LogicalResult mlir::linalg::applyStagedPatterns( 211 Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns, 212 const OwningRewritePatternList &stage2Patterns, 213 function_ref<LogicalResult(Operation *)> stage3Lambda) { 214 unsigned iteration = 0; 215 (void)iteration; 216 StringRef dbgPref = "\n[" DEBUG_TYPE "]: "; 217 (void)dbgPref; 218 for (const auto &patterns : stage1Patterns) { 219 if (!applyPatternsAndFoldGreedily(op, patterns)) { 220 dbgs() << "Underlying first stage rewrite did not converge"; 221 return failure(); 222 } 223 LLVM_DEBUG(dbgs() 224 << dbgPref << "After 1st stage, iter: " << ++iteration << "\n" 225 << *op); 226 if (!applyPatternsAndFoldGreedily(op, stage2Patterns)) { 227 LLVM_DEBUG(dbgs() 228 << dbgPref << "Underlying 2nd stage rewrite did not converge"); 229 return failure(); 230 } 231 LLVM_DEBUG(dbgs() 232 << dbgPref << "After 2nd stage, iter : " << iteration << "\n" 233 << *op); 234 if (stage3Lambda) { 235 if (failed(stage3Lambda(op))) 236 return failure(); 237 LLVM_DEBUG(dbgs() 238 << dbgPref << "After 3rd stage, iter : " << iteration << "\n" 239 << *op); 240 } 241 } 242 return success(); 243 } 244