1 //===-- IterationSpace.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "flang/Lower/IterationSpace.h"
14 #include "flang/Evaluate/expression.h"
15 #include "flang/Lower/AbstractConverter.h"
16 #include "flang/Lower/Support/Utils.h"
17 #include "llvm/Support/Debug.h"
18
19 #define DEBUG_TYPE "flang-lower-iteration-space"
20
21 namespace {
22 // Fortran::evaluate::Expr are functional values organized like an AST. A
23 // Fortran::evaluate::Expr is meant to be moved and cloned. Using the front end
24 // tools can often cause copies and extra wrapper classes to be added to any
// Fortran::evaluate::Expr. These values should not be assumed or relied upon to
26 // have an *object* identity. They are deeply recursive, irregular structures
27 // built from a large number of classes which do not use inheritance and
28 // necessitate a large volume of boilerplate code as a result.
29 //
30 // Contrastingly, LLVM data structures make ubiquitous assumptions about an
31 // object's identity via pointers to the object. An object's location in memory
32 // is thus very often an identifying relation.
33
34 // This class defines a hash computation of a Fortran::evaluate::Expr tree value
35 // so it can be used with llvm::DenseMap. The Fortran::evaluate::Expr need not
36 // have the same address.
37 class HashEvaluateExpr {
38 public:
39 // A Se::Symbol is the only part of an Fortran::evaluate::Expr with an
40 // identity property.
getHashValue(const Fortran::semantics::Symbol & x)41 static unsigned getHashValue(const Fortran::semantics::Symbol &x) {
42 return static_cast<unsigned>(reinterpret_cast<std::intptr_t>(&x));
43 }
44 template <typename A, bool COPY>
getHashValue(const Fortran::common::Indirection<A,COPY> & x)45 static unsigned getHashValue(const Fortran::common::Indirection<A, COPY> &x) {
46 return getHashValue(x.value());
47 }
48 template <typename A>
getHashValue(const std::optional<A> & x)49 static unsigned getHashValue(const std::optional<A> &x) {
50 if (x.has_value())
51 return getHashValue(x.value());
52 return 0u;
53 }
getHashValue(const Fortran::evaluate::Subscript & x)54 static unsigned getHashValue(const Fortran::evaluate::Subscript &x) {
55 return std::visit([&](const auto &v) { return getHashValue(v); }, x.u);
56 }
getHashValue(const Fortran::evaluate::Triplet & x)57 static unsigned getHashValue(const Fortran::evaluate::Triplet &x) {
58 return getHashValue(x.lower()) - getHashValue(x.upper()) * 5u -
59 getHashValue(x.stride()) * 11u;
60 }
getHashValue(const Fortran::evaluate::Component & x)61 static unsigned getHashValue(const Fortran::evaluate::Component &x) {
62 return getHashValue(x.base()) * 83u - getHashValue(x.GetLastSymbol());
63 }
getHashValue(const Fortran::evaluate::ArrayRef & x)64 static unsigned getHashValue(const Fortran::evaluate::ArrayRef &x) {
65 unsigned subs = 1u;
66 for (const Fortran::evaluate::Subscript &v : x.subscript())
67 subs -= getHashValue(v);
68 return getHashValue(x.base()) * 89u - subs;
69 }
getHashValue(const Fortran::evaluate::CoarrayRef & x)70 static unsigned getHashValue(const Fortran::evaluate::CoarrayRef &x) {
71 unsigned subs = 1u;
72 for (const Fortran::evaluate::Subscript &v : x.subscript())
73 subs -= getHashValue(v);
74 unsigned cosubs = 3u;
75 for (const Fortran::evaluate::Expr<Fortran::evaluate::SubscriptInteger> &v :
76 x.cosubscript())
77 cosubs -= getHashValue(v);
78 unsigned syms = 7u;
79 for (const Fortran::evaluate::SymbolRef &v : x.base())
80 syms += getHashValue(v);
81 return syms * 97u - subs - cosubs + getHashValue(x.stat()) + 257u +
82 getHashValue(x.team());
83 }
getHashValue(const Fortran::evaluate::NamedEntity & x)84 static unsigned getHashValue(const Fortran::evaluate::NamedEntity &x) {
85 if (x.IsSymbol())
86 return getHashValue(x.GetFirstSymbol()) * 11u;
87 return getHashValue(x.GetComponent()) * 13u;
88 }
getHashValue(const Fortran::evaluate::DataRef & x)89 static unsigned getHashValue(const Fortran::evaluate::DataRef &x) {
90 return std::visit([&](const auto &v) { return getHashValue(v); }, x.u);
91 }
getHashValue(const Fortran::evaluate::ComplexPart & x)92 static unsigned getHashValue(const Fortran::evaluate::ComplexPart &x) {
93 return getHashValue(x.complex()) - static_cast<unsigned>(x.part());
94 }
95 template <Fortran::common::TypeCategory TC1, int KIND,
96 Fortran::common::TypeCategory TC2>
getHashValue(const Fortran::evaluate::Convert<Fortran::evaluate::Type<TC1,KIND>,TC2> & x)97 static unsigned getHashValue(
98 const Fortran::evaluate::Convert<Fortran::evaluate::Type<TC1, KIND>, TC2>
99 &x) {
100 return getHashValue(x.left()) - (static_cast<unsigned>(TC1) + 2u) -
101 (static_cast<unsigned>(KIND) + 5u);
102 }
103 template <int KIND>
104 static unsigned
getHashValue(const Fortran::evaluate::ComplexComponent<KIND> & x)105 getHashValue(const Fortran::evaluate::ComplexComponent<KIND> &x) {
106 return getHashValue(x.left()) -
107 (static_cast<unsigned>(x.isImaginaryPart) + 1u) * 3u;
108 }
109 template <typename T>
getHashValue(const Fortran::evaluate::Parentheses<T> & x)110 static unsigned getHashValue(const Fortran::evaluate::Parentheses<T> &x) {
111 return getHashValue(x.left()) * 17u;
112 }
113 template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Negate<Fortran::evaluate::Type<TC,KIND>> & x)114 static unsigned getHashValue(
115 const Fortran::evaluate::Negate<Fortran::evaluate::Type<TC, KIND>> &x) {
116 return getHashValue(x.left()) - (static_cast<unsigned>(TC) + 5u) -
117 (static_cast<unsigned>(KIND) + 7u);
118 }
119 template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Add<Fortran::evaluate::Type<TC,KIND>> & x)120 static unsigned getHashValue(
121 const Fortran::evaluate::Add<Fortran::evaluate::Type<TC, KIND>> &x) {
122 return (getHashValue(x.left()) + getHashValue(x.right())) * 23u +
123 static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
124 }
125 template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Subtract<Fortran::evaluate::Type<TC,KIND>> & x)126 static unsigned getHashValue(
127 const Fortran::evaluate::Subtract<Fortran::evaluate::Type<TC, KIND>> &x) {
128 return (getHashValue(x.left()) - getHashValue(x.right())) * 19u +
129 static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
130 }
131 template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Multiply<Fortran::evaluate::Type<TC,KIND>> & x)132 static unsigned getHashValue(
133 const Fortran::evaluate::Multiply<Fortran::evaluate::Type<TC, KIND>> &x) {
134 return (getHashValue(x.left()) + getHashValue(x.right())) * 29u +
135 static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
136 }
137 template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Divide<Fortran::evaluate::Type<TC,KIND>> & x)138 static unsigned getHashValue(
139 const Fortran::evaluate::Divide<Fortran::evaluate::Type<TC, KIND>> &x) {
140 return (getHashValue(x.left()) - getHashValue(x.right())) * 31u +
141 static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
142 }
143 template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Power<Fortran::evaluate::Type<TC,KIND>> & x)144 static unsigned getHashValue(
145 const Fortran::evaluate::Power<Fortran::evaluate::Type<TC, KIND>> &x) {
146 return (getHashValue(x.left()) - getHashValue(x.right())) * 37u +
147 static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
148 }
149 template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Extremum<Fortran::evaluate::Type<TC,KIND>> & x)150 static unsigned getHashValue(
151 const Fortran::evaluate::Extremum<Fortran::evaluate::Type<TC, KIND>> &x) {
152 return (getHashValue(x.left()) + getHashValue(x.right())) * 41u +
153 static_cast<unsigned>(TC) + static_cast<unsigned>(KIND) +
154 static_cast<unsigned>(x.ordering) * 7u;
155 }
156 template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::RealToIntPower<Fortran::evaluate::Type<TC,KIND>> & x)157 static unsigned getHashValue(
158 const Fortran::evaluate::RealToIntPower<Fortran::evaluate::Type<TC, KIND>>
159 &x) {
160 return (getHashValue(x.left()) - getHashValue(x.right())) * 43u +
161 static_cast<unsigned>(TC) + static_cast<unsigned>(KIND);
162 }
163 template <int KIND>
164 static unsigned
getHashValue(const Fortran::evaluate::ComplexConstructor<KIND> & x)165 getHashValue(const Fortran::evaluate::ComplexConstructor<KIND> &x) {
166 return (getHashValue(x.left()) - getHashValue(x.right())) * 47u +
167 static_cast<unsigned>(KIND);
168 }
169 template <int KIND>
getHashValue(const Fortran::evaluate::Concat<KIND> & x)170 static unsigned getHashValue(const Fortran::evaluate::Concat<KIND> &x) {
171 return (getHashValue(x.left()) - getHashValue(x.right())) * 53u +
172 static_cast<unsigned>(KIND);
173 }
174 template <int KIND>
getHashValue(const Fortran::evaluate::SetLength<KIND> & x)175 static unsigned getHashValue(const Fortran::evaluate::SetLength<KIND> &x) {
176 return (getHashValue(x.left()) - getHashValue(x.right())) * 59u +
177 static_cast<unsigned>(KIND);
178 }
getHashValue(const Fortran::semantics::SymbolRef & sym)179 static unsigned getHashValue(const Fortran::semantics::SymbolRef &sym) {
180 return getHashValue(sym.get());
181 }
getHashValue(const Fortran::evaluate::Substring & x)182 static unsigned getHashValue(const Fortran::evaluate::Substring &x) {
183 return 61u * std::visit([&](const auto &p) { return getHashValue(p); },
184 x.parent()) -
185 getHashValue(x.lower()) - (getHashValue(x.lower()) + 1u);
186 }
187 static unsigned
getHashValue(const Fortran::evaluate::StaticDataObject::Pointer & x)188 getHashValue(const Fortran::evaluate::StaticDataObject::Pointer &x) {
189 return llvm::hash_value(x->name());
190 }
getHashValue(const Fortran::evaluate::SpecificIntrinsic & x)191 static unsigned getHashValue(const Fortran::evaluate::SpecificIntrinsic &x) {
192 return llvm::hash_value(x.name);
193 }
194 template <typename A>
getHashValue(const Fortran::evaluate::Constant<A> & x)195 static unsigned getHashValue(const Fortran::evaluate::Constant<A> &x) {
196 // FIXME: Should hash the content.
197 return 103u;
198 }
getHashValue(const Fortran::evaluate::ActualArgument & x)199 static unsigned getHashValue(const Fortran::evaluate::ActualArgument &x) {
200 if (const Fortran::evaluate::Symbol *sym = x.GetAssumedTypeDummy())
201 return getHashValue(*sym);
202 return getHashValue(*x.UnwrapExpr());
203 }
204 static unsigned
getHashValue(const Fortran::evaluate::ProcedureDesignator & x)205 getHashValue(const Fortran::evaluate::ProcedureDesignator &x) {
206 return std::visit([&](const auto &v) { return getHashValue(v); }, x.u);
207 }
getHashValue(const Fortran::evaluate::ProcedureRef & x)208 static unsigned getHashValue(const Fortran::evaluate::ProcedureRef &x) {
209 unsigned args = 13u;
210 for (const std::optional<Fortran::evaluate::ActualArgument> &v :
211 x.arguments())
212 args -= getHashValue(v);
213 return getHashValue(x.proc()) * 101u - args;
214 }
215 template <typename A>
216 static unsigned
getHashValue(const Fortran::evaluate::ArrayConstructor<A> & x)217 getHashValue(const Fortran::evaluate::ArrayConstructor<A> &x) {
218 // FIXME: hash the contents.
219 return 127u;
220 }
getHashValue(const Fortran::evaluate::ImpliedDoIndex & x)221 static unsigned getHashValue(const Fortran::evaluate::ImpliedDoIndex &x) {
222 return llvm::hash_value(toStringRef(x.name).str()) * 131u;
223 }
getHashValue(const Fortran::evaluate::TypeParamInquiry & x)224 static unsigned getHashValue(const Fortran::evaluate::TypeParamInquiry &x) {
225 return getHashValue(x.base()) * 137u - getHashValue(x.parameter()) * 3u;
226 }
getHashValue(const Fortran::evaluate::DescriptorInquiry & x)227 static unsigned getHashValue(const Fortran::evaluate::DescriptorInquiry &x) {
228 return getHashValue(x.base()) * 139u -
229 static_cast<unsigned>(x.field()) * 13u +
230 static_cast<unsigned>(x.dimension());
231 }
232 static unsigned
getHashValue(const Fortran::evaluate::StructureConstructor & x)233 getHashValue(const Fortran::evaluate::StructureConstructor &x) {
234 // FIXME: hash the contents.
235 return 149u;
236 }
237 template <int KIND>
getHashValue(const Fortran::evaluate::Not<KIND> & x)238 static unsigned getHashValue(const Fortran::evaluate::Not<KIND> &x) {
239 return getHashValue(x.left()) * 61u + static_cast<unsigned>(KIND);
240 }
241 template <int KIND>
242 static unsigned
getHashValue(const Fortran::evaluate::LogicalOperation<KIND> & x)243 getHashValue(const Fortran::evaluate::LogicalOperation<KIND> &x) {
244 unsigned result = getHashValue(x.left()) + getHashValue(x.right());
245 return result * 67u + static_cast<unsigned>(x.logicalOperator) * 5u;
246 }
247 template <Fortran::common::TypeCategory TC, int KIND>
getHashValue(const Fortran::evaluate::Relational<Fortran::evaluate::Type<TC,KIND>> & x)248 static unsigned getHashValue(
249 const Fortran::evaluate::Relational<Fortran::evaluate::Type<TC, KIND>>
250 &x) {
251 return (getHashValue(x.left()) + getHashValue(x.right())) * 71u +
252 static_cast<unsigned>(TC) + static_cast<unsigned>(KIND) +
253 static_cast<unsigned>(x.opr) * 11u;
254 }
255 template <typename A>
getHashValue(const Fortran::evaluate::Expr<A> & x)256 static unsigned getHashValue(const Fortran::evaluate::Expr<A> &x) {
257 return std::visit([&](const auto &v) { return getHashValue(v); }, x.u);
258 }
getHashValue(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> & x)259 static unsigned getHashValue(
260 const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x) {
261 return std::visit([&](const auto &v) { return getHashValue(v); }, x.u);
262 }
263 template <typename A>
getHashValue(const Fortran::evaluate::Designator<A> & x)264 static unsigned getHashValue(const Fortran::evaluate::Designator<A> &x) {
265 return std::visit([&](const auto &v) { return getHashValue(v); }, x.u);
266 }
267 template <int BITS>
268 static unsigned
getHashValue(const Fortran::evaluate::value::Integer<BITS> & x)269 getHashValue(const Fortran::evaluate::value::Integer<BITS> &x) {
270 return static_cast<unsigned>(x.ToSInt());
271 }
getHashValue(const Fortran::evaluate::NullPointer & x)272 static unsigned getHashValue(const Fortran::evaluate::NullPointer &x) {
273 return ~179u;
274 }
275 };
276 } // namespace
277
getHashValue(const Fortran::lower::ExplicitIterSpace::ArrayBases & x)278 unsigned Fortran::lower::getHashValue(
279 const Fortran::lower::ExplicitIterSpace::ArrayBases &x) {
280 return std::visit(
281 [&](const auto *p) { return HashEvaluateExpr::getHashValue(*p); }, x);
282 }
283
getHashValue(Fortran::lower::FrontEndExpr x)284 unsigned Fortran::lower::getHashValue(Fortran::lower::FrontEndExpr x) {
285 return HashEvaluateExpr::getHashValue(*x);
286 }
287
288 namespace {
// Define the equality test that allows Fortran::evaluate::Expr values to be
// used as keys with llvm::DenseMap.
291 class IsEqualEvaluateExpr {
292 public:
293 // A Se::Symbol is the only part of an Fortran::evaluate::Expr with an
294 // identity property.
isEqual(const Fortran::semantics::Symbol & x,const Fortran::semantics::Symbol & y)295 static bool isEqual(const Fortran::semantics::Symbol &x,
296 const Fortran::semantics::Symbol &y) {
297 return isEqual(&x, &y);
298 }
isEqual(const Fortran::semantics::Symbol * x,const Fortran::semantics::Symbol * y)299 static bool isEqual(const Fortran::semantics::Symbol *x,
300 const Fortran::semantics::Symbol *y) {
301 return x == y;
302 }
303 template <typename A, bool COPY>
isEqual(const Fortran::common::Indirection<A,COPY> & x,const Fortran::common::Indirection<A,COPY> & y)304 static bool isEqual(const Fortran::common::Indirection<A, COPY> &x,
305 const Fortran::common::Indirection<A, COPY> &y) {
306 return isEqual(x.value(), y.value());
307 }
308 template <typename A>
isEqual(const std::optional<A> & x,const std::optional<A> & y)309 static bool isEqual(const std::optional<A> &x, const std::optional<A> &y) {
310 if (x.has_value() && y.has_value())
311 return isEqual(x.value(), y.value());
312 return !x.has_value() && !y.has_value();
313 }
314 template <typename A>
isEqual(const std::vector<A> & x,const std::vector<A> & y)315 static bool isEqual(const std::vector<A> &x, const std::vector<A> &y) {
316 if (x.size() != y.size())
317 return false;
318 const std::size_t size = x.size();
319 for (std::remove_const_t<decltype(size)> i = 0; i < size; ++i)
320 if (!isEqual(x[i], y[i]))
321 return false;
322 return true;
323 }
isEqual(const Fortran::evaluate::Subscript & x,const Fortran::evaluate::Subscript & y)324 static bool isEqual(const Fortran::evaluate::Subscript &x,
325 const Fortran::evaluate::Subscript &y) {
326 return std::visit(
327 [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
328 }
isEqual(const Fortran::evaluate::Triplet & x,const Fortran::evaluate::Triplet & y)329 static bool isEqual(const Fortran::evaluate::Triplet &x,
330 const Fortran::evaluate::Triplet &y) {
331 return isEqual(x.lower(), y.lower()) && isEqual(x.upper(), y.upper()) &&
332 isEqual(x.stride(), y.stride());
333 }
isEqual(const Fortran::evaluate::Component & x,const Fortran::evaluate::Component & y)334 static bool isEqual(const Fortran::evaluate::Component &x,
335 const Fortran::evaluate::Component &y) {
336 return isEqual(x.base(), y.base()) &&
337 isEqual(x.GetLastSymbol(), y.GetLastSymbol());
338 }
isEqual(const Fortran::evaluate::ArrayRef & x,const Fortran::evaluate::ArrayRef & y)339 static bool isEqual(const Fortran::evaluate::ArrayRef &x,
340 const Fortran::evaluate::ArrayRef &y) {
341 return isEqual(x.base(), y.base()) && isEqual(x.subscript(), y.subscript());
342 }
isEqual(const Fortran::evaluate::CoarrayRef & x,const Fortran::evaluate::CoarrayRef & y)343 static bool isEqual(const Fortran::evaluate::CoarrayRef &x,
344 const Fortran::evaluate::CoarrayRef &y) {
345 return isEqual(x.base(), y.base()) &&
346 isEqual(x.subscript(), y.subscript()) &&
347 isEqual(x.cosubscript(), y.cosubscript()) &&
348 isEqual(x.stat(), y.stat()) && isEqual(x.team(), y.team());
349 }
isEqual(const Fortran::evaluate::NamedEntity & x,const Fortran::evaluate::NamedEntity & y)350 static bool isEqual(const Fortran::evaluate::NamedEntity &x,
351 const Fortran::evaluate::NamedEntity &y) {
352 if (x.IsSymbol() && y.IsSymbol())
353 return isEqual(x.GetFirstSymbol(), y.GetFirstSymbol());
354 return !x.IsSymbol() && !y.IsSymbol() &&
355 isEqual(x.GetComponent(), y.GetComponent());
356 }
isEqual(const Fortran::evaluate::DataRef & x,const Fortran::evaluate::DataRef & y)357 static bool isEqual(const Fortran::evaluate::DataRef &x,
358 const Fortran::evaluate::DataRef &y) {
359 return std::visit(
360 [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
361 }
isEqual(const Fortran::evaluate::ComplexPart & x,const Fortran::evaluate::ComplexPart & y)362 static bool isEqual(const Fortran::evaluate::ComplexPart &x,
363 const Fortran::evaluate::ComplexPart &y) {
364 return isEqual(x.complex(), y.complex()) && x.part() == y.part();
365 }
366 template <typename A, Fortran::common::TypeCategory TC2>
isEqual(const Fortran::evaluate::Convert<A,TC2> & x,const Fortran::evaluate::Convert<A,TC2> & y)367 static bool isEqual(const Fortran::evaluate::Convert<A, TC2> &x,
368 const Fortran::evaluate::Convert<A, TC2> &y) {
369 return isEqual(x.left(), y.left());
370 }
371 template <int KIND>
isEqual(const Fortran::evaluate::ComplexComponent<KIND> & x,const Fortran::evaluate::ComplexComponent<KIND> & y)372 static bool isEqual(const Fortran::evaluate::ComplexComponent<KIND> &x,
373 const Fortran::evaluate::ComplexComponent<KIND> &y) {
374 return isEqual(x.left(), y.left()) &&
375 x.isImaginaryPart == y.isImaginaryPart;
376 }
377 template <typename T>
isEqual(const Fortran::evaluate::Parentheses<T> & x,const Fortran::evaluate::Parentheses<T> & y)378 static bool isEqual(const Fortran::evaluate::Parentheses<T> &x,
379 const Fortran::evaluate::Parentheses<T> &y) {
380 return isEqual(x.left(), y.left());
381 }
382 template <typename A>
isEqual(const Fortran::evaluate::Negate<A> & x,const Fortran::evaluate::Negate<A> & y)383 static bool isEqual(const Fortran::evaluate::Negate<A> &x,
384 const Fortran::evaluate::Negate<A> &y) {
385 return isEqual(x.left(), y.left());
386 }
387 template <typename A>
isBinaryEqual(const A & x,const A & y)388 static bool isBinaryEqual(const A &x, const A &y) {
389 return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right());
390 }
391 template <typename A>
isEqual(const Fortran::evaluate::Add<A> & x,const Fortran::evaluate::Add<A> & y)392 static bool isEqual(const Fortran::evaluate::Add<A> &x,
393 const Fortran::evaluate::Add<A> &y) {
394 return isBinaryEqual(x, y);
395 }
396 template <typename A>
isEqual(const Fortran::evaluate::Subtract<A> & x,const Fortran::evaluate::Subtract<A> & y)397 static bool isEqual(const Fortran::evaluate::Subtract<A> &x,
398 const Fortran::evaluate::Subtract<A> &y) {
399 return isBinaryEqual(x, y);
400 }
401 template <typename A>
isEqual(const Fortran::evaluate::Multiply<A> & x,const Fortran::evaluate::Multiply<A> & y)402 static bool isEqual(const Fortran::evaluate::Multiply<A> &x,
403 const Fortran::evaluate::Multiply<A> &y) {
404 return isBinaryEqual(x, y);
405 }
406 template <typename A>
isEqual(const Fortran::evaluate::Divide<A> & x,const Fortran::evaluate::Divide<A> & y)407 static bool isEqual(const Fortran::evaluate::Divide<A> &x,
408 const Fortran::evaluate::Divide<A> &y) {
409 return isBinaryEqual(x, y);
410 }
411 template <typename A>
isEqual(const Fortran::evaluate::Power<A> & x,const Fortran::evaluate::Power<A> & y)412 static bool isEqual(const Fortran::evaluate::Power<A> &x,
413 const Fortran::evaluate::Power<A> &y) {
414 return isBinaryEqual(x, y);
415 }
416 template <typename A>
isEqual(const Fortran::evaluate::Extremum<A> & x,const Fortran::evaluate::Extremum<A> & y)417 static bool isEqual(const Fortran::evaluate::Extremum<A> &x,
418 const Fortran::evaluate::Extremum<A> &y) {
419 return isBinaryEqual(x, y);
420 }
421 template <typename A>
isEqual(const Fortran::evaluate::RealToIntPower<A> & x,const Fortran::evaluate::RealToIntPower<A> & y)422 static bool isEqual(const Fortran::evaluate::RealToIntPower<A> &x,
423 const Fortran::evaluate::RealToIntPower<A> &y) {
424 return isBinaryEqual(x, y);
425 }
426 template <int KIND>
isEqual(const Fortran::evaluate::ComplexConstructor<KIND> & x,const Fortran::evaluate::ComplexConstructor<KIND> & y)427 static bool isEqual(const Fortran::evaluate::ComplexConstructor<KIND> &x,
428 const Fortran::evaluate::ComplexConstructor<KIND> &y) {
429 return isBinaryEqual(x, y);
430 }
431 template <int KIND>
isEqual(const Fortran::evaluate::Concat<KIND> & x,const Fortran::evaluate::Concat<KIND> & y)432 static bool isEqual(const Fortran::evaluate::Concat<KIND> &x,
433 const Fortran::evaluate::Concat<KIND> &y) {
434 return isBinaryEqual(x, y);
435 }
436 template <int KIND>
isEqual(const Fortran::evaluate::SetLength<KIND> & x,const Fortran::evaluate::SetLength<KIND> & y)437 static bool isEqual(const Fortran::evaluate::SetLength<KIND> &x,
438 const Fortran::evaluate::SetLength<KIND> &y) {
439 return isBinaryEqual(x, y);
440 }
isEqual(const Fortran::semantics::SymbolRef & x,const Fortran::semantics::SymbolRef & y)441 static bool isEqual(const Fortran::semantics::SymbolRef &x,
442 const Fortran::semantics::SymbolRef &y) {
443 return isEqual(x.get(), y.get());
444 }
isEqual(const Fortran::evaluate::Substring & x,const Fortran::evaluate::Substring & y)445 static bool isEqual(const Fortran::evaluate::Substring &x,
446 const Fortran::evaluate::Substring &y) {
447 return std::visit(
448 [&](const auto &p, const auto &q) { return isEqual(p, q); },
449 x.parent(), y.parent()) &&
450 isEqual(x.lower(), y.lower()) && isEqual(x.lower(), y.lower());
451 }
isEqual(const Fortran::evaluate::StaticDataObject::Pointer & x,const Fortran::evaluate::StaticDataObject::Pointer & y)452 static bool isEqual(const Fortran::evaluate::StaticDataObject::Pointer &x,
453 const Fortran::evaluate::StaticDataObject::Pointer &y) {
454 return x->name() == y->name();
455 }
isEqual(const Fortran::evaluate::SpecificIntrinsic & x,const Fortran::evaluate::SpecificIntrinsic & y)456 static bool isEqual(const Fortran::evaluate::SpecificIntrinsic &x,
457 const Fortran::evaluate::SpecificIntrinsic &y) {
458 return x.name == y.name;
459 }
460 template <typename A>
isEqual(const Fortran::evaluate::Constant<A> & x,const Fortran::evaluate::Constant<A> & y)461 static bool isEqual(const Fortran::evaluate::Constant<A> &x,
462 const Fortran::evaluate::Constant<A> &y) {
463 return x == y;
464 }
isEqual(const Fortran::evaluate::ActualArgument & x,const Fortran::evaluate::ActualArgument & y)465 static bool isEqual(const Fortran::evaluate::ActualArgument &x,
466 const Fortran::evaluate::ActualArgument &y) {
467 if (const Fortran::evaluate::Symbol *xs = x.GetAssumedTypeDummy()) {
468 if (const Fortran::evaluate::Symbol *ys = y.GetAssumedTypeDummy())
469 return isEqual(*xs, *ys);
470 return false;
471 }
472 return !y.GetAssumedTypeDummy() &&
473 isEqual(*x.UnwrapExpr(), *y.UnwrapExpr());
474 }
isEqual(const Fortran::evaluate::ProcedureDesignator & x,const Fortran::evaluate::ProcedureDesignator & y)475 static bool isEqual(const Fortran::evaluate::ProcedureDesignator &x,
476 const Fortran::evaluate::ProcedureDesignator &y) {
477 return std::visit(
478 [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
479 }
isEqual(const Fortran::evaluate::ProcedureRef & x,const Fortran::evaluate::ProcedureRef & y)480 static bool isEqual(const Fortran::evaluate::ProcedureRef &x,
481 const Fortran::evaluate::ProcedureRef &y) {
482 return isEqual(x.proc(), y.proc()) && isEqual(x.arguments(), y.arguments());
483 }
484 template <typename A>
isEqual(const Fortran::evaluate::ArrayConstructor<A> & x,const Fortran::evaluate::ArrayConstructor<A> & y)485 static bool isEqual(const Fortran::evaluate::ArrayConstructor<A> &x,
486 const Fortran::evaluate::ArrayConstructor<A> &y) {
487 llvm::report_fatal_error("not implemented");
488 }
isEqual(const Fortran::evaluate::ImpliedDoIndex & x,const Fortran::evaluate::ImpliedDoIndex & y)489 static bool isEqual(const Fortran::evaluate::ImpliedDoIndex &x,
490 const Fortran::evaluate::ImpliedDoIndex &y) {
491 return toStringRef(x.name) == toStringRef(y.name);
492 }
isEqual(const Fortran::evaluate::TypeParamInquiry & x,const Fortran::evaluate::TypeParamInquiry & y)493 static bool isEqual(const Fortran::evaluate::TypeParamInquiry &x,
494 const Fortran::evaluate::TypeParamInquiry &y) {
495 return isEqual(x.base(), y.base()) && isEqual(x.parameter(), y.parameter());
496 }
isEqual(const Fortran::evaluate::DescriptorInquiry & x,const Fortran::evaluate::DescriptorInquiry & y)497 static bool isEqual(const Fortran::evaluate::DescriptorInquiry &x,
498 const Fortran::evaluate::DescriptorInquiry &y) {
499 return isEqual(x.base(), y.base()) && x.field() == y.field() &&
500 x.dimension() == y.dimension();
501 }
isEqual(const Fortran::evaluate::StructureConstructor & x,const Fortran::evaluate::StructureConstructor & y)502 static bool isEqual(const Fortran::evaluate::StructureConstructor &x,
503 const Fortran::evaluate::StructureConstructor &y) {
504 llvm::report_fatal_error("not implemented");
505 }
506 template <int KIND>
isEqual(const Fortran::evaluate::Not<KIND> & x,const Fortran::evaluate::Not<KIND> & y)507 static bool isEqual(const Fortran::evaluate::Not<KIND> &x,
508 const Fortran::evaluate::Not<KIND> &y) {
509 return isEqual(x.left(), y.left());
510 }
511 template <int KIND>
isEqual(const Fortran::evaluate::LogicalOperation<KIND> & x,const Fortran::evaluate::LogicalOperation<KIND> & y)512 static bool isEqual(const Fortran::evaluate::LogicalOperation<KIND> &x,
513 const Fortran::evaluate::LogicalOperation<KIND> &y) {
514 return isEqual(x.left(), y.left()) && isEqual(x.right(), x.right());
515 }
516 template <typename A>
isEqual(const Fortran::evaluate::Relational<A> & x,const Fortran::evaluate::Relational<A> & y)517 static bool isEqual(const Fortran::evaluate::Relational<A> &x,
518 const Fortran::evaluate::Relational<A> &y) {
519 return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right());
520 }
521 template <typename A>
isEqual(const Fortran::evaluate::Expr<A> & x,const Fortran::evaluate::Expr<A> & y)522 static bool isEqual(const Fortran::evaluate::Expr<A> &x,
523 const Fortran::evaluate::Expr<A> &y) {
524 return std::visit(
525 [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
526 }
527 static bool
isEqual(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> & x,const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> & y)528 isEqual(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x,
529 const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &y) {
530 return std::visit(
531 [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
532 }
533 template <typename A>
isEqual(const Fortran::evaluate::Designator<A> & x,const Fortran::evaluate::Designator<A> & y)534 static bool isEqual(const Fortran::evaluate::Designator<A> &x,
535 const Fortran::evaluate::Designator<A> &y) {
536 return std::visit(
537 [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u);
538 }
539 template <int BITS>
isEqual(const Fortran::evaluate::value::Integer<BITS> & x,const Fortran::evaluate::value::Integer<BITS> & y)540 static bool isEqual(const Fortran::evaluate::value::Integer<BITS> &x,
541 const Fortran::evaluate::value::Integer<BITS> &y) {
542 return x == y;
543 }
isEqual(const Fortran::evaluate::NullPointer & x,const Fortran::evaluate::NullPointer & y)544 static bool isEqual(const Fortran::evaluate::NullPointer &x,
545 const Fortran::evaluate::NullPointer &y) {
546 return true;
547 }
548 template <typename A, typename B,
549 std::enable_if_t<!std::is_same_v<A, B>, bool> = true>
isEqual(const A &,const B &)550 static bool isEqual(const A &, const B &) {
551 return false;
552 }
553 };
554 } // namespace
555
isEqual(const Fortran::lower::ExplicitIterSpace::ArrayBases & x,const Fortran::lower::ExplicitIterSpace::ArrayBases & y)556 bool Fortran::lower::isEqual(
557 const Fortran::lower::ExplicitIterSpace::ArrayBases &x,
558 const Fortran::lower::ExplicitIterSpace::ArrayBases &y) {
559 return std::visit(
560 Fortran::common::visitors{
561 // Fortran::semantics::Symbol * are the exception here. These pointers
562 // have identity; if two Symbol * values are the same (different) then
563 // they are the same (different) logical symbol.
564 [&](Fortran::lower::FrontEndSymbol p,
565 Fortran::lower::FrontEndSymbol q) { return p == q; },
566 [&](const auto *p, const auto *q) {
567 if constexpr (std::is_same_v<decltype(p), decltype(q)>) {
568 LLVM_DEBUG(llvm::dbgs()
569 << "is equal: " << p << ' ' << q << ' '
570 << IsEqualEvaluateExpr::isEqual(*p, *q) << '\n');
571 return IsEqualEvaluateExpr::isEqual(*p, *q);
572 } else {
573 // Different subtree types are never equal.
574 return false;
575 }
576 }},
577 x, y);
578 }
579
isEqual(Fortran::lower::FrontEndExpr x,Fortran::lower::FrontEndExpr y)580 bool Fortran::lower::isEqual(Fortran::lower::FrontEndExpr x,
581 Fortran::lower::FrontEndExpr y) {
582 auto empty = llvm::DenseMapInfo<Fortran::lower::FrontEndExpr>::getEmptyKey();
583 auto tombstone =
584 llvm::DenseMapInfo<Fortran::lower::FrontEndExpr>::getTombstoneKey();
585 if (x == empty || y == empty || x == tombstone || y == tombstone)
586 return x == y;
587 return x == y || IsEqualEvaluateExpr::isEqual(*x, *y);
588 }
589
590 namespace {
591
592 /// This class can recover the base array in an expression that contains
593 /// explicit iteration space symbols. Most of the class can be ignored as it is
594 /// boilerplate Fortran::evaluate::Expr traversal.
595 class ArrayBaseFinder {
596 public:
597 using RT = bool;
598
ArrayBaseFinder(llvm::ArrayRef<Fortran::lower::FrontEndSymbol> syms)599 ArrayBaseFinder(llvm::ArrayRef<Fortran::lower::FrontEndSymbol> syms)
600 : controlVars(syms.begin(), syms.end()) {}
601
602 template <typename T>
operator ()(const T & x)603 void operator()(const T &x) {
604 (void)find(x);
605 }
606
607 /// Get the list of bases.
608 llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases>
getBases() const609 getBases() const {
610 LLVM_DEBUG(llvm::dbgs()
611 << "number of array bases found: " << bases.size() << '\n');
612 return bases;
613 }
614
615 private:
616 // First, the cases that are of interest.
find(const Fortran::semantics::Symbol & symbol)617 RT find(const Fortran::semantics::Symbol &symbol) {
618 if (symbol.Rank() > 0) {
619 bases.push_back(&symbol);
620 return true;
621 }
622 return {};
623 }
find(const Fortran::evaluate::Component & x)624 RT find(const Fortran::evaluate::Component &x) {
625 auto found = find(x.base());
626 if (!found && x.base().Rank() == 0 && x.Rank() > 0) {
627 bases.push_back(&x);
628 return true;
629 }
630 return found;
631 }
find(const Fortran::evaluate::ArrayRef & x)632 RT find(const Fortran::evaluate::ArrayRef &x) {
633 for (const auto &sub : x.subscript())
634 (void)find(sub);
635 if (x.base().IsSymbol()) {
636 if (x.Rank() > 0 || intersection(x.subscript())) {
637 bases.push_back(&x);
638 return true;
639 }
640 return {};
641 }
642 auto found = find(x.base());
643 if (!found && ((x.base().Rank() == 0 && x.Rank() > 0) ||
644 intersection(x.subscript()))) {
645 bases.push_back(&x);
646 return true;
647 }
648 return found;
649 }
find(const Fortran::evaluate::Triplet & x)650 RT find(const Fortran::evaluate::Triplet &x) {
651 if (const auto *lower = x.GetLower())
652 (void)find(*lower);
653 if (const auto *upper = x.GetUpper())
654 (void)find(*upper);
655 return find(x.GetStride());
656 }
find(const Fortran::evaluate::IndirectSubscriptIntegerExpr & x)657 RT find(const Fortran::evaluate::IndirectSubscriptIntegerExpr &x) {
658 return find(x.value());
659 }
find(const Fortran::evaluate::Subscript & x)660 RT find(const Fortran::evaluate::Subscript &x) { return find(x.u); }
find(const Fortran::evaluate::DataRef & x)661 RT find(const Fortran::evaluate::DataRef &x) { return find(x.u); }
find(const Fortran::evaluate::CoarrayRef & x)662 RT find(const Fortran::evaluate::CoarrayRef &x) {
663 assert(false && "coarray reference");
664 return {};
665 }
666
667 template <typename A>
intersection(const A & subscripts)668 bool intersection(const A &subscripts) {
669 return Fortran::lower::symbolsIntersectSubscripts(controlVars, subscripts);
670 }
671
672 // The rest is traversal boilerplate and can be ignored.
find(const Fortran::evaluate::Substring & x)673 RT find(const Fortran::evaluate::Substring &x) { return find(x.parent()); }
674 template <typename A>
find(const Fortran::semantics::SymbolRef x)675 RT find(const Fortran::semantics::SymbolRef x) {
676 return find(*x);
677 }
find(const Fortran::evaluate::NamedEntity & x)678 RT find(const Fortran::evaluate::NamedEntity &x) {
679 if (x.IsSymbol())
680 return find(x.GetFirstSymbol());
681 return find(x.GetComponent());
682 }
683
684 template <typename A, bool C>
find(const Fortran::common::Indirection<A,C> & x)685 RT find(const Fortran::common::Indirection<A, C> &x) {
686 return find(x.value());
687 }
688 template <typename A>
find(const std::unique_ptr<A> & x)689 RT find(const std::unique_ptr<A> &x) {
690 return find(x.get());
691 }
692 template <typename A>
find(const std::shared_ptr<A> & x)693 RT find(const std::shared_ptr<A> &x) {
694 return find(x.get());
695 }
696 template <typename A>
find(const A * x)697 RT find(const A *x) {
698 if (x)
699 return find(*x);
700 return {};
701 }
702 template <typename A>
find(const std::optional<A> & x)703 RT find(const std::optional<A> &x) {
704 if (x)
705 return find(*x);
706 return {};
707 }
708 template <typename... A>
find(const std::variant<A...> & u)709 RT find(const std::variant<A...> &u) {
710 return std::visit([&](const auto &v) { return find(v); }, u);
711 }
712 template <typename A>
find(const std::vector<A> & x)713 RT find(const std::vector<A> &x) {
714 for (auto &v : x)
715 (void)find(v);
716 return {};
717 }
find(const Fortran::evaluate::BOZLiteralConstant &)718 RT find(const Fortran::evaluate::BOZLiteralConstant &) { return {}; }
find(const Fortran::evaluate::NullPointer &)719 RT find(const Fortran::evaluate::NullPointer &) { return {}; }
720 template <typename T>
find(const Fortran::evaluate::Constant<T> & x)721 RT find(const Fortran::evaluate::Constant<T> &x) {
722 return {};
723 }
find(const Fortran::evaluate::StaticDataObject &)724 RT find(const Fortran::evaluate::StaticDataObject &) { return {}; }
find(const Fortran::evaluate::ImpliedDoIndex &)725 RT find(const Fortran::evaluate::ImpliedDoIndex &) { return {}; }
find(const Fortran::evaluate::BaseObject & x)726 RT find(const Fortran::evaluate::BaseObject &x) {
727 (void)find(x.u);
728 return {};
729 }
find(const Fortran::evaluate::TypeParamInquiry &)730 RT find(const Fortran::evaluate::TypeParamInquiry &) { return {}; }
find(const Fortran::evaluate::ComplexPart & x)731 RT find(const Fortran::evaluate::ComplexPart &x) { return {}; }
732 template <typename T>
find(const Fortran::evaluate::Designator<T> & x)733 RT find(const Fortran::evaluate::Designator<T> &x) {
734 return find(x.u);
735 }
736 template <typename T>
find(const Fortran::evaluate::Variable<T> & x)737 RT find(const Fortran::evaluate::Variable<T> &x) {
738 return find(x.u);
739 }
find(const Fortran::evaluate::DescriptorInquiry &)740 RT find(const Fortran::evaluate::DescriptorInquiry &) { return {}; }
find(const Fortran::evaluate::SpecificIntrinsic &)741 RT find(const Fortran::evaluate::SpecificIntrinsic &) { return {}; }
find(const Fortran::evaluate::ProcedureDesignator & x)742 RT find(const Fortran::evaluate::ProcedureDesignator &x) { return {}; }
find(const Fortran::evaluate::ProcedureRef & x)743 RT find(const Fortran::evaluate::ProcedureRef &x) {
744 (void)find(x.proc());
745 if (x.IsElemental())
746 (void)find(x.arguments());
747 return {};
748 }
find(const Fortran::evaluate::ActualArgument & x)749 RT find(const Fortran::evaluate::ActualArgument &x) {
750 if (const auto *sym = x.GetAssumedTypeDummy())
751 (void)find(*sym);
752 else
753 (void)find(x.UnwrapExpr());
754 return {};
755 }
756 template <typename T>
find(const Fortran::evaluate::FunctionRef<T> & x)757 RT find(const Fortran::evaluate::FunctionRef<T> &x) {
758 (void)find(static_cast<const Fortran::evaluate::ProcedureRef &>(x));
759 return {};
760 }
761 template <typename T>
find(const Fortran::evaluate::ArrayConstructorValue<T> &)762 RT find(const Fortran::evaluate::ArrayConstructorValue<T> &) {
763 return {};
764 }
765 template <typename T>
find(const Fortran::evaluate::ArrayConstructorValues<T> &)766 RT find(const Fortran::evaluate::ArrayConstructorValues<T> &) {
767 return {};
768 }
769 template <typename T>
find(const Fortran::evaluate::ImpliedDo<T> &)770 RT find(const Fortran::evaluate::ImpliedDo<T> &) {
771 return {};
772 }
find(const Fortran::semantics::ParamValue &)773 RT find(const Fortran::semantics::ParamValue &) { return {}; }
find(const Fortran::semantics::DerivedTypeSpec &)774 RT find(const Fortran::semantics::DerivedTypeSpec &) { return {}; }
find(const Fortran::evaluate::StructureConstructor &)775 RT find(const Fortran::evaluate::StructureConstructor &) { return {}; }
776 template <typename D, typename R, typename O>
find(const Fortran::evaluate::Operation<D,R,O> & op)777 RT find(const Fortran::evaluate::Operation<D, R, O> &op) {
778 (void)find(op.left());
779 return false;
780 }
781 template <typename D, typename R, typename LO, typename RO>
find(const Fortran::evaluate::Operation<D,R,LO,RO> & op)782 RT find(const Fortran::evaluate::Operation<D, R, LO, RO> &op) {
783 (void)find(op.left());
784 (void)find(op.right());
785 return false;
786 }
find(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> & x)787 RT find(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x) {
788 (void)find(x.u);
789 return {};
790 }
791 template <typename T>
find(const Fortran::evaluate::Expr<T> & x)792 RT find(const Fortran::evaluate::Expr<T> &x) {
793 (void)find(x.u);
794 return {};
795 }
796
797 llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases> bases;
798 llvm::SmallVector<Fortran::lower::FrontEndSymbol> controlVars;
799 };
800
801 } // namespace
802
leave()803 void Fortran::lower::ExplicitIterSpace::leave() {
804 ccLoopNest.pop_back();
805 --forallContextOpen;
806 conditionalCleanup();
807 }
808
addSymbol(Fortran::lower::FrontEndSymbol sym)809 void Fortran::lower::ExplicitIterSpace::addSymbol(
810 Fortran::lower::FrontEndSymbol sym) {
811 assert(!symbolStack.empty());
812 symbolStack.back().push_back(sym);
813 }
814
exprBase(Fortran::lower::FrontEndExpr x,bool lhs)815 void Fortran::lower::ExplicitIterSpace::exprBase(Fortran::lower::FrontEndExpr x,
816 bool lhs) {
817 ArrayBaseFinder finder(collectAllSymbols());
818 finder(*x);
819 llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases> bases =
820 finder.getBases();
821 if (rhsBases.empty())
822 endAssign();
823 if (lhs) {
824 if (bases.empty()) {
825 lhsBases.push_back(llvm::None);
826 return;
827 }
828 assert(bases.size() >= 1 && "must detect an array reference on lhs");
829 if (bases.size() > 1)
830 rhsBases.back().append(bases.begin(), bases.end() - 1);
831 lhsBases.push_back(bases.back());
832 return;
833 }
834 rhsBases.back().append(bases.begin(), bases.end());
835 }
836
endAssign()837 void Fortran::lower::ExplicitIterSpace::endAssign() { rhsBases.emplace_back(); }
838
pushLevel()839 void Fortran::lower::ExplicitIterSpace::pushLevel() {
840 symbolStack.push_back(llvm::SmallVector<Fortran::lower::FrontEndSymbol>{});
841 }
842
popLevel()843 void Fortran::lower::ExplicitIterSpace::popLevel() { symbolStack.pop_back(); }
844
conditionalCleanup()845 void Fortran::lower::ExplicitIterSpace::conditionalCleanup() {
846 if (forallContextOpen == 0) {
847 // Exiting the outermost FORALL context.
848 // Cleanup any residual mask buffers.
849 outermostContext().finalize();
850 // Clear and reset all the cached information.
851 symbolStack.clear();
852 lhsBases.clear();
853 rhsBases.clear();
854 loadBindings.clear();
855 ccLoopNest.clear();
856 innerArgs.clear();
857 outerLoop = llvm::None;
858 clearLoops();
859 counter = 0;
860 }
861 }
862
863 llvm::Optional<size_t>
findArgPosition(fir::ArrayLoadOp load)864 Fortran::lower::ExplicitIterSpace::findArgPosition(fir::ArrayLoadOp load) {
865 if (lhsBases[counter]) {
866 auto ld = loadBindings.find(*lhsBases[counter]);
867 llvm::Optional<size_t> optPos;
868 if (ld != loadBindings.end() && ld->second == load)
869 optPos = static_cast<size_t>(0u);
870 assert(optPos.has_value() && "load does not correspond to lhs");
871 return optPos;
872 }
873 return llvm::None;
874 }
875
876 llvm::SmallVector<Fortran::lower::FrontEndSymbol>
collectAllSymbols()877 Fortran::lower::ExplicitIterSpace::collectAllSymbols() {
878 llvm::SmallVector<Fortran::lower::FrontEndSymbol> result;
879 for (llvm::SmallVector<FrontEndSymbol> vec : symbolStack)
880 result.append(vec.begin(), vec.end());
881 return result;
882 }
883
884 llvm::raw_ostream &
operator <<(llvm::raw_ostream & s,const Fortran::lower::ImplicitIterSpace & e)885 Fortran::lower::operator<<(llvm::raw_ostream &s,
886 const Fortran::lower::ImplicitIterSpace &e) {
887 for (const llvm::SmallVector<
888 Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr> &xs :
889 e.getMasks()) {
890 s << "{ ";
891 for (const Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr &x : xs)
892 x->AsFortran(s << '(') << "), ";
893 s << "}\n";
894 }
895 return s;
896 }
897
898 llvm::raw_ostream &
operator <<(llvm::raw_ostream & s,const Fortran::lower::ExplicitIterSpace & e)899 Fortran::lower::operator<<(llvm::raw_ostream &s,
900 const Fortran::lower::ExplicitIterSpace &e) {
901 auto dump = [&](const auto &u) {
902 std::visit(Fortran::common::visitors{
903 [&](const Fortran::semantics::Symbol *y) {
904 s << " " << *y << '\n';
905 },
906 [&](const Fortran::evaluate::ArrayRef *y) {
907 s << " ";
908 if (y->base().IsSymbol())
909 s << y->base().GetFirstSymbol();
910 else
911 s << y->base().GetComponent().GetLastSymbol();
912 s << '\n';
913 },
914 [&](const Fortran::evaluate::Component *y) {
915 s << " " << y->GetLastSymbol() << '\n';
916 }},
917 u);
918 };
919 s << "LHS bases:\n";
920 for (const llvm::Optional<Fortran::lower::ExplicitIterSpace::ArrayBases> &u :
921 e.lhsBases)
922 if (u)
923 dump(*u);
924 s << "RHS bases:\n";
925 for (const llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases>
926 &bases : e.rhsBases) {
927 for (const Fortran::lower::ExplicitIterSpace::ArrayBases &u : bases)
928 dump(u);
929 s << '\n';
930 }
931 return s;
932 }
933
dump() const934 void Fortran::lower::ImplicitIterSpace::dump() const {
935 llvm::errs() << *this << '\n';
936 }
937
dump() const938 void Fortran::lower::ExplicitIterSpace::dump() const {
939 llvm::errs() << *this << '\n';
940 }
941