1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic and helpers to expose Linalg transforms as rewrite 10 // patterns. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/Linalg/Utils/Utils.h" 18 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 19 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 20 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h" 21 #include "mlir/Dialect/Vector/VectorOps.h" 22 #include "mlir/IR/AffineExpr.h" 23 #include "mlir/IR/Matchers.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/Pass/Pass.h" 26 #include "mlir/Support/LLVM.h" 27 #include "llvm/Support/Debug.h" 28 #include "llvm/Support/raw_ostream.h" 29 #include <type_traits> 30 31 #define DEBUG_TYPE "linalg-transforms" 32 33 using namespace mlir; 34 using namespace mlir::edsc; 35 using namespace mlir::edsc::intrinsics; 36 using namespace mlir::linalg; 37 38 using llvm::dbgs; 39 40 //===----------------------------------------------------------------------===// 41 // Transformations exposed as rewrite patterns. 42 //===----------------------------------------------------------------------===// 43 // Marker used as attribute name in generated Linalg rewriting transformations. 44 const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = 45 "__internal_linalg_transform__"; 46 47 mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction, 48 llvm::Optional<StringRef> replacement) 49 : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), 50 replacement(replacement) {} 51 52 mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction, 53 StringRef replacement) 54 : LinalgMarker(matchDisjunction, llvm::Optional<StringRef>{replacement}) {} 55 56 LogicalResult 57 mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter, 58 Operation *op) const { 59 auto attr = op->template getAttrOfType<StringAttr>( 60 LinalgTransforms::kLinalgTransformMarker); 61 62 if (!attr) { 63 // 1. Has no marker case and matchDisjunction is empty. 64 if (matchDisjunction.empty()) 65 return success(); 66 67 // 2. Has no marker and matchDisjuntion matches the no-moarker case. 68 for (auto marker : matchDisjunction) 69 if (marker.empty()) 70 return success(); 71 72 // 3. Has no marker but was expecting a marker. 73 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 74 diag << " does not have any marker from list: "; 75 llvm::interleaveComma(matchDisjunction, diag); 76 }); 77 } 78 79 // 4. Match explicit marker. 80 for (auto marker : matchDisjunction) 81 if (attr.getValue() == marker) 82 return success(); 83 84 // 5. Fail to match. 85 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 86 diag << " does not have any marker from list: "; 87 llvm::interleaveComma(matchDisjunction, diag); 88 }); 89 } 90 91 void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter, 92 Operation *op) const { 93 if (replacement.hasValue()) 94 op->setAttr(LinalgTransforms::kLinalgTransformMarker, 95 rewriter.getStringAttr(replacement.getValue())); 96 else 97 op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker, 98 rewriter.getContext())); 99 } 100 101 /// Linalg base tiling pattern. 102 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( 103 StringRef opName, MLIRContext *context, LinalgTilingOptions options, 104 LinalgMarker marker, PatternBenefit benefit) 105 : RewritePattern(opName, {}, benefit, context), marker(marker), 106 options(options) {} 107 108 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite( 109 Operation *op, PatternRewriter &rewriter) const { 110 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 111 if (!linalgOp) 112 return failure(); 113 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 114 return failure(); 115 Optional<TiledLinalgOp> res; 116 if (options.loopType == LinalgTilingLoopType::Loops) 117 res = tileLinalgOp(rewriter, linalgOp, options.tileSizes, 118 options.interchangeVector); 119 else if (options.loopType == LinalgTilingLoopType::ParallelLoops) 120 res = tileLinalgOpToParallelLoops(rewriter, linalgOp, options.tileSizes, 121 options.interchangeVector); 122 // TODO: Impl tiling to affine loops when it makes sense. 123 124 if (!res) 125 return failure(); 126 127 // New marker if specified. 128 marker.replaceLinalgMarker(rewriter, res->op.getOperation()); 129 130 rewriter.eraseOp(op); 131 return success(); 132 } 133 134 /// Linalg base interchange pattern. 135 mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern( 136 StringRef opName, MLIRContext *context, 137 ArrayRef<unsigned> interchangeVector, LinalgMarker marker, 138 PatternBenefit benefit) 139 : RewritePattern(opName, {}, benefit, context), marker(marker), 140 interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} 141 142 LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite( 143 Operation *op, PatternRewriter &rewriter) const { 144 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 145 if (!linalgOp) 146 return failure(); 147 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 148 return failure(); 149 if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector))) 150 return failure(); 151 152 // TODO: figure out how this interplays with named ops. In particular this 153 // should break the named op property. 154 rewriter.updateRootInPlace(op, [&]() { 155 interchange(linalgOp, interchangeVector); 156 // New marker if specified. 157 marker.replaceLinalgMarker(rewriter, op); 158 }); 159 return success(); 160 } 161 162 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( 163 StringRef opName, MLIRContext *context, 164 ArrayRef<unsigned> operandsToPromote, unsigned alignment, 165 LinalgMarker marker, PatternBenefit benefit) 166 : RewritePattern(opName, {}, benefit, context), marker(marker), 167 operandsToPromote(operandsToPromote.begin(), operandsToPromote.end()), 168 alignment(alignment) {} 169 170 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite( 171 Operation *op, PatternRewriter &rewriter) const { 172 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 173 if (!linalgOp) 174 return failure(); 175 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 176 return failure(); 177 if (operandsToPromote.empty()) { 178 if (failed(promoteSubviewsLinalgOpPrecondition(op, llvm::None))) 179 return failure(); 180 } else { 181 DenseSet<unsigned> set; 182 set.insert(operandsToPromote.begin(), operandsToPromote.end()); 183 if (failed(promoteSubviewsLinalgOpPrecondition(op, set))) 184 return failure(); 185 } 186 187 llvm::SetVector<Value> subViews; 188 if (!operandsToPromote.empty()) { 189 for (unsigned idx : operandsToPromote) { 190 auto *op = linalgOp.getBuffer(idx).getDefiningOp(); 191 if (auto sv = dyn_cast_or_null<SubViewOp>(op)) 192 subViews.insert(sv); 193 } 194 } else { 195 unsigned nBuffers = linalgOp.getNumInputsAndOutputBuffers(); 196 for (unsigned idx = 0; idx < nBuffers; ++idx) { 197 auto *op = linalgOp.getBuffer(idx).getDefiningOp(); 198 if (auto sv = dyn_cast_or_null<SubViewOp>(op)) 199 subViews.insert(sv); 200 } 201 } 202 203 auto promotedOp = 204 promoteSubViewOperands(rewriter, op, subViews, /*dynamicBuffers=*/false, 205 /*alignment=*/alignment); 206 marker.replaceLinalgMarker(rewriter, promotedOp.getOperation()); 207 rewriter.eraseOp(op); 208 return success(); 209 } 210 211 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern( 212 StringRef opName, MLIRContext *context, LinalgMarker marker, 213 PatternBenefit benefit) 214 : RewritePattern(opName, {}, benefit, context), marker(marker) {} 215 216 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( 217 Operation *op, PatternRewriter &rewriter) const { 218 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 219 if (!linalgOp) 220 return failure(); 221 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 222 return failure(); 223 if (failed(vectorizeLinalgOpPrecondition(op))) 224 return failure(); 225 vectorizeLinalgOp(rewriter, op); 226 rewriter.eraseOp(op); 227 return success(); 228 } 229