# RUN: %PYTHON %s | FileCheck %s

import gc
import io
import itertools
from mlir.ir import *


def run(f):
  """Run test function `f` under a banner and verify no Context leaks.

  Prints a "TEST: <name>" line (matched by the label directives below),
  calls `f`, forces a garbage collection and asserts that the live Context
  count is back to zero. Returns `f` so it can be used as a decorator.
  """
  print("\nTEST:", f.__name__)
  f()
  gc.collect()
  assert Context._get_live_count() == 0
  return f


def expect_index_error(callback):
  """Call `callback` and require that it raises IndexError.

  Raises:
    RuntimeError: if `callback` returns without raising IndexError.
  """
  try:
    callback()
  except IndexError:
    pass
  else:
    # Kept out of the `try` body so only the callback itself is guarded.
    raise RuntimeError("Expected IndexError")


# Verify iterator based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
@run
def testTraverseOpRegionBlockIterators():
  """Walk regions, blocks and operations of a parsed module via iteration."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)
  op = module.operation
  assert op.context is ctx
  # Get the block using iterators off of the named collections.
  regions = list(op.regions)
  blocks = list(regions[0].blocks)
  # CHECK: MODULE REGIONS=1 BLOCKS=1
  print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")

  # Should verify.
  # CHECK: .verify = True
  print(f".verify = {module.operation.verify()}")

  # Get the regions and blocks from the default collections.
  default_regions = list(op.regions)
  default_blocks = list(default_regions[0])
  # They should compare equal regardless of how obtained.
  assert default_regions == regions
  assert default_blocks == blocks

  # Should be able to get the operations from either the named collection
  # or the block.
  operations = list(blocks[0].operations)
  default_operations = list(blocks[0])
  assert default_operations == operations

  def walk_operations(indent, op):
    """Recursively print every region/block/op nested under `op`."""
    for i, region in enumerate(op.regions):
      print(f"{indent}REGION {i}:")
      for j, block in enumerate(region):
        print(f"{indent} BLOCK {j}:")
        for k, child_op in enumerate(block):
          print(f"{indent} OP {k}: {child_op}")
          walk_operations(indent + " ", child_op)

  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: func
  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: %0 = "custom.addi"
  # CHECK: OP 1: func.return
  walk_operations("", op)


# Verify index based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
@run
def testTraverseOpRegionBlockIndices():
  """Walk regions, blocks and operations of a parsed module via indexing."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)

  def walk_operations(indent, op):
    """Recursively print the hierarchy under `op` using `len()` + `[]`."""
    # Index-based loops are deliberate here: this test exercises
    # __len__/__getitem__ rather than the iterator protocol.
    for i in range(len(op.regions)):
      region = op.regions[i]
      print(f"{indent}REGION {i}:")
      for j in range(len(region.blocks)):
        block = region.blocks[j]
        print(f"{indent} BLOCK {j}:")
        for k in range(len(block.operations)):
          child_op = block.operations[k]
          print(f"{indent} OP {k}: {child_op}")
          print(f"{indent} OP {k}: parent {child_op.operation.parent.name}")
          walk_operations(indent + " ", child_op)

  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: func
  # CHECK: OP 0: parent builtin.module
  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: %0 = "custom.addi"
  # CHECK: OP 0: parent func.func
  # CHECK: OP 1: func.return
  # CHECK: OP 1: parent func.func
  walk_operations("", module.operation)


# CHECK-LABEL: TEST: testBlockAndRegionOwners
@run
def testBlockAndRegionOwners():
  """Check that `.owner` of regions and blocks compares equal to their op."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    builtin.module {
      func.func @f() {
        func.return
      }
    }
  """, ctx)

  assert module.operation.regions[0].owner == module.operation
  assert module.operation.regions[0].blocks[0].owner == module.operation

  func = module.body.operations[0]
  assert func.operation.regions[0].owner == func
  assert func.operation.regions[0].blocks[0].owner == func


# CHECK-LABEL: TEST: testBlockArgumentList
@run
def testBlockArgumentList():
  """Exercise iteration, set_type, slicing, concatenation and `.types`."""
  with Context() as ctx:
    module = Module.parse(
        r"""
      func.func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
        return
      }
    """, ctx)
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    assert len(entry_block.arguments) == 3
    # CHECK: Argument 0, type i32
    # CHECK: Argument 1, type f64
    # CHECK: Argument 2, type index
    for arg in entry_block.arguments:
      print(f"Argument {arg.arg_number}, type {arg.type}")
      new_type = IntegerType.get_signless(8 * (arg.arg_number + 1))
      arg.set_type(new_type)

    # CHECK: Argument 0, type i8
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    for arg in entry_block.arguments:
      print(f"Argument {arg.arg_number}, type {arg.type}")

    # Check that slicing works for block argument lists.
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    for arg in entry_block.arguments[1:]:
      print(f"Argument {arg.arg_number}, type {arg.type}")

    # Check that we can concatenate slices of argument lists.
    # CHECK: Length: 4
    print("Length: ",
          len(entry_block.arguments[:2] + entry_block.arguments[1:]))

    # CHECK: Type: i8
    # CHECK: Type: i16
    # CHECK: Type: i24
    for t in entry_block.arguments.types:
      print("Type: ", t)


# CHECK-LABEL: TEST: testOperationOperands
@run
def testOperationOperands():
  """Check the length and element types of an operation's operand list."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func.func @f1(%arg0: i32) {
        %0 = "test.producer"() : () -> i64
        "test.consumer"(%arg0, %0) : (i32, i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    consumer = entry_block.operations[1]
    assert len(consumer.operands) == 2
    # CHECK: Operand 0, type i32
    # CHECK: Operand 1, type i64
    for i, operand in enumerate(consumer.operands):
      print(f"Operand {i}, type {operand.type}")


# CHECK-LABEL: TEST: testOperationOperandsSlice
@run
def testOperationOperandsSlice():
  """Check slicing (ranges, steps, nested slices) of an operand list."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func.func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        %3 = "test.producer3"() : () -> i64
        %4 = "test.producer4"() : () -> i64
        "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    consumer = entry_block.operations[5]
    assert len(consumer.operands) == 5
    # Reversing twice must give back the original sequence.
    for left, right in zip(consumer.operands, consumer.operands[::-1][::-1]):
      assert left == right

    # CHECK: test.producer0
    # CHECK: test.producer1
    # CHECK: test.producer2
    # CHECK: test.producer3
    # CHECK: test.producer4
    full_slice = consumer.operands[:]
    for operand in full_slice:
      print(operand)

    # CHECK: test.producer0
    # CHECK: test.producer1
    first_two = consumer.operands[0:2]
    for operand in first_two:
      print(operand)

    # CHECK: test.producer3
    # CHECK: test.producer4
    last_two = consumer.operands[3:]
    for operand in last_two:
      print(operand)

    # CHECK: test.producer0
    # CHECK: test.producer2
    # CHECK: test.producer4
    even = consumer.operands[::2]
    for operand in even:
      print(operand)

    # CHECK: test.producer2
    fourth = consumer.operands[::2][1::2]
    for operand in fourth:
      print(operand)


# CHECK-LABEL: TEST: testOperationOperandsSet
@run
def testOperationOperandsSet():
  """Check item assignment on an operand list, incl. a negative index."""
  with Context() as ctx, Location.unknown(ctx):
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func.func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        "test.consumer"(%0) : (i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    producer1 = entry_block.operations[1]
    producer2 = entry_block.operations[2]
    consumer = entry_block.operations[3]
    assert len(consumer.operands) == 1

    # CHECK: test.producer1
    consumer.operands[0] = producer1.result
    print(consumer.operands[0])

    # With a single operand, index -1 addresses the same slot as index 0.
    # CHECK: test.producer2
    consumer.operands[-1] = producer2.result
    print(consumer.operands[0])


# CHECK-LABEL: TEST: testDetachedOperation
@run
def testDetachedOperation():
  """Create an op with results, a region and attributes, then print it."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    i32 = IntegerType.get_signed(32)
    op1 = Operation.create(
        "custom.op1",
        results=[i32, i32],
        regions=1,
        attributes={
            "foo": StringAttr.get("foo_value"),
            "bar": StringAttr.get("bar_value"),
        })
    # CHECK: %0:2 = "custom.op1"() ({
    # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
    print(op1)
320 321 # TODO: Check successors once enough infra exists to do it properly. 322 323 324# CHECK-LABEL: TEST: testOperationInsertionPoint 325@run 326def testOperationInsertionPoint(): 327 ctx = Context() 328 ctx.allow_unregistered_dialects = True 329 module = Module.parse( 330 r""" 331 func.func @f1(%arg0: i32) -> i32 { 332 %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32 333 return %1 : i32 334 } 335 """, ctx) 336 337 # Create test op. 338 with Location.unknown(ctx): 339 op1 = Operation.create("custom.op1") 340 op2 = Operation.create("custom.op2") 341 342 func = module.body.operations[0] 343 entry_block = func.regions[0].blocks[0] 344 ip = InsertionPoint.at_block_begin(entry_block) 345 ip.insert(op1) 346 ip.insert(op2) 347 # CHECK: func @f1 348 # CHECK: "custom.op1"() 349 # CHECK: "custom.op2"() 350 # CHECK: %0 = "custom.addi" 351 print(module) 352 353 # Trying to add a previously added op should raise. 354 try: 355 ip.insert(op1) 356 except ValueError: 357 pass 358 else: 359 assert False, "expected insert of attached op to raise" 360 361 362# CHECK-LABEL: TEST: testOperationWithRegion 363@run 364def testOperationWithRegion(): 365 ctx = Context() 366 ctx.allow_unregistered_dialects = True 367 with Location.unknown(ctx): 368 i32 = IntegerType.get_signed(32) 369 op1 = Operation.create("custom.op1", regions=1) 370 block = op1.regions[0].blocks.append(i32, i32) 371 # CHECK: "custom.op1"() ({ 372 # CHECK: ^bb0(%arg0: si32, %arg1: si32): 373 # CHECK: "custom.terminator"() : () -> () 374 # CHECK: }) : () -> () 375 terminator = Operation.create("custom.terminator") 376 ip = InsertionPoint(block) 377 ip.insert(terminator) 378 print(op1) 379 380 # Now add the whole operation to another op. 381 # TODO: Verify lifetime hazard by nulling out the new owning module and 382 # accessing op1. 383 # TODO: Also verify accessing the terminator once both parents are nulled 384 # out. 
385 module = Module.parse(r""" 386 func.func @f1(%arg0: i32) -> i32 { 387 %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32 388 return %1 : i32 389 } 390 """) 391 func = module.body.operations[0] 392 entry_block = func.regions[0].blocks[0] 393 ip = InsertionPoint.at_block_begin(entry_block) 394 ip.insert(op1) 395 # CHECK: func @f1 396 # CHECK: "custom.op1"() 397 # CHECK: "custom.terminator" 398 # CHECK: %0 = "custom.addi" 399 print(module) 400 401 402# CHECK-LABEL: TEST: testOperationResultList 403@run 404def testOperationResultList(): 405 ctx = Context() 406 module = Module.parse( 407 r""" 408 func.func @f1() { 409 %0:3 = call @f2() : () -> (i32, f64, index) 410 return 411 } 412 func.func private @f2() -> (i32, f64, index) 413 """, ctx) 414 caller = module.body.operations[0] 415 call = caller.regions[0].blocks[0].operations[0] 416 assert len(call.results) == 3 417 # CHECK: Result 0, type i32 418 # CHECK: Result 1, type f64 419 # CHECK: Result 2, type index 420 for res in call.results: 421 print(f"Result {res.result_number}, type {res.type}") 422 423 # CHECK: Result type i32 424 # CHECK: Result type f64 425 # CHECK: Result type index 426 for t in call.results.types: 427 print(f"Result type {t}") 428 429 # Out of range 430 expect_index_error(lambda: call.results[3]) 431 expect_index_error(lambda: call.results[-4]) 432 433 434# CHECK-LABEL: TEST: testOperationResultListSlice 435@run 436def testOperationResultListSlice(): 437 with Context() as ctx: 438 ctx.allow_unregistered_dialects = True 439 module = Module.parse(r""" 440 func.func @f1() { 441 "some.op"() : () -> (i1, i2, i3, i4, i5) 442 return 443 } 444 """) 445 func = module.body.operations[0] 446 entry_block = func.regions[0].blocks[0] 447 producer = entry_block.operations[0] 448 449 assert len(producer.results) == 5 450 for left, right in zip(producer.results, producer.results[::-1][::-1]): 451 assert left == right 452 assert left.result_number == right.result_number 453 454 # CHECK: Result 0, type i1 455 # 
CHECK: Result 1, type i2 456 # CHECK: Result 2, type i3 457 # CHECK: Result 3, type i4 458 # CHECK: Result 4, type i5 459 full_slice = producer.results[:] 460 for res in full_slice: 461 print(f"Result {res.result_number}, type {res.type}") 462 463 # CHECK: Result 1, type i2 464 # CHECK: Result 2, type i3 465 # CHECK: Result 3, type i4 466 middle = producer.results[1:4] 467 for res in middle: 468 print(f"Result {res.result_number}, type {res.type}") 469 470 # CHECK: Result 1, type i2 471 # CHECK: Result 3, type i4 472 odd = producer.results[1::2] 473 for res in odd: 474 print(f"Result {res.result_number}, type {res.type}") 475 476 # CHECK: Result 3, type i4 477 # CHECK: Result 1, type i2 478 inverted_middle = producer.results[-2:0:-2] 479 for res in inverted_middle: 480 print(f"Result {res.result_number}, type {res.type}") 481 482 483# CHECK-LABEL: TEST: testOperationAttributes 484@run 485def testOperationAttributes(): 486 ctx = Context() 487 ctx.allow_unregistered_dialects = True 488 module = Module.parse( 489 r""" 490 "some.op"() { some.attribute = 1 : i8, 491 other.attribute = 3.0, 492 dependent = "text" } : () -> () 493 """, ctx) 494 op = module.body.operations[0] 495 assert len(op.attributes) == 3 496 iattr = IntegerAttr(op.attributes["some.attribute"]) 497 fattr = FloatAttr(op.attributes["other.attribute"]) 498 sattr = StringAttr(op.attributes["dependent"]) 499 # CHECK: Attribute type i8, value 1 500 print(f"Attribute type {iattr.type}, value {iattr.value}") 501 # CHECK: Attribute type f64, value 3.0 502 print(f"Attribute type {fattr.type}, value {fattr.value}") 503 # CHECK: Attribute value text 504 print(f"Attribute value {sattr.value}") 505 506 # We don't know in which order the attributes are stored. 
507 # CHECK-DAG: NamedAttribute(dependent="text") 508 # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64) 509 # CHECK-DAG: NamedAttribute(some.attribute=1 : i8) 510 for attr in op.attributes: 511 print(str(attr)) 512 513 # Check that exceptions are raised as expected. 514 try: 515 op.attributes["does_not_exist"] 516 except KeyError: 517 pass 518 else: 519 assert False, "expected KeyError on accessing a non-existent attribute" 520 521 try: 522 op.attributes[42] 523 except IndexError: 524 pass 525 else: 526 assert False, "expected IndexError on accessing an out-of-bounds attribute" 527 528 529 530 531# CHECK-LABEL: TEST: testOperationPrint 532@run 533def testOperationPrint(): 534 ctx = Context() 535 module = Module.parse( 536 r""" 537 func.func @f1(%arg0: i32) -> i32 { 538 %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> 539 return %arg0 : i32 540 } 541 """, ctx) 542 543 # Test print to stdout. 544 # CHECK: return %arg0 : i32 545 module.operation.print() 546 547 # Test print to text file. 548 f = io.StringIO() 549 # CHECK: <class 'str'> 550 # CHECK: return %arg0 : i32 551 module.operation.print(file=f) 552 str_value = f.getvalue() 553 print(str_value.__class__) 554 print(f.getvalue()) 555 556 # Test print to binary file. 557 f = io.BytesIO() 558 # CHECK: <class 'bytes'> 559 # CHECK: return %arg0 : i32 560 module.operation.print(file=f, binary=True) 561 bytes_value = f.getvalue() 562 print(bytes_value.__class__) 563 print(bytes_value) 564 565 # Test get_asm with options. 
566 # CHECK: value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<4xi32> 567 # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7 568 module.operation.print( 569 large_elements_limit=2, 570 enable_debug_info=True, 571 pretty_debug_info=True, 572 print_generic_op_form=True, 573 use_local_scope=True) 574 575 576 577 578# CHECK-LABEL: TEST: testKnownOpView 579@run 580def testKnownOpView(): 581 with Context(), Location.unknown(): 582 Context.current.allow_unregistered_dialects = True 583 module = Module.parse(r""" 584 %1 = "custom.f32"() : () -> f32 585 %2 = "custom.f32"() : () -> f32 586 %3 = arith.addf %1, %2 : f32 587 """) 588 print(module) 589 590 # addf should map to a known OpView class in the arithmetic dialect. 591 # We know the OpView for it defines an 'lhs' attribute. 592 addf = module.body.operations[2] 593 # CHECK: <mlir.dialects._arith_ops_gen._AddFOp object 594 print(repr(addf)) 595 # CHECK: "custom.f32"() 596 print(addf.lhs) 597 598 # One of the custom ops should resolve to the default OpView. 599 custom = module.body.operations[0] 600 # CHECK: OpView object 601 print(repr(custom)) 602 603 # Check again to make sure negative caching works. 
604 custom = module.body.operations[0] 605 # CHECK: OpView object 606 print(repr(custom)) 607 608 609# CHECK-LABEL: TEST: testSingleResultProperty 610@run 611def testSingleResultProperty(): 612 with Context(), Location.unknown(): 613 Context.current.allow_unregistered_dialects = True 614 module = Module.parse(r""" 615 "custom.no_result"() : () -> () 616 %0:2 = "custom.two_result"() : () -> (f32, f32) 617 %1 = "custom.one_result"() : () -> f32 618 """) 619 print(module) 620 621 try: 622 module.body.operations[0].result 623 except ValueError as e: 624 # CHECK: Cannot call .result on operation custom.no_result which has 0 results 625 print(e) 626 else: 627 assert False, "Expected exception" 628 629 try: 630 module.body.operations[1].result 631 except ValueError as e: 632 # CHECK: Cannot call .result on operation custom.two_result which has 2 results 633 print(e) 634 else: 635 assert False, "Expected exception" 636 637 # CHECK: %1 = "custom.one_result"() : () -> f32 638 print(module.body.operations[2]) 639 640 641def create_invalid_operation(): 642 # This module has two region and is invalid verify that we fallback 643 # to the generic printer for safety. 644 op = Operation.create("builtin.module", regions=2) 645 op.regions[0].blocks.append() 646 return op 647 648# CHECK-LABEL: TEST: testInvalidOperationStrSoftFails 649@run 650def testInvalidOperationStrSoftFails(): 651 ctx = Context() 652 with Location.unknown(ctx): 653 invalid_op = create_invalid_operation() 654 # Verify that we fallback to the generic printer for safety. 
655 # CHECK: // Verification failed, printing generic form 656 # CHECK: "builtin.module"() ({ 657 # CHECK: }) : () -> () 658 print(invalid_op) 659 # CHECK: .verify = False 660 print(f".verify = {invalid_op.operation.verify()}") 661 662 663# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails 664@run 665def testInvalidModuleStrSoftFails(): 666 ctx = Context() 667 with Location.unknown(ctx): 668 module = Module.create() 669 with InsertionPoint(module.body): 670 invalid_op = create_invalid_operation() 671 # Verify that we fallback to the generic printer for safety. 672 # CHECK: // Verification failed, printing generic form 673 print(module) 674 675 676# CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails 677@run 678def testInvalidOperationGetAsmBinarySoftFails(): 679 ctx = Context() 680 with Location.unknown(ctx): 681 invalid_op = create_invalid_operation() 682 # Verify that we fallback to the generic printer for safety. 683 # CHECK: b'// Verification failed, printing generic form\n 684 print(invalid_op.get_asm(binary=True)) 685 686 687# CHECK-LABEL: TEST: testCreateWithInvalidAttributes 688@run 689def testCreateWithInvalidAttributes(): 690 ctx = Context() 691 with Location.unknown(ctx): 692 try: 693 Operation.create( 694 "builtin.module", attributes={None: StringAttr.get("name")}) 695 except Exception as e: 696 # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module" 697 print(e) 698 try: 699 Operation.create( 700 "builtin.module", attributes={42: StringAttr.get("name")}) 701 except Exception as e: 702 # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module" 703 print(e) 704 try: 705 Operation.create("builtin.module", attributes={"some_key": ctx}) 706 except Exception as e: 707 # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module" 708 print(e) 709 try: 710 Operation.create("builtin.module", 
attributes={"some_key": None}) 711 except Exception as e: 712 # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module" 713 print(e) 714 715 716# CHECK-LABEL: TEST: testOperationName 717@run 718def testOperationName(): 719 ctx = Context() 720 ctx.allow_unregistered_dialects = True 721 module = Module.parse( 722 r""" 723 %0 = "custom.op1"() : () -> f32 724 %1 = "custom.op2"() : () -> i32 725 %2 = "custom.op1"() : () -> f32 726 """, ctx) 727 728 # CHECK: custom.op1 729 # CHECK: custom.op2 730 # CHECK: custom.op1 731 for op in module.body.operations: 732 print(op.operation.name) 733 734 735# CHECK-LABEL: TEST: testCapsuleConversions 736@run 737def testCapsuleConversions(): 738 ctx = Context() 739 ctx.allow_unregistered_dialects = True 740 with Location.unknown(ctx): 741 m = Operation.create("custom.op1").operation 742 m_capsule = m._CAPIPtr 743 assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule) 744 m2 = Operation._CAPICreate(m_capsule) 745 assert m2 is m 746 747 748# CHECK-LABEL: TEST: testOperationErase 749@run 750def testOperationErase(): 751 ctx = Context() 752 ctx.allow_unregistered_dialects = True 753 with Location.unknown(ctx): 754 m = Module.create() 755 with InsertionPoint(m.body): 756 op = Operation.create("custom.op1") 757 758 # CHECK: "custom.op1" 759 print(m) 760 761 op.operation.erase() 762 763 # CHECK-NOT: "custom.op1" 764 print(m) 765 766 # Ensure we can create another operation 767 Operation.create("custom.op2") 768 769 770# CHECK-LABEL: TEST: testOperationClone 771@run 772def testOperationClone(): 773 ctx = Context() 774 ctx.allow_unregistered_dialects = True 775 with Location.unknown(ctx): 776 m = Module.create() 777 with InsertionPoint(m.body): 778 op = Operation.create("custom.op1") 779 780 # CHECK: "custom.op1" 781 print(m) 782 783 clone = op.operation.clone() 784 op.operation.erase() 785 786 # CHECK: "custom.op1" 787 print(m) 788 789 790# CHECK-LABEL: TEST: 
testOperationLoc 791@run 792def testOperationLoc(): 793 ctx = Context() 794 ctx.allow_unregistered_dialects = True 795 with ctx: 796 loc = Location.name("loc") 797 op = Operation.create("custom.op", loc=loc) 798 assert op.location == loc 799 assert op.operation.location == loc 800 801 802# CHECK-LABEL: TEST: testModuleMerge 803@run 804def testModuleMerge(): 805 with Context(): 806 m1 = Module.parse("func.func private @foo()") 807 m2 = Module.parse(""" 808 func.func private @bar() 809 func.func private @qux() 810 """) 811 foo = m1.body.operations[0] 812 bar = m2.body.operations[0] 813 qux = m2.body.operations[1] 814 bar.move_before(foo) 815 qux.move_after(foo) 816 817 # CHECK: module 818 # CHECK: func private @bar 819 # CHECK: func private @foo 820 # CHECK: func private @qux 821 print(m1) 822 823 # CHECK: module { 824 # CHECK-NEXT: } 825 print(m2) 826 827 828# CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock 829@run 830def testAppendMoveFromAnotherBlock(): 831 with Context(): 832 m1 = Module.parse("func.func private @foo()") 833 m2 = Module.parse("func.func private @bar()") 834 func = m1.body.operations[0] 835 m2.body.append(func) 836 837 # CHECK: module 838 # CHECK: func private @bar 839 # CHECK: func private @foo 840 841 print(m2) 842 # CHECK: module { 843 # CHECK-NEXT: } 844 print(m1) 845 846 847# CHECK-LABEL: TEST: testDetachFromParent 848@run 849def testDetachFromParent(): 850 with Context(): 851 m1 = Module.parse("func.func private @foo()") 852 func = m1.body.operations[0].detach_from_parent() 853 854 try: 855 func.detach_from_parent() 856 except ValueError as e: 857 if "has no parent" not in str(e): 858 raise 859 else: 860 assert False, "expected ValueError when detaching a detached operation" 861 862 print(m1) 863 # CHECK-NOT: func private @foo 864 865 866# CHECK-LABEL: TEST: testOperationHash 867@run 868def testOperationHash(): 869 ctx = Context() 870 ctx.allow_unregistered_dialects = True 871 with ctx, Location.unknown(): 872 op = 
Operation.create("custom.op1") 873 assert hash(op) == hash(op.operation) 874