1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic and helpers to expose Linalg transforms as rewrite 10 // patterns. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/Affine/Utils.h" 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 18 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 19 #include "mlir/Dialect/Linalg/Utils/Utils.h" 20 #include "mlir/Dialect/SCF/Transforms.h" 21 #include "mlir/Dialect/Tensor/IR/Tensor.h" 22 #include "mlir/Dialect/Utils/StaticValueUtils.h" 23 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 24 #include "mlir/Dialect/Vector/VectorOps.h" 25 #include "mlir/IR/AffineExpr.h" 26 #include "mlir/IR/Matchers.h" 27 #include "mlir/Pass/Pass.h" 28 #include "mlir/Support/LLVM.h" 29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 30 #include "llvm/ADT/ScopeExit.h" 31 #include "llvm/ADT/TypeSwitch.h" 32 #include "llvm/Support/Debug.h" 33 #include "llvm/Support/raw_ostream.h" 34 #include <type_traits> 35 36 #define DEBUG_TYPE "linalg-transforms" 37 38 using namespace mlir; 39 using namespace mlir::linalg; 40 41 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") 42 43 //===----------------------------------------------------------------------===// 44 // Transformations exposed as rewrite patterns. 45 //===----------------------------------------------------------------------===// 46 // Marker used as attribute name in generated Linalg rewriting transformations. 
// Attribute name used to tag ops as they move through staged Linalg
// transformations; patterns match/replace this marker to sequence rewrites.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

/// Build a filter that matches any marker in `matchDisjunction` and, on
/// success, rewrites the marker to `replacement` (or removes it if None).
mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}

/// Same as above, with an additional predicate `f` that must succeed on the
/// candidate op for the filter to match.
mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<Identifier> matchDisjunction,
    Optional<Identifier> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {
  // A null FilterFunction is treated as "no extra predicate".
  if (f)
    filters.push_back(f);
}

/// Return success if `op` passes all predicate filters and carries (or is
/// allowed to lack) the transformation marker. On mismatch, emit a match
/// failure diagnostic through `rewriter`.
LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  // All user-provided predicates must succeed.
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Op has no marker and no marker was expected: match.
    if (matchDisjunction.empty())
      return success();

    // 2. Op has no marker but one was expected: fail with a diagnostic.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Op has a marker: match if it equals any expected marker.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Marker present but not in the expected list: fail with a diagnostic.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

/// Advance `op`'s transformation marker: set it to `replacement` when one was
/// provided, otherwise strip the marker entirely. Stripping/overwriting is
/// what prevents patterns from re-applying recursively.
void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                rewriter.getStringAttr(replacement.getValue().strref()));
  else
    op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
                                   rewriter.getContext()));
}

/// Record constant tile sizes `ts`. The stored closure materializes them as
/// arith.constant_index ops at the start of the enclosing function so the
/// constants dominate every potential tiling site.
LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  // Copy by value: the closure must own the sizes, not dangle on `ts`.
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

/// Configure tiling so that every dynamic loop dimension is tiled by 1
/// (scalarized) and every static dimension is left untiled (tile size 0).
LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    // Non-linalg ops get an empty tile-size vector (i.e. no tiling).
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    // Map operand shape sizes to loop-range sizes.
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}

/// Helper function that tries to pad `opOperand`. Exit early and return success
/// for scalar operands or if `paddingFunc` returns failure. Otherwise, try to
/// pad the operand even if it already has a static shape. Set `result` to the
/// result of the created PadTensorOp or return failure if the operand cannot be
/// padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc,
    const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
  // Can't pad scalars. Note: `result` is deliberately left null so the caller
  // falls back to the original operand.
  if (opToPad.getShape(opOperand).empty())
    return success();
  // Can't pad if no padding value is known; this is a benign skip, not an
  // error, so return success with `result` unset.
  FailureOr<Value> paddingValue = paddingFunc(rewriter, *opOperand);
  if (failed(paddingValue))
    return success();
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  // Not a slice op, cannot construct a static bounding box.
  if (!sliceOp)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  // Derive a static size for every dimension: either the size is already a
  // constant attribute, or a constant upper bound is computed for the SSA
  // value.
  for (auto size : shapedOp.getMixedSizes()) {
    auto indexAttr = size.is<Attribute>()
                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
                         : linalg::getSmallestBoundingIndex(size.get<Value>());
    // SmallestBoundingIndex must exist for all sizes.
    // For now return an error if we can't find it.
    if (!indexAttr)
      return rewriter.notifyMatchFailure(
          opToPad, "No constant bounding box can be found for padding");
    staticSizes.push_back(indexAttr.getInt());
  }
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
  result = linalg::PadTensorOp::createPadHighOp(
      staticTensorType, opOperand->get(), paddingValue.getValue(),
      /*nofold=*/nofold, opToPad->getLoc(), rewriter);
  return success();
}

/// Pad every shaped operand of `opToPad` to a static bounding box, clone the
/// op on the padded operands, and slice the original extents back out of the
/// padded results. On success, `paddedOp` is the cloned static op and
/// `opToPad` has been replaced by the extracted slices.
LogicalResult
linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          const PaddingNoFoldComputationFunction &nofoldFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(rewriter);
  // Set IP after op because we also take the dims of the original output.
  rewriter.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            rewriter, opToPad, opOperand, paddingFunc, nofoldFunc,
            paddedOperand)))
      return failure();
    // A null `paddedOperand` means "skipped" (scalar / no padding value);
    // keep the original operand in that case.
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  // Reify the dynamic sizes of the ORIGINAL results; they are needed below to
  // slice the padded results back to their true extents.
  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(rewriter, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (auto en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(v);
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubviewResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  rewriter.replaceOp(opToPad, paddedSubviewResults);
  return success();
}

/// Linalg base tiling pattern.
253 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( 254 StringRef opName, MLIRContext *context, LinalgTilingOptions options, 255 LinalgTransformationFilter filter, PatternBenefit benefit) 256 : RewritePattern(opName, benefit, context), filter(filter), 257 options(options) {} 258 259 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( 260 MLIRContext *context, LinalgTilingOptions options, 261 LinalgTransformationFilter filter, PatternBenefit benefit) 262 : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter), 263 options(options) {} 264 265 /// Try to peel a loop `op` and return the new result. 266 // TODO: Add support for scf.parallel and affine.for loops. 267 static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) { 268 return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op) 269 .Case<scf::ForOp>([&](scf::ForOp forOp) { 270 scf::ForOp partialIteration; 271 if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp, 272 partialIteration))) 273 return partialIteration->getResults(); 274 assert(!partialIteration && "expected that loop was not peeled"); 275 return forOp->getResults(); 276 }) 277 .Default([&](Operation *op) { return op->getResults(); }); 278 } 279 280 /// Try to peel a TiledLoopOp and return the new result. 281 static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, 282 TiledLoopOp tiledLoop, int64_t idx) { 283 assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) && 284 "requested peeling of non-existing loop"); 285 TiledLoopOp result; 286 if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result))) 287 return result->getResults(); 288 assert(!result && "expected that loop was not peeled"); 289 return tiledLoop->getResults(); 290 } 291 292 /// Peel loops after tiling. 
/// Peel each loop listed in `options.peeledLoops` out of the tiled loop nest
/// in `res`, updating `res.tensorResults` when the peeled loop was the one
/// producing them.
static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
                      const LinalgTilingOptions &options) {
  for (int64_t loop : options.peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    if (options.loopType == LinalgTilingLoopType::TiledLoops) {
      // With TiledLoops every entry of `res.loops` refers to the single
      // multi-dimensional linalg.tiled_loop op; peel dimension `loop` of it.
      assert(llvm::all_of(
                 res.loops,
                 [&](Operation *op) { return op == res.loops.front(); }) &&
             "expected that all loop ops are the same TiledLoopOp");
      auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
      assert(tiledLoopOp && "expected TiledLoopOp");
      loopResults = peelLoop(rewriter, tiledLoopOp, loop);
    } else {
      loopResults = peelLoop(rewriter, loopOp);
    }

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}

/// Tile a linalg op according to `options`, then optionally peel loops and
/// pad the tiled op to static shapes. On success `result` holds the tiled
/// (and possibly padded) op; replacement of the original op is left to
/// derived patterns.
LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();
  // Clear filter to stop recursive pattern application.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel loops.
  peelLoops(rewriter, *res, options);

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics()) {
    result = *res;
    return success();
  }

  // Try to pad on the fly by rewriting res->op as a padded op. If successful,
  // `res.op` is rewritten in static form with padded operands.
  LinalgOp paddedOp;
  if (succeeded(rewriteAsPaddedOp(
          rewriter, res->op, options.paddingValueComputationFunction,
          options.paddingNoFoldComputationFunction, paddedOp))) {
    filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
    res->op = paddedOp;
    result = *res;
    // Do not perform replacement of `linalgOp`, let the derived patterns
    // do this as they see fit, from the resulting TiledLinalgOp.
    return success();
  }
  // Set so RAII guard does not propagate TiledLinalgOp to `result`.
  return failure();
}

/// Return the values that replace the original op after tiling: the results
/// of the outermost generated loop, or of the tiled op itself if no loops
/// were created.
static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

/// Same as above for the tile-and-fuse result structure.
static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

/// Tile `op` and fuse the producers selected by `fusionOptions` into the
/// tiled loop nest, then tile any remaining (unfused) loops. The fused ops,
/// the tiled op, and the original producers each receive their respective
/// transformation markers.
LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Collect the producers feeding the operands selected for fusion.
  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into, indexingOp is always OpOperand. We
    // could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  // Gather the producers in block order (fusion expects program order), with
  // the root op last.
  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops. Loops already fused get tile size 0 (no further
  // tiling); the others keep their original tile size.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  // Re-mark all ops produced by the transformation so follow-up patterns can
  // find (or skip) them.
  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg generic interchange pattern.
473 mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern( 474 MLIRContext *context, ArrayRef<unsigned> interchangeVector, 475 LinalgTransformationFilter filter, PatternBenefit benefit) 476 : OpRewritePattern(context, benefit), filter(filter), 477 interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} 478 479 LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite( 480 GenericOp genericOp, PatternRewriter &rewriter) const { 481 if (failed(filter.checkAndNotify(rewriter, genericOp))) 482 return failure(); 483 if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector))) 484 return failure(); 485 486 // TODO: figure out how this interplays with named ops. In particular this 487 // should break the named op property. 488 rewriter.updateRootInPlace(genericOp, [&]() { 489 interchangeGenericOp(rewriter, genericOp, interchangeVector); 490 // New filter if specified. 491 filter.replaceLinalgTransformationFilter(rewriter, genericOp); 492 }); 493 return success(); 494 } 495 496 /// Linalg generalization pattern. 
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

/// Replace a named linalg op by its linalg.generic equivalent and forward the
/// transformation marker to the new op.
LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(generalizeNamedOpPrecondition(op)))
    return failure();

  GenericOp genericOp = generalizeNamedOp(rewriter, op);
  rewriter.replaceOp(op, genericOp.getResults());
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return success();
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

/// Promote the subviews used by `op` into local buffers per `options`. The
/// op is updated in place; on failure the root update is cancelled and an
/// error is emitted.
LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesnt seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

/// Vectorize a linalg op, replacing it with the vectorized results (or
/// erasing it when vectorization produces no results).
LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

/// Run each set in `stage1Patterns`, interleaving `stage2Patterns` and an
/// optional `stage3Lambda` after each stage-1 set. Fail as soon as any stage
/// fails to converge.
LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  // `iteration` only feeds the debug output; (void) silences -Wunused in
  // release builds where LLVM_DEBUG compiles away.
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

/// Return `nParallelLoops` copies of the "parallel" iterator-type string.
static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to initialize
/// with pad_val) and GenericOp (to copy contents).
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {

  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create tensor with the padded shape
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize tensor with the pad value
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy original contents into new tensor
  // Uses linalg.generic, but could be done with tensor.insert_slice
  // The output map shifts each dimension by the static low-padding amount so
  // the source lands at its padded offset.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Filling `dest` using FillOp constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  // Fast path: a constant padding value lowers to a simple linalg.fill.
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy region to new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite linalg::YieldOp to tensor::YieldOp.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp =
      dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1);
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}

/// Lower a PadTensorOp to InitTensorOp + fill (or generate) + copy of the
/// source. The copy is either produced by `optimizeCopyFn` or by a generic
/// tensor.insert_slice.
LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    // Attribute case: materialize the static value as a constant index.
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of InitTensorOp. Any combination of static/dynamic is
  // supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add low and high padding value.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try optimize the copy of source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // PadTensorOps cannot be optimized. Generate a InsertSliceOp instead
  // for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute size of source of PadTensorOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  // The source is inserted at the low-padding offset of the result tensor.
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

/// Swap an extract_slice of a pad_tensor into a pad of the sliced source,
/// i.e. rewrite subtensor(pad_tensor(x)) into pad_tensor(subtensor(x)).
LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
  if (!padOp)
    return failure();
  // Only unit stride supported.
  if (!sliceOp.hasUnitStride())
    return failure();

  Operation *tiledPadOp = padOp.getTiledImplementation(
      rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
      sliceOp.getMixedSizes());
  // All shapes are static and the data source is actually used. Rewrite into
  // pad_tensor(subtensor(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}