1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4import io
5import itertools
6from mlir.ir import *
7
8
def run(f):
  """Execute test function `f` right away and return it unchanged.

  Intended for use as a decorator. Prints a banner with the test's name,
  calls `f()`, then asserts that `Context._get_live_count()` is zero after a
  garbage collection, i.e. no Context created by the test is still alive.
  """
  banner_name = f.__name__
  print("\nTEST:", banner_name)
  f()
  # Force a collection first so contexts kept alive only by reference cycles
  # are released before counting.
  gc.collect()
  remaining = Context._get_live_count()
  assert remaining == 0
  return f
15
16
def expect_index_error(callback):
  """Invoke `callback` and require that it raises IndexError.

  The IndexError is swallowed. Any other exception propagates unchanged.
  Raises RuntimeError if `callback` returns normally.
  """
  try:
    callback()
  except IndexError:
    return
  raise RuntimeError("Expected IndexError")
23
24
25# Verify iterator based traversal of the op/region/block hierarchy.
26# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
@run
def testTraverseOpRegionBlockIterators():
  """Walk the op/region/block hierarchy using Python iteration only.

  Also asserts that iterating the named collections (`regions`, `blocks`,
  `operations`) and iterating regions/blocks directly yield equal elements.
  """
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)
  top_op = module.operation
  assert top_op.context is ctx

  # Reach the block through the explicitly named collections.
  named_regions = list(top_op.regions)
  named_blocks = list(named_regions[0].blocks)
  region_count, block_count = len(named_regions), len(named_blocks)
  # CHECK: MODULE REGIONS=1 BLOCKS=1
  print(f"MODULE REGIONS={region_count} BLOCKS={block_count}")

  # Should verify.
  # CHECK: .verify = True
  print(f".verify = {module.operation.verify()}")

  # Iterating a region directly yields its blocks. The elements must compare
  # equal no matter which route was used to obtain them.
  direct_regions = list(top_op.regions)
  direct_blocks = list(direct_regions[0])
  assert direct_regions == named_regions
  assert direct_blocks == named_blocks

  # Iterating a block directly yields the same ops as `.operations`.
  entry = named_blocks[0]
  named_ops = list(entry.operations)
  assert list(entry) == named_ops

  def dump(prefix, parent_op):
    for region_idx, region in enumerate(parent_op.regions):
      print(f"{prefix}REGION {region_idx}:")
      for block_idx, block in enumerate(region):
        print(f"{prefix}  BLOCK {block_idx}:")
        for op_idx, nested in enumerate(block):
          print(f"{prefix}    OP {op_idx}: {nested}")
          # Six extra spaces place nested output under the enclosing OP line.
          dump(prefix + "      ", nested)

  # CHECK: REGION 0:
  # CHECK:   BLOCK 0:
  # CHECK:     OP 0: func
  # CHECK:       REGION 0:
  # CHECK:         BLOCK 0:
  # CHECK:           OP 0: %0 = "custom.addi"
  # CHECK:           OP 1: func.return
  dump("", top_op)
80
81
82# Verify index based traversal of the op/region/block hierarchy.
83# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
@run
def testTraverseOpRegionBlockIndices():
  """Walk the op/region/block hierarchy using len() and integer indexing.

  Every visited op also prints the name of its parent operation.
  """
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)

  def dump(prefix, parent_op):
    # Indexing (rather than iterating) is deliberate: this test exercises the
    # __len__/__getitem__ paths of each collection.
    region_list = parent_op.regions
    for region_idx in range(len(region_list)):
      region = region_list[region_idx]
      print(f"{prefix}REGION {region_idx}:")
      block_list = region.blocks
      for block_idx in range(len(block_list)):
        block = block_list[block_idx]
        print(f"{prefix}  BLOCK {block_idx}:")
        op_list = block.operations
        for op_idx in range(len(op_list)):
          nested = op_list[op_idx]
          print(f"{prefix}    OP {op_idx}: {nested}")
          print(f"{prefix}    OP {op_idx}: parent {nested.operation.parent.name}")
          dump(prefix + "      ", nested)

  # CHECK: REGION 0:
  # CHECK:   BLOCK 0:
  # CHECK:     OP 0: func
  # CHECK:     OP 0: parent builtin.module
  # CHECK:       REGION 0:
  # CHECK:         BLOCK 0:
  # CHECK:           OP 0: %0 = "custom.addi"
  # CHECK:           OP 0: parent func.func
  # CHECK:           OP 1: func.return
  # CHECK:           OP 1: parent func.func
  dump("", module.operation)
120
121
122# CHECK-LABEL: TEST: testBlockAndRegionOwners
@run
def testBlockAndRegionOwners():
  """Regions and blocks report the operation that contains them via `.owner`."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    builtin.module {
      func.func @f() {
        func.return
      }
    }
  """, ctx)

  module_op = module.operation
  module_region = module_op.regions[0]
  assert module_region.owner == module_op
  assert module_region.blocks[0].owner == module_op

  func = module.body.operations[0]
  func_region = func.operation.regions[0]
  assert func_region.owner == func
  assert func_region.blocks[0].owner == func
142
143
144# CHECK-LABEL: TEST: testBlockArgumentList
@run
def testBlockArgumentList():
  """Read, retype, slice and concatenate an entry block's argument list."""
  with Context() as ctx:
    module = Module.parse(
        r"""
      func.func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
        return
      }
    """, ctx)
    entry_block = module.body.operations[0].regions[0].blocks[0]

    def describe(arg):
      return f"Argument {arg.arg_number}, type {arg.type}"

    assert len(entry_block.arguments) == 3
    # CHECK: Argument 0, type i32
    # CHECK: Argument 1, type f64
    # CHECK: Argument 2, type index
    for arg in entry_block.arguments:
      print(describe(arg))
      # Argument N becomes a signless integer of width 8 * (N + 1).
      width = 8 * (arg.arg_number + 1)
      arg.set_type(IntegerType.get_signless(width))

    # CHECK: Argument 0, type i8
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    for arg in entry_block.arguments:
      print(describe(arg))

    # Slicing a block argument list yields the trailing arguments.
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    for arg in entry_block.arguments[1:]:
      print(describe(arg))

    # Two overlapping slices concatenate to 2 + 2 elements.
    # CHECK: Length: 4
    combined = entry_block.arguments[:2] + entry_block.arguments[1:]
    print("Length: ", len(combined))

    # CHECK: Type: i8
    # CHECK: Type: i16
    # CHECK: Type: i24
    for arg_type in entry_block.arguments.types:
      print("Type: ", arg_type)
187
188
189# CHECK-LABEL: TEST: testOperationOperands
@run
def testOperationOperands():
  """An op's operands can be counted and iterated, exposing their types."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func.func @f1(%arg0: i32) {
        %0 = "test.producer"() : () -> i64
        "test.consumer"(%arg0, %0) : (i32, i64) -> ()
        return
      }""")
    entry_block = module.body.operations[0].regions[0].blocks[0]
    consumer = entry_block.operations[1]
    operands = consumer.operands
    assert len(operands) == 2
    # CHECK: Operand 0, type i32
    # CHECK: Operand 1, type i64
    for index, value in enumerate(operands):
      print(f"Operand {index}, type {value.type}")
208
209
210
211
212# CHECK-LABEL: TEST: testOperationOperandsSlice
@run
def testOperationOperandsSlice():
  """Operand lists support Python slicing, including steps and nested slices."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func.func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        %3 = "test.producer3"() : () -> i64
        %4 = "test.producer4"() : () -> i64
        "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
        return
      }""")
    entry_block = module.body.operations[0].regions[0].blocks[0]
    consumer = entry_block.operations[5]
    operands = consumer.operands
    assert len(operands) == 5
    # Reversing twice must reproduce the original order.
    for original, round_trip in zip(operands, operands[::-1][::-1]):
      assert original == round_trip

    def show(values):
      for value in values:
        print(value)

    # CHECK: test.producer0
    # CHECK: test.producer1
    # CHECK: test.producer2
    # CHECK: test.producer3
    # CHECK: test.producer4
    show(operands[:])

    # CHECK: test.producer0
    # CHECK: test.producer1
    show(operands[0:2])

    # CHECK: test.producer3
    # CHECK: test.producer4
    show(operands[3:])

    # CHECK: test.producer0
    # CHECK: test.producer2
    # CHECK: test.producer4
    show(operands[::2])

    # Every other element of the even-index slice leaves only index 2.
    # CHECK: test.producer2
    show(operands[::2][1::2])
266
267
268
269
270# CHECK-LABEL: TEST: testOperationOperandsSet
@run
def testOperationOperandsSet():
  """Operands can be replaced by assigning through positive or negative indices."""
  with Context() as ctx, Location.unknown(ctx):
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func.func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        "test.consumer"(%0) : (i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    producer1 = entry_block.operations[1]
    producer2 = entry_block.operations[2]
    consumer = entry_block.operations[3]
    assert len(consumer.operands) == 1

    # CHECK: test.producer1
    consumer.operands[0] = producer1.result
    print(consumer.operands[0])

    # With a single operand, index -1 addresses the same slot as index 0.
    # CHECK: test.producer2
    consumer.operands[-1] = producer2.result
    print(consumer.operands[0])
298
299
300
301
302# CHECK-LABEL: TEST: testDetachedOperation
@run
def testDetachedOperation():
  """An op created without a parent, with results/regions/attributes, prints."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    si32 = IntegerType.get_signed(32)
    attrs = {
        "foo": StringAttr.get("foo_value"),
        "bar": StringAttr.get("bar_value"),
    }
    op1 = Operation.create(
        "custom.op1", results=[si32, si32], regions=1, attributes=attrs)
    # CHECK: %0:2 = "custom.op1"() ({
    # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
    print(op1)

  # TODO: Check successors once enough infra exists to do it properly.
322
323
324# CHECK-LABEL: TEST: testOperationInsertionPoint
@run
def testOperationInsertionPoint():
  """Detached ops insert at a block-begin insertion point, but only once."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)

  with Location.unknown(ctx):
    first = Operation.create("custom.op1")
    second = Operation.create("custom.op2")

    entry_block = module.body.operations[0].regions[0].blocks[0]
    ip = InsertionPoint.at_block_begin(entry_block)
    # The printed module shows the ops in insertion order, ahead of the
    # block's original first op.
    for new_op in (first, second):
      ip.insert(new_op)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK: "custom.op2"()
    # CHECK: %0 = "custom.addi"
    print(module)

  # Re-inserting an op that already has a parent must raise ValueError.
  try:
    ip.insert(first)
  except ValueError:
    pass
  else:
    assert False, "expected insert of attached op to raise"
360
361
362# CHECK-LABEL: TEST: testOperationWithRegion
@run
def testOperationWithRegion():
  """Build an op whose region holds a terminated block, then move it.

  The op is first printed while detached, then inserted at the start of a
  parsed function's entry block and printed again as part of that module.
  """
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    # Note: despite its name, `i32` is a *signed* 32-bit type (prints as si32).
    i32 = IntegerType.get_signed(32)
    op1 = Operation.create("custom.op1", regions=1)
    # Appending with two types gives the new block two arguments.
    block = op1.regions[0].blocks.append(i32, i32)
    # CHECK: "custom.op1"() ({
    # CHECK: ^bb0(%arg0: si32, %arg1: si32):
    # CHECK:   "custom.terminator"() : () -> ()
    # CHECK: }) : () -> ()
    terminator = Operation.create("custom.terminator")
    ip = InsertionPoint(block)
    ip.insert(terminator)
    print(op1)

    # Now add the whole operation to another op.
    # TODO: Verify lifetime hazard by nulling out the new owning module and
    # accessing op1.
    # TODO: Also verify accessing the terminator once both parents are nulled
    # out.
    module = Module.parse(r"""
      func.func @f1(%arg0: i32) -> i32 {
        %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
        return %1 : i32
      }
    """)
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    # `ip` is rebound; op1 (with its nested terminator) lands before the
    # block's existing first op.
    ip = InsertionPoint.at_block_begin(entry_block)
    ip.insert(op1)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK:   "custom.terminator"
    # CHECK: %0 = "custom.addi"
    print(module)
400
401
402# CHECK-LABEL: TEST: testOperationResultList
@run
def testOperationResultList():
  """A call's result list exposes numbers, types and bounds-checked indexing."""
  ctx = Context()
  module = Module.parse(
      r"""
    func.func @f1() {
      %0:3 = call @f2() : () -> (i32, f64, index)
      return
    }
    func.func private @f2() -> (i32, f64, index)
  """, ctx)
  caller = module.body.operations[0]
  call = caller.regions[0].blocks[0].operations[0]
  results = call.results
  assert len(results) == 3
  # CHECK: Result 0, type i32
  # CHECK: Result 1, type f64
  # CHECK: Result 2, type index
  for res in results:
    print(f"Result {res.result_number}, type {res.type}")

  # CHECK: Result type i32
  # CHECK: Result type f64
  # CHECK: Result type index
  for result_type in results.types:
    print(f"Result type {result_type}")

  # Indices 3 and -4 both fall outside a 3-element list.
  for bad_index in (3, -4):
    expect_index_error(lambda i=bad_index: results[i])
432
433
434# CHECK-LABEL: TEST: testOperationResultListSlice
@run
def testOperationResultListSlice():
  """Result lists support slicing with positive and negative steps."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func.func @f1() {
        "some.op"() : () -> (i1, i2, i3, i4, i5)
        return
      }
    """)
    entry_block = module.body.operations[0].regions[0].blocks[0]
    producer = entry_block.operations[0]
    results = producer.results

    assert len(results) == 5
    # Reversing twice must reproduce the original results and numbering.
    for original, round_trip in zip(results, results[::-1][::-1]):
      assert original == round_trip
      assert original.result_number == round_trip.result_number

    def show(values):
      for res in values:
        print(f"Result {res.result_number}, type {res.type}")

    # CHECK: Result 0, type i1
    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    # CHECK: Result 4, type i5
    show(results[:])

    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    show(results[1:4])

    # CHECK: Result 1, type i2
    # CHECK: Result 3, type i4
    show(results[1::2])

    # Step backwards by two from the second-to-last result, stopping before 0.
    # CHECK: Result 3, type i4
    # CHECK: Result 1, type i2
    show(results[-2:0:-2])
481
482
483# CHECK-LABEL: TEST: testOperationAttributes
@run
def testOperationAttributes():
  """Op attributes can be looked up by name, cast, iterated and bounds-checked."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    "some.op"() { some.attribute = 1 : i8,
                  other.attribute = 3.0,
                  dependent = "text" } : () -> ()
  """, ctx)
  op = module.body.operations[0]
  attrs = op.attributes
  assert len(attrs) == 3
  iattr = IntegerAttr(attrs["some.attribute"])
  fattr = FloatAttr(attrs["other.attribute"])
  sattr = StringAttr(attrs["dependent"])
  # CHECK: Attribute type i8, value 1
  print(f"Attribute type {iattr.type}, value {iattr.value}")
  # CHECK: Attribute type f64, value 3.0
  print(f"Attribute type {fattr.type}, value {fattr.value}")
  # CHECK: Attribute value text
  print(f"Attribute value {sattr.value}")

  # We don't know in which order the attributes are stored.
  # CHECK-DAG: NamedAttribute(dependent="text")
  # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
  # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
  for named in attrs:
    print(str(named))

  # Unknown names must raise KeyError; out-of-range integer indices must
  # raise IndexError.
  bad_lookups = (
      ("does_not_exist", KeyError,
       "expected KeyError on accessing a non-existent attribute"),
      (42, IndexError,
       "expected IndexError on accessing an out-of-bounds attribute"),
  )
  for key, expected_error, message in bad_lookups:
    try:
      attrs[key]
    except expected_error:
      pass
    else:
      assert False, message
527
528
529
530
531# CHECK-LABEL: TEST: testOperationPrint
@run
def testOperationPrint():
  """Exercise Operation.print to stdout, text/binary streams and with options."""
  ctx = Context()
  module = Module.parse(
      r"""
    func.func @f1(%arg0: i32) -> i32 {
      %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
      return %arg0 : i32
    }
  """, ctx)
  top = module.operation

  # Print to stdout.
  # CHECK: return %arg0 : i32
  top.print()

  # Print into an in-memory text stream.
  text_buf = io.StringIO()
  # CHECK: <class 'str'>
  # CHECK: return %arg0 : i32
  top.print(file=text_buf)
  text = text_buf.getvalue()
  print(text.__class__)
  print(text)

  # Print into an in-memory binary stream.
  byte_buf = io.BytesIO()
  # CHECK: <class 'bytes'>
  # CHECK: return %arg0 : i32
  top.print(file=byte_buf, binary=True)
  raw = byte_buf.getvalue()
  print(raw.__class__)
  print(raw)

  # Debug info with local scope keeps the named location.
  # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
  top.print(enable_debug_info=True, use_local_scope=True)

  # Elide large constants; generic op form with pretty debug info.
  # CHECK: value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<4xi32>
  # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7
  top.print(
      large_elements_limit=2,
      enable_debug_info=True,
      pretty_debug_info=True,
      print_generic_op_form=True,
      use_local_scope=True)
578
579
580
581
582# CHECK-LABEL: TEST: testKnownOpView
@run
def testKnownOpView():
  """Registered ops map to generated OpView classes; unknown ops to OpView."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      %1 = "custom.f32"() : () -> f32
      %2 = "custom.f32"() : () -> f32
      %3 = arith.addf %1, %2 : f32
    """)
    print(module)

    ops = module.body.operations
    # addf maps to a generated arith OpView class, which exposes an 'lhs'
    # accessor.
    addf = ops[2]
    # CHECK: <mlir.dialects._arith_ops_gen._AddFOp object
    print(repr(addf))
    # CHECK: "custom.f32"()
    print(addf.lhs)

    # An unregistered custom op resolves to the default OpView. The lookup is
    # repeated to make sure negative caching works.
    # CHECK: OpView object
    # CHECK: OpView object
    for _ in range(2):
      custom = ops[0]
      print(repr(custom))
611
612
613# CHECK-LABEL: TEST: testSingleResultProperty
@run
def testSingleResultProperty():
  """`.result` raises ValueError on ops with zero or with several results."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      "custom.no_result"() : () -> ()
      %0:2 = "custom.two_result"() : () -> (f32, f32)
      %1 = "custom.one_result"() : () -> f32
    """)
    print(module)

  def print_result_error(index):
    # Print the ValueError raised by `.result`; fail if nothing is raised.
    try:
      module.body.operations[index].result
    except ValueError as e:
      print(e)
    else:
      assert False, "Expected exception"

  # CHECK: Cannot call .result on operation custom.no_result which has 0 results
  print_result_error(0)

  # CHECK: Cannot call .result on operation custom.two_result which has 2 results
  print_result_error(1)

  # CHECK: %1 = "custom.one_result"() : () -> f32
  print(module.body.operations[2])
643
644
def create_invalid_operation():
  """Return a detached `builtin.module` op that fails verification.

  The op is given two regions, and only the first one receives a block.
  That makes the module invalid, which lets tests check that printing falls
  back to the generic printer for safety. Callers in this file invoke it
  inside `with Location.unknown(ctx)`; presumably `Operation.create` needs
  that ambient location/context.
  """
  invalid = Operation.create("builtin.module", regions=2)
  first_region = invalid.regions[0]
  first_region.blocks.append()
  return invalid
651
652# CHECK-LABEL: TEST: testInvalidOperationStrSoftFails
@run
def testInvalidOperationStrSoftFails():
  """str() of an invalid op falls back to the generic form, and verify fails."""
  ctx = Context()
  with Location.unknown(ctx):
    bad_op = create_invalid_operation()
    # Verify that we fallback to the generic printer for safety.
    # CHECK: // Verification failed, printing generic form
    # CHECK: "builtin.module"() ({
    # CHECK: }) : () -> ()
    print(bad_op)
    verified = bad_op.operation.verify()
    # CHECK: .verify = False
    print(f".verify = {verified}")
665
666
667# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
@run
def testInvalidModuleStrSoftFails():
  """Printing a module that contains an invalid op falls back to generic form."""
  ctx = Context()
  with Location.unknown(ctx):
    module = Module.create()
    # The op is created under the module body's insertion point, so it is
    # inserted there. No Python handle to it is needed afterwards.
    with InsertionPoint(module.body):
      create_invalid_operation()
    # Verify that we fallback to the generic printer for safety.
    # CHECK: // Verification failed, printing generic form
    print(module)
678
679
680# CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails
@run
def testInvalidOperationGetAsmBinarySoftFails():
  """get_asm(binary=True) on an invalid op also prepends the failure note."""
  ctx = Context()
  with Location.unknown(ctx):
    bad_op = create_invalid_operation()
    asm_bytes = bad_op.get_asm(binary=True)
    # Verify that we fallback to the generic printer for safety.
    # CHECK: b'// Verification failed, printing generic form\n
    print(asm_bytes)
689
690
691# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
@run
def testCreateWithInvalidAttributes():
  """Operation.create rejects non-string attribute keys and invalid values.

  Each case prints the raised error message. If creation unexpectedly
  succeeds, the test fails immediately at that case; previously it relied on
  the expected message going missing from the output.
  """
  ctx = Context()
  with Location.unknown(ctx):

    def expect_create_error(attributes):
      # The concrete exception type raised by the bindings is not pinned here,
      # so any Exception is accepted and its message printed.
      try:
        Operation.create("builtin.module", attributes=attributes)
      except Exception as e:
        print(e)
      else:
        assert False, (
            f"expected Operation.create to reject attributes {attributes!r}")

    # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
    expect_create_error({None: StringAttr.get("name")})
    # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
    expect_create_error({42: StringAttr.get("name")})
    # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
    expect_create_error({"some_key": ctx})
    # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
    expect_create_error({"some_key": None})
718
719
720# CHECK-LABEL: TEST: testOperationName
@run
def testOperationName():
  """`Operation.name` reports each op's fully qualified name, in block order."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    %0 = "custom.op1"() : () -> f32
    %1 = "custom.op2"() : () -> i32
    %2 = "custom.op1"() : () -> f32
  """, ctx)

  names = [op.operation.name for op in module.body.operations]
  # CHECK: custom.op1
  # CHECK: custom.op2
  # CHECK: custom.op1
  for name in names:
    print(name)
737
738
739# CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
  """Round-trip an Operation through its C API capsule."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    original = Operation.create("custom.op1").operation
    capsule = original._CAPIPtr
    assert '"mlir.ir.Operation._CAPIPtr"' in repr(capsule)
    # Recreating from the capsule must yield the very same Python object.
    restored = Operation._CAPICreate(capsule)
    assert restored is original
750
751
752# CHECK-LABEL: TEST: testOperationErase
@run
def testOperationErase():
  """Erasing an op removes it from its module; creating ops still works."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    module = Module.create()
    with InsertionPoint(module.body):
      victim = Operation.create("custom.op1")

      # CHECK: "custom.op1"
      print(module)

      victim.operation.erase()

      # CHECK-NOT: "custom.op1"
      print(module)

      # Ensure we can create another operation
      Operation.create("custom.op2")
772
773
774# CHECK-LABEL: TEST: testOperationClone
@run
def testOperationClone():
  """A clone remains in the module after the op it was cloned from is erased."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    module = Module.create()
    with InsertionPoint(module.body):
      source = Operation.create("custom.op1")

      # CHECK: "custom.op1"
      print(module)

      # The clone is made while the module body's insertion point is active.
      # The second print shows "custom.op1" survives erasing the source op.
      duplicate = source.operation.clone()
      source.operation.erase()

      # CHECK: "custom.op1"
      print(module)
792
793
794# CHECK-LABEL: TEST: testOperationLoc
@run
def testOperationLoc():
  """An op created with an explicit `loc` reports it back via `.location`."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with ctx:
    named_loc = Location.name("loc")
    op = Operation.create("custom.op", loc=named_loc)
    # Both the returned object and its underlying Operation agree.
    for view in (op, op.operation):
      assert view.location == named_loc
804
805
806# CHECK-LABEL: TEST: testModuleMerge
@run
def testModuleMerge():
  """Ops move between modules via move_before / move_after."""
  with Context():
    dest = Module.parse("func.func private @foo()")
    src = Module.parse("""
      func.func private @bar()
      func.func private @qux()
    """)
    foo = dest.body.operations[0]
    bar, qux = src.body.operations[0], src.body.operations[1]
    # Surround @foo with the two ops taken from the source module.
    bar.move_before(foo)
    qux.move_after(foo)

    # CHECK: module
    # CHECK: func private @bar
    # CHECK: func private @foo
    # CHECK: func private @qux
    print(dest)

    # CHECK: module {
    # CHECK-NEXT: }
    print(src)
830
831
832# CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock
@run
def testAppendMoveFromAnotherBlock():
  """Appending an op owned by another block moves it into the new block."""
  with Context():
    source = Module.parse("func.func private @foo()")
    target = Module.parse("func.func private @bar()")
    moved = source.body.operations[0]
    target.body.append(moved)

    # CHECK: module
    # CHECK: func private @bar
    # CHECK: func private @foo

    print(target)
    # CHECK: module {
    # CHECK-NEXT: }
    print(source)
849
850
851# CHECK-LABEL: TEST: testDetachFromParent
@run
def testDetachFromParent():
  """detach_from_parent removes an op; detaching it again raises ValueError."""
  with Context():
    module = Module.parse("func.func private @foo()")
    detached = module.body.operations[0].detach_from_parent()

    try:
      detached.detach_from_parent()
    except ValueError as e:
      # Only the expected "has no parent" error is acceptable.
      if "has no parent" not in str(e):
        raise
    else:
      assert False, "expected ValueError when detaching a detached operation"

    print(module)
    # CHECK-NOT: func private @foo
868
869
870# CHECK-LABEL: TEST: testOperationHash
@run
def testOperationHash():
  """The object returned by Operation.create hashes like its `.operation`."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with ctx, Location.unknown():
    created = Operation.create("custom.op1")
    assert hash(created) == hash(created.operation)
878