//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#include <utility>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
    Optional<StringAttr> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no filter case and matchDisjunction is empty.
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. Has no filter but was expecting a filter.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                replacement.getValue());
  else
    op->removeAttr(
        rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
}

bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
    Operation *op) const {
  if (!replacement)
    return false;
  auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
                  .dyn_cast<StringAttr>();
  return attr && attr == replacement.getValue();
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}

/// Pad the `opOperand` in the `paddingDimensions` using the padding value and
/// the nofold flag found in `paddingValues` and `packPaddings`, respectively.
/// Exit early and return the `opOperand` value if the shape dimensions that
/// match `paddingDimensions` have a static size and the nofold flag is not set.
/// Otherwise, try to pad the shape dimensions that match the iterator
/// dimensions `paddingDimensions` and return the tensor::PadOp result if
/// padding succeeds or failure otherwise.
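///
/// A hypothetical example (shapes and names chosen for illustration only):
/// if `opOperand` is produced by
///
///   %slice = tensor.extract_slice %arg[0, 0] [%sz, 32] [1, 1]
///       : tensor<64x32xf32> to tensor<?x32xf32>
///
/// and shape dimension 0 must be padded, the helper bounds %sz by the
/// constant 64 (via `getConstantUpperBoundForIndex`) and emits roughly
///
///   %padded = tensor.pad %slice ... : tensor<?x32xf32> to tensor<64x32xf32>
///
/// so that the consuming op operates on a statically shaped operand.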
static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    ArrayRef<int64_t> paddingDimensions, ArrayRef<Attribute> paddingValues,
    ArrayRef<bool> packPaddings) {
  AffineMap indexingMap = opToPad.getTiedIndexingMap(opOperand);
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);

  // Collect the shape dimensions that are a function of `paddingDimensions`.
  llvm::SmallDenseSet<int64_t> shapeDimsToPad;
  for (int64_t dim : paddingDimensions)
    for (const auto &en : enumerate(indexingMap.getResults()))
      if (en.value().isFunctionOfDim(dim))
        shapeDimsToPad.insert(en.index());

  // Return the unpadded operand if padding to a static shape is not needed and
  // the nofold flag is not set.
  bool nofold = opOperand->getOperandNumber() < packPaddings.size()
                    ? packPaddings[opOperand->getOperandNumber()]
                    : false;
  bool hasStaticShape = llvm::none_of(shapeDimsToPad, [&](int64_t dim) {
    return ShapedType::isDynamic(shape[dim]);
  });
  if (!nofold && hasStaticShape)
    return opOperand->get();

  // Fail if `paddingValues` specifies no padding value.
  if (opOperand->getOperandNumber() >= paddingValues.size())
    return failure();
  Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()];
  Value paddingValue = b.create<arith::ConstantOp>(
      opToPad.getLoc(), paddingAttr.getType(), paddingAttr);

  // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
  OpOperand *currOpOperand = opOperand;
  while (auto linalgOp = currOpOperand->get().getDefiningOp<LinalgOp>()) {
    OpResult result = currOpOperand->get().cast<OpResult>();
    currOpOperand = linalgOp.getOutputOperand(result.getResultNumber());
  }

  // Fail if `currOpOperand` is not defined by an ExtractSliceOp.
  auto sliceOp = currOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure();

  // Compute the dropped dimensions if `sliceOp` is rank-reducing.
  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
  OffsetSizeAndStrideOpInterface shapedOp = sliceOp;

  // Upper bound the `sliceOp` sizes to obtain a static bounding box.
  SmallVector<int64_t> paddedShape(shape.begin(), shape.end());
  int64_t shapeIdx = 0;
  for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
    // Skip dropped dimensions.
    if (droppedDims.test(en.index()))
      continue;
    // Skip dimensions that do not require padding.
    if (!shapeDimsToPad.contains(shapeIdx)) {
      shapeIdx++;
      continue;
    }
    // If the size is an attribute, add it directly to `paddedShape`.
    if (en.value().is<Attribute>()) {
      paddedShape[shapeIdx++] =
          en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt();
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        getConstantUpperBoundForIndex(en.value().get<Value>());
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
      return failure();
    }
    paddedShape[shapeIdx++] = upperBound.getValue();
  }
  assert(shapeIdx == static_cast<int64_t>(shape.size()) &&
         "expect the dynamic and static ranks to match");

  // Pad the operand to the bounding box defined by `paddedShape`.
  auto paddedTensorType = RankedTensorType::get(
      paddedShape, getElementTypeOrSelf(opOperand->get()));
  return makeComposedPadHighOp(b, opToPad->getLoc(), paddedTensorType,
                               opOperand->get(), paddingValue, nofold);
}

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          ArrayRef<int64_t> paddingDimensions,
                          ArrayRef<Attribute> paddingValues,
                          ArrayRef<bool> packPaddings, LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
        b, opToPad, opOperand, paddingDimensions, paddingValues, packPaddings);
    // Exit if `paddingDimensions` cannot be bounded statically.
    if (failed(paddedOperand))
      return failure();
    newOperands.push_back(*paddedOperand);
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(getAsOpFoldResult(v));
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
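//
// For intuition, a hypothetical example (bounds chosen for illustration):
// peeling
//   scf.for %iv = %c0 to %c10 step %c4 { ... }
// yields a main loop over the evenly dividing iterations plus a separate
// partial iteration for the remainder, roughly
//   scf.for %iv = %c0 to %c8 step %c4 { ... }
//   scf.for %iv = %c8 to %c10 step %c4 { ... }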
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Peel loops after tiling.
void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
                                     ArrayRef<int64_t> peeledLoops,
                                     LinalgTilingLoopType loopType) {
  for (int64_t loop : peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    loopResults = peelLoop(rewriter, loopOp);

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter f, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)),
      fusionOptions(std::move(fusionOptions)), filter(std::move(f)),
      fusedOpMarker(std::move(fusedOpMarker)),
      originalOpMarker(std::move(originalOpMarker)) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into, the indexing op is always an
    // OpOperand. We could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (const auto &tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    FailureOr<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (failed(unfusedTiledOp))
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg tiling pattern.
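///
/// A minimal usage sketch (hypothetical; the marker names, tile sizes, and
/// defaulted constructor arguments are assumptions, not part of this file):
///
///   RewritePatternSet patterns(ctx);
///   patterns.add<LinalgTilingPattern>(
///       ctx, LinalgTilingOptions().setTileSizes({8, 16}),
///       LinalgTransformationFilter(StringAttr::get(ctx, "tile"),
///                                  StringAttr::get(ctx, "tiled")));
///   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));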
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<TiledLinalgOp>
mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
    LinalgOp op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
  if (failed(res))
    return failure();

  // Clear filter to stop recursive pattern application.
  // This must be done here to properly propagate to peeling branches.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel the loops of the TiledLinalgOp.
  peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);

  if (res->tensorResults.empty())
    rewriter.eraseOp(op);
  else
    rewriter.replaceOp(op, res->tensorResults);

  return res;
}

/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<LinalgOp>
mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults =
      rewriteAsPaddedOp(rewriter, linalgOp, options.paddingDimensions,
                        options.paddingValues, options.packPaddings, paddedOp);
  if (failed(newResults))
    return failure();

  // Hoist the padding.
  for (const auto &en : enumerate(options.hoistPaddings)) {
    if (static_cast<int64_t>(en.index()) >= paddedOp.getNumInputsAndOutputs())
      break;
    OpOperand *opOperand = &paddedOp->getOpOperand(en.index());
    auto padOp = opOperand->get().getDefiningOp<tensor::PadOp>();
    if (!padOp || en.value() == 0)
      continue;

    // Fail hoisting if the operand shape is not fully static.
    if (llvm::any_of(paddedOp.getShape(opOperand),
                     [](int64_t size) { return ShapedType::isDynamic(size); }))
      return failure();

    tensor::PadOp hoistedOp;
    SmallVector<GenericOp> transposeOps;
    SmallVector<int64_t> transposeVector =
        en.index() < options.transposePaddings.size()
            ? options.transposePaddings[en.index()]
            : SmallVector<int64_t>{};

    FailureOr<Value> newResult = hoistPaddingOnTensors(
        padOp, en.value(), transposeVector, hoistedOp, transposeOps);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padOp, newResult.getValue());

    // Do not apply hoist padding to the newly introduced transpose operations.
    for (GenericOp transposeOp : transposeOps)
      filter.replaceLinalgTransformationFilter(rewriter, transposeOp);
  }

  // Replace the original operation with the padded results.
  rewriter.replaceOp(linalgOp, newResults.getValue());
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);

  return paddedOp;
}

/// Linalg tile and fuse tensor ops pattern.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(std::move(f)),
      options(std::move(options)) {}

FailureOr<mlir::linalg::TileLoopNest>
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::returningMatchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
  if (!rootOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
  if (options.tileSizes.size() < rootOp.getNumLoops())
    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");

  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
  if (!options.tileInterchange.empty() &&
      options.tileInterchange.size() != options.tileSizes.size())
    return rewriter.notifyMatchFailure(
        op, "expect the number of tile sizes and interchange dims to match");

  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                     options.tileSizes.begin() +
                                         rootOp.getNumLoops());
  SmallVector<int64_t> rootInterchange =
      options.tileInterchange.empty()
          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
          : SmallVector<int64_t>(options.tileInterchange.begin(),
                                 options.tileInterchange.begin() +
                                     rootOp.getNumLoops());

  // Check `rootTileSizes` contains at least one non-zero tile size.
  if (llvm::count(rootTileSizes, 0) == static_cast<long>(rootTileSizes.size()))
    return rewriter.notifyMatchFailure(
        op, "expect at least one non-zero tile size");

  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
  // It has to be a permutation since the tiling cannot tile the same loop
  // dimension multiple times.
  if (!isPermutation(rootInterchange))
    return rewriter.notifyMatchFailure(
        op, "expect the tile interchange permutes the root loops");

  // Tile `rootOp` and fuse its producers.
  FailureOr<TileLoopNest> tileLoopNest =
      tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes,
                                   rootInterchange, options.tileDistribution);
  if (failed(tileLoopNest))
    return rewriter.notifyMatchFailure(
        op, "tileConsumerAndFuseProducers failed unexpectedly");

  // Replace all uses of the tiled loop operation.
  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());

  // Apply the filter if specified.
  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
  return tileLoopNest;
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(std::move(f)),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

FailureOr<GenericOp>
mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();

  FailureOr<GenericOp> transformedOp =
      interchangeGenericOp(rewriter, genericOp, interchangeVector);
  if (failed(transformedOp))
    return failure();

  // Apply the new filter if specified.
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return transformedOp;
}

/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
    PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

FailureOr<GenericOp>
mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
  if (failed(genericOp))
    return failure();
  filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
  return genericOp;
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(std::move(f)),
      options(std::move(options)) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use a root update here. This pattern creates other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't
  // seem to be happening here. So to fail properly, we should be cloning the
  // op and deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgVectorizationOptions options, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  return vectorize(rewriter, linalgOp);
}

LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
    memref::CopyOp copyOp, PatternRewriter &rewriter) const {
  return vectorizeCopy(rewriter, copyOp);
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp (to
/// initialize with the pad value) and GenericOp (to copy contents).
LogicalResult
PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
                                            PatternRewriter &rewriter) const {
  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  // 1. A BBarg from a different block.
  // 2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<tensor::YieldOp>(block.getTerminator());
  Value padValue = yieldOp.value();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor = rewriter
                        .create<linalg::FillOp>(loc, ValueRange{padValue},
                                                ValueRange{initTensor})
                        .result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Fill `dest` using a FillOp with the constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: lower to a tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

LogicalResult
GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                        PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the size of the InitTensorOp. Any combination of static/dynamic
  // is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The tensor::PadOp cannot be optimized. Generate an InsertSliceOp instead
  // for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the source of the tensor::PadOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.source().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (Optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = control.getValue();
    else
      return failure();
  }

  Operation *tiledPadOp =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  // All shapes are static and the data source is actually used. Rewrite into
  // pad(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}

namespace {
// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.

/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
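///
/// For example (a hypothetical sketch; the shapes are chosen purely for
/// illustration), a linalg.conv_2d_nhwc_hwcf with kernel height and output
/// height both equal to 1:
///
///   ins(%input, %kernel : tensor<1x1x16x4xf32>, tensor<1x3x4x8xf32>)
///   outs(%output : tensor<1x1x14x8xf32>)
///
/// is rewritten into a linalg.conv_1d_nwc_wcf on rank-reduced operands
/// (tensor<1x16x4xf32>, tensor<3x4x8xf32>, tensor<1x14x8xf32>) whose result
/// is then re-inserted into the original output tensor.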
struct DownscaleSizeOneWindowed2DConvolution final
    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
  DownscaleSizeOneWindowed2DConvolution(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
        filter(std::move(f)) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
  DownscaleDepthwiseConv2DNhwcHwcOp(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
        filter(std::move(f)) {}

  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

} // namespace

void linalg::populateDecomposeConvolutionPatterns(
    RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
    PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution,
               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
                                                  benefit);
}
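
// A minimal usage sketch for the population function above (hypothetical;
// `funcOp` and the default-constructed filter are assumptions):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateDecomposeConvolutionPatterns(patterns,
//                                                LinalgTransformationFilter());
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     llvm::errs() << "failed to decompose convolutions\n";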