//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<Identifier> matchDisjunction,
    Optional<Identifier> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no filter case and matchDisjunction is empty.
    if (matchDisjunction.empty())
      return success();

    // 2. Has no filter but was expecting a filter.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}
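
// Example usage (a minimal sketch; the "tile" / "tiled" marker names and
// `ctx` are hypothetical): match only ops tagged "tile" and re-tag rewritten
// ops as "tiled" so the same pattern does not fire on them again.
//
//   LinalgTransformationFilter filter(Identifier::get("tile", ctx),
//                                     Identifier::get("tiled", ctx));
//   // Inside a pattern:
//   if (failed(filter.checkAndNotify(rewriter, op)))
//     return failure();
//   ... rewrite ...
//   filter.replaceLinalgTransformationFilter(rewriter, newOp);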

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                rewriter.getStringAttr(replacement.getValue().strref()));
  else
    op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
                                   rewriter.getContext()));
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}
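
// Example (sketch): configure tiling either with fixed sizes, materialized as
// constants at the entry of the enclosing function, or by tiling only the
// dynamic dimensions by 1. The two setters are mutually exclusive.
//
//   LinalgTilingOptions options;
//   options.setTileSizes({16, 32, 64});
//   // or: options.scalarizeDynamicDims();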

/// Helper function that tries to pad `opOperand`. Exit early and return
/// success for scalar operands or if `paddingFunc` returns failure. Otherwise,
/// try to pad the operand even if it already has a static shape. Set `result`
/// to the result of the created PadTensorOp or return failure if the operand
/// cannot be padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc, Value &result) {
  // Can't pad scalars.
  if (opToPad.getShape(opOperand).empty())
    return success();
  // Can't pad if no padding value is known.
  FailureOr<Value> paddingValue = paddingFunc(rewriter, *opOperand);
  if (failed(paddingValue))
    return success();
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  // Not a slice op, cannot construct a static bounding box.
  if (!sliceOp)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    auto indexAttr = size.is<Attribute>()
                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
                         : linalg::getSmallestBoundingIndex(size.get<Value>());
    // SmallestBoundingIndex must exist for all sizes.
    // For now return an error if we can't find it.
    if (!indexAttr)
      return rewriter.notifyMatchFailure(
          opToPad, "No constant bounding box can be found for padding");
    staticSizes.push_back(indexAttr.getInt());
  }
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  result = linalg::PadTensorOp::createPadHighOp(
      staticTensorType, opOperand->get(), paddingValue.getValue(),
      /*nofold=*/true, opToPad->getLoc(), rewriter);
  return success();
}

LogicalResult
linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(rewriter);
  // Set IP after op because we also take the dims of the original output.
  rewriter.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            rewriter, opToPad, opOperand, paddingFunc, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(rewriter, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (auto en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(v);
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubviewResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  rewriter.replaceOp(opToPad, paddedSubviewResults);
  return success();
}
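
// Example padding callback (a sketch; `zeroPad` is hypothetical): pad every
// tensor operand with a zero of its element type. Returning failure() from
// the callback makes the helper above skip padding that operand.
//
//   PaddingValueComputationFunction zeroPad =
//       [](OpBuilder &b, OpOperand &opOperand) -> FailureOr<Value> {
//         Type t = getElementTypeOrSelf(opOperand.get());
//         return b
//             .create<arith::ConstantOp>(opOperand.getOwner()->getLoc(),
//                                        b.getZeroAttr(t))
//             .getResult();
//       };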

/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Try to peel a TiledLoopOp and return the new result.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
                                      TiledLoopOp tiledLoop, int64_t idx) {
  assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
         "requested peeling of non-existing loop");
  TiledLoopOp result;
  if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
    return result->getResults();
  assert(!result && "expected that loop was not peeled");
  return tiledLoop->getResults();
}

/// Peel loops after tiling.
static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
                      const LinalgTilingOptions &options) {
  for (int64_t loop : options.peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    if (options.loopType == LinalgTilingLoopType::TiledLoops) {
      assert(llvm::all_of(
                 res.loops,
                 [&](Operation *op) { return op == res.loops.front(); }) &&
             "expected that all loop ops are the same TiledLoopOp");
      auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
      assert(tiledLoopOp && "expected TiledLoopOp");
      loopResults = peelLoop(rewriter, tiledLoopOp, loop);
    } else {
      loopResults = peelLoop(rewriter, loopOp);
    }

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}
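
// Example (sketch, assuming a `setPeeledLoops` setter on the options): tile
// with fixed sizes and peel the two outermost tiled loops so that the main
// loops run on full tiles only.
//
//   LinalgTilingOptions options =
//       LinalgTilingOptions().setTileSizes({8, 16}).setPeeledLoops({0, 1});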

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();
  // Clear filter to stop recursive pattern application.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel loops.
  peelLoops(rewriter, *res, options);

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics()) {
    result = *res;
    return success();
  }

  // Try to pad on the fly by rewriting res->op as a padded op. If successful,
  // `res.op` is rewritten in static form with padded operands.
  LinalgOp paddedOp;
  if (succeeded(rewriteAsPaddedOp(rewriter, res->op,
                                  options.paddingValueComputationFunction,
                                  paddedOp))) {
    filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
    res->op = paddedOp;
    result = *res;
    // Do not perform replacement of `linalgOp`, let the derived patterns
    // do this as they see fit, from the resulting TiledLinalgOp.
    return success();
  }
  // Set so RAII guard does not propagate TiledLinalgOp to `result`.
  return failure();
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}
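
// Example (sketch, assuming the `setIndicesToFuse` helper on the options):
// fuse the producer feeding operand 0 of the matched op and leave all other
// operands alone.
//
//   LinalgFusionOptions fusionOptions;
//   fusionOptions.setIndicesToFuse({0});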

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into, indexingOp is always an OpOperand. We
    // could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}
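
// Example (sketch; `ctx` and `funcOp` are hypothetical): interchangeVector =
// {1, 0} swaps the two loops of a matching two-loop generic op.
//
//   RewritePatternSet patterns(ctx);
//   patterns.add<GenericOpInterchangePattern>(
//       ctx, /*interchangeVector=*/ArrayRef<unsigned>{1, 0});
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));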

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // New filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}

/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(generalizeNamedOpPrecondition(op)))
    return failure();

  GenericOp genericOp = generalizeNamedOp(rewriter, op);
  rewriter.replaceOp(op, genericOp.getResults());
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return success();
}
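
// Example (sketch; `ctx` and `funcOp` are hypothetical): rewrite every
// linalg.matmul in a function into its linalg.generic form.
//
//   RewritePatternSet patterns(ctx);
//   patterns.add<LinalgGeneralizationPattern>("linalg.matmul", ctx);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));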

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 1st stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}
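
// Example (sketch; the pattern sets and `funcOp` are hypothetical): each
// stage-1 set is applied in order, each application is followed by the
// stage-2 set, and the optional stage-3 callback runs last per iteration.
//
//   SmallVector<FrozenRewritePatternSet, 4> stage1;
//   stage1.emplace_back(std::move(tilingPatterns));
//   stage1.emplace_back(std::move(vectorizationPatterns));
//   FrozenRewritePatternSet stage2(std::move(canonicalizationPatterns));
//   if (failed(applyStagedPatterns(funcOp, stage1, stage2)))
//     signalPassFailure();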

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
/// initialize with pad_val) and GenericOp (to copy contents).
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {

  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Fill `dest` using a FillOp with the constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // The fill could not be optimized: lower to a tensor::GenerateOp with a
  // region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite linalg::YieldOp to tensor::YieldOp.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp = dyn_cast<linalg::YieldOp>(
      generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1);
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}
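
// Example (sketch): an `optimizeCopyFn` that never claims the copy, forcing
// the generic tensor::InsertSliceOp path in the pattern below. A real
// callback would emit its own copy of `padOp.source()` into `dest` and
// return success().
//
//   auto noCopyOpt = [](PatternRewriter &rewriter, PadTensorOp padOp,
//                       Value dest) -> LogicalResult { return failure(); };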

LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the size of the InitTensorOp. Any combination of static/dynamic is
  // supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy could not be optimized. Generate an InsertSliceOp instead to copy
  // the PadTensorOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the source of the PadTensorOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
  if (!padOp)
    return failure();
  // Only unit stride is supported.
  if (!sliceOp.hasUnitStride())
    return failure();

  Operation *tiledPadOp = padOp.getTiledImplementation(
      rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
      sliceOp.getMixedSizes());
  // All shapes are static and the data source is actually used. Rewrite into
  // pad_tensor(subtensor(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}