1 //===- Transforms.cpp - Linalg transformations as patterns ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic and helpers to expose Linalg transforms as rewrite 10 // patterns. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Func/IR/FuncOps.h" 18 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 19 #include "mlir/Dialect/Linalg/IR/Linalg.h" 20 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" 21 #include "mlir/Dialect/Linalg/Utils/Utils.h" 22 #include "mlir/Dialect/SCF/Transforms.h" 23 #include "mlir/Dialect/Tensor/IR/Tensor.h" 24 #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" 25 #include "mlir/Dialect/Utils/StaticValueUtils.h" 26 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 27 #include "mlir/Dialect/Vector/IR/VectorOps.h" 28 #include "mlir/IR/AffineExpr.h" 29 #include "mlir/IR/Matchers.h" 30 #include "mlir/Pass/Pass.h" 31 #include "mlir/Support/LLVM.h" 32 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 33 #include "llvm/ADT/ScopeExit.h" 34 #include "llvm/ADT/TypeSwitch.h" 35 #include "llvm/Support/Debug.h" 36 #include "llvm/Support/raw_ostream.h" 37 #include <type_traits> 38 #include <utility> 39 40 #define DEBUG_TYPE "linalg-transforms" 41 42 using namespace mlir; 43 using namespace mlir::linalg; 44 45 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") 46 47 //===----------------------------------------------------------------------===// 48 // Transformations exposed as rewrite patterns. 49 //===----------------------------------------------------------------------===// 50 // Marker used as attribute name in generated Linalg rewriting transformations. 51 const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = 52 "__internal_linalg_transform__"; 53 54 mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter( 55 ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement) 56 : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), 57 replacement(replacement), matchByDefault(false) {} 58 59 mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter( 60 const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction, 61 Optional<StringAttr> replacement) 62 : filters(), 63 matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), 64 replacement(replacement), matchByDefault(false) { 65 if (f) 66 filters.push_back(f); 67 } 68 69 LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify( 70 PatternRewriter &rewriter, Operation *op) const { 71 if (llvm::any_of(filters, 72 [&](const FilterFunction &f) { return failed(f(op)); })) 73 return failure(); 74 75 auto attr = op->template getAttrOfType<StringAttr>( 76 LinalgTransforms::kLinalgTransformMarker); 77 78 if (!attr) { 79 // 1. Has no filter case and matchDisjunction is empty. 80 if (matchDisjunction.empty() || matchByDefault) 81 return success(); 82 83 // 2. Has no filter but was expecting a filter. 84 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 85 diag << " does not have any filter from list: "; 86 interleaveComma(matchDisjunction, diag); 87 }); 88 } 89 90 // 4. Match explicit filter. 91 for (auto filter : matchDisjunction) 92 if (attr.getValue() == filter) 93 return success(); 94 95 // 5. Fail to match. 96 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 97 diag << " does not have any filter from list: "; 98 interleaveComma(matchDisjunction, diag); 99 }); 100 } 101 102 void mlir::linalg::LinalgTransformationFilter:: 103 replaceLinalgTransformationFilter(PatternRewriter &rewriter, 104 Operation *op) const { 105 if (replacement.hasValue()) 106 op->setAttr(LinalgTransforms::kLinalgTransformMarker, 107 replacement.getValue()); 108 else 109 op->removeAttr( 110 rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker)); 111 } 112 113 bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter( 114 Operation *op) const { 115 if (!replacement) 116 return false; 117 auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker) 118 .dyn_cast<StringAttr>(); 119 return attr && attr == replacement.getValue(); 120 } 121 122 LinalgTilingOptions & 123 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 124 assert(!tileSizeComputationFunction && "tile sizes already set"); 125 SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end()); 126 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 127 OpBuilder::InsertionGuard guard(b); 128 b.setInsertionPointToStart( 129 &op->getParentOfType<func::FuncOp>().getBody().front()); 130 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 131 Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s); 132 return v; 133 })); 134 }; 135 return *this; 136 } 137 138 LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() { 139 assert(!tileSizeComputationFunction && "tile sizes already set"); 140 tileSizeComputationFunction = [](OpBuilder &b, Operation *op) { 141 SmallVector<Value, 4> tileSizes; 142 auto linalgOp = dyn_cast<LinalgOp>(op); 143 if (!linalgOp) 144 return tileSizes; 145 Location loc = linalgOp.getLoc(); 146 auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc); 147 AffineMap map = linalgOp.getShapesToLoopsMap(); 148 if (!map) 149 return tileSizes; 150 auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes); 151 // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile 152 // size 0). 153 for (Value shapeSize : shapeSizes) 154 tileSizes.push_back(getConstantIntValue(shapeSize).hasValue() 155 ? b.create<arith::ConstantIndexOp>(loc, 0) 156 : b.create<arith::ConstantIndexOp>(loc, 1)); 157 return tileSizes; 158 }; 159 return *this; 160 } 161 162 /// Pad the `opOperand` in the `paddingDimensions` using the padding value and 163 /// the nofold flag found in `paddingValues` and `packPaddings`, respectively. 164 /// Exit early and return the `opOperand` value if the shape dimensions that 165 /// match `paddingDimensions` have a static size and the nofold flag is not set. 166 /// Otherwise, try to pad the shape dimensions that match the iterator 167 /// dimensions `paddingDimensions` and return the tensor::PadOp result if 168 /// padding succeeds or failure otherwise. 169 static FailureOr<Value> padOperandToSmallestStaticBoundingBox( 170 OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand, 171 ArrayRef<int64_t> paddingDimensions, ArrayRef<Attribute> paddingValues, 172 ArrayRef<bool> packPaddings) { 173 AffineMap indexingMap = opToPad.getTiedIndexingMap(opOperand); 174 ArrayRef<int64_t> shape = opToPad.getShape(opOperand); 175 176 // Collect the shape dimension that are a function of the `paddingDimensions`. 177 llvm::SmallDenseSet<int64_t> shapeDimsToPad; 178 for (int64_t dim : paddingDimensions) 179 for (const auto &en : enumerate(indexingMap.getResults())) 180 if (en.value().isFunctionOfDim(dim)) 181 shapeDimsToPad.insert(en.index()); 182 183 // Return the unpadded operand if padding to a static shape is not needed and 184 // if the nofold flag is not set. 185 bool nofold = opOperand->getOperandNumber() < packPaddings.size() 186 ? packPaddings[opOperand->getOperandNumber()] 187 : false; 188 bool hasStaticShape = llvm::none_of(shapeDimsToPad, [&](int64_t dim) { 189 return ShapedType::isDynamic(shape[dim]); 190 }); 191 if (!nofold && hasStaticShape) 192 return opOperand->get(); 193 194 // Fail if `paddingValues` specifies no padding value. 195 if (opOperand->getOperandNumber() >= paddingValues.size()) 196 return failure(); 197 Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()]; 198 Value paddingValue = b.create<arith::ConstantOp>( 199 opToPad.getLoc(), paddingAttr.getType(), paddingAttr); 200 201 // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp. 202 OpOperand *currOpOperand = opOperand; 203 while (auto linalgOp = currOpOperand->get().getDefiningOp<LinalgOp>()) { 204 OpResult result = currOpOperand->get().cast<OpResult>(); 205 currOpOperand = linalgOp.getOutputOperand(result.getResultNumber()); 206 } 207 208 // Fail if `currOpOperand` is not defined by an ExtractSliceOp. 209 auto sliceOp = currOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>(); 210 if (!sliceOp) 211 return failure(); 212 213 // Compute the dropped dimensions if `sliceOp` is ranke-reducing. 214 llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims(); 215 OffsetSizeAndStrideOpInterface shapedOp = sliceOp; 216 217 // Upper bound the `sliceOp` sizes to obtain a static bounding box. 218 SmallVector<int64_t> paddedShape(shape.begin(), shape.end()); 219 int64_t shapeIdx = 0; 220 for (const auto &en : enumerate(shapedOp.getMixedSizes())) { 221 // Skip dropped dimensions. 222 if (droppedDims.test(en.index())) 223 continue; 224 // Skip dimensions that do not require padding. 225 if (!shapeDimsToPad.contains(shapeIdx)) { 226 shapeIdx++; 227 continue; 228 } 229 // If the size is an attribute add it directly to `paddedShape`. 230 if (en.value().is<Attribute>()) { 231 paddedShape[shapeIdx++] = 232 en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt(); 233 continue; 234 } 235 // Otherwise, try to compute a constant upper bound for the size value. 236 FailureOr<int64_t> upperBound = 237 getConstantUpperBoundForIndex(en.value().get<Value>()); 238 if (failed(upperBound)) { 239 LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding"); 240 return failure(); 241 } 242 paddedShape[shapeIdx++] = upperBound.getValue(); 243 } 244 assert(shapeIdx == static_cast<int64_t>(shape.size()) && 245 "expect the dynamic and static ranks to match"); 246 247 // Pad the operand to the bounding box defined by `paddedShape`. 248 auto paddedTensorType = RankedTensorType::get( 249 paddedShape, getElementTypeOrSelf(opOperand->get())); 250 return makeComposedPadHighOp(b, opToPad->getLoc(), paddedTensorType, 251 opOperand->get(), paddingValue, nofold); 252 } 253 254 FailureOr<SmallVector<Value>> 255 linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad, 256 ArrayRef<int64_t> paddingDimensions, 257 ArrayRef<Attribute> paddingValues, 258 ArrayRef<bool> packPaddings, LinalgOp &paddedOp) { 259 Location loc = opToPad->getLoc(); 260 261 // TODO: there are cases where we may still want to pad to larger sizes. 262 assert(opToPad.hasTensorSemantics() && 263 "expected operation to have tensor semantics"); 264 265 OpBuilder::InsertionGuard g(b); 266 // Set IP after op because we also take the dims of the original output. 267 b.setInsertionPointAfter(opToPad); 268 // Make a copy of the shaped operands and update it. 269 SmallVector<Value> newOperands; 270 newOperands.reserve(opToPad.getNumInputsAndOutputs()); 271 for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) { 272 FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox( 273 b, opToPad, opOperand, paddingDimensions, paddingValues, packPaddings); 274 // Exit if `paddingDimensions` cannot be bounded statically. 275 if (failed(paddedOperand)) 276 return failure(); 277 newOperands.push_back(*paddedOperand); 278 } 279 280 SmallVector<SmallVector<Value>> reifiedResultShapes; 281 if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation()) 282 .reifyResultShapes(b, reifiedResultShapes))) 283 return failure(); 284 assert(reifiedResultShapes.size() == opToPad->getNumResults() && 285 "expected same number of results"); 286 287 // Clone `opToPad` to operate on the statically padded shapes. 288 auto resultTensorTypes = 289 ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes(); 290 paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands); 291 292 // Recover the slice out of the new static results. This keeps the original 293 // linalg op around because it uses the dims of the original results. 294 SmallVector<Value> paddedSubviewResults; 295 paddedSubviewResults.reserve(opToPad->getNumResults()); 296 for (const auto &en : llvm::enumerate(paddedOp->getResults())) { 297 Value paddedResult = en.value(); 298 int64_t resultNumber = en.index(); 299 int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank(); 300 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0)); 301 SmallVector<OpFoldResult> sizes; 302 for (Value v : reifiedResultShapes[resultNumber]) 303 sizes.push_back(getAsOpFoldResult(v)); 304 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1)); 305 paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>( 306 loc, paddedResult, offsets, sizes, strides)); 307 } 308 return paddedSubviewResults; 309 } 310 311 /// Try to peel a loop `op` and return the new result. 312 // TODO: Add support for scf.parallel and affine.for loops. 313 static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) { 314 return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op) 315 .Case<scf::ForOp>([&](scf::ForOp forOp) { 316 scf::ForOp partialIteration; 317 if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp, 318 partialIteration))) 319 return partialIteration->getResults(); 320 assert(!partialIteration && "expected that loop was not peeled"); 321 return forOp->getResults(); 322 }) 323 .Default([&](Operation *op) { return op->getResults(); }); 324 } 325 326 /// Peel and canonicalize 'loops'. 327 void mlir::linalg::peelLoops(RewriterBase &rewriter, 328 ArrayRef<scf::ForOp> loops) { 329 for (auto loopOp : loops) { 330 SmallVector<Value, 4> loopResults; 331 loopResults = peelLoop(rewriter, loopOp); 332 } 333 } 334 335 /// Peel loops after tiling. 336 void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res, 337 ArrayRef<int64_t> peeledLoops, 338 LinalgTilingLoopType loopType) { 339 for (int64_t loop : peeledLoops) { 340 assert(loop < static_cast<int64_t>(res.loops.size()) && 341 "requested peeling of non-existing loop"); 342 SmallVector<Value, 4> loopResults; 343 Operation *loopOp = res.loops[loop]; 344 loopResults = peelLoop(rewriter, loopOp); 345 346 // The result of the loop nest may change with peeling. 347 if (res.tensorResults.size() == loopOp->getNumResults() && 348 std::equal(res.tensorResults.begin(), res.tensorResults.end(), 349 loopOp->getResults().begin())) 350 res.tensorResults = loopResults; 351 } 352 } 353 354 static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) { 355 if (tiledOp.loops.empty()) 356 return tiledOp.op.getOperation()->getResults(); 357 return tiledOp.loops.front()->getResults(); 358 } 359 360 static ValueRange 361 getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) { 362 if (tiledAndFusedOp.fusedLoops.empty()) 363 return tiledAndFusedOp.op.getOperation()->getResults(); 364 return tiledAndFusedOp.fusedLoops.front()->getResults(); 365 } 366 367 mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern( 368 StringRef opName, MLIRContext *context, 369 const LinalgDependenceGraph &dependenceGraph, 370 LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions, 371 LinalgTransformationFilter f, LinalgTransformationFilter fusedOpMarker, 372 LinalgTransformationFilter originalOpMarker, PatternBenefit benefit) 373 : RewritePattern(opName, benefit, context, {}), 374 dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)), 375 fusionOptions(std::move(fusionOptions)), filter(std::move(f)), 376 fusedOpMarker(std::move(fusedOpMarker)), 377 originalOpMarker(std::move(originalOpMarker)) {} 378 379 LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite( 380 Operation *op, PatternRewriter &rewriter) const { 381 LinalgOp linalgOp = dyn_cast<LinalgOp>(op); 382 // TODO: remove hasIndexSemantics check once index ops are supported. 383 if (!linalgOp || linalgOp.hasIndexSemantics()) 384 return failure(); 385 if (failed(filter.checkAndNotify(rewriter, linalgOp))) 386 return failure(); 387 388 DenseSet<Operation *> producers; 389 producers.insert(linalgOp); 390 for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) { 391 Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum(); 392 // When looking at dependences into, indexingOp is always OpOperand. We 393 // could assert, but continue if this is not the case. 394 if (!operandNumber) 395 continue; 396 if (!fusionOptions.indicesToFuse.count(operandNumber.getValue())) 397 continue; 398 if (isa<LinalgOp>(dependence.getDependentOp())) 399 producers.insert(dependence.getDependentOp()); 400 } 401 402 SmallVector<LinalgOp, 1> fusionOps; 403 for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie; 404 ++it) { 405 auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it)); 406 if (producerLinalgOp && producers.count(producerLinalgOp)) 407 fusionOps.push_back(producerLinalgOp); 408 } 409 fusionOps.push_back(linalgOp); 410 411 SmallVector<Value, 4> tileSizes = 412 tilingOptions.tileSizeComputationFunction(rewriter, op); 413 LinalgTilingOptions instanceTilingOptions = tilingOptions; 414 instanceTilingOptions.setTileSizes(tileSizes); 415 Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps( 416 rewriter, fusionOps, dependenceGraph, instanceTilingOptions); 417 if (!tiledAndFusedOps) 418 return failure(); 419 420 // Tile the unfused loops; 421 SmallVector<Value, 4> unfusedLoopTileSizes; 422 Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0); 423 for (const auto &tileSize : enumerate(tileSizes)) { 424 if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index())) 425 unfusedLoopTileSizes.push_back(zero); 426 else 427 unfusedLoopTileSizes.push_back(tileSize.value()); 428 } 429 // Tile the loop only if there is a non-zero tile size. 430 if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops()) 431 unfusedLoopTileSizes.resize(linalgOp.getNumLoops()); 432 if (llvm::any_of(unfusedLoopTileSizes, [](Value val) { 433 if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>()) 434 return cst.value() != 0; 435 return true; 436 })) { 437 LinalgTilingOptions unfusedTilingOptions = tilingOptions; 438 unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes); 439 FailureOr<TiledLinalgOp> unfusedTiledOp = 440 tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions); 441 if (failed(unfusedTiledOp)) 442 return failure(); 443 rewriter.replaceOp(tiledAndFusedOps->op, 444 getTiledOpResult(unfusedTiledOp.getValue())); 445 tiledAndFusedOps->op = unfusedTiledOp->op; 446 } 447 op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue())); 448 449 filter.replaceLinalgTransformationFilter(rewriter, 450 tiledAndFusedOps->op.getOperation()); 451 for (auto fusedOp : tiledAndFusedOps->fusedProducers) { 452 fusedOpMarker.replaceLinalgTransformationFilter(rewriter, 453 fusedOp.getOperation()); 454 } 455 for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) { 456 originalOpMarker.replaceLinalgTransformationFilter( 457 rewriter, origProducerOp.getOperation()); 458 } 459 rewriter.updateRootInPlace(op, [&]() { 460 originalOpMarker.replaceLinalgTransformationFilter(rewriter, op); 461 }); 462 return success(); 463 } 464 465 /// Linalg tiling pattern. 466 mlir::linalg::LinalgTilingPattern::LinalgTilingPattern( 467 MLIRContext *context, LinalgTilingOptions options, 468 LinalgTransformationFilter f, PatternBenefit benefit) 469 : OpInterfaceRewritePattern<LinalgOp>(context, benefit), 470 filter(std::move(f)), options(std::move(options)) {} 471 472 mlir::linalg::LinalgTilingPattern::LinalgTilingPattern( 473 StringRef opName, MLIRContext *context, LinalgTilingOptions options, 474 LinalgTransformationFilter f, PatternBenefit benefit) 475 : OpInterfaceRewritePattern<LinalgOp>(context, benefit), 476 filter(f.addOpNameFilter(opName)), options(std::move(options)) {} 477 478 FailureOr<TiledLinalgOp> 479 mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite( 480 LinalgOp op, PatternRewriter &rewriter) const { 481 if (failed(filter.checkAndNotify(rewriter, op))) 482 return failure(); 483 484 FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options); 485 if (failed(res)) 486 return failure(); 487 488 // Clear filter to stop recursive pattern application. 489 // This must be done here to properly propagate to peeling branches. 490 filter.replaceLinalgTransformationFilter(rewriter, res->op); 491 492 // Peel the loops of the TiledLinalgOp. 493 peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType); 494 495 if (res->tensorResults.empty()) 496 rewriter.eraseOp(op); 497 else 498 rewriter.replaceOp(op, res->tensorResults); 499 500 return res; 501 } 502 503 /// Linalg padding pattern. 504 mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern( 505 MLIRContext *context, LinalgPaddingOptions options, 506 LinalgTransformationFilter f, PatternBenefit benefit) 507 : OpInterfaceRewritePattern<LinalgOp>(context, benefit), 508 filter(std::move(f)), options(std::move(options)) {} 509 510 mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern( 511 StringRef opName, MLIRContext *context, LinalgPaddingOptions options, 512 LinalgTransformationFilter f, PatternBenefit benefit) 513 : OpInterfaceRewritePattern<LinalgOp>(context, benefit), 514 filter(f.addOpNameFilter(opName)), options(std::move(options)) {} 515 516 FailureOr<LinalgOp> 517 mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite( 518 LinalgOp linalgOp, PatternRewriter &rewriter) const { 519 if (!linalgOp.hasTensorSemantics()) 520 return failure(); 521 if (failed(filter.checkAndNotify(rewriter, linalgOp))) 522 return failure(); 523 524 // Pad the operation. 525 LinalgOp paddedOp; 526 FailureOr<SmallVector<Value>> newResults = 527 rewriteAsPaddedOp(rewriter, linalgOp, options.paddingDimensions, 528 options.paddingValues, options.packPaddings, paddedOp); 529 if (failed(newResults)) 530 return failure(); 531 532 // Hoist the padding. 533 for (const auto &en : enumerate(options.hoistPaddings)) { 534 if (static_cast<int64_t>(en.index()) >= paddedOp.getNumInputsAndOutputs()) 535 break; 536 OpOperand *opOperand = &paddedOp->getOpOperand(en.index()); 537 auto padOp = opOperand->get().getDefiningOp<tensor::PadOp>(); 538 if (!padOp || en.value() == 0) 539 continue; 540 541 // Fail hoisting if the operand shape is not fully static. 542 if (llvm::any_of(paddedOp.getShape(opOperand), 543 [](int64_t size) { return ShapedType::isDynamic(size); })) 544 return failure(); 545 546 tensor::PadOp hoistedOp; 547 SmallVector<GenericOp> transposeOps; 548 SmallVector<int64_t> transposeVector = 549 en.index() < options.transposePaddings.size() 550 ? options.transposePaddings[en.index()] 551 : SmallVector<int64_t>{}; 552 553 FailureOr<Value> newResult = hoistPaddingOnTensors( 554 padOp, en.value(), transposeVector, hoistedOp, transposeOps); 555 if (failed(newResult)) 556 continue; 557 rewriter.replaceOp(padOp, newResult.getValue()); 558 559 // Do not apply hoist padding to the newly introduced transpose operations. 560 for (GenericOp transposeOp : transposeOps) 561 filter.replaceLinalgTransformationFilter(rewriter, transposeOp); 562 } 563 564 // Replace the original operation to pad. 565 rewriter.replaceOp(linalgOp, newResults.getValue()); 566 filter.replaceLinalgTransformationFilter(rewriter, paddedOp); 567 568 return paddedOp; 569 } 570 571 /// Linalg tile and fuse tensor ops pattern. 572 mlir::linalg::LinalgTileAndFuseTensorOpsPattern:: 573 LinalgTileAndFuseTensorOpsPattern(MLIRContext *context, 574 LinalgTilingAndFusionOptions options, 575 LinalgTransformationFilter f, 576 PatternBenefit benefit) 577 : RewritePattern(MatchAnyOpTypeTag(), benefit, context), 578 filter(std::move(f)), options(std::move(options)) {} 579 580 mlir::linalg::LinalgTileAndFuseTensorOpsPattern:: 581 LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context, 582 LinalgTilingAndFusionOptions options, 583 LinalgTransformationFilter f, 584 PatternBenefit benefit) 585 : RewritePattern(opName, benefit, context), filter(std::move(f)), 586 options(std::move(options)) {} 587 588 FailureOr<mlir::linalg::TileLoopNest> 589 mlir::linalg::LinalgTileAndFuseTensorOpsPattern::returningMatchAndRewrite( 590 Operation *op, PatternRewriter &rewriter) const { 591 LinalgOp rootOp = dyn_cast<LinalgOp>(op); 592 if (!rootOp) 593 return failure(); 594 if (failed(filter.checkAndNotify(rewriter, op))) 595 return failure(); 596 597 // Check `tileSizes` contains a tile size for every `rootOp` loop dimension. 598 if (options.tileSizes.size() < rootOp.getNumLoops()) 599 return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops"); 600 601 // Check `tileInterchange` contains no entries or as many as `tileSizes`. 602 if (!options.tileInterchange.empty() && 603 options.tileInterchange.size() != options.tileSizes.size()) 604 return rewriter.notifyMatchFailure( 605 op, "expect the number of tile sizes and interchange dims to match"); 606 607 // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`. 608 SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(), 609 options.tileSizes.begin() + 610 rootOp.getNumLoops()); 611 SmallVector<int64_t> rootInterchange = 612 options.tileInterchange.empty() 613 ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops())) 614 : SmallVector<int64_t>(options.tileInterchange.begin(), 615 options.tileInterchange.begin() + 616 rootOp.getNumLoops()); 617 618 // Check `rootTileSizes` contains non-zero tile sizes. 619 if (llvm::count(rootTileSizes, 0) == static_cast<long>(rootTileSizes.size())) 620 return rewriter.notifyMatchFailure( 621 op, "expect at least one non-zero tile size"); 622 623 // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions. 624 // It has to be a permutation since the tiling cannot tile the same loop 625 // dimension multiple times. 626 if (!isPermutation(rootInterchange)) 627 return rewriter.notifyMatchFailure( 628 op, "expect the tile interchange permutes the root loops"); 629 630 // Tile `rootOp` and fuse its producers. 631 FailureOr<TileLoopNest> tileLoopNest = 632 tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes, 633 rootInterchange, options.tileDistribution); 634 if (failed(tileLoopNest)) 635 return rewriter.notifyMatchFailure( 636 op, "tileConsumerAndFuseProducers failed unexpectedly"); 637 638 // Replace all uses of the tiled loop operation. 639 rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults()); 640 641 // Apply the filter if specified. 642 for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps()) 643 filter.replaceLinalgTransformationFilter(rewriter, linalgOp); 644 return tileLoopNest; 645 } 646 647 /// Linalg generic interchange pattern. 648 mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern( 649 MLIRContext *context, ArrayRef<unsigned> interchangeVector, 650 LinalgTransformationFilter f, PatternBenefit benefit) 651 : OpRewritePattern(context, benefit), filter(std::move(f)), 652 interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} 653 654 FailureOr<GenericOp> 655 mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite( 656 GenericOp genericOp, PatternRewriter &rewriter) const { 657 if (failed(filter.checkAndNotify(rewriter, genericOp))) 658 return failure(); 659 660 FailureOr<GenericOp> transformedOp = 661 interchangeGenericOp(rewriter, genericOp, interchangeVector); 662 if (failed(transformedOp)) 663 return failure(); 664 665 // New filter if specified. 666 filter.replaceLinalgTransformationFilter(rewriter, genericOp); 667 return transformedOp; 668 } 669 670 /// Linalg generalization pattern. 671 mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern( 672 MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit) 673 : OpInterfaceRewritePattern<LinalgOp>(context, benefit), 674 filter(std::move(f)) {} 675 676 mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern( 677 StringRef opName, MLIRContext *context, LinalgTransformationFilter f, 678 PatternBenefit benefit) 679 : OpInterfaceRewritePattern<LinalgOp>(context, benefit), 680 filter(f.addOpNameFilter(opName)) {} 681 682 FailureOr<GenericOp> 683 mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite( 684 LinalgOp linalgOp, PatternRewriter &rewriter) const { 685 if (failed(filter.checkAndNotify(rewriter, linalgOp))) 686 return failure(); 687 FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp); 688 if (failed(genericOp)) 689 return failure(); 690 filter.replaceLinalgTransformationFilter(rewriter, *genericOp); 691 return genericOp; 692 } 693 694 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( 695 MLIRContext *context, LinalgTransformationFilter f, 696 LinalgPromotionOptions options, PatternBenefit benefit) 697 : RewritePattern(MatchAnyOpTypeTag(), benefit, context), 698 filter(std::move(f)), options(std::move(options)) {} 699 700 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( 701 StringRef opName, MLIRContext *context, LinalgPromotionOptions options, 702 LinalgTransformationFilter f, PatternBenefit benefit) 703 : RewritePattern(opName, benefit, context, {}), filter(std::move(f)), 704 options(std::move(options)) {} 705 706 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite( 707 Operation *op, PatternRewriter &rewriter) const { 708 if (failed(filter.checkAndNotify(rewriter, op))) 709 return failure(); 710 if (failed(promoteSubviewsPrecondition(op, options))) 711 return failure(); 712 713 // TODO: We cannot use root update here. This pattern is creating other ops, 714 // so if the promotion fails, those need to be cleaned up, which doesnt seem 715 // to be happening here. So to fail properly, we should be cloning the op and 716 // deleting the previous op. This needs more investigation. 717 rewriter.startRootUpdate(op); 718 Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options); 719 if (!promotedOp) { 720 rewriter.cancelRootUpdate(op); 721 return op->emitError("subview promotion failed"); 722 } 723 rewriter.finalizeRootUpdate(op); 724 filter.replaceLinalgTransformationFilter(rewriter, op); 725 return success(); 726 } 727 728 mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern( 729 MLIRContext *context, LinalgTransformationFilter f, 730 LinalgPeelOptions options, PatternBenefit benefit) 731 : OpInterfaceRewritePattern<LinalgOp>(context, benefit), 732 filter(std::move(f)), options(std::move(options)) {} 733 734 mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern( 735 StringRef opName, MLIRContext *context, LinalgPeelOptions options, 736 LinalgTransformationFilter f, PatternBenefit benefit) 737 : OpInterfaceRewritePattern<LinalgOp>(context, benefit), 738 filter(f.addOpNameFilter(opName)), options(std::move(options)) {} 739 740 LogicalResult mlir::linalg::LinalgPeelingPattern::matchAndRewrite( 741 LinalgOp linalgOp, PatternRewriter &rewriter) const { 742 if (failed(filter.checkAndNotify(rewriter, linalgOp))) 743 return failure(); 744 745 // Increase marker counter even if peeling doesn't happen for this op. 746 filter.replaceLinalgTransformationFilter(rewriter, linalgOp); 747 748 if (!options.loopsToPeelComputationFunction) 749 return failure(); 750 751 SmallVector<scf::ForOp, 4> loopsToPeel; 752 options.loopsToPeelComputationFunction(rewriter, linalgOp, loopsToPeel); 753 peelLoops(rewriter, loopsToPeel); 754 return success(); 755 } 756 757 mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern( 758 MLIRContext *context, LinalgTransformationFilter f, 759 LinalgVectorizationOptions options, PatternBenefit benefit) 760 : OpInterfaceRewritePattern<LinalgOp>(context, benefit), 761 filter(std::move(f)) {} 762 763 mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern( 764 StringRef opName, MLIRContext *context, LinalgVectorizationOptions options, 765 LinalgTransformationFilter f, PatternBenefit benefit) 766 : OpInterfaceRewritePattern<LinalgOp>(context, benefit), 767 filter(f.addOpNameFilter(opName)) {} 768 769 LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite( 770 LinalgOp linalgOp, PatternRewriter &rewriter) const { 771 if (failed(filter.checkAndNotify(rewriter, linalgOp))) 772 return failure(); 773 return vectorize(rewriter, linalgOp); 774 } 775 776 LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite( 777 memref::CopyOp copyOp, PatternRewriter &rewriter) const { 778 return vectorizeCopy(rewriter, copyOp); 779 } 780 781 LogicalResult mlir::linalg::applyStagedPatterns( 782 Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns, 783 const FrozenRewritePatternSet &stage2Patterns, 784 function_ref<LogicalResult(Operation *)> stage3Lambda) { 785 unsigned iteration = 0; 786 (void)iteration; 787 for (const auto &patterns : stage1Patterns) { 788 LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n" 789 << *op); 790 if (failed(applyPatternsAndFoldGreedily(op, patterns))) { 791 LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge"); 792 return failure(); 793 } 794 LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n" 795 << *op); 796 if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) { 797 LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge"); 798 return failure(); 799 } 800 LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n" 801 << *op); 802 if (stage3Lambda) { 803 if (failed(stage3Lambda(op))) 804 return failure(); 805 LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n" 806 << *op); 807 } 808 } 809 return success(); 810 } 811 812 static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) { 813 return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName()); 814 } 815 816 /// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp (to 817 /// initialize with pad_val) and GenericOp (to copy contents). 818 LogicalResult 819 PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp, 820 PatternRewriter &rewriter) const { 821 822 auto inputShapedType = padOp.source().getType().cast<ShapedType>(); 823 auto resultShapedType = padOp.result().getType().cast<ShapedType>(); 824 825 // Bail on non-static shapes. 826 if (!inputShapedType.hasStaticShape()) 827 return failure(); 828 if (!resultShapedType.hasStaticShape()) 829 return failure(); 830 831 // Only support padding with a constant for now, i.e. either: 832 // 1. A BBarg from a different block. 833 // 2. A value defined outside of the current block. 834 Block &block = padOp.region().front(); 835 auto yieldOp = cast<tensor::YieldOp>(block.getTerminator()); 836 Value padValue = yieldOp.value(); 837 Operation *definingOp = padValue.getDefiningOp(); 838 if (definingOp && definingOp->getBlock() == &block) 839 return failure(); 840 if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block) 841 return failure(); 842 843 // Create tensor with the padded shape 844 Location loc = padOp.getLoc(); 845 SmallVector<Value> indices(resultShapedType.getRank(), 846 rewriter.create<arith::ConstantIndexOp>(loc, 0)); 847 Value initTensor = rewriter.create<InitTensorOp>( 848 loc, resultShapedType.getShape(), resultShapedType.getElementType()); 849 850 // Initialize tensor with the pad value 851 Value tmpTensor = rewriter 852 .create<linalg::FillOp>(loc, ValueRange{padValue}, 853 ValueRange{initTensor}) 854 .result(); 855 856 // Copy original contents into new tensor 857 // Uses linalg.generic, but could be done with tensor.insert_slice 858 SmallVector<AffineExpr, 4> outputExprs; 859 for (unsigned i = 0; i < resultShapedType.getRank(); ++i) { 860 outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) + 861 padOp.static_low()[i].cast<IntegerAttr>().getInt()); 862 } 863 864 SmallVector<AffineMap, 2> transferMaps = { 865 rewriter.getMultiDimIdentityMap(inputShapedType.getRank()), 866 AffineMap::get(resultShapedType.getRank(), 867 /*symbolCount=*/0, outputExprs, rewriter.getContext())}; 868 869 rewriter.replaceOpWithNewOp<linalg::GenericOp>( 870 padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps, 871 getNParallelLoopsAttrs(resultShapedType.getRank()), 872 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { 873 nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]); 874 }); 875 876 return success(); 877 } 878 879 /// Filling `dest` using FillOp constant padding value if possible. 880 /// Otherwise, generate a tensor::GenerateOp. 881 Value GeneralizePadOpPattern::createFillOrGenerateOp( 882 PatternRewriter &rewriter, tensor::PadOp padOp, Value dest, 883 const SmallVector<Value> &dynSizes) const { 884 auto padValue = padOp.getConstantPaddingValue(); 885 if (padValue) 886 return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result(); 887 888 // Fill could not be optimized: Lower to tensor::GenerateOp with region. 889 auto generateOp = rewriter.create<tensor::GenerateOp>( 890 padOp.getLoc(), padOp.getResultType(), dynSizes); 891 // Copy region to new op. 892 BlockAndValueMapping bvm; 893 padOp.region().cloneInto(&generateOp.getRegion(), bvm); 894 return generateOp; 895 } 896 897 LogicalResult 898 GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp, 899 PatternRewriter &rewriter) const { 900 // Given an OpFoldResult, return an index-typed value. 901 auto getIdxValue = [&](OpFoldResult ofr) { 902 if (auto val = ofr.dyn_cast<Value>()) 903 return val; 904 return rewriter 905 .create<arith::ConstantIndexOp>( 906 padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt()) 907 .getResult(); 908 }; 909 910 auto resultType = padOp.getResultType(); 911 // Compute size of InitTensorOp. Any combination of static/dynamic is 912 // supported. 913 SmallVector<Value> dynSizes; 914 SmallVector<int64_t> staticSizes; 915 for (unsigned dim = 0; dim < resultType.getRank(); ++dim) { 916 if (resultType.isDynamicDim(dim)) { 917 auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(), 918 padOp.source(), dim); 919 // Add low and high padding value. 920 auto plusLow = rewriter.createOrFold<arith::AddIOp>( 921 padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim])); 922 auto plusHigh = rewriter.createOrFold<arith::AddIOp>( 923 padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim])); 924 dynSizes.push_back(plusHigh); 925 } 926 staticSizes.push_back(resultType.getDimSize(dim)); 927 } 928 929 // Init tensor and fill it with padding. 930 Value init = rewriter.create<InitTensorOp>( 931 padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType()); 932 Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes); 933 934 // Try optimize the copy of source. 935 if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded()) 936 return success(); 937 938 // tensor::PadOps cannot be optimized. Generate a InsertSliceOp instead 939 // for copying the PadOp source. 940 auto sourceType = padOp.getSourceType(); 941 // Compute size of source of tensor::PadOp. 942 SmallVector<OpFoldResult> srcSizes; 943 for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) { 944 if (sourceType.isDynamicDim(dim)) { 945 srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>( 946 padOp.getLoc(), padOp.source(), dim)); 947 } else { 948 srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim))); 949 } 950 } 951 // Strides of InsertSliceOp are all 1. 952 SmallVector<OpFoldResult> strides(sourceType.getRank(), 953 rewriter.getIndexAttr(1)); 954 rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>( 955 padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides); 956 957 return success(); 958 } 959 960 LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite( 961 tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const { 962 if (!sliceOp.hasUnitStride()) 963 return failure(); 964 965 auto padOp = sliceOp.source().getDefiningOp<tensor::PadOp>(); 966 if (!padOp) 967 return failure(); 968 969 bool zeroSliceGuard = true; 970 if (controlFn) { 971 if (Optional<bool> control = controlFn(sliceOp)) 972 zeroSliceGuard = control.getValue(); 973 else 974 return failure(); 975 } 976 977 Operation *tiledPadOp = 978 tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(), 979 sliceOp.getMixedSizes(), zeroSliceGuard); 980 // All shapes are static and the data source is actually used. Rewrite into 981 // pad(extract_slice(x)). 982 rewriter.replaceOp(sliceOp, tiledPadOp->getResults()); 983 return success(); 984 } 985 986 // The following are patterns for downscaling convolution ops with size-1 987 // window dimensions. 988 // 989 // Note that we'd eventually want to write such transformations in a generic 990 // way, e.g., converting to linalg.generic, removing the size-1 dimensions, 991 // and then turning back to named ops. But for now it's fine to have a few 992 // patterns matching special ops to get started. 993 994 FailureOr<Conv1DNwcWcfOp> 995 DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite( 996 linalg::Conv2DNhwcHwcfOp convOp, PatternRewriter &rewriter) const { 997 if (failed(filter.checkAndNotify(rewriter, convOp))) 998 return failure(); 999 if (convOp.hasBufferSemantics()) 1000 return failure(); // To be implemented. 1001 1002 Value input = convOp.inputs().front(); 1003 Value kernel = convOp.inputs().back(); 1004 Value output = convOp.outputs().front(); 1005 1006 auto inputType = input.getType().dyn_cast<RankedTensorType>(); 1007 auto kernelType = kernel.getType().dyn_cast<RankedTensorType>(); 1008 auto outputType = output.getType().dyn_cast<RankedTensorType>(); 1009 1010 auto kernelShape = kernelType.getShape(); 1011 auto outputShape = outputType.getShape(); 1012 1013 // Only handle the case where at least one of the window dimensions is 1014 // of size 1. Other cases can rely on tiling to reduce to such cases. 1015 int64_t khSize = kernelShape[0], kwSize = kernelShape[1]; 1016 int64_t ohSize = outputShape[1], owSize = outputShape[2]; 1017 bool removeH = (khSize == 1 && ohSize == 1); 1018 bool removeW = (kwSize == 1 && owSize == 1); 1019 if (!removeH && !removeW) 1020 return failure(); 1021 1022 // Get new shapes and types for all operands by removing the size-1 1023 // dimension. 1024 using RTTBuilder = RankedTensorType::Builder; 1025 RankedTensorType newInputType = 1026 RTTBuilder(inputType).dropDim((removeH ? 1 : 2)); 1027 RankedTensorType newKernelType = 1028 RTTBuilder(kernelType).dropDim((removeH ? 0 : 1)); 1029 RankedTensorType newOutputType = 1030 RTTBuilder(outputType).dropDim(removeH ? 1 : 2); 1031 1032 // Rank-reduce operands. 1033 Location loc = convOp.getLoc(); 1034 Value newInput = tensor::createCanonicalRankReducingExtractSliceOp( 1035 rewriter, loc, input, newInputType); 1036 Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp( 1037 rewriter, loc, kernel, newKernelType); 1038 Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp( 1039 rewriter, loc, output, newOutputType); 1040 1041 // Rank-reduce strides and dilations too. 1042 // TODO: dropDim 1-liner helper. 1043 auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>()); 1044 strides.erase(strides.begin() + (removeH ? 0 : 1)); 1045 auto stridesAttr = rewriter.getI64VectorAttr(strides); 1046 1047 auto dilations = llvm::to_vector<4>(convOp.dilations().getValues<int64_t>()); 1048 dilations.erase(dilations.begin() + (removeH ? 0 : 1)); 1049 auto dilationsAttr = rewriter.getI64VectorAttr(dilations); 1050 1051 auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>( 1052 loc, newOutputType, ValueRange{newInput, newKernel}, 1053 ValueRange{newOutput}, stridesAttr, dilationsAttr); 1054 1055 // Insert back. 1056 Value inserted = tensor::createCanonicalRankReducingInsertSliceOp( 1057 rewriter, loc, conv1DOp.getResult(0), output); 1058 rewriter.replaceOp(convOp, inserted); 1059 1060 filter.replaceLinalgTransformationFilter(rewriter, conv1DOp); 1061 return conv1DOp; 1062 } 1063 1064 FailureOr<DepthwiseConv1DNwcWcOp> 1065 DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite( 1066 DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const { 1067 if (failed(filter.checkAndNotify(rewriter, convOp))) 1068 return failure(); 1069 if (convOp.hasBufferSemantics()) 1070 return failure(); // To be implemented. 1071 1072 Value input = convOp.inputs().front(); 1073 Value kernel = convOp.inputs().back(); 1074 Value output = convOp.outputs().front(); 1075 1076 auto inputType = input.getType().dyn_cast<RankedTensorType>(); 1077 auto kernelType = kernel.getType().dyn_cast<RankedTensorType>(); 1078 auto outputType = output.getType().dyn_cast<RankedTensorType>(); 1079 1080 auto kernelShape = kernelType.getShape(); 1081 auto outputShape = outputType.getShape(); 1082 1083 // Only handle the case where at least one of the window dimensions is 1084 // of size 1. Other cases can rely on tiling to reduce to such cases. 1085 int64_t khSize = kernelShape[0], kwSize = kernelShape[1]; 1086 int64_t ohSize = outputShape[1], owSize = outputShape[2]; 1087 bool removeH = (khSize == 1 && ohSize == 1); 1088 bool removeW = (kwSize == 1 && owSize == 1); 1089 if (!removeH && !removeW) 1090 return failure(); 1091 1092 // Get new shapes and types for all operands by removing the size-1 1093 // dimension. 1094 using RTTBuilder = RankedTensorType::Builder; 1095 RankedTensorType newInputType = 1096 RTTBuilder(inputType).dropDim((removeH ? 1 : 2)); 1097 RankedTensorType newKernelType = 1098 RTTBuilder(kernelType).dropDim((removeH ? 0 : 1)); 1099 RankedTensorType newOutputType = 1100 RTTBuilder(outputType).dropDim(removeH ? 1 : 2); 1101 1102 // Rank-reduce operands. 1103 Location loc = convOp.getLoc(); 1104 Value newInput = tensor::createCanonicalRankReducingExtractSliceOp( 1105 rewriter, loc, input, newInputType); 1106 Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp( 1107 rewriter, loc, kernel, newKernelType); 1108 Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp( 1109 rewriter, loc, output, newOutputType); 1110 1111 // Rank-reduce strides and dilations too. 1112 // TODO: dropDim 1-liner helper. 1113 auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>()); 1114 strides.erase(strides.begin() + (removeH ? 0 : 1)); 1115 auto stridesAttr = rewriter.getI64VectorAttr(strides); 1116 1117 auto dilations = llvm::to_vector<4>(convOp.dilations().getValues<int64_t>()); 1118 dilations.erase(dilations.begin() + (removeH ? 0 : 1)); 1119 auto dilationsAttr = rewriter.getI64VectorAttr(dilations); 1120 1121 auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>( 1122 loc, newOutputType, ValueRange{newInput, newKernel}, 1123 ValueRange{newOutput}, stridesAttr, dilationsAttr); 1124 1125 // Insert back. 1126 Value inserted = tensor::createCanonicalRankReducingInsertSliceOp( 1127 rewriter, loc, conv1DOp.getResult(0), output); 1128 rewriter.replaceOp(convOp, inserted); 1129 1130 filter.replaceLinalgTransformationFilter(rewriter, conv1DOp); 1131 return conv1DOp; 1132 } 1133 1134 void linalg::populateDecomposeConvolutionPatterns( 1135 RewritePatternSet &patterns, const LinalgTransformationFilter &filter, 1136 PatternBenefit benefit) { 1137 patterns.add<DownscaleSizeOneWindowed2DConvolution, 1138 DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter, 1139 benefit); 1140 } 1141