1 //===- TestLoopFusion.cpp - Test loop fusion ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test various loop fusion utility functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/Analysis/Utils.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
16 #include "mlir/Dialect/Affine/LoopUtils.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Pass/Pass.h"
19 
20 #define DEBUG_TYPE "test-loop-fusion"
21 
22 using namespace mlir;
23 
24 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
25 
26 static llvm::cl::opt<bool> clTestDependenceCheck(
27     "test-loop-fusion-dependence-check",
28     llvm::cl::desc("Enable testing of loop fusion dependence check"),
29     llvm::cl::cat(clOptionsCategory));
30 
31 static llvm::cl::opt<bool> clTestSliceComputation(
32     "test-loop-fusion-slice-computation",
33     llvm::cl::desc("Enable testing of loop fusion slice computation"),
34     llvm::cl::cat(clOptionsCategory));
35 
36 static llvm::cl::opt<bool> clTestLoopFusionTransformation(
37     "test-loop-fusion-transformation",
38     llvm::cl::desc("Enable testing of loop fusion transformation"),
39     llvm::cl::cat(clOptionsCategory));
40 
41 namespace {
42 
43 struct TestLoopFusion
44     : public PassWrapper<TestLoopFusion, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon6d9b19930111::TestLoopFusion45   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopFusion)
46 
47   StringRef getArgument() const final { return "test-loop-fusion"; }
getDescription__anon6d9b19930111::TestLoopFusion48   StringRef getDescription() const final {
49     return "Tests loop fusion utility functions.";
50   }
51   void runOnOperation() override;
52 };
53 
54 } // namespace
55 
56 // Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths
57 // in range ['loopDepth' + 1, 'maxLoopDepth'].
58 // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists.
59 // Returns false as IR is not transformed.
testDependenceCheck(AffineForOp srcForOp,AffineForOp dstForOp,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)60 static bool testDependenceCheck(AffineForOp srcForOp, AffineForOp dstForOp,
61                                 unsigned i, unsigned j, unsigned loopDepth,
62                                 unsigned maxLoopDepth) {
63   mlir::ComputationSliceState sliceUnion;
64   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
65     FusionResult result =
66         mlir::canFuseLoops(srcForOp, dstForOp, d, &sliceUnion);
67     if (result.value == FusionResult::FailBlockDependence) {
68       srcForOp->emitRemark("block-level dependence preventing"
69                            " fusion of loop nest ")
70           << i << " into loop nest " << j << " at depth " << loopDepth;
71     }
72   }
73   return false;
74 }
75 
76 // Returns the index of 'op' in its block.
getBlockIndex(Operation & op)77 static unsigned getBlockIndex(Operation &op) {
78   unsigned index = 0;
79   for (auto &opX : *op.getBlock()) {
80     if (&op == &opX)
81       break;
82     ++index;
83   }
84   return index;
85 }
86 
87 // Returns a string representation of 'sliceUnion'.
getSliceStr(const mlir::ComputationSliceState & sliceUnion)88 static std::string getSliceStr(const mlir::ComputationSliceState &sliceUnion) {
89   std::string result;
90   llvm::raw_string_ostream os(result);
91   // Slice insertion point format [loop-depth, operation-block-index]
92   unsigned ipd = getNestingDepth(&*sliceUnion.insertPoint);
93   unsigned ipb = getBlockIndex(*sliceUnion.insertPoint);
94   os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb)
95      << ")";
96   assert(sliceUnion.lbs.size() == sliceUnion.ubs.size());
97   os << " loop bounds: ";
98   for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) {
99     os << '[';
100     sliceUnion.lbs[k].print(os);
101     os << ", ";
102     sliceUnion.ubs[k].print(os);
103     os << "] ";
104   }
105   return os.str();
106 }
107 
108 /// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
109 /// in range ['loopDepth' + 1, 'maxLoopDepth'].
110 /// Emits a string representation of the slice union as a remark on 'loops[j]'
111 /// and marks this as incorrect slice if the slice is invalid. Returns false as
112 /// IR is not transformed.
testSliceComputation(AffineForOp forOpA,AffineForOp forOpB,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)113 static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB,
114                                  unsigned i, unsigned j, unsigned loopDepth,
115                                  unsigned maxLoopDepth) {
116   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
117     mlir::ComputationSliceState sliceUnion;
118     FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
119     if (result.value == FusionResult::Success) {
120       forOpB->emitRemark("slice (")
121           << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
122           << " : " << getSliceStr(sliceUnion) << ")";
123     } else if (result.value == FusionResult::FailIncorrectSlice) {
124       forOpB->emitRemark("Incorrect slice (")
125           << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
126           << " : " << getSliceStr(sliceUnion) << ")";
127     }
128   }
129   return false;
130 }
131 
132 // Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range
133 // ['loopDepth' + 1, 'maxLoopDepth'].
134 // Returns true if loops were successfully fused, false otherwise.
testLoopFusionTransformation(AffineForOp forOpA,AffineForOp forOpB,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)135 static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB,
136                                          unsigned i, unsigned j,
137                                          unsigned loopDepth,
138                                          unsigned maxLoopDepth) {
139   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
140     mlir::ComputationSliceState sliceUnion;
141     FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
142     if (result.value == FusionResult::Success) {
143       mlir::fuseLoops(forOpA, forOpB, sliceUnion);
144       // Note: 'forOpA' is removed to simplify test output. A proper loop
145       // fusion pass should check the data dependence graph and run memref
146       // region analysis to ensure removing 'forOpA' is safe.
147       forOpA.erase();
148       return true;
149     }
150   }
151   return false;
152 }
153 
154 using LoopFunc = function_ref<bool(AffineForOp, AffineForOp, unsigned, unsigned,
155                                    unsigned, unsigned)>;
156 
157 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
158 // If 'return_on_change' is true, returns on first invocation of 'fn' which
159 // returns true.
iterateLoops(ArrayRef<SmallVector<AffineForOp,2>> depthToLoops,LoopFunc fn,bool returnOnChange=false)160 static bool iterateLoops(ArrayRef<SmallVector<AffineForOp, 2>> depthToLoops,
161                          LoopFunc fn, bool returnOnChange = false) {
162   bool changed = false;
163   for (unsigned loopDepth = 0, end = depthToLoops.size(); loopDepth < end;
164        ++loopDepth) {
165     auto &loops = depthToLoops[loopDepth];
166     unsigned numLoops = loops.size();
167     for (unsigned j = 0; j < numLoops; ++j) {
168       for (unsigned k = 0; k < numLoops; ++k) {
169         if (j != k)
170           changed |=
171               fn(loops[j], loops[k], j, k, loopDepth, depthToLoops.size());
172         if (changed && returnOnChange)
173           return true;
174       }
175     }
176   }
177   return changed;
178 }
179 
runOnOperation()180 void TestLoopFusion::runOnOperation() {
181   std::vector<SmallVector<AffineForOp, 2>> depthToLoops;
182   if (clTestLoopFusionTransformation) {
183     // Run loop fusion until a fixed point is reached.
184     do {
185       depthToLoops.clear();
186       // Gather all AffineForOps by loop depth.
187       gatherLoops(getOperation(), depthToLoops);
188 
189       // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'.
190     } while (iterateLoops(depthToLoops, testLoopFusionTransformation,
191                           /*returnOnChange=*/true));
192     return;
193   }
194 
195   // Gather all AffineForOps by loop depth.
196   gatherLoops(getOperation(), depthToLoops);
197 
198   // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
199   if (clTestDependenceCheck)
200     iterateLoops(depthToLoops, testDependenceCheck);
201   if (clTestSliceComputation)
202     iterateLoops(depthToLoops, testSliceComputation);
203 }
204 
205 namespace mlir {
206 namespace test {
registerTestLoopFusion()207 void registerTestLoopFusion() { PassRegistration<TestLoopFusion>(); }
208 } // namespace test
209 } // namespace mlir
210