1 //===-- lib/Evaluate/shape.cpp --------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "flang/Evaluate/shape.h"
10 #include "flang/Common/idioms.h"
11 #include "flang/Common/template.h"
12 #include "flang/Evaluate/characteristics.h"
13 #include "flang/Evaluate/check-expression.h"
14 #include "flang/Evaluate/fold.h"
15 #include "flang/Evaluate/intrinsics.h"
16 #include "flang/Evaluate/tools.h"
17 #include "flang/Evaluate/type.h"
18 #include "flang/Parser/message.h"
19 #include "flang/Semantics/symbol.h"
20 #include <functional>
21
22 using namespace std::placeholders; // _1, _2, &c. for std::bind()
23
24 namespace Fortran::evaluate {
25
IsImpliedShape(const Symbol & original)26 bool IsImpliedShape(const Symbol &original) {
27 const Symbol &symbol{ResolveAssociations(original)};
28 const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()};
29 return details && symbol.attrs().test(semantics::Attr::PARAMETER) &&
30 details->shape().CanBeImpliedShape();
31 }
32
IsExplicitShape(const Symbol & original)33 bool IsExplicitShape(const Symbol &original) {
34 const Symbol &symbol{ResolveAssociations(original)};
35 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
36 const auto &shape{details->shape()};
37 return shape.Rank() == 0 ||
38 shape.IsExplicitShape(); // true when scalar, too
39 } else {
40 return symbol
41 .has<semantics::AssocEntityDetails>(); // exprs have explicit shape
42 }
43 }
44
ConstantShape(const Constant<ExtentType> & arrayConstant)45 Shape GetShapeHelper::ConstantShape(const Constant<ExtentType> &arrayConstant) {
46 CHECK(arrayConstant.Rank() == 1);
47 Shape result;
48 std::size_t dimensions{arrayConstant.size()};
49 for (std::size_t j{0}; j < dimensions; ++j) {
50 Scalar<ExtentType> extent{arrayConstant.values().at(j)};
51 result.emplace_back(MaybeExtentExpr{ExtentExpr{std::move(extent)}});
52 }
53 return result;
54 }
55
AsShapeResult(ExtentExpr && arrayExpr) const56 auto GetShapeHelper::AsShapeResult(ExtentExpr &&arrayExpr) const -> Result {
57 if (context_) {
58 arrayExpr = Fold(*context_, std::move(arrayExpr));
59 }
60 if (const auto *constArray{UnwrapConstantValue<ExtentType>(arrayExpr)}) {
61 return ConstantShape(*constArray);
62 }
63 if (auto *constructor{UnwrapExpr<ArrayConstructor<ExtentType>>(arrayExpr)}) {
64 Shape result;
65 for (auto &value : *constructor) {
66 auto *expr{std::get_if<ExtentExpr>(&value.u)};
67 if (expr && expr->Rank() == 0) {
68 result.emplace_back(std::move(*expr));
69 } else {
70 return std::nullopt;
71 }
72 }
73 return result;
74 } else {
75 return std::nullopt;
76 }
77 }
78
CreateShape(int rank,NamedEntity & base)79 Shape GetShapeHelper::CreateShape(int rank, NamedEntity &base) {
80 Shape shape;
81 for (int dimension{0}; dimension < rank; ++dimension) {
82 shape.emplace_back(GetExtent(base, dimension));
83 }
84 return shape;
85 }
86
AsExtentArrayExpr(const Shape & shape)87 std::optional<ExtentExpr> AsExtentArrayExpr(const Shape &shape) {
88 ArrayConstructorValues<ExtentType> values;
89 for (const auto &dim : shape) {
90 if (dim) {
91 values.Push(common::Clone(*dim));
92 } else {
93 return std::nullopt;
94 }
95 }
96 return ExtentExpr{ArrayConstructor<ExtentType>{std::move(values)}};
97 }
98
AsConstantShape(FoldingContext & context,const Shape & shape)99 std::optional<Constant<ExtentType>> AsConstantShape(
100 FoldingContext &context, const Shape &shape) {
101 if (auto shapeArray{AsExtentArrayExpr(shape)}) {
102 auto folded{Fold(context, std::move(*shapeArray))};
103 if (auto *p{UnwrapConstantValue<ExtentType>(folded)}) {
104 return std::move(*p);
105 }
106 }
107 return std::nullopt;
108 }
109
AsConstantShape(const ConstantSubscripts & shape)110 Constant<SubscriptInteger> AsConstantShape(const ConstantSubscripts &shape) {
111 using IntType = Scalar<SubscriptInteger>;
112 std::vector<IntType> result;
113 for (auto dim : shape) {
114 result.emplace_back(dim);
115 }
116 return {std::move(result), ConstantSubscripts{GetRank(shape)}};
117 }
118
AsConstantExtents(const Constant<ExtentType> & shape)119 ConstantSubscripts AsConstantExtents(const Constant<ExtentType> &shape) {
120 ConstantSubscripts result;
121 for (const auto &extent : shape.values()) {
122 result.push_back(extent.ToInt64());
123 }
124 return result;
125 }
126
AsConstantExtents(FoldingContext & context,const Shape & shape)127 std::optional<ConstantSubscripts> AsConstantExtents(
128 FoldingContext &context, const Shape &shape) {
129 if (auto shapeConstant{AsConstantShape(context, shape)}) {
130 return AsConstantExtents(*shapeConstant);
131 } else {
132 return std::nullopt;
133 }
134 }
135
AsShape(const ConstantSubscripts & shape)136 Shape AsShape(const ConstantSubscripts &shape) {
137 Shape result;
138 for (const auto &extent : shape) {
139 result.emplace_back(ExtentExpr{extent});
140 }
141 return result;
142 }
143
AsShape(const std::optional<ConstantSubscripts> & shape)144 std::optional<Shape> AsShape(const std::optional<ConstantSubscripts> &shape) {
145 if (shape) {
146 return AsShape(*shape);
147 } else {
148 return std::nullopt;
149 }
150 }
151
Fold(FoldingContext & context,Shape && shape)152 Shape Fold(FoldingContext &context, Shape &&shape) {
153 for (auto &dim : shape) {
154 dim = Fold(context, std::move(dim));
155 }
156 return std::move(shape);
157 }
158
Fold(FoldingContext & context,std::optional<Shape> && shape)159 std::optional<Shape> Fold(
160 FoldingContext &context, std::optional<Shape> &&shape) {
161 if (shape) {
162 return Fold(context, std::move(*shape));
163 } else {
164 return std::nullopt;
165 }
166 }
167
ComputeTripCount(ExtentExpr && lower,ExtentExpr && upper,ExtentExpr && stride)168 static ExtentExpr ComputeTripCount(
169 ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) {
170 ExtentExpr strideCopy{common::Clone(stride)};
171 ExtentExpr span{
172 (std::move(upper) - std::move(lower) + std::move(strideCopy)) /
173 std::move(stride)};
174 return ExtentExpr{
175 Extremum<ExtentType>{Ordering::Greater, std::move(span), ExtentExpr{0}}};
176 }
177
CountTrips(ExtentExpr && lower,ExtentExpr && upper,ExtentExpr && stride)178 ExtentExpr CountTrips(
179 ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) {
180 return ComputeTripCount(
181 std::move(lower), std::move(upper), std::move(stride));
182 }
183
// Trip count of lower:upper:stride. The operands are cloned, so the
// caller's expressions are left untouched.
ExtentExpr CountTrips(const ExtentExpr &lower, const ExtentExpr &upper,
    const ExtentExpr &stride) {
  ExtentExpr lowerCopy{common::Clone(lower)};
  ExtentExpr upperCopy{common::Clone(upper)};
  ExtentExpr strideCopy{common::Clone(stride)};
  return ComputeTripCount(
      std::move(lowerCopy), std::move(upperCopy), std::move(strideCopy));
}
189
CountTrips(MaybeExtentExpr && lower,MaybeExtentExpr && upper,MaybeExtentExpr && stride)190 MaybeExtentExpr CountTrips(MaybeExtentExpr &&lower, MaybeExtentExpr &&upper,
191 MaybeExtentExpr &&stride) {
192 std::function<ExtentExpr(ExtentExpr &&, ExtentExpr &&, ExtentExpr &&)> bound{
193 std::bind(ComputeTripCount, _1, _2, _3)};
194 return common::MapOptional(
195 std::move(bound), std::move(lower), std::move(upper), std::move(stride));
196 }
197
GetSize(Shape && shape)198 MaybeExtentExpr GetSize(Shape &&shape) {
199 ExtentExpr extent{1};
200 for (auto &&dim : std::move(shape)) {
201 if (dim) {
202 extent = std::move(extent) * std::move(*dim);
203 } else {
204 return std::nullopt;
205 }
206 }
207 return extent;
208 }
209
GetSize(const ConstantSubscripts & shape)210 ConstantSubscript GetSize(const ConstantSubscripts &shape) {
211 ConstantSubscript size{1};
212 for (auto dim : shape) {
213 CHECK(dim >= 0);
214 size *= dim;
215 }
216 return size;
217 }
218
ContainsAnyImpliedDoIndex(const ExtentExpr & expr)219 bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) {
220 struct MyVisitor : public AnyTraverse<MyVisitor> {
221 using Base = AnyTraverse<MyVisitor>;
222 MyVisitor() : Base{*this} {}
223 using Base::operator();
224 bool operator()(const ImpliedDoIndex &) { return true; }
225 };
226 return MyVisitor{}(expr);
227 }
228
229 // Determines lower bound on a dimension. This can be other than 1 only
230 // for a reference to a whole array object or component. (See LBOUND, 16.9.109).
231 // ASSOCIATE construct entities may require traversal of their referents.
232 template <typename RESULT, bool LBOUND_SEMANTICS>
233 class GetLowerBoundHelper
234 : public Traverse<GetLowerBoundHelper<RESULT, LBOUND_SEMANTICS>, RESULT> {
235 public:
236 using Result = RESULT;
237 using Base = Traverse<GetLowerBoundHelper, RESULT>;
238 using Base::operator();
GetLowerBoundHelper(int d,FoldingContext * context)239 explicit GetLowerBoundHelper(int d, FoldingContext *context)
240 : Base{*this}, dimension_{d}, context_{context} {}
Default()241 static Result Default() { return Result{1}; }
Combine(Result &&,Result &&)242 static Result Combine(Result &&, Result &&) {
243 // Operator results and array references always have lower bounds == 1
244 return Result{1};
245 }
246
GetLowerBound(const Symbol & symbol0,NamedEntity && base) const247 Result GetLowerBound(const Symbol &symbol0, NamedEntity &&base) const {
248 const Symbol &symbol{symbol0.GetUltimate()};
249 if (const auto *details{
250 symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
251 int rank{details->shape().Rank()};
252 if (dimension_ < rank) {
253 const semantics::ShapeSpec &shapeSpec{details->shape()[dimension_]};
254 if (shapeSpec.lbound().isExplicit()) {
255 if (const auto &lbound{shapeSpec.lbound().GetExplicit()}) {
256 if constexpr (LBOUND_SEMANTICS) {
257 bool ok{false};
258 auto lbValue{ToInt64(*lbound)};
259 if (dimension_ == rank - 1 && details->IsAssumedSize()) {
260 // last dimension of assumed-size dummy array: don't worry
261 // about handling an empty dimension
262 ok = IsScopeInvariantExpr(*lbound);
263 } else if (lbValue.value_or(0) == 1) {
264 // Lower bound is 1, regardless of extent
265 ok = true;
266 } else if (const auto &ubound{shapeSpec.ubound().GetExplicit()}) {
267 // If we can't prove that the dimension is nonempty,
268 // we must be conservative.
269 // TODO: simple symbolic math in expression rewriting to
270 // cope with cases like A(J:J)
271 if (context_) {
272 auto extent{ToInt64(Fold(*context_,
273 ExtentExpr{*ubound} - ExtentExpr{*lbound} +
274 ExtentExpr{1}))};
275 if (extent) {
276 if (extent <= 0) {
277 return Result{1};
278 }
279 ok = true;
280 } else {
281 ok = false;
282 }
283 } else {
284 auto ubValue{ToInt64(*ubound)};
285 if (lbValue && ubValue) {
286 if (*lbValue > *ubValue) {
287 return Result{1};
288 }
289 ok = true;
290 } else {
291 ok = false;
292 }
293 }
294 }
295 return ok ? *lbound : Result{};
296 } else {
297 return *lbound;
298 }
299 } else {
300 return Result{1};
301 }
302 }
303 if (IsDescriptor(symbol)) {
304 return ExtentExpr{DescriptorInquiry{std::move(base),
305 DescriptorInquiry::Field::LowerBound, dimension_}};
306 }
307 }
308 } else if (const auto *assoc{
309 symbol.detailsIf<semantics::AssocEntityDetails>()}) {
310 if (assoc->rank()) { // SELECT RANK case
311 const Symbol &resolved{ResolveAssociations(symbol)};
312 if (IsDescriptor(resolved) && dimension_ < *assoc->rank()) {
313 return ExtentExpr{DescriptorInquiry{std::move(base),
314 DescriptorInquiry::Field::LowerBound, dimension_}};
315 }
316 } else {
317 return (*this)(assoc->expr());
318 }
319 }
320 if constexpr (LBOUND_SEMANTICS) {
321 return Result{};
322 } else {
323 return Result{1};
324 }
325 }
326
operator ()(const Symbol & symbol0) const327 Result operator()(const Symbol &symbol0) const {
328 return GetLowerBound(symbol0, NamedEntity{symbol0});
329 }
330
operator ()(const Component & component) const331 Result operator()(const Component &component) const {
332 if (component.base().Rank() == 0) {
333 return GetLowerBound(
334 component.GetLastSymbol(), NamedEntity{common::Clone(component)});
335 }
336 return Result{1};
337 }
338
339 private:
340 int dimension_;
341 FoldingContext *context_{nullptr};
342 };
343
GetRawLowerBound(const NamedEntity & base,int dimension)344 ExtentExpr GetRawLowerBound(const NamedEntity &base, int dimension) {
345 return GetLowerBoundHelper<ExtentExpr, false>{dimension, nullptr}(base);
346 }
347
GetRawLowerBound(FoldingContext & context,const NamedEntity & base,int dimension)348 ExtentExpr GetRawLowerBound(
349 FoldingContext &context, const NamedEntity &base, int dimension) {
350 return Fold(context,
351 GetLowerBoundHelper<ExtentExpr, false>{dimension, &context}(base));
352 }
353
GetLBOUND(const NamedEntity & base,int dimension)354 MaybeExtentExpr GetLBOUND(const NamedEntity &base, int dimension) {
355 return GetLowerBoundHelper<MaybeExtentExpr, true>{dimension, nullptr}(base);
356 }
357
GetLBOUND(FoldingContext & context,const NamedEntity & base,int dimension)358 MaybeExtentExpr GetLBOUND(
359 FoldingContext &context, const NamedEntity &base, int dimension) {
360 return Fold(context,
361 GetLowerBoundHelper<MaybeExtentExpr, true>{dimension, &context}(base));
362 }
363
GetRawLowerBounds(const NamedEntity & base)364 Shape GetRawLowerBounds(const NamedEntity &base) {
365 Shape result;
366 int rank{base.Rank()};
367 for (int dim{0}; dim < rank; ++dim) {
368 result.emplace_back(GetRawLowerBound(base, dim));
369 }
370 return result;
371 }
372
GetRawLowerBounds(FoldingContext & context,const NamedEntity & base)373 Shape GetRawLowerBounds(FoldingContext &context, const NamedEntity &base) {
374 Shape result;
375 int rank{base.Rank()};
376 for (int dim{0}; dim < rank; ++dim) {
377 result.emplace_back(GetRawLowerBound(context, base, dim));
378 }
379 return result;
380 }
381
GetLBOUNDs(const NamedEntity & base)382 Shape GetLBOUNDs(const NamedEntity &base) {
383 Shape result;
384 int rank{base.Rank()};
385 for (int dim{0}; dim < rank; ++dim) {
386 result.emplace_back(GetLBOUND(base, dim));
387 }
388 return result;
389 }
390
GetLBOUNDs(FoldingContext & context,const NamedEntity & base)391 Shape GetLBOUNDs(FoldingContext &context, const NamedEntity &base) {
392 Shape result;
393 int rank{base.Rank()};
394 for (int dim{0}; dim < rank; ++dim) {
395 result.emplace_back(GetLBOUND(context, base, dim));
396 }
397 return result;
398 }
399
400 // If the upper and lower bounds are constant, return a constant expression for
401 // the extent. In particular, if the upper bound is less than the lower bound,
402 // return zero.
GetNonNegativeExtent(const semantics::ShapeSpec & shapeSpec)403 static MaybeExtentExpr GetNonNegativeExtent(
404 const semantics::ShapeSpec &shapeSpec) {
405 const auto &ubound{shapeSpec.ubound().GetExplicit()};
406 const auto &lbound{shapeSpec.lbound().GetExplicit()};
407 std::optional<ConstantSubscript> uval{ToInt64(ubound)};
408 std::optional<ConstantSubscript> lval{ToInt64(lbound)};
409 if (uval && lval) {
410 if (*uval < *lval) {
411 return ExtentExpr{0};
412 } else {
413 return ExtentExpr{*uval - *lval + 1};
414 }
415 } else if (lbound && ubound && IsScopeInvariantExpr(*lbound) &&
416 IsScopeInvariantExpr(*ubound)) {
417 // Apply effective IDIM (MAX calculation with 0) so thet the
418 // result is never negative
419 if (lval.value_or(0) == 1) {
420 return ExtentExpr{Extremum<SubscriptInteger>{
421 Ordering::Greater, ExtentExpr{0}, common::Clone(*ubound)}};
422 } else {
423 return ExtentExpr{
424 Extremum<SubscriptInteger>{Ordering::Greater, ExtentExpr{0},
425 common::Clone(*ubound) - common::Clone(*lbound) + ExtentExpr{1}}};
426 }
427 } else {
428 return std::nullopt;
429 }
430 }
431
GetExtent(const NamedEntity & base,int dimension)432 MaybeExtentExpr GetExtent(const NamedEntity &base, int dimension) {
433 CHECK(dimension >= 0);
434 const Symbol &last{base.GetLastSymbol()};
435 const Symbol &symbol{ResolveAssociations(last)};
436 if (const auto *assoc{last.detailsIf<semantics::AssocEntityDetails>()}) {
437 if (assoc->rank()) { // SELECT RANK case
438 if (semantics::IsDescriptor(symbol) && dimension < *assoc->rank()) {
439 return ExtentExpr{DescriptorInquiry{
440 NamedEntity{base}, DescriptorInquiry::Field::Extent, dimension}};
441 }
442 } else if (auto shape{GetShape(assoc->expr())}) {
443 if (dimension < static_cast<int>(shape->size())) {
444 return std::move(shape->at(dimension));
445 }
446 }
447 }
448 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
449 if (IsImpliedShape(symbol) && details->init()) {
450 if (auto shape{GetShape(symbol)}) {
451 if (dimension < static_cast<int>(shape->size())) {
452 return std::move(shape->at(dimension));
453 }
454 }
455 } else {
456 int j{0};
457 for (const auto &shapeSpec : details->shape()) {
458 if (j++ == dimension) {
459 if (auto extent{GetNonNegativeExtent(shapeSpec)}) {
460 return extent;
461 } else if (details->IsAssumedSize() && j == symbol.Rank()) {
462 return std::nullopt;
463 } else if (semantics::IsDescriptor(symbol)) {
464 return ExtentExpr{DescriptorInquiry{NamedEntity{base},
465 DescriptorInquiry::Field::Extent, dimension}};
466 } else {
467 break;
468 }
469 }
470 }
471 }
472 }
473 return std::nullopt;
474 }
475
GetExtent(FoldingContext & context,const NamedEntity & base,int dimension)476 MaybeExtentExpr GetExtent(
477 FoldingContext &context, const NamedEntity &base, int dimension) {
478 return Fold(context, GetExtent(base, dimension));
479 }
480
GetExtent(const Subscript & subscript,const NamedEntity & base,int dimension)481 MaybeExtentExpr GetExtent(
482 const Subscript &subscript, const NamedEntity &base, int dimension) {
483 return common::visit(
484 common::visitors{
485 [&](const Triplet &triplet) -> MaybeExtentExpr {
486 MaybeExtentExpr upper{triplet.upper()};
487 if (!upper) {
488 upper = GetUBOUND(base, dimension);
489 }
490 MaybeExtentExpr lower{triplet.lower()};
491 if (!lower) {
492 lower = GetLBOUND(base, dimension);
493 }
494 return CountTrips(std::move(lower), std::move(upper),
495 MaybeExtentExpr{triplet.stride()});
496 },
497 [&](const IndirectSubscriptIntegerExpr &subs) -> MaybeExtentExpr {
498 if (auto shape{GetShape(subs.value())}) {
499 if (GetRank(*shape) > 0) {
500 CHECK(GetRank(*shape) == 1); // vector-valued subscript
501 return std::move(shape->at(0));
502 }
503 }
504 return std::nullopt;
505 },
506 },
507 subscript.u);
508 }
509
GetExtent(FoldingContext & context,const Subscript & subscript,const NamedEntity & base,int dimension)510 MaybeExtentExpr GetExtent(FoldingContext &context, const Subscript &subscript,
511 const NamedEntity &base, int dimension) {
512 return Fold(context, GetExtent(subscript, base, dimension));
513 }
514
ComputeUpperBound(ExtentExpr && lower,MaybeExtentExpr && extent)515 MaybeExtentExpr ComputeUpperBound(
516 ExtentExpr &&lower, MaybeExtentExpr &&extent) {
517 if (extent) {
518 if (ToInt64(lower).value_or(0) == 1) {
519 return std::move(*extent);
520 } else {
521 return std::move(*extent) + std::move(lower) - ExtentExpr{1};
522 }
523 } else {
524 return std::nullopt;
525 }
526 }
527
ComputeUpperBound(FoldingContext & context,ExtentExpr && lower,MaybeExtentExpr && extent)528 MaybeExtentExpr ComputeUpperBound(
529 FoldingContext &context, ExtentExpr &&lower, MaybeExtentExpr &&extent) {
530 return Fold(context, ComputeUpperBound(std::move(lower), std::move(extent)));
531 }
532
GetRawUpperBound(const NamedEntity & base,int dimension)533 MaybeExtentExpr GetRawUpperBound(const NamedEntity &base, int dimension) {
534 const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
535 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
536 int rank{details->shape().Rank()};
537 if (dimension < rank) {
538 const auto &bound{details->shape()[dimension].ubound().GetExplicit()};
539 if (bound && IsScopeInvariantExpr(*bound)) {
540 return *bound;
541 } else if (details->IsAssumedSize() && dimension + 1 == symbol.Rank()) {
542 return std::nullopt;
543 } else {
544 return ComputeUpperBound(
545 GetRawLowerBound(base, dimension), GetExtent(base, dimension));
546 }
547 }
548 } else if (const auto *assoc{
549 symbol.detailsIf<semantics::AssocEntityDetails>()}) {
550 if (auto shape{GetShape(assoc->expr())}) {
551 if (dimension < static_cast<int>(shape->size())) {
552 return ComputeUpperBound(
553 GetRawLowerBound(base, dimension), std::move(shape->at(dimension)));
554 }
555 }
556 }
557 return std::nullopt;
558 }
559
GetRawUpperBound(FoldingContext & context,const NamedEntity & base,int dimension)560 MaybeExtentExpr GetRawUpperBound(
561 FoldingContext &context, const NamedEntity &base, int dimension) {
562 return Fold(context, GetRawUpperBound(base, dimension));
563 }
564
GetExplicitUBOUND(FoldingContext * context,const semantics::ShapeSpec & shapeSpec)565 static MaybeExtentExpr GetExplicitUBOUND(
566 FoldingContext *context, const semantics::ShapeSpec &shapeSpec) {
567 const auto &ubound{shapeSpec.ubound().GetExplicit()};
568 if (ubound && IsScopeInvariantExpr(*ubound)) {
569 if (auto extent{GetNonNegativeExtent(shapeSpec)}) {
570 if (auto cstExtent{ToInt64(
571 context ? Fold(*context, std::move(*extent)) : *extent)}) {
572 if (cstExtent > 0) {
573 return *ubound;
574 } else if (cstExtent == 0) {
575 return ExtentExpr{0};
576 }
577 }
578 }
579 }
580 return std::nullopt;
581 }
582
GetUBOUND(FoldingContext * context,const NamedEntity & base,int dimension)583 static MaybeExtentExpr GetUBOUND(
584 FoldingContext *context, const NamedEntity &base, int dimension) {
585 const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
586 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
587 int rank{details->shape().Rank()};
588 if (dimension < rank) {
589 const semantics::ShapeSpec &shapeSpec{details->shape()[dimension]};
590 if (auto ubound{GetExplicitUBOUND(context, shapeSpec)}) {
591 return *ubound;
592 } else if (details->IsAssumedSize() && dimension + 1 == symbol.Rank()) {
593 return std::nullopt;
594 } else if (auto lb{GetLBOUND(base, dimension)}) {
595 return ComputeUpperBound(std::move(*lb), GetExtent(base, dimension));
596 }
597 }
598 } else if (const auto *assoc{
599 symbol.detailsIf<semantics::AssocEntityDetails>()}) {
600 if (auto shape{GetShape(assoc->expr())}) {
601 if (dimension < static_cast<int>(shape->size())) {
602 if (auto lb{GetLBOUND(base, dimension)}) {
603 return ComputeUpperBound(
604 std::move(*lb), std::move(shape->at(dimension)));
605 }
606 }
607 }
608 }
609 return std::nullopt;
610 }
611
GetUBOUND(const NamedEntity & base,int dimension)612 MaybeExtentExpr GetUBOUND(const NamedEntity &base, int dimension) {
613 return GetUBOUND(nullptr, base, dimension);
614 }
615
GetUBOUND(FoldingContext & context,const NamedEntity & base,int dimension)616 MaybeExtentExpr GetUBOUND(
617 FoldingContext &context, const NamedEntity &base, int dimension) {
618 return Fold(context, GetUBOUND(&context, base, dimension));
619 }
620
GetUBOUNDs(FoldingContext * context,const NamedEntity & base)621 static Shape GetUBOUNDs(FoldingContext *context, const NamedEntity &base) {
622 const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
623 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
624 Shape result;
625 int dim{0};
626 for (const auto &shapeSpec : details->shape()) {
627 if (auto ubound{GetExplicitUBOUND(context, shapeSpec)}) {
628 result.emplace_back(*ubound);
629 } else if (details->IsAssumedSize() && dim + 1 == base.Rank()) {
630 result.emplace_back(std::nullopt); // UBOUND folding replaces with -1
631 } else if (auto lb{GetLBOUND(base, dim)}) {
632 result.emplace_back(
633 ComputeUpperBound(std::move(*lb), GetExtent(base, dim)));
634 } else {
635 result.emplace_back(); // unknown
636 }
637 ++dim;
638 }
639 CHECK(GetRank(result) == symbol.Rank());
640 return result;
641 } else {
642 return std::move(GetShape(symbol).value());
643 }
644 }
645
GetUBOUNDs(FoldingContext & context,const NamedEntity & base)646 Shape GetUBOUNDs(FoldingContext &context, const NamedEntity &base) {
647 return Fold(context, GetUBOUNDs(&context, base));
648 }
649
GetUBOUNDs(const NamedEntity & base)650 Shape GetUBOUNDs(const NamedEntity &base) { return GetUBOUNDs(nullptr, base); }
651
operator ()(const Symbol & symbol) const652 auto GetShapeHelper::operator()(const Symbol &symbol) const -> Result {
653 return common::visit(
654 common::visitors{
655 [&](const semantics::ObjectEntityDetails &object) {
656 if (IsImpliedShape(symbol) && object.init()) {
657 return (*this)(object.init());
658 } else if (IsAssumedRank(symbol)) {
659 return Result{};
660 } else {
661 int n{object.shape().Rank()};
662 NamedEntity base{symbol};
663 return Result{CreateShape(n, base)};
664 }
665 },
666 [](const semantics::EntityDetails &) {
667 return ScalarShape(); // no dimensions seen
668 },
669 [&](const semantics::ProcEntityDetails &proc) {
670 if (const Symbol * interface{proc.interface().symbol()}) {
671 return (*this)(*interface);
672 } else {
673 return ScalarShape();
674 }
675 },
676 [&](const semantics::AssocEntityDetails &assoc) {
677 if (assoc.rank()) { // SELECT RANK case
678 int n{assoc.rank().value()};
679 NamedEntity base{symbol};
680 return Result{CreateShape(n, base)};
681 } else {
682 return (*this)(assoc.expr());
683 }
684 },
685 [&](const semantics::SubprogramDetails &subp) -> Result {
686 if (subp.isFunction()) {
687 auto resultShape{(*this)(subp.result())};
688 if (resultShape && !useResultSymbolShape_) {
689 // Ensure the shape is constant. Otherwise, it may be referring
690 // to symbols that belong to the subroutine scope and are
691 // meaningless on the caller side without the related call
692 // expression.
693 for (auto &extent : *resultShape) {
694 if (extent && !IsActuallyConstant(*extent)) {
695 extent.reset();
696 }
697 }
698 }
699 return resultShape;
700 } else {
701 return Result{};
702 }
703 },
704 [&](const semantics::ProcBindingDetails &binding) {
705 return (*this)(binding.symbol());
706 },
707 [](const semantics::TypeParamDetails &) { return ScalarShape(); },
708 [](const auto &) { return Result{}; },
709 },
710 symbol.GetUltimate().details());
711 }
712
operator ()(const Component & component) const713 auto GetShapeHelper::operator()(const Component &component) const -> Result {
714 const Symbol &symbol{component.GetLastSymbol()};
715 int rank{symbol.Rank()};
716 if (rank == 0) {
717 return (*this)(component.base());
718 } else if (symbol.has<semantics::ObjectEntityDetails>()) {
719 NamedEntity base{Component{component}};
720 return CreateShape(rank, base);
721 } else if (symbol.has<semantics::AssocEntityDetails>()) {
722 NamedEntity base{Component{component}};
723 return Result{CreateShape(rank, base)};
724 } else {
725 return (*this)(symbol);
726 }
727 }
728
operator ()(const ArrayRef & arrayRef) const729 auto GetShapeHelper::operator()(const ArrayRef &arrayRef) const -> Result {
730 Shape shape;
731 int dimension{0};
732 const NamedEntity &base{arrayRef.base()};
733 for (const Subscript &ss : arrayRef.subscript()) {
734 if (ss.Rank() > 0) {
735 shape.emplace_back(GetExtent(ss, base, dimension));
736 }
737 ++dimension;
738 }
739 if (shape.empty()) {
740 if (const Component * component{base.UnwrapComponent()}) {
741 return (*this)(component->base());
742 }
743 }
744 return shape;
745 }
746
operator ()(const CoarrayRef & coarrayRef) const747 auto GetShapeHelper::operator()(const CoarrayRef &coarrayRef) const -> Result {
748 NamedEntity base{coarrayRef.GetBase()};
749 if (coarrayRef.subscript().empty()) {
750 return (*this)(base);
751 } else {
752 Shape shape;
753 int dimension{0};
754 for (const Subscript &ss : coarrayRef.subscript()) {
755 if (ss.Rank() > 0) {
756 shape.emplace_back(GetExtent(ss, base, dimension));
757 }
758 ++dimension;
759 }
760 return shape;
761 }
762 }
763
operator ()(const Substring & substring) const764 auto GetShapeHelper::operator()(const Substring &substring) const -> Result {
765 return (*this)(substring.parent());
766 }
767
operator ()(const ProcedureRef & call) const768 auto GetShapeHelper::operator()(const ProcedureRef &call) const -> Result {
769 if (call.Rank() == 0) {
770 return ScalarShape();
771 } else if (call.IsElemental()) {
772 for (const auto &arg : call.arguments()) {
773 if (arg && arg->Rank() > 0) {
774 return (*this)(*arg);
775 }
776 }
777 return ScalarShape();
778 } else if (const Symbol * symbol{call.proc().GetSymbol()}) {
779 return (*this)(*symbol);
780 } else if (const auto *intrinsic{call.proc().GetSpecificIntrinsic()}) {
781 if (intrinsic->name == "shape" || intrinsic->name == "lbound" ||
782 intrinsic->name == "ubound") {
783 // For LBOUND/UBOUND, these are the array-valued cases (no DIM=)
784 if (!call.arguments().empty() && call.arguments().front()) {
785 return Shape{
786 MaybeExtentExpr{ExtentExpr{call.arguments().front()->Rank()}}};
787 }
788 } else if (intrinsic->name == "all" || intrinsic->name == "any" ||
789 intrinsic->name == "count" || intrinsic->name == "iall" ||
790 intrinsic->name == "iany" || intrinsic->name == "iparity" ||
791 intrinsic->name == "maxval" || intrinsic->name == "minval" ||
792 intrinsic->name == "norm2" || intrinsic->name == "parity" ||
793 intrinsic->name == "product" || intrinsic->name == "sum") {
794 // Reduction with DIM=
795 if (call.arguments().size() >= 2) {
796 auto arrayShape{
797 (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))};
798 const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))};
799 if (arrayShape && dimArg) {
800 if (auto dim{ToInt64(*dimArg)}) {
801 if (*dim >= 1 &&
802 static_cast<std::size_t>(*dim) <= arrayShape->size()) {
803 arrayShape->erase(arrayShape->begin() + (*dim - 1));
804 return std::move(*arrayShape);
805 }
806 }
807 }
808 }
809 } else if (intrinsic->name == "findloc" || intrinsic->name == "maxloc" ||
810 intrinsic->name == "minloc") {
811 std::size_t dimIndex{intrinsic->name == "findloc" ? 2u : 1u};
812 if (call.arguments().size() > dimIndex) {
813 if (auto arrayShape{
814 (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))}) {
815 auto rank{static_cast<int>(arrayShape->size())};
816 if (const auto *dimArg{
817 UnwrapExpr<Expr<SomeType>>(call.arguments()[dimIndex])}) {
818 auto dim{ToInt64(*dimArg)};
819 if (dim && *dim >= 1 && *dim <= rank) {
820 arrayShape->erase(arrayShape->begin() + (*dim - 1));
821 return std::move(*arrayShape);
822 }
823 } else {
824 // xxxLOC(no DIM=) result is vector(1:RANK(ARRAY=))
825 return Shape{ExtentExpr{rank}};
826 }
827 }
828 }
829 } else if (intrinsic->name == "cshift" || intrinsic->name == "eoshift") {
830 if (!call.arguments().empty()) {
831 return (*this)(call.arguments()[0]);
832 }
833 } else if (intrinsic->name == "matmul") {
834 if (call.arguments().size() == 2) {
835 if (auto ashape{(*this)(call.arguments()[0])}) {
836 if (auto bshape{(*this)(call.arguments()[1])}) {
837 if (ashape->size() == 1 && bshape->size() == 2) {
838 bshape->erase(bshape->begin());
839 return std::move(*bshape); // matmul(vector, matrix)
840 } else if (ashape->size() == 2 && bshape->size() == 1) {
841 ashape->pop_back();
842 return std::move(*ashape); // matmul(matrix, vector)
843 } else if (ashape->size() == 2 && bshape->size() == 2) {
844 (*ashape)[1] = std::move((*bshape)[1]);
845 return std::move(*ashape); // matmul(matrix, matrix)
846 }
847 }
848 }
849 }
850 } else if (intrinsic->name == "pack") {
851 if (call.arguments().size() >= 3 && call.arguments().at(2)) {
852 // SHAPE(PACK(,,VECTOR=v)) -> SHAPE(v)
853 return (*this)(call.arguments().at(2));
854 } else if (call.arguments().size() >= 2 && context_) {
855 if (auto maskShape{(*this)(call.arguments().at(1))}) {
856 if (maskShape->size() == 0) {
857 // Scalar MASK= -> [MERGE(SIZE(ARRAY=), 0, mask)]
858 if (auto arrayShape{(*this)(call.arguments().at(0))}) {
859 auto arraySize{GetSize(std::move(*arrayShape))};
860 CHECK(arraySize);
861 ActualArguments toMerge{
862 ActualArgument{AsGenericExpr(std::move(*arraySize))},
863 ActualArgument{AsGenericExpr(ExtentExpr{0})},
864 common::Clone(call.arguments().at(1))};
865 auto specific{context_->intrinsics().Probe(
866 CallCharacteristics{"merge"}, toMerge, *context_)};
867 CHECK(specific);
868 return Shape{ExtentExpr{FunctionRef<ExtentType>{
869 ProcedureDesignator{std::move(specific->specificIntrinsic)},
870 std::move(specific->arguments)}}};
871 }
872 } else {
873 // Non-scalar MASK= -> [COUNT(mask)]
874 ActualArguments toCount{ActualArgument{common::Clone(
875 DEREF(call.arguments().at(1).value().UnwrapExpr()))}};
876 auto specific{context_->intrinsics().Probe(
877 CallCharacteristics{"count"}, toCount, *context_)};
878 CHECK(specific);
879 return Shape{ExtentExpr{FunctionRef<ExtentType>{
880 ProcedureDesignator{std::move(specific->specificIntrinsic)},
881 std::move(specific->arguments)}}};
882 }
883 }
884 }
885 } else if (intrinsic->name == "reshape") {
886 if (call.arguments().size() >= 2 && call.arguments().at(1)) {
887 // SHAPE(RESHAPE(array,shape)) -> shape
888 if (const auto *shapeExpr{
889 call.arguments().at(1).value().UnwrapExpr()}) {
890 auto shapeArg{std::get<Expr<SomeInteger>>(shapeExpr->u)};
891 if (auto result{AsShapeResult(
892 ConvertToType<ExtentType>(std::move(shapeArg)))}) {
893 return result;
894 }
895 }
896 }
897 } else if (intrinsic->name == "spread") {
898 // SHAPE(SPREAD(ARRAY,DIM,NCOPIES)) = SHAPE(ARRAY) with NCOPIES inserted
899 // at position DIM.
900 if (call.arguments().size() == 3) {
901 auto arrayShape{
902 (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))};
903 const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))};
904 const auto *nCopies{
905 UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))};
906 if (arrayShape && dimArg && nCopies) {
907 if (auto dim{ToInt64(*dimArg)}) {
908 if (*dim >= 1 &&
909 static_cast<std::size_t>(*dim) <= arrayShape->size() + 1) {
910 arrayShape->emplace(arrayShape->begin() + *dim - 1,
911 ConvertToType<ExtentType>(common::Clone(*nCopies)));
912 return std::move(*arrayShape);
913 }
914 }
915 }
916 }
917 } else if (intrinsic->name == "transfer") {
918 if (call.arguments().size() == 3 && call.arguments().at(2)) {
919 // SIZE= is present; shape is vector [SIZE=]
920 if (const auto *size{
921 UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))}) {
922 return Shape{
923 MaybeExtentExpr{ConvertToType<ExtentType>(common::Clone(*size))}};
924 }
925 } else if (context_) {
926 if (auto moldTypeAndShape{characteristics::TypeAndShape::Characterize(
927 call.arguments().at(1), *context_)}) {
928 if (GetRank(moldTypeAndShape->shape()) == 0) {
929 // SIZE= is absent and MOLD= is scalar: result is scalar
930 return ScalarShape();
931 } else {
932 // SIZE= is absent and MOLD= is array: result is vector whose
933 // length is determined by sizes of types. See 16.9.193p4 case(ii).
934 // Note that if sourceBytes is not known to be empty, we
935 // can fold only when moldElementBytes is known to not be zero;
936 // the most general case risks a division by zero otherwise.
937 if (auto sourceTypeAndShape{
938 characteristics::TypeAndShape::Characterize(
939 call.arguments().at(0), *context_)}) {
940 if (auto sourceBytes{
941 sourceTypeAndShape->MeasureSizeInBytes(*context_)}) {
942 *sourceBytes = Fold(*context_, std::move(*sourceBytes));
943 if (auto sourceBytesConst{ToInt64(*sourceBytes)}) {
944 if (*sourceBytesConst == 0) {
945 return Shape{ExtentExpr{0}};
946 }
947 }
948 if (auto moldElementBytes{
949 moldTypeAndShape->MeasureElementSizeInBytes(
950 *context_, true)}) {
951 *moldElementBytes =
952 Fold(*context_, std::move(*moldElementBytes));
953 auto moldElementBytesConst{ToInt64(*moldElementBytes)};
954 if (moldElementBytesConst && *moldElementBytesConst != 0) {
955 ExtentExpr extent{Fold(*context_,
956 (std::move(*sourceBytes) +
957 common::Clone(*moldElementBytes) - ExtentExpr{1}) /
958 common::Clone(*moldElementBytes))};
959 return Shape{MaybeExtentExpr{std::move(extent)}};
960 }
961 }
962 }
963 }
964 }
965 }
966 }
967 } else if (intrinsic->name == "transpose") {
968 if (call.arguments().size() >= 1) {
969 if (auto shape{(*this)(call.arguments().at(0))}) {
970 if (shape->size() == 2) {
971 std::swap((*shape)[0], (*shape)[1]);
972 return shape;
973 }
974 }
975 }
976 } else if (intrinsic->name == "unpack") {
977 if (call.arguments().size() >= 2) {
978 return (*this)(call.arguments()[1]); // MASK=
979 }
980 } else if (intrinsic->characteristics.value().attrs.test(characteristics::
981 Procedure::Attr::NullPointer)) { // NULL(MOLD=)
982 return (*this)(call.arguments());
983 } else {
984 // TODO: shapes of other non-elemental intrinsic results
985 }
986 }
987 // The rank is always known even if the extents are not.
988 return Shape(static_cast<std::size_t>(call.Rank()), MaybeExtentExpr{});
989 }
990
991 // Check conformance of the passed shapes.
CheckConformance(parser::ContextualMessages & messages,const Shape & left,const Shape & right,CheckConformanceFlags::Flags flags,const char * leftIs,const char * rightIs)992 std::optional<bool> CheckConformance(parser::ContextualMessages &messages,
993 const Shape &left, const Shape &right, CheckConformanceFlags::Flags flags,
994 const char *leftIs, const char *rightIs) {
995 int n{GetRank(left)};
996 if (n == 0 && (flags & CheckConformanceFlags::LeftScalarExpandable)) {
997 return true;
998 }
999 int rn{GetRank(right)};
1000 if (rn == 0 && (flags & CheckConformanceFlags::RightScalarExpandable)) {
1001 return true;
1002 }
1003 if (n != rn) {
1004 messages.Say("Rank of %1$s is %2$d, but %3$s has rank %4$d"_err_en_US,
1005 leftIs, n, rightIs, rn);
1006 return false;
1007 }
1008 for (int j{0}; j < n; ++j) {
1009 if (auto leftDim{ToInt64(left[j])}) {
1010 if (auto rightDim{ToInt64(right[j])}) {
1011 if (*leftDim != *rightDim) {
1012 messages.Say("Dimension %1$d of %2$s has extent %3$jd, "
1013 "but %4$s has extent %5$jd"_err_en_US,
1014 j + 1, leftIs, *leftDim, rightIs, *rightDim);
1015 return false;
1016 }
1017 } else if (!(flags & CheckConformanceFlags::RightIsDeferredShape)) {
1018 return std::nullopt;
1019 }
1020 } else if (!(flags & CheckConformanceFlags::LeftIsDeferredShape)) {
1021 return std::nullopt;
1022 }
1023 }
1024 return true;
1025 }
1026
IncrementSubscripts(ConstantSubscripts & indices,const ConstantSubscripts & extents)1027 bool IncrementSubscripts(
1028 ConstantSubscripts &indices, const ConstantSubscripts &extents) {
1029 std::size_t rank(indices.size());
1030 CHECK(rank <= extents.size());
1031 for (std::size_t j{0}; j < rank; ++j) {
1032 if (extents[j] < 1) {
1033 return false;
1034 }
1035 }
1036 for (std::size_t j{0}; j < rank; ++j) {
1037 if (indices[j]++ < extents[j]) {
1038 return true;
1039 }
1040 indices[j] = 1;
1041 }
1042 return false;
1043 }
1044
1045 } // namespace Fortran::evaluate
1046