# RUN: %PYTHON %s | FileCheck %s
"""Tests for `builtin.FuncOp` in the MLIR Python bindings.

Every test below is decorated with `run`, which prints a `TEST: <name>` label
and executes the test as soon as it is defined. The script's stdout is piped
into FileCheck, which matches it, in order, against the directive comments in
this file -- so those comments must stay in step with the prints.
"""

from mlir.ir import *
import mlir.dialects.builtin as builtin
import mlir.dialects.std as std


def run(f):
  """Print a label for `f`, call it immediately, and return it unchanged.

  Used as a decorator so each test runs at definition time; the printed
  `TEST: <name>` line is what each test's label directive anchors on.
  """
  print("\nTEST:", f.__name__)
  f()
  return f


# CHECK-LABEL: TEST: testFromPyFunc
@run
def testFromPyFunc():
  """Capture Python callables as FuncOps via `FuncOp.from_py_func`.

  Covers single, multiple and no return values; calls between captured
  functions; an explicit `name=`; the three ways the callable can receive the
  created FuncOp (positional `func_op`, keyword `func_op`, `**kwargs`); and an
  explicit `results=` list with a hand-built `std.ReturnOp`.
  """
  with Context(), Location.unknown():
    m = builtin.ModuleOp()
    f32 = F32Type.get()
    f64 = F64Type.get()
    with InsertionPoint(m.body):
      # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
      # CHECK: return %arg0 : f64
      @builtin.FuncOp.from_py_func(f64)
      def unary_return(a):
        return a

      # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
      # CHECK: return %arg0, %arg1 : f32, f64
      @builtin.FuncOp.from_py_func(f32, f64)
      def binary_return(a, b):
        return a, b

      # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
      # CHECK: return
      @builtin.FuncOp.from_py_func(f32, f64)
      def none_return(a, b):
        pass

      # CHECK-LABEL: func @call_unary
      # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
      # CHECK: return %0 : f64
      @builtin.FuncOp.from_py_func(f64)
      def call_unary(a):
        return unary_return(a)

      # CHECK-LABEL: func @call_binary
      # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
      # CHECK: return %0#0, %0#1 : f32, f64
      @builtin.FuncOp.from_py_func(f32, f64)
      def call_binary(a, b):
        return binary_return(a, b)

      # CHECK-LABEL: func @call_none
      # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
      # CHECK: return
      @builtin.FuncOp.from_py_func(f32, f64)
      def call_none(a, b):
        return none_return(a, b)

      # Variants and optional features.
      # CHECK-LABEL: func @from_name_arg
      @builtin.FuncOp.from_py_func(f32, f64, name="from_name_arg")
      def explicit_name(a, b):
        return b

      # The next three callables expect the created FuncOp to be handed in as
      # `func_op` (positionally, by keyword, or through **kwargs).
      @builtin.FuncOp.from_py_func(f32, f64)
      def positional_func_op(a, b, func_op):
        assert isinstance(func_op, builtin.FuncOp)
        return b

      @builtin.FuncOp.from_py_func(f32, f64)
      def kw_func_op(a, b=None, func_op=None):
        assert isinstance(func_op, builtin.FuncOp)
        return b

      @builtin.FuncOp.from_py_func(f32, f64)
      def kwargs_func_op(a, b=None, **kwargs):
        assert isinstance(kwargs["func_op"], builtin.FuncOp)
        return b

      # With explicit `results=`, the callable emits its own terminator and
      # returns None.
      # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
      # CHECK: return %arg1 : f64
      @builtin.FuncOp.from_py_func(f32, f64, results=[f64])
      def explicit_results(a, b):
        std.ReturnOp([b])

    print(m)


# CHECK-LABEL: TEST: testFromPyFuncErrors
@run
def testFromPyFuncErrors():
  """Verify `from_py_func(..., results=...)` rejects a value-returning callable.

  The expected AssertionError message is printed for matching.
  """
  with Context(), Location.unknown():
    m = builtin.ModuleOp()
    f64 = F64Type.get()
    with InsertionPoint(m.body):
      try:

        @builtin.FuncOp.from_py_func(f64, results=[f64])
        def unary_return(a):
          return a
      except AssertionError as e:
        # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
        print(e)


# CHECK-LABEL: TEST: testBuildFuncOp
@run
def testBuildFuncOp():
  """Construct FuncOps directly and exercise their accessors and builders.

  Reads back name/type/visibility, hits the entry-block error paths (reading a
  missing body, adding a second entry block), and builds a second function
  from an `(inputs, results)` tuple type plus `body_builder=`.
  """
  ctx = Context()
  with Location.unknown(ctx):
    m = builtin.ModuleOp()

    f32 = F32Type.get()
    tensor_type = RankedTensorType.get((2, 3, 4), f32)
    with InsertionPoint.at_block_begin(m.body):
      func = builtin.FuncOp(name="some_func",
                            type=FunctionType.get(
                                inputs=[tensor_type, tensor_type],
                                results=[tensor_type]),
                            visibility="nested")
      # CHECK: Name is: "some_func"
      print("Name is: ", func.name)

      # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
      print("Type is: ", func.type)

      # CHECK: Visibility is: "nested"
      print("Visibility is: ", func.visibility)

      # The function has no body yet; reading its entry block is expected to
      # raise IndexError.
      try:
        _ = func.entry_block
      except IndexError as e:
        # CHECK: External function does not have a body
        print(e)

      with InsertionPoint(func.add_entry_block()):
        std.ReturnOp([func.entry_block.arguments[0]])

      # Adding a second entry block is expected to raise IndexError.
      try:
        func.add_entry_block()
      except IndexError as e:
        # CHECK: The function already has an entry block!
        print(e)

      # Use the callback builder and pass the type as an (inputs, results)
      # tuple.
      func = builtin.FuncOp(name="some_other_func",
                            type=([tensor_type, tensor_type], [tensor_type]),
                            visibility="nested",
                            body_builder=lambda func: std.ReturnOp(
                                [func.entry_block.arguments[0]]))

    # CHECK: module {
    # CHECK: func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
    # CHECK: return %arg0 : tensor<2x3x4xf32>
    # CHECK: }
    # CHECK: func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
    # CHECK: return %arg0 : tensor<2x3x4xf32>
    # CHECK: }
    print(m)


# CHECK-LABEL: TEST: testFuncArgumentAccess
@run
def testFuncArgumentAccess():
  """Set and read back per-argument and per-result attribute dictionaries.

  `arg_attrs` is assigned both as an `ArrayAttr` and as a plain list of
  `DictAttr`s; the attributes are printed directly and as they appear in the
  printed function signatures.
  """
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    f64 = F64Type.get()
    with InsertionPoint(module.body):
      func = builtin.FuncOp("some_func", ([f32, f32], [f32, f32]))
      with InsertionPoint(func.add_entry_block()):
        std.ReturnOp(func.arguments)
      func.arg_attrs = ArrayAttr.get([
          DictAttr.get({
              "foo": StringAttr.get("bar"),
              "baz": UnitAttr.get()
          }),
          DictAttr.get({"qux": ArrayAttr.get([])})
      ])
      func.result_attrs = ArrayAttr.get([
          DictAttr.get({"res1": FloatAttr.get(f32, 42.0)}),
          DictAttr.get({"res2": FloatAttr.get(f64, 256.0)})
      ])

      other = builtin.FuncOp("other_func", ([f32, f32], []))
      with InsertionPoint(other.add_entry_block()):
        std.ReturnOp([])
      # A plain Python list is assigned here instead of an ArrayAttr.
      other.arg_attrs = [
          DictAttr.get({"foo": StringAttr.get("qux")}),
          DictAttr.get()
      ]

    # CHECK: [{baz, foo = "bar"}, {qux = []}]
    print(func.arg_attrs)

    # CHECK: [{res1 = 4.200000e+01 : f32}, {res2 = 2.560000e+02 : f64}]
    print(func.result_attrs)

    # CHECK: func @some_func(
    # CHECK: %[[ARG0:.*]]: f32 {baz, foo = "bar"},
    # CHECK: %[[ARG1:.*]]: f32 {qux = []}) ->
    # CHECK: f32 {res1 = 4.200000e+01 : f32},
    # CHECK: f32 {res2 = 2.560000e+02 : f64})
    # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
    #
    # CHECK: func @other_func(
    # CHECK: %{{.*}}: f32 {foo = "qux"},
    # CHECK: %{{.*}}: f32)
    print(module)