1 //===- ControlFlowInterfacesTest.cpp - Unit Tests for Control Flow Interf. ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/ControlFlowInterfaces.h"
10 #include "mlir/IR/BuiltinOps.h"
11 #include "mlir/IR/Dialect.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/OpDefinition.h"
14 #include "mlir/IR/OpImplementation.h"
15 #include "mlir/Parser.h"
16 
17 #include <gtest/gtest.h>
18 
19 using namespace mlir;
20 
21 /// A dummy op that is also a terminator.
22 struct DummyOp : public Op<DummyOp, OpTrait::IsTerminator> {
23   using Op::Op;
24   static ArrayRef<StringRef> getAttributeNames() { return {}; }
25 
26   static StringRef getOperationName() { return "cftest.dummy_op"; }
27 };
28 
29 /// All regions of this op are mutually exclusive.
30 struct MutuallyExclusiveRegionsOp
31     : public Op<MutuallyExclusiveRegionsOp, RegionBranchOpInterface::Trait> {
32   using Op::Op;
33   static ArrayRef<StringRef> getAttributeNames() { return {}; }
34 
35   static StringRef getOperationName() {
36     return "cftest.mutually_exclusive_regions_op";
37   }
38 
39   // Regions have no successors.
40   void getSuccessorRegions(Optional<unsigned> index,
41                            ArrayRef<Attribute> operands,
42                            SmallVectorImpl<RegionSuccessor> &regions) {}
43 };
44 
45 /// Regions are executed sequentially.
46 struct SequentialRegionsOp
47     : public Op<SequentialRegionsOp, RegionBranchOpInterface::Trait> {
48   using Op::Op;
49   static ArrayRef<StringRef> getAttributeNames() { return {}; }
50 
51   static StringRef getOperationName() { return "cftest.sequential_regions_op"; }
52 
53   // Region 0 has Region 1 as a successor.
54   void getSuccessorRegions(Optional<unsigned> index,
55                            ArrayRef<Attribute> operands,
56                            SmallVectorImpl<RegionSuccessor> &regions) {
57     assert(index.hasValue() && "expected index");
58     if (*index == 0) {
59       Operation *thisOp = this->getOperation();
60       regions.push_back(RegionSuccessor(&thisOp->getRegion(1)));
61     }
62   }
63 };
64 
65 /// A dialect putting all the above together.
66 struct CFTestDialect : Dialect {
67   explicit CFTestDialect(MLIRContext *ctx)
68       : Dialect(getDialectNamespace(), ctx, TypeID::get<CFTestDialect>()) {
69     addOperations<DummyOp, MutuallyExclusiveRegionsOp, SequentialRegionsOp>();
70   }
71   static StringRef getDialectNamespace() { return "cftest"; }
72 };
73 
74 TEST(RegionBranchOpInterface, MutuallyExclusiveOps) {
75   const char *ir = R"MLIR(
76 "cftest.mutually_exclusive_regions_op"() (
77       {"cftest.dummy_op"() : () -> ()},  // op1
78       {"cftest.dummy_op"() : () -> ()}   // op2
79   ) : () -> ()
80   )MLIR";
81 
82   DialectRegistry registry;
83   registry.insert<CFTestDialect>();
84   MLIRContext ctx(registry);
85 
86   OwningOpRef<ModuleOp> module = parseSourceString(ir, &ctx);
87   Operation *testOp = &module->getBody()->getOperations().front();
88   Operation *op1 = &testOp->getRegion(0).front().front();
89   Operation *op2 = &testOp->getRegion(1).front().front();
90 
91   EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
92   EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
93 }
94 
95 TEST(RegionBranchOpInterface, NotMutuallyExclusiveOps) {
96   const char *ir = R"MLIR(
97 "cftest.sequential_regions_op"() (
98       {"cftest.dummy_op"() : () -> ()},  // op1
99       {"cftest.dummy_op"() : () -> ()}   // op2
100   ) : () -> ()
101   )MLIR";
102 
103   DialectRegistry registry;
104   registry.insert<CFTestDialect>();
105   MLIRContext ctx(registry);
106 
107   OwningOpRef<ModuleOp> module = parseSourceString(ir, &ctx);
108   Operation *testOp = &module->getBody()->getOperations().front();
109   Operation *op1 = &testOp->getRegion(0).front().front();
110   Operation *op2 = &testOp->getRegion(1).front().front();
111 
112   EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op2));
113   EXPECT_FALSE(insideMutuallyExclusiveRegions(op2, op1));
114 }
115 
116 TEST(RegionBranchOpInterface, NestedMutuallyExclusiveOps) {
117   const char *ir = R"MLIR(
118 "cftest.mutually_exclusive_regions_op"() (
119       {
120         "cftest.sequential_regions_op"() (
121               {"cftest.dummy_op"() : () -> ()},  // op1
122               {"cftest.dummy_op"() : () -> ()}   // op3
123           ) : () -> ()
124         "cftest.dummy_op"() : () -> ()
125       },
126       {"cftest.dummy_op"() : () -> ()}           // op2
127   ) : () -> ()
128   )MLIR";
129 
130   DialectRegistry registry;
131   registry.insert<CFTestDialect>();
132   MLIRContext ctx(registry);
133 
134   OwningOpRef<ModuleOp> module = parseSourceString(ir, &ctx);
135   Operation *testOp = &module->getBody()->getOperations().front();
136   Operation *op1 =
137       &testOp->getRegion(0).front().front().getRegion(0).front().front();
138   Operation *op2 = &testOp->getRegion(1).front().front();
139   Operation *op3 =
140       &testOp->getRegion(0).front().front().getRegion(1).front().front();
141 
142   EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
143   EXPECT_TRUE(insideMutuallyExclusiveRegions(op3, op2));
144   EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op3));
145 }
146