1 //===- ReductionTest.cpp -- Reduction runtime builder unit tests ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "flang/Optimizer/Builder/Runtime/Reduction.h"
10 #include "RuntimeCallTestBase.h"
11 #include "gtest/gtest.h"
12
TEST_F(RuntimeCallTest, genAllTest) {
  // ALL(MASK, DIM): MASK is an undefined placeholder of the fixture type
  // seqTy10 and DIM is the i32 constant 1. The value returned by genAll is
  // expected to be defined by a call to _FortranAAll with 2 arguments.
  mlir::Location unknownLoc = firBuilder->getUnknownLoc();
  mlir::Value maskArg = firBuilder->create<fir::UndefOp>(unknownLoc, seqTy10);
  mlir::Value dimArg = firBuilder->createIntegerConstant(unknownLoc, i32Ty, 1);
  mlir::Value allValue =
      fir::runtime::genAll(*firBuilder, unknownLoc, maskArg, dimArg);
  auto *callOp = allValue.getDefiningOp();
  checkCallOp(callOp, "_FortranAAll", 2);
}
20
TEST_F(RuntimeCallTest, genAllDescriptorTest) {
  // Descriptor form of ALL with DIM. genAllDescriptor's return value is not
  // used; the emitted call (_FortranAAllDim, 3 arguments) is located through
  // the result operand instead.
  mlir::Location unknownLoc = firBuilder->getUnknownLoc();
  mlir::Value resultBox = firBuilder->create<fir::UndefOp>(unknownLoc, seqTy10);
  mlir::Value maskBox = firBuilder->create<fir::UndefOp>(unknownLoc, seqTy10);
  mlir::Value dimArg = firBuilder->createIntegerConstant(unknownLoc, i32Ty, 1);
  fir::runtime::genAllDescriptor(
      *firBuilder, unknownLoc, resultBox, maskBox, dimArg);
  checkCallOpFromResultBox(resultBox, "_FortranAAllDim", 3);
}
29
TEST_F(RuntimeCallTest, genAnyTest) {
  // ANY(MASK, DIM): MASK is an undefined placeholder of the fixture type
  // seqTy10 and DIM is the i32 constant 1. The value returned by genAny is
  // expected to be defined by a call to _FortranAAny with 2 arguments.
  mlir::Location unknownLoc = firBuilder->getUnknownLoc();
  mlir::Value maskArg = firBuilder->create<fir::UndefOp>(unknownLoc, seqTy10);
  mlir::Value dimArg = firBuilder->createIntegerConstant(unknownLoc, i32Ty, 1);
  mlir::Value anyValue =
      fir::runtime::genAny(*firBuilder, unknownLoc, maskArg, dimArg);
  auto *callOp = anyValue.getDefiningOp();
  checkCallOp(callOp, "_FortranAAny", 2);
}
37
TEST_F(RuntimeCallTest, genAnyDescriptorTest) {
  // Descriptor form of ANY with DIM. genAnyDescriptor's return value is not
  // used; the emitted call (_FortranAAnyDim, 3 arguments) is located through
  // the result operand instead.
  mlir::Location unknownLoc = firBuilder->getUnknownLoc();
  mlir::Value resultBox = firBuilder->create<fir::UndefOp>(unknownLoc, seqTy10);
  mlir::Value maskBox = firBuilder->create<fir::UndefOp>(unknownLoc, seqTy10);
  mlir::Value dimArg = firBuilder->createIntegerConstant(unknownLoc, i32Ty, 1);
  fir::runtime::genAnyDescriptor(
      *firBuilder, unknownLoc, resultBox, maskBox, dimArg);
  checkCallOpFromResultBox(resultBox, "_FortranAAnyDim", 3);
}
46
TEST_F(RuntimeCallTest, genCountTest) {
  // COUNT(MASK, DIM): MASK is an undefined placeholder of the fixture type
  // seqTy10 and DIM is the i32 constant 1. The value returned by genCount is
  // expected to be defined by a call to _FortranACount with 2 arguments.
  mlir::Location unknownLoc = firBuilder->getUnknownLoc();
  mlir::Value maskArg = firBuilder->create<fir::UndefOp>(unknownLoc, seqTy10);
  mlir::Value dimArg = firBuilder->createIntegerConstant(unknownLoc, i32Ty, 1);
  mlir::Value countValue =
      fir::runtime::genCount(*firBuilder, unknownLoc, maskArg, dimArg);
  auto *callOp = countValue.getDefiningOp();
  checkCallOp(callOp, "_FortranACount", 2);
}
54
TEST_F(RuntimeCallTest, genCountDimTest) {
  // COUNT with DIM and KIND, descriptor form. DIM and KIND are both the i32
  // constant 1. The emitted call (_FortranACountDim, 4 arguments) is located
  // through the result operand.
  mlir::Location unknownLoc = firBuilder->getUnknownLoc();
  mlir::Value resultBox = firBuilder->create<fir::UndefOp>(unknownLoc, seqTy10);
  mlir::Value maskBox = firBuilder->create<fir::UndefOp>(unknownLoc, seqTy10);
  mlir::Value dimArg = firBuilder->createIntegerConstant(unknownLoc, i32Ty, 1);
  mlir::Value kindArg = firBuilder->createIntegerConstant(unknownLoc, i32Ty, 1);
  fir::runtime::genCountDim(
      *firBuilder, unknownLoc, resultBox, maskBox, dimArg, kindArg);
  checkCallOpFromResultBox(resultBox, "_FortranACountDim", 4);
}
64
testGenMaxVal(fir::FirOpBuilder & builder,mlir::Type eleTy,llvm::StringRef fctName)65 void testGenMaxVal(
66 fir::FirOpBuilder &builder, mlir::Type eleTy, llvm::StringRef fctName) {
67 mlir::Location loc = builder.getUnknownLoc();
68 mlir::Type seqTy =
69 fir::SequenceType::get(fir::SequenceType::Shape(1, 10), eleTy);
70 mlir::Type refSeqTy = fir::ReferenceType::get(seqTy);
71 mlir::Value undef = builder.create<fir::UndefOp>(loc, refSeqTy);
72 mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy);
73 mlir::Value max = fir::runtime::genMaxval(builder, loc, undef, mask);
74 checkCallOp(max.getDefiningOp(), fctName, 3);
75 }
76
TEST_F(RuntimeCallTest, genMaxValTest) {
  // Element type -> expected MAXVAL runtime entry point. The cases run in
  // table order: first the real kinds, then the integer kinds.
  struct MaxvalCase {
    mlir::Type eleTy;
    llvm::StringRef fctName;
  };
  const MaxvalCase cases[] = {
      {f32Ty, "_FortranAMaxvalReal4"},
      {f64Ty, "_FortranAMaxvalReal8"},
      {f80Ty, "_FortranAMaxvalReal10"},
      {f128Ty, "_FortranAMaxvalReal16"},
      {i8Ty, "_FortranAMaxvalInteger1"},
      {i16Ty, "_FortranAMaxvalInteger2"},
      {i32Ty, "_FortranAMaxvalInteger4"},
      {i64Ty, "_FortranAMaxvalInteger8"},
      {i128Ty, "_FortranAMaxvalInteger16"},
  };
  for (const MaxvalCase &c : cases)
    testGenMaxVal(*firBuilder, c.eleTy, c.fctName);
}
89
testGenMinVal(fir::FirOpBuilder & builder,mlir::Type eleTy,llvm::StringRef fctName)90 void testGenMinVal(
91 fir::FirOpBuilder &builder, mlir::Type eleTy, llvm::StringRef fctName) {
92 mlir::Location loc = builder.getUnknownLoc();
93 mlir::Type seqTy =
94 fir::SequenceType::get(fir::SequenceType::Shape(1, 10), eleTy);
95 mlir::Type refSeqTy = fir::ReferenceType::get(seqTy);
96 mlir::Value undef = builder.create<fir::UndefOp>(loc, refSeqTy);
97 mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy);
98 mlir::Value min = fir::runtime::genMinval(builder, loc, undef, mask);
99 checkCallOp(min.getDefiningOp(), fctName, 3);
100 }
101
TEST_F(RuntimeCallTest, genMinValTest) {
  // Element type -> expected MINVAL runtime entry point. The cases run in
  // table order: first the real kinds, then the integer kinds.
  struct MinvalCase {
    mlir::Type eleTy;
    llvm::StringRef fctName;
  };
  const MinvalCase cases[] = {
      {f32Ty, "_FortranAMinvalReal4"},
      {f64Ty, "_FortranAMinvalReal8"},
      {f80Ty, "_FortranAMinvalReal10"},
      {f128Ty, "_FortranAMinvalReal16"},
      {i8Ty, "_FortranAMinvalInteger1"},
      {i16Ty, "_FortranAMinvalInteger2"},
      {i32Ty, "_FortranAMinvalInteger4"},
      {i64Ty, "_FortranAMinvalInteger8"},
      {i128Ty, "_FortranAMinvalInteger16"},
  };
  for (const MinvalCase &c : cases)
    testGenMinVal(*firBuilder, c.eleTy, c.fctName);
}
114
testGenSum(fir::FirOpBuilder & builder,mlir::Type eleTy,llvm::StringRef fctName)115 void testGenSum(
116 fir::FirOpBuilder &builder, mlir::Type eleTy, llvm::StringRef fctName) {
117 mlir::Location loc = builder.getUnknownLoc();
118 mlir::Type seqTy =
119 fir::SequenceType::get(fir::SequenceType::Shape(1, 10), eleTy);
120 mlir::Type refSeqTy = fir::ReferenceType::get(seqTy);
121 mlir::Value undef = builder.create<fir::UndefOp>(loc, refSeqTy);
122 mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy);
123 mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy);
124 mlir::Value sum = fir::runtime::genSum(builder, loc, undef, mask, result);
125 if (fir::isa_complex(eleTy))
126 checkCallOpFromResultBox(result, fctName, 4);
127 else
128 checkCallOp(sum.getDefiningOp(), fctName, 3);
129 }
130
TEST_F(RuntimeCallTest, genSumTest) {
  // Element type -> expected SUM runtime entry point. The cases run in table
  // order: real kinds, integer kinds, then complex kinds (Cpp-prefixed).
  struct SumCase {
    mlir::Type eleTy;
    llvm::StringRef fctName;
  };
  const SumCase cases[] = {
      {f32Ty, "_FortranASumReal4"},
      {f64Ty, "_FortranASumReal8"},
      {f80Ty, "_FortranASumReal10"},
      {f128Ty, "_FortranASumReal16"},
      {i8Ty, "_FortranASumInteger1"},
      {i16Ty, "_FortranASumInteger2"},
      {i32Ty, "_FortranASumInteger4"},
      {i64Ty, "_FortranASumInteger8"},
      {i128Ty, "_FortranASumInteger16"},
      {c4Ty, "_FortranACppSumComplex4"},
      {c8Ty, "_FortranACppSumComplex8"},
      {c10Ty, "_FortranACppSumComplex10"},
      {c16Ty, "_FortranACppSumComplex16"},
  };
  for (const SumCase &c : cases)
    testGenSum(*firBuilder, c.eleTy, c.fctName);
}
146
testGenProduct(fir::FirOpBuilder & builder,mlir::Type eleTy,llvm::StringRef fctName)147 void testGenProduct(
148 fir::FirOpBuilder &builder, mlir::Type eleTy, llvm::StringRef fctName) {
149 mlir::Location loc = builder.getUnknownLoc();
150 mlir::Type seqTy =
151 fir::SequenceType::get(fir::SequenceType::Shape(1, 10), eleTy);
152 mlir::Type refSeqTy = fir::ReferenceType::get(seqTy);
153 mlir::Value undef = builder.create<fir::UndefOp>(loc, refSeqTy);
154 mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy);
155 mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy);
156 mlir::Value prod =
157 fir::runtime::genProduct(builder, loc, undef, mask, result);
158 if (fir::isa_complex(eleTy))
159 checkCallOpFromResultBox(result, fctName, 4);
160 else
161 checkCallOp(prod.getDefiningOp(), fctName, 3);
162 }
163
TEST_F(RuntimeCallTest, genProduct) {
  // Element type -> expected PRODUCT runtime entry point. The cases run in
  // table order: real kinds, integer kinds, then complex kinds (Cpp-prefixed).
  struct ProductCase {
    mlir::Type eleTy;
    llvm::StringRef fctName;
  };
  const ProductCase cases[] = {
      {f32Ty, "_FortranAProductReal4"},
      {f64Ty, "_FortranAProductReal8"},
      {f80Ty, "_FortranAProductReal10"},
      {f128Ty, "_FortranAProductReal16"},
      {i8Ty, "_FortranAProductInteger1"},
      {i16Ty, "_FortranAProductInteger2"},
      {i32Ty, "_FortranAProductInteger4"},
      {i64Ty, "_FortranAProductInteger8"},
      {i128Ty, "_FortranAProductInteger16"},
      {c4Ty, "_FortranACppProductComplex4"},
      {c8Ty, "_FortranACppProductComplex8"},
      {c10Ty, "_FortranACppProductComplex10"},
      {c16Ty, "_FortranACppProductComplex16"},
  };
  for (const ProductCase &c : cases)
    testGenProduct(*firBuilder, c.eleTy, c.fctName);
}
179
testGenDotProduct(fir::FirOpBuilder & builder,mlir::Type eleTy,llvm::StringRef fctName)180 void testGenDotProduct(
181 fir::FirOpBuilder &builder, mlir::Type eleTy, llvm::StringRef fctName) {
182 mlir::Location loc = builder.getUnknownLoc();
183 mlir::Type seqTy =
184 fir::SequenceType::get(fir::SequenceType::Shape(1, 10), eleTy);
185 mlir::Type refSeqTy = fir::ReferenceType::get(seqTy);
186 mlir::Value a = builder.create<fir::UndefOp>(loc, refSeqTy);
187 mlir::Value b = builder.create<fir::UndefOp>(loc, refSeqTy);
188 mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy);
189 mlir::Value prod = fir::runtime::genDotProduct(builder, loc, a, b, result);
190 if (fir::isa_complex(eleTy))
191 checkCallOpFromResultBox(result, fctName, 3);
192 else
193 checkCallOp(prod.getDefiningOp(), fctName, 2);
194 }
195
TEST_F(RuntimeCallTest, genDotProduct) {
  // Element type -> expected DOT_PRODUCT runtime entry point. The cases run in
  // table order: real kinds, integer kinds, then complex kinds (Cpp-prefixed).
  struct DotProductCase {
    mlir::Type eleTy;
    llvm::StringRef fctName;
  };
  const DotProductCase cases[] = {
      {f32Ty, "_FortranADotProductReal4"},
      {f64Ty, "_FortranADotProductReal8"},
      {f80Ty, "_FortranADotProductReal10"},
      {f128Ty, "_FortranADotProductReal16"},
      {i8Ty, "_FortranADotProductInteger1"},
      {i16Ty, "_FortranADotProductInteger2"},
      {i32Ty, "_FortranADotProductInteger4"},
      {i64Ty, "_FortranADotProductInteger8"},
      {i128Ty, "_FortranADotProductInteger16"},
      {c4Ty, "_FortranACppDotProductComplex4"},
      {c8Ty, "_FortranACppDotProductComplex8"},
      {c10Ty, "_FortranACppDotProductComplex10"},
      {c16Ty, "_FortranACppDotProductComplex16"},
  };
  for (const DotProductCase &c : cases)
    testGenDotProduct(*firBuilder, c.eleTy, c.fctName);
}
211
checkGenMxxloc(fir::FirOpBuilder & builder,void (* genFct)(fir::FirOpBuilder &,mlir::Location,mlir::Value,mlir::Value,mlir::Value,mlir::Value,mlir::Value),llvm::StringRef fctName,unsigned nbArgs)212 void checkGenMxxloc(fir::FirOpBuilder &builder,
213 void (*genFct)(fir::FirOpBuilder &, mlir::Location, mlir::Value,
214 mlir::Value, mlir::Value, mlir::Value, mlir::Value),
215 llvm::StringRef fctName, unsigned nbArgs) {
216 mlir::Location loc = builder.getUnknownLoc();
217 mlir::Type i32Ty = builder.getI32Type();
218 mlir::Type seqTy =
219 fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty);
220 mlir::Type refSeqTy = fir::ReferenceType::get(seqTy);
221 mlir::Value a = builder.create<fir::UndefOp>(loc, refSeqTy);
222 mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy);
223 mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy);
224 mlir::Value kind = builder.createIntegerConstant(loc, i32Ty, 1);
225 mlir::Value back = builder.createIntegerConstant(loc, i32Ty, 1);
226 genFct(builder, loc, result, a, mask, kind, back);
227 checkCallOpFromResultBox(result, fctName, nbArgs);
228 }
229
TEST_F(RuntimeCallTest, genMaxlocTest) {
  // MAXLOC without DIM: expect a 5-argument call to _FortranAMaxloc.
  constexpr unsigned expectedArgCount = 5;
  fir::FirOpBuilder &builder = *firBuilder;
  checkGenMxxloc(
      builder, fir::runtime::genMaxloc, "_FortranAMaxloc", expectedArgCount);
}
233
TEST_F(RuntimeCallTest, genMinlocTest) {
  // MINLOC without DIM: expect a 5-argument call to _FortranAMinloc.
  constexpr unsigned expectedArgCount = 5;
  fir::FirOpBuilder &builder = *firBuilder;
  checkGenMxxloc(
      builder, fir::runtime::genMinloc, "_FortranAMinloc", expectedArgCount);
}
237
checkGenMxxlocDim(fir::FirOpBuilder & builder,void (* genFct)(fir::FirOpBuilder &,mlir::Location,mlir::Value,mlir::Value,mlir::Value,mlir::Value,mlir::Value,mlir::Value),llvm::StringRef fctName,unsigned nbArgs)238 void checkGenMxxlocDim(fir::FirOpBuilder &builder,
239 void (*genFct)(fir::FirOpBuilder &, mlir::Location, mlir::Value,
240 mlir::Value, mlir::Value, mlir::Value, mlir::Value, mlir::Value),
241 llvm::StringRef fctName, unsigned nbArgs) {
242 mlir::Location loc = builder.getUnknownLoc();
243 auto i32Ty = builder.getI32Type();
244 mlir::Type seqTy =
245 fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty);
246 mlir::Type refSeqTy = fir::ReferenceType::get(seqTy);
247 mlir::Value a = builder.create<fir::UndefOp>(loc, refSeqTy);
248 mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy);
249 mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy);
250 mlir::Value kind = builder.createIntegerConstant(loc, i32Ty, 1);
251 mlir::Value dim = builder.createIntegerConstant(loc, i32Ty, 1);
252 mlir::Value back = builder.createIntegerConstant(loc, i32Ty, 1);
253 genFct(builder, loc, result, a, dim, mask, kind, back);
254 checkCallOpFromResultBox(result, fctName, nbArgs);
255 }
256
TEST_F(RuntimeCallTest, genMaxlocDimTest) {
  // MAXLOC with DIM: expect a 6-argument call to _FortranAMaxlocDim.
  constexpr unsigned expectedArgCount = 6;
  fir::FirOpBuilder &builder = *firBuilder;
  checkGenMxxlocDim(builder, fir::runtime::genMaxlocDim, "_FortranAMaxlocDim",
      expectedArgCount);
}
261
TEST_F(RuntimeCallTest, genMinlocDimTest) {
  // MINLOC with DIM: expect a 6-argument call to _FortranAMinlocDim.
  constexpr unsigned expectedArgCount = 6;
  fir::FirOpBuilder &builder = *firBuilder;
  checkGenMxxlocDim(builder, fir::runtime::genMinlocDim, "_FortranAMinlocDim",
      expectedArgCount);
}
266
checkGenMxxvalChar(fir::FirOpBuilder & builder,void (* genFct)(fir::FirOpBuilder &,mlir::Location,mlir::Value,mlir::Value,mlir::Value),llvm::StringRef fctName,unsigned nbArgs)267 void checkGenMxxvalChar(fir::FirOpBuilder &builder,
268 void (*genFct)(fir::FirOpBuilder &, mlir::Location, mlir::Value,
269 mlir::Value, mlir::Value),
270 llvm::StringRef fctName, unsigned nbArgs) {
271 mlir::Location loc = builder.getUnknownLoc();
272 auto i32Ty = builder.getI32Type();
273 mlir::Type seqTy =
274 fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty);
275 mlir::Type refSeqTy = fir::ReferenceType::get(seqTy);
276 mlir::Value a = builder.create<fir::UndefOp>(loc, refSeqTy);
277 mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy);
278 mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy);
279 genFct(builder, loc, result, a, mask);
280 checkCallOpFromResultBox(result, fctName, nbArgs);
281 }
282
TEST_F(RuntimeCallTest, genMaxvalCharTest) {
  // Character MAXVAL: expect a 3-argument call to _FortranAMaxvalCharacter.
  constexpr unsigned expectedArgCount = 3;
  fir::FirOpBuilder &builder = *firBuilder;
  checkGenMxxvalChar(builder, fir::runtime::genMaxvalChar,
      "_FortranAMaxvalCharacter", expectedArgCount);
}
287
TEST_F(RuntimeCallTest, genMinvalCharTest) {
  // Character MINVAL: expect a 3-argument call to _FortranAMinvalCharacter.
  constexpr unsigned expectedArgCount = 3;
  fir::FirOpBuilder &builder = *firBuilder;
  checkGenMxxvalChar(builder, fir::runtime::genMinvalChar,
      "_FortranAMinvalCharacter", expectedArgCount);
}
292
checkGen4argsDim(fir::FirOpBuilder & builder,void (* genFct)(fir::FirOpBuilder &,mlir::Location,mlir::Value,mlir::Value,mlir::Value,mlir::Value),llvm::StringRef fctName,unsigned nbArgs)293 void checkGen4argsDim(fir::FirOpBuilder &builder,
294 void (*genFct)(fir::FirOpBuilder &, mlir::Location, mlir::Value,
295 mlir::Value, mlir::Value, mlir::Value),
296 llvm::StringRef fctName, unsigned nbArgs) {
297 mlir::Location loc = builder.getUnknownLoc();
298 auto i32Ty = builder.getI32Type();
299 mlir::Type seqTy =
300 fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty);
301 mlir::Type refSeqTy = fir::ReferenceType::get(seqTy);
302 mlir::Value a = builder.create<fir::UndefOp>(loc, refSeqTy);
303 mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy);
304 mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy);
305 mlir::Value dim = builder.createIntegerConstant(loc, i32Ty, 1);
306 genFct(builder, loc, result, a, dim, mask);
307 checkCallOpFromResultBox(result, fctName, nbArgs);
308 }
309
TEST_F(RuntimeCallTest, genMaxvalDimTest) {
  // MAXVAL with DIM: expect a 4-argument call to _FortranAMaxvalDim.
  constexpr unsigned expectedArgCount = 4;
  fir::FirOpBuilder &builder = *firBuilder;
  checkGen4argsDim(builder, fir::runtime::genMaxvalDim, "_FortranAMaxvalDim",
      expectedArgCount);
}
314
TEST_F(RuntimeCallTest, genMinvalDimTest) {
  // MINVAL with DIM: expect a 4-argument call to _FortranAMinvalDim.
  constexpr unsigned expectedArgCount = 4;
  fir::FirOpBuilder &builder = *firBuilder;
  checkGen4argsDim(builder, fir::runtime::genMinvalDim, "_FortranAMinvalDim",
      expectedArgCount);
}
319
TEST_F(RuntimeCallTest, genProductDimTest) {
  // PRODUCT with DIM: expect a 4-argument call to _FortranAProductDim.
  constexpr unsigned expectedArgCount = 4;
  fir::FirOpBuilder &builder = *firBuilder;
  checkGen4argsDim(builder, fir::runtime::genProductDim,
      "_FortranAProductDim", expectedArgCount);
}
324
TEST_F(RuntimeCallTest, genSumDimTest) {
  // SUM with DIM: expect a 4-argument call to _FortranASumDim.
  constexpr unsigned expectedArgCount = 4;
  fir::FirOpBuilder &builder = *firBuilder;
  checkGen4argsDim(
      builder, fir::runtime::genSumDim, "_FortranASumDim", expectedArgCount);
}
328