//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

// Debug stream used inside LLVM_DEBUG(...): llvm::dbgs() prefixed with
// "[linalg-transforms]: " (the DEBUG_TYPE tag above).
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
44 const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = 45 "__internal_linalg_transform__"; 46 47 mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<Identifier> matchDisjunction, 48 Optional<Identifier> replacement) 49 : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), 50 replacement(replacement) {} 51 52 LogicalResult 53 mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter, 54 Operation *op) const { 55 auto attr = op->template getAttrOfType<StringAttr>( 56 LinalgTransforms::kLinalgTransformMarker); 57 58 if (!attr) { 59 // 1. Has no marker case and matchDisjunction is empty. 60 if (matchDisjunction.empty()) 61 return success(); 62 63 // 2. Has no marker but was expecting a marker. 64 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 65 diag << " does not have any marker from list: "; 66 interleaveComma(matchDisjunction, diag); 67 }); 68 } 69 70 // 4. Match explicit marker. 71 for (auto marker : matchDisjunction) 72 if (attr.getValue() == marker) 73 return success(); 74 75 // 5. Fail to match. 
76 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 77 diag << " does not have any marker from list: "; 78 interleaveComma(matchDisjunction, diag); 79 }); 80 } 81 82 void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter, 83 Operation *op) const { 84 if (replacement.hasValue()) 85 op->setAttr(LinalgTransforms::kLinalgTransformMarker, 86 rewriter.getStringAttr(replacement.getValue())); 87 else 88 op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker, 89 rewriter.getContext())); 90 } 91 92 LinalgTilingOptions & 93 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 94 SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end()); 95 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 96 OpBuilder::InsertionGuard guard(b); 97 b.setInsertionPointToStart( 98 &op->getParentOfType<FuncOp>().getBody().front()); 99 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 100 Value v = b.create<ConstantIndexOp>(op->getLoc(), s); 101 return v; 102 })); 103 }; 104 return *this; 105 } 106 107 /// Linalg base tiling pattern. 108 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( 109 StringRef opName, MLIRContext *context, LinalgTilingOptions options, 110 LinalgMarker marker, PatternBenefit benefit) 111 : RewritePattern(opName, {}, benefit, context), marker(marker), 112 options(options) {} 113 114 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite( 115 Operation *op, PatternRewriter &rewriter) const { 116 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 117 if (!linalgOp) 118 return failure(); 119 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 120 return failure(); 121 122 Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options); 123 124 if (!res) 125 return failure(); 126 127 // New marker if specified. 
128 marker.replaceLinalgMarker(rewriter, res->op.getOperation()); 129 return success(); 130 } 131 132 mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern( 133 StringRef opName, MLIRContext *context, 134 const LinalgDependenceGraph &dependenceGraph, 135 LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions, 136 LinalgMarker marker, LinalgMarker fusedOpMarker, 137 LinalgMarker originalOpMarker, PatternBenefit benefit) 138 : RewritePattern(opName, {}, benefit, context), 139 dependenceGraph(dependenceGraph), tilingOptions(tilingOptions), 140 fusionOptions(fusionOptions), marker(marker), 141 fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {} 142 143 LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite( 144 Operation *op, PatternRewriter &rewriter) const { 145 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 146 if (!linalgOp) 147 return failure(); 148 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 149 return failure(); 150 if (!linalgOp.hasBufferSemantics()) 151 return failure(); 152 153 Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps( 154 rewriter, op, dependenceGraph, tilingOptions, fusionOptions); 155 if (!tiledAndFusedOps) 156 return failure(); 157 marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation()); 158 for (auto fusedOp : tiledAndFusedOps->fusedProducers) { 159 fusedOpMarker.replaceLinalgMarker(rewriter, fusedOp.getOperation()); 160 } 161 for (auto origProducerOp : tiledAndFusedOps->originalProducers) 162 originalOpMarker.replaceLinalgMarker(rewriter, 163 origProducerOp.getOperation()); 164 rewriter.updateRootInPlace( 165 op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); }); 166 return success(); 167 } 168 169 /// Linalg base interchange pattern. 
170 mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern( 171 StringRef opName, MLIRContext *context, 172 ArrayRef<unsigned> interchangeVector, LinalgMarker marker, 173 PatternBenefit benefit) 174 : RewritePattern(opName, {}, benefit, context), marker(marker), 175 interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} 176 177 LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite( 178 Operation *op, PatternRewriter &rewriter) const { 179 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 180 if (!linalgOp) 181 return failure(); 182 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 183 return failure(); 184 if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector))) 185 return failure(); 186 187 // TODO: figure out how this interplays with named ops. In particular this 188 // should break the named op property. 189 rewriter.updateRootInPlace(op, [&]() { 190 interchange(linalgOp, interchangeVector); 191 // New marker if specified. 192 marker.replaceLinalgMarker(rewriter, op); 193 }); 194 return success(); 195 } 196 197 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( 198 StringRef opName, MLIRContext *context, LinalgPromotionOptions options, 199 LinalgMarker marker, PatternBenefit benefit) 200 : RewritePattern(opName, {}, benefit, context), marker(marker), 201 options(options) {} 202 203 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite( 204 Operation *op, PatternRewriter &rewriter) const { 205 if (failed(marker.checkAndNotify(rewriter, op))) 206 return failure(); 207 if (failed(promoteSubviewsPrecondition(op, options))) 208 return failure(); 209 210 // TODO: We cannot use root update here. This pattern is creating other ops, 211 // so if the promotion fails, those need to be cleaned up, which doesnt seem 212 // to be happening here. So to fail properly, we should be cloning the op and 213 // deleting the previous op. This needs more investigation. 
214 rewriter.startRootUpdate(op); 215 Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options); 216 if (!promotedOp) { 217 rewriter.cancelRootUpdate(op); 218 return op->emitError("subview promotion failed"); 219 } 220 rewriter.finalizeRootUpdate(op); 221 marker.replaceLinalgMarker(rewriter, op); 222 return success(); 223 } 224 225 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern( 226 StringRef opName, MLIRContext *context, LinalgMarker marker, 227 PatternBenefit benefit) 228 : RewritePattern(opName, {}, benefit, context), marker(marker) {} 229 230 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( 231 Operation *op, PatternRewriter &rewriter) const { 232 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 233 if (!linalgOp) 234 return failure(); 235 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 236 return failure(); 237 if (failed(vectorizeLinalgOpPrecondition(op))) 238 return failure(); 239 vectorizeLinalgOp(rewriter, op); 240 rewriter.eraseOp(op); 241 return success(); 242 } 243 244 LogicalResult mlir::linalg::applyStagedPatterns( 245 Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns, 246 const OwningRewritePatternList &stage2Patterns, 247 function_ref<LogicalResult(Operation *)> stage3Lambda) { 248 unsigned iteration = 0; 249 (void)iteration; 250 for (const auto &patterns : stage1Patterns) { 251 LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n" 252 << *op); 253 if (failed(applyPatternsAndFoldGreedily(op, patterns))) { 254 LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge"); 255 return failure(); 256 } 257 LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n" 258 << *op); 259 if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) { 260 LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge"); 261 return failure(); 262 } 263 LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n" 264 << *op); 
265 if (stage3Lambda) { 266 if (failed(stage3Lambda(op))) 267 return failure(); 268 LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n" 269 << *op); 270 } 271 } 272 return success(); 273 } 274 275 /// Traverse `e` and return an AffineExpr where all occurrences of `dim` have 276 /// been replaced by either: 277 /// - `min` if `positivePath` is true when we reach an occurrence of `dim` 278 /// - `max` if `positivePath` is true when we reach an occurrence of `dim` 279 /// `positivePath` is negated each time we hit a multiplicative or divisive 280 /// binary op with a constant negative coefficient. 281 static AffineExpr substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min, 282 AffineExpr max, bool positivePath = true) { 283 if (e == dim) 284 return positivePath ? min : max; 285 if (auto bin = e.dyn_cast<AffineBinaryOpExpr>()) { 286 AffineExpr lhs = bin.getLHS(); 287 AffineExpr rhs = bin.getRHS(); 288 if (bin.getKind() == mlir::AffineExprKind::Add) 289 return substWithMin(lhs, dim, min, max, positivePath) + 290 substWithMin(rhs, dim, min, max, positivePath); 291 292 auto c1 = bin.getLHS().dyn_cast<AffineConstantExpr>(); 293 auto c2 = bin.getRHS().dyn_cast<AffineConstantExpr>(); 294 if (c1 && c1.getValue() < 0) 295 return getAffineBinaryOpExpr( 296 bin.getKind(), c1, substWithMin(rhs, dim, min, max, !positivePath)); 297 if (c2 && c2.getValue() < 0) 298 return getAffineBinaryOpExpr( 299 bin.getKind(), substWithMin(lhs, dim, min, max, !positivePath), c2); 300 return getAffineBinaryOpExpr( 301 bin.getKind(), substWithMin(lhs, dim, min, max, positivePath), 302 substWithMin(rhs, dim, min, max, positivePath)); 303 } 304 return e; 305 } 306 307 /// Given the `lbVal`, `ubVal` and `stepVal` of a loop, append `lbVal` and 308 /// `ubVal` to `dims` and `stepVal` to `symbols`. 309 /// Create new AffineDimExpr (`%lb` and `%ub`) and AffineSymbolExpr (`%step`) 310 /// with positions matching the newly appended values. 
Substitute occurrences of 311 /// `dimExpr` by either the min expression (i.e. `%lb`) or the max expression 312 /// (i.e. `%lb + %step * floordiv(%ub -1 - %lb, %step)`), depending on whether 313 /// the induction variable is used with a positive or negative coefficient. 314 static AffineExpr substituteLoopInExpr(AffineExpr expr, AffineExpr dimExpr, 315 Value lbVal, Value ubVal, Value stepVal, 316 SmallVectorImpl<Value> &dims, 317 SmallVectorImpl<Value> &symbols) { 318 MLIRContext *ctx = lbVal.getContext(); 319 AffineExpr lb = getAffineDimExpr(dims.size(), ctx); 320 dims.push_back(lbVal); 321 AffineExpr ub = getAffineDimExpr(dims.size(), ctx); 322 dims.push_back(ubVal); 323 AffineExpr step = getAffineSymbolExpr(symbols.size(), ctx); 324 symbols.push_back(stepVal); 325 LLVM_DEBUG(DBGS() << "Before: " << expr << "\n"); 326 AffineExpr ee = substWithMin(expr, dimExpr, lb, 327 lb + step * ((ub - 1) - lb).floorDiv(step)); 328 LLVM_DEBUG(DBGS() << "After: " << expr << "\n"); 329 return ee; 330 } 331 332 /// Traverse the `dims` and substitute known min or max expressions in place of 333 /// induction variables in `exprs`. 
334 static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims, 335 SmallVectorImpl<Value> &symbols) { 336 auto exprs = llvm::to_vector<4>(map.getResults()); 337 for (AffineExpr &expr : exprs) { 338 bool substituted = true; 339 while (substituted) { 340 substituted = false; 341 for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) { 342 Value dim = dims[dimIdx]; 343 AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext()); 344 LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n"); 345 AffineExpr substitutedExpr; 346 if (auto forOp = scf::getForInductionVarOwner(dim)) 347 substitutedExpr = substituteLoopInExpr( 348 expr, dimExpr, forOp.lowerBound(), forOp.upperBound(), 349 forOp.step(), dims, symbols); 350 351 if (auto parallelForOp = scf::getParallelForInductionVarOwner(dim)) 352 for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e; 353 ++idx) 354 substitutedExpr = substituteLoopInExpr( 355 expr, dimExpr, parallelForOp.lowerBound()[idx], 356 parallelForOp.upperBound()[idx], parallelForOp.step()[idx], 357 dims, symbols); 358 359 if (!substitutedExpr) 360 continue; 361 362 substituted = (substitutedExpr != expr); 363 expr = substitutedExpr; 364 } 365 } 366 367 // Cleanup and simplify the results. 368 // This needs to happen outside of the loop iterating on dims.size() since 369 // it modifies dims. 370 SmallVector<Value, 4> operands(dims.begin(), dims.end()); 371 operands.append(symbols.begin(), symbols.end()); 372 auto map = AffineMap::get(dims.size(), symbols.size(), exprs, 373 exprs.front().getContext()); 374 375 LLVM_DEBUG(DBGS() << "Map to simplify: " << map << "\n"); 376 377 // Pull in affine.apply operations and compose them fully into the 378 // result. 379 fullyComposeAffineMapAndOperands(&map, &operands); 380 canonicalizeMapAndOperands(&map, &operands); 381 map = simplifyAffineMap(map); 382 // Assign the results. 
383 exprs.assign(map.getResults().begin(), map.getResults().end()); 384 dims.assign(operands.begin(), operands.begin() + map.getNumDims()); 385 symbols.assign(operands.begin() + map.getNumDims(), operands.end()); 386 387 LLVM_DEBUG(DBGS() << "Map simplified: " << map << "\n"); 388 } 389 390 assert(!exprs.empty() && "Unexpected empty exprs"); 391 return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext()); 392 } 393 394 LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite( 395 AffineMinOp minOp, PatternRewriter &rewriter) const { 396 LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation() 397 << "\n"); 398 399 SmallVector<Value, 4> dims(minOp.getDimOperands()), 400 symbols(minOp.getSymbolOperands()); 401 AffineMap map = substitute(minOp.getAffineMap(), dims, symbols); 402 403 LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n"); 404 405 // Check whether any of the expressions, when subtracted from all other 406 // expressions, produces only >= 0 constants. If so, it is the min. 407 for (auto e : minOp.getAffineMap().getResults()) { 408 LLVM_DEBUG(DBGS() << "Candidate min: " << e << "\n"); 409 if (!e.isSymbolicOrConstant()) 410 continue; 411 412 auto isNonPositive = [](AffineExpr e) { 413 if (auto cst = e.dyn_cast<AffineConstantExpr>()) 414 return cst.getValue() < 0; 415 return true; 416 }; 417 418 // Build the subMap and check everything is statically known to be 419 // positive. 420 SmallVector<AffineExpr, 4> subExprs; 421 subExprs.reserve(map.getNumResults()); 422 for (auto ee : map.getResults()) 423 subExprs.push_back(ee - e); 424 MLIRContext *ctx = minOp.getContext(); 425 AffineMap subMap = simplifyAffineMap( 426 AffineMap::get(map.getNumDims(), map.getNumSymbols(), subExprs, ctx)); 427 LLVM_DEBUG(DBGS() << "simplified subMap: " << subMap << "\n"); 428 if (llvm::any_of(subMap.getResults(), isNonPositive)) 429 continue; 430 431 // Static min found. 
432 if (auto cst = e.dyn_cast<AffineConstantExpr>()) { 433 rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue()); 434 } else { 435 auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx); 436 SmallVector<Value, 4> resultOperands = dims; 437 resultOperands.append(symbols.begin(), symbols.end()); 438 canonicalizeMapAndOperands(&resultMap, &resultOperands); 439 resultMap = simplifyAffineMap(resultMap); 440 rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap, 441 resultOperands); 442 } 443 return success(); 444 } 445 446 return failure(); 447 } 448