1// RUN: mlir-opt %s \
2// RUN:   -func-bufferize -tensor-bufferize -arith-bufferize --canonicalize \
3// RUN:   -convert-scf-to-cf --convert-complex-to-standard \
4// RUN:   -convert-memref-to-llvm -convert-math-to-llvm -convert-math-to-libm \
5// RUN:   -convert-vector-to-llvm -convert-complex-to-llvm \
6// RUN:   -convert-func-to-llvm -reconcile-unrealized-casts |\
7// RUN: mlir-cpu-runner \
8// RUN:  -e entry -entry-point-result=void  \
9// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext |\
10// RUN: FileCheck %s
11
// Harness for unary complex -> complex ops: applies %func to every element of
// %input in index order and prints the real part followed by the imaginary
// part of each result.
func.func @test_unary(%input: tensor<?xcomplex<f32>>,
                      %func: (complex<f32>) -> complex<f32>) {
  %lb = arith.constant 0 : index
  %step = arith.constant 1 : index
  %ub = tensor.dim %input, %lb : tensor<?xcomplex<f32>>

  scf.for %idx = %lb to %ub step %step {
    %operand = tensor.extract %input[%idx] : tensor<?xcomplex<f32>>
    %result = func.call_indirect %func(%operand)
      : (complex<f32>) -> complex<f32>

    // Real part is printed first, imaginary part second.
    %re = complex.re %result : complex<f32>
    vector.print %re : f32
    %im = complex.im %result : complex<f32>
    vector.print %im : f32
  }
  return
}
30
// Wraps complex.sqrt in a function so it can be passed around as a
// (complex<f32>) -> complex<f32> function value.
func.func @sqrt(%arg: complex<f32>) -> complex<f32> {
  %result = complex.sqrt %arg : complex<f32>
  return %result : complex<f32>
}
35
// Wraps complex.tanh in a function so it can be passed around as a
// (complex<f32>) -> complex<f32> function value.
func.func @tanh(%arg: complex<f32>) -> complex<f32> {
  %result = complex.tanh %arg : complex<f32>
  return %result : complex<f32>
}
40
// Wraps complex.rsqrt in a function so it can be passed around as a
// (complex<f32>) -> complex<f32> function value.
func.func @rsqrt(%arg: complex<f32>) -> complex<f32> {
  %rsqrt = complex.rsqrt %arg : complex<f32>
  return %rsqrt : complex<f32>
}
45
// Wraps complex.conj in a function so it can be passed around as a
// (complex<f32>) -> complex<f32> function value.
func.func @conj(%arg: complex<f32>) -> complex<f32> {
  %result = complex.conj %arg : complex<f32>
  return %result : complex<f32>
}
50
// Harness for binary complex ops: applies %func to consecutive (lhs, rhs)
// pairs of %input and prints the real part followed by the imaginary part of
// each result.
//
// %input contains pairs of lhs, rhs, i.e. [lhs_0, rhs_0, lhs_1, rhs_1,...]
// A trailing element without a partner (odd-sized input) is skipped instead
// of being paired with an out-of-bounds read.
func.func @test_binary(%input: tensor<?xcomplex<f32>>,
                       %func: (complex<f32>, complex<f32>) -> complex<f32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %size = tensor.dim %input, %c0 : tensor<?xcomplex<f32>>

  // Round the element count down to an even number so that %i + 1 is always a
  // valid index inside the loop below.
  %unpaired = arith.remui %size, %c2 : index
  %paired_size = arith.subi %size, %unpaired : index

  scf.for %i = %c0 to %paired_size step %c2 {
    %lhs = tensor.extract %input[%i] : tensor<?xcomplex<f32>>
    %i_next = arith.addi %i, %c1 : index
    %rhs = tensor.extract %input[%i_next] : tensor<?xcomplex<f32>>

    %val = func.call_indirect %func(%lhs, %rhs)
      : (complex<f32>, complex<f32>) -> complex<f32>
    // Real part is printed first, imaginary part second.
    %real = complex.re %val : complex<f32>
    %imag = complex.im %val : complex<f32>
    vector.print %real : f32
    vector.print %imag : f32
    scf.yield
  }
  func.return
}
74
// Wraps complex.atan2 in a function so it can be passed around as a
// (complex<f32>, complex<f32>) -> complex<f32> function value.
func.func @atan2(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
  %result = complex.atan2 %lhs, %rhs : complex<f32>
  return %result : complex<f32>
}
79
// Wraps complex.pow (%lhs raised to %rhs) in a function so it can be passed
// around as a (complex<f32>, complex<f32>) -> complex<f32> function value.
func.func @pow(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
  %result = complex.pow %lhs, %rhs : complex<f32>
  return %result : complex<f32>
}
84
// Harness for complex -> f32 ops: applies %func to every element of %input in
// index order and prints each scalar result.
func.func @test_element(%input: tensor<?xcomplex<f32>>,
                        %func: (complex<f32>) -> f32) {
  %lb = arith.constant 0 : index
  %step = arith.constant 1 : index
  %ub = tensor.dim %input, %lb : tensor<?xcomplex<f32>>

  scf.for %idx = %lb to %ub step %step {
    %operand = tensor.extract %input[%idx] : tensor<?xcomplex<f32>>
    %result = func.call_indirect %func(%operand) : (complex<f32>) -> f32
    vector.print %result : f32
  }
  return
}
100
// Wraps complex.angle in a function so it can be passed around as a
// (complex<f32>) -> f32 function value.
func.func @angle(%arg: complex<f32>) -> f32 {
  %result = complex.angle %arg : complex<f32>
  return %result : f32
}
105
// Test driver. For each complex op under test it materializes a constant
// input tensor, casts it to a dynamically sized tensor, and passes it together
// with the op's wrapper function to the matching @test_* harness.
//
// The FileCheck expectations are interleaved with the inputs they belong to
// and are matched against the printed output in order, so the sections below
// must stay in the same order as the calls that produce the output.
func.func @entry() {
  // complex.sqrt test
  %sqrt_test = arith.constant dense<[
    (-1.0, -1.0),
    // CHECK:       0.455
    // CHECK-NEXT: -1.098
    (-1.0, 1.0),
    // CHECK-NEXT:  0.455
    // CHECK-NEXT:  1.098
    (0.0, 0.0),
    // CHECK-NEXT:  0
    // CHECK-NEXT:  0
    (0.0, 1.0),
    // CHECK-NEXT:  0.707
    // CHECK-NEXT:  0.707
    (1.0, -1.0),
    // CHECK-NEXT:  1.098
    // CHECK-NEXT:  -0.455
    (1.0, 0.0),
    // CHECK-NEXT:  1
    // CHECK-NEXT:  0
    (1.0, 1.0)
    // CHECK-NEXT:  1.098
    // CHECK-NEXT:  0.455
  ]> : tensor<7xcomplex<f32>>
  %sqrt_test_cast = tensor.cast %sqrt_test
    :  tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>

  %sqrt_func = func.constant @sqrt : (complex<f32>) -> complex<f32>
  call @test_unary(%sqrt_test_cast, %sqrt_func)
    : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()

  // complex.atan2 test
  // Inputs are consumed as (lhs, rhs) pairs, one pair per line.
  %atan2_test = arith.constant dense<[
    (1.0, 2.0), (2.0, 1.0),
    // CHECK:       0.785
    // CHECK-NEXT:  0.346
    (1.0, 1.0), (1.0, 0.0),
    // CHECK-NEXT:  1.017
    // CHECK-NEXT:  0.402
    (1.0, 1.0), (1.0, 1.0)
    // CHECK-NEXT:  0.785
    // CHECK-NEXT:  0
  ]> : tensor<6xcomplex<f32>>
  %atan2_test_cast = tensor.cast %atan2_test
    :  tensor<6xcomplex<f32>> to tensor<?xcomplex<f32>>

  %atan2_func = func.constant @atan2 : (complex<f32>, complex<f32>)
    -> complex<f32>
  call @test_binary(%atan2_test_cast, %atan2_func)
    : (tensor<?xcomplex<f32>>, (complex<f32>, complex<f32>)
    -> complex<f32>) -> ()

  // complex.pow test
  // Inputs are consumed as (base, exponent) pairs, one pair per line.
  // For 0^(-1) the result is NaN; the sign printed for a NaN differs between
  // targets, so only "nan" is matched (as a substring it also accepts "-nan"),
  // consistent with the rsqrt section below.
  %pow_test = arith.constant dense<[
    (0.0, 0.0), (0.0, 0.0),
    // CHECK:       1
    // CHECK-NEXT:  0
    (0.0, 0.0), (1.0, 0.0),
    // CHECK-NEXT:  0
    // CHECK-NEXT:  0
    (0.0, 0.0), (-1.0, 0.0),
    // CHECK-NEXT:  nan
    // CHECK-NEXT:  nan
    (1.0, 1.0), (1.0, 1.0)
    // CHECK-NEXT:  0.273
    // CHECK-NEXT:  0.583
  ]> : tensor<8xcomplex<f32>>
  %pow_test_cast = tensor.cast %pow_test
    :  tensor<8xcomplex<f32>> to tensor<?xcomplex<f32>>

  %pow_func = func.constant @pow : (complex<f32>, complex<f32>)
    -> complex<f32>
  call @test_binary(%pow_test_cast, %pow_func)
    : (tensor<?xcomplex<f32>>, (complex<f32>, complex<f32>)
    -> complex<f32>) -> ()

  // complex.tanh test
  %tanh_test = arith.constant dense<[
    (-1.0, -1.0),
    // CHECK:      -1.08392
    // CHECK-NEXT: -0.271753
    (-1.0, 1.0),
    // CHECK-NEXT:  -1.08392
    // CHECK-NEXT:  0.271753
    (0.0, 0.0),
    // CHECK-NEXT:  0
    // CHECK-NEXT:  0
    (0.0, 1.0),
    // CHECK-NEXT:  0
    // CHECK-NEXT:  1.5574
    (1.0, -1.0),
    // CHECK-NEXT:  1.08392
    // CHECK-NEXT:  -0.271753
    (1.0, 0.0),
    // CHECK-NEXT:  0.761594
    // CHECK-NEXT:  0
    (1.0, 1.0)
    // CHECK-NEXT:  1.08392
    // CHECK-NEXT:  0.271753
  ]> : tensor<7xcomplex<f32>>
  %tanh_test_cast = tensor.cast %tanh_test
    :  tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>

  %tanh_func = func.constant @tanh : (complex<f32>) -> complex<f32>
  call @test_unary(%tanh_test_cast, %tanh_func)
    : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()

  // complex.rsqrt test
  %rsqrt_test = arith.constant dense<[
    (-1.0, -1.0),
    // CHECK:       0.321
    // CHECK-NEXT:  0.776
    (-1.0, 1.0),
    // CHECK-NEXT:  0.321
    // CHECK-NEXT:  -0.776
    (0.0, 0.0),
    // CHECK-NEXT:  nan
    // CHECK-NEXT:  nan
    (0.0, 1.0),
    // CHECK-NEXT:  0.707
    // CHECK-NEXT:  -0.707
    (1.0, -1.0),
    // CHECK-NEXT:  0.776
    // CHECK-NEXT:  0.321
    (1.0, 0.0),
    // CHECK-NEXT:  1
    // CHECK-NEXT:  0
    (1.0, 1.0)
    // CHECK-NEXT:  0.776
    // CHECK-NEXT:  -0.321
  ]> : tensor<7xcomplex<f32>>
  %rsqrt_test_cast = tensor.cast %rsqrt_test
    :  tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>

  %rsqrt_func = func.constant @rsqrt : (complex<f32>) -> complex<f32>
  call @test_unary(%rsqrt_test_cast, %rsqrt_func)
    : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()

  // complex.conj test
  %conj_test = arith.constant dense<[
    (-1.0, -1.0),
    // CHECK:      -1
    // CHECK-NEXT: 1
    (-1.0, 1.0),
    // CHECK-NEXT:  -1
    // CHECK-NEXT:  -1
    (0.0, 0.0),
    // CHECK-NEXT:  0
    // CHECK-NEXT:  0
    (0.0, 1.0),
    // CHECK-NEXT:  0
    // CHECK-NEXT:  -1
    (1.0, -1.0),
    // CHECK-NEXT:  1
    // CHECK-NEXT:  1
    (1.0, 0.0),
    // CHECK-NEXT:  1
    // CHECK-NEXT:  0
    (1.0, 1.0)
    // CHECK-NEXT:  1
    // CHECK-NEXT:  -1
  ]> : tensor<7xcomplex<f32>>
  %conj_test_cast = tensor.cast %conj_test
    :  tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>

  %conj_func = func.constant @conj : (complex<f32>) -> complex<f32>
  call @test_unary(%conj_test_cast, %conj_func)
    : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()

  // complex.angle test
  // @test_element prints a single f32 per input, hence one expectation each.
  %angle_test = arith.constant dense<[
    (-1.0, -1.0),
    // CHECK:      -2.356
    (-1.0, 1.0),
    // CHECK-NEXT:  2.356
    (0.0, 0.0),
    // CHECK-NEXT:  0
    (0.0, 1.0),
    // CHECK-NEXT:  1.570
    (1.0, -1.0),
    // CHECK-NEXT:  -0.785
    (1.0, 0.0),
    // CHECK-NEXT:  0
    (1.0, 1.0)
    // CHECK-NEXT:  0.785
  ]> : tensor<7xcomplex<f32>>
  %angle_test_cast = tensor.cast %angle_test
    :  tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>

  %angle_func = func.constant @angle : (complex<f32>) -> f32
  call @test_element(%angle_test_cast, %angle_func)
    : (tensor<?xcomplex<f32>>, (complex<f32>) -> f32) -> ()

  func.return
}
302