1 //===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop fusion transformation utility functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
16 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
17 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
18 #include "mlir/Dialect/Affine/Analysis/Utils.h"
19 #include "mlir/Dialect/Affine/IR/AffineOps.h"
20 #include "mlir/Dialect/Affine/LoopUtils.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/BlockAndValueMapping.h"
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinOps.h"
26 #include "mlir/IR/Operation.h"
27 #include "llvm/ADT/DenseMap.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/raw_ostream.h"
31 
32 #define DEBUG_TYPE "loop-fusion-utils"
33 
34 using namespace mlir;
35 
36 // Gathers all load and store memref accesses in 'opA' into 'values', where
37 // 'values[memref] == true' for each store operation.
38 static void getLoadAndStoreMemRefAccesses(Operation *opA,
39                                           DenseMap<Value, bool> &values) {
40   opA->walk([&](Operation *op) {
41     if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
42       if (values.count(loadOp.getMemRef()) == 0)
43         values[loadOp.getMemRef()] = false;
44     } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
45       values[storeOp.getMemRef()] = true;
46     }
47   });
48 }
49 
50 /// Returns true if 'op' is a load or store operation which access a memref
51 /// accessed 'values' and at least one of the access is a store operation.
52 /// Returns false otherwise.
53 static bool isDependentLoadOrStoreOp(Operation *op,
54                                      DenseMap<Value, bool> &values) {
55   if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
56     return values.count(loadOp.getMemRef()) > 0 && values[loadOp.getMemRef()];
57   }
58   if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
59     return values.count(storeOp.getMemRef()) > 0;
60   }
61   return false;
62 }
63 
64 // Returns the first operation in range ('opA', 'opB') which has a data
65 // dependence on 'opA'. Returns 'nullptr' of no dependence exists.
66 static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) {
67   // Record memref values from all loads/store in loop nest rooted at 'opA'.
68   // Map from memref value to bool which is true if store, false otherwise.
69   DenseMap<Value, bool> values;
70   getLoadAndStoreMemRefAccesses(opA, values);
71 
72   // For each 'opX' in block in range ('opA', 'opB'), check if there is a data
73   // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref
74   // and at least one of the accesses is a store).
75   Operation *firstDepOp = nullptr;
76   for (Block::iterator it = std::next(Block::iterator(opA));
77        it != Block::iterator(opB); ++it) {
78     Operation *opX = &(*it);
79     opX->walk([&](Operation *op) {
80       if (!firstDepOp && isDependentLoadOrStoreOp(op, values))
81         firstDepOp = opX;
82     });
83     if (firstDepOp)
84       break;
85   }
86   return firstDepOp;
87 }
88 
89 // Returns the last operation 'opX' in range ('opA', 'opB'), for which there
90 // exists a data dependence from 'opX' to 'opB'.
91 // Returns 'nullptr' of no dependence exists.
92 static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
93   // Record memref values from all loads/store in loop nest rooted at 'opB'.
94   // Map from memref value to bool which is true if store, false otherwise.
95   DenseMap<Value, bool> values;
96   getLoadAndStoreMemRefAccesses(opB, values);
97 
98   // For each 'opX' in block in range ('opA', 'opB') in reverse order,
99   // check if there is a data dependence from 'opX' to 'opB':
100   // *) 'opX' and 'opB' access the same memref and at least one of the accesses
101   //    is a store.
102   // *) 'opX' produces an SSA Value which is used by 'opB'.
103   Operation *lastDepOp = nullptr;
104   for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB));
105        it != Block::reverse_iterator(opA); ++it) {
106     Operation *opX = &(*it);
107     opX->walk([&](Operation *op) {
108       if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
109         if (isDependentLoadOrStoreOp(op, values)) {
110           lastDepOp = opX;
111           return WalkResult::interrupt();
112         }
113         return WalkResult::advance();
114       }
115       for (auto value : op->getResults()) {
116         for (Operation *user : value.getUsers()) {
117           SmallVector<AffineForOp, 4> loops;
118           // Check if any loop in loop nest surrounding 'user' is 'opB'.
119           getLoopIVs(*user, &loops);
120           if (llvm::is_contained(loops, cast<AffineForOp>(opB))) {
121             lastDepOp = opX;
122             return WalkResult::interrupt();
123           }
124         }
125       }
126       return WalkResult::advance();
127     });
128     if (lastDepOp)
129       break;
130   }
131   return lastDepOp;
132 }
133 
134 // Computes and returns an insertion point operation, before which the
135 // the fused <srcForOp, dstForOp> loop nest can be inserted while preserving
136 // dependences. Returns nullptr if no such insertion point is found.
137 static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
138                                                  AffineForOp dstForOp) {
139   bool isSrcForOpBeforeDstForOp =
140       srcForOp->isBeforeInBlock(dstForOp.getOperation());
141   auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
142   auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
143 
144   auto *firstDepOpA =
145       getFirstDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
146   auto *lastDepOpB =
147       getLastDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
148   // Block:
149   //      ...
150   //  |-- opA
151   //  |   ...
152   //  |   lastDepOpB --|
153   //  |   ...          |
154   //  |-> firstDepOpA  |
155   //      ...          |
156   //      opB <---------
157   //
158   // Valid insertion point range: (lastDepOpB, firstDepOpA)
159   //
160   if (firstDepOpA != nullptr) {
161     if (lastDepOpB != nullptr) {
162       if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB)
163         // No valid insertion point exists which preserves dependences.
164         return nullptr;
165     }
166     // Return insertion point in valid range closest to 'opB'.
167     // TODO: Consider other insertion points in valid range.
168     return firstDepOpA;
169   }
170   // No dependences from 'opA' to operation in range ('opA', 'opB'), return
171   // 'opB' insertion point.
172   return forOpB.getOperation();
173 }
174 
175 // Gathers all load and store ops in loop nest rooted at 'forOp' into
176 // 'loadAndStoreOps'.
177 static bool
178 gatherLoadsAndStores(AffineForOp forOp,
179                      SmallVectorImpl<Operation *> &loadAndStoreOps) {
180   bool hasIfOp = false;
181   forOp.walk([&](Operation *op) {
182     if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
183       loadAndStoreOps.push_back(op);
184     else if (isa<AffineIfOp>(op))
185       hasIfOp = true;
186   });
187   return !hasIfOp;
188 }
189 
190 /// Returns the maximum loop depth at which we could fuse producer loop
191 /// 'srcForOp' into consumer loop 'dstForOp' without violating data dependences.
192 // TODO: Generalize this check for sibling and more generic fusion scenarios.
193 // TODO: Support forward slice fusion.
194 static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
195                                 ArrayRef<Operation *> dstOps) {
196   if (dstOps.empty())
197     // Expected at least one memory operation.
198     // TODO: Revisit this case with a specific example.
199     return 0;
200 
201   // Filter out ops in 'dstOps' that do not use the producer-consumer memref so
202   // that they are not considered for analysis.
203   DenseSet<Value> producerConsumerMemrefs;
204   gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs);
205   SmallVector<Operation *, 4> targetDstOps;
206   for (Operation *dstOp : dstOps) {
207     auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp);
208     Value memref = loadOp ? loadOp.getMemRef()
209                           : cast<AffineWriteOpInterface>(dstOp).getMemRef();
210     if (producerConsumerMemrefs.count(memref) > 0)
211       targetDstOps.push_back(dstOp);
212   }
213 
214   assert(!targetDstOps.empty() &&
215          "No dependences between 'srcForOp' and 'dstForOp'?");
216 
217   // Compute the innermost common loop depth for loads and stores.
218   unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);
219 
220   // Return common loop depth for loads if there are no store ops.
221   if (all_of(targetDstOps,
222              [&](Operation *op) { return isa<AffineReadOpInterface>(op); }))
223     return loopDepth;
224 
225   // Check dependences on all pairs of ops in 'targetDstOps' and store the
226   // minimum loop depth at which a dependence is satisfied.
227   for (unsigned i = 0, e = targetDstOps.size(); i < e; ++i) {
228     auto *srcOpInst = targetDstOps[i];
229     MemRefAccess srcAccess(srcOpInst);
230     for (unsigned j = 0; j < e; ++j) {
231       auto *dstOpInst = targetDstOps[j];
232       MemRefAccess dstAccess(dstOpInst);
233 
234       unsigned numCommonLoops =
235           getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
236       for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
237         FlatAffineValueConstraints dependenceConstraints;
238         // TODO: Cache dependence analysis results, check cache here.
239         DependenceResult result = checkMemrefAccessDependence(
240             srcAccess, dstAccess, d, &dependenceConstraints,
241             /*dependenceComponents=*/nullptr);
242         if (hasDependence(result)) {
243           // Store minimum loop depth and break because we want the min 'd' at
244           // which there is a dependence.
245           loopDepth = std::min(loopDepth, d - 1);
246           break;
247         }
248       }
249     }
250   }
251 
252   return loopDepth;
253 }
254 
255 // TODO: Prevent fusion of loop nests with side-effecting operations.
256 // TODO: This pass performs some computation that is the same for all the depths
257 // (e.g., getMaxLoopDepth). Implement a version of this utility that processes
258 // all the depths at once or only the legal maximal depth for maximal fusion.
259 FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
260                                 unsigned dstLoopDepth,
261                                 ComputationSliceState *srcSlice,
262                                 FusionStrategy fusionStrategy) {
263   // Return 'failure' if 'dstLoopDepth == 0'.
264   if (dstLoopDepth == 0) {
265     LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n");
266     return FusionResult::FailPrecondition;
267   }
268   // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
269   auto *block = srcForOp->getBlock();
270   if (block != dstForOp->getBlock()) {
271     LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n");
272     return FusionResult::FailPrecondition;
273   }
274 
275   // Return 'failure' if no valid insertion point for fused loop nest in 'block'
276   // exists which would preserve dependences.
277   if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
278     LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n");
279     return FusionResult::FailBlockDependence;
280   }
281 
282   // Check if 'srcForOp' precedes 'dstForOp' in 'block'.
283   bool isSrcForOpBeforeDstForOp =
284       srcForOp->isBeforeInBlock(dstForOp.getOperation());
285   // 'forOpA' executes before 'forOpB' in 'block'.
286   auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
287   auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
288 
289   // Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'.
290   SmallVector<Operation *, 4> opsA;
291   if (!gatherLoadsAndStores(forOpA, opsA)) {
292     LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
293     return FusionResult::FailPrecondition;
294   }
295 
296   // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'.
297   SmallVector<Operation *, 4> opsB;
298   if (!gatherLoadsAndStores(forOpB, opsB)) {
299     LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
300     return FusionResult::FailPrecondition;
301   }
302 
303   // Return 'failure' if fusing loops at depth 'dstLoopDepth' wouldn't preserve
304   // loop dependences.
305   // TODO: Enable this check for sibling and more generic loop fusion
306   // strategies.
307   if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) {
308     // TODO: 'getMaxLoopDepth' does not support forward slice fusion.
309     assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
310     if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) {
311       LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
312       return FusionResult::FailFusionDependence;
313     }
314   }
315 
316   // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'.
317   unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops(
318       *srcForOp.getOperation(), *dstForOp.getOperation());
319 
320   // Filter out ops in 'opsA' to compute the slice union based on the
321   // assumptions made by the fusion strategy.
322   SmallVector<Operation *, 4> strategyOpsA;
323   switch (fusionStrategy.getStrategy()) {
324   case FusionStrategy::Generic:
325     // Generic fusion. Take into account all the memory operations to compute
326     // the slice union.
327     strategyOpsA.append(opsA.begin(), opsA.end());
328     break;
329   case FusionStrategy::ProducerConsumer:
330     // Producer-consumer fusion (AffineLoopFusion pass) only takes into
331     // account stores in 'srcForOp' to compute the slice union.
332     for (Operation *op : opsA) {
333       if (isa<AffineWriteOpInterface>(op))
334         strategyOpsA.push_back(op);
335     }
336     break;
337   case FusionStrategy::Sibling:
338     // Sibling fusion (AffineLoopFusion pass) only takes into account the loads
339     // to 'memref' in 'srcForOp' to compute the slice union.
340     for (Operation *op : opsA) {
341       auto load = dyn_cast<AffineReadOpInterface>(op);
342       if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef())
343         strategyOpsA.push_back(op);
344     }
345     break;
346   }
347 
348   // Compute union of computation slices computed between all pairs of ops
349   // from 'forOpA' and 'forOpB'.
350   SliceComputationResult sliceComputationResult =
351       mlir::computeSliceUnion(strategyOpsA, opsB, dstLoopDepth, numCommonLoops,
352                               isSrcForOpBeforeDstForOp, srcSlice);
353   if (sliceComputationResult.value == SliceComputationResult::GenericFailure) {
354     LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
355     return FusionResult::FailPrecondition;
356   }
357   if (sliceComputationResult.value ==
358       SliceComputationResult::IncorrectSliceFailure) {
359     LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n");
360     return FusionResult::FailIncorrectSlice;
361   }
362 
363   return FusionResult::Success;
364 }
365 
/// Patch the loop body of a forOp that is a single iteration reduction loop
/// into its containing block.
///
/// Returns failure without modifying the IR unless 'forOp' has a constant trip
/// count of exactly 1 and its parent operation is an AffineForOp. Otherwise:
///  - the parent loop is rebuilt via replaceForOpWithNewYields, appending
///    'forOp's iter operands, the operands of 'forOp's terminator (as new
///    yielded values) and 'forOp's region iter args;
///  - uses of 'forOp's results, induction variable and region iter args are
///    redirected to the corresponding values of the new loop;
///  - the body of 'forOp' (minus its terminator) is spliced in place of
///    'forOp', and both 'forOp' and the original parent loop are erased.
/// If 'siblingFusionUser' is true, the operations in the forward slice of
/// 'forOp's results are also moved to right after the new loop, in
/// topological order.
///
/// NOTE(review): this function has external linkage (it is neither 'static'
/// nor defined in namespace 'mlir'); confirm whether it is meant to be
/// exported.
LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
                                             bool siblingFusionUser) {
  // Check if the reduction loop is a single iteration loop.
  Optional<uint64_t> tripCount = getConstantTripCount(forOp);
  if (!tripCount || tripCount.getValue() != 1)
    return failure();
  auto iterOperands = forOp.getIterOperands();
  auto *parentOp = forOp->getParentOp();
  // Only a loop directly nested in another affine.for can be patched.
  if (!isa<AffineForOp>(parentOp))
    return failure();
  // Values yielded by 'forOp'; they become new yielded values of the parent.
  auto newOperands = forOp.getBody()->getTerminator()->getOperands();
  OpBuilder b(parentOp);
  // Replace the parent loop and add iteroperands and results from the `forOp`.
  // 'parentOp' was checked to be an AffineForOp above, so 'parentForOp' is
  // that same operation.
  AffineForOp parentForOp = forOp->getParentOfType<AffineForOp>();
  AffineForOp newLoop = replaceForOpWithNewYields(
      b, parentForOp, iterOperands, newOperands, forOp.getRegionIterArgs());

  // For sibling-fusion users, collect operations that use the results of the
  // `forOp` outside the new parent loop that has absorbed all its iter args
  // and operands. These operations will be moved later after the results
  // have been replaced.
  SetVector<Operation *> forwardSlice;
  if (siblingFusionUser) {
    for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
      SetVector<Operation *> tmpForwardSlice;
      getForwardSlice(forOp.getResult(i), &tmpForwardSlice);
      forwardSlice.set_union(tmpForwardSlice);
    }
  }
  // Update the results of the `forOp` in the new loop.
  // The new results are appended after the original parent loop's results,
  // hence the 'parentOp->getNumResults()' offset.
  for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
    forOp.getResult(i).replaceAllUsesWith(
        newLoop.getResult(i + parentOp->getNumResults()));
  }
  // For sibling-fusion users, move operations that use the results of the
  // `forOp` outside the new parent loop
  // Each op is moved to right after 'newLoop' in reverse topological order, so
  // the moved ops end up in topological order after the loop.
  if (siblingFusionUser) {
    topologicalSort(forwardSlice);
    for (Operation *op : llvm::reverse(forwardSlice))
      op->moveAfter(newLoop);
  }
  // Replace the induction variable.
  auto iv = forOp.getInductionVar();
  iv.replaceAllUsesWith(newLoop.getInductionVar());
  // Replace the iter args.
  // 'forOp's iter args were appended last, so they are the trailing
  // region iter args of 'newLoop'.
  auto forOpIterArgs = forOp.getRegionIterArgs();
  for (auto it : llvm::zip(forOpIterArgs, newLoop.getRegionIterArgs().take_back(
                                              forOpIterArgs.size()))) {
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  }
  // Move the loop body operations, except for its terminator, to the loop's
  // containing block.
  forOp.getBody()->back().erase();
  auto *parentBlock = forOp->getBlock();
  parentBlock->getOperations().splice(Block::iterator(forOp),
                                      forOp.getBody()->getOperations());
  forOp.erase();
  // Erase the original parent loop, which 'newLoop' replaces.
  parentForOp.erase();
  return success();
}
428 
429 /// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point
430 /// and source slice loop bounds specified in 'srcSlice'.
431 void mlir::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
432                      const ComputationSliceState &srcSlice,
433                      bool isInnermostSiblingInsertion) {
434   // Clone 'srcForOp' into 'dstForOp' at 'srcSlice->insertPoint'.
435   OpBuilder b(srcSlice.insertPoint->getBlock(), srcSlice.insertPoint);
436   BlockAndValueMapping mapper;
437   b.clone(*srcForOp, mapper);
438 
439   // Update 'sliceLoopNest' upper and lower bounds from computed 'srcSlice'.
440   SmallVector<AffineForOp, 4> sliceLoops;
441   for (unsigned i = 0, e = srcSlice.ivs.size(); i < e; ++i) {
442     auto loopIV = mapper.lookupOrNull(srcSlice.ivs[i]);
443     if (!loopIV)
444       continue;
445     auto forOp = getForInductionVarOwner(loopIV);
446     sliceLoops.push_back(forOp);
447     if (AffineMap lbMap = srcSlice.lbs[i]) {
448       auto lbOperands = srcSlice.lbOperands[i];
449       canonicalizeMapAndOperands(&lbMap, &lbOperands);
450       forOp.setLowerBound(lbOperands, lbMap);
451     }
452     if (AffineMap ubMap = srcSlice.ubs[i]) {
453       auto ubOperands = srcSlice.ubOperands[i];
454       canonicalizeMapAndOperands(&ubMap, &ubOperands);
455       forOp.setUpperBound(ubOperands, ubMap);
456     }
457   }
458 
459   llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
460   auto srcIsUnitSlice = [&]() {
461     return (buildSliceTripCountMap(srcSlice, &sliceTripCountMap) &&
462             (getSliceIterationCount(sliceTripCountMap) == 1));
463   };
464   // Fix up and if possible, eliminate single iteration loops.
465   for (AffineForOp forOp : sliceLoops) {
466     if (isLoopParallelAndContainsReduction(forOp) &&
467         isInnermostSiblingInsertion && srcIsUnitSlice())
468       // Patch reduction loop - only ones that are sibling-fused with the
469       // destination loop - into the parent loop.
470       (void)promoteSingleIterReductionLoop(forOp, true);
471     else
472       // Promote any single iteration slice loops.
473       (void)promoteIfSingleIteration(forOp);
474   }
475 }
476 
477 /// Collect loop nest statistics (eg. loop trip count and operation count)
478 /// in 'stats' for loop nest rooted at 'forOp'. Returns true on success,
479 /// returns false otherwise.
480 bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
481   auto walkResult = forOpRoot.walk([&](AffineForOp forOp) {
482     auto *childForOp = forOp.getOperation();
483     auto *parentForOp = forOp->getParentOp();
484     if (!llvm::isa<FuncOp>(parentForOp)) {
485       if (!isa<AffineForOp>(parentForOp)) {
486         LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n");
487         return WalkResult::interrupt();
488       }
489       // Add mapping to 'forOp' from its parent AffineForOp.
490       stats->loopMap[parentForOp].push_back(forOp);
491     }
492 
493     // Record the number of op operations in the body of 'forOp'.
494     unsigned count = 0;
495     stats->opCountMap[childForOp] = 0;
496     for (auto &op : *forOp.getBody()) {
497       if (!isa<AffineForOp, AffineIfOp>(op))
498         ++count;
499     }
500     stats->opCountMap[childForOp] = count;
501 
502     // Record trip count for 'forOp'. Set flag if trip count is not
503     // constant.
504     Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
505     if (!maybeConstTripCount.hasValue()) {
506       // Currently only constant trip count loop nests are supported.
507       LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n");
508       return WalkResult::interrupt();
509     }
510 
511     stats->tripCountMap[childForOp] = maybeConstTripCount.getValue();
512     return WalkResult::advance();
513   });
514   return !walkResult.wasInterrupted();
515 }
516 
517 // Computes the total cost of the loop nest rooted at 'forOp'.
518 // Currently, the total cost is computed by counting the total operation
519 // instance count (i.e. total number of operations in the loop bodyloop
520 // operation count * loop trip count) for the entire loop nest.
521 // If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
522 // specified in the map when computing the total op instance count.
523 // NOTEs: 1) This is used to compute the cost of computation slices, which are
524 // sliced along the iteration dimension, and thus reduce the trip count.
525 // If 'computeCostMap' is non-null, the total op count for forOps specified
526 // in the map is increased (not overridden) by adding the op count from the
527 // map to the existing op count for the for loop. This is done before
528 // multiplying by the loop's trip count, and is used to model the cost of
529 // inserting a sliced loop nest of known cost into the loop's body.
530 // 2) This is also used to compute the cost of fusing a slice of some loop nest
531 // within another loop.
532 static int64_t getComputeCostHelper(
533     Operation *forOp, LoopNestStats &stats,
534     llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
535     DenseMap<Operation *, int64_t> *computeCostMap) {
536   // 'opCount' is the total number operations in one iteration of 'forOp' body,
537   // minus terminator op which is a no-op.
538   int64_t opCount = stats.opCountMap[forOp] - 1;
539   if (stats.loopMap.count(forOp) > 0) {
540     for (auto childForOp : stats.loopMap[forOp]) {
541       opCount += getComputeCostHelper(childForOp.getOperation(), stats,
542                                       tripCountOverrideMap, computeCostMap);
543     }
544   }
545   // Add in additional op instances from slice (if specified in map).
546   if (computeCostMap != nullptr) {
547     auto it = computeCostMap->find(forOp);
548     if (it != computeCostMap->end()) {
549       opCount += it->second;
550     }
551   }
552   // Override trip count (if specified in map).
553   int64_t tripCount = stats.tripCountMap[forOp];
554   if (tripCountOverrideMap != nullptr) {
555     auto it = tripCountOverrideMap->find(forOp);
556     if (it != tripCountOverrideMap->end()) {
557       tripCount = it->second;
558     }
559   }
560   // Returns the total number of dynamic instances of operations in loop body.
561   return tripCount * opCount;
562 }
563 
564 /// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
565 /// Currently, the total cost is computed by counting the total operation
566 /// instance count (i.e. total number of operations in the loop body * loop
567 /// trip count) for the entire loop nest.
568 int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
569   return getComputeCostHelper(forOp.getOperation(), stats,
570                               /*tripCountOverrideMap=*/nullptr,
571                               /*computeCostMap=*/nullptr);
572 }
573 
574 /// Computes and returns in 'computeCost', the total compute cost of fusing the
575 /// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently,
576 /// the total cost is computed by counting the total operation instance count
577 /// (i.e. total number of operations in the loop body * loop trip count) for
578 /// the entire loop nest.
579 bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
580                                 AffineForOp dstForOp, LoopNestStats &dstStats,
581                                 const ComputationSliceState &slice,
582                                 int64_t *computeCost) {
583   llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
584   DenseMap<Operation *, int64_t> computeCostMap;
585 
586   // Build trip count map for computation slice.
587   if (!buildSliceTripCountMap(slice, &sliceTripCountMap))
588     return false;
589   // Checks whether a store to load forwarding will happen.
590   int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
591   assert(sliceIterationCount > 0);
592   bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
593   auto *insertPointParent = slice.insertPoint->getParentOp();
594 
595   // The store and loads to this memref will disappear.
596   // TODO: Add load coalescing to memref data flow opt pass.
597   if (storeLoadFwdGuaranteed) {
598     // Subtract from operation count the loads/store we expect load/store
599     // forwarding to remove.
600     unsigned storeCount = 0;
601     llvm::SmallDenseSet<Value, 4> storeMemrefs;
602     srcForOp.walk([&](Operation *op) {
603       if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
604         storeMemrefs.insert(storeOp.getMemRef());
605         ++storeCount;
606       }
607     });
608     // Subtract out any store ops in single-iteration src slice loop nest.
609     if (storeCount > 0)
610       computeCostMap[insertPointParent] = -storeCount;
611     // Subtract out any load users of 'storeMemrefs' nested below
612     // 'insertPointParent'.
613     for (auto value : storeMemrefs) {
614       for (auto *user : value.getUsers()) {
615         if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) {
616           SmallVector<AffineForOp, 4> loops;
617           // Check if any loop in loop nest surrounding 'user' is
618           // 'insertPointParent'.
619           getLoopIVs(*user, &loops);
620           if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) {
621             if (auto forOp =
622                     dyn_cast_or_null<AffineForOp>(user->getParentOp())) {
623               if (computeCostMap.count(forOp) == 0)
624                 computeCostMap[forOp] = 0;
625               computeCostMap[forOp] -= 1;
626             }
627           }
628         }
629       }
630     }
631   }
632 
633   // Compute op instance count for the src loop nest with iteration slicing.
634   int64_t sliceComputeCost = getComputeCostHelper(
635       srcForOp.getOperation(), srcStats, &sliceTripCountMap, &computeCostMap);
636 
637   // Compute cost of fusion for this depth.
638   computeCostMap[insertPointParent] = sliceComputeCost;
639 
640   *computeCost =
641       getComputeCostHelper(dstForOp.getOperation(), dstStats,
642                            /*tripCountOverrideMap=*/nullptr, &computeCostMap);
643   return true;
644 }
645 
646 /// Returns in 'producerConsumerMemrefs' the memrefs involved in a
647 /// producer-consumer dependence between write ops in 'srcOps' and read ops in
648 /// 'dstOps'.
649 void mlir::gatherProducerConsumerMemrefs(
650     ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps,
651     DenseSet<Value> &producerConsumerMemrefs) {
652   // Gather memrefs from stores in 'srcOps'.
653   DenseSet<Value> srcStoreMemRefs;
654   for (Operation *op : srcOps)
655     if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op))
656       srcStoreMemRefs.insert(storeOp.getMemRef());
657 
658   // Compute the intersection between memrefs from stores in 'srcOps' and
659   // memrefs from loads in 'dstOps'.
660   for (Operation *op : dstOps)
661     if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
662       if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0)
663         producerConsumerMemrefs.insert(loadOp.getMemRef());
664 }
665