1 //===- TestLoopFusion.cpp - Test loop fusion ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test various loop fusion utility functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/Analysis/Utils.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
16 #include "mlir/Dialect/Affine/LoopUtils.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/Pass/Pass.h"
19 
20 #define DEBUG_TYPE "test-loop-fusion"
21 
22 using namespace mlir;
23 
// Category under which the three flags below are registered. Its name is
// DEBUG_TYPE concatenated with " options", i.e. "test-loop-fusion options".
static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");

// When set, TestLoopFusion::runOnOperation() runs testDependenceCheck over
// the gathered loop nests (unless the transformation flag below is also set,
// in which case only the transformation test runs).
static llvm::cl::opt<bool> clTestDependenceCheck(
    "test-loop-fusion-dependence-check",
    llvm::cl::desc("Enable testing of loop fusion dependence check"),
    llvm::cl::cat(clOptionsCategory));

// When set, TestLoopFusion::runOnOperation() runs testSliceComputation over
// the gathered loop nests (unless the transformation flag below is also set,
// in which case only the transformation test runs).
static llvm::cl::opt<bool> clTestSliceComputation(
    "test-loop-fusion-slice-computation",
    llvm::cl::desc("Enable testing of loop fusion slice computation"),
    llvm::cl::cat(clOptionsCategory));

// When set, TestLoopFusion::runOnOperation() repeatedly applies
// testLoopFusionTransformation until no more fusion happens, then returns
// without running the dependence-check or slice-computation tests.
static llvm::cl::opt<bool> clTestLoopFusionTransformation(
    "test-loop-fusion-transformation",
    llvm::cl::desc("Enable testing of loop fusion transformation"),
    llvm::cl::cat(clOptionsCategory));
40 
namespace {

/// Test-only pass, declared as an OperationPass over FuncOp, that exercises
/// the loop fusion utility functions. Which utility is exercised is selected
/// by the command-line flags declared in this file; see runOnOperation().
struct TestLoopFusion
    : public PassWrapper<TestLoopFusion, OperationPass<FuncOp>> {
  /// Command-line argument string identifying this pass.
  StringRef getArgument() const final { return "test-loop-fusion"; }
  /// One-line human-readable description of this pass.
  StringRef getDescription() const final {
    return "Tests loop fusion utility functions.";
  }
  /// Entry point of the pass; defined out of line later in this file.
  void runOnOperation() override;
};

} // namespace
53 
54 // Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths
55 // in range ['loopDepth' + 1, 'maxLoopDepth'].
56 // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists.
57 // Returns false as IR is not transformed.
58 static bool testDependenceCheck(AffineForOp srcForOp, AffineForOp dstForOp,
59                                 unsigned i, unsigned j, unsigned loopDepth,
60                                 unsigned maxLoopDepth) {
61   mlir::ComputationSliceState sliceUnion;
62   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
63     FusionResult result =
64         mlir::canFuseLoops(srcForOp, dstForOp, d, &sliceUnion);
65     if (result.value == FusionResult::FailBlockDependence) {
66       srcForOp->emitRemark("block-level dependence preventing"
67                            " fusion of loop nest ")
68           << i << " into loop nest " << j << " at depth " << loopDepth;
69     }
70   }
71   return false;
72 }
73 
74 // Returns the index of 'op' in its block.
75 static unsigned getBlockIndex(Operation &op) {
76   unsigned index = 0;
77   for (auto &opX : *op.getBlock()) {
78     if (&op == &opX)
79       break;
80     ++index;
81   }
82   return index;
83 }
84 
85 // Returns a string representation of 'sliceUnion'.
86 static std::string getSliceStr(const mlir::ComputationSliceState &sliceUnion) {
87   std::string result;
88   llvm::raw_string_ostream os(result);
89   // Slice insertion point format [loop-depth, operation-block-index]
90   unsigned ipd = getNestingDepth(&*sliceUnion.insertPoint);
91   unsigned ipb = getBlockIndex(*sliceUnion.insertPoint);
92   os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb)
93      << ")";
94   assert(sliceUnion.lbs.size() == sliceUnion.ubs.size());
95   os << " loop bounds: ";
96   for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) {
97     os << '[';
98     sliceUnion.lbs[k].print(os);
99     os << ", ";
100     sliceUnion.ubs[k].print(os);
101     os << "] ";
102   }
103   return os.str();
104 }
105 
106 /// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
107 /// in range ['loopDepth' + 1, 'maxLoopDepth'].
108 /// Emits a string representation of the slice union as a remark on 'loops[j]'
109 /// and marks this as incorrect slice if the slice is invalid. Returns false as
110 /// IR is not transformed.
111 static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB,
112                                  unsigned i, unsigned j, unsigned loopDepth,
113                                  unsigned maxLoopDepth) {
114   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
115     mlir::ComputationSliceState sliceUnion;
116     FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
117     if (result.value == FusionResult::Success) {
118       forOpB->emitRemark("slice (")
119           << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
120           << " : " << getSliceStr(sliceUnion) << ")";
121     } else if (result.value == FusionResult::FailIncorrectSlice) {
122       forOpB->emitRemark("Incorrect slice (")
123           << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
124           << " : " << getSliceStr(sliceUnion) << ")";
125     }
126   }
127   return false;
128 }
129 
130 // Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range
131 // ['loopDepth' + 1, 'maxLoopDepth'].
132 // Returns true if loops were successfully fused, false otherwise.
133 static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB,
134                                          unsigned i, unsigned j,
135                                          unsigned loopDepth,
136                                          unsigned maxLoopDepth) {
137   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
138     mlir::ComputationSliceState sliceUnion;
139     FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
140     if (result.value == FusionResult::Success) {
141       mlir::fuseLoops(forOpA, forOpB, sliceUnion);
142       // Note: 'forOpA' is removed to simplify test output. A proper loop
143       // fusion pass should check the data dependence graph and run memref
144       // region analysis to ensure removing 'forOpA' is safe.
145       forOpA.erase();
146       return true;
147     }
148   }
149   return false;
150 }
151 
152 using LoopFunc = function_ref<bool(AffineForOp, AffineForOp, unsigned, unsigned,
153                                    unsigned, unsigned)>;
154 
155 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
156 // If 'return_on_change' is true, returns on first invocation of 'fn' which
157 // returns true.
158 static bool iterateLoops(ArrayRef<SmallVector<AffineForOp, 2>> depthToLoops,
159                          LoopFunc fn, bool returnOnChange = false) {
160   bool changed = false;
161   for (unsigned loopDepth = 0, end = depthToLoops.size(); loopDepth < end;
162        ++loopDepth) {
163     auto &loops = depthToLoops[loopDepth];
164     unsigned numLoops = loops.size();
165     for (unsigned j = 0; j < numLoops; ++j) {
166       for (unsigned k = 0; k < numLoops; ++k) {
167         if (j != k)
168           changed |=
169               fn(loops[j], loops[k], j, k, loopDepth, depthToLoops.size());
170         if (changed && returnOnChange)
171           return true;
172       }
173     }
174   }
175   return changed;
176 }
177 
178 void TestLoopFusion::runOnOperation() {
179   std::vector<SmallVector<AffineForOp, 2>> depthToLoops;
180   if (clTestLoopFusionTransformation) {
181     // Run loop fusion until a fixed point is reached.
182     do {
183       depthToLoops.clear();
184       // Gather all AffineForOps by loop depth.
185       gatherLoops(getOperation(), depthToLoops);
186 
187       // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'.
188     } while (iterateLoops(depthToLoops, testLoopFusionTransformation,
189                           /*returnOnChange=*/true));
190     return;
191   }
192 
193   // Gather all AffineForOps by loop depth.
194   gatherLoops(getOperation(), depthToLoops);
195 
196   // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
197   if (clTestDependenceCheck)
198     iterateLoops(depthToLoops, testDependenceCheck);
199   if (clTestSliceComputation)
200     iterateLoops(depthToLoops, testSliceComputation);
201 }
202 
namespace mlir {
namespace test {
/// Registers the TestLoopFusion pass by constructing a
/// PassRegistration<TestLoopFusion>. Callers are outside this file;
/// presumably this is invoked by the test tool's pass-registration code.
void registerTestLoopFusion() { PassRegistration<TestLoopFusion>(); }
} // namespace test
} // namespace mlir
208