1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4import mlir.dialects.builtin as builtin
5import mlir.dialects.std as std
6
7
def run(f):
  """Announce *f* with a ``TEST:`` header line, execute it, and return it.

  Intended as a decorator: every decorated test runs at import time, and
  handing back the same function object keeps the decorated name bound.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  return f
12
13
14# CHECK-LABEL: TEST: testFromPyFunc
@run
def testFromPyFunc():
  """Exercise ``builtin.FuncOp.from_py_func`` and print the resulting module.

  Each decorated Python function below is captured as a func op in the
  module body. The FileCheck directives interleaved with the definitions are
  matched against the printed module, so they appear in definition order.
  """
  with Context() as ctx, Location.unknown() as loc:
    m = builtin.ModuleOp()
    f32 = F32Type.get()
    f64 = F64Type.get()
    with InsertionPoint(m.body):
      # Single argument returned unchanged.
      # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
      # CHECK: return %arg0 : f64
      @builtin.FuncOp.from_py_func(f64)
      def unary_return(a):
        return a

      # Returning a Python tuple yields multiple function results.
      # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
      # CHECK: return %arg0, %arg1 : f32, f64
      @builtin.FuncOp.from_py_func(f32, f64)
      def binary_return(a, b):
        return a, b

      # Returning None yields a function with no results and an empty return.
      # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
      # CHECK: return
      @builtin.FuncOp.from_py_func(f32, f64)
      def none_return(a, b):
        pass

      # Calling a previously captured function from inside another captured
      # function emits a call op to it.
      # CHECK-LABEL: func @call_unary
      # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
      # CHECK: return %0 : f64
      @builtin.FuncOp.from_py_func(f64)
      def call_unary(a):
        return unary_return(a)

      # CHECK-LABEL: func @call_binary
      # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
      # CHECK: return %0#0, %0#1 : f32, f64
      @builtin.FuncOp.from_py_func(f32, f64)
      def call_binary(a, b):
        return binary_return(a, b)

      # CHECK-LABEL: func @call_none
      # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
      # CHECK: return
      @builtin.FuncOp.from_py_func(f32, f64)
      def call_none(a, b):
        return none_return(a, b)

      ## Variants and optional feature tests.
      # `name=` overrides the symbol name taken from the Python function.
      # CHECK-LABEL: func @from_name_arg
      @builtin.FuncOp.from_py_func(f32, f64, name="from_name_arg")
      def explicit_name(a, b):
        return b

      # The FuncOp being built can be received as an extra `func_op`
      # parameter: positionally, as a keyword, or through **kwargs. These
      # three variants only assert during capture; no output is matched.
      @builtin.FuncOp.from_py_func(f32, f64)
      def positional_func_op(a, b, func_op):
        assert isinstance(func_op, builtin.FuncOp)
        return b

      @builtin.FuncOp.from_py_func(f32, f64)
      def kw_func_op(a, b=None, func_op=None):
        assert isinstance(func_op, builtin.FuncOp)
        return b

      @builtin.FuncOp.from_py_func(f32, f64)
      def kwargs_func_op(a, b=None, **kwargs):
        assert isinstance(kwargs["func_op"], builtin.FuncOp)
        return b

      # With an explicit `results=` list, the body emits its own return op
      # and must itself return None.
      # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
      # CHECK: return %arg1 : f64
      @builtin.FuncOp.from_py_func(f32, f64, results=[f64])
      def explicit_results(a, b):
        std.ReturnOp([b])

  print(m)
89
90
91# CHECK-LABEL: TEST: testFromPyFuncErrors
@run
def testFromPyFuncErrors():
  """Check that ``from_py_func`` rejects a value-returning body with ``results=``.

  The AssertionError message is printed so that FileCheck can match it. If no
  error is raised, nothing is printed and the expected message is missing
  from the output.
  """
  with Context() as ctx, Location.unknown() as loc:
    m = builtin.ModuleOp()
    f32 = F32Type.get()
    f64 = F64Type.get()
    with InsertionPoint(m.body):
      try:

        @builtin.FuncOp.from_py_func(f64, results=[f64])
        def unary_return(a):
          # Returning a value conflicts with the explicit `results=` list.
          return a
      except AssertionError as e:
        # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
        print(e)
107
108
109# CHECK-LABEL: TEST: testBuildFuncOp
@run
def testBuildFuncOp():
  """Build func ops directly through the ``builtin.FuncOp`` constructor.

  Covers:
  - the name, type and visibility accessors;
  - the IndexError raised by ``entry_block`` on a body-less function;
  - the IndexError raised by a second ``add_entry_block`` call;
  - the ``body_builder`` callback combined with a tuple-form function type.

  The printed messages and the final module are matched by FileCheck.
  """
  ctx = Context()
  # The location is created against an explicit context object rather than
  # inside an enclosing ``with Context()``.
  with Location.unknown(ctx) as loc:
    m = builtin.ModuleOp()

    f32 = F32Type.get()
    tensor_type = RankedTensorType.get((2, 3, 4), f32)
    with InsertionPoint.at_block_begin(m.body):
      func = builtin.FuncOp(name="some_func",
                            type=FunctionType.get(
                                inputs=[tensor_type, tensor_type],
                                results=[tensor_type]),
                            visibility="nested")
      # CHECK: Name is: "some_func"
      print("Name is: ", func.name)

      # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
      print("Type is: ", func.type)

      # CHECK: Visibility is: "nested"
      print("Visibility is: ", func.visibility)

      # The function has no body yet, so reading its entry block is
      # expected to raise.
      try:
        entry_block = func.entry_block
      except IndexError as e:
        # CHECK: External function does not have a body
        print(e)

      # Give the function a body that returns its first argument.
      with InsertionPoint(func.add_entry_block()):
        std.ReturnOp([func.entry_block.arguments[0]])

      # Adding a second entry block must be rejected.
      try:
        func.add_entry_block()
      except IndexError as e:
        # CHECK: The function already has an entry block!
        print(e)

      # Try the callback builder and passing type as tuple.
      func = builtin.FuncOp(name="some_other_func",
                            type=([tensor_type, tensor_type], [tensor_type]),
                            visibility="nested",
                            body_builder=lambda func: std.ReturnOp(
                                [func.entry_block.arguments[0]]))

  # CHECK: module  {
  # CHECK:  func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
  # CHECK:   return %arg0 : tensor<2x3x4xf32>
  # CHECK:  }
  # CHECK:  func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
  # CHECK:   return %arg0 : tensor<2x3x4xf32>
  # CHECK:  }
  print(m)
164