# RUN: %PYTHON %s | FileCheck %s
"""Tests for the generic ODS op builder ``OpView.build_generic``.

Each test prints to stdout and the FileCheck directives embedded as comments
in this file verify that output, so the relative order of prints and
directives matters. Every test is executed through ``run`` so that a leaked
``Context`` fails the script.
"""

import gc
from mlir.ir import *


def run(f):
  """Print a header for test ``f``, execute it, then assert no Context leaked.

  Garbage is collected before counting so that objects kept alive only by
  reference cycles are not reported as live.
  """
  print("\nTEST:", f.__name__)
  f()
  gc.collect()
  assert Context._get_live_count() == 0


def add_dummy_value():
  """Create a "custom.value" op and return its single i32 result Value.

  No explicit location or insertion point is passed, so this relies on the
  ambient Context/Location/InsertionPoint set up by the caller's ``with``
  blocks.
  """
  return Operation.create(
      "custom.value", results=[IntegerType.get_signless(32)]).result


def testOdsBuildDefaultImplicitRegions():
  """Check default and explicit ``regions=`` counts against _ODS_REGIONS."""

  class TestFixedRegionsOp(OpView):
    OPERATION_NAME = "custom.test_op"
    # (2, True): exactly 2 regions -- the cases below expect both regions=3
    # and regions=1 to be rejected.
    _ODS_REGIONS = (2, True)

  class TestVariadicRegionsOp(OpView):
    OPERATION_NAME = "custom.test_any_regions_op"
    # (2, False): at least 2 regions; the cases below expect regions=3 to be
    # accepted and regions=1 to be rejected.
    _ODS_REGIONS = (2, False)

  with Context() as ctx, Location.unknown():
    ctx.allow_unregistered_dialects = True
    m = Module.create()
    with InsertionPoint(m.body):
      op = TestFixedRegionsOp.build_generic(results=[], operands=[])
      # CHECK: NUM_REGIONS: 2
      print(f"NUM_REGIONS: {len(op.regions)}")
      # Including a regions= that matches should be fine.
      op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=2)
      # CHECK: NUM_REGIONS: 2
      print(f"NUM_REGIONS: {len(op.regions)}")
      # Reject greater than.
      try:
        op = TestFixedRegionsOp.build_generic(
            results=[], operands=[], regions=3)
      except ValueError as e:
        # CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3
        print(f"ERROR:{e}")
      # Reject less than.
      try:
        op = TestFixedRegionsOp.build_generic(
            results=[], operands=[], regions=1)
      except ValueError as e:
        # CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1
        print(f"ERROR:{e}")

      # If no regions specified for a variadic region op, build the minimum.
      op = TestVariadicRegionsOp.build_generic(results=[], operands=[])
      # CHECK: DEFAULT_NUM_REGIONS: 2
      print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}")
      # Should also accept an explicit regions= that matches the minimum.
      op = TestVariadicRegionsOp.build_generic(
          results=[], operands=[], regions=2)
      # CHECK: EQ_NUM_REGIONS: 2
      print(f"EQ_NUM_REGIONS: {len(op.regions)}")
      # And accept an explicit regions= greater than the minimum.
      op = TestVariadicRegionsOp.build_generic(
          results=[], operands=[], regions=3)
      # CHECK: GT_NUM_REGIONS: 3
      print(f"GT_NUM_REGIONS: {len(op.regions)}")
      # Should reject less than minimum.
      try:
        op = TestVariadicRegionsOp.build_generic(
            results=[], operands=[], regions=1)
      except ValueError as e:
        # CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1
        print(f"ERROR:{e}")


run(testOdsBuildDefaultImplicitRegions)


def testOdsBuildDefaultNonVariadic():
  """Without segment metadata, operands/results are flat, no *_segment_sizes."""

  class TestOp(OpView):
    OPERATION_NAME = "custom.test_op"

  with Context() as ctx, Location.unknown():
    ctx.allow_unregistered_dialects = True
    m = Module.create()
    with InsertionPoint(m.body):
      v0 = add_dummy_value()
      v1 = add_dummy_value()
      t0 = IntegerType.get_signless(8)
      t1 = IntegerType.get_signless(16)
      op = TestOp.build_generic(results=[t0, t1], operands=[v0, v1])
    # CHECK: %[[V0:.+]] = "custom.value"
    # CHECK: %[[V1:.+]] = "custom.value"
    # CHECK: "custom.test_op"(%[[V0]], %[[V1]])
    # CHECK-NOT: operand_segment_sizes
    # CHECK-NOT: result_segment_sizes
    # CHECK-SAME: : (i32, i32) -> (i8, i16)
    print(m)


run(testOdsBuildDefaultNonVariadic)


def testOdsBuildDefaultSizedVariadic():
  """Segmented operands/results: derived segment-size attrs and None handling."""

  class TestOp(OpView):
    OPERATION_NAME = "custom.test_op"
    # One entry per segment. Judging by the expected output below: 1 is a
    # single required value, -1 a variadic list, 0 an optional value.
    _ODS_OPERAND_SEGMENTS = [1, -1, 0]
    _ODS_RESULT_SEGMENTS = [-1, 0, 1]

  with Context() as ctx, Location.unknown():
    ctx.allow_unregistered_dialects = True
    m = Module.create()
    with InsertionPoint(m.body):
      v0 = add_dummy_value()
      v1 = add_dummy_value()
      v2 = add_dummy_value()
      v3 = add_dummy_value()
      t0 = IntegerType.get_signless(8)
      t1 = IntegerType.get_signless(16)
      t2 = IntegerType.get_signless(32)
      t3 = IntegerType.get_signless(64)
      # CHECK: %[[V0:.+]] = "custom.value"
      # CHECK: %[[V1:.+]] = "custom.value"
      # CHECK: %[[V2:.+]] = "custom.value"
      # CHECK: %[[V3:.+]] = "custom.value"
      # CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]])
      # CHECK-SAME: operand_segment_sizes = dense<[1, 2, 1]> : vector<3xi32>
      # CHECK-SAME: result_segment_sizes = dense<[2, 1, 1]> : vector<3xi32>
      # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64)
      op = TestOp.build_generic(
          results=[[t0, t1], t2, t3],
          operands=[v0, [v1, v2], v3])

      # Now test with optional omitted.
      # CHECK: "custom.test_op"(%[[V0]])
      # CHECK-SAME: operand_segment_sizes = dense<[1, 0, 0]>
      # CHECK-SAME: result_segment_sizes = dense<[0, 0, 1]>
      # CHECK-SAME: (i32) -> i64
      op = TestOp.build_generic(
          results=[None, None, t3],
          operands=[v0, None, None])
      # Prints both ops built above; the directives above match in order.
      print(m)

      # And verify that errors are raised for None in a required operand.
      try:
        op = TestOp.build_generic(
            results=[None, None, t3],
            operands=[None, None, None])
      except ValueError as e:
        # CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional)
        print(f"OPERAND_CAST_ERROR:{e}")

      # And verify that errors are raised for None in a required result.
      try:
        op = TestOp.build_generic(
            results=[None, None, None],
            operands=[v0, None, None])
      except ValueError as e:
        # CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional)
        print(f"RESULT_CAST_ERROR:{e}")

      # Variadic lists with None elements should reject.
      try:
        op = TestOp.build_generic(
            results=[None, None, t3],
            operands=[v0, [None], None])
      except ValueError as e:
        # CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item)
        print(f"OPERAND_LIST_CAST_ERROR:{e}")
      try:
        op = TestOp.build_generic(
            results=[[None], None, t3],
            operands=[v0, None, None])
      except ValueError as e:
        # CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item)
        print(f"RESULT_LIST_CAST_ERROR:{e}")


run(testOdsBuildDefaultSizedVariadic)


def testOdsBuildDefaultCastError():
  """Without segments, a None operand or result raises ValueError."""

  class TestOp(OpView):
    OPERATION_NAME = "custom.test_op"

  with Context() as ctx, Location.unknown():
    ctx.allow_unregistered_dialects = True
    m = Module.create()
    with InsertionPoint(m.body):
      v0 = add_dummy_value()
      v1 = add_dummy_value()
      t0 = IntegerType.get_signless(8)
      t1 = IntegerType.get_signless(16)
      try:
        op = TestOp.build_generic(
            results=[t0, t1],
            operands=[None, v1])
      except ValueError as e:
        # CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value
        print(f"ERROR: {e}")
      try:
        op = TestOp.build_generic(
            results=[t0, None],
            operands=[v0, v1])
      except ValueError as e:
        # CHECK: ERROR: Result 1 of operation "custom.test_op" must be a Type
        print(f"ERROR: {e}")


run(testOdsBuildDefaultCastError)