1 //===- FusionOnTensors.cpp - Implementation of linalg Fusion --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements linalg fusion on tensors 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Analysis/SliceAnalysis.h" 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Linalg/IR/Linalg.h" 17 #include "mlir/Dialect/Linalg/Passes.h" 18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 19 #include "mlir/Dialect/Linalg/Utils/Utils.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/Dialect/Utils/IndexingUtils.h" 22 #include "mlir/IR/AffineExpr.h" 23 #include "mlir/IR/AffineMap.h" 24 #include "mlir/Support/LLVM.h" 25 26 using namespace mlir; 27 using namespace linalg; 28 29 //===----------------------------------------------------------------------===// 30 // StructuredOp specific helpers. 31 //===----------------------------------------------------------------------===// 32 33 /// Returns the tiled slice dimensions given the tiled consumer loop dimensions. 34 /// The slice defines a hyper rectangular iteration space and fusing the 35 /// producer is always possible. However, depending on the consumer indexing 36 /// map, not all slice elements may be consumed and the tiles may overlap. In 37 /// these cases, fusion introduces redundant computation. 38 static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand, 39 ArrayRef<int64_t> tiledLoopDims) { 40 // Get the consumer operand indexing map. 41 LinalgOp consumerOp = consumerOperand->getOwner(); 42 AffineMap indexingMap = consumerOp.getTiedIndexingMap(consumerOperand); 43 44 // Search the slice dimensions tiled by a tile loop dimension. 45 DenseSet<int64_t> tiledSliceDimIndices; 46 for (const auto &en : enumerate(indexingMap.getResults())) { 47 for (auto tiledLoopDim : tiledLoopDims) { 48 if (en.value().isFunctionOfDim(tiledLoopDim)) 49 tiledSliceDimIndices.insert(en.index()); 50 } 51 } 52 return {tiledSliceDimIndices.begin(), tiledSliceDimIndices.end()}; 53 } 54 55 /// Given a vector of `tiledSliceDimIndices` that represent the tiled dimensions 56 /// of the producer result slice returns the tiled producer loop dimensions. 57 /// Example: 58 /// ``` 59 /// %res = linalg.fill(%cst, %input) 60 /// scf.for %i 61 /// scf.for %j 62 /// %slice = tensor.extract_slice %res[%i, %j] 63 /// ``` 64 /// getTiledProducerLoops(%res, [0, 1]) returns the loop indices [0, 1]. 65 static SmallVector<int64_t> 66 getTiledProducerLoops(OpResult producerResult, 67 ArrayRef<int64_t> tiledSliceDimIndices) { 68 LinalgOp producerOp = producerResult.getOwner(); 69 70 // Get the indexing map of the `producerOp` output operand that matches 71 // ´producerResult´. 72 AffineMap producerIndexingMap = producerOp.getTiedIndexingMap( 73 producerOp.getOutputOperand(producerResult.getResultNumber())); 74 75 // Keep only the tiled result slice dimensions of `producerIndexingMap`. 76 AffineMap tiledProducerIndexingSubMap = 77 producerIndexingMap.getSubMap(SmallVector<unsigned>( 78 tiledSliceDimIndices.begin(), tiledSliceDimIndices.end())); 79 80 // Compute the producer loop indices mapped to the tiled result slice 81 // dimensions. As the output indexing map of structured operations are 82 // projected permutations, `tiledProducerIndexingSubMap` has to be a 83 // projected permutation as well. We can thus obtain the producer loop indices 84 // by getting the positions of the result dimensions. 85 // Example: 86 // (d0, d1, d2) -> (d0, d2) has the result positions [0, 2]. 87 assert(tiledProducerIndexingSubMap.isProjectedPermutation() && 88 "expect slice and producer loop dimensions map one-to-one"); 89 SmallVector<int64_t> tiledProducerLoopIndices; 90 llvm::transform( 91 llvm::seq<unsigned>(0, tiledProducerIndexingSubMap.getNumResults()), 92 std::back_inserter(tiledProducerLoopIndices), [&](unsigned idx) { 93 return tiledProducerIndexingSubMap.getDimPosition(idx); 94 }); 95 96 return tiledProducerLoopIndices; 97 } 98 99 /// Returns the producer fused in place of `sliceOp`. Tile the producer operands 100 /// along the `tiledSliceDimIndices` and clone the producer. Consider the case 101 /// of fusion of an output tensor: 102 /// ``` 103 /// %1 = producer ins(...) outs(%0) 104 /// %2 = consumer ins(...) outs(%1) 105 /// ``` 106 /// When consumer is tiled, %1 appears in the loop iter_args: 107 /// ``` 108 /// %1 = producer ins(...) outs(%0) 109 /// %2 = scf.for ... iter_args(%1) .. (%bbarg) { 110 /// %t1 = tensor.extract_slice %bbarg[..] 111 /// %t2 = consumer ins(...) outs(%t1) 112 /// %r = tensor.insert_slice %t2, %bbarg[...] 113 /// } 114 /// ``` 115 /// Fusing %1 into the loop requires updating iter_args(%1) to iter_args(%0): 116 /// ``` 117 /// %2 = scf.for ... iter_args(%0) .. (%bbarg) { 118 /// %t0 = tensor.extract_slice %bbarg[..] 119 /// %t1 = producer ins(...) outs(%t0) 120 /// %t2 = consumer ins(...) outs(%t1) 121 /// %r = tensor.insert_slice %t2, %bbarg[...] 122 /// } 123 /// ``` 124 /// This transformation is only valid if %bbarg is exclusively used by the 125 /// output ExtractSliceOp / InsertSliceOp pair, which is checked by the 126 /// `fuseProducer` method. 127 /// TODO: instead of check and failure, insert new iter_args each time a 128 /// producer is fused into a consumer and fold away unused iter_args. 129 static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult, 130 tensor::ExtractSliceOp sliceOp, 131 ArrayRef<int64_t> tiledSliceDimIndices, 132 ArrayRef<int64_t> tiledProducerLoopIndices, 133 OpOperand *iterArg) { 134 // Clone the producer after `sliceOp` since the slice may be reused to pass in 135 // the producer result. 136 OpBuilder::InsertionGuard guard(b); 137 b.setInsertionPointAfter(sliceOp); 138 139 // Get the producer. 140 LinalgOp producerOp = producerResult.getOwner(); 141 Location loc = producerOp.getLoc(); 142 143 // Obtain the `producerOp` loop bounds and the `sliceOp` ranges. 144 SmallVector<Value> producerLoopBounds; 145 llvm::transform(producerOp.createLoopRanges(b, loc), 146 std::back_inserter(producerLoopBounds), 147 [](Range range) { return range.size; }); 148 SmallVector<Range> sliceOpRanges = sliceOp.getOrCreateRanges(b, loc); 149 150 // Tile the producer operands given the `sliceOp` ranges. Iterate the 151 // `tiledSliceDimIndices` and store the tile offset and size for the tiled 152 // slice dimension. 153 auto zero = b.create<arith::ConstantIndexOp>(loc, 0); 154 SmallVector<Value> tileIvs(producerOp.getNumLoops(), nullptr); 155 SmallVector<Value> tileSizes(producerOp.getNumLoops(), zero); 156 SmallVector<Value> allIvs(producerOp.getNumLoops(), nullptr); 157 for (auto it : zip(tiledSliceDimIndices, tiledProducerLoopIndices)) { 158 int64_t tiledSliceDim = std::get<0>(it); 159 int64_t tiledProducerLoop = std::get<1>(it); 160 tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset; 161 tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size; 162 allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop]; 163 } 164 erase_value(tileIvs, nullptr); 165 SmallVector<Value> tiledOperands = producerOp.getInputAndOutputOperands(); 166 tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs, 167 tileSizes, producerLoopBounds, 168 /**omitPartialTileCheck=*/false); 169 170 // Output fusion has to update the iteration arguments of the tile loop nest. 171 // In particular, the iteration argument of the outermost tile loop needs to 172 // be set to the producer output instead of the producer result and `clonedOp` 173 // shall use the existing `sliceOp` result instead of the tiled producer 174 // output operand. 175 if (iterArg) { 176 OpOperand *outputOperand = 177 producerOp.getOutputOperand(producerResult.getResultNumber()); 178 iterArg->set(outputOperand->get()); 179 tiledOperands[outputOperand->getOperandNumber()] = sliceOp.getResult(); 180 } 181 182 // Clone the producer using the tiled producer operands. 183 TypeRange resultTypes = ValueRange(tiledOperands) 184 .take_back(producerOp.getNumOutputs()) 185 .getTypes(); 186 LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands); 187 188 // Shift all IndexOp results by the tile offset. 189 offsetIndices(b, clonedOp, allIvs); 190 191 return clonedOp; 192 } 193 194 //===----------------------------------------------------------------------===// 195 // TileLoopNest specific helpers. 196 //===----------------------------------------------------------------------===// 197 198 bool TileLoopNest::isEmpty() { return tileLoopOps.empty(); } 199 200 bool TileLoopNest::isValid() { 201 // Check if `rootOp` has been tiled at least once. 202 if (isEmpty() || tiledRootAndFusedOpsLoops.count(rootOp) == 0) 203 return false; 204 205 // Check if the number of loop operations and dimensions match. 206 if (tileLoopOps.size() != tiledRootAndFusedOpsLoops[rootOp].size()) 207 return false; 208 209 // Check if the innermost tile loop is the parent of `tiledOp`. 210 if (rootOp->getParentOp() != tileLoopOps.back()) 211 return false; 212 213 // Check if the tile loops are directly nested. 214 return std::adjacent_find(tileLoopOps.begin(), tileLoopOps.end(), 215 [](Operation *op1, Operation *op2) { 216 return op1 != op2->getParentOp(); 217 }) == tileLoopOps.end(); 218 } 219 220 SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) { 221 assert(bbArg && "expect the block argument to be non-zero"); 222 SmallVector<BlockArgument> bbArgs; 223 224 // Search all tile loop block arguments from inner to outer. 225 for (auto tileLoop : reverse(tileLoopOps)) { 226 if (bbArg.getOwner()->getParentOp() != tileLoop) 227 return {}; 228 bbArgs.push_back(bbArg); 229 OpOperand *iterArg = &tileLoop.getOpOperandForRegionIterArg(bbArg); 230 bbArg = iterArg->get().dyn_cast<BlockArgument>(); 231 } 232 233 // Reverse the block arguments to order them from outer to inner. 234 return {bbArgs.rbegin(), bbArgs.rend()}; 235 } 236 237 OpOperand *TileLoopNest::getTiedIterArg(BlockArgument bbArg) { 238 // Search all block arguments and return the matching iteration argument. 239 SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg); 240 if (bbArgs.size() != tileLoopOps.size()) 241 return nullptr; 242 return &tileLoopOps.front().getOpOperandForRegionIterArg(bbArgs.front()); 243 } 244 245 bool TileLoopNest::hasOtherUses(BlockArgument bbArg, 246 tensor::ExtractSliceOp sliceOp) { 247 // Check the innermost block argument is either used by the ExtractSliceOp 248 // `sliceOp`, the matching InsertSliceOp, or by a DimOp. Handle other uses 249 // conservatively. 250 for (Operation *op : bbArg.getUsers()) { 251 if (!isa<tensor::DimOp, tensor::InsertSliceOp, tensor::ExtractSliceOp>(op)) 252 return false; 253 if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) { 254 if (extractSliceOp != sliceOp) 255 return false; 256 } 257 if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(op)) { 258 SetVector<Operation *> backwardSlice; 259 getBackwardSlice(insertSliceOp.getSource(), &backwardSlice, 260 [](Operation *op) { 261 return isa<LinalgOp, tensor::InsertSliceOp>(op); 262 }); 263 if (backwardSlice.empty() || backwardSlice.front() != sliceOp) 264 return false; 265 } 266 } 267 268 // Check the block arguments, except for the innermost one, have one use. 269 SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg); 270 return !all_of(bbArgs, [&](BlockArgument bbArg) { 271 return bbArg.hasOneUse() || bbArg == bbArgs.back(); 272 }); 273 } 274 275 LogicalResult TileLoopNest::tileRootOp( 276 OpBuilder &b, ArrayRef<int64_t> tileSizes, 277 ArrayRef<int64_t> tileInterchange, 278 Optional<LinalgLoopDistributionOptions> tileDistribution) { 279 // Exit if all tile sizes are zero. 280 if (tileSizes.size() == static_cast<size_t>(count(tileSizes, 0))) 281 return success(); 282 283 // Tile the root operation. 284 LinalgTilingOptions tilingOptions; 285 tilingOptions = tilingOptions 286 .setInterchange(SmallVector<unsigned>( 287 tileInterchange.begin(), tileInterchange.end())) 288 .setTileSizes(tileSizes) 289 .setLoopType(LinalgTilingLoopType::Loops); 290 if (tileDistribution) 291 tilingOptions = tilingOptions.setDistributionOptions(*tileDistribution); 292 293 // TODO: Propagate RewriterBase everywhere. 294 IRRewriter rewriter(b); 295 FailureOr<TiledLinalgOp> tiledRootOp = 296 tileLinalgOp(rewriter, rootOp, tilingOptions); 297 298 // Exit if tiling the root operation fails. 299 if (failed(tiledRootOp)) 300 return failure(); 301 302 // Replace all uses of the root operation if it has been tiled before. All 303 // uses of the original untiled root operation are updated by the calling pass 304 // or pattern. 305 if (!isEmpty()) 306 rootOp->replaceAllUsesWith(tiledRootOp->tensorResults); 307 308 // Transfer the stored `rootOp` loop dimensions if it has been tiled before. 309 if (tiledRootAndFusedOpsLoops.count(rootOp) != 0) { 310 tiledRootAndFusedOpsLoops[tiledRootOp->op] = 311 tiledRootAndFusedOpsLoops[rootOp]; 312 } 313 314 // Update the root operation and append the loops and tile loop dimensions. 315 rootOp = tiledRootOp->op; 316 tileLoopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end()); 317 for (const auto &en : enumerate(tileSizes)) { 318 // Copy only the tiled loop dimensions with non-zero tile size. 319 if (en.value() == 0) 320 continue; 321 tiledRootAndFusedOpsLoops[rootOp].push_back(tileInterchange[en.index()]); 322 } 323 assert(isValid() && "expect tile loop nest to be valid after tiling"); 324 return success(); 325 } 326 327 FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b, 328 OpOperand *consumerOpOperand) { 329 // Check if the consumer has been tiled before. For example, it may not have 330 // been tiled if the outermost tile loop is a reduction loop. 331 if (tiledRootAndFusedOpsLoops.count(consumerOpOperand->getOwner()) == 0) 332 return failure(); 333 334 assert(this->isValid() && 335 "expect the tile loop nest to satisfy all invariants"); 336 337 // Check the tile loop nest is non-empty. 338 if (isEmpty()) 339 return failure(); 340 341 // Check `consumerOpOperand` is defined by an ExtractSliceOp. 342 auto sliceOp = 343 consumerOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>(); 344 if (!sliceOp) 345 return failure(); 346 347 // Check `sliceOp` and `consumerOp` are in the same block. 348 LinalgOp consumerOp = consumerOpOperand->getOwner(); 349 if (sliceOp->getBlock() != rootOp->getBlock() || 350 consumerOp->getBlock() != rootOp->getBlock()) 351 return failure(); 352 353 // Check `consumerOpOperand` is not shape-only to avoid fusion if the data is 354 // not used by the `consumerOp` computation. 355 BlockArgument bbArg = consumerOp.getTiedBlockArgument(consumerOpOperand); 356 if (bbArg.getUses().empty()) 357 return failure(); 358 359 // Check if the producer is a LinalgOp possibly passed by iteration argument. 360 OpOperand *iterArg = nullptr; 361 auto producerResult = sliceOp.getSource().dyn_cast<OpResult>(); 362 if (auto bbArg = sliceOp.getSource().dyn_cast<BlockArgument>()) { 363 iterArg = getTiedIterArg(bbArg); 364 // Check the iteration argument may be used to pass in the producer output. 365 if (!iterArg || hasOtherUses(bbArg, sliceOp)) 366 return failure(); 367 producerResult = iterArg->get().dyn_cast<OpResult>(); 368 } 369 if (!producerResult || !isa<LinalgOp>(producerResult.getOwner())) 370 return failure(); 371 372 // Compute the tiled producer slice dimensions given the tiled consumer loops. 373 SmallVector<int64_t> tiledSliceDimIndices = getTiledSliceDims( 374 consumerOpOperand, tiledRootAndFusedOpsLoops[consumerOp]); 375 if (tiledSliceDimIndices.empty()) 376 return failure(); 377 378 // Compute the tiled producer loop indices. 379 SmallVector<int64_t> tiledProducerLoopIndices = 380 getTiledProducerLoops(producerResult, tiledSliceDimIndices); 381 382 // Tile the producer operands and clone the producer in place of `sliceOp`. 383 LinalgOp clonedOp = 384 getTiledProducer(b, producerResult, sliceOp, tiledSliceDimIndices, 385 tiledProducerLoopIndices, iterArg); 386 tiledRootAndFusedOpsLoops[clonedOp] = tiledProducerLoopIndices; 387 388 // Cast the `clonedOp` result to gap type mismatches before canonicalization. 389 Type consumerOperandType = consumerOpOperand->get().getType(); 390 Value newResult = clonedOp->getResult(producerResult.getResultNumber()); 391 if (newResult.getType() != consumerOperandType) { 392 OpBuilder::InsertionGuard guard(b); 393 b.setInsertionPointAfter(clonedOp); 394 newResult = b.create<tensor::CastOp>(producerResult.getLoc(), 395 consumerOperandType, newResult); 396 } 397 398 // Replace the `sliceOp` uses except for the `clonedOp` output uses. 399 sliceOp.getResult().replaceAllUsesExcept(newResult, clonedOp); 400 return clonedOp; 401 } 402 403 ValueRange TileLoopNest::getRootOpReplacementResults() { 404 assert(!isEmpty() && "expect tile loop nest to be non-empty"); 405 return tileLoopOps.front()->getOpResults(); 406 } 407 408 SmallVector<LinalgOp> TileLoopNest::getAllTiledAndFusedOps() { 409 SmallVector<LinalgOp> result; 410 for (const auto &kvp : tiledRootAndFusedOpsLoops) { 411 auto linalgOp = dyn_cast<LinalgOp>(kvp.getFirst()); 412 assert(linalgOp && 413 "expect all tiled and fused operations are linalg operations"); 414 result.push_back(linalgOp); 415 } 416 return result; 417 } 418 419 //===----------------------------------------------------------------------===// 420 // Tile and fuse entry-points. 421 //===----------------------------------------------------------------------===// 422 423 FailureOr<TileLoopNest> mlir::linalg::tileConsumerAndFuseProducers( 424 OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes, 425 ArrayRef<int64_t> tileInterchange, 426 const Optional<LinalgLoopDistributionOptions> &tileDistribution) { 427 assert(tileSizes.size() == tileInterchange.size() && 428 "expect the number of tile sizes and interchange dims to match"); 429 assert(isPermutation(tileInterchange) && 430 "expect tile interchange is a permutation"); 431 432 // Create an empty tile loop nest. 433 TileLoopNest tileLoopNest(consumerOp); 434 435 // Search the number of outer parallel loops to separate them from possible 436 // inner reduction dimensions. 437 SmallVector<StringAttr> iterTypes = 438 llvm::to_vector<6>(consumerOp.iterator_types().getAsRange<StringAttr>()); 439 applyPermutationToVector(iterTypes, tileInterchange); 440 auto *it = find_if(iterTypes, [&](StringAttr iterType) { 441 return !isParallelIterator(iterType); 442 }); 443 int64_t split = std::distance(iterTypes.begin(), it); 444 445 // Helper to fuse the producers greedily using a queue of fusion candidates. 446 auto fuseProducersGreedily = [&](ArrayRef<OpOperand *> operands) { 447 SmallVector<OpOperand *> candidates(operands.begin(), operands.end()); 448 while (!candidates.empty()) { 449 FailureOr<LinalgOp> fusedProducer = 450 tileLoopNest.fuseProducer(b, candidates.pop_back_val()); 451 if (failed(fusedProducer)) 452 continue; 453 candidates.append(fusedProducer->getInputAndOutputOperands()); 454 } 455 }; 456 457 // Perform tiling and fusion in two steps. We need to respect the loop 458 // interchange here; filter parellel dimensions based on their order *after* 459 // permutation but pass in the original configuration *before* permuation, 460 // given the tiling and interchange happen together. 461 SmallVector<int64_t> outerTileSizes(tileSizes.size(), 0); 462 SmallVector<int64_t> innerTileSizes(tileSizes.size(), 0); 463 for (int64_t i : tileInterchange.take_front(split)) 464 outerTileSizes[i] = tileSizes[i]; 465 for (int64_t i : tileInterchange.drop_front(split)) 466 innerTileSizes[i] = tileSizes[i]; 467 468 // Tile the outer parallel loops and fuse the output operands. 469 if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange, 470 tileDistribution))) 471 return failure(); 472 fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands()); 473 474 // Tile the remaining loops and fuse the input operands. 475 if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange, 476 tileDistribution))) 477 return failure(); 478 fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands()); 479 480 // Exit if the tile loop nest is empty since all tile sizes are zero. 481 if (tileLoopNest.isEmpty()) 482 return failure(); 483 484 return tileLoopNest; 485 } 486