1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic and helpers to expose Linalg transforms as rewrite 10 // patterns. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/Linalg/Utils/Utils.h" 18 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 19 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 20 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h" 21 #include "mlir/Dialect/Vector/VectorOps.h" 22 #include "mlir/IR/AffineExpr.h" 23 #include "mlir/IR/Matchers.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/Pass/Pass.h" 26 #include "mlir/Support/LLVM.h" 27 #include "llvm/Support/Debug.h" 28 #include "llvm/Support/raw_ostream.h" 29 #include <type_traits> 30 31 #define DEBUG_TYPE "linalg-transforms" 32 33 using namespace mlir; 34 using namespace mlir::edsc; 35 using namespace mlir::edsc::intrinsics; 36 using namespace mlir::linalg; 37 38 using llvm::dbgs; 39 40 //===----------------------------------------------------------------------===// 41 // Transformations exposed as rewrite patterns. 42 //===----------------------------------------------------------------------===// 43 // Marker used as attribute name in generated Linalg rewriting transformations. 44 const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = 45 "__internal_linalg_transform__"; 46 47 mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction, 48 llvm::Optional<StringRef> replacement) 49 : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), 50 replacement(replacement) {} 51 52 mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction, 53 StringRef replacement) 54 : LinalgMarker(matchDisjunction, llvm::Optional<StringRef>{replacement}) {} 55 56 LogicalResult 57 mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter, 58 Operation *op) const { 59 auto attr = op->template getAttrOfType<StringAttr>( 60 LinalgTransforms::kLinalgTransformMarker); 61 62 if (!attr) { 63 // 1. Has no marker case and matchDisjunction is empty. 64 if (matchDisjunction.empty()) 65 return success(); 66 67 // 2. Has no marker and matchDisjuntion matches the no-moarker case. 68 for (auto marker : matchDisjunction) 69 if (marker.empty()) 70 return success(); 71 72 // 3. Has no marker but was expecting a marker. 73 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 74 diag << " does not have any marker from list: "; 75 llvm::interleaveComma(matchDisjunction, diag); 76 }); 77 } 78 79 // 4. Match explicit marker. 80 for (auto marker : matchDisjunction) 81 if (attr.getValue() == marker) 82 return success(); 83 84 // 5. Fail to match. 85 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 86 diag << " does not have any marker from list: "; 87 llvm::interleaveComma(matchDisjunction, diag); 88 }); 89 } 90 91 void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter, 92 Operation *op) const { 93 if (replacement.hasValue()) 94 op->setAttr(LinalgTransforms::kLinalgTransformMarker, 95 rewriter.getStringAttr(replacement.getValue())); 96 else 97 op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker, 98 rewriter.getContext())); 99 } 100 101 /// Linalg base tiling pattern. 102 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( 103 StringRef opName, MLIRContext *context, LinalgTilingOptions options, 104 LinalgMarker marker, PatternBenefit benefit) 105 : RewritePattern(opName, {}, benefit, context), marker(marker), 106 options(options) {} 107 108 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite( 109 Operation *op, PatternRewriter &rewriter) const { 110 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 111 if (!linalgOp) 112 return failure(); 113 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 114 return failure(); 115 Optional<TiledLinalgOp> res; 116 if (options.loopType == LinalgTilingLoopType::Loops) 117 res = tileLinalgOp(rewriter, linalgOp, options.tileSizes, 118 options.interchangeVector); 119 else if (options.loopType == LinalgTilingLoopType::ParallelLoops) 120 res = tileLinalgOpToParallelLoops(rewriter, linalgOp, options.tileSizes, 121 options.interchangeVector); 122 // TODO: Impl tiling to affine loops when it makes sense. 123 124 if (!res) 125 return failure(); 126 127 // New marker if specified. 128 marker.replaceLinalgMarker(rewriter, res->op.getOperation()); 129 130 rewriter.eraseOp(op); 131 return success(); 132 } 133 134 /// Linalg base interchange pattern. 135 mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern( 136 StringRef opName, MLIRContext *context, 137 ArrayRef<unsigned> interchangeVector, LinalgMarker marker, 138 PatternBenefit benefit) 139 : RewritePattern(opName, {}, benefit, context), marker(marker), 140 interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} 141 142 LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite( 143 Operation *op, PatternRewriter &rewriter) const { 144 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 145 if (!linalgOp) 146 return failure(); 147 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 148 return failure(); 149 if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector))) 150 return failure(); 151 152 // TODO: figure out how this interplays with named ops. In particular this 153 // should break the named op property. 154 rewriter.updateRootInPlace(op, [&]() { 155 interchange(linalgOp, interchangeVector); 156 // New marker if specified. 157 marker.replaceLinalgMarker(rewriter, op); 158 }); 159 return success(); 160 } 161 162 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( 163 StringRef opName, MLIRContext *context, LinalgPromotionOptions options, 164 LinalgMarker marker, PatternBenefit benefit) 165 : RewritePattern(opName, {}, benefit, context), marker(marker), 166 options(options) {} 167 168 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite( 169 Operation *op, PatternRewriter &rewriter) const { 170 if (failed(marker.checkAndNotify(rewriter, op))) 171 return failure(); 172 if (failed(promoteSubviewsPrecondition(op, options))) 173 return failure(); 174 rewriter.updateRootInPlace(op, [&]() { 175 auto promotedOp = promoteSubViews(rewriter, op, options); 176 (void)promotedOp; 177 assert(promotedOp && "Unexpected pattern failure"); 178 marker.replaceLinalgMarker(rewriter, op); 179 }); 180 return success(); 181 } 182 183 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern( 184 StringRef opName, MLIRContext *context, LinalgMarker marker, 185 PatternBenefit benefit) 186 : RewritePattern(opName, {}, benefit, context), marker(marker) {} 187 188 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( 189 Operation *op, PatternRewriter &rewriter) const { 190 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 191 if (!linalgOp) 192 return failure(); 193 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 194 return failure(); 195 if (failed(vectorizeLinalgOpPrecondition(op))) 196 return failure(); 197 vectorizeLinalgOp(rewriter, op); 198 rewriter.eraseOp(op); 199 return success(); 200 } 201 202 LogicalResult mlir::linalg::applyStagedPatterns( 203 Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns, 204 const OwningRewritePatternList &stage2Patterns, 205 llvm::function_ref<LogicalResult(Operation *)> stage3Lambda) { 206 for (const auto &patterns : stage1Patterns) { 207 if (!applyPatternsAndFoldGreedily(op, patterns)) { 208 llvm::dbgs() << "Underlying first stage rewrite did not converge"; 209 return failure(); 210 } 211 if (!applyPatternsAndFoldGreedily(op, stage2Patterns)) { 212 llvm::dbgs() << "Underlying second stage rewrite did not converge"; 213 return failure(); 214 } 215 if (stage3Lambda) { 216 if (failed(stage3Lambda(op))) 217 return failure(); 218 } 219 } 220 return success(); 221 } 222