1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4import io
5import itertools
6from mlir.ir import *
7
8
def run(f):
  """Run test function `f` right away and hand it back unchanged.

  Intended as a decorator. It prints the `TEST:` header that CHECK-LABEL
  lines key on, invokes the test, forces a garbage collection, and then
  asserts that `Context._get_live_count()` reports zero. This catches
  contexts leaked by the test.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  # Collect cyclic garbage first so only genuinely leaked contexts remain.
  gc.collect()
  live_contexts = Context._get_live_count()
  assert live_contexts == 0
  return f
15
16
def expect_index_error(callback):
  """Call `callback` with no arguments and require it to raise IndexError.

  Returns None when IndexError is raised. If the callback returns normally,
  raises RuntimeError("Expected IndexError"). Any other exception raised by
  the callback propagates unchanged.
  """
  try:
    callback()
  except IndexError:
    return
  raise RuntimeError("Expected IndexError")
23
24
25# Verify iterator based traversal of the op/region/block hierarchy.
26# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
@run
def testTraverseOpRegionBlockIterators():
  """Walk the op/region/block hierarchy using Python iteration only."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  asm = r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """
  module = Module.parse(asm, ctx)
  module_op = module.operation
  assert module_op.context is ctx

  # Materialize regions and blocks through the explicitly named collections.
  named_regions = list(module_op.regions)
  named_blocks = list(named_regions[0].blocks)
  # CHECK: MODULE REGIONS=1 BLOCKS=1
  print(f"MODULE REGIONS={len(named_regions)} BLOCKS={len(named_blocks)}")

  # The parsed module is expected to verify.
  # CHECK: .verify = True
  print(f".verify = {module_op.verify()}")

  # Iterating a region directly yields its blocks. Both ways of obtaining them
  # must compare equal.
  implicit_regions = list(module_op.regions)
  implicit_blocks = list(implicit_regions[0])
  assert implicit_regions == named_regions
  assert implicit_blocks == named_blocks

  # Likewise, iterating a block directly yields its operations.
  entry = named_blocks[0]
  named_ops = list(entry.operations)
  assert list(entry) == named_ops

  def walk(prefix, parent):
    # Recursively print every region/block/op nested under `parent`.
    for region_idx, region in enumerate(parent.regions):
      print(f"{prefix}REGION {region_idx}:")
      for block_idx, block in enumerate(region):
        print(f"{prefix}  BLOCK {block_idx}:")
        for op_idx, child in enumerate(block):
          print(f"{prefix}    OP {op_idx}: {child}")
          walk(prefix + "      ", child)

  # CHECK: REGION 0:
  # CHECK:   BLOCK 0:
  # CHECK:     OP 0: func
  # CHECK:       REGION 0:
  # CHECK:         BLOCK 0:
  # CHECK:           OP 0: %0 = "custom.addi"
  # CHECK:           OP 1: func.return
  walk("", module_op)
80
81
82# Verify index based traversal of the op/region/block hierarchy.
83# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
@run
def testTraverseOpRegionBlockIndices():
  """Walk the op/region/block hierarchy using len() and integer indexing."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  asm = r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """
  module = Module.parse(asm, ctx)

  def walk(prefix, parent_op):
    # Every collection is accessed by position on purpose: this test covers
    # __len__ and __getitem__ rather than __iter__.
    region_list = parent_op.regions
    for r in range(len(region_list)):
      region = region_list[r]
      print(f"{prefix}REGION {r}:")
      block_list = region.blocks
      for b in range(len(block_list)):
        block = block_list[b]
        print(f"{prefix}  BLOCK {b}:")
        op_list = block.operations
        for o in range(len(op_list)):
          child = op_list[o]
          print(f"{prefix}    OP {o}: {child}")
          print(f"{prefix}    OP {o}: parent {child.operation.parent.name}")
          walk(prefix + "      ", child)

  # CHECK: REGION 0:
  # CHECK:   BLOCK 0:
  # CHECK:     OP 0: func
  # CHECK:     OP 0: parent builtin.module
  # CHECK:       REGION 0:
  # CHECK:         BLOCK 0:
  # CHECK:           OP 0: %0 = "custom.addi"
  # CHECK:           OP 0: parent func.func
  # CHECK:           OP 1: func.return
  # CHECK:           OP 1: parent func.func
  walk("", module.operation)
120
121
122# CHECK-LABEL: TEST: testBlockAndRegionOwners
@run
def testBlockAndRegionOwners():
  """Regions and blocks report the operation that owns them via `.owner`."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  source = r"""
    builtin.module {
      func.func @f() {
        func.return
      }
    }
  """
  module = Module.parse(source, ctx)

  # The top-level module op owns its region and that region's block.
  top = module.operation
  top_region = top.regions[0]
  assert top_region.owner == top
  assert top_region.blocks[0].owner == top

  # The same holds for the nested function.
  func = module.body.operations[0]
  func_region = func.operation.regions[0]
  assert func_region.owner == func
  assert func_region.blocks[0].owner == func
142
143
144# CHECK-LABEL: TEST: testBlockArgumentList
@run
def testBlockArgumentList():
  """Exercise block arguments: iteration, set_type, slicing, concat, types."""
  with Context() as ctx:
    text = r"""
      func.func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
        return
      }
    """
    module = Module.parse(text, ctx)
    entry_block = module.body.operations[0].regions[0].blocks[0]
    assert len(entry_block.arguments) == 3

    def print_arguments(arguments):
      for arg in arguments:
        print(f"Argument {arg.arg_number}, type {arg.type}")

    # Print each argument, then retype argument N to a signless integer of
    # width 8 * (N + 1).
    # CHECK: Argument 0, type i32
    # CHECK: Argument 1, type f64
    # CHECK: Argument 2, type index
    for arg in entry_block.arguments:
      print(f"Argument {arg.arg_number}, type {arg.type}")
      width = 8 * (arg.arg_number + 1)
      arg.set_type(IntegerType.get_signless(width))

    # CHECK: Argument 0, type i8
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    print_arguments(entry_block.arguments)

    # Slicing a block argument list yields another iterable argument list.
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    print_arguments(entry_block.arguments[1:])

    # Slices concatenate: [arg0, arg1] + [arg1, arg2] has four elements.
    # CHECK: Length: 4
    combined = entry_block.arguments[:2] + entry_block.arguments[1:]
    print("Length: ", len(combined))

    # CHECK: Type: i8
    # CHECK: Type: i16
    # CHECK: Type: i24
    for arg_type in entry_block.arguments.types:
      print("Type: ", arg_type)
187
188
189# CHECK-LABEL: TEST: testOperationOperands
@run
def testOperationOperands():
  """Operands of an operation can be counted, iterated, and typed."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func.func @f1(%arg0: i32) {
        %0 = "test.producer"() : () -> i64
        "test.consumer"(%arg0, %0) : (i32, i64) -> ()
        return
      }""")
    body_ops = module.body.operations[0].regions[0].blocks[0].operations
    # Index 1 is "test.consumer", which takes (%arg0, %0).
    operands = body_ops[1].operands
    assert len(operands) == 2
    # CHECK: Operand 0, type i32
    # CHECK: Operand 1, type i64
    for index, value in enumerate(operands):
      print(f"Operand {index}, type {value.type}")
208
209
210
211
212# CHECK-LABEL: TEST: testOperationOperandsSlice
@run
def testOperationOperandsSlice():
  """Operand lists support Python slice syntax, including strides."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func.func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        %3 = "test.producer3"() : () -> i64
        %4 = "test.producer4"() : () -> i64
        "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
        return
      }""")
    entry_block = module.body.operations[0].regions[0].blocks[0]
    operands = entry_block.operations[5].operands
    assert len(operands) == 5
    # Reversing twice must round-trip to the original order.
    for original, round_tripped in zip(operands, operands[::-1][::-1]):
      assert original == round_tripped

    def print_all(values):
      for value in values:
        print(value)

    # CHECK: test.producer0
    # CHECK: test.producer1
    # CHECK: test.producer2
    # CHECK: test.producer3
    # CHECK: test.producer4
    print_all(operands[:])

    # CHECK: test.producer0
    # CHECK: test.producer1
    print_all(operands[0:2])

    # CHECK: test.producer3
    # CHECK: test.producer4
    print_all(operands[3:])

    # CHECK: test.producer0
    # CHECK: test.producer2
    # CHECK: test.producer4
    print_all(operands[::2])

    # [::2] selects operands 0, 2 and 4; taking every other element of that
    # starting at position 1 leaves only operand 2.
    # CHECK: test.producer2
    middle_of_evens = operands[::2][1::2]
    print_all(middle_of_evens)
266
267
268
269
270# CHECK-LABEL: TEST: testOperationOperandsSet
@run
def testOperationOperandsSet():
  """Assigning through an operand list rewires an operation's operands.

  The test exercises both a non-negative and a negative index. After each
  write it prints the consumer's single operand so FileCheck can confirm
  which producer that operand now refers to.
  """
  with Context() as ctx, Location.unknown(ctx):
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func.func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        "test.consumer"(%0) : (i64) -> ()
        return
      }""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    producer1 = entry_block.operations[1]
    producer2 = entry_block.operations[2]
    consumer = entry_block.operations[3]
    assert len(consumer.operands) == 1
    # (The former unused `type = ...` local, which shadowed the builtin
    # `type`, has been removed.)

    # CHECK: test.producer1
    consumer.operands[0] = producer1.result
    print(consumer.operands[0])

    # A negative index counts from the end. With a single operand, -1 is the
    # same slot as 0, so reading index 0 shows the effect of the write.
    # CHECK: test.producer2
    consumer.operands[-1] = producer2.result
    print(consumer.operands[0])
298
299
300
301
302# CHECK-LABEL: TEST: testDetachedOperation
@run
def testDetachedOperation():
  """An operation created with no parent still prints on its own."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    si32 = IntegerType.get_signed(32)
    attrs = {
        "foo": StringAttr.get("foo_value"),
        "bar": StringAttr.get("bar_value"),
    }
    detached = Operation.create(
        "custom.op1", results=[si32, si32], regions=1, attributes=attrs)
    # CHECK: %0:2 = "custom.op1"() ({
    # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
    print(detached)

  # TODO: Check successors once enough infra exists to do it properly.
322
323
324# CHECK-LABEL: TEST: testOperationInsertionPoint
@run
def testOperationInsertionPoint():
  """InsertionPoint.insert places detached ops and rejects attached ones."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """, ctx)

  with Location.unknown(ctx):
    op1 = Operation.create("custom.op1")
    op2 = Operation.create("custom.op2")

    entry_block = module.body.operations[0].regions[0].blocks[0]
    ip = InsertionPoint.at_block_begin(entry_block)
    # Insert in order; the CHECKs below expect op1, op2, then the original ops.
    for new_op in (op1, op2):
      ip.insert(new_op)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK: "custom.op2"()
    # CHECK: %0 = "custom.addi"
    print(module)

  # An op that already has a parent cannot be inserted a second time.
  try:
    ip.insert(op1)
  except ValueError:
    pass
  else:
    assert False, "expected insert of attached op to raise"
360
361
362# CHECK-LABEL: TEST: testOperationWithRegion
@run
def testOperationWithRegion():
  """Build an op with a populated region, then splice it into a module."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    si32 = IntegerType.get_signed(32)
    outer = Operation.create("custom.op1", regions=1)
    # Appending with two types creates a block taking two si32 arguments.
    body = outer.regions[0].blocks.append(si32, si32)
    # CHECK: "custom.op1"() ({
    # CHECK: ^bb0(%arg0: si32, %arg1: si32):
    # CHECK:   "custom.terminator"() : () -> ()
    # CHECK: }) : () -> ()
    terminator = Operation.create("custom.terminator")
    InsertionPoint(body).insert(terminator)
    print(outer)

    # Now move the whole `outer` op into a function of a parsed module.
    # TODO: Verify lifetime hazard by nulling out the new owning module and
    # accessing `outer`.
    # TODO: Also verify accessing the terminator once both parents are nulled
    # out.
    module = Module.parse(r"""
      func.func @f1(%arg0: i32) -> i32 {
        %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
        return %1 : i32
      }
    """)
    entry_block = module.body.operations[0].regions[0].blocks[0]
    InsertionPoint.at_block_begin(entry_block).insert(outer)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK:   "custom.terminator"
    # CHECK: %0 = "custom.addi"
    print(module)
400
401
402# CHECK-LABEL: TEST: testOperationResultList
@run
def testOperationResultList():
  """Result lists: length, iteration, `.types`, and out-of-range indexing."""
  ctx = Context()
  module = Module.parse(
      r"""
    func.func @f1() {
      %0:3 = call @f2() : () -> (i32, f64, index)
      return
    }
    func.func private @f2() -> (i32, f64, index)
  """, ctx)
  call_op = module.body.operations[0].regions[0].blocks[0].operations[0]
  results = call_op.results
  assert len(results) == 3

  # CHECK: Result 0, type i32
  # CHECK: Result 1, type f64
  # CHECK: Result 2, type index
  for result in results:
    print(f"Result {result.result_number}, type {result.type}")

  # CHECK: Result type i32
  # CHECK: Result type f64
  # CHECK: Result type index
  for result_type in results.types:
    print(f"Result type {result_type}")

  # One past either end of a three-element list must raise IndexError.
  expect_index_error(lambda: results[3])
  expect_index_error(lambda: results[-4])
432
433
434# CHECK-LABEL: TEST: testOperationResultListSlice
@run
def testOperationResultListSlice():
  """Result lists support Python slice syntax, including negative steps."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
      func.func @f1() {
        "some.op"() : () -> (i1, i2, i3, i4, i5)
        return
      }
    """)
    entry_block = module.body.operations[0].regions[0].blocks[0]
    results = entry_block.operations[0].results
    assert len(results) == 5

    # Reversing twice must give back the same values in the same order.
    for expected, actual in zip(results, results[::-1][::-1]):
      assert expected == actual
      assert expected.result_number == actual.result_number

    def describe(values):
      for value in values:
        print(f"Result {value.result_number}, type {value.type}")

    # CHECK: Result 0, type i1
    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    # CHECK: Result 4, type i5
    describe(results[:])

    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    describe(results[1:4])

    # CHECK: Result 1, type i2
    # CHECK: Result 3, type i4
    describe(results[1::2])

    # Step backwards by two from index -2 (i.e. 3), stopping before index 0.
    # CHECK: Result 3, type i4
    # CHECK: Result 1, type i2
    describe(results[-2:0:-2])
481
482
483# CHECK-LABEL: TEST: testOperationAttributes
@run
def testOperationAttributes():
  """Look up attributes by name, iterate them, and check lookup failures."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    "some.op"() { some.attribute = 1 : i8,
                  other.attribute = 3.0,
                  dependent = "text" } : () -> ()
  """, ctx)
  attributes = module.body.operations[0].attributes
  assert len(attributes) == 3

  # Downcast the generic attributes to their concrete Python classes.
  int_attr = IntegerAttr(attributes["some.attribute"])
  float_attr = FloatAttr(attributes["other.attribute"])
  str_attr = StringAttr(attributes["dependent"])
  # CHECK: Attribute type i8, value 1
  print(f"Attribute type {int_attr.type}, value {int_attr.value}")
  # CHECK: Attribute type f64, value 3.0
  print(f"Attribute type {float_attr.type}, value {float_attr.value}")
  # CHECK: Attribute value text
  print(f"Attribute value {str_attr.value}")

  # We don't know in which order the attributes are stored.
  # CHECK-DAG: NamedAttribute(dependent="text")
  # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
  # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
  for named_attr in attributes:
    print(str(named_attr))

  # An unknown name raises KeyError; an out-of-range position raises IndexError.
  try:
    attributes["does_not_exist"]
  except KeyError:
    pass
  else:
    assert False, "expected KeyError on accessing a non-existent attribute"

  try:
    attributes[42]
  except IndexError:
    pass
  else:
    assert False, "expected IndexError on accessing an out-of-bounds attribute"
527
528
529
530
531# CHECK-LABEL: TEST: testOperationPrint
@run
def testOperationPrint():
  """Operation.print to stdout, text/binary streams, and with options."""
  ctx = Context()
  module = Module.parse(
      r"""
    func.func @f1(%arg0: i32) -> i32 {
      %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
      return %arg0 : i32
    }
  """, ctx)
  module_op = module.operation

  # With no `file`, output goes to stdout.
  # CHECK: return %arg0 : i32
  module_op.print()

  # A text stream receives `str` output.
  text_stream = io.StringIO()
  # CHECK: <class 'str'>
  # CHECK: return %arg0 : i32
  module_op.print(file=text_stream)
  text = text_stream.getvalue()
  print(type(text))
  print(text)

  # With binary=True the stream receives `bytes` output.
  byte_stream = io.BytesIO()
  # CHECK: <class 'bytes'>
  # CHECK: return %arg0 : i32
  module_op.print(file=byte_stream, binary=True)
  data = byte_stream.getvalue()
  print(type(data))
  print(data)

  # Printing options: with a limit of 2, the 4-element constant is elided.
  # Debug locations, the generic op form, and a local SSA scope are also
  # enabled.
  # CHECK: value = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<4xi32>
  # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7
  module_op.print(
      large_elements_limit=2,
      enable_debug_info=True,
      pretty_debug_info=True,
      print_generic_op_form=True,
      use_local_scope=True)
574
575
576
577
578# CHECK-LABEL: TEST: testKnownOpView
@run
def testKnownOpView():
  """Ops resolve to a registered OpView subclass when one exists."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      %1 = "custom.f32"() : () -> f32
      %2 = "custom.f32"() : () -> f32
      %3 = arith.addf %1, %2 : f32
    """)
    print(module)
    body_ops = module.body.operations

    # arith.addf maps to the generated arith view class, which exposes `lhs`.
    addf = body_ops[2]
    # CHECK: <mlir.dialects._arith_ops_gen._AddFOp object
    print(repr(addf))
    # CHECK: "custom.f32"()
    print(addf.lhs)

    # An unregistered op falls back to the generic OpView.
    first_lookup = body_ops[0]
    # CHECK: OpView object
    print(repr(first_lookup))

    # Look it up again so the negative-cache path is exercised.
    second_lookup = body_ops[0]
    # CHECK: OpView object
    print(repr(second_lookup))
607
608
609# CHECK-LABEL: TEST: testSingleResultProperty
@run
def testSingleResultProperty():
  """`.result` is only valid on operations with exactly one result."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
      "custom.no_result"() : () -> ()
      %0:2 = "custom.two_result"() : () -> (f32, f32)
      %1 = "custom.one_result"() : () -> f32
    """)
    print(module)

  # Ops 0 (no results) and 1 (two results) must both reject `.result`.
  # CHECK: Cannot call .result on operation custom.no_result which has 0 results
  # CHECK: Cannot call .result on operation custom.two_result which has 2 results
  for bad_index in (0, 1):
    try:
      module.body.operations[bad_index].result
    except ValueError as e:
      print(e)
    else:
      assert False, "Expected exception"

  # CHECK: %1 = "custom.one_result"() : () -> f32
  print(module.body.operations[2])
639
640
def create_invalid_operation():
  """Return a detached `builtin.module` op that does not verify.

  The op is created with two regions, and only the first one receives a
  block, which makes it invalid. Callers use it to check that printing falls
  back to the generic form instead of failing.
  """
  # Callers invoke this inside `with Location.unknown(ctx)`, which supplies
  # the location and context used by Operation.create.
  invalid = Operation.create("builtin.module", regions=2)
  first_region = invalid.regions[0]
  first_region.blocks.append()
  return invalid
647
648# CHECK-LABEL: TEST: testInvalidOperationStrSoftFails
@run
def testInvalidOperationStrSoftFails():
  """str() of an unverifiable op falls back to the generic printer."""
  ctx = Context()
  with Location.unknown(ctx):
    bad_op = create_invalid_operation()
    # Printing must fall back to the generic form rather than fail.
    # CHECK: // Verification failed, printing generic form
    # CHECK: "builtin.module"() ({
    # CHECK: }) : () -> ()
    print(bad_op)
    # CHECK: .verify = False
    is_valid = bad_op.operation.verify()
    print(f".verify = {is_valid}")
661
662
663# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
@run
def testInvalidModuleStrSoftFails():
  """A module containing an invalid op also prints via the generic fallback."""
  ctx = Context()
  with Location.unknown(ctx):
    module = Module.create()
    # The active insertion point puts the new op into `module.body`. The name
    # keeps a Python reference alive for the rest of the test.
    with InsertionPoint(module.body):
      nested_invalid = create_invalid_operation()
    # CHECK: // Verification failed, printing generic form
    print(module)
674
675
676# CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails
@run
def testInvalidOperationGetAsmBinarySoftFails():
  """get_asm(binary=True) on an invalid op also uses the generic fallback."""
  ctx = Context()
  with Location.unknown(ctx):
    broken = create_invalid_operation()
    asm_bytes = broken.get_asm(binary=True)
    # The warning header appears at the start of the returned bytes.
    # CHECK: b'// Verification failed, printing generic form\n
    print(asm_bytes)
685
686
687# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
@run
def testCreateWithInvalidAttributes():
  """Operation.create must reject malformed attribute dictionaries.

  Every case below is expected to raise. The exception message is printed so
  the CHECK lines can pin the diagnostic text. A case that unexpectedly
  succeeds fails immediately with a descriptive assertion, matching the other
  negative tests in this file. Previously it fell through silently to the
  next case.
  """
  ctx = Context()
  with Location.unknown(ctx):
    invalid_attribute_dicts = [
        # Keys must be strings.
        {None: StringAttr.get("name")},
        {42: StringAttr.get("name")},
        # Values must be attributes; a Context is not one.
        {"some_key": ctx},
        # `None` as a value gets its own, more specific diagnostic.
        {"some_key": None},
    ]
    # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
    # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
    # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
    # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
    for case_index, attributes in enumerate(invalid_attribute_dicts):
      try:
        Operation.create("builtin.module", attributes=attributes)
      except Exception as e:
        print(e)
      else:
        assert False, (
            "expected Operation.create to reject invalid attributes "
            f"(case {case_index})")
714
715
716# CHECK-LABEL: TEST: testOperationName
@run
def testOperationName():
  """Operation.name reports the op name for each top-level op, in order."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
    %0 = "custom.op1"() : () -> f32
    %1 = "custom.op2"() : () -> i32
    %2 = "custom.op1"() : () -> f32
  """, ctx)

  # CHECK: custom.op1
  # CHECK: custom.op2
  # CHECK: custom.op1
  names = [op_view.operation.name for op_view in module.body.operations]
  print("\n".join(names))
733
734
735# CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
  """An Operation round-trips through its C API capsule to the same object."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    original = Operation.create("custom.op1").operation
    capsule = original._CAPIPtr
    assert '"mlir.ir.Operation._CAPIPtr"' in repr(capsule)
    # Re-wrapping the capsule must return the identical Python object.
    rewrapped = Operation._CAPICreate(capsule)
    assert rewrapped is original
746
747
748# CHECK-LABEL: TEST: testOperationErase
@run
def testOperationErase():
  """Erasing an op removes it from the module; creation still works after."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    module = Module.create()
    with InsertionPoint(module.body):
      doomed = Operation.create("custom.op1")

      # CHECK: "custom.op1"
      print(module)

      doomed.operation.erase()

      # CHECK-NOT: "custom.op1"
      print(module)

      # Creating another op at the same insertion point must still succeed.
      Operation.create("custom.op2")
768
769
770# CHECK-LABEL: TEST: testOperationClone
@run
def testOperationClone():
  """A clone survives erasure of the op it was cloned from."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    module = Module.create()
    with InsertionPoint(module.body):
      source_op = Operation.create("custom.op1")

      # CHECK: "custom.op1"
      print(module)

      # Keep a reference to the clone; only the source op is erased.
      duplicate = source_op.operation.clone()
      source_op.operation.erase()

      # A "custom.op1" is still present after the source is erased; presumably
      # this is the clone, placed at the active insertion point.
      # CHECK: "custom.op1"
      print(module)
788
789
790# CHECK-LABEL: TEST: testOperationLoc
@run
def testOperationLoc():
  """The `loc` passed to Operation.create is reported by `.location`."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with ctx:
    named_loc = Location.name("loc")
    created = Operation.create("custom.op", loc=named_loc)
    # Both the OpView and its underlying Operation expose that location.
    for view in (created, created.operation):
      assert view.location == named_loc
800
801
802# CHECK-LABEL: TEST: testModuleMerge
@run
def testModuleMerge():
  """move_before/move_after relocate ops from one module into another."""
  with Context():
    dest = Module.parse("func.func private @foo()")
    src = Module.parse("""
      func.func private @bar()
      func.func private @qux()
    """)
    # Grab every handle before moving anything.
    anchor = dest.body.operations[0]
    bar, qux = src.body.operations[0], src.body.operations[1]
    bar.move_before(anchor)
    qux.move_after(anchor)

    # CHECK: module
    # CHECK: func private @bar
    # CHECK: func private @foo
    # CHECK: func private @qux
    print(dest)

    # Both ops have left, so the source module is now empty.
    # CHECK: module {
    # CHECK-NEXT: }
    print(src)
826
827
828# CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock
@run
def testAppendMoveFromAnotherBlock():
  """Block.append moves an op that currently lives in a different block."""
  with Context():
    donor = Module.parse("func.func private @foo()")
    receiver = Module.parse("func.func private @bar()")
    moved = donor.body.operations[0]
    receiver.body.append(moved)

    # CHECK: module
    # CHECK: func private @bar
    # CHECK: func private @foo
    print(receiver)

    # The donor is left with an empty body.
    # CHECK: module {
    # CHECK-NEXT: }
    print(donor)
845
846
847# CHECK-LABEL: TEST: testDetachFromParent
@run
def testDetachFromParent():
  """detach_from_parent works once and rejects a second detach."""
  with Context():
    module = Module.parse("func.func private @foo()")
    detached = module.body.operations[0].detach_from_parent()

    # Detaching an already-detached op must raise ValueError saying it has
    # no parent; any other ValueError is re-raised.
    try:
      detached.detach_from_parent()
    except ValueError as e:
      if "has no parent" not in str(e):
        raise
    else:
      assert False, "expected ValueError when detaching a detached operation"

    # CHECK-NOT: func private @foo
    print(module)
864
865
866# CHECK-LABEL: TEST: testOperationHash
@run
def testOperationHash():
  """An OpView hashes the same as its underlying Operation."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with ctx, Location.unknown():
    view = Operation.create("custom.op1")
    underlying = view.operation
    assert hash(view) == hash(underlying)
874