19f3f6d7bSStella Laurenzo# RUN: %PYTHON %s | FileCheck %s
29f3f6d7bSStella Laurenzo
39f3f6d7bSStella Laurenzoimport gc
49f3f6d7bSStella Laurenzoimport io
59f3f6d7bSStella Laurenzoimport itertools
69f3f6d7bSStella Laurenzofrom mlir.ir import *
79f3f6d7bSStella Laurenzo
89f3f6d7bSStella Laurenzodef run(f):
99f3f6d7bSStella Laurenzo  print("\nTEST:", f.__name__)
109f3f6d7bSStella Laurenzo  f()
119f3f6d7bSStella Laurenzo  gc.collect()
129f3f6d7bSStella Laurenzo  assert Context._get_live_count() == 0
139f3f6d7bSStella Laurenzo
149f3f6d7bSStella Laurenzo
159f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: test_insert_at_block_end
169f3f6d7bSStella Laurenzodef test_insert_at_block_end():
179f3f6d7bSStella Laurenzo  ctx = Context()
189f3f6d7bSStella Laurenzo  ctx.allow_unregistered_dialects = True
199f3f6d7bSStella Laurenzo  with Location.unknown(ctx):
209f3f6d7bSStella Laurenzo    module = Module.parse(r"""
21*2310ced8SRiver Riddle      func.func @foo() -> () {
229f3f6d7bSStella Laurenzo        "custom.op1"() : () -> ()
239f3f6d7bSStella Laurenzo      }
249f3f6d7bSStella Laurenzo    """)
259f3f6d7bSStella Laurenzo    entry_block = module.body.operations[0].regions[0].blocks[0]
269f3f6d7bSStella Laurenzo    ip = InsertionPoint(entry_block)
279f3f6d7bSStella Laurenzo    ip.insert(Operation.create("custom.op2"))
289f3f6d7bSStella Laurenzo    # CHECK: "custom.op1"
299f3f6d7bSStella Laurenzo    # CHECK: "custom.op2"
309f3f6d7bSStella Laurenzo    module.operation.print()
319f3f6d7bSStella Laurenzo
329f3f6d7bSStella Laurenzorun(test_insert_at_block_end)
339f3f6d7bSStella Laurenzo
349f3f6d7bSStella Laurenzo
359f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: test_insert_before_operation
369f3f6d7bSStella Laurenzodef test_insert_before_operation():
379f3f6d7bSStella Laurenzo  ctx = Context()
389f3f6d7bSStella Laurenzo  ctx.allow_unregistered_dialects = True
399f3f6d7bSStella Laurenzo  with Location.unknown(ctx):
409f3f6d7bSStella Laurenzo    module = Module.parse(r"""
41*2310ced8SRiver Riddle      func.func @foo() -> () {
429f3f6d7bSStella Laurenzo        "custom.op1"() : () -> ()
439f3f6d7bSStella Laurenzo        "custom.op2"() : () -> ()
449f3f6d7bSStella Laurenzo      }
459f3f6d7bSStella Laurenzo    """)
469f3f6d7bSStella Laurenzo    entry_block = module.body.operations[0].regions[0].blocks[0]
479f3f6d7bSStella Laurenzo    ip = InsertionPoint(entry_block.operations[1])
489f3f6d7bSStella Laurenzo    ip.insert(Operation.create("custom.op3"))
499f3f6d7bSStella Laurenzo    # CHECK: "custom.op1"
509f3f6d7bSStella Laurenzo    # CHECK: "custom.op3"
519f3f6d7bSStella Laurenzo    # CHECK: "custom.op2"
529f3f6d7bSStella Laurenzo    module.operation.print()
539f3f6d7bSStella Laurenzo
549f3f6d7bSStella Laurenzorun(test_insert_before_operation)
559f3f6d7bSStella Laurenzo
569f3f6d7bSStella Laurenzo
579f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: test_insert_at_block_begin
589f3f6d7bSStella Laurenzodef test_insert_at_block_begin():
599f3f6d7bSStella Laurenzo  ctx = Context()
609f3f6d7bSStella Laurenzo  ctx.allow_unregistered_dialects = True
619f3f6d7bSStella Laurenzo  with Location.unknown(ctx):
629f3f6d7bSStella Laurenzo    module = Module.parse(r"""
63*2310ced8SRiver Riddle      func.func @foo() -> () {
649f3f6d7bSStella Laurenzo        "custom.op2"() : () -> ()
659f3f6d7bSStella Laurenzo      }
669f3f6d7bSStella Laurenzo    """)
679f3f6d7bSStella Laurenzo    entry_block = module.body.operations[0].regions[0].blocks[0]
689f3f6d7bSStella Laurenzo    ip = InsertionPoint.at_block_begin(entry_block)
699f3f6d7bSStella Laurenzo    ip.insert(Operation.create("custom.op1"))
709f3f6d7bSStella Laurenzo    # CHECK: "custom.op1"
719f3f6d7bSStella Laurenzo    # CHECK: "custom.op2"
729f3f6d7bSStella Laurenzo    module.operation.print()
739f3f6d7bSStella Laurenzo
749f3f6d7bSStella Laurenzorun(test_insert_at_block_begin)
759f3f6d7bSStella Laurenzo
769f3f6d7bSStella Laurenzo
779f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: test_insert_at_block_begin_empty
789f3f6d7bSStella Laurenzodef test_insert_at_block_begin_empty():
799f3f6d7bSStella Laurenzo  # TODO: Write this test case when we can create such a situation.
809f3f6d7bSStella Laurenzo  pass
819f3f6d7bSStella Laurenzo
829f3f6d7bSStella Laurenzorun(test_insert_at_block_begin_empty)
839f3f6d7bSStella Laurenzo
849f3f6d7bSStella Laurenzo
859f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: test_insert_at_terminator
869f3f6d7bSStella Laurenzodef test_insert_at_terminator():
879f3f6d7bSStella Laurenzo  ctx = Context()
889f3f6d7bSStella Laurenzo  ctx.allow_unregistered_dialects = True
899f3f6d7bSStella Laurenzo  with Location.unknown(ctx):
909f3f6d7bSStella Laurenzo    module = Module.parse(r"""
91*2310ced8SRiver Riddle      func.func @foo() -> () {
929f3f6d7bSStella Laurenzo        "custom.op1"() : () -> ()
939f3f6d7bSStella Laurenzo        return
949f3f6d7bSStella Laurenzo      }
959f3f6d7bSStella Laurenzo    """)
969f3f6d7bSStella Laurenzo    entry_block = module.body.operations[0].regions[0].blocks[0]
979f3f6d7bSStella Laurenzo    ip = InsertionPoint.at_block_terminator(entry_block)
989f3f6d7bSStella Laurenzo    ip.insert(Operation.create("custom.op2"))
999f3f6d7bSStella Laurenzo    # CHECK: "custom.op1"
1009f3f6d7bSStella Laurenzo    # CHECK: "custom.op2"
1019f3f6d7bSStella Laurenzo    module.operation.print()
1029f3f6d7bSStella Laurenzo
1039f3f6d7bSStella Laurenzorun(test_insert_at_terminator)
1049f3f6d7bSStella Laurenzo
1059f3f6d7bSStella Laurenzo
1069f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: test_insert_at_block_terminator_missing
1079f3f6d7bSStella Laurenzodef test_insert_at_block_terminator_missing():
1089f3f6d7bSStella Laurenzo  ctx = Context()
1099f3f6d7bSStella Laurenzo  ctx.allow_unregistered_dialects = True
1109f3f6d7bSStella Laurenzo  with ctx:
1119f3f6d7bSStella Laurenzo    module = Module.parse(r"""
112*2310ced8SRiver Riddle      func.func @foo() -> () {
1139f3f6d7bSStella Laurenzo        "custom.op1"() : () -> ()
1149f3f6d7bSStella Laurenzo      }
1159f3f6d7bSStella Laurenzo    """)
1169f3f6d7bSStella Laurenzo    entry_block = module.body.operations[0].regions[0].blocks[0]
1179f3f6d7bSStella Laurenzo    try:
1189f3f6d7bSStella Laurenzo      ip = InsertionPoint.at_block_terminator(entry_block)
1199f3f6d7bSStella Laurenzo    except ValueError as e:
1209f3f6d7bSStella Laurenzo      # CHECK: Block has no terminator
1219f3f6d7bSStella Laurenzo      print(e)
1229f3f6d7bSStella Laurenzo    else:
1239f3f6d7bSStella Laurenzo      assert False, "Expected exception"
1249f3f6d7bSStella Laurenzo
1259f3f6d7bSStella Laurenzorun(test_insert_at_block_terminator_missing)
1269f3f6d7bSStella Laurenzo
1279f3f6d7bSStella Laurenzo
1289f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: test_insert_at_end_with_terminator_errors
1299f3f6d7bSStella Laurenzodef test_insert_at_end_with_terminator_errors():
1309f3f6d7bSStella Laurenzo  with Context() as ctx, Location.unknown():
1319f3f6d7bSStella Laurenzo    ctx.allow_unregistered_dialects = True
1329f3f6d7bSStella Laurenzo    module = Module.parse(r"""
133*2310ced8SRiver Riddle      func.func @foo() -> () {
1349f3f6d7bSStella Laurenzo        return
1359f3f6d7bSStella Laurenzo      }
1369f3f6d7bSStella Laurenzo    """)
1379f3f6d7bSStella Laurenzo    entry_block = module.body.operations[0].regions[0].blocks[0]
1389f3f6d7bSStella Laurenzo    with InsertionPoint(entry_block):
1399f3f6d7bSStella Laurenzo      try:
1409f3f6d7bSStella Laurenzo        Operation.create("custom.op1", results=[], operands=[])
1419f3f6d7bSStella Laurenzo      except IndexError as e:
1429f3f6d7bSStella Laurenzo        # CHECK: ERROR: Cannot insert operation at the end of a block that already has a terminator.
1439f3f6d7bSStella Laurenzo        print(f"ERROR: {e}")
1449f3f6d7bSStella Laurenzo
1459f3f6d7bSStella Laurenzorun(test_insert_at_end_with_terminator_errors)
1469f3f6d7bSStella Laurenzo
1479f3f6d7bSStella Laurenzo
1489f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: test_insertion_point_context
1499f3f6d7bSStella Laurenzodef test_insertion_point_context():
1509f3f6d7bSStella Laurenzo  ctx = Context()
1519f3f6d7bSStella Laurenzo  ctx.allow_unregistered_dialects = True
1529f3f6d7bSStella Laurenzo  with Location.unknown(ctx):
1539f3f6d7bSStella Laurenzo    module = Module.parse(r"""
154*2310ced8SRiver Riddle      func.func @foo() -> () {
1559f3f6d7bSStella Laurenzo        "custom.op1"() : () -> ()
1569f3f6d7bSStella Laurenzo      }
1579f3f6d7bSStella Laurenzo    """)
1589f3f6d7bSStella Laurenzo    entry_block = module.body.operations[0].regions[0].blocks[0]
1599f3f6d7bSStella Laurenzo    with InsertionPoint(entry_block):
1609f3f6d7bSStella Laurenzo      Operation.create("custom.op2")
1619f3f6d7bSStella Laurenzo      with InsertionPoint.at_block_begin(entry_block):
1629f3f6d7bSStella Laurenzo        Operation.create("custom.opa")
1639f3f6d7bSStella Laurenzo        Operation.create("custom.opb")
1649f3f6d7bSStella Laurenzo      Operation.create("custom.op3")
1659f3f6d7bSStella Laurenzo    # CHECK: "custom.opa"
1669f3f6d7bSStella Laurenzo    # CHECK: "custom.opb"
1679f3f6d7bSStella Laurenzo    # CHECK: "custom.op1"
1689f3f6d7bSStella Laurenzo    # CHECK: "custom.op2"
1699f3f6d7bSStella Laurenzo    # CHECK: "custom.op3"
1709f3f6d7bSStella Laurenzo    module.operation.print()
1719f3f6d7bSStella Laurenzo
1729f3f6d7bSStella Laurenzorun(test_insertion_point_context)
173