1 //===-- lib/Evaluate/fold.cpp ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "flang/Evaluate/fold.h"
#include "fold-implementation.h"
#include "flang/Evaluate/characteristics.h"
#include <limits>
12 
13 namespace Fortran::evaluate {
14 
15 characteristics::TypeAndShape Fold(
16     FoldingContext &context, characteristics::TypeAndShape &&x) {
17   x.Rewrite(context);
18   return std::move(x);
19 }
20 
21 std::optional<Constant<SubscriptInteger>> GetConstantSubscript(
22     FoldingContext &context, Subscript &ss, const NamedEntity &base, int dim) {
23   ss = FoldOperation(context, std::move(ss));
24   return std::visit(
25       common::visitors{
26           [](IndirectSubscriptIntegerExpr &expr)
27               -> std::optional<Constant<SubscriptInteger>> {
28             if (const auto *constant{
29                     UnwrapConstantValue<SubscriptInteger>(expr.value())}) {
30               return *constant;
31             } else {
32               return std::nullopt;
33             }
34           },
35           [&](Triplet &triplet) -> std::optional<Constant<SubscriptInteger>> {
36             auto lower{triplet.lower()}, upper{triplet.upper()};
37             std::optional<ConstantSubscript> stride{ToInt64(triplet.stride())};
38             if (!lower) {
39               lower = GetLowerBound(context, base, dim);
40             }
41             if (!upper) {
42               upper =
43                   ComputeUpperBound(context, GetLowerBound(context, base, dim),
44                       GetExtent(context, base, dim));
45             }
46             auto lbi{ToInt64(lower)}, ubi{ToInt64(upper)};
47             if (lbi && ubi && stride && *stride != 0) {
48               std::vector<SubscriptInteger::Scalar> values;
49               while ((*stride > 0 && *lbi <= *ubi) ||
50                   (*stride < 0 && *lbi >= *ubi)) {
51                 values.emplace_back(*lbi);
52                 *lbi += *stride;
53               }
54               return Constant<SubscriptInteger>{std::move(values),
55                   ConstantSubscripts{
56                       static_cast<ConstantSubscript>(values.size())}};
57             } else {
58               return std::nullopt;
59             }
60           },
61       },
62       ss.u);
63 }
64 
65 Expr<SomeDerived> FoldOperation(
66     FoldingContext &context, StructureConstructor &&structure) {
67   StructureConstructor ctor{structure.derivedTypeSpec()};
68   bool isConstant{true};
69   for (auto &&[symbol, value] : std::move(structure)) {
70     auto expr{Fold(context, std::move(value.value()))};
71     if (IsPointer(symbol)) {
72       if (IsProcedure(symbol)) {
73         isConstant &= IsInitialProcedureTarget(expr);
74       } else {
75         isConstant &= IsInitialDataTarget(expr);
76       }
77     } else {
78       isConstant &= IsActuallyConstant(expr);
79       if (auto valueShape{GetConstantExtents(context, expr)}) {
80         if (auto componentShape{GetConstantExtents(context, symbol)}) {
81           if (GetRank(*componentShape) > 0 && GetRank(*valueShape) == 0) {
82             expr = ScalarConstantExpander{std::move(*componentShape)}.Expand(
83                 std::move(expr));
84             isConstant &= expr.Rank() > 0;
85           } else {
86             isConstant &= *valueShape == *componentShape;
87           }
88         }
89       }
90     }
91     ctor.Add(symbol, std::move(expr));
92   }
93   if (isConstant) {
94     return Expr<SomeDerived>{Constant<SomeDerived>{std::move(ctor)}};
95   } else {
96     return Expr<SomeDerived>{std::move(ctor)};
97   }
98 }
99 
100 Component FoldOperation(FoldingContext &context, Component &&component) {
101   return {FoldOperation(context, std::move(component.base())),
102       component.GetLastSymbol()};
103 }
104 
105 NamedEntity FoldOperation(FoldingContext &context, NamedEntity &&x) {
106   if (Component * c{x.UnwrapComponent()}) {
107     return NamedEntity{FoldOperation(context, std::move(*c))};
108   } else {
109     return std::move(x);
110   }
111 }
112 
113 Triplet FoldOperation(FoldingContext &context, Triplet &&triplet) {
114   MaybeExtentExpr lower{triplet.lower()};
115   MaybeExtentExpr upper{triplet.upper()};
116   return {Fold(context, std::move(lower)), Fold(context, std::move(upper)),
117       Fold(context, triplet.stride())};
118 }
119 
120 Subscript FoldOperation(FoldingContext &context, Subscript &&subscript) {
121   return std::visit(common::visitors{
122                         [&](IndirectSubscriptIntegerExpr &&expr) {
123                           expr.value() = Fold(context, std::move(expr.value()));
124                           return Subscript(std::move(expr));
125                         },
126                         [&](Triplet &&triplet) {
127                           return Subscript(
128                               FoldOperation(context, std::move(triplet)));
129                         },
130                     },
131       std::move(subscript.u));
132 }
133 
134 ArrayRef FoldOperation(FoldingContext &context, ArrayRef &&arrayRef) {
135   NamedEntity base{FoldOperation(context, std::move(arrayRef.base()))};
136   for (Subscript &subscript : arrayRef.subscript()) {
137     subscript = FoldOperation(context, std::move(subscript));
138   }
139   return ArrayRef{std::move(base), std::move(arrayRef.subscript())};
140 }
141 
142 CoarrayRef FoldOperation(FoldingContext &context, CoarrayRef &&coarrayRef) {
143   std::vector<Subscript> subscript;
144   for (Subscript x : coarrayRef.subscript()) {
145     subscript.emplace_back(FoldOperation(context, std::move(x)));
146   }
147   std::vector<Expr<SubscriptInteger>> cosubscript;
148   for (Expr<SubscriptInteger> x : coarrayRef.cosubscript()) {
149     cosubscript.emplace_back(Fold(context, std::move(x)));
150   }
151   CoarrayRef folded{std::move(coarrayRef.base()), std::move(subscript),
152       std::move(cosubscript)};
153   if (std::optional<Expr<SomeInteger>> stat{coarrayRef.stat()}) {
154     folded.set_stat(Fold(context, std::move(*stat)));
155   }
156   if (std::optional<Expr<SomeInteger>> team{coarrayRef.team()}) {
157     folded.set_team(
158         Fold(context, std::move(*team)), coarrayRef.teamIsTeamNumber());
159   }
160   return folded;
161 }
162 
163 DataRef FoldOperation(FoldingContext &context, DataRef &&dataRef) {
164   return std::visit(common::visitors{
165                         [&](SymbolRef symbol) { return DataRef{*symbol}; },
166                         [&](auto &&x) {
167                           return DataRef{FoldOperation(context, std::move(x))};
168                         },
169                     },
170       std::move(dataRef.u));
171 }
172 
173 Substring FoldOperation(FoldingContext &context, Substring &&substring) {
174   auto lower{Fold(context, substring.lower())};
175   auto upper{Fold(context, substring.upper())};
176   if (const DataRef * dataRef{substring.GetParentIf<DataRef>()}) {
177     return Substring{FoldOperation(context, DataRef{*dataRef}),
178         std::move(lower), std::move(upper)};
179   } else {
180     auto p{*substring.GetParentIf<StaticDataObject::Pointer>()};
181     return Substring{std::move(p), std::move(lower), std::move(upper)};
182   }
183 }
184 
185 ComplexPart FoldOperation(FoldingContext &context, ComplexPart &&complexPart) {
186   DataRef complex{complexPart.complex()};
187   return ComplexPart{
188       FoldOperation(context, std::move(complex)), complexPart.part()};
189 }
190 
191 std::optional<std::int64_t> GetInt64Arg(
192     const std::optional<ActualArgument> &arg) {
193   if (const auto *intExpr{UnwrapExpr<Expr<SomeInteger>>(arg)}) {
194     return ToInt64(*intExpr);
195   } else {
196     return std::nullopt;
197   }
198 }
199 
200 std::optional<std::int64_t> GetInt64ArgOr(
201     const std::optional<ActualArgument> &arg, std::int64_t defaultValue) {
202   if (!arg) {
203     return defaultValue;
204   } else if (const auto *intExpr{UnwrapExpr<Expr<SomeInteger>>(arg)}) {
205     return ToInt64(*intExpr);
206   } else {
207     return std::nullopt;
208   }
209 }
210 
211 Expr<ImpliedDoIndex::Result> FoldOperation(
212     FoldingContext &context, ImpliedDoIndex &&iDo) {
213   if (std::optional<ConstantSubscript> value{context.GetImpliedDo(iDo.name)}) {
214     return Expr<ImpliedDoIndex::Result>{*value};
215   } else {
216     return Expr<ImpliedDoIndex::Result>{std::move(iDo)};
217   }
218 }
219 
// Explicit instantiation definitions: ExpressionBase<SomeDerived> and
// ExpressionBase<SomeType> are instantiated in this translation unit.
// NOTE(review): presumably placed here so their member definitions are
// emitted alongside the folding code; confirm against the header.
template class ExpressionBase<SomeDerived>;
template class ExpressionBase<SomeType>;
222 
223 } // namespace Fortran::evaluate
224