# RUN: %PYTHON %s | FileCheck %s

import gc
import io
import itertools
from mlir.ir import *


def run(f):
  """Decorator that executes test `f` immediately at definition time.

  Prints a "TEST: <name>" banner (the label the FileCheck directives key on),
  calls `f`, forces a garbage collection and asserts that no `Context` is
  still alive afterwards, so a test that leaks a context fails right away.

  Returns:
    `f` unchanged, so the decorated name stays bound to the function.
  """
  print("\nTEST:", f.__name__)
  f()
  gc.collect()
  assert Context._get_live_count() == 0
  return f


def expect_index_error(callback):
  """Calls `callback` and requires that it raises IndexError.

  Args:
    callback: zero-argument callable expected to raise IndexError.

  Raises:
    RuntimeError: if `callback` returns without raising IndexError.
    Any other exception raised by `callback` propagates unchanged.
  """
  # Keep the try body limited to the call under test.
  try:
    callback()
  except IndexError:
    return
  raise RuntimeError("Expected IndexError")


# Verify iterator based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
@run
def testTraverseOpRegionBlockIterators():
  """Walks a parsed module through the iterator protocol of each container."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)
  op = module.operation
  assert op.context is ctx
  # Get the block using iterators off of the named collections.
  regions = list(op.regions)
  blocks = list(regions[0].blocks)
  # CHECK: MODULE REGIONS=1 BLOCKS=1
  print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")

  # Should verify.
  # CHECK: .verify = True
  print(f".verify = {module.operation.verify()}")

  # Get the regions and blocks from the default collections.
  default_regions = list(op.regions)
  default_blocks = list(default_regions[0])
  # They should compare equal regardless of how obtained.
  assert default_regions == regions
  assert default_blocks == blocks

  # Should be able to get the operations from either the named collection
  # or the block.
  operations = list(blocks[0].operations)
  default_operations = list(blocks[0])
  assert default_operations == operations

  def walk_operations(indent, op):
    # Recursively prints every region, block and operation nested under `op`.
    for i, region in enumerate(op.regions):
      print(f"{indent}REGION {i}:")
      for j, block in enumerate(region):
        print(f"{indent} BLOCK {j}:")
        for k, child_op in enumerate(block):
          print(f"{indent} OP {k}: {child_op}")
          walk_operations(indent + " ", child_op)

  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: func
  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: %0 = "custom.addi"
  # CHECK: OP 1: func.return
  walk_operations("", op)


# Verify index based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
@run
def testTraverseOpRegionBlockIndices():
  """Walks a parsed module using len() and integer indexing only."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)

  def walk_operations(indent, op):
    # Index-based loops are deliberate here: this test exercises __len__ and
    # __getitem__ of the collections rather than their iterators.
    for i in range(len(op.regions)):
      region = op.regions[i]
      print(f"{indent}REGION {i}:")
      for j in range(len(region.blocks)):
        block = region.blocks[j]
        print(f"{indent} BLOCK {j}:")
        for k in range(len(block.operations)):
          child_op = block.operations[k]
          print(f"{indent} OP {k}: {child_op}")
          print(f"{indent} OP {k}: parent {child_op.operation.parent.name}")
          walk_operations(indent + " ", child_op)

  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: func
  # CHECK: OP 0: parent builtin.module
  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: %0 = "custom.addi"
  # CHECK: OP 0: parent func.func
  # CHECK: OP 1: func.return
  # CHECK: OP 1: parent func.func
  walk_operations("", module.operation)


# CHECK-LABEL: TEST: testBlockAndRegionOwners
@run
def testBlockAndRegionOwners():
  """Checks that regions and blocks report the operation that owns them."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    builtin.module {
      func.func @f() {
        func.return
      }
    }
  """, ctx)

  assert module.operation.regions[0].owner == module.operation
  assert module.operation.regions[0].blocks[0].owner == module.operation

  func = module.body.operations[0]
  assert func.operation.regions[0].owner == func
  assert func.operation.regions[0].blocks[0].owner == func


# CHECK-LABEL: TEST: testBlockArgumentList
@run
def testBlockArgumentList():
  """Exercises iteration, retyping, slicing and concatenation of block args."""
  with Context() as ctx:
    module = Module.parse(
        r"""
      func.func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
        return
      }
    """, ctx)
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    assert len(entry_block.arguments) == 3
    # CHECK: Argument 0, type i32
    # CHECK: Argument 1, type f64
    # CHECK: Argument 2, type index
    for arg in entry_block.arguments:
      print(f"Argument {arg.arg_number}, type {arg.type}")
      # Retype argument N to a signless integer of width 8 * (N + 1).
      new_type = IntegerType.get_signless(8 * (arg.arg_number + 1))
      arg.set_type(new_type)

    # CHECK: Argument 0, type i8
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    for arg in entry_block.arguments:
      print(f"Argument {arg.arg_number}, type {arg.type}")

    # Check that slicing works for block argument lists.
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    for arg in entry_block.arguments[1:]:
      print(f"Argument {arg.arg_number}, type {arg.type}")

    # Check that we can concatenate slices of argument lists.
    # CHECK: Length: 4
    print("Length: ",
          len(entry_block.arguments[:2] + entry_block.arguments[1:]))

    # CHECK: Type: i8
    # CHECK: Type: i16
    # CHECK: Type: i24
    for t in entry_block.arguments.types:
      print("Type: ", t)


# CHECK-LABEL: TEST: testOperationOperands
@run
def testOperationOperands():
  """Checks the length of an operand list and the type of each operand."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func.func @f1(%arg0: i32) {
        %0 = "test.producer"() : () -> i64
        "test.consumer"(%arg0, %0) : (i32, i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    consumer = entry_block.operations[1]
    assert len(consumer.operands) == 2
    # CHECK: Operand 0, type i32
    # CHECK: Operand 1, type i64
    for i, operand in enumerate(consumer.operands):
      print(f"Operand {i}, type {operand.type}")


# CHECK-LABEL: TEST: testOperationOperandsSlice
@run
def testOperationOperandsSlice():
  """Exercises slicing (including strides and nested slices) of operands."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func.func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        %3 = "test.producer3"() : () -> i64
        %4 = "test.producer4"() : () -> i64
        "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    consumer = entry_block.operations[5]
    assert len(consumer.operands) == 5
    # Reversing twice must give back the original sequence.
    for left, right in zip(consumer.operands, consumer.operands[::-1][::-1]):
      assert left == right

    # CHECK: test.producer0
    # CHECK: test.producer1
    # CHECK: test.producer2
    # CHECK: test.producer3
    # CHECK: test.producer4
    full_slice = consumer.operands[:]
    for operand in full_slice:
      print(operand)

    # CHECK: test.producer0
    # CHECK: test.producer1
    first_two = consumer.operands[0:2]
    for operand in first_two:
      print(operand)

    # CHECK: test.producer3
    # CHECK: test.producer4
    last_two = consumer.operands[3:]
    for operand in last_two:
      print(operand)

    # CHECK: test.producer0
    # CHECK: test.producer2
    # CHECK: test.producer4
    even = consumer.operands[::2]
    for operand in even:
      print(operand)

    # A slice of a slice: indices (0, 2, 4) narrowed by [1::2] leaves only
    # index 2, i.e. the middle operand.
    # CHECK: test.producer2
    middle = consumer.operands[::2][1::2]
    for operand in middle:
      print(operand)


# CHECK-LABEL: TEST: testOperationOperandsSet
@run
def testOperationOperandsSet():
  """Checks item assignment on an operand list, with a negative index too."""
  with Context() as ctx, Location.unknown(ctx):
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func.func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        "test.consumer"(%0) : (i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    producer1 = entry_block.operations[1]
    producer2 = entry_block.operations[2]
    consumer = entry_block.operations[3]
    assert len(consumer.operands) == 1

    # CHECK: test.producer1
    consumer.operands[0] = producer1.result
    print(consumer.operands[0])

    # With a single operand, index -1 addresses the same slot as index 0.
    # CHECK: test.producer2
    consumer.operands[-1] = producer2.result
    print(consumer.operands[0])


# CHECK-LABEL: TEST: testDetachedOperation
@run
def testDetachedOperation():
  """Creates an op that is not inserted anywhere and prints it."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    i32 = IntegerType.get_signed(32)
    op1 = Operation.create(
        "custom.op1",
        results=[i32, i32],
        regions=1,
        attributes={
            "foo": StringAttr.get("foo_value"),
            "bar": StringAttr.get("bar_value"),
        })
    # CHECK: %0:2 = "custom.op1"() ({
    # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
    print(op1)
# TODO(testDetachedOperation): Check successors once enough infra exists to do
# it properly.


# CHECK-LABEL: TEST: testOperationInsertionPoint
@run
def testOperationInsertionPoint():
  """Inserts detached ops at a block start and rejects a second insertion."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)

  # Create test op.
  with Location.unknown(ctx):
    op1 = Operation.create("custom.op1")
    op2 = Operation.create("custom.op2")

    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    ip = InsertionPoint.at_block_begin(entry_block)
    ip.insert(op1)
    ip.insert(op2)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK: "custom.op2"()
    # CHECK: %0 = "custom.addi"
    print(module)

  # Trying to add a previously added op should raise.
  try:
    ip.insert(op1)
  except ValueError:
    pass
  else:
    assert False, "expected insert of attached op to raise"


# CHECK-LABEL: TEST: testOperationWithRegion
@run
def testOperationWithRegion():
  """Builds an op with a region and block, then moves it into a module."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    i32 = IntegerType.get_signed(32)
    op1 = Operation.create("custom.op1", regions=1)
    block = op1.regions[0].blocks.append(i32, i32)
    # CHECK: "custom.op1"() ({
    # CHECK: ^bb0(%arg0: si32, %arg1: si32):
    # CHECK: "custom.terminator"() : () -> ()
    # CHECK: }) : () -> ()
    terminator = Operation.create("custom.terminator")
    ip = InsertionPoint(block)
    ip.insert(terminator)
    print(op1)

    # Now add the whole operation to another op.
    # TODO: Verify lifetime hazard by nulling out the new owning module and
    # accessing op1.
    # TODO: Also verify accessing the terminator once both parents are nulled
    # out.
    module = Module.parse(r"""
      func.func @f1(%arg0: i32) -> i32 {
        %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
        return %1 : i32
      }
    """)
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    ip = InsertionPoint.at_block_begin(entry_block)
    ip.insert(op1)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK: "custom.terminator"
    # CHECK: %0 = "custom.addi"
    print(module)


# CHECK-LABEL: TEST: testOperationResultList
@run
def testOperationResultList():
  """Checks length, iteration, types and bounds of an op's result list."""
  ctx = Context()
  module = Module.parse(
      r"""
    func.func @f1() {
      %0:3 = call @f2() : () -> (i32, f64, index)
      return
    }
    func.func private @f2() -> (i32, f64, index)
  """, ctx)
  caller = module.body.operations[0]
  call = caller.regions[0].blocks[0].operations[0]
  assert len(call.results) == 3
  # CHECK: Result 0, type i32
  # CHECK: Result 1, type f64
  # CHECK: Result 2, type index
  for res in call.results:
    print(f"Result {res.result_number}, type {res.type}")

  # CHECK: Result type i32
  # CHECK: Result type f64
  # CHECK: Result type index
  for t in call.results.types:
    print(f"Result type {t}")

  # Out of range
  expect_index_error(lambda: call.results[3])
  expect_index_error(lambda: call.results[-4])


# CHECK-LABEL: TEST: testOperationResultListSlice
@run
def testOperationResultListSlice():
  """Exercises slicing (ranges, strides, negative steps) of a result list."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func.func @f1() {
        "some.op"() : () -> (i1, i2, i3, i4, i5)
        return
      }
    """)
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    producer = entry_block.operations[0]

    assert len(producer.results) == 5
    # Reversing twice must give back the original sequence.
    for left, right in zip(producer.results, producer.results[::-1][::-1]):
      assert left == right
      assert left.result_number == right.result_number

    # CHECK: Result 0, type i1
    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    # CHECK: Result 4, type i5
    full_slice = producer.results[:]
    for res in full_slice:
      print(f"Result {res.result_number}, type {res.type}")

    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    middle = producer.results[1:4]
    for res in middle:
      print(f"Result {res.result_number}, type {res.type}")

    # CHECK: Result 1, type i2
    # CHECK: Result 3, type i4
    odd = producer.results[1::2]
    for res in odd:
      print(f"Result {res.result_number}, type {res.type}")

    # CHECK: Result 3, type i4
    # CHECK: Result 1, type i2
    inverted_middle = producer.results[-2:0:-2]
    for res in inverted_middle:
      print(f"Result {res.result_number}, type {res.type}")


# CHECK-LABEL: TEST: testOperationAttributes
@run
def testOperationAttributes():
  """Checks attribute lookup by name, iteration and the lookup errors."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    "some.op"() { some.attribute = 1 : i8,
                  other.attribute = 3.0,
                  dependent = "text" } : () -> ()
  """, ctx)
  op = module.body.operations[0]
  assert len(op.attributes) == 3
  iattr = IntegerAttr(op.attributes["some.attribute"])
  fattr = FloatAttr(op.attributes["other.attribute"])
  sattr = StringAttr(op.attributes["dependent"])
  # CHECK: Attribute type i8, value 1
  print(f"Attribute type {iattr.type}, value {iattr.value}")
  # CHECK: Attribute type f64, value 3.0
  print(f"Attribute type {fattr.type}, value {fattr.value}")
  # CHECK: Attribute value text
  print(f"Attribute value {sattr.value}")

  # We don't know in which order the attributes are stored.
  # CHECK-DAG: NamedAttribute(dependent="text")
  # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
  # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
  for attr in op.attributes:
    print(str(attr))

  # Check that exceptions are raised as expected.
  try:
    op.attributes["does_not_exist"]
  except KeyError:
    pass
  else:
    assert False, "expected KeyError on accessing a non-existent attribute"

  try:
    op.attributes[42]
  except IndexError:
    pass
  else:
    assert False, "expected IndexError on accessing an out-of-bounds attribute"


# CHECK-LABEL: TEST: testOperationPrint
@run
def testOperationPrint():
  """Prints an operation to stdout, text and binary files, with options."""
  ctx = Context()
  # NOTE(review): the expected location `-:4:7` further down appears to be the
  # line:column of `return` inside this literal (4th line, 6 leading spaces),
  # so keep the layout of the string intact when editing.
  module = Module.parse(
      r"""
    func.func @f1(%arg0: i32) -> i32 {
      %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
      return %arg0 : i32
    }
  """, ctx)

  # Test print to stdout.
  # CHECK: return %arg0 : i32
  module.operation.print()

  # Test print to text file.
  f = io.StringIO()
  # CHECK: <class 'str'>
  # CHECK: return %arg0 : i32
  module.operation.print(file=f)
  str_value = f.getvalue()
  print(str_value.__class__)
  print(f.getvalue())

  # Test print to binary file.
  f = io.BytesIO()
  # CHECK: <class 'bytes'>
  # CHECK: return %arg0 : i32
  module.operation.print(file=f, binary=True)
  bytes_value = f.getvalue()
  print(bytes_value.__class__)
  print(bytes_value)

  # Test get_asm local_scope.
  # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
  module.operation.print(enable_debug_info=True, use_local_scope=True)

  # Test get_asm with options.
  # CHECK: value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<4xi32>
  # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7
  module.operation.print(
      large_elements_limit=2,
      enable_debug_info=True,
      pretty_debug_info=True,
      print_generic_op_form=True,
      use_local_scope=True)


# CHECK-LABEL: TEST: testKnownOpView
@run
def testKnownOpView():
  """Checks which OpView class is returned for known and unknown ops."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      %1 = "custom.f32"() : () -> f32
      %2 = "custom.f32"() : () -> f32
      %3 = arith.addf %1, %2 : f32
    """)
    print(module)

    # addf should map to a known OpView class in the arithmetic dialect.
    # We know the OpView for it defines an 'lhs' attribute.
    addf = module.body.operations[2]
    # CHECK: <mlir.dialects._arith_ops_gen._AddFOp object
    print(repr(addf))
    # CHECK: "custom.f32"()
    print(addf.lhs)

    # One of the custom ops should resolve to the default OpView.
    custom = module.body.operations[0]
    # CHECK: OpView object
    print(repr(custom))

    # Check again to make sure negative caching works.
    custom = module.body.operations[0]
    # CHECK: OpView object
    print(repr(custom))


# CHECK-LABEL: TEST: testSingleResultProperty
@run
def testSingleResultProperty():
  """Checks that `.result` raises ValueError unless there is one result."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      "custom.no_result"() : () -> ()
      %0:2 = "custom.two_result"() : () -> (f32, f32)
      %1 = "custom.one_result"() : () -> f32
    """)
    print(module)

    try:
      module.body.operations[0].result
    except ValueError as e:
      # CHECK: Cannot call .result on operation custom.no_result which has 0 results
      print(e)
    else:
      assert False, "Expected exception"

    try:
      module.body.operations[1].result
    except ValueError as e:
      # CHECK: Cannot call .result on operation custom.two_result which has 2 results
      print(e)
    else:
      assert False, "Expected exception"

    # CHECK: %1 = "custom.one_result"() : () -> f32
    print(module.body.operations[2])


def create_invalid_operation():
  """Returns a `builtin.module` op built with two regions.

  Must be called with an ambient location, since `Operation.create` is given
  none explicitly here.
  """
  # This module has two regions and is therefore invalid; it is used to verify
  # that we fall back to the generic printer for safety.
  op = Operation.create("builtin.module", regions=2)
  op.regions[0].blocks.append()
  return op


# CHECK-LABEL: TEST: testInvalidOperationStrSoftFails
@run
def testInvalidOperationStrSoftFails():
  """Checks that printing an invalid op uses the generic form, not an error."""
  ctx = Context()
  with Location.unknown(ctx):
    invalid_op = create_invalid_operation()
    # Verify that we fall back to the generic printer for safety.
    # CHECK: // Verification failed, printing generic form
    # CHECK: "builtin.module"() ({
    # CHECK: }) : () -> ()
    print(invalid_op)
    # CHECK: .verify = False
    print(f".verify = {invalid_op.operation.verify()}")


# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
@run
def testInvalidModuleStrSoftFails():
  """Checks that a module holding an invalid op prints in generic form."""
  ctx = Context()
  with Location.unknown(ctx):
    module = Module.create()
    with InsertionPoint(module.body):
      invalid_op = create_invalid_operation()
    # Verify that we fall back to the generic printer for safety.
    # CHECK: // Verification failed, printing generic form
    print(module)


# CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails
@run
def testInvalidOperationGetAsmBinarySoftFails():
  """Checks the generic-form fallback of `get_asm(binary=True)`."""
  ctx = Context()
  with Location.unknown(ctx):
    invalid_op = create_invalid_operation()
    # Verify that we fall back to the generic printer for safety.
    # CHECK: b'// Verification failed, printing generic form\n
    print(invalid_op.get_asm(binary=True))


# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
@run
def testCreateWithInvalidAttributes():
  """Checks the errors raised for invalid attribute keys and values.

  Each case must raise; the `else` branches make a silently accepted invalid
  attribute fail at the offending call rather than only through a missing
  expected line in the output.
  """
  ctx = Context()
  with Location.unknown(ctx):
    try:
      Operation.create(
          "builtin.module", attributes={None: StringAttr.get("name")})
    except Exception as e:
      # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
      print(e)
    else:
      assert False, "expected an exception for a `None` attribute key"
    try:
      Operation.create(
          "builtin.module", attributes={42: StringAttr.get("name")})
    except Exception as e:
      # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
      print(e)
    else:
      assert False, "expected an exception for an integer attribute key"
    try:
      Operation.create("builtin.module", attributes={"some_key": ctx})
    except Exception as e:
      # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
      print(e)
    else:
      assert False, "expected an exception for a non-attribute value"
    try:
      Operation.create("builtin.module", attributes={"some_key": None})
    except Exception as e:
      # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
      print(e)
    else:
      assert False, "expected an exception for a `None` attribute value"


# CHECK-LABEL: TEST: testOperationName
@run
def testOperationName():
  """Prints the name of every top-level operation in a parsed module."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    %0 = "custom.op1"() : () -> f32
    %1 = "custom.op2"() : () -> i32
    %2 = "custom.op1"() : () -> f32
  """, ctx)

  # CHECK: custom.op1
  # CHECK: custom.op2
  # CHECK: custom.op1
  for op in module.body.operations:
    print(op.operation.name)


# CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
  """Round-trips an operation through its capsule; expects the same object."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    m = Operation.create("custom.op1").operation
    m_capsule = m._CAPIPtr
    assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
    m2 = Operation._CAPICreate(m_capsule)
    assert m2 is m


# CHECK-LABEL: TEST: testOperationErase
@run
def testOperationErase():
  """Erases an op from a module and checks it no longer prints."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    m = Module.create()
    with InsertionPoint(m.body):
      op = Operation.create("custom.op1")

      # CHECK: "custom.op1"
      print(m)

      op.operation.erase()

      # CHECK-NOT: "custom.op1"
      print(m)

      # Ensure we can create another operation
      Operation.create("custom.op2")


# CHECK-LABEL: TEST: testOperationClone
@run
def testOperationClone():
  """Clones an op, erases the original and expects the name still printed."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    m = Module.create()
    with InsertionPoint(m.body):
      op = Operation.create("custom.op1")

      # CHECK: "custom.op1"
      print(m)

      clone = op.operation.clone()
      op.operation.erase()

      # CHECK: "custom.op1"
      print(m)


# CHECK-LABEL: TEST: testOperationLoc
@run
def testOperationLoc():
  """Checks that the location passed at creation is reported by the op."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with ctx:
    loc = Location.name("loc")
    op = Operation.create("custom.op", loc=loc)
    assert op.location == loc
    assert op.operation.location == loc


# CHECK-LABEL: TEST: testModuleMerge
@run
def testModuleMerge():
  """Moves every op of one module into another with move_before/move_after."""
  with Context():
    m1 = Module.parse("func.func private @foo()")
    m2 = Module.parse("""
      func.func private @bar()
      func.func private @qux()
    """)
    foo = m1.body.operations[0]
    bar = m2.body.operations[0]
    qux = m2.body.operations[1]
    bar.move_before(foo)
    qux.move_after(foo)

    # CHECK: module
    # CHECK: func private @bar
    # CHECK: func private @foo
    # CHECK: func private @qux
    print(m1)

    # CHECK: module {
    # CHECK-NEXT: }
    print(m2)


# CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock
@run
def testAppendMoveFromAnotherBlock():
  """Appends an op owned by one module's body to another module's body."""
  with Context():
    m1 = Module.parse("func.func private @foo()")
    m2 = Module.parse("func.func private @bar()")
    func = m1.body.operations[0]
    m2.body.append(func)

    # CHECK: module
    # CHECK: func private @bar
    # CHECK: func private @foo

    print(m2)
    # CHECK: module {
    # CHECK-NEXT: }
    print(m1)


# CHECK-LABEL: TEST: testDetachFromParent
@run
def testDetachFromParent():
  """Detaches an op from its parent; a second detach must raise ValueError."""
  with Context():
    m1 = Module.parse("func.func private @foo()")
    func = m1.body.operations[0].detach_from_parent()

    try:
      func.detach_from_parent()
    except ValueError as e:
      # Only the "has no parent" error is expected; anything else is a bug.
      if "has no parent" not in str(e):
        raise
    else:
      assert False, "expected ValueError when detaching a detached operation"

    print(m1)
    # CHECK-NOT: func private @foo


# CHECK-LABEL: TEST: testOperationHash
@run
def testOperationHash():
  """Checks that an op view and its `.operation` hash identically."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with ctx, Location.unknown():
    op = Operation.create("custom.op1")
    assert hash(op) == hash(op.operation)