1 //===-- lib/Evaluate/fold.cpp ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "flang/Evaluate/fold.h"
#include "fold-implementation.h"
#include "flang/Evaluate/characteristics.h"
#include "flang/Evaluate/initial-image.h"
#include <cstdint>
13 
14 namespace Fortran::evaluate {
15 
Fold(FoldingContext & context,characteristics::TypeAndShape && x)16 characteristics::TypeAndShape Fold(
17     FoldingContext &context, characteristics::TypeAndShape &&x) {
18   x.Rewrite(context);
19   return std::move(x);
20 }
21 
GetConstantSubscript(FoldingContext & context,Subscript & ss,const NamedEntity & base,int dim)22 std::optional<Constant<SubscriptInteger>> GetConstantSubscript(
23     FoldingContext &context, Subscript &ss, const NamedEntity &base, int dim) {
24   ss = FoldOperation(context, std::move(ss));
25   return common::visit(
26       common::visitors{
27           [](IndirectSubscriptIntegerExpr &expr)
28               -> std::optional<Constant<SubscriptInteger>> {
29             if (const auto *constant{
30                     UnwrapConstantValue<SubscriptInteger>(expr.value())}) {
31               return *constant;
32             } else {
33               return std::nullopt;
34             }
35           },
36           [&](Triplet &triplet) -> std::optional<Constant<SubscriptInteger>> {
37             auto lower{triplet.lower()}, upper{triplet.upper()};
38             std::optional<ConstantSubscript> stride{ToInt64(triplet.stride())};
39             if (!lower) {
40               lower = GetLBOUND(context, base, dim);
41             }
42             if (!upper) {
43               if (auto lb{GetLBOUND(context, base, dim)}) {
44                 upper = ComputeUpperBound(
45                     context, std::move(*lb), GetExtent(context, base, dim));
46               }
47             }
48             auto lbi{ToInt64(lower)}, ubi{ToInt64(upper)};
49             if (lbi && ubi && stride && *stride != 0) {
50               std::vector<SubscriptInteger::Scalar> values;
51               while ((*stride > 0 && *lbi <= *ubi) ||
52                   (*stride < 0 && *lbi >= *ubi)) {
53                 values.emplace_back(*lbi);
54                 *lbi += *stride;
55               }
56               return Constant<SubscriptInteger>{std::move(values),
57                   ConstantSubscripts{
58                       static_cast<ConstantSubscript>(values.size())}};
59             } else {
60               return std::nullopt;
61             }
62           },
63       },
64       ss.u);
65 }
66 
FoldOperation(FoldingContext & context,StructureConstructor && structure)67 Expr<SomeDerived> FoldOperation(
68     FoldingContext &context, StructureConstructor &&structure) {
69   StructureConstructor ctor{structure.derivedTypeSpec()};
70   bool isConstant{true};
71   auto restorer{context.WithPDTInstance(structure.derivedTypeSpec())};
72   for (auto &&[symbol, value] : std::move(structure)) {
73     auto expr{Fold(context, std::move(value.value()))};
74     if (IsPointer(symbol)) {
75       if (IsProcedure(symbol)) {
76         isConstant &= IsInitialProcedureTarget(expr);
77       } else {
78         isConstant &= IsInitialDataTarget(expr);
79       }
80     } else {
81       isConstant &= IsActuallyConstant(expr);
82       if (auto valueShape{GetConstantExtents(context, expr)}) {
83         if (auto componentShape{GetConstantExtents(context, symbol)}) {
84           if (GetRank(*componentShape) > 0 && GetRank(*valueShape) == 0) {
85             expr = ScalarConstantExpander{std::move(*componentShape)}.Expand(
86                 std::move(expr));
87             isConstant &= expr.Rank() > 0;
88           } else {
89             isConstant &= *valueShape == *componentShape;
90           }
91         }
92       }
93     }
94     ctor.Add(symbol, std::move(expr));
95   }
96   if (isConstant) {
97     return Expr<SomeDerived>{Constant<SomeDerived>{std::move(ctor)}};
98   } else {
99     return Expr<SomeDerived>{std::move(ctor)};
100   }
101 }
102 
FoldOperation(FoldingContext & context,Component && component)103 Component FoldOperation(FoldingContext &context, Component &&component) {
104   return {FoldOperation(context, std::move(component.base())),
105       component.GetLastSymbol()};
106 }
107 
FoldOperation(FoldingContext & context,NamedEntity && x)108 NamedEntity FoldOperation(FoldingContext &context, NamedEntity &&x) {
109   if (Component * c{x.UnwrapComponent()}) {
110     return NamedEntity{FoldOperation(context, std::move(*c))};
111   } else {
112     return std::move(x);
113   }
114 }
115 
FoldOperation(FoldingContext & context,Triplet && triplet)116 Triplet FoldOperation(FoldingContext &context, Triplet &&triplet) {
117   MaybeExtentExpr lower{triplet.lower()};
118   MaybeExtentExpr upper{triplet.upper()};
119   return {Fold(context, std::move(lower)), Fold(context, std::move(upper)),
120       Fold(context, triplet.stride())};
121 }
122 
FoldOperation(FoldingContext & context,Subscript && subscript)123 Subscript FoldOperation(FoldingContext &context, Subscript &&subscript) {
124   return common::visit(
125       common::visitors{
126           [&](IndirectSubscriptIntegerExpr &&expr) {
127             expr.value() = Fold(context, std::move(expr.value()));
128             return Subscript(std::move(expr));
129           },
130           [&](Triplet &&triplet) {
131             return Subscript(FoldOperation(context, std::move(triplet)));
132           },
133       },
134       std::move(subscript.u));
135 }
136 
FoldOperation(FoldingContext & context,ArrayRef && arrayRef)137 ArrayRef FoldOperation(FoldingContext &context, ArrayRef &&arrayRef) {
138   NamedEntity base{FoldOperation(context, std::move(arrayRef.base()))};
139   for (Subscript &subscript : arrayRef.subscript()) {
140     subscript = FoldOperation(context, std::move(subscript));
141   }
142   return ArrayRef{std::move(base), std::move(arrayRef.subscript())};
143 }
144 
FoldOperation(FoldingContext & context,CoarrayRef && coarrayRef)145 CoarrayRef FoldOperation(FoldingContext &context, CoarrayRef &&coarrayRef) {
146   std::vector<Subscript> subscript;
147   for (Subscript x : coarrayRef.subscript()) {
148     subscript.emplace_back(FoldOperation(context, std::move(x)));
149   }
150   std::vector<Expr<SubscriptInteger>> cosubscript;
151   for (Expr<SubscriptInteger> x : coarrayRef.cosubscript()) {
152     cosubscript.emplace_back(Fold(context, std::move(x)));
153   }
154   CoarrayRef folded{std::move(coarrayRef.base()), std::move(subscript),
155       std::move(cosubscript)};
156   if (std::optional<Expr<SomeInteger>> stat{coarrayRef.stat()}) {
157     folded.set_stat(Fold(context, std::move(*stat)));
158   }
159   if (std::optional<Expr<SomeInteger>> team{coarrayRef.team()}) {
160     folded.set_team(
161         Fold(context, std::move(*team)), coarrayRef.teamIsTeamNumber());
162   }
163   return folded;
164 }
165 
FoldOperation(FoldingContext & context,DataRef && dataRef)166 DataRef FoldOperation(FoldingContext &context, DataRef &&dataRef) {
167   return common::visit(common::visitors{
168                            [&](SymbolRef symbol) { return DataRef{*symbol}; },
169                            [&](auto &&x) {
170                              return DataRef{
171                                  FoldOperation(context, std::move(x))};
172                            },
173                        },
174       std::move(dataRef.u));
175 }
176 
FoldOperation(FoldingContext & context,Substring && substring)177 Substring FoldOperation(FoldingContext &context, Substring &&substring) {
178   auto lower{Fold(context, substring.lower())};
179   auto upper{Fold(context, substring.upper())};
180   if (const DataRef * dataRef{substring.GetParentIf<DataRef>()}) {
181     return Substring{FoldOperation(context, DataRef{*dataRef}),
182         std::move(lower), std::move(upper)};
183   } else {
184     auto p{*substring.GetParentIf<StaticDataObject::Pointer>()};
185     return Substring{std::move(p), std::move(lower), std::move(upper)};
186   }
187 }
188 
FoldOperation(FoldingContext & context,ComplexPart && complexPart)189 ComplexPart FoldOperation(FoldingContext &context, ComplexPart &&complexPart) {
190   DataRef complex{complexPart.complex()};
191   return ComplexPart{
192       FoldOperation(context, std::move(complex)), complexPart.part()};
193 }
194 
GetInt64Arg(const std::optional<ActualArgument> & arg)195 std::optional<std::int64_t> GetInt64Arg(
196     const std::optional<ActualArgument> &arg) {
197   if (const auto *intExpr{UnwrapExpr<Expr<SomeInteger>>(arg)}) {
198     return ToInt64(*intExpr);
199   } else {
200     return std::nullopt;
201   }
202 }
203 
GetInt64ArgOr(const std::optional<ActualArgument> & arg,std::int64_t defaultValue)204 std::optional<std::int64_t> GetInt64ArgOr(
205     const std::optional<ActualArgument> &arg, std::int64_t defaultValue) {
206   if (!arg) {
207     return defaultValue;
208   } else if (const auto *intExpr{UnwrapExpr<Expr<SomeInteger>>(arg)}) {
209     return ToInt64(*intExpr);
210   } else {
211     return std::nullopt;
212   }
213 }
214 
FoldOperation(FoldingContext & context,ImpliedDoIndex && iDo)215 Expr<ImpliedDoIndex::Result> FoldOperation(
216     FoldingContext &context, ImpliedDoIndex &&iDo) {
217   if (std::optional<ConstantSubscript> value{context.GetImpliedDo(iDo.name)}) {
218     return Expr<ImpliedDoIndex::Result>{*value};
219   } else {
220     return Expr<ImpliedDoIndex::Result>{std::move(iDo)};
221   }
222 }
223 
224 // TRANSFER (F'2018 16.9.193)
FoldTransfer(FoldingContext & context,const ActualArguments & arguments)225 std::optional<Expr<SomeType>> FoldTransfer(
226     FoldingContext &context, const ActualArguments &arguments) {
227   CHECK(arguments.size() == 2 || arguments.size() == 3);
228   const auto *source{UnwrapExpr<Expr<SomeType>>(arguments[0])};
229   std::optional<std::size_t> sourceBytes;
230   if (source) {
231     if (auto sourceTypeAndShape{
232             characteristics::TypeAndShape::Characterize(*source, context)}) {
233       if (auto sourceBytesExpr{
234               sourceTypeAndShape->MeasureSizeInBytes(context)}) {
235         sourceBytes = ToInt64(*sourceBytesExpr);
236       }
237     }
238   }
239   std::optional<DynamicType> moldType;
240   if (arguments[1]) {
241     moldType = arguments[1]->GetType();
242   }
243   std::optional<ConstantSubscripts> extents;
244   if (arguments.size() == 2) { // no SIZE=
245     if (moldType && sourceBytes) {
246       if (arguments[1]->Rank() == 0) { // scalar MOLD=
247         extents = ConstantSubscripts{}; // empty extents (scalar result)
248       } else if (auto moldBytesExpr{
249                      moldType->MeasureSizeInBytes(context, true)}) {
250         if (auto moldBytes{ToInt64(Fold(context, std::move(*moldBytesExpr)))};
251             *moldBytes > 0) {
252           extents = ConstantSubscripts{
253               static_cast<ConstantSubscript>((*sourceBytes) + *moldBytes - 1) /
254               *moldBytes};
255         }
256       }
257     }
258   } else if (arguments[2]) { // SIZE= is present
259     if (const auto *sizeExpr{arguments[2]->UnwrapExpr()}) {
260       if (auto sizeValue{ToInt64(*sizeExpr)}) {
261         extents = ConstantSubscripts{*sizeValue};
262       }
263     }
264   }
265   if (sourceBytes && IsActuallyConstant(*source) && moldType && extents) {
266     InitialImage image{*sourceBytes};
267     InitialImage::Result imageResult{
268         image.Add(0, *sourceBytes, *source, context)};
269     CHECK(imageResult == InitialImage::Ok);
270     return image.AsConstant(context, *moldType, *extents, true /*pad with 0*/);
271   } else {
272     return std::nullopt;
273   }
274 }
275 
// Explicit instantiation definitions of ExpressionBase for the SomeDerived
// and SomeType expression categories, emitted in this translation unit.
template class ExpressionBase<SomeDerived>;
template class ExpressionBase<SomeType>;
278 
279 } // namespace Fortran::evaluate
280