1 //===-- lib/Evaluate/fold-integer.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "fold-implementation.h"
10 #include "fold-reduction.h"
11 #include "flang/Evaluate/check-expression.h"
12
13 namespace Fortran::evaluate {
14
15 // Given a collection of ConstantSubscripts values, package them as a Constant.
16 // Return scalar value if asScalar == true and shape-dim array otherwise.
17 template <typename T>
PackageConstantBounds(const ConstantSubscripts && bounds,bool asScalar=false)18 Expr<T> PackageConstantBounds(
19 const ConstantSubscripts &&bounds, bool asScalar = false) {
20 if (asScalar) {
21 return Expr<T>{Constant<T>{bounds.at(0)}};
22 } else {
23 // As rank-dim array
24 const int rank{GetRank(bounds)};
25 std::vector<Scalar<T>> packed(rank);
26 std::transform(bounds.begin(), bounds.end(), packed.begin(),
27 [](ConstantSubscript x) { return Scalar<T>(x); });
28 return Expr<T>{Constant<T>{std::move(packed), ConstantSubscripts{rank}}};
29 }
30 }
31
32 // Class to retrieve the constant bound of an expression which is an
33 // array that devolves to a type of Constant<T>
34 class GetConstantArrayBoundHelper {
35 public:
36 template <typename T>
GetLbound(const Expr<SomeType> & array,std::optional<int> dim)37 static Expr<T> GetLbound(
38 const Expr<SomeType> &array, std::optional<int> dim) {
39 return PackageConstantBounds<T>(
40 GetConstantArrayBoundHelper(dim, /*getLbound=*/true).Get(array),
41 dim.has_value());
42 }
43
44 template <typename T>
GetUbound(const Expr<SomeType> & array,std::optional<int> dim)45 static Expr<T> GetUbound(
46 const Expr<SomeType> &array, std::optional<int> dim) {
47 return PackageConstantBounds<T>(
48 GetConstantArrayBoundHelper(dim, /*getLbound=*/false).Get(array),
49 dim.has_value());
50 }
51
52 private:
GetConstantArrayBoundHelper(std::optional<ConstantSubscript> dim,bool getLbound)53 GetConstantArrayBoundHelper(
54 std::optional<ConstantSubscript> dim, bool getLbound)
55 : dim_{dim}, getLbound_{getLbound} {}
56
Get(const T &)57 template <typename T> ConstantSubscripts Get(const T &) {
58 // The method is needed for template expansion, but we should never get
59 // here in practice.
60 CHECK(false);
61 return {0};
62 }
63
Get(const Constant<T> & x)64 template <typename T> ConstantSubscripts Get(const Constant<T> &x) {
65 if (getLbound_) {
66 // Return the lower bound
67 if (dim_) {
68 return {x.lbounds().at(*dim_)};
69 } else {
70 return x.lbounds();
71 }
72 } else {
73 // Return the upper bound
74 if (arrayFromParenthesesExpr) {
75 // Underlying array comes from (x) expression - return shapes
76 if (dim_) {
77 return {x.shape().at(*dim_)};
78 } else {
79 return x.shape();
80 }
81 } else {
82 return x.ComputeUbounds(dim_);
83 }
84 }
85 }
86
Get(const Parentheses<T> & x)87 template <typename T> ConstantSubscripts Get(const Parentheses<T> &x) {
88 // Cause of temp variable inside parentheses - return [1, ... 1] for lower
89 // bounds and shape for upper bounds
90 if (getLbound_) {
91 return ConstantSubscripts(x.Rank(), ConstantSubscript{1});
92 } else {
93 // Indicate that underlying array comes from parentheses expression.
94 // Continue to unwrap expression until we hit a constant
95 arrayFromParenthesesExpr = true;
96 return Get(x.left());
97 }
98 }
99
Get(const Expr<T> & x)100 template <typename T> ConstantSubscripts Get(const Expr<T> &x) {
101 // recurse through Expr<T>'a until we hit a constant
102 return common::visit([&](const auto &inner) { return Get(inner); },
103 // [&](const auto &) { return 0; },
104 x.u);
105 }
106
107 const std::optional<ConstantSubscript> dim_;
108 const bool getLbound_;
109 bool arrayFromParenthesesExpr{false};
110 };
111
112 template <int KIND>
LBOUND(FoldingContext & context,FunctionRef<Type<TypeCategory::Integer,KIND>> && funcRef)113 Expr<Type<TypeCategory::Integer, KIND>> LBOUND(FoldingContext &context,
114 FunctionRef<Type<TypeCategory::Integer, KIND>> &&funcRef) {
115 using T = Type<TypeCategory::Integer, KIND>;
116 ActualArguments &args{funcRef.arguments()};
117 if (const auto *array{UnwrapExpr<Expr<SomeType>>(args[0])}) {
118 if (int rank{array->Rank()}; rank > 0) {
119 std::optional<int> dim;
120 if (funcRef.Rank() == 0) {
121 // Optional DIM= argument is present: result is scalar.
122 if (auto dim64{GetInt64Arg(args[1])}) {
123 if (*dim64 < 1 || *dim64 > rank) {
124 context.messages().Say("DIM=%jd dimension is out of range for "
125 "rank-%d array"_err_en_US,
126 *dim64, rank);
127 return MakeInvalidIntrinsic<T>(std::move(funcRef));
128 } else {
129 dim = *dim64 - 1; // 1-based to 0-based
130 }
131 } else {
132 // DIM= is present but not constant
133 return Expr<T>{std::move(funcRef)};
134 }
135 }
136 bool lowerBoundsAreOne{true};
137 if (auto named{ExtractNamedEntity(*array)}) {
138 const Symbol &symbol{named->GetLastSymbol()};
139 if (symbol.Rank() == rank) {
140 lowerBoundsAreOne = false;
141 if (dim) {
142 if (auto lb{GetLBOUND(context, *named, *dim)}) {
143 return Fold(context, ConvertToType<T>(std::move(*lb)));
144 }
145 } else if (auto extents{
146 AsExtentArrayExpr(GetLBOUNDs(context, *named))}) {
147 return Fold(context,
148 ConvertToType<T>(Expr<ExtentType>{std::move(*extents)}));
149 }
150 } else {
151 lowerBoundsAreOne = symbol.Rank() == 0; // LBOUND(array%component)
152 }
153 }
154 if (IsActuallyConstant(*array)) {
155 return GetConstantArrayBoundHelper::GetLbound<T>(*array, dim);
156 }
157 if (lowerBoundsAreOne) {
158 ConstantSubscripts ones(rank, ConstantSubscript{1});
159 return PackageConstantBounds<T>(std::move(ones), dim.has_value());
160 }
161 }
162 }
163 return Expr<T>{std::move(funcRef)};
164 }
165
166 template <int KIND>
UBOUND(FoldingContext & context,FunctionRef<Type<TypeCategory::Integer,KIND>> && funcRef)167 Expr<Type<TypeCategory::Integer, KIND>> UBOUND(FoldingContext &context,
168 FunctionRef<Type<TypeCategory::Integer, KIND>> &&funcRef) {
169 using T = Type<TypeCategory::Integer, KIND>;
170 ActualArguments &args{funcRef.arguments()};
171 if (auto *array{UnwrapExpr<Expr<SomeType>>(args[0])}) {
172 if (int rank{array->Rank()}; rank > 0) {
173 std::optional<int> dim;
174 if (funcRef.Rank() == 0) {
175 // Optional DIM= argument is present: result is scalar.
176 if (auto dim64{GetInt64Arg(args[1])}) {
177 if (*dim64 < 1 || *dim64 > rank) {
178 context.messages().Say("DIM=%jd dimension is out of range for "
179 "rank-%d array"_err_en_US,
180 *dim64, rank);
181 return MakeInvalidIntrinsic<T>(std::move(funcRef));
182 } else {
183 dim = *dim64 - 1; // 1-based to 0-based
184 }
185 } else {
186 // DIM= is present but not constant
187 return Expr<T>{std::move(funcRef)};
188 }
189 }
190 bool takeBoundsFromShape{true};
191 if (auto named{ExtractNamedEntity(*array)}) {
192 const Symbol &symbol{named->GetLastSymbol()};
193 if (symbol.Rank() == rank) {
194 takeBoundsFromShape = false;
195 if (dim) {
196 if (semantics::IsAssumedSizeArray(symbol) && *dim == rank - 1) {
197 context.messages().Say("DIM=%jd dimension is out of range for "
198 "rank-%d assumed-size array"_err_en_US,
199 rank, rank);
200 return MakeInvalidIntrinsic<T>(std::move(funcRef));
201 } else if (auto ub{GetUBOUND(context, *named, *dim)}) {
202 return Fold(context, ConvertToType<T>(std::move(*ub)));
203 }
204 } else {
205 Shape ubounds{GetUBOUNDs(context, *named)};
206 if (semantics::IsAssumedSizeArray(symbol)) {
207 CHECK(!ubounds.back());
208 ubounds.back() = ExtentExpr{-1};
209 }
210 if (auto extents{AsExtentArrayExpr(ubounds)}) {
211 return Fold(context,
212 ConvertToType<T>(Expr<ExtentType>{std::move(*extents)}));
213 }
214 }
215 } else {
216 takeBoundsFromShape = symbol.Rank() == 0; // UBOUND(array%component)
217 }
218 }
219 if (IsActuallyConstant(*array)) {
220 return GetConstantArrayBoundHelper::GetUbound<T>(*array, dim);
221 }
222 if (takeBoundsFromShape) {
223 if (auto shape{GetContextFreeShape(context, *array)}) {
224 if (dim) {
225 if (auto &dimSize{shape->at(*dim)}) {
226 return Fold(context,
227 ConvertToType<T>(Expr<ExtentType>{std::move(*dimSize)}));
228 }
229 } else if (auto shapeExpr{AsExtentArrayExpr(*shape)}) {
230 return Fold(context, ConvertToType<T>(std::move(*shapeExpr)));
231 }
232 }
233 }
234 }
235 }
236 return Expr<T>{std::move(funcRef)};
237 }
238
239 // COUNT()
240 template <typename T>
FoldCount(FoldingContext & context,FunctionRef<T> && ref)241 static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
242 static_assert(T::category == TypeCategory::Integer);
243 ActualArguments &arg{ref.arguments()};
244 if (const Constant<LogicalResult> *mask{arg.empty()
245 ? nullptr
246 : Folder<LogicalResult>{context}.Folding(arg[0])}) {
247 std::optional<int> dim;
248 if (CheckReductionDIM(dim, context, arg, 1, mask->Rank())) {
249 auto accumulator{[&](Scalar<T> &element, const ConstantSubscripts &at) {
250 if (mask->At(at).IsTrue()) {
251 element = element.AddSigned(Scalar<T>{1}).value;
252 }
253 }};
254 return Expr<T>{DoReduction<T>(*mask, dim, Scalar<T>{}, accumulator)};
255 }
256 }
257 return Expr<T>{std::move(ref)};
258 }
259
260 // FINDLOC(), MAXLOC(), & MINLOC()
261 enum class WhichLocation { Findloc, Maxloc, Minloc };
262 template <WhichLocation WHICH> class LocationHelper {
263 public:
LocationHelper(DynamicType && type,ActualArguments & arg,FoldingContext & context)264 LocationHelper(
265 DynamicType &&type, ActualArguments &arg, FoldingContext &context)
266 : type_{type}, arg_{arg}, context_{context} {}
267 using Result = std::optional<Constant<SubscriptInteger>>;
268 using Types = std::conditional_t<WHICH == WhichLocation::Findloc,
269 AllIntrinsicTypes, RelationalTypes>;
270
Test() const271 template <typename T> Result Test() const {
272 if (T::category != type_.category() || T::kind != type_.kind()) {
273 return std::nullopt;
274 }
275 CHECK(arg_.size() == (WHICH == WhichLocation::Findloc ? 6 : 5));
276 Folder<T> folder{context_};
277 Constant<T> *array{folder.Folding(arg_[0])};
278 if (!array) {
279 return std::nullopt;
280 }
281 std::optional<Constant<T>> value;
282 if constexpr (WHICH == WhichLocation::Findloc) {
283 if (const Constant<T> *p{folder.Folding(arg_[1])}) {
284 value.emplace(*p);
285 } else {
286 return std::nullopt;
287 }
288 }
289 std::optional<int> dim;
290 Constant<LogicalResult> *mask{
291 GetReductionMASK(arg_[maskArg], array->shape(), context_)};
292 if ((!mask && arg_[maskArg]) ||
293 !CheckReductionDIM(dim, context_, arg_, dimArg, array->Rank())) {
294 return std::nullopt;
295 }
296 bool back{false};
297 if (arg_[backArg]) {
298 const auto *backConst{
299 Folder<LogicalResult>{context_}.Folding(arg_[backArg])};
300 if (backConst) {
301 back = backConst->GetScalarValue().value().IsTrue();
302 } else {
303 return std::nullopt;
304 }
305 }
306 const RelationalOperator relation{WHICH == WhichLocation::Findloc
307 ? RelationalOperator::EQ
308 : WHICH == WhichLocation::Maxloc
309 ? (back ? RelationalOperator::GE : RelationalOperator::GT)
310 : back ? RelationalOperator::LE
311 : RelationalOperator::LT};
312 // Use lower bounds of 1 exclusively.
313 array->SetLowerBoundsToOne();
314 ConstantSubscripts at{array->lbounds()}, maskAt, resultIndices, resultShape;
315 if (mask) {
316 if (auto scalarMask{mask->GetScalarValue()}) {
317 // Convert into array in case of scalar MASK= (for
318 // MAXLOC/MINLOC/FINDLOC mask should be be conformable)
319 ConstantSubscript n{GetSize(array->shape())};
320 std::vector<Scalar<LogicalResult>> mask_elements(
321 n, Scalar<LogicalResult>{scalarMask.value()});
322 *mask = Constant<LogicalResult>{
323 std::move(mask_elements), ConstantSubscripts{n}};
324 }
325 mask->SetLowerBoundsToOne();
326 maskAt = mask->lbounds();
327 }
328 if (dim) { // DIM=
329 if (*dim < 1 || *dim > array->Rank()) {
330 context_.messages().Say("DIM=%d is out of range"_err_en_US, *dim);
331 return std::nullopt;
332 }
333 int zbDim{*dim - 1};
334 resultShape = array->shape();
335 resultShape.erase(
336 resultShape.begin() + zbDim); // scalar if array is vector
337 ConstantSubscript dimLength{array->shape()[zbDim]};
338 ConstantSubscript n{GetSize(resultShape)};
339 for (ConstantSubscript j{0}; j < n; ++j) {
340 ConstantSubscript hit{0};
341 if constexpr (WHICH == WhichLocation::Maxloc ||
342 WHICH == WhichLocation::Minloc) {
343 value.reset();
344 }
345 for (ConstantSubscript k{0}; k < dimLength;
346 ++k, ++at[zbDim], mask && ++maskAt[zbDim]) {
347 if ((!mask || mask->At(maskAt).IsTrue()) &&
348 IsHit(array->At(at), value, relation)) {
349 hit = at[zbDim];
350 if constexpr (WHICH == WhichLocation::Findloc) {
351 if (!back) {
352 break;
353 }
354 }
355 }
356 }
357 resultIndices.emplace_back(hit);
358 at[zbDim] = std::max<ConstantSubscript>(dimLength, 1);
359 array->IncrementSubscripts(at);
360 at[zbDim] = 1;
361 if (mask) {
362 maskAt[zbDim] = mask->lbounds()[zbDim] +
363 std::max<ConstantSubscript>(dimLength, 1) - 1;
364 mask->IncrementSubscripts(maskAt);
365 maskAt[zbDim] = mask->lbounds()[zbDim];
366 }
367 }
368 } else { // no DIM=
369 resultShape = ConstantSubscripts{array->Rank()}; // always a vector
370 ConstantSubscript n{GetSize(array->shape())};
371 resultIndices = ConstantSubscripts(array->Rank(), 0);
372 for (ConstantSubscript j{0}; j < n; ++j, array->IncrementSubscripts(at),
373 mask && mask->IncrementSubscripts(maskAt)) {
374 if ((!mask || mask->At(maskAt).IsTrue()) &&
375 IsHit(array->At(at), value, relation)) {
376 resultIndices = at;
377 if constexpr (WHICH == WhichLocation::Findloc) {
378 if (!back) {
379 break;
380 }
381 }
382 }
383 }
384 }
385 std::vector<Scalar<SubscriptInteger>> resultElements;
386 for (ConstantSubscript j : resultIndices) {
387 resultElements.emplace_back(j);
388 }
389 return Constant<SubscriptInteger>{
390 std::move(resultElements), std::move(resultShape)};
391 }
392
393 private:
394 template <typename T>
IsHit(typename Constant<T>::Element element,std::optional<Constant<T>> & value,RelationalOperator relation) const395 bool IsHit(typename Constant<T>::Element element,
396 std::optional<Constant<T>> &value,
397 [[maybe_unused]] RelationalOperator relation) const {
398 std::optional<Expr<LogicalResult>> cmp;
399 bool result{true};
400 if (value) {
401 if constexpr (T::category == TypeCategory::Logical) {
402 // array(at) .EQV. value?
403 static_assert(WHICH == WhichLocation::Findloc);
404 cmp.emplace(ConvertToType<LogicalResult>(
405 Expr<T>{LogicalOperation<T::kind>{LogicalOperator::Eqv,
406 Expr<T>{Constant<T>{element}}, Expr<T>{Constant<T>{*value}}}}));
407 } else { // compare array(at) to value
408 cmp.emplace(PackageRelation(relation, Expr<T>{Constant<T>{element}},
409 Expr<T>{Constant<T>{*value}}));
410 }
411 Expr<LogicalResult> folded{Fold(context_, std::move(*cmp))};
412 result = GetScalarConstantValue<LogicalResult>(folded).value().IsTrue();
413 } else {
414 // first unmasked element for MAXLOC/MINLOC - always take it
415 }
416 if constexpr (WHICH == WhichLocation::Maxloc ||
417 WHICH == WhichLocation::Minloc) {
418 if (result) {
419 value.emplace(std::move(element));
420 }
421 }
422 return result;
423 }
424
425 static constexpr int dimArg{WHICH == WhichLocation::Findloc ? 2 : 1};
426 static constexpr int maskArg{dimArg + 1};
427 static constexpr int backArg{maskArg + 2};
428
429 DynamicType type_;
430 ActualArguments &arg_;
431 FoldingContext &context_;
432 };
433
434 template <WhichLocation which>
FoldLocationCall(ActualArguments & arg,FoldingContext & context)435 static std::optional<Constant<SubscriptInteger>> FoldLocationCall(
436 ActualArguments &arg, FoldingContext &context) {
437 if (arg[0]) {
438 if (auto type{arg[0]->GetType()}) {
439 if constexpr (which == WhichLocation::Findloc) {
440 // Both ARRAY and VALUE are susceptible to conversion to a common
441 // comparison type.
442 if (arg[1]) {
443 if (auto valType{arg[1]->GetType()}) {
444 if (auto compareType{ComparisonType(*type, *valType)}) {
445 type = compareType;
446 }
447 }
448 }
449 }
450 return common::SearchTypes(
451 LocationHelper<which>{std::move(*type), arg, context});
452 }
453 }
454 return std::nullopt;
455 }
456
457 template <WhichLocation which, typename T>
FoldLocation(FoldingContext & context,FunctionRef<T> && ref)458 static Expr<T> FoldLocation(FoldingContext &context, FunctionRef<T> &&ref) {
459 static_assert(T::category == TypeCategory::Integer);
460 if (std::optional<Constant<SubscriptInteger>> found{
461 FoldLocationCall<which>(ref.arguments(), context)}) {
462 return Expr<T>{Fold(
463 context, ConvertToType<T>(Expr<SubscriptInteger>{std::move(*found)}))};
464 } else {
465 return Expr<T>{std::move(ref)};
466 }
467 }
468
469 // for IALL, IANY, & IPARITY
470 template <typename T>
FoldBitReduction(FoldingContext & context,FunctionRef<T> && ref,Scalar<T> (Scalar<T>::* operation)(const Scalar<T> &)const,Scalar<T> identity)471 static Expr<T> FoldBitReduction(FoldingContext &context, FunctionRef<T> &&ref,
472 Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
473 Scalar<T> identity) {
474 static_assert(T::category == TypeCategory::Integer);
475 std::optional<int> dim;
476 if (std::optional<Constant<T>> array{
477 ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
478 /*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
479 auto accumulator{[&](Scalar<T> &element, const ConstantSubscripts &at) {
480 element = (element.*operation)(array->At(at));
481 }};
482 return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
483 }
484 return Expr<T>{std::move(ref)};
485 }
486
487 template <int KIND>
FoldIntrinsicFunction(FoldingContext & context,FunctionRef<Type<TypeCategory::Integer,KIND>> && funcRef)488 Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
489 FoldingContext &context,
490 FunctionRef<Type<TypeCategory::Integer, KIND>> &&funcRef) {
491 using T = Type<TypeCategory::Integer, KIND>;
492 using Int4 = Type<TypeCategory::Integer, 4>;
493 ActualArguments &args{funcRef.arguments()};
494 auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)};
495 CHECK(intrinsic);
496 std::string name{intrinsic->name};
497 if (name == "abs") { // incl. babs, iiabs, jiaabs, & kiabs
498 return FoldElementalIntrinsic<T, T>(context, std::move(funcRef),
499 ScalarFunc<T, T>([&context](const Scalar<T> &i) -> Scalar<T> {
500 typename Scalar<T>::ValueWithOverflow j{i.ABS()};
501 if (j.overflow) {
502 context.messages().Say(
503 "abs(integer(kind=%d)) folding overflowed"_warn_en_US, KIND);
504 }
505 return j.value;
506 }));
507 } else if (name == "bit_size") {
508 return Expr<T>{Scalar<T>::bits};
509 } else if (name == "ceiling" || name == "floor" || name == "nint") {
510 if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
511 // NINT rounds ties away from zero, not to even
512 common::RoundingMode mode{name == "ceiling" ? common::RoundingMode::Up
513 : name == "floor" ? common::RoundingMode::Down
514 : common::RoundingMode::TiesAwayFromZero};
515 return common::visit(
516 [&](const auto &kx) {
517 using TR = ResultType<decltype(kx)>;
518 return FoldElementalIntrinsic<T, TR>(context, std::move(funcRef),
519 ScalarFunc<T, TR>([&](const Scalar<TR> &x) {
520 auto y{x.template ToInteger<Scalar<T>>(mode)};
521 if (y.flags.test(RealFlag::Overflow)) {
522 context.messages().Say(
523 "%s intrinsic folding overflow"_warn_en_US, name);
524 }
525 return y.value;
526 }));
527 },
528 cx->u);
529 }
530 } else if (name == "count") {
531 return FoldCount<T>(context, std::move(funcRef));
532 } else if (name == "digits") {
533 if (const auto *cx{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
534 return Expr<T>{common::visit(
535 [](const auto &kx) {
536 return Scalar<ResultType<decltype(kx)>>::DIGITS;
537 },
538 cx->u)};
539 } else if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
540 return Expr<T>{common::visit(
541 [](const auto &kx) {
542 return Scalar<ResultType<decltype(kx)>>::DIGITS;
543 },
544 cx->u)};
545 } else if (const auto *cx{UnwrapExpr<Expr<SomeComplex>>(args[0])}) {
546 return Expr<T>{common::visit(
547 [](const auto &kx) {
548 return Scalar<typename ResultType<decltype(kx)>::Part>::DIGITS;
549 },
550 cx->u)};
551 }
552 } else if (name == "dim") {
553 return FoldElementalIntrinsic<T, T, T>(
554 context, std::move(funcRef), &Scalar<T>::DIM);
555 } else if (name == "dshiftl" || name == "dshiftr") {
556 const auto fptr{
557 name == "dshiftl" ? &Scalar<T>::DSHIFTL : &Scalar<T>::DSHIFTR};
558 // Third argument can be of any kind. However, it must be smaller or equal
559 // than BIT_SIZE. It can be converted to Int4 to simplify.
560 return FoldElementalIntrinsic<T, T, T, Int4>(context, std::move(funcRef),
561 ScalarFunc<T, T, T, Int4>(
562 [&fptr](const Scalar<T> &i, const Scalar<T> &j,
563 const Scalar<Int4> &shift) -> Scalar<T> {
564 return std::invoke(fptr, i, j, static_cast<int>(shift.ToInt64()));
565 }));
566 } else if (name == "exponent") {
567 if (auto *sx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
568 return common::visit(
569 [&funcRef, &context](const auto &x) -> Expr<T> {
570 using TR = typename std::decay_t<decltype(x)>::Result;
571 return FoldElementalIntrinsic<T, TR>(context, std::move(funcRef),
572 &Scalar<TR>::template EXPONENT<Scalar<T>>);
573 },
574 sx->u);
575 } else {
576 DIE("exponent argument must be real");
577 }
578 } else if (name == "findloc") {
579 return FoldLocation<WhichLocation::Findloc, T>(context, std::move(funcRef));
580 } else if (name == "huge") {
581 return Expr<T>{Scalar<T>::HUGE()};
582 } else if (name == "iachar" || name == "ichar") {
583 auto *someChar{UnwrapExpr<Expr<SomeCharacter>>(args[0])};
584 CHECK(someChar);
585 if (auto len{ToInt64(someChar->LEN())}) {
586 if (len.value() != 1) {
587 // Do not die, this was not checked before
588 context.messages().Say(
589 "Character in intrinsic function %s must have length one"_warn_en_US,
590 name);
591 } else {
592 return common::visit(
593 [&funcRef, &context](const auto &str) -> Expr<T> {
594 using Char = typename std::decay_t<decltype(str)>::Result;
595 return FoldElementalIntrinsic<T, Char>(context,
596 std::move(funcRef),
597 ScalarFunc<T, Char>([](const Scalar<Char> &c) {
598 return Scalar<T>{CharacterUtils<Char::kind>::ICHAR(c)};
599 }));
600 },
601 someChar->u);
602 }
603 }
604 } else if (name == "iand" || name == "ior" || name == "ieor") {
605 auto fptr{&Scalar<T>::IAND};
606 if (name == "iand") { // done in fptr declaration
607 } else if (name == "ior") {
608 fptr = &Scalar<T>::IOR;
609 } else if (name == "ieor") {
610 fptr = &Scalar<T>::IEOR;
611 } else {
612 common::die("missing case to fold intrinsic function %s", name.c_str());
613 }
614 return FoldElementalIntrinsic<T, T, T>(
615 context, std::move(funcRef), ScalarFunc<T, T, T>(fptr));
616 } else if (name == "iall") {
617 return FoldBitReduction(
618 context, std::move(funcRef), &Scalar<T>::IAND, Scalar<T>{}.NOT());
619 } else if (name == "iany") {
620 return FoldBitReduction(
621 context, std::move(funcRef), &Scalar<T>::IOR, Scalar<T>{});
622 } else if (name == "ibclr" || name == "ibset") {
623 // Second argument can be of any kind. However, it must be smaller
624 // than BIT_SIZE. It can be converted to Int4 to simplify.
625 auto fptr{&Scalar<T>::IBCLR};
626 if (name == "ibclr") { // done in fptr definition
627 } else if (name == "ibset") {
628 fptr = &Scalar<T>::IBSET;
629 } else {
630 common::die("missing case to fold intrinsic function %s", name.c_str());
631 }
632 return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
633 ScalarFunc<T, T, Int4>([&](const Scalar<T> &i,
634 const Scalar<Int4> &pos) -> Scalar<T> {
635 auto posVal{static_cast<int>(pos.ToInt64())};
636 if (posVal < 0) {
637 context.messages().Say(
638 "bit position for %s (%d) is negative"_err_en_US, name, posVal);
639 } else if (posVal >= i.bits) {
640 context.messages().Say(
641 "bit position for %s (%d) is not less than %d"_err_en_US, name,
642 posVal, i.bits);
643 }
644 return std::invoke(fptr, i, posVal);
645 }));
646 } else if (name == "ibits") {
647 return FoldElementalIntrinsic<T, T, Int4, Int4>(context, std::move(funcRef),
648 ScalarFunc<T, T, Int4, Int4>([&](const Scalar<T> &i,
649 const Scalar<Int4> &pos,
650 const Scalar<Int4> &len) -> Scalar<T> {
651 auto posVal{static_cast<int>(pos.ToInt64())};
652 auto lenVal{static_cast<int>(len.ToInt64())};
653 if (posVal < 0) {
654 context.messages().Say(
655 "bit position for IBITS(POS=%d,LEN=%d) is negative"_err_en_US,
656 posVal, lenVal);
657 } else if (lenVal < 0) {
658 context.messages().Say(
659 "bit length for IBITS(POS=%d,LEN=%d) is negative"_err_en_US,
660 posVal, lenVal);
661 } else if (posVal + lenVal > i.bits) {
662 context.messages().Say(
663 "IBITS(POS=%d,LEN=%d) must have POS+LEN no greater than %d"_err_en_US,
664 posVal + lenVal, i.bits);
665 }
666 return i.IBITS(posVal, lenVal);
667 }));
668 } else if (name == "index" || name == "scan" || name == "verify") {
669 if (auto *charExpr{UnwrapExpr<Expr<SomeCharacter>>(args[0])}) {
670 return common::visit(
671 [&](const auto &kch) -> Expr<T> {
672 using TC = typename std::decay_t<decltype(kch)>::Result;
673 if (UnwrapExpr<Expr<SomeLogical>>(args[2])) { // BACK=
674 return FoldElementalIntrinsic<T, TC, TC, LogicalResult>(context,
675 std::move(funcRef),
676 ScalarFunc<T, TC, TC, LogicalResult>{
677 [&name](const Scalar<TC> &str, const Scalar<TC> &other,
678 const Scalar<LogicalResult> &back) -> Scalar<T> {
679 return name == "index"
680 ? CharacterUtils<TC::kind>::INDEX(
681 str, other, back.IsTrue())
682 : name == "scan" ? CharacterUtils<TC::kind>::SCAN(
683 str, other, back.IsTrue())
684 : CharacterUtils<TC::kind>::VERIFY(
685 str, other, back.IsTrue());
686 }});
687 } else {
688 return FoldElementalIntrinsic<T, TC, TC>(context,
689 std::move(funcRef),
690 ScalarFunc<T, TC, TC>{
691 [&name](const Scalar<TC> &str,
692 const Scalar<TC> &other) -> Scalar<T> {
693 return name == "index"
694 ? CharacterUtils<TC::kind>::INDEX(str, other)
695 : name == "scan"
696 ? CharacterUtils<TC::kind>::SCAN(str, other)
697 : CharacterUtils<TC::kind>::VERIFY(str, other);
698 }});
699 }
700 },
701 charExpr->u);
702 } else {
703 DIE("first argument must be CHARACTER");
704 }
705 } else if (name == "int") {
706 if (auto *expr{UnwrapExpr<Expr<SomeType>>(args[0])}) {
707 return common::visit(
708 [&](auto &&x) -> Expr<T> {
709 using From = std::decay_t<decltype(x)>;
710 if constexpr (std::is_same_v<From, BOZLiteralConstant> ||
711 IsNumericCategoryExpr<From>()) {
712 return Fold(context, ConvertToType<T>(std::move(x)));
713 }
714 DIE("int() argument type not valid");
715 },
716 std::move(expr->u));
717 }
718 } else if (name == "int_ptr_kind") {
719 return Expr<T>{8};
720 } else if (name == "kind") {
721 if constexpr (common::HasMember<T, IntegerTypes>) {
722 return Expr<T>{args[0].value().GetType()->kind()};
723 } else {
724 DIE("kind() result not integral");
725 }
726 } else if (name == "iparity") {
727 return FoldBitReduction(
728 context, std::move(funcRef), &Scalar<T>::IEOR, Scalar<T>{});
729 } else if (name == "ishft") {
730 return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
731 ScalarFunc<T, T, Int4>([&](const Scalar<T> &i,
732 const Scalar<Int4> &pos) -> Scalar<T> {
733 auto posVal{static_cast<int>(pos.ToInt64())};
734 if (posVal < -i.bits) {
735 context.messages().Say(
736 "SHIFT=%d count for ishft is less than %d"_err_en_US, posVal,
737 -i.bits);
738 } else if (posVal > i.bits) {
739 context.messages().Say(
740 "SHIFT=%d count for ishft is greater than %d"_err_en_US, posVal,
741 i.bits);
742 }
743 return i.ISHFT(posVal);
744 }));
745 } else if (name == "ishftc") {
746 if (args.at(2)) { // SIZE= is present
747 return FoldElementalIntrinsic<T, T, Int4, Int4>(context,
748 std::move(funcRef),
749 ScalarFunc<T, T, Int4, Int4>(
750 [&](const Scalar<T> &i, const Scalar<Int4> &shift,
751 const Scalar<Int4> &size) -> Scalar<T> {
752 // Errors are caught in intrinsics.cpp
753 auto shiftVal{static_cast<int>(shift.ToInt64())};
754 auto sizeVal{static_cast<int>(size.ToInt64())};
755 return i.ISHFTC(shiftVal, sizeVal);
756 }));
757 } else { // no SIZE=
758 return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
759 ScalarFunc<T, T, Int4>(
760 [&](const Scalar<T> &i, const Scalar<Int4> &count) -> Scalar<T> {
761 auto countVal{static_cast<int>(count.ToInt64())};
762 return i.ISHFTC(countVal);
763 }));
764 }
765 } else if (name == "lbound") {
766 return LBOUND(context, std::move(funcRef));
767 } else if (name == "leadz" || name == "trailz" || name == "poppar" ||
768 name == "popcnt") {
769 if (auto *sn{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
770 return common::visit(
771 [&funcRef, &context, &name](const auto &n) -> Expr<T> {
772 using TI = typename std::decay_t<decltype(n)>::Result;
773 if (name == "poppar") {
774 return FoldElementalIntrinsic<T, TI>(context, std::move(funcRef),
775 ScalarFunc<T, TI>([](const Scalar<TI> &i) -> Scalar<T> {
776 return Scalar<T>{i.POPPAR() ? 1 : 0};
777 }));
778 }
779 auto fptr{&Scalar<TI>::LEADZ};
780 if (name == "leadz") { // done in fptr definition
781 } else if (name == "trailz") {
782 fptr = &Scalar<TI>::TRAILZ;
783 } else if (name == "popcnt") {
784 fptr = &Scalar<TI>::POPCNT;
785 } else {
786 common::die(
787 "missing case to fold intrinsic function %s", name.c_str());
788 }
789 return FoldElementalIntrinsic<T, TI>(context, std::move(funcRef),
790 ScalarFunc<T, TI>([&fptr](const Scalar<TI> &i) -> Scalar<T> {
791 return Scalar<T>{std::invoke(fptr, i)};
792 }));
793 },
794 sn->u);
795 } else {
796 DIE("leadz argument must be integer");
797 }
798 } else if (name == "len") {
799 if (auto *charExpr{UnwrapExpr<Expr<SomeCharacter>>(args[0])}) {
800 return common::visit(
801 [&](auto &kx) {
802 if (auto len{kx.LEN()}) {
803 if (IsScopeInvariantExpr(*len)) {
804 return Fold(context, ConvertToType<T>(*std::move(len)));
805 } else {
806 return Expr<T>{std::move(funcRef)};
807 }
808 } else {
809 return Expr<T>{std::move(funcRef)};
810 }
811 },
812 charExpr->u);
813 } else {
814 DIE("len() argument must be of character type");
815 }
816 } else if (name == "len_trim") {
817 if (auto *charExpr{UnwrapExpr<Expr<SomeCharacter>>(args[0])}) {
818 return common::visit(
819 [&](const auto &kch) -> Expr<T> {
820 using TC = typename std::decay_t<decltype(kch)>::Result;
821 return FoldElementalIntrinsic<T, TC>(context, std::move(funcRef),
822 ScalarFunc<T, TC>{[](const Scalar<TC> &str) -> Scalar<T> {
823 return CharacterUtils<TC::kind>::LEN_TRIM(str);
824 }});
825 },
826 charExpr->u);
827 } else {
828 DIE("len_trim() argument must be of character type");
829 }
830 } else if (name == "maskl" || name == "maskr") {
831 // Argument can be of any kind but value has to be smaller than BIT_SIZE.
832 // It can be safely converted to Int4 to simplify.
833 const auto fptr{name == "maskl" ? &Scalar<T>::MASKL : &Scalar<T>::MASKR};
834 return FoldElementalIntrinsic<T, Int4>(context, std::move(funcRef),
835 ScalarFunc<T, Int4>([&fptr](const Scalar<Int4> &places) -> Scalar<T> {
836 return fptr(static_cast<int>(places.ToInt64()));
837 }));
838 } else if (name == "max") {
839 return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater);
840 } else if (name == "max0" || name == "max1") {
841 return RewriteSpecificMINorMAX(context, std::move(funcRef));
842 } else if (name == "maxexponent") {
843 if (auto *sx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
844 return common::visit(
845 [](const auto &x) {
846 using TR = typename std::decay_t<decltype(x)>::Result;
847 return Expr<T>{Scalar<TR>::MAXEXPONENT};
848 },
849 sx->u);
850 }
851 } else if (name == "maxloc") {
852 return FoldLocation<WhichLocation::Maxloc, T>(context, std::move(funcRef));
853 } else if (name == "maxval") {
854 return FoldMaxvalMinval<T>(context, std::move(funcRef),
855 RelationalOperator::GT, T::Scalar::Least());
856 } else if (name == "merge") {
857 return FoldMerge<T>(context, std::move(funcRef));
858 } else if (name == "merge_bits") {
859 return FoldElementalIntrinsic<T, T, T, T>(
860 context, std::move(funcRef), &Scalar<T>::MERGE_BITS);
861 } else if (name == "min") {
862 return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
863 } else if (name == "min0" || name == "min1") {
864 return RewriteSpecificMINorMAX(context, std::move(funcRef));
865 } else if (name == "minexponent") {
866 if (auto *sx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
867 return common::visit(
868 [](const auto &x) {
869 using TR = typename std::decay_t<decltype(x)>::Result;
870 return Expr<T>{Scalar<TR>::MINEXPONENT};
871 },
872 sx->u);
873 }
874 } else if (name == "minloc") {
875 return FoldLocation<WhichLocation::Minloc, T>(context, std::move(funcRef));
876 } else if (name == "minval") {
877 return FoldMaxvalMinval<T>(
878 context, std::move(funcRef), RelationalOperator::LT, T::Scalar::HUGE());
879 } else if (name == "mod") {
880 return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
881 ScalarFuncWithContext<T, T, T>(
882 [](FoldingContext &context, const Scalar<T> &x,
883 const Scalar<T> &y) -> Scalar<T> {
884 auto quotRem{x.DivideSigned(y)};
885 if (quotRem.divisionByZero) {
886 context.messages().Say("mod() by zero"_warn_en_US);
887 } else if (quotRem.overflow) {
888 context.messages().Say("mod() folding overflowed"_warn_en_US);
889 }
890 return quotRem.remainder;
891 }));
892 } else if (name == "modulo") {
893 return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
894 ScalarFuncWithContext<T, T, T>(
895 [](FoldingContext &context, const Scalar<T> &x,
896 const Scalar<T> &y) -> Scalar<T> {
897 auto result{x.MODULO(y)};
898 if (result.overflow) {
899 context.messages().Say(
900 "modulo() folding overflowed"_warn_en_US);
901 }
902 return result.value;
903 }));
904 } else if (name == "not") {
905 return FoldElementalIntrinsic<T, T>(
906 context, std::move(funcRef), &Scalar<T>::NOT);
907 } else if (name == "precision") {
908 if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
909 return Expr<T>{common::visit(
910 [](const auto &kx) {
911 return Scalar<ResultType<decltype(kx)>>::PRECISION;
912 },
913 cx->u)};
914 } else if (const auto *cx{UnwrapExpr<Expr<SomeComplex>>(args[0])}) {
915 return Expr<T>{common::visit(
916 [](const auto &kx) {
917 return Scalar<typename ResultType<decltype(kx)>::Part>::PRECISION;
918 },
919 cx->u)};
920 }
921 } else if (name == "product") {
922 return FoldProduct<T>(context, std::move(funcRef), Scalar<T>{1});
923 } else if (name == "radix") {
924 return Expr<T>{2};
925 } else if (name == "range") {
926 if (const auto *cx{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
927 return Expr<T>{common::visit(
928 [](const auto &kx) {
929 return Scalar<ResultType<decltype(kx)>>::RANGE;
930 },
931 cx->u)};
932 } else if (const auto *cx{UnwrapExpr<Expr<SomeReal>>(args[0])}) {
933 return Expr<T>{common::visit(
934 [](const auto &kx) {
935 return Scalar<ResultType<decltype(kx)>>::RANGE;
936 },
937 cx->u)};
938 } else if (const auto *cx{UnwrapExpr<Expr<SomeComplex>>(args[0])}) {
939 return Expr<T>{common::visit(
940 [](const auto &kx) {
941 return Scalar<typename ResultType<decltype(kx)>::Part>::RANGE;
942 },
943 cx->u)};
944 }
945 } else if (name == "rank") {
946 if (const auto *array{UnwrapExpr<Expr<SomeType>>(args[0])}) {
947 if (auto named{ExtractNamedEntity(*array)}) {
948 const Symbol &symbol{named->GetLastSymbol()};
949 if (IsAssumedRank(symbol)) {
950 // DescriptorInquiry can only be placed in expression of kind
951 // DescriptorInquiry::Result::kind.
952 return ConvertToType<T>(Expr<
953 Type<TypeCategory::Integer, DescriptorInquiry::Result::kind>>{
954 DescriptorInquiry{*named, DescriptorInquiry::Field::Rank}});
955 }
956 }
957 return Expr<T>{args[0].value().Rank()};
958 }
959 return Expr<T>{args[0].value().Rank()};
960 } else if (name == "selected_char_kind") {
961 if (const auto *chCon{UnwrapExpr<Constant<TypeOf<std::string>>>(args[0])}) {
962 if (std::optional<std::string> value{chCon->GetScalarValue()}) {
963 int defaultKind{
964 context.defaults().GetDefaultKind(TypeCategory::Character)};
965 return Expr<T>{SelectedCharKind(*value, defaultKind)};
966 }
967 }
968 } else if (name == "selected_int_kind") {
969 if (auto p{GetInt64Arg(args[0])}) {
970 return Expr<T>{context.targetCharacteristics().SelectedIntKind(*p)};
971 }
972 } else if (name == "selected_real_kind" ||
973 name == "__builtin_ieee_selected_real_kind") {
974 if (auto p{GetInt64ArgOr(args[0], 0)}) {
975 if (auto r{GetInt64ArgOr(args[1], 0)}) {
976 if (auto radix{GetInt64ArgOr(args[2], 2)}) {
977 return Expr<T>{
978 context.targetCharacteristics().SelectedRealKind(*p, *r, *radix)};
979 }
980 }
981 }
982 } else if (name == "shape") {
983 if (auto shape{GetContextFreeShape(context, args[0])}) {
984 if (auto shapeExpr{AsExtentArrayExpr(*shape)}) {
985 return Fold(context, ConvertToType<T>(std::move(*shapeExpr)));
986 }
987 }
988 } else if (name == "shifta" || name == "shiftr" || name == "shiftl") {
989 // Second argument can be of any kind. However, it must be smaller or
990 // equal than BIT_SIZE. It can be converted to Int4 to simplify.
991 auto fptr{&Scalar<T>::SHIFTA};
992 if (name == "shifta") { // done in fptr definition
993 } else if (name == "shiftr") {
994 fptr = &Scalar<T>::SHIFTR;
995 } else if (name == "shiftl") {
996 fptr = &Scalar<T>::SHIFTL;
997 } else {
998 common::die("missing case to fold intrinsic function %s", name.c_str());
999 }
1000 return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
1001 ScalarFunc<T, T, Int4>([&](const Scalar<T> &i,
1002 const Scalar<Int4> &pos) -> Scalar<T> {
1003 auto posVal{static_cast<int>(pos.ToInt64())};
1004 if (posVal < 0) {
1005 context.messages().Say(
1006 "SHIFT=%d count for %s is negative"_err_en_US, posVal, name);
1007 } else if (posVal > i.bits) {
1008 context.messages().Say(
1009 "SHIFT=%d count for %s is greater than %d"_err_en_US, posVal,
1010 name, i.bits);
1011 }
1012 return std::invoke(fptr, i, posVal);
1013 }));
1014 } else if (name == "sign") {
1015 return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
1016 ScalarFunc<T, T, T>(
1017 [&context](const Scalar<T> &j, const Scalar<T> &k) -> Scalar<T> {
1018 typename Scalar<T>::ValueWithOverflow result{j.SIGN(k)};
1019 if (result.overflow) {
1020 context.messages().Say(
1021 "sign(integer(kind=%d)) folding overflowed"_warn_en_US,
1022 KIND);
1023 }
1024 return result.value;
1025 }));
1026 } else if (name == "size") {
1027 if (auto shape{GetContextFreeShape(context, args[0])}) {
1028 if (auto &dimArg{args[1]}) { // DIM= is present, get one extent
1029 if (auto dim{GetInt64Arg(args[1])}) {
1030 int rank{GetRank(*shape)};
1031 if (*dim >= 1 && *dim <= rank) {
1032 const Symbol *symbol{UnwrapWholeSymbolDataRef(args[0])};
1033 if (symbol && IsAssumedSizeArray(*symbol) && *dim == rank) {
1034 context.messages().Say(
1035 "size(array,dim=%jd) of last dimension is not available for rank-%d assumed-size array dummy argument"_err_en_US,
1036 *dim, rank);
1037 return MakeInvalidIntrinsic<T>(std::move(funcRef));
1038 } else if (auto &extent{shape->at(*dim - 1)}) {
1039 return Fold(context, ConvertToType<T>(std::move(*extent)));
1040 }
1041 } else {
1042 context.messages().Say(
1043 "size(array,dim=%jd) dimension is out of range for rank-%d array"_warn_en_US,
1044 *dim, rank);
1045 }
1046 }
1047 } else if (auto extents{common::AllElementsPresent(std::move(*shape))}) {
1048 // DIM= is absent; compute PRODUCT(SHAPE())
1049 ExtentExpr product{1};
1050 for (auto &&extent : std::move(*extents)) {
1051 product = std::move(product) * std::move(extent);
1052 }
1053 return Expr<T>{ConvertToType<T>(Fold(context, std::move(product)))};
1054 }
1055 }
1056 } else if (name == "sizeof") { // in bytes; extension
1057 if (auto info{
1058 characteristics::TypeAndShape::Characterize(args[0], context)}) {
1059 if (auto bytes{info->MeasureSizeInBytes(context)}) {
1060 return Expr<T>{Fold(context, ConvertToType<T>(std::move(*bytes)))};
1061 }
1062 }
1063 } else if (name == "storage_size") { // in bits
1064 if (auto info{
1065 characteristics::TypeAndShape::Characterize(args[0], context)}) {
1066 if (auto bytes{info->MeasureElementSizeInBytes(context, true)}) {
1067 return Expr<T>{
1068 Fold(context, Expr<T>{8} * ConvertToType<T>(std::move(*bytes)))};
1069 }
1070 }
1071 } else if (name == "sum") {
1072 return FoldSum<T>(context, std::move(funcRef));
1073 } else if (name == "ubound") {
1074 return UBOUND(context, std::move(funcRef));
1075 }
1076 // TODO: dot_product, matmul, sign
1077 return Expr<T>{std::move(funcRef)};
1078 }
1079
1080 // Substitutes a bare type parameter reference with its value if it has one now
1081 // in an instantiation. Bare LEN type parameters are substituted only when
1082 // the known value is constant.
FoldOperation(FoldingContext & context,TypeParamInquiry && inquiry)1083 Expr<TypeParamInquiry::Result> FoldOperation(
1084 FoldingContext &context, TypeParamInquiry &&inquiry) {
1085 std::optional<NamedEntity> base{inquiry.base()};
1086 parser::CharBlock parameterName{inquiry.parameter().name()};
1087 if (base) {
1088 // Handling "designator%typeParam". Get the value of the type parameter
1089 // from the instantiation of the base
1090 if (const semantics::DeclTypeSpec *
1091 declType{base->GetLastSymbol().GetType()}) {
1092 if (const semantics::ParamValue *
1093 paramValue{
1094 declType->derivedTypeSpec().FindParameter(parameterName)}) {
1095 const semantics::MaybeIntExpr ¶mExpr{paramValue->GetExplicit()};
1096 if (paramExpr && IsConstantExpr(*paramExpr)) {
1097 Expr<SomeInteger> intExpr{*paramExpr};
1098 return Fold(context,
1099 ConvertToType<TypeParamInquiry::Result>(std::move(intExpr)));
1100 }
1101 }
1102 }
1103 } else {
1104 // A "bare" type parameter: replace with its value, if that's now known
1105 // in a current derived type instantiation, for KIND type parameters.
1106 if (const auto *pdt{context.pdtInstance()}) {
1107 bool isLen{false};
1108 if (const semantics::Scope * scope{context.pdtInstance()->scope()}) {
1109 auto iter{scope->find(parameterName)};
1110 if (iter != scope->end()) {
1111 const Symbol &symbol{*iter->second};
1112 const auto *details{symbol.detailsIf<semantics::TypeParamDetails>()};
1113 if (details) {
1114 isLen = details->attr() == common::TypeParamAttr::Len;
1115 const semantics::MaybeIntExpr &initExpr{details->init()};
1116 if (initExpr && IsConstantExpr(*initExpr) &&
1117 (!isLen || ToInt64(*initExpr))) {
1118 Expr<SomeInteger> expr{*initExpr};
1119 return Fold(context,
1120 ConvertToType<TypeParamInquiry::Result>(std::move(expr)));
1121 }
1122 }
1123 }
1124 }
1125 if (const auto *value{pdt->FindParameter(parameterName)}) {
1126 if (value->isExplicit()) {
1127 auto folded{Fold(context,
1128 AsExpr(ConvertToType<TypeParamInquiry::Result>(
1129 Expr<SomeInteger>{value->GetExplicit().value()})))};
1130 if (!isLen || ToInt64(folded)) {
1131 return folded;
1132 }
1133 }
1134 }
1135 }
1136 }
1137 return AsExpr(std::move(inquiry));
1138 }
1139
ToInt64(const Expr<SomeInteger> & expr)1140 std::optional<std::int64_t> ToInt64(const Expr<SomeInteger> &expr) {
1141 return common::visit(
1142 [](const auto &kindExpr) { return ToInt64(kindExpr); }, expr.u);
1143 }
1144
ToInt64(const Expr<SomeType> & expr)1145 std::optional<std::int64_t> ToInt64(const Expr<SomeType> &expr) {
1146 if (const auto *intExpr{UnwrapExpr<Expr<SomeInteger>>(expr)}) {
1147 return ToInt64(*intExpr);
1148 } else {
1149 return std::nullopt;
1150 }
1151 }
1152
// Explicit instantiations of ExpressionBase for integer expression types.
// MSVC warning C4661 ("no suitable definition provided for explicit template
// instantiation request") is silenced first, since it would otherwise be
// reported for these instantiations.
#ifdef _MSC_VER // disable bogus warning about missing definitions
#pragma warning(disable : 4661)
#endif
// FOR_EACH_INTEGER_KIND is defined elsewhere; presumably it expands the
// given prefix once per supported integer kind -- TODO confirm.
FOR_EACH_INTEGER_KIND(template class ExpressionBase, )
// Instantiation for the SomeInteger category type.
template class ExpressionBase<SomeInteger>;
1158 } // namespace Fortran::evaluate
1159