//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#include <utility>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//

// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
    Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement), matchByDefault(false) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. The op has no filter attribute and none was expected: match.
    if (matchDisjunction.empty() || matchByDefault)
      return success();

    // 2. The op has no filter attribute but one was expected: fail.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. Match an explicit filter.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. The attribute matched none of the filters: fail.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.has_value())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker, replacement.value());
  else
    op->removeAttr(
        rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
}

bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
    Operation *op) const {
  if (!replacement)
    return false;
  auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker)
                  .dyn_cast<StringAttr>();
  return attr && attr == *replacement;
}

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
    SmallVector<Value, 4> tileSizes;
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return tileSizes;
    Location loc = linalgOp.getLoc();
    auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
    // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
    // size 0).
    for (Value shapeSize : shapeSizes)
      tileSizes.push_back(getConstantIntValue(shapeSize)
                              ? b.create<arith::ConstantIndexOp>(loc, 0)
                              : b.create<arith::ConstantIndexOp>(loc, 1));
    return tileSizes;
  };
  return *this;
}

/// Pad the `opOperand` along the `paddingDimensions` using the padding value
/// and the nofold flag found in `paddingValues` and `packPaddings`,
/// respectively. Exit early and return the `opOperand` value if all shape
/// dimensions that match `paddingDimensions` have a static size and the nofold
/// flag is not set. Otherwise, try to pad the shape dimensions that match the
/// iterator dimensions `paddingDimensions` and return the tensor::PadOp result
/// if padding succeeds, or failure otherwise.
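///
/// As an illustrative sketch (the IR and names below are made up, not taken
/// from a test), assume the operand is a dynamic slice of a statically shaped
/// tensor:
///   %s = tensor.extract_slice %t[0, 0] [%sz, 16] [1, 1]
///       : tensor<64x16xf32> to tensor<?x16xf32>
/// The dynamic size is bounded by 64, so the operand can be padded to the
/// static bounding box tensor<64x16xf32> with a pad of the form:
///   %p = tensor.pad %s low[0, 0] high[%h, 0] {
///   ^bb0(%i: index, %j: index):
///     tensor.yield %cst : f32
///   } : tensor<?x16xf32> to tensor<64x16xf32>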
static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
    ArrayRef<int64_t> paddingDimensions, ArrayRef<Attribute> paddingValues,
    ArrayRef<bool> packPaddings) {
  AffineMap indexingMap = opToPad.getTiedIndexingMap(opOperand);
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);

  // Collect the shape dimensions that are a function of `paddingDimensions`.
  llvm::SmallDenseSet<int64_t> shapeDimsToPad;
  for (int64_t dim : paddingDimensions)
    for (const auto &en : enumerate(indexingMap.getResults()))
      if (en.value().isFunctionOfDim(dim))
        shapeDimsToPad.insert(en.index());

  // Return the unpadded operand if padding to a static shape is not needed and
  // the nofold flag is not set.
  bool nofold = opOperand->getOperandNumber() < packPaddings.size()
                    ? packPaddings[opOperand->getOperandNumber()]
                    : false;
  bool hasStaticShape = llvm::none_of(shapeDimsToPad, [&](int64_t dim) {
    return ShapedType::isDynamic(shape[dim]);
  });
  if (!nofold && hasStaticShape)
    return opOperand->get();

  // Fail if `paddingValues` specifies no padding value.
  if (opOperand->getOperandNumber() >= paddingValues.size())
    return failure();
  Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()];
  Value paddingValue = b.create<arith::ConstantOp>(
      opToPad.getLoc(), paddingAttr.getType(), paddingAttr);

  // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
  OpOperand *currOpOperand = opOperand;
  while (auto linalgOp = currOpOperand->get().getDefiningOp<LinalgOp>()) {
    OpResult result = currOpOperand->get().cast<OpResult>();
    currOpOperand = linalgOp.getOutputOperand(result.getResultNumber());
  }

  // Fail if `currOpOperand` is not defined by an ExtractSliceOp.
  auto sliceOp = currOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure();

  // Compute the dropped dimensions if `sliceOp` is rank-reducing.
  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
  OffsetSizeAndStrideOpInterface shapedOp = sliceOp;

  // Upper bound the `sliceOp` sizes to obtain a static bounding box.
  SmallVector<int64_t> paddedShape(shape.begin(), shape.end());
  int64_t shapeIdx = 0;
  for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
    // Skip dropped dimensions.
    if (droppedDims.test(en.index()))
      continue;
    // Skip dimensions that do not require padding.
    if (!shapeDimsToPad.contains(shapeIdx)) {
      shapeIdx++;
      continue;
    }
    // If the size is an attribute, add it directly to `paddedShape`.
    if (en.value().is<Attribute>()) {
      paddedShape[shapeIdx++] =
          en.value().get<Attribute>().cast<IntegerAttr>().getInt();
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        getConstantUpperBoundForIndex(en.value().get<Value>());
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS()
                 << "No constant bounding box can be found for padding\n");
      return failure();
    }
    paddedShape[shapeIdx++] = *upperBound;
  }
  assert(shapeIdx == static_cast<int64_t>(shape.size()) &&
         "expect the dynamic and static ranks to match");

  // Pad the operand to the bounding box defined by `paddedShape`.
  auto paddedTensorType = RankedTensorType::get(
      paddedShape, getElementTypeOrSelf(opOperand->get()));
  return makeComposedPadHighOp(b, opToPad->getLoc(), paddedTensorType,
                               opOperand->get(), paddingValue, nofold);
}

FailureOr<SmallVector<Value>>
linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                          ArrayRef<int64_t> paddingDimensions,
                          ArrayRef<Attribute> paddingValues,
                          ArrayRef<bool> packPaddings, LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();

  // TODO: there are cases where we may still want to pad to larger sizes.
  assert(opToPad.hasTensorSemantics() &&
         "expected operation to have tensor semantics");

  OpBuilder::InsertionGuard g(b);
  // Set the insertion point after `opToPad` because we also take the dims of
  // the original output.
  b.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it with the padded operands.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumInputsAndOutputs());
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
        b, opToPad, opOperand, paddingDimensions, paddingValues, packPaddings);
    // Exit if `paddingDimensions` cannot be bounded statically.
    if (failed(paddedOperand))
      return failure();
    newOperands.push_back(*paddedOperand);
  }

  SmallVector<SmallVector<Value>> reifiedResultShapes;
  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
                 .reifyResultShapes(b, reifiedResultShapes)))
    return failure();
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);

  // Recover the slices of the original shape out of the new static results.
  // This keeps the original linalg op around because the slice sizes use the
  // dims of its original results.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    for (Value v : reifiedResultShapes[resultNumber])
      sizes.push_back(getAsOpFoldResult(v));
    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, sizes, strides));
  }
  return paddedSubviewResults;
}

/// Try to peel the loop `op`; return the results of the partial iteration if
/// peeling succeeds, and the original loop results otherwise.
// TODO: Add support for scf.parallel and affine.for loops.
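//
// As an illustrative sketch (assuming a step of 4 that does not divide the
// upper bound evenly), peeling rewrites
//   scf.for %iv = %c0 to %ub step %c4 { ... }
// into a main loop over the evenly dividing range plus a partial iteration
// covering the remainder:
//   scf.for %iv = %c0 to %split step %c4 { ... }
//   scf.for %iv = %split to %ub step %c4 { ... }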
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
                                                      partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Peel and canonicalize `loops`.
void mlir::linalg::peelLoops(RewriterBase &rewriter,
                             ArrayRef<scf::ForOp> loops) {
  for (scf::ForOp loopOp : loops)
    (void)peelLoop(rewriter, loopOp);
}

/// Peel loops after tiling.
void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
                                     ArrayRef<int64_t> peeledLoops,
                                     LinalgTilingLoopType loopType) {
  for (int64_t loop : peeledLoops) {
    assert(loop < static_cast<int64_t>(res.loops.size()) &&
           "requested peeling of non-existent loop");
    Operation *loopOp = res.loops[loop];
    SmallVector<Value, 4> loopResults = peelLoop(rewriter, loopOp);

    // The result of the loop nest may change with peeling.
    if (res.tensorResults.size() == loopOp->getNumResults() &&
        std::equal(res.tensorResults.begin(), res.tensorResults.end(),
                   loopOp->getResults().begin()))
      res.tensorResults = loopResults;
  }
}

/// Linalg tiling pattern.
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<TiledLinalgOp>
mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
    LinalgOp op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
  if (failed(res))
    return failure();

  // Clear the filter to stop recursive pattern application.
  // This must be done here to properly propagate to peeling branches.
  filter.replaceLinalgTransformationFilter(rewriter, res->op);

  // Peel the loops of the TiledLinalgOp.
  peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);

  if (res->tensorResults.empty())
    rewriter.eraseOp(op);
  else
    rewriter.replaceOp(op, res->tensorResults);

  return res;
}

/// Linalg padding pattern.
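///
/// A minimal configuration sketch (the values are illustrative, not defaults;
/// `b`, `ctx`, `patterns`, and `filter` are assumed to be in scope): pad the
/// two outermost iteration dimensions of a float op with zeros and mark both
/// padded operands as nofold:
///   LinalgPaddingOptions options;
///   options.paddingDimensions = {0, 1};
///   options.paddingValues = {b.getZeroAttr(b.getF32Type()),
///                            b.getZeroAttr(b.getF32Type())};
///   options.packPaddings = {true, true};
///   patterns.add<LinalgPaddingPattern>(ctx, options, filter);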
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

FailureOr<LinalgOp>
mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (!linalgOp.hasTensorSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Pad the operation.
  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults =
      rewriteAsPaddedOp(rewriter, linalgOp, options.paddingDimensions,
                        options.paddingValues, options.packPaddings, paddedOp);
  if (failed(newResults))
    return failure();

  // Hoist the padding.
  for (const auto &en : enumerate(options.hoistPaddings)) {
    if (static_cast<int64_t>(en.index()) >= paddedOp.getNumInputsAndOutputs())
      break;
    OpOperand *opOperand = &paddedOp->getOpOperand(en.index());
    auto padOp = opOperand->get().getDefiningOp<tensor::PadOp>();
    if (!padOp || en.value() == 0)
      continue;

    // Fail hoisting if the operand shape is not fully static.
    if (llvm::any_of(paddedOp.getShape(opOperand),
                     [](int64_t size) { return ShapedType::isDynamic(size); }))
      return failure();

    tensor::PadOp hoistedOp;
    SmallVector<GenericOp> transposeOps;
    SmallVector<int64_t> transposeVector =
        en.index() < options.transposePaddings.size()
            ? options.transposePaddings[en.index()]
            : SmallVector<int64_t>{};

    FailureOr<Value> newResult = hoistPaddingOnTensors(
        padOp, en.value(), transposeVector, hoistedOp, transposeOps);
    if (failed(newResult))
      continue;
    rewriter.replaceOp(padOp, *newResult);

    // Do not apply hoist padding to the newly introduced transpose operations.
    for (GenericOp transposeOp : transposeOps)
      filter.replaceLinalgTransformationFilter(rewriter, transposeOp);
  }

  // Replace the results of the original operation with the padded results.
  rewriter.replaceOp(linalgOp, *newResults);
  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);

  return paddedOp;
}

/// Linalg tile and fuse tensor ops pattern.
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                      LinalgTilingAndFusionOptions options,
                                      LinalgTransformationFilter f,
                                      PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(std::move(f)),
      options(std::move(options)) {}

FailureOr<mlir::linalg::TileLoopNest>
mlir::linalg::LinalgTileAndFuseTensorOpsPattern::returningMatchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
  if (!rootOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();

  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
  if (options.tileSizes.size() < rootOp.getNumLoops())
    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");

  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
  if (!options.tileInterchange.empty() &&
      options.tileInterchange.size() != options.tileSizes.size())
    return rewriter.notifyMatchFailure(
        op, "expect the number of tile sizes and interchange dims to match");

  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                     options.tileSizes.begin() +
                                         rootOp.getNumLoops());
  SmallVector<int64_t> rootInterchange =
      options.tileInterchange.empty()
          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
          : SmallVector<int64_t>(options.tileInterchange.begin(),
                                 options.tileInterchange.begin() +
                                     rootOp.getNumLoops());

  // Check `rootTileSizes` contains at least one non-zero tile size.
  if (llvm::count(rootTileSizes, 0) ==
      static_cast<int64_t>(rootTileSizes.size()))
    return rewriter.notifyMatchFailure(
        op, "expect at least one non-zero tile size");

  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
  // It has to be a permutation since tiling cannot tile the same loop
  // dimension multiple times.
  if (!isPermutation(rootInterchange))
    return rewriter.notifyMatchFailure(
        op, "expect the tile interchange permutes the root loops");

  // Tile `rootOp` and fuse its producers.
  FailureOr<TileLoopNest> tileLoopNest =
      tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes,
                                   rootInterchange, options.tileDistribution);
  if (failed(tileLoopNest))
    return rewriter.notifyMatchFailure(
        op, "tileConsumerAndFuseProducers failed unexpectedly");

  // Replace all uses of the root operation with the tile loop nest results.
  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());

  // Apply the filter if specified.
  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
  return tileLoopNest;
}

/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(std::move(f)),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

FailureOr<GenericOp>
mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();

  FailureOr<GenericOp> transformedOp =
      interchangeGenericOp(rewriter, genericOp, interchangeVector);
  if (failed(transformedOp))
    return failure();

  // Apply the replacement filter if specified.
  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  return transformedOp;
}

/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
    PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

FailureOr<GenericOp>
mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
  if (failed(genericOp))
    return failure();
  filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
  return genericOp;
}

mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgPeelOptions options, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern(
    StringRef opName, MLIRContext *context, LinalgPeelOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

LogicalResult mlir::linalg::LinalgPeelingPattern::matchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // Advance the transformation marker even if peeling does not happen for
  // this op.
  filter.replaceLinalgTransformationFilter(rewriter, linalgOp);

  if (!options.loopsToPeelComputationFunction)
    return failure();

  SmallVector<scf::ForOp, 4> loopsToPeel;
  options.loopsToPeelComputationFunction(rewriter, linalgOp, loopsToPeel);
  peelLoops(rewriter, loopsToPeel);
  return success();
}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter f,
    LinalgVectorizationOptions options, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(std::move(f)) {}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
    LinalgTransformationFilter f, PatternBenefit benefit)
    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
      filter(f.addOpNameFilter(opName)) {}

LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
    LinalgOp linalgOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  return vectorize(rewriter, linalgOp);
}

LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
    memref::CopyOp copyOp, PatternRewriter &rewriter) const {
  return vectorizeCopy(rewriter, copyOp);
}

LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 1st stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}

/// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp (to
/// initialize with the pad value) and GenericOp (to copy the contents).
LogicalResult
PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
                                            PatternRewriter &rewriter) const {
  auto inputShapedType = padOp.getSource().getType().cast<ShapedType>();
  auto resultShapedType = padOp.getResult().getType().cast<ShapedType>();

  // Bail on non-static shapes.
  if (!inputShapedType.hasStaticShape())
    return failure();
  if (!resultShapedType.hasStaticShape())
    return failure();

  // Only support padding with a constant for now, i.e. either:
  // 1. A block argument from a different block.
  // 2. A value defined outside of the current block.
  Block &block = padOp.getRegion().front();
  auto yieldOp = cast<tensor::YieldOp>(block.getTerminator());
  Value padValue = yieldOp.getValue();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // Create a tensor with the padded shape.
  Location loc = padOp.getLoc();
  Value initTensor = rewriter.create<InitTensorOp>(
      loc, resultShapedType.getShape(), resultShapedType.getElementType());

  // Initialize the tensor with the pad value.
  Value tmpTensor = rewriter
                        .create<linalg::FillOp>(loc, ValueRange{padValue},
                                                ValueRange{initTensor})
                        .result();

  // Copy the original contents into the new tensor. This uses linalg.generic,
  // but could also be done with tensor.insert_slice.
  SmallVector<AffineExpr, 4> outputExprs;
  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
                          padOp.getStaticLow()[i].cast<IntegerAttr>().getInt());
  }

  SmallVector<AffineMap, 2> transferMaps = {
      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
      AffineMap::get(resultShapedType.getRank(),
                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      padOp, resultShapedType, padOp.getSource(), tmpTensor, transferMaps,
      getNParallelLoopsAttrs(resultShapedType.getRank()),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  return success();
}

/// Fill `dest` using a FillOp if the padding value is a constant.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadOpPattern::createFillOrGenerateOp(
    PatternRewriter &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // The fill cannot be optimized: lower to a tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  BlockAndValueMapping bvm;
  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

LogicalResult
GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                        PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the sizes of the InitTensorOp. Any combination of static and
  // dynamic sizes is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.getSource(), dim);
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init the tensor and fill it with the padding value.
  Value init = rewriter.create<InitTensorOp>(
      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy cannot be optimized: generate an InsertSliceOp instead to copy
  // the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute the size of the source of the tensor::PadOp.
  SmallVector<OpFoldResult> srcSizes;
  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
    if (sourceType.isDynamicDim(dim)) {
      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          padOp.getLoc(), padOp.getSource(), dim));
    } else {
      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
    }
  }
  // The strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
      strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (Optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = *control;
    else
      return failure();
  }

  Operation *tiledPadOp =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  // All shapes are static and the data source is actually used. Rewrite into
  // pad(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}

// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.

FailureOr<Conv1DNwcWcfOp>
DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
    linalg::Conv2DNhwcHwcfOp convOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, convOp)))
    return failure();
  if (convOp.hasBufferSemantics())
    return failure(); // To be implemented.
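
  // As an illustrative sketch (shapes chosen for exposition only), the
  // rewrite turns
  //   linalg.conv_2d_nhwc_hwcf ins(%input, %kernel
  //       : tensor<1x1x10x3xf32>, tensor<1x2x3x4xf32>)
  //       outs(%init : tensor<1x1x9x4xf32>)
  // into
  //   linalg.conv_1d_nwc_wcf ins(%in, %ker
  //       : tensor<1x10x3xf32>, tensor<2x3x4xf32>)
  //       outs(%out : tensor<1x9x4xf32>)
  // by dropping the unit H dimension from all operands with rank-reducing
  // extract_slice / insert_slice pairs.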

  Value input = convOp.inputs().front();
  Value kernel = convOp.inputs().back();
  Value output = convOp.outputs().front();

  auto inputType = input.getType().cast<RankedTensorType>();
  auto kernelType = kernel.getType().cast<RankedTensorType>();
  auto outputType = output.getType().cast<RankedTensorType>();

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is
  // of size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  // TODO: dropDim 1-liner helper.
  auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations = llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert the result back into the output tensor.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
  return conv1DOp;
}

FailureOr<DepthwiseConv1DNwcWcOp>
DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
    DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, convOp)))
    return failure();
  if (convOp.hasBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.inputs().front();
  Value kernel = convOp.inputs().back();
  Value output = convOp.outputs().front();

  auto inputType = input.getType().cast<RankedTensorType>();
  auto kernelType = kernel.getType().cast<RankedTensorType>();
  auto outputType = output.getType().cast<RankedTensorType>();

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is
  // of size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  // TODO: dropDim 1-liner helper.
  auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations = llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert the result back into the output tensor.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
  return conv1DOp;
}

void linalg::populateDecomposeConvolutionPatterns(
    RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
    PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution,
               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
                                                  benefit);
}
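
// Example usage of the decomposition entry point (a sketch; the enclosing
// pass, `funcOp`, and `filter` are assumed to exist elsewhere):
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateDecomposeConvolutionPatterns(patterns, filter);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();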