1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5
def run(f):
  """Execute test function `f` right away and verify nothing was leaked.

  Applied as a decorator: prints a "TEST: <name>" banner (the text the
  FileCheck labels in this file anchor on), calls `f`, forces a garbage
  collection pass, and asserts that `Context._get_live_count()` reports zero.
  Returns `f` itself so the decorated name stays bound to the function.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  # Collect cycles first so contexts the test dropped are really released
  # before the live count is inspected.
  gc.collect()
  live_contexts = Context._get_live_count()
  assert live_contexts == 0
  return f
12
13
@run
def testLifecycleContextDestroy():
  """Releasing the Context first must leave a live handler detached."""
  def ignore_diagnostic(_diag): ...
  context = Context()
  diag_handler = context.attach_diagnostic_handler(ignore_diagnostic)
  assert diag_handler.attached

  # Drop the only reference to the context while the handler object is still
  # alive; after collection the handler should report itself as detached.
  del context
  gc.collect()
  assert not diag_handler.attached

  # Releasing the already-detached handler afterwards must also be safe.
  del diag_handler
  gc.collect()
28
29
@run
def testLifecycleExplicitDetach():
  """Calling detach() explicitly clears the handler's `attached` flag."""
  context = Context()
  diag_handler = context.attach_diagnostic_handler(lambda _diag: None)
  assert diag_handler.attached
  diag_handler.detach()
  assert not diag_handler.attached
38
39
@run
def testLifecycleWith():
  """Leaving the `with` block detaches the handler bound by it."""
  context = Context()
  with context.attach_diagnostic_handler(lambda _diag: None) as diag_handler:
    assert diag_handler.attached
  assert not diag_handler.attached
47
48
@run
def testLifecycleWithAndExplicitDetach():
  """Detaching inside a `with` block must not break the block's exit."""
  def noop(_diag): ...
  context = Context()
  with context.attach_diagnostic_handler(noop) as diag_handler:
    assert diag_handler.attached
    diag_handler.detach()
  assert not diag_handler.attached
57
58
59# CHECK-LABEL: TEST: testDiagnosticCallback
@run
def testDiagnosticCallback():
  """The callback sees the emitted message, severity and location."""
  def report(d):
    # CHECK: DIAGNOSTIC: message='foobar', severity=DiagnosticSeverity.ERROR, loc=loc(unknown)
    print(f"DIAGNOSTIC: message='{d.message}', severity={d.severity}, loc={d.location}")
    return True
  context = Context()
  diag_handler = context.attach_diagnostic_handler(report)
  Location.unknown(context).emit_error("foobar")
  assert not diag_handler.had_error
71
72
73# CHECK-LABEL: TEST: testDiagnosticEmptyNotes
74# TODO: Come up with a way to inject a diagnostic with notes from this API.
@run
def testDiagnosticEmptyNotes():
  """A plain emitted error is delivered with an empty `notes` tuple."""
  def report(d):
    # CHECK: DIAGNOSTIC: notes=()
    print(f"DIAGNOSTIC: notes={d.notes}")
    return True
  context = Context()
  diag_handler = context.attach_diagnostic_handler(report)
  Location.unknown(context).emit_error("foobar")
  assert not diag_handler.had_error
86
87
88# CHECK-LABEL: TEST: testDiagnosticCallbackException
@run
def testDiagnosticCallbackException():
  """An exception raised by the callback is recorded in `had_error`."""
  def failing_handler(_diag):
    raise ValueError("Error in handler")
  context = Context()
  diag_handler = context.attach_diagnostic_handler(failing_handler)
  Location.unknown(context).emit_error("foobar")
  assert diag_handler.had_error
98
99
100# CHECK-LABEL: TEST: testEscapingDiagnostic
@run
def testEscapingDiagnostic():
  """A Diagnostic retained past its callback is invalid and unreadable."""
  diags = []
  def capture(diag):
    diags.append(diag)
    return True
  context = Context()
  diag_handler = context.attach_diagnostic_handler(capture)
  Location.unknown(context).emit_error("foobar")
  assert not diag_handler.had_error

  # CHECK: DIAGNOSTIC: <Invalid Diagnostic>
  print(f"DIAGNOSTIC: {str(diags[0])}")
  # Every accessor on the escaped diagnostic must raise ValueError; a
  # successful read means the object was not invalidated.
  for attr_name in ("severity", "location", "message", "notes"):
    try:
      getattr(diags[0], attr_name)
    except ValueError:
      continue
    raise RuntimeError("expected exception")
135
136
137
138# CHECK-LABEL: TEST: testDiagnosticReturnTrueHandles
@run
def testDiagnosticReturnTrueHandles():
  """Returning True from the latest handler keeps it from reaching callback1."""
  def callback1(d):
    print(f"CALLBACK1: {d}")
    return True
  def callback2(d):
    print(f"CALLBACK2: {d}")
    return True
  context = Context()
  # Attach in order; the returned handler objects are intentionally unused.
  for cb in (callback1, callback2):
    context.attach_diagnostic_handler(cb)
  # CHECK-NOT: CALLBACK1
  # CHECK: CALLBACK2: foobar
  # CHECK-NOT: CALLBACK1
  Location.unknown(context).emit_error("foobar")
155
156
157# CHECK-LABEL: TEST: testDiagnosticReturnFalseDoesNotHandle
@run
def testDiagnosticReturnFalseDoesNotHandle():
  """Returning False from the latest handler lets callback1 run afterwards."""
  def callback1(d):
    print(f"CALLBACK1: {d}")
    return True
  def callback2(d):
    print(f"CALLBACK2: {d}")
    return False
  context = Context()
  # Attach in order; the returned handler objects are intentionally unused.
  for cb in (callback1, callback2):
    context.attach_diagnostic_handler(cb)
  # CHECK: CALLBACK2: foobar
  # CHECK: CALLBACK1: foobar
  Location.unknown(context).emit_error("foobar")
173