//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<StringAttr> matchDisjunction,
    Optional<StringAttr> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no filter case and matchDisjunction is empty.
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. Has no filter but was expecting a filter.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match explicit filter.
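  // For example (a sketch): an op carrying the attribute
  //   __internal_linalg_transform__ = "tile"
  // matches a filter whose `matchDisjunction` contains the StringAttr "tile";
  // any other attribute value falls through to the failure case below.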
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                replacement.getValue());
  else
    op->removeAttr(
        rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
}

bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
    Operation *op) const {
  if (!replacement)
    return false;
  auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
                  .dyn_cast<StringAttr>();
  return attr && attr == replacement.getValue();
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}

/// Helper function that tries to pad `opOperand`. Exit early for scalar
/// operands, if `paddingFunc` returns failure, or if `opOperand` is not defined
/// by an ExtractSliceOp. Otherwise, try to pad the operand even if it already
/// has a static shape. Set `result` to the result of the created PadTensorOp,
/// and return success if the operand either has been padded to a static shape
/// or already had a static shape; return failure otherwise.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc,
    const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
  // Get the shape of the operand and check if it has a dynamic shape. Only
  // return failure if the operand is not a scalar and has a dynamic shape.
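  // Illustrative sketch (hypothetical values): given
  //   %0 = tensor.extract_slice %t[0, 0] [%sz, 16] [1, 1]
  //       : tensor<?x16xf32> to tensor<?x16xf32>
  // where %sz has a constant upper bound of 8, the operand is padded to a
  // static tensor<8x16xf32> bounding box using the value from `paddingFunc`.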
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
  bool hasDynamicShape = llvm::is_contained(shape, ShapedType::kDynamicSize);

  // Cannot pad scalar operands.
  if (shape.empty())
    return success();

  // Cannot pad if the padding value is unknown.
  FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
  if (failed(paddingValue))
    return failure(hasDynamicShape);

  // Cannot construct a static bounding box if the operand is not defined by an
  // ExtractSliceOp.
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure(hasDynamicShape);

  // Compute the dropped dimensions if `sliceOp` is rank-reducing.
  llvm::SmallDenseSet<unsigned> droppedDims = sliceOp.getDroppedDims();

  // Upper bound the `sliceOp` sizes to obtain a static bounding box.
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(shape.size());
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (auto en : enumerate(shapedOp.getMixedSizes())) {
    // Skip dropped dimensions.
    if (droppedDims.contains(en.index()))
      continue;
    // If the size is an attribute add it directly to `staticSizes`.
    if (en.value().is<Attribute>()) {
      staticSizes.push_back(
          en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt());
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        getConstantUpperBoundForIndex(en.value().get<Value>());
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
      return failure();
    }
    staticSizes.push_back(upperBound.getValue());
  }
  assert(staticSizes.size() == shape.size() &&
         "expect the dynamic and static ranks to match");

  // Pad the operand to the bounding box defined by `staticSizes`.
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
  result =
      makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
                            opOperand->get(), paddingValue.getValue(), nofold);
  return success();
}

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          const PaddingNoFoldComputationFunction &nofoldFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (auto en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(getAsOpFoldResult(v));
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}

/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Try to peel a TiledLoopOp and return the new result.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
                                      TiledLoopOp tiledLoop, int64_t idx) {
  assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
         "requested peeling of non-existing loop");
  TiledLoopOp result;
  if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
    return result->getResults();
  assert(!result && "expected that loop was not peeled");
  return tiledLoop->getResults();
}

/// Peel loops after tiling.
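/// For example (a hedged sketch of the scf.for case): peeling
///   scf.for %iv = %c0 to %ub step %c4 { ... }
/// where %ub need not be a multiple of 4 yields a main loop over the largest
/// step-multiple range plus one partial iteration for the remainder, which
/// lets affine.min-based size computations in the main loop body canonicalize
/// to the constant tile size.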
static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
                      const LinalgTilingOptions &options) {
  for (int64_t loop : options.peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    if (options.loopType == LinalgTilingLoopType::TiledLoops) {
      assert(llvm::all_of(
                 res.loops,
                 [&](Operation *op) { return op == res.loops.front(); }) &&
             "expected that all loop ops are the same TiledLoopOp");
      auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
      assert(tiledLoopOp && "expected TiledLoopOp");
      loopResults = peelLoop(rewriter, tiledLoopOp, loop);
    } else {
      loopResults = peelLoop(rewriter, loopOp);
    }

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();
  // Clear filter to stop recursive pattern application.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel loops.
  peelLoops(rewriter, *res, options);

  result = *res;
  return success();
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into, indexingOp is always an OpOperand. We
    // could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg padding pattern.
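/// For example (a hedged sketch): each tensor operand of a tiled linalg.matmul
/// produced by tensor.extract_slice with dynamic sizes is padded with
/// linalg.pad_tensor to its static bounding box, the matmul is cloned onto the
/// padded operands, and tensor.extract_slice recovers the original sizes from
/// the padded results; the pads may additionally be hoisted out of enclosing
/// loops by hoistPaddingOnTensors.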
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
      rewriter, linalgOp, options.paddingValueComputationFunction,
      options.paddingNoFoldComputationFunction, paddedOp);
  if (failed(newResults))
    return failure();

  // Compute the desired hoisting depths.
  SmallVector<int64_t> depths;
  if (options.paddingHoistComputationFunction) {
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
      depths.push_back(options.paddingHoistComputationFunction(*opOperand));
  }

  // Hoist the padding.
  for (auto en : enumerate(depths)) {
    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
    auto padTensorOp = opOperand.get().getDefiningOp<PadTensorOp>();
    if (!padTensorOp || en.value() == 0)
      continue;
    PadTensorOp hoistedOp;
    FailureOr<Value> newResult =
        hoistPaddingOnTensors(padTensorOp, en.value(), hoistedOp);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padTensorOp, newResult.getValue());
  }

  // Replace the original operation to pad.
  rewriter.replaceOp(op, newResults.getValue());
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
  return success();
}

/// Linalg tile and fuse tensor ops pattern.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter filter,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter filter,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
  if (!rootOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
  if (options.tileSizes.size() < rootOp.getNumLoops())
    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");

  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
  if (!options.tileInterchange.empty() &&
      options.tileInterchange.size() != options.tileSizes.size())
    return rewriter.notifyMatchFailure(
        op, "expect the number of tile sizes and interchange dims to match");

  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                     options.tileSizes.begin() +
                                         rootOp.getNumLoops());
  SmallVector<int64_t> rootInterchange =
      options.tileInterchange.empty()
          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
          : SmallVector<int64_t>(options.tileInterchange.begin(),
                                 options.tileInterchange.begin() +
                                     rootOp.getNumLoops());

  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
  // It has to be a permutation since the tiling cannot tile the same loop
  // dimension multiple times.
  if (!isPermutation(rootInterchange))
    return rewriter.notifyMatchFailure(
        op, "expect the tile interchange permutes the root loops");

  // Tile `rootOp` and fuse its producers.
  FailureOr<TileLoopNest> tileLoopNest = tileConsumerAndFuseProducers(
      rewriter, rootOp, rootTileSizes, rootInterchange);
  if (failed(tileLoopNest))
    return rewriter.notifyMatchFailure(
        op, "tileConsumerAndFuseProducers failed unexpectedly");

  // Replace all uses of the tiled loop operation.
  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());

  // Apply the filter if specified.
  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
  return success();
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // New filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}

/// Linalg generalization pattern.
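/// For example (a hedged sketch): linalg.matmul generalizes to a
/// linalg.generic with indexing maps
///   (m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n)
/// iterator types ["parallel", "parallel", "reduction"], and a region that
/// multiplies and accumulates, which makes the op accessible to
/// transformations that only understand linalg.generic.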
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(generalizeNamedOpPrecondition(op)))
    return failure();

  GenericOp genericOp = generalizeNamedOp(rewriter, op);
  rewriter.replaceOp(op, genericOp.getResults());
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return success();
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
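  // Sketch of the intended effect (hedged): promoteSubViews allocates local
  // buffers for the promoted subview operands, copies the data in (and back
  // out for outputs), and rewrites the linalg op to compute on those buffers.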
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
/// initialize with pad_val) and GenericOp (to copy contents).
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {

  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  // 1. A BBarg from a different block.
  // 2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize tensor with the pad value.
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy original contents into new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Fill `dest` using a FillOp with the constant padding value, if one exists.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy region to new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite linalg::YieldOp to tensor::YieldOp.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp = dyn_cast<linalg::YieldOp>(
      generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1);
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}

LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of InitTensorOp. Any combination of static/dynamic is
  // supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add low and high padding value.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // PadTensorOps cannot be optimized. Generate an InsertSliceOp instead
  // for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute size of source of PadTensorOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
  if (!padOp)
    return failure();
  // Only unit stride supported.
  if (!sliceOp.hasUnitStride())
    return failure();

  Operation *tiledPadOp =
      padOp
          .getTiledImplementation(
              rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
              sliceOp.getMixedSizes(), /*tileDestOperands=*/false)
          .front();
  // All shapes are static and the data source is actually used. Rewrite into
  // pad_tensor(subtensor(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}

namespace {
// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.
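//
// For example (a hedged sketch): a linalg.conv_2d_nhwc_hwcf whose kernel and
// output both have size 1 in the height dimension, e.g. with a
// tensor<1x1x?x?xf32> kernel, can drop that dimension and become a
// linalg.conv_1d_nwc_wcf on rank-reduced operands.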

/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
struct DownscaleSizeOneWindowed2DConvolution final
    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
  DownscaleSizeOneWindowed2DConvolution(
      MLIRContext *context,
      LinalgTransformationFilter filter = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit), filter(filter) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
  DownscaleDepthwiseConv2DNhwcHwcOp(
      MLIRContext *context,
      LinalgTransformationFilter filter = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
        filter(filter) {}

  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, convOp)))
      return failure();
    if (convOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

} // namespace

void linalg::populateDecomposeConvolutionPatterns(
    RewritePatternSet &patterns, LinalgTransformationFilter filter,
    PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution,
               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
                                                  benefit);
}
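
// A minimal usage sketch (hypothetical caller, not part of this file): apply
// the decomposition patterns greedily inside a pass, assuming a FuncOp
// `funcOp` is in scope and the default filter and benefit are acceptable.
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateDecomposeConvolutionPatterns(patterns,
//                                                LinalgTransformationFilter());
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();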