# RUN: %PYTHON %s | FileCheck %s

import gc
import io
import itertools
from mlir.ir import *


def run(f):
    """Run test function *f* under a printed label and check for leaks.

    Prints ``TEST: <name>`` so the label directives in this file can anchor
    on it, calls *f*, forces a garbage collection, and asserts that no
    ``Context`` objects remain alive afterwards. Returns *f* unchanged, so
    it works as a decorator that executes the test at import time.
    """
    print("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0
    return f


# CHECK-LABEL: TEST: testSymbolTableInsert
@run
def testSymbolTableInsert():
    """Exercise SymbolTable lookup, deletion, erase, and insertion."""
    with Context() as ctx:
        # m2 contains the generic op "foo.bar", which presumably belongs to
        # no registered dialect, so unregistered dialects must be allowed.
        ctx.allow_unregistered_dialects = True
        m1 = Module.parse("""
        func.func private @foo()
        func.func private @bar()""")
        m2 = Module.parse("""
        func.func private @qux()
        func.func private @foo()
        "foo.bar"() : () -> ()""")

        symbol_table = SymbolTable(m1.operation)

        # CHECK: func private @foo
        # CHECK: func private @bar
        assert "foo" in symbol_table
        print(symbol_table["foo"])
        assert "bar" in symbol_table
        # Keep a handle to @bar so its invalidation can be checked below.
        bar = symbol_table["bar"]
        print(symbol_table["bar"])

        assert "qux" not in symbol_table

        del symbol_table["bar"]
        # "bar" has been deleted, so the lookup feeding erase() is expected
        # to raise KeyError.
        try:
            symbol_table.erase(symbol_table["bar"])
        except KeyError:
            pass
        else:
            assert False, "expected KeyError"

        # CHECK: module
        # CHECK: func private @foo()
        print(m1)
        assert "bar" not in symbol_table

        # The Python handle to the deleted op must now refuse to be used.
        try:
            print(bar)
        except RuntimeError as e:
            if "the operation has been invalidated" not in str(e):
                raise
        else:
            assert False, "expected RuntimeError due to invalidated operation"

        # Move @qux from m2 into m1, then register it in m1's symbol table.
        qux = m2.body.operations[0]
        m1.body.append(qux)
        symbol_table.insert(qux)
        assert "qux" in symbol_table

        # Check that insertion actually renames this symbol in the symbol table.
        # With @qux moved out, m2's first op is now its @foo, which clashes
        # with the @foo already present in m1.
        foo2 = m2.body.operations[0]
        m1.body.append(foo2)
        updated_name = symbol_table.insert(foo2)
        assert foo2.name.value != "foo"
        assert foo2.name == updated_name

        # CHECK: module
        # CHECK: func private @foo()
        # CHECK: func private @qux()
        # CHECK: func private @foo{{.*}}
        print(m1)

        # The only op left in m2 is "foo.bar", which has no symbol name.
        try:
            symbol_table.insert(m2.body.operations[0])
        except ValueError as e:
            if "Expected operation to have a symbol name" not in str(e):
                raise
        else:
            assert False, "expected ValueError when adding a non-symbol"


# CHECK-LABEL: TEST: testSymbolTableRAUW
@run
def testSymbolTableRAUW():
    """Rename a symbol and rewrite its uses via replace_all_symbol_uses."""
    with Context():
        m = Module.parse("""
        func.func private @foo() {
          call @bar() : () -> ()
          return
        }
        func.func private @bar()
        """)
        foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2]
        SymbolTable.set_symbol_name(bar, "bam")
        # Note that module.operation counts as a "nested symbol table" which won't
        # be traversed into, so it is necessary to traverse its children.
        SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
        # CHECK: call @bam()
        # CHECK: func private @bam
        print(m)
        # CHECK: Foo symbol: "foo"
        # CHECK: Bar symbol: "bam"
        print(f"Foo symbol: {SymbolTable.get_symbol_name(foo)}")
        print(f"Bar symbol: {SymbolTable.get_symbol_name(bar)}")


# CHECK-LABEL: TEST: testSymbolTableVisibility
@run
def testSymbolTableVisibility():
    """Read a symbol's visibility, then change it from private to public."""
    with Context():
        m = Module.parse("""
        func.func private @foo() {
          return
        }
        """)
        foo = m.operation.regions[0].blocks[0].operations[0]
        # CHECK: Existing visibility: "private"
        print(f"Existing visibility: {SymbolTable.get_visibility(foo)}")
        SymbolTable.set_visibility(foo, "public")
        # CHECK: func public @foo
        print(m)


# CHECK-LABEL: TEST: testWalkSymbolTables
@run
def testWalkSymbolTables():
    """Walk nested symbol tables; a failing callback surfaces as RuntimeError."""
    with Context():
        m = Module.parse("""
        module @outer {
          module @inner{
          }
        }
        """)

        def callback(symbol_table_op, uses_visible):
            print(f"SYMBOL TABLE: {uses_visible}: {symbol_table_op}")

        # CHECK: SYMBOL TABLE: True: module @inner
        # CHECK: SYMBOL TABLE: True: module @outer
        SymbolTable.walk_symbol_tables(m.operation, True, callback)

        # Make sure exceptions in the callback are handled.
        def error_callback(symbol_table_op, uses_visible):
            assert False, "Raised from python"

        try:
            SymbolTable.walk_symbol_tables(m.operation, True, error_callback)
        except RuntimeError as e:
            # CHECK: GOT EXCEPTION: Exception raised in callback: AssertionError: Raised from python
            print(f"GOT EXCEPTION: {e}")
        else:
            # Without this, a silently swallowed callback error would pass.
            assert False, "expected RuntimeError from failing callback"