//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<Identifier> matchDisjunction,
    Optional<Identifier> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. The op has no marker and no filter was expected: match.
    if (matchDisjunction.empty())
      return success();

    // 2. The op has no marker but one was expected: fail.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. The marker matches one of the expected filters: match.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. The marker matches none of the expected filters: fail.
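  // E.g., with matchDisjunction = {"tile"}, an op whose
  // `__internal_linalg_transform__` marker is "vectorize" ends up here.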
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                rewriter.getStringAttr(replacement.getValue().strref()));
  else
    op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
                                   rewriter.getContext()));
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<ConstantIndexOp>(loc, 0)
                              : b.create<ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}

/// Try to compute a static bounding box for `operand`. The padding happens
/// even if the operand already has a static shape. `result` is the result of a
/// freshly created PadTensorOp. Return failure if the operand cannot be padded
/// to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc, Value &result) {
  // Can't pad scalars.
  if (opToPad.getShape(opOperand).empty())
    return success();
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  // Not a slice op, cannot construct a static bounding box.
  if (!sliceOp)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    auto indexAttr = size.is<Attribute>()
                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
                         : linalg::getSmallestBoundingIndex(size.get<Value>());
    // SmallestBoundingIndex must exist for all sizes.
    // For now return an error if we can't find it.
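    // A static size carries its bound directly as the IntegerAttr above; only
    // dynamic sizes go through getSmallestBoundingIndex and may fail to yield
    // a constant bound.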
    if (!indexAttr)
      return rewriter.notifyMatchFailure(
          opToPad, "No constant bounding box can be found for padding");
    staticSizes.push_back(indexAttr.getInt());
  }
  Value pad = paddingFunc(rewriter, *opOperand);
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  result = linalg::PadTensorOp::createPadHighOp(
      staticTensorType, opOperand->get(), pad, /*packing=*/true,
      opToPad->getLoc(), rewriter);
  return success();
}

LogicalResult
linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(rewriter);
  // Set the insertion point after `opToPad` because we also take the dims of
  // the original output.
  rewriter.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update them.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically, the
    // pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            rewriter, opToPad, opOperand, paddingFunc, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(rewriter, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (auto en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(v);
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubviewResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  rewriter.replaceOp(opToPad, paddedSubviewResults);
  return success();
}

/// Linalg base tiling pattern.
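///
/// A minimal usage sketch (the op type, tile sizes, and "tile" marker are
/// illustrative; `ctx` and `funcOp` are assumed to be in scope):
///
///   RewritePatternSet patterns(ctx);
///   patterns.add<LinalgTilingPattern<linalg::MatmulOp>>(
///       ctx, LinalgTilingOptions().setTileSizes({8, 16, 32}),
///       LinalgTransformationFilter(ArrayRef<Identifier>{},
///                                  Identifier::get("tile", ctx)));
///   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));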
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Try to peel a TiledLoopOp and return the new result.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
                                      TiledLoopOp tiledLoop, int64_t idx) {
  assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
         "requested peeling of non-existing loop");
  TiledLoopOp result;
  if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
    return result->getResults();
  assert(!result && "expected that loop was not peeled");
  return tiledLoop->getResults();
}

/// Peel loops after tiling.
static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
                      const LinalgTilingOptions &options) {
  for (int64_t loop : options.peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    if (options.loopType == LinalgTilingLoopType::TiledLoops) {
      assert(llvm::all_of(
                 res.loops,
                 [&](Operation *op) { return op == res.loops.front(); }) &&
             "expected that all loop ops are the same TiledLoopOp");
      auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
      assert(tiledLoopOp && "expected TiledLoopOp");
      loopResults = peelLoop(rewriter, tiledLoopOp, loop);
    } else {
      loopResults = peelLoop(rewriter, loopOp);
    }

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();
  // Clear the filter to stop recursive pattern application.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel the loops requested in `options.peeledLoops`.
  peelLoops(rewriter, *res, options);

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics()) {
    result = *res;
    return success();
  }

  // Try to pad on the fly by rewriting res->op as a padded op. If successful,
  // `res.op` is rewritten in static form with padded operands.
  LinalgOp paddedOp;
  if (succeeded(rewriteAsPaddedOp(rewriter, res->op,
                                  options.paddingValueComputationFunction,
                                  paddedOp))) {
    filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
    res->op = paddedOp;
    result = *res;
    // Do not perform replacement of `linalgOp`; let the derived patterns
    // do this as they see fit, from the resulting TiledLinalgOp.
    return success();
  }
  // Padding failed: fail the match so that no TiledLinalgOp is propagated to
  // `result`.
  return failure();
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // For dependences into `linalgOp`, the indexing op is always an OpOperand,
    // so the operand number should exist. We could assert, but simply continue
    // if it does not.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<ConstantIndexOp>())
          return cst.getValue() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg generic interchange pattern.
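///
/// A minimal usage sketch (the permutation {1, 0, 2} is illustrative and
/// assumes 3-D iteration spaces; the filter and benefit parameters take their
/// defaults):
///
///   RewritePatternSet patterns(ctx);
///   patterns.add<GenericOpInterchangePattern>(
///       ctx, /*interchangeVector=*/ArrayRef<unsigned>{1, 0, 2});
///   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));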
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // Set the new filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
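  // Bracket the in-place promotion in start/finalizeRootUpdate so the rewriter
  // is notified of the mutation, and cancel the update if promotion fails.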
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 1st stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
/// initialize with the pad value) and GenericOp (to copy contents).
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {

  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  // 1. A BBarg from a different block.
  // 2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Fill `dest` using a FillOp if the padding value is a constant.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // The fill could not be optimized: lower to a tensor::GenerateOp with a
  // region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite the linalg::YieldOp to a tensor::YieldOp.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp =
      dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1);
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}

LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
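  // E.g., a static pad size arrives as an IntegerAttr and is materialized as a
  // ConstantIndexOp; a dynamic pad size is already a Value and is returned
  // as-is.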
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the size of the InitTensorOp. Any combination of static/dynamic is
  // supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init the tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy could not be optimized. Generate a tensor::InsertSliceOp instead
  // to copy the PadTensorOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the PadTensorOp source.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
  if (!padOp)
    return failure();
  // Only unit stride is supported.
  if (!sliceOp.hasUnitStride())
    return failure();

  Operation *tiledPadOp = padOp.getTiledImplementation(
      rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
      sliceOp.getMixedSizes());
  // All shapes are static and the data source is actually used. Rewrite into
  // pad_tensor(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}