1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4import io
5import itertools
6from mlir.ir import *
7
8
9def run(f):
10  print("\nTEST:", f.__name__)
11  f()
12  gc.collect()
13  assert Context._get_live_count() == 0
14  return f
15
16
17# Verify iterator based traversal of the op/region/block hierarchy.
18# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
19@run
20def testTraverseOpRegionBlockIterators():
21  ctx = Context()
22  ctx.allow_unregistered_dialects = True
23  module = Module.parse(
24      r"""
25    func @f1(%arg0: i32) -> i32 {
26      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
27      return %1 : i32
28    }
29  """, ctx)
30  op = module.operation
31  assert op.context is ctx
32  # Get the block using iterators off of the named collections.
33  regions = list(op.regions)
34  blocks = list(regions[0].blocks)
35  # CHECK: MODULE REGIONS=1 BLOCKS=1
36  print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")
37
38  # Should verify.
39  # CHECK: .verify = True
40  print(f".verify = {module.operation.verify()}")
41
42  # Get the regions and blocks from the default collections.
43  default_regions = list(op.regions)
44  default_blocks = list(default_regions[0])
45  # They should compare equal regardless of how obtained.
46  assert default_regions == regions
47  assert default_blocks == blocks
48
49  # Should be able to get the operations from either the named collection
50  # or the block.
51  operations = list(blocks[0].operations)
52  default_operations = list(blocks[0])
53  assert default_operations == operations
54
55  def walk_operations(indent, op):
56    for i, region in enumerate(op.regions):
57      print(f"{indent}REGION {i}:")
58      for j, block in enumerate(region):
59        print(f"{indent}  BLOCK {j}:")
60        for k, child_op in enumerate(block):
61          print(f"{indent}    OP {k}: {child_op}")
62          walk_operations(indent + "      ", child_op)
63
64  # CHECK: REGION 0:
65  # CHECK:   BLOCK 0:
66  # CHECK:     OP 0: func
67  # CHECK:       REGION 0:
68  # CHECK:         BLOCK 0:
69  # CHECK:           OP 0: %0 = "custom.addi"
70  # CHECK:           OP 1: return
71  walk_operations("", op)
72
73
74# Verify index based traversal of the op/region/block hierarchy.
75# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
76@run
77def testTraverseOpRegionBlockIndices():
78  ctx = Context()
79  ctx.allow_unregistered_dialects = True
80  module = Module.parse(
81      r"""
82    func @f1(%arg0: i32) -> i32 {
83      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
84      return %1 : i32
85    }
86  """, ctx)
87
88  def walk_operations(indent, op):
89    for i in range(len(op.regions)):
90      region = op.regions[i]
91      print(f"{indent}REGION {i}:")
92      for j in range(len(region.blocks)):
93        block = region.blocks[j]
94        print(f"{indent}  BLOCK {j}:")
95        for k in range(len(block.operations)):
96          child_op = block.operations[k]
97          print(f"{indent}    OP {k}: {child_op}")
98          print(f"{indent}    OP {k}: parent {child_op.operation.parent.name}")
99          walk_operations(indent + "      ", child_op)
100
101  # CHECK: REGION 0:
102  # CHECK:   BLOCK 0:
103  # CHECK:     OP 0: func
104  # CHECK:     OP 0: parent builtin.module
105  # CHECK:       REGION 0:
106  # CHECK:         BLOCK 0:
107  # CHECK:           OP 0: %0 = "custom.addi"
108  # CHECK:           OP 0: parent builtin.func
109  # CHECK:           OP 1: return
110  # CHECK:           OP 1: parent builtin.func
111  walk_operations("", module.operation)
112
113
114# CHECK-LABEL: TEST: testBlockAndRegionOwners
115@run
116def testBlockAndRegionOwners():
117  ctx = Context()
118  ctx.allow_unregistered_dialects = True
119  module = Module.parse(
120      r"""
121    builtin.module {
122      builtin.func @f() {
123        std.return
124      }
125    }
126  """, ctx)
127
128  assert module.operation.regions[0].owner == module.operation
129  assert module.operation.regions[0].blocks[0].owner == module.operation
130
131  func = module.body.operations[0]
132  assert func.operation.regions[0].owner == func
133  assert func.operation.regions[0].blocks[0].owner == func
134
135
136# CHECK-LABEL: TEST: testBlockArgumentList
137@run
138def testBlockArgumentList():
139  with Context() as ctx:
140    module = Module.parse(
141        r"""
142      func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
143        return
144      }
145    """, ctx)
146    func = module.body.operations[0]
147    entry_block = func.regions[0].blocks[0]
148    assert len(entry_block.arguments) == 3
149    # CHECK: Argument 0, type i32
150    # CHECK: Argument 1, type f64
151    # CHECK: Argument 2, type index
152    for arg in entry_block.arguments:
153      print(f"Argument {arg.arg_number}, type {arg.type}")
154      new_type = IntegerType.get_signless(8 * (arg.arg_number + 1))
155      arg.set_type(new_type)
156
157    # CHECK: Argument 0, type i8
158    # CHECK: Argument 1, type i16
159    # CHECK: Argument 2, type i24
160    for arg in entry_block.arguments:
161      print(f"Argument {arg.arg_number}, type {arg.type}")
162
163    # Check that slicing works for block argument lists.
164    # CHECK: Argument 1, type i16
165    # CHECK: Argument 2, type i24
166    for arg in entry_block.arguments[1:]:
167      print(f"Argument {arg.arg_number}, type {arg.type}")
168
169    # Check that we can concatenate slices of argument lists.
170    # CHECK: Length: 4
171    print("Length: ",
172          len(entry_block.arguments[:2] + entry_block.arguments[1:]))
173
174    # CHECK: Type: i8
175    # CHECK: Type: i16
176    # CHECK: Type: i24
177    for t in entry_block.arguments.types:
178      print("Type: ", t)
179
180
181# CHECK-LABEL: TEST: testOperationOperands
182@run
183def testOperationOperands():
184  with Context() as ctx:
185    ctx.allow_unregistered_dialects = True
186    module = Module.parse(r"""
187      func @f1(%arg0: i32) {
188        %0 = "test.producer"() : () -> i64
189        "test.consumer"(%arg0, %0) : (i32, i64) -> ()
190        return
191      }""")
192    func = module.body.operations[0]
193    entry_block = func.regions[0].blocks[0]
194    consumer = entry_block.operations[1]
195    assert len(consumer.operands) == 2
196    # CHECK: Operand 0, type i32
197    # CHECK: Operand 1, type i64
198    for i, operand in enumerate(consumer.operands):
199      print(f"Operand {i}, type {operand.type}")
200
201
202
203
204# CHECK-LABEL: TEST: testOperationOperandsSlice
205@run
206def testOperationOperandsSlice():
207  with Context() as ctx:
208    ctx.allow_unregistered_dialects = True
209    module = Module.parse(r"""
210      func @f1() {
211        %0 = "test.producer0"() : () -> i64
212        %1 = "test.producer1"() : () -> i64
213        %2 = "test.producer2"() : () -> i64
214        %3 = "test.producer3"() : () -> i64
215        %4 = "test.producer4"() : () -> i64
216        "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
217        return
218      }""")
219    func = module.body.operations[0]
220    entry_block = func.regions[0].blocks[0]
221    consumer = entry_block.operations[5]
222    assert len(consumer.operands) == 5
223    for left, right in zip(consumer.operands, consumer.operands[::-1][::-1]):
224      assert left == right
225
226    # CHECK: test.producer0
227    # CHECK: test.producer1
228    # CHECK: test.producer2
229    # CHECK: test.producer3
230    # CHECK: test.producer4
231    full_slice = consumer.operands[:]
232    for operand in full_slice:
233      print(operand)
234
235    # CHECK: test.producer0
236    # CHECK: test.producer1
237    first_two = consumer.operands[0:2]
238    for operand in first_two:
239      print(operand)
240
241    # CHECK: test.producer3
242    # CHECK: test.producer4
243    last_two = consumer.operands[3:]
244    for operand in last_two:
245      print(operand)
246
247    # CHECK: test.producer0
248    # CHECK: test.producer2
249    # CHECK: test.producer4
250    even = consumer.operands[::2]
251    for operand in even:
252      print(operand)
253
254    # CHECK: test.producer2
255    fourth = consumer.operands[::2][1::2]
256    for operand in fourth:
257      print(operand)
258
259
260
261
262# CHECK-LABEL: TEST: testOperationOperandsSet
263@run
264def testOperationOperandsSet():
265  with Context() as ctx, Location.unknown(ctx):
266    ctx.allow_unregistered_dialects = True
267    module = Module.parse(r"""
268      func @f1() {
269        %0 = "test.producer0"() : () -> i64
270        %1 = "test.producer1"() : () -> i64
271        %2 = "test.producer2"() : () -> i64
272        "test.consumer"(%0) : (i64) -> ()
273        return
274      }""")
275    func = module.body.operations[0]
276    entry_block = func.regions[0].blocks[0]
277    producer1 = entry_block.operations[1]
278    producer2 = entry_block.operations[2]
279    consumer = entry_block.operations[3]
280    assert len(consumer.operands) == 1
281    type = consumer.operands[0].type
282
283    # CHECK: test.producer1
284    consumer.operands[0] = producer1.result
285    print(consumer.operands[0])
286
287    # CHECK: test.producer2
288    consumer.operands[-1] = producer2.result
289    print(consumer.operands[0])
290
291
292
293
294# CHECK-LABEL: TEST: testDetachedOperation
295@run
296def testDetachedOperation():
297  ctx = Context()
298  ctx.allow_unregistered_dialects = True
299  with Location.unknown(ctx):
300    i32 = IntegerType.get_signed(32)
301    op1 = Operation.create(
302        "custom.op1",
303        results=[i32, i32],
304        regions=1,
305        attributes={
306            "foo": StringAttr.get("foo_value"),
307            "bar": StringAttr.get("bar_value"),
308        })
309    # CHECK: %0:2 = "custom.op1"() ( {
310    # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
311    print(op1)
312
313  # TODO: Check successors once enough infra exists to do it properly.
314
315
316# CHECK-LABEL: TEST: testOperationInsertionPoint
317@run
318def testOperationInsertionPoint():
319  ctx = Context()
320  ctx.allow_unregistered_dialects = True
321  module = Module.parse(
322      r"""
323    func @f1(%arg0: i32) -> i32 {
324      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
325      return %1 : i32
326    }
327  """, ctx)
328
329  # Create test op.
330  with Location.unknown(ctx):
331    op1 = Operation.create("custom.op1")
332    op2 = Operation.create("custom.op2")
333
334    func = module.body.operations[0]
335    entry_block = func.regions[0].blocks[0]
336    ip = InsertionPoint.at_block_begin(entry_block)
337    ip.insert(op1)
338    ip.insert(op2)
339    # CHECK: func @f1
340    # CHECK: "custom.op1"()
341    # CHECK: "custom.op2"()
342    # CHECK: %0 = "custom.addi"
343    print(module)
344
345  # Trying to add a previously added op should raise.
346  try:
347    ip.insert(op1)
348  except ValueError:
349    pass
350  else:
351    assert False, "expected insert of attached op to raise"
352
353
354# CHECK-LABEL: TEST: testOperationWithRegion
355@run
356def testOperationWithRegion():
357  ctx = Context()
358  ctx.allow_unregistered_dialects = True
359  with Location.unknown(ctx):
360    i32 = IntegerType.get_signed(32)
361    op1 = Operation.create("custom.op1", regions=1)
362    block = op1.regions[0].blocks.append(i32, i32)
363    # CHECK: "custom.op1"() ( {
364    # CHECK: ^bb0(%arg0: si32, %arg1: si32):  // no predecessors
365    # CHECK:   "custom.terminator"() : () -> ()
366    # CHECK: }) : () -> ()
367    terminator = Operation.create("custom.terminator")
368    ip = InsertionPoint(block)
369    ip.insert(terminator)
370    print(op1)
371
372    # Now add the whole operation to another op.
373    # TODO: Verify lifetime hazard by nulling out the new owning module and
374    # accessing op1.
375    # TODO: Also verify accessing the terminator once both parents are nulled
376    # out.
377    module = Module.parse(r"""
378      func @f1(%arg0: i32) -> i32 {
379        %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
380        return %1 : i32
381      }
382    """)
383    func = module.body.operations[0]
384    entry_block = func.regions[0].blocks[0]
385    ip = InsertionPoint.at_block_begin(entry_block)
386    ip.insert(op1)
387    # CHECK: func @f1
388    # CHECK: "custom.op1"()
389    # CHECK:   "custom.terminator"
390    # CHECK: %0 = "custom.addi"
391    print(module)
392
393
394# CHECK-LABEL: TEST: testOperationResultList
395@run
396def testOperationResultList():
397  ctx = Context()
398  module = Module.parse(
399      r"""
400    func @f1() {
401      %0:3 = call @f2() : () -> (i32, f64, index)
402      return
403    }
404    func private @f2() -> (i32, f64, index)
405  """, ctx)
406  caller = module.body.operations[0]
407  call = caller.regions[0].blocks[0].operations[0]
408  assert len(call.results) == 3
409  # CHECK: Result 0, type i32
410  # CHECK: Result 1, type f64
411  # CHECK: Result 2, type index
412  for res in call.results:
413    print(f"Result {res.result_number}, type {res.type}")
414
415  # CHECK: Result type i32
416  # CHECK: Result type f64
417  # CHECK: Result type index
418  for t in call.results.types:
419    print(f"Result type {t}")
420
421
422
423
424# CHECK-LABEL: TEST: testOperationResultListSlice
425@run
426def testOperationResultListSlice():
427  with Context() as ctx:
428    ctx.allow_unregistered_dialects = True
429    module = Module.parse(r"""
430      func @f1() {
431        "some.op"() : () -> (i1, i2, i3, i4, i5)
432        return
433      }
434    """)
435    func = module.body.operations[0]
436    entry_block = func.regions[0].blocks[0]
437    producer = entry_block.operations[0]
438
439    assert len(producer.results) == 5
440    for left, right in zip(producer.results, producer.results[::-1][::-1]):
441      assert left == right
442      assert left.result_number == right.result_number
443
444    # CHECK: Result 0, type i1
445    # CHECK: Result 1, type i2
446    # CHECK: Result 2, type i3
447    # CHECK: Result 3, type i4
448    # CHECK: Result 4, type i5
449    full_slice = producer.results[:]
450    for res in full_slice:
451      print(f"Result {res.result_number}, type {res.type}")
452
453    # CHECK: Result 1, type i2
454    # CHECK: Result 2, type i3
455    # CHECK: Result 3, type i4
456    middle = producer.results[1:4]
457    for res in middle:
458      print(f"Result {res.result_number}, type {res.type}")
459
460    # CHECK: Result 1, type i2
461    # CHECK: Result 3, type i4
462    odd = producer.results[1::2]
463    for res in odd:
464      print(f"Result {res.result_number}, type {res.type}")
465
466    # CHECK: Result 3, type i4
467    # CHECK: Result 1, type i2
468    inverted_middle = producer.results[-2:0:-2]
469    for res in inverted_middle:
470      print(f"Result {res.result_number}, type {res.type}")
471
472
473
474
475# CHECK-LABEL: TEST: testOperationAttributes
476@run
477def testOperationAttributes():
478  ctx = Context()
479  ctx.allow_unregistered_dialects = True
480  module = Module.parse(
481      r"""
482    "some.op"() { some.attribute = 1 : i8,
483                  other.attribute = 3.0,
484                  dependent = "text" } : () -> ()
485  """, ctx)
486  op = module.body.operations[0]
487  assert len(op.attributes) == 3
488  iattr = IntegerAttr(op.attributes["some.attribute"])
489  fattr = FloatAttr(op.attributes["other.attribute"])
490  sattr = StringAttr(op.attributes["dependent"])
491  # CHECK: Attribute type i8, value 1
492  print(f"Attribute type {iattr.type}, value {iattr.value}")
493  # CHECK: Attribute type f64, value 3.0
494  print(f"Attribute type {fattr.type}, value {fattr.value}")
495  # CHECK: Attribute value text
496  print(f"Attribute value {sattr.value}")
497
498  # We don't know in which order the attributes are stored.
499  # CHECK-DAG: NamedAttribute(dependent="text")
500  # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
501  # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
502  for attr in op.attributes:
503    print(str(attr))
504
505  # Check that exceptions are raised as expected.
506  try:
507    op.attributes["does_not_exist"]
508  except KeyError:
509    pass
510  else:
511    assert False, "expected KeyError on accessing a non-existent attribute"
512
513  try:
514    op.attributes[42]
515  except IndexError:
516    pass
517  else:
518    assert False, "expected IndexError on accessing an out-of-bounds attribute"
519
520
521
522
523# CHECK-LABEL: TEST: testOperationPrint
524@run
525def testOperationPrint():
526  ctx = Context()
527  module = Module.parse(
528      r"""
529    func @f1(%arg0: i32) -> i32 {
530      %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
531      return %arg0 : i32
532    }
533  """, ctx)
534
535  # Test print to stdout.
536  # CHECK: return %arg0 : i32
537  module.operation.print()
538
539  # Test print to text file.
540  f = io.StringIO()
541  # CHECK: <class 'str'>
542  # CHECK: return %arg0 : i32
543  module.operation.print(file=f)
544  str_value = f.getvalue()
545  print(str_value.__class__)
546  print(f.getvalue())
547
548  # Test print to binary file.
549  f = io.BytesIO()
550  # CHECK: <class 'bytes'>
551  # CHECK: return %arg0 : i32
552  module.operation.print(file=f, binary=True)
553  bytes_value = f.getvalue()
554  print(bytes_value.__class__)
555  print(bytes_value)
556
557  # Test get_asm with options.
558  # CHECK: value = opaque<"_", "0xDEADBEEF"> : tensor<4xi32>
559  # CHECK: "std.return"(%arg0) : (i32) -> () -:4:7
560  module.operation.print(
561      large_elements_limit=2,
562      enable_debug_info=True,
563      pretty_debug_info=True,
564      print_generic_op_form=True,
565      use_local_scope=True)
566
567
568
569
570# CHECK-LABEL: TEST: testKnownOpView
571@run
572def testKnownOpView():
573  with Context(), Location.unknown():
574    Context.current.allow_unregistered_dialects = True
575    module = Module.parse(r"""
576      %1 = "custom.f32"() : () -> f32
577      %2 = "custom.f32"() : () -> f32
578      %3 = arith.addf %1, %2 : f32
579    """)
580    print(module)
581
582    # addf should map to a known OpView class in the std dialect.
583    # We know the OpView for it defines an 'lhs' attribute.
584    addf = module.body.operations[2]
585    # CHECK: <mlir.dialects._arith_ops_gen._AddFOp object
586    print(repr(addf))
587    # CHECK: "custom.f32"()
588    print(addf.lhs)
589
590    # One of the custom ops should resolve to the default OpView.
591    custom = module.body.operations[0]
592    # CHECK: OpView object
593    print(repr(custom))
594
595    # Check again to make sure negative caching works.
596    custom = module.body.operations[0]
597    # CHECK: OpView object
598    print(repr(custom))
599
600
601# CHECK-LABEL: TEST: testSingleResultProperty
602@run
603def testSingleResultProperty():
604  with Context(), Location.unknown():
605    Context.current.allow_unregistered_dialects = True
606    module = Module.parse(r"""
607      "custom.no_result"() : () -> ()
608      %0:2 = "custom.two_result"() : () -> (f32, f32)
609      %1 = "custom.one_result"() : () -> f32
610    """)
611    print(module)
612
613  try:
614    module.body.operations[0].result
615  except ValueError as e:
616    # CHECK: Cannot call .result on operation custom.no_result which has 0 results
617    print(e)
618  else:
619    assert False, "Expected exception"
620
621  try:
622    module.body.operations[1].result
623  except ValueError as e:
624    # CHECK: Cannot call .result on operation custom.two_result which has 2 results
625    print(e)
626  else:
627    assert False, "Expected exception"
628
629  # CHECK: %1 = "custom.one_result"() : () -> f32
630  print(module.body.operations[2])
631
632
633def create_invalid_operation():
634  # This module has two region and is invalid verify that we fallback
635  # to the generic printer for safety.
636  op = Operation.create("builtin.module", regions=2)
637  op.regions[0].blocks.append()
638  return op
639
640# CHECK-LABEL: TEST: testInvalidOperationStrSoftFails
641@run
642def testInvalidOperationStrSoftFails():
643  ctx = Context()
644  with Location.unknown(ctx):
645    invalid_op = create_invalid_operation()
646    # Verify that we fallback to the generic printer for safety.
647    # CHECK: // Verification failed, printing generic form
648    # CHECK: "builtin.module"() ( {
649    # CHECK: }) : () -> ()
650    print(invalid_op)
651    # CHECK: .verify = False
652    print(f".verify = {invalid_op.operation.verify()}")
653
654
655# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
656@run
657def testInvalidModuleStrSoftFails():
658  ctx = Context()
659  with Location.unknown(ctx):
660    module = Module.create()
661    with InsertionPoint(module.body):
662      invalid_op = create_invalid_operation()
663    # Verify that we fallback to the generic printer for safety.
664    # CHECK: // Verification failed, printing generic form
665    print(module)
666
667
668# CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails
669@run
670def testInvalidOperationGetAsmBinarySoftFails():
671  ctx = Context()
672  with Location.unknown(ctx):
673    invalid_op = create_invalid_operation()
674    # Verify that we fallback to the generic printer for safety.
675    # CHECK: b'// Verification failed, printing generic form\n
676    print(invalid_op.get_asm(binary=True))
677
678
679# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
680@run
681def testCreateWithInvalidAttributes():
682  ctx = Context()
683  with Location.unknown(ctx):
684    try:
685      Operation.create(
686          "builtin.module", attributes={None: StringAttr.get("name")})
687    except Exception as e:
688      # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
689      print(e)
690    try:
691      Operation.create(
692          "builtin.module", attributes={42: StringAttr.get("name")})
693    except Exception as e:
694      # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
695      print(e)
696    try:
697      Operation.create("builtin.module", attributes={"some_key": ctx})
698    except Exception as e:
699      # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
700      print(e)
701    try:
702      Operation.create("builtin.module", attributes={"some_key": None})
703    except Exception as e:
704      # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
705      print(e)
706
707
708# CHECK-LABEL: TEST: testOperationName
709@run
710def testOperationName():
711  ctx = Context()
712  ctx.allow_unregistered_dialects = True
713  module = Module.parse(
714      r"""
715    %0 = "custom.op1"() : () -> f32
716    %1 = "custom.op2"() : () -> i32
717    %2 = "custom.op1"() : () -> f32
718  """, ctx)
719
720  # CHECK: custom.op1
721  # CHECK: custom.op2
722  # CHECK: custom.op1
723  for op in module.body.operations:
724    print(op.operation.name)
725
726
727# CHECK-LABEL: TEST: testCapsuleConversions
728@run
729def testCapsuleConversions():
730  ctx = Context()
731  ctx.allow_unregistered_dialects = True
732  with Location.unknown(ctx):
733    m = Operation.create("custom.op1").operation
734    m_capsule = m._CAPIPtr
735    assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
736    m2 = Operation._CAPICreate(m_capsule)
737    assert m2 is m
738
739
740# CHECK-LABEL: TEST: testOperationErase
741@run
742def testOperationErase():
743  ctx = Context()
744  ctx.allow_unregistered_dialects = True
745  with Location.unknown(ctx):
746    m = Module.create()
747    with InsertionPoint(m.body):
748      op = Operation.create("custom.op1")
749
750      # CHECK: "custom.op1"
751      print(m)
752
753      op.operation.erase()
754
755      # CHECK-NOT: "custom.op1"
756      print(m)
757
758      # Ensure we can create another operation
759      Operation.create("custom.op2")
760
761
762# CHECK-LABEL: TEST: testOperationLoc
763@run
764def testOperationLoc():
765  ctx = Context()
766  ctx.allow_unregistered_dialects = True
767  with ctx:
768    loc = Location.name("loc")
769    op = Operation.create("custom.op", loc=loc)
770    assert op.location == loc
771    assert op.operation.location == loc
772
773
774# CHECK-LABEL: TEST: testModuleMerge
775@run
776def testModuleMerge():
777  with Context():
778    m1 = Module.parse("func private @foo()")
779    m2 = Module.parse("""
780      func private @bar()
781      func private @qux()
782    """)
783    foo = m1.body.operations[0]
784    bar = m2.body.operations[0]
785    qux = m2.body.operations[1]
786    bar.move_before(foo)
787    qux.move_after(foo)
788
789    # CHECK: module
790    # CHECK: func private @bar
791    # CHECK: func private @foo
792    # CHECK: func private @qux
793    print(m1)
794
795    # CHECK: module {
796    # CHECK-NEXT: }
797    print(m2)
798
799
800# CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock
801@run
802def testAppendMoveFromAnotherBlock():
803  with Context():
804    m1 = Module.parse("func private @foo()")
805    m2 = Module.parse("func private @bar()")
806    func = m1.body.operations[0]
807    m2.body.append(func)
808
809    # CHECK: module
810    # CHECK: func private @bar
811    # CHECK: func private @foo
812
813    print(m2)
814    # CHECK: module {
815    # CHECK-NEXT: }
816    print(m1)
817
818
819# CHECK-LABEL: TEST: testDetachFromParent
820@run
821def testDetachFromParent():
822  with Context():
823    m1 = Module.parse("func private @foo()")
824    func = m1.body.operations[0].detach_from_parent()
825
826    try:
827      func.detach_from_parent()
828    except ValueError as e:
829      if "has no parent" not in str(e):
830        raise
831    else:
832      assert False, "expected ValueError when detaching a detached operation"
833
834    print(m1)
835    # CHECK-NOT: func private @foo
836
837
838# CHECK-LABEL: TEST: testOperationHash
839@run
840def testOperationHash():
841  ctx = Context()
842  ctx.allow_unregistered_dialects = True
843  with ctx, Location.unknown():
844    op = Operation.create("custom.op1")
845    assert hash(op) == hash(op.operation)
846