1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4import io
5import itertools
6from mlir.ir import *
7
def run(f):
  """Run one test function and verify it released every MLIR Context.

  Prints a "TEST: <name>" label so each test's output can be located, calls
  `f`, then forces a garbage collection before asserting that no Context
  objects are still alive.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  gc.collect()
  remaining = Context._get_live_count()
  assert remaining == 0
13
14
# Verify iterator-based traversal of the op/region/block hierarchy.
16# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
def testTraverseOpRegionBlockIterators():
  """Walk op -> region -> block -> op using only Python iteration."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)
  top = module.operation
  assert top.context is ctx

  # Reach the entry block through the explicitly named collections.
  named_regions = [region for region in top.regions]
  named_blocks = [block for block in named_regions[0].blocks]
  # CHECK: MODULE REGIONS=1 BLOCKS=1
  print(f"MODULE REGIONS={len(named_regions)} BLOCKS={len(named_blocks)}")

  # The freshly parsed module is expected to pass the verifier.
  # CHECK: .verify = True
  print(f".verify = {module.operation.verify()}")

  # Iterating an op directly yields its regions, and iterating a region yields
  # its blocks; either route must produce equal objects.
  implicit_regions = list(top)
  implicit_blocks = list(implicit_regions[0])
  assert implicit_regions == named_regions
  assert implicit_blocks == named_blocks

  # A block can likewise be iterated directly or through `.operations`.
  entry = named_blocks[0]
  via_named = list(entry.operations)
  via_iter = list(entry)
  assert via_iter == via_named

  def dump_hierarchy(prefix, parent):
    """Recursively print every region, block and op nested under parent."""
    for region_idx, region in enumerate(parent):
      print(f"{prefix}REGION {region_idx}:")
      for block_idx, block in enumerate(region):
        print(f"{prefix}  BLOCK {block_idx}:")
        for op_idx, nested in enumerate(block):
          print(f"{prefix}    OP {op_idx}: {nested}")
          dump_hierarchy(prefix + "      ", nested)

  # CHECK: REGION 0:
  # CHECK:   BLOCK 0:
  # CHECK:     OP 0: func
  # CHECK:       REGION 0:
  # CHECK:         BLOCK 0:
  # CHECK:           OP 0: %0 = "custom.addi"
  # CHECK:           OP 1: return
  dump_hierarchy("", top)

run(testTraverseOpRegionBlockIterators)
70
71
# Verify index-based traversal of the op/region/block hierarchy.
73# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
def testTraverseOpRegionBlockIndices():
  """Walk op -> region -> block -> op using len() and integer indexing."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)

  def dump_by_index(prefix, parent):
    """Print parent's nested IR, addressing every child by position."""
    region_count = len(parent.regions)
    for r in range(region_count):
      region = parent.regions[r]
      print(f"{prefix}REGION {r}:")
      block_count = len(region.blocks)
      for b in range(block_count):
        block = region.blocks[b]
        print(f"{prefix}  BLOCK {b}:")
        op_count = len(block.operations)
        for o in range(op_count):
          nested = block.operations[o]
          print(f"{prefix}    OP {o}: {nested}")
          # Also report the name of the op that owns this one.
          print(f"{prefix}    OP {o}: parent {nested.operation.parent.name}")
          dump_by_index(prefix + "      ", nested)

  # CHECK: REGION 0:
  # CHECK:   BLOCK 0:
  # CHECK:     OP 0: func
  # CHECK:     OP 0: parent builtin.module
  # CHECK:       REGION 0:
  # CHECK:         BLOCK 0:
  # CHECK:           OP 0: %0 = "custom.addi"
  # CHECK:           OP 0: parent builtin.func
  # CHECK:           OP 1: return
  # CHECK:           OP 1: parent builtin.func
  dump_by_index("", module.operation)

run(testTraverseOpRegionBlockIndices)
110
111
112# CHECK-LABEL: TEST: testBlockArgumentList
def testBlockArgumentList():
  """Read, retype, slice and concatenate a block's argument list."""
  with Context() as ctx:
    module = Module.parse(r"""
      func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
        return
      }
    """, ctx)
    entry_block = module.body.operations[0].regions[0].blocks[0]
    assert len(entry_block.arguments) == 3
    # CHECK: Argument 0, type i32
    # CHECK: Argument 1, type f64
    # CHECK: Argument 2, type index
    for argument in entry_block.arguments:
      position = argument.arg_number
      print(f"Argument {position}, type {argument.type}")
      # Retype argument N to a signless integer of width 8 * (N + 1).
      argument.set_type(IntegerType.get_signless(8 * (position + 1)))

    # CHECK: Argument 0, type i8
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    for argument in entry_block.arguments:
      print(f"Argument {argument.arg_number}, type {argument.type}")

    # Slices of the argument list iterate just like the full list.
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    for argument in entry_block.arguments[1:]:
      print(f"Argument {argument.arg_number}, type {argument.type}")

    # Slices support `+`; the overlapping argument is kept, so 2 + 2 remain.
    # CHECK: Length: 4
    head = entry_block.arguments[:2]
    tail = entry_block.arguments[1:]
    print("Length: ", len(head + tail))


run(testBlockArgumentList)
150
151
152# CHECK-LABEL: TEST: testOperationOperands
def testOperationOperands():
  """Iterate an op's operand list and report each operand's type."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1(%arg0: i32) {
        %0 = "test.producer"() : () -> i64
        "test.consumer"(%arg0, %0) : (i32, i64) -> ()
        return
      }""")
    body_ops = module.body.operations[0].regions[0].blocks[0].operations
    consumer = body_ops[1]
    operands = consumer.operands
    assert len(operands) == 2
    # CHECK: Operand 0, type i32
    # CHECK: Operand 1, type i64
    for index, value in enumerate(operands):
      print(f"Operand {index}, type {value.type}")


run(testOperationOperands)
173
174
175# CHECK-LABEL: TEST: testOperationOperandsSlice
def testOperationOperandsSlice():
  """Operand lists support Python slicing, including steps and chaining."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        %3 = "test.producer3"() : () -> i64
        %4 = "test.producer4"() : () -> i64
        "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
        return
      }""")
    entry_block = module.body.operations[0].regions[0].blocks[0]
    consumer = entry_block.operations[5]
    assert len(consumer.operands) == 5

    # Reversing twice must restore the original element order.
    round_trip = consumer.operands[::-1][::-1]
    for expected, actual in zip(consumer.operands, round_trip):
      assert expected == actual

    def print_each(values):
      """Print every value of an operand sequence on its own line."""
      for value in values:
        print(value)

    # CHECK: test.producer0
    # CHECK: test.producer1
    # CHECK: test.producer2
    # CHECK: test.producer3
    # CHECK: test.producer4
    print_each(consumer.operands[:])

    # CHECK: test.producer0
    # CHECK: test.producer1
    print_each(consumer.operands[0:2])

    # CHECK: test.producer3
    # CHECK: test.producer4
    print_each(consumer.operands[3:])

    # CHECK: test.producer0
    # CHECK: test.producer2
    # CHECK: test.producer4
    print_each(consumer.operands[::2])

    # Slicing a slice composes: positions [0, 2, 4], then [1::2] -> [2].
    # CHECK: test.producer2
    print_each(consumer.operands[::2][1::2])


run(testOperationOperandsSlice)
231
232
233# CHECK-LABEL: TEST: testOperationOperandsSet
def testOperationOperandsSet():
  """Operands can be replaced in place through positive and negative indices.

  The consumer starts with a single operand produced by test.producer0. It is
  rewired first to producer1 (index 0) and then to producer2 (index -1, which
  aliases index 0 for a one-operand list). Each assignment must replace the
  operand rather than grow the operand list.
  """
  with Context() as ctx, Location.unknown(ctx):
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        "test.consumer"(%0) : (i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    producer1 = entry_block.operations[1]
    producer2 = entry_block.operations[2]
    consumer = entry_block.operations[3]
    assert len(consumer.operands) == 1

    # CHECK: test.producer1
    consumer.operands[0] = producer1.result
    # Assignment replaces the operand; it must not append a new one.
    assert len(consumer.operands) == 1
    print(consumer.operands[0])

    # CHECK: test.producer2
    consumer.operands[-1] = producer2.result
    assert len(consumer.operands) == 1
    print(consumer.operands[0])


run(testOperationOperandsSet)
263
264
265# CHECK-LABEL: TEST: testDetachedOperation
def testDetachedOperation():
  """Create an op that has no parent block and print it on its own."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    si32 = IntegerType.get_signed(32)
    op_attributes = {
        "foo": StringAttr.get("foo_value"),
        "bar": StringAttr.get("bar_value"),
    }
    detached = Operation.create(
        "custom.op1",
        results=[si32, si32],
        regions=1,
        attributes=op_attributes,
    )
    # CHECK: %0:2 = "custom.op1"() ( {
    # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
    print(detached)

  # TODO: Check successors once enough infra exists to do it properly.


run(testDetachedOperation)
283
284
285# CHECK-LABEL: TEST: testOperationInsertionPoint
def testOperationInsertionPoint():
  """Insert detached ops at a block's start; re-inserting one must fail."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)

  with Location.unknown(ctx):
    # Two detached ops to be placed at the start of the function body.
    first = Operation.create("custom.op1")
    second = Operation.create("custom.op2")

    entry_block = module.body.operations[0].regions[0].blocks[0]
    ip = InsertionPoint.at_block_begin(entry_block)
    # Inserting through one point keeps the ops in insertion order.
    for detached in (first, second):
      ip.insert(detached)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK: "custom.op2"()
    # CHECK: %0 = "custom.addi"
    print(module)

  # An op that was already inserted cannot be inserted again.
  raised = False
  try:
    ip.insert(first)
  except ValueError:
    raised = True
  assert raised, "expected insert of attached op to raise"


run(testOperationInsertionPoint)
321
322
323# CHECK-LABEL: TEST: testOperationWithRegion
def testOperationWithRegion():
  """Build an op with a populated region, then splice it into a module."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    si32 = IntegerType.get_signed(32)
    outer = Operation.create("custom.op1", regions=1)
    # The new block takes two si32 arguments.
    body = outer.regions[0].blocks.append(si32, si32)
    # CHECK: "custom.op1"() ( {
    # CHECK: ^bb0(%arg0: si32, %arg1: si32):  // no predecessors
    # CHECK:   "custom.terminator"() : () -> ()
    # CHECK: }) : () -> ()
    terminator = Operation.create("custom.terminator")
    InsertionPoint(body).insert(terminator)
    print(outer)

    # Move the whole op into a region owned by another op.
    # TODO: Verify lifetime hazard by nulling out the new owning module and
    # accessing `outer`.
    # TODO: Also verify accessing the terminator once both parents are nulled
    # out.
    module = Module.parse(r"""
      func @f1(%arg0: i32) -> i32 {
        %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
        return %1 : i32
      }
    """)
    entry_block = module.body.operations[0].regions[0].blocks[0]
    InsertionPoint.at_block_begin(entry_block).insert(outer)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK:   "custom.terminator"
    # CHECK: %0 = "custom.addi"
    print(module)


run(testOperationWithRegion)
362
363
364# CHECK-LABEL: TEST: testOperationResultList
def testOperationResultList():
  """Each result of a multi-result call reports its number and type."""
  ctx = Context()
  module = Module.parse(r"""
    func @f1() {
      %0:3 = call @f2() : () -> (i32, f64, index)
      return
    }
    func private @f2() -> (i32, f64, index)
  """, ctx)
  caller_entry = module.body.operations[0].regions[0].blocks[0]
  call_results = caller_entry.operations[0].results
  assert len(call_results) == 3
  # CHECK: Result 0, type i32
  # CHECK: Result 1, type f64
  # CHECK: Result 2, type index
  for value in call_results:
    print(f"Result {value.result_number}, type {value.type}")


run(testOperationResultList)
385
386
387# CHECK-LABEL: TEST: testOperationResultListSlice
def testOperationResultListSlice():
  """Result lists support slicing with starts, stops and negative steps."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func @f1() {
        "some.op"() : () -> (i1, i2, i3, i4, i5)
        return
      }
    """)
    producer = module.body.operations[0].regions[0].blocks[0].operations[0]

    assert len(producer.results) == 5
    # Reversing twice must restore both identity and numbering.
    round_trip = producer.results[::-1][::-1]
    for expected, actual in zip(producer.results, round_trip):
      assert expected == actual
      assert expected.result_number == actual.result_number

    def describe(results):
      """Print the number and type of every result in the sequence."""
      for res in results:
        print(f"Result {res.result_number}, type {res.type}")

    # CHECK: Result 0, type i1
    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    # CHECK: Result 4, type i5
    describe(producer.results[:])

    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    describe(producer.results[1:4])

    # CHECK: Result 1, type i2
    # CHECK: Result 3, type i4
    describe(producer.results[1::2])

    # A negative start with a negative step walks backwards: 3, then 1.
    # CHECK: Result 3, type i4
    # CHECK: Result 1, type i2
    describe(producer.results[-2:0:-2])


run(testOperationResultListSlice)
436
437
438# CHECK-LABEL: TEST: testOperationAttributes
def testOperationAttributes():
  """Attributes can be looked up by name, iterated, and reject bad keys."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    "some.op"() { some.attribute = 1 : i8,
                  other.attribute = 3.0,
                  dependent = "text" } : () -> ()
  """, ctx)
  attrs = module.body.operations[0].attributes
  assert len(attrs) == 3

  # Wrap the generic attributes in their concrete attribute classes.
  int_attr = IntegerAttr(attrs["some.attribute"])
  float_attr = FloatAttr(attrs["other.attribute"])
  str_attr = StringAttr(attrs["dependent"])
  # CHECK: Attribute type i8, value 1
  print(f"Attribute type {int_attr.type}, value {int_attr.value}")
  # CHECK: Attribute type f64, value 3.0
  print(f"Attribute type {float_attr.type}, value {float_attr.value}")
  # CHECK: Attribute value text
  print(f"Attribute value {str_attr.value}")

  # Iteration order is unspecified, hence the order-independent matches.
  # CHECK-DAG: NamedAttribute(dependent="text")
  # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
  # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
  for named in attrs:
    print(str(named))

  # A missing string key raises KeyError.
  key_error_raised = False
  try:
    attrs["does_not_exist"]
  except KeyError:
    key_error_raised = True
  assert key_error_raised, "expected KeyError on accessing a non-existent attribute"

  # An out-of-range integer index raises IndexError.
  index_error_raised = False
  try:
    attrs[42]
  except IndexError:
    index_error_raised = True
  assert index_error_raised, "expected IndexError on accessing an out-of-bounds attribute"


run(testOperationAttributes)
483
484
485# CHECK-LABEL: TEST: testOperationPrint
def testOperationPrint():
  """Print an op to stdout, text and binary files, and with options."""
  ctx = Context()
  module = Module.parse(r"""
    func @f1(%arg0: i32) -> i32 {
      %0 = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
      return %arg0 : i32
    }
  """, ctx)
  top = module.operation

  # The default destination is stdout.
  # CHECK: return %arg0 : i32
  top.print()

  # A text-mode file object receives a str.
  text_buffer = io.StringIO()
  # CHECK: <class 'str'>
  # CHECK: return %arg0 : i32
  top.print(file=text_buffer)
  text = text_buffer.getvalue()
  print(type(text))
  print(text)

  # With binary=True the output is written as bytes.
  binary_buffer = io.BytesIO()
  # CHECK: <class 'bytes'>
  # CHECK: return %arg0 : i32
  top.print(file=binary_buffer, binary=True)
  data = binary_buffer.getvalue()
  print(type(data))
  print(data)

  # Printing options: elide the large constant, print the generic op form
  # and include pretty-printed debug locations.
  # CHECK: value = opaque<"_", "0xDEADBEEF"> : tensor<4xi32>
  # CHECK: "std.return"(%arg0) : (i32) -> () -:4:7
  top.print(
      large_elements_limit=2,
      enable_debug_info=True,
      pretty_debug_info=True,
      print_generic_op_form=True,
      use_local_scope=True,
  )

run(testOperationPrint)
524
525
526# CHECK-LABEL: TEST: testKnownOpView
def testKnownOpView():
  """A known op resolves to its generated OpView; unknown ops do not."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      %1 = "custom.f32"() : () -> f32
      %2 = "custom.f32"() : () -> f32
      %3 = addf %1, %2 : f32
    """)
    print(module)

    # addf belongs to the std dialect, so indexing should return its
    # generated OpView subclass, which offers an `lhs` accessor.
    known = module.body.operations[2]
    # CHECK: <mlir.dialects._std_ops_gen._AddFOp object
    print(repr(known))
    # CHECK: "custom.f32"()
    print(known.lhs)

    # An unregistered custom op falls back to the plain OpView.
    unknown = module.body.operations[0]
    # CHECK: OpView object
    print(repr(unknown))

    # Look it up a second time to exercise the negative-result caching.
    unknown = module.body.operations[0]
    # CHECK: OpView object
    print(repr(unknown))

run(testKnownOpView)
556
557
558# CHECK-LABEL: TEST: testSingleResultProperty
def testSingleResultProperty():
  """Accessing `.result` on ops with zero or several results raises."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      "custom.no_result"() : () -> ()
      %0:2 = "custom.two_result"() : () -> (f32, f32)
      %1 = "custom.one_result"() : () -> f32
    """)
    print(module)

  def print_result_error(op):
    """Print the ValueError raised by `op.result`; fail if none is raised."""
    try:
      op.result
    except ValueError as e:
      print(e)
    else:
      assert False, "Expected exception"

  # CHECK: Cannot call .result on operation custom.no_result which has 0 results
  print_result_error(module.body.operations[0])

  # CHECK: Cannot call .result on operation custom.two_result which has 2 results
  print_result_error(module.body.operations[1])

  # CHECK: %1 = "custom.one_result"() : () -> f32
  print(module.body.operations[2])

run(testSingleResultProperty)
589
590# CHECK-LABEL: TEST: testPrintInvalidOperation
def testPrintInvalidOperation():
  """Printing an op that fails verification falls back to the generic form."""
  ctx = Context()
  with Location.unknown(ctx):
    # This module is created with two regions, which makes it invalid. Verify
    # that the printer falls back to the generic form for safety.
    module = Operation.create("builtin.module", regions=2)
    # Give the first region a block; the returned Block handle is not needed.
    module.regions[0].blocks.append()
    # CHECK: // Verification failed, printing generic form
    # CHECK: "builtin.module"() ( {
    # CHECK: }) : () -> ()
    print(module)
    # CHECK: .verify = False
    print(f".verify = {module.operation.verify()}")


run(testPrintInvalidOperation)
605
606
607# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
def testCreateWithInvalidAttributes():
  """Operation.create rejects malformed attribute dictionaries.

  Each case prints the raised error message. If a case unexpectedly
  succeeds, it fails immediately with an assertion. Otherwise the failure
  would only surface later as a confusing output mismatch, because the
  expected print would simply be missing.
  """
  ctx = Context()

  def expect_create_error(attributes):
    """Try to create a builtin.module with `attributes`; print the error."""
    try:
      Operation.create("builtin.module", attributes=attributes)
    except Exception as e:
      print(e)
    else:
      assert False, "expected Operation.create to reject the attributes"

  with Location.unknown(ctx):
    # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
    expect_create_error({None: StringAttr.get("name")})
    # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
    expect_create_error({42: StringAttr.get("name")})
    # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
    expect_create_error({"some_key": ctx})
    # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
    expect_create_error({"some_key": None})


run(testCreateWithInvalidAttributes)
634
635
636# CHECK-LABEL: TEST: testOperationName
def testOperationName():
  """Every op exposes its name (e.g. custom.op1) via `.operation.name`."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(r"""
    %0 = "custom.op1"() : () -> f32
    %1 = "custom.op2"() : () -> i32
    %2 = "custom.op1"() : () -> f32
  """, ctx)

  # CHECK: custom.op1
  # CHECK: custom.op2
  # CHECK: custom.op1
  names = [top_op.operation.name for top_op in module.body.operations]
  for name in names:
    print(name)

run(testOperationName)
653
654# CHECK-LABEL: TEST: testCapsuleConversions
def testCapsuleConversions():
  """An Operation round-trips through its C API capsule to the same object."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    original = Operation.create("custom.op1").operation
    capsule = original._CAPIPtr
    assert '"mlir.ir.Operation._CAPIPtr"' in repr(capsule)
    # Rebuilding from the capsule must hand back the identical Python object.
    restored = Operation._CAPICreate(capsule)
    assert restored is original

run(testCapsuleConversions)
666
667# CHECK-LABEL: TEST: testOperationErase
def testOperationErase():
  """Erasing an op removes it from the module and leaves the IR usable."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    container = Module.create()
    with InsertionPoint(container.body):
      # Created under an active insertion point, so it lands in the body.
      victim = Operation.create("custom.op1")

      # CHECK: "custom.op1"
      print(container)

      victim.operation.erase()

      # CHECK-NOT: "custom.op1"
      print(container)

      # The insertion point must still accept new ops after the erase.
      Operation.create("custom.op2")

run(testOperationErase)
688