# RUN: %PYTHON %s | FileCheck %s
"""Tests for the MLIR Python bindings' Operation/Region/Block APIs.

Every test prints to stdout; the RUN line above pipes that output through
FileCheck, which matches it against the directive comments in this file.
Directive comments must therefore stay in the same order as the prints
that produce the text they match.
"""

import gc
import io
import itertools
from mlir.ir import *


def run(f):
  """Run a single test function and check that no MLIR context leaked.

  Prints a "TEST: <name>" header (matched by the label directives), calls
  `f`, forces a garbage collection and asserts that no `Context` objects
  remain alive afterwards.
  """
  print("\nTEST:", f.__name__)
  f()
  gc.collect()
  assert Context._get_live_count() == 0


# Verify iterator based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
def testTraverseOpRegionBlockIterators():
  """Walk the IR through both the named and the default iterable collections."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)
  op = module.operation
  assert op.context is ctx
  # Get the block using iterators off of the named collections.
  regions = list(op.regions)
  blocks = list(regions[0].blocks)
  # CHECK: MODULE REGIONS=1 BLOCKS=1
  print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")

  # Should verify.
  # CHECK: .verify = True
  print(f".verify = {module.operation.verify()}")

  # Get the regions and blocks from the default collections.
  default_regions = list(op)
  default_blocks = list(default_regions[0])
  # They should compare equal regardless of how obtained.
  assert default_regions == regions
  assert default_blocks == blocks

  # Should be able to get the operations from either the named collection
  # or the block.
  operations = list(blocks[0].operations)
  default_operations = list(blocks[0])
  assert default_operations == operations

  def walk_operations(indent, op):
    # Recursively print every region/block/op nested under `op`, relying
    # only on the default iteration protocol of each IR object.
    for i, region in enumerate(op):
      print(f"{indent}REGION {i}:")
      for j, block in enumerate(region):
        print(f"{indent} BLOCK {j}:")
        for k, child_op in enumerate(block):
          print(f"{indent} OP {k}: {child_op}")
          walk_operations(indent + " ", child_op)

  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: func
  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: %0 = "custom.addi"
  # CHECK: OP 1: return
  walk_operations("", op)


run(testTraverseOpRegionBlockIterators)


# Verify index based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
def testTraverseOpRegionBlockIndices():
  """Walk the IR using len()/indexing and check each op's parent name."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)

  def walk_operations(indent, op):
    # Deliberately index-based (range(len(...)) + [i]) rather than
    # enumerate(): this test exercises __len__ and __getitem__ of the
    # region, block and operation lists.
    for i in range(len(op.regions)):
      region = op.regions[i]
      print(f"{indent}REGION {i}:")
      for j in range(len(region.blocks)):
        block = region.blocks[j]
        print(f"{indent} BLOCK {j}:")
        for k in range(len(block.operations)):
          child_op = block.operations[k]
          print(f"{indent} OP {k}: {child_op}")
          print(f"{indent} OP {k}: parent {child_op.operation.parent.name}")
          walk_operations(indent + " ", child_op)

  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: func
  # CHECK: OP 0: parent builtin.module
  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: %0 = "custom.addi"
  # CHECK: OP 0: parent builtin.func
  # CHECK: OP 1: return
  # CHECK: OP 1: parent builtin.func
  walk_operations("", module.operation)


run(testTraverseOpRegionBlockIndices)


# CHECK-LABEL: TEST: testBlockArgumentList
def testBlockArgumentList():
  """Inspect, retype, slice and concatenate a block's argument list."""
  with Context() as ctx:
    module = Module.parse(r"""
      func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
        return
      }
    """, ctx)
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    assert len(entry_block.arguments) == 3
    # CHECK: Argument 0, type i32
    # CHECK: Argument 1, type f64
    # CHECK: Argument 2, type index
    for arg in entry_block.arguments:
      print(f"Argument {arg.arg_number}, type {arg.type}")
      # Retype argument N to a signless integer of width 8 * (N + 1).
      new_type = IntegerType.get_signless(8 * (arg.arg_number + 1))
      arg.set_type(new_type)

    # CHECK: Argument 0, type i8
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    for arg in entry_block.arguments:
      print(f"Argument {arg.arg_number}, type {arg.type}")

    # Check that slicing works for block argument lists.
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    for arg in entry_block.arguments[1:]:
      print(f"Argument {arg.arg_number}, type {arg.type}")

    # Check that we can concatenate slices of argument lists.
    # CHECK: Length: 4
    print("Length:",
          len(entry_block.arguments[:2] + entry_block.arguments[1:]))


run(testBlockArgumentList)


# CHECK-LABEL: TEST: testOperationOperands
def testOperationOperands():
  """Iterate over an operation's operands and print their types."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1(%arg0: i32) {
        %0 = "test.producer"() : () -> i64
        "test.consumer"(%arg0, %0) : (i32, i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    consumer = entry_block.operations[1]
    assert len(consumer.operands) == 2
    # CHECK: Operand 0, type i32
    # CHECK: Operand 1, type i64
    for i, operand in enumerate(consumer.operands):
      print(f"Operand {i}, type {operand.type}")


run(testOperationOperands)


# CHECK-LABEL: TEST: testOperationOperandsSlice
def testOperationOperandsSlice():
  """Exercise full, bounded, stepped and nested slicing of operand lists."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        %3 = "test.producer3"() : () -> i64
        %4 = "test.producer4"() : () -> i64
        "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    consumer = entry_block.operations[5]
    assert len(consumer.operands) == 5
    # Reversing twice must give back the original operands in order.
    for left, right in zip(consumer.operands, consumer.operands[::-1][::-1]):
      assert left == right

    # CHECK: test.producer0
    # CHECK: test.producer1
    # CHECK: test.producer2
    # CHECK: test.producer3
    # CHECK: test.producer4
    full_slice = consumer.operands[:]
    for operand in full_slice:
      print(operand)

    # CHECK: test.producer0
    # CHECK: test.producer1
    first_two = consumer.operands[0:2]
    for operand in first_two:
      print(operand)

    # CHECK: test.producer3
    # CHECK: test.producer4
    last_two = consumer.operands[3:]
    for operand in last_two:
      print(operand)

    # CHECK: test.producer0
    # CHECK: test.producer2
    # CHECK: test.producer4
    even = consumer.operands[::2]
    for operand in even:
      print(operand)

    # Slicing a slice: [::2] keeps indices 0, 2, 4 and [1::2] of that keeps
    # only original index 2 (i.e. every fourth operand starting at 2).
    # CHECK: test.producer2
    fourth = consumer.operands[::2][1::2]
    for operand in fourth:
      print(operand)


run(testOperationOperandsSlice)


# CHECK-LABEL: TEST: testOperationOperandsSet
def testOperationOperandsSet():
  """Replace an operand via positive and negative indices."""
  with Context() as ctx, Location.unknown(ctx):
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        "test.consumer"(%0) : (i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    producer1 = entry_block.operations[1]
    producer2 = entry_block.operations[2]
    consumer = entry_block.operations[3]
    assert len(consumer.operands) == 1

    # CHECK: test.producer1
    consumer.operands[0] = producer1.result
    print(consumer.operands[0])

    # The consumer has a single operand, so index -1 aliases index 0.
    # CHECK: test.producer2
    consumer.operands[-1] = producer2.result
    print(consumer.operands[0])


run(testOperationOperandsSet)


# CHECK-LABEL: TEST: testDetachedOperation
def testDetachedOperation():
  """Create and print an operation that is not attached to any block."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    i32 = IntegerType.get_signed(32)
    op1 = Operation.create(
        "custom.op1", results=[i32, i32], regions=1, attributes={
            "foo": StringAttr.get("foo_value"),
            "bar": StringAttr.get("bar_value"),
        })
    # CHECK: %0:2 = "custom.op1"() ( {
    # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
    print(op1)

  # TODO: Check successors once enough infra exists to do it properly.


run(testDetachedOperation)


# CHECK-LABEL: TEST: testOperationInsertionPoint
def testOperationInsertionPoint():
  """Insert detached ops at a block start; re-inserting one must fail."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)

  # Create test op.
  with Location.unknown(ctx):
    op1 = Operation.create("custom.op1")
    op2 = Operation.create("custom.op2")

    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    ip = InsertionPoint.at_block_begin(entry_block)
    ip.insert(op1)
    ip.insert(op2)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK: "custom.op2"()
    # CHECK: %0 = "custom.addi"
    print(module)

  # Trying to add a previously added op should raise.
  try:
    ip.insert(op1)
  except ValueError:
    pass
  else:
    assert False, "expected insert of attached op to raise"


run(testOperationInsertionPoint)


# CHECK-LABEL: TEST: testOperationWithRegion
def testOperationWithRegion():
  """Build an op with a populated region, then insert it into a function."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    i32 = IntegerType.get_signed(32)
    op1 = Operation.create("custom.op1", regions=1)
    block = op1.regions[0].blocks.append(i32, i32)
    # CHECK: "custom.op1"() ( {
    # CHECK: ^bb0(%arg0: si32, %arg1: si32): // no predecessors
    # CHECK: "custom.terminator"() : () -> ()
    # CHECK: }) : () -> ()
    terminator = Operation.create("custom.terminator")
    ip = InsertionPoint(block)
    ip.insert(terminator)
    print(op1)

    # Now add the whole operation to another op.
    # TODO: Verify lifetime hazard by nulling out the new owning module and
    # accessing op1.
    # TODO: Also verify accessing the terminator once both parents are nulled
    # out.
    module = Module.parse(r"""
      func @f1(%arg0: i32) -> i32 {
        %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
        return %1 : i32
      }
    """)
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    ip = InsertionPoint.at_block_begin(entry_block)
    ip.insert(op1)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK: "custom.terminator"
    # CHECK: %0 = "custom.addi"
    print(module)


run(testOperationWithRegion)


# CHECK-LABEL: TEST: testOperationResultList
def testOperationResultList():
  """Iterate over a multi-result call and print each result's number/type."""
  ctx = Context()
  module = Module.parse(r"""
    func @f1() {
      %0:3 = call @f2() : () -> (i32, f64, index)
      return
    }
    func private @f2() -> (i32, f64, index)
  """, ctx)
  caller = module.body.operations[0]
  call = caller.regions[0].blocks[0].operations[0]
  assert len(call.results) == 3
  # CHECK: Result 0, type i32
  # CHECK: Result 1, type f64
  # CHECK: Result 2, type index
  for res in call.results:
    print(f"Result {res.result_number}, type {res.type}")


run(testOperationResultList)


# CHECK-LABEL: TEST: testOperationResultListSlice
def testOperationResultListSlice():
  """Exercise full, bounded, stepped and negative-step result slicing."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1() {
        "some.op"() : () -> (i1, i2, i3, i4, i5)
        return
      }
    """)
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    producer = entry_block.operations[0]

    assert len(producer.results) == 5
    # Reversing twice must give back the original results in order.
    for left, right in zip(producer.results, producer.results[::-1][::-1]):
      assert left == right
      assert left.result_number == right.result_number

    # CHECK: Result 0, type i1
    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    # CHECK: Result 4, type i5
    full_slice = producer.results[:]
    for res in full_slice:
      print(f"Result {res.result_number}, type {res.type}")

    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    middle = producer.results[1:4]
    for res in middle:
      print(f"Result {res.result_number}, type {res.type}")

    # CHECK: Result 1, type i2
    # CHECK: Result 3, type i4
    odd = producer.results[1::2]
    for res in odd:
      print(f"Result {res.result_number}, type {res.type}")

    # Negative step: starts at index 3 (-2) and stops before index 0.
    # CHECK: Result 3, type i4
    # CHECK: Result 1, type i2
    inverted_middle = producer.results[-2:0:-2]
    for res in inverted_middle:
      print(f"Result {res.result_number}, type {res.type}")


run(testOperationResultListSlice)


# CHECK-LABEL: TEST: testOperationAttributes
def testOperationAttributes():
  """Read typed attributes by name, iterate them, and probe lookup errors."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    "some.op"() { some.attribute = 1 : i8,
                  other.attribute = 3.0,
                  dependent = "text" } : () -> ()
  """, ctx)
  op = module.body.operations[0]
  assert len(op.attributes) == 3
  iattr = IntegerAttr(op.attributes["some.attribute"])
  fattr = FloatAttr(op.attributes["other.attribute"])
  sattr = StringAttr(op.attributes["dependent"])
  # CHECK: Attribute type i8, value 1
  print(f"Attribute type {iattr.type}, value {iattr.value}")
  # CHECK: Attribute type f64, value 3.0
  print(f"Attribute type {fattr.type}, value {fattr.value}")
  # CHECK: Attribute value text
  print(f"Attribute value {sattr.value}")

  # We don't know in which order the attributes are stored.
  # CHECK-DAG: NamedAttribute(dependent="text")
  # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
  # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
  for attr in op.attributes:
    print(str(attr))

  # Check that exceptions are raised as expected.
  try:
    op.attributes["does_not_exist"]
  except KeyError:
    pass
  else:
    assert False, "expected KeyError on accessing a non-existent attribute"

  try:
    op.attributes[42]
  except IndexError:
    pass
  else:
    assert False, "expected IndexError on accessing an out-of-bounds attribute"


run(testOperationAttributes)


# CHECK-LABEL: TEST: testOperationPrint
def testOperationPrint():
  """Print an operation to stdout, a text file, a binary file, and with options."""
  ctx = Context()
  # The leading whitespace in this IR is significant: the debug-location
  # directive below expects `return` at line 4, column 7 of this string.
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %0 = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
      return %arg0 : i32
    }
  """, ctx)

  # Test print to stdout.
  # CHECK: return %arg0 : i32
  module.operation.print()

  # Test print to text file.
  f = io.StringIO()
  # CHECK: <class 'str'>
  # CHECK: return %arg0 : i32
  module.operation.print(file=f)
  str_value = f.getvalue()
  print(str_value.__class__)
  print(str_value)

  # Test print to binary file.
  f = io.BytesIO()
  # CHECK: <class 'bytes'>
  # CHECK: return %arg0 : i32
  module.operation.print(file=f, binary=True)
  bytes_value = f.getvalue()
  print(bytes_value.__class__)
  print(bytes_value)

  # Test get_asm with options.
  # CHECK: value = opaque<"_", "0xDEADBEEF"> : tensor<4xi32>
  # CHECK: "std.return"(%arg0) : (i32) -> () -:4:7
  module.operation.print(large_elements_limit=2, enable_debug_info=True,
      pretty_debug_info=True, print_generic_op_form=True, use_local_scope=True)


run(testOperationPrint)


# CHECK-LABEL: TEST: testKnownOpView
def testKnownOpView():
  """Registered ops resolve to generated OpView subclasses; others do not."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      %1 = "custom.f32"() : () -> f32
      %2 = "custom.f32"() : () -> f32
      %3 = addf %1, %2 : f32
    """)
    print(module)

    # addf should map to a known OpView class in the std dialect.
    # We know the OpView for it defines an 'lhs' attribute.
    addf = module.body.operations[2]
    # CHECK: <mlir.dialects._std_ops_gen._AddFOp object
    print(repr(addf))
    # CHECK: "custom.f32"()
    print(addf.lhs)

    # One of the custom ops should resolve to the default OpView.
    custom = module.body.operations[0]
    # CHECK: OpView object
    print(repr(custom))

    # Check again to make sure negative caching works.
    custom = module.body.operations[0]
    # CHECK: OpView object
    print(repr(custom))


run(testKnownOpView)


# CHECK-LABEL: TEST: testSingleResultProperty
def testSingleResultProperty():
  """`.result` works only for ops with exactly one result."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      "custom.no_result"() : () -> ()
      %0:2 = "custom.two_result"() : () -> (f32, f32)
      %1 = "custom.one_result"() : () -> f32
    """)
    print(module)

  try:
    module.body.operations[0].result
  except ValueError as e:
    # CHECK: Cannot call .result on operation custom.no_result which has 0 results
    print(e)
  else:
    assert False, "Expected exception"

  try:
    module.body.operations[1].result
  except ValueError as e:
    # CHECK: Cannot call .result on operation custom.two_result which has 2 results
    print(e)
  else:
    assert False, "Expected exception"

  # CHECK: %1 = "custom.one_result"() : () -> f32
  print(module.body.operations[2])


run(testSingleResultProperty)


# CHECK-LABEL: TEST: testPrintInvalidOperation
def testPrintInvalidOperation():
  """An op that fails verification is printed in generic form."""
  ctx = Context()
  with Location.unknown(ctx):
    module = Operation.create("builtin.module", regions=2)
    # This module has two regions and is therefore invalid; verify that we
    # fall back to the generic printer for safety.
    module.regions[0].blocks.append()
    # CHECK: // Verification failed, printing generic form
    # CHECK: "builtin.module"() ( {
    # CHECK: }) : () -> ()
    print(module)
    # CHECK: .verify = False
    print(f".verify = {module.operation.verify()}")


run(testPrintInvalidOperation)


# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
def testCreateWithInvalidAttributes():
  """Invalid attribute keys/values passed to Operation.create must raise."""
  ctx = Context()
  with Location.unknown(ctx):
    try:
      Operation.create(
          "builtin.module", attributes={None: StringAttr.get("name")})
    except Exception as e:
      # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
      print(e)
    else:
      assert False, "expected a None attribute key to raise"
    try:
      Operation.create(
          "builtin.module", attributes={42: StringAttr.get("name")})
    except Exception as e:
      # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
      print(e)
    else:
      assert False, "expected an int attribute key to raise"
    try:
      Operation.create("builtin.module", attributes={"some_key": ctx})
    except Exception as e:
      # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
      print(e)
    else:
      assert False, "expected a non-attribute value to raise"
    try:
      Operation.create("builtin.module", attributes={"some_key": None})
    except Exception as e:
      # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
      print(e)
    else:
      assert False, "expected a None attribute value to raise"


run(testCreateWithInvalidAttributes)


# CHECK-LABEL: TEST: testOperationName
def testOperationName():
  """Each operation reports its fully qualified name."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    %0 = "custom.op1"() : () -> f32
    %1 = "custom.op2"() : () -> i32
    %2 = "custom.op1"() : () -> f32
  """, ctx)

  # CHECK: custom.op1
  # CHECK: custom.op2
  # CHECK: custom.op1
  for op in module.body.operations:
    print(op.operation.name)


run(testOperationName)


# CHECK-LABEL: TEST: testCapsuleConversions
def testCapsuleConversions():
  """Round-trip an Operation through its C-API capsule."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    m = Operation.create("custom.op1").operation
    m_capsule = m._CAPIPtr
    assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
    m2 = Operation._CAPICreate(m_capsule)
    # Recreating from the capsule must return the same Python object.
    assert m2 is m


run(testCapsuleConversions)


# CHECK-LABEL: TEST: testOperationErase
def testOperationErase():
  """Erase an op from a module, then create another op at the same point."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    m = Module.create()
    with InsertionPoint(m.body):
      op = Operation.create("custom.op1")

      # CHECK: "custom.op1"
      print(m)

      op.operation.erase()

      # CHECK-NOT: "custom.op1"
      print(m)

      # Ensure we can create another operation
      Operation.create("custom.op2")


run(testOperationErase)