19f3f6d7bSStella Laurenzo# RUN: %PYTHON %s | FileCheck %s 29f3f6d7bSStella Laurenzo 39f3f6d7bSStella Laurenzoimport gc 49f3f6d7bSStella Laurenzoimport io 59f3f6d7bSStella Laurenzoimport itertools 69f3f6d7bSStella Laurenzofrom mlir.ir import * 79f3f6d7bSStella Laurenzo 89f3f6d7bSStella Laurenzodef run(f): 99f3f6d7bSStella Laurenzo print("\nTEST:", f.__name__) 109f3f6d7bSStella Laurenzo f() 119f3f6d7bSStella Laurenzo gc.collect() 129f3f6d7bSStella Laurenzo assert Context._get_live_count() == 0 139f3f6d7bSStella Laurenzo 149f3f6d7bSStella Laurenzo 159f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: test_insert_at_block_end 169f3f6d7bSStella Laurenzodef test_insert_at_block_end(): 179f3f6d7bSStella Laurenzo ctx = Context() 189f3f6d7bSStella Laurenzo ctx.allow_unregistered_dialects = True 199f3f6d7bSStella Laurenzo with Location.unknown(ctx): 209f3f6d7bSStella Laurenzo module = Module.parse(r""" 21*2310ced8SRiver Riddle func.func @foo() -> () { 229f3f6d7bSStella Laurenzo "custom.op1"() : () -> () 239f3f6d7bSStella Laurenzo } 249f3f6d7bSStella Laurenzo """) 259f3f6d7bSStella Laurenzo entry_block = module.body.operations[0].regions[0].blocks[0] 269f3f6d7bSStella Laurenzo ip = InsertionPoint(entry_block) 279f3f6d7bSStella Laurenzo ip.insert(Operation.create("custom.op2")) 289f3f6d7bSStella Laurenzo # CHECK: "custom.op1" 299f3f6d7bSStella Laurenzo # CHECK: "custom.op2" 309f3f6d7bSStella Laurenzo module.operation.print() 319f3f6d7bSStella Laurenzo 329f3f6d7bSStella Laurenzorun(test_insert_at_block_end) 339f3f6d7bSStella Laurenzo 349f3f6d7bSStella Laurenzo 359f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: test_insert_before_operation 369f3f6d7bSStella Laurenzodef test_insert_before_operation(): 379f3f6d7bSStella Laurenzo ctx = Context() 389f3f6d7bSStella Laurenzo ctx.allow_unregistered_dialects = True 399f3f6d7bSStella Laurenzo with Location.unknown(ctx): 409f3f6d7bSStella Laurenzo module = Module.parse(r""" 41*2310ced8SRiver Riddle func.func @foo() -> () { 429f3f6d7bSStella Laurenzo "custom.op1"() : () -> () 439f3f6d7bSStella Laurenzo "custom.op2"() : () -> () 449f3f6d7bSStella Laurenzo } 459f3f6d7bSStella Laurenzo """) 469f3f6d7bSStella Laurenzo entry_block = module.body.operations[0].regions[0].blocks[0] 479f3f6d7bSStella Laurenzo ip = InsertionPoint(entry_block.operations[1]) 489f3f6d7bSStella Laurenzo ip.insert(Operation.create("custom.op3")) 499f3f6d7bSStella Laurenzo # CHECK: "custom.op1" 509f3f6d7bSStella Laurenzo # CHECK: "custom.op3" 519f3f6d7bSStella Laurenzo # CHECK: "custom.op2" 529f3f6d7bSStella Laurenzo module.operation.print() 539f3f6d7bSStella Laurenzo 549f3f6d7bSStella Laurenzorun(test_insert_before_operation) 559f3f6d7bSStella Laurenzo 569f3f6d7bSStella Laurenzo 579f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: test_insert_at_block_begin 589f3f6d7bSStella Laurenzodef test_insert_at_block_begin(): 599f3f6d7bSStella Laurenzo ctx = Context() 609f3f6d7bSStella Laurenzo ctx.allow_unregistered_dialects = True 619f3f6d7bSStella Laurenzo with Location.unknown(ctx): 629f3f6d7bSStella Laurenzo module = Module.parse(r""" 63*2310ced8SRiver Riddle func.func @foo() -> () { 649f3f6d7bSStella Laurenzo "custom.op2"() : () -> () 659f3f6d7bSStella Laurenzo } 669f3f6d7bSStella Laurenzo """) 679f3f6d7bSStella Laurenzo entry_block = module.body.operations[0].regions[0].blocks[0] 689f3f6d7bSStella Laurenzo ip = InsertionPoint.at_block_begin(entry_block) 699f3f6d7bSStella Laurenzo ip.insert(Operation.create("custom.op1")) 709f3f6d7bSStella Laurenzo # CHECK: "custom.op1" 719f3f6d7bSStella Laurenzo # CHECK: "custom.op2" 729f3f6d7bSStella Laurenzo module.operation.print() 739f3f6d7bSStella Laurenzo 749f3f6d7bSStella Laurenzorun(test_insert_at_block_begin) 759f3f6d7bSStella Laurenzo 769f3f6d7bSStella Laurenzo 779f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: test_insert_at_block_begin_empty 789f3f6d7bSStella Laurenzodef test_insert_at_block_begin_empty(): 799f3f6d7bSStella Laurenzo # TODO: Write this test case when we can create such a situation. 809f3f6d7bSStella Laurenzo pass 819f3f6d7bSStella Laurenzo 829f3f6d7bSStella Laurenzorun(test_insert_at_block_begin_empty) 839f3f6d7bSStella Laurenzo 849f3f6d7bSStella Laurenzo 859f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: test_insert_at_terminator 869f3f6d7bSStella Laurenzodef test_insert_at_terminator(): 879f3f6d7bSStella Laurenzo ctx = Context() 889f3f6d7bSStella Laurenzo ctx.allow_unregistered_dialects = True 899f3f6d7bSStella Laurenzo with Location.unknown(ctx): 909f3f6d7bSStella Laurenzo module = Module.parse(r""" 91*2310ced8SRiver Riddle func.func @foo() -> () { 929f3f6d7bSStella Laurenzo "custom.op1"() : () -> () 939f3f6d7bSStella Laurenzo return 949f3f6d7bSStella Laurenzo } 959f3f6d7bSStella Laurenzo """) 969f3f6d7bSStella Laurenzo entry_block = module.body.operations[0].regions[0].blocks[0] 979f3f6d7bSStella Laurenzo ip = InsertionPoint.at_block_terminator(entry_block) 989f3f6d7bSStella Laurenzo ip.insert(Operation.create("custom.op2")) 999f3f6d7bSStella Laurenzo # CHECK: "custom.op1" 1009f3f6d7bSStella Laurenzo # CHECK: "custom.op2" 1019f3f6d7bSStella Laurenzo module.operation.print() 1029f3f6d7bSStella Laurenzo 1039f3f6d7bSStella Laurenzorun(test_insert_at_terminator) 1049f3f6d7bSStella Laurenzo 1059f3f6d7bSStella Laurenzo 1069f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: test_insert_at_block_terminator_missing 1079f3f6d7bSStella Laurenzodef test_insert_at_block_terminator_missing(): 1089f3f6d7bSStella Laurenzo ctx = Context() 1099f3f6d7bSStella Laurenzo ctx.allow_unregistered_dialects = True 1109f3f6d7bSStella Laurenzo with ctx: 1119f3f6d7bSStella Laurenzo module = Module.parse(r""" 112*2310ced8SRiver Riddle func.func @foo() -> () { 1139f3f6d7bSStella Laurenzo "custom.op1"() : () -> () 1149f3f6d7bSStella Laurenzo } 1159f3f6d7bSStella Laurenzo """) 1169f3f6d7bSStella Laurenzo entry_block = module.body.operations[0].regions[0].blocks[0] 1179f3f6d7bSStella Laurenzo try: 1189f3f6d7bSStella Laurenzo ip = InsertionPoint.at_block_terminator(entry_block) 1199f3f6d7bSStella Laurenzo except ValueError as e: 1209f3f6d7bSStella Laurenzo # CHECK: Block has no terminator 1219f3f6d7bSStella Laurenzo print(e) 1229f3f6d7bSStella Laurenzo else: 1239f3f6d7bSStella Laurenzo assert False, "Expected exception" 1249f3f6d7bSStella Laurenzo 1259f3f6d7bSStella Laurenzorun(test_insert_at_block_terminator_missing) 1269f3f6d7bSStella Laurenzo 1279f3f6d7bSStella Laurenzo 1289f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: test_insert_at_end_with_terminator_errors 1299f3f6d7bSStella Laurenzodef test_insert_at_end_with_terminator_errors(): 1309f3f6d7bSStella Laurenzo with Context() as ctx, Location.unknown(): 1319f3f6d7bSStella Laurenzo ctx.allow_unregistered_dialects = True 1329f3f6d7bSStella Laurenzo module = Module.parse(r""" 133*2310ced8SRiver Riddle func.func @foo() -> () { 1349f3f6d7bSStella Laurenzo return 1359f3f6d7bSStella Laurenzo } 1369f3f6d7bSStella Laurenzo """) 1379f3f6d7bSStella Laurenzo entry_block = module.body.operations[0].regions[0].blocks[0] 1389f3f6d7bSStella Laurenzo with InsertionPoint(entry_block): 1399f3f6d7bSStella Laurenzo try: 1409f3f6d7bSStella Laurenzo Operation.create("custom.op1", results=[], operands=[]) 1419f3f6d7bSStella Laurenzo except IndexError as e: 1429f3f6d7bSStella Laurenzo # CHECK: ERROR: Cannot insert operation at the end of a block that already has a terminator. 1439f3f6d7bSStella Laurenzo print(f"ERROR: {e}") 1449f3f6d7bSStella Laurenzo 1459f3f6d7bSStella Laurenzorun(test_insert_at_end_with_terminator_errors) 1469f3f6d7bSStella Laurenzo 1479f3f6d7bSStella Laurenzo 1489f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: test_insertion_point_context 1499f3f6d7bSStella Laurenzodef test_insertion_point_context(): 1509f3f6d7bSStella Laurenzo ctx = Context() 1519f3f6d7bSStella Laurenzo ctx.allow_unregistered_dialects = True 1529f3f6d7bSStella Laurenzo with Location.unknown(ctx): 1539f3f6d7bSStella Laurenzo module = Module.parse(r""" 154*2310ced8SRiver Riddle func.func @foo() -> () { 1559f3f6d7bSStella Laurenzo "custom.op1"() : () -> () 1569f3f6d7bSStella Laurenzo } 1579f3f6d7bSStella Laurenzo """) 1589f3f6d7bSStella Laurenzo entry_block = module.body.operations[0].regions[0].blocks[0] 1599f3f6d7bSStella Laurenzo with InsertionPoint(entry_block): 1609f3f6d7bSStella Laurenzo Operation.create("custom.op2") 1619f3f6d7bSStella Laurenzo with InsertionPoint.at_block_begin(entry_block): 1629f3f6d7bSStella Laurenzo Operation.create("custom.opa") 1639f3f6d7bSStella Laurenzo Operation.create("custom.opb") 1649f3f6d7bSStella Laurenzo Operation.create("custom.op3") 1659f3f6d7bSStella Laurenzo # CHECK: "custom.opa" 1669f3f6d7bSStella Laurenzo # CHECK: "custom.opb" 1679f3f6d7bSStella Laurenzo # CHECK: "custom.op1" 1689f3f6d7bSStella Laurenzo # CHECK: "custom.op2" 1699f3f6d7bSStella Laurenzo # CHECK: "custom.op3" 1709f3f6d7bSStella Laurenzo module.operation.print() 1719f3f6d7bSStella Laurenzo 1729f3f6d7bSStella Laurenzorun(test_insertion_point_context) 173