//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#include <utility>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//

// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
    Optional<StringAttr> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}
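
// Example (sketch, not part of the upstream API surface): a filter that
// matches ops tagged "tile" and re-tags the rewritten op "vectorize" can be
// built as follows; the marker names are illustrative assumptions.
//
//   MLIRContext *ctx = ...;
//   LinalgTransformationFilter filter(
//       /*matchDisjunction=*/{StringAttr::get(ctx, "tile")},
//       /*replacement=*/StringAttr::get(ctx, "vectorize"));
//
// Patterns constructed with this filter only fire on ops carrying
// `__internal_linalg_transform__ = "tile"` and mark their replacements with
// `__internal_linalg_transform__ = "vectorize"`.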

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no filter case and matchDisjunction is empty.
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. Has no filter but was expecting a filter.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                replacement.getValue());
  else
    op->removeAttr(
        rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
}

bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
    Operation *op) const {
  if (!replacement)
    return false;
  auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
                  .dyn_cast<StringAttr>();
  return attr && attr == replacement.getValue();
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}
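
// Example (sketch): typical client-side configuration of tiling options. The
// sizes below are illustrative; a tile size of 0 means "do not tile this
// loop".
//
//   LinalgTilingOptions options;
//   options.setTileSizes({8, 16, 0})
//       .setLoopType(LinalgTilingLoopType::Loops);
//
// Alternatively, `scalarizeDynamicDims()` tiles every dynamic dimension by 1
// and leaves static dimensions untiled, which is useful to make all shapes
// static before vectorization. The two are mutually exclusive: both set
// `tileSizeComputationFunction` and assert that it was not already set.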

/// Helper function that tries to pad `opOperand`. Exit early for scalar
/// operands, if `paddingFunc` returns failure, or if `opOperand` is not
/// defined by an ExtractSliceOp. Otherwise, try to pad the operand even if it
/// already has a static shape. Set `result` to the result of the created
/// tensor::PadOp and return success if the operand either has been padded to
/// a static shape or already had a static shape, and failure otherwise.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc,
    const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
  // Get the shape of the operand and check if it has a dynamic shape. Only
  // return failure if the operand is not a scalar and has a dynamic shape.
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
  bool hasDynamicShape = llvm::is_contained(shape, ShapedType::kDynamicSize);

  // Cannot pad scalar operands.
  if (shape.empty())
    return success();

  // Cannot pad if the padding value is unknown.
  FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
  if (failed(paddingValue))
    return failure(hasDynamicShape);

  // Cannot construct a static bounding box if the operand is not defined by an
  // ExtractSliceOp.
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure(hasDynamicShape);

  // Compute the dropped dimensions if `sliceOp` is rank-reducing.
  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();

  // Upper bound the `sliceOp` sizes to obtain a static bounding box.
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(shape.size());
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
    // Skip dropped dimensions.
    if (droppedDims.test(en.index()))
      continue;
    // If the size is an attribute add it directly to `staticSizes`.
    if (en.value().is<Attribute>()) {
      staticSizes.push_back(
          en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt());
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        getConstantUpperBoundForIndex(en.value().get<Value>());
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
      return failure();
    }
    staticSizes.push_back(upperBound.getValue());
  }
  assert(staticSizes.size() == shape.size() &&
         "expect the dynamic and static ranks to match");

  // Pad the operand to the bounding box defined by `staticSizes`.
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
  result =
      makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
                            opOperand->get(), paddingValue.getValue(), nofold);
  return success();
}

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          const PaddingNoFoldComputationFunction &nofoldFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(getAsOpFoldResult(v));
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}
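
// For reference, peeling rewrites a loop with a possibly incomplete last
// iteration into a main loop with only full iterations plus a remainder.
// Sketch in scf terms (shown schematically, not as verbatim output):
//
//   scf.for %iv = %c0 to %c10 step %c4 { ... }
//
// becomes
//
//   scf.for %iv = %c0 to %c8 step %c4 { ... }   // full iterations
//   scf.for %iv = %c8 to %c10 step %c4 { ... }  // single partial iteration
//
// so that the body of the main loop can assume a full tile size.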

/// Peel loops after tiling.
void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter,
                                     TiledLinalgOp &res,
                                     ArrayRef<int64_t> peeledLoops,
                                     LinalgTilingLoopType loopType) {
  for (int64_t loop : peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    loopResults = peelLoop(rewriter, loopOp);

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter f, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)),
      fusionOptions(std::move(fusionOptions)), filter(std::move(f)),
      fusedOpMarker(std::move(fusedOpMarker)),
      originalOpMarker(std::move(originalOpMarker)) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into, indexingOp is always an OpOperand. We
    // could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (const auto &tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    FailureOr<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (failed(unfusedTiledOp))
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}
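
// Example (sketch): the patterns in this file are typically chained through
// the `__internal_linalg_transform__` marker. The marker names below are
// illustrative assumptions, not fixed API.
//
//   RewritePatternSet patterns(ctx);
//   patterns.add<LinalgTilingPattern>(
//       "linalg.matmul", ctx, LinalgTilingOptions().setTileSizes({8, 16, 4}),
//       LinalgTransformationFilter(StringAttr::get(ctx, "START"),
//                                  StringAttr::get(ctx, "TILED")));
//
// Running these patterns tiles every matmul tagged "START" and re-tags the
// tiled op "TILED", so a later pattern (e.g. vectorization) can match on
// "TILED" without re-processing already-transformed ops.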

/// Linalg tiling pattern.
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<TiledLinalgOp>
mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
    LinalgOp op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
  if (failed(res))
    return failure();

  // Clear the filter to stop recursive pattern application.
  // This must be done here to properly propagate to peeling branches.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel the loops of the TiledLinalgOp.
  peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);

  if (res->tensorResults.empty())
    rewriter.eraseOp(op);
  else
    rewriter.replaceOp(op, res->tensorResults);

  return res;
}

/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<LinalgOp>
mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
      rewriter, linalgOp, options.paddingValueComputationFunction,
      options.paddingNoFoldComputationFunction, paddedOp);
  if (failed(newResults))
    return failure();

  // Compute the desired hoisting depths.
  SmallVector<int64_t> depths;
  if (options.paddingHoistComputationFunction) {
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
      depths.push_back(options.paddingHoistComputationFunction(*opOperand));
  }

  // Hoist the padding.
  for (const auto &en : enumerate(depths)) {
    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
    auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
    if (!padOp || en.value() == 0)
      continue;
    tensor::PadOp hoistedOp;
    SmallVector<GenericOp> transposeOps;
    SmallVector<int64_t> transposeVector =
        options.paddingTransposeComputationFunction(opOperand);

    FailureOr<Value> newResult = hoistPaddingOnTensors(
        padOp, en.value(), transposeVector, hoistedOp, transposeOps);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padOp, newResult.getValue());

    // Do not apply hoist padding to the newly introduced transpose operations.
    for (GenericOp transposeOp : transposeOps)
      filter.replaceLinalgTransformationFilter(rewriter, transposeOp);
  }

  // Replace the original operation to pad.
  rewriter.replaceOp(linalgOp, newResults.getValue());
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);

  return paddedOp;
}

/// Linalg tile and fuse tensor ops pattern.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(std::move(f)),
      options(std::move(options)) {}

FailureOr<mlir::linalg::TileLoopNest>
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::returningMatchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
  if (!rootOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
  if (options.tileSizes.size() < rootOp.getNumLoops())
    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");

  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
  if (!options.tileInterchange.empty() &&
      options.tileInterchange.size() != options.tileSizes.size())
    return rewriter.notifyMatchFailure(
        op, "expect the number of tile sizes and interchange dims to match");

  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                     options.tileSizes.begin() +
                                         rootOp.getNumLoops());
  SmallVector<int64_t> rootInterchange =
      options.tileInterchange.empty()
          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
          : SmallVector<int64_t>(options.tileInterchange.begin(),
                                 options.tileInterchange.begin() +
                                     rootOp.getNumLoops());

  // Check `rootTileSizes` contains non-zero tile sizes.
  if (llvm::count(rootTileSizes, 0) == static_cast<long>(rootTileSizes.size()))
    return rewriter.notifyMatchFailure(
        op, "expect at least one non-zero tile size");

  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
  // It has to be a permutation since the tiling cannot tile the same loop
  // dimension multiple times.
  if (!isPermutation(rootInterchange))
    return rewriter.notifyMatchFailure(
        op, "expect the tile interchange permutes the root loops");

  // Tile `rootOp` and fuse its producers.
  FailureOr<TileLoopNest> tileLoopNest =
      tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes,
                                   rootInterchange, options.tileDistribution);
  if (failed(tileLoopNest))
    return rewriter.notifyMatchFailure(
        op, "tileConsumerAndFuseProducers failed unexpectedly");

  // Replace all uses of the tiled loop operation.
  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());

  // Apply the filter if specified.
  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
  return tileLoopNest;
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(std::move(f)),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

FailureOr<GenericOp>
mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();

  FailureOr<GenericOp> transformedOp =
      interchangeGenericOp(rewriter, genericOp, interchangeVector);
  if (failed(transformedOp))
    return failure();

  // New filter if specified.
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return transformedOp;
}
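
// Example (sketch): interchanging the loops of a generic op. For a 2-D
// generic with iteration space (i, j), the interchange vector {1, 0} swaps
// the two loops so the op subsequently iterates as (j, i). Illustrative
// usage:
//
//   SmallVector<unsigned> perm = {1, 0};
//   RewritePatternSet patterns(ctx);
//   patterns.add<GenericOpInterchangePattern>(ctx, perm);
//
// The vector must be a permutation of [0, numLoops); interchangeGenericOp
// rewrites the indexing maps and iterator types accordingly.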

/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
    PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

FailureOr<GenericOp>
mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
  if (failed(genericOp))
    return failure();
  filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
  return genericOp;
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(std::move(f)),
      options(std::move(options)) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgVectorizationOptions options, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  return vectorize(rewriter, linalgOp);
}

LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
    memref::CopyOp copyOp, PatternRewriter &rewriter) const {
  return vectorizeCopy(rewriter, copyOp);
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}
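
// Example (sketch): driving the staged application. The pattern sets below
// are placeholders; each stage 1 set is applied in turn, followed by the
// stage 2 cleanup set and the optional stage 3 callback.
//
//   SmallVector<FrozenRewritePatternSet> stage1;
//   stage1.push_back(std::move(tilingPatterns));
//   stage1.push_back(std::move(vectorizationPatterns));
//   FrozenRewritePatternSet stage2(std::move(canonicalizationPatterns));
//   if (failed(applyStagedPatterns(funcOp, stage1, stage2,
//                                  [](Operation *op) { return success(); })))
//     signalPassFailure();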

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp (to
/// initialize with pad_val) and GenericOp (to copy contents).
LogicalResult
PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
                                            PatternRewriter &rewriter) const {
  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  // 1. A BBarg from a different block.
  // 2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<tensor::YieldOp>(block.getTerminator());
  Value padValue = yieldOp.value();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor = rewriter
                        .create<linalg::FillOp>(loc, ValueRange{padValue},
                                                ValueRange{initTensor})
                        .result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Fill `dest` using a FillOp if the padding value is a constant.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: lower to a tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

LogicalResult
GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                        PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the size of the InitTensorOp. Any combination of static/dynamic
  // is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy could not be optimized. Generate an InsertSliceOp instead to
  // copy the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the source of the tensor::PadOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.source().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (Optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = control.getValue();
    else
      return failure();
  }

  Operation *tiledPadOp =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  // All shapes are static and the data source is actually used. Rewrite into
  // pad(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}
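
// Example (sketch): the effect of GeneralizePadOpPattern on IR, shown
// schematically rather than as verbatim MLIR output. A pad such as
//
//   %0 = tensor.pad %src low[1, 1] high[2, 2] { yield %cst }
//
// is rewritten into an init tensor of the padded shape, a fill with the
// padding value, and an insert of the source at the low-pad offsets:
//
//   %init = linalg.init_tensor [...]
//   %fill = linalg.fill(%cst, %init)
//   %0    = tensor.insert_slice %src into %fill[1, 1] [...] [1, 1]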

namespace {
// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.

/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
struct DownscaleSizeOneWindowed2DConvolution final
    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
  DownscaleSizeOneWindowed2DConvolution(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
        filter(std::move(f)) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
  DownscaleDepthwiseConv2DNhwcHwcOp(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
        filter(std::move(f)) {}

  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

} // namespace

void linalg::populateDecomposeConvolutionPatterns(
    RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
    PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution,
               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
                                                  benefit);
}
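
// Example (sketch): decomposing convolutions from within a pass. `funcOp` and
// the default filter are illustrative.
//
//   RewritePatternSet patterns(&getContext());
//   linalg::populateDecomposeConvolutionPatterns(patterns,
//                                                LinalgTransformationFilter());
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();
//
// A conv_2d_nhwc_hwcf whose kh (or kw) and oh (or ow) dimensions are both 1
// is rewritten into extract_slice ops that drop the unit dimension, a
// conv_1d_nwc_wcf on the rank-reduced operands, and an insert_slice that
// restores the original output shape.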