1a5c2f782SMatthias Springer //===- ControlFlowInterfacesTest.cpp - Unit Tests for Control Flow Interf. ===//
2a5c2f782SMatthias Springer //
3a5c2f782SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a5c2f782SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5a5c2f782SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a5c2f782SMatthias Springer //
7a5c2f782SMatthias Springer //===----------------------------------------------------------------------===//
8a5c2f782SMatthias Springer
9a5c2f782SMatthias Springer #include "mlir/Interfaces/ControlFlowInterfaces.h"
10a5c2f782SMatthias Springer #include "mlir/IR/BuiltinOps.h"
11a5c2f782SMatthias Springer #include "mlir/IR/Dialect.h"
12a5c2f782SMatthias Springer #include "mlir/IR/DialectImplementation.h"
13a5c2f782SMatthias Springer #include "mlir/IR/OpDefinition.h"
14a5c2f782SMatthias Springer #include "mlir/IR/OpImplementation.h"
159eaff423SRiver Riddle #include "mlir/Parser/Parser.h"
16a5c2f782SMatthias Springer
17a5c2f782SMatthias Springer #include <gtest/gtest.h>
18a5c2f782SMatthias Springer
19a5c2f782SMatthias Springer using namespace mlir;
20a5c2f782SMatthias Springer
21a5c2f782SMatthias Springer /// A dummy op that is also a terminator.
22a5c2f782SMatthias Springer struct DummyOp : public Op<DummyOp, OpTrait::IsTerminator> {
23a5c2f782SMatthias Springer using Op::Op;
getAttributeNamesDummyOp24a5c2f782SMatthias Springer static ArrayRef<StringRef> getAttributeNames() { return {}; }
25a5c2f782SMatthias Springer
getOperationNameDummyOp26a5c2f782SMatthias Springer static StringRef getOperationName() { return "cftest.dummy_op"; }
27a5c2f782SMatthias Springer };
28a5c2f782SMatthias Springer
29a5c2f782SMatthias Springer /// All regions of this op are mutually exclusive.
30a5c2f782SMatthias Springer struct MutuallyExclusiveRegionsOp
31a5c2f782SMatthias Springer : public Op<MutuallyExclusiveRegionsOp, RegionBranchOpInterface::Trait> {
32a5c2f782SMatthias Springer using Op::Op;
getAttributeNamesMutuallyExclusiveRegionsOp33a5c2f782SMatthias Springer static ArrayRef<StringRef> getAttributeNames() { return {}; }
34a5c2f782SMatthias Springer
getOperationNameMutuallyExclusiveRegionsOp35a5c2f782SMatthias Springer static StringRef getOperationName() {
36a5c2f782SMatthias Springer return "cftest.mutually_exclusive_regions_op";
37a5c2f782SMatthias Springer }
38a5c2f782SMatthias Springer
39a5c2f782SMatthias Springer // Regions have no successors.
getSuccessorRegionsMutuallyExclusiveRegionsOp40a5c2f782SMatthias Springer void getSuccessorRegions(Optional<unsigned> index,
41a5c2f782SMatthias Springer ArrayRef<Attribute> operands,
42a5c2f782SMatthias Springer SmallVectorImpl<RegionSuccessor> ®ions) {}
43a5c2f782SMatthias Springer };
44a5c2f782SMatthias Springer
450f4ba02dSMatthias Springer /// All regions of this op call each other in a large circle.
460f4ba02dSMatthias Springer struct LoopRegionsOp
470f4ba02dSMatthias Springer : public Op<LoopRegionsOp, RegionBranchOpInterface::Trait> {
480f4ba02dSMatthias Springer using Op::Op;
490f4ba02dSMatthias Springer static const unsigned kNumRegions = 3;
500f4ba02dSMatthias Springer
getAttributeNamesLoopRegionsOp510f4ba02dSMatthias Springer static ArrayRef<StringRef> getAttributeNames() { return {}; }
520f4ba02dSMatthias Springer
getOperationNameLoopRegionsOp530f4ba02dSMatthias Springer static StringRef getOperationName() { return "cftest.loop_regions_op"; }
540f4ba02dSMatthias Springer
getSuccessorRegionsLoopRegionsOp550f4ba02dSMatthias Springer void getSuccessorRegions(Optional<unsigned> index,
560f4ba02dSMatthias Springer ArrayRef<Attribute> operands,
570f4ba02dSMatthias Springer SmallVectorImpl<RegionSuccessor> ®ions) {
580f4ba02dSMatthias Springer if (index) {
590f4ba02dSMatthias Springer if (*index == 1)
600f4ba02dSMatthias Springer // This region also branches back to the parent.
610f4ba02dSMatthias Springer regions.push_back(RegionSuccessor());
620f4ba02dSMatthias Springer regions.push_back(
630f4ba02dSMatthias Springer RegionSuccessor(&getOperation()->getRegion(*index % kNumRegions)));
640f4ba02dSMatthias Springer }
650f4ba02dSMatthias Springer }
660f4ba02dSMatthias Springer };
670f4ba02dSMatthias Springer
68a3005a40SMatthias Springer /// Each region branches back it itself or the parent.
69a3005a40SMatthias Springer struct DoubleLoopRegionsOp
70a3005a40SMatthias Springer : public Op<DoubleLoopRegionsOp, RegionBranchOpInterface::Trait> {
71a3005a40SMatthias Springer using Op::Op;
72a3005a40SMatthias Springer
getAttributeNamesDoubleLoopRegionsOp73a3005a40SMatthias Springer static ArrayRef<StringRef> getAttributeNames() { return {}; }
74a3005a40SMatthias Springer
getOperationNameDoubleLoopRegionsOp75a3005a40SMatthias Springer static StringRef getOperationName() {
76a3005a40SMatthias Springer return "cftest.double_loop_regions_op";
77a3005a40SMatthias Springer }
78a3005a40SMatthias Springer
getSuccessorRegionsDoubleLoopRegionsOp79a3005a40SMatthias Springer void getSuccessorRegions(Optional<unsigned> index,
80a3005a40SMatthias Springer ArrayRef<Attribute> operands,
81a3005a40SMatthias Springer SmallVectorImpl<RegionSuccessor> ®ions) {
82*491d2701SKazu Hirata if (index.has_value()) {
83a3005a40SMatthias Springer regions.push_back(RegionSuccessor());
84a3005a40SMatthias Springer regions.push_back(RegionSuccessor(&getOperation()->getRegion(*index)));
85a3005a40SMatthias Springer }
86a3005a40SMatthias Springer }
87a3005a40SMatthias Springer };
88a3005a40SMatthias Springer
89a5c2f782SMatthias Springer /// Regions are executed sequentially.
90a5c2f782SMatthias Springer struct SequentialRegionsOp
91a5c2f782SMatthias Springer : public Op<SequentialRegionsOp, RegionBranchOpInterface::Trait> {
92a5c2f782SMatthias Springer using Op::Op;
getAttributeNamesSequentialRegionsOp93a5c2f782SMatthias Springer static ArrayRef<StringRef> getAttributeNames() { return {}; }
94a5c2f782SMatthias Springer
getOperationNameSequentialRegionsOp95a5c2f782SMatthias Springer static StringRef getOperationName() { return "cftest.sequential_regions_op"; }
96a5c2f782SMatthias Springer
97a5c2f782SMatthias Springer // Region 0 has Region 1 as a successor.
getSuccessorRegionsSequentialRegionsOp98a5c2f782SMatthias Springer void getSuccessorRegions(Optional<unsigned> index,
99a5c2f782SMatthias Springer ArrayRef<Attribute> operands,
100a5c2f782SMatthias Springer SmallVectorImpl<RegionSuccessor> ®ions) {
101e7c7b16aSMogball if (index == 0u) {
102a5c2f782SMatthias Springer Operation *thisOp = this->getOperation();
103a5c2f782SMatthias Springer regions.push_back(RegionSuccessor(&thisOp->getRegion(1)));
104a5c2f782SMatthias Springer }
105a5c2f782SMatthias Springer }
106a5c2f782SMatthias Springer };
107a5c2f782SMatthias Springer
108a5c2f782SMatthias Springer /// A dialect putting all the above together.
109a5c2f782SMatthias Springer struct CFTestDialect : Dialect {
CFTestDialectCFTestDialect110a5c2f782SMatthias Springer explicit CFTestDialect(MLIRContext *ctx)
111a5c2f782SMatthias Springer : Dialect(getDialectNamespace(), ctx, TypeID::get<CFTestDialect>()) {
1120f4ba02dSMatthias Springer addOperations<DummyOp, MutuallyExclusiveRegionsOp, LoopRegionsOp,
113a3005a40SMatthias Springer DoubleLoopRegionsOp, SequentialRegionsOp>();
114a5c2f782SMatthias Springer }
getDialectNamespaceCFTestDialect115a5c2f782SMatthias Springer static StringRef getDialectNamespace() { return "cftest"; }
116a5c2f782SMatthias Springer };
117a5c2f782SMatthias Springer
/// Ops in the two regions of an op whose regions have no successors are
/// mutually exclusive, regardless of argument order.
TEST(RegionBranchOpInterface, MutuallyExclusiveOps) {
  const char *ir = R"MLIR(
"cftest.mutually_exclusive_regions_op"() (
      {"cftest.dummy_op"() : () -> ()},  // op1
      {"cftest.dummy_op"() : () -> ()}  // op2
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // A parse failure yields a null module; stop here instead of dereferencing.
  ASSERT_TRUE(module) << "failed to parse test IR";
  Operation *testOp = &module->getBody()->getOperations().front();
  Operation *op1 = &testOp->getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();

  EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
  EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
}
138a5c2f782SMatthias Springer
/// Regions that only loop back to themselves (or exit to the parent) never
/// reach each other, so ops in different regions are mutually exclusive.
TEST(RegionBranchOpInterface, MutuallyExclusiveOps2) {
  const char *ir = R"MLIR(
"cftest.double_loop_regions_op"() (
      {"cftest.dummy_op"() : () -> ()},  // op1
      {"cftest.dummy_op"() : () -> ()}  // op2
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // A parse failure yields a null module; stop here instead of dereferencing.
  ASSERT_TRUE(module) << "failed to parse test IR";
  Operation *testOp = &module->getBody()->getOperations().front();
  Operation *op1 = &testOp->getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();

  EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
  EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
}
159a3005a40SMatthias Springer
/// Region 0 of the sequential op flows into region 1, so ops in those regions
/// are not mutually exclusive in either argument order.
TEST(RegionBranchOpInterface, NotMutuallyExclusiveOps) {
  const char *ir = R"MLIR(
"cftest.sequential_regions_op"() (
      {"cftest.dummy_op"() : () -> ()},  // op1
      {"cftest.dummy_op"() : () -> ()}  // op2
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // A parse failure yields a null module; stop here instead of dereferencing.
  ASSERT_TRUE(module) << "failed to parse test IR";
  Operation *testOp = &module->getBody()->getOperations().front();
  Operation *op1 = &testOp->getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();

  EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op2));
  EXPECT_FALSE(insideMutuallyExclusiveRegions(op2, op1));
}
180a5c2f782SMatthias Springer
/// A sequential op nested in region 0 of a mutually-exclusive op:
/// - op1 and op3 live in the nested sequential op's regions 0 and 1, so they
///   are not mutually exclusive with each other;
/// - both are mutually exclusive with op2 in the outer op's region 1.
TEST(RegionBranchOpInterface, NestedMutuallyExclusiveOps) {
  const char *ir = R"MLIR(
"cftest.mutually_exclusive_regions_op"() (
      {
        "cftest.sequential_regions_op"() (
              {"cftest.dummy_op"() : () -> ()},  // op1
              {"cftest.dummy_op"() : () -> ()}  // op3
          ) : () -> ()
        "cftest.dummy_op"() : () -> ()
      },
      {"cftest.dummy_op"() : () -> ()}  // op2
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // A parse failure yields a null module; stop here instead of dereferencing.
  ASSERT_TRUE(module) << "failed to parse test IR";
  Operation *testOp = &module->getBody()->getOperations().front();
  Operation *op1 =
      &testOp->getRegion(0).front().front().getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();
  Operation *op3 =
      &testOp->getRegion(0).front().front().getRegion(1).front().front();

  EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
  EXPECT_TRUE(insideMutuallyExclusiveRegions(op3, op2));
  EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op3));
}
2110f4ba02dSMatthias Springer
/// Every region of the loop op lies on a control-flow cycle, so each region is
/// repetitive and every op inside has an enclosing repetitive region.
TEST(RegionBranchOpInterface, RecursiveRegions) {
  const char *ir = R"MLIR(
"cftest.loop_regions_op"() (
      {"cftest.dummy_op"() : () -> ()},  // op1
      {"cftest.dummy_op"() : () -> ()},  // op2
      {"cftest.dummy_op"() : () -> ()}  // op3
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // A parse failure yields a null module; stop here instead of dereferencing.
  ASSERT_TRUE(module) << "failed to parse test IR";
  Operation *testOp = &module->getBody()->getOperations().front();
  auto regionOp = cast<RegionBranchOpInterface>(testOp);
  Operation *op1 = &testOp->getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();
  Operation *op3 = &testOp->getRegion(2).front().front();

  EXPECT_TRUE(regionOp.isRepetitiveRegion(0));
  EXPECT_TRUE(regionOp.isRepetitiveRegion(1));
  EXPECT_TRUE(regionOp.isRepetitiveRegion(2));
  EXPECT_NE(getEnclosingRepetitiveRegion(op1), nullptr);
  EXPECT_NE(getEnclosingRepetitiveRegion(op2), nullptr);
  EXPECT_NE(getEnclosingRepetitiveRegion(op3), nullptr);
}
2390f4ba02dSMatthias Springer
/// Sequential regions (0 -> 1, no back edge) form no cycle, so neither op has
/// an enclosing repetitive region.
TEST(RegionBranchOpInterface, NotRecursiveRegions) {
  const char *ir = R"MLIR(
"cftest.sequential_regions_op"() (
      {"cftest.dummy_op"() : () -> ()},  // op1
      {"cftest.dummy_op"() : () -> ()}  // op2
  ) : () -> ()
  )MLIR";

  DialectRegistry registry;
  registry.insert<CFTestDialect>();
  MLIRContext ctx(registry);

  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
  // A parse failure yields a null module; stop here instead of dereferencing.
  ASSERT_TRUE(module) << "failed to parse test IR";
  Operation *testOp = &module->getBody()->getOperations().front();
  Operation *op1 = &testOp->getRegion(0).front().front();
  Operation *op2 = &testOp->getRegion(1).front().front();

  EXPECT_EQ(getEnclosingRepetitiveRegion(op1), nullptr);
  EXPECT_EQ(getEnclosingRepetitiveRegion(op2), nullptr);
}
260