# RUN: %PYTHON %s | FileCheck %s

import gc
import io
import itertools
from mlir.ir import *


def run(f):
    """Run the test function `f` and assert that no MLIR Context outlives it.

    Prints a "TEST: <name>" header line that the FileCheck labels in this
    file anchor on.
    """
    print("\nTEST:", f.__name__)
    f()
    # Collect garbage first so Contexts that are merely unreachable (but not
    # yet freed) are not reported as live.
    gc.collect()
    assert Context._get_live_count() == 0


# CHECK-LABEL: TEST: test_insert_at_block_end
def test_insert_at_block_end():
    """InsertionPoint(block) appends new operations after the block's last op."""
    ctx = Context()
    # The custom.* ops below do not belong to any registered dialect.
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """)
        entry_block = module.body.operations[0].regions[0].blocks[0]
        ip = InsertionPoint(entry_block)
        ip.insert(Operation.create("custom.op2"))
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        module.operation.print()


run(test_insert_at_block_end)


# CHECK-LABEL: TEST: test_insert_before_operation
def test_insert_before_operation():
    """InsertionPoint(op) inserts new operations immediately before `op`."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
        "custom.op2"() : () -> ()
      }
    """)
        entry_block = module.body.operations[0].regions[0].blocks[0]
        # Anchor on the second op ("custom.op2"), so the new op lands between
        # op1 and op2.
        ip = InsertionPoint(entry_block.operations[1])
        ip.insert(Operation.create("custom.op3"))
        # CHECK: "custom.op1"
        # CHECK: "custom.op3"
        # CHECK: "custom.op2"
        module.operation.print()


run(test_insert_before_operation)


# CHECK-LABEL: TEST: test_insert_at_block_begin
def test_insert_at_block_begin():
    """InsertionPoint.at_block_begin inserts ahead of the block's first op."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(r"""
      func.func @foo() -> () {
        "custom.op2"() : () -> ()
      }
    """)
        entry_block = module.body.operations[0].regions[0].blocks[0]
        ip = InsertionPoint.at_block_begin(entry_block)
        ip.insert(Operation.create("custom.op1"))
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        module.operation.print()


run(test_insert_at_block_begin)


# CHECK-LABEL: TEST: test_insert_at_block_begin_empty
def test_insert_at_block_begin_empty():
    """Placeholder: at_block_begin on a block that has no operations."""
    # TODO: Write this test case when we can create such a situation.
    pass


run(test_insert_at_block_begin_empty)


# CHECK-LABEL: TEST: test_insert_at_terminator
def test_insert_at_terminator():
    """InsertionPoint.at_block_terminator inserts just before the terminator."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
        return
      }
    """)
        entry_block = module.body.operations[0].regions[0].blocks[0]
        ip = InsertionPoint.at_block_terminator(entry_block)
        ip.insert(Operation.create("custom.op2"))
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        module.operation.print()


run(test_insert_at_terminator)


# CHECK-LABEL: TEST: test_insert_at_block_terminator_missing
def test_insert_at_block_terminator_missing():
    """at_block_terminator raises ValueError when the block has no terminator."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    # Only parsing happens here (no Operation.create), so no Location is
    # pushed; entering the Context alone is enough.
    with ctx:
        module = Module.parse(r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """)
        entry_block = module.body.operations[0].regions[0].blocks[0]
        try:
            # Only the raised error matters; the insertion point is unused.
            InsertionPoint.at_block_terminator(entry_block)
        except ValueError as e:
            # CHECK: Block has no terminator
            print(e)
        else:
            assert False, "Expected exception"


run(test_insert_at_block_terminator_missing)


# CHECK-LABEL: TEST: test_insert_at_end_with_terminator_errors
def test_insert_at_end_with_terminator_errors():
    """Appending after an existing terminator raises IndexError."""
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        module = Module.parse(r"""
      func.func @foo() -> () {
        return
      }
    """)
        entry_block = module.body.operations[0].regions[0].blocks[0]
        # InsertionPoint(block) targets the block end, i.e. after `return`.
        with InsertionPoint(entry_block):
            try:
                Operation.create("custom.op1", results=[], operands=[])
            except IndexError as e:
                # CHECK: ERROR: Cannot insert operation at the end of a block that already has a terminator.
                print(f"ERROR: {e}")
            else:
                # Fail loudly here (as the sibling terminator test does) rather
                # than relying only on the missing FileCheck line.
                assert False, "Expected exception"


run(test_insert_at_end_with_terminator_errors)


# CHECK-LABEL: TEST: test_insertion_point_context
def test_insertion_point_context():
    """Insertion points used as context managers nest.

    Operation.create with no explicit insertion point uses the innermost
    active one; leaving the inner `with` resumes the outer one.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """)
        entry_block = module.body.operations[0].regions[0].blocks[0]
        with InsertionPoint(entry_block):
            Operation.create("custom.op2")
            with InsertionPoint.at_block_begin(entry_block):
                Operation.create("custom.opa")
                Operation.create("custom.opb")
            Operation.create("custom.op3")
        # CHECK: "custom.opa"
        # CHECK: "custom.opb"
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        # CHECK: "custom.op3"
        module.operation.print()


run(test_insertion_point_context)