1 //===- TestLoopFusion.cpp - Test loop fusion ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test various loop fusion utility functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/Analysis/Utils.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
16 #include "mlir/Dialect/Affine/LoopUtils.h"
17 #include "mlir/Pass/Pass.h"
18 
19 #define DEBUG_TYPE "test-loop-fusion"
20 
21 using namespace mlir;
22 
23 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
24 
25 static llvm::cl::opt<bool> clTestDependenceCheck(
26     "test-loop-fusion-dependence-check",
27     llvm::cl::desc("Enable testing of loop fusion dependence check"),
28     llvm::cl::cat(clOptionsCategory));
29 
30 static llvm::cl::opt<bool> clTestSliceComputation(
31     "test-loop-fusion-slice-computation",
32     llvm::cl::desc("Enable testing of loop fusion slice computation"),
33     llvm::cl::cat(clOptionsCategory));
34 
35 static llvm::cl::opt<bool> clTestLoopFusionTransformation(
36     "test-loop-fusion-transformation",
37     llvm::cl::desc("Enable testing of loop fusion transformation"),
38     llvm::cl::cat(clOptionsCategory));
39 
40 namespace {
41 
42 struct TestLoopFusion
43     : public PassWrapper<TestLoopFusion, OperationPass<FuncOp>> {
44   StringRef getArgument() const final { return "test-loop-fusion"; }
45   StringRef getDescription() const final {
46     return "Tests loop fusion utility functions.";
47   }
48   void runOnOperation() override;
49 };
50 
51 } // namespace
52 
53 // Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths
54 // in range ['loopDepth' + 1, 'maxLoopDepth'].
55 // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists.
56 // Returns false as IR is not transformed.
57 static bool testDependenceCheck(AffineForOp srcForOp, AffineForOp dstForOp,
58                                 unsigned i, unsigned j, unsigned loopDepth,
59                                 unsigned maxLoopDepth) {
60   mlir::ComputationSliceState sliceUnion;
61   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
62     FusionResult result =
63         mlir::canFuseLoops(srcForOp, dstForOp, d, &sliceUnion);
64     if (result.value == FusionResult::FailBlockDependence) {
65       srcForOp->emitRemark("block-level dependence preventing"
66                            " fusion of loop nest ")
67           << i << " into loop nest " << j << " at depth " << loopDepth;
68     }
69   }
70   return false;
71 }
72 
73 // Returns the index of 'op' in its block.
74 static unsigned getBlockIndex(Operation &op) {
75   unsigned index = 0;
76   for (auto &opX : *op.getBlock()) {
77     if (&op == &opX)
78       break;
79     ++index;
80   }
81   return index;
82 }
83 
84 // Returns a string representation of 'sliceUnion'.
85 static std::string getSliceStr(const mlir::ComputationSliceState &sliceUnion) {
86   std::string result;
87   llvm::raw_string_ostream os(result);
88   // Slice insertion point format [loop-depth, operation-block-index]
89   unsigned ipd = getNestingDepth(&*sliceUnion.insertPoint);
90   unsigned ipb = getBlockIndex(*sliceUnion.insertPoint);
91   os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb)
92      << ")";
93   assert(sliceUnion.lbs.size() == sliceUnion.ubs.size());
94   os << " loop bounds: ";
95   for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) {
96     os << '[';
97     sliceUnion.lbs[k].print(os);
98     os << ", ";
99     sliceUnion.ubs[k].print(os);
100     os << "] ";
101   }
102   return os.str();
103 }
104 
105 /// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
106 /// in range ['loopDepth' + 1, 'maxLoopDepth'].
107 /// Emits a string representation of the slice union as a remark on 'loops[j]'
108 /// and marks this as incorrect slice if the slice is invalid. Returns false as
109 /// IR is not transformed.
110 static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB,
111                                  unsigned i, unsigned j, unsigned loopDepth,
112                                  unsigned maxLoopDepth) {
113   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
114     mlir::ComputationSliceState sliceUnion;
115     FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
116     if (result.value == FusionResult::Success) {
117       forOpB->emitRemark("slice (")
118           << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
119           << " : " << getSliceStr(sliceUnion) << ")";
120     } else if (result.value == FusionResult::FailIncorrectSlice) {
121       forOpB->emitRemark("Incorrect slice (")
122           << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
123           << " : " << getSliceStr(sliceUnion) << ")";
124     }
125   }
126   return false;
127 }
128 
129 // Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range
130 // ['loopDepth' + 1, 'maxLoopDepth'].
131 // Returns true if loops were successfully fused, false otherwise.
132 static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB,
133                                          unsigned i, unsigned j,
134                                          unsigned loopDepth,
135                                          unsigned maxLoopDepth) {
136   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
137     mlir::ComputationSliceState sliceUnion;
138     FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
139     if (result.value == FusionResult::Success) {
140       mlir::fuseLoops(forOpA, forOpB, sliceUnion);
141       // Note: 'forOpA' is removed to simplify test output. A proper loop
142       // fusion pass should check the data dependence graph and run memref
143       // region analysis to ensure removing 'forOpA' is safe.
144       forOpA.erase();
145       return true;
146     }
147   }
148   return false;
149 }
150 
151 using LoopFunc = function_ref<bool(AffineForOp, AffineForOp, unsigned, unsigned,
152                                    unsigned, unsigned)>;
153 
154 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
155 // If 'return_on_change' is true, returns on first invocation of 'fn' which
156 // returns true.
157 static bool iterateLoops(ArrayRef<SmallVector<AffineForOp, 2>> depthToLoops,
158                          LoopFunc fn, bool returnOnChange = false) {
159   bool changed = false;
160   for (unsigned loopDepth = 0, end = depthToLoops.size(); loopDepth < end;
161        ++loopDepth) {
162     auto &loops = depthToLoops[loopDepth];
163     unsigned numLoops = loops.size();
164     for (unsigned j = 0; j < numLoops; ++j) {
165       for (unsigned k = 0; k < numLoops; ++k) {
166         if (j != k)
167           changed |=
168               fn(loops[j], loops[k], j, k, loopDepth, depthToLoops.size());
169         if (changed && returnOnChange)
170           return true;
171       }
172     }
173   }
174   return changed;
175 }
176 
177 void TestLoopFusion::runOnOperation() {
178   std::vector<SmallVector<AffineForOp, 2>> depthToLoops;
179   if (clTestLoopFusionTransformation) {
180     // Run loop fusion until a fixed point is reached.
181     do {
182       depthToLoops.clear();
183       // Gather all AffineForOps by loop depth.
184       gatherLoops(getOperation(), depthToLoops);
185 
186       // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'.
187     } while (iterateLoops(depthToLoops, testLoopFusionTransformation,
188                           /*returnOnChange=*/true));
189     return;
190   }
191 
192   // Gather all AffineForOps by loop depth.
193   gatherLoops(getOperation(), depthToLoops);
194 
195   // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
196   if (clTestDependenceCheck)
197     iterateLoops(depthToLoops, testDependenceCheck);
198   if (clTestSliceComputation)
199     iterateLoops(depthToLoops, testSliceComputation);
200 }
201 
202 namespace mlir {
203 namespace test {
204 void registerTestLoopFusion() { PassRegistration<TestLoopFusion>(); }
205 } // namespace test
206 } // namespace mlir
207