1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic and helpers to expose Linalg transforms as rewrite 10 // patterns. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/Affine/Utils.h" 16 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 18 #include "mlir/Dialect/Linalg/Utils/Utils.h" 19 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 20 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 21 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h" 22 #include "mlir/Dialect/Vector/VectorOps.h" 23 #include "mlir/IR/AffineExpr.h" 24 #include "mlir/IR/Matchers.h" 25 #include "mlir/Pass/Pass.h" 26 #include "mlir/Support/LLVM.h" 27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 28 #include "llvm/Support/Debug.h" 29 #include "llvm/Support/raw_ostream.h" 30 #include <type_traits> 31 32 #define DEBUG_TYPE "linalg-transforms" 33 34 using namespace mlir; 35 using namespace mlir::edsc; 36 using namespace mlir::edsc::intrinsics; 37 using namespace mlir::linalg; 38 39 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") 40 41 //===----------------------------------------------------------------------===// 42 // Transformations exposed as rewrite patterns. 43 //===----------------------------------------------------------------------===// 44 // Marker used as attribute name in generated Linalg rewriting transformations. 45 const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = 46 "__internal_linalg_transform__"; 47 48 mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<Identifier> matchDisjunction, 49 Optional<Identifier> replacement) 50 : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), 51 replacement(replacement) {} 52 53 LogicalResult 54 mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter, 55 Operation *op) const { 56 auto attr = op->template getAttrOfType<StringAttr>( 57 LinalgTransforms::kLinalgTransformMarker); 58 59 if (!attr) { 60 // 1. Has no marker case and matchDisjunction is empty. 61 if (matchDisjunction.empty()) 62 return success(); 63 64 // 2. Has no marker but was expecting a marker. 65 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 66 diag << " does not have any marker from list: "; 67 interleaveComma(matchDisjunction, diag); 68 }); 69 } 70 71 // 4. Match explicit marker. 72 for (auto marker : matchDisjunction) 73 if (attr.getValue() == marker) 74 return success(); 75 76 // 5. Fail to match. 77 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 78 diag << " does not have any marker from list: "; 79 interleaveComma(matchDisjunction, diag); 80 }); 81 } 82 83 void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter, 84 Operation *op) const { 85 if (replacement.hasValue()) 86 op->setAttr(LinalgTransforms::kLinalgTransformMarker, 87 rewriter.getStringAttr(replacement.getValue())); 88 else 89 op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker, 90 rewriter.getContext())); 91 } 92 93 LinalgTilingOptions & 94 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 95 SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end()); 96 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 97 OpBuilder::InsertionGuard guard(b); 98 b.setInsertionPointToStart( 99 &op->getParentOfType<FuncOp>().getBody().front()); 100 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 101 Value v = b.create<ConstantIndexOp>(op->getLoc(), s); 102 return v; 103 })); 104 }; 105 return *this; 106 } 107 108 /// Linalg base tiling pattern. 109 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( 110 StringRef opName, MLIRContext *context, LinalgTilingOptions options, 111 LinalgMarker marker, PatternBenefit benefit) 112 : RewritePattern(opName, {}, benefit, context), marker(marker), 113 options(options) {} 114 115 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( 116 LinalgTilingOptions options, LinalgMarker marker, PatternBenefit benefit) 117 : RewritePattern(benefit, MatchAnyOpTypeTag()), marker(marker), 118 options(options) {} 119 120 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase( 121 Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const { 122 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 123 if (!linalgOp) 124 return failure(); 125 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 126 return failure(); 127 128 Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options); 129 130 if (!res) 131 return failure(); 132 133 // Return relevant information to derived pattern. 134 result = *res; 135 136 // New marker if specified. 137 marker.replaceLinalgMarker(rewriter, res->op.getOperation()); 138 return success(); 139 } 140 141 mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern( 142 StringRef opName, MLIRContext *context, 143 const LinalgDependenceGraph &dependenceGraph, 144 LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions, 145 LinalgMarker marker, LinalgMarker fusedOpMarker, 146 LinalgMarker originalOpMarker, PatternBenefit benefit) 147 : RewritePattern(opName, {}, benefit, context), 148 dependenceGraph(dependenceGraph), tilingOptions(tilingOptions), 149 fusionOptions(fusionOptions), marker(marker), 150 fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {} 151 152 LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite( 153 Operation *op, PatternRewriter &rewriter) const { 154 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 155 if (!linalgOp) 156 return failure(); 157 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 158 return failure(); 159 if (!linalgOp.hasBufferSemantics()) 160 return failure(); 161 162 DenseSet<Operation *> producers; 163 producers.insert(linalgOp); 164 for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) { 165 Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum(); 166 // When looking at dependences into, indexingOp is always OpOperand. We 167 // could assert, but continue if this is not the case. 168 if (!operandNumber) 169 continue; 170 if (!fusionOptions.indicesToFuse.count(operandNumber.getValue())) 171 continue; 172 if (isa<LinalgOp>(dependence.getDependentOp())) 173 producers.insert(dependence.getDependentOp()); 174 } 175 176 SmallVector<LinalgOp, 1> fusionOps; 177 for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie; 178 ++it) { 179 auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it)); 180 if (producerLinalgOp && producers.count(producerLinalgOp)) 181 fusionOps.push_back(producerLinalgOp); 182 } 183 fusionOps.push_back(linalgOp); 184 185 SmallVector<Value, 4> tileSizes = 186 tilingOptions.tileSizeComputationFunction(rewriter, op); 187 LinalgTilingOptions instanceTilingOptions = tilingOptions; 188 instanceTilingOptions.setTileSizes(tileSizes); 189 Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps( 190 rewriter, fusionOps, dependenceGraph, instanceTilingOptions); 191 if (!tiledAndFusedOps) 192 return failure(); 193 194 // Tile the unfused loops; 195 SmallVector<Value, 4> unfusedLoopTileSizes; 196 Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0); 197 for (auto tileSize : enumerate(tileSizes)) { 198 if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index())) 199 unfusedLoopTileSizes.push_back(zero); 200 else 201 unfusedLoopTileSizes.push_back(tileSize.value()); 202 } 203 // Tile the loop only if there is a non-zero tile size. 204 if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops()) 205 unfusedLoopTileSizes.resize(linalgOp.getNumLoops()); 206 if (llvm::any_of(unfusedLoopTileSizes, [](Value val) { 207 if (auto cst = val.getDefiningOp<ConstantIndexOp>()) 208 return cst.getValue() != 0; 209 return true; 210 })) { 211 LinalgTilingOptions unfusedTilingOptions = tilingOptions; 212 unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes); 213 Optional<TiledLinalgOp> unfusedTiledOp = 214 tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions); 215 if (!unfusedTiledOp) 216 return failure(); 217 rewriter.eraseOp(tiledAndFusedOps->op); 218 tiledAndFusedOps->op = unfusedTiledOp->op; 219 } 220 221 marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation()); 222 for (auto fusedOp : tiledAndFusedOps->fusedProducers) { 223 fusedOpMarker.replaceLinalgMarker(rewriter, fusedOp.getOperation()); 224 } 225 for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) { 226 originalOpMarker.replaceLinalgMarker(rewriter, 227 origProducerOp.getOperation()); 228 } 229 rewriter.updateRootInPlace( 230 op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); }); 231 return success(); 232 } 233 234 /// Linalg base interchange pattern. 235 mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern( 236 StringRef opName, MLIRContext *context, 237 ArrayRef<unsigned> interchangeVector, LinalgMarker marker, 238 PatternBenefit benefit) 239 : RewritePattern(opName, {}, benefit, context), marker(marker), 240 interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} 241 242 LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite( 243 Operation *op, PatternRewriter &rewriter) const { 244 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 245 if (!linalgOp) 246 return failure(); 247 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 248 return failure(); 249 if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector))) 250 return failure(); 251 252 // TODO: figure out how this interplays with named ops. In particular this 253 // should break the named op property. 254 rewriter.updateRootInPlace(op, [&]() { 255 interchange(linalgOp, interchangeVector); 256 // New marker if specified. 257 marker.replaceLinalgMarker(rewriter, op); 258 }); 259 return success(); 260 } 261 262 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( 263 StringRef opName, MLIRContext *context, LinalgPromotionOptions options, 264 LinalgMarker marker, PatternBenefit benefit) 265 : RewritePattern(opName, {}, benefit, context), marker(marker), 266 options(options) {} 267 268 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite( 269 Operation *op, PatternRewriter &rewriter) const { 270 if (failed(marker.checkAndNotify(rewriter, op))) 271 return failure(); 272 if (failed(promoteSubviewsPrecondition(op, options))) 273 return failure(); 274 275 // TODO: We cannot use root update here. This pattern is creating other ops, 276 // so if the promotion fails, those need to be cleaned up, which doesnt seem 277 // to be happening here. So to fail properly, we should be cloning the op and 278 // deleting the previous op. This needs more investigation. 279 rewriter.startRootUpdate(op); 280 Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options); 281 if (!promotedOp) { 282 rewriter.cancelRootUpdate(op); 283 return op->emitError("subview promotion failed"); 284 } 285 rewriter.finalizeRootUpdate(op); 286 marker.replaceLinalgMarker(rewriter, op); 287 return success(); 288 } 289 290 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern( 291 StringRef opName, MLIRContext *context, LinalgMarker marker, 292 PatternBenefit benefit) 293 : RewritePattern(opName, {}, benefit, context), marker(marker) {} 294 295 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( 296 Operation *op, PatternRewriter &rewriter) const { 297 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 298 if (!linalgOp) 299 return failure(); 300 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 301 return failure(); 302 if (failed(vectorizeLinalgOpPrecondition(op))) 303 return failure(); 304 vectorizeLinalgOp(rewriter, op); 305 rewriter.eraseOp(op); 306 return success(); 307 } 308 309 LogicalResult mlir::linalg::applyStagedPatterns( 310 Operation *op, ArrayRef<FrozenRewritePatternList> stage1Patterns, 311 const FrozenRewritePatternList &stage2Patterns, 312 function_ref<LogicalResult(Operation *)> stage3Lambda) { 313 unsigned iteration = 0; 314 (void)iteration; 315 for (const auto &patterns : stage1Patterns) { 316 LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n" 317 << *op); 318 if (failed(applyPatternsAndFoldGreedily(op, patterns))) { 319 LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge"); 320 return failure(); 321 } 322 LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n" 323 << *op); 324 if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) { 325 LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge"); 326 return failure(); 327 } 328 LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n" 329 << *op); 330 if (stage3Lambda) { 331 if (failed(stage3Lambda(op))) 332 return failure(); 333 LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n" 334 << *op); 335 } 336 } 337 return success(); 338 } 339 340 /// Given the `lbVal`, `ubVal` and `stepVal` of a loop, append `lbVal` and 341 /// `ubVal` to `dims` and `stepVal` to `symbols`. 342 /// Create new AffineDimExpr (`%lb` and `%ub`) and AffineSymbolExpr (`%step`) 343 /// with positions matching the newly appended values. Substitute occurrences of 344 /// `dimExpr` by either the min expression (i.e. `%lb`) or the max expression 345 /// (i.e. `%lb + %step * floordiv(%ub -1 - %lb, %step)`), depending on whether 346 /// the induction variable is used with a positive or negative coefficient. 347 static AffineExpr substituteLoopInExpr(AffineExpr expr, AffineExpr dimExpr, 348 Value lbVal, Value ubVal, Value stepVal, 349 SmallVectorImpl<Value> &dims, 350 SmallVectorImpl<Value> &symbols) { 351 MLIRContext *ctx = lbVal.getContext(); 352 AffineExpr lb = getAffineDimExpr(dims.size(), ctx); 353 dims.push_back(lbVal); 354 AffineExpr ub = getAffineDimExpr(dims.size(), ctx); 355 dims.push_back(ubVal); 356 AffineExpr step = getAffineSymbolExpr(symbols.size(), ctx); 357 symbols.push_back(stepVal); 358 LLVM_DEBUG(DBGS() << "Before: " << expr << "\n"); 359 AffineExpr ee = substWithMin(expr, dimExpr, lb, 360 lb + step * ((ub - 1) - lb).floorDiv(step)); 361 LLVM_DEBUG(DBGS() << "After: " << expr << "\n"); 362 return ee; 363 } 364 365 /// Traverse the `dims` and substitute known min or max expressions in place of 366 /// induction variables in `exprs`. 367 static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims, 368 SmallVectorImpl<Value> &symbols) { 369 auto exprs = llvm::to_vector<4>(map.getResults()); 370 for (AffineExpr &expr : exprs) { 371 bool substituted = true; 372 while (substituted) { 373 substituted = false; 374 for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) { 375 Value dim = dims[dimIdx]; 376 AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext()); 377 LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n"); 378 AffineExpr substitutedExpr; 379 if (auto forOp = scf::getForInductionVarOwner(dim)) 380 substitutedExpr = substituteLoopInExpr( 381 expr, dimExpr, forOp.lowerBound(), forOp.upperBound(), 382 forOp.step(), dims, symbols); 383 384 if (auto parallelForOp = scf::getParallelForInductionVarOwner(dim)) 385 for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e; 386 ++idx) 387 substitutedExpr = substituteLoopInExpr( 388 expr, dimExpr, parallelForOp.lowerBound()[idx], 389 parallelForOp.upperBound()[idx], parallelForOp.step()[idx], 390 dims, symbols); 391 392 if (!substitutedExpr) 393 continue; 394 395 substituted = (substitutedExpr != expr); 396 expr = substitutedExpr; 397 } 398 } 399 400 // Cleanup and simplify the results. 401 // This needs to happen outside of the loop iterating on dims.size() since 402 // it modifies dims. 403 SmallVector<Value, 4> operands(dims.begin(), dims.end()); 404 operands.append(symbols.begin(), symbols.end()); 405 auto map = AffineMap::get(dims.size(), symbols.size(), exprs, 406 exprs.front().getContext()); 407 408 LLVM_DEBUG(DBGS() << "Map to simplify: " << map << "\n"); 409 410 // Pull in affine.apply operations and compose them fully into the 411 // result. 412 fullyComposeAffineMapAndOperands(&map, &operands); 413 canonicalizeMapAndOperands(&map, &operands); 414 map = simplifyAffineMap(map); 415 // Assign the results. 416 exprs.assign(map.getResults().begin(), map.getResults().end()); 417 dims.assign(operands.begin(), operands.begin() + map.getNumDims()); 418 symbols.assign(operands.begin() + map.getNumDims(), operands.end()); 419 420 LLVM_DEBUG(DBGS() << "Map simplified: " << map << "\n"); 421 } 422 423 assert(!exprs.empty() && "Unexpected empty exprs"); 424 return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext()); 425 } 426 427 LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite( 428 AffineMinOp minOp, PatternRewriter &rewriter) const { 429 LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation() 430 << "\n"); 431 432 SmallVector<Value, 4> dims(minOp.getDimOperands()), 433 symbols(minOp.getSymbolOperands()); 434 AffineMap map = substitute(minOp.getAffineMap(), dims, symbols); 435 436 LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n"); 437 438 // Check whether any of the expressions, when subtracted from all other 439 // expressions, produces only >= 0 constants. If so, it is the min. 440 for (auto e : minOp.getAffineMap().getResults()) { 441 LLVM_DEBUG(DBGS() << "Candidate min: " << e << "\n"); 442 if (!e.isSymbolicOrConstant()) 443 continue; 444 445 auto isNonPositive = [](AffineExpr e) { 446 if (auto cst = e.dyn_cast<AffineConstantExpr>()) 447 return cst.getValue() < 0; 448 return true; 449 }; 450 451 // Build the subMap and check everything is statically known to be 452 // positive. 453 SmallVector<AffineExpr, 4> subExprs; 454 subExprs.reserve(map.getNumResults()); 455 for (auto ee : map.getResults()) 456 subExprs.push_back(ee - e); 457 MLIRContext *ctx = minOp.getContext(); 458 AffineMap subMap = simplifyAffineMap( 459 AffineMap::get(map.getNumDims(), map.getNumSymbols(), subExprs, ctx)); 460 LLVM_DEBUG(DBGS() << "simplified subMap: " << subMap << "\n"); 461 if (llvm::any_of(subMap.getResults(), isNonPositive)) 462 continue; 463 464 // Static min found. 465 if (auto cst = e.dyn_cast<AffineConstantExpr>()) { 466 rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue()); 467 } else { 468 auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx); 469 SmallVector<Value, 4> resultOperands = dims; 470 resultOperands.append(symbols.begin(), symbols.end()); 471 canonicalizeMapAndOperands(&resultMap, &resultOperands); 472 resultMap = simplifyAffineMap(resultMap); 473 rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap, 474 resultOperands); 475 } 476 return success(); 477 } 478 479 return failure(); 480 } 481