1 //===- FusionOnTensors.cpp - Implementation of linalg Fusion --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements linalg fusion on tensors 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Analysis/SliceAnalysis.h" 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Linalg/IR/Linalg.h" 17 #include "mlir/Dialect/Linalg/Passes.h" 18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 19 #include "mlir/Dialect/Linalg/Utils/Utils.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/IR/AffineExpr.h" 22 #include "mlir/IR/AffineMap.h" 23 #include "mlir/Support/LLVM.h" 24 25 using namespace mlir; 26 using namespace linalg; 27 28 //===----------------------------------------------------------------------===// 29 // StructuredOp specific helpers. 30 //===----------------------------------------------------------------------===// 31 32 /// Returns the tiled slice dimensions given the tiled consumer loop dimensions. 33 /// The slice defines a hyper rectangular iteration space and fusing the 34 /// producer is always possible. However, depending on the consumer indexing 35 /// map, not all slice elements may be consumed and the tiles may overlap. In 36 /// these cases, fusion introduces redundant computation. 37 static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand, 38 ArrayRef<int64_t> tiledLoopDims) { 39 // Get the consumer operand indexing map. 40 LinalgOp consumerOp = consumerOperand->getOwner(); 41 AffineMap indexingMap = consumerOp.getTiedIndexingMap(consumerOperand); 42 43 // Search the slice dimensions tiled by a tile loop dimension. 44 DenseSet<int64_t> tiledSliceDimIndices; 45 for (const auto &en : enumerate(indexingMap.getResults())) { 46 for (auto tiledLoopDim : tiledLoopDims) { 47 if (en.value().isFunctionOfDim(tiledLoopDim)) 48 tiledSliceDimIndices.insert(en.index()); 49 } 50 } 51 return {tiledSliceDimIndices.begin(), tiledSliceDimIndices.end()}; 52 } 53 54 /// Given a vector of `tiledSliceDimIndices` that represent the tiled dimensions 55 /// of the producer result slice returns the tiled producer loop dimensions. 56 /// Example: 57 /// ``` 58 /// %res = linalg.fill(%cst, %input) 59 /// scf.for %i 60 /// scf.for %j 61 /// %slice = tensor.extract_slice %res[%i, %j] 62 /// ``` 63 /// getTiledProducerLoops(%res, [0, 1]) returns the loop indices [0, 1]. 64 static SmallVector<int64_t> 65 getTiledProducerLoops(OpResult producerResult, 66 ArrayRef<int64_t> tiledSliceDimIndices) { 67 LinalgOp producerOp = producerResult.getOwner(); 68 69 // Get the indexing map of the `producerOp` output operand that matches 70 // ´producerResult´. 71 AffineMap producerIndexingMap = producerOp.getTiedIndexingMap( 72 producerOp.getOutputOperand(producerResult.getResultNumber())); 73 74 // Keep only the tiled result slice dimensions of `producerIndexingMap`. 75 AffineMap tiledProducerIndexingSubMap = 76 producerIndexingMap.getSubMap(SmallVector<unsigned>( 77 tiledSliceDimIndices.begin(), tiledSliceDimIndices.end())); 78 79 // Compute the producer loop indices mapped to the tiled result slice 80 // dimensions. As the output indexing map of structured operations are 81 // projected permutations, `tiledProducerIndexingSubMap` has to be a 82 // projected permutation as well. We can thus obtain the producer loop indices 83 // by getting the positions of the result dimensions. 84 // Example: 85 // (d0, d1, d2) -> (d0, d2) has the result positions [0, 2]. 86 assert(tiledProducerIndexingSubMap.isProjectedPermutation() && 87 "expect slice and producer loop dimensions map one-to-one"); 88 SmallVector<int64_t> tiledProducerLoopIndices; 89 transform(llvm::seq<unsigned>(0, tiledProducerIndexingSubMap.getNumResults()), 90 std::back_inserter(tiledProducerLoopIndices), [&](unsigned idx) { 91 return tiledProducerIndexingSubMap.getDimPosition(idx); 92 }); 93 94 return tiledProducerLoopIndices; 95 } 96 97 /// Returns the producer fused in place of `sliceOp`. Tile the producer operands 98 /// along the `tiledSliceDimIndices` and clone the producer. Consider the case 99 /// of fusion of an output tensor: 100 /// ``` 101 /// %1 = producer ins(...) outs(%0) 102 /// %2 = consumer ins(...) outs(%1) 103 /// ``` 104 /// When consumer is tiled, %1 appears in the loop iter_args: 105 /// ``` 106 /// %1 = producer ins(...) outs(%0) 107 /// %2 = scf.for ... iter_args(%1) .. (%bbarg) { 108 /// %t1 = tensor.extract_slice %bbarg[..] 109 /// %t2 = consumer ins(...) outs(%t1) 110 /// %r = tensor.insert_slice %t2, %bbarg[...] 111 /// } 112 /// ``` 113 /// Fusing %1 into the loop requires updating iter_args(%1) to iter_args(%0): 114 /// ``` 115 /// %2 = scf.for ... iter_args(%0) .. (%bbarg) { 116 /// %t0 = tensor.extract_slice %bbarg[..] 117 /// %t1 = producer ins(...) outs(%t0) 118 /// %t2 = consumer ins(...) outs(%t1) 119 /// %r = tensor.insert_slice %t2, %bbarg[...] 120 /// } 121 /// ``` 122 /// This transformation is only valid if %bbarg is exclusively used by the 123 /// output ExtractSliceOp / InsertSliceOp pair, which is checked by the 124 /// `fuseProducer` method. 125 /// TODO: instead of check and failure, insert new iter_args each time a 126 /// producer is fused into a consumer and fold away unused iter_args. 127 static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult, 128 tensor::ExtractSliceOp sliceOp, 129 ArrayRef<int64_t> tiledSliceDimIndices, 130 ArrayRef<int64_t> tiledProducerLoopIndices, 131 OpOperand *iterArg) { 132 // Clone the producer after `sliceOp` since the slice may be reused to pass in 133 // the producer result. 134 OpBuilder::InsertionGuard guard(b); 135 b.setInsertionPointAfter(sliceOp); 136 137 // Get the producer. 138 LinalgOp producerOp = producerResult.getOwner(); 139 Location loc = producerOp.getLoc(); 140 141 // Obtain the `producerOp` loop bounds and the `sliceOp` ranges. 142 SmallVector<Value> producerLoopBounds; 143 transform(producerOp.createLoopRanges(b, loc), 144 std::back_inserter(producerLoopBounds), 145 [](Range range) { return range.size; }); 146 SmallVector<Range> sliceOpRanges = sliceOp.getOrCreateRanges(b, loc); 147 148 // Tile the producer operands given the `sliceOp` ranges. Iterate the 149 // `tiledSliceDimIndices` and store the tile offset and size for the tiled 150 // slice dimension. 151 auto zero = b.create<arith::ConstantIndexOp>(loc, 0); 152 SmallVector<Value> tileIvs(producerOp.getNumLoops(), nullptr); 153 SmallVector<Value> tileSizes(producerOp.getNumLoops(), zero); 154 SmallVector<Value> allIvs(producerOp.getNumLoops(), nullptr); 155 for (auto it : zip(tiledSliceDimIndices, tiledProducerLoopIndices)) { 156 int64_t tiledSliceDim = std::get<0>(it); 157 int64_t tiledProducerLoop = std::get<1>(it); 158 tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset; 159 tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size; 160 allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop]; 161 } 162 erase_value(tileIvs, nullptr); 163 SmallVector<Value> tiledOperands = producerOp.getInputAndOutputOperands(); 164 tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs, 165 tileSizes, producerLoopBounds); 166 167 // Output fusion has to update the iteration arguments of the tile loop nest. 168 // In particular, the iteration argument of the outermost tile loop needs to 169 // be set to the producer output instead of the producer result and `clonedOp` 170 // shall use the existing `sliceOp` result instead of the tiled producer 171 // output operand. 172 if (iterArg) { 173 OpOperand *outputOperand = 174 producerOp.getOutputOperand(producerResult.getResultNumber()); 175 iterArg->set(outputOperand->get()); 176 tiledOperands[outputOperand->getOperandNumber()] = sliceOp.getResult(); 177 } 178 179 // Clone the producer using the tiled producer operands. 180 TypeRange resultTypes = ValueRange(tiledOperands) 181 .take_back(producerOp.getNumOutputs()) 182 .getTypes(); 183 LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands); 184 185 // Shift all IndexOp results by the tile offset. 186 addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs); 187 188 return clonedOp; 189 } 190 191 //===----------------------------------------------------------------------===// 192 // TileLoopNest specific helpers. 193 //===----------------------------------------------------------------------===// 194 195 bool TileLoopNest::isEmpty() { return tileLoopOps.empty(); } 196 197 bool TileLoopNest::isValid() { 198 // Check if `rootOp` has been tiled at least once. 199 if (isEmpty() || tiledRootAndFusedOpsLoops.count(rootOp) == 0) 200 return false; 201 202 // Check if the number of loop operations and dimensions match. 203 if (tileLoopOps.size() != tiledRootAndFusedOpsLoops[rootOp].size()) 204 return false; 205 206 // Check if the innermost tile loop is the parent of `tiledOp`. 207 if (rootOp->getParentOp() != tileLoopOps.back()) 208 return false; 209 210 // Check if the tile loops are directly nested. 211 return std::adjacent_find(tileLoopOps.begin(), tileLoopOps.end(), 212 [](Operation *op1, Operation *op2) { 213 return op1 != op2->getParentOp(); 214 }) == tileLoopOps.end(); 215 } 216 217 SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) { 218 assert(bbArg && "expect the block argument to be non-zero"); 219 SmallVector<BlockArgument> bbArgs; 220 221 // Search all tile loop block arguments from inner to outer. 222 for (auto tileLoop : reverse(tileLoopOps)) { 223 if (bbArg.getOwner()->getParentOp() != tileLoop) 224 return {}; 225 bbArgs.push_back(bbArg); 226 OpOperand *iterArg = &tileLoop.getOpOperandForRegionIterArg(bbArg); 227 bbArg = iterArg->get().dyn_cast<BlockArgument>(); 228 } 229 230 // Reverse the block arguments to order them from outer to inner. 231 return {bbArgs.rbegin(), bbArgs.rend()}; 232 } 233 234 OpOperand *TileLoopNest::getTiedIterArg(BlockArgument bbArg) { 235 // Search all block arguments and return the matching iteration argument. 236 SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg); 237 if (bbArgs.size() != tileLoopOps.size()) 238 return nullptr; 239 return &tileLoopOps.front().getOpOperandForRegionIterArg(bbArgs.front()); 240 } 241 242 bool TileLoopNest::hasOtherUses(BlockArgument bbArg, 243 tensor::ExtractSliceOp sliceOp) { 244 // Check the innermost block argument is either used by the ExtractSliceOp 245 // `sliceOp`, the matching InsertSliceOp, or by a DimOp. Handle other uses 246 // conservatively. 247 for (Operation *op : bbArg.getUsers()) { 248 if (!isa<tensor::DimOp, tensor::InsertSliceOp, tensor::ExtractSliceOp>(op)) 249 return false; 250 if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) { 251 if (extractSliceOp != sliceOp) 252 return false; 253 } 254 if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(op)) { 255 SetVector<Operation *> backwardSlice; 256 getBackwardSlice(insertSliceOp.source(), &backwardSlice, 257 [](Operation *op) { 258 return isa<LinalgOp, tensor::InsertSliceOp>(op); 259 }); 260 if (backwardSlice.empty() || backwardSlice.front() != sliceOp) 261 return false; 262 } 263 } 264 265 // Check the block arguments, except for the innermost one, have one use. 266 SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg); 267 return !all_of(bbArgs, [&](BlockArgument bbArg) { 268 return bbArg.hasOneUse() || bbArg == bbArgs.back(); 269 }); 270 } 271 272 LogicalResult TileLoopNest::tileRootOp(OpBuilder &b, 273 ArrayRef<int64_t> tileSizes, 274 ArrayRef<int64_t> tileInterchange) { 275 // Exit if all tile sizes are zero. 276 if (tileSizes.size() == static_cast<size_t>(count(tileSizes, 0))) 277 return success(); 278 279 // Tile the root operation. 280 LinalgTilingOptions tilingOptions; 281 tilingOptions = tilingOptions 282 .setInterchange(SmallVector<unsigned>( 283 tileInterchange.begin(), tileInterchange.end())) 284 .setTileSizes(tileSizes) 285 .setLoopType(LinalgTilingLoopType::Loops); 286 287 // TODO: Propagate RewriterBase everywhere. 288 IRRewriter rewriter(b); 289 FailureOr<TiledLinalgOp> tiledRootOp = 290 tileLinalgOp(rewriter, rootOp, tilingOptions); 291 292 // Exit if tiling the root operation fails. 293 if (failed(tiledRootOp)) 294 return failure(); 295 296 // Replace all uses of the root operation if it has been tiled before. All 297 // uses of the original untiled root operation are updated by the calling pass 298 // or pattern. 299 if (!isEmpty()) 300 rootOp->replaceAllUsesWith(tiledRootOp->tensorResults); 301 302 // Transfer the stored `rootOp` loop dimensions if it has been tiled before. 303 if (tiledRootAndFusedOpsLoops.count(rootOp) != 0) { 304 tiledRootAndFusedOpsLoops[tiledRootOp->op] = 305 tiledRootAndFusedOpsLoops[rootOp]; 306 } 307 308 // Update the root operation and append the loops and tile loop dimensions. 309 rootOp = tiledRootOp->op; 310 tileLoopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end()); 311 for (const auto &en : enumerate(tileSizes)) { 312 // Copy only the tiled loop dimensions with non-zero tile size. 313 if (en.value() == 0) 314 continue; 315 tiledRootAndFusedOpsLoops[rootOp].push_back(tileInterchange[en.index()]); 316 } 317 assert(isValid() && "expect tile loop nest to be valid after tiling"); 318 return success(); 319 } 320 321 FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b, 322 OpOperand *consumerOpOperand) { 323 // Check if the consumer has been tiled before. For example, it may not have 324 // been tiled if the outermost tile loop is a reduction loop. 325 if (tiledRootAndFusedOpsLoops.count(consumerOpOperand->getOwner()) == 0) 326 return failure(); 327 328 assert(this->isValid() && 329 "expect the tile loop nest to satisfy all invariants"); 330 331 // Check the tile loop nest is non-empty. 332 if (isEmpty()) 333 return failure(); 334 335 // Check `consumerOpOperand` is defined by an ExtractSliceOp. 336 auto sliceOp = 337 consumerOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>(); 338 if (!sliceOp) 339 return failure(); 340 341 // Check `sliceOp` and `consumerOp` are in the same block. 342 LinalgOp consumerOp = consumerOpOperand->getOwner(); 343 if (sliceOp->getBlock() != rootOp->getBlock() || 344 consumerOp->getBlock() != rootOp->getBlock()) 345 return failure(); 346 347 // Check if the producer is a LinalgOp possibly passed by iteration argument. 348 OpOperand *iterArg = nullptr; 349 auto producerResult = sliceOp.source().dyn_cast<OpResult>(); 350 if (auto bbArg = sliceOp.source().dyn_cast<BlockArgument>()) { 351 iterArg = getTiedIterArg(bbArg); 352 // Check the iteration argument may be used to pass in the producer output. 353 if (!iterArg || hasOtherUses(bbArg, sliceOp)) 354 return failure(); 355 producerResult = iterArg->get().dyn_cast<OpResult>(); 356 } 357 if (!producerResult || !isa<LinalgOp>(producerResult.getOwner())) 358 return failure(); 359 360 // Compute the tiled producer slice dimensions given the tiled consumer loops. 361 SmallVector<int64_t> tiledSliceDimIndices = getTiledSliceDims( 362 consumerOpOperand, tiledRootAndFusedOpsLoops[consumerOp]); 363 if (tiledSliceDimIndices.empty()) 364 return failure(); 365 366 // Compute the tiled producer loop indices. 367 SmallVector<int64_t> tiledProducerLoopIndices = 368 getTiledProducerLoops(producerResult, tiledSliceDimIndices); 369 370 // Tile the producer operands and clone the producer in place of `sliceOp`. 371 LinalgOp clonedOp = 372 getTiledProducer(b, producerResult, sliceOp, tiledSliceDimIndices, 373 tiledProducerLoopIndices, iterArg); 374 tiledRootAndFusedOpsLoops[clonedOp] = tiledProducerLoopIndices; 375 376 // Cast the `clonedOp` result to gap type mismatches before canonicalization. 377 Type consumerOperandType = consumerOpOperand->get().getType(); 378 Value newResult = clonedOp->getResult(producerResult.getResultNumber()); 379 if (newResult.getType() != consumerOperandType) { 380 OpBuilder::InsertionGuard guard(b); 381 b.setInsertionPointAfter(clonedOp); 382 newResult = b.create<tensor::CastOp>(producerResult.getLoc(), 383 consumerOperandType, newResult); 384 } 385 386 // Replace the `sliceOp` uses except for the `clonedOp` output uses. 387 sliceOp.getResult().replaceAllUsesExcept(newResult, clonedOp); 388 return clonedOp; 389 } 390 391 ValueRange TileLoopNest::getRootOpReplacementResults() { 392 assert(!isEmpty() && "expect tile loop nest to be non-empty"); 393 return tileLoopOps.front()->getOpResults(); 394 } 395 396 SmallVector<LinalgOp> TileLoopNest::getAllTiledAndFusedOps() { 397 SmallVector<LinalgOp> result; 398 for (const auto &kvp : tiledRootAndFusedOpsLoops) { 399 auto linalgOp = dyn_cast<LinalgOp>(kvp.getFirst()); 400 assert(linalgOp && 401 "expect all tiled and fused operations are linalg operations"); 402 result.push_back(linalgOp); 403 } 404 return result; 405 } 406 407 //===----------------------------------------------------------------------===// 408 // Tile and fuse entry-points. 409 //===----------------------------------------------------------------------===// 410 411 FailureOr<TileLoopNest> 412 mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp, 413 ArrayRef<int64_t> tileSizes, 414 ArrayRef<int64_t> tileInterchange) { 415 assert(tileSizes.size() == tileInterchange.size() && 416 "expect the number of tile sizes and interchange dims to match"); 417 assert(isPermutation(tileInterchange) && 418 "expect tile interchange is a permutation"); 419 420 // Create an empty tile loop nest. 421 TileLoopNest tileLoopNest(consumerOp); 422 423 // Search the number of outer parallel loops to separate them from possible 424 // inner reduction dimensions. 425 SmallVector<StringAttr> iterTypes = 426 llvm::to_vector<6>(consumerOp.iterator_types().getAsRange<StringAttr>()); 427 applyPermutationToVector(iterTypes, tileInterchange); 428 auto *it = find_if(iterTypes, [&](StringAttr iterType) { 429 return !isParallelIterator(iterType); 430 }); 431 int64_t split = std::distance(iterTypes.begin(), it); 432 433 // Helper to fuse the producers greedily using a queue of fusion candidates. 434 auto fuseProducersGreedily = [&](ArrayRef<OpOperand *> operands) { 435 SmallVector<OpOperand *> candidates(operands.begin(), operands.end()); 436 while (!candidates.empty()) { 437 FailureOr<LinalgOp> fusedProducer = 438 tileLoopNest.fuseProducer(b, candidates.pop_back_val()); 439 if (failed(fusedProducer)) 440 continue; 441 candidates.append(fusedProducer->getInputAndOutputOperands()); 442 } 443 }; 444 445 // Tile the outer parallel loops and fuse the output operands. 446 SmallVector<int64_t> outerTileSizes; 447 outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split); 448 outerTileSizes.append(tileSizes.size() - split, 0); 449 if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange))) 450 return failure(); 451 fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands()); 452 453 // Tile the remaining loops and fuse the input operands. 454 SmallVector<int64_t> innerTileSizes; 455 innerTileSizes.append(split, 0); 456 innerTileSizes.append(tileSizes.begin() + split, tileSizes.end()); 457 if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange))) 458 return failure(); 459 fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands()); 460 461 return tileLoopNest; 462 } 463