1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic and helpers to expose Linalg transforms as rewrite 10 // patterns. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/Linalg/Utils/Utils.h" 18 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 19 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 20 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h" 21 #include "mlir/Dialect/Vector/VectorOps.h" 22 #include "mlir/IR/AffineExpr.h" 23 #include "mlir/IR/Matchers.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/Pass/Pass.h" 26 #include "mlir/Support/LLVM.h" 27 #include "llvm/Support/Debug.h" 28 #include "llvm/Support/raw_ostream.h" 29 #include <type_traits> 30 31 #define DEBUG_TYPE "linalg-transforms" 32 33 using namespace mlir; 34 using namespace mlir::edsc; 35 using namespace mlir::edsc::intrinsics; 36 using namespace mlir::linalg; 37 38 using llvm::dbgs; 39 40 #define DEBUG_TYPE "linalg-transforms" 41 42 //===----------------------------------------------------------------------===// 43 // Transformations exposed as rewrite patterns. 44 //===----------------------------------------------------------------------===// 45 // Marker used as attribute name in generated Linalg rewriting transformations. 46 const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = 47 "__internal_linalg_transform__"; 48 49 mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction, 50 Optional<StringRef> replacement) 51 : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), 52 replacement(replacement) {} 53 54 mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction, 55 StringRef replacement) 56 : LinalgMarker(matchDisjunction, Optional<StringRef>{replacement}) {} 57 58 LogicalResult 59 mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter, 60 Operation *op) const { 61 auto attr = op->template getAttrOfType<StringAttr>( 62 LinalgTransforms::kLinalgTransformMarker); 63 64 if (!attr) { 65 // 1. Has no marker case and matchDisjunction is empty. 66 if (matchDisjunction.empty()) 67 return success(); 68 69 // 2. Has no marker and matchDisjuntion matches the no-moarker case. 70 for (auto marker : matchDisjunction) 71 if (marker.empty()) 72 return success(); 73 74 // 3. Has no marker but was expecting a marker. 75 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 76 diag << " does not have any marker from list: "; 77 interleaveComma(matchDisjunction, diag); 78 }); 79 } 80 81 // 4. Match explicit marker. 82 for (auto marker : matchDisjunction) 83 if (attr.getValue() == marker) 84 return success(); 85 86 // 5. Fail to match. 87 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 88 diag << " does not have any marker from list: "; 89 interleaveComma(matchDisjunction, diag); 90 }); 91 } 92 93 void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter, 94 Operation *op) const { 95 if (replacement.hasValue()) 96 op->setAttr(LinalgTransforms::kLinalgTransformMarker, 97 rewriter.getStringAttr(replacement.getValue())); 98 else 99 op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker, 100 rewriter.getContext())); 101 } 102 103 LinalgTilingOptions & 104 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 105 SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end()); 106 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 107 OpBuilder::InsertionGuard guard(b); 108 b.setInsertionPointToStart( 109 &op->getParentOfType<FuncOp>().getBody().front()); 110 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 111 Value v = b.create<ConstantIndexOp>(op->getLoc(), s); 112 return v; 113 })); 114 }; 115 return *this; 116 }; 117 118 /// Linalg base tiling pattern. 119 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( 120 StringRef opName, MLIRContext *context, LinalgTilingOptions options, 121 LinalgMarker marker, PatternBenefit benefit) 122 : RewritePattern(opName, {}, benefit, context), marker(marker), 123 options(options) {} 124 125 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite( 126 Operation *op, PatternRewriter &rewriter) const { 127 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 128 if (!linalgOp) 129 return failure(); 130 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 131 return failure(); 132 Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options); 133 134 if (!res) 135 return failure(); 136 137 // New marker if specified. 138 marker.replaceLinalgMarker(rewriter, res->op.getOperation()); 139 140 rewriter.eraseOp(op); 141 return success(); 142 } 143 144 /// Linalg base interchange pattern. 145 mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern( 146 StringRef opName, MLIRContext *context, 147 ArrayRef<unsigned> interchangeVector, LinalgMarker marker, 148 PatternBenefit benefit) 149 : RewritePattern(opName, {}, benefit, context), marker(marker), 150 interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} 151 152 LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite( 153 Operation *op, PatternRewriter &rewriter) const { 154 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 155 if (!linalgOp) 156 return failure(); 157 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 158 return failure(); 159 if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector))) 160 return failure(); 161 162 // TODO: figure out how this interplays with named ops. In particular this 163 // should break the named op property. 164 rewriter.updateRootInPlace(op, [&]() { 165 interchange(linalgOp, interchangeVector); 166 // New marker if specified. 167 marker.replaceLinalgMarker(rewriter, op); 168 }); 169 return success(); 170 } 171 172 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( 173 StringRef opName, MLIRContext *context, LinalgPromotionOptions options, 174 LinalgMarker marker, PatternBenefit benefit) 175 : RewritePattern(opName, {}, benefit, context), marker(marker), 176 options(options) {} 177 178 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite( 179 Operation *op, PatternRewriter &rewriter) const { 180 if (failed(marker.checkAndNotify(rewriter, op))) 181 return failure(); 182 if (failed(promoteSubviewsPrecondition(op, options))) 183 return failure(); 184 185 // TODO: We cannot use root update here. This pattern is creating other ops, 186 // so if the promotion fails, those need to be cleaned up, which doesnt seem 187 // to be happening here. So to fail properly, we should be cloning the op and 188 // deleting the previous op. This needs more investigation. 189 rewriter.startRootUpdate(op); 190 Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options); 191 if (!promotedOp) { 192 rewriter.cancelRootUpdate(op); 193 return op->emitError("subview promotion failed"); 194 } 195 rewriter.finalizeRootUpdate(op); 196 marker.replaceLinalgMarker(rewriter, op); 197 return success(); 198 } 199 200 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern( 201 StringRef opName, MLIRContext *context, LinalgMarker marker, 202 PatternBenefit benefit) 203 : RewritePattern(opName, {}, benefit, context), marker(marker) {} 204 205 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( 206 Operation *op, PatternRewriter &rewriter) const { 207 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 208 if (!linalgOp) 209 return failure(); 210 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 211 return failure(); 212 if (failed(vectorizeLinalgOpPrecondition(op))) 213 return failure(); 214 vectorizeLinalgOp(rewriter, op); 215 rewriter.eraseOp(op); 216 return success(); 217 } 218 219 LogicalResult mlir::linalg::applyStagedPatterns( 220 Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns, 221 const OwningRewritePatternList &stage2Patterns, 222 function_ref<LogicalResult(Operation *)> stage3Lambda) { 223 unsigned iteration = 0; 224 (void)iteration; 225 StringRef dbgPref = "\n[" DEBUG_TYPE "]: "; 226 (void)dbgPref; 227 for (const auto &patterns : stage1Patterns) { 228 if (!applyPatternsAndFoldGreedily(op, patterns)) { 229 dbgs() << "Underlying first stage rewrite did not converge"; 230 return failure(); 231 } 232 LLVM_DEBUG(dbgs() 233 << dbgPref << "After 1st stage, iter: " << ++iteration << "\n" 234 << *op); 235 if (!applyPatternsAndFoldGreedily(op, stage2Patterns)) { 236 LLVM_DEBUG(dbgs() 237 << dbgPref << "Underlying 2nd stage rewrite did not converge"); 238 return failure(); 239 } 240 LLVM_DEBUG(dbgs() 241 << dbgPref << "After 2nd stage, iter : " << iteration << "\n" 242 << *op); 243 if (stage3Lambda) { 244 if (failed(stage3Lambda(op))) 245 return failure(); 246 LLVM_DEBUG(dbgs() 247 << dbgPref << "After 3rd stage, iter : " << iteration << "\n" 248 << *op); 249 } 250 } 251 return success(); 252 } 253