1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic and helpers to expose Linalg transforms as rewrite 10 // patterns. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/Affine/Utils.h" 16 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 18 #include "mlir/Dialect/Linalg/Utils/Utils.h" 19 #include "mlir/Dialect/SCF/Transforms.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/Dialect/Utils/StaticValueUtils.h" 22 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 23 #include "mlir/Dialect/Vector/VectorOps.h" 24 #include "mlir/IR/AffineExpr.h" 25 #include "mlir/IR/Matchers.h" 26 #include "mlir/Pass/Pass.h" 27 #include "mlir/Support/LLVM.h" 28 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 29 #include "llvm/ADT/ScopeExit.h" 30 #include "llvm/ADT/TypeSwitch.h" 31 #include "llvm/Support/Debug.h" 32 #include "llvm/Support/raw_ostream.h" 33 #include <type_traits> 34 35 #define DEBUG_TYPE "linalg-transforms" 36 37 using namespace mlir; 38 using namespace mlir::linalg; 39 40 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") 41 42 //===----------------------------------------------------------------------===// 43 // Transformations exposed as rewrite patterns. 44 //===----------------------------------------------------------------------===// 45 // Marker used as attribute name in generated Linalg rewriting transformations. 46 const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = 47 "__internal_linalg_transform__"; 48 49 mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter( 50 ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement) 51 : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), 52 replacement(replacement) {} 53 54 mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter( 55 FilterFunction f, ArrayRef<Identifier> matchDisjunction, 56 Optional<Identifier> replacement) 57 : filters(), 58 matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), 59 replacement(replacement) { 60 if (f) 61 filters.push_back(f); 62 } 63 64 LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify( 65 PatternRewriter &rewriter, Operation *op) const { 66 if (llvm::any_of(filters, 67 [&](const FilterFunction &f) { return failed(f(op)); })) 68 return failure(); 69 70 auto attr = op->template getAttrOfType<StringAttr>( 71 LinalgTransforms::kLinalgTransformMarker); 72 73 if (!attr) { 74 // 1. Has no filter case and matchDisjunction is empty. 75 if (matchDisjunction.empty()) 76 return success(); 77 78 // 2. Has no filter but was expecting a filter. 79 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 80 diag << " does not have any filter from list: "; 81 interleaveComma(matchDisjunction, diag); 82 }); 83 } 84 85 // 4. Match explicit filter. 86 for (auto filter : matchDisjunction) 87 if (attr.getValue() == filter) 88 return success(); 89 90 // 5. Fail to match. 91 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 92 diag << " does not have any filter from list: "; 93 interleaveComma(matchDisjunction, diag); 94 }); 95 } 96 97 void mlir::linalg::LinalgTransformationFilter:: 98 replaceLinalgTransformationFilter(PatternRewriter &rewriter, 99 Operation *op) const { 100 if (replacement.hasValue()) 101 op->setAttr(LinalgTransforms::kLinalgTransformMarker, 102 rewriter.getStringAttr(replacement.getValue().strref())); 103 else 104 op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker, 105 rewriter.getContext())); 106 } 107 108 LinalgTilingOptions & 109 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 110 assert(!tileSizeComputationFunction && "tile sizes already set"); 111 SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end()); 112 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 113 OpBuilder::InsertionGuard guard(b); 114 b.setInsertionPointToStart( 115 &op->getParentOfType<FuncOp>().getBody().front()); 116 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 117 Value v = b.create<ConstantIndexOp>(op->getLoc(), s); 118 return v; 119 })); 120 }; 121 return *this; 122 } 123 124 LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() { 125 assert(!tileSizeComputationFunction && "tile sizes already set"); 126 tileSizeComputationFunction = [](OpBuilder &b, Operation *op) { 127 SmallVector<Value, 4> tileSizes; 128 auto linalgOp = dyn_cast<LinalgOp>(op); 129 if (!linalgOp) 130 return tileSizes; 131 Location loc = linalgOp.getLoc(); 132 auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc); 133 AffineMap map = linalgOp.getShapesToLoopsMap(); 134 if (!map) 135 return tileSizes; 136 auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes); 137 // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile 138 // size 0). 139 for (Value shapeSize : shapeSizes) 140 tileSizes.push_back(getConstantIntValue(shapeSize).hasValue() 141 ? b.create<ConstantIndexOp>(loc, 0) 142 : b.create<ConstantIndexOp>(loc, 1)); 143 return tileSizes; 144 }; 145 return *this; 146 } 147 148 /// Helper function that tries to pad `opOperand`. Exit early and return success 149 /// for scalar operands or if `paddingFunc` returns failure. Otherwise, try to 150 /// pad the operand even if it already has a static shape. Set `result` to the 151 /// result of the created PadTensorOp or return failure if the operand cannot be 152 /// padded to a static shape. 153 static LogicalResult padOperandToSmallestStaticBoundingBox( 154 PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand, 155 const PaddingValueComputationFunction &paddingFunc, Value &result) { 156 // Can't pad scalars. 157 if (opToPad.getShape(opOperand).empty()) 158 return success(); 159 // Can't pad if no padding value is known. 160 FailureOr<Value> paddingValue = paddingFunc(rewriter, *opOperand); 161 if (failed(paddingValue)) 162 return success(); 163 auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>(); 164 // Not a slice op, cannot construct a static bounding box. 165 if (!sliceOp) 166 return failure(); 167 SmallVector<int64_t> staticSizes; 168 staticSizes.reserve(opToPad.getRank(opOperand)); 169 auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation()); 170 for (auto size : shapedOp.getMixedSizes()) { 171 auto indexAttr = size.is<Attribute>() 172 ? size.get<Attribute>().dyn_cast<IntegerAttr>() 173 : linalg::getSmallestBoundingIndex(size.get<Value>()); 174 // SmallestBoundingIndex must exist for all sizes. 175 // For now return an error if we can't find it. 176 if (!indexAttr) 177 return rewriter.notifyMatchFailure( 178 opToPad, "No constant bounding box can be found for padding"); 179 staticSizes.push_back(indexAttr.getInt()); 180 } 181 auto staticTensorType = RankedTensorType::get( 182 staticSizes, getElementTypeOrSelf(opOperand->get())); 183 result = linalg::PadTensorOp::createPadHighOp( 184 staticTensorType, opOperand->get(), paddingValue.getValue(), 185 /*nofold=*/true, opToPad->getLoc(), rewriter); 186 return success(); 187 } 188 189 LogicalResult 190 linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad, 191 const PaddingValueComputationFunction &paddingFunc, 192 LinalgOp &paddedOp) { 193 Location loc = opToPad->getLoc(); 194 195 // TODO: there are cases where we may still want to pad to larger sizes. 196 assert(opToPad.hasTensorSemantics() && 197 "expected operation to have tensor semantics"); 198 199 OpBuilder::InsertionGuard g(rewriter); 200 // Set IP after op because we also take the dims of the original output. 201 rewriter.setInsertionPointAfter(opToPad); 202 // Make a copy of the shaped operands and update it. 203 SmallVector<Value> newOperands; 204 newOperands.reserve(opToPad.getNumInputsAndOutputs()); 205 for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) { 206 Value paddedOperand; 207 // If padding was requested but the shape cannot be bounded statically then 208 // the pattern fails to apply. 209 if (failed(padOperandToSmallestStaticBoundingBox( 210 rewriter, opToPad, opOperand, paddingFunc, paddedOperand))) 211 return failure(); 212 newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get()); 213 } 214 215 SmallVector<SmallVector<Value>> reifiedResultShapes; 216 if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation()) 217 .reifyResultShapes(rewriter, reifiedResultShapes))) 218 return failure(); 219 assert(reifiedResultShapes.size() == opToPad->getNumResults() && 220 "expected same number of results"); 221 222 // Clone `opToPad` to operate on the statically padded shapes. 223 auto resultTensorTypes = 224 ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes(); 225 paddedOp = opToPad.clone(rewriter, loc, resultTensorTypes, newOperands); 226 227 // Recover the slice out of the new static results. This keeps the original 228 // linalg op around because it uses the dims of the original results. 229 SmallVector<Value> paddedSubviewResults; 230 paddedSubviewResults.reserve(opToPad->getNumResults()); 231 for (auto en : llvm::enumerate(paddedOp->getResults())) { 232 Value paddedResult = en.value(); 233 int64_t resultNumber = en.index(); 234 int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank(); 235 SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0)); 236 SmallVector<OpFoldResult> sizes; 237 for (Value v : reifiedResultShapes[resultNumber]) 238 sizes.push_back(v); 239 SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1)); 240 paddedSubviewResults.push_back(rewriter.create<tensor::ExtractSliceOp>( 241 loc, paddedResult, offsets, sizes, strides)); 242 } 243 rewriter.replaceOp(opToPad, paddedSubviewResults); 244 return success(); 245 } 246 247 /// Linalg base tiling pattern. 248 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( 249 StringRef opName, MLIRContext *context, LinalgTilingOptions options, 250 LinalgTransformationFilter filter, PatternBenefit benefit) 251 : RewritePattern(opName, benefit, context), filter(filter), 252 options(options) {} 253 254 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( 255 MLIRContext *context, LinalgTilingOptions options, 256 LinalgTransformationFilter filter, PatternBenefit benefit) 257 : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter), 258 options(options) {} 259 260 /// Try to peel a loop `op` and return the new result. 261 // TODO: Add support for scf.parallel and affine.for loops. 262 static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) { 263 return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op) 264 .Case<scf::ForOp>([&](scf::ForOp forOp) { 265 scf::ForOp partialIteration; 266 if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp, 267 partialIteration))) 268 return partialIteration->getResults(); 269 assert(!partialIteration && "expected that loop was not peeled"); 270 return forOp->getResults(); 271 }) 272 .Default([&](Operation *op) { return op->getResults(); }); 273 } 274 275 /// Try to peel a TiledLoopOp and return the new result. 276 static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, 277 TiledLoopOp tiledLoop, int64_t idx) { 278 assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) && 279 "requested peeling of non-existing loop"); 280 TiledLoopOp result; 281 if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result))) 282 return result->getResults(); 283 assert(!result && "expected that loop was not peeled"); 284 return tiledLoop->getResults(); 285 } 286 287 /// Peel loops after tiling. 288 static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res, 289 const LinalgTilingOptions &options) { 290 for (int64_t loop : options.peeledLoops) { 291 assert(loop < static_cast<int64_t>(res.loops.size()) && 292 "requested peeling of non-existing loop"); 293 SmallVector<Value, 4> loopResults; 294 Operation *loopOp = res.loops[loop]; 295 if (options.loopType == LinalgTilingLoopType::TiledLoops) { 296 assert(llvm::all_of( 297 res.loops, 298 [&](Operation *op) { return op == res.loops.front(); }) && 299 "expected that all loop ops are the same TiledLoopOp"); 300 auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp); 301 assert(tiledLoopOp && "expected TiledLoopOp"); 302 loopResults = peelLoop(rewriter, tiledLoopOp, loop); 303 } else { 304 loopResults = peelLoop(rewriter, loopOp); 305 } 306 307 // The result of the loop nest may change with peeling. 308 if (res.tensorResults.size() == loopOp->getNumResults() && 309 std::equal(res.tensorResults.begin(), res.tensorResults.end(), 310 loopOp->getResults().begin())) 311 res.tensorResults = loopResults; 312 } 313 } 314 315 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase( 316 Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const { 317 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 318 if (!linalgOp) 319 return failure(); 320 if (failed(filter.checkAndNotify(rewriter, linalgOp))) 321 return failure(); 322 323 Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options); 324 325 if (!res) 326 return failure(); 327 // Clear filter to stop recursive pattern application. 328 filter.replaceLinalgTransformationFilter(rewriter, res->op); 329 330 // Peel loops. 331 peelLoops(rewriter, *res, options); 332 333 // Consider padding on the fly only if the op has tensor semantics. 334 if (!options.paddingValueComputationFunction || 335 !linalgOp.hasTensorSemantics()) { 336 result = *res; 337 return success(); 338 } 339 340 // Try to pad on the fly by rewriting res->op as a padded op. If successful, 341 // `res.op` is rewritten in static form with padded operands. 342 LinalgOp paddedOp; 343 if (succeeded(rewriteAsPaddedOp(rewriter, res->op, 344 options.paddingValueComputationFunction, 345 paddedOp))) { 346 filter.replaceLinalgTransformationFilter(rewriter, paddedOp); 347 res->op = paddedOp; 348 result = *res; 349 // Do not perform replacement of `linalgOp`, let the derived patterns 350 // do this as they see fit, from the resulting TiledLinalgOp. 351 return success(); 352 } 353 // Set so RAII guard does not propagate TiledLinalgOp to `result`. 354 return failure(); 355 } 356 357 static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) { 358 if (tiledOp.loops.empty()) 359 return tiledOp.op.getOperation()->getResults(); 360 return tiledOp.loops.front()->getResults(); 361 } 362 363 static ValueRange 364 getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) { 365 if (tiledAndFusedOp.fusedLoops.empty()) 366 return tiledAndFusedOp.op.getOperation()->getResults(); 367 return tiledAndFusedOp.fusedLoops.front()->getResults(); 368 } 369 370 mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern( 371 StringRef opName, MLIRContext *context, 372 const LinalgDependenceGraph &dependenceGraph, 373 LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions, 374 LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker, 375 LinalgTransformationFilter originalOpMarker, PatternBenefit benefit) 376 : RewritePattern(opName, benefit, context, {}), 377 dependenceGraph(dependenceGraph), tilingOptions(tilingOptions), 378 fusionOptions(fusionOptions), filter(filter), 379 fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {} 380 381 LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite( 382 Operation *op, PatternRewriter &rewriter) const { 383 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 384 // TODO: remove hasIndexSemantics check once index ops are supported. 385 if (!linalgOp || linalgOp.hasIndexSemantics()) 386 return failure(); 387 if (failed(filter.checkAndNotify(rewriter, linalgOp))) 388 return failure(); 389 390 DenseSet<Operation *> producers; 391 producers.insert(linalgOp); 392 for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) { 393 Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum(); 394 // When looking at dependences into, indexingOp is always OpOperand. We 395 // could assert, but continue if this is not the case. 396 if (!operandNumber) 397 continue; 398 if (!fusionOptions.indicesToFuse.count(operandNumber.getValue())) 399 continue; 400 if (isa<LinalgOp>(dependence.getDependentOp())) 401 producers.insert(dependence.getDependentOp()); 402 } 403 404 SmallVector<LinalgOp, 1> fusionOps; 405 for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie; 406 ++it) { 407 auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it)); 408 if (producerLinalgOp && producers.count(producerLinalgOp)) 409 fusionOps.push_back(producerLinalgOp); 410 } 411 fusionOps.push_back(linalgOp); 412 413 SmallVector<Value, 4> tileSizes = 414 tilingOptions.tileSizeComputationFunction(rewriter, op); 415 LinalgTilingOptions instanceTilingOptions = tilingOptions; 416 instanceTilingOptions.setTileSizes(tileSizes); 417 Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps( 418 rewriter, fusionOps, dependenceGraph, instanceTilingOptions); 419 if (!tiledAndFusedOps) 420 return failure(); 421 422 // Tile the unfused loops; 423 SmallVector<Value, 4> unfusedLoopTileSizes; 424 Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0); 425 for (auto tileSize : enumerate(tileSizes)) { 426 if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index())) 427 unfusedLoopTileSizes.push_back(zero); 428 else 429 unfusedLoopTileSizes.push_back(tileSize.value()); 430 } 431 // Tile the loop only if there is a non-zero tile size. 432 if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops()) 433 unfusedLoopTileSizes.resize(linalgOp.getNumLoops()); 434 if (llvm::any_of(unfusedLoopTileSizes, [](Value val) { 435 if (auto cst = val.getDefiningOp<ConstantIndexOp>()) 436 return cst.getValue() != 0; 437 return true; 438 })) { 439 LinalgTilingOptions unfusedTilingOptions = tilingOptions; 440 unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes); 441 Optional<TiledLinalgOp> unfusedTiledOp = 442 tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions); 443 if (!unfusedTiledOp) 444 return failure(); 445 rewriter.replaceOp(tiledAndFusedOps->op, 446 getTiledOpResult(unfusedTiledOp.getValue())); 447 tiledAndFusedOps->op = unfusedTiledOp->op; 448 } 449 op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue())); 450 451 filter.replaceLinalgTransformationFilter(rewriter, 452 tiledAndFusedOps->op.getOperation()); 453 for (auto fusedOp : tiledAndFusedOps->fusedProducers) { 454 fusedOpMarker.replaceLinalgTransformationFilter(rewriter, 455 fusedOp.getOperation()); 456 } 457 for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) { 458 originalOpMarker.replaceLinalgTransformationFilter( 459 rewriter, origProducerOp.getOperation()); 460 } 461 rewriter.updateRootInPlace(op, [&]() { 462 originalOpMarker.replaceLinalgTransformationFilter(rewriter, op); 463 }); 464 return success(); 465 } 466 467 /// Linalg generic interchange pattern. 468 mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern( 469 MLIRContext *context, ArrayRef<unsigned> interchangeVector, 470 LinalgTransformationFilter filter, PatternBenefit benefit) 471 : OpRewritePattern(context, benefit), filter(filter), 472 interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} 473 474 LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite( 475 GenericOp genericOp, PatternRewriter &rewriter) const { 476 if (failed(filter.checkAndNotify(rewriter, genericOp))) 477 return failure(); 478 if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector))) 479 return failure(); 480 481 // TODO: figure out how this interplays with named ops. In particular this 482 // should break the named op property. 483 rewriter.updateRootInPlace(genericOp, [&]() { 484 interchangeGenericOp(rewriter, genericOp, interchangeVector); 485 // New filter if specified. 486 filter.replaceLinalgTransformationFilter(rewriter, genericOp); 487 }); 488 return success(); 489 } 490 491 /// Linalg generalization pattern. 492 mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern( 493 MLIRContext *context, LinalgTransformationFilter filter, 494 PatternBenefit benefit) 495 : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {} 496 497 mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern( 498 StringRef opName, MLIRContext *context, LinalgTransformationFilter filter, 499 PatternBenefit benefit) 500 : RewritePattern(opName, benefit, context, {}), filter(filter) {} 501 502 LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite( 503 Operation *op, PatternRewriter &rewriter) const { 504 if (failed(filter.checkAndNotify(rewriter, op))) 505 return failure(); 506 if (failed(generalizeNamedOpPrecondition(op))) 507 return failure(); 508 509 GenericOp genericOp = generalizeNamedOp(rewriter, op); 510 rewriter.replaceOp(op, genericOp.getResults()); 511 filter.replaceLinalgTransformationFilter(rewriter, genericOp); 512 return success(); 513 } 514 515 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( 516 MLIRContext *context, LinalgTransformationFilter filter, 517 LinalgPromotionOptions options, PatternBenefit benefit) 518 : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter), 519 options(options) {} 520 521 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( 522 StringRef opName, MLIRContext *context, LinalgPromotionOptions options, 523 LinalgTransformationFilter filter, PatternBenefit benefit) 524 : RewritePattern(opName, benefit, context, {}), filter(filter), 525 options(options) {} 526 527 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite( 528 Operation *op, PatternRewriter &rewriter) const { 529 if (failed(filter.checkAndNotify(rewriter, op))) 530 return failure(); 531 if (failed(promoteSubviewsPrecondition(op, options))) 532 return failure(); 533 534 // TODO: We cannot use root update here. This pattern is creating other ops, 535 // so if the promotion fails, those need to be cleaned up, which doesnt seem 536 // to be happening here. So to fail properly, we should be cloning the op and 537 // deleting the previous op. This needs more investigation. 538 rewriter.startRootUpdate(op); 539 Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options); 540 if (!promotedOp) { 541 rewriter.cancelRootUpdate(op); 542 return op->emitError("subview promotion failed"); 543 } 544 rewriter.finalizeRootUpdate(op); 545 filter.replaceLinalgTransformationFilter(rewriter, op); 546 return success(); 547 } 548 549 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern( 550 MLIRContext *context, LinalgTransformationFilter filter, 551 PatternBenefit benefit) 552 : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {} 553 554 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern( 555 StringRef opName, MLIRContext *context, LinalgTransformationFilter filter, 556 PatternBenefit benefit) 557 : RewritePattern(opName, benefit, context, {}), filter(filter) {} 558 559 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( 560 Operation *op, PatternRewriter &rewriter) const { 561 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 562 if (!linalgOp) 563 return failure(); 564 if (failed(filter.checkAndNotify(rewriter, linalgOp))) 565 return failure(); 566 SmallVector<Value> newResults; 567 if (failed(vectorizeLinalgOp(rewriter, op, newResults))) 568 return failure(); 569 if (!newResults.empty()) 570 rewriter.replaceOp(op, newResults); 571 else 572 rewriter.eraseOp(op); 573 return success(); 574 } 575 576 LogicalResult mlir::linalg::applyStagedPatterns( 577 Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns, 578 const FrozenRewritePatternSet &stage2Patterns, 579 function_ref<LogicalResult(Operation *)> stage3Lambda) { 580 unsigned iteration = 0; 581 (void)iteration; 582 for (const auto &patterns : stage1Patterns) { 583 LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n" 584 << *op); 585 if (failed(applyPatternsAndFoldGreedily(op, patterns))) { 586 LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge"); 587 return failure(); 588 } 589 LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n" 590 << *op); 591 if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) { 592 LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge"); 593 return failure(); 594 } 595 LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n" 596 << *op); 597 if (stage3Lambda) { 598 if (failed(stage3Lambda(op))) 599 return failure(); 600 LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n" 601 << *op); 602 } 603 } 604 return success(); 605 } 606 607 static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) { 608 return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName()); 609 } 610 611 /// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to initialize 612 /// with pad_val) and GenericOp (to copy contents). 613 LogicalResult PadTensorOpTransformationPattern::matchAndRewrite( 614 linalg::PadTensorOp padOp, PatternRewriter &rewriter) const { 615 616 auto inputShapedType = padOp.source().getType().cast<ShapedType>(); 617 auto resultShapedType = padOp.result().getType().cast<ShapedType>(); 618 619 // Bail on non-static shapes. 620 if (!inputShapedType.hasStaticShape()) 621 return failure(); 622 if (!resultShapedType.hasStaticShape()) 623 return failure(); 624 625 // Only support padding with a constant for now, i.e. either: 626 // 1. A BBarg from a different block. 627 // 2. A value defined outside of the current block. 628 Block &block = padOp.region().front(); 629 auto yieldOp = cast<YieldOp>(block.getTerminator()); 630 assert(yieldOp.getNumOperands() == 1 && "expected single operand yield"); 631 Value padValue = yieldOp.values().front(); 632 Operation *definingOp = padValue.getDefiningOp(); 633 if (definingOp && definingOp->getBlock() == &block) 634 return failure(); 635 if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block) 636 return failure(); 637 638 // Create tensor with the padded shape 639 Location loc = padOp.getLoc(); 640 SmallVector<Value> indices(resultShapedType.getRank(), 641 rewriter.create<ConstantIndexOp>(loc, 0)); 642 Value initTensor = rewriter.create<InitTensorOp>( 643 loc, resultShapedType.getShape(), resultShapedType.getElementType()); 644 645 // Initialize tensor with the pad value 646 Value tmpTensor = 647 rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result(); 648 649 // Copy original contents into new tensor 650 // Uses linalg.generic, but could be done with tensor.insert_slice 651 SmallVector<AffineExpr, 4> outputExprs; 652 for (unsigned i = 0; i < resultShapedType.getRank(); ++i) { 653 outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) + 654 padOp.static_low()[i].cast<IntegerAttr>().getInt()); 655 } 656 657 SmallVector<AffineMap, 2> transferMaps = { 658 rewriter.getMultiDimIdentityMap(inputShapedType.getRank()), 659 AffineMap::get(resultShapedType.getRank(), 660 /*symbolCount=*/0, outputExprs, rewriter.getContext())}; 661 662 rewriter.replaceOpWithNewOp<linalg::GenericOp>( 663 padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps, 664 getNParallelLoopsAttrs(resultShapedType.getRank()), 665 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { 666 nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]); 667 }); 668 669 return success(); 670 } 671 672 /// Filling `dest` using FillOp constant padding value if possible. 673 /// Otherwise, generate a tensor::GenerateOp. 674 Value GeneralizePadTensorOpPattern::createFillOrGenerateOp( 675 PatternRewriter &rewriter, PadTensorOp padOp, Value dest, 676 const SmallVector<Value> &dynSizes) const { 677 auto padValue = padOp.getConstantPaddingValue(); 678 if (padValue) 679 return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result(); 680 681 // Fill could not be optimized: Lower to tensor::GenerateOp with region. 682 auto generateOp = rewriter.create<tensor::GenerateOp>( 683 padOp.getLoc(), padOp.getResultType(), dynSizes); 684 // Copy region to new op. 685 BlockAndValueMapping bvm; 686 padOp.region().cloneInto(&generateOp.getRegion(), bvm); 687 // Rewrite linalg::YieldOp to tensor::YieldOp. 688 OpBuilder::InsertionGuard guard(rewriter); 689 auto yieldOp = 690 dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator()); 691 assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator"); 692 assert(yieldOp.values().size() == 1); 693 rewriter.setInsertionPoint(yieldOp); 694 rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]); 695 return generateOp; 696 } 697 698 LogicalResult 699 GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp, 700 PatternRewriter &rewriter) const { 701 // Given an OpFoldResult, return an index-typed value. 702 auto getIdxValue = [&](OpFoldResult ofr) { 703 if (auto val = ofr.dyn_cast<Value>()) 704 return val; 705 return rewriter 706 .create<ConstantIndexOp>( 707 padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt()) 708 .getResult(); 709 }; 710 711 auto resultType = padOp.getResultType(); 712 // Compute size of InitTensorOp. Any combination of static/dynamic is 713 // supported. 714 SmallVector<Value> dynSizes; 715 SmallVector<int64_t> staticSizes; 716 for (unsigned dim = 0; dim < resultType.getRank(); ++dim) { 717 if (resultType.isDynamicDim(dim)) { 718 auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(), 719 padOp.source(), dim); 720 // Add low and high padding value. 721 auto plusLow = rewriter.createOrFold<AddIOp>( 722 padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim])); 723 auto plusHigh = rewriter.createOrFold<AddIOp>( 724 padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim])); 725 dynSizes.push_back(plusHigh); 726 } 727 staticSizes.push_back(resultType.getDimSize(dim)); 728 } 729 730 // Init tensor and fill it with padding. 731 Value init = rewriter.create<InitTensorOp>( 732 padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType()); 733 Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes); 734 735 // Try optimize the copy of source. 736 if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded()) 737 return success(); 738 739 // PadTensorOps cannot be optimized. Generate a InsertSliceOp instead 740 // for copying the PadOp source. 741 auto sourceType = padOp.getSourceType(); 742 // Compute size of source of PadTensorOp. 743 SmallVector<OpFoldResult> srcSizes; 744 for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) { 745 if (sourceType.isDynamicDim(dim)) { 746 srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>( 747 padOp.getLoc(), padOp.source(), dim)); 748 } else { 749 srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim))); 750 } 751 } 752 // Strides of InsertSliceOp are all 1. 753 SmallVector<OpFoldResult> strides(sourceType.getRank(), 754 rewriter.getIndexAttr(1)); 755 rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>( 756 padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides); 757 758 return success(); 759 } 760 761 LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite( 762 tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const { 763 auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>(); 764 if (!padOp) 765 return failure(); 766 // Only unit stride supported. 767 if (!sliceOp.hasUnitStride()) 768 return failure(); 769 770 Operation *tiledPadOp = padOp.getTiledImplementation( 771 rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(), 772 sliceOp.getMixedSizes()); 773 // All shapes are static and the data source is actually used. Rewrite into 774 // pad_tensor(subtensor(x)). 775 rewriter.replaceOp(sliceOp, tiledPadOp->getResults()); 776 return success(); 777 } 778