//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#include <utility>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";
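
// For illustration, a filter constructed to match the marker string "tile"
// only fires on ops that carry the attribute (hypothetical IR sketch, not
// tied to a specific test):
//
//   %0 = linalg.matmul {__internal_linalg_transform__ = "tile"}
//          ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
//          outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>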

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
    Optional<StringAttr> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no filter and matchDisjunction is empty or matchByDefault is set.
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. Has no filter but was expecting a filter.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match an explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                replacement.getValue());
  else
    op->removeAttr(
        rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
}

bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
    Operation *op) const {
  if (!replacement)
    return false;
  auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
                  .dyn_cast<StringAttr>();
  return attr && attr == replacement.getValue();
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}
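
// A minimal usage sketch (hypothetical, for illustration): tile the two outer
// loops by constant sizes and leave the remaining loop untiled (tile size 0).
//
//   LinalgTilingOptions tilingOptions;
//   tilingOptions.setTileSizes({8, 16, 0})
//       .setLoopType(LinalgTilingLoopType::Loops);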

/// Helper function that tries to pad `opOperand`. Exit early for scalar
/// operands, if `paddingFunc` returns failure, or if `opOperand` is not
/// defined by an ExtractSliceOp. Otherwise, try to pad the operand even if it
/// already has a static shape. Set `result` to the result of the created
/// tensor::PadOp and return success if the operand either has been padded to a
/// static shape or already had a static shape; return failure otherwise.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc,
    const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
  // Get the shape of the operand and check if it has a dynamic shape. Only
  // return failure if the operand is not a scalar and has a dynamic shape.
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
  bool hasDynamicShape = llvm::is_contained(shape, ShapedType::kDynamicSize);

  // Cannot pad scalar operands.
  if (shape.empty())
    return success();

  // Cannot pad if the padding value is unknown.
  FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
  if (failed(paddingValue))
    return failure(hasDynamicShape);

  // Cannot construct a static bounding box if the operand is not defined by an
  // ExtractSliceOp.
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure(hasDynamicShape);

  // Compute the dropped dimensions if `sliceOp` is rank-reducing.
  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();

  // Upper bound the `sliceOp` sizes to obtain a static bounding box.
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(shape.size());
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
    // Skip dropped dimensions.
    if (droppedDims.test(en.index()))
      continue;
    // If the size is an attribute, add it directly to `staticSizes`.
    if (en.value().is<Attribute>()) {
      staticSizes.push_back(
          en.value().get<Attribute>().cast<IntegerAttr>().getInt());
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        getConstantUpperBoundForIndex(en.value().get<Value>());
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
      return failure();
    }
    staticSizes.push_back(upperBound.getValue());
  }
  assert(staticSizes.size() == shape.size() &&
         "expect the dynamic and static ranks to match");

  // Pad the operand to the bounding box defined by `staticSizes`.
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
  result =
      makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
                            opOperand->get(), paddingValue.getValue(), nofold);
  return success();
}
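
// For example (illustrative IR sketch), assuming the constant upper bound 4
// can be derived for %sz, the dynamically sized slice
//
//   %slice = tensor.extract_slice %t[0, 0] [%sz, 8] [1, 1]
//       : tensor<16x8xf32> to tensor<?x8xf32>
//
// is padded to the static bounding box tensor<4x8xf32> by a tensor.pad whose
// high padding covers the remaining %sz..4 elements.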

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          const PaddingNoFoldComputationFunction &nofoldFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(getAsOpFoldResult(v));
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}

/// Try to peel a loop `op` and return the new results.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Try to peel a TiledLoopOp and return the new results.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
                                      TiledLoopOp tiledLoop, int64_t idx) {
  assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
         "requested peeling of non-existing loop");
  TiledLoopOp result;
  if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
    return result->getResults();
  assert(!result && "expected that loop was not peeled");
  return tiledLoop->getResults();
}
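
// As a concrete example of the peeling performed above, an scf.for with lower
// bound 0, upper bound 10, and step 4 is split into a main loop over [0, 8),
// where the step divides the iteration space evenly, plus one partial
// iteration covering [8, 10).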

/// Peel loops after tiling.
void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
                                     ArrayRef<int64_t> peeledLoops,
                                     LinalgTilingLoopType loopType) {
  for (int64_t loop : peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    if (loopType == LinalgTilingLoopType::TiledLoops) {
      assert(llvm::all_of(
                 res.loops,
                 [&](Operation *op) { return op == res.loops.front(); }) &&
             "expected that all loop ops are the same TiledLoopOp");
      auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
      assert(tiledLoopOp && "expected TiledLoopOp");
      loopResults = peelLoop(rewriter, tiledLoopOp, loop);
    } else {
      loopResults = peelLoop(rewriter, loopOp);
    }

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter f, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)),
      fusionOptions(std::move(fusionOptions)), filter(std::move(f)),
      fusedOpMarker(std::move(fusedOpMarker)),
      originalOpMarker(std::move(originalOpMarker)) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into, the indexing op is always an
    // OpOperand. We could assert, but just continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (const auto &tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Truncate the tile sizes to the number of loops.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  // Tile the loops only if there is a non-zero tile size.
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    FailureOr<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (failed(unfusedTiledOp))
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg tiling pattern.
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<TiledLinalgOp>
mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
    LinalgOp op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
  if (failed(res))
    return failure();

  // Clear filter to stop recursive pattern application.
  // This must be done here to properly propagate to peeling branches.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel the loops of the TiledLinalgOp.
  peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);

  if (res->tensorResults.empty())
    rewriter.eraseOp(op);
  else
    rewriter.replaceOp(op, res->tensorResults);

  return res;
}
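
// Usage sketch (hypothetical driver code): ops marked "tile" are tiled and
// relabeled "tiled" so the pattern does not fire on them again.
//
//   RewritePatternSet patterns(ctx);
//   patterns.add<LinalgTilingPattern>(
//       ctx, LinalgTilingOptions().setTileSizes({8, 16}),
//       LinalgTransformationFilter(StringAttr::get(ctx, "tile"),
//                                  StringAttr::get(ctx, "tiled")));
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));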

/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<LinalgOp>
mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
      rewriter, linalgOp, options.paddingValueComputationFunction,
      options.paddingNoFoldComputationFunction, paddedOp);
  if (failed(newResults))
    return failure();

  // Compute the desired hoisting depths.
  SmallVector<int64_t> depths;
  if (options.paddingHoistComputationFunction) {
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
      depths.push_back(options.paddingHoistComputationFunction(*opOperand));
  }

  // Hoist the padding.
  for (const auto &en : enumerate(depths)) {
    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
    auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
    if (!padOp || en.value() == 0)
      continue;
    tensor::PadOp hoistedOp;
    SmallVector<GenericOp> transposeOps;
    SmallVector<int64_t> transposeVector =
        options.paddingTransposeComputationFunction(opOperand);

    FailureOr<Value> newResult = hoistPaddingOnTensors(
        padOp, en.value(), transposeVector, hoistedOp, transposeOps);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padOp, newResult.getValue());

    // Do not apply hoist padding to the newly introduced transpose operations.
    for (GenericOp transposeOp : transposeOps)
      filter.replaceLinalgTransformationFilter(rewriter, transposeOp);
  }

  // Replace the original operation with the padded results.
  rewriter.replaceOp(linalgOp, newResults.getValue());
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);

  return paddedOp;
}
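
// Usage sketch (hypothetical): pad every operand with a zero constant so the
// cloned op gets static shapes; the callback decides the value per operand.
//
//   LinalgPaddingOptions paddingOptions;
//   paddingOptions.paddingValueComputationFunction =
//       [](OpBuilder &b, OpOperand &opOperand) -> FailureOr<Value> {
//     auto t = getElementTypeOrSelf(opOperand.get());
//     return b
//         .create<arith::ConstantOp>(opOperand.getOwner()->getLoc(), t,
//                                    b.getZeroAttr(t))
//         .getResult();
//   };
//   patterns.add<LinalgPaddingPattern>(ctx, paddingOptions);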

/// Linalg tile and fuse tensor ops pattern.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(std::move(f)),
      options(std::move(options)) {}

LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
  if (!rootOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
  if (options.tileSizes.size() < rootOp.getNumLoops())
    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");

  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
  if (!options.tileInterchange.empty() &&
      options.tileInterchange.size() != options.tileSizes.size())
    return rewriter.notifyMatchFailure(
        op, "expect the number of tile sizes and interchange dims to match");

  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                     options.tileSizes.begin() +
                                         rootOp.getNumLoops());
  SmallVector<int64_t> rootInterchange =
      options.tileInterchange.empty()
          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
          : SmallVector<int64_t>(options.tileInterchange.begin(),
                                 options.tileInterchange.begin() +
                                     rootOp.getNumLoops());

  // Check `rootTileSizes` contains at least one non-zero tile size.
  if (llvm::count(rootTileSizes, 0) == static_cast<long>(rootTileSizes.size()))
    return rewriter.notifyMatchFailure(
        op, "expect at least one non-zero tile size");

  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
  // It has to be a permutation since the tiling cannot tile the same loop
  // dimension multiple times.
  if (!isPermutation(rootInterchange))
    return rewriter.notifyMatchFailure(
        op, "expect the tile interchange permutes the root loops");

  // Tile `rootOp` and fuse its producers.
  FailureOr<TileLoopNest> tileLoopNest = tileConsumerAndFuseProducers(
      rewriter, rootOp, rootTileSizes, rootInterchange);
  if (failed(tileLoopNest))
    return rewriter.notifyMatchFailure(
        op, "tileConsumerAndFuseProducers failed unexpectedly");

  // Replace all uses of the tiled loop operation.
  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());

  // Apply the filter if specified.
  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
  return success();
}
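
// Usage sketch (hypothetical): tile a matmul root by 8x16, leave the
// reduction untiled, interchange the two parallel loops, and fuse producers
// into the tile loop nest.
//
//   LinalgTilingAndFusionOptions fusionOptions;
//   fusionOptions.tileSizes = {8, 16, 0};
//   fusionOptions.tileInterchange = {1, 0, 2};
//   patterns.add<LinalgTileAndFuseTensorOpsPattern>("linalg.matmul", ctx,
//                                                   fusionOptions);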

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(std::move(f)),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

FailureOr<GenericOp>
mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();

  FailureOr<GenericOp> transformedOp =
      interchangeGenericOp(rewriter, genericOp, interchangeVector);
  if (failed(transformedOp))
    return failure();

  // Apply the replacement filter if specified.
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return transformedOp;
}

/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
    PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

FailureOr<GenericOp>
mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
  if (failed(genericOp))
    return failure();
  filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
  return genericOp;
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(std::move(f)),
      options(std::move(options)) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgVectorizationOptions options, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  return vectorize(rewriter, linalgOp);
}

LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
    memref::CopyOp copyOp, PatternRewriter &rewriter) const {
  return vectorizeCopy(rewriter, copyOp);
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 1st stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}
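
// Usage sketch (hypothetical): interleave two stage-1 pattern sets with the
// stage-2 canonicalizations, followed by a stage-3 callback.
//
//   SmallVector<FrozenRewritePatternSet> stage1;
//   stage1.emplace_back(std::move(tilingPatterns));
//   stage1.emplace_back(std::move(promotionPatterns));
//   FrozenRewritePatternSet stage2(std::move(canonicalizationPatterns));
//   (void)applyStagedPatterns(funcOp, stage1, stage2, [](Operation *op) {
//     // Stage-3 processing, e.g. hoisting, would go here.
//     return success();
//   });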

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp (to
/// initialize with pad_val) and GenericOp (to copy contents).
LogicalResult
PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
                                            PatternRewriter &rewriter) const {
  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  // 1. A BBarg from a different block.
  // 2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<tensor::YieldOp>(block.getTerminator());
  Value padValue = yieldOp.value();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}
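
// For example (illustrative IR sketch), statically padding a tensor<2x3xf32>
// by one row on each side becomes roughly:
//
//   %init = linalg.init_tensor [4, 3] : tensor<4x3xf32>
//   %fill = linalg.fill(%pad_val, %init)
//       : f32, tensor<4x3xf32> -> tensor<4x3xf32>
//   %res = linalg.generic
//            {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
//                              affine_map<(d0, d1) -> (d0 + 1, d1)>],
//             iterator_types = ["parallel", "parallel"]}
//            ins(%src : tensor<2x3xf32>) outs(%fill : tensor<4x3xf32>) { ... }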

/// Fill `dest` using FillOp with the constant padding value if possible;
/// otherwise, generate a tensor::GenerateOp.
Value GeneralizePadOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: lower to a tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

LogicalResult
GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                        PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the size of the InitTensorOp. Any combination of static and
  // dynamic sizes is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Create the init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy cannot be optimized. Generate an InsertSliceOp instead to copy
  // the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the source of the tensor::PadOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();
  // Only unit stride is supported.
  if (!sliceOp.hasUnitStride())
    return failure();

  TilingInterface tilingInterface =
      dyn_cast<TilingInterface>(padOp.getOperation());
  Operation *tiledPadOp =
      tilingInterface
          .getTiledImplementation(
              rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
              sliceOp.getMixedSizes(), /*tileDestOperands=*/false)
          .front();
  // Swap the order of the ops: rewrite extract_slice(pad(x)) into
  // pad(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}
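
// For example (illustrative IR sketch), a slice taken out of a padded tensor
//
//   %0 = tensor.pad %t low[1, 1] high[1, 1] { ... }
//       : tensor<6x6xf32> to tensor<8x8xf32>
//   %1 = tensor.extract_slice %0[2, 2] [4, 4] [1, 1]
//       : tensor<8x8xf32> to tensor<4x4xf32>
//
// becomes a pad of a smaller slice of %t, so the padding is computed only for
// the region that is actually used.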

namespace {
// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.

/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
struct DownscaleSizeOneWindowed2DConvolution final
    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
  DownscaleSizeOneWindowed2DConvolution(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
        filter(std::move(f)) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim(removeH ? 1 : 2);
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim(removeH ? 0 : 1);
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};
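
// For example (illustrative shapes), a linalg.conv_2d_nhwc_hwcf with kernel
// type tensor<1x3x4x8xf32> and output type tensor<1x1x10x8xf32> has a unit H
// window; it is rewritten into a linalg.conv_1d_nwc_wcf operating on the
// rank-reduced operands (kernel tensor<3x4x8xf32>, output tensor<1x10x8xf32>).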

/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
  DownscaleDepthwiseConv2DNhwcHwcOp(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
        filter(std::move(f)) {}

  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim(removeH ? 1 : 2);
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim(removeH ? 0 : 1);
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

} // namespace

void linalg::populateDecomposeConvolutionPatterns(
    RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
    PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution,
               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
                                                  benefit);
}
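
// Usage sketch (hypothetical): decompose all eligible convolutions in a
// function, relying on the default filter and benefit.
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateDecomposeConvolutionPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));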