1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4import io
5import itertools
6from mlir.ir import *
7
8
def run(f):
  """Execute test function ``f`` right away and verify no MLIR contexts leak.

  Intended for use as a decorator. It prints a ``TEST:`` banner naming ``f``,
  invokes ``f`` with no arguments, forces a garbage-collection pass, and
  asserts that the live ``Context`` count has dropped back to zero. ``f`` is
  returned unchanged so the decorated name remains bound to the function.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  # Collect eagerly so contexts that are only reachable through reference
  # cycles are released before the live count is inspected.
  gc.collect()
  live_contexts = Context._get_live_count()
  assert live_contexts == 0
  return f
15
16
17# CHECK-LABEL: TEST: testSymbolTableInsert
@run
def testSymbolTableInsert():
  """Exercise lookup, deletion, insertion and renaming on a SymbolTable.

  Builds a symbol table over ``m1`` and deletes ``@bar`` through it. It then
  moves ``@qux`` and a second ``@foo`` over from ``m2`` and inserts them; the
  conflicting ``@foo`` must be renamed on insertion. Finally it checks that
  inserting an operation without a symbol name raises ``ValueError``.
  """
  with Context() as ctx:
    # m2 contains an unregistered "foo.bar" op, used below as a non-symbol.
    ctx.allow_unregistered_dialects = True
    m1 = Module.parse("""
      func.func private @foo()
      func.func private @bar()""")
    m2 = Module.parse("""
      func.func private @qux()
      func.func private @foo()
      "foo.bar"() : () -> ()""")

    symbol_table = SymbolTable(m1.operation)

    # CHECK: func private @foo
    # CHECK: func private @bar
    assert "foo" in symbol_table
    print(symbol_table["foo"])
    assert "bar" in symbol_table
    bar = symbol_table["bar"]
    print(symbol_table["bar"])

    assert "qux" not in symbol_table

    del symbol_table["bar"]
    # After deletion, looking the name up again must raise KeyError.
    try:
      symbol_table.erase(symbol_table["bar"])
    except KeyError:
      pass
    else:
      # Explicit raise instead of `assert False`: asserts vanish under -O, which
      # would let a missing exception pass silently.
      raise AssertionError("expected KeyError")

    # CHECK: module
    # CHECK:   func private @foo()
    print(m1)
    assert "bar" not in symbol_table

    # The Python handle to the deleted @bar op must now refuse to be used.
    try:
      print(bar)
    except RuntimeError as e:
      if "the operation has been invalidated" not in str(e):
        raise
    else:
      raise AssertionError(
          "expected RuntimeError due to invalidated operation")

    qux = m2.body.operations[0]
    m1.body.append(qux)
    symbol_table.insert(qux)
    assert "qux" in symbol_table

    # Check that insertion actually renames this symbol in the symbol table.
    # m2's first op is now its @foo, since @qux was moved out above.
    foo2 = m2.body.operations[0]
    m1.body.append(foo2)
    updated_name = symbol_table.insert(foo2)
    assert foo2.name.value != "foo"
    assert foo2.name == updated_name

    # CHECK: module
    # CHECK:   func private @foo()
    # CHECK:   func private @qux()
    # CHECK:   func private @foo{{.*}}
    print(m1)

    # Only the unregistered "foo.bar" op remains in m2; it has no symbol name.
    try:
      symbol_table.insert(m2.body.operations[0])
    except ValueError as e:
      if "Expected operation to have a symbol name" not in str(e):
        raise
    else:
      raise AssertionError("expected ValueError when adding a non-symbol")
88
89
90# CHECK-LABEL: testSymbolTableRAUW
@run
def testSymbolTableRAUW():
  """Rename @bar to @bam and rewrite the call site that references it."""
  with Context():
    m = Module.parse("""
      func.func private @foo() {
        call @bar() : () -> ()
        return
      }
      func.func private @bar()
      """)
    top_level_ops = list(m.body.operations)
    foo, bar = top_level_ops[:2]
    SymbolTable.set_symbol_name(bar, "bam")
    # The module op itself counts as a nested symbol table that the replacement
    # would not descend into, so the rewrite is applied to @foo, whose body
    # holds the call.
    SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
    # CHECK: call @bam()
    # CHECK: func private @bam
    print(m)
    # CHECK: Foo symbol: "foo"
    # CHECK: Bar symbol: "bam"
    foo_symbol = SymbolTable.get_symbol_name(foo)
    print(f"Foo symbol: {foo_symbol}")
    bar_symbol = SymbolTable.get_symbol_name(bar)
    print(f"Bar symbol: {bar_symbol}")
113
114
115# CHECK-LABEL: testSymbolTableVisibility
@run
def testSymbolTableVisibility():
  """Read a symbol's visibility, then switch it from private to public."""
  with Context():
    m = Module.parse("""
      func.func private @foo() {
        return
      }
      """)
    foo = m.body.operations[0]
    # CHECK: Existing visibility: "private"
    current_visibility = SymbolTable.get_visibility(foo)
    print(f"Existing visibility: {current_visibility}")
    SymbolTable.set_visibility(foo, "public")
    # CHECK: func public @foo
    print(m)
130
131
132# CHECK: testWalkSymbolTables
@run
def testWalkSymbolTables():
  """Walk nested symbol tables and check that callback errors propagate.

  The walk reports the inner module before the outer one. An exception raised
  inside the callback must surface from the walk as a RuntimeError.
  """
  with Context():
    m = Module.parse("""
      module @outer {
        module @inner{
        }
      }
      """)
    def callback(symbol_table_op, uses_visible):
      print(f"SYMBOL TABLE: {uses_visible}: {symbol_table_op}")
    # CHECK: SYMBOL TABLE: True: module @inner
    # CHECK: SYMBOL TABLE: True: module @outer
    SymbolTable.walk_symbol_tables(m.operation, True, callback)

    # Make sure exceptions in the callback are handled.
    def error_callback(symbol_table_op, uses_visible):
      # Explicit raise (not `assert False`) so the error still fires when
      # asserts are disabled with -O; type and message are unchanged.
      raise AssertionError("Raised from python")
    try:
      SymbolTable.walk_symbol_tables(m.operation, True, error_callback)
    except RuntimeError as e:
      # CHECK: GOT EXCEPTION: Exception raised in callback: AssertionError: Raised from python
      print(f"GOT EXCEPTION: {e}")
    else:
      # Mirror the other tests: a walk that swallows the error is a failure.
      raise AssertionError("expected RuntimeError from the failing callback")
156
157