//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Tiling pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/CommandLine.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::scf;

#define DEBUG_TYPE "linalg-tiling"

static bool isZero(Value v) {
  if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
    return cst.value() == 0;
  return false;
}

using LoopIndexToRangeIndexMap = DenseMap<int, int>;

// Creates a number of ranges equal to the number of non-zero entries in
// `tileSizes`, one for each loop of the LinalgOp that is tiled. The
// `tileSizes` argument has one entry per surrounding loop; by convention, a
// zero entry means that the corresponding loop is not tiled. This convention
// simplifies implementations by avoiding affine map manipulations. The
// returned ranges correspond to the loop ranges, in the proper order, that
// are tiled and for which new loops will be created. The function also
// returns a map from the loop indices of the LinalgOp to the corresponding
// range indices of the newly created loops.
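// For example, with three loops (i, j, k) and tile sizes (%c2, %c0, %c4),
// loop j is left untiled: two ranges are returned (for i and k) and the
// returned map is {0 -> 0, 2 -> 1}.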
static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
                    ValueRange allShapeSizes, ValueRange allTileSizes) {
  assert(allTileSizes.size() == map.getNumResults());
  // Apply `map` to get shape sizes in loop order.
  auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
  SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());

  // Traverse the tile sizes, which are in loop order, and erase the zeros.
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
    if (isZero(tileSizes[idx - zerosCount])) {
      shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
      tileSizes.erase(tileSizes.begin() + idx - zerosCount);
      ++zerosCount;
      continue;
    }
    loopIndexToRangeIndex[idx] = idx - zerosCount;
  }

  // Create a range for each remaining tiled loop using the applied tile sizes.
  SmallVector<Range, 4> res;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
    res.push_back(Range{b.create<arith::ConstantIndexOp>(loc, 0),
                        shapeSizes[idx], tileSizes[idx]});
  return std::make_tuple(res, loopIndexToRangeIndex);
}

// All indices returned by IndexOp should be invariant with respect to tiling.
// Therefore, if an operation is tiled, we have to transform the indices
// accordingly, i.e. offset them by the values of the corresponding induction
// variables that are captured implicitly in the body of the op.
//
// Example. `linalg.generic` before tiling:
//
//   #id_2d = (i, j) -> (i, j)
//   #pointwise_2d_trait = {
//     indexing_maps = [#id_2d, #id_2d],
//     iterator_types = ["parallel", "parallel"]
//   }
//   linalg.generic #pointwise_2d_trait %operand, %result {
//   ^bb0(%operand_in: f32, %result_in: f32):
//     %i = linalg.index 0 : index
//     %j = linalg.index 1 : index
//     <some operations that use %i, %j>
//   }: memref<50x100xf32>, memref<50x100xf32>
//
// After tiling pass with tile sizes 10 and 25:
//
//   #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
//
//   %c1 = arith.constant 1 : index
//   %c0 = arith.constant 0 : index
//   %c25 = arith.constant 25 : index
//   %c10 = arith.constant 10 : index
//   %operand_dim_0 = memref.dim %operand, %c0 : memref<50x100xf32>
//   %operand_dim_1 = memref.dim %operand, %c1 : memref<50x100xf32>
//   scf.for %k = %c0 to %operand_dim_0 step %c10 {
//     scf.for %l = %c0 to %operand_dim_1 step %c25 {
//       %4 = memref.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
//         : memref<50x100xf32> to memref<?x?xf32, #strided>
//       %5 = memref.subview %result[%k, %l][%c10, %c25][%c1, %c1]
//         : memref<50x100xf32> to memref<?x?xf32, #strided>
//       linalg.generic #pointwise_2d_trait %4, %5 {
//       ^bb0(%operand_in: f32, %result_in: f32):
//         %i = linalg.index 0 : index
//         %j = linalg.index 1 : index
//         // Indices `k` and `l` are implicitly captured in the body.
//         %transformed_i = arith.addi %i, %k : index // index `i` offset by %k
//         %transformed_j = arith.addi %j, %l : index // index `j` offset by %l
//         // Every use of %i, %j is replaced with %transformed_i,
//         // %transformed_j.
//         <some operations that use %transformed_i, %transformed_j>
//       }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
//     }
//   }
//
// TODO: Investigate whether mixing implicit and explicit indices can lead to
// losing information.
static void
transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
                  const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
  SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
  for (auto &en : enumerate(allIvs)) {
    auto rangeIndex = loopIndexToRangeIndex.find(en.index());
    if (rangeIndex == loopIndexToRangeIndex.end())
      continue;
    en.value() = ivs[rangeIndex->second];
  }
  addTileLoopIvsToIndexOpResults(b, op, allIvs);
}

// Insert a tile `source` into the destination tensor `dest`. The position at
// which the tile is inserted (as well as the size of the tile) is taken from
// a given ExtractSliceOp `sliceOp`.
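// For example, if `sliceOp` extracted its tile at offsets [%i, %j] with sizes
// [%c10, %c25], `source` is written back into `dest` at the same offsets,
// with the same sizes and strides.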
static Value insertSliceIntoTensor(OpBuilder &b, Location loc,
                                   tensor::ExtractSliceOp sliceOp, Value source,
                                   Value dest) {
  return b.create<tensor::InsertSliceOp>(
      loc, sliceOp.source().getType(), source, dest, sliceOp.offsets(),
      sliceOp.sizes(), sliceOp.strides(), sliceOp.static_offsets(),
      sliceOp.static_sizes(), sliceOp.static_strides());
}

template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
                 const LinalgTilingOptions &options) {
  auto nLoops = op.getNumLoops();
  // Initial tile sizes may be too large; only take the first nLoops entries.
  tileSizes = tileSizes.take_front(nLoops);

  if (llvm::all_of(tileSizes, isZero)) {
    TiledLinalgOp tiledOp;
    tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation()));
    tiledOp.tensorResults.assign(tiledOp.op->result_begin(),
                                 tiledOp.op->result_end());
    return tiledOp;
  }

  // 1. Build the tiled loop ranges.
  auto allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc());
  AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
  if (!shapeSizesToLoopsMap)
    return failure();

  SmallVector<Range, 4> loopRanges;
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
      b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);

  SmallVector<Attribute, 4> iteratorTypes;
  for (auto attr :
       enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
    if (loopIndexToRangeIndex.count(attr.index()))
      iteratorTypes.push_back(attr.value());
  }
  // If interchangeVector is empty, use the identity. Build the permutation map
  // otherwise.
  auto invPermutationMap =
      AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
  if (!options.interchangeVector.empty()) {
    // Based on the pruned iterations (due to zero tile size), recompute the
    // interchange vector.
    SmallVector<unsigned, 4> interchangeVector;
    interchangeVector.reserve(options.interchangeVector.size());
    for (auto pos : options.interchangeVector) {
      auto it = loopIndexToRangeIndex.find(pos);
      if (it == loopIndexToRangeIndex.end())
        continue;
      interchangeVector.push_back(it->second);
    }
    // The interchange vector is guaranteed to be a permutation, so
    // `inversePermutation` must succeed.
    invPermutationMap = inversePermutation(
        AffineMap::getPermutationMap(interchangeVector, b.getContext()));
    assert(invPermutationMap);
    SmallVector<int64_t> permutation(interchangeVector.begin(),
                                     interchangeVector.end());
    applyPermutationToVector(loopRanges, permutation);
    applyPermutationToVector(iteratorTypes, permutation);
  }

  // 2. Create the tiled loops.
  LinalgOp res = op;
  SmallVector<Value, 4> ivs, tensorResults;
  auto tiledLoopBodyBuilder =
      [&](OpBuilder &b, Location loc, ValueRange localIvs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
    ivs.assign(localIvs.begin(), localIvs.end());

    // When an `interchangeVector` is present, it has been applied to the
    // loop ranges and the iterator types. Apply its inverse to the
    // resulting loop `ivs` to match the op definition.
    SmallVector<Value, 4> interchangedIvs;
    if (!options.interchangeVector.empty())
      interchangedIvs = applyMapToValues(b, loc, invPermutationMap, ivs);
    else
      interchangedIvs.assign(ivs.begin(), ivs.end());

    // Tile the `operandValuesToUse` that either match the `op` operands
    // themselves or the tile loop arguments forwarding them.
    assert(operandValuesToUse.size() ==
               static_cast<size_t>(op.getNumInputsAndOutputs()) &&
           "expect the number of operands and inputs and outputs to match");
    SmallVector<Value> valuesToTile = operandValuesToUse;
    auto sizeBounds =
        applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
        b, loc, op, valuesToTile, interchangedIvs, tileSizes, sizeBounds);

    // TODO: use an interface/adaptor to avoid leaking position in
    // `tiledOperands`.
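    // The tiled op must return tensors whose types match the tiled output
    // operands (e.g. a tensor<?x?xf32> tile of a tensor<50x100xf32> operand),
    // so collect the result types from `tiledOperands`.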
    SmallVector<Type, 4> resultTensorTypes;
    for (OpOperand *opOperand : op.getOutputTensorOperands())
      resultTensorTypes.push_back(
          tiledOperands[opOperand->getOperandNumber()].getType());

    res = op.clone(b, loc, resultTensorTypes, tiledOperands);

    // Insert an insert_slice for each output tensor.
    unsigned resultIdx = 0;
    for (OpOperand *opOperand : op.getOutputTensorOperands()) {
      // TODO: use an interface/adaptor to avoid leaking position in
      // `tiledOperands`.
      Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
      if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
        tensorResults.push_back(insertSliceIntoTensor(
            b, loc, sliceOp, res->getResult(resultIdx), sliceOp.source()));
      } else {
        tensorResults.push_back(res->getResult(resultIdx));
      }
      ++resultIdx;
    }
    return scf::ValueVector(tensorResults.begin(), tensorResults.end());
  };
  GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
                                 tiledLoopBodyBuilder, options.distribution,
                                 options.distributionTypes);

  // 3. Transform IndexOp results w.r.t. the tiling.
  transformIndexOps(b, res, ivs, loopIndexToRangeIndex);

  // 4. Gather the newly created loops and return them with the new op.
  SmallVector<Operation *, 8> loops;
  loops.reserve(ivs.size());
  for (auto iv : ivs) {
    if (iv.isa<BlockArgument>()) {
      loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
      assert(loops.back() && "no owner found for induction variable!");
    } else {
      // TODO: Instead of doing this, try to recover the ops used in place of
      // the loop.
      loops.push_back(nullptr);
    }
  }

  // 5. Get the tensor results from the outermost loop if available. Otherwise
  // use the previously captured `tensorResults`.
  Operation *outermostLoop = nullptr;
  for (Operation *loop : loops)
    if ((outermostLoop = loop))
      break;

  return TiledLinalgOp{
      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}

template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
                 const LinalgTilingOptions &options) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);

  if (!options.tileSizeComputationFunction)
    return failure();

  // Enforce the convention that "tiling by zero" skips tiling a particular
  // dimension. This convention is significantly simpler to handle than
  // adjusting affine maps to account for missing dimensions.
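  // For example, for an op with three loops and a size-computation callback
  // returning a single tile size {%c8}, the vector is padded below to
  // {%c8, %c0, %c0}, leaving loops 1 and 2 untiled.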
  auto nLoops = op.getNumLoops();
  SmallVector<Value, 4> tileSizeVector =
      options.tileSizeComputationFunction(b, op);
  if (tileSizeVector.size() < nLoops) {
    auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0);
    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
  }

  return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
}

FailureOr<TiledLinalgOp>
mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
                           const LinalgTilingOptions &options) {
  switch (options.loopType) {
  case LinalgTilingLoopType::Loops:
    return tileLinalgOpImpl<scf::ForOp>(b, op, options);
  case LinalgTilingLoopType::ParallelLoops:
    return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
  case LinalgTilingLoopType::TiledLoops:
    return tileLinalgOpImpl<linalg::TiledLoopOp>(b, op, options);
  default:;
  }
  return failure();
}

/// Generate a loop nest around a given PadTensorOp (for tiling). `newPadOp`
/// and `loopNest` are output parameters that return the new (tiled)
/// PadTensorOp and the loop nest.
static LogicalResult tilePadTensorOp(OpBuilder &builder, PadTensorOp op,
                                     PadTensorOp &newPadOp, LoopNest &loopNest,
                                     const LinalgTilingOptions &options) {
  Location loc = op.getLoc();
  OpBuilder::InsertionGuard g(builder);
  builder.setInsertionPoint(op);

  // Clone PadTensorOp so that the existing op can be replaced more easily.
  newPadOp = cast<PadTensorOp>(builder.clone(*op.getOperation()));
  // Get rank and tile sizes.
  int64_t rank = op.getResultType().getRank();
  SmallVector<Value> tileSizes =
      options.tileSizeComputationFunction(builder, op);
  assert(static_cast<int64_t>(tileSizes.size()) == rank);
  // Compute lower and upper bounds of the loop nest.
  SmallVector<Range> ranges = op.getIterationDomain(builder);
  SmallVector<Value> lbs, dims, allDims, steps;
  for (int64_t i = 0; i < rank; ++i) {
    allDims.push_back(ranges[i].size);
    if (!isZero(tileSizes[i])) {
      lbs.push_back(ranges[i].offset);
      dims.push_back(ranges[i].size);
      steps.push_back(tileSizes[i]);
    }
  }
  // Generate loop nest: One loop per dimension.
  SmallVector<Value> destOperand = op.getDestinationOperands(builder);
  loopNest = mlir::scf::buildLoopNest(
      builder, loc, lbs, /*ubs=*/dims, steps, ValueRange(destOperand),
      [&](OpBuilder &b, Location loc, ValueRange localIvs,
          ValueRange iterArgs) -> scf::ValueVector {
        // Compute offsets and sizes of ExtractSliceOp.
        SmallVector<Value> offsets =
            computeTileOffsets(b, loc, localIvs, tileSizes);
        SmallVector<Value> sizes =
            computeTileSizes(b, loc, localIvs, tileSizes, allDims);
        // Create ExtractSliceOp: Extract a tile from the PadTensorOp.
        // Note: The PadTensorOp is located outside of the loop nest. It is
        // later moved inside by ExtractSliceOfPadTensorSwapPattern.
        auto map = AffineMap::getMultiDimIdentityMap(rank, b.getContext());
        Value tiledOutput =
            makeTiledShape(b, loc, newPadOp->getResult(0), tileSizes, map,
                           offsets, allDims, sizes);
        auto sliceOp = tiledOutput.getDefiningOp<tensor::ExtractSliceOp>();
        assert(sliceOp && "expected ExtractSliceOp");
        // Insert the tile into the output tensor.
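        // Note that `sliceOp` is also passed as the `source` value: a
        // single-result op converts implicitly to its result, i.e. the
        // extracted tile.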
        Value yieldValue =
            insertSliceIntoTensor(b, loc, sliceOp, sliceOp, iterArgs[0]);
        return scf::ValueVector({yieldValue});
      });
  return success();
}

namespace {
struct PadTensorOpTilingPattern : public OpRewritePattern<PadTensorOp> {
  PadTensorOpTilingPattern(MLIRContext *ctx, LinalgTilingOptions opt)
      : OpRewritePattern<PadTensorOp>(ctx), options(opt) {}

  LogicalResult matchAndRewrite(PadTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op->hasAttr(LinalgTransforms::kLinalgTransformMarker))
      return failure();
    PadTensorOp newPadOp;
    LoopNest loopNest;
    if (failed(tilePadTensorOp(rewriter, op, newPadOp, loopNest, options)))
      return failure();
    newPadOp->setAttr(LinalgTransforms::kLinalgTransformMarker,
                      rewriter.getUnitAttr());
    // Replace all uses of the original PadTensorOp.
    rewriter.replaceOp(op, loopNest.getResults()[0]);
    return success();
  }

  LinalgTilingOptions options;
};
} // namespace

namespace {
/// Helper classes for type list expansion.
template <typename... OpTypes>
class CanonicalizationPatternList;

template <>
class CanonicalizationPatternList<> {
public:
  static void insert(RewritePatternSet &patterns) {}
};

template <typename OpTy, typename... OpTypes>
class CanonicalizationPatternList<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns) {
    OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
    CanonicalizationPatternList<OpTypes...>::insert(patterns);
  }
};

/// Helper classes for type list expansion.
template <typename... OpTypes>
class RewritePatternList;

template <>
class RewritePatternList<> {
public:
  static void insert(RewritePatternSet &patterns,
                     const LinalgTilingOptions &options) {}
};
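/// Recursive case: add a tiling pattern for `OpTy`, then recurse on the
/// remaining op types. E.g. RewritePatternList<GenericOp, MatmulOp>::insert
/// adds a LinalgTilingPattern for GenericOp and one for MatmulOp.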
template <typename OpTy, typename... OpTypes>
class RewritePatternList<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns,
                     const LinalgTilingOptions &options) {
    auto *ctx = patterns.getContext();
    patterns.add<LinalgTilingPattern<OpTy>>(
        ctx, options,
        LinalgTransformationFilter(ArrayRef<StringAttr>{},
                                   StringAttr::get(ctx, "tiled")));
    RewritePatternList<OpTypes...>::insert(patterns, options);
  }
};
} // namespace

RewritePatternSet
mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
  RewritePatternSet patterns(ctx);
  populateLinalgTilingCanonicalizationPatterns(patterns);
  return patterns;
}

void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
  AffineForOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
  arith::ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);

  memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
  memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);

  scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
  scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);

  tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);

  InitTensorOp::getCanonicalizationPatterns(patterns, ctx);
  PadTensorOp::getCanonicalizationPatterns(patterns, ctx);
  ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);

  CanonicalizationPatternList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
      >::insert(patterns);
}

/// Populate the given list with patterns that apply Linalg tiling.
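/// One LinalgTilingPattern is added for GenericOp and for every op in
/// LinalgStructuredOps.cpp.inc; PadTensorOp is handled by the dedicated
/// PadTensorOpTilingPattern above.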
static void insertTilingPatterns(RewritePatternSet &patterns,
                                 const LinalgTilingOptions &options) {
  RewritePatternList<GenericOp,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                     >::insert(patterns, options);
  patterns.add<PadTensorOpTilingPattern>(patterns.getContext(), options);
}

static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);
  patterns.add<ExtractSliceOfPadTensorSwapPattern>(patterns.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
  (void)applyPatternsAndFoldGreedily(
      funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
}

namespace {
struct LinalgTilingPass : public LinalgTilingBase<LinalgTilingPass> {
  LinalgTilingPass() = default;
  LinalgTilingPass(ArrayRef<int64_t> tileSizes, LinalgTilingLoopType loopType,
                   ArrayRef<StringRef> distributionTypes) {
    this->tileSizes = tileSizes;
    this->loopType = "";
    this->loopTypeEnum = loopType;
    this->distributionTypes = llvm::to_vector<2>(llvm::map_range(
        distributionTypes, [](StringRef ref) { return ref.str(); }));
  }

  void runOnFunction() override {
    FuncOp funcOp = getFunction();
    LinalgTilingLoopType type =
        llvm::StringSwitch<LinalgTilingLoopType>(loopType)
            .Case("for", LinalgTilingLoopType::Loops)
            .Case("affine", LinalgTilingLoopType::AffineLoops)
            .Case("parallel", LinalgTilingLoopType::ParallelLoops)
            .Case("tiled_loop", LinalgTilingLoopType::TiledLoops)
            .Default(loopTypeEnum);
    auto distTypes = llvm::to_vector<2>(llvm::map_range(
        distributionTypes, [](std::string &str) { return StringRef(str); }));
    auto options = LinalgTilingOptions()
                       .setTileSizes(tileSizes)
                       .setLoopType(type)
                       .setDistributionTypes(distTypes);
    MLIRContext *ctx = funcOp.getContext();
    RewritePatternSet patterns(ctx);
    insertTilingPatterns(patterns, options);
    scf::populateSCFForLoopCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
    (void)applyPatternsAndFoldGreedily(
        funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
    // Drop the marker.
    funcOp.walk([](LinalgOp op) {
      op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
    });

    // Apply the swap pattern after generating the loop nest and running
    // canonicalizations.
    applyExtractSliceOfPadTensorSwapPattern(funcOp);
  }

  LinalgTilingLoopType loopTypeEnum;
};

} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes,
                             linalg::LinalgTilingLoopType loopType,
                             ArrayRef<StringRef> distributionTypes) {
  return std::make_unique<LinalgTilingPass>(tileSizes, loopType,
                                            distributionTypes);
}