1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic and helpers to expose Linalg transforms as rewrite 10 // patterns. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/Linalg/Utils/Utils.h" 18 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 19 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 20 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h" 21 #include "mlir/Dialect/Vector/VectorOps.h" 22 #include "mlir/IR/AffineExpr.h" 23 #include "mlir/IR/Matchers.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/Pass/Pass.h" 26 #include "mlir/Support/LLVM.h" 27 #include "llvm/Support/Debug.h" 28 #include "llvm/Support/raw_ostream.h" 29 #include <type_traits> 30 31 #define DEBUG_TYPE "linalg-transforms" 32 33 using namespace mlir; 34 using namespace mlir::edsc; 35 using namespace mlir::edsc::intrinsics; 36 using namespace mlir::linalg; 37 38 using llvm::dbgs; 39 40 //===----------------------------------------------------------------------===// 41 // Transformations exposed as rewrite patterns. 42 //===----------------------------------------------------------------------===// 43 // Marker used as attribute name in generated Linalg rewriting transformations. 44 const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = 45 "__internal_linalg_transform__"; 46 47 mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction, 48 llvm::Optional<StringRef> replacement) 49 : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), 50 replacement(replacement) {} 51 52 mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction, 53 StringRef replacement) 54 : LinalgMarker(matchDisjunction, llvm::Optional<StringRef>{replacement}) {} 55 56 LogicalResult 57 mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter, 58 Operation *op) const { 59 auto attr = op->template getAttrOfType<StringAttr>( 60 LinalgTransforms::kLinalgTransformMarker); 61 62 if (!attr) { 63 // 1. Has no marker case and matchDisjunction is empty. 64 if (matchDisjunction.empty()) 65 return success(); 66 67 // 2. Has no marker and matchDisjuntion matches the no-moarker case. 68 for (auto marker : matchDisjunction) 69 if (marker.empty()) 70 return success(); 71 72 // 3. Has no marker but was expecting a marker. 73 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 74 diag << " does not have any marker from list: "; 75 llvm::interleaveComma(matchDisjunction, diag); 76 }); 77 } 78 79 // 4. Match explicit marker. 80 for (auto marker : matchDisjunction) 81 if (attr.getValue() == marker) 82 return success(); 83 84 // 5. Fail to match. 85 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 86 diag << " does not have any marker from list: "; 87 llvm::interleaveComma(matchDisjunction, diag); 88 }); 89 } 90 91 void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter, 92 Operation *op) const { 93 if (replacement.hasValue()) 94 op->setAttr(LinalgTransforms::kLinalgTransformMarker, 95 rewriter.getStringAttr(replacement.getValue())); 96 else 97 op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker, 98 rewriter.getContext())); 99 } 100 101 LinalgTilingOptions & 102 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 103 SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end()); 104 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 105 OpBuilder::InsertionGuard guard(b); 106 b.setInsertionPointToStart( 107 &op->getParentOfType<FuncOp>().getBody().front()); 108 return llvm::to_vector<4>(llvm::map_range(tileSizes, [&](int64_t s) { 109 Value v = b.create<ConstantIndexOp>(op->getLoc(), s); 110 return v; 111 })); 112 }; 113 return *this; 114 }; 115 116 /// Linalg base tiling pattern. 117 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( 118 StringRef opName, MLIRContext *context, LinalgTilingOptions options, 119 LinalgMarker marker, PatternBenefit benefit) 120 : RewritePattern(opName, {}, benefit, context), marker(marker), 121 options(options) {} 122 123 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite( 124 Operation *op, PatternRewriter &rewriter) const { 125 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 126 if (!linalgOp) 127 return failure(); 128 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 129 return failure(); 130 Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options); 131 132 if (!res) 133 return failure(); 134 135 // New marker if specified. 136 marker.replaceLinalgMarker(rewriter, res->op.getOperation()); 137 138 rewriter.eraseOp(op); 139 return success(); 140 } 141 142 /// Linalg base interchange pattern. 143 mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern( 144 StringRef opName, MLIRContext *context, 145 ArrayRef<unsigned> interchangeVector, LinalgMarker marker, 146 PatternBenefit benefit) 147 : RewritePattern(opName, {}, benefit, context), marker(marker), 148 interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} 149 150 LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite( 151 Operation *op, PatternRewriter &rewriter) const { 152 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 153 if (!linalgOp) 154 return failure(); 155 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 156 return failure(); 157 if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector))) 158 return failure(); 159 160 // TODO: figure out how this interplays with named ops. In particular this 161 // should break the named op property. 162 rewriter.updateRootInPlace(op, [&]() { 163 interchange(linalgOp, interchangeVector); 164 // New marker if specified. 165 marker.replaceLinalgMarker(rewriter, op); 166 }); 167 return success(); 168 } 169 170 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( 171 StringRef opName, MLIRContext *context, LinalgPromotionOptions options, 172 LinalgMarker marker, PatternBenefit benefit) 173 : RewritePattern(opName, {}, benefit, context), marker(marker), 174 options(options) {} 175 176 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite( 177 Operation *op, PatternRewriter &rewriter) const { 178 if (failed(marker.checkAndNotify(rewriter, op))) 179 return failure(); 180 if (failed(promoteSubviewsPrecondition(op, options))) 181 return failure(); 182 rewriter.updateRootInPlace(op, [&]() { 183 auto promotedOp = promoteSubViews(rewriter, op, options); 184 (void)promotedOp; 185 assert(promotedOp && "Unexpected pattern failure"); 186 marker.replaceLinalgMarker(rewriter, op); 187 }); 188 return success(); 189 } 190 191 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern( 192 StringRef opName, MLIRContext *context, LinalgMarker marker, 193 PatternBenefit benefit) 194 : RewritePattern(opName, {}, benefit, context), marker(marker) {} 195 196 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( 197 Operation *op, PatternRewriter &rewriter) const { 198 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 199 if (!linalgOp) 200 return failure(); 201 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 202 return failure(); 203 if (failed(vectorizeLinalgOpPrecondition(op))) 204 return failure(); 205 vectorizeLinalgOp(rewriter, op); 206 rewriter.eraseOp(op); 207 return success(); 208 } 209 210 LogicalResult mlir::linalg::applyStagedPatterns( 211 Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns, 212 const OwningRewritePatternList &stage2Patterns, 213 llvm::function_ref<LogicalResult(Operation *)> stage3Lambda) { 214 for (const auto &patterns : stage1Patterns) { 215 if (!applyPatternsAndFoldGreedily(op, patterns)) { 216 llvm::dbgs() << "Underlying first stage rewrite did not converge"; 217 return failure(); 218 } 219 if (!applyPatternsAndFoldGreedily(op, stage2Patterns)) { 220 llvm::dbgs() << "Underlying second stage rewrite did not converge"; 221 return failure(); 222 } 223 if (stage3Lambda) { 224 if (failed(stage3Lambda(op))) 225 return failure(); 226 } 227 } 228 return success(); 229 } 230