# RUN: %PYTHON %s | FileCheck %s

import gc
import io
import itertools
from mlir.ir import *


def run(f):
    """Run the test function `f` immediately and return it unchanged.

    Prints a ``TEST:`` header line (the per-test labels below match on it),
    calls ``f``, forces a garbage collection and then asserts that no
    ``Context`` is still alive, so every test doubles as a leak check.
    Returning ``f`` keeps the decorated name bound at module level.
    """
    print("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0
    return f


def expect_index_error(callback):
    """Call `callback` and require that it raises ``IndexError``.

    Any other exception raised by `callback` propagates unchanged.

    Raises:
      RuntimeError: if `callback` returns normally instead of raising.
    """
    # Keep the try body minimal: only the call under test may raise the
    # expected IndexError; the "no error" report lives in the else branch.
    try:
        callback()
    except IndexError:
        pass
    else:
        raise RuntimeError("Expected IndexError")


# Verify iterator based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
@run
def testTraverseOpRegionBlockIterators():
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """,
        ctx,
    )
    op = module.operation
    assert op.context is ctx
    # Get the block using iterators off of the named collections.
    regions = list(op.regions)
    blocks = list(regions[0].blocks)
    # CHECK: MODULE REGIONS=1 BLOCKS=1
    print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")

    # Should verify.
    # CHECK: .verify = True
    print(f".verify = {module.operation.verify()}")

    # Get the regions and blocks from the default collections.
    default_regions = list(op.regions)
    default_blocks = list(default_regions[0])
    # They should compare equal regardless of how obtained.
    assert default_regions == regions
    assert default_blocks == blocks

    # Should be able to get the operations from either the named collection
    # or the block.
    operations = list(blocks[0].operations)
    default_operations = list(blocks[0])
    assert default_operations == operations

    def walk_operations(indent, op):
        # Recursively print regions, blocks and ops using plain iteration.
        for i, region in enumerate(op.regions):
            print(f"{indent}REGION {i}:")
            for j, block in enumerate(region):
                print(f"{indent}  BLOCK {j}:")
                for k, child_op in enumerate(block):
                    print(f"{indent}    OP {k}: {child_op}")
                    walk_operations(indent + "      ", child_op)

    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: func
    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: %0 = "custom.addi"
    # CHECK: OP 1: func.return
    walk_operations("", op)


# Verify index based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
@run
def testTraverseOpRegionBlockIndices():
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """,
        ctx,
    )

    def walk_operations(indent, op):
        # Index-based loops are deliberate here: this test exercises
        # __len__/__getitem__ of the region, block and operation lists.
        for i in range(len(op.regions)):
            region = op.regions[i]
            print(f"{indent}REGION {i}:")
            for j in range(len(region.blocks)):
                block = region.blocks[j]
                print(f"{indent}  BLOCK {j}:")
                for k in range(len(block.operations)):
                    child_op = block.operations[k]
                    print(f"{indent}    OP {k}: {child_op}")
                    print(f"{indent}    OP {k}: parent {child_op.operation.parent.name}")
                    walk_operations(indent + "      ", child_op)

    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: func
    # CHECK: OP 0: parent builtin.module
    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: %0 = "custom.addi"
    # CHECK: OP 0: parent func.func
    # CHECK: OP 1: func.return
    # CHECK: OP 1: parent func.func
    walk_operations("", module.operation)


# CHECK-LABEL: TEST: testBlockAndRegionOwners
@run
def testBlockAndRegionOwners():
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    builtin.module {
      func.func @f() {
        func.return
      }
    }
  """,
        ctx,
    )

    assert module.operation.regions[0].owner == module.operation
    assert module.operation.regions[0].blocks[0].owner == module.operation

    func = module.body.operations[0]
    assert func.operation.regions[0].owner == func
    assert func.operation.regions[0].blocks[0].owner == func


# CHECK-LABEL: TEST: testBlockArgumentList
@run
def testBlockArgumentList():
    with Context() as ctx:
        module = Module.parse(
            r"""
      func.func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
        return
      }
    """,
            ctx,
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        assert len(entry_block.arguments) == 3
        # CHECK: Argument 0, type i32
        # CHECK: Argument 1, type f64
        # CHECK: Argument 2, type index
        for arg in entry_block.arguments:
            print(f"Argument {arg.arg_number}, type {arg.type}")
            # Retype argument N to a signless integer of width 8 * (N + 1).
            new_type = IntegerType.get_signless(8 * (arg.arg_number + 1))
            arg.set_type(new_type)

        # CHECK: Argument 0, type i8
        # CHECK: Argument 1, type i16
        # CHECK: Argument 2, type i24
        for arg in entry_block.arguments:
            print(f"Argument {arg.arg_number}, type {arg.type}")

        # Check that slicing works for block argument lists.
        # CHECK: Argument 1, type i16
        # CHECK: Argument 2, type i24
        for arg in entry_block.arguments[1:]:
            print(f"Argument {arg.arg_number}, type {arg.type}")

        # Check that we can concatenate slices of argument lists.
        # CHECK: Length: 4
        print("Length: ",
              len(entry_block.arguments[:2] + entry_block.arguments[1:]))

        # CHECK: Type: i8
        # CHECK: Type: i16
        # CHECK: Type: i24
        for t in entry_block.arguments.types:
            print("Type: ", t)

        # Check that slicing and type access compose.
        # CHECK: Sliced type: i16
        # CHECK: Sliced type: i24
        for t in entry_block.arguments[1:].types:
            print("Sliced type: ", t)

        # Check that slice addition works as expected.
        # CHECK: Argument 2, type i24
        # CHECK: Argument 0, type i8
        restructured = entry_block.arguments[-1:] + entry_block.arguments[:1]
        for arg in restructured:
            print(f"Argument {arg.arg_number}, type {arg.type}")


# CHECK-LABEL: TEST: testOperationOperands
@run
def testOperationOperands():
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(r"""
      func.func @f1(%arg0: i32) {
        %0 = "test.producer"() : () -> i64
        "test.consumer"(%arg0, %0) : (i32, i64) -> ()
        return
      }""")
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        consumer = entry_block.operations[1]
        assert len(consumer.operands) == 2
        # CHECK: Operand 0, type i32
        # CHECK: Operand 1, type i64
        for i, operand in enumerate(consumer.operands):
            print(f"Operand {i}, type {operand.type}")


# CHECK-LABEL: TEST: testOperationOperandsSlice
@run
def testOperationOperandsSlice():
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(r"""
      func.func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        %3 = "test.producer3"() : () -> i64
        %4 = "test.producer4"() : () -> i64
        "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
        return
      }""")
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        consumer = entry_block.operations[5]
        assert len(consumer.operands) == 5
        # Reversing twice must give back the original order.
        for left, right in zip(consumer.operands, consumer.operands[::-1][::-1]):
            assert left == right

        # CHECK: test.producer0
        # CHECK: test.producer1
        # CHECK: test.producer2
        # CHECK: test.producer3
        # CHECK: test.producer4
        full_slice = consumer.operands[:]
        for operand in full_slice:
            print(operand)

        # CHECK: test.producer0
        # CHECK: test.producer1
        first_two = consumer.operands[0:2]
        for operand in first_two:
            print(operand)

        # CHECK: test.producer3
        # CHECK: test.producer4
        last_two = consumer.operands[3:]
        for operand in last_two:
            print(operand)

        # CHECK: test.producer0
        # CHECK: test.producer2
        # CHECK: test.producer4
        even = consumer.operands[::2]
        for operand in even:
            print(operand)

        # Slicing a slice: [0, 2, 4][1::2] selects only operand 2.
        # CHECK: test.producer2
        fourth = consumer.operands[::2][1::2]
        for operand in fourth:
            print(operand)


# CHECK-LABEL: TEST: testOperationOperandsSet
@run
def testOperationOperandsSet():
    with Context() as ctx, Location.unknown(ctx):
        ctx.allow_unregistered_dialects = True
        module = Module.parse(r"""
      func.func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        "test.consumer"(%0) : (i64) -> ()
        return
      }""")
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        producer1 = entry_block.operations[1]
        producer2 = entry_block.operations[2]
        consumer = entry_block.operations[3]
        assert len(consumer.operands) == 1
        # All producers yield i64, so replacing the operand must keep its type.
        operand_type = consumer.operands[0].type

        # CHECK: test.producer1
        consumer.operands[0] = producer1.result
        print(consumer.operands[0])
        assert consumer.operands[0].type == operand_type

        # Negative indices address the operand list from the end.
        # CHECK: test.producer2
        consumer.operands[-1] = producer2.result
        print(consumer.operands[0])
        assert consumer.operands[0].type == operand_type


# CHECK-LABEL: TEST: testDetachedOperation
@run
def testDetachedOperation():
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signed(32)
        op1 = Operation.create(
            "custom.op1",
            results=[i32, i32],
            regions=1,
            attributes={
                "foo": StringAttr.get("foo_value"),
                "bar": StringAttr.get("bar_value"),
            })
        # CHECK: %0:2 = "custom.op1"() ({
        # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
        print(op1)

    # TODO: Check successors once enough infra exists to do it properly.


# CHECK-LABEL: TEST: testOperationInsertionPoint
@run
def testOperationInsertionPoint():
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """,
        ctx,
    )

    # Create test op.
    with Location.unknown(ctx):
        op1 = Operation.create("custom.op1")
        op2 = Operation.create("custom.op2")

        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        ip = InsertionPoint.at_block_begin(entry_block)
        ip.insert(op1)
        ip.insert(op2)
        # CHECK: func @f1
        # CHECK: "custom.op1"()
        # CHECK: "custom.op2"()
        # CHECK: %0 = "custom.addi"
        print(module)

        # Trying to add a previously added op should raise.
        try:
            ip.insert(op1)
        except ValueError:
            pass
        else:
            assert False, "expected insert of attached op to raise"


# CHECK-LABEL: TEST: testOperationWithRegion
@run
def testOperationWithRegion():
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signed(32)
        op1 = Operation.create("custom.op1", regions=1)
        block = op1.regions[0].blocks.append(i32, i32)
        # CHECK: "custom.op1"() ({
        # CHECK: ^bb0(%arg0: si32, %arg1: si32):
        # CHECK: "custom.terminator"() : () -> ()
        # CHECK: }) : () -> ()
        terminator = Operation.create("custom.terminator")
        ip = InsertionPoint(block)
        ip.insert(terminator)
        print(op1)

        # Now add the whole operation to another op.
        # TODO: Verify lifetime hazard by nulling out the new owning module and
        # accessing op1.
        # TODO: Also verify accessing the terminator once both parents are nulled
        # out.
        module = Module.parse(r"""
      func.func @f1(%arg0: i32) -> i32 {
        %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
        return %1 : i32
      }
    """)
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        ip = InsertionPoint.at_block_begin(entry_block)
        ip.insert(op1)
        # CHECK: func @f1
        # CHECK: "custom.op1"()
        # CHECK: "custom.terminator"
        # CHECK: %0 = "custom.addi"
        print(module)


# CHECK-LABEL: TEST: testOperationResultList
@run
def testOperationResultList():
    ctx = Context()
    module = Module.parse(
        r"""
    func.func @f1() {
      %0:3 = call @f2() : () -> (i32, f64, index)
      return
    }
    func.func private @f2() -> (i32, f64, index)
  """,
        ctx,
    )
    caller = module.body.operations[0]
    call = caller.regions[0].blocks[0].operations[0]
    assert len(call.results) == 3
    # CHECK: Result 0, type i32
    # CHECK: Result 1, type f64
    # CHECK: Result 2, type index
    for res in call.results:
        print(f"Result {res.result_number}, type {res.type}")

    # CHECK: Result type i32
    # CHECK: Result type f64
    # CHECK: Result type index
    for t in call.results.types:
        print(f"Result type {t}")

    # Out of range
    expect_index_error(lambda: call.results[3])
    expect_index_error(lambda: call.results[-4])


# CHECK-LABEL: TEST: testOperationResultListSlice
@run
def testOperationResultListSlice():
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(r"""
      func.func @f1() {
        "some.op"() : () -> (i1, i2, i3, i4, i5)
        return
      }
    """)
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        producer = entry_block.operations[0]

        assert len(producer.results) == 5
        # Reversing twice must give back the original order.
        for left, right in zip(producer.results, producer.results[::-1][::-1]):
            assert left == right
            assert left.result_number == right.result_number

        # CHECK: Result 0, type i1
        # CHECK: Result 1, type i2
        # CHECK: Result 2, type i3
        # CHECK: Result 3, type i4
        # CHECK: Result 4, type i5
        full_slice = producer.results[:]
        for res in full_slice:
            print(f"Result {res.result_number}, type {res.type}")

        # CHECK: Result 1, type i2
        # CHECK: Result 2, type i3
        # CHECK: Result 3, type i4
        middle = producer.results[1:4]
        for res in middle:
            print(f"Result {res.result_number}, type {res.type}")

        # CHECK: Result 1, type i2
        # CHECK: Result 3, type i4
        odd = producer.results[1::2]
        for res in odd:
            print(f"Result {res.result_number}, type {res.type}")

        # Negative step: starts at index 3 and stops before index 0.
        # CHECK: Result 3, type i4
        # CHECK: Result 1, type i2
        inverted_middle = producer.results[-2:0:-2]
        for res in inverted_middle:
            print(f"Result {res.result_number}, type {res.type}")


# CHECK-LABEL: TEST: testOperationAttributes
@run
def testOperationAttributes():
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    "some.op"() { some.attribute = 1 : i8,
                  other.attribute = 3.0,
                  dependent = "text" } : () -> ()
  """,
        ctx,
    )
    op = module.body.operations[0]
    assert len(op.attributes) == 3
    iattr = IntegerAttr(op.attributes["some.attribute"])
    fattr = FloatAttr(op.attributes["other.attribute"])
    sattr = StringAttr(op.attributes["dependent"])
    # CHECK: Attribute type i8, value 1
    print(f"Attribute type {iattr.type}, value {iattr.value}")
    # CHECK: Attribute type f64, value 3.0
    print(f"Attribute type {fattr.type}, value {fattr.value}")
    # CHECK: Attribute value text
    print(f"Attribute value {sattr.value}")

    # We don't know in which order the attributes are stored.
    # CHECK-DAG: NamedAttribute(dependent="text")
    # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
    # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
    for attr in op.attributes:
        print(str(attr))

    # Check that exceptions are raised as expected.
    try:
        op.attributes["does_not_exist"]
    except KeyError:
        pass
    else:
        assert False, "expected KeyError on accessing a non-existent attribute"

    try:
        op.attributes[42]
    except IndexError:
        pass
    else:
        assert False, "expected IndexError on accessing an out-of-bounds attribute"


# CHECK-LABEL: TEST: testOperationPrint
@run
def testOperationPrint():
    ctx = Context()
    # NOTE: the pretty debug-info expectation at the end of this test pins the
    # `return` op to line 4, column 7 of this input string (the string starts
    # with an empty line); do not change its line or indentation.
    module = Module.parse(
        r"""
    func.func @f1(%arg0: i32) -> i32 {
      %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
      return %arg0 : i32
    }
  """,
        ctx,
    )

    # Test print to stdout.
    # CHECK: return %arg0 : i32
    module.operation.print()

    # Test print to text file.
    f = io.StringIO()
    # CHECK: <class 'str'>
    # CHECK: return %arg0 : i32
    module.operation.print(file=f)
    str_value = f.getvalue()
    print(str_value.__class__)
    print(f.getvalue())

    # Test print to binary file.
    f = io.BytesIO()
    # CHECK: <class 'bytes'>
    # CHECK: return %arg0 : i32
    module.operation.print(file=f, binary=True)
    bytes_value = f.getvalue()
    print(bytes_value.__class__)
    print(bytes_value)

    # Test get_asm local_scope.
    # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
    module.operation.print(enable_debug_info=True, use_local_scope=True)

    # Test get_asm with options.
    # CHECK: value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<4xi32>
    # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7
    module.operation.print(
        large_elements_limit=2,
        enable_debug_info=True,
        pretty_debug_info=True,
        print_generic_op_form=True,
        use_local_scope=True)


# CHECK-LABEL: TEST: testKnownOpView
@run
def testKnownOpView():
    with Context(), Location.unknown():
        Context.current.allow_unregistered_dialects = True
        module = Module.parse(r"""
      %1 = "custom.f32"() : () -> f32
      %2 = "custom.f32"() : () -> f32
      %3 = arith.addf %1, %2 : f32
    """)
        print(module)

        # addf should map to a known OpView class in the arithmetic dialect.
        # We know the OpView for it defines an 'lhs' attribute.
        addf = module.body.operations[2]
        # CHECK: <mlir.dialects._arith_ops_gen._AddFOp object
        print(repr(addf))
        # CHECK: "custom.f32"()
        print(addf.lhs)

        # One of the custom ops should resolve to the default OpView.
        custom = module.body.operations[0]
        # CHECK: OpView object
        print(repr(custom))

        # Check again to make sure negative caching works.
        custom = module.body.operations[0]
        # CHECK: OpView object
        print(repr(custom))


# CHECK-LABEL: TEST: testSingleResultProperty
@run
def testSingleResultProperty():
    with Context(), Location.unknown():
        Context.current.allow_unregistered_dialects = True
        module = Module.parse(r"""
      "custom.no_result"() : () -> ()
      %0:2 = "custom.two_result"() : () -> (f32, f32)
      %1 = "custom.one_result"() : () -> f32
    """)
        print(module)

        try:
            module.body.operations[0].result
        except ValueError as e:
            # CHECK: Cannot call .result on operation custom.no_result which has 0 results
            print(e)
        else:
            assert False, "Expected exception"

        try:
            module.body.operations[1].result
        except ValueError as e:
            # CHECK: Cannot call .result on operation custom.two_result which has 2 results
            print(e)
        else:
            assert False, "Expected exception"

        # CHECK: %1 = "custom.one_result"() : () -> f32
        print(module.body.operations[2])


def create_invalid_operation():
    """Create a detached ``builtin.module`` op that fails verification.

    Requires a current ``Location`` (and thus context) to be active.
    """
    # This module has two regions and is therefore invalid; it is used to
    # verify that printing falls back to the generic printer for safety.
    op = Operation.create("builtin.module", regions=2)
    op.regions[0].blocks.append()
    return op


# CHECK-LABEL: TEST: testInvalidOperationStrSoftFails
@run
def testInvalidOperationStrSoftFails():
    ctx = Context()
    with Location.unknown(ctx):
        invalid_op = create_invalid_operation()
        # Verify that we fallback to the generic printer for safety.
        # CHECK: // Verification failed, printing generic form
        # CHECK: "builtin.module"() ({
        # CHECK: }) : () -> ()
        print(invalid_op)
        # CHECK: .verify = False
        print(f".verify = {invalid_op.operation.verify()}")


# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
@run
def testInvalidModuleStrSoftFails():
    ctx = Context()
    with Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            # Created under the insertion point, so it lands in module.body.
            invalid_op = create_invalid_operation()
        # Verify that we fallback to the generic printer for safety.
        # CHECK: // Verification failed, printing generic form
        print(module)


# CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails
@run
def testInvalidOperationGetAsmBinarySoftFails():
    ctx = Context()
    with Location.unknown(ctx):
        invalid_op = create_invalid_operation()
        # Verify that we fallback to the generic printer for safety.
        # CHECK: b'// Verification failed, printing generic form\n
        print(invalid_op.get_asm(binary=True))


# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
@run
def testCreateWithInvalidAttributes():
    ctx = Context()
    # Each creation below must fail. The `else` branches turn an unexpected
    # success into an explicit test failure instead of relying solely on the
    # printed message being absent from the output.
    with Location.unknown(ctx):
        try:
            Operation.create(
                "builtin.module", attributes={None: StringAttr.get("name")})
        except Exception as e:
            # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
            print(e)
        else:
            assert False, "expected failure for a None attribute key"
        try:
            Operation.create(
                "builtin.module", attributes={42: StringAttr.get("name")})
        except Exception as e:
            # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
            print(e)
        else:
            assert False, "expected failure for an int attribute key"
        try:
            Operation.create("builtin.module", attributes={"some_key": ctx})
        except Exception as e:
            # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
            print(e)
        else:
            assert False, "expected failure for a non-attribute value"
        try:
            Operation.create("builtin.module", attributes={"some_key": None})
        except Exception as e:
            # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
            print(e)
        else:
            assert False, "expected failure for a None attribute value"


# CHECK-LABEL: TEST: testOperationName
@run
def testOperationName():
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    %0 = "custom.op1"() : () -> f32
    %1 = "custom.op2"() : () -> i32
    %2 = "custom.op1"() : () -> f32
  """,
        ctx,
    )

    # CHECK: custom.op1
    # CHECK: custom.op2
    # CHECK: custom.op1
    for op in module.body.operations:
        print(op.operation.name)


# CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        m = Operation.create("custom.op1").operation
        m_capsule = m._CAPIPtr
        assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
        # Round-tripping through the capsule must yield the same Python object.
        m2 = Operation._CAPICreate(m_capsule)
        assert m2 is m


# CHECK-LABEL: TEST: testOperationErase
@run
def testOperationErase():
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        m = Module.create()
        with InsertionPoint(m.body):
            op = Operation.create("custom.op1")

            # CHECK: "custom.op1"
            print(m)

            op.operation.erase()

            # CHECK-NOT: "custom.op1"
            print(m)

            # Ensure we can create another operation
            Operation.create("custom.op2")


# CHECK-LABEL: TEST: testOperationClone
@run
def testOperationClone():
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        m = Module.create()
        with InsertionPoint(m.body):
            op = Operation.create("custom.op1")

            # CHECK: "custom.op1"
            print(m)

            # The clone is made under the active insertion point, so the
            # module still holds a "custom.op1" after the original is erased.
            clone = op.operation.clone()
            op.operation.erase()

            # CHECK: "custom.op1"
            print(m)


# CHECK-LABEL: TEST: testOperationLoc
@run
def testOperationLoc():
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with ctx:
        loc = Location.name("loc")
        op = Operation.create("custom.op", loc=loc)
        assert op.location == loc
        assert op.operation.location == loc


# CHECK-LABEL: TEST: testModuleMerge
@run
def testModuleMerge():
    with Context():
        m1 = Module.parse("func.func private @foo()")
        m2 = Module.parse("""
      func.func private @bar()
      func.func private @qux()
    """)
        foo = m1.body.operations[0]
        bar = m2.body.operations[0]
        qux = m2.body.operations[1]
        bar.move_before(foo)
        qux.move_after(foo)

        # CHECK: module
        # CHECK: func private @bar
        # CHECK: func private @foo
        # CHECK: func private @qux
        print(m1)

        # CHECK: module {
        # CHECK-NEXT: }
        print(m2)


# CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock
@run
def testAppendMoveFromAnotherBlock():
    with Context():
        m1 = Module.parse("func.func private @foo()")
        m2 = Module.parse("func.func private @bar()")
        func = m1.body.operations[0]
        m2.body.append(func)

        # CHECK: module
        # CHECK: func private @bar
        # CHECK: func private @foo

        print(m2)
        # CHECK: module {
        # CHECK-NEXT: }
        print(m1)


# CHECK-LABEL: TEST: testDetachFromParent
@run
def testDetachFromParent():
    with Context():
        m1 = Module.parse("func.func private @foo()")
        func = m1.body.operations[0].detach_from_parent()

        try:
            func.detach_from_parent()
        except ValueError as e:
            if "has no parent" not in str(e):
                raise
        else:
            assert False, "expected ValueError when detaching a detached operation"

        print(m1)
        # CHECK-NOT: func private @foo


# CHECK-LABEL: TEST: testOperationHash
@run
def testOperationHash():
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with ctx, Location.unknown():
        op = Operation.create("custom.op1")
        assert hash(op) == hash(op.operation)