1 //===-- lib/Evaluate/fold.cpp ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
#include "flang/Evaluate/fold.h"
#include "fold-implementation.h"
#include "flang/Evaluate/characteristics.h"
#include "flang/Evaluate/initial-image.h"
#include <cstdint>
13
14 namespace Fortran::evaluate {
15
Fold(FoldingContext & context,characteristics::TypeAndShape && x)16 characteristics::TypeAndShape Fold(
17 FoldingContext &context, characteristics::TypeAndShape &&x) {
18 x.Rewrite(context);
19 return std::move(x);
20 }
21
GetConstantSubscript(FoldingContext & context,Subscript & ss,const NamedEntity & base,int dim)22 std::optional<Constant<SubscriptInteger>> GetConstantSubscript(
23 FoldingContext &context, Subscript &ss, const NamedEntity &base, int dim) {
24 ss = FoldOperation(context, std::move(ss));
25 return common::visit(
26 common::visitors{
27 [](IndirectSubscriptIntegerExpr &expr)
28 -> std::optional<Constant<SubscriptInteger>> {
29 if (const auto *constant{
30 UnwrapConstantValue<SubscriptInteger>(expr.value())}) {
31 return *constant;
32 } else {
33 return std::nullopt;
34 }
35 },
36 [&](Triplet &triplet) -> std::optional<Constant<SubscriptInteger>> {
37 auto lower{triplet.lower()}, upper{triplet.upper()};
38 std::optional<ConstantSubscript> stride{ToInt64(triplet.stride())};
39 if (!lower) {
40 lower = GetLBOUND(context, base, dim);
41 }
42 if (!upper) {
43 if (auto lb{GetLBOUND(context, base, dim)}) {
44 upper = ComputeUpperBound(
45 context, std::move(*lb), GetExtent(context, base, dim));
46 }
47 }
48 auto lbi{ToInt64(lower)}, ubi{ToInt64(upper)};
49 if (lbi && ubi && stride && *stride != 0) {
50 std::vector<SubscriptInteger::Scalar> values;
51 while ((*stride > 0 && *lbi <= *ubi) ||
52 (*stride < 0 && *lbi >= *ubi)) {
53 values.emplace_back(*lbi);
54 *lbi += *stride;
55 }
56 return Constant<SubscriptInteger>{std::move(values),
57 ConstantSubscripts{
58 static_cast<ConstantSubscript>(values.size())}};
59 } else {
60 return std::nullopt;
61 }
62 },
63 },
64 ss.u);
65 }
66
FoldOperation(FoldingContext & context,StructureConstructor && structure)67 Expr<SomeDerived> FoldOperation(
68 FoldingContext &context, StructureConstructor &&structure) {
69 StructureConstructor ctor{structure.derivedTypeSpec()};
70 bool isConstant{true};
71 auto restorer{context.WithPDTInstance(structure.derivedTypeSpec())};
72 for (auto &&[symbol, value] : std::move(structure)) {
73 auto expr{Fold(context, std::move(value.value()))};
74 if (IsPointer(symbol)) {
75 if (IsProcedure(symbol)) {
76 isConstant &= IsInitialProcedureTarget(expr);
77 } else {
78 isConstant &= IsInitialDataTarget(expr);
79 }
80 } else {
81 isConstant &= IsActuallyConstant(expr);
82 if (auto valueShape{GetConstantExtents(context, expr)}) {
83 if (auto componentShape{GetConstantExtents(context, symbol)}) {
84 if (GetRank(*componentShape) > 0 && GetRank(*valueShape) == 0) {
85 expr = ScalarConstantExpander{std::move(*componentShape)}.Expand(
86 std::move(expr));
87 isConstant &= expr.Rank() > 0;
88 } else {
89 isConstant &= *valueShape == *componentShape;
90 }
91 }
92 }
93 }
94 ctor.Add(symbol, std::move(expr));
95 }
96 if (isConstant) {
97 return Expr<SomeDerived>{Constant<SomeDerived>{std::move(ctor)}};
98 } else {
99 return Expr<SomeDerived>{std::move(ctor)};
100 }
101 }
102
FoldOperation(FoldingContext & context,Component && component)103 Component FoldOperation(FoldingContext &context, Component &&component) {
104 return {FoldOperation(context, std::move(component.base())),
105 component.GetLastSymbol()};
106 }
107
FoldOperation(FoldingContext & context,NamedEntity && x)108 NamedEntity FoldOperation(FoldingContext &context, NamedEntity &&x) {
109 if (Component * c{x.UnwrapComponent()}) {
110 return NamedEntity{FoldOperation(context, std::move(*c))};
111 } else {
112 return std::move(x);
113 }
114 }
115
FoldOperation(FoldingContext & context,Triplet && triplet)116 Triplet FoldOperation(FoldingContext &context, Triplet &&triplet) {
117 MaybeExtentExpr lower{triplet.lower()};
118 MaybeExtentExpr upper{triplet.upper()};
119 return {Fold(context, std::move(lower)), Fold(context, std::move(upper)),
120 Fold(context, triplet.stride())};
121 }
122
FoldOperation(FoldingContext & context,Subscript && subscript)123 Subscript FoldOperation(FoldingContext &context, Subscript &&subscript) {
124 return common::visit(
125 common::visitors{
126 [&](IndirectSubscriptIntegerExpr &&expr) {
127 expr.value() = Fold(context, std::move(expr.value()));
128 return Subscript(std::move(expr));
129 },
130 [&](Triplet &&triplet) {
131 return Subscript(FoldOperation(context, std::move(triplet)));
132 },
133 },
134 std::move(subscript.u));
135 }
136
FoldOperation(FoldingContext & context,ArrayRef && arrayRef)137 ArrayRef FoldOperation(FoldingContext &context, ArrayRef &&arrayRef) {
138 NamedEntity base{FoldOperation(context, std::move(arrayRef.base()))};
139 for (Subscript &subscript : arrayRef.subscript()) {
140 subscript = FoldOperation(context, std::move(subscript));
141 }
142 return ArrayRef{std::move(base), std::move(arrayRef.subscript())};
143 }
144
FoldOperation(FoldingContext & context,CoarrayRef && coarrayRef)145 CoarrayRef FoldOperation(FoldingContext &context, CoarrayRef &&coarrayRef) {
146 std::vector<Subscript> subscript;
147 for (Subscript x : coarrayRef.subscript()) {
148 subscript.emplace_back(FoldOperation(context, std::move(x)));
149 }
150 std::vector<Expr<SubscriptInteger>> cosubscript;
151 for (Expr<SubscriptInteger> x : coarrayRef.cosubscript()) {
152 cosubscript.emplace_back(Fold(context, std::move(x)));
153 }
154 CoarrayRef folded{std::move(coarrayRef.base()), std::move(subscript),
155 std::move(cosubscript)};
156 if (std::optional<Expr<SomeInteger>> stat{coarrayRef.stat()}) {
157 folded.set_stat(Fold(context, std::move(*stat)));
158 }
159 if (std::optional<Expr<SomeInteger>> team{coarrayRef.team()}) {
160 folded.set_team(
161 Fold(context, std::move(*team)), coarrayRef.teamIsTeamNumber());
162 }
163 return folded;
164 }
165
FoldOperation(FoldingContext & context,DataRef && dataRef)166 DataRef FoldOperation(FoldingContext &context, DataRef &&dataRef) {
167 return common::visit(common::visitors{
168 [&](SymbolRef symbol) { return DataRef{*symbol}; },
169 [&](auto &&x) {
170 return DataRef{
171 FoldOperation(context, std::move(x))};
172 },
173 },
174 std::move(dataRef.u));
175 }
176
FoldOperation(FoldingContext & context,Substring && substring)177 Substring FoldOperation(FoldingContext &context, Substring &&substring) {
178 auto lower{Fold(context, substring.lower())};
179 auto upper{Fold(context, substring.upper())};
180 if (const DataRef * dataRef{substring.GetParentIf<DataRef>()}) {
181 return Substring{FoldOperation(context, DataRef{*dataRef}),
182 std::move(lower), std::move(upper)};
183 } else {
184 auto p{*substring.GetParentIf<StaticDataObject::Pointer>()};
185 return Substring{std::move(p), std::move(lower), std::move(upper)};
186 }
187 }
188
FoldOperation(FoldingContext & context,ComplexPart && complexPart)189 ComplexPart FoldOperation(FoldingContext &context, ComplexPart &&complexPart) {
190 DataRef complex{complexPart.complex()};
191 return ComplexPart{
192 FoldOperation(context, std::move(complex)), complexPart.part()};
193 }
194
GetInt64Arg(const std::optional<ActualArgument> & arg)195 std::optional<std::int64_t> GetInt64Arg(
196 const std::optional<ActualArgument> &arg) {
197 if (const auto *intExpr{UnwrapExpr<Expr<SomeInteger>>(arg)}) {
198 return ToInt64(*intExpr);
199 } else {
200 return std::nullopt;
201 }
202 }
203
GetInt64ArgOr(const std::optional<ActualArgument> & arg,std::int64_t defaultValue)204 std::optional<std::int64_t> GetInt64ArgOr(
205 const std::optional<ActualArgument> &arg, std::int64_t defaultValue) {
206 if (!arg) {
207 return defaultValue;
208 } else if (const auto *intExpr{UnwrapExpr<Expr<SomeInteger>>(arg)}) {
209 return ToInt64(*intExpr);
210 } else {
211 return std::nullopt;
212 }
213 }
214
FoldOperation(FoldingContext & context,ImpliedDoIndex && iDo)215 Expr<ImpliedDoIndex::Result> FoldOperation(
216 FoldingContext &context, ImpliedDoIndex &&iDo) {
217 if (std::optional<ConstantSubscript> value{context.GetImpliedDo(iDo.name)}) {
218 return Expr<ImpliedDoIndex::Result>{*value};
219 } else {
220 return Expr<ImpliedDoIndex::Result>{std::move(iDo)};
221 }
222 }
223
224 // TRANSFER (F'2018 16.9.193)
FoldTransfer(FoldingContext & context,const ActualArguments & arguments)225 std::optional<Expr<SomeType>> FoldTransfer(
226 FoldingContext &context, const ActualArguments &arguments) {
227 CHECK(arguments.size() == 2 || arguments.size() == 3);
228 const auto *source{UnwrapExpr<Expr<SomeType>>(arguments[0])};
229 std::optional<std::size_t> sourceBytes;
230 if (source) {
231 if (auto sourceTypeAndShape{
232 characteristics::TypeAndShape::Characterize(*source, context)}) {
233 if (auto sourceBytesExpr{
234 sourceTypeAndShape->MeasureSizeInBytes(context)}) {
235 sourceBytes = ToInt64(*sourceBytesExpr);
236 }
237 }
238 }
239 std::optional<DynamicType> moldType;
240 if (arguments[1]) {
241 moldType = arguments[1]->GetType();
242 }
243 std::optional<ConstantSubscripts> extents;
244 if (arguments.size() == 2) { // no SIZE=
245 if (moldType && sourceBytes) {
246 if (arguments[1]->Rank() == 0) { // scalar MOLD=
247 extents = ConstantSubscripts{}; // empty extents (scalar result)
248 } else if (auto moldBytesExpr{
249 moldType->MeasureSizeInBytes(context, true)}) {
250 if (auto moldBytes{ToInt64(Fold(context, std::move(*moldBytesExpr)))};
251 *moldBytes > 0) {
252 extents = ConstantSubscripts{
253 static_cast<ConstantSubscript>((*sourceBytes) + *moldBytes - 1) /
254 *moldBytes};
255 }
256 }
257 }
258 } else if (arguments[2]) { // SIZE= is present
259 if (const auto *sizeExpr{arguments[2]->UnwrapExpr()}) {
260 if (auto sizeValue{ToInt64(*sizeExpr)}) {
261 extents = ConstantSubscripts{*sizeValue};
262 }
263 }
264 }
265 if (sourceBytes && IsActuallyConstant(*source) && moldType && extents) {
266 InitialImage image{*sourceBytes};
267 InitialImage::Result imageResult{
268 image.Add(0, *sourceBytes, *source, context)};
269 CHECK(imageResult == InitialImage::Ok);
270 return image.AsConstant(context, *moldType, *extents, true /*pad with 0*/);
271 } else {
272 return std::nullopt;
273 }
274 }
275
// Explicit instantiations of ExpressionBase<SomeDerived> and
// ExpressionBase<SomeType>.
// NOTE(review): presumably placed in this translation unit so that other
// files can rely on these definitions without instantiating them -- confirm.
template class ExpressionBase<SomeDerived>;
template class ExpressionBase<SomeType>;
278
279 } // namespace Fortran::evaluate
280