# RUN: %PYTHON %s | FileCheck %s

import gc
import io
import itertools
from mlir.ir import *


def run(f):
  """Runs test function `f` at decoration time and checks for leaked Contexts.

  Prints a "TEST: <name>" banner that the label directives in this file
  anchor on, calls `f()`, forces a garbage collection and asserts that no
  `Context` objects remain alive. Returns `f` unchanged so it can be used as
  a decorator.
  """
  print("\nTEST:", f.__name__)
  f()
  # Collect unreachable objects (e.g. reference cycles) that may still hold a
  # Context before counting the live ones.
  gc.collect()
  assert Context._get_live_count() == 0
  return f


# Verify iterator based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
@run
def testTraverseOpRegionBlockIterators():
  """Walks a parsed module by iterating over regions, blocks and ops."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)
  op = module.operation
  assert op.context is ctx
  # Get the block using iterators off of the named collections.
  regions = list(op.regions)
  blocks = list(regions[0].blocks)
  # CHECK: MODULE REGIONS=1 BLOCKS=1
  print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")

  # Should verify.
  # CHECK: .verify = True
  print(f".verify = {module.operation.verify()}")

  # Get the regions and blocks from the default collections.
  # NOTE(review): `list(op.regions)` is the same expression used for
  # `regions` above; only the blocks come from a default iterator here.
  default_regions = list(op.regions)
  default_blocks = list(default_regions[0])
  # They should compare equal regardless of how obtained.
  assert default_regions == regions
  assert default_blocks == blocks

  # Should be able to get the operations from either the named collection
  # or the block.
  operations = list(blocks[0].operations)
  default_operations = list(blocks[0])
  assert default_operations == operations

  def walk_operations(indent, op):
    # Recursively print every region, block and op nested under `op`.
    for i, region in enumerate(op.regions):
      print(f"{indent}REGION {i}:")
      for j, block in enumerate(region):
        print(f"{indent} BLOCK {j}:")
        for k, child_op in enumerate(block):
          print(f"{indent} OP {k}: {child_op}")
          walk_operations(indent + " ", child_op)

  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: func
  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: %0 = "custom.addi"
  # CHECK: OP 1: return
  walk_operations("", op)


# Verify index based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
@run
def testTraverseOpRegionBlockIndices():
  """Walks a parsed module using only len() and integer indexing."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)

  # The range(len(...)) loops are deliberate: this test exercises the
  # len()/indexing protocol of the collections rather than iteration.
  def walk_operations(indent, op):
    for i in range(len(op.regions)):
      region = op.regions[i]
      print(f"{indent}REGION {i}:")
      for j in range(len(region.blocks)):
        block = region.blocks[j]
        print(f"{indent} BLOCK {j}:")
        for k in range(len(block.operations)):
          child_op = block.operations[k]
          print(f"{indent} OP {k}: {child_op}")
          print(f"{indent} OP {k}: parent {child_op.operation.parent.name}")
          walk_operations(indent + " ", child_op)

  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: func
  # CHECK: OP 0: parent builtin.module
  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: %0 = "custom.addi"
  # CHECK: OP 0: parent builtin.func
  # CHECK: OP 1: return
  # CHECK: OP 1: parent builtin.func
  walk_operations("", module.operation)


# CHECK-LABEL: TEST: testBlockAndRegionOwners
@run
def testBlockAndRegionOwners():
  """Checks that regions and blocks report the operation that owns them."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    builtin.module {
      builtin.func @f() {
        std.return
      }
    }
  """, ctx)

  assert module.operation.regions[0].owner == module.operation
  assert module.operation.regions[0].blocks[0].owner == module.operation

  func = module.body.operations[0]
  assert func.operation.regions[0].owner == func
  assert func.operation.regions[0].blocks[0].owner == func


# CHECK-LABEL: TEST: testBlockArgumentList
@run
def testBlockArgumentList():
  """Reads, retypes, slices and concatenates block argument lists."""
  with Context() as ctx:
    module = Module.parse(
        r"""
      func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
        return
      }
    """, ctx)
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    assert len(entry_block.arguments) == 3
    # CHECK: Argument 0, type i32
    # CHECK: Argument 1, type f64
    # CHECK: Argument 2, type index
    for arg in entry_block.arguments:
      print(f"Argument {arg.arg_number}, type {arg.type}")
      # Retype argument N as a signless integer of width 8 * (N + 1).
      new_type = IntegerType.get_signless(8 * (arg.arg_number + 1))
      arg.set_type(new_type)

    # CHECK: Argument 0, type i8
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    for arg in entry_block.arguments:
      print(f"Argument {arg.arg_number}, type {arg.type}")

    # Check that slicing works for block argument lists.
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    for arg in entry_block.arguments[1:]:
      print(f"Argument {arg.arg_number}, type {arg.type}")

    # Check that we can concatenate slices of argument lists.
    # CHECK: Length: 4
    print("Length: ",
          len(entry_block.arguments[:2] + entry_block.arguments[1:]))

    # CHECK: Type: i8
    # CHECK: Type: i16
    # CHECK: Type: i24
    for t in entry_block.arguments.types:
      print("Type: ", t)


# CHECK-LABEL: TEST: testOperationOperands
@run
def testOperationOperands():
  """Iterates over the operands of an op and prints their types."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1(%arg0: i32) {
        %0 = "test.producer"() : () -> i64
        "test.consumer"(%arg0, %0) : (i32, i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    consumer = entry_block.operations[1]
    assert len(consumer.operands) == 2
    # CHECK: Operand 0, type i32
    # CHECK: Operand 1, type i64
    for i, operand in enumerate(consumer.operands):
      print(f"Operand {i}, type {operand.type}")


# CHECK-LABEL: TEST: testOperationOperandsSlice
@run
def testOperationOperandsSlice():
  """Exercises plain, strided and chained slicing of an op's operands."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        %3 = "test.producer3"() : () -> i64
        %4 = "test.producer4"() : () -> i64
        "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    consumer = entry_block.operations[5]
    assert len(consumer.operands) == 5
    # Reversing twice must give back the original order.
    for left, right in zip(consumer.operands, consumer.operands[::-1][::-1]):
      assert left == right

    # CHECK: test.producer0
    # CHECK: test.producer1
    # CHECK: test.producer2
    # CHECK: test.producer3
    # CHECK: test.producer4
    full_slice = consumer.operands[:]
    for operand in full_slice:
      print(operand)

    # CHECK: test.producer0
    # CHECK: test.producer1
    first_two = consumer.operands[0:2]
    for operand in first_two:
      print(operand)

    # CHECK: test.producer3
    # CHECK: test.producer4
    last_two = consumer.operands[3:]
    for operand in last_two:
      print(operand)

    # CHECK: test.producer0
    # CHECK: test.producer2
    # CHECK: test.producer4
    even = consumer.operands[::2]
    for operand in even:
      print(operand)

    # Slice of a slice: [::2] picks operands 0, 2, 4 and [1::2] of that keeps
    # only operand 2 (every fourth operand, starting at index 2).
    # CHECK: test.producer2
    fourth = consumer.operands[::2][1::2]
    for operand in fourth:
      print(operand)


# CHECK-LABEL: TEST: testOperationOperandsSet
@run
def testOperationOperandsSet():
  """Replaces an operand in place via positive and negative indices."""
  with Context() as ctx, Location.unknown(ctx):
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        "test.consumer"(%0) : (i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    producer1 = entry_block.operations[1]
    producer2 = entry_block.operations[2]
    consumer = entry_block.operations[3]
    assert len(consumer.operands) == 1
    # Every producer returns i64, so swapping the operand must keep its type.
    operand_type = consumer.operands[0].type

    # CHECK: test.producer1
    consumer.operands[0] = producer1.result
    print(consumer.operands[0])
    assert consumer.operands[0].type == operand_type

    # CHECK: test.producer2
    consumer.operands[-1] = producer2.result
    print(consumer.operands[0])
    assert consumer.operands[0].type == operand_type


# CHECK-LABEL: TEST: testDetachedOperation
@run
def testDetachedOperation():
  """Creates and prints an op that is not attached to any block."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    i32 = IntegerType.get_signed(32)
    op1 = Operation.create(
        "custom.op1",
        results=[i32, i32],
        regions=1,
        attributes={
            "foo": StringAttr.get("foo_value"),
            "bar": StringAttr.get("bar_value"),
        })
    # CHECK: %0:2 = "custom.op1"() ( {
    # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
    print(op1)

  # TODO: Check successors once enough infra exists to do it properly.


# CHECK-LABEL: TEST: testOperationInsertionPoint
@run
def testOperationInsertionPoint():
  """Inserts detached ops at a block start; re-inserting one must raise."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)

  # Create test op.
  with Location.unknown(ctx):
    op1 = Operation.create("custom.op1")
    op2 = Operation.create("custom.op2")

    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    ip = InsertionPoint.at_block_begin(entry_block)
    ip.insert(op1)
    ip.insert(op2)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK: "custom.op2"()
    # CHECK: %0 = "custom.addi"
    print(module)

  # Trying to add a previously added op should raise.
  try:
    ip.insert(op1)
  except ValueError:
    pass
  else:
    assert False, "expected insert of attached op to raise"


# CHECK-LABEL: TEST: testOperationWithRegion
@run
def testOperationWithRegion():
  """Builds an op with a populated region, then inserts it into a module."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    i32 = IntegerType.get_signed(32)
    op1 = Operation.create("custom.op1", regions=1)
    block = op1.regions[0].blocks.append(i32, i32)
    # CHECK: "custom.op1"() ( {
    # CHECK: ^bb0(%arg0: si32, %arg1: si32): // no predecessors
    # CHECK: "custom.terminator"() : () -> ()
    # CHECK: }) : () -> ()
    terminator = Operation.create("custom.terminator")
    ip = InsertionPoint(block)
    ip.insert(terminator)
    print(op1)

    # Now add the whole operation to another op.
    # TODO: Verify lifetime hazard by nulling out the new owning module and
    # accessing op1.
    # TODO: Also verify accessing the terminator once both parents are nulled
    # out.
    module = Module.parse(r"""
      func @f1(%arg0: i32) -> i32 {
        %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
        return %1 : i32
      }
    """)
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    ip = InsertionPoint.at_block_begin(entry_block)
    ip.insert(op1)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK: "custom.terminator"
    # CHECK: %0 = "custom.addi"
    print(module)


# CHECK-LABEL: TEST: testOperationResultList
@run
def testOperationResultList():
  """Iterates over the results of a call op and over their types."""
  ctx = Context()
  module = Module.parse(
      r"""
    func @f1() {
      %0:3 = call @f2() : () -> (i32, f64, index)
      return
    }
    func private @f2() -> (i32, f64, index)
  """, ctx)
  caller = module.body.operations[0]
  call = caller.regions[0].blocks[0].operations[0]
  assert len(call.results) == 3
  # CHECK: Result 0, type i32
  # CHECK: Result 1, type f64
  # CHECK: Result 2, type index
  for res in call.results:
    print(f"Result {res.result_number}, type {res.type}")

  # CHECK: Result type i32
  # CHECK: Result type f64
  # CHECK: Result type index
  for t in call.results.types:
    print(f"Result type {t}")


# CHECK-LABEL: TEST: testOperationResultListSlice
@run
def testOperationResultListSlice():
  """Exercises slicing of result lists, including negative strides."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1() {
        "some.op"() : () -> (i1, i2, i3, i4, i5)
        return
      }
    """)
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    producer = entry_block.operations[0]

    assert len(producer.results) == 5
    # Reversing twice must give back the original order and numbering.
    for left, right in zip(producer.results, producer.results[::-1][::-1]):
      assert left == right
      assert left.result_number == right.result_number

    # CHECK: Result 0, type i1
    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    # CHECK: Result 4, type i5
    full_slice = producer.results[:]
    for res in full_slice:
      print(f"Result {res.result_number}, type {res.type}")

    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    middle = producer.results[1:4]
    for res in middle:
      print(f"Result {res.result_number}, type {res.type}")

    # CHECK: Result 1, type i2
    # CHECK: Result 3, type i4
    odd = producer.results[1::2]
    for res in odd:
      print(f"Result {res.result_number}, type {res.type}")

    # [-2:0:-2] walks backwards from index 3, stopping before index 0.
    # CHECK: Result 3, type i4
    # CHECK: Result 1, type i2
    inverted_middle = producer.results[-2:0:-2]
    for res in inverted_middle:
      print(f"Result {res.result_number}, type {res.type}")


# CHECK-LABEL: TEST: testOperationAttributes
@run
def testOperationAttributes():
  """Reads attributes by name, iterates over them and checks lookup errors."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    "some.op"() { some.attribute = 1 : i8,
                  other.attribute = 3.0,
                  dependent = "text" } : () -> ()
  """, ctx)
  op = module.body.operations[0]
  assert len(op.attributes) == 3
  iattr = IntegerAttr(op.attributes["some.attribute"])
  fattr = FloatAttr(op.attributes["other.attribute"])
  sattr = StringAttr(op.attributes["dependent"])
  # CHECK: Attribute type i8, value 1
  print(f"Attribute type {iattr.type}, value {iattr.value}")
  # CHECK: Attribute type f64, value 3.0
  print(f"Attribute type {fattr.type}, value {fattr.value}")
  # CHECK: Attribute value text
  print(f"Attribute value {sattr.value}")

  # We don't know in which order the attributes are stored.
  # CHECK-DAG: NamedAttribute(dependent="text")
  # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
  # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
  for attr in op.attributes:
    print(str(attr))

  # Check that exceptions are raised as expected: string keys look up by
  # name (KeyError), integer keys index positionally (IndexError).
  try:
    op.attributes["does_not_exist"]
  except KeyError:
    pass
  else:
    assert False, "expected KeyError on accessing a non-existent attribute"

  try:
    op.attributes[42]
  except IndexError:
    pass
  else:
    assert False, "expected IndexError on accessing an out-of-bounds attribute"


# CHECK-LABEL: TEST: testOperationPrint
@run
def testOperationPrint():
  """Prints an op to stdout, to text/binary files, and with print options."""
  ctx = Context()
  # The `-:4:7` location checked below points at `return`: it is on line 4 of
  # this string and starts at column 7, so keep this layout intact.
  module = Module.parse(
      r"""
    func @f1(%arg0: i32) -> i32 {
      %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
      return %arg0 : i32
    }
  """, ctx)

  # Test print to stdout.
  # CHECK: return %arg0 : i32
  module.operation.print()

  # Test print to text file.
  f = io.StringIO()
  # CHECK: <class 'str'>
  # CHECK: return %arg0 : i32
  module.operation.print(file=f)
  str_value = f.getvalue()
  print(str_value.__class__)
  print(str_value)

  # Test print to binary file.
  f = io.BytesIO()
  # CHECK: <class 'bytes'>
  # CHECK: return %arg0 : i32
  module.operation.print(file=f, binary=True)
  bytes_value = f.getvalue()
  print(bytes_value.__class__)
  print(bytes_value)

  # Test print with options.
  # CHECK: value = opaque<"_", "0xDEADBEEF"> : tensor<4xi32>
  # CHECK: "std.return"(%arg0) : (i32) -> () -:4:7
  module.operation.print(
      large_elements_limit=2,
      enable_debug_info=True,
      pretty_debug_info=True,
      print_generic_op_form=True,
      use_local_scope=True)


# CHECK-LABEL: TEST: testKnownOpView
@run
def testKnownOpView():
  """Checks that ops resolve to their generated or to the default OpView."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      %1 = "custom.f32"() : () -> f32
      %2 = "custom.f32"() : () -> f32
      %3 = arith.addf %1, %2 : f32
    """)
    print(module)

    # addf should map to a known OpView class in the arith dialect.
    # We know the OpView for it defines an 'lhs' attribute.
    addf = module.body.operations[2]
    # CHECK: <mlir.dialects._arith_ops_gen._AddFOp object
    print(repr(addf))
    # CHECK: "custom.f32"()
    print(addf.lhs)

    # One of the custom ops should resolve to the default OpView.
    custom = module.body.operations[0]
    # CHECK: OpView object
    print(repr(custom))

    # Check again to make sure negative caching works.
    custom = module.body.operations[0]
    # CHECK: OpView object
    print(repr(custom))


# CHECK-LABEL: TEST: testSingleResultProperty
@run
def testSingleResultProperty():
  """Checks that `.result` raises ValueError on ops with 0 or 2 results."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      "custom.no_result"() : () -> ()
      %0:2 = "custom.two_result"() : () -> (f32, f32)
      %1 = "custom.one_result"() : () -> f32
    """)
    print(module)

    try:
      module.body.operations[0].result
    except ValueError as e:
      # CHECK: Cannot call .result on operation custom.no_result which has 0 results
      print(e)
    else:
      assert False, "Expected exception"

    try:
      module.body.operations[1].result
    except ValueError as e:
      # CHECK: Cannot call .result on operation custom.two_result which has 2 results
      print(e)
    else:
      assert False, "Expected exception"

    # CHECK: %1 = "custom.one_result"() : () -> f32
    print(module.body.operations[2])


def create_invalid_operation():
  """Creates a `builtin.module` op with two regions, which fails to verify.

  Region 0 receives one empty block; region 1 is left empty. Callers in this
  file invoke it inside a `with Location.unknown(...)` scope.
  """
  # This module has two regions and is therefore invalid; the tests below
  # verify that printing it falls back to the generic printer for safety.
  op = Operation.create("builtin.module", regions=2)
  op.regions[0].blocks.append()
  return op


# CHECK-LABEL: TEST: testInvalidOperationStrSoftFails
@run
def testInvalidOperationStrSoftFails():
  """Printing an invalid op falls back to the generic form; verify() fails."""
  ctx = Context()
  with Location.unknown(ctx):
    invalid_op = create_invalid_operation()
    # Verify that we fallback to the generic printer for safety.
    # CHECK: // Verification failed, printing generic form
    # CHECK: "builtin.module"() ( {
    # CHECK: }) : () -> ()
    print(invalid_op)
    # CHECK: .verify = False
    print(f".verify = {invalid_op.operation.verify()}")


# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
@run
def testInvalidModuleStrSoftFails():
  """Printing a module that contains an invalid op uses the generic form."""
  ctx = Context()
  with Location.unknown(ctx):
    module = Module.create()
    with InsertionPoint(module.body):
      # Created under the active insertion point, so the op is placed in the
      # module body; the returned handle itself is not needed.
      create_invalid_operation()
    # Verify that we fallback to the generic printer for safety.
    # CHECK: // Verification failed, printing generic form
    print(module)


# CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails
@run
def testInvalidOperationGetAsmBinarySoftFails():
  """get_asm(binary=True) on an invalid op also uses the generic form."""
  ctx = Context()
  with Location.unknown(ctx):
    invalid_op = create_invalid_operation()
    # Verify that we fallback to the generic printer for safety.
    # CHECK: b'// Verification failed, printing generic form\n
    print(invalid_op.get_asm(binary=True))


# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
@run
def testCreateWithInvalidAttributes():
  """Checks the errors Operation.create raises for bad attribute keys/values.

  Each case must raise; a silent success is reported as an assertion failure
  instead of only surfacing as a missing output line.
  """
  ctx = Context()
  with Location.unknown(ctx):
    try:
      Operation.create(
          "builtin.module", attributes={None: StringAttr.get("name")})
    except Exception as e:
      # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
      print(e)
    else:
      assert False, "expected a None attribute key to be rejected"
    try:
      Operation.create(
          "builtin.module", attributes={42: StringAttr.get("name")})
    except Exception as e:
      # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
      print(e)
    else:
      assert False, "expected an integer attribute key to be rejected"
    try:
      Operation.create("builtin.module", attributes={"some_key": ctx})
    except Exception as e:
      # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
      print(e)
    else:
      assert False, "expected a non-attribute value to be rejected"
    try:
      Operation.create("builtin.module", attributes={"some_key": None})
    except Exception as e:
      # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
      print(e)
    else:
      assert False, "expected a None attribute value to be rejected"


# CHECK-LABEL: TEST: testOperationName
@run
def testOperationName():
  """Prints the name of every op in a module body."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    %0 = "custom.op1"() : () -> f32
    %1 = "custom.op2"() : () -> i32
    %2 = "custom.op1"() : () -> f32
  """, ctx)

  # CHECK: custom.op1
  # CHECK: custom.op2
  # CHECK: custom.op1
  for op in module.body.operations:
    print(op.operation.name)


# CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
  """Round-trips an Operation through its C API capsule to the same object."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    m = Operation.create("custom.op1").operation
    m_capsule = m._CAPIPtr
    assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
    m2 = Operation._CAPICreate(m_capsule)
    assert m2 is m


# CHECK-LABEL: TEST: testOperationErase
@run
def testOperationErase():
  """Erases an op from a module and checks new ops can still be created."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    m = Module.create()
    with InsertionPoint(m.body):
      op = Operation.create("custom.op1")

      # CHECK: "custom.op1"
      print(m)

      op.operation.erase()

      # CHECK-NOT: "custom.op1"
      print(m)

      # Ensure we can create another operation
      Operation.create("custom.op2")


# CHECK-LABEL: TEST: testOperationLoc
@run
def testOperationLoc():
  """Checks that the location given to Operation.create is reported back."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with ctx:
    loc = Location.name("loc")
    op = Operation.create("custom.op", loc=loc)
    assert op.location == loc
    assert op.operation.location == loc


# CHECK-LABEL: TEST: testModuleMerge
@run
def testModuleMerge():
  """Moves ops between modules with move_before/move_after."""
  with Context():
    m1 = Module.parse("func private @foo()")
    m2 = Module.parse("""
      func private @bar()
      func private @qux()
    """)
    foo = m1.body.operations[0]
    bar = m2.body.operations[0]
    qux = m2.body.operations[1]
    bar.move_before(foo)
    qux.move_after(foo)

    # CHECK: module
    # CHECK: func private @bar
    # CHECK: func private @foo
    # CHECK: func private @qux
    print(m1)

    # CHECK: module {
    # CHECK-NEXT: }
    print(m2)


# CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock
@run
def testAppendMoveFromAnotherBlock():
  """Appending an op owned by another block moves it out of that block."""
  with Context():
    m1 = Module.parse("func private @foo()")
    m2 = Module.parse("func private @bar()")
    func = m1.body.operations[0]
    m2.body.append(func)

    # CHECK: module
    # CHECK: func private @bar
    # CHECK: func private @foo

    print(m2)
    # CHECK: module {
    # CHECK-NEXT: }
    print(m1)


# CHECK-LABEL: TEST: testDetachFromParent
@run
def testDetachFromParent():
  """Detaches an op; detaching it a second time must raise ValueError."""
  with Context():
    m1 = Module.parse("func private @foo()")
    func = m1.body.operations[0].detach_from_parent()

    try:
      func.detach_from_parent()
    except ValueError as e:
      if "has no parent" not in str(e):
        raise
    else:
      assert False, "expected ValueError when detaching a detached operation"

    print(m1)
    # CHECK-NOT: func private @foo


# CHECK-LABEL: TEST: testOperationHash
@run
def testOperationHash():
  """Checks that `op` and `op.operation` hash identically."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with ctx, Location.unknown():
    op = Operation.create("custom.op1")
    assert hash(op) == hash(op.operation)