1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic and helpers to expose Linalg transforms as rewrite 10 // patterns. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/Linalg/Utils/Utils.h" 18 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 19 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 20 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h" 21 #include "mlir/Dialect/Vector/VectorOps.h" 22 #include "mlir/IR/AffineExpr.h" 23 #include "mlir/IR/Matchers.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/Pass/Pass.h" 26 #include "mlir/Support/LLVM.h" 27 #include "llvm/Support/Debug.h" 28 #include "llvm/Support/raw_ostream.h" 29 #include <type_traits> 30 31 #define DEBUG_TYPE "linalg-transforms" 32 33 using namespace mlir; 34 using namespace mlir::edsc; 35 using namespace mlir::edsc::intrinsics; 36 using namespace mlir::linalg; 37 38 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") 39 //===----------------------------------------------------------------------===// 40 // Transformations exposed as rewrite patterns. 41 //===----------------------------------------------------------------------===// 42 // Marker used as attribute name in generated Linalg rewriting transformations. 43 const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = 44 "__internal_linalg_transform__"; 45 46 mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<Identifier> matchDisjunction, 47 Optional<Identifier> replacement) 48 : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), 49 replacement(replacement) {} 50 51 LogicalResult 52 mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter, 53 Operation *op) const { 54 auto attr = op->template getAttrOfType<StringAttr>( 55 LinalgTransforms::kLinalgTransformMarker); 56 57 if (!attr) { 58 // 1. Has no marker case and matchDisjunction is empty. 59 if (matchDisjunction.empty()) 60 return success(); 61 62 // 2. Has no marker but was expecting a marker. 63 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 64 diag << " does not have any marker from list: "; 65 interleaveComma(matchDisjunction, diag); 66 }); 67 } 68 69 // 4. Match explicit marker. 70 for (auto marker : matchDisjunction) 71 if (attr.getValue() == marker) 72 return success(); 73 74 // 5. Fail to match. 75 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 76 diag << " does not have any marker from list: "; 77 interleaveComma(matchDisjunction, diag); 78 }); 79 } 80 81 void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter, 82 Operation *op) const { 83 if (replacement.hasValue()) 84 op->setAttr(LinalgTransforms::kLinalgTransformMarker, 85 rewriter.getStringAttr(replacement.getValue())); 86 else 87 op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker, 88 rewriter.getContext())); 89 } 90 91 LinalgTilingOptions & 92 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 93 SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end()); 94 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 95 OpBuilder::InsertionGuard guard(b); 96 b.setInsertionPointToStart( 97 &op->getParentOfType<FuncOp>().getBody().front()); 98 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 99 Value v = b.create<ConstantIndexOp>(op->getLoc(), s); 100 return v; 101 })); 102 }; 103 return *this; 104 } 105 106 /// Linalg base tiling pattern. 107 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( 108 StringRef opName, MLIRContext *context, LinalgTilingOptions options, 109 LinalgMarker marker, PatternBenefit benefit) 110 : RewritePattern(opName, {}, benefit, context), marker(marker), 111 options(options) {} 112 113 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite( 114 Operation *op, PatternRewriter &rewriter) const { 115 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 116 if (!linalgOp) 117 return failure(); 118 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 119 return failure(); 120 121 Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options); 122 123 if (!res) 124 return failure(); 125 126 // New marker if specified. 127 marker.replaceLinalgMarker(rewriter, res->op.getOperation()); 128 129 rewriter.eraseOp(op); 130 return success(); 131 } 132 133 /// Linalg base interchange pattern. 134 mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern( 135 StringRef opName, MLIRContext *context, 136 ArrayRef<unsigned> interchangeVector, LinalgMarker marker, 137 PatternBenefit benefit) 138 : RewritePattern(opName, {}, benefit, context), marker(marker), 139 interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} 140 141 LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite( 142 Operation *op, PatternRewriter &rewriter) const { 143 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 144 if (!linalgOp) 145 return failure(); 146 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 147 return failure(); 148 if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector))) 149 return failure(); 150 151 // TODO: figure out how this interplays with named ops. In particular this 152 // should break the named op property. 153 rewriter.updateRootInPlace(op, [&]() { 154 interchange(linalgOp, interchangeVector); 155 // New marker if specified. 156 marker.replaceLinalgMarker(rewriter, op); 157 }); 158 return success(); 159 } 160 161 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( 162 StringRef opName, MLIRContext *context, LinalgPromotionOptions options, 163 LinalgMarker marker, PatternBenefit benefit) 164 : RewritePattern(opName, {}, benefit, context), marker(marker), 165 options(options) {} 166 167 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite( 168 Operation *op, PatternRewriter &rewriter) const { 169 if (failed(marker.checkAndNotify(rewriter, op))) 170 return failure(); 171 if (failed(promoteSubviewsPrecondition(op, options))) 172 return failure(); 173 174 // TODO: We cannot use root update here. This pattern is creating other ops, 175 // so if the promotion fails, those need to be cleaned up, which doesnt seem 176 // to be happening here. So to fail properly, we should be cloning the op and 177 // deleting the previous op. This needs more investigation. 178 rewriter.startRootUpdate(op); 179 Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options); 180 if (!promotedOp) { 181 rewriter.cancelRootUpdate(op); 182 return op->emitError("subview promotion failed"); 183 } 184 rewriter.finalizeRootUpdate(op); 185 marker.replaceLinalgMarker(rewriter, op); 186 return success(); 187 } 188 189 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern( 190 StringRef opName, MLIRContext *context, LinalgMarker marker, 191 PatternBenefit benefit) 192 : RewritePattern(opName, {}, benefit, context), marker(marker) {} 193 194 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( 195 Operation *op, PatternRewriter &rewriter) const { 196 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 197 if (!linalgOp) 198 return failure(); 199 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 200 return failure(); 201 if (failed(vectorizeLinalgOpPrecondition(op))) 202 return failure(); 203 vectorizeLinalgOp(rewriter, op); 204 rewriter.eraseOp(op); 205 return success(); 206 } 207 208 LogicalResult mlir::linalg::applyStagedPatterns( 209 Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns, 210 const OwningRewritePatternList &stage2Patterns, 211 function_ref<LogicalResult(Operation *)> stage3Lambda) { 212 unsigned iteration = 0; 213 (void)iteration; 214 for (const auto &patterns : stage1Patterns) { 215 LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n" 216 << *op); 217 if (failed(applyPatternsAndFoldGreedily(op, patterns))) { 218 LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge"); 219 return failure(); 220 } 221 LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n" 222 << *op); 223 if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) { 224 LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge"); 225 return failure(); 226 } 227 LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n" 228 << *op); 229 if (stage3Lambda) { 230 if (failed(stage3Lambda(op))) 231 return failure(); 232 LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n" 233 << *op); 234 } 235 } 236 return success(); 237 } 238