1 //===-- IterationSpace.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Lower/IterationSpace.h"
14 #include "flang/Evaluate/expression.h"
15 #include "flang/Lower/AbstractConverter.h"
16 #include "flang/Lower/Support/Utils.h"
17 #include "llvm/Support/Debug.h"
18 
19 #define DEBUG_TYPE "flang-lower-iteration-space"
20 
21 namespace {
22 // Fortran::evaluate::Expr are functional values organized like an AST. A
23 // Fortran::evaluate::Expr is meant to be moved and cloned. Using the front end
24 // tools can often cause copies and extra wrapper classes to be added to any
// Fortran::evaluate::Expr. These values should not be assumed or relied upon to
26 // have an *object* identity. They are deeply recursive, irregular structures
27 // built from a large number of classes which do not use inheritance and
28 // necessitate a large volume of boilerplate code as a result.
29 //
30 // Contrastingly, LLVM data structures make ubiquitous assumptions about an
31 // object's identity via pointers to the object. An object's location in memory
32 // is thus very often an identifying relation.
33 
34 // This class defines a hash computation of a Fortran::evaluate::Expr tree value
35 // so it can be used with llvm::DenseMap. The Fortran::evaluate::Expr need not
36 // have the same address.
37 class HashEvaluateExpr {
38 public:
39   // A Se::Symbol is the only part of an Fortran::evaluate::Expr with an
40   // identity property.
getHashValue(const Fortran::semantics::Symbol & x)41   static unsigned getHashValue(const Fortran::semantics::Symbol &x) {
42     return static_cast<unsigned>(reinterpret_cast<std::intptr_t>(&x));
43   }
44   template <typename A, bool COPY>
getHashValue(const Fortran::common::Indirection<A,COPY> & x)45   static unsigned getHashValue(const Fortran::common::Indirection<A, COPY> &x) {
46     return getHashValue(x.value());
47   }
48   template <typename A>
getHashValue(const std::optional<A> & x)49   static unsigned getHashValue(const std::optional<A> &x) {
50     if (x.has_value())
51       return getHashValue(x.value());
52     return 0u;
53   }
getHashValue(const Fortran::evaluate::Subscript & x)54   static unsigned getHashValue(const Fortran::evaluate::Subscript &x) {
55     return std::visit([&](const auto &v) { return getHashValue(v); }, x.u);
56   }
getHashValue(const Fortran::evaluate::Triplet & x)57   static unsigned getHashValue(const Fortran::evaluate::Triplet &x) {
58     return getHashValue(x.lower()) - getHashValue(x.upper()) * 5u -
59            getHashValue(x.stride()) * 11u;
60   }
getHashValue(const Fortran::evaluate::Component & x)61   static unsigned getHashValue(const Fortran::evaluate::Component &x) {
62     return getHashValue(x.base()) * 83u - getHashValue(x.GetLastSymbol());
63   }
getHashValue(const Fortran::evaluate::ArrayRef & x)64   static unsigned getHashValue(const Fortran::evaluate::ArrayRef &x) {
65     unsigned subs = 1u;
66     for (const Fortran::evaluate::Subscript &v : x.subscript())
67       subs -= getHashValue(v);
68     return getHashValue(x.base()) * 89u - subs;
69   }
getHashValue(const Fortran::evaluate::CoarrayRef & x)70   static unsigned getHashValue(const Fortran::evaluate::CoarrayRef &x) {
71     unsigned subs = 1u;
72     for (const Fortran::evaluate::Subscript &v : x.subscript())
73       subs -= getHashValue(v);
74     unsigned cosubs = 3u;
75     for (const Fortran::evaluate::Expr<Fortran::evaluate::SubscriptInteger> &v :
76          x.cosubscript())
77       cosubs -= getHashValue(v);
78     unsigned syms = 7u;
79     for (const Fortran::evaluate::SymbolRef &v : x.base())
80       syms += getHashValue(v);
81     return syms * 97u - subs - cosubs + getHashValue(x.stat()) + 257u +
82            getHashValue(x.team());
83   }
getHashValue(const Fortran::evaluate::NamedEntity & x)84   static unsigned getHashValue(const Fortran::evaluate::NamedEntity &x) {
85     if (x.IsSymbol())
86       return getHashValue(x.GetFirstSymbol()) * 11u;
87     return getHashValue(x.GetComponent()) * 13u;
88   }
getHashValue(const Fortran::evaluate::DataRef & x)89   static unsigned getHashValue(const Fortran::evaluate::DataRef &x) {
90     return std::visit([&](const auto &v) { return getHashValue(v); }, x.u);
91   }
getHashValue(const Fortran::evaluate::ComplexPart & x)92   static unsigned getHashValue(const Fortran::evaluate::ComplexPart &x) {
93     return getHashValue(x.complex()) - static_cast<unsigned>(x.part());
94   }
95   template <Fortran::common::TypeCategory TC1, int KIND,
96             Fortran::common::TypeCategory TC2>
getHashValue(const Fortran::evaluate::Convert<Fortran::evaluate::Type<TC1,KIND>,TC2> & x)97   static unsigned getHashValue(
98       const Fortran::evaluate::Convert<Fortran::evaluate::Type<TC1, KIND>, TC2>
99           &x) {
100     return getHashValue(x.left()) - (static_cast<unsigned>(TC1) + 2u) -
101            (static_cast<unsigned>(KIND) + 5u);
102   }
103   template <int KIND>
104   static unsigned
getHashValue(const Fortran::evaluate::ComplexComponent<KIND> & x)105   getHashValue(const Fortran::evaluate::ComplexComponent<KIND> &x) {
106     return getHashValue(x.left()) -
107            (static_cast<unsigned>(x.isImaginaryPart) + 1u) * 3u;
108   }
109   template <typename T>
getHashValue(const Fortran::evaluate::Parentheses<T> & x)110   static unsigned getHashValue(const Fortran::evaluate::Parentheses<T> &x) {
111     return getHashValue(x.left()) * 17u;
112   }
113   template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Negate<Fortran::evaluate::Type<TC,KIND>> & x)114   static unsigned getHashValue(
115       const Fortran::evaluate::Negate<Fortran::evaluate::Type<TC, KIND>> &x) {
116     return getHashValue(x.left()) - (static_cast<unsigned>(TC) + 5u) -
117            (static_cast<unsigned>(KIND) + 7u);
118   }
119   template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Add<Fortran::evaluate::Type<TC,KIND>> & x)120   static unsigned getHashValue(
121       const Fortran::evaluate::Add<Fortran::evaluate::Type<TC, KIND>> &x) {
122     return (getHashValue(x.left()) + getHashValue(x.right())) * 23u +
123            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
124   }
125   template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Subtract<Fortran::evaluate::Type<TC,KIND>> & x)126   static unsigned getHashValue(
127       const Fortran::evaluate::Subtract<Fortran::evaluate::Type<TC, KIND>> &x) {
128     return (getHashValue(x.left()) - getHashValue(x.right())) * 19u +
129            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
130   }
131   template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Multiply<Fortran::evaluate::Type<TC,KIND>> & x)132   static unsigned getHashValue(
133       const Fortran::evaluate::Multiply<Fortran::evaluate::Type<TC, KIND>> &x) {
134     return (getHashValue(x.left()) + getHashValue(x.right())) * 29u +
135            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
136   }
137   template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Divide<Fortran::evaluate::Type<TC,KIND>> & x)138   static unsigned getHashValue(
139       const Fortran::evaluate::Divide<Fortran::evaluate::Type<TC, KIND>> &x) {
140     return (getHashValue(x.left()) - getHashValue(x.right())) * 31u +
141            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
142   }
143   template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Power<Fortran::evaluate::Type<TC,KIND>> & x)144   static unsigned getHashValue(
145       const Fortran::evaluate::Power<Fortran::evaluate::Type<TC, KIND>> &x) {
146     return (getHashValue(x.left()) - getHashValue(x.right())) * 37u +
147            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
148   }
149   template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Extremum<Fortran::evaluate::Type<TC,KIND>> & x)150   static unsigned getHashValue(
151       const Fortran::evaluate::Extremum<Fortran::evaluate::Type<TC, KIND>> &x) {
152     return (getHashValue(x.left()) + getHashValue(x.right())) * 41u +
153            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND) +
154            static_cast<unsigned>(x.ordering) * 7u;
155   }
156   template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::RealToIntPower<Fortran::evaluate::Type<TC,KIND>> & x)157   static unsigned getHashValue(
158       const Fortran::evaluate::RealToIntPower<Fortran::evaluate::Type<TC, KIND>>
159           &x) {
160     return (getHashValue(x.left()) - getHashValue(x.right())) * 43u +
161            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
162   }
163   template <int KIND>
164   static unsigned
getHashValue(const Fortran::evaluate::ComplexConstructor<KIND> & x)165   getHashValue(const Fortran::evaluate::ComplexConstructor<KIND> &x) {
166     return (getHashValue(x.left()) - getHashValue(x.right())) * 47u +
167            static_cast<unsigned>(KIND);
168   }
169   template <int KIND>
getHashValue(const Fortran::evaluate::Concat<KIND> & x)170   static unsigned getHashValue(const Fortran::evaluate::Concat<KIND> &x) {
171     return (getHashValue(x.left()) - getHashValue(x.right())) * 53u +
172            static_cast<unsigned>(KIND);
173   }
174   template <int KIND>
getHashValue(const Fortran::evaluate::SetLength<KIND> & x)175   static unsigned getHashValue(const Fortran::evaluate::SetLength<KIND> &x) {
176     return (getHashValue(x.left()) - getHashValue(x.right())) * 59u +
177            static_cast<unsigned>(KIND);
178   }
getHashValue(const Fortran::semantics::SymbolRef & sym)179   static unsigned getHashValue(const Fortran::semantics::SymbolRef &sym) {
180     return getHashValue(sym.get());
181   }
getHashValue(const Fortran::evaluate::Substring & x)182   static unsigned getHashValue(const Fortran::evaluate::Substring &x) {
183     return 61u * std::visit([&](const auto &p) { return getHashValue(p); },
184                             x.parent()) -
185            getHashValue(x.lower()) - (getHashValue(x.lower()) + 1u);
186   }
187   static unsigned
getHashValue(const Fortran::evaluate::StaticDataObject::Pointer & x)188   getHashValue(const Fortran::evaluate::StaticDataObject::Pointer &x) {
189     return llvm::hash_value(x->name());
190   }
getHashValue(const Fortran::evaluate::SpecificIntrinsic & x)191   static unsigned getHashValue(const Fortran::evaluate::SpecificIntrinsic &x) {
192     return llvm::hash_value(x.name);
193   }
194   template <typename A>
getHashValue(const Fortran::evaluate::Constant<A> & x)195   static unsigned getHashValue(const Fortran::evaluate::Constant<A> &x) {
196     // FIXME: Should hash the content.
197     return 103u;
198   }
getHashValue(const Fortran::evaluate::ActualArgument & x)199   static unsigned getHashValue(const Fortran::evaluate::ActualArgument &x) {
200     if (const Fortran::evaluate::Symbol *sym = x.GetAssumedTypeDummy())
201       return getHashValue(*sym);
202     return getHashValue(*x.UnwrapExpr());
203   }
204   static unsigned
getHashValue(const Fortran::evaluate::ProcedureDesignator & x)205   getHashValue(const Fortran::evaluate::ProcedureDesignator &x) {
206     return std::visit([&](const auto &v) { return getHashValue(v); }, x.u);
207   }
getHashValue(const Fortran::evaluate::ProcedureRef & x)208   static unsigned getHashValue(const Fortran::evaluate::ProcedureRef &x) {
209     unsigned args = 13u;
210     for (const std::optional<Fortran::evaluate::ActualArgument> &v :
211          x.arguments())
212       args -= getHashValue(v);
213     return getHashValue(x.proc()) * 101u - args;
214   }
215   template <typename A>
216   static unsigned
getHashValue(const Fortran::evaluate::ArrayConstructor<A> & x)217   getHashValue(const Fortran::evaluate::ArrayConstructor<A> &x) {
218     // FIXME: hash the contents.
219     return 127u;
220   }
getHashValue(const Fortran::evaluate::ImpliedDoIndex & x)221   static unsigned getHashValue(const Fortran::evaluate::ImpliedDoIndex &x) {
222     return llvm::hash_value(toStringRef(x.name).str()) * 131u;
223   }
getHashValue(const Fortran::evaluate::TypeParamInquiry & x)224   static unsigned getHashValue(const Fortran::evaluate::TypeParamInquiry &x) {
225     return getHashValue(x.base()) * 137u - getHashValue(x.parameter()) * 3u;
226   }
getHashValue(const Fortran::evaluate::DescriptorInquiry & x)227   static unsigned getHashValue(const Fortran::evaluate::DescriptorInquiry &x) {
228     return getHashValue(x.base()) * 139u -
229            static_cast<unsigned>(x.field()) * 13u +
230            static_cast<unsigned>(x.dimension());
231   }
232   static unsigned
getHashValue(const Fortran::evaluate::StructureConstructor & x)233   getHashValue(const Fortran::evaluate::StructureConstructor &x) {
234     // FIXME: hash the contents.
235     return 149u;
236   }
237   template <int KIND>
getHashValue(const Fortran::evaluate::Not<KIND> & x)238   static unsigned getHashValue(const Fortran::evaluate::Not<KIND> &x) {
239     return getHashValue(x.left()) * 61u + static_cast<unsigned>(KIND);
240   }
241   template <int KIND>
242   static unsigned
getHashValue(const Fortran::evaluate::LogicalOperation<KIND> & x)243   getHashValue(const Fortran::evaluate::LogicalOperation<KIND> &x) {
244     unsigned result = getHashValue(x.left()) + getHashValue(x.right());
245     return result * 67u + static_cast<unsigned>(x.logicalOperator) * 5u;
246   }
247   template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Relational<Fortran::evaluate::Type<TC,KIND>> & x)248   static unsigned getHashValue(
249       const Fortran::evaluate::Relational<Fortran::evaluate::Type<TC, KIND>>
250           &x) {
251     return (getHashValue(x.left()) + getHashValue(x.right())) * 71u +
252            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND) +
253            static_cast<unsigned>(x.opr) * 11u;
254   }
255   template <typename A>
getHashValue(const Fortran::evaluate::Expr<A> & x)256   static unsigned getHashValue(const Fortran::evaluate::Expr<A> &x) {
257     return std::visit([&](const auto &v) { return getHashValue(v); }, x.u);
258   }
getHashValue(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> & x)259   static unsigned getHashValue(
260       const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x) {
261     return std::visit([&](const auto &v) { return getHashValue(v); }, x.u);
262   }
263   template <typename A>
getHashValue(const Fortran::evaluate::Designator<A> & x)264   static unsigned getHashValue(const Fortran::evaluate::Designator<A> &x) {
265     return std::visit([&](const auto &v) { return getHashValue(v); }, x.u);
266   }
267   template <int BITS>
268   static unsigned
getHashValue(const Fortran::evaluate::value::Integer<BITS> & x)269   getHashValue(const Fortran::evaluate::value::Integer<BITS> &x) {
270     return static_cast<unsigned>(x.ToSInt());
271   }
getHashValue(const Fortran::evaluate::NullPointer & x)272   static unsigned getHashValue(const Fortran::evaluate::NullPointer &x) {
273     return ~179u;
274   }
275 };
276 } // namespace
277 
getHashValue(const Fortran::lower::ExplicitIterSpace::ArrayBases & x)278 unsigned Fortran::lower::getHashValue(
279     const Fortran::lower::ExplicitIterSpace::ArrayBases &x) {
280   return std::visit(
281       [&](const auto *p) { return HashEvaluateExpr::getHashValue(*p); }, x);
282 }
283 
getHashValue(Fortran::lower::FrontEndExpr x)284 unsigned Fortran::lower::getHashValue(Fortran::lower::FrontEndExpr x) {
285   return HashEvaluateExpr::getHashValue(*x);
286 }
287 
288 namespace {
289 // Define the is equals test for using Fortran::evaluate::Expr values with
290 // llvm::DenseMap.
291 class IsEqualEvaluateExpr {
292 public:
293   // A Se::Symbol is the only part of an Fortran::evaluate::Expr with an
294   // identity property.
isEqual(const Fortran::semantics::Symbol & x,const Fortran::semantics::Symbol & y)295   static bool isEqual(const Fortran::semantics::Symbol &x,
296                       const Fortran::semantics::Symbol &y) {
297     return isEqual(&x, &y);
298   }
isEqual(const Fortran::semantics::Symbol * x,const Fortran::semantics::Symbol * y)299   static bool isEqual(const Fortran::semantics::Symbol *x,
300                       const Fortran::semantics::Symbol *y) {
301     return x == y;
302   }
303   template <typename A, bool COPY>
isEqual(const Fortran::common::Indirection<A,COPY> & x,const Fortran::common::Indirection<A,COPY> & y)304   static bool isEqual(const Fortran::common::Indirection<A, COPY> &x,
305                       const Fortran::common::Indirection<A, COPY> &y) {
306     return isEqual(x.value(), y.value());
307   }
308   template <typename A>
isEqual(const std::optional<A> & x,const std::optional<A> & y)309   static bool isEqual(const std::optional<A> &x, const std::optional<A> &y) {
310     if (x.has_value() && y.has_value())
311       return isEqual(x.value(), y.value());
312     return !x.has_value() && !y.has_value();
313   }
314   template <typename A>
isEqual(const std::vector<A> & x,const std::vector<A> & y)315   static bool isEqual(const std::vector<A> &x, const std::vector<A> &y) {
316     if (x.size() != y.size())
317       return false;
318     const std::size_t size = x.size();
319     for (std::remove_const_t<decltype(size)> i = 0; i < size; ++i)
320       if (!isEqual(x[i], y[i]))
321         return false;
322     return true;
323   }
isEqual(const Fortran::evaluate::Subscript & x,const Fortran::evaluate::Subscript & y)324   static bool isEqual(const Fortran::evaluate::Subscript &x,
325                       const Fortran::evaluate::Subscript &y) {
326     return std::visit(
327         [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
328   }
isEqual(const Fortran::evaluate::Triplet & x,const Fortran::evaluate::Triplet & y)329   static bool isEqual(const Fortran::evaluate::Triplet &x,
330                       const Fortran::evaluate::Triplet &y) {
331     return isEqual(x.lower(), y.lower()) && isEqual(x.upper(), y.upper()) &&
332            isEqual(x.stride(), y.stride());
333   }
isEqual(const Fortran::evaluate::Component & x,const Fortran::evaluate::Component & y)334   static bool isEqual(const Fortran::evaluate::Component &x,
335                       const Fortran::evaluate::Component &y) {
336     return isEqual(x.base(), y.base()) &&
337            isEqual(x.GetLastSymbol(), y.GetLastSymbol());
338   }
isEqual(const Fortran::evaluate::ArrayRef & x,const Fortran::evaluate::ArrayRef & y)339   static bool isEqual(const Fortran::evaluate::ArrayRef &x,
340                       const Fortran::evaluate::ArrayRef &y) {
341     return isEqual(x.base(), y.base()) && isEqual(x.subscript(), y.subscript());
342   }
isEqual(const Fortran::evaluate::CoarrayRef & x,const Fortran::evaluate::CoarrayRef & y)343   static bool isEqual(const Fortran::evaluate::CoarrayRef &x,
344                       const Fortran::evaluate::CoarrayRef &y) {
345     return isEqual(x.base(), y.base()) &&
346            isEqual(x.subscript(), y.subscript()) &&
347            isEqual(x.cosubscript(), y.cosubscript()) &&
348            isEqual(x.stat(), y.stat()) && isEqual(x.team(), y.team());
349   }
isEqual(const Fortran::evaluate::NamedEntity & x,const Fortran::evaluate::NamedEntity & y)350   static bool isEqual(const Fortran::evaluate::NamedEntity &x,
351                       const Fortran::evaluate::NamedEntity &y) {
352     if (x.IsSymbol() && y.IsSymbol())
353       return isEqual(x.GetFirstSymbol(), y.GetFirstSymbol());
354     return !x.IsSymbol() && !y.IsSymbol() &&
355            isEqual(x.GetComponent(), y.GetComponent());
356   }
isEqual(const Fortran::evaluate::DataRef & x,const Fortran::evaluate::DataRef & y)357   static bool isEqual(const Fortran::evaluate::DataRef &x,
358                       const Fortran::evaluate::DataRef &y) {
359     return std::visit(
360         [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
361   }
isEqual(const Fortran::evaluate::ComplexPart & x,const Fortran::evaluate::ComplexPart & y)362   static bool isEqual(const Fortran::evaluate::ComplexPart &x,
363                       const Fortran::evaluate::ComplexPart &y) {
364     return isEqual(x.complex(), y.complex()) && x.part() == y.part();
365   }
366   template <typename A, Fortran::common::TypeCategory TC2>
isEqual(const Fortran::evaluate::Convert<A,TC2> & x,const Fortran::evaluate::Convert<A,TC2> & y)367   static bool isEqual(const Fortran::evaluate::Convert<A, TC2> &x,
368                       const Fortran::evaluate::Convert<A, TC2> &y) {
369     return isEqual(x.left(), y.left());
370   }
371   template <int KIND>
isEqual(const Fortran::evaluate::ComplexComponent<KIND> & x,const Fortran::evaluate::ComplexComponent<KIND> & y)372   static bool isEqual(const Fortran::evaluate::ComplexComponent<KIND> &x,
373                       const Fortran::evaluate::ComplexComponent<KIND> &y) {
374     return isEqual(x.left(), y.left()) &&
375            x.isImaginaryPart == y.isImaginaryPart;
376   }
377   template <typename T>
isEqual(const Fortran::evaluate::Parentheses<T> & x,const Fortran::evaluate::Parentheses<T> & y)378   static bool isEqual(const Fortran::evaluate::Parentheses<T> &x,
379                       const Fortran::evaluate::Parentheses<T> &y) {
380     return isEqual(x.left(), y.left());
381   }
382   template <typename A>
isEqual(const Fortran::evaluate::Negate<A> & x,const Fortran::evaluate::Negate<A> & y)383   static bool isEqual(const Fortran::evaluate::Negate<A> &x,
384                       const Fortran::evaluate::Negate<A> &y) {
385     return isEqual(x.left(), y.left());
386   }
387   template <typename A>
isBinaryEqual(const A & x,const A & y)388   static bool isBinaryEqual(const A &x, const A &y) {
389     return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right());
390   }
391   template <typename A>
isEqual(const Fortran::evaluate::Add<A> & x,const Fortran::evaluate::Add<A> & y)392   static bool isEqual(const Fortran::evaluate::Add<A> &x,
393                       const Fortran::evaluate::Add<A> &y) {
394     return isBinaryEqual(x, y);
395   }
396   template <typename A>
isEqual(const Fortran::evaluate::Subtract<A> & x,const Fortran::evaluate::Subtract<A> & y)397   static bool isEqual(const Fortran::evaluate::Subtract<A> &x,
398                       const Fortran::evaluate::Subtract<A> &y) {
399     return isBinaryEqual(x, y);
400   }
401   template <typename A>
isEqual(const Fortran::evaluate::Multiply<A> & x,const Fortran::evaluate::Multiply<A> & y)402   static bool isEqual(const Fortran::evaluate::Multiply<A> &x,
403                       const Fortran::evaluate::Multiply<A> &y) {
404     return isBinaryEqual(x, y);
405   }
406   template <typename A>
isEqual(const Fortran::evaluate::Divide<A> & x,const Fortran::evaluate::Divide<A> & y)407   static bool isEqual(const Fortran::evaluate::Divide<A> &x,
408                       const Fortran::evaluate::Divide<A> &y) {
409     return isBinaryEqual(x, y);
410   }
411   template <typename A>
isEqual(const Fortran::evaluate::Power<A> & x,const Fortran::evaluate::Power<A> & y)412   static bool isEqual(const Fortran::evaluate::Power<A> &x,
413                       const Fortran::evaluate::Power<A> &y) {
414     return isBinaryEqual(x, y);
415   }
416   template <typename A>
isEqual(const Fortran::evaluate::Extremum<A> & x,const Fortran::evaluate::Extremum<A> & y)417   static bool isEqual(const Fortran::evaluate::Extremum<A> &x,
418                       const Fortran::evaluate::Extremum<A> &y) {
419     return isBinaryEqual(x, y);
420   }
421   template <typename A>
isEqual(const Fortran::evaluate::RealToIntPower<A> & x,const Fortran::evaluate::RealToIntPower<A> & y)422   static bool isEqual(const Fortran::evaluate::RealToIntPower<A> &x,
423                       const Fortran::evaluate::RealToIntPower<A> &y) {
424     return isBinaryEqual(x, y);
425   }
426   template <int KIND>
isEqual(const Fortran::evaluate::ComplexConstructor<KIND> & x,const Fortran::evaluate::ComplexConstructor<KIND> & y)427   static bool isEqual(const Fortran::evaluate::ComplexConstructor<KIND> &x,
428                       const Fortran::evaluate::ComplexConstructor<KIND> &y) {
429     return isBinaryEqual(x, y);
430   }
431   template <int KIND>
isEqual(const Fortran::evaluate::Concat<KIND> & x,const Fortran::evaluate::Concat<KIND> & y)432   static bool isEqual(const Fortran::evaluate::Concat<KIND> &x,
433                       const Fortran::evaluate::Concat<KIND> &y) {
434     return isBinaryEqual(x, y);
435   }
436   template <int KIND>
isEqual(const Fortran::evaluate::SetLength<KIND> & x,const Fortran::evaluate::SetLength<KIND> & y)437   static bool isEqual(const Fortran::evaluate::SetLength<KIND> &x,
438                       const Fortran::evaluate::SetLength<KIND> &y) {
439     return isBinaryEqual(x, y);
440   }
isEqual(const Fortran::semantics::SymbolRef & x,const Fortran::semantics::SymbolRef & y)441   static bool isEqual(const Fortran::semantics::SymbolRef &x,
442                       const Fortran::semantics::SymbolRef &y) {
443     return isEqual(x.get(), y.get());
444   }
isEqual(const Fortran::evaluate::Substring & x,const Fortran::evaluate::Substring & y)445   static bool isEqual(const Fortran::evaluate::Substring &x,
446                       const Fortran::evaluate::Substring &y) {
447     return std::visit(
448                [&](const auto &p, const auto &q) { return isEqual(p, q); },
449                x.parent(), y.parent()) &&
450            isEqual(x.lower(), y.lower()) && isEqual(x.lower(), y.lower());
451   }
isEqual(const Fortran::evaluate::StaticDataObject::Pointer & x,const Fortran::evaluate::StaticDataObject::Pointer & y)452   static bool isEqual(const Fortran::evaluate::StaticDataObject::Pointer &x,
453                       const Fortran::evaluate::StaticDataObject::Pointer &y) {
454     return x->name() == y->name();
455   }
isEqual(const Fortran::evaluate::SpecificIntrinsic & x,const Fortran::evaluate::SpecificIntrinsic & y)456   static bool isEqual(const Fortran::evaluate::SpecificIntrinsic &x,
457                       const Fortran::evaluate::SpecificIntrinsic &y) {
458     return x.name == y.name;
459   }
460   template <typename A>
isEqual(const Fortran::evaluate::Constant<A> & x,const Fortran::evaluate::Constant<A> & y)461   static bool isEqual(const Fortran::evaluate::Constant<A> &x,
462                       const Fortran::evaluate::Constant<A> &y) {
463     return x == y;
464   }
isEqual(const Fortran::evaluate::ActualArgument & x,const Fortran::evaluate::ActualArgument & y)465   static bool isEqual(const Fortran::evaluate::ActualArgument &x,
466                       const Fortran::evaluate::ActualArgument &y) {
467     if (const Fortran::evaluate::Symbol *xs = x.GetAssumedTypeDummy()) {
468       if (const Fortran::evaluate::Symbol *ys = y.GetAssumedTypeDummy())
469         return isEqual(*xs, *ys);
470       return false;
471     }
472     return !y.GetAssumedTypeDummy() &&
473            isEqual(*x.UnwrapExpr(), *y.UnwrapExpr());
474   }
isEqual(const Fortran::evaluate::ProcedureDesignator & x,const Fortran::evaluate::ProcedureDesignator & y)475   static bool isEqual(const Fortran::evaluate::ProcedureDesignator &x,
476                       const Fortran::evaluate::ProcedureDesignator &y) {
477     return std::visit(
478         [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
479   }
isEqual(const Fortran::evaluate::ProcedureRef & x,const Fortran::evaluate::ProcedureRef & y)480   static bool isEqual(const Fortran::evaluate::ProcedureRef &x,
481                       const Fortran::evaluate::ProcedureRef &y) {
482     return isEqual(x.proc(), y.proc()) && isEqual(x.arguments(), y.arguments());
483   }
484   template <typename A>
isEqual(const Fortran::evaluate::ArrayConstructor<A> & x,const Fortran::evaluate::ArrayConstructor<A> & y)485   static bool isEqual(const Fortran::evaluate::ArrayConstructor<A> &x,
486                       const Fortran::evaluate::ArrayConstructor<A> &y) {
487     llvm::report_fatal_error("not implemented");
488   }
isEqual(const Fortran::evaluate::ImpliedDoIndex & x,const Fortran::evaluate::ImpliedDoIndex & y)489   static bool isEqual(const Fortran::evaluate::ImpliedDoIndex &x,
490                       const Fortran::evaluate::ImpliedDoIndex &y) {
491     return toStringRef(x.name) == toStringRef(y.name);
492   }
isEqual(const Fortran::evaluate::TypeParamInquiry & x,const Fortran::evaluate::TypeParamInquiry & y)493   static bool isEqual(const Fortran::evaluate::TypeParamInquiry &x,
494                       const Fortran::evaluate::TypeParamInquiry &y) {
495     return isEqual(x.base(), y.base()) && isEqual(x.parameter(), y.parameter());
496   }
isEqual(const Fortran::evaluate::DescriptorInquiry & x,const Fortran::evaluate::DescriptorInquiry & y)497   static bool isEqual(const Fortran::evaluate::DescriptorInquiry &x,
498                       const Fortran::evaluate::DescriptorInquiry &y) {
499     return isEqual(x.base(), y.base()) && x.field() == y.field() &&
500            x.dimension() == y.dimension();
501   }
isEqual(const Fortran::evaluate::StructureConstructor & x,const Fortran::evaluate::StructureConstructor & y)502   static bool isEqual(const Fortran::evaluate::StructureConstructor &x,
503                       const Fortran::evaluate::StructureConstructor &y) {
504     llvm::report_fatal_error("not implemented");
505   }
506   template <int KIND>
isEqual(const Fortran::evaluate::Not<KIND> & x,const Fortran::evaluate::Not<KIND> & y)507   static bool isEqual(const Fortran::evaluate::Not<KIND> &x,
508                       const Fortran::evaluate::Not<KIND> &y) {
509     return isEqual(x.left(), y.left());
510   }
511   template <int KIND>
isEqual(const Fortran::evaluate::LogicalOperation<KIND> & x,const Fortran::evaluate::LogicalOperation<KIND> & y)512   static bool isEqual(const Fortran::evaluate::LogicalOperation<KIND> &x,
513                       const Fortran::evaluate::LogicalOperation<KIND> &y) {
514     return isEqual(x.left(), y.left()) && isEqual(x.right(), x.right());
515   }
516   template <typename A>
isEqual(const Fortran::evaluate::Relational<A> & x,const Fortran::evaluate::Relational<A> & y)517   static bool isEqual(const Fortran::evaluate::Relational<A> &x,
518                       const Fortran::evaluate::Relational<A> &y) {
519     return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right());
520   }
521   template <typename A>
isEqual(const Fortran::evaluate::Expr<A> & x,const Fortran::evaluate::Expr<A> & y)522   static bool isEqual(const Fortran::evaluate::Expr<A> &x,
523                       const Fortran::evaluate::Expr<A> &y) {
524     return std::visit(
525         [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
526   }
527   static bool
isEqual(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> & x,const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> & y)528   isEqual(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x,
529           const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &y) {
530     return std::visit(
531         [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
532   }
533   template <typename A>
isEqual(const Fortran::evaluate::Designator<A> & x,const Fortran::evaluate::Designator<A> & y)534   static bool isEqual(const Fortran::evaluate::Designator<A> &x,
535                       const Fortran::evaluate::Designator<A> &y) {
536     return std::visit(
537         [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
538   }
539   template <int BITS>
isEqual(const Fortran::evaluate::value::Integer<BITS> & x,const Fortran::evaluate::value::Integer<BITS> & y)540   static bool isEqual(const Fortran::evaluate::value::Integer<BITS> &x,
541                       const Fortran::evaluate::value::Integer<BITS> &y) {
542     return x == y;
543   }
isEqual(const Fortran::evaluate::NullPointer & x,const Fortran::evaluate::NullPointer & y)544   static bool isEqual(const Fortran::evaluate::NullPointer &x,
545                       const Fortran::evaluate::NullPointer &y) {
546     return true;
547   }
/// Fallback overload: participates in overload resolution only when the two
/// argument types differ, in which case the values are never equal.
template <typename A, typename B,
          std::enable_if_t<!std::is_same<A, B>::value, bool> = true>
static bool isEqual(const A & /*unused*/, const B & /*unused*/) {
  return false;
}
553 };
554 } // namespace
555 
isEqual(const Fortran::lower::ExplicitIterSpace::ArrayBases & x,const Fortran::lower::ExplicitIterSpace::ArrayBases & y)556 bool Fortran::lower::isEqual(
557     const Fortran::lower::ExplicitIterSpace::ArrayBases &x,
558     const Fortran::lower::ExplicitIterSpace::ArrayBases &y) {
559   return std::visit(
560       Fortran::common::visitors{
561           // Fortran::semantics::Symbol * are the exception here. These pointers
562           // have identity; if two Symbol * values are the same (different) then
563           // they are the same (different) logical symbol.
564           [&](Fortran::lower::FrontEndSymbol p,
565               Fortran::lower::FrontEndSymbol q) { return p == q; },
566           [&](const auto *p, const auto *q) {
567             if constexpr (std::is_same_v<decltype(p), decltype(q)>) {
568               LLVM_DEBUG(llvm::dbgs()
569                          << "is equal: " << p << ' ' << q << ' '
570                          << IsEqualEvaluateExpr::isEqual(*p, *q) << '\n');
571               return IsEqualEvaluateExpr::isEqual(*p, *q);
572             } else {
573               // Different subtree types are never equal.
574               return false;
575             }
576           }},
577       x, y);
578 }
579 
isEqual(Fortran::lower::FrontEndExpr x,Fortran::lower::FrontEndExpr y)580 bool Fortran::lower::isEqual(Fortran::lower::FrontEndExpr x,
581                              Fortran::lower::FrontEndExpr y) {
582   auto empty = llvm::DenseMapInfo<Fortran::lower::FrontEndExpr>::getEmptyKey();
583   auto tombstone =
584       llvm::DenseMapInfo<Fortran::lower::FrontEndExpr>::getTombstoneKey();
585   if (x == empty || y == empty || x == tombstone || y == tombstone)
586     return x == y;
587   return x == y || IsEqualEvaluateExpr::isEqual(*x, *y);
588 }
589 
590 namespace {
591 
592 /// This class can recover the base array in an expression that contains
593 /// explicit iteration space symbols. Most of the class can be ignored as it is
594 /// boilerplate Fortran::evaluate::Expr traversal.
595 class ArrayBaseFinder {
596 public:
597   using RT = bool;
598 
ArrayBaseFinder(llvm::ArrayRef<Fortran::lower::FrontEndSymbol> syms)599   ArrayBaseFinder(llvm::ArrayRef<Fortran::lower::FrontEndSymbol> syms)
600       : controlVars(syms.begin(), syms.end()) {}
601 
602   template <typename T>
operator ()(const T & x)603   void operator()(const T &x) {
604     (void)find(x);
605   }
606 
607   /// Get the list of bases.
608   llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases>
getBases() const609   getBases() const {
610     LLVM_DEBUG(llvm::dbgs()
611                << "number of array bases found: " << bases.size() << '\n');
612     return bases;
613   }
614 
615 private:
616   // First, the cases that are of interest.
find(const Fortran::semantics::Symbol & symbol)617   RT find(const Fortran::semantics::Symbol &symbol) {
618     if (symbol.Rank() > 0) {
619       bases.push_back(&symbol);
620       return true;
621     }
622     return {};
623   }
find(const Fortran::evaluate::Component & x)624   RT find(const Fortran::evaluate::Component &x) {
625     auto found = find(x.base());
626     if (!found && x.base().Rank() == 0 && x.Rank() > 0) {
627       bases.push_back(&x);
628       return true;
629     }
630     return found;
631   }
find(const Fortran::evaluate::ArrayRef & x)632   RT find(const Fortran::evaluate::ArrayRef &x) {
633     for (const auto &sub : x.subscript())
634       (void)find(sub);
635     if (x.base().IsSymbol()) {
636       if (x.Rank() > 0 || intersection(x.subscript())) {
637         bases.push_back(&x);
638         return true;
639       }
640       return {};
641     }
642     auto found = find(x.base());
643     if (!found && ((x.base().Rank() == 0 && x.Rank() > 0) ||
644                    intersection(x.subscript()))) {
645       bases.push_back(&x);
646       return true;
647     }
648     return found;
649   }
find(const Fortran::evaluate::Triplet & x)650   RT find(const Fortran::evaluate::Triplet &x) {
651     if (const auto *lower = x.GetLower())
652       (void)find(*lower);
653     if (const auto *upper = x.GetUpper())
654       (void)find(*upper);
655     return find(x.GetStride());
656   }
find(const Fortran::evaluate::IndirectSubscriptIntegerExpr & x)657   RT find(const Fortran::evaluate::IndirectSubscriptIntegerExpr &x) {
658     return find(x.value());
659   }
find(const Fortran::evaluate::Subscript & x)660   RT find(const Fortran::evaluate::Subscript &x) { return find(x.u); }
find(const Fortran::evaluate::DataRef & x)661   RT find(const Fortran::evaluate::DataRef &x) { return find(x.u); }
find(const Fortran::evaluate::CoarrayRef & x)662   RT find(const Fortran::evaluate::CoarrayRef &x) {
663     assert(false && "coarray reference");
664     return {};
665   }
666 
667   template <typename A>
intersection(const A & subscripts)668   bool intersection(const A &subscripts) {
669     return Fortran::lower::symbolsIntersectSubscripts(controlVars, subscripts);
670   }
671 
672   // The rest is traversal boilerplate and can be ignored.
find(const Fortran::evaluate::Substring & x)673   RT find(const Fortran::evaluate::Substring &x) { return find(x.parent()); }
674   template <typename A>
find(const Fortran::semantics::SymbolRef x)675   RT find(const Fortran::semantics::SymbolRef x) {
676     return find(*x);
677   }
find(const Fortran::evaluate::NamedEntity & x)678   RT find(const Fortran::evaluate::NamedEntity &x) {
679     if (x.IsSymbol())
680       return find(x.GetFirstSymbol());
681     return find(x.GetComponent());
682   }
683 
684   template <typename A, bool C>
find(const Fortran::common::Indirection<A,C> & x)685   RT find(const Fortran::common::Indirection<A, C> &x) {
686     return find(x.value());
687   }
688   template <typename A>
find(const std::unique_ptr<A> & x)689   RT find(const std::unique_ptr<A> &x) {
690     return find(x.get());
691   }
692   template <typename A>
find(const std::shared_ptr<A> & x)693   RT find(const std::shared_ptr<A> &x) {
694     return find(x.get());
695   }
696   template <typename A>
find(const A * x)697   RT find(const A *x) {
698     if (x)
699       return find(*x);
700     return {};
701   }
702   template <typename A>
find(const std::optional<A> & x)703   RT find(const std::optional<A> &x) {
704     if (x)
705       return find(*x);
706     return {};
707   }
708   template <typename... A>
find(const std::variant<A...> & u)709   RT find(const std::variant<A...> &u) {
710     return std::visit([&](const auto &v) { return find(v); }, u);
711   }
712   template <typename A>
find(const std::vector<A> & x)713   RT find(const std::vector<A> &x) {
714     for (auto &v : x)
715       (void)find(v);
716     return {};
717   }
find(const Fortran::evaluate::BOZLiteralConstant &)718   RT find(const Fortran::evaluate::BOZLiteralConstant &) { return {}; }
find(const Fortran::evaluate::NullPointer &)719   RT find(const Fortran::evaluate::NullPointer &) { return {}; }
720   template <typename T>
find(const Fortran::evaluate::Constant<T> & x)721   RT find(const Fortran::evaluate::Constant<T> &x) {
722     return {};
723   }
find(const Fortran::evaluate::StaticDataObject &)724   RT find(const Fortran::evaluate::StaticDataObject &) { return {}; }
find(const Fortran::evaluate::ImpliedDoIndex &)725   RT find(const Fortran::evaluate::ImpliedDoIndex &) { return {}; }
find(const Fortran::evaluate::BaseObject & x)726   RT find(const Fortran::evaluate::BaseObject &x) {
727     (void)find(x.u);
728     return {};
729   }
find(const Fortran::evaluate::TypeParamInquiry &)730   RT find(const Fortran::evaluate::TypeParamInquiry &) { return {}; }
find(const Fortran::evaluate::ComplexPart & x)731   RT find(const Fortran::evaluate::ComplexPart &x) { return {}; }
732   template <typename T>
find(const Fortran::evaluate::Designator<T> & x)733   RT find(const Fortran::evaluate::Designator<T> &x) {
734     return find(x.u);
735   }
736   template <typename T>
find(const Fortran::evaluate::Variable<T> & x)737   RT find(const Fortran::evaluate::Variable<T> &x) {
738     return find(x.u);
739   }
find(const Fortran::evaluate::DescriptorInquiry &)740   RT find(const Fortran::evaluate::DescriptorInquiry &) { return {}; }
find(const Fortran::evaluate::SpecificIntrinsic &)741   RT find(const Fortran::evaluate::SpecificIntrinsic &) { return {}; }
find(const Fortran::evaluate::ProcedureDesignator & x)742   RT find(const Fortran::evaluate::ProcedureDesignator &x) { return {}; }
find(const Fortran::evaluate::ProcedureRef & x)743   RT find(const Fortran::evaluate::ProcedureRef &x) {
744     (void)find(x.proc());
745     if (x.IsElemental())
746       (void)find(x.arguments());
747     return {};
748   }
find(const Fortran::evaluate::ActualArgument & x)749   RT find(const Fortran::evaluate::ActualArgument &x) {
750     if (const auto *sym = x.GetAssumedTypeDummy())
751       (void)find(*sym);
752     else
753       (void)find(x.UnwrapExpr());
754     return {};
755   }
756   template <typename T>
find(const Fortran::evaluate::FunctionRef<T> & x)757   RT find(const Fortran::evaluate::FunctionRef<T> &x) {
758     (void)find(static_cast<const Fortran::evaluate::ProcedureRef &>(x));
759     return {};
760   }
761   template <typename T>
find(const Fortran::evaluate::ArrayConstructorValue<T> &)762   RT find(const Fortran::evaluate::ArrayConstructorValue<T> &) {
763     return {};
764   }
765   template <typename T>
find(const Fortran::evaluate::ArrayConstructorValues<T> &)766   RT find(const Fortran::evaluate::ArrayConstructorValues<T> &) {
767     return {};
768   }
769   template <typename T>
find(const Fortran::evaluate::ImpliedDo<T> &)770   RT find(const Fortran::evaluate::ImpliedDo<T> &) {
771     return {};
772   }
find(const Fortran::semantics::ParamValue &)773   RT find(const Fortran::semantics::ParamValue &) { return {}; }
find(const Fortran::semantics::DerivedTypeSpec &)774   RT find(const Fortran::semantics::DerivedTypeSpec &) { return {}; }
find(const Fortran::evaluate::StructureConstructor &)775   RT find(const Fortran::evaluate::StructureConstructor &) { return {}; }
776   template <typename D, typename R, typename O>
find(const Fortran::evaluate::Operation<D,R,O> & op)777   RT find(const Fortran::evaluate::Operation<D, R, O> &op) {
778     (void)find(op.left());
779     return false;
780   }
781   template <typename D, typename R, typename LO, typename RO>
find(const Fortran::evaluate::Operation<D,R,LO,RO> & op)782   RT find(const Fortran::evaluate::Operation<D, R, LO, RO> &op) {
783     (void)find(op.left());
784     (void)find(op.right());
785     return false;
786   }
find(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> & x)787   RT find(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x) {
788     (void)find(x.u);
789     return {};
790   }
791   template <typename T>
find(const Fortran::evaluate::Expr<T> & x)792   RT find(const Fortran::evaluate::Expr<T> &x) {
793     (void)find(x.u);
794     return {};
795   }
796 
797   llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases> bases;
798   llvm::SmallVector<Fortran::lower::FrontEndSymbol> controlVars;
799 };
800 
801 } // namespace
802 
leave()803 void Fortran::lower::ExplicitIterSpace::leave() {
804   ccLoopNest.pop_back();
805   --forallContextOpen;
806   conditionalCleanup();
807 }
808 
addSymbol(Fortran::lower::FrontEndSymbol sym)809 void Fortran::lower::ExplicitIterSpace::addSymbol(
810     Fortran::lower::FrontEndSymbol sym) {
811   assert(!symbolStack.empty());
812   symbolStack.back().push_back(sym);
813 }
814 
exprBase(Fortran::lower::FrontEndExpr x,bool lhs)815 void Fortran::lower::ExplicitIterSpace::exprBase(Fortran::lower::FrontEndExpr x,
816                                                  bool lhs) {
817   ArrayBaseFinder finder(collectAllSymbols());
818   finder(*x);
819   llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases> bases =
820       finder.getBases();
821   if (rhsBases.empty())
822     endAssign();
823   if (lhs) {
824     if (bases.empty()) {
825       lhsBases.push_back(llvm::None);
826       return;
827     }
828     assert(bases.size() >= 1 && "must detect an array reference on lhs");
829     if (bases.size() > 1)
830       rhsBases.back().append(bases.begin(), bases.end() - 1);
831     lhsBases.push_back(bases.back());
832     return;
833   }
834   rhsBases.back().append(bases.begin(), bases.end());
835 }
836 
endAssign()837 void Fortran::lower::ExplicitIterSpace::endAssign() { rhsBases.emplace_back(); }
838 
pushLevel()839 void Fortran::lower::ExplicitIterSpace::pushLevel() {
840   symbolStack.push_back(llvm::SmallVector<Fortran::lower::FrontEndSymbol>{});
841 }
842 
popLevel()843 void Fortran::lower::ExplicitIterSpace::popLevel() { symbolStack.pop_back(); }
844 
conditionalCleanup()845 void Fortran::lower::ExplicitIterSpace::conditionalCleanup() {
846   if (forallContextOpen == 0) {
847     // Exiting the outermost FORALL context.
848     // Cleanup any residual mask buffers.
849     outermostContext().finalize();
850     // Clear and reset all the cached information.
851     symbolStack.clear();
852     lhsBases.clear();
853     rhsBases.clear();
854     loadBindings.clear();
855     ccLoopNest.clear();
856     innerArgs.clear();
857     outerLoop = llvm::None;
858     clearLoops();
859     counter = 0;
860   }
861 }
862 
863 llvm::Optional<size_t>
findArgPosition(fir::ArrayLoadOp load)864 Fortran::lower::ExplicitIterSpace::findArgPosition(fir::ArrayLoadOp load) {
865   if (lhsBases[counter]) {
866     auto ld = loadBindings.find(*lhsBases[counter]);
867     llvm::Optional<size_t> optPos;
868     if (ld != loadBindings.end() && ld->second == load)
869       optPos = static_cast<size_t>(0u);
870     assert(optPos.has_value() && "load does not correspond to lhs");
871     return optPos;
872   }
873   return llvm::None;
874 }
875 
876 llvm::SmallVector<Fortran::lower::FrontEndSymbol>
collectAllSymbols()877 Fortran::lower::ExplicitIterSpace::collectAllSymbols() {
878   llvm::SmallVector<Fortran::lower::FrontEndSymbol> result;
879   for (llvm::SmallVector<FrontEndSymbol> vec : symbolStack)
880     result.append(vec.begin(), vec.end());
881   return result;
882 }
883 
884 llvm::raw_ostream &
operator <<(llvm::raw_ostream & s,const Fortran::lower::ImplicitIterSpace & e)885 Fortran::lower::operator<<(llvm::raw_ostream &s,
886                            const Fortran::lower::ImplicitIterSpace &e) {
887   for (const llvm::SmallVector<
888            Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr> &xs :
889        e.getMasks()) {
890     s << "{ ";
891     for (const Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr &x : xs)
892       x->AsFortran(s << '(') << "), ";
893     s << "}\n";
894   }
895   return s;
896 }
897 
898 llvm::raw_ostream &
operator <<(llvm::raw_ostream & s,const Fortran::lower::ExplicitIterSpace & e)899 Fortran::lower::operator<<(llvm::raw_ostream &s,
900                            const Fortran::lower::ExplicitIterSpace &e) {
901   auto dump = [&](const auto &u) {
902     std::visit(Fortran::common::visitors{
903                    [&](const Fortran::semantics::Symbol *y) {
904                      s << "  " << *y << '\n';
905                    },
906                    [&](const Fortran::evaluate::ArrayRef *y) {
907                      s << "  ";
908                      if (y->base().IsSymbol())
909                        s << y->base().GetFirstSymbol();
910                      else
911                        s << y->base().GetComponent().GetLastSymbol();
912                      s << '\n';
913                    },
914                    [&](const Fortran::evaluate::Component *y) {
915                      s << "  " << y->GetLastSymbol() << '\n';
916                    }},
917                u);
918   };
919   s << "LHS bases:\n";
920   for (const llvm::Optional<Fortran::lower::ExplicitIterSpace::ArrayBases> &u :
921        e.lhsBases)
922     if (u)
923       dump(*u);
924   s << "RHS bases:\n";
925   for (const llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases>
926            &bases : e.rhsBases) {
927     for (const Fortran::lower::ExplicitIterSpace::ArrayBases &u : bases)
928       dump(u);
929     s << '\n';
930   }
931   return s;
932 }
933 
dump() const934 void Fortran::lower::ImplicitIterSpace::dump() const {
935   llvm::errs() << *this << '\n';
936 }
937 
dump() const938 void Fortran::lower::ExplicitIterSpace::dump() const {
939   llvm::errs() << *this << '\n';
940 }
941