1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4import io
5import itertools
6from mlir.ir import *
7
def run(f):
  """Runs one test function and asserts that no MLIR context outlives it.

  Prints a "TEST: <name>" banner for the test, calls `f`, forces a garbage
  collection and finally asserts that `Context._get_live_count()` is zero.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  # Collect before counting the live contexts.
  gc.collect()
  live_contexts = Context._get_live_count()
  assert live_contexts == 0
13
14
15# Verify iterator based traversal of the op/region/block hierarchy.
16# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
def testTraverseOpRegionBlockIterators():
  """Traverses op -> region -> block -> op purely through iteration.

  Compares the named collections (`.regions`, `.blocks`, `.operations`)
  against iterating the parent object directly, then prints the nested
  structure of the parsed module.
  """
  context = Context()
  context.allow_unregistered_dialects = True
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, context)
  root = module.operation
  assert root.context is context

  # Materialize the named collections through their iterators.
  named_regions = list(root.regions)
  named_blocks = list(named_regions[0].blocks)
  # CHECK: MODULE REGIONS=1 BLOCKS=1
  print(f"MODULE REGIONS={len(named_regions)} BLOCKS={len(named_blocks)}")

  # The parsed module is expected to pass verification.
  # CHECK: .verify = True
  verified = module.operation.verify()
  print(f".verify = {verified}")

  # Iterating the op yields its regions; iterating a region yields its blocks.
  iterated_regions = list(root)
  iterated_blocks = list(iterated_regions[0])
  # Both ways of obtaining them must agree.
  assert iterated_regions == named_regions
  assert iterated_blocks == named_blocks

  # Likewise the operations of a block are reachable through the named
  # collection or by iterating the block itself.
  first_block = named_blocks[0]
  named_operations = list(first_block.operations)
  iterated_operations = list(first_block)
  assert iterated_operations == named_operations

  def dump(prefix, parent):
    # Recursively print every region, block and operation below `parent`.
    nested_prefix = prefix + "      "
    for region_idx, region in enumerate(parent):
      print(f"{prefix}REGION {region_idx}:")
      for block_idx, block in enumerate(region):
        print(f"{prefix}  BLOCK {block_idx}:")
        for op_idx, nested_op in enumerate(block):
          print(f"{prefix}    OP {op_idx}: {nested_op}")
          dump(nested_prefix, nested_op)

  # CHECK: REGION 0:
  # CHECK:   BLOCK 0:
  # CHECK:     OP 0: func
  # CHECK:       REGION 0:
  # CHECK:         BLOCK 0:
  # CHECK:           OP 0: %0 = "custom.addi"
  # CHECK:           OP 1: return
  dump("", root)
68
69run(testTraverseOpRegionBlockIterators)
70
71
72# Verify index based traversal of the op/region/block hierarchy.
73# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
def testTraverseOpRegionBlockIndices():
  """Traverses op -> region -> block -> op using integer indexing only.

  Every collection is accessed through `len()` and `[index]` on purpose, and
  the name of each operation's parent is printed alongside the operation.
  """
  context = Context()
  context.allow_unregistered_dialects = True
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, context)

  def dump(prefix, parent):
    # Recursively print the hierarchy below `parent`, indexing explicitly.
    nested_prefix = prefix + "      "
    regions = parent.regions
    for region_idx in range(len(regions)):
      blocks = regions[region_idx].blocks
      print(f"{prefix}REGION {region_idx}:")
      for block_idx in range(len(blocks)):
        operations = blocks[block_idx].operations
        print(f"{prefix}  BLOCK {block_idx}:")
        for op_idx in range(len(operations)):
          nested_op = operations[op_idx]
          print(f"{prefix}    OP {op_idx}: {nested_op}")
          print(f"{prefix}    OP {op_idx}: parent {nested_op.operation.parent.name}")
          dump(nested_prefix, nested_op)

  # CHECK: REGION 0:
  # CHECK:   BLOCK 0:
  # CHECK:     OP 0: func
  # CHECK:     OP 0: parent module
  # CHECK:       REGION 0:
  # CHECK:         BLOCK 0:
  # CHECK:           OP 0: %0 = "custom.addi"
  # CHECK:           OP 0: parent func
  # CHECK:           OP 1: return
  # CHECK:           OP 1: parent func
  dump("", module.operation)
108
109run(testTraverseOpRegionBlockIndices)
110
111
112# CHECK-LABEL: TEST: testBlockArgumentList
def testBlockArgumentList():
  """Lists the entry block arguments of a function and retypes them.

  Prints number and type of each argument, replaces the type of argument N
  with a signless integer of width 8 * (N + 1), and prints them again.
  """
  with Context() as context:
    module = Module.parse(r"""
      func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
        return
      }
    """, context)
    func_op = module.body.operations[0]
    entry = func_op.regions[0].blocks[0]
    assert len(entry.arguments) == 3
    # CHECK: Argument 0, type i32
    # CHECK: Argument 1, type f64
    # CHECK: Argument 2, type index
    for argument in entry.arguments:
      print(f"Argument {argument.arg_number}, type {argument.type}")
      bit_width = 8 * (argument.arg_number + 1)
      argument.set_type(IntegerType.get_signless(bit_width))

    # CHECK: Argument 0, type i8
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    for argument in entry.arguments:
      print(f"Argument {argument.arg_number}, type {argument.type}")
136
137
138run(testBlockArgumentList)
139
140
141# CHECK-LABEL: TEST: testOperationOperands
def testOperationOperands():
  """Prints the index and type of each operand of a two-operand operation."""
  with Context() as context:
    context.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1(%arg0: i32) {
        %0 = "test.producer"() : () -> i64
        "test.consumer"(%arg0, %0) : (i32, i64) -> ()
        return
      }""")
    func_op = module.body.operations[0]
    body_ops = func_op.regions[0].blocks[0].operations
    # The consumer is the second operation of the function body.
    consumer_op = body_ops[1]
    assert len(consumer_op.operands) == 2
    # CHECK: Operand 0, type i32
    # CHECK: Operand 1, type i64
    for position, value in enumerate(consumer_op.operands):
      print(f"Operand {position}, type {value.type}")
159
160
161run(testOperationOperands)
162
163
164# CHECK-LABEL: TEST: testOperationOperandsSlice
def testOperationOperandsSlice():
  """Exercises Python slicing on the operand list of an operation.

  Covers a full slice, prefix and suffix slices, a strided slice, a slice of
  a slice, and a double reversal that must reproduce the original sequence.
  """
  def dump(values):
    # Print each value of the given operand (sub)list on its own line.
    for value in values:
      print(value)

  with Context() as context:
    context.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        %3 = "test.producer3"() : () -> i64
        %4 = "test.producer4"() : () -> i64
        "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
        return
      }""")
    func_op = module.body.operations[0]
    entry = func_op.regions[0].blocks[0]
    # The consumer follows the five producers.
    consumer = entry.operations[5]
    assert len(consumer.operands) == 5
    # Reversing twice must yield the operands in their original order.
    twice_reversed = consumer.operands[::-1][::-1]
    assert all(
        left == right for left, right in zip(consumer.operands, twice_reversed))

    # CHECK: test.producer0
    # CHECK: test.producer1
    # CHECK: test.producer2
    # CHECK: test.producer3
    # CHECK: test.producer4
    dump(consumer.operands[:])

    # CHECK: test.producer0
    # CHECK: test.producer1
    dump(consumer.operands[0:2])

    # CHECK: test.producer3
    # CHECK: test.producer4
    dump(consumer.operands[3:])

    # CHECK: test.producer0
    # CHECK: test.producer2
    # CHECK: test.producer4
    dump(consumer.operands[::2])

    # Slicing an already strided slice.
    # CHECK: test.producer2
    dump(consumer.operands[::2][1::2])
217
218
219run(testOperationOperandsSlice)
220
221
222# CHECK-LABEL: TEST: testOperationOperandsSet
def testOperationOperandsSet():
  """Replaces an operand of an operation through item assignment.

  The single operand of the consumer is overwritten first via index 0 and
  then via index -1 (which addresses the same, only, operand); the operand is
  printed after each assignment.
  """
  with Context() as ctx, Location.unknown(ctx):
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        "test.consumer"(%0) : (i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    producer1 = entry_block.operations[1]
    producer2 = entry_block.operations[2]
    consumer = entry_block.operations[3]
    assert len(consumer.operands) == 1

    # CHECK: test.producer1
    consumer.operands[0] = producer1.result
    print(consumer.operands[0])

    # A negative index counts from the end of the operand list.
    # CHECK: test.producer2
    consumer.operands[-1] = producer2.result
    print(consumer.operands[0])
249
250
251run(testOperationOperandsSet)
252
253
254# CHECK-LABEL: TEST: testDetachedOperation
def testDetachedOperation():
  """Creates an operation that is not attached to any block and prints it.

  The operation has two `si32` results, one region and two string attributes.
  """
  context = Context()
  context.allow_unregistered_dialects = True
  with Location.unknown(context):
    si32 = IntegerType.get_signed(32)
    op_attributes = {
        "foo": StringAttr.get("foo_value"),
        "bar": StringAttr.get("bar_value"),
    }
    detached = Operation.create(
        "custom.op1",
        results=[si32, si32],
        regions=1,
        attributes=op_attributes)
    # CHECK: %0:2 = "custom.op1"() ( {
    # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
    print(detached)

  # TODO: Check successors once enough infra exists to do it properly.
270
271run(testDetachedOperation)
272
273
274# CHECK-LABEL: TEST: testOperationInsertionPoint
def testOperationInsertionPoint():
  """Inserts detached operations at the beginning of a block.

  Two new operations are inserted through one insertion point and the module
  is printed; inserting an already attached operation a second time is
  expected to raise `ValueError`.
  """
  context = Context()
  context.allow_unregistered_dialects = True
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, context)

  # Create the test ops and insert them into the function's entry block.
  with Location.unknown(context):
    first = Operation.create("custom.op1")
    second = Operation.create("custom.op2")

    func_op = module.body.operations[0]
    entry = func_op.regions[0].blocks[0]
    ip = InsertionPoint.at_block_begin(entry)
    for new_op in (first, second):
      ip.insert(new_op)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK: "custom.op2"()
    # CHECK: %0 = "custom.addi"
    print(module)

  # Trying to add a previously added op should raise.
  reinsert_raised = False
  try:
    ip.insert(first)
  except ValueError:
    reinsert_raised = True
  assert reinsert_raised, "expected insert of attached op to raise"
308
309run(testOperationInsertionPoint)
310
311
312# CHECK-LABEL: TEST: testOperationWithRegion
def testOperationWithRegion():
  """Builds an operation with a populated region and moves it into a module.

  A block with two `si32` arguments and a terminator is appended to the
  region of a detached operation; the operation is printed, inserted at the
  start of a parsed function body, and the resulting module is printed.
  """
  context = Context()
  context.allow_unregistered_dialects = True
  with Location.unknown(context):
    si32 = IntegerType.get_signed(32)
    region_op = Operation.create("custom.op1", regions=1)
    body = region_op.regions[0].blocks.append(si32, si32)
    # CHECK: "custom.op1"() ( {
    # CHECK: ^bb0(%arg0: si32, %arg1: si32):  // no predecessors
    # CHECK:   "custom.terminator"() : () -> ()
    # CHECK: }) : () -> ()
    terminator = Operation.create("custom.terminator")
    InsertionPoint(body).insert(terminator)
    print(region_op)

    # Now add the whole operation to another op.
    # TODO: Verify lifetime hazard by nulling out the new owning module and
    # accessing op1.
    # TODO: Also verify accessing the terminator once both parents are nulled
    # out.
    module = Module.parse(r"""
      func @f1(%arg0: i32) -> i32 {
        %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
        return %1 : i32
      }
    """)
    func_op = module.body.operations[0]
    entry = func_op.regions[0].blocks[0]
    InsertionPoint.at_block_begin(entry).insert(region_op)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK:   "custom.terminator"
    # CHECK: %0 = "custom.addi"
    print(module)
349
350run(testOperationWithRegion)
351
352
353# CHECK-LABEL: TEST: testOperationResultList
def testOperationResultList():
  """Prints the number and type of each result of a three-result call op."""
  context = Context()
  module = Module.parse(r"""
    func @f1() {
      %0:3 = call @f2() : () -> (i32, f64, index)
      return
    }
    func private @f2() -> (i32, f64, index)
  """, context)
  caller_func = module.body.operations[0]
  # The call is the first operation in the body of the calling function.
  call_op = caller_func.regions[0].blocks[0].operations[0]
  assert len(call_op.results) == 3
  # CHECK: Result 0, type i32
  # CHECK: Result 1, type f64
  # CHECK: Result 2, type index
  for result in call_op.results:
    print(f"Result {result.result_number}, type {result.type}")
371
372
373run(testOperationResultList)
374
375
376# CHECK-LABEL: TEST: testOperationResultListSlice
def testOperationResultListSlice():
  """Exercises Python slicing on the result list of an operation.

  Covers a full slice, a middle range, a strided slice, a slice with a
  negative step, and a double reversal that must reproduce the original
  results including their result numbers.
  """
  def dump(results):
    # Print the number and type of every result in the given (sub)list.
    for result in results:
      print(f"Result {result.result_number}, type {result.type}")

  with Context() as context:
    context.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1() {
        "some.op"() : () -> (i1, i2, i3, i4, i5)
        return
      }
    """)
    func_op = module.body.operations[0]
    entry = func_op.regions[0].blocks[0]
    producer = entry.operations[0]

    assert len(producer.results) == 5
    # Reversing twice must yield the results in their original order.
    twice_reversed = producer.results[::-1][::-1]
    for expected, actual in zip(producer.results, twice_reversed):
      assert expected == actual
      assert expected.result_number == actual.result_number

    # CHECK: Result 0, type i1
    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    # CHECK: Result 4, type i5
    dump(producer.results[:])

    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    dump(producer.results[1:4])

    # CHECK: Result 1, type i2
    # CHECK: Result 3, type i4
    dump(producer.results[1::2])

    # A negative step walks the results backwards.
    # CHECK: Result 3, type i4
    # CHECK: Result 1, type i2
    dump(producer.results[-2:0:-2])
422
423
424run(testOperationResultListSlice)
425
426
427# CHECK-LABEL: TEST: testOperationAttributes
def testOperationAttributes():
  """Reads the attributes of an operation by name and by iteration.

  Also checks that looking up an unknown name raises `KeyError` and that an
  out-of-range integer index raises `IndexError`.
  """
  context = Context()
  context.allow_unregistered_dialects = True
  module = Module.parse(r"""
    "some.op"() { some.attribute = 1 : i8,
                  other.attribute = 3.0,
                  dependent = "text" } : () -> ()
  """, context)
  some_op = module.body.operations[0]
  attributes = some_op.attributes
  assert len(attributes) == 3
  int_attr = IntegerAttr(attributes["some.attribute"])
  float_attr = FloatAttr(attributes["other.attribute"])
  string_attr = StringAttr(attributes["dependent"])
  # CHECK: Attribute type i8, value 1
  print(f"Attribute type {int_attr.type}, value {int_attr.value}")
  # CHECK: Attribute type f64, value 3.0
  print(f"Attribute type {float_attr.type}, value {float_attr.value}")
  # CHECK: Attribute value text
  print(f"Attribute value {string_attr.value}")

  # We don't know in which order the attributes are stored.
  # CHECK-DAG: NamedAttribute(dependent="text")
  # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
  # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
  for named_attr in attributes:
    print(str(named_attr))

  # Check that exceptions are raised as expected.
  missing_key_raised = False
  try:
    attributes["does_not_exist"]
  except KeyError:
    missing_key_raised = True
  assert missing_key_raised, "expected KeyError on accessing a non-existent attribute"

  bad_index_raised = False
  try:
    attributes[42]
  except IndexError:
    bad_index_raised = True
  assert bad_index_raised, "expected IndexError on accessing an out-of-bounds attribute"
469
470
471run(testOperationAttributes)
472
473
474# CHECK-LABEL: TEST: testOperationPrint
def testOperationPrint():
  """Prints an operation to stdout, a text buffer and a binary buffer.

  Finishes by printing to stdout with several printing options enabled.
  """
  context = Context()
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %0 = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
      return %arg0 : i32
    }
  """, context)
  root = module.operation

  # Test print to stdout.
  # CHECK: return %arg0 : i32
  root.print()

  # Test print to text file.
  text_buffer = io.StringIO()
  # CHECK: <class 'str'>
  # CHECK: return %arg0 : i32
  root.print(file=text_buffer)
  text_value = text_buffer.getvalue()
  print(text_value.__class__)
  print(text_value)

  # Test print to binary file.
  binary_buffer = io.BytesIO()
  # CHECK: <class 'bytes'>
  # CHECK: return %arg0 : i32
  root.print(file=binary_buffer, binary=True)
  binary_value = binary_buffer.getvalue()
  print(binary_value.__class__)
  print(binary_value)

  # Test print with options.
  # CHECK: value = opaque<"_", "0xDEADBEEF"> : tensor<4xi32>
  # CHECK: "std.return"(%arg0) : (i32) -> () -:4:7
  root.print(
      large_elements_limit=2,
      enable_debug_info=True,
      pretty_debug_info=True,
      print_generic_op_form=True,
      use_local_scope=True)
511
512run(testOperationPrint)
513
514
515# CHECK-LABEL: TEST: testKnownOpView
def testKnownOpView():
  """Checks which view class is returned for known and unknown operations.

  `addf` is expected to come back as its generated std-dialect class, while
  an unregistered custom operation comes back as the generic `OpView`.
  """
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      %1 = "custom.f32"() : () -> f32
      %2 = "custom.f32"() : () -> f32
      %3 = addf %1, %2 : f32
    """)
    print(module)
    body_ops = module.body.operations

    # addf should map to a known OpView class in the std dialect.
    # We know the OpView for it defines an 'lhs' attribute.
    addf_view = body_ops[2]
    # CHECK: <mlir.dialects._std_ops_gen._AddFOp object
    print(repr(addf_view))
    # CHECK: "custom.f32"()
    print(addf_view.lhs)

    # One of the custom ops should resolve to the default OpView.
    custom_view = body_ops[0]
    # CHECK: <_mlir.ir.OpView object
    print(repr(custom_view))

    # Check again to make sure negative caching works.
    custom_view = body_ops[0]
    # CHECK: <_mlir.ir.OpView object
    print(repr(custom_view))
543
544run(testKnownOpView)
545
546
547# CHECK-LABEL: TEST: testSingleResultProperty
def testSingleResultProperty():
  """Checks that `.result` raises for operations without exactly one result.

  Accessing `.result` on the zero-result and on the two-result operation is
  expected to raise `ValueError`; the raised messages are printed. The
  one-result operation is printed as is.
  """
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      "custom.no_result"() : () -> ()
      %0:2 = "custom.two_result"() : () -> (f32, f32)
      %1 = "custom.one_result"() : () -> f32
    """)
    print(module)

  def result_error(index):
    # Return the message of the ValueError raised by `.result` on the
    # operation at `index`; fail if nothing is raised.
    try:
      module.body.operations[index].result
    except ValueError as e:
      return str(e)
    assert False, "Expected exception"

  # CHECK: Cannot call .result on operation custom.no_result which has 0 results
  print(result_error(0))

  # CHECK: Cannot call .result on operation custom.two_result which has 2 results
  print(result_error(1))

  # CHECK: %1 = "custom.one_result"() : () -> f32
  print(module.body.operations[2])
576
577run(testSingleResultProperty)
578
579# CHECK-LABEL: TEST: testPrintInvalidOperation
def testPrintInvalidOperation():
  """Prints an operation that fails verification.

  A "module" operation is created with two regions and printed; the expected
  output is the generic form preceded by a verification-failure note, and
  `verify()` is expected to report False.
  """
  ctx = Context()
  with Location.unknown(ctx):
    module = Operation.create("module", regions=2)
    # This module has two regions and is therefore invalid; verify that we
    # fall back to the generic printer for safety.
    module.regions[0].blocks.append()
    # CHECK: // Verification failed, printing generic form
    # CHECK: "module"() ( {
    # CHECK: }) : () -> ()
    print(module)
    # CHECK: .verify = False
    print(f".verify = {module.operation.verify()}")
593run(testPrintInvalidOperation)
594
595
596# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
def testCreateWithInvalidAttributes():
  """Checks the errors raised for malformed `attributes` dictionaries.

  Covers a `None` key, an integer key, a value that is not an attribute and a
  `None` value. Each raised message is printed; if a call unexpectedly
  succeeds the test fails right away with an assertion naming the case.
  """
  ctx = Context()
  with Location.unknown(ctx):
    try:
      Operation.create("module", attributes={None:StringAttr.get("name")})
    except Exception as e:
      # CHECK: Invalid attribute key (not a string) when attempting to create the operation "module"
      print(e)
    else:
      assert False, "expected a None attribute key to raise"
    try:
      Operation.create("module", attributes={42:StringAttr.get("name")})
    except Exception as e:
      # CHECK: Invalid attribute key (not a string) when attempting to create the operation "module"
      print(e)
    else:
      assert False, "expected an integer attribute key to raise"
    try:
      Operation.create("module", attributes={"some_key":ctx})
    except Exception as e:
      # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "module"
      print(e)
    else:
      assert False, "expected a non-attribute value to raise"
    try:
      Operation.create("module", attributes={"some_key":None})
    except Exception as e:
      # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "module"
      print(e)
    else:
      assert False, "expected a None attribute value to raise"
620run(testCreateWithInvalidAttributes)
621
622
623# CHECK-LABEL: TEST: testOperationName
def testOperationName():
  """Prints the name of every top-level operation of a parsed module."""
  context = Context()
  context.allow_unregistered_dialects = True
  module = Module.parse(r"""
    %0 = "custom.op1"() : () -> f32
    %1 = "custom.op2"() : () -> i32
    %2 = "custom.op1"() : () -> f32
  """, context)

  # CHECK: custom.op1
  # CHECK: custom.op2
  # CHECK: custom.op1
  for op_view in module.body.operations:
    op_name = op_view.operation.name
    print(op_name)
638
639run(testOperationName)
640
641# CHECK-LABEL: TEST: testCapsuleConversions
def testCapsuleConversions():
  """Round-trips an operation through its C API capsule.

  The capsule's repr must mention the operation capsule name, and recreating
  an operation from the capsule must return the very same Python object.
  """
  context = Context()
  context.allow_unregistered_dialects = True
  with Location.unknown(context):
    original = Operation.create("custom.op1").operation
    capsule = original._CAPIPtr
    capsule_repr = repr(capsule)
    assert '"mlir.ir.Operation._CAPIPtr"' in capsule_repr
    round_tripped = Operation._CAPICreate(capsule)
    assert round_tripped is original
651
652run(testCapsuleConversions)
653
654# CHECK-LABEL: TEST: testOperationErase
def testOperationErase():
  """Erases an operation from a module and creates another one afterwards.

  The module is printed before and after the erase; the second print is
  expected to no longer contain the erased operation.
  """
  context = Context()
  context.allow_unregistered_dialects = True
  with Location.unknown(context):
    module = Module.create()
    with InsertionPoint(module.body):
      doomed = Operation.create("custom.op1")

      # CHECK: "custom.op1"
      print(module)

      doomed.operation.erase()

      # CHECK-NOT: "custom.op1"
      print(module)

      # Ensure we can create another operation
      Operation.create("custom.op2")
673
674run(testOperationErase)
675