//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Tiling pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/CommandLine.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
using namespace mlir::scf;

#define DEBUG_TYPE "linalg-tiling"

// Returns true if `v` is a constant index op with value zero.
static bool isZero(Value v) {
  if (auto cst = v.getDefiningOp<ConstantIndexOp>())
    return cst.getValue() == 0;
  return false;
}

using LoopIndexToRangeIndexMap = DenseMap<int, int>;

// Creates a number of ranges equal to the number of non-zero entries in
// `tileSizes`: one for each loop of the LinalgOp that is tiled. The
// `tileSizes` argument has one entry per surrounding loop; by convention, a
// zero entry means that the corresponding loop is not tiled. This convention
// simplifies implementations by avoiding affine map manipulations.
// The returned ranges correspond to the loop ranges, in the proper order, that
// are tiled and for which new loops will be created. The function also returns
// a map from loop indices of the LinalgOp to the corresponding non-empty range
// indices of newly created loops.
static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
                    ValueRange allShapeSizes, ValueRange allTileSizes) {
  assert(allTileSizes.size() == map.getNumResults());
  // Apply `map` to get shape sizes in loop order.
  auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
  SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());

  // Traverse the tile sizes, which are in loop order, and erase zeros
  // everywhere.
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
    if (isZero(tileSizes[idx - zerosCount])) {
      shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
      tileSizes.erase(tileSizes.begin() + idx - zerosCount);
      ++zerosCount;
      continue;
    }
    loopIndexToRangeIndex[idx] = idx - zerosCount;
  }

  // Create a new range with the applied tile sizes.
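  // Note: each `Range` packs (offset, size, stride); used as a loop range this
  // reads as (lower bound = 0, upper bound = shape size, step = tile size), so
  // the generated loops iterate over tile origins rather than single elements.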
  SmallVector<Range, 4> res;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
    res.push_back(
        Range{std_constant_index(0), shapeSizes[idx], tileSizes[idx]});
  return std::make_tuple(res, loopIndexToRangeIndex);
}

// IndexedGenericOp explicitly uses induction variables in the loop body. The
// values of the indices that are used in the loop body for any given access of
// an input/output memref before the `subview` op was applied should be
// invariant with respect to tiling.
//
// Therefore, if the operation is tiled, we have to transform the indices
// accordingly, i.e. offset them by the values of the corresponding induction
// variables that are captured implicitly in the body of the op.
//
// Example. `linalg.indexed_generic` before tiling:
//
// #id_2d = (i, j) -> (i, j)
// #pointwise_2d_trait = {
//   indexing_maps = [#id_2d, #id_2d],
//   iterator_types = ["parallel", "parallel"],
//   n_views = [1, 1]
// }
// linalg.indexed_generic #pointwise_2d_trait %operand, %result {
//   ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
//     <some operations that use %i, %j>
// }: memref<50x100xf32>, memref<50x100xf32>
//
// After tiling pass with tile sizes 10 and 25:
//
// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
//
// %c1 = constant 1 : index
// %c0 = constant 0 : index
// %c25 = constant 25 : index
// %c10 = constant 10 : index
// operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
// operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
// scf.for %k = %c0 to operand_dim_0 step %c10 {
//   scf.for %l = %c0 to operand_dim_1 step %c25 {
//     %4 = memref.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
//       : memref<50x100xf32> to memref<?x?xf32, #strided>
//     %5 = memref.subview %result[%k, %l][%c10, %c25][%c1, %c1]
//       : memref<50x100xf32> to memref<?x?xf32, #strided>
//     linalg.indexed_generic pointwise_2d_trait %4, %5 {
//     ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
//       // Indices `k` and `l` are implicitly captured in the body.
//       %transformed_i = addi %i, %k : index // index `i` is offset by %k
//       %transformed_j = addi %j, %l : index // index `j` is offset by %l
//       // Every use of %i, %j is replaced with %transformed_i, %transformed_j
//       <some operations that use %transformed_i, %transformed_j>
//     }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
//   }
// }
//
// TODO: Investigate whether mixing implicit and explicit indices
// can lead to losing information.
static void transformIndexedGenericOpIndices(
    OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
    const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
  auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op.getOperation());
  if (!indexedGenericOp)
    return;

  // `linalg.indexed_generic` comes in two flavours. One has a region with a
  // single block that defines the loop body. The other has a `fun` attribute
  // that refers to an existing function symbol; in that case a call to the
  // `fun` function will be inserted in the loop body.
  //
  // TODO: Add support for `linalg.indexed_generic` with `fun` attribute.
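  // Only the region flavour is handled here; an op carrying the `fun`
  // attribute has no region body, so the emptiness check below rejects it
  // with an error.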
  auto &region = indexedGenericOp.region();
  if (region.empty()) {
    indexedGenericOp.emitOpError("expected a region");
    return;
  }
  auto &block = region.front();

  OpBuilder::InsertionGuard g(b);
  b.setInsertionPointToStart(&block);
  for (unsigned i = 0; i < indexedGenericOp.getNumLoops(); ++i) {
    auto rangeIndex = loopIndexToRangeIndex.find(i);
    if (rangeIndex == loopIndexToRangeIndex.end())
      continue;
    Value oldIndex = block.getArgument(i);
    // Offset the index argument `i` by the value of the corresponding
    // induction variable and replace all uses of the previous value.
    Value newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
                                      ivs[rangeIndex->second]);
    // Use an early-increment range: `use.set` mutates the use list while we
    // iterate over it.
    for (auto &use : llvm::make_early_inc_range(oldIndex.getUses())) {
      if (use.getOwner() == newIndex.getDefiningOp())
        continue;
      use.set(newIndex);
    }
  }
}

// All indices returned by IndexOp should be invariant with respect to tiling.
// Therefore, if an operation is tiled, we have to transform the indices
// accordingly, i.e. offset them by the values of the corresponding induction
// variables that are captured implicitly in the body of the op.
//
// Example. `linalg.generic` before tiling:
//
// #id_2d = (i, j) -> (i, j)
// #pointwise_2d_trait = {
//   indexing_maps = [#id_2d, #id_2d],
//   iterator_types = ["parallel", "parallel"]
// }
// linalg.generic #pointwise_2d_trait %operand, %result {
//   ^bb0(%operand_in: f32, %result_in: f32):
//     %i = linalg.index 0 : index
//     %j = linalg.index 1 : index
//     <some operations that use %i, %j>
// }: memref<50x100xf32>, memref<50x100xf32>
//
// After tiling pass with tile sizes 10 and 25:
//
// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
//
// %c1 = constant 1 : index
// %c0 = constant 0 : index
// %c25 = constant 25 : index
// %c10 = constant 10 : index
// operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
// operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
// scf.for %k = %c0 to operand_dim_0 step %c10 {
//   scf.for %l = %c0 to operand_dim_1 step %c25 {
//     %4 = memref.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
//       : memref<50x100xf32> to memref<?x?xf32, #strided>
//     %5 = memref.subview %result[%k, %l][%c10, %c25][%c1, %c1]
//       : memref<50x100xf32> to memref<?x?xf32, #strided>
//     linalg.generic pointwise_2d_trait %4, %5 {
//     ^bb0(%operand_in: f32, %result_in: f32):
//       %i = linalg.index 0 : index
//       %j = linalg.index 1 : index
//       // Indices `k` and `l` are implicitly captured in the body.
//       %transformed_i = addi %i, %k : index // index `i` is offset by %k
//       %transformed_j = addi %j, %l : index // index `j` is offset by %l
//       // Every use of %i, %j is replaced with %transformed_i, %transformed_j
//       <some operations that use %transformed_i, %transformed_j>
//     }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
//   }
// }
//
// TODO: Investigate whether mixing implicit and explicit indices
// can lead to losing information.
static void
transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
                  const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
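  // For example, after tiling the i-th loop with induction variable %k, every
  // `%i = linalg.index 0` in the body is rewritten below to
  //   %i_new = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %k)
  // and all remaining uses of %i are redirected to %i_new.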
  // Skip operations that have no region attached.
  if (op->getNumRegions() == 0)
    return;
  assert(op->getNumRegions() == 1 &&
         op->getRegion(0).getBlocks().size() == 1 &&
         "expected linalg operation to have one block.");
  Block &block = op->getRegion(0).front();

  for (IndexOp indexOp : block.getOps<linalg::IndexOp>()) {
    auto rangeIndex = loopIndexToRangeIndex.find(indexOp.dim());
    if (rangeIndex == loopIndexToRangeIndex.end())
      continue;
    // Offset the index by the value of the corresponding induction variable
    // and replace all uses of the previous value.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPointAfter(indexOp);
    AffineExpr index, iv;
    bindDims(b.getContext(), index, iv);
    AffineApplyOp applyOp = b.create<AffineApplyOp>(
        indexOp.getLoc(), index + iv,
        ValueRange{indexOp.getResult(), ivs[rangeIndex->second]});
    indexOp.getResult().replaceAllUsesExcept(
        applyOp.getResult(), SmallPtrSet<Operation *, 1>{applyOp});
  }
}

template <typename LoopTy>
static Optional<TiledLinalgOp>
tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
                 const LinalgTilingOptions &options) {
  auto nLoops = op.getNumLoops();
  // The initial tile sizes may be too big; only take the first nLoops.
  tileSizes = tileSizes.take_front(nLoops);

  if (llvm::all_of(tileSizes, isZero))
    return llvm::None;

  if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
    // For conv ops, only tiling along the batch dimension (which is the first
    // loop) is supported.
    if (convOp.padding() && !llvm::all_of(tileSizes.drop_front(), isZero))
      return llvm::None;
  }

  // 1. Build the tiled loop ranges.
  auto allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc());
  AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
  if (!shapeSizesToLoopsMap)
    return llvm::None;

  SmallVector<Range, 4> loopRanges;
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
      b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);

  SmallVector<Attribute, 4> iteratorTypes;
  for (auto attr :
       enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
    if (loopIndexToRangeIndex.count(attr.index()))
      iteratorTypes.push_back(attr.value());
  }
  // If interchangeVector is empty, use the identity. Otherwise build the
  // permutation map.
  auto invPermutationMap =
      AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
  if (!options.interchangeVector.empty()) {
    // Based on the pruned iterations (due to zero tile sizes), recompute the
    // interchange vector.
    SmallVector<unsigned, 4> interchangeVector;
    interchangeVector.reserve(options.interchangeVector.size());
    for (auto pos : options.interchangeVector) {
      auto it = loopIndexToRangeIndex.find(pos);
      if (it == loopIndexToRangeIndex.end())
        continue;
      interchangeVector.push_back(it->second);
    }
    // The interchange vector is guaranteed to be a permutation, so
    // `inversePermutation` must succeed.
    invPermutationMap = inversePermutation(
        AffineMap::getPermutationMap(interchangeVector, b.getContext()));
    assert(invPermutationMap);
    applyPermutationToVector(loopRanges, interchangeVector);
    applyPermutationToVector(iteratorTypes, interchangeVector);
  }

  // 2. Create the tiled loops.
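  // The loop nest is materialized by `GenerateLoopNest<LoopTy>`. For ops with
  // output tensors, those tensors are threaded through the loops as iteration
  // arguments so that every iteration yields the partially updated results.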
  LinalgOp res = op;
  SmallVector<Value, 4> ivs, tensorResults;
  auto outputTensors = op.getOutputTensors();
  GenerateLoopNest<LoopTy>::doit(
      loopRanges, /*iterArgInitValues=*/outputTensors, iteratorTypes,
      [&](ValueRange localIvs, ValueRange iterArgs) -> scf::ValueVector {
        auto &b = ScopedContext::getBuilderRef();
        auto loc = ScopedContext::getLocation();
        ivs.assign(localIvs.begin(), localIvs.end());

        // When an `interchangeVector` is present, it has been applied to the
        // loop ranges and the iterator types. Apply its inverse to the
        // resulting loop `ivs` to match the op definition.
        SmallVector<Value, 4> interchangedIvs;
        if (!options.interchangeVector.empty())
          interchangedIvs = applyMapToValues(b, loc, invPermutationMap, ivs);
        else
          interchangedIvs.assign(ivs.begin(), ivs.end());

        assert(op.getNumOutputTensors() == iterArgs.size() &&
               "num output tensors must match number of loop iter arguments");

        auto operands = llvm::to_vector<4>(op.getInputs());
        SmallVector<Value, 4> outputBuffers = op.getOutputBuffers();
        // TODO: thanks to a simplifying assumption we do not need to worry
        // about the order of output buffers and tensors: there is only ever
        // one kind.
        assert(outputBuffers.empty() || iterArgs.empty());
        operands.append(outputBuffers.begin(), outputBuffers.end());
        operands.append(iterArgs.begin(), iterArgs.end());
        auto sizeBounds =
            applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
        SmallVector<Value, 4> tiledOperands = makeTiledShapes(
            b, loc, op, operands, interchangedIvs, tileSizes, sizeBounds);
        auto nonShapedOperands = op.getAssumedNonShapedOperands();
        tiledOperands.append(nonShapedOperands.begin(),
                             nonShapedOperands.end());

        // TODO: use an interface/adaptor to avoid leaking position in
        // `tiledOperands`.
        SmallVector<Type, 4> resultTensorTypes;
        for (OpOperand *opOperand : op.getOutputTensorsOpOperands())
          resultTensorTypes.push_back(
              tiledOperands[opOperand->getOperandNumber()].getType());

        res = op.clone(b, loc, resultTensorTypes, tiledOperands);

        // Insert a subtensor_insert for each output tensor.
        unsigned resultIdx = 0;
        for (OpOperand *opOperand : op.getOutputTensorsOpOperands()) {
          // TODO: use an interface/adaptor to avoid leaking position in
          // `tiledOperands`.
          Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
          if (auto subtensor = outputTensor.getDefiningOp<SubTensorOp>()) {
            tensorResults.push_back(b.create<SubTensorInsertOp>(
                loc, subtensor.source().getType(), res->getResult(resultIdx),
                subtensor.source(), subtensor.offsets(), subtensor.sizes(),
                subtensor.strides(), subtensor.static_offsets(),
                subtensor.static_sizes(), subtensor.static_strides()));
          } else {
            tensorResults.push_back(res->getResult(resultIdx));
          }
          ++resultIdx;
        }
        return scf::ValueVector(tensorResults.begin(), tensorResults.end());
      },
      options.distribution);

  // 3a. Transform the index arguments of `linalg.indexed_generic` w.r.t. the
  // tiling.
  transformIndexedGenericOpIndices(b, res, ivs, loopIndexToRangeIndex);
  // 3b. Transform the `linalg.index` results w.r.t. the tiling.
  transformIndexOps(b, res, ivs, loopIndexToRangeIndex);

  // 4. Gather the newly created loops and return them with the new op.
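  // An induction variable that is still a block argument identifies its
  // enclosing loop through its owner block's parent op; any other value has
  // no recoverable loop and maps to nullptr.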
  SmallVector<Operation *, 8> loops;
  loops.reserve(ivs.size());
  for (auto iv : ivs) {
    if (iv.isa<BlockArgument>()) {
      loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
      assert(loops.back() && "no owner found for induction variable!");
    } else {
      // TODO: Instead of doing this, try to recover the ops used instead of
      // the loop.
      loops.push_back(nullptr);
    }
  }

  // 5. Get the tensor results from the outermost loop if available. Otherwise
  // use the previously captured `tensorResults`.
  Operation *outermostLoop = nullptr;
  for (Operation *loop : loops)
    if ((outermostLoop = loop))
      break;

  return TiledLinalgOp{
      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}

template <typename LoopTy>
static Optional<TiledLinalgOp>
tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
                 const LinalgTilingOptions &options) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);
  ScopedContext scope(b, op.getLoc());

  if (!options.tileSizeComputationFunction)
    return llvm::None;

  // Enforce the convention that "tiling by zero" skips tiling a particular
  // dimension. This convention is significantly simpler to handle than
  // adjusting affine maps to account for missing dimensions.
  auto nLoops = op.getNumLoops();
  SmallVector<Value, 4> tileSizeVector =
      options.tileSizeComputationFunction(b, op);
  if (tileSizeVector.size() < nLoops) {
    auto zero = std_constant_index(0);
    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
  }

  return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
}

Optional<TiledLinalgOp>
mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
                           const LinalgTilingOptions &options) {
  switch (options.loopType) {
  case LinalgTilingLoopType::Loops:
    return tileLinalgOpImpl<scf::ForOp>(b, op, options);
  case LinalgTilingLoopType::ParallelLoops:
    return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
  default:;
  }
  return llvm::None;
}

namespace {
/// Helper classes for type list expansion.
template <typename... OpTypes>
class CanonicalizationPatternList;

template <>
class CanonicalizationPatternList<> {
public:
  static void insert(RewritePatternSet &patterns) {}
};

template <typename OpTy, typename... OpTypes>
class CanonicalizationPatternList<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns) {
    OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
    CanonicalizationPatternList<OpTypes...>::insert(patterns);
  }
};

/// Helper classes for type list expansion.
template <typename... OpTypes>
class RewritePatternList;

template <>
class RewritePatternList<> {
public:
  static void insert(RewritePatternSet &patterns,
                     const LinalgTilingOptions &options) {}
};

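/// Recursive case: register a tiling pattern for `OpTy`, then recurse on the
/// remaining op types in the list.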
template <typename OpTy, typename... OpTypes>
class RewritePatternList<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns,
                     const LinalgTilingOptions &options) {
    auto *ctx = patterns.getContext();
    patterns.add<LinalgTilingPattern<OpTy>>(
        ctx, options,
        LinalgTransformationFilter(ArrayRef<Identifier>{},
                                   Identifier::get("tiled", ctx)));
    RewritePatternList<OpTypes...>::insert(patterns, options);
  }
};
} // namespace

RewritePatternSet
mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
  RewritePatternSet patterns(ctx);
  populateLinalgTilingCanonicalizationPatterns(patterns);
  return patterns;
}

void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
  AffineForOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
  scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
  scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
  ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
  SubTensorOp::getCanonicalizationPatterns(patterns, ctx);
  memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
  memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
  CanonicalizationPatternList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
      >::insert(patterns);
}

/// Populate the given list with patterns that apply Linalg tiling.
static void insertTilingPatterns(RewritePatternSet &patterns,
                                 const LinalgTilingOptions &options) {
  RewritePatternList<GenericOp, IndexedGenericOp,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                     >::insert(patterns, options);
}

static void applyTilingToLoopPatterns(LinalgTilingLoopType loopType,
                                      FuncOp funcOp,
                                      ArrayRef<int64_t> tileSizes) {
  auto options =
      LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType);
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);
  insertTilingPatterns(patterns, options);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
  (void)applyPatternsAndFoldGreedily(
      funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
  // Drop the marker.
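  // The tiling patterns tag every op they produce with the "tiled"
  // transformation marker so that the greedy driver does not tile the same op
  // twice; once rewriting has converged the marker is no longer needed.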
  funcOp.walk([](LinalgOp op) {
    op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
  });
}

namespace {
struct LinalgTilingPass : public LinalgTilingBase<LinalgTilingPass> {
  LinalgTilingPass() = default;
  LinalgTilingPass(ArrayRef<int64_t> sizes) { tileSizes = sizes; }

  void runOnFunction() override {
    applyTilingToLoopPatterns(LinalgTilingLoopType::Loops, getFunction(),
                              tileSizes);
  }
};

struct LinalgTilingToParallelLoopsPass
    : public LinalgTilingToParallelLoopsBase<LinalgTilingToParallelLoopsPass> {
  LinalgTilingToParallelLoopsPass() = default;
  LinalgTilingToParallelLoopsPass(ArrayRef<int64_t> sizes) {
    tileSizes = sizes;
  }

  void runOnFunction() override {
    applyTilingToLoopPatterns(LinalgTilingLoopType::ParallelLoops,
                              getFunction(), tileSizes);
  }
};

} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
  return std::make_unique<LinalgTilingPass>(tileSizes);
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgTilingToParallelLoopsPass(ArrayRef<int64_t> tileSizes) {
  return std::make_unique<LinalgTilingToParallelLoopsPass>(tileSizes);
}