//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<StringAttr> matchDisjunction,
    Optional<StringAttr> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no filter attribute and `matchDisjunction` is empty.
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. Has no filter attribute but one was expected.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match an explicit filter from the disjunction.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                replacement.getValue());
  else
    op->removeAttr(
        rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
}

bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
    Operation *op) const {
  if (!replacement)
    return false;
  auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
                  .dyn_cast<StringAttr>();
  return attr && attr == replacement.getValue();
}
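// Illustrative sketch (not part of the upstream file): the filter above is
// typically used to chain transformations by matching one marker and
// installing the next. The marker names below are hypothetical.
//
//   MLIRContext *ctx = ...;
//   LinalgTransformationFilter filter(
//       StringAttr::get(ctx, "tile"),        // only match ops marked "tile"
//       StringAttr::get(ctx, "vectorize"));  // re-mark matched ops
//   // `checkAndNotify` succeeds only on ops carrying
//   // __internal_linalg_transform__ = "tile"; after rewriting,
//   // `replaceLinalgTransformationFilter` sets the marker to "vectorize" so
//   // a later pattern set can pick the op up.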
LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}
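// Illustrative sketch (not part of the upstream file): tile sizes are either
// fixed up front or derived per op. The values below are hypothetical.
//
//   LinalgTilingOptions options;
//   options.setTileSizes({8, 16, 4}); // constant tile sizes per loop
//   // Or, to turn every dynamic loop dimension into a size-1 loop:
//   //   options.scalarizeDynamicDims();
//   // Note that both setters assert if a tileSizeComputationFunction has
//   // already been installed, so use one or the other.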
/// Helper function that tries to pad `opOperand`. Exits early for scalar
/// operands, if `paddingFunc` returns failure, or if `opOperand` is not
/// defined by an ExtractSliceOp. Otherwise, tries to pad the operand even if
/// it already has a static shape. Sets `result` to the result of the created
/// PadTensorOp and returns success if the operand either has been padded to a
/// static shape or already had a static shape; returns failure otherwise.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc,
    const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
  // Get the shape of the operand and check if it has a dynamic shape. Only
  // return failure if the operand is not a scalar and has a dynamic shape.
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
  bool hasDynamicShape = llvm::is_contained(shape, ShapedType::kDynamicSize);

  // Cannot pad scalar operands.
  if (shape.empty())
    return success();

  // Cannot pad if the padding value is unknown.
  FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
  if (failed(paddingValue))
    return failure(hasDynamicShape);

  // Cannot construct a static bounding box if the operand is not defined by an
  // ExtractSliceOp.
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure(hasDynamicShape);

  // Upper bound the `sliceOp` sizes to obtain a static bounding box.
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    // If the size is an attribute add it directly to `staticSizes`.
    if (size.is<Attribute>()) {
      staticSizes.push_back(
          size.get<Attribute>().dyn_cast<IntegerAttr>().getInt());
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        getConstantUpperBoundForIndex(size.get<Value>());
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS()
                 << "No constant bounding box can be found for padding\n");
      return failure();
    }
    staticSizes.push_back(upperBound.getValue());
  }

  // Pad the operand to the bounding box defined by `staticSizes`.
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
  result =
      makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
                            opOperand->get(), paddingValue.getValue(), nofold);
  return success();
}
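// Illustrative sketch (not part of the upstream file): hypothetical padding
// callbacks as consumed by the helper above. A
// PaddingValueComputationFunction typically materializes a neutral element
// for the payload (here: zero); the builder signature follows the typedefs
// used in this file.
//
//   auto paddingFunc =
//       [](OpBuilder &b, OpOperand &opOperand) -> FailureOr<Value> {
//     Type elemType = getElementTypeOrSelf(opOperand.get());
//     return b
//         .create<arith::ConstantOp>(opOperand.getOwner()->getLoc(),
//                                    b.getZeroAttr(elemType))
//         .getResult();
//   };
//   // A PaddingNoFoldComputationFunction marks operands whose pad op must
//   // not fold away, e.g. to force a copy for every input operand:
//   auto nofoldFunc = [](OpOperand &opOperand) {
//     return opOperand.getOperandNumber() < 2; // hypothetical choice
//   };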
FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          const PaddingNoFoldComputationFunction &nofoldFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (auto en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(getAsOpFoldResult(v));
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}

/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Try to peel a TiledLoopOp and return the new result.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
                                      TiledLoopOp tiledLoop, int64_t idx) {
  assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
         "requested peeling of non-existing loop");
  TiledLoopOp result;
  if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
    return result->getResults();
  assert(!result && "expected that loop was not peeled");
  return tiledLoop->getResults();
}
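// Illustrative sketch (not part of the upstream file): loop peeling after
// tiling is driven by the `peeledLoops` member of LinalgTilingOptions, which
// the helper below consumes. The indices are hypothetical.
//
//   LinalgTilingOptions options;
//   options.setTileSizes({8, 16});
//   options.loopType = LinalgTilingLoopType::Loops;
//   options.peeledLoops = {0, 1}; // peel the two outermost tiled loops so
//                                 // the main loops only run full iterations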
/// Peel loops after tiling.
static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
                      const LinalgTilingOptions &options) {
  for (int64_t loop : options.peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    if (options.loopType == LinalgTilingLoopType::TiledLoops) {
      assert(llvm::all_of(
                 res.loops,
                 [&](Operation *op) { return op == res.loops.front(); }) &&
             "expected that all loop ops are the same TiledLoopOp");
      auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
      assert(tiledLoopOp && "expected TiledLoopOp");
      loopResults = peelLoop(rewriter, tiledLoopOp, loop);
    } else {
      loopResults = peelLoop(rewriter, loopOp);
    }

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();
  // Clear filter to stop recursive pattern application.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel loops.
  peelLoops(rewriter, *res, options);

  result = *res;
  return success();
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}
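// Illustrative sketch (not part of the upstream file): a hypothetical pattern
// derived from LinalgBaseTilingPattern. matchAndRewriteBase produces the
// tiled op and loops; the derived pattern decides how to replace the root.
//
//   struct MyTilingPattern : public LinalgBaseTilingPattern {
//     using LinalgBaseTilingPattern::LinalgBaseTilingPattern;
//     LogicalResult matchAndRewrite(Operation *op,
//                                   PatternRewriter &rewriter) const override {
//       TiledLinalgOp tiledOp;
//       if (failed(matchAndRewriteBase(op, rewriter, tiledOp)))
//         return failure();
//       if (tiledOp.tensorResults.empty())
//         rewriter.eraseOp(op);
//       else
//         rewriter.replaceOp(op, tiledOp.tensorResults);
//       return success();
//     }
//   };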
mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into `linalgOp`, the indexing op is always
    // an OpOperand. We could assert, but simply skip the dependence if this
    // is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}
/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
      rewriter, linalgOp, options.paddingValueComputationFunction,
      options.paddingNoFoldComputationFunction, paddedOp);
  if (failed(newResults)) {
    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
    return failure();
  }

  // Compute the desired hoisting depths.
  SmallVector<int64_t> depths;
  if (options.paddingHoistComputationFunction) {
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
      depths.push_back(options.paddingHoistComputationFunction(*opOperand));
  }

  // Hoist the padding.
  for (auto en : enumerate(depths)) {
    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
    auto padTensorOp = opOperand.get().getDefiningOp<PadTensorOp>();
    if (!padTensorOp || en.value() == 0)
      continue;
    PadTensorOp hoistedOp;
    FailureOr<Value> newResult =
        hoistPaddingOnTensors(padTensorOp, en.value(), hoistedOp);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padTensorOp, newResult.getValue());
  }

  // Replace the original operation by the padded results.
  rewriter.replaceOp(op, newResults.getValue());
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
  return success();
}
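// Illustrative sketch (not part of the upstream file): hypothetical
// LinalgPaddingOptions wiring for the pattern above, reusing callback shapes
// like those sketched earlier in this file.
//
//   LinalgPaddingOptions paddingOptions;
//   paddingOptions.paddingValueComputationFunction = paddingFunc;
//   paddingOptions.paddingNoFoldComputationFunction = nofoldFunc;
//   // Hoist every resulting PadTensorOp out of (hypothetically) two loops:
//   paddingOptions.paddingHoistComputationFunction =
//       [](OpOperand &) -> int64_t { return 2; };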
/// Linalg tile and fuse tensor ops pattern.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter filter,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter filter,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
  if (!rootOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
  if (options.tileSizes.size() < rootOp.getNumLoops())
    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");

  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
  if (!options.tileInterchange.empty() &&
      options.tileInterchange.size() != options.tileSizes.size())
    return rewriter.notifyMatchFailure(
        op, "expect the number of tile sizes and interchange dims to match");

  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                     options.tileSizes.begin() +
                                         rootOp.getNumLoops());
  SmallVector<int64_t> rootInterchange =
      options.tileInterchange.empty()
          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
          : SmallVector<int64_t>(options.tileInterchange.begin(),
                                 options.tileInterchange.begin() +
                                     rootOp.getNumLoops());

  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
  // It has to be a permutation since the tiling cannot tile the same loop
  // dimension multiple times.
  if (!isPermutation(rootInterchange))
    return rewriter.notifyMatchFailure(
        op, "expect the tile interchange permutes the root loops");

  // Tile `rootOp` and fuse its producers.
  FailureOr<TileLoopNest> tileLoopNest = tileConsumerAndFuseProducers(
      rewriter, rootOp, rootTileSizes, rootInterchange);
  if (failed(tileLoopNest))
    return rewriter.notifyMatchFailure(
        op, "tileConsumerAndFuseProducers failed unexpectedly");

  // Replace all uses of the tiled loop operation.
  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());

  // Apply the filter if specified.
  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
  return success();
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // New filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}
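// Illustrative sketch (not part of the upstream file): interchanging a
// hypothetical 2-D generic op. An interchangeVector of {1, 0} swaps the two
// loops, i.e. the op's first loop becomes the innermost one.
//
//   RewritePatternSet patterns(ctx);
//   patterns.add<GenericOpInterchangePattern>(
//       ctx, /*interchangeVector=*/ArrayRef<unsigned>{1, 0},
//       LinalgTransformationFilter());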
/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(generalizeNamedOpPrecondition(op)))
    return failure();

  GenericOp genericOp = generalizeNamedOp(rewriter, op);
  rewriter.replaceOp(op, genericOp.getResults());
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return success();
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 1st stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}
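// Illustrative sketch (not part of the upstream file): hypothetical staged
// application. Each stage-1 pattern set is applied once, interleaved with the
// stage-2 cleanup set, followed by an optional stage-3 callback.
//
//   SmallVector<FrozenRewritePatternSet> stage1;
//   stage1.emplace_back(std::move(tilingPatterns));        // hypothetical
//   stage1.emplace_back(std::move(vectorizationPatterns)); // hypothetical
//   FrozenRewritePatternSet stage2(std::move(canonicalizationPatterns));
//   auto stage3 = [](Operation *op) { return success(); }; // hypothetical hook
//   if (failed(applyStagedPatterns(funcOp, stage1, stage2, stage3)))
//     signalPassFailure();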
/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
/// initialize with pad_val) and GenericOp (to copy contents).
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {

  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  // 1. A BBarg from a different block.
  // 2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Fill `dest` using a FillOp with the constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // The fill could not be optimized: lower to a tensor::GenerateOp with a
  // region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite the linalg::YieldOp to a tensor::YieldOp.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp = dyn_cast<linalg::YieldOp>(
      generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1);
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}
LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of InitTensorOp. Any combination of static/dynamic is
  // supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy could not be optimized. Generate an InsertSliceOp instead to
  // copy the PadTensorOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the source of the PadTensorOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
  if (!padOp)
    return failure();
  // Only unit stride is supported.
  if (!sliceOp.hasUnitStride())
    return failure();

  Operation *tiledPadOp = padOp.getTiledImplementation(
      rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
      sliceOp.getMixedSizes());
  // All shapes are static and the data source is actually used. Rewrite into
  // pad_tensor(subtensor(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}
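// Illustrative sketch (not part of the upstream file): a hypothetical
// optimizeCopyFn hook for GeneralizePadTensorOpPattern. The typedef name is
// assumed from the `optimizeCopyFn` member used above; returning failure()
// falls back to the generic InsertSliceOp copy.
//
//   GeneralizePadTensorOpPattern::OptimizeCopyFn optimizeCopy =
//       [](PatternRewriter &rewriter, PadTensorOp padOp,
//          Value paddedTensor) -> LogicalResult {
//     return failure(); // no custom copy; use the default lowering
//   };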
namespace {
// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.

/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
struct DownscaleSizeOneWindowed2DConvolution final
    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = cast<linalg::LinalgOp>(*convOp);
    if (linalgOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value filter = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto filterType = filter.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto filterShape = filterType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t fhSize = filterShape[0], fwSize = filterShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (fhSize == 1 && ohSize == 1);
    bool removeW = (fwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
    RankedTensorType newFilterType =
        RTTBuilder(filterType).dropDim((removeH ? 0 : 1));
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newFilter = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, filter, newFilterType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
        loc, newOutputType, ValueRange{newInput, newFilter},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    return success();
  }
};
/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = cast<linalg::LinalgOp>(*convOp);
    if (linalgOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    return success();
  }
};

} // namespace

void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                  PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution,
               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
                                                  benefit);
}
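// Illustrative sketch (not part of the upstream file): applying the
// decomposition patterns greedily inside a hypothetical function pass.
//
//   void runOnFunction() override {
//     RewritePatternSet patterns(&getContext());
//     linalg::populateDecomposeConvolutionPatterns(patterns);
//     if (failed(applyPatternsAndFoldGreedily(getFunction(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }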