1 //===- ControlFlowInterfacesTest.cpp - Unit Tests for Control Flow Interf. ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Interfaces/ControlFlowInterfaces.h"
10 #include "mlir/IR/BuiltinOps.h"
11 #include "mlir/IR/Dialect.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/OpDefinition.h"
14 #include "mlir/IR/OpImplementation.h"
15 #include "mlir/Parser/Parser.h"
16
17 #include <gtest/gtest.h>
18
19 using namespace mlir;
20
21 /// A dummy op that is also a terminator.
22 struct DummyOp : public Op<DummyOp, OpTrait::IsTerminator> {
23 using Op::Op;
getAttributeNamesDummyOp24 static ArrayRef<StringRef> getAttributeNames() { return {}; }
25
getOperationNameDummyOp26 static StringRef getOperationName() { return "cftest.dummy_op"; }
27 };
28
29 /// All regions of this op are mutually exclusive.
30 struct MutuallyExclusiveRegionsOp
31 : public Op<MutuallyExclusiveRegionsOp, RegionBranchOpInterface::Trait> {
32 using Op::Op;
getAttributeNamesMutuallyExclusiveRegionsOp33 static ArrayRef<StringRef> getAttributeNames() { return {}; }
34
getOperationNameMutuallyExclusiveRegionsOp35 static StringRef getOperationName() {
36 return "cftest.mutually_exclusive_regions_op";
37 }
38
39 // Regions have no successors.
getSuccessorRegionsMutuallyExclusiveRegionsOp40 void getSuccessorRegions(Optional<unsigned> index,
41 ArrayRef<Attribute> operands,
42 SmallVectorImpl<RegionSuccessor> ®ions) {}
43 };
44
45 /// All regions of this op call each other in a large circle.
46 struct LoopRegionsOp
47 : public Op<LoopRegionsOp, RegionBranchOpInterface::Trait> {
48 using Op::Op;
49 static const unsigned kNumRegions = 3;
50
getAttributeNamesLoopRegionsOp51 static ArrayRef<StringRef> getAttributeNames() { return {}; }
52
getOperationNameLoopRegionsOp53 static StringRef getOperationName() { return "cftest.loop_regions_op"; }
54
getSuccessorRegionsLoopRegionsOp55 void getSuccessorRegions(Optional<unsigned> index,
56 ArrayRef<Attribute> operands,
57 SmallVectorImpl<RegionSuccessor> ®ions) {
58 if (index) {
59 if (*index == 1)
60 // This region also branches back to the parent.
61 regions.push_back(RegionSuccessor());
62 regions.push_back(
63 RegionSuccessor(&getOperation()->getRegion(*index % kNumRegions)));
64 }
65 }
66 };
67
68 /// Each region branches back it itself or the parent.
69 struct DoubleLoopRegionsOp
70 : public Op<DoubleLoopRegionsOp, RegionBranchOpInterface::Trait> {
71 using Op::Op;
72
getAttributeNamesDoubleLoopRegionsOp73 static ArrayRef<StringRef> getAttributeNames() { return {}; }
74
getOperationNameDoubleLoopRegionsOp75 static StringRef getOperationName() {
76 return "cftest.double_loop_regions_op";
77 }
78
getSuccessorRegionsDoubleLoopRegionsOp79 void getSuccessorRegions(Optional<unsigned> index,
80 ArrayRef<Attribute> operands,
81 SmallVectorImpl<RegionSuccessor> ®ions) {
82 if (index.has_value()) {
83 regions.push_back(RegionSuccessor());
84 regions.push_back(RegionSuccessor(&getOperation()->getRegion(*index)));
85 }
86 }
87 };
88
89 /// Regions are executed sequentially.
90 struct SequentialRegionsOp
91 : public Op<SequentialRegionsOp, RegionBranchOpInterface::Trait> {
92 using Op::Op;
getAttributeNamesSequentialRegionsOp93 static ArrayRef<StringRef> getAttributeNames() { return {}; }
94
getOperationNameSequentialRegionsOp95 static StringRef getOperationName() { return "cftest.sequential_regions_op"; }
96
97 // Region 0 has Region 1 as a successor.
getSuccessorRegionsSequentialRegionsOp98 void getSuccessorRegions(Optional<unsigned> index,
99 ArrayRef<Attribute> operands,
100 SmallVectorImpl<RegionSuccessor> ®ions) {
101 if (index == 0u) {
102 Operation *thisOp = this->getOperation();
103 regions.push_back(RegionSuccessor(&thisOp->getRegion(1)));
104 }
105 }
106 };
107
108 /// A dialect putting all the above together.
109 struct CFTestDialect : Dialect {
CFTestDialectCFTestDialect110 explicit CFTestDialect(MLIRContext *ctx)
111 : Dialect(getDialectNamespace(), ctx, TypeID::get<CFTestDialect>()) {
112 addOperations<DummyOp, MutuallyExclusiveRegionsOp, LoopRegionsOp,
113 DoubleLoopRegionsOp, SequentialRegionsOp>();
114 }
getDialectNamespaceCFTestDialect115 static StringRef getDialectNamespace() { return "cftest"; }
116 };
117
TEST(RegionBranchOpInterface, MutuallyExclusiveOps) {
  const char *ir = R"MLIR(
"cftest.mutually_exclusive_regions_op"() (
      {"cftest.dummy_op"() : () -> ()},  // op1
      {"cftest.dummy_op"() : () -> ()}   // op2
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // Fail the test cleanly instead of dereferencing a null module below.
  ASSERT_TRUE(module);
  Operation *testOp = &module->getBody()->getOperations().front();
  Operation *op1 = &testOp->getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();

  EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
  EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
}
138
TEST(RegionBranchOpInterface, MutuallyExclusiveOps2) {
  const char *ir = R"MLIR(
"cftest.double_loop_regions_op"() (
      {"cftest.dummy_op"() : () -> ()},  // op1
      {"cftest.dummy_op"() : () -> ()}   // op2
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // Fail the test cleanly instead of dereferencing a null module below.
  ASSERT_TRUE(module);
  Operation *testOp = &module->getBody()->getOperations().front();
  Operation *op1 = &testOp->getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();

  // Each region only loops to itself or the parent, never to the other one.
  EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
  EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
}
159
TEST(RegionBranchOpInterface, NotMutuallyExclusiveOps) {
  const char *ir = R"MLIR(
"cftest.sequential_regions_op"() (
      {"cftest.dummy_op"() : () -> ()},  // op1
      {"cftest.dummy_op"() : () -> ()}   // op2
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // Fail the test cleanly instead of dereferencing a null module below.
  ASSERT_TRUE(module);
  Operation *testOp = &module->getBody()->getOperations().front();
  Operation *op1 = &testOp->getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();

  // Region 0 flows into region 1, so the two are not exclusive in either
  // direction.
  EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op2));
  EXPECT_FALSE(insideMutuallyExclusiveRegions(op2, op1));
}
180
TEST(RegionBranchOpInterface, NestedMutuallyExclusiveOps) {
  const char *ir = R"MLIR(
"cftest.mutually_exclusive_regions_op"() (
      {
        "cftest.sequential_regions_op"() (
          {"cftest.dummy_op"() : () -> ()},  // op1
          {"cftest.dummy_op"() : () -> ()}   // op3
        ) : () -> ()
        "cftest.dummy_op"() : () -> ()
      },
      {"cftest.dummy_op"() : () -> ()}       // op2
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // Fail the test cleanly instead of dereferencing a null module below.
  ASSERT_TRUE(module);
  Operation *testOp = &module->getBody()->getOperations().front();
  Operation *op1 =
      &testOp->getRegion(0).front().front().getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();
  Operation *op3 =
      &testOp->getRegion(0).front().front().getRegion(1).front().front();

  // op1/op3 vs. op2: separated by the outer, mutually exclusive op.
  EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
  EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
  EXPECT_TRUE(insideMutuallyExclusiveRegions(op3, op2));
  EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op3));
  // op1 vs. op3: the closest common op is sequential, so not exclusive.
  EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op3));
  EXPECT_FALSE(insideMutuallyExclusiveRegions(op3, op1));
}
211
TEST(RegionBranchOpInterface, RecursiveRegions) {
  const char *ir = R"MLIR(
"cftest.loop_regions_op"() (
      {"cftest.dummy_op"() : () -> ()},  // op1
      {"cftest.dummy_op"() : () -> ()},  // op2
      {"cftest.dummy_op"() : () -> ()}   // op3
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // Fail the test cleanly instead of dereferencing a null module below.
  ASSERT_TRUE(module);
  Operation *testOp = &module->getBody()->getOperations().front();
  auto regionOp = cast<RegionBranchOpInterface>(testOp);
  Operation *op1 = &testOp->getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();
  Operation *op3 = &testOp->getRegion(2).front().front();

  // Every region lies on the loop, so each can reach itself again.
  EXPECT_TRUE(regionOp.isRepetitiveRegion(0));
  EXPECT_TRUE(regionOp.isRepetitiveRegion(1));
  EXPECT_TRUE(regionOp.isRepetitiveRegion(2));
  EXPECT_NE(getEnclosingRepetitiveRegion(op1), nullptr);
  EXPECT_NE(getEnclosingRepetitiveRegion(op2), nullptr);
  EXPECT_NE(getEnclosingRepetitiveRegion(op3), nullptr);
}
239
TEST(RegionBranchOpInterface, NotRecursiveRegions) {
  const char *ir = R"MLIR(
"cftest.sequential_regions_op"() (
      {"cftest.dummy_op"() : () -> ()},  // op1
      {"cftest.dummy_op"() : () -> ()}   // op2
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // Fail the test cleanly instead of dereferencing a null module below.
  ASSERT_TRUE(module);
  Operation *testOp = &module->getBody()->getOperations().front();
  auto regionOp = cast<RegionBranchOpInterface>(testOp);
  Operation *op1 = &testOp->getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();

  // Region 0 -> region 1 -> (nothing): neither region can reach itself.
  EXPECT_FALSE(regionOp.isRepetitiveRegion(0));
  EXPECT_FALSE(regionOp.isRepetitiveRegion(1));
  EXPECT_EQ(getEnclosingRepetitiveRegion(op1), nullptr);
  EXPECT_EQ(getEnclosingRepetitiveRegion(op2), nullptr);
}
260