1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4import io
5import itertools
6from mlir.ir import *
7
8
def run(f):
  """Executes the test function `f` and checks that no Context stays alive.

  A "TEST: <name>" banner is printed before `f` is called. After the call a
  garbage collection is forced and `Context._get_live_count()` must be zero.
  `f` is returned unchanged, so this helper can be applied as a decorator.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  # Force a collection before counting the contexts that are still alive.
  gc.collect()
  live_contexts = Context._get_live_count()
  assert live_contexts == 0
  return f
15
16
def expect_index_error(callback):
  """Calls `callback` and requires it to raise IndexError.

  Args:
    callback: zero-argument callable; its return value is ignored.

  Raises:
    RuntimeError: if `callback` returns without raising. Exceptions other
      than IndexError raised by `callback` propagate unchanged.
  """
  try:
    callback()
  except IndexError:
    return
  raise RuntimeError("Expected IndexError")
23
24
25# Verify iterator based traversal of the op/region/block hierarchy.
26# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
@run
def testTraverseOpRegionBlockIterators():
  """Traverses op -> region -> block -> op by iterating each collection."""
  context = Context()
  context.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, context)
  root = module.operation
  assert root.context is context
  # Materialize the regions and blocks by iterating the named collections.
  named_regions = list(root.regions)
  named_blocks = list(named_regions[0].blocks)
  # CHECK: MODULE REGIONS=1 BLOCKS=1
  print(f"MODULE REGIONS={len(named_regions)} BLOCKS={len(named_blocks)}")

  # The parsed module is expected to pass verification.
  # CHECK: .verify = True
  verified = module.operation.verify()
  print(f".verify = {verified}")

  # Fetch them again, this time iterating the region itself for its blocks.
  regions_again = list(root.regions)
  iterated_blocks = list(regions_again[0])
  # Equality must not depend on how the objects were obtained.
  assert regions_again == named_regions
  assert iterated_blocks == named_blocks

  # Operations are reachable both through `.operations` and by iterating the
  # block directly.
  first_block = named_blocks[0]
  named_ops = list(first_block.operations)
  iterated_ops = list(first_block)
  assert iterated_ops == named_ops

  def dump(prefix, parent_op):
    # Recursively prints the nesting below `parent_op`, one level per call.
    for region_idx, region in enumerate(parent_op.regions):
      print(f"{prefix}REGION {region_idx}:")
      for block_idx, block in enumerate(region):
        print(f"{prefix}  BLOCK {block_idx}:")
        for op_idx, nested_op in enumerate(block):
          print(f"{prefix}    OP {op_idx}: {nested_op}")
          dump(prefix + "      ", nested_op)

  # CHECK: REGION 0:
  # CHECK:   BLOCK 0:
  # CHECK:     OP 0: func
  # CHECK:       REGION 0:
  # CHECK:         BLOCK 0:
  # CHECK:           OP 0: %0 = "custom.addi"
  # CHECK:           OP 1: func.return
  dump("", root)
80
81
82# Verify index based traversal of the op/region/block hierarchy.
83# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
@run
def testTraverseOpRegionBlockIndices():
  """Traverses the hierarchy using only len() and integer indexing."""
  context = Context()
  context.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, context)

  def dump(prefix, parent_op):
    # Deliberately index based: this test covers len() and subscripting of
    # the collections rather than their iterators.
    region_count = len(parent_op.regions)
    for region_idx in range(region_count):
      region = parent_op.regions[region_idx]
      print(f"{prefix}REGION {region_idx}:")
      block_count = len(region.blocks)
      for block_idx in range(block_count):
        block = region.blocks[block_idx]
        print(f"{prefix}  BLOCK {block_idx}:")
        op_count = len(block.operations)
        for op_idx in range(op_count):
          nested_op = block.operations[op_idx]
          print(f"{prefix}    OP {op_idx}: {nested_op}")
          parent_name = nested_op.operation.parent.name
          print(f"{prefix}    OP {op_idx}: parent {parent_name}")
          dump(prefix + "      ", nested_op)

  # CHECK: REGION 0:
  # CHECK:   BLOCK 0:
  # CHECK:     OP 0: func
  # CHECK:     OP 0: parent builtin.module
  # CHECK:       REGION 0:
  # CHECK:         BLOCK 0:
  # CHECK:           OP 0: %0 = "custom.addi"
  # CHECK:           OP 0: parent func.func
  # CHECK:           OP 1: func.return
  # CHECK:           OP 1: parent func.func
  dump("", module.operation)
120
121
122# CHECK-LABEL: TEST: testBlockAndRegionOwners
@run
def testBlockAndRegionOwners():
  """Checks that `.owner` of a region and of its block is the enclosing op."""
  context = Context()
  context.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    builtin.module {
      func.func @f() {
        func.return
      }
    }
  """, context)

  # The module's own region and entry block are owned by the module op.
  module_region = module.operation.regions[0]
  assert module_region.owner == module.operation
  assert module_region.blocks[0].owner == module.operation

  # The same holds one level down, for the function op.
  func = module.body.operations[0]
  func_region = func.operation.regions[0]
  assert func_region.owner == func
  assert func_region.blocks[0].owner == func
142
143
144# CHECK-LABEL: TEST: testBlockArgumentList
@run
def testBlockArgumentList():
  """Covers iteration, retyping, slicing and concatenation of block arguments."""
  with Context() as context:
    module = Module.parse(
        r"""
      func.func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
        return
      }
    """, context)
    func = module.body.operations[0]
    entry = func.regions[0].blocks[0]
    assert len(entry.arguments) == 3
    # Print each argument as parsed, then retype it to a signless integer
    # whose width is 8 * (argument number + 1).
    # CHECK: Argument 0, type i32
    # CHECK: Argument 1, type f64
    # CHECK: Argument 2, type index
    for argument in entry.arguments:
      print(f"Argument {argument.arg_number}, type {argument.type}")
      width = 8 * (argument.arg_number + 1)
      argument.set_type(IntegerType.get_signless(width))

    # The new types are visible on a fresh traversal.
    # CHECK: Argument 0, type i8
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    for argument in entry.arguments:
      print(f"Argument {argument.arg_number}, type {argument.type}")

    # Check that slicing works for block argument lists.
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    tail = entry.arguments[1:]
    for argument in tail:
      print(f"Argument {argument.arg_number}, type {argument.type}")

    # Check that we can concatenate slices of argument lists.
    # CHECK: Length: 4
    concatenated = entry.arguments[:2] + entry.arguments[1:]
    print("Length: ", len(concatenated))

    # CHECK: Type: i8
    # CHECK: Type: i16
    # CHECK: Type: i24
    for arg_type in entry.arguments.types:
      print("Type: ", arg_type)

    # Check that slicing and type access compose.
    # CHECK: Sliced type: i16
    # CHECK: Sliced type: i24
    for arg_type in entry.arguments[1:].types:
      print("Sliced type: ", arg_type)

    # Check that slice addition works as expected.
    # CHECK: Argument 2, type i24
    # CHECK: Argument 0, type i8
    last_then_first = entry.arguments[-1:] + entry.arguments[:1]
    for argument in last_then_first:
      print(f"Argument {argument.arg_number}, type {argument.type}")
200
201
202# CHECK-LABEL: TEST: testOperationOperands
@run
def testOperationOperands():
  """Iterates the operand list of an op and prints each operand's type."""
  with Context() as context:
    context.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func.func @f1(%arg0: i32) {
        %0 = "test.producer"() : () -> i64
        "test.consumer"(%arg0, %0) : (i32, i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry = func.regions[0].blocks[0]
    # The consumer is the second operation of the entry block.
    consumer = entry.operations[1]
    assert len(consumer.operands) == 2
    # CHECK: Operand 0, type i32
    # CHECK: Operand 1, type i64
    for position, value in enumerate(consumer.operands):
      print(f"Operand {position}, type {value.type}")
221
222
223
224
225# CHECK-LABEL: TEST: testOperationOperandsSlice
@run
def testOperationOperandsSlice():
  """Checks slicing (ranges, strides, nested slices) of an operand list."""
  with Context() as context:
    context.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func.func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        %3 = "test.producer3"() : () -> i64
        %4 = "test.producer4"() : () -> i64
        "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry = func.regions[0].blocks[0]
    consumer = entry.operations[5]
    assert len(consumer.operands) == 5

    def show(operands):
      # Prints every value of `operands`, one per line.
      for operand in operands:
        print(operand)

    # Reversing twice must give back the operands in their original order.
    reversed_twice = consumer.operands[::-1][::-1]
    for original, restored in zip(consumer.operands, reversed_twice):
      assert original == restored

    # Full slice.
    # CHECK: test.producer0
    # CHECK: test.producer1
    # CHECK: test.producer2
    # CHECK: test.producer3
    # CHECK: test.producer4
    show(consumer.operands[:])

    # First two operands.
    # CHECK: test.producer0
    # CHECK: test.producer1
    show(consumer.operands[0:2])

    # Last two operands.
    # CHECK: test.producer3
    # CHECK: test.producer4
    show(consumer.operands[3:])

    # Every other operand, starting with the first.
    # CHECK: test.producer0
    # CHECK: test.producer2
    # CHECK: test.producer4
    show(consumer.operands[::2])

    # A slice of a strided slice.
    # CHECK: test.producer2
    show(consumer.operands[::2][1::2])
279
280
281
282
283# CHECK-LABEL: TEST: testOperationOperandsSet
@run
def testOperationOperandsSet():
  """Checks that operands can be replaced via positive and negative indices."""
  with Context() as ctx, Location.unknown(ctx):
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func.func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        "test.consumer"(%0) : (i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    producer1 = entry_block.operations[1]
    producer2 = entry_block.operations[2]
    consumer = entry_block.operations[3]
    assert len(consumer.operands) == 1

    # Replace the only operand through index 0.
    # CHECK: test.producer1
    consumer.operands[0] = producer1.result
    print(consumer.operands[0])

    # Index -1 addresses the same (only) operand, so reading index 0 back
    # shows the new value.
    # CHECK: test.producer2
    consumer.operands[-1] = producer2.result
    print(consumer.operands[0])
311
312
313
314
315# CHECK-LABEL: TEST: testDetachedOperation
@run
def testDetachedOperation():
  """Creates an op that is not inserted anywhere and prints it."""
  context = Context()
  context.allow_unregistered_dialects = True
  with Location.unknown(context):
    si32 = IntegerType.get_signed(32)
    attrs = {
        "foo": StringAttr.get("foo_value"),
        "bar": StringAttr.get("bar_value"),
    }
    detached = Operation.create(
        "custom.op1", results=[si32, si32], regions=1, attributes=attrs)
    # CHECK: %0:2 = "custom.op1"() ({
    # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
    print(detached)

  # TODO: Check successors once enough infra exists to do it properly.
335
336
337# CHECK-LABEL: TEST: testOperationInsertionPoint
@run
def testOperationInsertionPoint():
  """Inserts detached ops at the start of a block through an InsertionPoint."""
  context = Context()
  context.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, context)

  # Create the ops to insert.
  with Location.unknown(context):
    first = Operation.create("custom.op1")
    second = Operation.create("custom.op2")

    func = module.body.operations[0]
    entry = func.regions[0].blocks[0]
    ip = InsertionPoint.at_block_begin(entry)
    for new_op in (first, second):
      ip.insert(new_op)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK: "custom.op2"()
    # CHECK: %0 = "custom.addi"
    print(module)

  # Inserting an op that has already been added must raise.
  rejected = False
  try:
    ip.insert(first)
  except ValueError:
    rejected = True
  assert rejected, "expected insert of attached op to raise"
373
374
375# CHECK-LABEL: TEST: testOperationWithRegion
@run
def testOperationWithRegion():
  """Builds an op with a populated region, then inserts it into a function."""
  context = Context()
  context.allow_unregistered_dialects = True
  with Location.unknown(context):
    si32 = IntegerType.get_signed(32)
    region_op = Operation.create("custom.op1", regions=1)
    # Append a block taking two si32 arguments to the op's only region.
    body = region_op.regions[0].blocks.append(si32, si32)
    # CHECK: "custom.op1"() ({
    # CHECK: ^bb0(%arg0: si32, %arg1: si32):
    # CHECK:   "custom.terminator"() : () -> ()
    # CHECK: }) : () -> ()
    terminator = Operation.create("custom.terminator")
    body_ip = InsertionPoint(body)
    body_ip.insert(terminator)
    print(region_op)

    # Now add the whole operation to another op.
    # TODO: Verify lifetime hazard by nulling out the new owning module and
    # accessing op1.
    # TODO: Also verify accessing the terminator once both parents are nulled
    # out.
    module = Module.parse(r"""
      func.func @f1(%arg0: i32) -> i32 {
        %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
        return %1 : i32
      }
    """)
    func = module.body.operations[0]
    entry = func.regions[0].blocks[0]
    entry_ip = InsertionPoint.at_block_begin(entry)
    entry_ip.insert(region_op)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK:   "custom.terminator"
    # CHECK: %0 = "custom.addi"
    print(module)
413
414
415# CHECK-LABEL: TEST: testOperationResultList
@run
def testOperationResultList():
  """Checks iteration, `.types` and out-of-range indexing of a result list."""
  context = Context()
  module = Module.parse(
      r"""
    func.func @f1() {
      %0:3 = call @f2() : () -> (i32, f64, index)
      return
    }
    func.func private @f2() -> (i32, f64, index)
  """, context)
  caller = module.body.operations[0]
  call = caller.regions[0].blocks[0].operations[0]
  assert len(call.results) == 3
  # CHECK: Result 0, type i32
  # CHECK: Result 1, type f64
  # CHECK: Result 2, type index
  for result in call.results:
    print(f"Result {result.result_number}, type {result.type}")

  # CHECK: Result type i32
  # CHECK: Result type f64
  # CHECK: Result type index
  for result_type in call.results.types:
    print(f"Result type {result_type}")

  # One past the end, from the front and from the back, is out of range.
  for bad_index in (3, -4):
    expect_index_error(lambda: call.results[bad_index])
445
446
447# CHECK-LABEL: TEST: testOperationResultListSlice
@run
def testOperationResultListSlice():
  """Checks slicing (ranges, strides, negative steps) of a result list."""
  with Context() as context:
    context.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func.func @f1() {
        "some.op"() : () -> (i1, i2, i3, i4, i5)
        return
      }
    """)
    func = module.body.operations[0]
    entry = func.regions[0].blocks[0]
    producer = entry.operations[0]

    def show(results):
      # Prints the result number and type of every value in `results`.
      for result in results:
        print(f"Result {result.result_number}, type {result.type}")

    assert len(producer.results) == 5
    # Reversing twice must give back the same results in the same order.
    reversed_twice = producer.results[::-1][::-1]
    for original, restored in zip(producer.results, reversed_twice):
      assert original == restored
      assert original.result_number == restored.result_number

    # Full slice.
    # CHECK: Result 0, type i1
    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    # CHECK: Result 4, type i5
    show(producer.results[:])

    # Middle three results.
    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    show(producer.results[1:4])

    # Odd-numbered results.
    # CHECK: Result 1, type i2
    # CHECK: Result 3, type i4
    show(producer.results[1::2])

    # Negative start and negative step walk the list backwards.
    # CHECK: Result 3, type i4
    # CHECK: Result 1, type i2
    show(producer.results[-2:0:-2])
494
495
496# CHECK-LABEL: TEST: testOperationAttributes
@run
def testOperationAttributes():
  """Checks lookup, iteration and error cases of an op's attribute map."""
  context = Context()
  context.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    "some.op"() { some.attribute = 1 : i8,
                  other.attribute = 3.0,
                  dependent = "text" } : () -> ()
  """, context)
  op = module.body.operations[0]
  attributes = op.attributes
  assert len(attributes) == 3
  int_attr = IntegerAttr(attributes["some.attribute"])
  float_attr = FloatAttr(attributes["other.attribute"])
  string_attr = StringAttr(attributes["dependent"])
  # CHECK: Attribute type i8, value 1
  print(f"Attribute type {int_attr.type}, value {int_attr.value}")
  # CHECK: Attribute type f64, value 3.0
  print(f"Attribute type {float_attr.type}, value {float_attr.value}")
  # CHECK: Attribute value text
  print(f"Attribute value {string_attr.value}")

  # We don't know in which order the attributes are stored.
  # CHECK-DAG: NamedAttribute(dependent="text")
  # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
  # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
  for named_attr in op.attributes:
    print(str(named_attr))

  # An unknown name must raise KeyError.
  key_error_raised = False
  try:
    op.attributes["does_not_exist"]
  except KeyError:
    key_error_raised = True
  assert key_error_raised, (
      "expected KeyError on accessing a non-existent attribute")

  # An out-of-bounds position must raise IndexError.
  index_error_raised = False
  try:
    op.attributes[42]
  except IndexError:
    index_error_raised = True
  assert index_error_raised, (
      "expected IndexError on accessing an out-of-bounds attribute")
540
541
542
543
544# CHECK-LABEL: TEST: testOperationPrint
@run
def testOperationPrint():
  """Exercises Operation.print with stdout, text and binary sinks and flags."""
  context = Context()
  module = Module.parse(
      r"""
    func.func @f1(%arg0: i32) -> i32 {
      %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
      return %arg0 : i32
    }
  """, context)

  # With no arguments the op is printed to stdout.
  # CHECK: return %arg0 : i32
  module.operation.print()

  # Printing into a text file object yields a str.
  text_sink = io.StringIO()
  # CHECK: <class 'str'>
  # CHECK: return %arg0 : i32
  module.operation.print(file=text_sink)
  text = text_sink.getvalue()
  print(text.__class__)
  print(text)

  # Printing into a binary file object with binary=True yields bytes.
  binary_sink = io.BytesIO()
  # CHECK: <class 'bytes'>
  # CHECK: return %arg0 : i32
  module.operation.print(file=binary_sink, binary=True)
  raw = binary_sink.getvalue()
  print(raw.__class__)
  print(raw)

  # Print with debug info and local scope.
  # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
  module.operation.print(enable_debug_info=True, use_local_scope=True)

  # Print with the remaining formatting options switched on together.
  # CHECK: value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<4xi32>
  # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7
  print_options = {
      "large_elements_limit": 2,
      "enable_debug_info": True,
      "pretty_debug_info": True,
      "print_generic_op_form": True,
      "use_local_scope": True,
  }
  module.operation.print(**print_options)
591
592
593
594
595# CHECK-LABEL: TEST: testKnownOpView
@run
def testKnownOpView():
  """Checks which OpView class is returned for known and for unknown ops."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      %1 = "custom.f32"() : () -> f32
      %2 = "custom.f32"() : () -> f32
      %3 = arith.addf %1, %2 : f32
    """)
    print(module)

    body_ops = module.body.operations

    # addf should map to a known OpView class in the arithmetic dialect.
    # We know the OpView for it defines an 'lhs' attribute.
    addf_view = body_ops[2]
    # CHECK: <mlir.dialects._arith_ops_gen._AddFOp object
    print(repr(addf_view))
    # CHECK: "custom.f32"()
    print(addf_view.lhs)

    # One of the custom ops should resolve to the default OpView.
    custom_view = module.body.operations[0]
    # CHECK: OpView object
    print(repr(custom_view))

    # Look the same op up a second time to make sure negative caching works.
    custom_view = module.body.operations[0]
    # CHECK: OpView object
    print(repr(custom_view))
624
625
626# CHECK-LABEL: TEST: testSingleResultProperty
@run
def testSingleResultProperty():
  """Checks that `.result` is rejected unless the op has exactly one result."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      "custom.no_result"() : () -> ()
      %0:2 = "custom.two_result"() : () -> (f32, f32)
      %1 = "custom.one_result"() : () -> f32
    """)
    print(module)

  # The first op has no results and the second has two; `.result` must raise
  # ValueError for both.
  # CHECK: Cannot call .result on operation custom.no_result which has 0 results
  # CHECK: Cannot call .result on operation custom.two_result which has 2 results
  for op_index in (0, 1):
    try:
      module.body.operations[op_index].result
    except ValueError as e:
      print(e)
    else:
      assert False, "Expected exception"

  # CHECK: %1 = "custom.one_result"() : () -> f32
  single_result_op = module.body.operations[2]
  print(single_result_op)
656
657
def create_invalid_operation():
  """Returns a `builtin.module` op created with two regions.

  Such a module is invalid. The tests in this file use it to verify that
  printing falls back to the generic printer for safety. Callers in this file
  invoke it inside a `with Location.unknown(...)` block, since no location is
  passed to `Operation.create` here.
  """
  invalid = Operation.create("builtin.module", regions=2)
  first_region = invalid.regions[0]
  first_region.blocks.append()
  return invalid
664
665# CHECK-LABEL: TEST: testInvalidOperationStrSoftFails
@run
def testInvalidOperationStrSoftFails():
  """Checks that printing an invalid op falls back to the generic form."""
  context = Context()
  with Location.unknown(context):
    broken = create_invalid_operation()
    # Verify that we fallback to the generic printer for safety.
    # CHECK: // Verification failed, printing generic form
    # CHECK: "builtin.module"() ({
    # CHECK: }) : () -> ()
    print(broken)
    # CHECK: .verify = False
    verified = broken.operation.verify()
    print(f".verify = {verified}")
678
679
680# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
@run
def testInvalidModuleStrSoftFails():
  """Checks that a module containing an invalid op prints in generic form."""
  context = Context()
  with Location.unknown(context):
    enclosing = Module.create()
    # Create the invalid op while the module body is the insertion point.
    with InsertionPoint(enclosing.body):
      broken = create_invalid_operation()
    # Verify that we fallback to the generic printer for safety.
    # CHECK: // Verification failed, printing generic form
    print(enclosing)
691
692
693# CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails
@run
def testInvalidOperationGetAsmBinarySoftFails():
  """Checks the generic-form fallback of `get_asm(binary=True)`."""
  context = Context()
  with Location.unknown(context):
    broken = create_invalid_operation()
    # Verify that we fallback to the generic printer for safety.
    # CHECK: b'// Verification failed, printing generic form\n
    asm_bytes = broken.get_asm(binary=True)
    print(asm_bytes)
702
703
704# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
@run
def testCreateWithInvalidAttributes():
  """Checks the errors reported for malformed `attributes` dictionaries.

  Each case must make `Operation.create` raise; the message is printed and
  matched by the directive that precedes the call.
  """
  ctx = Context()
  with Location.unknown(ctx):

    def expect_rejected(attributes):
      # Tries to create a module op with `attributes` and prints the error.
      try:
        Operation.create("builtin.module", attributes=attributes)
      except Exception as e:
        print(e)
      else:
        # Fail right here if creation unexpectedly succeeds, instead of only
        # indirectly through a line missing from the checked output.
        assert False, "expected Operation.create to reject the attributes"

    # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
    expect_rejected({None: StringAttr.get("name")})
    # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
    expect_rejected({42: StringAttr.get("name")})
    # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
    expect_rejected({"some_key": ctx})
    # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
    expect_rejected({"some_key": None})
731
732
733# CHECK-LABEL: TEST: testOperationName
@run
def testOperationName():
  """Prints the `name` of every operation in the module body."""
  context = Context()
  context.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    %0 = "custom.op1"() : () -> f32
    %1 = "custom.op2"() : () -> i32
    %2 = "custom.op1"() : () -> f32
  """, context)

  # CHECK: custom.op1
  # CHECK: custom.op2
  # CHECK: custom.op1
  for body_op in module.body.operations:
    op_name = body_op.operation.name
    print(op_name)
750
751
752# CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
  """Round-trips an Operation through its `_CAPIPtr` capsule."""
  context = Context()
  context.allow_unregistered_dialects = True
  with Location.unknown(context):
    original = Operation.create("custom.op1").operation
    capsule = original._CAPIPtr
    assert '"mlir.ir.Operation._CAPIPtr"' in repr(capsule)
    # Re-creating from the capsule must give back the very same object.
    restored = Operation._CAPICreate(capsule)
    assert restored is original
763
764
765# CHECK-LABEL: TEST: testOperationErase
@run
def testOperationErase():
  """Erases an op and checks that the module no longer prints it."""
  context = Context()
  context.allow_unregistered_dialects = True
  with Location.unknown(context):
    module = Module.create()
    with InsertionPoint(module.body):
      doomed = Operation.create("custom.op1")

      # CHECK: "custom.op1"
      print(module)

      doomed.operation.erase()

      # CHECK-NOT: "custom.op1"
      print(module)

      # Ensure we can create another operation
      Operation.create("custom.op2")
785
786
787# CHECK-LABEL: TEST: testOperationClone
@run
def testOperationClone():
  """Clones an op, erases the original and prints the module again."""
  context = Context()
  context.allow_unregistered_dialects = True
  with Location.unknown(context):
    module = Module.create()
    with InsertionPoint(module.body):
      original = Operation.create("custom.op1")

      # CHECK: "custom.op1"
      print(module)

      # NOTE(review): the module is expected to still print a custom.op1
      # after the original is erased, so the clone presumably gets inserted
      # at the active insertion point -- confirm against the bindings.
      duplicate = original.operation.clone()
      original.operation.erase()

      # CHECK: "custom.op1"
      print(module)
805
806
807# CHECK-LABEL: TEST: testOperationLoc
@run
def testOperationLoc():
  """Checks that an op reports the location it was created with."""
  context = Context()
  context.allow_unregistered_dialects = True
  with context:
    named_loc = Location.name("loc")
    created = Operation.create("custom.op", loc=named_loc)
    # The location is readable both directly and through `.operation`.
    assert created.location == named_loc
    assert created.operation.location == named_loc
817
818
819# CHECK-LABEL: TEST: testModuleMerge
@run
def testModuleMerge():
  """Moves every op of one module next to an op of another module."""
  with Context():
    target = Module.parse("func.func private @foo()")
    source = Module.parse("""
      func.func private @bar()
      func.func private @qux()
    """)
    foo = target.body.operations[0]
    # Look up both source ops before moving either of them.
    source_ops = source.body.operations
    bar = source_ops[0]
    qux = source_ops[1]
    bar.move_before(foo)
    qux.move_after(foo)

    # CHECK: module
    # CHECK: func private @bar
    # CHECK: func private @foo
    # CHECK: func private @qux
    print(target)

    # CHECK: module {
    # CHECK-NEXT: }
    print(source)
843
844
845# CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock
@run
def testAppendMoveFromAnotherBlock():
  """Appends an op that lives in one module to the body of another one."""
  with Context():
    donor = Module.parse("func.func private @foo()")
    receiver = Module.parse("func.func private @bar()")
    moved = donor.body.operations[0]
    receiver.body.append(moved)

    # The receiving module now prints both functions, the appended one last.
    # CHECK: module
    # CHECK: func private @bar
    # CHECK: func private @foo

    print(receiver)
    # The donor module is expected to print as empty.
    # CHECK: module {
    # CHECK-NEXT: }
    print(donor)
862
863
864# CHECK-LABEL: TEST: testDetachFromParent
@run
def testDetachFromParent():
  """Detaches an op from its module; detaching it again must fail."""
  with Context():
    module = Module.parse("func.func private @foo()")
    attached = module.body.operations[0]
    detached = attached.detach_from_parent()

    try:
      detached.detach_from_parent()
    except ValueError as e:
      # Only the "has no parent" error is expected; re-raise anything else.
      message = str(e)
      if "has no parent" not in message:
        raise
    else:
      assert False, "expected ValueError when detaching a detached operation"

    print(module)
    # CHECK-NOT: func private @foo
881
882
883# CHECK-LABEL: TEST: testOperationHash
@run
def testOperationHash():
  """Checks that an op and its `.operation` hash to the same value."""
  context = Context()
  context.allow_unregistered_dialects = True
  with context, Location.unknown():
    created = Operation.create("custom.op1")
    created_hash = hash(created)
    operation_hash = hash(created.operation)
    assert created_hash == operation_hash
891