# RUN: %PYTHON %s | FileCheck %s
#
# Tests for the Python bindings of MLIR operations, regions and blocks:
# traversal, operand/result/argument lists, attributes, printing, insertion
# and erasure. Each test prints to stdout and the expected output is matched
# by the FileCheck directives written in the comments next to the prints.

import gc
import io
import itertools
from mlir.ir import *


def run(f):
  """Run one test function, then verify that no MLIR Context leaked.

  Prints a header line naming the test (the per-test label directives anchor
  on it), invokes `f`, forces a garbage collection and asserts that every
  Context created by the test has been released.
  """
  print("\nTEST:", f.__name__)
  f()
  gc.collect()
  assert Context._get_live_count() == 0


# Verify iterator based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
def testTraverseOpRegionBlockIterators():
  """Walk the IR using the iteration protocol of ops, regions and blocks."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)
  op = module.operation
  assert op.context is ctx
  # Get the block using iterators off of the named collections.
  regions = list(op.regions)
  blocks = list(regions[0].blocks)
  # CHECK: MODULE REGIONS=1 BLOCKS=1
  print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")

  # Should verify.
  # CHECK: .verify = True
  print(f".verify = {module.operation.verify()}")

  # Get the regions and blocks from the default collections.
  default_regions = list(op)
  default_blocks = list(default_regions[0])
  # They should compare equal regardless of how obtained.
  assert default_regions == regions
  assert default_blocks == blocks

  # Should be able to get the operations from either the named collection
  # or the block.
  operations = list(blocks[0].operations)
  default_operations = list(blocks[0])
  assert default_operations == operations

  def walk_operations(indent, op):
    # Iterating an op yields its regions, iterating a region yields its
    # blocks, and iterating a block yields its operations.
    for i, region in enumerate(op):
      print(f"{indent}REGION {i}:")
      for j, block in enumerate(region):
        print(f"{indent}  BLOCK {j}:")
        for k, child_op in enumerate(block):
          print(f"{indent}    OP {k}: {child_op}")
          walk_operations(indent + "      ", child_op)

  # CHECK: REGION 0:
  # CHECK:   BLOCK 0:
  # CHECK:     OP 0: func
  # CHECK:       REGION 0:
  # CHECK:         BLOCK 0:
  # CHECK:           OP 0: %0 = "custom.addi"
  # CHECK:           OP 1: return
  walk_operations("", op)

run(testTraverseOpRegionBlockIterators)


# Verify index based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
def testTraverseOpRegionBlockIndices():
  """Walk the IR using len() and integer indexing of the named collections."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)

  def walk_operations(indent, op):
    # Deliberately uses range(len(...)) plus indexing instead of iteration:
    # this test exercises len() and [] on the region, block and operation
    # lists.
    for i in range(len(op.regions)):
      region = op.regions[i]
      print(f"{indent}REGION {i}:")
      for j in range(len(region.blocks)):
        block = region.blocks[j]
        print(f"{indent}  BLOCK {j}:")
        for k in range(len(block.operations)):
          child_op = block.operations[k]
          print(f"{indent}    OP {k}: {child_op}")
          print(f"{indent}    OP {k}: parent {child_op.operation.parent.name}")
          walk_operations(indent + "      ", child_op)

  # CHECK: REGION 0:
  # CHECK:   BLOCK 0:
  # CHECK:     OP 0: func
  # CHECK:     OP 0: parent builtin.module
  # CHECK:       REGION 0:
  # CHECK:         BLOCK 0:
  # CHECK:           OP 0: %0 = "custom.addi"
  # CHECK:           OP 0: parent builtin.func
  # CHECK:           OP 1: return
  # CHECK:           OP 1: parent builtin.func
  walk_operations("", module.operation)

run(testTraverseOpRegionBlockIndices)


# CHECK-LABEL: TEST: testBlockArgumentList
def testBlockArgumentList():
  """Read block arguments and retype them in place with set_type."""
  with Context() as ctx:
    module = Module.parse(r"""
      func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
        return
      }
    """, ctx)
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    assert len(entry_block.arguments) == 3
    # CHECK: Argument 0, type i32
    # CHECK: Argument 1, type f64
    # CHECK: Argument 2, type index
    for arg in entry_block.arguments:
      print(f"Argument {arg.arg_number}, type {arg.type}")
      # Retype argument N to a signless integer of width 8 * (N + 1).
      new_type = IntegerType.get_signless(8 * (arg.arg_number + 1))
      arg.set_type(new_type)

    # CHECK: Argument 0, type i8
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    for arg in entry_block.arguments:
      print(f"Argument {arg.arg_number}, type {arg.type}")


run(testBlockArgumentList)


# CHECK-LABEL: TEST: testOperationOperands
def testOperationOperands():
  """Iterate the operand list of an operation and print operand types."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1(%arg0: i32) {
        %0 = "test.producer"() : () -> i64
        "test.consumer"(%arg0, %0) : (i32, i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    consumer = entry_block.operations[1]
    assert len(consumer.operands) == 2
    # CHECK: Operand 0, type i32
    # CHECK: Operand 1, type i64
    for i, operand in enumerate(consumer.operands):
      print(f"Operand {i}, type {operand.type}")


run(testOperationOperands)


# CHECK-LABEL: TEST: testOperationOperandsSlice
def testOperationOperandsSlice():
  """Slice the operand list with start/stop/step, including nested slices."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        %3 = "test.producer3"() : () -> i64
        %4 = "test.producer4"() : () -> i64
        "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    consumer = entry_block.operations[5]
    assert len(consumer.operands) == 5
    # Reversing twice must round-trip to the original order.
    for left, right in zip(consumer.operands, consumer.operands[::-1][::-1]):
      assert left == right

    # CHECK: test.producer0
    # CHECK: test.producer1
    # CHECK: test.producer2
    # CHECK: test.producer3
    # CHECK: test.producer4
    full_slice = consumer.operands[:]
    for operand in full_slice:
      print(operand)

    # CHECK: test.producer0
    # CHECK: test.producer1
    first_two = consumer.operands[0:2]
    for operand in first_two:
      print(operand)

    # CHECK: test.producer3
    # CHECK: test.producer4
    last_two = consumer.operands[3:]
    for operand in last_two:
      print(operand)

    # CHECK: test.producer0
    # CHECK: test.producer2
    # CHECK: test.producer4
    even = consumer.operands[::2]
    for operand in even:
      print(operand)

    # Slice of a slice: every second element of (p0, p2, p4) starting at 1,
    # which leaves only producer2.
    # CHECK: test.producer2
    fourth = consumer.operands[::2][1::2]
    for operand in fourth:
      print(operand)


run(testOperationOperandsSlice)


# CHECK-LABEL: TEST: testOperationOperandsSet
def testOperationOperandsSet():
  """Replace an operand through item assignment, with and without negative
  indices."""
  with Context() as ctx, Location.unknown(ctx):
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        "test.consumer"(%0) : (i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    producer1 = entry_block.operations[1]
    producer2 = entry_block.operations[2]
    consumer = entry_block.operations[3]
    assert len(consumer.operands) == 1
    # The consumer's only operand is %0, declared as i64 above.
    operand_type = consumer.operands[0].type
    assert str(operand_type) == "i64"

    # CHECK: test.producer1
    consumer.operands[0] = producer1.result
    print(consumer.operands[0])

    # With a single operand, index -1 addresses the same slot as index 0, so
    # reading back index 0 checks that the negative index was honored.
    # CHECK: test.producer2
    consumer.operands[-1] = producer2.result
    print(consumer.operands[0])


run(testOperationOperandsSet)


# CHECK-LABEL: TEST: testDetachedOperation
def testDetachedOperation():
  """Create and print an operation that is not attached to any block."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    si32 = IntegerType.get_signed(32)
    op1 = Operation.create(
        "custom.op1", results=[si32, si32], regions=1, attributes={
            "foo": StringAttr.get("foo_value"),
            "bar": StringAttr.get("bar_value"),
        })
    # CHECK: %0:2 = "custom.op1"() ( {
    # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
    print(op1)

  # TODO: Check successors once enough infra exists to do it properly.

run(testDetachedOperation)


# CHECK-LABEL: TEST: testOperationInsertionPoint
def testOperationInsertionPoint():
  """Insert detached ops at a block start; re-inserting one must fail."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)

  # Create test op.
  with Location.unknown(ctx):
    op1 = Operation.create("custom.op1")
    op2 = Operation.create("custom.op2")

    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    ip = InsertionPoint.at_block_begin(entry_block)
    ip.insert(op1)
    ip.insert(op2)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK: "custom.op2"()
    # CHECK: %0 = "custom.addi"
    print(module)

  # Trying to add a previously added op should raise.
  try:
    ip.insert(op1)
  except ValueError:
    pass
  else:
    assert False, "expected insert of attached op to raise"

run(testOperationInsertionPoint)


# CHECK-LABEL: TEST: testOperationWithRegion
def testOperationWithRegion():
  """Build an op with a populated region, then move it into a module."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    si32 = IntegerType.get_signed(32)
    op1 = Operation.create("custom.op1", regions=1)
    block = op1.regions[0].blocks.append(si32, si32)
    # CHECK: "custom.op1"() ( {
    # CHECK: ^bb0(%arg0: si32, %arg1: si32):  // no predecessors
    # CHECK:   "custom.terminator"() : () -> ()
    # CHECK: }) : () -> ()
    terminator = Operation.create("custom.terminator")
    ip = InsertionPoint(block)
    ip.insert(terminator)
    print(op1)

    # Now add the whole operation to another op.
    # TODO: Verify lifetime hazard by nulling out the new owning module and
    # accessing op1.
    # TODO: Also verify accessing the terminator once both parents are nulled
    # out.
    module = Module.parse(r"""
      func @f1(%arg0: i32) -> i32 {
        %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
        return %1 : i32
      }
    """)
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    ip = InsertionPoint.at_block_begin(entry_block)
    ip.insert(op1)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK: "custom.terminator"
    # CHECK: %0 = "custom.addi"
    print(module)

run(testOperationWithRegion)


# CHECK-LABEL: TEST: testOperationResultList
def testOperationResultList():
  """Iterate the result list of a multi-result call."""
  ctx = Context()
  module = Module.parse(r"""
    func @f1() {
      %0:3 = call @f2() : () -> (i32, f64, index)
      return
    }
    func private @f2() -> (i32, f64, index)
  """, ctx)
  caller = module.body.operations[0]
  call = caller.regions[0].blocks[0].operations[0]
  assert len(call.results) == 3
  # CHECK: Result 0, type i32
  # CHECK: Result 1, type f64
  # CHECK: Result 2, type index
  for res in call.results:
    print(f"Result {res.result_number}, type {res.type}")


run(testOperationResultList)


# CHECK-LABEL: TEST: testOperationResultListSlice
def testOperationResultListSlice():
  """Slice the result list, including negative-step slices."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1() {
        "some.op"() : () -> (i1, i2, i3, i4, i5)
        return
      }
    """)
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    producer = entry_block.operations[0]

    assert len(producer.results) == 5
    # Reversing twice must round-trip, preserving result numbers too.
    for left, right in zip(producer.results, producer.results[::-1][::-1]):
      assert left == right
      assert left.result_number == right.result_number

    # CHECK: Result 0, type i1
    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    # CHECK: Result 4, type i5
    full_slice = producer.results[:]
    for res in full_slice:
      print(f"Result {res.result_number}, type {res.type}")

    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    middle = producer.results[1:4]
    for res in middle:
      print(f"Result {res.result_number}, type {res.type}")

    # CHECK: Result 1, type i2
    # CHECK: Result 3, type i4
    odd = producer.results[1::2]
    for res in odd:
      print(f"Result {res.result_number}, type {res.type}")

    # Indices 3 then 1: start at -2, step -2, stop before index 0.
    # CHECK: Result 3, type i4
    # CHECK: Result 1, type i2
    inverted_middle = producer.results[-2:0:-2]
    for res in inverted_middle:
      print(f"Result {res.result_number}, type {res.type}")


run(testOperationResultListSlice)


# CHECK-LABEL: TEST: testOperationAttributes
def testOperationAttributes():
  """Access operation attributes by name, by position and by iteration."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    "some.op"() { some.attribute = 1 : i8,
                  other.attribute = 3.0,
                  dependent = "text" } : () -> ()
  """, ctx)
  op = module.body.operations[0]
  assert len(op.attributes) == 3
  iattr = IntegerAttr(op.attributes["some.attribute"])
  fattr = FloatAttr(op.attributes["other.attribute"])
  sattr = StringAttr(op.attributes["dependent"])
  # CHECK: Attribute type i8, value 1
  print(f"Attribute type {iattr.type}, value {iattr.value}")
  # CHECK: Attribute type f64, value 3.0
  print(f"Attribute type {fattr.type}, value {fattr.value}")
  # CHECK: Attribute value text
  print(f"Attribute value {sattr.value}")

  # We don't know in which order the attributes are stored.
  # CHECK-DAG: NamedAttribute(dependent="text")
  # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
  # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
  for attr in op.attributes:
    print(str(attr))

  # Check that exceptions are raised as expected.
  try:
    op.attributes["does_not_exist"]
  except KeyError:
    pass
  else:
    assert False, "expected KeyError on accessing a non-existent attribute"

  try:
    op.attributes[42]
  except IndexError:
    pass
  else:
    assert False, "expected IndexError on accessing an out-of-bounds attribute"


run(testOperationAttributes)


# CHECK-LABEL: TEST: testOperationPrint
def testOperationPrint():
  """Print an operation to stdout, a text stream and a binary stream."""
  ctx = Context()
  # NOTE: the debug-location expectation at the end of this test ("-:4:7")
  # points at the `return` below: line 4 of this string, column 7. Keep the
  # string's layout unchanged.
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %0 = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
      return %arg0 : i32
    }
  """, ctx)

  # Test print to stdout.
  # CHECK: return %arg0 : i32
  module.operation.print()

  # Test print to text file.
  f = io.StringIO()
  # CHECK: <class 'str'>
  # CHECK: return %arg0 : i32
  module.operation.print(file=f)
  str_value = f.getvalue()
  print(str_value.__class__)
  print(f.getvalue())

  # Test print to binary file.
  f = io.BytesIO()
  # CHECK: <class 'bytes'>
  # CHECK: return %arg0 : i32
  module.operation.print(file=f, binary=True)
  bytes_value = f.getvalue()
  print(bytes_value.__class__)
  print(bytes_value)

  # Test print with printing options: large_elements_limit=2 elides the
  # 4-element dense constant, and the debug-info flags append locations.
  # CHECK: value = opaque<"_", "0xDEADBEEF"> : tensor<4xi32>
  # CHECK: "std.return"(%arg0) : (i32) -> () -:4:7
  module.operation.print(large_elements_limit=2, enable_debug_info=True,
      pretty_debug_info=True, print_generic_op_form=True, use_local_scope=True)

run(testOperationPrint)


# CHECK-LABEL: TEST: testKnownOpView
def testKnownOpView():
  """Registered ops resolve to generated OpView classes; others to OpView."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      %1 = "custom.f32"() : () -> f32
      %2 = "custom.f32"() : () -> f32
      %3 = addf %1, %2 : f32
    """)
    print(module)

    # addf should map to a known OpView class in the std dialect.
    # We know the OpView for it defines an 'lhs' attribute.
    addf = module.body.operations[2]
    # CHECK: <mlir.dialects._std_ops_gen._AddFOp object
    print(repr(addf))
    # CHECK: "custom.f32"()
    print(addf.lhs)

    # One of the custom ops should resolve to the default OpView.
    custom = module.body.operations[0]
    # CHECK: OpView object
    print(repr(custom))

    # Check again to make sure negative caching works.
    custom = module.body.operations[0]
    # CHECK: OpView object
    print(repr(custom))

run(testKnownOpView)


# CHECK-LABEL: TEST: testSingleResultProperty
def testSingleResultProperty():
  """`.result` works only on single-result ops and raises otherwise."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      "custom.no_result"() : () -> ()
      %0:2 = "custom.two_result"() : () -> (f32, f32)
      %1 = "custom.one_result"() : () -> f32
    """)
    print(module)

  try:
    module.body.operations[0].result
  except ValueError as e:
    # CHECK: Cannot call .result on operation custom.no_result which has 0 results
    print(e)
  else:
    assert False, "Expected exception"

  try:
    module.body.operations[1].result
  except ValueError as e:
    # CHECK: Cannot call .result on operation custom.two_result which has 2 results
    print(e)
  else:
    assert False, "Expected exception"

  # CHECK: %1 = "custom.one_result"() : () -> f32
  print(module.body.operations[2])

run(testSingleResultProperty)

# CHECK-LABEL: TEST: testPrintInvalidOperation
def testPrintInvalidOperation():
  """Printing an op that fails verification falls back to generic form."""
  ctx = Context()
  with Location.unknown(ctx):
    module = Operation.create("builtin.module", regions=2)
    # This module has two regions and is therefore invalid; verify that we
    # fall back to the generic printer for safety.
    module.regions[0].blocks.append()
    # CHECK: // Verification failed, printing generic form
    # CHECK: "builtin.module"() ( {
    # CHECK: }) : () -> ()
    print(module)
    # CHECK: .verify = False
    print(f".verify = {module.operation.verify()}")
run(testPrintInvalidOperation)


# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
def testCreateWithInvalidAttributes():
  """Operation.create rejects non-string keys and non-attribute values."""
  ctx = Context()
  with Location.unknown(ctx):
    # The concrete exception type is not pinned here; only the message is
    # matched. Each call must raise, so a silent success fails the test.
    try:
      Operation.create("builtin.module", attributes={None:StringAttr.get("name")})
    except Exception as e:
      # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
      print(e)
    else:
      assert False, "Expected exception"
    try:
      Operation.create("builtin.module", attributes={42:StringAttr.get("name")})
    except Exception as e:
      # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
      print(e)
    else:
      assert False, "Expected exception"
    try:
      Operation.create("builtin.module", attributes={"some_key":ctx})
    except Exception as e:
      # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
      print(e)
    else:
      assert False, "Expected exception"
    try:
      Operation.create("builtin.module", attributes={"some_key":None})
    except Exception as e:
      # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
      print(e)
    else:
      assert False, "Expected exception"
run(testCreateWithInvalidAttributes)


# CHECK-LABEL: TEST: testOperationName
def testOperationName():
  """Read the fully qualified name of each operation in a module body."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    %0 = "custom.op1"() : () -> f32
    %1 = "custom.op2"() : () -> i32
    %2 = "custom.op1"() : () -> f32
  """, ctx)

  # CHECK: custom.op1
  # CHECK: custom.op2
  # CHECK: custom.op1
  for op in module.body.operations:
    print(op.operation.name)

run(testOperationName)

# CHECK-LABEL: TEST: testCapsuleConversions
def testCapsuleConversions():
  """Round-trip an Operation through its C API capsule."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    m = Operation.create("custom.op1").operation
    m_capsule = m._CAPIPtr
    assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
    # Recreating from the capsule must return the very same Python object.
    m2 = Operation._CAPICreate(m_capsule)
    assert m2 is m

run(testCapsuleConversions)

# CHECK-LABEL: TEST: testOperationErase
def testOperationErase():
  """Erase an op from a module and keep using the insertion point after."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    m = Module.create()
    with InsertionPoint(m.body):
      op = Operation.create("custom.op1")

      # CHECK: "custom.op1"
      print(m)

      op.operation.erase()

      # CHECK-NOT: "custom.op1"
      print(m)

      # Ensure we can create another operation
      Operation.create("custom.op2")

run(testOperationErase)