//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#include <set>

#define DEBUG_TYPE "linalg-fusion"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using llvm::dbgs;

/// Implements a simple high-level fusion pass on linalg structured operations.
///
/// In each block, linalg ops are processed in reverse textual order.
/// Given a linalg op `O`, fusion occurs by:
///   1. inspecting the linalg ops that write into the views read by `O`.
There 51 /// are 2 cases: 52 /// a) buffer case: use the SSA value of the views and a simple alias 53 /// analysis on subview ops to determine producer-consumer dependences; 54 /// b) tensor case: use SSA use-def chains on subtensor ops; 55 /// 2. greedily fuse the linalg ops that produce the subview/subtensor. 56 /// 3. inspect the fused ops and determine whether they have other remaining 57 /// LinalgOp uses. If not, then erase the original producing linalg op. 58 /// 59 /// More advanced use cases, analyses as well as profitability heuristics are 60 /// left for future work. 61 62 struct ShapeDimension { 63 Value shape; 64 unsigned dimension; 65 }; 66 67 // Given an `op`, returns the first (`shape`, `dimension`) pair that identifies 68 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps 69 // guarantees at least one such dimension is found. If multiple candidates exist 70 // they must agree by construction (i.e. have the same size) and we just return 71 // the first one. 72 static ShapeDimension 73 getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth, 74 bool fromSubViewOpOnly = false) { 75 auto maps = op.indexing_maps(); 76 // Iterate over the inputs and outputs in order. 77 // Extract the subranges from the linearized ranges. 78 for (auto en : llvm::enumerate(op.getShapedOperands())) { 79 // The method `getRangeFromOperandShape` requires using SubViewOp or 80 // SubTensorOps. If the value isnt defined from there continue. 81 // todo: The method should be adapted to get the values from 82 // `ViewInterface`. The interface needs a `getOrCreateRanges` method which 83 // currently returns a `linalg.range`. The fix here is to move this op to 84 // `std` dialect and add the method to `ViewInterface`. 
85 if (fromSubViewOpOnly && !isa_and_nonnull<memref::SubViewOp, SubTensorOp>( 86 en.value().getDefiningOp())) 87 continue; 88 89 unsigned idx = en.index(); 90 auto map = maps[idx].cast<AffineMapAttr>().getValue(); 91 LLVM_DEBUG(llvm::dbgs() 92 << "getShapeDefiningLoopRange I/O idx: " << idx << "\n"); 93 LLVM_DEBUG(llvm::dbgs() 94 << "getShapeDefiningLoopRange map: " << map << "\n"); 95 Value shape = en.value(); 96 SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr); 97 for (auto en2 : llvm::enumerate(map.getResults())) { 98 auto dimExpr = en2.value().dyn_cast<AffineDimExpr>(); 99 if (!dimExpr) 100 continue; 101 if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) { 102 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: " 103 << loopDepth << "\n"); 104 LLVM_DEBUG(llvm::dbgs() 105 << "getShapeDefiningLoopRange shape: " << shape << "\n"); 106 return ShapeDimension{shape, static_cast<unsigned>(en2.index())}; 107 } 108 } 109 } 110 llvm_unreachable("Expect to be able to extract a shape defining loop range"); 111 } 112 113 /// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges` 114 /// provides the loop range information for the fused loops. The rest are 115 /// obtained from the producer itself, since they are not tiled + fused. 
116 static LinalgOp fuse(OpBuilder &builder, LinalgOp producer, 117 const DenseMap<unsigned, Range> &fusedLoopsAndRanges) { 118 SmallVector<Value, 8> ivs, tileSizes, sizeBounds; 119 SmallVector<Range, 8> loopRanges; 120 auto zero = std_constant_index(0); 121 auto one = std_constant_index(1); 122 Location loc = producer.getLoc(); 123 124 for (unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) { 125 auto it = fusedLoopsAndRanges.find(i); 126 if (it != fusedLoopsAndRanges.end()) { 127 ivs.push_back(it->second.offset); 128 tileSizes.push_back(it->second.size); 129 sizeBounds.push_back(nullptr); 130 loopRanges.push_back(it->second); 131 LLVM_DEBUG(llvm::dbgs() << "tiled loop#" << i << " with LoopRange " 132 << loopRanges.back() << "\n"); 133 } else { 134 auto shapeDim = getShapeDefiningLoopRange(producer, i); 135 Value dim = memref_dim(shapeDim.shape, shapeDim.dimension); 136 tileSizes.push_back(zero); 137 sizeBounds.push_back(dim); 138 loopRanges.push_back(Range{zero, dim, one}); 139 LLVM_DEBUG(llvm::dbgs() << "full loop#" << i << " with LoopRange " 140 << loopRanges.back() << "\n"); 141 } 142 } 143 144 SmallVector<Value, 8> clonedShapes; 145 clonedShapes.reserve(producer.getNumShapedOperands()); 146 147 // Compute subranges for all tensor input/output operands. 148 auto tiledOperands = llvm::to_vector<4>(producer.getShapedOperands()); 149 clonedShapes.append(makeTiledShapes(builder, loc, producer, tiledOperands, 150 ivs, tileSizes, sizeBounds)); 151 152 // Append the other operands. 153 auto operands = producer.getAssumedNonShapedOperands(); 154 clonedShapes.append(operands.begin(), operands.end()); 155 156 // Iterate over the results in order. 157 // Extract the subtensor type from the linearized range. 158 // Since we do not enforce any canonicalizations on the fly, this is always 159 // fully dynamic at construction time. 
160 SmallVector<Type, 4> resultTypes; 161 resultTypes.reserve(producer->getNumResults()); 162 for (RankedTensorType t : producer.getOutputTensorTypes()) { 163 unsigned rank = t.getRank(); 164 SmallVector<int64_t, 4> staticOffsetsVector( 165 rank, ShapedType::kDynamicStrideOrOffset); 166 SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize); 167 SmallVector<int64_t, 4> staticStridesVector( 168 rank, ShapedType::kDynamicStrideOrOffset); 169 resultTypes.push_back(SubTensorOp::inferResultType( 170 t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector, 171 staticStridesVector)); 172 } 173 174 Operation *clonedOp = producer.clone(builder, loc, resultTypes, clonedShapes); 175 // When the producer is an IndexedGenericOp, we have to transform its block 176 // IV arguments according to the tiling of the consumer, i.e. offset them by 177 // the values computed in `loopRanges`. 178 if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) { 179 auto &block = indexedGenericOp.region().front(); 180 OpBuilder::InsertionGuard g(builder); 181 builder.setInsertionPointToStart(&block); 182 for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) { 183 Value oldIndex = block.getArgument(i); 184 // TODO: replace by an affine_apply. 185 AddIOp newIndex = builder.create<AddIOp>(indexedGenericOp.getLoc(), 186 oldIndex, loopRanges[i].offset); 187 oldIndex.replaceAllUsesExcept(newIndex, 188 SmallPtrSet<Operation *, 1>{newIndex}); 189 } 190 } 191 192 return clonedOp; 193 } 194 195 /// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is 196 /// expected to be defined by a subview op or a subtensor op. 
197 static Range getRangeFromOperandShape(OpBuilder &b, Location loc, 198 Value shapedOperand, unsigned dim) { 199 Operation *shapeProducingOp = shapedOperand.getDefiningOp(); 200 if (auto subViewOp = dyn_cast<memref::SubViewOp>(shapeProducingOp)) 201 return subViewOp.getOrCreateRanges(b, loc)[dim]; 202 if (auto subTensorOp = dyn_cast<SubTensorOp>(shapeProducingOp)) 203 return subTensorOp.getOrCreateRanges(b, loc)[dim]; 204 llvm_unreachable("SubviewOp or SubTensorOp expected"); 205 } 206 207 /// Fuses the producer of `producerIdx` into the loop immediately enclosing 208 /// `consumer`. This is achieved by "recomputing" the `producer` at the time it 209 /// is needed just before the `consumer. 210 /// 211 /// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are 212 /// 2 cases: 213 /// 1. Buffer case: `producerIdx` is the index of the buffer in 214 /// `producer.getOutputBuffers()`. 215 /// 2. Tensor case: `producerIdx` is the index of the tensor in 216 /// `producer.getResults()`. 217 static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap, 218 OpOperand &consumerOpOperand) { 219 LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n"); 220 DenseMap<unsigned, Range> fusedLoopsAndRanges; 221 Value shapedOperand = consumerOpOperand.get(); 222 for (auto en : llvm::enumerate(producerMap.getResults())) { 223 unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition(); 224 fusedLoopsAndRanges[posInProducerLoop] = getRangeFromOperandShape( 225 b, consumerOpOperand.getOwner()->getLoc(), shapedOperand, en.index()); 226 } 227 return fuse(b, producerOp, fusedLoopsAndRanges); 228 } 229 230 // Encode structural fusion safety preconditions. 231 // Some of these will be lifted in the future with better analysis. 
232 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView, 233 LinalgOp consumer) { 234 assert(producer.hasBufferSemantics() && 235 "expected linalg op with buffer semantics"); 236 assert(consumer.hasBufferSemantics() && 237 "expected linalg op with buffer semantics"); 238 if (producer.getNumOutputs() != 1) { 239 LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)"); 240 return false; 241 } 242 // Only fuse when the producer block dominates. 243 DominanceInfo dom(producer.getOperation()); 244 if (!dom.dominates(producer->getBlock(), consumer->getBlock())) { 245 LLVM_DEBUG( 246 llvm::dbgs() 247 << "\nNot structurally fusable (producer block does not dominate)"); 248 return false; 249 } 250 return true; 251 } 252 253 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, 254 LinalgOp consumer, 255 Value consumedView, 256 LinalgOp producer) { 257 assert(producer.hasBufferSemantics() && 258 "expected linalg op with buffer semantics"); 259 assert(consumer.hasBufferSemantics() && 260 "expected linalg op with buffer semantics"); 261 // Make some simple structural checks that alleviate the need for more 262 // complex analyses. 263 if (!isStructurallyFusableProducer(producer, consumedView, consumer)) { 264 LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t" 265 << *producer.getOperation()); 266 return false; 267 } 268 // Check for any interleaved write to consumedView. 
269 if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) { 270 LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t" 271 << *producer.getOperation()); 272 return false; 273 } 274 return true; 275 } 276 277 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, 278 LinalgOp consumer, Value consumedView, 279 LinalgOp producer) { 280 assert(producer.hasBufferSemantics() && 281 "expected linalg op with buffer semantics"); 282 assert(consumer.hasBufferSemantics() && 283 "expected linalg op with buffer semantics"); 284 if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) 285 return false; 286 // Check for any fusion-preventing dependence to any shape read/written that 287 // would violate dependences. 288 if (!graph.findCoveringDependences(producer, consumer).empty()) { 289 LLVM_DEBUG(llvm::dbgs() 290 << "\n***Not fusable due to an interleaved dependence:\t" 291 << *producer.getOperation()); 292 return false; 293 } 294 if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) { 295 // TODO: add a level of indirection to linalg.generic. 296 if (convOp.padding()) 297 return false; 298 } 299 if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) { 300 // TODO: add a level of indirection to linalg.generic. 301 if (convOp.padding()) 302 return false; 303 } 304 return true; 305 } 306 307 /// For `consumer` with buffer semantics, find the Linalg operation on buffers 308 /// that is the last writer of `consumerOpOperand`. For now the fusable 309 /// dependence is returned as an instance of the `dependenceGraph`. 310 static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> 311 findFusableProducer(OpOperand &consumerOpOperand, 312 const LinalgDependenceGraph &dependenceGraph) { 313 LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner()); 314 if (!consumerOp) 315 return {}; 316 317 // Only consider RAW and WAW atm. 
318 for (auto depType : { 319 LinalgDependenceGraph::DependenceType::RAW, 320 LinalgDependenceGraph::DependenceType::WAW, 321 }) { 322 for (auto dependence : llvm::make_filter_range( 323 dependenceGraph.getDependencesInto(consumerOp, depType), 324 [&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) { 325 Value v = elem.getIndexingValue(); 326 Optional<unsigned> operandNum = 327 elem.getIndexingOpViewOperandNum(); 328 return isa<LinalgOp>(elem.getDependentOp()) && 329 v == consumerOpOperand.get() && operandNum && 330 operandNum.getValue() == 331 consumerOpOperand.getOperandNumber(); 332 })) { 333 // Consumer consumes this view, `isStructurallyFusableProducer` also 334 // checks whether it is a strict subview of the producer view. 335 auto producer = cast<LinalgOp>(dependence.getDependentOp()); 336 LLVM_DEBUG(llvm::dbgs() 337 << "\n" 338 << LinalgDependenceGraph::getDependenceTypeStr(depType) 339 << "producer: " << *dependence.getDependentOp() 340 << " view: " << dependence.getDependentValue() << "\n"); 341 342 // If the producer and consumer have tensor semantics, the only dependence 343 // between them is through a RAW dependence and they are fusable by 344 // construction. For buffer semantics need additional checks. 
345 if (producer.hasBufferSemantics() && consumerOp.hasBufferSemantics() && 346 isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(), 347 producer)) 348 return dependence; 349 if (producer.hasTensorSemantics() && consumerOp.hasTensorSemantics()) { 350 assert(dependence.dependenceType == 351 LinalgDependenceGraph::DependenceType::RAW); 352 return dependence; 353 } 354 } 355 } 356 return {}; 357 } 358 359 Optional<FusionInfo> 360 mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand, 361 const LinalgDependenceGraph &graph) { 362 Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence = 363 findFusableProducer(consumerOpOperand, graph); 364 if (!fusableDependence) 365 return llvm::None; 366 367 LinalgOp producerOp = dyn_cast<LinalgOp>(fusableDependence->getDependentOp()); 368 if (!producerOp) 369 return llvm::None; 370 371 // If producer is already in the same block as consumer, we are done. 372 if (consumerOpOperand.get().getParentBlock() == 373 fusableDependence->getDependentValue().getParentBlock()) 374 return llvm::None; 375 376 Optional<AffineMap> producerMap = 377 fusableDependence->getDependentOpViewIndexingMap(); 378 if (!producerMap) 379 return llvm::None; 380 381 // Must be a subview or a slice to guarantee there are loops we can fuse 382 // into. 383 auto subView = consumerOpOperand.get().getDefiningOp<memref::SubViewOp>(); 384 if (!subView) { 385 LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview)"); 386 return llvm::None; 387 } 388 389 // Fuse `producer` just before `consumer`. 
390 OpBuilder::InsertionGuard g(b); 391 b.setInsertionPoint(consumerOpOperand.getOwner()); 392 ScopedContext scope(b, consumerOpOperand.getOwner()->getLoc()); 393 LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " 394 << *consumerOpOperand.getOwner() << "\n"); 395 396 auto fusedProducer = fuse(b, producerOp, *producerMap, consumerOpOperand); 397 return FusionInfo{producerOp, fusedProducer}; 398 } 399 400 /// Walk back use-def chain through scf::For yields. 401 /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp 402 403 // TODO(ravishankarm, ntv): This can be moved into the dependence graphs 404 // dependence tracking since the dependence tracking is similar to what is done 405 // w.r.t to buffers. 406 static void getProducerOfTensor(Value tensor, OpResult &opResult) { 407 if (!tensor.getType().isa<RankedTensorType>()) 408 return; 409 410 while (true) { 411 LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor); 412 if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) { 413 opResult = tensor.cast<OpResult>(); 414 return; 415 } 416 if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) { 417 tensor = subTensorOp.source(); 418 continue; 419 } 420 if (auto blockArg = tensor.dyn_cast<BlockArgument>()) { 421 if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) { 422 tensor = *(forOp.getIterOperands().begin() + blockArg.getArgNumber()); 423 continue; 424 } 425 } 426 return; 427 } 428 } 429 430 Optional<FusionInfo> 431 mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) { 432 Value inputTensor = consumerOpOperand.get(); 433 OpResult producerOpResult; 434 getProducerOfTensor(inputTensor, producerOpResult); 435 if (!producerOpResult) { 436 LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer"); 437 return {}; 438 } 439 return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand); 440 } 441 442 Optional<FusionInfo> 443 mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult, 444 OpOperand 
&consumerOpOperand) { 445 auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner()); 446 if (!producerOp) 447 return llvm::None; 448 449 LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner()); 450 if (!consumerOp) 451 return llvm::None; 452 453 Value inputTensor = consumerOpOperand.get(); 454 455 // Must be a subtensor to guarantee there are loops we can fuse into. 456 auto subTensor = inputTensor.getDefiningOp<SubTensorOp>(); 457 if (!subTensor) { 458 LLVM_DEBUG(llvm::dbgs() 459 << "\nNot fusable, not a subtensor: " << inputTensor); 460 return {}; 461 } 462 463 // If producer is already in the same block as consumer, we are done. 464 if (consumerOpOperand.get().getParentBlock() == 465 producerOpResult.getParentBlock()) 466 return {}; 467 468 // Insert fused `producer` just before `consumer`. 469 OpBuilder::InsertionGuard g(b); 470 b.setInsertionPoint(consumerOp); 471 ScopedContext scope(b, consumerOp->getLoc()); 472 LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n"); 473 LinalgOp fusedProducer = 474 fuse(b, producerOp, 475 producerOp.getOutputIndexingMap(producerOpResult.getResultNumber()), 476 consumerOpOperand); 477 478 // Replace use. 479 // Canonicalizations are not guaranteed to have happened before constructing 480 // `fusedProducer`. In the tensor case this can result in temporary type 481 // mismatches. Insert a `tensor.cast` op to propagate the transformation 482 // invariant that types are compatible. 483 Value def = fusedProducer->getResult(producerOpResult.getResultNumber()); 484 Type consumerType = consumerOpOperand.get().getType(); 485 if (consumerType != def.getType()) 486 def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def); 487 consumerOpOperand.set(def); 488 return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer}; 489 } 490 491 /// Prune all dimensions that are of reduction iterator type from `map`. 
492 static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes, 493 AffineMap map) { 494 llvm::SmallDenseSet<unsigned> projectedDims; 495 for (auto attr : llvm::enumerate(iteratorTypes)) { 496 if (!isParallelIterator(attr.value())) 497 projectedDims.insert(attr.index()); 498 } 499 return getProjectedMap(map, projectedDims); 500 } 501 502 /// Returns the mapping from iterations in the consumer that write to the same 503 /// location as the iterations in the producer. To do so use 504 /// - indexing map of the fused view in the consumer : consumerIndexMap 505 /// - indexing map of the fused view in the producer : producerIndexMap 506 /// consumerLoopToProducerLoop = 507 /// inverse(producerIndexMap).compose(consumerIndexMap) 508 static Optional<AffineMap> getConsumerLoopToProducerLoopMap( 509 LinalgDependenceGraph::LinalgDependenceGraphElem dependence) { 510 auto producer = dyn_cast<LinalgOp>(dependence.getDependentOp()); 511 if (!producer) 512 return None; 513 514 Optional<AffineMap> producerIndexingMap = 515 dependence.getDependentOpViewIndexingMap(); 516 Optional<AffineMap> consumerIndexingMap = 517 dependence.getIndexingOpViewIndexingMap(); 518 if (!producerIndexingMap || !consumerIndexingMap) 519 return None; 520 521 AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap( 522 producer.iterator_types().getValue(), *producerIndexingMap); 523 if (!prunedProducerIndexingMap.isPermutation()) 524 return None; 525 526 if (consumerIndexingMap->getNumResults() != 527 prunedProducerIndexingMap.getNumResults()) 528 return None; 529 530 LLVM_DEBUG({ 531 llvm::dbgs() << "\t producerMap : "; 532 producerIndexingMap->print(llvm::dbgs()); 533 llvm::dbgs() << " pruned : "; 534 prunedProducerIndexingMap.print(llvm::dbgs()); 535 llvm::dbgs() << "\n"; 536 llvm::dbgs() << "\t consumerMap : "; 537 consumerIndexingMap->print(llvm::dbgs()); 538 llvm::dbgs() << "\n"; 539 }); 540 541 AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap); 
542 if (!invProducerIndexMap) 543 return None; 544 545 return invProducerIndexMap.compose(*consumerIndexingMap); 546 } 547 548 /// Given a projected permutation `map`, returns true if the map changes the 549 /// order in which the fused loop dimension appear. 550 static bool doesTransposeAccess(AffineMap map, 551 const std::set<unsigned> &fusableLoops) { 552 Optional<unsigned> lastFusableLoop; 553 for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) { 554 return expr.cast<AffineDimExpr>().getPosition(); 555 })) { 556 if (!fusableLoops.count(pos)) 557 continue; 558 if (!lastFusableLoop) { 559 lastFusableLoop = pos; 560 continue; 561 } 562 if (pos <= lastFusableLoop.getValue()) 563 return true; 564 lastFusableLoop = pos; 565 } 566 return false; 567 } 568 569 /// Returns the positions of the loop in `op` that can be tiled based on the 570 /// operations that are to be fused with it. For example, in a 571 /// 572 /// linalg.matmul ins(%a, %b : ...) outs(%c : ...) 573 /// 574 /// if the producer of %a needs to be fused with this op, only the `i` loop of 575 /// the matmul can be tiled while fusing. If producer of %a, and %b are to be 576 /// fused, then no loops can be tiled while fusing. The conditions used are: 577 /// 1. Only parallel loops can be used for tile + fuse. Find the number of 578 /// common outer parallel loops between the op and its producers being fused. 579 /// 2. Of the parallel loops only some can be fused. Only those loops can be 580 /// fused such where the fusable loops iteration space only touches one tile 581 /// of the fused operation. This is because the producer (which is writing 582 /// the fused subview) has update semantics. 583 /// 584 /// Since an inverse computation is needed, we need to consider the projection 585 /// of the producerIndexMap w.r.t the parallel loops. 
The actual fusable loops 586 /// are the dimensions of the consumerLoopToProducerLoop map that correspond to 587 /// parallel loops and appear in the result of the map 588 /// 589 /// Example 1: 590 /// linalg.fill(%c, %cst) 591 /// linalg.matmul ins(%a, %b) outs(%c) 592 /// Number of parallel loops : 2 593 /// producerIndexMap = affine_map<(i, j) ->(i , j)> 594 /// consumerIndexMap = affine_map<(i, j, k) -> (i, j)> 595 /// consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)> 596 /// Fused dimensions : i, j 597 /// 598 /// Example 2: 599 /// linalg.matmul ins(%a, %b) outs(%c) 600 /// linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ... 601 /// iterator_types = ["parallel", "parallel"]} 602 /// ins(%c) ... 603 /// 604 /// Number of parallel loops = 2: 605 /// producerIndexMap (projected to parallel loops) = 606 /// affine_map<(i, j) -> (i, j)> 607 /// consumerLoopToProducerLoop2 = affine_map<(i, j) -> (j, i)> 608 /// Fused dimensions : i, j 609 /// 610 /// Example 3: 611 /// linalg.copy(%s, %b) 612 /// linalg.matmul ins(%a, %b) outs(%c) 613 /// 614 /// Number of parallel loops = 2 615 /// produceIndexMap : affine_map<(i, j) -> (i, j)> 616 /// consumerLoopToProduceLoops = affine_map<(i, j, k) -> (k, j)> 617 /// submap with only parallel loops = affine_map<(i, j) -> (j)> 618 /// Fused dimensions : j 619 static std::set<unsigned> 620 collectFusableLoops(ArrayRef<LinalgOp> ops, 621 const FusableOpDependencesTy &fusableDependences) { 622 assert(!ops.empty()); 623 auto getNumOuterParallelLoops = [](LinalgOp linalgOp) { 624 return linalgOp.iterator_types() 625 .getValue() 626 .take_while([](Attribute attr) -> bool { 627 return attr.cast<StringAttr>().getValue() == 628 getParallelIteratorTypeName(); 629 }) 630 .size(); 631 }; 632 633 size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back()); 634 for (auto op : ops.drop_back()) { 635 numOuterParallelLoops = 636 std::min(numOuterParallelLoops, getNumOuterParallelLoops(op)); 637 } 638 639 
std::set<unsigned> fusableLoops; 640 auto range = llvm::seq<unsigned>(0, numOuterParallelLoops); 641 fusableLoops.insert(range.begin(), range.end()); 642 643 for (auto op : reverse(ops)) { 644 for (auto dependence : fusableDependences.lookup(op)) { 645 LLVM_DEBUG({ 646 llvm::dbgs() << "\t fusable :"; 647 for (unsigned i : fusableLoops) 648 llvm::dbgs() << " " << i; 649 llvm::dbgs() << "\n"; 650 }); 651 652 Optional<AffineMap> consumerLoopToProducerLoop = 653 getConsumerLoopToProducerLoopMap(dependence); 654 if (!consumerLoopToProducerLoop) { 655 op.emitRemark("failed to get map from consumer loop to producer loop"); 656 return {}; 657 } 658 // todo: This condition is only an implementation limitation. When fusing 659 // the operation, if the accesses in the producer/consumer are transposes 660 // of each other, the loop bounds for the tiled producer can be 661 // manipulated accordingly. This requires some additional bookkeeping in 662 // the implementation of tile+fuse that is deferred to later. 663 if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) { 664 op.emitRemark("unhandled fusion when fusion requires permutation"); 665 return {}; 666 } 667 668 std::set<unsigned> candidates; 669 for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) { 670 unsigned position = expr.cast<AffineDimExpr>().getPosition(); 671 if (fusableLoops.count(position)) 672 candidates.insert(position); 673 } 674 LLVM_DEBUG({ 675 llvm::dbgs() << "\t candidates :"; 676 for (unsigned i : candidates) 677 llvm::dbgs() << " " << i; 678 llvm::dbgs() << "\n"; 679 }); 680 if (candidates.empty()) 681 return {}; 682 std::swap(candidates, fusableLoops); 683 } 684 } 685 686 return fusableLoops; 687 } 688 689 /// Find all dependences that are fusable. 
690 FusableOpDependencesTy mlir::linalg::findAllFusableDependences( 691 ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) { 692 FusableOpDependencesTy fusableDependences; 693 DenseMap<Operation *, SmallVector<AffineMap, 1>> fusedProducerIndexingMap; 694 for (LinalgOp op : reverse(ops)) { 695 for (OpOperand &opOperand : op.getShapedOpOperands()) { 696 Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> 697 fusableDependence = findFusableProducer(opOperand, dependenceGraph); 698 if (!fusableDependence) 699 continue; 700 LinalgOp producerOp = 701 dyn_cast<LinalgOp>(fusableDependence->getDependentOp()); 702 if (!producerOp) 703 continue; 704 // Do not fuse dependences that are to operations not in the same basic 705 // block. This avoid moving fused operations across loops that might 706 // themselves carry dependency making the fusion illegal. 707 if (producerOp->getBlock() != op->getBlock()) 708 continue; 709 710 // Make sure that the indexing map of the view used for fusion in the 711 // producer is a projected permutation. 712 Optional<AffineMap> producerMap = 713 fusableDependence->getDependentOpViewIndexingMap(); 714 Optional<AffineMap> consumerMap = 715 fusableDependence->getIndexingOpViewIndexingMap(); 716 assert( 717 consumerMap && 718 "unable to find indexing map of operand/result of indexing OpView"); 719 fusedProducerIndexingMap[producerOp.getOperation()].push_back( 720 *consumerMap); 721 if (!producerMap || !producerMap->isProjectedPermutation() || 722 !consumerMap->isProjectedPermutation()) 723 continue; 724 725 fusableDependences[producerOp.getOperation()].push_back( 726 *fusableDependence); 727 } 728 } 729 // TODO: Currently fusion would not be legal if the fusable dependence is to 730 // the same producer but different indexing map in the consumer. Fix this, but 731 // in the meanwhile disallow such a fusion. 
732 for (auto useIndexingMapsList : fusedProducerIndexingMap) { 733 AffineMap map1 = useIndexingMapsList.second.front(); 734 for (AffineMap map2 : 735 ArrayRef<AffineMap>(useIndexingMapsList.second).drop_front()) { 736 if (map1 != map2) { 737 fusableDependences.erase(useIndexingMapsList.first); 738 break; 739 } 740 } 741 } 742 return fusableDependences; 743 } 744 745 /// Tile the fused loops in the root operation, by setting the tile sizes for 746 /// all other loops to zero (those will be tiled later). 747 static Optional<TiledLinalgOp> tileRootOperation( 748 OpBuilder &builder, LinalgOp op, ArrayRef<Value> tileSizeVector, 749 const LinalgTilingOptions &options, const std::set<unsigned> &fusedLoops) { 750 SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end()); 751 auto zero = std_constant_index(0); 752 for (unsigned i = 0, e = tileSizes.size(); i != e; ++i) 753 if (!fusedLoops.count(i)) 754 tileSizes[i] = zero; 755 LinalgTilingOptions tileFusedLoopsOptions = options; 756 tileFusedLoopsOptions.setTileSizes(tileSizes); 757 return tileLinalgOp(builder, op, tileFusedLoopsOptions); 758 } 759 760 /// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected 761 /// to be a tiled operation such that it is valid to fuse all operations in 762 /// `fusionCandidates`, i.e. move the operation within the inter-tile loops of 763 /// `tiledOp`. 
764 static SmallVector<LinalgOp, 1> 765 fuseOperations(OpBuilder &builder, LinalgOp rootOp, LinalgOp tiledOp, 766 ArrayRef<LinalgOp> fusionCandidates, 767 const FusableOpDependencesTy &fusableDependences, 768 const std::set<unsigned> &fusedLoops) { 769 OpBuilder::InsertionGuard guard(builder); 770 builder.setInsertionPoint(tiledOp); 771 DenseMap<unsigned, Range> fusedLoopsAndRanges; 772 for (unsigned loop : fusedLoops) { 773 ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true); 774 fusedLoopsAndRanges[loop] = getRangeFromOperandShape( 775 builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension); 776 } 777 778 SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size()); 779 DenseMap<Operation *, LinalgOp> origOpToFusedOp; 780 origOpToFusedOp[rootOp.getOperation()] = tiledOp; 781 for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) { 782 LinalgOp origOp = candidate.value(); 783 LinalgOp fusedOp = fuse(builder, origOp, fusedLoopsAndRanges); 784 origOpToFusedOp[origOp.getOperation()] = fusedOp; 785 fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp; 786 // If the producer consumer operations are linalg operations on tensors, the 787 // dependence is due to value produced (as a return tensor) by the producer 788 // and used in the consumer. The returned value of the fused op needs to be 789 // made the operand of the tiled/fused consumer operation. By construction 790 // the value returned by the producer is the value used by the consumer. 
791 for (auto &dependence : fusableDependences.lookup(origOp.getOperation())) { 792 if (origOp.hasTensorSemantics() && 793 dependence.dependenceType == 794 LinalgDependenceGraph::DependenceType::RAW) { 795 unsigned resultIndex = 796 dependence.getDependentOpViewResultNum().getValue(); 797 LinalgOp consumer = origOpToFusedOp.lookup(dependence.getIndexingOp()); 798 if (!consumer) 799 continue; 800 Value replacementValue = fusedOp.getOperation()->getResult(resultIndex); 801 consumer.getOperation()->setOperand( 802 dependence.getIndexingOpViewOperandNum().getValue(), 803 replacementValue); 804 } 805 } 806 builder.setInsertionPoint(fusedOp); 807 } 808 return fusedOps; 809 } 810 811 template <typename LoopType> 812 static Optional<TiledAndFusedLinalgOps> 813 tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops, 814 const LinalgDependenceGraph &dependenceGraph, 815 const LinalgTilingOptions &tilingOptions) { 816 if (ops.size() < 2) 817 return llvm::None; 818 LinalgOp rootOp = ops.back(); 819 if (!llvm::all_of( 820 ops, 821 [](LinalgOp linalgOp) { return linalgOp.hasBufferSemantics(); }) && 822 !llvm::all_of(ops, [](LinalgOp linalgOp) { 823 return linalgOp.hasTensorSemantics(); 824 })) { 825 rootOp.emitError( 826 "unable to fuse operations that have tensor semantics with operations " 827 "that have buffer semantics and viceversa."); 828 return llvm::None; 829 } 830 // TODO: Support interchange with tile + fuse. This might actually help do 831 // better fusion. 832 if (!tilingOptions.interchangeVector.empty()) { 833 rootOp.emitRemark("unable to handle tile and fuse with interchange"); 834 return llvm::None; 835 } 836 837 OpBuilder::InsertionGuard guard(builder); 838 builder.setInsertionPoint(rootOp); 839 ScopedContext scope(builder, rootOp.getLoc()); 840 841 // Find all the producers. 
842 FusableOpDependencesTy fusableDependences = 843 findAllFusableDependences(ops, dependenceGraph); 844 if (fusableDependences.empty()) 845 return llvm::None; 846 847 TiledAndFusedLinalgOps ret; 848 // Find the loops that can be tiled and fused. 849 ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences); 850 851 // If there are no fusable dependences or there are no tile+fusable loops, 852 // just return. 853 if (ret.fusedLoopDims.empty()) { 854 return llvm::None; 855 } 856 857 // Tile the fused loops in the last operation in the list. 858 SmallVector<Value, 4> tileSizeVector = 859 tilingOptions.tileSizeComputationFunction(builder, rootOp); 860 Optional<TiledLinalgOp> tiledRootOp = tileRootOperation( 861 builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims); 862 if (!tiledRootOp) { 863 rootOp.emitRemark("failed to tile the fused loops"); 864 return llvm::None; 865 } 866 ret.op = tiledRootOp->op; 867 ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end()); 868 869 // Fuse the other operations into the fused inter-tile loops produced above. 870 ret.fusedProducers = fuseOperations(builder, rootOp, ret.op, ops.drop_back(), 871 fusableDependences, ret.fusedLoopDims); 872 873 return ret; 874 } 875 876 Optional<TiledAndFusedLinalgOps> 877 mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops, 878 const LinalgDependenceGraph &dependenceGraph, 879 const LinalgTilingOptions &tilingOptions) { 880 switch (tilingOptions.loopType) { 881 case LinalgTilingLoopType::Loops: 882 return tileAndFuseLinalgOpsImpl<scf::ForOp>(builder, ops, dependenceGraph, 883 tilingOptions); 884 case LinalgTilingLoopType::ParallelLoops: 885 return tileAndFuseLinalgOpsImpl<scf::ParallelOp>( 886 builder, ops, dependenceGraph, tilingOptions); 887 default:; 888 } 889 return llvm::None; 890 } 891