//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<StringAttr> matchDisjunction,
    Optional<StringAttr> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->template getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. Has no filter attribute and the match disjunction is empty.
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. Has no filter attribute but one was expected.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match an explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. Fail to match.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                replacement.getValue());
  else
    op->removeAttr(
        rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
}

bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
    Operation *op) const {
  if (!replacement)
    return false;
  auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
                  .dyn_cast<StringAttr>();
  return attr && attr == replacement.getValue();
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}

/// Helper function that tries to pad `opOperand`. Exits early and returns
/// success for scalar operands or if `paddingFunc` returns failure. Otherwise,
/// tries to pad the operand even if it already has a static shape. Sets
/// `result` to the result of the created PadTensorOp or returns failure if
/// the operand cannot be padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const PaddingValueComputationFunction &paddingFunc,
    const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
  // Can't pad scalars.
  if (opToPad.getShape(opOperand).empty())
    return success();
  // Can't pad if no padding value is known.
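  // (The padding value computation may decline an operand, in which case it
  // is left unpadded.)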
  FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
  if (failed(paddingValue))
    return success();
  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  // Not a slice op, cannot construct a static bounding box.
  if (!sliceOp)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(opToPad.getRank(opOperand));
  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    // If the size is an attribute add it directly to `staticSizes`.
    if (size.is<Attribute>()) {
      staticSizes.push_back(
          size.get<Attribute>().dyn_cast<IntegerAttr>().getInt());
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        getConstantUpperBoundForIndex(size.get<Value>());
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
      return failure();
    }
    staticSizes.push_back(upperBound.getValue());
  }
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
  result = linalg::PadTensorOp::createPadHighOp(
      staticTensorType, opOperand->get(), paddingValue.getValue(),
      /*nofold=*/nofold, opToPad->getLoc(), b);
  return success();
}

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          const PaddingValueComputationFunction &paddingFunc,
                          const PaddingNoFoldComputationFunction &nofoldFunc,
                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set IP after op because we also take the dims of the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
            b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
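  // For each padded result, extract a slice of the original (unpadded) size.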
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (auto en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(v);
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}

/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Try to peel a TiledLoopOp and return the new result.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
                                      TiledLoopOp tiledLoop, int64_t idx) {
  assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
         "requested peeling of non-existing loop");
  TiledLoopOp result;
  if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
    return result->getResults();
  assert(!result && "expected that loop was not peeled");
  return tiledLoop->getResults();
}

/// Peel loops after tiling.
static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
                      const LinalgTilingOptions &options) {
  for (int64_t loop : options.peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existing loop");
    SmallVector<Value, 4> loopResults;
    Operation *loopOp = res.loops[loop];
    if (options.loopType == LinalgTilingLoopType::TiledLoops) {
      assert(llvm::all_of(
                 res.loops,
                 [&](Operation *op) { return op == res.loops.front(); }) &&
             "expected that all loop ops are the same TiledLoopOp");
      auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
      assert(tiledLoopOp && "expected TiledLoopOp");
      loopResults = peelLoop(rewriter, tiledLoopOp, loop);
    } else {
      loopResults = peelLoop(rewriter, loopOp);
    }

    // The result of the loop nest may change with peeling.
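    // Only update the tracked tensor results if they were exactly the results
    // of the peeled loop.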
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();
  // Clear the filter to stop recursive pattern application.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel loops.
  peelLoops(rewriter, *res, options);

  result = *res;
  return success();
}

static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // When looking at dependences into `linalgOp`, the indexing op is always
    // an OpOperand. We could assert, but continue if this is not the case.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
          return cst.value() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg padding pattern.
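/// Pads the operands of a linalg op to a static bounding box (see
/// `rewriteAsPaddedOp`) and then optionally hoists the created pad_tensor ops
/// out of the enclosing loops, up to the depths requested per operand.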
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
      rewriter, linalgOp, options.paddingValueComputationFunction,
      options.paddingNoFoldComputationFunction, paddedOp);
  if (failed(newResults))
    return failure();

  // Compute the desired hoisting depths.
  SmallVector<int64_t> depths;
  if (options.paddingHoistComputationFunction) {
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
      depths.push_back(options.paddingHoistComputationFunction(*opOperand));
  }

  // Hoist the padding.
  for (auto en : enumerate(depths)) {
    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
    auto padTensorOp = opOperand.get().getDefiningOp<PadTensorOp>();
    if (!padTensorOp || en.value() == 0)
      continue;
    PadTensorOp hoistedOp;
    FailureOr<Value> newResult =
        hoistPaddingOnTensors(padTensorOp, en.value(), hoistedOp);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padTensorOp, newResult.getValue());
  }

  // Replace the original operation by the padded results.
  rewriter.replaceOp(op, newResults.getValue());
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
  return success();
}

/// Linalg tile and fuse tensor ops pattern.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter filter,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter filter,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
  if (!rootOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
  if (options.tileSizes.size() < rootOp.getNumLoops())
    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");

  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
  if (!options.tileInterchange.empty() &&
      options.tileInterchange.size() != options.tileSizes.size())
    return rewriter.notifyMatchFailure(
        op, "expect the number of tile sizes and interchange dims to match");

  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                     options.tileSizes.begin() +
                                         rootOp.getNumLoops());
  SmallVector<int64_t> rootInterchange =
      options.tileInterchange.empty()
          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
          : SmallVector<int64_t>(options.tileInterchange.begin(),
                                 options.tileInterchange.begin() +
                                     rootOp.getNumLoops());

  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
  // It has to be a permutation since the tiling cannot tile the same loop
  // dimension multiple times.
  if (!isPermutation(rootInterchange))
    return rewriter.notifyMatchFailure(
        op, "expect the tile interchange permutes the root loops");

  // Tile `rootOp` and fuse its producers.
  FailureOr<TileLoopNest> tileLoopNest = tileConsumerAndFuseProducers(
      rewriter, rootOp, rootTileSizes, rootInterchange);
  if (failed(tileLoopNest))
    return rewriter.notifyMatchFailure(
        op, "tileConsumerAndFuseProducers failed unexpectedly");

  // Replace all uses of the tiled loop operation.
  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());

  // Apply the filter if specified.
  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
  return success();
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // New filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}

/// Linalg generalization pattern.
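/// Replaces a matched named op by its linalg.generic counterpart, carrying
/// over the results.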
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(generalizeNamedOpPrecondition(op)))
    return failure();

  GenericOp genericOp = generalizeNamedOp(rewriter, op);
  rewriter.replaceOp(op, genericOp.getResults());
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return success();
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    LinalgPromotionOptions options, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
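  // For now, conservatively notify the rewriter via an explicit root update
  // and cancel it if the promotion fails.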
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
/// initialize with pad_val) and GenericOp (to copy contents).
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {

  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
  auto resultShapedType = padOp.result().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  // 1. A BBarg from a different block.
  // 2. A value defined outside of the current block.
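  // Either way, the yielded padding value is invariant in the pad region.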
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  SmallVector<Value> indices(resultShapedType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor =
      rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();

  // Copy the original contents into the new tensor.
  // Uses linalg.generic, but could be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Fill `dest` using a FillOp if the padding value is a constant.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // The fill could not be optimized: lower to a tensor::GenerateOp with a
  // region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  BlockAndValueMapping bvm;
  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
  // Rewrite the linalg::YieldOp to a tensor::YieldOp.
  OpBuilder::InsertionGuard guard(rewriter);
  auto yieldOp =
      dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator());
  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
  assert(yieldOp.values().size() == 1);
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
  return generateOp;
}

LogicalResult
GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
                                              PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
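  // (Materializes an arith constant if the OpFoldResult holds an attribute.)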
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the size of the InitTensorOp. Any combination of static/dynamic
  // sizes is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
                                                          padOp.source(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init the tensor and fill it with padding.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy cannot be optimized. Generate a tensor::InsertSliceOp instead
  // to copy the PadTensorOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the source of the PadTensorOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.source(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // Strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
  if (!padOp)
    return failure();
  // Only unit stride is supported.
  if (!sliceOp.hasUnitStride())
    return failure();

  Operation *tiledPadOp = padOp.getTiledImplementation(
      rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
      sliceOp.getMixedSizes());
  // All shapes are static and the data source is actually used. Rewrite into
  // pad_tensor(subtensor(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}

namespace {
// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.

/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
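///
/// For example, when fh == 1 and oh == 1 the H dimension is sliced away and a
/// conv_1d_nwc_wcf is emitted on the rank-reduced operands (and similarly for
/// the W dimension when fw == 1 and ow == 1).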
struct DownscaleSizeOneWindowed2DConvolution final
    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = cast<linalg::LinalgOp>(*convOp);
    if (linalgOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value filter = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto filterType = filter.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto filterShape = filterType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t fhSize = filterShape[0], fwSize = filterShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (fhSize == 1 && ohSize == 1);
    bool removeW = (fwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
    RankedTensorType newFilterType =
        RTTBuilder(filterType).dropDim((removeH ? 0 : 1));
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newFilter = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, filter, newFilterType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
        loc, newOutputType, ValueRange{newInput, newFilter},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    return success();
  }
};

/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
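///
/// Follows the same recipe as the pattern above but emits a
/// depthwise_conv_1d_nwc_wc on the rank-reduced operands.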
struct DownscaleDepthwiseConv2DNhwcHwcOp final
    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = cast<linalg::LinalgOp>(*convOp);
    if (linalgOp.hasBufferSemantics())
      return failure(); // To be implemented.

    Value input = convOp.inputs().front();
    Value kernel = convOp.inputs().back();
    Value output = convOp.outputs().front();

    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();

    auto kernelShape = kernelType.getShape();
    auto outputShape = outputType.getShape();

    // Only handle the case where at least one of the window dimensions is
    // of size 1. Other cases can rely on tiling to reduce to such cases.
    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
    int64_t ohSize = outputShape[1], owSize = outputShape[2];
    bool removeH = (khSize == 1 && ohSize == 1);
    bool removeW = (kwSize == 1 && owSize == 1);
    if (!removeH && !removeW)
      return failure();

    // Get new shapes and types for all operands by removing the size-1
    // dimension.
    using RTTBuilder = RankedTensorType::Builder;
    RankedTensorType newInputType =
        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
    RankedTensorType newKernelType =
        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
    RankedTensorType newOutputType =
        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

    // Rank-reduce operands.
    Location loc = convOp.getLoc();
    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, input, newInputType);
    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, kernel, newKernelType);
    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
        rewriter, loc, output, newOutputType);

    // Rank-reduce strides and dilations too.
    // TODO: dropDim 1-liner helper.
    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
    strides.erase(strides.begin() + (removeH ? 0 : 1));
    auto stridesAttr = rewriter.getI64VectorAttr(strides);

    auto dilations =
        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

    auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
        loc, newOutputType, ValueRange{newInput, newKernel},
        ValueRange{newOutput}, stridesAttr, dilationsAttr);

    // Insert back.
    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
        rewriter, loc, conv1DOp.getResult(0), output);
    rewriter.replaceOp(convOp, inserted);

    return success();
  }
};

} // namespace

void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                  PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution,
               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
                                                  benefit);
}