1 //===- FusionOnTensors.cpp - Implementation of linalg Fusion --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements linalg fusion on tensors 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Analysis/SliceAnalysis.h" 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Linalg/IR/Linalg.h" 17 #include "mlir/Dialect/Linalg/Passes.h" 18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 19 #include "mlir/Dialect/Linalg/Utils/Utils.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/Dialect/Utils/IndexingUtils.h" 22 #include "mlir/IR/AffineExpr.h" 23 #include "mlir/IR/AffineMap.h" 24 #include "mlir/Support/LLVM.h" 25 26 using namespace mlir; 27 using namespace linalg; 28 29 //===----------------------------------------------------------------------===// 30 // StructuredOp specific helpers. 31 //===----------------------------------------------------------------------===// 32 33 /// Returns the tiled slice dimensions given the tiled consumer loop dimensions. 34 /// The slice defines a hyper rectangular iteration space and fusing the 35 /// producer is always possible. However, depending on the consumer indexing 36 /// map, not all slice elements may be consumed and the tiles may overlap. In 37 /// these cases, fusion introduces redundant computation. 38 static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand, 39 ArrayRef<int64_t> tiledLoopDims) { 40 // Get the consumer operand indexing map. 41 LinalgOp consumerOp = consumerOperand->getOwner(); 42 AffineMap indexingMap = consumerOp.getTiedIndexingMap(consumerOperand); 43 44 // Search the slice dimensions tiled by a tile loop dimension. 45 DenseSet<int64_t> tiledSliceDimIndices; 46 for (const auto &en : enumerate(indexingMap.getResults())) { 47 for (auto tiledLoopDim : tiledLoopDims) { 48 if (en.value().isFunctionOfDim(tiledLoopDim)) 49 tiledSliceDimIndices.insert(en.index()); 50 } 51 } 52 return {tiledSliceDimIndices.begin(), tiledSliceDimIndices.end()}; 53 } 54 55 /// Given a vector of `tiledSliceDimIndices` that represent the tiled dimensions 56 /// of the producer result slice returns the tiled producer loop dimensions. 57 /// Example: 58 /// ``` 59 /// %res = linalg.fill(%cst, %input) 60 /// scf.for %i 61 /// scf.for %j 62 /// %slice = tensor.extract_slice %res[%i, %j] 63 /// ``` 64 /// getTiledProducerLoops(%res, [0, 1]) returns the loop indices [0, 1]. 65 static SmallVector<int64_t> 66 getTiledProducerLoops(OpResult producerResult, 67 ArrayRef<int64_t> tiledSliceDimIndices) { 68 LinalgOp producerOp = producerResult.getOwner(); 69 70 // Get the indexing map of the `producerOp` output operand that matches 71 // ´producerResult´. 72 AffineMap producerIndexingMap = producerOp.getTiedIndexingMap( 73 producerOp.getOutputOperand(producerResult.getResultNumber())); 74 75 // Keep only the tiled result slice dimensions of `producerIndexingMap`. 76 AffineMap tiledProducerIndexingSubMap = 77 producerIndexingMap.getSubMap(SmallVector<unsigned>( 78 tiledSliceDimIndices.begin(), tiledSliceDimIndices.end())); 79 80 // Compute the producer loop indices mapped to the tiled result slice 81 // dimensions. As the output indexing map of structured operations are 82 // projected permutations, `tiledProducerIndexingSubMap` has to be a 83 // projected permutation as well. We can thus obtain the producer loop indices 84 // by getting the positions of the result dimensions. 85 // Example: 86 // (d0, d1, d2) -> (d0, d2) has the result positions [0, 2]. 87 assert(tiledProducerIndexingSubMap.isProjectedPermutation() && 88 "expect slice and producer loop dimensions map one-to-one"); 89 SmallVector<int64_t> tiledProducerLoopIndices; 90 transform(llvm::seq<unsigned>(0, tiledProducerIndexingSubMap.getNumResults()), 91 std::back_inserter(tiledProducerLoopIndices), [&](unsigned idx) { 92 return tiledProducerIndexingSubMap.getDimPosition(idx); 93 }); 94 95 return tiledProducerLoopIndices; 96 } 97 98 /// Returns the producer fused in place of `sliceOp`. Tile the producer operands 99 /// along the `tiledSliceDimIndices` and clone the producer. Consider the case 100 /// of fusion of an output tensor: 101 /// ``` 102 /// %1 = producer ins(...) outs(%0) 103 /// %2 = consumer ins(...) outs(%1) 104 /// ``` 105 /// When consumer is tiled, %1 appears in the loop iter_args: 106 /// ``` 107 /// %1 = producer ins(...) outs(%0) 108 /// %2 = scf.for ... iter_args(%1) .. (%bbarg) { 109 /// %t1 = tensor.extract_slice %bbarg[..] 110 /// %t2 = consumer ins(...) outs(%t1) 111 /// %r = tensor.insert_slice %t2, %bbarg[...] 112 /// } 113 /// ``` 114 /// Fusing %1 into the loop requires updating iter_args(%1) to iter_args(%0): 115 /// ``` 116 /// %2 = scf.for ... iter_args(%0) .. (%bbarg) { 117 /// %t0 = tensor.extract_slice %bbarg[..] 118 /// %t1 = producer ins(...) outs(%t0) 119 /// %t2 = consumer ins(...) outs(%t1) 120 /// %r = tensor.insert_slice %t2, %bbarg[...] 121 /// } 122 /// ``` 123 /// This transformation is only valid if %bbarg is exclusively used by the 124 /// output ExtractSliceOp / InsertSliceOp pair, which is checked by the 125 /// `fuseProducer` method. 126 /// TODO: instead of check and failure, insert new iter_args each time a 127 /// producer is fused into a consumer and fold away unused iter_args. 128 static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult, 129 tensor::ExtractSliceOp sliceOp, 130 ArrayRef<int64_t> tiledSliceDimIndices, 131 ArrayRef<int64_t> tiledProducerLoopIndices, 132 OpOperand *iterArg) { 133 // Clone the producer after `sliceOp` since the slice may be reused to pass in 134 // the producer result. 135 OpBuilder::InsertionGuard guard(b); 136 b.setInsertionPointAfter(sliceOp); 137 138 // Get the producer. 139 LinalgOp producerOp = producerResult.getOwner(); 140 Location loc = producerOp.getLoc(); 141 142 // Obtain the `producerOp` loop bounds and the `sliceOp` ranges. 143 SmallVector<Value> producerLoopBounds; 144 transform(producerOp.createLoopRanges(b, loc), 145 std::back_inserter(producerLoopBounds), 146 [](Range range) { return range.size; }); 147 SmallVector<Range> sliceOpRanges = sliceOp.getOrCreateRanges(b, loc); 148 149 // Tile the producer operands given the `sliceOp` ranges. Iterate the 150 // `tiledSliceDimIndices` and store the tile offset and size for the tiled 151 // slice dimension. 152 auto zero = b.create<arith::ConstantIndexOp>(loc, 0); 153 SmallVector<Value> tileIvs(producerOp.getNumLoops(), nullptr); 154 SmallVector<Value> tileSizes(producerOp.getNumLoops(), zero); 155 SmallVector<Value> allIvs(producerOp.getNumLoops(), nullptr); 156 for (auto it : zip(tiledSliceDimIndices, tiledProducerLoopIndices)) { 157 int64_t tiledSliceDim = std::get<0>(it); 158 int64_t tiledProducerLoop = std::get<1>(it); 159 tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset; 160 tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size; 161 allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop]; 162 } 163 erase_value(tileIvs, nullptr); 164 SmallVector<Value> tiledOperands = producerOp.getInputAndOutputOperands(); 165 tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs, 166 tileSizes, producerLoopBounds); 167 168 // Output fusion has to update the iteration arguments of the tile loop nest. 169 // In particular, the iteration argument of the outermost tile loop needs to 170 // be set to the producer output instead of the producer result and `clonedOp` 171 // shall use the existing `sliceOp` result instead of the tiled producer 172 // output operand. 173 if (iterArg) { 174 OpOperand *outputOperand = 175 producerOp.getOutputOperand(producerResult.getResultNumber()); 176 iterArg->set(outputOperand->get()); 177 tiledOperands[outputOperand->getOperandNumber()] = sliceOp.getResult(); 178 } 179 180 // Clone the producer using the tiled producer operands. 181 TypeRange resultTypes = ValueRange(tiledOperands) 182 .take_back(producerOp.getNumOutputs()) 183 .getTypes(); 184 LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands); 185 186 // Shift all IndexOp results by the tile offset. 187 addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs); 188 189 return clonedOp; 190 } 191 192 //===----------------------------------------------------------------------===// 193 // TileLoopNest specific helpers. 194 //===----------------------------------------------------------------------===// 195 196 bool TileLoopNest::isEmpty() { return tileLoopOps.empty(); } 197 198 bool TileLoopNest::isValid() { 199 // Check if `rootOp` has been tiled at least once. 200 if (isEmpty() || tiledRootAndFusedOpsLoops.count(rootOp) == 0) 201 return false; 202 203 // Check if the number of loop operations and dimensions match. 204 if (tileLoopOps.size() != tiledRootAndFusedOpsLoops[rootOp].size()) 205 return false; 206 207 // Check if the innermost tile loop is the parent of `tiledOp`. 208 if (rootOp->getParentOp() != tileLoopOps.back()) 209 return false; 210 211 // Check if the tile loops are directly nested. 212 return std::adjacent_find(tileLoopOps.begin(), tileLoopOps.end(), 213 [](Operation *op1, Operation *op2) { 214 return op1 != op2->getParentOp(); 215 }) == tileLoopOps.end(); 216 } 217 218 SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) { 219 assert(bbArg && "expect the block argument to be non-zero"); 220 SmallVector<BlockArgument> bbArgs; 221 222 // Search all tile loop block arguments from inner to outer. 223 for (auto tileLoop : reverse(tileLoopOps)) { 224 if (bbArg.getOwner()->getParentOp() != tileLoop) 225 return {}; 226 bbArgs.push_back(bbArg); 227 OpOperand *iterArg = &tileLoop.getOpOperandForRegionIterArg(bbArg); 228 bbArg = iterArg->get().dyn_cast<BlockArgument>(); 229 } 230 231 // Reverse the block arguments to order them from outer to inner. 232 return {bbArgs.rbegin(), bbArgs.rend()}; 233 } 234 235 OpOperand *TileLoopNest::getTiedIterArg(BlockArgument bbArg) { 236 // Search all block arguments and return the matching iteration argument. 237 SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg); 238 if (bbArgs.size() != tileLoopOps.size()) 239 return nullptr; 240 return &tileLoopOps.front().getOpOperandForRegionIterArg(bbArgs.front()); 241 } 242 243 bool TileLoopNest::hasOtherUses(BlockArgument bbArg, 244 tensor::ExtractSliceOp sliceOp) { 245 // Check the innermost block argument is either used by the ExtractSliceOp 246 // `sliceOp`, the matching InsertSliceOp, or by a DimOp. Handle other uses 247 // conservatively. 248 for (Operation *op : bbArg.getUsers()) { 249 if (!isa<tensor::DimOp, tensor::InsertSliceOp, tensor::ExtractSliceOp>(op)) 250 return false; 251 if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) { 252 if (extractSliceOp != sliceOp) 253 return false; 254 } 255 if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(op)) { 256 SetVector<Operation *> backwardSlice; 257 getBackwardSlice(insertSliceOp.source(), &backwardSlice, 258 [](Operation *op) { 259 return isa<LinalgOp, tensor::InsertSliceOp>(op); 260 }); 261 if (backwardSlice.empty() || backwardSlice.front() != sliceOp) 262 return false; 263 } 264 } 265 266 // Check the block arguments, except for the innermost one, have one use. 267 SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg); 268 return !all_of(bbArgs, [&](BlockArgument bbArg) { 269 return bbArg.hasOneUse() || bbArg == bbArgs.back(); 270 }); 271 } 272 273 LogicalResult TileLoopNest::tileRootOp( 274 OpBuilder &b, ArrayRef<int64_t> tileSizes, 275 ArrayRef<int64_t> tileInterchange, 276 Optional<LinalgLoopDistributionOptions> tileDistribution) { 277 // Exit if all tile sizes are zero. 278 if (tileSizes.size() == static_cast<size_t>(count(tileSizes, 0))) 279 return success(); 280 281 // Tile the root operation. 282 LinalgTilingOptions tilingOptions; 283 tilingOptions = tilingOptions 284 .setInterchange(SmallVector<unsigned>( 285 tileInterchange.begin(), tileInterchange.end())) 286 .setTileSizes(tileSizes) 287 .setLoopType(LinalgTilingLoopType::Loops); 288 if (tileDistribution) 289 tilingOptions = 290 tilingOptions.setDistributionOptions(tileDistribution.getValue()); 291 292 // TODO: Propagate RewriterBase everywhere. 293 IRRewriter rewriter(b); 294 FailureOr<TiledLinalgOp> tiledRootOp = 295 tileLinalgOp(rewriter, rootOp, tilingOptions); 296 297 // Exit if tiling the root operation fails. 298 if (failed(tiledRootOp)) 299 return failure(); 300 301 // Replace all uses of the root operation if it has been tiled before. All 302 // uses of the original untiled root operation are updated by the calling pass 303 // or pattern. 304 if (!isEmpty()) 305 rootOp->replaceAllUsesWith(tiledRootOp->tensorResults); 306 307 // Transfer the stored `rootOp` loop dimensions if it has been tiled before. 308 if (tiledRootAndFusedOpsLoops.count(rootOp) != 0) { 309 tiledRootAndFusedOpsLoops[tiledRootOp->op] = 310 tiledRootAndFusedOpsLoops[rootOp]; 311 } 312 313 // Update the root operation and append the loops and tile loop dimensions. 314 rootOp = tiledRootOp->op; 315 tileLoopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end()); 316 for (const auto &en : enumerate(tileSizes)) { 317 // Copy only the tiled loop dimensions with non-zero tile size. 318 if (en.value() == 0) 319 continue; 320 tiledRootAndFusedOpsLoops[rootOp].push_back(tileInterchange[en.index()]); 321 } 322 assert(isValid() && "expect tile loop nest to be valid after tiling"); 323 return success(); 324 } 325 326 FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b, 327 OpOperand *consumerOpOperand) { 328 // Check if the consumer has been tiled before. For example, it may not have 329 // been tiled if the outermost tile loop is a reduction loop. 330 if (tiledRootAndFusedOpsLoops.count(consumerOpOperand->getOwner()) == 0) 331 return failure(); 332 333 assert(this->isValid() && 334 "expect the tile loop nest to satisfy all invariants"); 335 336 // Check the tile loop nest is non-empty. 337 if (isEmpty()) 338 return failure(); 339 340 // Check `consumerOpOperand` is defined by an ExtractSliceOp. 341 auto sliceOp = 342 consumerOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>(); 343 if (!sliceOp) 344 return failure(); 345 346 // Check `sliceOp` and `consumerOp` are in the same block. 347 LinalgOp consumerOp = consumerOpOperand->getOwner(); 348 if (sliceOp->getBlock() != rootOp->getBlock() || 349 consumerOp->getBlock() != rootOp->getBlock()) 350 return failure(); 351 352 // Check if the producer is a LinalgOp possibly passed by iteration argument. 353 OpOperand *iterArg = nullptr; 354 auto producerResult = sliceOp.source().dyn_cast<OpResult>(); 355 if (auto bbArg = sliceOp.source().dyn_cast<BlockArgument>()) { 356 iterArg = getTiedIterArg(bbArg); 357 // Check the iteration argument may be used to pass in the producer output. 358 if (!iterArg || hasOtherUses(bbArg, sliceOp)) 359 return failure(); 360 producerResult = iterArg->get().dyn_cast<OpResult>(); 361 } 362 if (!producerResult || !isa<LinalgOp>(producerResult.getOwner())) 363 return failure(); 364 365 // Compute the tiled producer slice dimensions given the tiled consumer loops. 366 SmallVector<int64_t> tiledSliceDimIndices = getTiledSliceDims( 367 consumerOpOperand, tiledRootAndFusedOpsLoops[consumerOp]); 368 if (tiledSliceDimIndices.empty()) 369 return failure(); 370 371 // Compute the tiled producer loop indices. 372 SmallVector<int64_t> tiledProducerLoopIndices = 373 getTiledProducerLoops(producerResult, tiledSliceDimIndices); 374 375 // Tile the producer operands and clone the producer in place of `sliceOp`. 376 LinalgOp clonedOp = 377 getTiledProducer(b, producerResult, sliceOp, tiledSliceDimIndices, 378 tiledProducerLoopIndices, iterArg); 379 tiledRootAndFusedOpsLoops[clonedOp] = tiledProducerLoopIndices; 380 381 // Cast the `clonedOp` result to gap type mismatches before canonicalization. 382 Type consumerOperandType = consumerOpOperand->get().getType(); 383 Value newResult = clonedOp->getResult(producerResult.getResultNumber()); 384 if (newResult.getType() != consumerOperandType) { 385 OpBuilder::InsertionGuard guard(b); 386 b.setInsertionPointAfter(clonedOp); 387 newResult = b.create<tensor::CastOp>(producerResult.getLoc(), 388 consumerOperandType, newResult); 389 } 390 391 // Replace the `sliceOp` uses except for the `clonedOp` output uses. 392 sliceOp.getResult().replaceAllUsesExcept(newResult, clonedOp); 393 return clonedOp; 394 } 395 396 ValueRange TileLoopNest::getRootOpReplacementResults() { 397 assert(!isEmpty() && "expect tile loop nest to be non-empty"); 398 return tileLoopOps.front()->getOpResults(); 399 } 400 401 SmallVector<LinalgOp> TileLoopNest::getAllTiledAndFusedOps() { 402 SmallVector<LinalgOp> result; 403 for (const auto &kvp : tiledRootAndFusedOpsLoops) { 404 auto linalgOp = dyn_cast<LinalgOp>(kvp.getFirst()); 405 assert(linalgOp && 406 "expect all tiled and fused operations are linalg operations"); 407 result.push_back(linalgOp); 408 } 409 return result; 410 } 411 412 //===----------------------------------------------------------------------===// 413 // Tile and fuse entry-points. 414 //===----------------------------------------------------------------------===// 415 416 FailureOr<TileLoopNest> mlir::linalg::tileConsumerAndFuseProducers( 417 OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes, 418 ArrayRef<int64_t> tileInterchange, 419 const Optional<LinalgLoopDistributionOptions> &tileDistribution) { 420 assert(tileSizes.size() == tileInterchange.size() && 421 "expect the number of tile sizes and interchange dims to match"); 422 assert(isPermutation(tileInterchange) && 423 "expect tile interchange is a permutation"); 424 425 // Create an empty tile loop nest. 426 TileLoopNest tileLoopNest(consumerOp); 427 428 // Search the number of outer parallel loops to separate them from possible 429 // inner reduction dimensions. 430 SmallVector<StringAttr> iterTypes = 431 llvm::to_vector<6>(consumerOp.iterator_types().getAsRange<StringAttr>()); 432 applyPermutationToVector(iterTypes, tileInterchange); 433 auto *it = find_if(iterTypes, [&](StringAttr iterType) { 434 return !isParallelIterator(iterType); 435 }); 436 int64_t split = std::distance(iterTypes.begin(), it); 437 438 // Helper to fuse the producers greedily using a queue of fusion candidates. 439 auto fuseProducersGreedily = [&](ArrayRef<OpOperand *> operands) { 440 SmallVector<OpOperand *> candidates(operands.begin(), operands.end()); 441 while (!candidates.empty()) { 442 FailureOr<LinalgOp> fusedProducer = 443 tileLoopNest.fuseProducer(b, candidates.pop_back_val()); 444 if (failed(fusedProducer)) 445 continue; 446 candidates.append(fusedProducer->getInputAndOutputOperands()); 447 } 448 }; 449 450 // Tile the outer parallel loops and fuse the output operands. 451 SmallVector<int64_t> outerTileSizes; 452 outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split); 453 outerTileSizes.append(tileSizes.size() - split, 0); 454 if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange, 455 tileDistribution))) 456 return failure(); 457 fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands()); 458 459 // Tile the remaining loops and fuse the input operands. 460 SmallVector<int64_t> innerTileSizes; 461 innerTileSizes.append(split, 0); 462 innerTileSizes.append(tileSizes.begin() + split, tileSizes.end()); 463 if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange, 464 tileDistribution))) 465 return failure(); 466 fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands()); 467 468 // Exit if the tile loop nest is empty since all tile sizes are zero. 469 if (tileLoopNest.isEmpty()) 470 return failure(); 471 472 return tileLoopNest; 473 } 474