1 //===- ReductionTest.cpp -- Reduction runtime builder unit tests ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Builder/Runtime/Reduction.h"
10 #include "RuntimeCallTestBase.h"
11 #include "gtest/gtest.h"
12 
// genAll on an undefined seqTy10 value with a constant dim of 1 must yield a
// value defined by a call to "_FortranAAll"; 2 is the argument count handed to
// checkCallOp.
TEST_F(RuntimeCallTest, genAllTest) {
  fir::FirOpBuilder &builder = *firBuilder;
  mlir::Location unknownLoc = builder.getUnknownLoc();
  mlir::Value maskArg = builder.create<fir::UndefOp>(unknownLoc, seqTy10);
  mlir::Value dimArg = builder.createIntegerConstant(unknownLoc, i32Ty, 1);
  mlir::Value callResult =
      fir::runtime::genAll(builder, unknownLoc, maskArg, dimArg);
  checkCallOp(callResult.getDefiningOp(), "_FortranAAll", 2);
}
20 
// genAllDescriptor returns nothing, so the generated call is located through
// the result value and checked against "_FortranAAllDim" with 3 arguments.
TEST_F(RuntimeCallTest, genAllDescriptorTest) {
  fir::FirOpBuilder &builder = *firBuilder;
  mlir::Location unknownLoc = builder.getUnknownLoc();
  mlir::Value resultArg = builder.create<fir::UndefOp>(unknownLoc, seqTy10);
  mlir::Value maskArg = builder.create<fir::UndefOp>(unknownLoc, seqTy10);
  mlir::Value dimArg = builder.createIntegerConstant(unknownLoc, i32Ty, 1);
  fir::runtime::genAllDescriptor(
      builder, unknownLoc, resultArg, maskArg, dimArg);
  checkCallOpFromResultBox(resultArg, "_FortranAAllDim", 3);
}
29 
// genAny on an undefined seqTy10 value with a constant dim of 1 must yield a
// value defined by a call to "_FortranAAny"; 2 is the argument count handed to
// checkCallOp.
TEST_F(RuntimeCallTest, genAnyTest) {
  fir::FirOpBuilder &builder = *firBuilder;
  mlir::Location unknownLoc = builder.getUnknownLoc();
  mlir::Value maskArg = builder.create<fir::UndefOp>(unknownLoc, seqTy10);
  mlir::Value dimArg = builder.createIntegerConstant(unknownLoc, i32Ty, 1);
  mlir::Value callResult =
      fir::runtime::genAny(builder, unknownLoc, maskArg, dimArg);
  checkCallOp(callResult.getDefiningOp(), "_FortranAAny", 2);
}
37 
// genAnyDescriptor returns nothing, so the generated call is located through
// the result value and checked against "_FortranAAnyDim" with 3 arguments.
TEST_F(RuntimeCallTest, genAnyDescriptorTest) {
  fir::FirOpBuilder &builder = *firBuilder;
  mlir::Location unknownLoc = builder.getUnknownLoc();
  mlir::Value resultArg = builder.create<fir::UndefOp>(unknownLoc, seqTy10);
  mlir::Value maskArg = builder.create<fir::UndefOp>(unknownLoc, seqTy10);
  mlir::Value dimArg = builder.createIntegerConstant(unknownLoc, i32Ty, 1);
  fir::runtime::genAnyDescriptor(
      builder, unknownLoc, resultArg, maskArg, dimArg);
  checkCallOpFromResultBox(resultArg, "_FortranAAnyDim", 3);
}
46 
// genCount on an undefined seqTy10 value with a constant dim of 1 must yield a
// value defined by a call to "_FortranACount"; 2 is the argument count handed
// to checkCallOp.
TEST_F(RuntimeCallTest, genCountTest) {
  fir::FirOpBuilder &builder = *firBuilder;
  mlir::Location unknownLoc = builder.getUnknownLoc();
  mlir::Value maskArg = builder.create<fir::UndefOp>(unknownLoc, seqTy10);
  mlir::Value dimArg = builder.createIntegerConstant(unknownLoc, i32Ty, 1);
  mlir::Value callResult =
      fir::runtime::genCount(builder, unknownLoc, maskArg, dimArg);
  checkCallOp(callResult.getDefiningOp(), "_FortranACount", 2);
}
54 
// genCountDim takes an extra kind operand on top of result/mask/dim; the
// generated call is located through the result value and checked against
// "_FortranACountDim" with 4 arguments.
TEST_F(RuntimeCallTest, genCountDimTest) {
  fir::FirOpBuilder &builder = *firBuilder;
  mlir::Location unknownLoc = builder.getUnknownLoc();
  mlir::Value resultArg = builder.create<fir::UndefOp>(unknownLoc, seqTy10);
  mlir::Value maskArg = builder.create<fir::UndefOp>(unknownLoc, seqTy10);
  mlir::Value dimArg = builder.createIntegerConstant(unknownLoc, i32Ty, 1);
  mlir::Value kindArg = builder.createIntegerConstant(unknownLoc, i32Ty, 1);
  fir::runtime::genCountDim(
      builder, unknownLoc, resultArg, maskArg, dimArg, kindArg);
  checkCallOpFromResultBox(resultArg, "_FortranACountDim", 4);
}
64 
testGenMaxVal(fir::FirOpBuilder & builder,mlir::Type eleTy,llvm::StringRef fctName)65 void testGenMaxVal(
66     fir::FirOpBuilder &builder, mlir::Type eleTy, llvm::StringRef fctName) {
67   mlir::Location loc = builder.getUnknownLoc();
68   mlir::Type seqTy =
69       fir::SequenceType::get(fir::SequenceType::Shape(1, 10), eleTy);
70   mlir::Type refSeqTy = fir::ReferenceType::get(seqTy);
71   mlir::Value undef = builder.create<fir::UndefOp>(loc, refSeqTy);
72   mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy);
73   mlir::Value max = fir::runtime::genMaxval(builder, loc, undef, mask);
74   checkCallOp(max.getDefiningOp(), fctName, 3);
75 }
76 
// Runs testGenMaxVal once per supported real and integer element type, pairing
// each type with the runtime entry point name expected for it.
TEST_F(RuntimeCallTest, genMaxValTest) {
  struct MaxValCase {
    mlir::Type eleTy;
    llvm::StringRef fctName;
  };
  const MaxValCase cases[] = {
      // Real kinds.
      {f32Ty, "_FortranAMaxvalReal4"},
      {f64Ty, "_FortranAMaxvalReal8"},
      {f80Ty, "_FortranAMaxvalReal10"},
      {f128Ty, "_FortranAMaxvalReal16"},
      // Integer kinds.
      {i8Ty, "_FortranAMaxvalInteger1"},
      {i16Ty, "_FortranAMaxvalInteger2"},
      {i32Ty, "_FortranAMaxvalInteger4"},
      {i64Ty, "_FortranAMaxvalInteger8"},
      {i128Ty, "_FortranAMaxvalInteger16"},
  };
  for (const MaxValCase &testCase : cases)
    testGenMaxVal(*firBuilder, testCase.eleTy, testCase.fctName);
}
89 
testGenMinVal(fir::FirOpBuilder & builder,mlir::Type eleTy,llvm::StringRef fctName)90 void testGenMinVal(
91     fir::FirOpBuilder &builder, mlir::Type eleTy, llvm::StringRef fctName) {
92   mlir::Location loc = builder.getUnknownLoc();
93   mlir::Type seqTy =
94       fir::SequenceType::get(fir::SequenceType::Shape(1, 10), eleTy);
95   mlir::Type refSeqTy = fir::ReferenceType::get(seqTy);
96   mlir::Value undef = builder.create<fir::UndefOp>(loc, refSeqTy);
97   mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy);
98   mlir::Value min = fir::runtime::genMinval(builder, loc, undef, mask);
99   checkCallOp(min.getDefiningOp(), fctName, 3);
100 }
101 
// Runs testGenMinVal once per supported real and integer element type, pairing
// each type with the runtime entry point name expected for it.
TEST_F(RuntimeCallTest, genMinValTest) {
  struct MinValCase {
    mlir::Type eleTy;
    llvm::StringRef fctName;
  };
  const MinValCase cases[] = {
      // Real kinds.
      {f32Ty, "_FortranAMinvalReal4"},
      {f64Ty, "_FortranAMinvalReal8"},
      {f80Ty, "_FortranAMinvalReal10"},
      {f128Ty, "_FortranAMinvalReal16"},
      // Integer kinds.
      {i8Ty, "_FortranAMinvalInteger1"},
      {i16Ty, "_FortranAMinvalInteger2"},
      {i32Ty, "_FortranAMinvalInteger4"},
      {i64Ty, "_FortranAMinvalInteger8"},
      {i128Ty, "_FortranAMinvalInteger16"},
  };
  for (const MinValCase &testCase : cases)
    testGenMinVal(*firBuilder, testCase.eleTy, testCase.fctName);
}
114 
testGenSum(fir::FirOpBuilder & builder,mlir::Type eleTy,llvm::StringRef fctName)115 void testGenSum(
116     fir::FirOpBuilder &builder, mlir::Type eleTy, llvm::StringRef fctName) {
117   mlir::Location loc = builder.getUnknownLoc();
118   mlir::Type seqTy =
119       fir::SequenceType::get(fir::SequenceType::Shape(1, 10), eleTy);
120   mlir::Type refSeqTy = fir::ReferenceType::get(seqTy);
121   mlir::Value undef = builder.create<fir::UndefOp>(loc, refSeqTy);
122   mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy);
123   mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy);
124   mlir::Value sum = fir::runtime::genSum(builder, loc, undef, mask, result);
125   if (fir::isa_complex(eleTy))
126     checkCallOpFromResultBox(result, fctName, 4);
127   else
128     checkCallOp(sum.getDefiningOp(), fctName, 3);
129 }
130 
// Runs testGenSum once per supported real, integer and complex element type,
// pairing each type with the runtime entry point name expected for it.
TEST_F(RuntimeCallTest, genSumTest) {
  struct SumCase {
    mlir::Type eleTy;
    llvm::StringRef fctName;
  };
  const SumCase cases[] = {
      // Real kinds.
      {f32Ty, "_FortranASumReal4"},
      {f64Ty, "_FortranASumReal8"},
      {f80Ty, "_FortranASumReal10"},
      {f128Ty, "_FortranASumReal16"},
      // Integer kinds.
      {i8Ty, "_FortranASumInteger1"},
      {i16Ty, "_FortranASumInteger2"},
      {i32Ty, "_FortranASumInteger4"},
      {i64Ty, "_FortranASumInteger8"},
      {i128Ty, "_FortranASumInteger16"},
      // Complex kinds.
      {c4Ty, "_FortranACppSumComplex4"},
      {c8Ty, "_FortranACppSumComplex8"},
      {c10Ty, "_FortranACppSumComplex10"},
      {c16Ty, "_FortranACppSumComplex16"},
  };
  for (const SumCase &testCase : cases)
    testGenSum(*firBuilder, testCase.eleTy, testCase.fctName);
}
146 
testGenProduct(fir::FirOpBuilder & builder,mlir::Type eleTy,llvm::StringRef fctName)147 void testGenProduct(
148     fir::FirOpBuilder &builder, mlir::Type eleTy, llvm::StringRef fctName) {
149   mlir::Location loc = builder.getUnknownLoc();
150   mlir::Type seqTy =
151       fir::SequenceType::get(fir::SequenceType::Shape(1, 10), eleTy);
152   mlir::Type refSeqTy = fir::ReferenceType::get(seqTy);
153   mlir::Value undef = builder.create<fir::UndefOp>(loc, refSeqTy);
154   mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy);
155   mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy);
156   mlir::Value prod =
157       fir::runtime::genProduct(builder, loc, undef, mask, result);
158   if (fir::isa_complex(eleTy))
159     checkCallOpFromResultBox(result, fctName, 4);
160   else
161     checkCallOp(prod.getDefiningOp(), fctName, 3);
162 }
163 
// Runs testGenProduct once per supported real, integer and complex element
// type, pairing each type with the runtime entry point name expected for it.
TEST_F(RuntimeCallTest, genProduct) {
  struct ProductCase {
    mlir::Type eleTy;
    llvm::StringRef fctName;
  };
  const ProductCase cases[] = {
      // Real kinds.
      {f32Ty, "_FortranAProductReal4"},
      {f64Ty, "_FortranAProductReal8"},
      {f80Ty, "_FortranAProductReal10"},
      {f128Ty, "_FortranAProductReal16"},
      // Integer kinds.
      {i8Ty, "_FortranAProductInteger1"},
      {i16Ty, "_FortranAProductInteger2"},
      {i32Ty, "_FortranAProductInteger4"},
      {i64Ty, "_FortranAProductInteger8"},
      {i128Ty, "_FortranAProductInteger16"},
      // Complex kinds.
      {c4Ty, "_FortranACppProductComplex4"},
      {c8Ty, "_FortranACppProductComplex8"},
      {c10Ty, "_FortranACppProductComplex10"},
      {c16Ty, "_FortranACppProductComplex16"},
  };
  for (const ProductCase &testCase : cases)
    testGenProduct(*firBuilder, testCase.eleTy, testCase.fctName);
}
179 
testGenDotProduct(fir::FirOpBuilder & builder,mlir::Type eleTy,llvm::StringRef fctName)180 void testGenDotProduct(
181     fir::FirOpBuilder &builder, mlir::Type eleTy, llvm::StringRef fctName) {
182   mlir::Location loc = builder.getUnknownLoc();
183   mlir::Type seqTy =
184       fir::SequenceType::get(fir::SequenceType::Shape(1, 10), eleTy);
185   mlir::Type refSeqTy = fir::ReferenceType::get(seqTy);
186   mlir::Value a = builder.create<fir::UndefOp>(loc, refSeqTy);
187   mlir::Value b = builder.create<fir::UndefOp>(loc, refSeqTy);
188   mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy);
189   mlir::Value prod = fir::runtime::genDotProduct(builder, loc, a, b, result);
190   if (fir::isa_complex(eleTy))
191     checkCallOpFromResultBox(result, fctName, 3);
192   else
193     checkCallOp(prod.getDefiningOp(), fctName, 2);
194 }
195 
// Runs testGenDotProduct once per supported real, integer and complex element
// type, pairing each type with the runtime entry point name expected for it.
TEST_F(RuntimeCallTest, genDotProduct) {
  struct DotProductCase {
    mlir::Type eleTy;
    llvm::StringRef fctName;
  };
  const DotProductCase cases[] = {
      // Real kinds.
      {f32Ty, "_FortranADotProductReal4"},
      {f64Ty, "_FortranADotProductReal8"},
      {f80Ty, "_FortranADotProductReal10"},
      {f128Ty, "_FortranADotProductReal16"},
      // Integer kinds.
      {i8Ty, "_FortranADotProductInteger1"},
      {i16Ty, "_FortranADotProductInteger2"},
      {i32Ty, "_FortranADotProductInteger4"},
      {i64Ty, "_FortranADotProductInteger8"},
      {i128Ty, "_FortranADotProductInteger16"},
      // Complex kinds.
      {c4Ty, "_FortranACppDotProductComplex4"},
      {c8Ty, "_FortranACppDotProductComplex8"},
      {c10Ty, "_FortranACppDotProductComplex10"},
      {c16Ty, "_FortranACppDotProductComplex16"},
  };
  for (const DotProductCase &testCase : cases)
    testGenDotProduct(*firBuilder, testCase.eleTy, testCase.fctName);
}
211 
checkGenMxxloc(fir::FirOpBuilder & builder,void (* genFct)(fir::FirOpBuilder &,mlir::Location,mlir::Value,mlir::Value,mlir::Value,mlir::Value,mlir::Value),llvm::StringRef fctName,unsigned nbArgs)212 void checkGenMxxloc(fir::FirOpBuilder &builder,
213     void (*genFct)(fir::FirOpBuilder &, mlir::Location, mlir::Value,
214         mlir::Value, mlir::Value, mlir::Value, mlir::Value),
215     llvm::StringRef fctName, unsigned nbArgs) {
216   mlir::Location loc = builder.getUnknownLoc();
217   mlir::Type i32Ty = builder.getI32Type();
218   mlir::Type seqTy =
219       fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty);
220   mlir::Type refSeqTy = fir::ReferenceType::get(seqTy);
221   mlir::Value a = builder.create<fir::UndefOp>(loc, refSeqTy);
222   mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy);
223   mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy);
224   mlir::Value kind = builder.createIntegerConstant(loc, i32Ty, 1);
225   mlir::Value back = builder.createIntegerConstant(loc, i32Ty, 1);
226   genFct(builder, loc, result, a, mask, kind, back);
227   checkCallOpFromResultBox(result, fctName, nbArgs);
228 }
229 
// genMaxloc must produce a call to "_FortranAMaxloc" checked with 5 arguments.
TEST_F(RuntimeCallTest, genMaxlocTest) {
  llvm::StringRef runtimeName = "_FortranAMaxloc";
  unsigned expectedArgs = 5;
  checkGenMxxloc(
      *firBuilder, fir::runtime::genMaxloc, runtimeName, expectedArgs);
}
233 
// genMinloc must produce a call to "_FortranAMinloc" checked with 5 arguments.
TEST_F(RuntimeCallTest, genMinlocTest) {
  llvm::StringRef runtimeName = "_FortranAMinloc";
  unsigned expectedArgs = 5;
  checkGenMxxloc(
      *firBuilder, fir::runtime::genMinloc, runtimeName, expectedArgs);
}
237 
checkGenMxxlocDim(fir::FirOpBuilder & builder,void (* genFct)(fir::FirOpBuilder &,mlir::Location,mlir::Value,mlir::Value,mlir::Value,mlir::Value,mlir::Value,mlir::Value),llvm::StringRef fctName,unsigned nbArgs)238 void checkGenMxxlocDim(fir::FirOpBuilder &builder,
239     void (*genFct)(fir::FirOpBuilder &, mlir::Location, mlir::Value,
240         mlir::Value, mlir::Value, mlir::Value, mlir::Value, mlir::Value),
241     llvm::StringRef fctName, unsigned nbArgs) {
242   mlir::Location loc = builder.getUnknownLoc();
243   auto i32Ty = builder.getI32Type();
244   mlir::Type seqTy =
245       fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty);
246   mlir::Type refSeqTy = fir::ReferenceType::get(seqTy);
247   mlir::Value a = builder.create<fir::UndefOp>(loc, refSeqTy);
248   mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy);
249   mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy);
250   mlir::Value kind = builder.createIntegerConstant(loc, i32Ty, 1);
251   mlir::Value dim = builder.createIntegerConstant(loc, i32Ty, 1);
252   mlir::Value back = builder.createIntegerConstant(loc, i32Ty, 1);
253   genFct(builder, loc, result, a, dim, mask, kind, back);
254   checkCallOpFromResultBox(result, fctName, nbArgs);
255 }
256 
// genMaxlocDim must produce a call to "_FortranAMaxlocDim" checked with 6
// arguments.
TEST_F(RuntimeCallTest, genMaxlocDimTest) {
  llvm::StringRef runtimeName = "_FortranAMaxlocDim";
  unsigned expectedArgs = 6;
  checkGenMxxlocDim(
      *firBuilder, fir::runtime::genMaxlocDim, runtimeName, expectedArgs);
}
261 
// genMinlocDim must produce a call to "_FortranAMinlocDim" checked with 6
// arguments.
TEST_F(RuntimeCallTest, genMinlocDimTest) {
  llvm::StringRef runtimeName = "_FortranAMinlocDim";
  unsigned expectedArgs = 6;
  checkGenMxxlocDim(
      *firBuilder, fir::runtime::genMinlocDim, runtimeName, expectedArgs);
}
266 
checkGenMxxvalChar(fir::FirOpBuilder & builder,void (* genFct)(fir::FirOpBuilder &,mlir::Location,mlir::Value,mlir::Value,mlir::Value),llvm::StringRef fctName,unsigned nbArgs)267 void checkGenMxxvalChar(fir::FirOpBuilder &builder,
268     void (*genFct)(fir::FirOpBuilder &, mlir::Location, mlir::Value,
269         mlir::Value, mlir::Value),
270     llvm::StringRef fctName, unsigned nbArgs) {
271   mlir::Location loc = builder.getUnknownLoc();
272   auto i32Ty = builder.getI32Type();
273   mlir::Type seqTy =
274       fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty);
275   mlir::Type refSeqTy = fir::ReferenceType::get(seqTy);
276   mlir::Value a = builder.create<fir::UndefOp>(loc, refSeqTy);
277   mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy);
278   mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy);
279   genFct(builder, loc, result, a, mask);
280   checkCallOpFromResultBox(result, fctName, nbArgs);
281 }
282 
// genMaxvalChar must produce a call to "_FortranAMaxvalCharacter" checked with
// 3 arguments.
TEST_F(RuntimeCallTest, genMaxvalCharTest) {
  llvm::StringRef runtimeName = "_FortranAMaxvalCharacter";
  unsigned expectedArgs = 3;
  checkGenMxxvalChar(
      *firBuilder, fir::runtime::genMaxvalChar, runtimeName, expectedArgs);
}
287 
// genMinvalChar must produce a call to "_FortranAMinvalCharacter" checked with
// 3 arguments.
TEST_F(RuntimeCallTest, genMinvalCharTest) {
  llvm::StringRef runtimeName = "_FortranAMinvalCharacter";
  unsigned expectedArgs = 3;
  checkGenMxxvalChar(
      *firBuilder, fir::runtime::genMinvalChar, runtimeName, expectedArgs);
}
292 
checkGen4argsDim(fir::FirOpBuilder & builder,void (* genFct)(fir::FirOpBuilder &,mlir::Location,mlir::Value,mlir::Value,mlir::Value,mlir::Value),llvm::StringRef fctName,unsigned nbArgs)293 void checkGen4argsDim(fir::FirOpBuilder &builder,
294     void (*genFct)(fir::FirOpBuilder &, mlir::Location, mlir::Value,
295         mlir::Value, mlir::Value, mlir::Value),
296     llvm::StringRef fctName, unsigned nbArgs) {
297   mlir::Location loc = builder.getUnknownLoc();
298   auto i32Ty = builder.getI32Type();
299   mlir::Type seqTy =
300       fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty);
301   mlir::Type refSeqTy = fir::ReferenceType::get(seqTy);
302   mlir::Value a = builder.create<fir::UndefOp>(loc, refSeqTy);
303   mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy);
304   mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy);
305   mlir::Value dim = builder.createIntegerConstant(loc, i32Ty, 1);
306   genFct(builder, loc, result, a, dim, mask);
307   checkCallOpFromResultBox(result, fctName, nbArgs);
308 }
309 
// genMaxvalDim must produce a call to "_FortranAMaxvalDim" checked with 4
// arguments.
TEST_F(RuntimeCallTest, genMaxvalDimTest) {
  llvm::StringRef runtimeName = "_FortranAMaxvalDim";
  unsigned expectedArgs = 4;
  checkGen4argsDim(
      *firBuilder, fir::runtime::genMaxvalDim, runtimeName, expectedArgs);
}
314 
// genMinvalDim must produce a call to "_FortranAMinvalDim" checked with 4
// arguments.
TEST_F(RuntimeCallTest, genMinvalDimTest) {
  llvm::StringRef runtimeName = "_FortranAMinvalDim";
  unsigned expectedArgs = 4;
  checkGen4argsDim(
      *firBuilder, fir::runtime::genMinvalDim, runtimeName, expectedArgs);
}
319 
// genProductDim must produce a call to "_FortranAProductDim" checked with 4
// arguments.
TEST_F(RuntimeCallTest, genProductDimTest) {
  llvm::StringRef runtimeName = "_FortranAProductDim";
  unsigned expectedArgs = 4;
  checkGen4argsDim(
      *firBuilder, fir::runtime::genProductDim, runtimeName, expectedArgs);
}
324 
// genSumDim must produce a call to "_FortranASumDim" checked with 4 arguments.
TEST_F(RuntimeCallTest, genSumDimTest) {
  llvm::StringRef runtimeName = "_FortranASumDim";
  unsigned expectedArgs = 4;
  checkGen4argsDim(
      *firBuilder, fir::runtime::genSumDim, runtimeName, expectedArgs);
}
328