//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Tiling pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/CommandLine.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
using namespace mlir::scf;

#define DEBUG_TYPE "linalg-tiling"

static bool isZero(Value v) {
  if (auto cst = v.getDefiningOp<ConstantIndexOp>())
    return cst.getValue() == 0;
  return false;
}

using LoopIndexToRangeIndexMap = DenseMap<int, int>;

// Creates a number of ranges equal to the number of non-zero entries in
// `tileSizes`, one for each loop of the LinalgOp that is tiled. The
// `tileSizes` argument has one entry per surrounding loop; a zero entry
// means that the corresponding loop is not tiled. This convention simplifies
// implementations by avoiding affine map manipulations.
// The returned ranges correspond to the loop ranges, in the proper order, that
// are tiled and for which new loops will be created. The function also returns
// a map from loop indices of the LinalgOp to the corresponding range indices
// of the newly created loops.
static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
                    ValueRange allShapeSizes, ValueRange allTileSizes) {
  assert(allTileSizes.size() == map.getNumResults());
  // Apply `map` to get shape sizes in loop order.
  auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
  SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());

  // Traverse the tile sizes, which are in loop order, and erase zeros
  // everywhere.
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
    if (isZero(tileSizes[idx - zerosCount])) {
      shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
      tileSizes.erase(tileSizes.begin() + idx - zerosCount);
      ++zerosCount;
      continue;
    }
    loopIndexToRangeIndex[idx] = idx - zerosCount;
  }

  // Create a new range with the applied tile sizes.
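  // For example, with three loops, tile sizes (2, 0, 8), and shape sizes
  // (%d0, %d1, %d2), the zero entry prunes loop 1: the resulting ranges are
  // {offset = 0, size = %d0, stride = 2} and {offset = 0, size = %d2,
  // stride = 8}, and loopIndexToRangeIndex maps {0 -> 0, 2 -> 1}.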
  SmallVector<Range, 4> res;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
    res.push_back(
        Range{std_constant_index(0), shapeSizes[idx], tileSizes[idx]});
  return std::make_tuple(res, loopIndexToRangeIndex);
}

// IndexedGenericOp explicitly uses induction variables in the loop body. The
// values of the indices that are used in the loop body for any given access of
// an input/output memref before the `subview` op was applied should be
// invariant with respect to tiling.
//
// Therefore, if the operation is tiled, we have to transform the indices
// accordingly, i.e. offset them by the values of the corresponding induction
// variables that are captured implicitly in the body of the op.
//
// Example. `linalg.indexed_generic` before tiling:
//
//   #id_2d = (i, j) -> (i, j)
//   #pointwise_2d_trait = {
//     indexing_maps = [#id_2d, #id_2d],
//     iterator_types = ["parallel", "parallel"],
//     n_views = [1, 1]
//   }
//   linalg.indexed_generic #pointwise_2d_trait %operand, %result {
//     ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
//       <some operations that use %i, %j>
//   }: memref<50x100xf32>, memref<50x100xf32>
//
// After the tiling pass with tile sizes 10 and 25:
//
//   #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
//
//   %c1 = constant 1 : index
//   %c0 = constant 0 : index
//   %c25 = constant 25 : index
//   %c10 = constant 10 : index
//   operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
//   operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
//   scf.for %k = %c0 to operand_dim_0 step %c10 {
//     scf.for %l = %c0 to operand_dim_1 step %c25 {
//       %4 = memref.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
//         : memref<50x100xf32> to memref<?x?xf32, #strided>
//       %5 = memref.subview %result[%k, %l][%c10, %c25][%c1, %c1]
//         : memref<50x100xf32> to memref<?x?xf32, #strided>
//       linalg.indexed_generic pointwise_2d_trait %4, %5 {
//       ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
//         // Indices `k` and `l` are implicitly captured in the body.
//         %transformed_i = addi %i, %k : index // index `i` is offset by %k
//         %transformed_j = addi %j, %l : index // index `j` is offset by %l
//         // Every use of %i, %j is replaced with %transformed_i, %transformed_j
//         <some operations that use %transformed_i, %transformed_j>
//       }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
//     }
//   }
//
// TODO: Investigate whether mixing implicit and explicit indices
// can lead to losing information.
static void transformIndexedGenericOpIndices(
    OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
    const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
  auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op.getOperation());
  if (!indexedGenericOp)
    return;

  // `linalg.indexed_generic` comes in two flavours. One has a region with a
  // single block that defines the loop body. The other has a `fun` attribute
  // that refers to an existing function symbol; in that case a call to the
  // `fun` function is inserted in the loop body.
  //
  // TODO: Add support for `linalg.indexed_generic` with `fun` attribute.
  auto &region = indexedGenericOp.region();
  if (region.empty()) {
    indexedGenericOp.emitOpError("expected a region");
    return;
  }
  auto &block = region.front();

  OpBuilder::InsertionGuard g(b);
  b.setInsertionPointToStart(&block);
  for (unsigned i = 0; i < indexedGenericOp.getNumLoops(); ++i) {
    auto rangeIndex = loopIndexToRangeIndex.find(i);
    if (rangeIndex == loopIndexToRangeIndex.end())
      continue;
    Value oldIndex = block.getArgument(i);
    // Offset the index argument `i` by the value of the corresponding
    // induction variable and replace all uses of the previous value.
    Value newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
                                      ivs[rangeIndex->second]);
    for (auto &use : oldIndex.getUses()) {
      if (use.getOwner() == newIndex.getDefiningOp())
        continue;
      use.set(newIndex);
    }
  }
}

// All indices returned by IndexOp should be invariant with respect to tiling.
// Therefore, if an operation is tiled, we have to transform the indices
// accordingly, i.e. offset them by the values of the corresponding induction
// variables that are captured implicitly in the body of the op.
//
// Example. `linalg.generic` before tiling:
//
//   #id_2d = (i, j) -> (i, j)
//   #pointwise_2d_trait = {
//     indexing_maps = [#id_2d, #id_2d],
//     iterator_types = ["parallel", "parallel"]
//   }
//   linalg.generic #pointwise_2d_trait %operand, %result {
//     ^bb0(%operand_in: f32, %result_in: f32):
//       %i = linalg.index 0 : index
//       %j = linalg.index 1 : index
//       <some operations that use %i, %j>
//   }: memref<50x100xf32>, memref<50x100xf32>
//
// After the tiling pass with tile sizes 10 and 25:
//
//   #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
//
//   %c1 = constant 1 : index
//   %c0 = constant 0 : index
//   %c25 = constant 25 : index
//   %c10 = constant 10 : index
//   operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
//   operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
//   scf.for %k = %c0 to operand_dim_0 step %c10 {
//     scf.for %l = %c0 to operand_dim_1 step %c25 {
//       %4 = memref.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
//         : memref<50x100xf32> to memref<?x?xf32, #strided>
//       %5 = memref.subview %result[%k, %l][%c10, %c25][%c1, %c1]
//         : memref<50x100xf32> to memref<?x?xf32, #strided>
//       linalg.generic pointwise_2d_trait %4, %5 {
//       ^bb0(%operand_in: f32, %result_in: f32):
//         %i = linalg.index 0 : index
//         %j = linalg.index 1 : index
//         // Indices `k` and `l` are implicitly captured in the body.
//         %transformed_i = addi %i, %k : index // index `i` is offset by %k
//         %transformed_j = addi %j, %l : index // index `j` is offset by %l
//         // Every use of %i, %j is replaced with %transformed_i, %transformed_j
//         <some operations that use %transformed_i, %transformed_j>
//       }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
//     }
//   }
//
// TODO: Investigate whether mixing implicit and explicit indices
// can lead to losing information.
static void
transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
                  const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
  // Skip operations that have no region attached.
  if (op->getNumRegions() == 0)
    return;
  assert(op->getNumRegions() == 1 &&
         op->getRegion(0).getBlocks().size() == 1 &&
         "expected linalg operation to have one block.");
  Block &block = op->getRegion(0).front();

  for (IndexOp indexOp : block.getOps<linalg::IndexOp>()) {
    auto rangeIndex = loopIndexToRangeIndex.find(indexOp.dim());
    if (rangeIndex == loopIndexToRangeIndex.end())
      continue;
    // Offset the index by the value of the corresponding induction variable
    // and replace all uses of the previous value.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPointAfter(indexOp);
    AffineExpr index, iv;
    bindDims(b.getContext(), index, iv);
    AffineApplyOp applyOp = b.create<AffineApplyOp>(
        indexOp.getLoc(), index + iv,
        ValueRange{indexOp.getResult(), ivs[rangeIndex->second]});
    indexOp.getResult().replaceAllUsesExcept(
        applyOp.getResult(), SmallPtrSet<Operation *, 1>{applyOp});
  }
}

template <typename LoopTy>
static Optional<TiledLinalgOp>
tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
                 const LinalgTilingOptions &options) {
  auto nLoops = op.getNumLoops();
  // The initial tile sizes may contain more entries than there are loops;
  // only take the first nLoops.
  tileSizes = tileSizes.take_front(nLoops);

  if (llvm::all_of(tileSizes, isZero))
    return llvm::None;

  if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
    // For conv ops with padding, only tiling along the batch dimension (the
    // first loop) is supported.
    if (convOp.padding() && !llvm::all_of(tileSizes.drop_front(), isZero))
      return llvm::None;
  }

  // 1. Build the tiled loop ranges.
  auto allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc());
  AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
  if (!shapeSizesToLoopsMap)
    return llvm::None;

  SmallVector<Range, 4> loopRanges;
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
      b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);

  SmallVector<Attribute, 4> iteratorTypes;
  for (auto attr :
       enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
    if (loopIndexToRangeIndex.count(attr.index()))
      iteratorTypes.push_back(attr.value());
  }
  // If interchangeVector is empty, use the identity. Otherwise, build the
  // permutation map.
  auto invPermutationMap =
      AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
  if (!options.interchangeVector.empty()) {
    // Based on the pruned iterations (due to zero tile sizes), recompute the
    // interchange vector.
    SmallVector<unsigned, 4> interchangeVector;
    interchangeVector.reserve(options.interchangeVector.size());
    for (auto pos : options.interchangeVector) {
      auto it = loopIndexToRangeIndex.find(pos);
      if (it == loopIndexToRangeIndex.end())
        continue;
      interchangeVector.push_back(it->second);
    }
    // The interchange vector is guaranteed to be a permutation, so
    // `inversePermutation` must succeed.
    invPermutationMap = inversePermutation(
        AffineMap::getPermutationMap(interchangeVector, b.getContext()));
    assert(invPermutationMap);
    applyPermutationToVector(loopRanges, interchangeVector);
    applyPermutationToVector(iteratorTypes, interchangeVector);
  }

  // 2. Create the tiled loops.
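  // The body builder below clones `op` on tiled views of its operands
  // (subviews for memrefs, subtensors for tensors). For tensor outputs, it
  // also creates subtensor_insert ops so that each loop iteration yields the
  // updated full tensors.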
  LinalgOp res = op;
  SmallVector<Value, 4> ivs, tensorResults;
  GenerateLoopNest<LoopTy>::doit(
      loopRanges, op, iteratorTypes,
      [&](ValueRange localIvs, ValueRange iterArgs) -> scf::ValueVector {
        auto &b = ScopedContext::getBuilderRef();
        auto loc = ScopedContext::getLocation();
        ivs.assign(localIvs.begin(), localIvs.end());

        // When an `interchangeVector` is present, it has been applied to the
        // loop ranges and the iterator types. Apply its inverse to the
        // resulting loop `ivs` to match the op definition.
        SmallVector<Value, 4> interchangedIvs;
        if (!options.interchangeVector.empty())
          interchangedIvs = applyMapToValues(b, loc, invPermutationMap, ivs);
        else
          interchangedIvs.assign(ivs.begin(), ivs.end());

        assert(op.getNumOutputTensors() == iterArgs.size() &&
               "num output tensors must match number of loop iter arguments");

        auto operands = llvm::to_vector<4>(op.getInputs());
        SmallVector<Value, 4> outputBuffers = op.getOutputBuffers();
        // TODO: thanks to a simplifying assumption, we do not need to worry
        // about the order of output buffers and tensors: there is only ever
        // one kind.
        assert(outputBuffers.empty() || iterArgs.empty());
        operands.append(outputBuffers.begin(), outputBuffers.end());
        operands.append(iterArgs.begin(), iterArgs.end());
        auto sizeBounds =
            applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
        SmallVector<Value, 4> tiledOperands = makeTiledShapes(
            b, loc, op, operands, interchangedIvs, tileSizes, sizeBounds);
        auto nonShapedOperands = op.getAssumedNonShapedOperands();
        tiledOperands.append(nonShapedOperands.begin(),
                             nonShapedOperands.end());

        // TODO: use an interface/adaptor to avoid leaking position in
        // `tiledOperands`.
        SmallVector<Type, 4> resultTensorTypes;
        for (OpOperand *opOperand : op.getOutputTensorsOpOperands())
          resultTensorTypes.push_back(
              tiledOperands[opOperand->getOperandNumber()].getType());

        res = op.clone(b, loc, resultTensorTypes, tiledOperands);

        // Insert a subtensor_insert for each output tensor.
        unsigned resultIdx = 0;
        for (OpOperand *opOperand : op.getOutputTensorsOpOperands()) {
          // TODO: use an interface/adaptor to avoid leaking position in
          // `tiledOperands`.
          Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
          if (auto subtensor = outputTensor.getDefiningOp<SubTensorOp>()) {
            tensorResults.push_back(b.create<SubTensorInsertOp>(
                loc, subtensor.source().getType(), res->getResult(resultIdx),
                subtensor.source(), subtensor.offsets(), subtensor.sizes(),
                subtensor.strides(), subtensor.static_offsets(),
                subtensor.static_sizes(), subtensor.static_strides()));
          } else {
            tensorResults.push_back(res->getResult(resultIdx));
          }
          ++resultIdx;
        }
        return scf::ValueVector(tensorResults.begin(), tensorResults.end());
      },
      options.distribution);

  // 3a. Transform the index arguments of `linalg.indexed_generic` w.r.t. the
  // tiling.
  transformIndexedGenericOpIndices(b, res, ivs, loopIndexToRangeIndex);
  // 3b. Transform IndexOp results w.r.t. the tiling.
  transformIndexOps(b, res, ivs, loopIndexToRangeIndex);

  // 4. Gather the newly created loops and return them with the new op.
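  // An induction variable that is still a block argument belongs to the
  // region of the loop op that owns it, so the loop is recovered as the
  // parent op of that block; other IVs map to nullptr (see the TODO below).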
  SmallVector<Operation *, 8> loops;
  loops.reserve(ivs.size());
  for (auto iv : ivs) {
    if (iv.isa<BlockArgument>()) {
      loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
      assert(loops.back() && "no owner found for induction variable!");
    } else {
      // TODO: Instead of doing this, try to recover the ops used instead of
      // the loop.
      loops.push_back(nullptr);
    }
  }

  // 5. Get the tensor results from the outermost loop if available. Otherwise
  // use the previously captured `tensorResults`.
  Operation *outermostLoop = nullptr;
  for (Operation *loop : loops)
    if ((outermostLoop = loop))
      break;

  return TiledLinalgOp{
      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}

template <typename LoopTy>
Optional<TiledLinalgOp> static tileLinalgOpImpl(
    OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);
  ScopedContext scope(b, op.getLoc());

  if (!options.tileSizeComputationFunction)
    return llvm::None;

  // Enforce the convention that "tiling by zero" skips tiling a particular
  // dimension. This convention is significantly simpler to handle than
  // adjusting affine maps to account for missing dimensions.
  auto nLoops = op.getNumLoops();
  SmallVector<Value, 4> tileSizeVector =
      options.tileSizeComputationFunction(b, op);
  if (tileSizeVector.size() < nLoops) {
    auto zero = std_constant_index(0);
    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
  }

  return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
}

Optional<TiledLinalgOp>
mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
                           const LinalgTilingOptions &options) {
  switch (options.loopType) {
  case LinalgTilingLoopType::Loops:
    return tileLinalgOpImpl<scf::ForOp>(b, op, options);
  case LinalgTilingLoopType::ParallelLoops:
    return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
  case LinalgTilingLoopType::TiledLoops:
    return tileLinalgOpImpl<linalg::TiledLoopOp>(b, op, options);
  default:;
  }
  return llvm::None;
}

namespace {
/// Helper classes for type list expansion.
template <typename... OpTypes>
class CanonicalizationPatternList;

template <>
class CanonicalizationPatternList<> {
public:
  static void insert(RewritePatternSet &patterns) {}
};

template <typename OpTy, typename... OpTypes>
class CanonicalizationPatternList<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns) {
    OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
    CanonicalizationPatternList<OpTypes...>::insert(patterns);
  }
};

/// Helper classes for type list expansion.
template <typename... OpTypes>
class RewritePatternList;

template <>
class RewritePatternList<> {
public:
  static void insert(RewritePatternSet &patterns,
                     const LinalgTilingOptions &options) {}
};

template <typename OpTy, typename... OpTypes>
class RewritePatternList<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns,
                     const LinalgTilingOptions &options) {
    auto *ctx = patterns.getContext();
    patterns.add<LinalgTilingPattern<OpTy>>(
        ctx, options,
        LinalgTransformationFilter(ArrayRef<Identifier>{},
                                   Identifier::get("tiled", ctx)));
    RewritePatternList<OpTypes...>::insert(patterns, options);
  }
};
} // namespace

RewritePatternSet
mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
  RewritePatternSet patterns(ctx);
  populateLinalgTilingCanonicalizationPatterns(patterns);
  return patterns;
}

void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
  AffineForOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
  scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
  scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
  ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
  SubTensorOp::getCanonicalizationPatterns(patterns, ctx);
  memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
  memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
  CanonicalizationPatternList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
      >::insert(patterns);
}

/// Populate the given list with patterns that apply Linalg tiling.
static void insertTilingPatterns(RewritePatternSet &patterns,
                                 const LinalgTilingOptions &options) {
  RewritePatternList<GenericOp, IndexedGenericOp,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                     >::insert(patterns, options);
}

static void applyTilingToLoopPatterns(LinalgTilingLoopType loopType,
                                      FuncOp funcOp,
                                      ArrayRef<int64_t> tileSizes) {
  auto options =
      LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType);
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);
  insertTilingPatterns(patterns, options);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
  (void)applyPatternsAndFoldGreedily(
      funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
  // Drop the marker.
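  // Tiled ops were tagged with the "tiled" transformation marker by
  // LinalgTilingPattern so that the greedy driver does not tile them again;
  // the attribute is no longer needed once all patterns have been applied.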
  funcOp.walk([](LinalgOp op) {
    op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
  });
}

namespace {
struct LinalgTilingPass : public LinalgTilingBase<LinalgTilingPass> {
  LinalgTilingPass() = default;
  LinalgTilingPass(ArrayRef<int64_t> sizes) { tileSizes = sizes; }

  void runOnFunction() override {
    applyTilingToLoopPatterns(LinalgTilingLoopType::Loops, getFunction(),
                              tileSizes);
  }
};

struct LinalgTilingToParallelLoopsPass
    : public LinalgTilingToParallelLoopsBase<LinalgTilingToParallelLoopsPass> {
  LinalgTilingToParallelLoopsPass() = default;
  LinalgTilingToParallelLoopsPass(ArrayRef<int64_t> sizes) {
    tileSizes = sizes;
  }

  void runOnFunction() override {
    applyTilingToLoopPatterns(LinalgTilingLoopType::ParallelLoops,
                              getFunction(), tileSizes);
  }
};

struct LinalgTilingToTiledLoopsPass
    : public LinalgTilingToTiledLoopsBase<LinalgTilingToTiledLoopsPass> {
  LinalgTilingToTiledLoopsPass() = default;
  LinalgTilingToTiledLoopsPass(ArrayRef<int64_t> sizes) { tileSizes = sizes; }

  void runOnFunction() override {
    applyTilingToLoopPatterns(LinalgTilingLoopType::TiledLoops, getFunction(),
                              tileSizes);
  }
};

} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
  return std::make_unique<LinalgTilingPass>(tileSizes);
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgTilingToParallelLoopsPass(ArrayRef<int64_t> tileSizes) {
  return std::make_unique<LinalgTilingToParallelLoopsPass>(tileSizes);
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgTilingToTiledLoopPass(ArrayRef<int64_t> tileSizes) {
  return std::make_unique<LinalgTilingToTiledLoopsPass>(tileSizes);
}