1 //===- TestLoopFusion.cpp - Test loop fusion ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test various loop fusion utility functions.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/Affine/Analysis/Utils.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
16 #include "mlir/Dialect/Affine/LoopUtils.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Pass/Pass.h"
19
20 #define DEBUG_TYPE "test-loop-fusion"
21
22 using namespace mlir;
23
24 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
25
26 static llvm::cl::opt<bool> clTestDependenceCheck(
27 "test-loop-fusion-dependence-check",
28 llvm::cl::desc("Enable testing of loop fusion dependence check"),
29 llvm::cl::cat(clOptionsCategory));
30
31 static llvm::cl::opt<bool> clTestSliceComputation(
32 "test-loop-fusion-slice-computation",
33 llvm::cl::desc("Enable testing of loop fusion slice computation"),
34 llvm::cl::cat(clOptionsCategory));
35
36 static llvm::cl::opt<bool> clTestLoopFusionTransformation(
37 "test-loop-fusion-transformation",
38 llvm::cl::desc("Enable testing of loop fusion transformation"),
39 llvm::cl::cat(clOptionsCategory));
40
41 namespace {
42
43 struct TestLoopFusion
44 : public PassWrapper<TestLoopFusion, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon6d9b19930111::TestLoopFusion45 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopFusion)
46
47 StringRef getArgument() const final { return "test-loop-fusion"; }
getDescription__anon6d9b19930111::TestLoopFusion48 StringRef getDescription() const final {
49 return "Tests loop fusion utility functions.";
50 }
51 void runOnOperation() override;
52 };
53
54 } // namespace
55
56 // Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths
57 // in range ['loopDepth' + 1, 'maxLoopDepth'].
58 // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists.
59 // Returns false as IR is not transformed.
testDependenceCheck(AffineForOp srcForOp,AffineForOp dstForOp,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)60 static bool testDependenceCheck(AffineForOp srcForOp, AffineForOp dstForOp,
61 unsigned i, unsigned j, unsigned loopDepth,
62 unsigned maxLoopDepth) {
63 mlir::ComputationSliceState sliceUnion;
64 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
65 FusionResult result =
66 mlir::canFuseLoops(srcForOp, dstForOp, d, &sliceUnion);
67 if (result.value == FusionResult::FailBlockDependence) {
68 srcForOp->emitRemark("block-level dependence preventing"
69 " fusion of loop nest ")
70 << i << " into loop nest " << j << " at depth " << loopDepth;
71 }
72 }
73 return false;
74 }
75
76 // Returns the index of 'op' in its block.
getBlockIndex(Operation & op)77 static unsigned getBlockIndex(Operation &op) {
78 unsigned index = 0;
79 for (auto &opX : *op.getBlock()) {
80 if (&op == &opX)
81 break;
82 ++index;
83 }
84 return index;
85 }
86
87 // Returns a string representation of 'sliceUnion'.
getSliceStr(const mlir::ComputationSliceState & sliceUnion)88 static std::string getSliceStr(const mlir::ComputationSliceState &sliceUnion) {
89 std::string result;
90 llvm::raw_string_ostream os(result);
91 // Slice insertion point format [loop-depth, operation-block-index]
92 unsigned ipd = getNestingDepth(&*sliceUnion.insertPoint);
93 unsigned ipb = getBlockIndex(*sliceUnion.insertPoint);
94 os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb)
95 << ")";
96 assert(sliceUnion.lbs.size() == sliceUnion.ubs.size());
97 os << " loop bounds: ";
98 for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) {
99 os << '[';
100 sliceUnion.lbs[k].print(os);
101 os << ", ";
102 sliceUnion.ubs[k].print(os);
103 os << "] ";
104 }
105 return os.str();
106 }
107
108 /// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
109 /// in range ['loopDepth' + 1, 'maxLoopDepth'].
110 /// Emits a string representation of the slice union as a remark on 'loops[j]'
111 /// and marks this as incorrect slice if the slice is invalid. Returns false as
112 /// IR is not transformed.
testSliceComputation(AffineForOp forOpA,AffineForOp forOpB,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)113 static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB,
114 unsigned i, unsigned j, unsigned loopDepth,
115 unsigned maxLoopDepth) {
116 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
117 mlir::ComputationSliceState sliceUnion;
118 FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
119 if (result.value == FusionResult::Success) {
120 forOpB->emitRemark("slice (")
121 << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
122 << " : " << getSliceStr(sliceUnion) << ")";
123 } else if (result.value == FusionResult::FailIncorrectSlice) {
124 forOpB->emitRemark("Incorrect slice (")
125 << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
126 << " : " << getSliceStr(sliceUnion) << ")";
127 }
128 }
129 return false;
130 }
131
132 // Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range
133 // ['loopDepth' + 1, 'maxLoopDepth'].
134 // Returns true if loops were successfully fused, false otherwise.
testLoopFusionTransformation(AffineForOp forOpA,AffineForOp forOpB,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)135 static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB,
136 unsigned i, unsigned j,
137 unsigned loopDepth,
138 unsigned maxLoopDepth) {
139 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
140 mlir::ComputationSliceState sliceUnion;
141 FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
142 if (result.value == FusionResult::Success) {
143 mlir::fuseLoops(forOpA, forOpB, sliceUnion);
144 // Note: 'forOpA' is removed to simplify test output. A proper loop
145 // fusion pass should check the data dependence graph and run memref
146 // region analysis to ensure removing 'forOpA' is safe.
147 forOpA.erase();
148 return true;
149 }
150 }
151 return false;
152 }
153
154 using LoopFunc = function_ref<bool(AffineForOp, AffineForOp, unsigned, unsigned,
155 unsigned, unsigned)>;
156
157 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
158 // If 'return_on_change' is true, returns on first invocation of 'fn' which
159 // returns true.
iterateLoops(ArrayRef<SmallVector<AffineForOp,2>> depthToLoops,LoopFunc fn,bool returnOnChange=false)160 static bool iterateLoops(ArrayRef<SmallVector<AffineForOp, 2>> depthToLoops,
161 LoopFunc fn, bool returnOnChange = false) {
162 bool changed = false;
163 for (unsigned loopDepth = 0, end = depthToLoops.size(); loopDepth < end;
164 ++loopDepth) {
165 auto &loops = depthToLoops[loopDepth];
166 unsigned numLoops = loops.size();
167 for (unsigned j = 0; j < numLoops; ++j) {
168 for (unsigned k = 0; k < numLoops; ++k) {
169 if (j != k)
170 changed |=
171 fn(loops[j], loops[k], j, k, loopDepth, depthToLoops.size());
172 if (changed && returnOnChange)
173 return true;
174 }
175 }
176 }
177 return changed;
178 }
179
runOnOperation()180 void TestLoopFusion::runOnOperation() {
181 std::vector<SmallVector<AffineForOp, 2>> depthToLoops;
182 if (clTestLoopFusionTransformation) {
183 // Run loop fusion until a fixed point is reached.
184 do {
185 depthToLoops.clear();
186 // Gather all AffineForOps by loop depth.
187 gatherLoops(getOperation(), depthToLoops);
188
189 // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'.
190 } while (iterateLoops(depthToLoops, testLoopFusionTransformation,
191 /*returnOnChange=*/true));
192 return;
193 }
194
195 // Gather all AffineForOps by loop depth.
196 gatherLoops(getOperation(), depthToLoops);
197
198 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
199 if (clTestDependenceCheck)
200 iterateLoops(depthToLoops, testDependenceCheck);
201 if (clTestSliceComputation)
202 iterateLoops(depthToLoops, testSliceComputation);
203 }
204
205 namespace mlir {
206 namespace test {
registerTestLoopFusion()207 void registerTestLoopFusion() { PassRegistration<TestLoopFusion>(); }
208 } // namespace test
209 } // namespace mlir
210