# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
import mlir.dialects.builtin as builtin
import mlir.dialects.std as std


def run(f):
  """Print a banner naming `f`, call it immediately, and return it.

  Intended for use as a decorator: each decorated test therefore executes
  while the module is being loaded, and the banner line is what the
  per-test label directives in this file match against.
  """
  print("\nTEST:", f.__name__)
  f()
  return f


# CHECK-LABEL: TEST: testFromPyFunc
@run
def testFromPyFunc():
  """Build functions with `builtin.FuncOp.from_py_func` and print the module.

  The Python function names below double as the expected MLIR symbol names
  (unless `name=` is given), so they must not be renamed.
  """
  with Context() as ctx, Location.unknown() as loc:
    m = builtin.ModuleOp()
    f32 = F32Type.get()
    f64 = F64Type.get()
    with InsertionPoint(m.body):
      # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
      # CHECK: return %arg0 : f64
      @builtin.FuncOp.from_py_func(f64)
      def unary_return(a):
        return a

      # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
      # CHECK: return %arg0, %arg1 : f32, f64
      @builtin.FuncOp.from_py_func(f32, f64)
      def binary_return(a, b):
        return a, b

      # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
      # CHECK: return
      @builtin.FuncOp.from_py_func(f32, f64)
      def none_return(a, b):
        pass

      # The next three bodies invoke the functions captured above; the
      # expected output is a `call` to the corresponding symbol.

      # CHECK-LABEL: func @call_unary
      # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
      # CHECK: return %0 : f64
      @builtin.FuncOp.from_py_func(f64)
      def call_unary(a):
        return unary_return(a)

      # CHECK-LABEL: func @call_binary
      # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
      # CHECK: return %0#0, %0#1 : f32, f64
      @builtin.FuncOp.from_py_func(f32, f64)
      def call_binary(a, b):
        return binary_return(a, b)

      # CHECK-LABEL: func @call_none
      # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
      # CHECK: return
      @builtin.FuncOp.from_py_func(f32, f64)
      def call_none(a, b):
        return none_return(a, b)

      ## Variants and optional feature tests.
      # CHECK-LABEL: func @from_name_arg
      @builtin.FuncOp.from_py_func(f32, f64, name="from_name_arg")
      def explicit_name(a, b):
        return b

      # The following three variants only assert (in Python) that a
      # `func_op` argument is supplied as a builtin.FuncOp, whether it is
      # declared positionally, as a keyword, or collected by **kwargs.
      @builtin.FuncOp.from_py_func(f32, f64)
      def positional_func_op(a, b, func_op):
        assert isinstance(func_op, builtin.FuncOp)
        return b

      @builtin.FuncOp.from_py_func(f32, f64)
      def kw_func_op(a, b=None, func_op=None):
        assert isinstance(func_op, builtin.FuncOp)
        return b

      @builtin.FuncOp.from_py_func(f32, f64)
      def kwargs_func_op(a, b=None, **kwargs):
        assert isinstance(kwargs["func_op"], builtin.FuncOp)
        return b

      # With explicit `results=`, the body emits its own terminator and
      # returns None from Python.
      # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
      # CHECK: return %arg1 : f64
      @builtin.FuncOp.from_py_func(f32, f64, results=[f64])
      def explicit_results(a, b):
        std.ReturnOp([b])

  print(m)


# CHECK-LABEL: TEST: testFromPyFuncErrors
@run
def testFromPyFuncErrors():
  """Check the error raised when `results=` is combined with a return value.

  The AssertionError message is printed so it can be matched on stdout.
  """
  with Context() as ctx, Location.unknown() as loc:
    m = builtin.ModuleOp()
    f32 = F32Type.get()
    f64 = F64Type.get()
    with InsertionPoint(m.body):
      try:

        @builtin.FuncOp.from_py_func(f64, results=[f64])
        def unary_return(a):
          return a
      except AssertionError as e:
        # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
        print(e)


# CHECK-LABEL: TEST: testBuildFuncOp
@run
def testBuildFuncOp():
  """Construct `builtin.FuncOp` directly and exercise its accessors.

  Covers the name/type/visibility properties, entry-block access and
  creation (including both IndexError paths), and the `body_builder`
  callback with the function type given as an (inputs, results) tuple.
  """
  ctx = Context()
  with Location.unknown(ctx) as loc:
    m = builtin.ModuleOp()

    f32 = F32Type.get()
    tensor_type = RankedTensorType.get((2, 3, 4), f32)
    with InsertionPoint.at_block_begin(m.body):
      func = builtin.FuncOp(name="some_func",
                            type=FunctionType.get(
                                inputs=[tensor_type, tensor_type],
                                results=[tensor_type]),
                            visibility="nested")
      # CHECK: Name is: "some_func"
      print("Name is: ", func.name)

      # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
      print("Type is: ", func.type)

      # CHECK: Visibility is: "nested"
      print("Visibility is: ", func.visibility)

      try:
        # No entry block has been added yet; the access itself is the test,
        # so the value is intentionally discarded.
        _ = func.entry_block
      except IndexError as e:
        # CHECK: External function does not have a body
        print(e)

      with InsertionPoint(func.add_entry_block()):
        std.ReturnOp([func.entry_block.arguments[0]])

      try:
        # A second entry block must be rejected.
        func.add_entry_block()
      except IndexError as e:
        # CHECK: The function already has an entry block!
        print(e)

      # Try the callback builder and passing type as tuple.
      func = builtin.FuncOp(name="some_other_func",
                            type=([tensor_type, tensor_type], [tensor_type]),
                            visibility="nested",
                            body_builder=lambda func: std.ReturnOp(
                                [func.entry_block.arguments[0]]))

  # CHECK: module {
  # CHECK: func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
  # CHECK: return %arg0 : tensor<2x3x4xf32>
  # CHECK: }
  # CHECK: func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
  # CHECK: return %arg0 : tensor<2x3x4xf32>
  # CHECK: }
  print(m)