1 //===-- IterationSpace.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Lower/IterationSpace.h"
14 #include "flang/Evaluate/expression.h"
15 #include "flang/Lower/AbstractConverter.h"
16 #include "flang/Lower/Support/Utils.h"
17 #include "llvm/Support/Debug.h"
18 
19 #define DEBUG_TYPE "flang-lower-iteration-space"
20 
21 namespace {
// Fortran::evaluate::Expr objects are functional values organized like an AST.
// A Fortran::evaluate::Expr is meant to be moved and cloned. Using the front
// end tools can often cause copies and extra wrapper classes to be added to any
// Fortran::evaluate::Expr. These values should not be assumed or relied upon to
// have an *object* identity. They are deeply recursive, irregular structures
// built from a large number of classes which do not use inheritance and
// necessitate a large volume of boilerplate code as a result.
29 //
30 // Contrastingly, LLVM data structures make ubiquitous assumptions about an
31 // object's identity via pointers to the object. An object's location in memory
32 // is thus very often an identifying relation.
33 
34 // This class defines a hash computation of a Fortran::evaluate::Expr tree value
35 // so it can be used with llvm::DenseMap. The Fortran::evaluate::Expr need not
36 // have the same address.
37 class HashEvaluateExpr {
38 public:
39   // A Se::Symbol is the only part of an Fortran::evaluate::Expr with an
40   // identity property.
41   static unsigned getHashValue(const Fortran::semantics::Symbol &x) {
42     return static_cast<unsigned>(reinterpret_cast<std::intptr_t>(&x));
43   }
44   template <typename A, bool COPY>
45   static unsigned getHashValue(const Fortran::common::Indirection<A, COPY> &x) {
46     return getHashValue(x.value());
47   }
48   template <typename A>
49   static unsigned getHashValue(const std::optional<A> &x) {
50     if (x.has_value())
51       return getHashValue(x.value());
52     return 0u;
53   }
54   static unsigned getHashValue(const Fortran::evaluate::Subscript &x) {
55     return std::visit([&](const auto &v) { return getHashValue(v); }, x.u);
56   }
57   static unsigned getHashValue(const Fortran::evaluate::Triplet &x) {
58     return getHashValue(x.lower()) - getHashValue(x.upper()) * 5u -
59            getHashValue(x.stride()) * 11u;
60   }
61   static unsigned getHashValue(const Fortran::evaluate::Component &x) {
62     return getHashValue(x.base()) * 83u - getHashValue(x.GetLastSymbol());
63   }
64   static unsigned getHashValue(const Fortran::evaluate::ArrayRef &x) {
65     unsigned subs = 1u;
66     for (const Fortran::evaluate::Subscript &v : x.subscript())
67       subs -= getHashValue(v);
68     return getHashValue(x.base()) * 89u - subs;
69   }
70   static unsigned getHashValue(const Fortran::evaluate::CoarrayRef &x) {
71     unsigned subs = 1u;
72     for (const Fortran::evaluate::Subscript &v : x.subscript())
73       subs -= getHashValue(v);
74     unsigned cosubs = 3u;
75     for (const Fortran::evaluate::Expr<Fortran::evaluate::SubscriptInteger> &v :
76          x.cosubscript())
77       cosubs -= getHashValue(v);
78     unsigned syms = 7u;
79     for (const Fortran::evaluate::SymbolRef &v : x.base())
80       syms += getHashValue(v);
81     return syms * 97u - subs - cosubs + getHashValue(x.stat()) + 257u +
82            getHashValue(x.team());
83   }
84   static unsigned getHashValue(const Fortran::evaluate::NamedEntity &x) {
85     if (x.IsSymbol())
86       return getHashValue(x.GetFirstSymbol()) * 11u;
87     return getHashValue(x.GetComponent()) * 13u;
88   }
89   static unsigned getHashValue(const Fortran::evaluate::DataRef &x) {
90     return std::visit([&](const auto &v) { return getHashValue(v); }, x.u);
91   }
92   static unsigned getHashValue(const Fortran::evaluate::ComplexPart &x) {
93     return getHashValue(x.complex()) - static_cast<unsigned>(x.part());
94   }
95   template <Fortran::common::TypeCategory TC1, int KIND,
96             Fortran::common::TypeCategory TC2>
97   static unsigned getHashValue(
98       const Fortran::evaluate::Convert<Fortran::evaluate::Type<TC1, KIND>, TC2>
99           &x) {
100     return getHashValue(x.left()) - (static_cast<unsigned>(TC1) + 2u) -
101            (static_cast<unsigned>(KIND) + 5u);
102   }
103   template <int KIND>
104   static unsigned
105   getHashValue(const Fortran::evaluate::ComplexComponent<KIND> &x) {
106     return getHashValue(x.left()) -
107            (static_cast<unsigned>(x.isImaginaryPart) + 1u) * 3u;
108   }
109   template <typename T>
110   static unsigned getHashValue(const Fortran::evaluate::Parentheses<T> &x) {
111     return getHashValue(x.left()) * 17u;
112   }
113   template <Fortran::common::TypeCategory TC, int KIND>
114   static unsigned getHashValue(
115       const Fortran::evaluate::Negate<Fortran::evaluate::Type<TC, KIND>> &x) {
116     return getHashValue(x.left()) - (static_cast<unsigned>(TC) + 5u) -
117            (static_cast<unsigned>(KIND) + 7u);
118   }
119   template <Fortran::common::TypeCategory TC, int KIND>
120   static unsigned getHashValue(
121       const Fortran::evaluate::Add<Fortran::evaluate::Type<TC, KIND>> &x) {
122     return (getHashValue(x.left()) + getHashValue(x.right())) * 23u +
123            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
124   }
125   template <Fortran::common::TypeCategory TC, int KIND>
126   static unsigned getHashValue(
127       const Fortran::evaluate::Subtract<Fortran::evaluate::Type<TC, KIND>> &x) {
128     return (getHashValue(x.left()) - getHashValue(x.right())) * 19u +
129            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
130   }
131   template <Fortran::common::TypeCategory TC, int KIND>
132   static unsigned getHashValue(
133       const Fortran::evaluate::Multiply<Fortran::evaluate::Type<TC, KIND>> &x) {
134     return (getHashValue(x.left()) + getHashValue(x.right())) * 29u +
135            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
136   }
137   template <Fortran::common::TypeCategory TC, int KIND>
138   static unsigned getHashValue(
139       const Fortran::evaluate::Divide<Fortran::evaluate::Type<TC, KIND>> &x) {
140     return (getHashValue(x.left()) - getHashValue(x.right())) * 31u +
141            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
142   }
143   template <Fortran::common::TypeCategory TC, int KIND>
144   static unsigned getHashValue(
145       const Fortran::evaluate::Power<Fortran::evaluate::Type<TC, KIND>> &x) {
146     return (getHashValue(x.left()) - getHashValue(x.right())) * 37u +
147            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
148   }
149   template <Fortran::common::TypeCategory TC, int KIND>
150   static unsigned getHashValue(
151       const Fortran::evaluate::Extremum<Fortran::evaluate::Type<TC, KIND>> &x) {
152     return (getHashValue(x.left()) + getHashValue(x.right())) * 41u +
153            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND) +
154            static_cast<unsigned>(x.ordering) * 7u;
155   }
156   template <Fortran::common::TypeCategory TC, int KIND>
157   static unsigned getHashValue(
158       const Fortran::evaluate::RealToIntPower<Fortran::evaluate::Type<TC, KIND>>
159           &x) {
160     return (getHashValue(x.left()) - getHashValue(x.right())) * 43u +
161            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
162   }
163   template <int KIND>
164   static unsigned
165   getHashValue(const Fortran::evaluate::ComplexConstructor<KIND> &x) {
166     return (getHashValue(x.left()) - getHashValue(x.right())) * 47u +
167            static_cast<unsigned>(KIND);
168   }
169   template <int KIND>
170   static unsigned getHashValue(const Fortran::evaluate::Concat<KIND> &x) {
171     return (getHashValue(x.left()) - getHashValue(x.right())) * 53u +
172            static_cast<unsigned>(KIND);
173   }
174   template <int KIND>
175   static unsigned getHashValue(const Fortran::evaluate::SetLength<KIND> &x) {
176     return (getHashValue(x.left()) - getHashValue(x.right())) * 59u +
177            static_cast<unsigned>(KIND);
178   }
179   static unsigned getHashValue(const Fortran::semantics::SymbolRef &sym) {
180     return getHashValue(sym.get());
181   }
182   static unsigned getHashValue(const Fortran::evaluate::Substring &x) {
183     return 61u * std::visit([&](const auto &p) { return getHashValue(p); },
184                             x.parent()) -
185            getHashValue(x.lower()) - (getHashValue(x.lower()) + 1u);
186   }
187   static unsigned
188   getHashValue(const Fortran::evaluate::StaticDataObject::Pointer &x) {
189     return llvm::hash_value(x->name());
190   }
191   static unsigned getHashValue(const Fortran::evaluate::SpecificIntrinsic &x) {
192     return llvm::hash_value(x.name);
193   }
194   template <typename A>
195   static unsigned getHashValue(const Fortran::evaluate::Constant<A> &x) {
196     // FIXME: Should hash the content.
197     return 103u;
198   }
199   static unsigned getHashValue(const Fortran::evaluate::ActualArgument &x) {
200     if (const Fortran::evaluate::Symbol *sym = x.GetAssumedTypeDummy())
201       return getHashValue(*sym);
202     return getHashValue(*x.UnwrapExpr());
203   }
204   static unsigned
205   getHashValue(const Fortran::evaluate::ProcedureDesignator &x) {
206     return std::visit([&](const auto &v) { return getHashValue(v); }, x.u);
207   }
208   static unsigned getHashValue(const Fortran::evaluate::ProcedureRef &x) {
209     unsigned args = 13u;
210     for (const std::optional<Fortran::evaluate::ActualArgument> &v :
211          x.arguments())
212       args -= getHashValue(v);
213     return getHashValue(x.proc()) * 101u - args;
214   }
215   template <typename A>
216   static unsigned
217   getHashValue(const Fortran::evaluate::ArrayConstructor<A> &x) {
218     // FIXME: hash the contents.
219     return 127u;
220   }
221   static unsigned getHashValue(const Fortran::evaluate::ImpliedDoIndex &x) {
222     return llvm::hash_value(toStringRef(x.name).str()) * 131u;
223   }
224   static unsigned getHashValue(const Fortran::evaluate::TypeParamInquiry &x) {
225     return getHashValue(x.base()) * 137u - getHashValue(x.parameter()) * 3u;
226   }
227   static unsigned getHashValue(const Fortran::evaluate::DescriptorInquiry &x) {
228     return getHashValue(x.base()) * 139u -
229            static_cast<unsigned>(x.field()) * 13u +
230            static_cast<unsigned>(x.dimension());
231   }
232   static unsigned
233   getHashValue(const Fortran::evaluate::StructureConstructor &x) {
234     // FIXME: hash the contents.
235     return 149u;
236   }
237   template <int KIND>
238   static unsigned getHashValue(const Fortran::evaluate::Not<KIND> &x) {
239     return getHashValue(x.left()) * 61u + static_cast<unsigned>(KIND);
240   }
241   template <int KIND>
242   static unsigned
243   getHashValue(const Fortran::evaluate::LogicalOperation<KIND> &x) {
244     unsigned result = getHashValue(x.left()) + getHashValue(x.right());
245     return result * 67u + static_cast<unsigned>(x.logicalOperator) * 5u;
246   }
247   template <Fortran::common::TypeCategory TC, int KIND>
248   static unsigned getHashValue(
249       const Fortran::evaluate::Relational<Fortran::evaluate::Type<TC, KIND>>
250           &x) {
251     return (getHashValue(x.left()) + getHashValue(x.right())) * 71u +
252            static_cast<unsigned>(TC) + static_cast<unsigned>(KIND) +
253            static_cast<unsigned>(x.opr) * 11u;
254   }
255   template <typename A>
256   static unsigned getHashValue(const Fortran::evaluate::Expr<A> &x) {
257     return std::visit([&](const auto &v) { return getHashValue(v); }, x.u);
258   }
259   static unsigned getHashValue(
260       const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x) {
261     return std::visit([&](const auto &v) { return getHashValue(v); }, x.u);
262   }
263   template <typename A>
264   static unsigned getHashValue(const Fortran::evaluate::Designator<A> &x) {
265     return std::visit([&](const auto &v) { return getHashValue(v); }, x.u);
266   }
267   template <int BITS>
268   static unsigned
269   getHashValue(const Fortran::evaluate::value::Integer<BITS> &x) {
270     return static_cast<unsigned>(x.ToSInt());
271   }
272   static unsigned getHashValue(const Fortran::evaluate::NullPointer &x) {
273     return ~179u;
274   }
275 };
276 } // namespace
277 
278 unsigned Fortran::lower::getHashValue(
279     const Fortran::lower::ExplicitIterSpace::ArrayBases &x) {
280   return std::visit(
281       [&](const auto *p) { return HashEvaluateExpr::getHashValue(*p); }, x);
282 }
283 
284 unsigned Fortran::lower::getHashValue(Fortran::lower::FrontEndExpr x) {
285   return HashEvaluateExpr::getHashValue(*x);
286 }
287 
288 namespace {
// Define the structural equality test used to key llvm::DenseMap on
// Fortran::evaluate::Expr values.
291 class IsEqualEvaluateExpr {
292 public:
293   // A Se::Symbol is the only part of an Fortran::evaluate::Expr with an
294   // identity property.
295   static bool isEqual(const Fortran::semantics::Symbol &x,
296                       const Fortran::semantics::Symbol &y) {
297     return isEqual(&x, &y);
298   }
299   static bool isEqual(const Fortran::semantics::Symbol *x,
300                       const Fortran::semantics::Symbol *y) {
301     return x == y;
302   }
303   template <typename A, bool COPY>
304   static bool isEqual(const Fortran::common::Indirection<A, COPY> &x,
305                       const Fortran::common::Indirection<A, COPY> &y) {
306     return isEqual(x.value(), y.value());
307   }
308   template <typename A>
309   static bool isEqual(const std::optional<A> &x, const std::optional<A> &y) {
310     if (x.has_value() && y.has_value())
311       return isEqual(x.value(), y.value());
312     return !x.has_value() && !y.has_value();
313   }
314   template <typename A>
315   static bool isEqual(const std::vector<A> &x, const std::vector<A> &y) {
316     if (x.size() != y.size())
317       return false;
318     const std::size_t size = x.size();
319     for (std::remove_const_t<decltype(size)> i = 0; i < size; ++i)
320       if (!isEqual(x[i], y[i]))
321         return false;
322     return true;
323   }
324   static bool isEqual(const Fortran::evaluate::Subscript &x,
325                       const Fortran::evaluate::Subscript &y) {
326     return std::visit(
327         [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
328   }
329   static bool isEqual(const Fortran::evaluate::Triplet &x,
330                       const Fortran::evaluate::Triplet &y) {
331     return isEqual(x.lower(), y.lower()) && isEqual(x.upper(), y.upper()) &&
332            isEqual(x.stride(), y.stride());
333   }
334   static bool isEqual(const Fortran::evaluate::Component &x,
335                       const Fortran::evaluate::Component &y) {
336     return isEqual(x.base(), y.base()) &&
337            isEqual(x.GetLastSymbol(), y.GetLastSymbol());
338   }
339   static bool isEqual(const Fortran::evaluate::ArrayRef &x,
340                       const Fortran::evaluate::ArrayRef &y) {
341     return isEqual(x.base(), y.base()) && isEqual(x.subscript(), y.subscript());
342   }
343   static bool isEqual(const Fortran::evaluate::CoarrayRef &x,
344                       const Fortran::evaluate::CoarrayRef &y) {
345     return isEqual(x.base(), y.base()) &&
346            isEqual(x.subscript(), y.subscript()) &&
347            isEqual(x.cosubscript(), y.cosubscript()) &&
348            isEqual(x.stat(), y.stat()) && isEqual(x.team(), y.team());
349   }
350   static bool isEqual(const Fortran::evaluate::NamedEntity &x,
351                       const Fortran::evaluate::NamedEntity &y) {
352     if (x.IsSymbol() && y.IsSymbol())
353       return isEqual(x.GetFirstSymbol(), y.GetFirstSymbol());
354     return !x.IsSymbol() && !y.IsSymbol() &&
355            isEqual(x.GetComponent(), y.GetComponent());
356   }
357   static bool isEqual(const Fortran::evaluate::DataRef &x,
358                       const Fortran::evaluate::DataRef &y) {
359     return std::visit(
360         [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
361   }
362   static bool isEqual(const Fortran::evaluate::ComplexPart &x,
363                       const Fortran::evaluate::ComplexPart &y) {
364     return isEqual(x.complex(), y.complex()) && x.part() == y.part();
365   }
366   template <typename A, Fortran::common::TypeCategory TC2>
367   static bool isEqual(const Fortran::evaluate::Convert<A, TC2> &x,
368                       const Fortran::evaluate::Convert<A, TC2> &y) {
369     return isEqual(x.left(), y.left());
370   }
371   template <int KIND>
372   static bool isEqual(const Fortran::evaluate::ComplexComponent<KIND> &x,
373                       const Fortran::evaluate::ComplexComponent<KIND> &y) {
374     return isEqual(x.left(), y.left()) &&
375            x.isImaginaryPart == y.isImaginaryPart;
376   }
377   template <typename T>
378   static bool isEqual(const Fortran::evaluate::Parentheses<T> &x,
379                       const Fortran::evaluate::Parentheses<T> &y) {
380     return isEqual(x.left(), y.left());
381   }
382   template <typename A>
383   static bool isEqual(const Fortran::evaluate::Negate<A> &x,
384                       const Fortran::evaluate::Negate<A> &y) {
385     return isEqual(x.left(), y.left());
386   }
387   template <typename A>
388   static bool isBinaryEqual(const A &x, const A &y) {
389     return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right());
390   }
391   template <typename A>
392   static bool isEqual(const Fortran::evaluate::Add<A> &x,
393                       const Fortran::evaluate::Add<A> &y) {
394     return isBinaryEqual(x, y);
395   }
396   template <typename A>
397   static bool isEqual(const Fortran::evaluate::Subtract<A> &x,
398                       const Fortran::evaluate::Subtract<A> &y) {
399     return isBinaryEqual(x, y);
400   }
401   template <typename A>
402   static bool isEqual(const Fortran::evaluate::Multiply<A> &x,
403                       const Fortran::evaluate::Multiply<A> &y) {
404     return isBinaryEqual(x, y);
405   }
406   template <typename A>
407   static bool isEqual(const Fortran::evaluate::Divide<A> &x,
408                       const Fortran::evaluate::Divide<A> &y) {
409     return isBinaryEqual(x, y);
410   }
411   template <typename A>
412   static bool isEqual(const Fortran::evaluate::Power<A> &x,
413                       const Fortran::evaluate::Power<A> &y) {
414     return isBinaryEqual(x, y);
415   }
416   template <typename A>
417   static bool isEqual(const Fortran::evaluate::Extremum<A> &x,
418                       const Fortran::evaluate::Extremum<A> &y) {
419     return isBinaryEqual(x, y);
420   }
421   template <typename A>
422   static bool isEqual(const Fortran::evaluate::RealToIntPower<A> &x,
423                       const Fortran::evaluate::RealToIntPower<A> &y) {
424     return isBinaryEqual(x, y);
425   }
426   template <int KIND>
427   static bool isEqual(const Fortran::evaluate::ComplexConstructor<KIND> &x,
428                       const Fortran::evaluate::ComplexConstructor<KIND> &y) {
429     return isBinaryEqual(x, y);
430   }
431   template <int KIND>
432   static bool isEqual(const Fortran::evaluate::Concat<KIND> &x,
433                       const Fortran::evaluate::Concat<KIND> &y) {
434     return isBinaryEqual(x, y);
435   }
436   template <int KIND>
437   static bool isEqual(const Fortran::evaluate::SetLength<KIND> &x,
438                       const Fortran::evaluate::SetLength<KIND> &y) {
439     return isBinaryEqual(x, y);
440   }
441   static bool isEqual(const Fortran::semantics::SymbolRef &x,
442                       const Fortran::semantics::SymbolRef &y) {
443     return isEqual(x.get(), y.get());
444   }
445   static bool isEqual(const Fortran::evaluate::Substring &x,
446                       const Fortran::evaluate::Substring &y) {
447     return std::visit(
448                [&](const auto &p, const auto &q) { return isEqual(p, q); },
449                x.parent(), y.parent()) &&
450            isEqual(x.lower(), y.lower()) && isEqual(x.lower(), y.lower());
451   }
452   static bool isEqual(const Fortran::evaluate::StaticDataObject::Pointer &x,
453                       const Fortran::evaluate::StaticDataObject::Pointer &y) {
454     return x->name() == y->name();
455   }
456   static bool isEqual(const Fortran::evaluate::SpecificIntrinsic &x,
457                       const Fortran::evaluate::SpecificIntrinsic &y) {
458     return x.name == y.name;
459   }
460   template <typename A>
461   static bool isEqual(const Fortran::evaluate::Constant<A> &x,
462                       const Fortran::evaluate::Constant<A> &y) {
463     return x == y;
464   }
465   static bool isEqual(const Fortran::evaluate::ActualArgument &x,
466                       const Fortran::evaluate::ActualArgument &y) {
467     if (const Fortran::evaluate::Symbol *xs = x.GetAssumedTypeDummy()) {
468       if (const Fortran::evaluate::Symbol *ys = y.GetAssumedTypeDummy())
469         return isEqual(*xs, *ys);
470       return false;
471     }
472     return !y.GetAssumedTypeDummy() &&
473            isEqual(*x.UnwrapExpr(), *y.UnwrapExpr());
474   }
475   static bool isEqual(const Fortran::evaluate::ProcedureDesignator &x,
476                       const Fortran::evaluate::ProcedureDesignator &y) {
477     return std::visit(
478         [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
479   }
480   static bool isEqual(const Fortran::evaluate::ProcedureRef &x,
481                       const Fortran::evaluate::ProcedureRef &y) {
482     return isEqual(x.proc(), y.proc()) && isEqual(x.arguments(), y.arguments());
483   }
484   template <typename A>
485   static bool isEqual(const Fortran::evaluate::ArrayConstructor<A> &x,
486                       const Fortran::evaluate::ArrayConstructor<A> &y) {
487     llvm::report_fatal_error("not implemented");
488   }
489   static bool isEqual(const Fortran::evaluate::ImpliedDoIndex &x,
490                       const Fortran::evaluate::ImpliedDoIndex &y) {
491     return toStringRef(x.name) == toStringRef(y.name);
492   }
493   static bool isEqual(const Fortran::evaluate::TypeParamInquiry &x,
494                       const Fortran::evaluate::TypeParamInquiry &y) {
495     return isEqual(x.base(), y.base()) && isEqual(x.parameter(), y.parameter());
496   }
497   static bool isEqual(const Fortran::evaluate::DescriptorInquiry &x,
498                       const Fortran::evaluate::DescriptorInquiry &y) {
499     return isEqual(x.base(), y.base()) && x.field() == y.field() &&
500            x.dimension() == y.dimension();
501   }
502   static bool isEqual(const Fortran::evaluate::StructureConstructor &x,
503                       const Fortran::evaluate::StructureConstructor &y) {
504     llvm::report_fatal_error("not implemented");
505   }
506   template <int KIND>
507   static bool isEqual(const Fortran::evaluate::Not<KIND> &x,
508                       const Fortran::evaluate::Not<KIND> &y) {
509     return isEqual(x.left(), y.left());
510   }
511   template <int KIND>
512   static bool isEqual(const Fortran::evaluate::LogicalOperation<KIND> &x,
513                       const Fortran::evaluate::LogicalOperation<KIND> &y) {
514     return isEqual(x.left(), y.left()) && isEqual(x.right(), x.right());
515   }
516   template <typename A>
517   static bool isEqual(const Fortran::evaluate::Relational<A> &x,
518                       const Fortran::evaluate::Relational<A> &y) {
519     return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right());
520   }
521   template <typename A>
522   static bool isEqual(const Fortran::evaluate::Expr<A> &x,
523                       const Fortran::evaluate::Expr<A> &y) {
524     return std::visit(
525         [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
526   }
527   static bool
528   isEqual(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x,
529           const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &y) {
530     return std::visit(
531         [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
532   }
533   template <typename A>
534   static bool isEqual(const Fortran::evaluate::Designator<A> &x,
535                       const Fortran::evaluate::Designator<A> &y) {
536     return std::visit(
537         [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
538   }
539   template <int BITS>
540   static bool isEqual(const Fortran::evaluate::value::Integer<BITS> &x,
541                       const Fortran::evaluate::value::Integer<BITS> &y) {
542     return x == y;
543   }
544   static bool isEqual(const Fortran::evaluate::NullPointer &x,
545                       const Fortran::evaluate::NullPointer &y) {
546     return true;
547   }
548   template <typename A, typename B,
549             std::enable_if_t<!std::is_same_v<A, B>, bool> = true>
550   static bool isEqual(const A &, const B &) {
551     return false;
552   }
553 };
554 } // namespace
555 
556 bool Fortran::lower::isEqual(
557     const Fortran::lower::ExplicitIterSpace::ArrayBases &x,
558     const Fortran::lower::ExplicitIterSpace::ArrayBases &y) {
559   return std::visit(
560       Fortran::common::visitors{
561           // Fortran::semantics::Symbol * are the exception here. These pointers
562           // have identity; if two Symbol * values are the same (different) then
563           // they are the same (different) logical symbol.
564           [&](Fortran::lower::FrontEndSymbol p,
565               Fortran::lower::FrontEndSymbol q) { return p == q; },
566           [&](const auto *p, const auto *q) {
567             if constexpr (std::is_same_v<decltype(p), decltype(q)>) {
568               LLVM_DEBUG(llvm::dbgs()
569                          << "is equal: " << p << ' ' << q << ' '
570                          << IsEqualEvaluateExpr::isEqual(*p, *q) << '\n');
571               return IsEqualEvaluateExpr::isEqual(*p, *q);
572             } else {
573               // Different subtree types are never equal.
574               return false;
575             }
576           }},
577       x, y);
578 }
579 
580 bool Fortran::lower::isEqual(Fortran::lower::FrontEndExpr x,
581                              Fortran::lower::FrontEndExpr y) {
582   auto empty = llvm::DenseMapInfo<Fortran::lower::FrontEndExpr>::getEmptyKey();
583   auto tombstone =
584       llvm::DenseMapInfo<Fortran::lower::FrontEndExpr>::getTombstoneKey();
585   if (x == empty || y == empty || x == tombstone || y == tombstone)
586     return x == y;
587   return x == y || IsEqualEvaluateExpr::isEqual(*x, *y);
588 }
589 
590 namespace {
591 
592 /// This class can recover the base array in an expression that contains
593 /// explicit iteration space symbols. Most of the class can be ignored as it is
594 /// boilerplate Fortran::evaluate::Expr traversal.
595 class ArrayBaseFinder {
596 public:
597   using RT = bool;
598 
  /// Construct a finder whose explicit iteration space control variables are
  /// \p syms. They are copied into `controlVars` and later consulted (via
  /// intersection()) to decide whether subscripts depend on those variables.
  ArrayBaseFinder(llvm::ArrayRef<Fortran::lower::FrontEndSymbol> syms)
      : controlVars(syms.begin(), syms.end()) {}

  /// Traverse \p x, appending every array base discovered to the list returned
  /// by getBases(). The per-node boolean result of the traversal is discarded.
  template <typename T>
  void operator()(const T &x) {
    (void)find(x);
  }

  /// Get the list of bases collected so far, in the order they were found.
  /// The returned ArrayRef is a non-owning view of this finder's storage.
  llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases>
  getBases() const {
    LLVM_DEBUG(llvm::dbgs()
               << "number of array bases found: " << bases.size() << '\n');
    return bases;
  }
614 
615 private:
616   // First, the cases that are of interest.
  /// A bare symbol reference: record the symbol itself as an array base when
  /// it has nonzero rank. Returns true iff a base was recorded.
  RT find(const Fortran::semantics::Symbol &symbol) {
    if (symbol.Rank() > 0) {
      bases.push_back(&symbol);
      return true;
    }
    return {};
  }
  /// A component reference. Search the parent (`base()`) first. If nothing was
  /// recorded there, and the parent is scalar while this component has rank,
  /// the component itself is the array base and is recorded.
  /// Returns true iff a base was recorded by this node or its parent.
  RT find(const Fortran::evaluate::Component &x) {
    auto found = find(x.base());
    if (!found && x.base().Rank() == 0 && x.Rank() > 0) {
      bases.push_back(&x);
      return true;
    }
    return found;
  }
  /// An array element/section reference.
  /// Subscripts are always traversed first, so any bases nested inside them
  /// are recorded regardless of the outcome below (their results are ignored).
  RT find(const Fortran::evaluate::ArrayRef &x) {
    for (const auto &sub : x.subscript())
      (void)find(sub);
    // Base is a plain symbol: this ArrayRef is the base if it yields an array
    // or if its subscripts mention an iteration space control variable.
    if (x.base().IsSymbol()) {
      if (x.Rank() > 0 || intersection(x.subscript())) {
        bases.push_back(&x);
        return true;
      }
      return {};
    }
    // Base is a component: prefer a base found inside the component; otherwise
    // record this ArrayRef when the component is scalar and this reference has
    // rank, or when the subscripts mention a control variable.
    auto found = find(x.base());
    if (!found && ((x.base().Rank() == 0 && x.Rank() > 0) ||
                   intersection(x.subscript()))) {
      bases.push_back(&x);
      return true;
    }
    return found;
  }
  /// A subscript triplet. The optional lower and upper bounds are traversed for
  /// their side effects only; the returned value is that of the stride.
  RT find(const Fortran::evaluate::Triplet &x) {
    if (const auto *lower = x.GetLower())
      (void)find(*lower);
    if (const auto *upper = x.GetUpper())
      (void)find(*upper);
    return find(x.GetStride());
  }
  /// Unwrap an indirect subscript expression.
  RT find(const Fortran::evaluate::IndirectSubscriptIntegerExpr &x) {
    return find(x.value());
  }
  /// Dispatch on the active alternative of the subscript / data reference.
  RT find(const Fortran::evaluate::Subscript &x) { return find(x.u); }
  RT find(const Fortran::evaluate::DataRef &x) { return find(x.u); }
  /// Coarray references are not expected here. The assert only fires in
  /// assertion-enabled builds; otherwise nothing is recorded.
  RT find(const Fortran::evaluate::CoarrayRef &x) {
    assert(false && "coarray reference");
    return {};
  }
666 
  /// Delegate to Fortran::lower::symbolsIntersectSubscripts with this finder's
  /// control variables. Presumably true when some control variable occurs in
  /// \p subscripts — see that helper (defined elsewhere) to confirm.
  template <typename A>
  bool intersection(const A &subscripts) {
    return Fortran::lower::symbolsIntersectSubscripts(controlVars, subscripts);
  }
671 
672   // The rest is traversal boilerplate and can be ignored.
  // A substring's base is searched through its parent; the bounds are not
  // traversed.
  RT find(const Fortran::evaluate::Substring &x) { return find(x.parent()); }
674   template <typename A>
675   RT find(const Fortran::semantics::SymbolRef x) {
676     return find(*x);
677   }
  /// A named entity is either a bare symbol or a component reference; forward
  /// to the matching overload.
  RT find(const Fortran::evaluate::NamedEntity &x) {
    if (x.IsSymbol())
      return find(x.GetFirstSymbol());
    return find(x.GetComponent());
  }
683 
  // Wrapper/container overloads: unwrap and forward to the contained value.
  // These are kept deliberately thin because overload resolution between them
  // and the node-specific overloads above determines which case handles a node.
  template <typename A, bool C>
  RT find(const Fortran::common::Indirection<A, C> &x) {
    return find(x.value());
  }
  template <typename A>
  RT find(const std::unique_ptr<A> &x) {
    return find(x.get());
  }
  template <typename A>
  RT find(const std::shared_ptr<A> &x) {
    return find(x.get());
  }
  // Null pointers contribute nothing.
  template <typename A>
  RT find(const A *x) {
    if (x)
      return find(*x);
    return {};
  }
  // Empty optionals contribute nothing.
  template <typename A>
  RT find(const std::optional<A> &x) {
    if (x)
      return find(*x);
    return {};
  }
  // Forward to whichever alternative is active.
  template <typename... A>
  RT find(const std::variant<A...> &u) {
    return std::visit([&](const auto &v) { return find(v); }, u);
  }
  // Visit every element for its side effects; always returns false, so a
  // caller never treats a list as having produced a base itself.
  template <typename A>
  RT find(const std::vector<A> &x) {
    for (auto &v : x)
      (void)find(v);
    return {};
  }
  // Leaf nodes that cannot contain an array base: nothing to record.
  RT find(const Fortran::evaluate::BOZLiteralConstant &) { return {}; }
  RT find(const Fortran::evaluate::NullPointer &) { return {}; }
  template <typename T>
  RT find(const Fortran::evaluate::Constant<T> &x) {
    return {};
  }
  RT find(const Fortran::evaluate::StaticDataObject &) { return {}; }
  RT find(const Fortran::evaluate::ImpliedDoIndex &) { return {}; }
  // Traverse the base object's alternative, but never report a base upward.
  RT find(const Fortran::evaluate::BaseObject &x) {
    (void)find(x.u);
    return {};
  }
  RT find(const Fortran::evaluate::TypeParamInquiry &) { return {}; }
  // NOTE(review): the complex part's underlying DataRef is not traversed, so
  // arrays referenced only through %RE/%IM are not recorded — confirm intent.
  RT find(const Fortran::evaluate::ComplexPart &x) { return {}; }
  template <typename T>
  RT find(const Fortran::evaluate::Designator<T> &x) {
    return find(x.u);
  }
  template <typename T>
  RT find(const Fortran::evaluate::Variable<T> &x) {
    return find(x.u);
  }
  RT find(const Fortran::evaluate::DescriptorInquiry &) { return {}; }
  RT find(const Fortran::evaluate::SpecificIntrinsic &) { return {}; }
  RT find(const Fortran::evaluate::ProcedureDesignator &x) { return {}; }
  // Calls: the designator overload above records nothing, and actual
  // arguments are only searched for elemental procedures.
  RT find(const Fortran::evaluate::ProcedureRef &x) {
    (void)find(x.proc());
    if (x.IsElemental())
      (void)find(x.arguments());
    return {};
  }
  RT find(const Fortran::evaluate::ActualArgument &x) {
    if (const auto *sym = x.GetAssumedTypeDummy())
      (void)find(*sym);
    else
      (void)find(x.UnwrapExpr());
    return {};
  }
  template <typename T>
  RT find(const Fortran::evaluate::FunctionRef<T> &x) {
    (void)find(static_cast<const Fortran::evaluate::ProcedureRef &>(x));
    return {};
  }
  // Array constructor pieces and implied-DOs are not traversed.
  template <typename T>
  RT find(const Fortran::evaluate::ArrayConstructorValue<T> &) {
    return {};
  }
  template <typename T>
  RT find(const Fortran::evaluate::ArrayConstructorValues<T> &) {
    return {};
  }
  template <typename T>
  RT find(const Fortran::evaluate::ImpliedDo<T> &) {
    return {};
  }
  RT find(const Fortran::semantics::ParamValue &) { return {}; }
  RT find(const Fortran::semantics::DerivedTypeSpec &) { return {}; }
  RT find(const Fortran::evaluate::StructureConstructor &) { return {}; }
  // Unary operation: search the single operand; the operation itself is never
  // reported as a base.
  template <typename D, typename R, typename O>
  RT find(const Fortran::evaluate::Operation<D, R, O> &op) {
    (void)find(op.left());
    return false;
  }
781   template <typename D, typename R, typename LO, typename RO>
782   RT find(const Fortran::evaluate::Operation<D, R, LO, RO> &op) {
783     (void)find(op.left());
784     (void)find(op.right());
785     return false;
786   }
787   RT find(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x) {
788     (void)find(x.u);
789     return {};
790   }
791   template <typename T>
792   RT find(const Fortran::evaluate::Expr<T> &x) {
793     (void)find(x.u);
794     return {};
795   }
796 
797   llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases> bases;
798   llvm::SmallVector<Fortran::lower::FrontEndSymbol> controlVars;
799 };
800 
801 } // namespace
802 
803 void Fortran::lower::ExplicitIterSpace::leave() {
804   ccLoopNest.pop_back();
805   --forallContextOpen;
806   conditionalCleanup();
807 }
808 
809 void Fortran::lower::ExplicitIterSpace::addSymbol(
810     Fortran::lower::FrontEndSymbol sym) {
811   assert(!symbolStack.empty());
812   symbolStack.back().push_back(sym);
813 }
814 
815 void Fortran::lower::ExplicitIterSpace::exprBase(Fortran::lower::FrontEndExpr x,
816                                                  bool lhs) {
817   ArrayBaseFinder finder(collectAllSymbols());
818   finder(*x);
819   llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases> bases =
820       finder.getBases();
821   if (rhsBases.empty())
822     endAssign();
823   if (lhs) {
824     if (bases.empty()) {
825       lhsBases.push_back(llvm::None);
826       return;
827     }
828     assert(bases.size() >= 1 && "must detect an array reference on lhs");
829     if (bases.size() > 1)
830       rhsBases.back().append(bases.begin(), bases.end() - 1);
831     lhsBases.push_back(bases.back());
832     return;
833   }
834   rhsBases.back().append(bases.begin(), bases.end());
835 }
836 
837 void Fortran::lower::ExplicitIterSpace::endAssign() { rhsBases.emplace_back(); }
838 
839 void Fortran::lower::ExplicitIterSpace::pushLevel() {
840   symbolStack.push_back(llvm::SmallVector<Fortran::lower::FrontEndSymbol>{});
841 }
842 
843 void Fortran::lower::ExplicitIterSpace::popLevel() { symbolStack.pop_back(); }
844 
845 void Fortran::lower::ExplicitIterSpace::conditionalCleanup() {
846   if (forallContextOpen == 0) {
847     // Exiting the outermost FORALL context.
848     // Cleanup any residual mask buffers.
849     outermostContext().finalize();
850     // Clear and reset all the cached information.
851     symbolStack.clear();
852     lhsBases.clear();
853     rhsBases.clear();
854     loadBindings.clear();
855     ccLoopNest.clear();
856     innerArgs.clear();
857     outerLoop = llvm::None;
858     clearLoops();
859     counter = 0;
860   }
861 }
862 
863 llvm::Optional<size_t>
864 Fortran::lower::ExplicitIterSpace::findArgPosition(fir::ArrayLoadOp load) {
865   if (lhsBases[counter].hasValue()) {
866     auto ld = loadBindings.find(lhsBases[counter].getValue());
867     llvm::Optional<size_t> optPos;
868     if (ld != loadBindings.end() && ld->second == load)
869       optPos = static_cast<size_t>(0u);
870     assert(optPos.hasValue() && "load does not correspond to lhs");
871     return optPos;
872   }
873   return llvm::None;
874 }
875 
876 llvm::SmallVector<Fortran::lower::FrontEndSymbol>
877 Fortran::lower::ExplicitIterSpace::collectAllSymbols() {
878   llvm::SmallVector<Fortran::lower::FrontEndSymbol> result;
879   for (llvm::SmallVector<FrontEndSymbol> vec : symbolStack)
880     result.append(vec.begin(), vec.end());
881   return result;
882 }
883 
884 llvm::raw_ostream &
885 Fortran::lower::operator<<(llvm::raw_ostream &s,
886                            const Fortran::lower::ImplicitIterSpace &e) {
887   for (const llvm::SmallVector<
888            Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr> &xs :
889        e.getMasks()) {
890     s << "{ ";
891     for (const Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr &x : xs)
892       x->AsFortran(s << '(') << "), ";
893     s << "}\n";
894   }
895   return s;
896 }
897 
898 llvm::raw_ostream &
899 Fortran::lower::operator<<(llvm::raw_ostream &s,
900                            const Fortran::lower::ExplicitIterSpace &e) {
901   auto dump = [&](const auto &u) {
902     std::visit(Fortran::common::visitors{
903                    [&](const Fortran::semantics::Symbol *y) {
904                      s << "  " << *y << '\n';
905                    },
906                    [&](const Fortran::evaluate::ArrayRef *y) {
907                      s << "  ";
908                      if (y->base().IsSymbol())
909                        s << y->base().GetFirstSymbol();
910                      else
911                        s << y->base().GetComponent().GetLastSymbol();
912                      s << '\n';
913                    },
914                    [&](const Fortran::evaluate::Component *y) {
915                      s << "  " << y->GetLastSymbol() << '\n';
916                    }},
917                u);
918   };
919   s << "LHS bases:\n";
920   for (const llvm::Optional<Fortran::lower::ExplicitIterSpace::ArrayBases> &u :
921        e.lhsBases)
922     if (u.hasValue())
923       dump(u.getValue());
924   s << "RHS bases:\n";
925   for (const llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases>
926            &bases : e.rhsBases) {
927     for (const Fortran::lower::ExplicitIterSpace::ArrayBases &u : bases)
928       dump(u);
929     s << '\n';
930   }
931   return s;
932 }
933 
934 void Fortran::lower::ImplicitIterSpace::dump() const {
935   llvm::errs() << *this << '\n';
936 }
937 
938 void Fortran::lower::ExplicitIterSpace::dump() const {
939   llvm::errs() << *this << '\n';
940 }
941