//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Tiling pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/CommandLine.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
using namespace mlir::scf;

#define DEBUG_TYPE "linalg-tiling"

static bool isZero(Value v) {
  if (auto cst = v.getDefiningOp<ConstantIndexOp>())
    return cst.getValue() == 0;
  return false;
}

using LoopIndexToRangeIndexMap = DenseMap<int, int>;

// Creates a number of ranges equal to the number of non-zero entries in
// `tileSizes`, one for each loop of the LinalgOp that is tiled. The
// `tileSizes` argument has one entry per surrounding loop. A zero tile size
// means, by convention, that the corresponding loop is not tiled. This
// convention simplifies implementations by avoiding affine map manipulations.
// The returned ranges correspond to the loop ranges, in the proper order, that
// are tiled and for which new loops will be created. The function also returns
// a map from loop indices of the LinalgOp to the corresponding range indices
// of the newly created loops.
static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
                    ValueRange allShapeSizes, ValueRange allTileSizes) {
  assert(allTileSizes.size() == map.getNumResults());
  // Apply `map` to get shape sizes in loop order.
  auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
  SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());

  // Traverse the tile sizes, which are in loop order, and erase zeros
  // everywhere.
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
    if (isZero(tileSizes[idx - zerosCount])) {
      shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
      tileSizes.erase(tileSizes.begin() + idx - zerosCount);
      ++zerosCount;
      continue;
    }
    loopIndexToRangeIndex[idx] = idx - zerosCount;
  }

  // Create a new range with the applied tile sizes.
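  // Illustrative sketch of the convention (hypothetical values, not taken
  // from a test): with loop-order shape sizes [%s0, %s1, %s2] and tile sizes
  // [%c10, %c0, %c25], only loops 0 and 2 are tiled. The ranges created below
  // are {0, %s0, %c10} and {0, %s2, %c25}, and loopIndexToRangeIndex maps
  // {0 -> 0, 2 -> 1}.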
  SmallVector<Range, 4> res;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
    res.push_back(
        Range{std_constant_index(0), shapeSizes[idx], tileSizes[idx]});
  return std::make_tuple(res, loopIndexToRangeIndex);
}

// IndexedGenericOp explicitly uses induction variables in the loop body. The
// values of the indices that are used in the loop body for any given access of
// an input/output memref (before the `subview` op was applied) should be
// invariant with respect to tiling.
//
// Therefore, if the operation is tiled, we have to transform the indices
// accordingly, i.e. offset them by the values of the corresponding induction
// variables that are captured implicitly in the body of the op.
//
// Example. `linalg.indexed_generic` before tiling:
//
//   #id_2d = (i, j) -> (i, j)
//   #pointwise_2d_trait = {
//     indexing_maps = [#id_2d, #id_2d],
//     iterator_types = ["parallel", "parallel"],
//     n_views = [1, 1]
//   }
//   linalg.indexed_generic #pointwise_2d_trait %operand, %result {
//     ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
//       <some operations that use %i, %j>
//   }: memref<50x100xf32>, memref<50x100xf32>
//
// After the tiling pass with tile sizes 10 and 25:
//
//   #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
//
//   %c1 = constant 1 : index
//   %c0 = constant 0 : index
//   %c25 = constant 25 : index
//   %c10 = constant 10 : index
//   operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
//   operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
//   scf.for %k = %c0 to operand_dim_0 step %c10 {
//     scf.for %l = %c0 to operand_dim_1 step %c25 {
//       %4 = memref.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
//         : memref<50x100xf32> to memref<?x?xf32, #strided>
//       %5 = memref.subview %result[%k, %l][%c10, %c25][%c1, %c1]
//         : memref<50x100xf32> to memref<?x?xf32, #strided>
//       linalg.indexed_generic pointwise_2d_trait %4, %5 {
//       ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
//         // Indices `k` and `l` are implicitly captured in the body.
//         %transformed_i = addi %i, %k : index // index `i` is offset by %k
//         %transformed_j = addi %j, %l : index // index `j` is offset by %l
//         // Every use of %i, %j is replaced with %transformed_i,
//         // %transformed_j.
//         <some operations that use %transformed_i, %transformed_j>
//       }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
//     }
//   }
//
// TODO: Investigate whether mixing implicit and explicit indices
// does not lead to losing information.
static void transformIndexedGenericOpIndices(
    OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
    const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
  auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op.getOperation());
  if (!indexedGenericOp)
    return;

  // `linalg.indexed_generic` comes in two flavours. One has a region with a
  // single block that defines the loop body. The other has a `fun` attribute
  // that refers to an existing function symbol. In that case, a call to the
  // `fun` function is inserted into the loop body.
  //
  // TODO: Add support for `linalg.indexed_generic` with `fun` attribute.
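  // Only the region-based flavour is handled below: for every loop that is
  // actually tiled (i.e. present in `loopIndexToRangeIndex`), an `addi` of
  // the original index block argument and the corresponding induction
  // variable is created at the top of the body, and all other uses of the
  // original index are redirected to that sum.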
  auto &region = indexedGenericOp.region();
  if (region.empty()) {
    indexedGenericOp.emitOpError("expected a region");
    return;
  }
  auto &block = region.front();

  OpBuilder::InsertionGuard g(b);
  b.setInsertionPointToStart(&block);
  for (unsigned i = 0; i < indexedGenericOp.getNumLoops(); ++i) {
    auto rangeIndex = loopIndexToRangeIndex.find(i);
    if (rangeIndex == loopIndexToRangeIndex.end())
      continue;
    Value oldIndex = block.getArgument(i);
    // Offset the index argument `i` by the value of the corresponding
    // induction variable and replace all uses of the previous value.
    Value newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
                                      ivs[rangeIndex->second]);
    for (auto &use : oldIndex.getUses()) {
      if (use.getOwner() == newIndex.getDefiningOp())
        continue;
      use.set(newIndex);
    }
  }
}

template <typename LoopTy>
static Optional<TiledLinalgOp>
tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
                 const LinalgTilingOptions &options) {
  auto nLoops = op.getNumLoops();
  // Initial tile sizes may be too big, only take the first nLoops.
  tileSizes = tileSizes.take_front(nLoops);

  if (llvm::all_of(tileSizes, isZero))
    return llvm::None;

  if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
    // For conv ops, only support tiling along the batch dimension (which is
    // the first loop).
    if (convOp.padding() && !llvm::all_of(tileSizes.drop_front(), isZero))
      return llvm::None;
  }

  // 1. Build the tiled loop ranges.
  auto allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc());
  AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
  if (!shapeSizesToLoopsMap)
    return llvm::None;

  SmallVector<Range, 4> loopRanges;
  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
      b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);

  SmallVector<Attribute, 4> iteratorTypes;
  for (auto attr :
       enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
    if (loopIndexToRangeIndex.count(attr.index()))
      iteratorTypes.push_back(attr.value());
  }
  // If interchangeVector is empty, use the identity. Otherwise, build the
  // permutation map.
  auto invPermutationMap =
      AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
  if (!options.interchangeVector.empty()) {
    // Based on the pruned iterations (due to zero tile sizes), recompute the
    // interchange vector.
    SmallVector<unsigned, 4> interchangeVector;
    interchangeVector.reserve(options.interchangeVector.size());
    for (auto pos : options.interchangeVector) {
      auto it = loopIndexToRangeIndex.find(pos);
      if (it == loopIndexToRangeIndex.end())
        continue;
      interchangeVector.push_back(it->second);
    }
    // The interchange vector is guaranteed to be a permutation, so
    // `inversePermutation` must succeed.
    invPermutationMap = inversePermutation(
        AffineMap::getPermutationMap(interchangeVector, b.getContext()));
    assert(invPermutationMap);
    applyPermutationToVector(loopRanges, interchangeVector);
    applyPermutationToVector(iteratorTypes, interchangeVector);
  }

  // 2. Create the tiled loops.
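  // `GenerateLoopNest` builds a loop nest of the requested `LoopTy` over
  // `loopRanges` and invokes the callback below with the induction variables
  // and, when tiling on tensors, the loop-carried output tensors as
  // `iterArgs`. The callback builds the tiled operands, clones the op on
  // them, and yields the tensor results.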
  LinalgOp res = op;
  SmallVector<Value, 4> ivs, tensorResults;
  auto outputTensors = op.getOutputTensors();
  GenerateLoopNest<LoopTy>::doit(
      loopRanges, /*iterArgInitValues*/ outputTensors, iteratorTypes,
      [&](ValueRange localIvs, ValueRange iterArgs) -> scf::ValueVector {
        auto &b = ScopedContext::getBuilderRef();
        auto loc = ScopedContext::getLocation();
        ivs.assign(localIvs.begin(), localIvs.end());

        // When an `interchangeVector` is present, it has been applied to the
        // loop ranges and the iterator types. Apply its inverse to the
        // resulting loop `ivs` to match the op definition.
        SmallVector<Value, 4> interchangedIvs;
        if (!options.interchangeVector.empty())
          interchangedIvs = applyMapToValues(b, loc, invPermutationMap, ivs);
        else
          interchangedIvs.assign(ivs.begin(), ivs.end());

        assert(op.getNumOutputTensors() == iterArgs.size() &&
               "num output tensors must match number of loop iter arguments");

        auto operands = llvm::to_vector<4>(op.getInputs());
        SmallVector<Value, 4> outputBuffers = op.getOutputBuffers();
        // TODO: thanks to a simplifying assumption we do not need to worry
        // about the order of output buffers and tensors: there is only ever
        // one kind.
        assert(outputBuffers.empty() || iterArgs.empty());
        operands.append(outputBuffers.begin(), outputBuffers.end());
        operands.append(iterArgs.begin(), iterArgs.end());
        auto sizeBounds =
            applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
        SmallVector<Value, 4> tiledOperands = makeTiledShapes(
            b, loc, op, operands, interchangedIvs, tileSizes, sizeBounds);
        auto nonShapedOperands = op.getAssumedNonShapedOperands();
        tiledOperands.append(nonShapedOperands.begin(),
                             nonShapedOperands.end());

        // TODO: use an interface/adaptor to avoid leaking position in
        // `tiledOperands`.
        SmallVector<Type, 4> resultTensorTypes;
        for (OpOperand *opOperand : op.getOutputTensorsOpOperands())
          resultTensorTypes.push_back(
              tiledOperands[opOperand->getOperandNumber()].getType());

        res = op.clone(b, loc, resultTensorTypes, tiledOperands);

        // Insert a subtensor_insert for each output tensor.
        unsigned resultIdx = 0;
        for (OpOperand *opOperand : op.getOutputTensorsOpOperands()) {
          // TODO: use an interface/adaptor to avoid leaking position in
          // `tiledOperands`.
          Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
          if (auto subtensor = outputTensor.getDefiningOp<SubTensorOp>()) {
            tensorResults.push_back(b.create<SubTensorInsertOp>(
                loc, subtensor.source().getType(), res->getResult(resultIdx),
                subtensor.source(), subtensor.offsets(), subtensor.sizes(),
                subtensor.strides(), subtensor.static_offsets(),
                subtensor.static_sizes(), subtensor.static_strides()));
          } else {
            tensorResults.push_back(res->getResult(resultIdx));
          }
          ++resultIdx;
        }
        return scf::ValueVector(tensorResults.begin(), tensorResults.end());
      },
      options.distribution);

  // 3. Transform the index arguments of `linalg.indexed_generic` w.r.t. the
  // tiling.
  transformIndexedGenericOpIndices(b, res, ivs, loopIndexToRangeIndex);

  // 4. Gather the newly created loops and return them with the new op.
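  // Induction variables produced by the loop builders are block arguments of
  // the generated loop bodies, so the parent op of their owning block is the
  // loop itself. Induction variables materialized as other values (which can
  // happen, e.g., under a distribution scheme) have no single loop to report,
  // so a nullptr placeholder is recorded instead.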
  SmallVector<Operation *, 8> loops;
  loops.reserve(ivs.size());
  for (auto iv : ivs) {
    if (iv.isa<BlockArgument>()) {
      loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
      assert(loops.back() && "no owner found for induction variable!");
    } else {
      // TODO: Instead of doing this, try to recover the ops used instead of
      // the loop.
      loops.push_back(nullptr);
    }
  }

  // 5. Get the tensor results from the outermost loop if available. Otherwise
  // use the previously captured `tensorResults`.
  Operation *outermostLoop = nullptr;
  for (Operation *loop : loops)
    if ((outermostLoop = loop))
      break;

  return TiledLinalgOp{
      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}

template <typename LoopTy>
Optional<TiledLinalgOp> static tileLinalgOpImpl(
    OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);
  ScopedContext scope(b, op.getLoc());

  if (!options.tileSizeComputationFunction)
    return llvm::None;

  // Enforce the convention that "tiling by zero" skips tiling a particular
  // dimension. This convention is significantly simpler to handle than
  // adjusting affine maps to account for missing dimensions.
  auto nLoops = op.getNumLoops();
  SmallVector<Value, 4> tileSizeVector =
      options.tileSizeComputationFunction(b, op);
  if (tileSizeVector.size() < nLoops) {
    auto zero = std_constant_index(0);
    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
  }

  return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
}

Optional<TiledLinalgOp>
mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
                           const LinalgTilingOptions &options) {
  switch (options.loopType) {
  case LinalgTilingLoopType::Loops:
    return tileLinalgOpImpl<scf::ForOp>(b, op, options);
  case LinalgTilingLoopType::ParallelLoops:
    return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
  default:;
  }
  return llvm::None;
}

namespace {
/// Helper classes for type list expansion.
template <typename... OpTypes>
class CanonicalizationPatternList;

template <>
class CanonicalizationPatternList<> {
public:
  static void insert(RewritePatternSet &patterns) {}
};

template <typename OpTy, typename... OpTypes>
class CanonicalizationPatternList<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns) {
    OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
    CanonicalizationPatternList<OpTypes...>::insert(patterns);
  }
};

/// Helper classes for type list expansion.
template <typename... OpTypes>
class RewritePatternList;

template <>
class RewritePatternList<> {
public:
  static void insert(RewritePatternSet &patterns,
                     const LinalgTilingOptions &options) {}
};

template <typename OpTy, typename... OpTypes>
class RewritePatternList<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns,
                     const LinalgTilingOptions &options) {
    auto *ctx = patterns.getContext();
    patterns.add<LinalgTilingPattern<OpTy>>(
        ctx, options,
        LinalgTransformationFilter(ArrayRef<Identifier>{},
                                   Identifier::get("tiled", ctx)));
    RewritePatternList<OpTypes...>::insert(patterns, options);
  }
};
} // namespace

RewritePatternSet
mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
  RewritePatternSet patterns(ctx);
  populateLinalgTilingCanonicalizationPatterns(patterns);
  return patterns;
}

void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
  AffineForOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
  AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
  scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
  scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
  ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
  SubTensorOp::getCanonicalizationPatterns(patterns, ctx);
  memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
  memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
  CanonicalizationPatternList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
      >::insert(patterns);
}

/// Populate the given list with patterns that apply Linalg tiling.
static void insertTilingPatterns(RewritePatternSet &patterns,
                                 const LinalgTilingOptions &options) {
  RewritePatternList<GenericOp, IndexedGenericOp,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                     >::insert(patterns, options);
}

static void applyTilingToLoopPatterns(LinalgTilingLoopType loopType,
                                      FuncOp funcOp,
                                      ArrayRef<int64_t> tileSizes) {
  auto options =
      LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType);
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);
  insertTilingPatterns(patterns, options);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
  (void)applyPatternsAndFoldGreedily(
      funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
  // Drop the marker.
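  // The tiling patterns registered above tag every op they tile with the
  // "tiled" transformation marker (stored in the kLinalgTransformMarker
  // attribute) so the greedy driver does not tile the same op again. The
  // marker is only needed while the patterns run, so it is erased here.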
  funcOp.walk([](LinalgOp op) {
    op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
  });
}

namespace {
struct LinalgTilingPass : public LinalgTilingBase<LinalgTilingPass> {
  LinalgTilingPass() = default;
  LinalgTilingPass(ArrayRef<int64_t> sizes) { tileSizes = sizes; }

  void runOnFunction() override {
    applyTilingToLoopPatterns(LinalgTilingLoopType::Loops, getFunction(),
                              tileSizes);
  }
};

struct LinalgTilingToParallelLoopsPass
    : public LinalgTilingToParallelLoopsBase<LinalgTilingToParallelLoopsPass> {
  LinalgTilingToParallelLoopsPass() = default;
  LinalgTilingToParallelLoopsPass(ArrayRef<int64_t> sizes) {
    tileSizes = sizes;
  }

  void runOnFunction() override {
    applyTilingToLoopPatterns(LinalgTilingLoopType::ParallelLoops,
                              getFunction(), tileSizes);
  }
};

} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
  return std::make_unique<LinalgTilingPass>(tileSizes);
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgTilingToParallelLoopsPass(ArrayRef<int64_t> tileSizes) {
  return std::make_unique<LinalgTilingToParallelLoopsPass>(tileSizes);
}
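// A minimal usage sketch for the factory functions above (hypothetical driver
// code, not part of this pass; assumes a `module` whose functions contain
// Linalg ops):
//
//   PassManager pm(module.getContext());
//   pm.addNestedPass<FuncOp>(mlir::createLinalgTilingPass({32, 32}));
//   (void)pm.run(module);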