//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Tiling pass.
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "PassDetail.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/CommandLine.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::scf;

#define DEBUG_TYPE "linalg-tiling"

/// Returns true iff `v` is statically known to be the constant index 0.
/// A dynamic value that merely happens to be zero at runtime returns false;
/// tiling uses "statically zero tile size" as the marker for "do not tile
/// this dimension".
static bool isZero(Value v) {
  if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
    return cst.value() == 0;
  return false;
}

/// Builds the loop ranges to be tiled, in loop order, pruning every loop
/// whose tile size is statically zero (the skip-tiling convention). Returns
/// the pruned ranges together with a map from the original loop index to the
/// position of its range in the pruned list, so later stages can relate tile
/// loop induction variables back to the op's loops.
std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
                                  ValueRange allShapeSizes,
                                  ValueRange allTileSizes) {
  assert(allTileSizes.size() == map.getNumResults());
  // Apply `map` to get shape sizes in loop order.
  auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
  SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());

  // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
    // `idx - zerosCount` is the current position of original loop `idx`
    // after the erasures performed so far; both vectors shrink in lockstep.
    if (isZero(tileSizes[idx - zerosCount])) {
      shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
      tileSizes.erase(tileSizes.begin() + idx - zerosCount);
      ++zerosCount;
      continue;
    }
    loopIndexToRangeIndex[idx] = idx - zerosCount;
  }

  // Create a new range with the applied tile sizes: [0, shapeSize) stepping
  // by the tile size.
  SmallVector<Range, 4> res;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
    res.push_back(Range{b.create<arith::ConstantIndexOp>(loc, 0),
                        shapeSizes[idx], tileSizes[idx]});
  return std::make_tuple(res, loopIndexToRangeIndex);
}

/// Rewrites the linalg.index-style ops inside `op` after tiling by offsetting
/// them with the tile loop induction variables `ivs`. Loops absent from
/// `loopIndexToRangeIndex` (i.e. untiled loops) keep a null Value in `allIvs`
/// — presumably `offsetIndices` treats null as "no offset"; confirm against
/// its implementation.
void mlir::linalg::transformIndexOps(
    RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
    const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
  SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
  for (auto &en : enumerate(allIvs)) {
    // Only loops that were actually tiled have an induction variable.
    auto rangeIndex = loopIndexToRangeIndex.find(en.index());
    if (rangeIndex == loopIndexToRangeIndex.end())
      continue;
    en.value() = ivs[rangeIndex->second];
  }
  offsetIndices(b, op, allIvs);
}

/// Asserts that the given index-typed value is strictly positive. If the value
/// is an attribute, asserts at compile time, otherwise emits an assertion
/// checked at runtime.
static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
                                         OpFoldResult value) {
  // Static case: the value is an attribute; check it while building and emit
  // nothing into the IR.
  if (auto attr = value.dyn_cast<Attribute>()) {
    assert(attr.cast<IntegerAttr>().getValue().isStrictlyPositive() &&
           "expected strictly positive tile size and divisor");
    return;
  }

  // Dynamic case: emit a cf.assert checking `value > 0` at runtime.
  Value zero = b.create<arith::ConstantIndexOp>(0);
  Value condition = b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt,
                                            value.get<Value>(), zero);
  b.create<cf::AssertOp>(
      condition,
      b.getStringAttr("expected strictly positive tile size and divisor"));
}

/// Computes a two-tile-size decomposition of loop `dimension` of `op`: a
/// "low" and a "high" tile size, both multiples of `divisor`, together with
/// trip counts such that lowTileSize * lowTripCount +
/// highTileSize * highTripCount covers the full dimension (when such a
/// decomposition exists). Fails if `dimension` is out of range. When
/// `emitAssertions` is set, positivity of the inputs and exact coverage of
/// the dimension are checked (statically when possible, at runtime
/// otherwise).
FailureOr<MultiSizeSpecification>
mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
                                    unsigned dimension, OpFoldResult targetSize,
                                    OpFoldResult divisor, bool emitAssertions) {
  // Bail out on dimension overflow.
  if (dimension >= op.getNumLoops())
    return failure();

  // The code below works only on values.
  ImplicitLocOpBuilder b(op.getLoc(), builder);
  if (emitAssertions) {
    emitIsPositiveIndexAssertion(b, targetSize);
    emitIsPositiveIndexAssertion(b, divisor);
  }
  Value targetSizeValue = materializeOpFoldResult(b, targetSize);
  Value divisorValue = materializeOpFoldResult(b, divisor);

  // Find the trip count of the iteration space dimension for which the tile
  // sizes are computed.
  // TODO: update createFlatListOfOperandDims to return OpFoldResults and avoid
  // littering by useless constant materialization.
  SmallVector<Value, 4> allShapes =
      op.createFlatListOfOperandDims(b, b.getLoc());
  AffineMap shapesToLoops = op.getShapesToLoopsMap();
  SmallVector<Value, 4> loopRanges =
      applyMapToValues(b, op.getLoc(), shapesToLoops, allShapes);
  Value tripCount = loopRanges[dimension];

  // Compute the tile sizes and the respective numbers of tiles. Working in
  // units of `divisor`:
  //   a = tripCount floordiv divisor   (whole divisor-groups to cover)
  //   t = ceildiv(targetSize, divisor) (max groups per tile)
  //   d = ceildiv(a, t)                (number of tiles)
  //   s = (a floordiv d) * divisor     (low tile size, in elements)
  //   v = a mod d                      (tiles that get one extra group)
  //   u = d - v                        (tiles of the low size)
  AffineExpr s0 = b.getAffineSymbolExpr(0);
  AffineExpr s1 = b.getAffineSymbolExpr(1);
  AffineExpr s2 = b.getAffineSymbolExpr(2);
  auto apply = [&](AffineExpr expr, ValueRange values) -> Value {
    return makeComposedAffineApply(b, b.getLoc(), expr, values);
  };
  Value a = apply(s0.floorDiv(s1), {tripCount, divisorValue});
  Value t = apply((s0 + s1 - 1).floorDiv(s1), {targetSizeValue, divisorValue});
  Value d = apply((s0 + s1 - 1).floorDiv(s1), {a, t});
  Value s = apply(s0.floorDiv(s1) * s2, {a, d, divisorValue});
  Value v = apply(s0 % s1, {a, d});
  Value u = apply(s0 - s1, {d, v});

  MultiSizeSpecification spec;
  spec.lowTileSize = s;
  // The high tile size is exactly one `divisor` larger than the low one.
  spec.highTileSize = apply(s0 + s1, {s, divisorValue});
  spec.lowTripCount = u;
  spec.highTripCount = v;

  // If requested, emit the check that the tile sizes are computed correctly.
  // For example, for iteration dimension size of 15 and the target size 8 it is
  // impossible to find two tile sizes both divisible by 8 that fully cover the
  // original space dimension.
  if (emitAssertions) {
    AffineExpr s3 = builder.getAffineSymbolExpr(3);
    Value coveredSize =
        apply(s0 * s1 + s2 * s3, {spec.lowTileSize, spec.lowTripCount,
                                  spec.highTileSize, spec.highTripCount});
    Value equals = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
                                           coveredSize, tripCount);
    b.create<cf::AssertOp>(
        equals, builder.getStringAttr(
                    "could not compute dynamic multi-size tile shapes"));
  }

  return spec;
}

// Insert a tile `source` into the destination tensor `dest`. The position at
// which the tile is inserted (as well as size of tile) is taken from a given
// ExtractSliceOp `sliceOp`.
static Value insertSliceIntoTensor(RewriterBase &b, Location loc,
                                   tensor::ExtractSliceOp sliceOp, Value source,
                                   Value dest) {
  // Mirror the slice op's (static and dynamic) offsets/sizes/strides so the
  // insert is the exact inverse placement of the extract.
  return b.create<tensor::InsertSliceOp>(
      loc, sliceOp.getSource().getType(), source, dest, sliceOp.getOffsets(),
      sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
      sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
}

/// Tiles `op` with the given `tileSizes` (in loop order; zero means "do not
/// tile that loop") by wrapping a tiled clone of the op in a loop nest of
/// kind `LoopTy` (e.g. scf.for or scf.parallel). Returns the tiled op, the
/// generated loops, and the tensor results to use as replacements.
template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ValueRange tileSizes,
                 const LinalgTilingOptions &options) {
  auto nLoops = op.getNumLoops();
  // Initial tile sizes may be too big, only take the first nLoops.
  tileSizes = tileSizes.take_front(nLoops);

  // All-zero tile sizes: nothing to tile; return a plain clone of the op.
  if (llvm::all_of(tileSizes, isZero)) {
    TiledLinalgOp tiledOp;
    tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation()));
    tiledOp.tensorResults.assign(tiledOp.op->result_begin(),
                                 tiledOp.op->result_end());
    return tiledOp;
  }

  // 1. Build the tiled loop ranges.
  auto allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc());
  AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
  if (!shapeSizesToLoopsMap)
    return failure();

  SmallVector<Range, 4> loopRanges;
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
      b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);

  // Keep only the iterator types of loops that survived the zero-tile-size
  // pruning, so they stay aligned with `loopRanges`.
  SmallVector<Attribute, 4> iteratorTypes;
  for (const auto &attr :
       enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
    if (loopIndexToRangeIndex.count(attr.index()))
      iteratorTypes.push_back(attr.value());
  }
  // If interchangeVector is empty, use the identity. Build the permutation map
  // otherwise.
  auto invPermutationMap =
      AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
  if (!options.interchangeVector.empty()) {
    // Based on the pruned iterations (due to zero tile size), recompute the
    // interchange vector.
    SmallVector<unsigned, 4> interchangeVector;
    interchangeVector.reserve(options.interchangeVector.size());
    for (auto pos : options.interchangeVector) {
      auto it = loopIndexToRangeIndex.find(pos);
      if (it == loopIndexToRangeIndex.end())
        continue;
      interchangeVector.push_back(it->second);
    }
    // Interchange vector is guaranteed to be a permutation,
    // `inversePermutation` must succeed.
    invPermutationMap = inversePermutation(
        AffineMap::getPermutationMap(interchangeVector, b.getContext()));
    assert(invPermutationMap);
    SmallVector<int64_t> permutation(interchangeVector.begin(),
                                     interchangeVector.end());
    applyPermutationToVector(loopRanges, permutation);
    applyPermutationToVector(iteratorTypes, permutation);
  }

  // 2. Create the tiled loops.
  LinalgOp res = op;
  SmallVector<Value, 4> ivs, tensorResults;
  auto tiledLoopBodyBuilder =
      [&](OpBuilder &builder, Location loc, ValueRange localIvs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
    ivs.assign(localIvs.begin(), localIvs.end());

    // When an `interchangeVector` is present, it has been applied to the
    // loop ranges and the iterator types. Apply its inverse to the
    // resulting loop `ivs` to match the op definition.
    SmallVector<Value, 4> interchangedIvs;
    if (!options.interchangeVector.empty())
      interchangedIvs = applyMapToValues(b, loc, invPermutationMap, ivs);
    else
      interchangedIvs.assign(ivs.begin(), ivs.end());

    // Tile the `operandValuesToUse` that either match the `op` operands
    // themselves or the tile loop arguments forwarding them.
    assert(operandValuesToUse.size() ==
               static_cast<size_t>(op.getNumInputsAndOutputs()) &&
           "expect the number of operands and inputs and outputs to match");
    SmallVector<Value> valuesToTile = operandValuesToUse;
    auto sizeBounds =
        applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
    SmallVector<Value, 4> tiledOperands =
        makeTiledShapes(b, loc, op, valuesToTile, interchangedIvs, tileSizes,
                        sizeBounds, /*omitPartialTileCheck=*/false);

    // Clone the op over the tiled operands and stitch its tensor results
    // back into the enclosing iteration arguments.
    SmallVector<Type> resultTensorTypes =
        getTensorOutputTypes(op, tiledOperands);
    res = op.clone(b, loc, resultTensorTypes, tiledOperands);
    tensorResults =
        insertSlicesBack(builder, loc, op, tiledOperands, res->getResults());
    return scf::ValueVector(tensorResults.begin(), tensorResults.end());
  };
  GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
                                 tiledLoopBodyBuilder, options.distribution,
                                 options.distributionTypes);

  // 3. Transform IndexOp results w.r.t. the tiling.
  transformIndexOps(b, res, ivs, loopIndexToRangeIndex);

  // 4. Gather the newly created loops and return them with the new op.
  SmallVector<Operation *, 8> loops;
  loops.reserve(ivs.size());
  for (auto iv : ivs) {
    if (iv.isa<BlockArgument>()) {
      loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
      assert(loops.back() && "no owner found for induction variable!");
    } else {
      // TODO: Instead of doing this, try to recover the ops used instead of the
      // loop.
      loops.push_back(nullptr);
    }
  }

  // 5. Get the tensor results from the outermost loop if available. Otherwise
  // use the previously captured `tensorResults`.
  Operation *outermostLoop = nullptr;
  for (Operation *loop : loops)
    if ((outermostLoop = loop))
      break;

  return TiledLinalgOp{
      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}

/// Entry-point overload: computes the tile sizes via the options' callback,
/// pads the tile-size vector with zeros up to the op's loop count (zero means
/// "skip tiling that loop"), and dispatches to the ValueRange overload above.
template <typename LoopTy>
FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
    RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);

  if (!options.tileSizeComputationFunction)
    return failure();

  // Enforce the convention that "tiling by zero" skips tiling a particular
  // dimension. This convention is significantly simpler to handle instead of
  // adjusting affine maps to account for missing dimensions.
  auto nLoops = op.getNumLoops();
  SmallVector<Value, 4> tileSizeVector =
      options.tileSizeComputationFunction(b, op);
  if (tileSizeVector.size() < nLoops) {
    auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0);
    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
  }

  return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
}

/// Public tiling entry point: selects the loop kind from the options.
/// Only scf.for and scf.parallel are supported here; any other loop type
/// (e.g. affine loops) yields failure.
FailureOr<TiledLinalgOp>
mlir::linalg::tileLinalgOp(RewriterBase &b, LinalgOp op,
                           const LinalgTilingOptions &options) {
  switch (options.loopType) {
  case LinalgTilingLoopType::Loops:
    return tileLinalgOpImpl<scf::ForOp>(b, op, options);
  case LinalgTilingLoopType::ParallelLoops:
    return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
  default:;
  }
  return failure();
}

/// Generate a loop nest around a given tensor::PadOp (for tiling). `newPadOp`
/// and `loopNest` are output parameters that return the new (tiled)
/// tensor::PadOp and the loop nest.
static LogicalResult tilePadOp(RewriterBase &builder, tensor::PadOp op,
                               tensor::PadOp &newPadOp, LoopNest &loopNest,
                               const LinalgTilingOptions &options) {
  Location loc = op.getLoc();
  OpBuilder::InsertionGuard g(builder);
  builder.setInsertionPoint(op);

  // Clone tensor::PadOp so that the existing op can be replaced more easily.
  newPadOp = cast<tensor::PadOp>(builder.clone(*op.getOperation()));
  // Get rank and tile sizes.
  int64_t rank = op.getResultType().getRank();
  SmallVector<Value> tileSizes =
      options.tileSizeComputationFunction(builder, op);
  // Normalize untiled padding dimensions to 0.
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  tileSizes.append(rank - tileSizes.size(), zero);
  // Compute lower and upper bounds of the loop nest.
  TilingInterface tilingInterface =
      dyn_cast<TilingInterface>(op.getOperation());
  SmallVector<Range> ranges = tilingInterface.getIterationDomain(builder);
  SmallVector<Value> lbs, dims, allDims, steps;
  // Only dimensions with a non-zero tile size get a loop; `allDims` keeps
  // every dimension's extent for tile-size computation below.
  for (int64_t i = 0; i < rank; ++i) {
    allDims.push_back(ranges[i].size);
    if (!isZero(tileSizes[i])) {
      lbs.push_back(ranges[i].offset);
      dims.push_back(ranges[i].size);
      steps.push_back(tileSizes[i]);
    }
  }
  // Generate loop nest: One loop per dimension.
  SmallVector<Value> destOperand =
      tilingInterface.getDestinationOperands(builder);
  loopNest = mlir::scf::buildLoopNest(
      builder, loc, lbs, /*ubs=*/dims, steps, ValueRange(destOperand),
      [&](OpBuilder &b, Location loc, ValueRange localIvs,
          ValueRange iterArgs) -> scf::ValueVector {
        // Compute offsets and sizes of ExtractSliceOp.
        SmallVector<Value> offsets =
            computeTileOffsets(b, loc, localIvs, tileSizes);
        SmallVector<Value> sizes = computeTileSizes(b, loc, tileSizes, allDims);
        // Create ExtractSliceOp: Extract a tile from the tensor::PadOp.
        // Note: The tensor::PadOp is located outside of the loop nest. It is
        // later moved inside by ExtractSliceOfPadTensorSwapPattern.
        auto map = AffineMap::getMultiDimIdentityMap(rank, b.getContext());
        Value tiledOutput = makeTiledShape(
            b, loc, newPadOp->getResult(0), tileSizes, map, offsets, allDims,
            sizes, /*omitPartialTileCheck=*/false);
        auto sliceOp = tiledOutput.getDefiningOp<tensor::ExtractSliceOp>();
        assert(sliceOp && "expected ExtractSliceOp");
        // Insert the tile into the output tensor.
        // TODO: Propagate RewriterBase everywhere.
        IRRewriter rewriter(b);
        Value yieldValue =
            insertSliceIntoTensor(rewriter, loc, sliceOp, sliceOp, iterArgs[0]);
        return scf::ValueVector({yieldValue});
      });
  return success();
}

namespace {
/// Rewrite pattern wrapping `tilePadOp`. Marks the new pad op with the
/// transform marker attribute so the pattern does not fire on it again.
struct PadOpTilingPattern : public OpRewritePattern<tensor::PadOp> {
  PadOpTilingPattern(MLIRContext *ctx, LinalgTilingOptions opt)
      : OpRewritePattern<tensor::PadOp>(ctx), options(std::move(opt)) {}

  LogicalResult matchAndRewrite(tensor::PadOp op,
                                PatternRewriter &rewriter) const override {
    // Skip ops already processed (marker set below on the tiled clone).
    if (op->hasAttr(LinalgTransforms::kLinalgTransformMarker))
      return failure();
    tensor::PadOp newPadOp;
    LoopNest loopNest;
    if (failed(tilePadOp(rewriter, op, newPadOp, loopNest, options)))
      return failure();
    newPadOp->setAttr(LinalgTransforms::kLinalgTransformMarker,
                      rewriter.getUnitAttr());
    // Replace all uses of the original tensor::PadOp.
    rewriter.replaceOp(op, loopNest.getResults()[0]);
    return success();
  }

  LinalgTilingOptions options;
};
} // namespace

namespace {
/// Helper classes for type list expansion: recursively collects the
/// canonicalization patterns of every op type in the parameter pack.
template <typename... OpTypes>
class CanonicalizationPatternList;

/// Base case: empty pack, nothing to insert.
template <>
class CanonicalizationPatternList<> {
public:
  static void insert(RewritePatternSet &patterns) {}
};

/// Recursive case: insert patterns for the head op type, then recurse on the
/// tail of the pack.
template <typename OpTy, typename... OpTypes>
class CanonicalizationPatternList<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns) {
    OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
    CanonicalizationPatternList<OpTypes...>::insert(patterns);
  }
};
} // namespace

/// Convenience overload returning a fresh pattern set with all tiling
/// canonicalization patterns populated.
RewritePatternSet
mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
  RewritePatternSet patterns(ctx);
  populateLinalgTilingCanonicalizationPatterns(patterns);
  return patterns;
}

/// Populates `patterns` with the canonicalization patterns of every dialect
/// op that tiling commonly produces (affine/arith/memref/scf/tensor ops and
/// all Linalg structured ops), so the tiled IR can be cleaned up greedily.
void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
  AffineForOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
  arith::ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);

  memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
  memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);

  scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
  scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);

  tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);

  InitTensorOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
  ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);

  // Add canonicalizations for every generated Linalg structured op.
  CanonicalizationPatternList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
      >::insert(patterns);
}

/// Populate the given list with patterns that apply Linalg tiling.
static void insertTilingPatterns(RewritePatternSet &patterns,
                                 const LinalgTilingOptions &options) {
  auto *ctx = patterns.getContext();
  // Filter: match ops without a marker, then tag tiled ops with "tiled" so
  // the patterns do not re-tile their own output.
  LinalgTransformationFilter f(ArrayRef<StringAttr>{},
                               StringAttr::get(ctx, "tiled"));
  TilingPatterns<GenericOp,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                 >::insert(patterns, options, f);
  patterns.add<PadOpTilingPattern>(ctx, options);
}

/// Populates only the tensor::PadOp tiling pattern (no Linalg structured-op
/// patterns).
void mlir::linalg::populatePadTensorTilingPatterns(
    RewritePatternSet &patterns, const LinalgTilingOptions &options) {
  auto *ctx = patterns.getContext();
  patterns.add<PadOpTilingPattern>(ctx, options);
}

/// Applies the extract_slice(pad) -> pad(extract_slice) swap pattern followed
/// by the tiling canonicalizations on `funcOp`. Pattern application is
/// best-effort: failures of the greedy driver are deliberately ignored.
static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);
  patterns.add<ExtractSliceOfPadTensorSwapPattern>(patterns.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
  (void)applyPatternsAndFoldGreedily(
      funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
}

namespace {
/// Function pass driving Linalg tiling: applies the tiling patterns, runs
/// canonicalizations, strips the transform markers, and finally swaps
/// extract_slice over pad ops.
struct LinalgTilingPass : public LinalgTilingBase<LinalgTilingPass> {
  LinalgTilingPass() = default;
  LinalgTilingPass(ArrayRef<int64_t> tileSizes, LinalgTilingLoopType loopType) {
    this->tileSizes = tileSizes;
    // Clear the string option so the StringSwitch below falls through to
    // `loopTypeEnum`, making the programmatic enum authoritative.
    this->loopType = "";
    this->loopTypeEnum = loopType;
  }

  void runOnOperation() override {
    func::FuncOp funcOp = getOperation();
    // Resolve the loop type: the textual pass option wins when recognized,
    // otherwise fall back to the programmatically supplied enum.
    LinalgTilingLoopType type =
        llvm::StringSwitch<LinalgTilingLoopType>(loopType)
            .Case("for", LinalgTilingLoopType::Loops)
            .Case("affine", LinalgTilingLoopType::AffineLoops)
            .Case("parallel", LinalgTilingLoopType::ParallelLoops)
            .Default(loopTypeEnum);
    auto options =
        LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(type);
    MLIRContext *ctx = funcOp.getContext();
    RewritePatternSet patterns(ctx);
    insertTilingPatterns(patterns, options);
    scf::populateSCFForLoopCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
    (void)applyPatternsAndFoldGreedily(
        funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
    // Drop the marker.
    funcOp.walk([](LinalgOp op) {
      op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
    });

    // Apply swap pattern after generating loop nest and running
    // canonicalizations.
    applyExtractSliceOfPadTensorSwapPattern(funcOp);
  }

  // Loop type used when the `loopType` string option is empty/unrecognized.
  LinalgTilingLoopType loopTypeEnum;
};

} // namespace

/// Factory for the Linalg tiling pass with the given tile sizes and loop
/// type.
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes,
                             linalg::LinalgTilingLoopType loopType) {
  return std::make_unique<LinalgTilingPass>(tileSizes, loopType);
}