//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<Identifier> matchDisjunction,
    Optional<Identifier> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. `op` has no marker attribute and `matchDisjunction` is empty (or
    // matching by default was requested).
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. `op` has no marker attribute but one was expected.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match an explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. No filter matched.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}
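
// A typical staged pipeline consumes one marker and installs the next one. A
// hypothetical setup (marker names invented for illustration):
//
//   LinalgTransformationFilter filter(
//       Identifier::get("tile", ctx),        // only match ops marked "tile"
//       Identifier::get("vectorize", ctx));  // relabel rewritten ops
//
// Ops without the marker attribute only match when the disjunction is empty
// (or matchByDefault is set), which is what seeds the first stage.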

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                rewriter.getStringAttr(replacement.getValue().strref()));
  else
    op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
                                   rewriter.getContext()));
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}
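
// For example (assuming a linalg.matmul on tensor<?x4xf32> and
// tensor<4x?xf32>), scalarizeDynamicDims() produces the tile sizes {1, 1, 0}:
// the two dynamic parallel loops are scalarized with tile size 1 and the
// static reduction loop is left untiled (size 0).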

/// Helper function that tries to pad `opOperand`. Exit early and return
/// success for scalar operands or if `paddingFunc` returns failure. Otherwise,
/// try to pad the operand even if it already has a static shape. Set `result`
/// to the result of the created PadTensorOp or return failure if the operand
/// cannot be padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc,
    const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
  // Can't pad scalars.
  if (opToPad.getShape(opOperand).empty())
    return success();
  // Can't pad if no padding value is known.
  FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
  if (failed(paddingValue))
    return success();
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  // Not a slice op, cannot construct a static bounding box.
  if (!sliceOp)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    // If the size is an attribute, add it directly to `staticSizes`.
    if (size.is<Attribute>()) {
      staticSizes.push_back(size.get<Attribute>().cast<IntegerAttr>().getInt());
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        getConstantUpperBoundForIndex(size.get<Value>());
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS() << "No constant bounding box found for padding\n");
      return failure();
    }
    staticSizes.push_back(upperBound.getValue());
  }
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
  result = linalg::PadTensorOp::createPadHighOp(
      staticTensorType, opOperand->get(), paddingValue.getValue(),
      /*nofold=*/nofold, opToPad->getLoc(), b);
  return success();
}

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          const PaddingNoFoldComputationFunction &nofoldFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set the insertion point after `opToPad` because we also take the dims of
  // the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically,
    // then the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
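  // E.g. a result padded from tensor<?x?xf32> to tensor<16x16xf32> yields
  //   tensor.extract_slice %padded[0, 0] [%d0, %d1] [1, 1]
  // where %d0 and %d1 are the reified dynamic sizes of the original result
  // (value names hypothetical).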
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (auto en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(v);
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}

/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Try to peel a TiledLoopOp and return the new result.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
                                      TiledLoopOp tiledLoop, int64_t idx) {
  assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
         "requested peeling of non-existing loop");
  TiledLoopOp result;
  if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
    return result->getResults();
  assert(!result && "expected that loop was not peeled");
  return tiledLoop->getResults();
}

/// Peel loops after tiling.
static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
                      const LinalgTilingOptions &options) {
  for (int64_t loop : options.peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    if (options.loopType == LinalgTilingLoopType::TiledLoops) {
      assert(llvm::all_of(
                 res.loops,
                 [&](Operation *op) { return op == res.loops.front(); }) &&
             "expected that all loop ops are the same TiledLoopOp");
      auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
      assert(tiledLoopOp && "expected TiledLoopOp");
      loopResults = peelLoop(rewriter, tiledLoopOp, loop);
    } else {
      loopResults = peelLoop(rewriter, loopOp);
    }

    // The result of the loop nest may change with peeling.
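    // Only redirect the recorded tensor results if they are exactly the
    // results of the loop op that was just peeled.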
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();
  // Clear the filter to stop recursive pattern application.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel the loops.
  peelLoops(rewriter, *res, options);

  result = *res;
  return success();
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // For dependences into `linalgOp`, the indexing op view is always an
    // OpOperand. We could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loops only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}
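
// The padding pattern below is driven by the callbacks carried in
// LinalgPaddingOptions. A minimal hypothetical configuration pads every
// operand with a zero of its element type and hoists nothing (a sketch; only
// the option field names are taken from this file):
//
//   LinalgPaddingOptions options;
//   options.paddingValueComputationFunction =
//       [](OpBuilder &b, OpOperand &operand) -> FailureOr<Value> {
//         Type elemTy = getElementTypeOrSelf(operand.get());
//         return b
//             .create<arith::ConstantOp>(operand.getOwner()->getLoc(),
//                                        b.getZeroAttr(elemTy))
//             .getResult();
//       };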

/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
      rewriter, linalgOp, options.paddingValueComputationFunction,
      options.paddingNoFoldComputationFunction, paddedOp);
  if (failed(newResults))
    return failure();

  // Compute the desired hoisting depths.
  SmallVector<int64_t> depths;
  if (options.paddingHoistComputationFunction) {
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
      depths.push_back(options.paddingHoistComputationFunction(*opOperand));
  }

  // Hoist the padding.
  for (auto en : enumerate(depths)) {
    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
    auto padTensorOp = opOperand.get().getDefiningOp<PadTensorOp>();
    if (!padTensorOp || en.value() == 0)
      continue;
    PadTensorOp hoistedOp;
    FailureOr<Value> newResult =
        hoistPaddingOnTensors(padTensorOp, en.value(), hoistedOp);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padTensorOp, newResult.getValue());
  }

  // Replace the original operation with the padded results.
  rewriter.replaceOp(op, newResults.getValue());
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
  return success();
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
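
// E.g. with interchangeVector = {1, 0}, a generic op iterating (i, j) is
// rewritten to iterate (j, i); the indexing maps and iterator types are
// permuted to match.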

LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // Set the new filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}

/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(generalizeNamedOpPrecondition(op)))
    return failure();

  GenericOp genericOp = generalizeNamedOp(rewriter, op);
  rewriter.replaceOp(op, genericOp.getResults());
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return success();
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
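  // Note: startRootUpdate must be paired with finalizeRootUpdate on success
  // or cancelRootUpdate on failure before returning.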
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 1st stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
/// initialize with the pad value) and GenericOp (to copy the contents).
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  // 1. A BBarg from a different block.
  // 2. A value defined outside of the current block.
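  // E.g. (hypothetical IR, shapes invented), %cst below qualifies because it
  // is defined outside the pad op's body:
  //
  //   %cst = arith.constant 0.0 : f32
  //   %0 = linalg.pad_tensor %src low[0, 0] high[%h0, %h1] {
  //   ^bb0(%i: index, %j: index):
  //     linalg.yield %cst : f32
  //   } : tensor<?x?xf32> to tensor<8x8xf32>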
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Fill `dest` using a FillOp with the constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // The fill could not be optimized: lower to a tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite linalg::YieldOp to tensor::YieldOp.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp =
      dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1 && "expected single operand yield");
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}

LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
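  // (An OpFoldResult is either an SSA Value or a constant Attribute; the
  // attribute case is materialized as an arith.constant of index type.)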
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the size of the InitTensorOp. Any combination of static/dynamic
  // is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init the tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The PadTensorOp cannot be optimized. Generate an InsertSliceOp instead
  // for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the source of the PadTensorOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
  if (!padOp)
    return failure();
  // Only unit stride is supported.
  if (!sliceOp.hasUnitStride())
    return failure();

  Operation *tiledPadOp = padOp.getTiledImplementation(
      rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
      sliceOp.getMixedSizes());
  // All shapes are static and the data source is actually used. Rewrite into
  // pad_tensor(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}

namespace {
// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.

/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
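/// E.g. a linalg.conv_2d_nhwc_hwcf whose filter height and output height are
/// both 1 collapses to a linalg.conv_1d_nwc_wcf over the width dimension
/// (and symmetrically for a unit-width window).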
struct DownscaleSizeOneWindowed2DConvolution final
    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = cast<linalg::LinalgOp>(*convOp);
    if (linalgOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value filter = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto filterType = filter.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto inputShape = inputType.getShape();
    auto filterShape = filterType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t fhSize = filterShape[0], fwSize = filterShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    if (!(fhSize == 1 && ohSize == 1) && !(fwSize == 1 && owSize == 1))
      return failure();
    bool removeH = ohSize == 1;

    // Get new shapes and types for all operands by removing the size-1
    // dimension.

    SmallVector<int64_t, 3> newInputShape{
        inputShape[0], inputShape[removeH ? 2 : 1], inputShape[3]};
    auto newInputType = RankedTensorType::get(
        newInputShape, inputType.getElementType(), inputType.getEncoding());

    SmallVector<int64_t, 3> newFilterShape{filterShape[removeH ? 1 : 0],
                                           filterShape[2], filterShape[3]};
    auto newFilterType = RankedTensorType::get(
        newFilterShape, filterType.getElementType(), filterType.getEncoding());

    SmallVector<int64_t, 3> newOutputShape{
        outputShape[0], outputShape[removeH ? 2 : 1], outputShape[3]};
    auto newOutputType = RankedTensorType::get(
        newOutputShape, outputType.getElementType(), outputType.getEncoding());

    SmallVector<ReassociationIndices, 3> ioReshapeIndices = {{0}, {1, 2}, {3}};
    SmallVector<ReassociationIndices, 3> fReshapeIndices = {{0, 1}, {2}, {3}};

    // Reshape all operands for 1-D convolution.
    Location loc = convOp.getLoc();
    Value newInput = rewriter.create<linalg::TensorCollapseShapeOp>(
        loc, newInputType, input, ioReshapeIndices);
    Value newFilter = rewriter.create<linalg::TensorCollapseShapeOp>(
        loc, newFilterType, filter, fReshapeIndices);
    Value newOutput = rewriter.create<linalg::TensorCollapseShapeOp>(
        loc, newOutputType, output, ioReshapeIndices);

    // We need to shrink the strides and dilations too.
    auto stride = convOp.strides().getValues<int64_t>()[removeH ? 1 : 0];
    auto stridesAttr = rewriter.getI64VectorAttr(stride);
    auto dilation = convOp.dilations().getValues<int64_t>()[removeH ? 1 : 0];
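    // E.g. strides dense<[2, 3]> becomes dense<[3]> when the height
    // dimension is removed: only the surviving spatial dimension's value is
    // kept.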
    auto dilationsAttr = rewriter.getI64VectorAttr(dilation);

    auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
        loc, newOutputType, ValueRange{newInput, newFilter},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
        convOp, outputType, conv1DOp.getResult(0), ioReshapeIndices);
    return success();
  }
};

} // namespace

void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                  PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution>(patterns.getContext(),
                                                      benefit);
}