//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<StringAttr> matchDisjunction,
    Optional<StringAttr> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no filter case and matchDisjunction is empty.
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. Has no filter but was expecting a filter.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                replacement.getValue());
  else
    op->removeAttr(
        rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
}

bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
    Operation *op) const {
  if (!replacement)
    return false;
  auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
                  .dyn_cast<StringAttr>();
  return attr && attr == replacement.getValue();
}

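// The filter is typically used to chain transformations: a pattern matches
// only ops carrying one marker and re-tags its result with another, so later
// patterns (or a second application of the same pattern) can pick the result
// up without recursing. A minimal usage sketch (the `ctx` variable and the
// marker strings are illustrative, not defined in this file):
//
//   LinalgTransformationFilter filter(
//       StringAttr::get(ctx, "tile"),      // Match ops marked "tile".
//       StringAttr::get(ctx, "tiled"));    // Re-tag results as "tiled".
//
// checkAndNotify(rewriter, op) then succeeds only for ops whose
// __internal_linalg_transform__ attribute equals "tile", and
// replaceLinalgTransformationFilter(rewriter, newOp) sets the attribute of
// the rewritten op to "tiled".
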
LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}

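// As an illustration of the callback above: for a matmul whose LHS is
// tensor<?x16xf32> and whose RHS is tensor<16x?xf32>, the (i, j, k) loop
// bounds are (dynamic, dynamic, 16), so the computed tile sizes are
// (1, 1, 0). The two dynamic loops are scalarized to size-1 loops while the
// static reduction loop is left untiled.
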
/// Helper function that tries to pad `opOperand`. Exit early for scalar
/// operands, if `paddingFunc` returns failure, or if `opOperand` is not defined
/// by an ExtractSliceOp. Otherwise, try to pad the operand even if it already
/// has a static shape. Set `result` to the result of the created PadTensorOp
/// and return success if the operand either has been padded to a static shape
/// or already had a static shape; return failure otherwise.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc,
    const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
  // Get the shape of the operand and check if it has a dynamic shape. Only
  // return failure if the operand is not a scalar and has a dynamic shape.
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
  bool hasDynamicShape = llvm::is_contained(shape, ShapedType::kDynamicSize);

  // Cannot pad scalar operands.
  if (shape.empty())
    return success();

  // Cannot pad if the padding value is unknown.
  FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
  if (failed(paddingValue))
    return failure(hasDynamicShape);

  // Cannot construct a static bounding box if the operand is not defined by an
  // ExtractSliceOp.
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure(hasDynamicShape);

  // Upper bound the `sliceOp` sizes to obtain a static bounding box.
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    // If the size is an attribute add it directly to `staticSizes`.
    if (size.is<Attribute>()) {
      staticSizes.push_back(
          size.get<Attribute>().dyn_cast<IntegerAttr>().getInt());
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        getConstantUpperBoundForIndex(size.get<Value>());
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS()
                 << "No constant bounding box can be found for padding\n");
      return failure();
    }
    staticSizes.push_back(upperBound.getValue());
  }

  // Pad the operand to the bounding box defined by `staticSizes`.
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
  result =
      makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
                            opOperand->get(), paddingValue.getValue(), nofold);
  return success();
}

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          const PaddingNoFoldComputationFunction &nofoldFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (auto en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(getAsOpFoldResult(v));
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}

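// For intuition: if an operand is a dynamically sized slice whose sizes can
// be bounded by constants (say 4 and 8), the op is rewritten to compute on
// the padded static shape, and the original extent is re-extracted from the
// result. A rough before/after sketch (illustrative IR, details elided):
//
//   %s = tensor.extract_slice %t[0, 0] [%d0, %d1] [1, 1]
//       : tensor<16x16xf32> to tensor<?x?xf32>
//   %r = linalg.<op> ins(%s : tensor<?x?xf32>) ...
//
// becomes
//
//   %p = linalg.pad_tensor %s low[0, 0] high[...] ... to tensor<4x8xf32>
//   %o = linalg.<op> ins(%p : tensor<4x8xf32>) ...
//   %r = tensor.extract_slice %o[0, 0] [%d0, %d1] [1, 1] ...
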
/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Try to peel a TiledLoopOp and return the new result.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
                                      TiledLoopOp tiledLoop, int64_t idx) {
  assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
         "requested peeling of non-existing loop");
  TiledLoopOp result;
  if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
    return result->getResults();
  assert(!result && "expected that loop was not peeled");
  return tiledLoop->getResults();
}

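// Peeling splits a loop with a possibly incomplete last iteration into a
// main loop whose trip count is a multiple of the step and an epilogue for
// the remainder, so the main body needs no bounds checks. Schematically
// (illustrative scf IR, not produced verbatim by this file):
//
//   scf.for %iv = %c0 to %ub step %c4 { ... }
//
// becomes
//
//   %split = <largest multiple of 4 that is <= %ub>
//   scf.for %iv = %c0 to %split step %c4 { ... }  // Full iterations only.
//   scf.for %iv = %split to %ub step %c4 { ... }  // At most one partial one.
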
/// Peel loops after tiling.
static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
                      const LinalgTilingOptions &options) {
  for (int64_t loop : options.peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    if (options.loopType == LinalgTilingLoopType::TiledLoops) {
      assert(llvm::all_of(
                 res.loops,
                 [&](Operation *op) { return op == res.loops.front(); }) &&
             "expected that all loop ops are the same TiledLoopOp");
      auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
      assert(tiledLoopOp && "expected TiledLoopOp");
      loopResults = peelLoop(rewriter, tiledLoopOp, loop);
    } else {
      loopResults = peelLoop(rewriter, loopOp);
    }

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();
  // Clear filter to stop recursive pattern application.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel loops.
  peelLoops(rewriter, *res, options);

  result = *res;
  return success();
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into `linalgOp`, the indexing op view is
    // always an OpOperand. We could assert, but simply continue if this is
    // not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
      rewriter, linalgOp, options.paddingValueComputationFunction,
      options.paddingNoFoldComputationFunction, paddedOp);
  if (failed(newResults))
    return failure();

  // Compute the desired hoisting depths.
  SmallVector<int64_t> depths;
  if (options.paddingHoistComputationFunction) {
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
      depths.push_back(options.paddingHoistComputationFunction(*opOperand));
  }

  // Hoist the padding.
  for (auto en : enumerate(depths)) {
    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
    auto padTensorOp = opOperand.get().getDefiningOp<PadTensorOp>();
    if (!padTensorOp || en.value() == 0)
      continue;
    PadTensorOp hoistedOp;
    FailureOr<Value> newResult =
        hoistPaddingOnTensors(padTensorOp, en.value(), hoistedOp);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padTensorOp, newResult.getValue());
  }

  // Replace the original operation with the padded results.
  rewriter.replaceOp(op, newResults.getValue());
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
  return success();
}

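// Hoisting moves a linalg.pad_tensor op out of up to `depth` enclosing loops.
// Conceptually (a rough sketch of the mechanism in HoistPadding, heavily
// simplified): the padded tiles are precomputed by a packing loop into a
// temporary tensor that gains one dimension per hoisted loop, and the pad
// inside the loop nest is replaced by a slice into that packed tensor:
//
//   scf.for %i ... {
//     %p = linalg.pad_tensor %s ... : tensor<?xf32> to tensor<4xf32>
//     ... use %p ...
//   }
//
// becomes roughly
//
//   %packed = scf.for %i ... iter_args(%arg = %init) {
//     %p = linalg.pad_tensor ... : tensor<?xf32> to tensor<4xf32>
//     %ins = tensor.insert_slice %p into %arg ...
//     scf.yield %ins : tensor<?x4xf32>
//   }
//   scf.for %i ... {
//     %p = tensor.extract_slice %packed[...] ... to tensor<4xf32>
//     ... use %p ...
//   }
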
/// Linalg tile and fuse tensor ops pattern.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter filter,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter filter,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
  if (!rootOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
  if (options.tileSizes.size() < rootOp.getNumLoops())
    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");

  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
  if (!options.tileInterchange.empty() &&
      options.tileInterchange.size() != options.tileSizes.size())
    return rewriter.notifyMatchFailure(
        op, "expect the number of tile sizes and interchange dims to match");

  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                     options.tileSizes.begin() +
                                         rootOp.getNumLoops());
  SmallVector<int64_t> rootInterchange =
      options.tileInterchange.empty()
          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
          : SmallVector<int64_t>(options.tileInterchange.begin(),
                                 options.tileInterchange.begin() +
                                     rootOp.getNumLoops());

  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
  // It has to be a permutation since the tiling cannot tile the same loop
  // dimension multiple times.
  if (!isPermutation(rootInterchange))
    return rewriter.notifyMatchFailure(
        op, "expect the tile interchange permutes the root loops");

  // Tile `rootOp` and fuse its producers.
  FailureOr<TileLoopNest> tileLoopNest = tileConsumerAndFuseProducers(
      rewriter, rootOp, rootTileSizes, rootInterchange);
  if (failed(tileLoopNest))
    return rewriter.notifyMatchFailure(
        op, "tileConsumerAndFuseProducers failed unexpectedly");

  // Replace all uses of the tiled loop operation.
  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());

  // Apply the filter if specified.
  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
  return success();
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // New filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}

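// Interchange permutes the loop order of a linalg.generic by rewriting its
// indexing maps and iterator types; the computation itself is unchanged. For
// example, interchangeVector = {1, 0} on a two-loop generic swaps the loops,
// turning an identity indexing map
//
//   affine_map<(d0, d1) -> (d0, d1)>
//
// into
//
//   affine_map<(d0, d1) -> (d1, d0)>
//
// so the loop that used to iterate over the second dimension now runs
// outermost.
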
/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(generalizeNamedOpPrecondition(op)))
    return failure();

  GenericOp genericOp = generalizeNamedOp(rewriter, op);
  rewriter.replaceOp(op, genericOp.getResults());
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return success();
}

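// Generalization replaces a named structured op by the equivalent
// linalg.generic, exposing its indexing maps, iterator types, and body. For
// example (illustrative IR, operand types elided), a linalg.matmul becomes
// roughly:
//
//   linalg.generic {
//       indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
//                        affine_map<(m, n, k) -> (k, n)>,
//                        affine_map<(m, n, k) -> (m, n)>],
//       iterator_types = ["parallel", "parallel", "reduction"]}
//       ins(%A, %B : ...) outs(%C : ...) {
//     ^bb0(%a: f32, %b: f32, %c: f32):
//       %0 = arith.mulf %a, %b : f32
//       %1 = arith.addf %c, %0 : f32
//       linalg.yield %1 : f32
//   }
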
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

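// applyStagedPatterns is meant for transformations that must reach a fixed
// point before the next stage runs: each stage-1 pattern set is applied to
// convergence, interleaved with the stage-2 (typically canonicalization)
// patterns, and an optional stage-3 callback runs non-pattern logic. A rough
// usage sketch (the variables below are placeholders, not defined here):
//
//   SmallVector<FrozenRewritePatternSet, 4> stage1;  // e.g. tiling stages.
//   FrozenRewritePatternSet stage2 = ...;            // e.g. canonicalization.
//   LogicalResult status = applyStagedPatterns(
//       funcOp, stage1, stage2,
//       [](Operation *op) { return success(); });    // Stage 3: no-op here.
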
static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
/// initialize with pad_val) and GenericOp (to copy contents).
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {

  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize tensor with the pad value.
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy original contents into the new tensor. Uses linalg.generic, but this
  // could also be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Fill `dest` using a FillOp with the constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy region to new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite linalg::YieldOp to tensor::YieldOp.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp =
      dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1);
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}

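// The net effect of GeneralizePadTensorOpPattern below is to decompose a pad
// into "allocate + fill + insert the source", e.g. (illustrative IR, region
// and some types elided):
//
//   %0 = linalg.pad_tensor %src low[1, 1] high[1, 1] { ... yield %cst ... }
//       : tensor<2x2xf32> to tensor<4x4xf32>
//
// becomes
//
//   %init = linalg.init_tensor [4, 4] : tensor<4x4xf32>
//   %fill = linalg.fill(%cst, %init) ...
//   %0 = tensor.insert_slice %src into %fill[1, 1] [2, 2] [1, 1]
//       : tensor<2x2xf32> into tensor<4x4xf32>
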
LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of InitTensorOp. Any combination of static/dynamic is
  // supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add low and high padding value.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy could not be optimized. Generate a tensor::InsertSliceOp to copy
  // the PadTensorOp source instead.
  auto sourceType = padOp.getSourceType();
  // Compute size of source of PadTensorOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
  if (!padOp)
    return failure();
  // Only unit stride supported.
  if (!sliceOp.hasUnitStride())
    return failure();

  Operation *tiledPadOp =
      padOp
          .getTiledImplementation(
              rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
              sliceOp.getMixedSizes(), /*tileDestOperands=*/false)
          .front();
  // All shapes are static and the data source is actually used. Rewrite into
  // pad_tensor(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}

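// In other words, the pattern above commutes a slice with a pad so that the
// pad applies to the smaller tile, e.g. (illustrative IR, offsets and sizes
// elided):
//
//   %0 = linalg.pad_tensor %src ... : tensor<8x8xf32> to tensor<10x10xf32>
//   %1 = tensor.extract_slice %0[...] [...] [1, 1]
//
// becomes
//
//   %0 = tensor.extract_slice %src[...] [...] [1, 1]
//   %1 = linalg.pad_tensor %0 ...
//
// where the new offsets, sizes, and padding amounts are computed by the
// PadTensorOp tiling implementation.
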
namespace {
// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.

/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
struct DownscaleSizeOneWindowed2DConvolution final
    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
  DownscaleSizeOneWindowed2DConvolution(
      MLIRContext *context,
      LinalgTransformationFilter filter = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit), filter(filter) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

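// Concretely, when the kernel height and output height are both 1, the
// pattern above collapses the H dimension, e.g. (illustrative IR, types
// abbreviated):
//
//   %0 = linalg.conv_2d_nhwc_hwcf
//       ins(%in, %ker : tensor<1x1x10x3xf32>, tensor<1x3x3x8xf32>)
//       outs(%out : tensor<1x1x8x8xf32>)
//
// becomes extract_slice ops dropping the unit dimension, a
// linalg.conv_1d_nwc_wcf on the rank-reduced operands, and an insert_slice
// restoring the original output shape.
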
/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
  DownscaleDepthwiseConv2DNhwcHwcOp(
      MLIRContext *context,
      LinalgTransformationFilter filter = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
        filter(filter) {}

  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

} // namespace

void linalg::populateDecomposeConvolutionPatterns(
    RewritePatternSet &patterns, LinalgTransformationFilter filter,
    PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution,
               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
                                                  benefit);
}