1// RUN: mlir-opt %s -test-math-algebraic-simplification | FileCheck %s --dump-input=always
2
// CHECK-LABEL: @pow_noop
// powf(x, 1.0) is the identity: for both the scalar and the splat-vector
// operand the result is expected to be the original argument, and no powf
// op may survive in the simplified function body.
func.func @pow_noop(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
  // Without the negative check below, a dead powf left behind after its uses
  // were redirected would go unnoticed, since the return alone still matches.
  // CHECK-NOT: math.powf
  // CHECK: return %arg0, %arg1
  %c = arith.constant 1.0 : f32
  %v = arith.constant dense <1.0> : vector<4xf32>
  %0 = math.powf %arg0, %c : f32
  %1 = math.powf %arg1, %v : vector<4xf32>
  return %0, %1 : f32, vector<4xf32>
}
12
// CHECK-LABEL: @pow_square
// Exponent 2.0: each powf is expected to become a single self-multiply,
// for the scalar argument and for the splat-vector argument alike.
func.func @pow_square(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
  // CHECK: %[[SQ_S:.*]] = arith.mulf %arg0, %arg0
  // CHECK: %[[SQ_V:.*]] = arith.mulf %arg1, %arg1
  // CHECK: return %[[SQ_S]], %[[SQ_V]]
  %two = arith.constant 2.0 : f32
  %two_v = arith.constant dense<2.0> : vector<4xf32>
  %sq = math.powf %arg0, %two : f32
  %sq_v = math.powf %arg1, %two_v : vector<4xf32>
  func.return %sq, %sq_v : f32, vector<4xf32>
}
24
// CHECK-LABEL: @pow_cube
// Exponent 3.0: each powf is expected to become two multiplies, first x * x
// and then x times that square, for both the scalar and the vector argument.
func.func @pow_cube(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
  // CHECK: %[[SQ_S:.*]] = arith.mulf %arg0, %arg0
  // CHECK: %[[CUBE_S:.*]] = arith.mulf %arg0, %[[SQ_S]]
  // CHECK: %[[SQ_V:.*]] = arith.mulf %arg1, %arg1
  // CHECK: %[[CUBE_V:.*]] = arith.mulf %arg1, %[[SQ_V]]
  // CHECK: return %[[CUBE_S]], %[[CUBE_V]]
  %three = arith.constant 3.0 : f32
  %three_v = arith.constant dense<3.0> : vector<4xf32>
  %cube = math.powf %arg0, %three : f32
  %cube_v = math.powf %arg1, %three_v : vector<4xf32>
  func.return %cube, %cube_v : f32, vector<4xf32>
}
38
// CHECK-LABEL: @pow_recip
// Exponent -1.0: each powf is expected to become a division of a 1.0
// constant (scalar or splat vector, matching the operand type) by the operand.
// The constants' relative order is matched, so keep the scalar pair first.
func.func @pow_recip(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
  // CHECK: %[[ONE_S:.*]] = arith.constant 1.0{{.*}} : f32
  // CHECK: %[[ONE_V:.*]] = arith.constant dense<1.0{{.*}}> : vector<4xf32>
  // CHECK: %[[INV_S:.*]] = arith.divf %[[ONE_S]], %arg0
  // CHECK: %[[INV_V:.*]] = arith.divf %[[ONE_V]], %arg1
  // CHECK: return %[[INV_S]], %[[INV_V]]
  %neg_one = arith.constant -1.0 : f32
  %neg_one_v = arith.constant dense<-1.0> : vector<4xf32>
  %inv = math.powf %arg0, %neg_one : f32
  %inv_v = math.powf %arg1, %neg_one_v : vector<4xf32>
  func.return %inv, %inv_v : f32, vector<4xf32>
}
52
// CHECK-LABEL: @pow_sqrt
// Exponent 0.5: each powf is expected to be replaced by math.sqrt of the
// operand, for the scalar argument and for the splat-vector argument.
func.func @pow_sqrt(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
  // CHECK: %[[ROOT_S:.*]] = math.sqrt %arg0
  // CHECK: %[[ROOT_V:.*]] = math.sqrt %arg1
  // CHECK: return %[[ROOT_S]], %[[ROOT_V]]
  %half = arith.constant 0.5 : f32
  %half_v = arith.constant dense<0.5> : vector<4xf32>
  %root = math.powf %arg0, %half : f32
  %root_v = math.powf %arg1, %half_v : vector<4xf32>
  func.return %root, %root_v : f32, vector<4xf32>
}
64
// CHECK-LABEL: @pow_rsqrt
// Exponent -0.5: each powf is expected to be replaced by math.rsqrt of the
// operand, for the scalar argument and for the splat-vector argument.
func.func @pow_rsqrt(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
  // CHECK: %[[RROOT_S:.*]] = math.rsqrt %arg0
  // CHECK: %[[RROOT_V:.*]] = math.rsqrt %arg1
  // CHECK: return %[[RROOT_S]], %[[RROOT_V]]
  %neg_half = arith.constant -0.5 : f32
  %neg_half_v = arith.constant dense<-0.5> : vector<4xf32>
  %rroot = math.powf %arg0, %neg_half : f32
  %rroot_v = math.powf %arg1, %neg_half_v : vector<4xf32>
  func.return %rroot, %rroot_v : f32, vector<4xf32>
}
76