// RUN: mlir-opt %s \
// RUN:   -func-bufferize -tensor-bufferize -arith-bufferize --canonicalize \
// RUN:   -convert-scf-to-cf --convert-complex-to-standard \
// RUN:   -convert-memref-to-llvm -convert-math-to-llvm -convert-math-to-libm \
// RUN:   -convert-vector-to-llvm -convert-complex-to-llvm \
// RUN:   -convert-func-to-llvm -reconcile-unrealized-casts |\
// RUN: mlir-cpu-runner \
// RUN:   -e entry -entry-point-result=void \
// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext |\
// RUN: FileCheck %s

// Applies the unary callback %func to every element of %input, in index
// order, and prints the real part followed by the imaginary part of each
// result, one value per line.
func.func @test_unary(%input: tensor<?xcomplex<f32>>,
                      %func: (complex<f32>) -> complex<f32>) {
  %first = arith.constant 0 : index
  %stride = arith.constant 1 : index
  %count = tensor.dim %input, %first : tensor<?xcomplex<f32>>

  // The loop carries no values, so its scf.yield terminator is left implicit.
  scf.for %idx = %first to %count step %stride {
    %z = tensor.extract %input[%idx] : tensor<?xcomplex<f32>>
    %w = func.call_indirect %func(%z) : (complex<f32>) -> complex<f32>

    %w_re = complex.re %w : complex<f32>
    vector.print %w_re : f32
    %w_im = complex.im %w : complex<f32>
    vector.print %w_im : f32
  }
  func.return
}

// The wrappers below give each complex-dialect op under test a function
// symbol, so it can be passed to a driver as a callback value.

// Wraps complex.sqrt.
func.func @sqrt(%arg: complex<f32>) -> complex<f32> {
  %result = complex.sqrt %arg : complex<f32>
  func.return %result : complex<f32>
}

// Wraps complex.tanh.
func.func @tanh(%arg: complex<f32>) -> complex<f32> {
  %result = complex.tanh %arg : complex<f32>
  func.return %result : complex<f32>
}

// Wraps complex.rsqrt.
func.func @rsqrt(%arg: complex<f32>) -> complex<f32> {
  %result = complex.rsqrt %arg : complex<f32>
  func.return %result : complex<f32>
}

// Wraps complex.conj.
func.func @conj(%arg: complex<f32>) -> complex<f32> {
  %result = complex.conj %arg : complex<f32>
  func.return %result : complex<f32>
}

// %input contains pairs of lhs, rhs, i.e. [lhs_0, rhs_0, lhs_1, rhs_1,...]
// Applies the binary callback %func to consecutive (lhs, rhs) pairs of %input
// and prints the real part, then the imaginary part, of each result, one
// value per line. The loop advances two elements at a time and reads the rhs
// at index + 1 without a bounds check, so %input is assumed to hold an even
// number of elements.
func.func @test_binary(%input: tensor<?xcomplex<f32>>,
                       %func: (complex<f32>, complex<f32>) -> complex<f32>) {
  %first = arith.constant 0 : index
  %one = arith.constant 1 : index
  %pair_stride = arith.constant 2 : index
  %count = tensor.dim %input, %first : tensor<?xcomplex<f32>>

  // The loop carries no values, so its scf.yield terminator is left implicit.
  scf.for %lhs_idx = %first to %count step %pair_stride {
    %rhs_idx = arith.addi %lhs_idx, %one : index
    %lhs = tensor.extract %input[%lhs_idx] : tensor<?xcomplex<f32>>
    %rhs = tensor.extract %input[%rhs_idx] : tensor<?xcomplex<f32>>

    %w = func.call_indirect %func(%lhs, %rhs)
        : (complex<f32>, complex<f32>) -> complex<f32>

    %w_re = complex.re %w : complex<f32>
    vector.print %w_re : f32
    %w_im = complex.im %w : complex<f32>
    vector.print %w_im : f32
  }
  func.return
}

// Wraps complex.atan2 so it can be passed around as a callback value.
func.func @atan2(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
  %result = complex.atan2 %lhs, %rhs : complex<f32>
  func.return %result : complex<f32>
}

// Wraps complex.pow so it can be passed around as a callback value.
func.func @pow(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
  %result = complex.pow %lhs, %rhs : complex<f32>
  func.return %result : complex<f32>
}

// Applies the complex-to-real callback %func to every element of %input, in
// index order, and prints each scalar result on its own line.
func.func @test_element(%input: tensor<?xcomplex<f32>>,
                        %func: (complex<f32>) -> f32) {
  %first = arith.constant 0 : index
  %stride = arith.constant 1 : index
  %count = tensor.dim %input, %first : tensor<?xcomplex<f32>>

  // The loop carries no values, so its scf.yield terminator is left implicit.
  scf.for %idx = %first to %count step %stride {
    %z = tensor.extract %input[%idx] : tensor<?xcomplex<f32>>
    %scalar = func.call_indirect %func(%z) : (complex<f32>) -> f32
    vector.print %scalar : f32
  }
  func.return
}

// Wraps complex.angle so it can be passed around as a callback value.
func.func @angle(%arg: complex<f32>) -> f32 {
  %result = complex.angle %arg : complex<f32>
  func.return %result : f32
}

// Test driver. For each op it builds a constant tensor of inputs, casts it to
// a dynamically sized tensor and hands it to the matching driver together
// with the op's wrapper function. The expected output of every input is
// written directly beneath it as FileCheck directives: two per complex result
// (real, then imaginary) and one per scalar result.
func.func @entry() {
  // complex.sqrt test
  %sqrt_fn = func.constant @sqrt : (complex<f32>) -> complex<f32>
  %sqrt_inputs = arith.constant dense<[
    (-1.0, -1.0),
    // CHECK: 0.455
    // CHECK-NEXT: -1.098
    (-1.0, 1.0),
    // CHECK-NEXT: 0.455
    // CHECK-NEXT: 1.098
    (0.0, 0.0),
    // CHECK-NEXT: 0
    // CHECK-NEXT: 0
    (0.0, 1.0),
    // CHECK-NEXT: 0.707
    // CHECK-NEXT: 0.707
    (1.0, -1.0),
    // CHECK-NEXT: 1.098
    // CHECK-NEXT: -0.455
    (1.0, 0.0),
    // CHECK-NEXT: 1
    // CHECK-NEXT: 0
    (1.0, 1.0)
    // CHECK-NEXT: 1.098
    // CHECK-NEXT: 0.455
  ]> : tensor<7xcomplex<f32>>
  %sqrt_dyn = tensor.cast %sqrt_inputs
      : tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>
  call @test_unary(%sqrt_dyn, %sqrt_fn)
      : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()

  // complex.atan2 test
  %atan2_fn = func.constant @atan2
      : (complex<f32>, complex<f32>) -> complex<f32>
  %atan2_inputs = arith.constant dense<[
    (1.0, 2.0), (2.0, 1.0),
    // CHECK: 0.785
    // CHECK-NEXT: 0.346
    (1.0, 1.0), (1.0, 0.0),
    // CHECK-NEXT: 1.017
    // CHECK-NEXT: 0.402
    (1.0, 1.0), (1.0, 1.0)
    // CHECK-NEXT: 0.785
    // CHECK-NEXT: 0
  ]> : tensor<6xcomplex<f32>>
  %atan2_dyn = tensor.cast %atan2_inputs
      : tensor<6xcomplex<f32>> to tensor<?xcomplex<f32>>
  call @test_binary(%atan2_dyn, %atan2_fn)
      : (tensor<?xcomplex<f32>>,
         (complex<f32>, complex<f32>) -> complex<f32>) -> ()

  // complex.pow test
  %pow_fn = func.constant @pow
      : (complex<f32>, complex<f32>) -> complex<f32>
  %pow_inputs = arith.constant dense<[
    (0.0, 0.0), (0.0, 0.0),
    // CHECK: 1
    // CHECK-NEXT: 0
    (0.0, 0.0), (1.0, 0.0),
    // CHECK-NEXT: 0
    // CHECK-NEXT: 0
    (0.0, 0.0), (-1.0, 0.0),
    // CHECK-NEXT: -nan
    // CHECK-NEXT: -nan
    (1.0, 1.0), (1.0, 1.0)
    // CHECK-NEXT: 0.273
    // CHECK-NEXT: 0.583
  ]> : tensor<8xcomplex<f32>>
  %pow_dyn = tensor.cast %pow_inputs
      : tensor<8xcomplex<f32>> to tensor<?xcomplex<f32>>
  call @test_binary(%pow_dyn, %pow_fn)
      : (tensor<?xcomplex<f32>>,
         (complex<f32>, complex<f32>) -> complex<f32>) -> ()

  // complex.tanh test
  %tanh_fn = func.constant @tanh : (complex<f32>) -> complex<f32>
  %tanh_inputs = arith.constant dense<[
    (-1.0, -1.0),
    // CHECK: -1.08392
    // CHECK-NEXT: -0.271753
    (-1.0, 1.0),
    // CHECK-NEXT: -1.08392
    // CHECK-NEXT: 0.271753
    (0.0, 0.0),
    // CHECK-NEXT: 0
    // CHECK-NEXT: 0
    (0.0, 1.0),
    // CHECK-NEXT: 0
    // CHECK-NEXT: 1.5574
    (1.0, -1.0),
    // CHECK-NEXT: 1.08392
    // CHECK-NEXT: -0.271753
    (1.0, 0.0),
    // CHECK-NEXT: 0.761594
    // CHECK-NEXT: 0
    (1.0, 1.0)
    // CHECK-NEXT: 1.08392
    // CHECK-NEXT: 0.271753
  ]> : tensor<7xcomplex<f32>>
  %tanh_dyn = tensor.cast %tanh_inputs
      : tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>
  call @test_unary(%tanh_dyn, %tanh_fn)
      : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()

  // complex.rsqrt test
  %rsqrt_fn = func.constant @rsqrt : (complex<f32>) -> complex<f32>
  %rsqrt_inputs = arith.constant dense<[
    (-1.0, -1.0),
    // CHECK: 0.321
    // CHECK-NEXT: 0.776
    (-1.0, 1.0),
    // CHECK-NEXT: 0.321
    // CHECK-NEXT: -0.776
    (0.0, 0.0),
    // CHECK-NEXT: nan
    // CHECK-NEXT: nan
    (0.0, 1.0),
    // CHECK-NEXT: 0.707
    // CHECK-NEXT: -0.707
    (1.0, -1.0),
    // CHECK-NEXT: 0.776
    // CHECK-NEXT: 0.321
    (1.0, 0.0),
    // CHECK-NEXT: 1
    // CHECK-NEXT: 0
    (1.0, 1.0)
    // CHECK-NEXT: 0.776
    // CHECK-NEXT: -0.321
  ]> : tensor<7xcomplex<f32>>
  %rsqrt_dyn = tensor.cast %rsqrt_inputs
      : tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>
  call @test_unary(%rsqrt_dyn, %rsqrt_fn)
      : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()

  // complex.conj test
  %conj_fn = func.constant @conj : (complex<f32>) -> complex<f32>
  %conj_inputs = arith.constant dense<[
    (-1.0, -1.0),
    // CHECK: -1
    // CHECK-NEXT: 1
    (-1.0, 1.0),
    // CHECK-NEXT: -1
    // CHECK-NEXT: -1
    (0.0, 0.0),
    // CHECK-NEXT: 0
    // CHECK-NEXT: 0
    (0.0, 1.0),
    // CHECK-NEXT: 0
    // CHECK-NEXT: -1
    (1.0, -1.0),
    // CHECK-NEXT: 1
    // CHECK-NEXT: 1
    (1.0, 0.0),
    // CHECK-NEXT: 1
    // CHECK-NEXT: 0
    (1.0, 1.0)
    // CHECK-NEXT: 1
    // CHECK-NEXT: -1
  ]> : tensor<7xcomplex<f32>>
  %conj_dyn = tensor.cast %conj_inputs
      : tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>
  call @test_unary(%conj_dyn, %conj_fn)
      : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()

  // complex.angle test
  %angle_fn = func.constant @angle : (complex<f32>) -> f32
  %angle_inputs = arith.constant dense<[
    (-1.0, -1.0),
    // CHECK: -2.356
    (-1.0, 1.0),
    // CHECK-NEXT: 2.356
    (0.0, 0.0),
    // CHECK-NEXT: 0
    (0.0, 1.0),
    // CHECK-NEXT: 1.570
    (1.0, -1.0),
    // CHECK-NEXT: -0.785
    (1.0, 0.0),
    // CHECK-NEXT: 0
    (1.0, 1.0)
    // CHECK-NEXT: 0.785
  ]> : tensor<7xcomplex<f32>>
  %angle_dyn = tensor.cast %angle_inputs
      : tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>
  call @test_element(%angle_dyn, %angle_fn)
      : (tensor<?xcomplex<f32>>, (complex<f32>) -> f32) -> ()

  func.return
}