1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Fusion pass. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 18 #include "mlir/Dialect/Linalg/Passes.h" 19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 20 #include "mlir/Dialect/Linalg/Utils/Utils.h" 21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 22 #include "mlir/Dialect/Tensor/IR/Tensor.h" 23 #include "mlir/IR/AffineExpr.h" 24 #include "mlir/IR/AffineMap.h" 25 #include "mlir/IR/Dominance.h" 26 #include "mlir/Support/LLVM.h" 27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 28 #include "mlir/Transforms/RegionUtils.h" 29 #include "llvm/ADT/MapVector.h" 30 #include "llvm/Support/CommandLine.h" 31 #include "llvm/Support/Debug.h" 32 33 #include <set> 34 35 #define DEBUG_TYPE "linalg-fusion" 36 37 using namespace mlir; 38 using namespace mlir::edsc; 39 using namespace mlir::edsc::intrinsics; 40 using namespace mlir::linalg; 41 42 using llvm::dbgs; 43 44 /// Implements a simple high-level fusion pass on linalg structured operations. 45 /// 46 /// In each block, linalg ops are processed in reverse textual order. 47 /// Given a linalg op `O`, fusion occurs by: 48 /// 1. inspecting the linalg ops that write into the views read by `O`. 
/// There are 2 cases:
///   a) buffer case: use the SSA value of the views and a simple alias
///      analysis on subview ops to determine producer-consumer dependences;
///   b) tensor case: use SSA use-def chains on subtensor ops;
/// 2. greedily fuse the linalg ops that produce the subview/subtensor.
/// 3. inspect the fused ops and determine whether they have other remaining
///    LinalgOp uses. If not, then erase the original producing linalg op.
///
/// More advanced use cases, analyses as well as profitability heuristics are
/// left for future work.

// Fill `offsets`, `sizes` and `strides` used to iterate over the shape indexed
// by `permutationMap`. `permutationMap` must be a projected permutation (a
// subset of loop dimensions, each appearing at most once); each of its results
// selects the loop range that bounds the corresponding shape dimension.
static void inferShapeComponents(AffineMap permutationMap,
                                 ArrayRef<Range> loopRanges,
                                 SmallVectorImpl<OpFoldResult> &offsets,
                                 SmallVectorImpl<OpFoldResult> &sizes,
                                 SmallVectorImpl<OpFoldResult> &strides) {
  assert(permutationMap.isProjectedPermutation() &&
         "expected some subset of a permutation map");
  // Gather, in map-result order, the loop range associated with each shape
  // dimension.
  SmallVector<Range, 4> shapeRanges(permutationMap.getNumResults());
  unsigned idx = 0;
  for (AffineExpr e : permutationMap.getResults()) {
    // loopToOperandRangesMaps are permutations-only, just swap indices.
    unsigned loopPos = e.cast<AffineDimExpr>().getPosition();
    shapeRanges[idx++] = loopRanges[loopPos];
  }
  // Construct a new subshape for the tile: split each Range into its
  // offset/size/stride components, one entry per shape dimension.
  unsigned rank = shapeRanges.size();
  offsets.reserve(rank);
  sizes.reserve(rank);
  strides.reserve(rank);
  for (auto r : shapeRanges) {
    offsets.push_back(r.offset);
    sizes.push_back(r.size);
    strides.push_back(r.stride);
  }
}

// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
// a subset of the original loop ranges of `op`.
// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
// to the `loopRanges` in order to obtain view ranges.
92 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, 93 ArrayRef<Range> loopRanges) { 94 SmallVector<Value, 8> clonedShapes; 95 clonedShapes.reserve(op.getNumShapedOperands()); 96 97 // Iterate over the shape operands in order. 98 // Extract the subranges from the linearized ranges. 99 for (auto en : llvm::enumerate(op.getShapedOperands())) { 100 unsigned shapedOperandIdx = en.index(); 101 AffineMap map = op.getIndexingMap(shapedOperandIdx); 102 LLVM_DEBUG(llvm::dbgs() << "shapedOperandIdx: " << shapedOperandIdx 103 << " with indexingMap: " << map << "\n"); 104 SmallVector<OpFoldResult, 4> offsets, sizes, strides; 105 inferShapeComponents(map, loopRanges, offsets, sizes, strides); 106 Value shape = en.value(); 107 Value sub = shape.getType().isa<MemRefType>() 108 ? b.create<SubViewOp>(loc, shape, offsets, sizes, strides) 109 .getResult() 110 : b.create<SubTensorOp>(loc, shape, offsets, sizes, strides) 111 .getResult(); 112 clonedShapes.push_back(sub); 113 } 114 // Append the other operands. 115 auto operands = op.getAssumedNonShapedOperands(); 116 clonedShapes.append(operands.begin(), operands.end()); 117 118 // Iterate over the results in order. 119 // Extract the subtensor type from the linearized range. 120 // Since we do not enforce any canonicalizations on the fly, this is always 121 // fully dynamic at construction time. 
122 SmallVector<Type, 4> resultTypes; 123 resultTypes.reserve(op->getNumResults()); 124 for (RankedTensorType t : op.getOutputTensorTypes()) { 125 unsigned rank = t.getRank(); 126 SmallVector<int64_t, 4> staticOffsetsVector( 127 rank, ShapedType::kDynamicStrideOrOffset); 128 SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize); 129 SmallVector<int64_t, 4> staticStridesVector( 130 rank, ShapedType::kDynamicStrideOrOffset); 131 resultTypes.push_back(SubTensorOp::inferResultType( 132 t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector, 133 staticStridesVector)); 134 } 135 136 Operation *clonedOp = op.clone(b, loc, resultTypes, clonedShapes); 137 // When the producer is an IndexedGenericOp, we have to transform its block 138 // IV arguments according to the tiling of the consumer, i.e. offset them by 139 // the values computed in `loopRanges`. 140 if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) { 141 auto &block = indexedGenericOp.region().front(); 142 OpBuilder::InsertionGuard g(b); 143 b.setInsertionPointToStart(&block); 144 for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) { 145 Value oldIndex = block.getArgument(i); 146 // TODO: replace by an affine_apply. 147 AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex, 148 loopRanges[i].offset); 149 oldIndex.replaceAllUsesExcept(newIndex, 150 SmallPtrSet<Operation *, 1>{newIndex}); 151 } 152 } 153 154 return clonedOp; 155 } 156 157 struct ShapeDimension { 158 Value shape; 159 unsigned dimension; 160 }; 161 162 // Given an `op`, returns the first (`shape`, `dimension`) pair that identifies 163 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps 164 // guarantees at least one such dimension is found. If multiple candidates exist 165 // they must agree by construction (i.e. have the same size) and we just return 166 // the first one. 
167 static ShapeDimension 168 getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth, 169 bool fromSubViewOpOnly = false) { 170 auto maps = op.indexing_maps(); 171 // Iterate over the inputs and outputs in order. 172 // Extract the subranges from the linearized ranges. 173 for (auto en : llvm::enumerate(op.getShapedOperands())) { 174 // The method `getRangeFromOperandShape` requires using SubViewOp or 175 // SubTensorOps. If the value isnt defined from there continue. 176 // todo: The method should be adapted to get the values from 177 // `ViewInterface`. The interface needs a `getOrCreateRanges` method which 178 // currently returns a `linalg.range`. The fix here is to move this op to 179 // `std` dialect and add the method to `ViewInterface`. 180 if (fromSubViewOpOnly && 181 !isa_and_nonnull<SubViewOp, SubTensorOp>(en.value().getDefiningOp())) 182 continue; 183 184 unsigned idx = en.index(); 185 auto map = maps[idx].cast<AffineMapAttr>().getValue(); 186 LLVM_DEBUG(llvm::dbgs() 187 << "getShapeDefiningLoopRange I/O idx: " << idx << "\n"); 188 LLVM_DEBUG(llvm::dbgs() 189 << "getShapeDefiningLoopRange map: " << map << "\n"); 190 Value shape = en.value(); 191 SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr); 192 for (auto en2 : llvm::enumerate(map.getResults())) { 193 auto dimExpr = en2.value().dyn_cast<AffineDimExpr>(); 194 if (!dimExpr) 195 continue; 196 if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) { 197 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: " 198 << loopDepth << "\n"); 199 LLVM_DEBUG(llvm::dbgs() 200 << "getShapeDefiningLoopRange shape: " << shape << "\n"); 201 return ShapeDimension{shape, static_cast<unsigned>(en2.index())}; 202 } 203 } 204 } 205 llvm_unreachable("Expect to be able to extract a shape defining loop range"); 206 } 207 208 /// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges` 209 /// provides the loop range information for the fused loops. 
The rest are 210 /// obtained from the producer itself, since they are not tiled + fused. 211 static LinalgOp fuse(OpBuilder &b, LinalgOp producer, 212 const DenseMap<unsigned, Range> &fusedLoopsAndRanges) { 213 214 unsigned nPar = producer.getNumParallelLoops(); 215 unsigned nRed = producer.getNumReductionLoops(); 216 unsigned nWin = producer.getNumWindowLoops(); 217 SmallVector<Range, 8> loopRanges(nPar + nRed + nWin); 218 for (auto fusedLoops : fusedLoopsAndRanges) 219 loopRanges[fusedLoops.first] = fusedLoops.second; 220 221 // Iterate over all dimensions. For the dimensions not identified by the 222 // producer map for `producerIdx`, we need to explicitly compute the shape 223 // that defines the loop ranges using the `producer`. 224 for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) { 225 if (loopRanges[i].offset) 226 LLVM_DEBUG(llvm::dbgs() 227 << "existing LoopRange: " << loopRanges[i] << "\n"); 228 else { 229 auto shapeDim = getShapeDefiningLoopRange(producer, i); 230 loopRanges[i] = Range{std_constant_index(0), 231 std_dim(shapeDim.shape, shapeDim.dimension), 232 std_constant_index(1)}; 233 LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n"); 234 } 235 } 236 237 return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges); 238 } 239 240 /// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is 241 /// expected to be defined by a subview op or a subtensor op. 
static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
                                      Value shapedOperand, unsigned dim) {
  Operation *shapeProducingOp = shapedOperand.getDefiningOp();
  if (auto subViewOp = dyn_cast<SubViewOp>(shapeProducingOp))
    return subViewOp.getOrCreateRanges(b, loc)[dim];
  if (auto subTensorOp = dyn_cast<SubTensorOp>(shapeProducingOp))
    return subTensorOp.getOrCreateRanges(b, loc)[dim];
  llvm_unreachable("SubviewOp or SubTensorOp expected");
}

/// Fuses `producerOp` into the loop immediately enclosing the owner of
/// `consumerOpOperand`. This is achieved by "recomputing" the producer at the
/// time it is needed just before the consumer.
///
/// `producerMap` is the indexing map of the producer view that is consumed
/// through `consumerOpOperand`; each of its results pins one producer loop to
/// the range covered by the corresponding dimension of the consumed
/// subview/subtensor. Depending on the type of `consumerOpOperand.get()` this
/// handles both the buffer (subview) and the tensor (subtensor) case.
static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
                     OpOperand &consumerOpOperand) {
  LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n");
  DenseMap<unsigned, Range> fusedLoopsAndRanges;
  Value shapedOperand = consumerOpOperand.get();
  // Map each producer loop appearing in `producerMap` to the range of the
  // corresponding dimension of the consumed subview/subtensor.
  for (auto en : llvm::enumerate(producerMap.getResults())) {
    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
    fusedLoopsAndRanges[posInProducerLoop] = getRangeFromOperandShape(
        b, consumerOpOperand.getOwner()->getLoc(), shapedOperand, en.index());
  }
  return fuse(b, producerOp, fusedLoopsAndRanges);
}

// Encode structural fusion safety preconditions.
// Some of these will be lifted in the future with better analysis.
// Returns true when it is structurally safe (before any dependence analysis)
// to fuse `producer` into `consumer`. Both ops must have buffer semantics.
// NOTE(review): `consumedView` is currently unused here; it is kept for
// interface symmetry with the callers.
static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
                                          LinalgOp consumer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  // Multi-output producers are not supported yet.
  if (producer.getNumOutputs() != 1) {
    LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)");
    return false;
  }
  // Only fuse when the producer block dominates.
  DominanceInfo dom(producer.getOperation());
  if (!dom.dominates(producer->getBlock(), consumer->getBlock())) {
    LLVM_DEBUG(
        llvm::dbgs()
        << "\nNot structurally fusable (producer block does not dominate)");
    return false;
  }
  return true;
}

// Returns true when `producer` is provably the last op writing `consumedView`
// before `consumer` reads it, i.e. no other covering write is interleaved.
bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
                                             LinalgOp consumer,
                                             Value consumedView,
                                             LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  // Make some simple structural checks that alleviate the need for more
  // complex analyses.
  if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
    LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t"
                            << *producer.getOperation());
    return false;
  }
  // Check for any interleaved write to consumedView.
  if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
    LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t"
                            << *producer.getOperation());
    return false;
  }
  return true;
}

// Returns true when fusing `producer` (which writes `consumedView`) into
// `consumer` violates no dependence in `graph`. Both ops must have buffer
// semantics.
bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
                                 LinalgOp consumer, Value consumedView,
                                 LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
    return false;
  // Check for any fusion-preventing dependence to any shape read/written that
  // would violate dependences.
  if (!graph.findCoveringDependences(producer, consumer).empty()) {
    LLVM_DEBUG(llvm::dbgs()
               << "\n***Not fusable due to an interleaved dependence:\t"
               << *producer.getOperation());
    return false;
  }
  // Convolutions with padding cannot be fused yet (recomputation at tile
  // boundaries would need padding-aware logic).
  if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
    // TODO: add a level of indirection to linalg.generic.
    if (convOp.padding())
      return false;
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
    // TODO: add a level of indirection to linalg.generic.
    if (convOp.padding())
      return false;
  }
  return true;
}

/// Find the producer LinalgOp of `consumerOpOperand` that can legally be
/// fused. For buffers this is the last writer of the consumed view; for
/// tensors the RAW use-def edge suffices. The fusable dependence is returned
/// as an instance of the `dependenceGraph` element type, or None when no
/// fusable producer exists.
static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
findFusableProducer(OpOperand &consumerOpOperand,
                    const LinalgDependenceGraph &dependenceGraph) {
  LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
  if (!consumerOp)
    return {};

  // Only consider RAW and WAW atm.
  for (auto depType : {
           LinalgDependenceGraph::DependenceType::RAW,
           LinalgDependenceGraph::DependenceType::WAW,
       }) {
    // Keep only the dependences whose consumer-side endpoint is exactly
    // `consumerOpOperand` (same value AND same operand number).
    for (auto dependence : llvm::make_filter_range(
             dependenceGraph.getDependencesInto(consumerOp, depType),
             [&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
               Value v = elem.getIndexingValue();
               Optional<unsigned> operandNum =
                   elem.getIndexingOpViewOperandNum();
               return isa<LinalgOp>(elem.getDependentOp()) &&
                      v == consumerOpOperand.get() && operandNum &&
                      operandNum.getValue() ==
                          consumerOpOperand.getOperandNumber();
             })) {
      // Consumer consumes this view, `isStructurallyFusableProducer` also
      // checks whether it is a strict subview of the producer view.
      auto producer = cast<LinalgOp>(dependence.getDependentOp());
      LLVM_DEBUG(llvm::dbgs()
                 << "\n"
                 << LinalgDependenceGraph::getDependenceTypeStr(depType)
                 << "producer: " << *dependence.getDependentOp()
                 << " view: " << dependence.getDependentValue() << "\n");

      // If the producer and consumer have tensor semantics, the only
      // dependence between them is through a RAW dependence and they are
      // fusable by construction. For buffer semantics need additional checks.
      if (producer.hasBufferSemantics() && consumerOp.hasBufferSemantics() &&
          isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(),
                        producer))
        return dependence;
      if (producer.hasTensorSemantics() && consumerOp.hasTensorSemantics()) {
        assert(dependence.dependenceType ==
               LinalgDependenceGraph::DependenceType::RAW);
        return dependence;
      }
    }
  }
  return {};
}

/// Buffer-semantics entry point: fuse the producer of `consumerOpOperand`
/// (found via `graph`) just before its consumer. Returns the (original
/// producer, fused clone) pair, or None when fusion is not possible/needed.
Optional<FusionInfo>
mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
                                   const LinalgDependenceGraph &graph) {
  Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
      findFusableProducer(consumerOpOperand, graph);
  if (!fusableDependence)
    return llvm::None;

  LinalgOp producerOp = dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
  if (!producerOp)
    return llvm::None;

  // If producer is already in the same block as consumer, we are done.
  if (consumerOpOperand.get().getParentBlock() ==
      fusableDependence->getDependentValue().getParentBlock())
    return llvm::None;

  Optional<AffineMap> producerMap =
      fusableDependence->getDependentOpViewIndexingMap();
  if (!producerMap)
    return llvm::None;

  // Must be a subview or a slice to guarantee there are loops we can fuse
  // into.
  auto subView = consumerOpOperand.get().getDefiningOp<SubViewOp>();
  if (!subView) {
    LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview)");
    return llvm::None;
  }

  // Fuse `producer` just before `consumer`.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(consumerOpOperand.getOwner());
  // ScopedContext enables the edsc intrinsics (std_constant_index, std_dim)
  // used transitively by `fuse`.
  ScopedContext scope(b, consumerOpOperand.getOwner()->getLoc());
  LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: "
                          << *consumerOpOperand.getOwner() << "\n");

  auto fusedProducer = fuse(b, producerOp, *producerMap, consumerOpOperand);
  return FusionInfo{producerOp, fusedProducer};
}

/// Walk back the use-def chain of `tensor` (through subtensor ops and
/// scf::For iter args) and set `opResult` to the LinalgOp result that
/// produces it, if any. Leaves `opResult` untouched otherwise.

// TODO(ravishankarm, ntv): This can be moved into the dependence graphs
// dependence tracking since the dependence tracking is similar to what is done
// w.r.t to buffers.
static void getProducerOfTensor(Value tensor, OpResult &opResult) {
  if (!tensor.getType().isa<RankedTensorType>())
    return;

  while (true) {
    LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor);
    if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
      opResult = tensor.cast<OpResult>();
      return;
    }
    // Look through subtensor ops to their source tensor.
    if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) {
      tensor = subTensorOp.source();
      continue;
    }
    if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
      // NOTE(review): `getDefiningOp` on a BlockArgument always returns null
      // (block arguments have no defining op), so this scf.for walk-back
      // appears to be dead code; likely it should inspect
      // `blockArg.getOwner()->getParentOp()` instead — TODO confirm.
      if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
        tensor = *(forOp.getIterOperands().begin() + blockArg.getArgNumber());
        continue;
      }
    }
    return;
  }
}

/// Tensor-semantics entry point: locate the producer of `consumerOpOperand`
/// via use-def chains and delegate to the OpResult-based overload.
Optional<FusionInfo>
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
  Value inputTensor = consumerOpOperand.get();
  OpResult producerOpResult;
  getProducerOfTensor(inputTensor, producerOpResult);
  if (!producerOpResult) {
    LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer");
    return {};
  }
  return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
}

/// Tensor-semantics fusion: recompute the producer of `producerOpResult`
/// just before the consumer and redirect `consumerOpOperand` to the fused
/// result (inserting a tensor.cast when types temporarily mismatch).
Optional<FusionInfo>
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                                   OpOperand &consumerOpOperand) {
  auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner());
  if (!producerOp)
    return llvm::None;

  LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
  if (!consumerOp)
    return llvm::None;

  Value inputTensor = consumerOpOperand.get();

  // Must be a subtensor to guarantee there are loops we can fuse into.
  auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
  if (!subTensor) {
    LLVM_DEBUG(llvm::dbgs()
               << "\nNot fusable, not a subtensor: " << inputTensor);
    return {};
  }

  // If producer is already in the same block as consumer, we are done.
  if (consumerOpOperand.get().getParentBlock() ==
      producerOpResult.getParentBlock())
    return {};

  // Insert fused `producer` just before `consumer`.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(consumerOp);
  // ScopedContext enables the edsc intrinsics used transitively by `fuse`.
  ScopedContext scope(b, consumerOp->getLoc());
  LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
  LinalgOp fusedProducer =
      fuse(b, producerOp,
           producerOp.getOutputIndexingMap(producerOpResult.getResultNumber()),
           consumerOpOperand);

  // Replace use.
  // Canonicalizations are not guaranteed to have happened before constructing
  // `fusedProducer`. In the tensor case this can result in temporary type
  // mismatches. Insert a `tensor.cast` op to propagate the transformation
  // invariant that types are compatible.
  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
  Type consumerType = consumerOpOperand.get().getType();
  if (consumerType != def.getType())
    def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
  consumerOpOperand.set(def);
  return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer};
}

/// Prune all dimensions that are of reduction iterator type from `map`.
// Project out of `map` every dimension whose iterator type is not parallel
// (i.e. reduction/window dims), keeping only the parallel loop dims.
static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
                                           AffineMap map) {
  llvm::SmallDenseSet<unsigned> projectedDims;
  for (auto attr : llvm::enumerate(iteratorTypes)) {
    if (!isParallelIterator(attr.value()))
      projectedDims.insert(attr.index());
  }
  return getProjectedMap(map, projectedDims);
}

/// Returns the mapping from iterations in the consumer that write to the same
/// location as the iterations in the producer. To do so use
/// - indexing map of the fused view in the consumer : consumerIndexMap
/// - indexing map of the fused view in the producer : producerIndexMap
///   consumerLoopToProducerLoop =
///     inverse(producerIndexMap).compose(consumerIndexMap)
/// Returns None when either indexing map is unavailable, when the pruned
/// producer map is not an invertible permutation, or when the result counts
/// of the two maps disagree.
static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
    LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
  auto producer = dyn_cast<LinalgOp>(dependence.getDependentOp());
  if (!producer)
    return None;

  Optional<AffineMap> producerIndexingMap =
      dependence.getDependentOpViewIndexingMap();
  Optional<AffineMap> consumerIndexingMap =
      dependence.getIndexingOpViewIndexingMap();
  if (!producerIndexingMap || !consumerIndexingMap)
    return None;

  // Only the producer's parallel dims can participate; the pruned map must be
  // a full permutation so that it can be inverted below.
  AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
      producer.iterator_types().getValue(), *producerIndexingMap);
  if (!prunedProducerIndexingMap.isPermutation())
    return None;

  if (consumerIndexingMap->getNumResults() !=
      prunedProducerIndexingMap.getNumResults())
    return None;

  LLVM_DEBUG({
    llvm::dbgs() << "\t producerMap : ";
    producerIndexingMap->print(llvm::dbgs());
    llvm::dbgs() << "  pruned : ";
    prunedProducerIndexingMap.print(llvm::dbgs());
    llvm::dbgs() << "\n";
    llvm::dbgs() << "\t consumerMap : ";
    consumerIndexingMap->print(llvm::dbgs());
    llvm::dbgs() << "\n";
  });

  AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
  if (!invProducerIndexMap)
    return None;

  return invProducerIndexMap.compose(*consumerIndexingMap);
}

/// Given a projected permutation `map`, returns true if the map changes the
/// order in which the fused loop dimension appear.
static bool doesTransposeAccess(AffineMap map,
                                const std::set<unsigned> &fusableLoops) {
  // Track the last fusable loop seen; any fusable loop appearing at or before
  // it indicates a reordering (transpose) of the fused dims.
  Optional<unsigned> lastFusableLoop;
  for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
         return expr.cast<AffineDimExpr>().getPosition();
       })) {
    if (!fusableLoops.count(pos))
      continue;
    if (!lastFusableLoop) {
      lastFusableLoop = pos;
      continue;
    }
    if (pos <= lastFusableLoop.getValue())
      return true;
    lastFusableLoop = pos;
  }
  return false;
}

/// Returns the positions of the loop in `op` that can be tiled based on the
/// operations that are to be fused with it. For example, in a
///
///   linalg.matmul ins(%a, %b : ...) outs(%c : ...)
///
/// if the producer of %a needs to be fused with this op, only the `i` loop of
/// the matmul can be tiled while fusing. If producer of %a, and %b are to be
/// fused, then no loops can be tiled while fusing. The conditions used are:
/// 1. Only parallel loops can be used for tile + fuse. Find the number of
///    common outer parallel loops between the op and its producers being fused.
/// 2. Of the parallel loops only some can be fused. Only those loops can be
///    fused such where the fusable loops iteration space only touches one tile
///    of the fused operation. This is because the producer (which is writing
///    the fused subview) has update semantics.
///
/// Since an inverse computation is needed, we need to consider the projection
/// of the producerIndexMap w.r.t the parallel loops.
/// The actual fusable loops
/// are the dimensions of the consumerLoopToProducerLoop map that correspond to
/// parallel loops and appear in the result of the map.
///
/// Example 1:
///   linalg.fill(%c, %cst)
///   linalg.matmul ins(%a, %b) outs(%c)
///   Number of parallel loops : 2
///   producerIndexMap = affine_map<(i, j) -> (i, j)>
///   consumerIndexMap = affine_map<(i, j, k) -> (i, j)>
///   consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)>
///   Fused dimensions : i, j
///
/// Example 2:
///   linalg.matmul ins(%a, %b) outs(%c)
///   linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ...
///                   iterator_types = ["parallel", "parallel"]}
///     ins(%c) ...
///
///   Number of parallel loops = 2:
///   producerIndexMap (projected to parallel loops) =
///     affine_map<(i, j) -> (i, j)>
///   consumerLoopToProducerLoop2 = affine_map<(i, j) -> (j, i)>
///   Fused dimensions : i, j
///
/// Example 3:
///   linalg.copy(%s, %b)
///   linalg.matmul ins(%a, %b) outs(%c)
///
///   Number of parallel loops = 2
///   produceIndexMap : affine_map<(i, j) -> (i, j)>
///   consumerLoopToProduceLoops = affine_map<(i, j, k) -> (k, j)>
///   submap with only parallel loops = affine_map<(i, j) -> (j)>
///   Fused dimensions : j
static std::set<unsigned>
collectFusableLoops(ArrayRef<LinalgOp> ops,
                    const FusableOpDependencesTy &fusableDependences) {
  assert(!ops.empty());
  // Number of leading parallel iterators of `linalgOp`.
  auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
    return linalgOp.iterator_types()
        .getValue()
        .take_while([](Attribute attr) -> bool {
          return attr.cast<StringAttr>().getValue() ==
                 getParallelIteratorTypeName();
        })
        .size();
  };

  // Start from the outer parallel loops common to every op (the root op is
  // `ops.back()`, processed first).
  size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
  for (auto op : ops.drop_back()) {
    numOuterParallelLoops =
        std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
  }

  std::set<unsigned> fusableLoops;
  auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
  fusableLoops.insert(range.begin(), range.end());

  // Iteratively intersect the candidate set with the loops each dependence
  // actually permits; an empty intersection means nothing can be fused.
  for (auto op : reverse(ops)) {
    for (auto dependence : fusableDependences.lookup(op)) {
      LLVM_DEBUG({
        llvm::dbgs() << "\t fusable :";
        for (unsigned i : fusableLoops)
          llvm::dbgs() << " " << i;
        llvm::dbgs() << "\n";
      });

      Optional<AffineMap> consumerLoopToProducerLoop =
          getConsumerLoopToProducerLoopMap(dependence);
      if (!consumerLoopToProducerLoop) {
        op.emitRemark("failed to get map from consumer loop to producer loop");
        return {};
      }
      // TODO: This condition is only an implementation limitation. When fusing
      // the operation, if the accesses in the producer/consumer are transposes
      // of each other, the loop bounds for the tiled producer can be
      // manipulated accordingly. This requires some additional bookkeeping in
      // the implementation of tile+fuse that is deferred to later.
      if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
        op.emitRemark("unhandled fusion when fusion requires permutation");
        return {};
      }

      // Keep only the loops that both were fusable so far AND appear in the
      // consumer-to-producer loop map.
      std::set<unsigned> candidates;
      for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
        unsigned position = expr.cast<AffineDimExpr>().getPosition();
        if (fusableLoops.count(position))
          candidates.insert(position);
      }
      LLVM_DEBUG({
        llvm::dbgs() << "\t candidates :";
        for (unsigned i : candidates)
          llvm::dbgs() << " " << i;
        llvm::dbgs() << "\n";
      });
      if (candidates.empty())
        return {};
      std::swap(candidates, fusableLoops);
    }
  }

  return fusableLoops;
}

/// Find all dependences that are fusable.
735 FusableOpDependencesTy mlir::linalg::findAllFusableDependences( 736 ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) { 737 FusableOpDependencesTy fusableDependences; 738 DenseMap<Operation *, SmallVector<AffineMap, 1>> fusedProducerIndexingMap; 739 for (LinalgOp op : reverse(ops)) { 740 for (OpOperand &opOperand : op.getShapedOpOperands()) { 741 Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> 742 fusableDependence = findFusableProducer(opOperand, dependenceGraph); 743 if (!fusableDependence) 744 continue; 745 LinalgOp producerOp = 746 dyn_cast<LinalgOp>(fusableDependence->getDependentOp()); 747 if (!producerOp) 748 continue; 749 // Do not fuse dependences that are to operations not in the same basic 750 // block. This avoid moving fused operations across loops that might 751 // themselves carry dependency making the fusion illegal. 752 if (producerOp->getBlock() != op->getBlock()) 753 continue; 754 755 // Make sure that the indexing map of the view used for fusion in the 756 // producer is a projected permutation. 757 Optional<AffineMap> producerMap = 758 fusableDependence->getDependentOpViewIndexingMap(); 759 Optional<AffineMap> consumerMap = 760 fusableDependence->getIndexingOpViewIndexingMap(); 761 assert( 762 consumerMap && 763 "unable to find indexing map of operand/result of indexing OpView"); 764 fusedProducerIndexingMap[producerOp.getOperation()].push_back( 765 *consumerMap); 766 if (!producerMap || !producerMap->isProjectedPermutation() || 767 !consumerMap->isProjectedPermutation()) 768 continue; 769 770 fusableDependences[producerOp.getOperation()].push_back( 771 *fusableDependence); 772 } 773 } 774 // TODO: Currently fusion would not be legal if the fusable dependence is to 775 // the same producer but different indexing map in the consumer. Fix this, but 776 // in the meanwhile disallow such a fusion. 
777 for (auto useIndexingMapsList : fusedProducerIndexingMap) { 778 AffineMap map1 = useIndexingMapsList.second.front(); 779 for (AffineMap map2 : 780 ArrayRef<AffineMap>(useIndexingMapsList.second).drop_front()) { 781 if (map1 != map2) { 782 fusableDependences.erase(useIndexingMapsList.first); 783 break; 784 } 785 } 786 } 787 return fusableDependences; 788 } 789 790 /// Tile the fused loops in the root operation, by setting the tile sizes for 791 /// all other loops to zero (those will be tiled later). 792 static Optional<TiledLinalgOp> tileRootOperation( 793 OpBuilder &builder, LinalgOp op, ArrayRef<Value> tileSizeVector, 794 const LinalgTilingOptions &options, const std::set<unsigned> &fusedLoops) { 795 SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end()); 796 auto zero = std_constant_index(0); 797 for (unsigned i = 0, e = tileSizes.size(); i != e; ++i) 798 if (!fusedLoops.count(i)) 799 tileSizes[i] = zero; 800 LinalgTilingOptions tileFusedLoopsOptions = options; 801 tileFusedLoopsOptions.setTileSizes(tileSizes); 802 return tileLinalgOp(builder, op, tileFusedLoopsOptions); 803 } 804 805 /// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected 806 /// to be a tiled operation such that it is valid to fuse all operations in 807 /// `fusionCandidates`, i.e. move the operation within the inter-tile loops of 808 /// `tiledOp`. 
static SmallVector<LinalgOp, 1>
fuseOperations(OpBuilder &builder, LinalgOp rootOp, LinalgOp tiledOp,
               ArrayRef<LinalgOp> fusionCandidates,
               const FusableOpDependencesTy &fusableDependences,
               const std::set<unsigned> &fusedLoops) {
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPoint(tiledOp);
  // Compute the range of each fused loop from the shape of a tiled operand
  // that defines that loop; the candidates are cloned onto these sub-ranges.
  DenseMap<unsigned, Range> fusedLoopsAndRanges;
  for (unsigned loop : fusedLoops) {
    ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true);
    fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
        builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
  }

  SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
  // Maps each original op to its fused clone so consumers fused earlier in the
  // reverse walk can be rewired to the results of later-fused producers.
  DenseMap<Operation *, LinalgOp> origOpToFusedOp;
  origOpToFusedOp[rootOp.getOperation()] = tiledOp;
  // Walk the candidates in reverse (consumers before producers); results are
  // stored back in the original forward order.
  for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) {
    LinalgOp origOp = candidate.value();
    LinalgOp fusedOp = fuse(builder, origOp, fusedLoopsAndRanges);
    origOpToFusedOp[origOp.getOperation()] = fusedOp;
    fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
    // If the producer consumer operations are linalg operations on tensors, the
    // dependence is due to value produced (as a return tensor) by the producer
    // and used in the consumer. The returned value of the fused op needs to be
    // made the operand of the tiled/fused consumer operation. By construction
    // the value returned by the producer is the value used by the consumer.
    for (auto &dependence : fusableDependences.lookup(origOp.getOperation())) {
      if (origOp.hasTensorSemantics() &&
          dependence.dependenceType ==
              LinalgDependenceGraph::DependenceType::RAW) {
        unsigned resultIndex =
            dependence.getDependentOpViewResultNum().getValue();
        // Consumers not (yet) fused have no entry; skip them.
        LinalgOp consumer = origOpToFusedOp.lookup(dependence.getIndexingOp());
        if (!consumer)
          continue;
        Value replacementValue = fusedOp.getOperation()->getResult(resultIndex);
        consumer.getOperation()->setOperand(
            dependence.getIndexingOpViewOperandNum().getValue(),
            replacementValue);
      }
    }
    // The next candidate is a (transitive) producer of this one, so it must be
    // created before the op just fused.
    builder.setInsertionPoint(fusedOp);
  }
  return fusedOps;
}

// NOTE(review): the `LoopType` template parameter is not referenced in this
// body; the loop kind appears to be driven by `tilingOptions` inside
// `tileLinalgOp` — confirm before removing it.
template <typename LoopType>
static Optional<TiledAndFusedLinalgOps>
tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
                         const LinalgDependenceGraph &dependenceGraph,
                         const LinalgTilingOptions &tilingOptions) {
  // Nothing to fuse with fewer than two ops.
  if (ops.size() < 2)
    return llvm::None;
  // The last op in the list is the root: the op that gets tiled and into whose
  // inter-tile loops the others are fused.
  LinalgOp rootOp = ops.back();
  // All ops must share the same semantics: all-buffer or all-tensor.
  if (!llvm::all_of(
          ops,
          [](LinalgOp linalgOp) { return linalgOp.hasBufferSemantics(); }) &&
      !llvm::all_of(ops, [](LinalgOp linalgOp) {
        return linalgOp.hasTensorSemantics();
      })) {
    rootOp.emitError(
        "unable to fuse operations that have tensor semantics with operations "
        "that have buffer semantics and viceversa.");
    return llvm::None;
  }
  // TODO: Support interchange with tile + fuse. This might actually help do
  // better fusion.
  if (!tilingOptions.interchangeVector.empty()) {
    rootOp.emitRemark("unable to handle tile and fuse with interchange");
    return llvm::None;
  }

  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPoint(rootOp);
  // EDSC scoped context — presumably required by the std_constant_index calls
  // made during tiling; confirm against the EDSC intrinsics contract.
  ScopedContext scope(builder, rootOp.getLoc());

  // Find all the producers.
  FusableOpDependencesTy fusableDependences =
      findAllFusableDependences(ops, dependenceGraph);
  if (fusableDependences.empty())
    return llvm::None;

  TiledAndFusedLinalgOps ret;
  // Find the loops that can be tiled and fused.
  ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);

  // If there are no fusable dependences or there are no tile+fusable loops,
  // just return.
  if (ret.fusedLoopDims.empty()) {
    return llvm::None;
  }

  // Tile the fused loops in the last operation in the list.
  SmallVector<Value, 4> tileSizeVector =
      tilingOptions.tileSizeComputationFunction(builder, rootOp);
  Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
      builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
  if (!tiledRootOp) {
    rootOp.emitRemark("failed to tile the fused loops");
    return llvm::None;
  }
  ret.op = tiledRootOp->op;
  ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());

  // Fuse the other operations into the fused inter-tile loops produced above.
  ret.fusedProducers = fuseOperations(builder, rootOp, ret.op, ops.drop_back(),
                                      fusableDependences, ret.fusedLoopDims);

  return ret;
}

/// Public entry point: dispatches tile-and-fuse on the requested loop type.
/// Returns llvm::None for unsupported loop types (e.g. TiledLoops).
Optional<TiledAndFusedLinalgOps>
mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
                                   const LinalgDependenceGraph &dependenceGraph,
                                   const LinalgTilingOptions &tilingOptions) {
  switch (tilingOptions.loopType) {
  case LinalgTilingLoopType::Loops:
    return tileAndFuseLinalgOpsImpl<scf::ForOp>(builder, ops, dependenceGraph,
                                                tilingOptions);
  case LinalgTilingLoopType::ParallelLoops:
    return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
        builder, ops, dependenceGraph, tilingOptions);
  default:;
  }
  return llvm::None;
}