# RUN: %PYTHON %s | FileCheck %s
"""Tests for the Python bindings of the MLIR `func` dialect.

Every test is executed at import time through the `run` decorator and its
printed output is verified by the FileCheck directives written in the
comments next to the code that produces it.
"""

from mlir.ir import *
import mlir.dialects.builtin as builtin
import mlir.dialects.func as func


def run(f):
    """Print a banner naming `f`, call it with no arguments and return it.

    Used as a decorator, so each decorated test runs as soon as it is defined.
    The banner is what the per-test label directives match against.
    """
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: testFromPyFunc
@run
def testFromPyFunc():
    """Build functions with `FuncOp.from_py_func` and print the module."""
    # Unregistered dialects are allowed because "custom.op1" is created below.
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = builtin.ModuleOp()
        f32 = F32Type.get()
        f64 = F64Type.get()
        with InsertionPoint(m.body):
            # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
            # CHECK: return %arg0 : f64
            @func.FuncOp.from_py_func(f64)
            def unary_return(a):
                return a

            # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
            # CHECK: return %arg0, %arg1 : f32, f64
            @func.FuncOp.from_py_func(f32, f64)
            def binary_return(a, b):
                return a, b

            # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
            # CHECK: return
            @func.FuncOp.from_py_func(f32, f64)
            def none_return(a, b):
                pass

            # CHECK-LABEL: func @call_unary
            # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
            # CHECK: return %0 : f64
            @func.FuncOp.from_py_func(f64)
            def call_unary(a):
                return unary_return(a)

            # CHECK-LABEL: func @call_binary
            # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
            # CHECK: return %0#0, %0#1 : f32, f64
            @func.FuncOp.from_py_func(f32, f64)
            def call_binary(a, b):
                return binary_return(a, b)

            # We expect coercion of a single result operation to a returned value.
            # CHECK-LABEL: func @single_result_op
            # CHECK: %0 = "custom.op1"() : () -> f32
            # CHECK: return %0 : f32
            @func.FuncOp.from_py_func()
            def single_result_op():
                return Operation.create("custom.op1", results=[f32])

            # CHECK-LABEL: func @call_none
            # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
            # CHECK: return
            @func.FuncOp.from_py_func(f32, f64)
            def call_none(a, b):
                return none_return(a, b)

            ## Variants and optional feature tests.
            # CHECK-LABEL: func @from_name_arg
            @func.FuncOp.from_py_func(f32, f64, name="from_name_arg")
            def explicit_name(a, b):
                return b

            # The three variants below receive the op being built as `func_op`:
            # positionally, as a keyword with a default, and through **kwargs.
            @func.FuncOp.from_py_func(f32, f64)
            def positional_func_op(a, b, func_op):
                assert isinstance(func_op, func.FuncOp)
                return b

            @func.FuncOp.from_py_func(f32, f64)
            def kw_func_op(a, b=None, func_op=None):
                assert isinstance(func_op, func.FuncOp)
                return b

            @func.FuncOp.from_py_func(f32, f64)
            def kwargs_func_op(a, b=None, **kwargs):
                assert isinstance(kwargs["func_op"], func.FuncOp)
                return b

            # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
            # CHECK: return %arg1 : f64
            @func.FuncOp.from_py_func(f32, f64, results=[f64])
            def explicit_results(a, b):
                func.ReturnOp([b])

        print(m)


# CHECK-LABEL: TEST: testFromPyFuncErrors
@run
def testFromPyFuncErrors():
    """Print the error raised when `results=` is combined with a return value."""
    with Context(), Location.unknown():
        m = builtin.ModuleOp()
        f64 = F64Type.get()
        with InsertionPoint(m.body):
            try:

                @func.FuncOp.from_py_func(f64, results=[f64])
                def unary_return(a):
                    return a

            except AssertionError as e:
                # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
                print(e)


# CHECK-LABEL: TEST: testBuildFuncOp
@run
def testBuildFuncOp():
    """Construct `func.FuncOp` directly and exercise its entry-block API.

    Prints the name, type and visibility of a body-less function, the errors
    reported for a missing and for a duplicate entry block, and finally the
    module holding both the manually built and the callback-built function.
    """
    # The context is passed explicitly to the location instead of being entered.
    ctx = Context()
    with Location.unknown(ctx):
        m = builtin.ModuleOp()

        f32 = F32Type.get()
        tensor_type = RankedTensorType.get((2, 3, 4), f32)
        with InsertionPoint.at_block_begin(m.body):
            f = func.FuncOp(
                name="some_func",
                type=FunctionType.get(
                    inputs=[tensor_type, tensor_type], results=[tensor_type]
                ),
                visibility="nested",
            )
            # CHECK: Name is: "some_func"
            print("Name is: ", f.name)

            # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
            print("Type is: ", f.type)

            # CHECK: Visibility is: "nested"
            print("Visibility is: ", f.visibility)

            try:
                # Only the property access matters; the value is discarded.
                _ = f.entry_block
            except IndexError as e:
                # CHECK: External function does not have a body
                print(e)

            with InsertionPoint(f.add_entry_block()):
                func.ReturnOp([f.entry_block.arguments[0]])

            try:
                f.add_entry_block()
            except IndexError as e:
                # CHECK: The function already has an entry block!
                print(e)

            # Try the callback builder and passing type as tuple.
            f = func.FuncOp(
                name="some_other_func",
                type=([tensor_type, tensor_type], [tensor_type]),
                visibility="nested",
                body_builder=lambda f: func.ReturnOp([f.entry_block.arguments[0]]),
            )

    # CHECK: module {
    # CHECK:  func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
    # CHECK:   return %arg0 : tensor<2x3x4xf32>
    # CHECK:  }
    # CHECK:  func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
    # CHECK:   return %arg0 : tensor<2x3x4xf32>
    # CHECK:  }
    print(m)


# CHECK-LABEL: TEST: testFuncArgumentAccess
@run
def testFuncArgumentAccess():
    """Set `arg_attrs` / `result_attrs` on functions and print them back.

    `arg_attrs` is assigned once as an `ArrayAttr` and once as a plain Python
    list of `DictAttr`.
    """
    # Unregistered dialects are allowed for the "custom_dialect.*" attributes.
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        module = Module.create()
        f32 = F32Type.get()
        f64 = F64Type.get()
        with InsertionPoint(module.body):
            f = func.FuncOp("some_func", ([f32, f32], [f32, f32]))
            with InsertionPoint(f.add_entry_block()):
                func.ReturnOp(f.arguments)
            f.arg_attrs = ArrayAttr.get(
                [
                    DictAttr.get(
                        {
                            "custom_dialect.foo": StringAttr.get("bar"),
                            "custom_dialect.baz": UnitAttr.get(),
                        }
                    ),
                    DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])}),
                ]
            )
            f.result_attrs = ArrayAttr.get(
                [
                    DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}),
                    DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)}),
                ]
            )

            other = func.FuncOp("other_func", ([f32, f32], []))
            with InsertionPoint(other.add_entry_block()):
                func.ReturnOp([])
            other.arg_attrs = [
                DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
                DictAttr.get(),
            ]

        # CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}]
        print(f.arg_attrs)

        # CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}]
        print(f.result_attrs)

        # CHECK: func @some_func(
        # CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"},
        # CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) ->
        # CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32},
        # CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64})
        # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
        #
        # CHECK: func @other_func(
        # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
        # CHECK: %{{.*}}: f32)
        print(module)