1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5
def run(f):
  """Decorator: execute the test function ``f`` right away and hand it back.

  Prints a blank line followed by ``TEST: <function name>`` so the output can
  be split per test, invokes ``f``, forces a garbage collection, and then
  asserts that no ``Context`` instances are still alive (i.e. the test did not
  leak one).
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  gc.collect()
  remaining = Context._get_live_count()
  assert remaining == 0
  return f
12
13
@run
def testLifecycleContextDestroy():
  """A handler whose Context is destroyed first reports itself detached,
  and collecting the handler afterwards is harmless."""
  context = Context()

  def ignore_diagnostic(diagnostic): ...

  diag_handler = context.attach_diagnostic_handler(ignore_diagnostic)
  assert diag_handler.attached

  # Release our reference to the context while the handler is still held;
  # destroying the context is expected to detach the handler automatically.
  del context
  gc.collect()
  assert not diag_handler.attached

  # Dropping the already-detached handler must not cause problems.
  del diag_handler
  gc.collect()
28
29
@run
def testLifecycleExplicitDetach():
  """Calling ``detach()`` on an attached handler marks it as not attached."""
  context = Context()
  diag_handler = context.attach_diagnostic_handler(lambda diagnostic: None)
  assert diag_handler.attached
  diag_handler.detach()
  assert not diag_handler.attached
38
39
@run
def testLifecycleWith():
  """Using the handler as a context manager detaches it on ``with`` exit."""
  context = Context()

  def ignore_diagnostic(diagnostic): ...

  with context.attach_diagnostic_handler(ignore_diagnostic) as diag_handler:
    assert diag_handler.attached
  assert not diag_handler.attached
47
48
@run
def testLifecycleWithAndExplicitDetach():
  """An explicit ``detach()`` inside a ``with`` body, followed by the
  block's normal exit, leaves the handler detached without raising."""
  context = Context()

  def ignore_diagnostic(diagnostic): ...

  with context.attach_diagnostic_handler(ignore_diagnostic) as diag_handler:
    assert diag_handler.attached
    diag_handler.detach()
  assert not diag_handler.attached
57
58
59# CHECK-LABEL: TEST: testDiagnosticCallback
@run
def testDiagnosticCallback():
  """An error emitted at a location is delivered to the Python callback,
  which can read the diagnostic's message, severity and location."""
  context = Context()

  def on_diagnostic(diag):
    # CHECK: DIAGNOSTIC: message='foobar', severity=DiagnosticSeverity.ERROR, loc=loc(unknown)
    line = f"DIAGNOSTIC: message='{diag.message}', severity={diag.severity}, loc={diag.location}"
    print(line)
    return True

  diag_handler = context.attach_diagnostic_handler(on_diagnostic)
  Location.unknown(context).emit_error("foobar")
  assert not diag_handler.had_error
71
72
73# CHECK-LABEL: TEST: testDiagnosticEmptyNotes
74# TODO: Come up with a way to inject a diagnostic with notes from this API.
@run
def testDiagnosticEmptyNotes():
  """A diagnostic produced by ``emit_error`` carries no notes
  (its ``notes`` prints as ``()``)."""
  context = Context()

  def on_diagnostic(diag):
    # CHECK: DIAGNOSTIC: notes=()
    print(f"DIAGNOSTIC: notes={diag.notes}")
    return True

  diag_handler = context.attach_diagnostic_handler(on_diagnostic)
  Location.unknown(context).emit_error("foobar")
  assert not diag_handler.had_error
86
87
88# CHECK-LABEL: TEST: testDiagnosticNonEmptyNotes
@run
def testDiagnosticNonEmptyNotes():
  """Verifying a malformed op produces a diagnostic with attached notes.

  ``arith.addi`` is created with no operands or results, so ``verify()``
  emits an error whose notes include the printed form of the offending op.
  """
  ctx = Context()
  def callback(d):
    # CHECK: DIAGNOSTIC:
    # CHECK:   message='arith.addi' op requires one result
    # CHECK:   notes=['see current operation: "arith.addi"() : () -> ()']
    # Plain string: there is nothing to interpolate on this line.
    print("DIAGNOSTIC:")
    print(f"  message={d.message}")
    # Stringify each note so the list repr shows the note text.
    print(f"  notes={[str(note) for note in d.notes]}")
    return True
  handler = ctx.attach_diagnostic_handler(callback)
  loc = Location.unknown(ctx)
  Operation.create('arith.addi', loc=loc).verify()
  assert not handler.had_error
104
105# CHECK-LABEL: TEST: testDiagnosticCallbackException
@run
def testDiagnosticCallbackException():
  """An exception raised inside the callback does not propagate out of
  ``emit_error``; instead the handler's ``had_error`` flag becomes set."""
  context = Context()

  def failing_callback(diag):
    raise ValueError("Error in handler")

  diag_handler = context.attach_diagnostic_handler(failing_callback)
  Location.unknown(context).emit_error("foobar")
  assert diag_handler.had_error
115
116
117# CHECK-LABEL: TEST: testEscapingDiagnostic
@run
def testEscapingDiagnostic():
  """A Diagnostic retained past its callback becomes invalid.

  The callback stashes the diagnostic object. After ``emit_error`` returns,
  the stashed object's string form is ``<Invalid Diagnostic>`` and each of
  its accessors (``severity``, ``location``, ``message``, ``notes``) must
  raise ``ValueError``.
  """
  ctx = Context()
  diags = []
  def callback(d):
    diags.append(d)
    return True
  handler = ctx.attach_diagnostic_handler(callback)
  loc = Location.unknown(ctx)
  loc.emit_error("foobar")
  assert not handler.had_error

  # CHECK: DIAGNOSTIC: <Invalid Diagnostic>
  print(f"DIAGNOSTIC: {diags[0]}")
  escaped = diags[0]
  for attr in ("severity", "location", "message", "notes"):
    # Guard only the attribute access; the failure signal is raised from the
    # `else` branch so it can never be confused with the expected error.
    try:
      getattr(escaped, attr)
    except ValueError:
      pass
    else:
      raise RuntimeError("expected exception")
152
153
154
155# CHECK-LABEL: TEST: testDiagnosticReturnTrueHandles
@run
def testDiagnosticReturnTrueHandles():
  """A handler that returns True consumes the diagnostic.

  The handler attached last is invoked first; because it returns True, the
  handler attached earlier never sees the diagnostic.
  """
  context = Context()

  def attached_first(diag):
    print(f"CALLBACK1: {diag}")
    return True

  def attached_second(diag):
    print(f"CALLBACK2: {diag}")
    return True

  for callback in (attached_first, attached_second):
    context.attach_diagnostic_handler(callback)
  # CHECK-NOT: CALLBACK1
  # CHECK: CALLBACK2: foobar
  # CHECK-NOT: CALLBACK1
  Location.unknown(context).emit_error("foobar")
172
173
174# CHECK-LABEL: TEST: testDiagnosticReturnFalseDoesNotHandle
@run
def testDiagnosticReturnFalseDoesNotHandle():
  """A handler that returns False lets the diagnostic fall through.

  The handler attached last runs first; since it returns False, the handler
  attached earlier is invoked afterwards as well.
  """
  context = Context()

  def attached_first(diag):
    print(f"CALLBACK1: {diag}")
    return True

  def attached_second(diag):
    print(f"CALLBACK2: {diag}")
    return False

  for callback in (attached_first, attached_second):
    context.attach_diagnostic_handler(callback)
  # CHECK: CALLBACK2: foobar
  # CHECK: CALLBACK1: foobar
  Location.unknown(context).emit_error("foobar")
190