1 //===- ReductionTest.cpp -- Reduction runtime builder unit tests ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Builder/Runtime/Reduction.h" 10 #include "RuntimeCallTestBase.h" 11 #include "gtest/gtest.h" 12 13 TEST_F(RuntimeCallTest, genAllTest) { 14 mlir::Location loc = firBuilder->getUnknownLoc(); 15 mlir::Value undef = firBuilder->create<fir::UndefOp>(loc, seqTy10); 16 mlir::Value dim = firBuilder->createIntegerConstant(loc, i32Ty, 1); 17 mlir::Value all = fir::runtime::genAll(*firBuilder, loc, undef, dim); 18 checkCallOp(all.getDefiningOp(), "_FortranAAll", 2); 19 } 20 21 TEST_F(RuntimeCallTest, genAllDescriptorTest) { 22 mlir::Location loc = firBuilder->getUnknownLoc(); 23 mlir::Value result = firBuilder->create<fir::UndefOp>(loc, seqTy10); 24 mlir::Value mask = firBuilder->create<fir::UndefOp>(loc, seqTy10); 25 mlir::Value dim = firBuilder->createIntegerConstant(loc, i32Ty, 1); 26 fir::runtime::genAllDescriptor(*firBuilder, loc, result, mask, dim); 27 checkCallOpFromResultBox(result, "_FortranAAllDim", 3); 28 } 29 30 TEST_F(RuntimeCallTest, genAnyTest) { 31 mlir::Location loc = firBuilder->getUnknownLoc(); 32 mlir::Value undef = firBuilder->create<fir::UndefOp>(loc, seqTy10); 33 mlir::Value dim = firBuilder->createIntegerConstant(loc, i32Ty, 1); 34 mlir::Value any = fir::runtime::genAny(*firBuilder, loc, undef, dim); 35 checkCallOp(any.getDefiningOp(), "_FortranAAny", 2); 36 } 37 38 TEST_F(RuntimeCallTest, genAnyDescriptorTest) { 39 mlir::Location loc = firBuilder->getUnknownLoc(); 40 mlir::Value result = firBuilder->create<fir::UndefOp>(loc, seqTy10); 41 mlir::Value mask = firBuilder->create<fir::UndefOp>(loc, seqTy10); 42 mlir::Value dim = firBuilder->createIntegerConstant(loc, i32Ty, 1); 43 fir::runtime::genAnyDescriptor(*firBuilder, loc, result, mask, dim); 44 checkCallOpFromResultBox(result, "_FortranAAnyDim", 3); 45 } 46 47 TEST_F(RuntimeCallTest, genCountTest) { 48 mlir::Location loc = firBuilder->getUnknownLoc(); 49 mlir::Value undef = firBuilder->create<fir::UndefOp>(loc, seqTy10); 50 mlir::Value dim = firBuilder->createIntegerConstant(loc, i32Ty, 1); 51 mlir::Value count = fir::runtime::genCount(*firBuilder, loc, undef, dim); 52 checkCallOp(count.getDefiningOp(), "_FortranACount", 2); 53 } 54 55 TEST_F(RuntimeCallTest, genCountDimTest) { 56 mlir::Location loc = firBuilder->getUnknownLoc(); 57 mlir::Value result = firBuilder->create<fir::UndefOp>(loc, seqTy10); 58 mlir::Value mask = firBuilder->create<fir::UndefOp>(loc, seqTy10); 59 mlir::Value dim = firBuilder->createIntegerConstant(loc, i32Ty, 1); 60 mlir::Value kind = firBuilder->createIntegerConstant(loc, i32Ty, 1); 61 fir::runtime::genCountDim(*firBuilder, loc, result, mask, dim, kind); 62 checkCallOpFromResultBox(result, "_FortranACountDim", 4); 63 } 64 65 void testGenMaxVal( 66 fir::FirOpBuilder &builder, mlir::Type eleTy, llvm::StringRef fctName) { 67 mlir::Location loc = builder.getUnknownLoc(); 68 mlir::Type seqTy = 69 fir::SequenceType::get(fir::SequenceType::Shape(1, 10), eleTy); 70 mlir::Type refSeqTy = fir::ReferenceType::get(seqTy); 71 mlir::Value undef = builder.create<fir::UndefOp>(loc, refSeqTy); 72 mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy); 73 mlir::Value max = fir::runtime::genMaxval(builder, loc, undef, mask); 74 checkCallOp(max.getDefiningOp(), fctName, 3); 75 } 76 77 TEST_F(RuntimeCallTest, genMaxValTest) { 78 testGenMaxVal(*firBuilder, f32Ty, "_FortranAMaxvalReal4"); 79 testGenMaxVal(*firBuilder, f64Ty, "_FortranAMaxvalReal8"); 80 testGenMaxVal(*firBuilder, f80Ty, "_FortranAMaxvalReal10"); 81 testGenMaxVal(*firBuilder, f128Ty, "_FortranAMaxvalReal16"); 82 83 testGenMaxVal(*firBuilder, i8Ty, "_FortranAMaxvalInteger1"); 84 testGenMaxVal(*firBuilder, i16Ty, "_FortranAMaxvalInteger2"); 85 testGenMaxVal(*firBuilder, i32Ty, "_FortranAMaxvalInteger4"); 86 testGenMaxVal(*firBuilder, i64Ty, "_FortranAMaxvalInteger8"); 87 testGenMaxVal(*firBuilder, i128Ty, "_FortranAMaxvalInteger16"); 88 } 89 90 void testGenMinVal( 91 fir::FirOpBuilder &builder, mlir::Type eleTy, llvm::StringRef fctName) { 92 mlir::Location loc = builder.getUnknownLoc(); 93 mlir::Type seqTy = 94 fir::SequenceType::get(fir::SequenceType::Shape(1, 10), eleTy); 95 mlir::Type refSeqTy = fir::ReferenceType::get(seqTy); 96 mlir::Value undef = builder.create<fir::UndefOp>(loc, refSeqTy); 97 mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy); 98 mlir::Value min = fir::runtime::genMinval(builder, loc, undef, mask); 99 checkCallOp(min.getDefiningOp(), fctName, 3); 100 } 101 102 TEST_F(RuntimeCallTest, genMinValTest) { 103 testGenMinVal(*firBuilder, f32Ty, "_FortranAMinvalReal4"); 104 testGenMinVal(*firBuilder, f64Ty, "_FortranAMinvalReal8"); 105 testGenMinVal(*firBuilder, f80Ty, "_FortranAMinvalReal10"); 106 testGenMinVal(*firBuilder, f128Ty, "_FortranAMinvalReal16"); 107 108 testGenMinVal(*firBuilder, i8Ty, "_FortranAMinvalInteger1"); 109 testGenMinVal(*firBuilder, i16Ty, "_FortranAMinvalInteger2"); 110 testGenMinVal(*firBuilder, i32Ty, "_FortranAMinvalInteger4"); 111 testGenMinVal(*firBuilder, i64Ty, "_FortranAMinvalInteger8"); 112 testGenMinVal(*firBuilder, i128Ty, "_FortranAMinvalInteger16"); 113 } 114 115 void testGenSum( 116 fir::FirOpBuilder &builder, mlir::Type eleTy, llvm::StringRef fctName) { 117 mlir::Location loc = builder.getUnknownLoc(); 118 mlir::Type seqTy = 119 fir::SequenceType::get(fir::SequenceType::Shape(1, 10), eleTy); 120 mlir::Type refSeqTy = fir::ReferenceType::get(seqTy); 121 mlir::Value undef = builder.create<fir::UndefOp>(loc, refSeqTy); 122 mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy); 123 mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy); 124 mlir::Value sum = fir::runtime::genSum(builder, loc, undef, mask, result); 125 if (fir::isa_complex(eleTy)) 126 checkCallOpFromResultBox(result, fctName, 4); 127 else 128 checkCallOp(sum.getDefiningOp(), fctName, 3); 129 } 130 131 TEST_F(RuntimeCallTest, genSumTest) { 132 testGenSum(*firBuilder, f32Ty, "_FortranASumReal4"); 133 testGenSum(*firBuilder, f64Ty, "_FortranASumReal8"); 134 testGenSum(*firBuilder, f80Ty, "_FortranASumReal10"); 135 testGenSum(*firBuilder, f128Ty, "_FortranASumReal16"); 136 testGenSum(*firBuilder, i8Ty, "_FortranASumInteger1"); 137 testGenSum(*firBuilder, i16Ty, "_FortranASumInteger2"); 138 testGenSum(*firBuilder, i32Ty, "_FortranASumInteger4"); 139 testGenSum(*firBuilder, i64Ty, "_FortranASumInteger8"); 140 testGenSum(*firBuilder, i128Ty, "_FortranASumInteger16"); 141 testGenSum(*firBuilder, c4Ty, "_FortranACppSumComplex4"); 142 testGenSum(*firBuilder, c8Ty, "_FortranACppSumComplex8"); 143 testGenSum(*firBuilder, c10Ty, "_FortranACppSumComplex10"); 144 testGenSum(*firBuilder, c16Ty, "_FortranACppSumComplex16"); 145 } 146 147 void testGenProduct( 148 fir::FirOpBuilder &builder, mlir::Type eleTy, llvm::StringRef fctName) { 149 mlir::Location loc = builder.getUnknownLoc(); 150 mlir::Type seqTy = 151 fir::SequenceType::get(fir::SequenceType::Shape(1, 10), eleTy); 152 mlir::Type refSeqTy = fir::ReferenceType::get(seqTy); 153 mlir::Value undef = builder.create<fir::UndefOp>(loc, refSeqTy); 154 mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy); 155 mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy); 156 mlir::Value prod = 157 fir::runtime::genProduct(builder, loc, undef, mask, result); 158 if (fir::isa_complex(eleTy)) 159 checkCallOpFromResultBox(result, fctName, 4); 160 else 161 checkCallOp(prod.getDefiningOp(), fctName, 3); 162 } 163 164 TEST_F(RuntimeCallTest, genProduct) { 165 testGenProduct(*firBuilder, f32Ty, "_FortranAProductReal4"); 166 testGenProduct(*firBuilder, f64Ty, "_FortranAProductReal8"); 167 testGenProduct(*firBuilder, f80Ty, "_FortranAProductReal10"); 168 testGenProduct(*firBuilder, f128Ty, "_FortranAProductReal16"); 169 testGenProduct(*firBuilder, i8Ty, "_FortranAProductInteger1"); 170 testGenProduct(*firBuilder, i16Ty, "_FortranAProductInteger2"); 171 testGenProduct(*firBuilder, i32Ty, "_FortranAProductInteger4"); 172 testGenProduct(*firBuilder, i64Ty, "_FortranAProductInteger8"); 173 testGenProduct(*firBuilder, i128Ty, "_FortranAProductInteger16"); 174 testGenProduct(*firBuilder, c4Ty, "_FortranACppProductComplex4"); 175 testGenProduct(*firBuilder, c8Ty, "_FortranACppProductComplex8"); 176 testGenProduct(*firBuilder, c10Ty, "_FortranACppProductComplex10"); 177 testGenProduct(*firBuilder, c16Ty, "_FortranACppProductComplex16"); 178 } 179 180 void testGenDotProduct( 181 fir::FirOpBuilder &builder, mlir::Type eleTy, llvm::StringRef fctName) { 182 mlir::Location loc = builder.getUnknownLoc(); 183 mlir::Type seqTy = 184 fir::SequenceType::get(fir::SequenceType::Shape(1, 10), eleTy); 185 mlir::Type refSeqTy = fir::ReferenceType::get(seqTy); 186 mlir::Value a = builder.create<fir::UndefOp>(loc, refSeqTy); 187 mlir::Value b = builder.create<fir::UndefOp>(loc, refSeqTy); 188 mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy); 189 mlir::Value prod = fir::runtime::genDotProduct(builder, loc, a, b, result); 190 if (fir::isa_complex(eleTy)) 191 checkCallOpFromResultBox(result, fctName, 3); 192 else 193 checkCallOp(prod.getDefiningOp(), fctName, 2); 194 } 195 196 TEST_F(RuntimeCallTest, genDotProduct) { 197 testGenDotProduct(*firBuilder, f32Ty, "_FortranADotProductReal4"); 198 testGenDotProduct(*firBuilder, f64Ty, "_FortranADotProductReal8"); 199 testGenDotProduct(*firBuilder, f80Ty, "_FortranADotProductReal10"); 200 testGenDotProduct(*firBuilder, f128Ty, "_FortranADotProductReal16"); 201 testGenDotProduct(*firBuilder, i8Ty, "_FortranADotProductInteger1"); 202 testGenDotProduct(*firBuilder, i16Ty, "_FortranADotProductInteger2"); 203 testGenDotProduct(*firBuilder, i32Ty, "_FortranADotProductInteger4"); 204 testGenDotProduct(*firBuilder, i64Ty, "_FortranADotProductInteger8"); 205 testGenDotProduct(*firBuilder, i128Ty, "_FortranADotProductInteger16"); 206 testGenDotProduct(*firBuilder, c4Ty, "_FortranACppDotProductComplex4"); 207 testGenDotProduct(*firBuilder, c8Ty, "_FortranACppDotProductComplex8"); 208 testGenDotProduct(*firBuilder, c10Ty, "_FortranACppDotProductComplex10"); 209 testGenDotProduct(*firBuilder, c16Ty, "_FortranACppDotProductComplex16"); 210 } 211 212 void checkGenMxxloc(fir::FirOpBuilder &builder, 213 void (*genFct)(fir::FirOpBuilder &, Location, mlir::Value, mlir::Value, 214 mlir::Value, mlir::Value, mlir::Value), 215 llvm::StringRef fctName, unsigned nbArgs) { 216 mlir::Location loc = builder.getUnknownLoc(); 217 mlir::Type i32Ty = builder.getI32Type(); 218 mlir::Type seqTy = 219 fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty); 220 mlir::Type refSeqTy = fir::ReferenceType::get(seqTy); 221 mlir::Value a = builder.create<fir::UndefOp>(loc, refSeqTy); 222 mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy); 223 mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy); 224 mlir::Value kind = builder.createIntegerConstant(loc, i32Ty, 1); 225 mlir::Value back = builder.createIntegerConstant(loc, i32Ty, 1); 226 genFct(builder, loc, result, a, mask, kind, back); 227 checkCallOpFromResultBox(result, fctName, nbArgs); 228 } 229 230 TEST_F(RuntimeCallTest, genMaxlocTest) { 231 checkGenMxxloc(*firBuilder, fir::runtime::genMaxloc, "_FortranAMaxloc", 5); 232 } 233 234 TEST_F(RuntimeCallTest, genMinlocTest) { 235 checkGenMxxloc(*firBuilder, fir::runtime::genMinloc, "_FortranAMinloc", 5); 236 } 237 238 void checkGenMxxlocDim(fir::FirOpBuilder &builder, 239 void (*genFct)(fir::FirOpBuilder &, Location, mlir::Value, mlir::Value, 240 mlir::Value, mlir::Value, mlir::Value, mlir::Value), 241 llvm::StringRef fctName, unsigned nbArgs) { 242 mlir::Location loc = builder.getUnknownLoc(); 243 auto i32Ty = builder.getI32Type(); 244 mlir::Type seqTy = 245 fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty); 246 mlir::Type refSeqTy = fir::ReferenceType::get(seqTy); 247 mlir::Value a = builder.create<fir::UndefOp>(loc, refSeqTy); 248 mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy); 249 mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy); 250 mlir::Value kind = builder.createIntegerConstant(loc, i32Ty, 1); 251 mlir::Value dim = builder.createIntegerConstant(loc, i32Ty, 1); 252 mlir::Value back = builder.createIntegerConstant(loc, i32Ty, 1); 253 genFct(builder, loc, result, a, dim, mask, kind, back); 254 checkCallOpFromResultBox(result, fctName, nbArgs); 255 } 256 257 TEST_F(RuntimeCallTest, genMaxlocDimTest) { 258 checkGenMxxlocDim( 259 *firBuilder, fir::runtime::genMaxlocDim, "_FortranAMaxlocDim", 6); 260 } 261 262 TEST_F(RuntimeCallTest, genMinlocDimTest) { 263 checkGenMxxlocDim( 264 *firBuilder, fir::runtime::genMinlocDim, "_FortranAMinlocDim", 6); 265 } 266 267 void checkGenMxxvalChar(fir::FirOpBuilder &builder, 268 void (*genFct)( 269 fir::FirOpBuilder &, Location, mlir::Value, mlir::Value, mlir::Value), 270 llvm::StringRef fctName, unsigned nbArgs) { 271 mlir::Location loc = builder.getUnknownLoc(); 272 auto i32Ty = builder.getI32Type(); 273 mlir::Type seqTy = 274 fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty); 275 mlir::Type refSeqTy = fir::ReferenceType::get(seqTy); 276 mlir::Value a = builder.create<fir::UndefOp>(loc, refSeqTy); 277 mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy); 278 mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy); 279 genFct(builder, loc, result, a, mask); 280 checkCallOpFromResultBox(result, fctName, nbArgs); 281 } 282 283 TEST_F(RuntimeCallTest, genMaxvalCharTest) { 284 checkGenMxxvalChar( 285 *firBuilder, fir::runtime::genMaxvalChar, "_FortranAMaxvalCharacter", 3); 286 } 287 288 TEST_F(RuntimeCallTest, genMinvalCharTest) { 289 checkGenMxxvalChar( 290 *firBuilder, fir::runtime::genMinvalChar, "_FortranAMinvalCharacter", 3); 291 } 292 293 void checkGen4argsDim(fir::FirOpBuilder &builder, 294 void (*genFct)(fir::FirOpBuilder &, Location, mlir::Value, mlir::Value, 295 mlir::Value, mlir::Value), 296 llvm::StringRef fctName, unsigned nbArgs) { 297 mlir::Location loc = builder.getUnknownLoc(); 298 auto i32Ty = builder.getI32Type(); 299 mlir::Type seqTy = 300 fir::SequenceType::get(fir::SequenceType::Shape(1, 10), i32Ty); 301 mlir::Type refSeqTy = fir::ReferenceType::get(seqTy); 302 mlir::Value a = builder.create<fir::UndefOp>(loc, refSeqTy); 303 mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy); 304 mlir::Value mask = builder.create<fir::UndefOp>(loc, seqTy); 305 mlir::Value dim = builder.createIntegerConstant(loc, i32Ty, 1); 306 genFct(builder, loc, result, a, dim, mask); 307 checkCallOpFromResultBox(result, fctName, nbArgs); 308 } 309 310 TEST_F(RuntimeCallTest, genMaxvalDimTest) { 311 checkGen4argsDim( 312 *firBuilder, fir::runtime::genMaxvalDim, "_FortranAMaxvalDim", 4); 313 } 314 315 TEST_F(RuntimeCallTest, genMinvalDimTest) { 316 checkGen4argsDim( 317 *firBuilder, fir::runtime::genMinvalDim, "_FortranAMinvalDim", 4); 318 } 319 320 TEST_F(RuntimeCallTest, genProductDimTest) { 321 checkGen4argsDim( 322 *firBuilder, fir::runtime::genProductDim, "_FortranAProductDim", 4); 323 } 324 325 TEST_F(RuntimeCallTest, genSumDimTest) { 326 checkGen4argsDim(*firBuilder, fir::runtime::genSumDim, "_FortranASumDim", 4); 327 } 328