1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5
def run(f):
  """Execute test function ``f`` immediately and check for leaked contexts.

  Prints a blank line followed by a ``TEST: <name>`` header, invokes ``f``,
  forces a garbage-collection pass, then asserts that no ``Context`` objects
  are still alive. ``f`` is returned unchanged so this works as a decorator.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  # Collect unreachable cycles before counting, so contexts kept alive only
  # by garbage are not reported as leaks.
  gc.collect()
  leaked = Context._get_live_count()
  assert leaked == 0
  return f
12
13
14# Verify successful parse.
15# CHECK-LABEL: TEST: testParseSuccess
16# CHECK: module @successfulParse
@run
def testParseSuccess():
  """Parse a trivial named module and use it after dropping the context."""
  context = Context()
  asm = r"""module @successfulParse {}"""
  parsed = Module.parse(asm, context)
  assert parsed.context is context
  print("CLEAR CONTEXT")
  # Drop the only local context reference: the module must keep it alive.
  del context
  gc.collect()
  parsed.dump()  # Writes to stderr only; just exercises the call.
  print(str(parsed))
27
28
29# Verify parse error.
30# CHECK-LABEL: TEST: testParseError
31# CHECK: testParseError: Unable to parse module assembly (see diagnostics)
@run
def testParseError():
  """Check that malformed assembly is rejected with a ``ValueError``.

  The exception message is printed so the output can be matched; if no
  exception is raised, a distinct marker line is printed instead.
  """
  ctx = Context()
  try:
    # No result is bound: this call is expected to raise.
    Module.parse(r"""}SYNTAX ERROR{""", ctx)
  except ValueError as e:
    print("testParseError:", e)
  else:
    print("Exception not produced")
41
42
# Verify creation of an empty module.
44# CHECK-LABEL: TEST: testCreateEmpty
45# CHECK: module {
@run
def testCreateEmpty():
  """Create an empty module at an unknown location and print it."""
  context = Context()
  location = Location.unknown(context)
  empty_module = Module.create(location)
  print("CLEAR CONTEXT")
  # Release the local context reference; the module must still be usable.
  del context
  gc.collect()
  print(str(empty_module))
55
56
57# Verify round-trip of ASM that contains unicode.
58# Note that this does not test that the print path converts unicode properly
59# because MLIR asm always normalizes it to the hex encoding.
60# CHECK-LABEL: TEST: testRoundtripUnicode
61# CHECK: func private @roundtripUnicode()
62# CHECK: foo = "\F0\9F\98\8A"
@run
def testRoundtripUnicode():
  """Parse assembly holding a non-ASCII string attribute and print it.

  The attribute value is U+1F60A (UTF-8 bytes F0 9F 98 8A). The printed
  form is expected to show those bytes hex-escaped.
  """
  ctx = Context()
  # The emoji is written as a Python escape (non-raw string) so this file
  # stays ASCII-only and the literal cannot be corrupted by re-encoding.
  module = Module.parse("""
    func.func private @roundtripUnicode() attributes { foo = "\U0001F60A" }
  """, ctx)
  print(str(module))
70
71
# Verify round-trip of ASM that contains unicode when the ASM is emitted as
# `bytes` via get_asm(binary=True) and parsed back from those bytes.
# Note that this does not test that the print path converts unicode properly
# because MLIR asm always normalizes it to the hex encoding.
75# CHECK-LABEL: TEST: testRoundtripBinary
76# CHECK: func private @roundtripUnicode()
77# CHECK: foo = "\F0\9F\98\8A"
@run
def testRoundtripBinary():
  """Round-trip unicode-bearing assembly through its ``bytes`` form.

  ``get_asm(binary=True)`` is asserted to return ``bytes``. Those bytes are
  parsed back with ``Module.parse`` and the re-parsed module is printed.
  The attribute value is U+1F60A (UTF-8 bytes F0 9F 98 8A).
  """
  with Context():
    # The emoji is written as a Python escape (non-raw string) so this file
    # stays ASCII-only and the literal cannot be corrupted by re-encoding.
    module = Module.parse("""
      func.func private @roundtripUnicode() attributes { foo = "\U0001F60A" }
    """)
    binary_asm = module.operation.get_asm(binary=True)
    assert isinstance(binary_asm, bytes)
    module = Module.parse(binary_asm)
    print(module)
88
89
90# Tests that module.operation works and correctly interns instances.
91# CHECK-LABEL: TEST: testModuleOperation
@run
def testModuleOperation():
  """Exercise ``Module.operation`` interning and live-operation bookkeeping.

  Checks that:
  - repeated ``module.operation`` accesses return the identical Python object;
  - ``Context._clear_live_operations`` empties the live-operation count;
  - an operation is still printable after the ``Module`` reference is dropped;
  - live-operation and live-module counts reach zero once all references
    are released.
  """
  ctx = Context()
  module = Module.parse(r"""module @successfulParse {}""", ctx)
  assert ctx._get_live_module_count() == 1
  op1 = module.operation
  assert ctx._get_live_operation_count() == 1
  # CHECK: module @successfulParse
  print(op1)

  # Ensure that operations are the same on multiple calls.
  op2 = module.operation
  assert ctx._get_live_operation_count() == 1
  assert op1 is op2

  # Test live operation clearing.
  op1 = module.operation
  assert ctx._get_live_operation_count() == 1
  # Returns how many live entries were invalidated; only the module op here.
  num_invalidated = ctx._clear_live_operations()
  assert num_invalidated == 1
  assert ctx._get_live_operation_count() == 0
  op1 = None
  gc.collect()
  # Re-fetch after clearing. Presumably this yields a new wrapper, because the
  # old one was removed from the live map (TODO: confirm). Note that op2 still
  # refers to the wrapper obtained before the clear.
  op1 = module.operation

  # Ensure that if module is de-referenced, the operations are still valid.
  module = None
  gc.collect()
  print(op1)

  # Collect and verify lifetime.
  # Both the post-clear wrapper (op1) and the pre-clear one (op2) are released.
  op1 = None
  op2 = None
  gc.collect()
  print("LIVE OPERATIONS:", ctx._get_live_operation_count())
  assert ctx._get_live_operation_count() == 0
  assert ctx._get_live_module_count() == 0
129
130
131# CHECK-LABEL: TEST: testModuleCapsule
@run
def testModuleCapsule():
  """Round-trip a Module through its C-API capsule and check identity/lifetime."""
  context = Context()
  original = Module.parse(r"""module @successfulParse {}""", context)
  assert context._get_live_module_count() == 1
  # CHECK: "mlir.ir.Module._CAPIPtr"
  capsule = original._CAPIPtr
  print(capsule)
  # Rebuilding from the capsule must give back the same interned Python object.
  rebuilt = Module._CAPICreate(capsule)
  assert original is rebuilt
  assert rebuilt.context is context
  # Release every reference, collect, and confirm the module was destroyed.
  del original, capsule, rebuilt
  gc.collect()
  assert context._get_live_module_count() == 0
149
150