1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4import mlir.dialects.builtin as builtin
5import mlir.dialects.func as func
6
7
def run(f):
  """Announce a test case, execute it immediately, and return it unchanged.

  Intended as a decorator: the banner printed here (a blank line followed by
  ``TEST: <function name>``) is the text each test's FileCheck label matches.

  Args:
    f: A zero-argument test function.

  Returns:
    ``f`` itself, so the decorated module-level name still refers to it.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  return f
12
13
14# CHECK-LABEL: TEST: testFromPyFunc
@run
def testFromPyFunc():
  """Exercise ``func.FuncOp.from_py_func`` as a decorator for building funcs.

  Each decorated nested Python function is turned into a ``func`` op in ``m``.
  The Python function name becomes the symbol name (unless ``name=`` is
  given), and the positional type arguments become the argument types. The
  module is printed at the end and matched by the FileCheck directives
  interleaved below, so the order in which functions are defined matters.
  """
  with Context() as ctx, Location.unknown() as loc:
    # Needed for the unregistered "custom.op1" op created further down.
    ctx.allow_unregistered_dialects = True
    m = builtin.ModuleOp()
    f32 = F32Type.get()
    f64 = F64Type.get()
    with InsertionPoint(m.body):
      # Returning a single value yields a single result type.
      # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
      # CHECK: return %arg0 : f64
      @func.FuncOp.from_py_func(f64)
      def unary_return(a):
        return a

      # Returning a tuple yields one result per element.
      # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
      # CHECK: return %arg0, %arg1 : f32, f64
      @func.FuncOp.from_py_func(f32, f64)
      def binary_return(a, b):
        return a, b

      # Returning None yields a function with no results and a bare return.
      # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
      # CHECK: return
      @func.FuncOp.from_py_func(f32, f64)
      def none_return(a, b):
        pass

      # Calling a previously decorated function inside another body emits a
      # `call` op to its symbol rather than inlining it.
      # CHECK-LABEL: func @call_unary
      # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
      # CHECK: return %0 : f64
      @func.FuncOp.from_py_func(f64)
      def call_unary(a):
        return unary_return(a)

      # CHECK-LABEL: func @call_binary
      # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
      # CHECK: return %0#0, %0#1 : f32, f64
      @func.FuncOp.from_py_func(f32, f64)
      def call_binary(a, b):
        return binary_return(a, b)

      # We expect coercion of a single result operation to a returned value.
      # CHECK-LABEL: func @single_result_op
      # CHECK: %0 = "custom.op1"() : () -> f32
      # CHECK: return %0 : f32
      @func.FuncOp.from_py_func()
      def single_result_op():
        return Operation.create("custom.op1", results=[f32])

      # CHECK-LABEL: func @call_none
      # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
      # CHECK: return
      @func.FuncOp.from_py_func(f32, f64)
      def call_none(a, b):
        return none_return(a, b)

      ## Variants and optional feature tests.
      # `name=` overrides the symbol name derived from the Python function.
      # CHECK-LABEL: func @from_name_arg
      @func.FuncOp.from_py_func(f32, f64, name="from_name_arg")
      def explicit_name(a, b):
        return b

      # The next three variants have no FileCheck directives. The assertion
      # inside each body is the check: the FuncOp under construction is
      # handed to the wrapped function as `func_op`, whether that parameter is
      # positional, keyword-with-default, or collected via **kwargs.
      @func.FuncOp.from_py_func(f32, f64)
      def positional_func_op(a, b, func_op):
        assert isinstance(func_op, func.FuncOp)
        return b

      @func.FuncOp.from_py_func(f32, f64)
      def kw_func_op(a, b=None, func_op=None):
        assert isinstance(func_op, func.FuncOp)
        return b

      @func.FuncOp.from_py_func(f32, f64)
      def kwargs_func_op(a, b=None, **kwargs):
        assert isinstance(kwargs["func_op"], func.FuncOp)
        return b

      # With explicit `results=`, the body emits its own ReturnOp and must
      # return None (the companion error test checks the opposite case).
      # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
      # CHECK: return %arg1 : f64
      @func.FuncOp.from_py_func(f32, f64, results=[f64])
      def explicit_results(a, b):
        func.ReturnOp([b])

  print(m)
98
99
100# CHECK-LABEL: TEST: testFromPyFuncErrors
@run
def testFromPyFuncErrors():
  """Verify ``from_py_func`` rejects ``results=`` combined with a return value.

  When explicit ``results=`` are given, the wrapped function must return None.
  The resulting AssertionError message is printed for FileCheck to match. If
  no error is raised, an explicit diagnostic line is printed instead. Without
  it, the regression would surface only as a missing FileCheck match.
  """
  with Context(), Location.unknown():
    m = builtin.ModuleOp()
    f64 = F64Type.get()
    with InsertionPoint(m.body):
      try:

        @func.FuncOp.from_py_func(f64, results=[f64])
        def unary_return(a):
          return a
      except AssertionError as e:
        # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
        print(e)
      else:
        print("ERROR: from_py_func accepted results= with a non-None return")
116
117
118# CHECK-LABEL: TEST: testBuildFuncOp
@run
def testBuildFuncOp():
  """Construct ``func.FuncOp``s directly and inspect their accessors.

  This test covers:
    * the ``name``, ``type`` and ``visibility`` accessors;
    * the ``IndexError`` raised by ``entry_block`` on a body-less function;
    * the ``IndexError`` raised by a second ``add_entry_block()`` call;
    * the ``body_builder`` callback, with the function type given as an
      ``(inputs, results)`` tuple.

  The module is printed at the end and matched by FileCheck.
  """
  ctx = Context()
  with Location.unknown(ctx):
    m = builtin.ModuleOp()

    f32 = F32Type.get()
    tensor_type = RankedTensorType.get((2, 3, 4), f32)
    with InsertionPoint.at_block_begin(m.body):
      f = func.FuncOp(name="some_func",
                      type=FunctionType.get(
                          inputs=[tensor_type, tensor_type],
                          results=[tensor_type]),
                      visibility="nested")
      # CHECK: Name is: "some_func"
      print("Name is: ", f.name)

      # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
      print("Type is: ", f.type)

      # CHECK: Visibility is: "nested"
      print("Visibility is: ", f.visibility)

      # The function has no body yet, so accessing its entry block must fail.
      try:
        f.entry_block
      except IndexError as e:
        # CHECK: External function does not have a body
        print(e)
      else:
        print("ERROR: entry_block did not raise for a body-less function")

      # Give the function a body that returns its first argument.
      with InsertionPoint(f.add_entry_block()):
        func.ReturnOp([f.entry_block.arguments[0]])

      # A second entry block must be rejected.
      try:
        f.add_entry_block()
      except IndexError as e:
        # CHECK: The function already has an entry block!
        print(e)
      else:
        print("ERROR: add_entry_block() accepted a second entry block")

      # Try the callback builder and passing type as tuple. The op is inserted
      # into the module as a side effect; no Python binding is needed.
      func.FuncOp(name="some_other_func",
                  type=([tensor_type, tensor_type], [tensor_type]),
                  visibility="nested",
                  body_builder=lambda op: func.ReturnOp(
                      [op.entry_block.arguments[0]]))

  # CHECK: module  {
  # CHECK:  func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
  # CHECK:   return %arg0 : tensor<2x3x4xf32>
  # CHECK:  }
  # CHECK:  func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
  # CHECK:   return %arg0 : tensor<2x3x4xf32>
  # CHECK:  }
  print(m)
173
174
175# CHECK-LABEL: TEST: testFuncArgumentAccess
@run
def testFuncArgumentAccess():
  """Set and read per-argument and per-result attributes on ``func.FuncOp``.

  This test covers:
    * ``arg_attrs`` and ``result_attrs`` set from an ``ArrayAttr`` of
      ``DictAttr``s;
    * ``arg_attrs`` set from a plain Python list of ``DictAttr``s;
    * printing both the attribute arrays and the whole module for FileCheck.
  """
  with Context() as ctx, Location.unknown():
    # Presumably required so the "custom_dialect.*" attribute names are
    # accepted without a registered dialect -- confirm.
    ctx.allow_unregistered_dialects = True
    module = Module.create()
    f32 = F32Type.get()
    f64 = F64Type.get()
    with InsertionPoint(module.body):
      # Two f32 arguments, two f32 results; the body returns the arguments.
      f = func.FuncOp("some_func", ([f32, f32], [f32, f32]))
      with InsertionPoint(f.add_entry_block()):
        func.ReturnOp(f.arguments)
      # One dictionary per argument, in argument order. The second argument
      # gets an attribute whose value is an empty ArrayAttr.
      f.arg_attrs = ArrayAttr.get([
          DictAttr.get({
              "custom_dialect.foo": StringAttr.get("bar"),
              "custom_dialect.baz": UnitAttr.get()
          }),
          DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])})
      ])
      # One dictionary per result, in result order.
      f.result_attrs = ArrayAttr.get([
          DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}),
          DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)})
      ])

      # Same setter, but fed a plain Python list instead of an ArrayAttr. An
      # empty DictAttr leaves the second argument without printed attributes.
      other = func.FuncOp("other_func", ([f32, f32], []))
      with InsertionPoint(other.add_entry_block()):
        func.ReturnOp([])
      other.arg_attrs = [
          DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
          DictAttr.get()
      ]

  # NOTE: the prints below run after the `with` block has exited; they rely on
  # `f` and `module` remaining usable outside it.
  # The expected text lists `baz` before `foo`, although `foo` was inserted
  # first, so the printed dictionary is key-ordered, not insertion-ordered.
  # CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}]
  print(f.arg_attrs)

  # CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}]
  print(f.result_attrs)

  # CHECK: func @some_func(
  # CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"},
  # CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) ->
  # CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32},
  # CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64})
  # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
  #
  # CHECK: func @other_func(
  # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
  # CHECK: %{{.*}}: f32)
  print(module)
224