//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#include <utility>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
    Optional<StringAttr> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no filter case and matchDisjunction is empty.
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. Has no filter but was expecting a filter.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match an explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                replacement.getValue());
  else
    op->removeAttr(
        rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
}

bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
    Operation *op) const {
  if (!replacement)
    return false;
  auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
                  .dyn_cast<StringAttr>();
  return attr && attr == replacement.getValue();
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is static, do not tile (tile size 0). Otherwise, tile
    // by 1 to scalarize the dynamic dimension.
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}

/// Pad `opOperand` using the provided `paddingValues`. Exit early for scalar
/// operands, if `paddingValues` contains no value for the `opOperand`, or if
/// `opOperand` is not defined by an ExtractSliceOp. Otherwise, try to pad the
/// operand even if it already has a static shape. Set `result` to the result
/// of the created tensor::PadOp, and return success if the operand either has
/// been padded to a static shape or already had a static shape; return failure
/// otherwise.
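///
/// As a sketch of the intended effect (IR names and shapes invented for
/// illustration), given an operand produced by
///
///   %slice = tensor.extract_slice %t[0, 0] [%sz0, %sz1] [1, 1]
///       : tensor<16x32xf32> to tensor<?x?xf32>
///
/// where %sz0 and %sz1 have constant upper bounds 4 and 8, the operand is
/// padded with the matching entry of `paddingValues` to the static bounding
/// box tensor<4x8xf32> via a tensor.pad built by makeComposedPadHighOp.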
static LogicalResult padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    ArrayRef<Attribute> paddingValues, ArrayRef<bool> packPaddings,
    Value &result) {
  // Get the shape of the operand and check if it has a dynamic shape. Only
  // return failure if the operand is not a scalar and has a dynamic shape.
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
  bool hasDynamicShape = llvm::is_contained(shape, ShapedType::kDynamicSize);

  // Cannot pad scalar operands.
  if (shape.empty())
    return success();

  // Cannot pad if the padding value is unknown.
  if (opOperand->getOperandNumber() >= paddingValues.size())
    return failure(hasDynamicShape);
  Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()];
  Value paddingValue = b.create<arith::ConstantOp>(
      opToPad.getLoc(), paddingAttr.getType(), paddingAttr);

  // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
  OpOperand *currOpOperand = opOperand;
  while (auto linalgOp = currOpOperand->get().getDefiningOp<LinalgOp>()) {
    OpResult result = currOpOperand->get().cast<OpResult>();
    currOpOperand = linalgOp.getOutputOperand(result.getResultNumber());
  }

  // Cannot construct a static bounding box if the `currOpOperand` is not
  // defined by an ExtractSliceOp.
  auto sliceOp = currOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure(hasDynamicShape);

  // Compute the dropped dimensions if `sliceOp` is rank-reducing.
  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();

  // Upper bound the `sliceOp` sizes to obtain a static bounding box.
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(shape.size());
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
    // Skip dropped dimensions.
    if (droppedDims.test(en.index()))
      continue;
    // If the size is an attribute, add it directly to `staticSizes`.
    if (en.value().is<Attribute>()) {
      staticSizes.push_back(
          en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt());
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        getConstantUpperBoundForIndex(en.value().get<Value>());
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
      return failure();
    }
    staticSizes.push_back(upperBound.getValue());
  }
  assert(staticSizes.size() == shape.size() &&
         "expect the dynamic and static ranks to match");

  // Pad the operand to the bounding box defined by `staticSizes`.
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  bool nofold = opOperand->getOperandNumber() < packPaddings.size()
                    ? packPaddings[opOperand->getOperandNumber()]
                    : false;
  result = makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
                                 opOperand->get(), paddingValue, nofold);
  return success();
}

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          ArrayRef<Attribute> paddingValues,
                          ArrayRef<bool> packPaddings, LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            b, opToPad, opOperand, paddingValues, packPaddings, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(getAsOpFoldResult(v));
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Peel loops after tiling.
void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
                                     ArrayRef<int64_t> peeledLoops,
                                     LinalgTilingLoopType loopType) {
  for (int64_t loop : peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    loopResults = peelLoop(rewriter, loopOp);

    // The result of the loop nest may change with peeling.
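    // If the original tensor results were exactly this loop's results, they
    // must now be taken from the (possibly peeled) results computed above.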
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter f, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)),
      fusionOptions(std::move(fusionOptions)), filter(std::move(f)),
      fusedOpMarker(std::move(fusedOpMarker)),
      originalOpMarker(std::move(originalOpMarker)) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into `linalgOp`, the indexing op view is
    // always an OpOperand. We could assert, but continue if this is not the
    // case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (const auto &tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  // Tile the unfused loops only if there is a non-zero tile size.
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    FailureOr<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (failed(unfusedTiledOp))
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg tiling pattern.
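///
/// A minimal usage sketch (the filter marker names are invented for
/// illustration):
///
///   RewritePatternSet patterns(ctx);
///   patterns.add<LinalgTilingPattern>(
///       ctx, LinalgTilingOptions().setTileSizes({8, 16}),
///       LinalgTransformationFilter(StringAttr::get(ctx, "tile"),
///                                  StringAttr::get(ctx, "tiled")));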
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<TiledLinalgOp>
mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
    LinalgOp op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
  if (failed(res))
    return failure();

  // Clear filter to stop recursive pattern application.
  // This must be done here to properly propagate to peeling branches.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel the loops of the TiledLinalgOp.
  peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);

  if (res->tensorResults.empty())
    rewriter.eraseOp(op);
  else
    rewriter.replaceOp(op, res->tensorResults);

  return res;
}

/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<LinalgOp>
mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults =
      rewriteAsPaddedOp(rewriter, linalgOp, options.paddingValues,
                        options.packPaddings, paddedOp);
  if (failed(newResults))
    return failure();

  // Hoist the padding.
  for (const auto &en : enumerate(options.hoistPaddings)) {
    if (static_cast<int64_t>(en.index()) >= paddedOp.getNumInputsAndOutputs())
      break;
    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
    auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
    if (!padOp || en.value() == 0)
      continue;
    tensor::PadOp hoistedOp;
    SmallVector<GenericOp> transposeOps;
    SmallVector<int64_t> transposeVector =
        en.index() < options.transposePaddings.size()
            ? options.transposePaddings[en.index()]
            : SmallVector<int64_t>{};

    FailureOr<Value> newResult = hoistPaddingOnTensors(
        padOp, en.value(), transposeVector, hoistedOp, transposeOps);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padOp, newResult.getValue());

    // Do not apply hoist padding to the newly introduced transpose operations.
    for (GenericOp transposeOp : transposeOps)
      filter.replaceLinalgTransformationFilter(rewriter, transposeOp);
  }

  // Replace the original op-to-pad with the new results.
  rewriter.replaceOp(linalgOp, newResults.getValue());
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);

  return paddedOp;
}

/// Linalg tile and fuse tensor ops pattern.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(std::move(f)),
      options(std::move(options)) {}

FailureOr<mlir::linalg::TileLoopNest>
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::returningMatchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
  if (!rootOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
  if (options.tileSizes.size() < rootOp.getNumLoops())
    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");

  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
  if (!options.tileInterchange.empty() &&
      options.tileInterchange.size() != options.tileSizes.size())
    return rewriter.notifyMatchFailure(
        op, "expect the number of tile sizes and interchange dims to match");

  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                     options.tileSizes.begin() +
                                         rootOp.getNumLoops());
  SmallVector<int64_t> rootInterchange =
      options.tileInterchange.empty()
          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
          : SmallVector<int64_t>(options.tileInterchange.begin(),
                                 options.tileInterchange.begin() +
                                     rootOp.getNumLoops());

  // Check `rootTileSizes` contains at least one non-zero tile size.
  if (llvm::count(rootTileSizes, 0) == static_cast<long>(rootTileSizes.size()))
    return rewriter.notifyMatchFailure(
        op, "expect at least one non-zero tile size");

  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
  // It has to be a permutation since the tiling cannot tile the same loop
  // dimension multiple times.
  if (!isPermutation(rootInterchange))
    return rewriter.notifyMatchFailure(
        op, "expect the tile interchange permutes the root loops");

  // Tile `rootOp` and fuse its producers.
  FailureOr<TileLoopNest> tileLoopNest =
      tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes,
                                   rootInterchange, options.tileDistribution);
  if (failed(tileLoopNest))
    return rewriter.notifyMatchFailure(
        op, "tileConsumerAndFuseProducers failed unexpectedly");

  // Replace all uses of the tiled loop operation.
  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());

  // Apply the filter if specified.
  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
  return tileLoopNest;
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(std::move(f)),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

FailureOr<GenericOp>
mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();

  FailureOr<GenericOp> transformedOp =
      interchangeGenericOp(rewriter, genericOp, interchangeVector);
  if (failed(transformedOp))
    return failure();

  // Set the new filter if specified.
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return transformedOp;
}

/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
    PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

FailureOr<GenericOp>
mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
  if (failed(genericOp))
    return failure();
  filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
  return genericOp;
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(std::move(f)),
      options(std::move(options)) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
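  // A hypothetical sketch of that clone-and-replace approach (not implemented;
  // cleanup of partially created intermediate ops is still elided):
  //
  //   Operation *cloned = rewriter.clone(*op);
  //   if (!promoteSubViews(rewriter, cloned, options)) {
  //     rewriter.eraseOp(cloned); // Roll back without touching `op`.
  //     return failure();
  //   }
  //   rewriter.replaceOp(op, cloned->getResults());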
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgVectorizationOptions options, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  return vectorize(rewriter, linalgOp);
}

LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
    memref::CopyOp copyOp, PatternRewriter &rewriter) const {
  return vectorizeCopy(rewriter, copyOp);
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 1st stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp (to
/// initialize with pad_val) and GenericOp (to copy contents).
LogicalResult
PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
                                            PatternRewriter &rewriter) const {
  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  // 1. A BBarg from a different block.
  // 2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<tensor::YieldOp>(block.getTerminator());
  Value padValue = yieldOp.value();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor = rewriter
                        .create<linalg::FillOp>(loc, ValueRange{padValue},
                                                ValueRange{initTensor})
                        .result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Fill `dest` using a FillOp with the constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // The fill could not be optimized: lower to tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

LogicalResult
GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                        PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the size of the InitTensorOp. Any combination of static/dynamic
  // is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add the low and high padding values.
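      // I.e., the padded size along this dimension is
      // srcSize + lowPad + highPad.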
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init a tensor and fill it with the padding value.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The tensor::PadOp cannot be optimized. Generate an InsertSliceOp instead
  // for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the source of the tensor::PadOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.source().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (Optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = control.getValue();
    else
      return failure();
  }

  Operation *tiledPadOp =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  // All shapes are static and the data source is actually used. Rewrite into
  // pad(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}

namespace {
// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.

/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
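///
/// For example (shapes invented for illustration), a convolution whose kernel
/// and output have size 1 along H:
///
///   linalg.conv_2d_nhwc_hwcf
///     ins(%input, %kernel : tensor<1x1x32x3xf32>, tensor<1x3x3x8xf32>)
///     outs(%init : tensor<1x1x30x8xf32>)
///
/// is rewritten into a linalg.conv_1d_nwc_wcf on rank-reduced operands, and
/// the 1-D result is inserted back into the original output tensor.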
struct DownscaleSizeOneWindowed2DConvolution final
    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
  DownscaleSizeOneWindowed2DConvolution(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
        filter(std::move(f)) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
  DownscaleDepthwiseConv2DNhwcHwcOp(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
        filter(std::move(f)) {}

  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

} // namespace

void linalg::populateDecomposeConvolutionPatterns(
    RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
    PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution,
               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
                                                  benefit);
}