//===- TestLoopFusion.cpp - Test loop fusion ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to test various loop fusion utility functions.
//
//===----------------------------------------------------------------------===//
12a70aa7bbSRiver Riddle 
13a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Analysis/Utils.h"
14a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
15a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/LoopFusionUtils.h"
16a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/LoopUtils.h"
1736550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
18a70aa7bbSRiver Riddle #include "mlir/Pass/Pass.h"
19a70aa7bbSRiver Riddle 
20a70aa7bbSRiver Riddle #define DEBUG_TYPE "test-loop-fusion"
21a70aa7bbSRiver Riddle 
22a70aa7bbSRiver Riddle using namespace mlir;
23a70aa7bbSRiver Riddle 
24a70aa7bbSRiver Riddle static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
25a70aa7bbSRiver Riddle 
26a70aa7bbSRiver Riddle static llvm::cl::opt<bool> clTestDependenceCheck(
27a70aa7bbSRiver Riddle     "test-loop-fusion-dependence-check",
28a70aa7bbSRiver Riddle     llvm::cl::desc("Enable testing of loop fusion dependence check"),
29a70aa7bbSRiver Riddle     llvm::cl::cat(clOptionsCategory));
30a70aa7bbSRiver Riddle 
31a70aa7bbSRiver Riddle static llvm::cl::opt<bool> clTestSliceComputation(
32a70aa7bbSRiver Riddle     "test-loop-fusion-slice-computation",
33a70aa7bbSRiver Riddle     llvm::cl::desc("Enable testing of loop fusion slice computation"),
34a70aa7bbSRiver Riddle     llvm::cl::cat(clOptionsCategory));
35a70aa7bbSRiver Riddle 
36a70aa7bbSRiver Riddle static llvm::cl::opt<bool> clTestLoopFusionTransformation(
37a70aa7bbSRiver Riddle     "test-loop-fusion-transformation",
38a70aa7bbSRiver Riddle     llvm::cl::desc("Enable testing of loop fusion transformation"),
39a70aa7bbSRiver Riddle     llvm::cl::cat(clOptionsCategory));
40a70aa7bbSRiver Riddle 
41a70aa7bbSRiver Riddle namespace {
42a70aa7bbSRiver Riddle 
43a70aa7bbSRiver Riddle struct TestLoopFusion
44*58ceae95SRiver Riddle     : public PassWrapper<TestLoopFusion, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon6d9b19930111::TestLoopFusion455e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopFusion)
465e50dd04SRiver Riddle 
47a70aa7bbSRiver Riddle   StringRef getArgument() const final { return "test-loop-fusion"; }
getDescription__anon6d9b19930111::TestLoopFusion48a70aa7bbSRiver Riddle   StringRef getDescription() const final {
49a70aa7bbSRiver Riddle     return "Tests loop fusion utility functions.";
50a70aa7bbSRiver Riddle   }
51a70aa7bbSRiver Riddle   void runOnOperation() override;
52a70aa7bbSRiver Riddle };
53a70aa7bbSRiver Riddle 
54a70aa7bbSRiver Riddle } // namespace
55a70aa7bbSRiver Riddle 
56a70aa7bbSRiver Riddle // Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths
57a70aa7bbSRiver Riddle // in range ['loopDepth' + 1, 'maxLoopDepth'].
58a70aa7bbSRiver Riddle // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists.
59a70aa7bbSRiver Riddle // Returns false as IR is not transformed.
testDependenceCheck(AffineForOp srcForOp,AffineForOp dstForOp,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)60a70aa7bbSRiver Riddle static bool testDependenceCheck(AffineForOp srcForOp, AffineForOp dstForOp,
61a70aa7bbSRiver Riddle                                 unsigned i, unsigned j, unsigned loopDepth,
62a70aa7bbSRiver Riddle                                 unsigned maxLoopDepth) {
63a70aa7bbSRiver Riddle   mlir::ComputationSliceState sliceUnion;
64a70aa7bbSRiver Riddle   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
65a70aa7bbSRiver Riddle     FusionResult result =
66a70aa7bbSRiver Riddle         mlir::canFuseLoops(srcForOp, dstForOp, d, &sliceUnion);
67a70aa7bbSRiver Riddle     if (result.value == FusionResult::FailBlockDependence) {
68a70aa7bbSRiver Riddle       srcForOp->emitRemark("block-level dependence preventing"
69a70aa7bbSRiver Riddle                            " fusion of loop nest ")
70a70aa7bbSRiver Riddle           << i << " into loop nest " << j << " at depth " << loopDepth;
71a70aa7bbSRiver Riddle     }
72a70aa7bbSRiver Riddle   }
73a70aa7bbSRiver Riddle   return false;
74a70aa7bbSRiver Riddle }
75a70aa7bbSRiver Riddle 
76a70aa7bbSRiver Riddle // Returns the index of 'op' in its block.
getBlockIndex(Operation & op)77a70aa7bbSRiver Riddle static unsigned getBlockIndex(Operation &op) {
78a70aa7bbSRiver Riddle   unsigned index = 0;
79a70aa7bbSRiver Riddle   for (auto &opX : *op.getBlock()) {
80a70aa7bbSRiver Riddle     if (&op == &opX)
81a70aa7bbSRiver Riddle       break;
82a70aa7bbSRiver Riddle     ++index;
83a70aa7bbSRiver Riddle   }
84a70aa7bbSRiver Riddle   return index;
85a70aa7bbSRiver Riddle }
86a70aa7bbSRiver Riddle 
87a70aa7bbSRiver Riddle // Returns a string representation of 'sliceUnion'.
getSliceStr(const mlir::ComputationSliceState & sliceUnion)88a70aa7bbSRiver Riddle static std::string getSliceStr(const mlir::ComputationSliceState &sliceUnion) {
89a70aa7bbSRiver Riddle   std::string result;
90a70aa7bbSRiver Riddle   llvm::raw_string_ostream os(result);
91a70aa7bbSRiver Riddle   // Slice insertion point format [loop-depth, operation-block-index]
92a70aa7bbSRiver Riddle   unsigned ipd = getNestingDepth(&*sliceUnion.insertPoint);
93a70aa7bbSRiver Riddle   unsigned ipb = getBlockIndex(*sliceUnion.insertPoint);
94a70aa7bbSRiver Riddle   os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb)
95a70aa7bbSRiver Riddle      << ")";
96a70aa7bbSRiver Riddle   assert(sliceUnion.lbs.size() == sliceUnion.ubs.size());
97a70aa7bbSRiver Riddle   os << " loop bounds: ";
98a70aa7bbSRiver Riddle   for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) {
99a70aa7bbSRiver Riddle     os << '[';
100a70aa7bbSRiver Riddle     sliceUnion.lbs[k].print(os);
101a70aa7bbSRiver Riddle     os << ", ";
102a70aa7bbSRiver Riddle     sliceUnion.ubs[k].print(os);
103a70aa7bbSRiver Riddle     os << "] ";
104a70aa7bbSRiver Riddle   }
105a70aa7bbSRiver Riddle   return os.str();
106a70aa7bbSRiver Riddle }
107a70aa7bbSRiver Riddle 
108a70aa7bbSRiver Riddle /// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
109a70aa7bbSRiver Riddle /// in range ['loopDepth' + 1, 'maxLoopDepth'].
110a70aa7bbSRiver Riddle /// Emits a string representation of the slice union as a remark on 'loops[j]'
111a70aa7bbSRiver Riddle /// and marks this as incorrect slice if the slice is invalid. Returns false as
112a70aa7bbSRiver Riddle /// IR is not transformed.
testSliceComputation(AffineForOp forOpA,AffineForOp forOpB,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)113a70aa7bbSRiver Riddle static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB,
114a70aa7bbSRiver Riddle                                  unsigned i, unsigned j, unsigned loopDepth,
115a70aa7bbSRiver Riddle                                  unsigned maxLoopDepth) {
116a70aa7bbSRiver Riddle   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
117a70aa7bbSRiver Riddle     mlir::ComputationSliceState sliceUnion;
118a70aa7bbSRiver Riddle     FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
119a70aa7bbSRiver Riddle     if (result.value == FusionResult::Success) {
120a70aa7bbSRiver Riddle       forOpB->emitRemark("slice (")
121a70aa7bbSRiver Riddle           << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
122a70aa7bbSRiver Riddle           << " : " << getSliceStr(sliceUnion) << ")";
123a70aa7bbSRiver Riddle     } else if (result.value == FusionResult::FailIncorrectSlice) {
124a70aa7bbSRiver Riddle       forOpB->emitRemark("Incorrect slice (")
125a70aa7bbSRiver Riddle           << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
126a70aa7bbSRiver Riddle           << " : " << getSliceStr(sliceUnion) << ")";
127a70aa7bbSRiver Riddle     }
128a70aa7bbSRiver Riddle   }
129a70aa7bbSRiver Riddle   return false;
130a70aa7bbSRiver Riddle }
131a70aa7bbSRiver Riddle 
132a70aa7bbSRiver Riddle // Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range
133a70aa7bbSRiver Riddle // ['loopDepth' + 1, 'maxLoopDepth'].
134a70aa7bbSRiver Riddle // Returns true if loops were successfully fused, false otherwise.
testLoopFusionTransformation(AffineForOp forOpA,AffineForOp forOpB,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)135a70aa7bbSRiver Riddle static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB,
136a70aa7bbSRiver Riddle                                          unsigned i, unsigned j,
137a70aa7bbSRiver Riddle                                          unsigned loopDepth,
138a70aa7bbSRiver Riddle                                          unsigned maxLoopDepth) {
139a70aa7bbSRiver Riddle   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
140a70aa7bbSRiver Riddle     mlir::ComputationSliceState sliceUnion;
141a70aa7bbSRiver Riddle     FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
142a70aa7bbSRiver Riddle     if (result.value == FusionResult::Success) {
143a70aa7bbSRiver Riddle       mlir::fuseLoops(forOpA, forOpB, sliceUnion);
144a70aa7bbSRiver Riddle       // Note: 'forOpA' is removed to simplify test output. A proper loop
145a70aa7bbSRiver Riddle       // fusion pass should check the data dependence graph and run memref
146a70aa7bbSRiver Riddle       // region analysis to ensure removing 'forOpA' is safe.
147a70aa7bbSRiver Riddle       forOpA.erase();
148a70aa7bbSRiver Riddle       return true;
149a70aa7bbSRiver Riddle     }
150a70aa7bbSRiver Riddle   }
151a70aa7bbSRiver Riddle   return false;
152a70aa7bbSRiver Riddle }
153a70aa7bbSRiver Riddle 
154a70aa7bbSRiver Riddle using LoopFunc = function_ref<bool(AffineForOp, AffineForOp, unsigned, unsigned,
155a70aa7bbSRiver Riddle                                    unsigned, unsigned)>;
156a70aa7bbSRiver Riddle 
157a70aa7bbSRiver Riddle // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
158a70aa7bbSRiver Riddle // If 'return_on_change' is true, returns on first invocation of 'fn' which
159a70aa7bbSRiver Riddle // returns true.
iterateLoops(ArrayRef<SmallVector<AffineForOp,2>> depthToLoops,LoopFunc fn,bool returnOnChange=false)160a70aa7bbSRiver Riddle static bool iterateLoops(ArrayRef<SmallVector<AffineForOp, 2>> depthToLoops,
161a70aa7bbSRiver Riddle                          LoopFunc fn, bool returnOnChange = false) {
162a70aa7bbSRiver Riddle   bool changed = false;
163a70aa7bbSRiver Riddle   for (unsigned loopDepth = 0, end = depthToLoops.size(); loopDepth < end;
164a70aa7bbSRiver Riddle        ++loopDepth) {
165a70aa7bbSRiver Riddle     auto &loops = depthToLoops[loopDepth];
166a70aa7bbSRiver Riddle     unsigned numLoops = loops.size();
167a70aa7bbSRiver Riddle     for (unsigned j = 0; j < numLoops; ++j) {
168a70aa7bbSRiver Riddle       for (unsigned k = 0; k < numLoops; ++k) {
169a70aa7bbSRiver Riddle         if (j != k)
170a70aa7bbSRiver Riddle           changed |=
171a70aa7bbSRiver Riddle               fn(loops[j], loops[k], j, k, loopDepth, depthToLoops.size());
172a70aa7bbSRiver Riddle         if (changed && returnOnChange)
173a70aa7bbSRiver Riddle           return true;
174a70aa7bbSRiver Riddle       }
175a70aa7bbSRiver Riddle     }
176a70aa7bbSRiver Riddle   }
177a70aa7bbSRiver Riddle   return changed;
178a70aa7bbSRiver Riddle }
179a70aa7bbSRiver Riddle 
runOnOperation()180a70aa7bbSRiver Riddle void TestLoopFusion::runOnOperation() {
181a70aa7bbSRiver Riddle   std::vector<SmallVector<AffineForOp, 2>> depthToLoops;
182a70aa7bbSRiver Riddle   if (clTestLoopFusionTransformation) {
183a70aa7bbSRiver Riddle     // Run loop fusion until a fixed point is reached.
184a70aa7bbSRiver Riddle     do {
185a70aa7bbSRiver Riddle       depthToLoops.clear();
186a70aa7bbSRiver Riddle       // Gather all AffineForOps by loop depth.
187a70aa7bbSRiver Riddle       gatherLoops(getOperation(), depthToLoops);
188a70aa7bbSRiver Riddle 
189a70aa7bbSRiver Riddle       // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'.
190a70aa7bbSRiver Riddle     } while (iterateLoops(depthToLoops, testLoopFusionTransformation,
191a70aa7bbSRiver Riddle                           /*returnOnChange=*/true));
192a70aa7bbSRiver Riddle     return;
193a70aa7bbSRiver Riddle   }
194a70aa7bbSRiver Riddle 
195a70aa7bbSRiver Riddle   // Gather all AffineForOps by loop depth.
196a70aa7bbSRiver Riddle   gatherLoops(getOperation(), depthToLoops);
197a70aa7bbSRiver Riddle 
198a70aa7bbSRiver Riddle   // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
199a70aa7bbSRiver Riddle   if (clTestDependenceCheck)
200a70aa7bbSRiver Riddle     iterateLoops(depthToLoops, testDependenceCheck);
201a70aa7bbSRiver Riddle   if (clTestSliceComputation)
202a70aa7bbSRiver Riddle     iterateLoops(depthToLoops, testSliceComputation);
203a70aa7bbSRiver Riddle }
204a70aa7bbSRiver Riddle 
205a70aa7bbSRiver Riddle namespace mlir {
206a70aa7bbSRiver Riddle namespace test {
registerTestLoopFusion()207a70aa7bbSRiver Riddle void registerTestLoopFusion() { PassRegistration<TestLoopFusion>(); }
208a70aa7bbSRiver Riddle } // namespace test
209a70aa7bbSRiver Riddle } // namespace mlir
210