//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#include <utility>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//

// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
    Optional<StringAttr> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no filter attribute and matchDisjunction is empty.
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. Has no filter attribute but was expecting a filter.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match an explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                replacement.getValue());
  else
    op->removeAttr(
        rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
}

bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
    Operation *op) const {
  if (!replacement)
    return false;
  auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
                  .dyn_cast<StringAttr>();
  return attr && attr == replacement.getValue();
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}

/// Helper function that tries to pad `opOperand`. Exit early for scalar
/// operands, if `paddingFunc` returns failure, or if `opOperand` is not
/// defined by an ExtractSliceOp. Otherwise, try to pad the operand even if it
/// already has a static shape. Set `result` to the result of the created
/// tensor::PadOp, and return success if the operand either has been padded to
/// a static shape or already had a static shape; return failure otherwise.
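///
/// Illustrative sketch (made-up types and sizes, not verbatim output): if
/// `opOperand` is produced by
///   %slice = tensor.extract_slice %t[%iv, 0] [%sz, 16] [1, 1]
/// and `%sz` has a constant upper bound of 4, the operand is padded with a
/// tensor::PadOp to the static bounding box tensor<4x16xf32>.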
static LogicalResult padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc,
    const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
  // Get the shape of the operand and check if it has a dynamic shape. Only
  // return failure if the operand is not a scalar and has a dynamic shape.
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
  bool hasDynamicShape = llvm::is_contained(shape, ShapedType::kDynamicSize);

  // Cannot pad scalar operands.
  if (shape.empty())
    return success();

  // Cannot pad if the padding value is unknown.
  FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
  if (failed(paddingValue))
    return failure(hasDynamicShape);

  // Cannot construct a static bounding box if the operand is not defined by an
  // ExtractSliceOp.
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure(hasDynamicShape);

  // Compute the dropped dimensions if `sliceOp` is rank-reducing.
  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();

  // Upper bound the `sliceOp` sizes to obtain a static bounding box.
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(shape.size());
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
    // Skip dropped dimensions.
    if (droppedDims.test(en.index()))
      continue;
    // If the size is an attribute, add it directly to `staticSizes`.
    if (en.value().is<Attribute>()) {
      staticSizes.push_back(
          en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt());
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        getConstantUpperBoundForIndex(en.value().get<Value>());
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
      return failure();
    }
    staticSizes.push_back(upperBound.getValue());
  }
  assert(staticSizes.size() == shape.size() &&
         "expect the dynamic and static ranks to match");

  // Pad the operand to the bounding box defined by `staticSizes`.
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
  result =
      makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
                            opOperand->get(), paddingValue.getValue(), nofold);
  return success();
}

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          const PaddingNoFoldComputationFunction &nofoldFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(getAsOpFoldResult(v));
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Try to peel a TiledLoopOp and return the new result.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
                                      TiledLoopOp tiledLoop, int64_t idx) {
  assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
         "requested peeling of non-existing loop");
  TiledLoopOp result;
  if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
    return result->getResults();
  assert(!result && "expected that loop was not peeled");
  return tiledLoop->getResults();
}

/// Peel loops after tiling.
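///
/// Illustrative sketch of the effect (schematic IR, not verbatim output):
/// peeling a tiled loop such as
///   scf.for %iv = %c0 to %c10 step %c4 { ... }
/// produces a main loop over the full tiles, here [0, 8), plus a separate
/// partial iteration covering the remainder [8, 10), so the main loop body no
/// longer needs to guard against a partial last tile.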
void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
                                     ArrayRef<int64_t> peeledLoops,
                                     LinalgTilingLoopType loopType) {
  for (int64_t loop : peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    if (loopType == LinalgTilingLoopType::TiledLoops) {
      assert(llvm::all_of(
                 res.loops,
                 [&](Operation *op) { return op == res.loops.front(); }) &&
             "expected that all loop ops are the same TiledLoopOp");
      auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
      assert(tiledLoopOp && "expected TiledLoopOp");
      loopResults = peelLoop(rewriter, tiledLoopOp, loop);
    } else {
      loopResults = peelLoop(rewriter, loopOp);
    }

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter f, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)),
      fusionOptions(std::move(fusionOptions)), filter(std::move(f)),
      fusedOpMarker(std::move(fusedOpMarker)),
      originalOpMarker(std::move(originalOpMarker)) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // For dependences into `linalgOp`, the indexing op is always an OpOperand.
    // We could assert, but simply continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (const auto &tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    FailureOr<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (failed(unfusedTiledOp))
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg tiling pattern.
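///
/// Minimal usage sketch (assumes a RewritePatternSet `patterns` and an
/// MLIRContext `ctx` are in scope; the marker names are illustrative):
///
///   LinalgTilingOptions tilingOptions;
///   tilingOptions.setTileSizes({8, 16});
///   patterns.add<LinalgTilingPattern>(
///       ctx, tilingOptions,
///       LinalgTransformationFilter(StringAttr::get(ctx, "tile"),
///                                  StringAttr::get(ctx, "tiled")));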
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<TiledLinalgOp>
mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
    LinalgOp op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
  if (failed(res))
    return failure();

  // Clear filter to stop recursive pattern application.
  // This must be done here to properly propagate to peeling branches.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel the loops of the TiledLinalgOp.
  peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);

  if (res->tensorResults.empty())
    rewriter.eraseOp(op);
  else
    rewriter.replaceOp(op, res->tensorResults);

  return res;
}

/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<LinalgOp>
mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
      rewriter, linalgOp, options.paddingValueComputationFunction,
      options.paddingNoFoldComputationFunction, paddedOp);
  if (failed(newResults))
    return failure();

  // Compute the desired hoisting depths.
  SmallVector<int64_t> depths;
  if (options.paddingHoistComputationFunction) {
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
      depths.push_back(options.paddingHoistComputationFunction(*opOperand));
  }

  // Hoist the padding.
  for (const auto &en : enumerate(depths)) {
    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
    auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
    if (!padOp || en.value() == 0)
      continue;
    tensor::PadOp hoistedOp;
    SmallVector<GenericOp> transposeOps;
    SmallVector<int64_t> transposeVector =
        options.paddingTransposeComputationFunction(opOperand);

    FailureOr<Value> newResult = hoistPaddingOnTensors(
        padOp, en.value(), transposeVector, hoistedOp, transposeOps);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padOp, newResult.getValue());

    // Do not apply hoist padding to the newly introduced transpose operations.
    for (GenericOp transposeOp : transposeOps)
      filter.replaceLinalgTransformationFilter(rewriter, transposeOp);
  }

  // Replace the original operation with the padded results.
  rewriter.replaceOp(linalgOp, newResults.getValue());
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);

  return paddedOp;
}

/// Linalg tile and fuse tensor ops pattern.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(std::move(f)),
      options(std::move(options)) {}

LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
  if (!rootOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
  if (options.tileSizes.size() < rootOp.getNumLoops())
    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");

  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
  if (!options.tileInterchange.empty() &&
      options.tileInterchange.size() != options.tileSizes.size())
    return rewriter.notifyMatchFailure(
        op, "expect the number of tile sizes and interchange dims to match");

  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                     options.tileSizes.begin() +
                                         rootOp.getNumLoops());
  SmallVector<int64_t> rootInterchange =
      options.tileInterchange.empty()
          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
          : SmallVector<int64_t>(options.tileInterchange.begin(),
                                 options.tileInterchange.begin() +
                                     rootOp.getNumLoops());

  // Check `rootTileSizes` contains at least one non-zero tile size.
  if (llvm::count(rootTileSizes, 0) == static_cast<long>(rootTileSizes.size()))
    return rewriter.notifyMatchFailure(
        op, "expect at least one non-zero tile size");

  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
  // It has to be a permutation since the tiling cannot tile the same loop
  // dimension multiple times.
  if (!isPermutation(rootInterchange))
    return rewriter.notifyMatchFailure(
        op, "expect the tile interchange permutes the root loops");

  // Tile `rootOp` and fuse its producers.
  FailureOr<TileLoopNest> tileLoopNest =
      tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes,
                                   rootInterchange, options.tileDistribution);
  if (failed(tileLoopNest))
    return rewriter.notifyMatchFailure(
        op, "tileConsumerAndFuseProducers failed unexpectedly");

  // Replace all uses of the tiled loop operation.
  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());

  // Apply the filter if specified.
  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
  return success();
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(std::move(f)),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

FailureOr<GenericOp>
mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();

  FailureOr<GenericOp> transformedOp =
      interchangeGenericOp(rewriter, genericOp, interchangeVector);
  if (failed(transformedOp))
    return failure();

  // Apply the replacement filter if specified.
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return transformedOp;
}

/// Linalg generalization pattern.
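///
/// E.g. (schematic; the exact indexing maps and region depend on the op): a
/// named op such as linalg.matmul is rewritten into an equivalent
/// linalg.generic that spells out the indexing maps, iterator types, and
/// scalar body explicitly.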
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
    PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

FailureOr<GenericOp>
mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
  if (failed(genericOp))
    return failure();
  filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
  return genericOp;
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(std::move(f)),
      options(std::move(options)) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgVectorizationOptions options, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  return vectorize(rewriter, linalgOp);
}

LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
    memref::CopyOp copyOp, PatternRewriter &rewriter) const {
  return vectorizeCopy(rewriter, copyOp);
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp (to
/// initialize with pad_val) and GenericOp (to copy contents).
LogicalResult
PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
                                            PatternRewriter &rewriter) const {
  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  // 1. A BBarg from a different block.
  // 2. A value defined outside of the current block.
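  // E.g. (schematic IR, not verbatim): a pad value yielded from outside the
  // pad region is supported,
  //   %cst = arith.constant 0.0 : f32
  //   %0 = tensor.pad %src ... { ^bb0(...): tensor.yield %cst : f32 }
  // whereas a pad value computed inside the region is rejected below.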
  Block &block = padOp.region().front();
  auto yieldOp = cast<tensor::YieldOp>(block.getTerminator());
  Value padValue = yieldOp.value();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Fill `dest` using a FillOp if `padOp` has a constant padding value.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // The fill could not be optimized: lower to a tensor::GenerateOp with
  // region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

LogicalResult
GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                        PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the size of the InitTensorOp. Any combination of static/dynamic
  // dimensions is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init the tensor and fill it with the padding value.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy cannot be optimized. Generate an InsertSliceOp instead for
  // copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the source of the tensor::PadOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.source().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (Optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = control.getValue();
    else
      return failure();
  }

  Operation *tiledPadOp =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  // All shapes are static and the data source is actually used. Rewrite into
  // pad(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}

namespace {
// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.

/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
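///
/// E.g. (schematic; types are illustrative): a linalg.conv_2d_nhwc_hwcf whose
/// kernel and output have size 1 along the H dimension becomes a
/// linalg.conv_1d_nwc_wcf on rank-reduced operands obtained via
/// tensor.extract_slice, with the result re-inserted via tensor.insert_slice.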
struct DownscaleSizeOneWindowed2DConvolution final
    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
  DownscaleSizeOneWindowed2DConvolution(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
        filter(std::move(f)) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
  DownscaleDepthwiseConv2DNhwcHwcOp(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
        filter(std::move(f)) {}

  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

} // namespace

void linalg::populateDecomposeConvolutionPatterns(
    RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
    PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution,
               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
                                                  benefit);
}