//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<Identifier> matchDisjunction,
    Optional<Identifier> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no marker and matchDisjunction is empty: always match.
    if (matchDisjunction.empty())
      return success();

    // 2. Has no marker but was expecting one: fail to match.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match an explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}
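
// A minimal usage sketch of the filter mechanism (hedged: `ctx` and the
// marker names below are hypothetical, chosen for illustration only). A
// staged pipeline can chain markers so that each stage only fires on ops
// produced by the previous one:
//
//   Identifier tileMarker = Identifier::get("tile", ctx);
//   Identifier vectorizeMarker = Identifier::get("vectorize", ctx);
//   // Matches ops marked "tile"; on success, re-marks them "vectorize".
//   LinalgTransformationFilter filter({tileMarker}, vectorizeMarker);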

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                rewriter.getStringAttr(replacement.getValue().strref()));
  else
    op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
                                   rewriter.getContext()));
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}
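
// A minimal configuration sketch (hypothetical tile sizes): the two
// initializers above are mutually exclusive, since each installs its own
// `tileSizeComputationFunction`.
//
//   LinalgTilingOptions options;
//   options.setTileSizes({8, 16, 32}); // constant tile sizes, or:
//   // options.scalarizeDynamicDims(); // tile only dynamic dims, by 1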

/// Helper function that tries to pad `opOperand`. Exit early and return
/// success for scalar operands or if `paddingFunc` returns failure. Otherwise,
/// try to pad the operand even if it already has a static shape. Set `result`
/// to the result of the created PadTensorOp or return failure if the operand
/// cannot be padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc,
    const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
  // Can't pad scalars.
  if (opToPad.getShape(opOperand).empty())
    return success();
  // Can't pad if no padding value is known.
  FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
  if (failed(paddingValue))
    return success();
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  // Not a slice op, cannot construct a static bounding box.
  if (!sliceOp)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    auto indexAttr = size.is<Attribute>()
                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
                         : linalg::getSmallestBoundingIndex(size.get<Value>());
    // The smallest bounding index must exist for all sizes. For now, return
    // failure if we cannot find it.
    if (!indexAttr) {
      LLVM_DEBUG(DBGS()
                 << "No constant bounding box can be found for padding\n");
      return failure();
    }
    staticSizes.push_back(indexAttr.getInt());
  }
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
  result = linalg::PadTensorOp::createPadHighOp(
      staticTensorType, opOperand->get(), paddingValue.getValue(),
      /*nofold=*/nofold, opToPad->getLoc(), b);
  return success();
}

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          const PaddingNoFoldComputationFunction &nofoldFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set the insertion point after `opToPad` because we also take the dims of
  // the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically,
    // then the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (auto en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(v);
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}
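
// A hedged sketch of the two callbacks consumed above (the lambda bodies are
// illustrative assumptions, not canonical implementations; the exact
// arith::ConstantOp builder signature may differ across versions):
//
//   PaddingValueComputationFunction paddingFunc =
//       [](OpBuilder &b, OpOperand &opOperand) -> FailureOr<Value> {
//     // Pad with a zero of the operand's element type.
//     Type t = getElementTypeOrSelf(opOperand.get());
//     return Value(b.create<arith::ConstantOp>(
//         opOperand.getOwner()->getLoc(), t, b.getZeroAttr(t)));
//   };
//   PaddingNoFoldComputationFunction nofoldFunc =
//       [](OpOperand &) { return false; }; // allow pads to fold away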

/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Try to peel a TiledLoopOp and return the new result.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
                                      TiledLoopOp tiledLoop, int64_t idx) {
  assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
         "requested peeling of non-existing loop");
  TiledLoopOp result;
  if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
    return result->getResults();
  assert(!result && "expected that loop was not peeled");
  return tiledLoop->getResults();
}

/// Peel loops after tiling.
static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
                      const LinalgTilingOptions &options) {
  for (int64_t loop : options.peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    if (options.loopType == LinalgTilingLoopType::TiledLoops) {
      assert(llvm::all_of(
                 res.loops,
                 [&](Operation *op) { return op == res.loops.front(); }) &&
             "expected that all loop ops are the same TiledLoopOp");
      auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
      assert(tiledLoopOp && "expected TiledLoopOp");
      loopResults = peelLoop(rewriter, tiledLoopOp, loop);
    } else {
      loopResults = peelLoop(rewriter, loopOp);
    }

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}
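
// Peeling is requested through the tiling options. An illustrative sketch
// (assumes the fluent `setPeeledLoops` setter that populates `peeledLoops`):
//
//   LinalgTilingOptions options;
//   options.setTileSizes({8, 16}).setPeeledLoops({0, 1});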

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);
  if (!res)
    return failure();
  // Clear the filter to stop recursive pattern application.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel loops.
  peelLoops(rewriter, *res, options);

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics()) {
    result = *res;
    return success();
  }

  // Try to pad on the fly by rewriting res->op as a padded op. If successful,
  // res->op is rewritten in static form with padded operands.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
      rewriter, res->op, options.paddingValueComputationFunction,
      options.paddingNoFoldComputationFunction, paddedOp);
  if (succeeded(newResults)) {
    rewriter.replaceOp(res->op, newResults.getValue());
    filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
    res->op = paddedOp;
    result = *res;
    // Do not perform replacement of `linalgOp`; let the derived patterns
    // do this as they see fit, from the resulting TiledLinalgOp.
    return success();
  }
  // Padding failed: do not set `result`, so the partially transformed op is
  // not propagated to callers.
  return failure();
}
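
// Derived patterns are expected to call matchAndRewriteBase and then replace
// the original op from the resulting TiledLinalgOp as they see fit. A hedged
// sketch of such a matchAndRewrite body:
//
//   TiledLinalgOp tiledOp;
//   if (failed(matchAndRewriteBase(op, rewriter, tiledOp)))
//     return failure();
//   if (tiledOp.tensorResults.empty())
//     rewriter.eraseOp(op);
//   else
//     rewriter.replaceOp(op, tiledOp.tensorResults);
//   return success();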

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into an op, the indexing op is always an
    // OpOperand. We could assert, but simply continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  // Tile the remaining loops only if there is a non-zero tile size.
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}
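
// An illustrative configuration sketch (assumes the `setIndicesToFuse` setter
// that populates `indicesToFuse`): tile the op and fuse the producer of its
// operand 2 into the tiled loop nest.
//
//   LinalgFusionOptions fusionOptions;
//   fusionOptions.setIndicesToFuse({2});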

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // Set the new filter marker, if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}

/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(generalizeNamedOpPrecondition(op)))
    return failure();

  GenericOp genericOp = generalizeNamedOp(rewriter, op);
  rewriter.replaceOp(op, genericOp.getResults());
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return success();
}
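
// A hedged registration sketch (`ctx` and the permutation are hypothetical;
// see interchangeGenericOp for the exact permutation semantics):
//
//   SmallVector<unsigned> perm = {1, 0, 2}; // swap the two outer loops
//   RewritePatternSet patterns(ctx);
//   patterns.add<GenericOpInterchangePattern>(ctx, perm);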

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 1st stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}
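
// A hedged usage sketch (the pattern sets are assumed to be built elsewhere):
// each stage-1 set is applied in turn, each application followed by the shared
// stage-2 cleanup patterns and, if provided, the stage-3 callback.
//
//   SmallVector<FrozenRewritePatternSet> stage1;
//   stage1.emplace_back(std::move(tilingPatterns));
//   stage1.emplace_back(std::move(vectorizationPatterns));
//   FrozenRewritePatternSet stage2(std::move(cleanupPatterns));
//   if (failed(applyStagedPatterns(funcOp, stage1, stage2,
//                                  /*stage3Lambda=*/nullptr)))
//     signalPassFailure();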

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
/// initialize with pad_val) and GenericOp (to copy contents).
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {

  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Fill `dest` using FillOp if the padding value is a constant. Otherwise,
/// generate a tensor::GenerateOp.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // The fill could not be optimized: lower to a tensor::GenerateOp with a
  // region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite linalg::YieldOp to tensor::YieldOp.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp = dyn_cast<linalg::YieldOp>(
      generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1);
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}
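
// A hedged sketch of supplying a copy-optimization hook (the constructor
// argument is inferred from the `optimizeCopyFn` member; the lambda body is
// illustrative): returning failure falls through to the generic
// InsertSliceOp-based copy below.
//
//   auto optimizeCopy = [](PatternRewriter &rewriter, PadTensorOp padOp,
//                          Value paddedDest) -> LogicalResult {
//     return failure(); // no special-case copy available
//   };
//   patterns.add<GeneralizePadTensorOpPattern>(ctx, optimizeCopy);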

LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the size of the InitTensorOp. Any combination of static/dynamic
  // dimensions is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init the tensor and fill it with the padding value.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy could not be optimized. Generate an InsertSliceOp instead for
  // copying the PadTensorOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the source of the PadTensorOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
  if (!padOp)
    return failure();
  // Only unit stride is supported.
  if (!sliceOp.hasUnitStride())
    return failure();

  Operation *tiledPadOp = padOp.getTiledImplementation(
      rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
      sliceOp.getMixedSizes());
  // All shapes are static and the data source is actually used. Rewrite into
  // pad_tensor(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}
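
// A hedged end-to-end sketch (`ctx`, `funcOp`, and the driver invocation are
// assumed boilerplate) of applying the PadTensorOp-related patterns defined
// above:
//
//   RewritePatternSet patterns(ctx);
//   patterns.add<PadTensorOpTransformationPattern,
//                GeneralizePadTensorOpPattern,
//                ExtractSliceOfPadTensorSwapPattern>(ctx);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));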