1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5
def run(f):
  """Execute one test function and verify it leaked no MLIR contexts.

  Emits a blank line followed by a "TEST:" banner naming the test, invokes
  the test, forces a garbage-collection pass and then asserts that the
  number of live Context objects has dropped back to zero.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  # Collect eagerly so Contexts the test dropped are really released before
  # the live-count check.
  gc.collect()
  live_contexts = Context._get_live_count()
  assert live_contexts == 0
11
12
def add_dummy_value():
  """Create an unregistered "custom.value" op and return its sole result.

  The op yields one signless 32-bit integer, which the tests use as an
  operand. No location or insertion point is passed explicitly. The callers
  in this file invoke it inside their `with` blocks, and the printed modules
  show the op being inserted there.
  """
  i32 = IntegerType.get_signless(32)
  dummy_op = Operation.create("custom.value", results=[i32])
  return dummy_op.result
17
18
def testOdsBuildDefaultImplicitRegions():
  """Check how build_generic applies the region count from _ODS_REGIONS.

  As the expectations below show, _ODS_REGIONS = (2, True) requires exactly
  two regions. (2, False) requires at least two and builds that minimum when
  regions= is omitted. Counts outside the allowed range must raise ValueError.
  """

  class TestFixedRegionsOp(OpView):
    OPERATION_NAME = "custom.test_op"
    _ODS_REGIONS = (2, True)

  class TestVariadicRegionsOp(OpView):
    OPERATION_NAME = "custom.test_any_regions_op"
    _ODS_REGIONS = (2, False)

  with Context() as ctx, Location.unknown():
    ctx.allow_unregistered_dialects = True
    m = Module.create()
    with InsertionPoint(m.body):
      op = TestFixedRegionsOp.build_generic(results=[], operands=[])
      # CHECK: NUM_REGIONS: 2
      print(f"NUM_REGIONS: {len(op.regions)}")
      # Including a regions= that matches should be fine.
      op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=2)
      # CHECK: NUM_REGIONS: 2
      print(f"NUM_REGIONS: {len(op.regions)}")
      # Reject greater than.
      try:
        op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=3)
      except ValueError as e:
        # CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3
        print(f"ERROR:{e}")
      # Reject less than.
      try:
        op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=1)
      except ValueError as e:
        # CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1
        print(f"ERROR:{e}")

      # If no regions specified for a variadic region op, build the minimum.
      op = TestVariadicRegionsOp.build_generic(results=[], operands=[])
      # CHECK: DEFAULT_NUM_REGIONS: 2
      print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}")
      # Should also accept an explicit regions= that matches the minimum.
      op = TestVariadicRegionsOp.build_generic(
          results=[], operands=[], regions=2)
      # CHECK: EQ_NUM_REGIONS: 2
      print(f"EQ_NUM_REGIONS: {len(op.regions)}")
      # And accept an explicit regions= greater than the minimum.
      op = TestVariadicRegionsOp.build_generic(
          results=[], operands=[], regions=3)
      # CHECK: GT_NUM_REGIONS: 3
      print(f"GT_NUM_REGIONS: {len(op.regions)}")
      # Should reject less than minimum.
      try:
        op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=1)
      except ValueError as e:
        # CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1
        print(f"ERROR:{e}")


run(testOdsBuildDefaultImplicitRegions)
77
78
def testOdsBuildDefaultNonVariadic():
  """Exercise build_generic on an op that declares no ODS segment metadata.

  Operands and result types are passed as flat lists. The printed module
  must show the op with both operands and result types and without any
  operand/result segment-size attributes.
  """

  class TestOp(OpView):
    OPERATION_NAME = "custom.test_op"

  with Context() as ctx, Location.unknown():
    ctx.allow_unregistered_dialects = True
    module = Module.create()
    with InsertionPoint(module.body):
      operands = [add_dummy_value(), add_dummy_value()]
      result_types = [IntegerType.get_signless(width) for width in (8, 16)]
      TestOp.build_generic(results=result_types, operands=operands)
      # CHECK: %[[V0:.+]] = "custom.value"
      # CHECK: %[[V1:.+]] = "custom.value"
      # CHECK: "custom.test_op"(%[[V0]], %[[V1]])
      # CHECK-NOT: operand_segment_sizes
      # CHECK-NOT: result_segment_sizes
      # CHECK-SAME: : (i32, i32) -> (i8, i16)
      print(module)

run(testOdsBuildDefaultNonVariadic)
102
103
def testOdsBuildDefaultSizedVariadic():
  """Exercise build_generic on an op with ODS operand/result segments.

  The expected output shows how the segment specs behave:
    * 1 is a single required value.
    * -1 is a variadic list.
    * 0 is an optional value, where None yields an empty segment.
  The test checks the emitted segment-size attributes. It also checks the
  ValueError messages for None in a required slot and for None inside a
  variadic list.
  """

  class TestOp(OpView):
    OPERATION_NAME = "custom.test_op"
    _ODS_OPERAND_SEGMENTS = [1, -1, 0]
    _ODS_RESULT_SEGMENTS = [-1, 0, 1]

  with Context() as ctx, Location.unknown():
    ctx.allow_unregistered_dialects = True
    module = Module.create()
    with InsertionPoint(module.body):
      v0, v1, v2, v3 = [add_dummy_value() for _ in range(4)]
      t0, t1, t2, t3 = [
          IntegerType.get_signless(bits) for bits in (8, 16, 32, 64)]

      # Every segment populated, including a two-element variadic list.
      # CHECK: %[[V0:.+]] = "custom.value"
      # CHECK: %[[V1:.+]] = "custom.value"
      # CHECK: %[[V2:.+]] = "custom.value"
      # CHECK: %[[V3:.+]] = "custom.value"
      # CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]])
      # CHECK-SAME: operand_segment_sizes = dense<[1, 2, 1]> : vector<3xi32>
      # CHECK-SAME: result_segment_sizes = dense<[2, 1, 1]> : vector<3xi32>
      # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64)
      TestOp.build_generic(
          results=[[t0, t1], t2, t3],
          operands=[v0, [v1, v2], v3])

      # Optional and variadic segments left out via None.
      # CHECK: "custom.test_op"(%[[V0]])
      # CHECK-SAME: operand_segment_sizes = dense<[1, 0, 0]>
      # CHECK-SAME: result_segment_sizes = dense<[0, 0, 1]>
      # CHECK-SAME: (i32) -> i64
      TestOp.build_generic(
          results=[None, None, t3],
          operands=[v0, None, None])
      print(module)

      # Each case below is expected to raise ValueError. The message is
      # printed behind its label so the expected text can be matched.
      # CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional)
      # CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional)
      # CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item)
      # CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item)
      rejected_cases = (
          # None in a required operand.
          ("OPERAND_CAST_ERROR", [None, None, t3], [None, None, None]),
          # None in a required result.
          ("RESULT_CAST_ERROR", [None, None, None], [v0, None, None]),
          # Variadic lists containing a None element.
          ("OPERAND_LIST_CAST_ERROR", [None, None, t3], [v0, [None], None]),
          ("RESULT_LIST_CAST_ERROR", [[None], None, t3], [v0, None, None]),
      )
      for label, result_spec, operand_spec in rejected_cases:
        try:
          TestOp.build_generic(results=result_spec, operands=operand_spec)
        except ValueError as e:
          print(f"{label}:{e}")

run(testOdsBuildDefaultSizedVariadic)
180
181
def testOdsBuildDefaultCastError():
  """Check that build_generic rejects None operands/results on a plain op.

  The op declares no ODS segment metadata, so every operand and result is
  required. Passing None for one must raise ValueError, and the message is
  printed behind an "ERROR: " prefix so it can be matched.
  """

  class TestOp(OpView):
    OPERATION_NAME = "custom.test_op"

  with Context() as ctx, Location.unknown():
    ctx.allow_unregistered_dialects = True
    m = Module.create()
    with InsertionPoint(m.body):
      v0 = add_dummy_value()
      v1 = add_dummy_value()
      t0 = IntegerType.get_signless(8)
      t1 = IntegerType.get_signless(16)
      # None for a required operand.
      try:
        op = TestOp.build_generic(
            results=[t0, t1],
            operands=[None, v1])
      except ValueError as e:
        # CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value
        print(f"ERROR: {e}")
      # None for a required result type.
      try:
        op = TestOp.build_generic(
            results=[t0, None],
            operands=[v0, v1])
      except ValueError as e:
        # Match the printed prefix too, consistent with the operand case.
        # CHECK: ERROR: Result 1 of operation "custom.test_op" must be a Type
        print(f"ERROR: {e}")

run(testOdsBuildDefaultCastError)
211