1 //===- TestLoopFusion.cpp - Test loop fusion ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test various loop fusion utility functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/Analysis/Utils.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
16 #include "mlir/Dialect/Affine/LoopUtils.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Pass/Pass.h"
19 
20 #define DEBUG_TYPE "test-loop-fusion"
21 
22 using namespace mlir;
23 
24 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
25 
26 static llvm::cl::opt<bool> clTestDependenceCheck(
27     "test-loop-fusion-dependence-check",
28     llvm::cl::desc("Enable testing of loop fusion dependence check"),
29     llvm::cl::cat(clOptionsCategory));
30 
31 static llvm::cl::opt<bool> clTestSliceComputation(
32     "test-loop-fusion-slice-computation",
33     llvm::cl::desc("Enable testing of loop fusion slice computation"),
34     llvm::cl::cat(clOptionsCategory));
35 
36 static llvm::cl::opt<bool> clTestLoopFusionTransformation(
37     "test-loop-fusion-transformation",
38     llvm::cl::desc("Enable testing of loop fusion transformation"),
39     llvm::cl::cat(clOptionsCategory));
40 
41 namespace {
42 
43 struct TestLoopFusion
44     : public PassWrapper<TestLoopFusion, OperationPass<func::FuncOp>> {
45   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopFusion)
46 
47   StringRef getArgument() const final { return "test-loop-fusion"; }
48   StringRef getDescription() const final {
49     return "Tests loop fusion utility functions.";
50   }
51   void runOnOperation() override;
52 };
53 
54 } // namespace
55 
56 // Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths
57 // in range ['loopDepth' + 1, 'maxLoopDepth'].
58 // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists.
59 // Returns false as IR is not transformed.
60 static bool testDependenceCheck(AffineForOp srcForOp, AffineForOp dstForOp,
61                                 unsigned i, unsigned j, unsigned loopDepth,
62                                 unsigned maxLoopDepth) {
63   mlir::ComputationSliceState sliceUnion;
64   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
65     FusionResult result =
66         mlir::canFuseLoops(srcForOp, dstForOp, d, &sliceUnion);
67     if (result.value == FusionResult::FailBlockDependence) {
68       srcForOp->emitRemark("block-level dependence preventing"
69                            " fusion of loop nest ")
70           << i << " into loop nest " << j << " at depth " << loopDepth;
71     }
72   }
73   return false;
74 }
75 
76 // Returns the index of 'op' in its block.
77 static unsigned getBlockIndex(Operation &op) {
78   unsigned index = 0;
79   for (auto &opX : *op.getBlock()) {
80     if (&op == &opX)
81       break;
82     ++index;
83   }
84   return index;
85 }
86 
87 // Returns a string representation of 'sliceUnion'.
88 static std::string getSliceStr(const mlir::ComputationSliceState &sliceUnion) {
89   std::string result;
90   llvm::raw_string_ostream os(result);
91   // Slice insertion point format [loop-depth, operation-block-index]
92   unsigned ipd = getNestingDepth(&*sliceUnion.insertPoint);
93   unsigned ipb = getBlockIndex(*sliceUnion.insertPoint);
94   os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb)
95      << ")";
96   assert(sliceUnion.lbs.size() == sliceUnion.ubs.size());
97   os << " loop bounds: ";
98   for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) {
99     os << '[';
100     sliceUnion.lbs[k].print(os);
101     os << ", ";
102     sliceUnion.ubs[k].print(os);
103     os << "] ";
104   }
105   return os.str();
106 }
107 
108 /// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
109 /// in range ['loopDepth' + 1, 'maxLoopDepth'].
110 /// Emits a string representation of the slice union as a remark on 'loops[j]'
111 /// and marks this as incorrect slice if the slice is invalid. Returns false as
112 /// IR is not transformed.
113 static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB,
114                                  unsigned i, unsigned j, unsigned loopDepth,
115                                  unsigned maxLoopDepth) {
116   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
117     mlir::ComputationSliceState sliceUnion;
118     FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
119     if (result.value == FusionResult::Success) {
120       forOpB->emitRemark("slice (")
121           << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
122           << " : " << getSliceStr(sliceUnion) << ")";
123     } else if (result.value == FusionResult::FailIncorrectSlice) {
124       forOpB->emitRemark("Incorrect slice (")
125           << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
126           << " : " << getSliceStr(sliceUnion) << ")";
127     }
128   }
129   return false;
130 }
131 
132 // Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range
133 // ['loopDepth' + 1, 'maxLoopDepth'].
134 // Returns true if loops were successfully fused, false otherwise.
135 static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB,
136                                          unsigned i, unsigned j,
137                                          unsigned loopDepth,
138                                          unsigned maxLoopDepth) {
139   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
140     mlir::ComputationSliceState sliceUnion;
141     FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
142     if (result.value == FusionResult::Success) {
143       mlir::fuseLoops(forOpA, forOpB, sliceUnion);
144       // Note: 'forOpA' is removed to simplify test output. A proper loop
145       // fusion pass should check the data dependence graph and run memref
146       // region analysis to ensure removing 'forOpA' is safe.
147       forOpA.erase();
148       return true;
149     }
150   }
151   return false;
152 }
153 
154 using LoopFunc = function_ref<bool(AffineForOp, AffineForOp, unsigned, unsigned,
155                                    unsigned, unsigned)>;
156 
157 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
158 // If 'return_on_change' is true, returns on first invocation of 'fn' which
159 // returns true.
160 static bool iterateLoops(ArrayRef<SmallVector<AffineForOp, 2>> depthToLoops,
161                          LoopFunc fn, bool returnOnChange = false) {
162   bool changed = false;
163   for (unsigned loopDepth = 0, end = depthToLoops.size(); loopDepth < end;
164        ++loopDepth) {
165     auto &loops = depthToLoops[loopDepth];
166     unsigned numLoops = loops.size();
167     for (unsigned j = 0; j < numLoops; ++j) {
168       for (unsigned k = 0; k < numLoops; ++k) {
169         if (j != k)
170           changed |=
171               fn(loops[j], loops[k], j, k, loopDepth, depthToLoops.size());
172         if (changed && returnOnChange)
173           return true;
174       }
175     }
176   }
177   return changed;
178 }
179 
180 void TestLoopFusion::runOnOperation() {
181   std::vector<SmallVector<AffineForOp, 2>> depthToLoops;
182   if (clTestLoopFusionTransformation) {
183     // Run loop fusion until a fixed point is reached.
184     do {
185       depthToLoops.clear();
186       // Gather all AffineForOps by loop depth.
187       gatherLoops(getOperation(), depthToLoops);
188 
189       // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'.
190     } while (iterateLoops(depthToLoops, testLoopFusionTransformation,
191                           /*returnOnChange=*/true));
192     return;
193   }
194 
195   // Gather all AffineForOps by loop depth.
196   gatherLoops(getOperation(), depthToLoops);
197 
198   // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
199   if (clTestDependenceCheck)
200     iterateLoops(depthToLoops, testDependenceCheck);
201   if (clTestSliceComputation)
202     iterateLoops(depthToLoops, testSliceComputation);
203 }
204 
205 namespace mlir {
206 namespace test {
207 void registerTestLoopFusion() { PassRegistration<TestLoopFusion>(); }
208 } // namespace test
209 } // namespace mlir
210