1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4import io
5import itertools
6from mlir.ir import *
7
def run(f):
  """Run one test case under a banner and verify no Context leaked.

  Prints a "TEST:" header followed by the function's name (the FileCheck
  label directives in this file match on it), invokes ``f``, forces a
  garbage collection, and asserts that no ``Context`` objects are still
  alive afterwards.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  # Collect eagerly so unreachable Python wrapper objects are released
  # before the live-context count is checked.
  gc.collect()
  assert Context._get_live_count() == 0
13
14
15# CHECK-LABEL: TEST: test_insert_at_block_end
def test_insert_at_block_end():
  """An InsertionPoint built from a block appends new ops at its end."""
  context = Context()
  context.allow_unregistered_dialects = True
  with Location.unknown(context):
    source_asm = r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """
    module = Module.parse(source_asm)
    func_op = module.body.operations[0]
    body_block = func_op.regions[0].blocks[0]
    end_point = InsertionPoint(body_block)
    # Created detached (no insertion point is active), then placed explicitly.
    appended_op = Operation.create("custom.op2")
    end_point.insert(appended_op)
    # CHECK: "custom.op1"
    # CHECK: "custom.op2"
    module.operation.print()

run(test_insert_at_block_end)
33
34
35# CHECK-LABEL: TEST: test_insert_before_operation
def test_insert_before_operation():
  """An InsertionPoint built from an operation inserts directly before it."""
  context = Context()
  context.allow_unregistered_dialects = True
  with Location.unknown(context):
    source_asm = r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
        "custom.op2"() : () -> ()
      }
    """
    module = Module.parse(source_asm)
    func_op = module.body.operations[0]
    body_block = func_op.regions[0].blocks[0]
    second_op = body_block.operations[1]
    before_second = InsertionPoint(second_op)
    middle_op = Operation.create("custom.op3")
    before_second.insert(middle_op)
    # CHECK: "custom.op1"
    # CHECK: "custom.op3"
    # CHECK: "custom.op2"
    module.operation.print()

run(test_insert_before_operation)
55
56
57# CHECK-LABEL: TEST: test_insert_at_block_begin
def test_insert_at_block_begin():
  """InsertionPoint.at_block_begin places new ops ahead of existing ones."""
  context = Context()
  context.allow_unregistered_dialects = True
  with Location.unknown(context):
    source_asm = r"""
      func.func @foo() -> () {
        "custom.op2"() : () -> ()
      }
    """
    module = Module.parse(source_asm)
    func_op = module.body.operations[0]
    body_block = func_op.regions[0].blocks[0]
    front_point = InsertionPoint.at_block_begin(body_block)
    leading_op = Operation.create("custom.op1")
    front_point.insert(leading_op)
    # CHECK: "custom.op1"
    # CHECK: "custom.op2"
    module.operation.print()

run(test_insert_at_block_begin)
75
76
77# CHECK-LABEL: TEST: test_insert_at_block_begin_empty
def test_insert_at_block_begin_empty():
  """Placeholder for inserting at the start of a block with no operations.

  TODO: Write this test case when we can create such a situation.
  """

run(test_insert_at_block_begin_empty)
83
84
85# CHECK-LABEL: TEST: test_insert_at_terminator
def test_insert_at_terminator():
  """InsertionPoint.at_block_terminator inserts just before the terminator."""
  context = Context()
  context.allow_unregistered_dialects = True
  with Location.unknown(context):
    source_asm = r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
        return
      }
    """
    module = Module.parse(source_asm)
    func_op = module.body.operations[0]
    body_block = func_op.regions[0].blocks[0]
    before_return = InsertionPoint.at_block_terminator(body_block)
    inserted_op = Operation.create("custom.op2")
    before_return.insert(inserted_op)
    # CHECK: "custom.op1"
    # CHECK: "custom.op2"
    module.operation.print()

run(test_insert_at_terminator)
104
105
106# CHECK-LABEL: TEST: test_insert_at_block_terminator_missing
def test_insert_at_block_terminator_missing():
  """at_block_terminator raises ValueError for a block lacking a terminator."""
  context = Context()
  context.allow_unregistered_dialects = True
  with context:
    source_asm = r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """
    module = Module.parse(source_asm)
    func_op = module.body.operations[0]
    body_block = func_op.regions[0].blocks[0]
    try:
      # The returned insertion point is never needed; only the raise matters.
      InsertionPoint.at_block_terminator(body_block)
    except ValueError as err:
      # CHECK: Block has no terminator
      print(err)
    else:
      assert False, "Expected exception"

run(test_insert_at_block_terminator_missing)
126
127
128# CHECK-LABEL: TEST: test_insert_at_end_with_terminator_errors
def test_insert_at_end_with_terminator_errors():
  """Creating an op at the end of a terminated block must raise IndexError.

  The active insertion point is the end of ``entry_block``, whose last
  operation is ``return``. Creating an operation there is expected to be
  rejected rather than placed after the terminator.
  """
  with Context() as ctx, Location.unknown():
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func.func @foo() -> () {
        return
      }
    """)
    entry_block = module.body.operations[0].regions[0].blocks[0]
    with InsertionPoint(entry_block):
      try:
        Operation.create("custom.op1", results=[], operands=[])
      except IndexError as e:
        # CHECK: ERROR: Cannot insert operation at the end of a block that already has a terminator.
        print(f"ERROR: {e}")
      else:
        # Fail loudly in Python instead of relying solely on FileCheck to
        # notice the missing error line.
        assert False, "Expected exception"

run(test_insert_at_end_with_terminator_errors)
146
147
148# CHECK-LABEL: TEST: test_insertion_point_context
def test_insertion_point_context():
  """InsertionPoint as a context manager directs where Operation.create goes.

  A nested insertion point takes over while its ``with`` block is active.
  After it exits, creation resumes at the outer insertion point.
  """
  context = Context()
  context.allow_unregistered_dialects = True
  with Location.unknown(context):
    source_asm = r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """
    module = Module.parse(source_asm)
    func_op = module.body.operations[0]
    body_block = func_op.regions[0].blocks[0]
    with InsertionPoint(body_block):
      Operation.create("custom.op2")
      with InsertionPoint.at_block_begin(body_block):
        # Per the FileCheck expectations below, these land ahead of the
        # block's original first op, in creation order.
        for front_name in ("custom.opa", "custom.opb"):
          Operation.create(front_name)
      # Outer (block-end) insertion point is active again here.
      Operation.create("custom.op3")
    # CHECK: "custom.opa"
    # CHECK: "custom.opb"
    # CHECK: "custom.op1"
    # CHECK: "custom.op2"
    # CHECK: "custom.op3"
    module.operation.print()

run(test_insertion_point_context)
173