1 //===-- lib/Semantics/check-do-forall.cpp ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "check-do-forall.h"
#include "flang/Common/template.h"
#include "flang/Evaluate/call.h"
#include "flang/Evaluate/expression.h"
#include "flang/Evaluate/tools.h"
#include "flang/Parser/message.h"
#include "flang/Parser/parse-tree-visitor.h"
#include "flang/Parser/tools.h"
#include "flang/Semantics/attr.h"
#include "flang/Semantics/scope.h"
#include "flang/Semantics/semantics.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
#include "flang/Semantics/type.h"
#include <functional>
23 
24 namespace Fortran::evaluate {
25 using ActualArgumentRef = common::Reference<const ActualArgument>;
26 
27 inline bool operator<(ActualArgumentRef x, ActualArgumentRef y) {
28   return &*x < &*y;
29 }
30 } // namespace Fortran::evaluate
31 
32 namespace Fortran::semantics {
33 
// Brings the message-text literal suffixes used below (_en_US, _err_en_US)
// into scope.
using namespace parser::literals;

// Shorthand for the bounds of a counted DO loop (see GetBounds).
using Bounds = parser::LoopControl::Bounds;
// Selects DO vs. FORALL when activating index variables and naming loops
// in messages.
using IndexVarKind = SemanticsContext::IndexVarKind;
38 
39 static const parser::ConcurrentHeader &GetConcurrentHeader(
40     const parser::LoopControl &loopControl) {
41   const auto &concurrent{
42       std::get<parser::LoopControl::Concurrent>(loopControl.u)};
43   return std::get<parser::ConcurrentHeader>(concurrent.t);
44 }
45 static const parser::ConcurrentHeader &GetConcurrentHeader(
46     const parser::ForallConstruct &construct) {
47   const auto &stmt{
48       std::get<parser::Statement<parser::ForallConstructStmt>>(construct.t)};
49   return std::get<common::Indirection<parser::ConcurrentHeader>>(
50       stmt.statement.t)
51       .value();
52 }
53 static const parser::ConcurrentHeader &GetConcurrentHeader(
54     const parser::ForallStmt &stmt) {
55   return std::get<common::Indirection<parser::ConcurrentHeader>>(stmt.t)
56       .value();
57 }
58 template <typename T>
59 static const std::list<parser::ConcurrentControl> &GetControls(const T &x) {
60   return std::get<std::list<parser::ConcurrentControl>>(
61       GetConcurrentHeader(x).t);
62 }
63 
64 static const Bounds &GetBounds(const parser::DoConstruct &doConstruct) {
65   auto &loopControl{doConstruct.GetLoopControl().value()};
66   return std::get<Bounds>(loopControl.u);
67 }
68 
69 static const parser::Name &GetDoVariable(
70     const parser::DoConstruct &doConstruct) {
71   const Bounds &bounds{GetBounds(doConstruct)};
72   return bounds.name.thing;
73 }
74 
75 static parser::MessageFixedText GetEnclosingDoMsg() {
76   return "Enclosing DO CONCURRENT statement"_en_US;
77 }
78 
79 static void SayWithDo(SemanticsContext &context, parser::CharBlock stmtLocation,
80     parser::MessageFixedText &&message, parser::CharBlock doLocation) {
81   context.Say(stmtLocation, message).Attach(doLocation, GetEnclosingDoMsg());
82 }
83 
84 // 11.1.7.5 - enforce semantics constraints on a DO CONCURRENT loop body
85 class DoConcurrentBodyEnforce {
86 public:
87   DoConcurrentBodyEnforce(
88       SemanticsContext &context, parser::CharBlock doConcurrentSourcePosition)
89       : context_{context}, doConcurrentSourcePosition_{
90                                doConcurrentSourcePosition} {}
91   std::set<parser::Label> labels() { return labels_; }
92   template <typename T> bool Pre(const T &) { return true; }
93   template <typename T> void Post(const T &) {}
94 
95   template <typename T> bool Pre(const parser::Statement<T> &statement) {
96     currentStatementSourcePosition_ = statement.source;
97     if (statement.label.has_value()) {
98       labels_.insert(*statement.label);
99     }
100     return true;
101   }
102 
103   template <typename T> bool Pre(const parser::UnlabeledStatement<T> &stmt) {
104     currentStatementSourcePosition_ = stmt.source;
105     return true;
106   }
107 
108   // C1140 -- Can't deallocate a polymorphic entity in a DO CONCURRENT.
109   // Deallocation can be caused by exiting a block that declares an allocatable
110   // entity, assignment to an allocatable variable, or an actual DEALLOCATE
111   // statement
112   //
113   // Note also that the deallocation of a derived type entity might cause the
114   // invocation of an IMPURE final subroutine. (C1139)
115   //
116 
117   // Only to be called for symbols with ObjectEntityDetails
118   static bool HasImpureFinal(const Symbol &original) {
119     const Symbol &symbol{ResolveAssociations(original)};
120     if (symbol.has<ObjectEntityDetails>()) {
121       if (const DeclTypeSpec * symType{symbol.GetType()}) {
122         if (const DerivedTypeSpec * derived{symType->AsDerived()}) {
123           return semantics::HasImpureFinal(*derived);
124         }
125       }
126     }
127     return false;
128   }
129 
130   // Predicate for deallocations caused by block exit and direct deallocation
131   static bool DeallocateAll(const Symbol &) { return true; }
132 
133   // Predicate for deallocations caused by intrinsic assignment
134   static bool DeallocateNonCoarray(const Symbol &component) {
135     return !IsCoarray(component);
136   }
137 
138   static bool WillDeallocatePolymorphic(const Symbol &entity,
139       const std::function<bool(const Symbol &)> &WillDeallocate) {
140     return WillDeallocate(entity) && IsPolymorphicAllocatable(entity);
141   }
142 
143   // Is it possible that we will we deallocate a polymorphic entity or one
144   // of its components?
145   static bool MightDeallocatePolymorphic(const Symbol &original,
146       const std::function<bool(const Symbol &)> &WillDeallocate) {
147     const Symbol &symbol{ResolveAssociations(original)};
148     // Check the entity itself, no coarray exception here
149     if (IsPolymorphicAllocatable(symbol)) {
150       return true;
151     }
152     // Check the components
153     if (const auto *details{symbol.detailsIf<ObjectEntityDetails>()}) {
154       if (const DeclTypeSpec * entityType{details->type()}) {
155         if (const DerivedTypeSpec * derivedType{entityType->AsDerived()}) {
156           UltimateComponentIterator ultimates{*derivedType};
157           for (const auto &ultimate : ultimates) {
158             if (WillDeallocatePolymorphic(ultimate, WillDeallocate)) {
159               return true;
160             }
161           }
162         }
163       }
164     }
165     return false;
166   }
167 
168   void SayDeallocateWithImpureFinal(const Symbol &entity, const char *reason) {
169     context_.SayWithDecl(entity, currentStatementSourcePosition_,
170         "Deallocation of an entity with an IMPURE FINAL procedure"
171         " caused by %s not allowed in DO CONCURRENT"_err_en_US,
172         reason);
173   }
174 
175   void SayDeallocateOfPolymorph(
176       parser::CharBlock location, const Symbol &entity, const char *reason) {
177     context_.SayWithDecl(entity, location,
178         "Deallocation of a polymorphic entity caused by %s"
179         " not allowed in DO CONCURRENT"_err_en_US,
180         reason);
181   }
182 
183   // Deallocation caused by block exit
184   // Allocatable entities and all of their allocatable subcomponents will be
185   // deallocated.  This test is different from the other two because it does
186   // not deallocate in cases where the entity itself is not allocatable but
187   // has allocatable polymorphic components
188   void Post(const parser::BlockConstruct &blockConstruct) {
189     const auto &endBlockStmt{
190         std::get<parser::Statement<parser::EndBlockStmt>>(blockConstruct.t)};
191     const Scope &blockScope{context_.FindScope(endBlockStmt.source)};
192     const Scope &doScope{context_.FindScope(doConcurrentSourcePosition_)};
193     if (DoesScopeContain(&doScope, blockScope)) {
194       const char *reason{"block exit"};
195       for (auto &pair : blockScope) {
196         const Symbol &entity{*pair.second};
197         if (IsAllocatable(entity) && !IsSaved(entity) &&
198             MightDeallocatePolymorphic(entity, DeallocateAll)) {
199           SayDeallocateOfPolymorph(endBlockStmt.source, entity, reason);
200         }
201         if (HasImpureFinal(entity)) {
202           SayDeallocateWithImpureFinal(entity, reason);
203         }
204       }
205     }
206   }
207 
208   // Deallocation caused by assignment
209   // Note that this case does not cause deallocation of coarray components
210   void Post(const parser::AssignmentStmt &stmt) {
211     const auto &variable{std::get<parser::Variable>(stmt.t)};
212     if (const Symbol * entity{GetLastName(variable).symbol}) {
213       const char *reason{"assignment"};
214       if (MightDeallocatePolymorphic(*entity, DeallocateNonCoarray)) {
215         SayDeallocateOfPolymorph(variable.GetSource(), *entity, reason);
216       }
217       if (HasImpureFinal(*entity)) {
218         SayDeallocateWithImpureFinal(*entity, reason);
219       }
220     }
221   }
222 
223   // Deallocation from a DEALLOCATE statement
224   // This case is different because DEALLOCATE statements deallocate both
225   // ALLOCATABLE and POINTER entities
226   void Post(const parser::DeallocateStmt &stmt) {
227     const auto &allocateObjectList{
228         std::get<std::list<parser::AllocateObject>>(stmt.t)};
229     for (const auto &allocateObject : allocateObjectList) {
230       const parser::Name &name{GetLastName(allocateObject)};
231       const char *reason{"a DEALLOCATE statement"};
232       if (name.symbol) {
233         const Symbol &entity{*name.symbol};
234         const DeclTypeSpec *entityType{entity.GetType()};
235         if ((entityType && entityType->IsPolymorphic()) || // POINTER case
236             MightDeallocatePolymorphic(entity, DeallocateAll)) {
237           SayDeallocateOfPolymorph(
238               currentStatementSourcePosition_, entity, reason);
239         }
240         if (HasImpureFinal(entity)) {
241           SayDeallocateWithImpureFinal(entity, reason);
242         }
243       }
244     }
245   }
246 
247   // C1137 -- No image control statements in a DO CONCURRENT
248   void Post(const parser::ExecutableConstruct &construct) {
249     if (IsImageControlStmt(construct)) {
250       const parser::CharBlock statementLocation{
251           GetImageControlStmtLocation(construct)};
252       auto &msg{context_.Say(statementLocation,
253           "An image control statement is not allowed in DO"
254           " CONCURRENT"_err_en_US)};
255       if (auto coarrayMsg{GetImageControlStmtCoarrayMsg(construct)}) {
256         msg.Attach(statementLocation, *coarrayMsg);
257       }
258       msg.Attach(doConcurrentSourcePosition_, GetEnclosingDoMsg());
259     }
260   }
261 
262   // C1136 -- No RETURN statements in a DO CONCURRENT
263   void Post(const parser::ReturnStmt &) {
264     context_
265         .Say(currentStatementSourcePosition_,
266             "RETURN is not allowed in DO CONCURRENT"_err_en_US)
267         .Attach(doConcurrentSourcePosition_, GetEnclosingDoMsg());
268   }
269 
270   // C1139: call to impure procedure and ...
271   // C1141: cannot call ieee_get_flag, ieee_[gs]et_halting_mode
272   // It's not necessary to check the ieee_get* procedures because they're
273   // not pure, and impure procedures are caught by checks for constraint C1139
274   void Post(const parser::ProcedureDesignator &procedureDesignator) {
275     if (auto *name{std::get_if<parser::Name>(&procedureDesignator.u)}) {
276       if (name->symbol && !IsPureProcedure(*name->symbol)) {
277         SayWithDo(context_, currentStatementSourcePosition_,
278             "Call to an impure procedure is not allowed in DO"
279             " CONCURRENT"_err_en_US,
280             doConcurrentSourcePosition_);
281       }
282       if (name->symbol && fromScope(*name->symbol, "ieee_exceptions"s)) {
283         if (name->source == "ieee_set_halting_mode") {
284           SayWithDo(context_, currentStatementSourcePosition_,
285               "IEEE_SET_HALTING_MODE is not allowed in DO "
286               "CONCURRENT"_err_en_US,
287               doConcurrentSourcePosition_);
288         }
289       }
290     } else {
291       // C1139: this a procedure component
292       auto &component{std::get<parser::ProcComponentRef>(procedureDesignator.u)
293                           .v.thing.component};
294       if (component.symbol && !IsPureProcedure(*component.symbol)) {
295         SayWithDo(context_, currentStatementSourcePosition_,
296             "Call to an impure procedure component is not allowed"
297             " in DO CONCURRENT"_err_en_US,
298             doConcurrentSourcePosition_);
299       }
300     }
301   }
302 
303   // 11.1.7.5, paragraph 5, no ADVANCE specifier in a DO CONCURRENT
304   void Post(const parser::IoControlSpec &ioControlSpec) {
305     if (auto *charExpr{
306             std::get_if<parser::IoControlSpec::CharExpr>(&ioControlSpec.u)}) {
307       if (std::get<parser::IoControlSpec::CharExpr::Kind>(charExpr->t) ==
308           parser::IoControlSpec::CharExpr::Kind::Advance) {
309         SayWithDo(context_, currentStatementSourcePosition_,
310             "ADVANCE specifier is not allowed in DO"
311             " CONCURRENT"_err_en_US,
312             doConcurrentSourcePosition_);
313       }
314     }
315   }
316 
317 private:
318   bool fromScope(const Symbol &symbol, const std::string &moduleName) {
319     if (symbol.GetUltimate().owner().IsModule() &&
320         symbol.GetUltimate().owner().GetName().value().ToString() ==
321             moduleName) {
322       return true;
323     }
324     return false;
325   }
326 
327   std::set<parser::Label> labels_;
328   parser::CharBlock currentStatementSourcePosition_;
329   SemanticsContext &context_;
330   parser::CharBlock doConcurrentSourcePosition_;
331 }; // class DoConcurrentBodyEnforce
332 
333 // Class for enforcing C1130 -- in a DO CONCURRENT with DEFAULT(NONE),
334 // variables from enclosing scopes must have their locality specified
335 class DoConcurrentVariableEnforce {
336 public:
337   DoConcurrentVariableEnforce(
338       SemanticsContext &context, parser::CharBlock doConcurrentSourcePosition)
339       : context_{context},
340         doConcurrentSourcePosition_{doConcurrentSourcePosition},
341         blockScope_{context.FindScope(doConcurrentSourcePosition_)} {}
342 
343   template <typename T> bool Pre(const T &) { return true; }
344   template <typename T> void Post(const T &) {}
345 
346   // Check to see if the name is a variable from an enclosing scope
347   void Post(const parser::Name &name) {
348     if (const Symbol * symbol{name.symbol}) {
349       if (IsVariableName(*symbol)) {
350         const Scope &variableScope{symbol->owner()};
351         if (DoesScopeContain(&variableScope, blockScope_)) {
352           context_.SayWithDecl(*symbol, name.source,
353               "Variable '%s' from an enclosing scope referenced in DO "
354               "CONCURRENT with DEFAULT(NONE) must appear in a "
355               "locality-spec"_err_en_US,
356               symbol->name());
357         }
358       }
359     }
360   }
361 
362 private:
363   SemanticsContext &context_;
364   parser::CharBlock doConcurrentSourcePosition_;
365   const Scope &blockScope_;
366 }; // class DoConcurrentVariableEnforce
367 
368 // Find a DO or FORALL and enforce semantics checks on its body
369 class DoContext {
370 public:
371   DoContext(SemanticsContext &context, IndexVarKind kind)
372       : context_{context}, kind_{kind} {}
373 
374   // Mark this DO construct as a point of definition for the DO variables
375   // or index-names it contains.  If they're already defined, emit an error
376   // message.  We need to remember both the variable and the source location of
377   // the variable in the DO construct so that we can remove it when we leave
378   // the DO construct and use its location in error messages.
379   void DefineDoVariables(const parser::DoConstruct &doConstruct) {
380     if (doConstruct.IsDoNormal()) {
381       context_.ActivateIndexVar(GetDoVariable(doConstruct), IndexVarKind::DO);
382     } else if (doConstruct.IsDoConcurrent()) {
383       if (const auto &loopControl{doConstruct.GetLoopControl()}) {
384         ActivateIndexVars(GetControls(*loopControl));
385       }
386     }
387   }
388 
389   // Called at the end of a DO construct to deactivate the DO construct
390   void ResetDoVariables(const parser::DoConstruct &doConstruct) {
391     if (doConstruct.IsDoNormal()) {
392       context_.DeactivateIndexVar(GetDoVariable(doConstruct));
393     } else if (doConstruct.IsDoConcurrent()) {
394       if (const auto &loopControl{doConstruct.GetLoopControl()}) {
395         DeactivateIndexVars(GetControls(*loopControl));
396       }
397     }
398   }
399 
400   void ActivateIndexVars(const std::list<parser::ConcurrentControl> &controls) {
401     for (const auto &control : controls) {
402       context_.ActivateIndexVar(std::get<parser::Name>(control.t), kind_);
403     }
404   }
405   void DeactivateIndexVars(
406       const std::list<parser::ConcurrentControl> &controls) {
407     for (const auto &control : controls) {
408       context_.DeactivateIndexVar(std::get<parser::Name>(control.t));
409     }
410   }
411 
412   void Check(const parser::DoConstruct &doConstruct) {
413     if (doConstruct.IsDoConcurrent()) {
414       CheckDoConcurrent(doConstruct);
415       return;
416     }
417     if (doConstruct.IsDoNormal()) {
418       CheckDoNormal(doConstruct);
419       return;
420     }
421     // TODO: handle the other cases
422   }
423 
424   void Check(const parser::ForallStmt &stmt) {
425     CheckConcurrentHeader(GetConcurrentHeader(stmt));
426   }
427   void Check(const parser::ForallConstruct &construct) {
428     CheckConcurrentHeader(GetConcurrentHeader(construct));
429   }
430 
431   void Check(const parser::ForallAssignmentStmt &stmt) {
432     const evaluate::Assignment *assignment{std::visit(
433         common::visitors{[&](const auto &x) { return GetAssignment(x); }},
434         stmt.u)};
435     if (assignment) {
436       CheckForallIndexesUsed(*assignment);
437       CheckForImpureCall(assignment->lhs);
438       CheckForImpureCall(assignment->rhs);
439       if (const auto *proc{
440               std::get_if<evaluate::ProcedureRef>(&assignment->u)}) {
441         CheckForImpureCall(*proc);
442       }
443       std::visit(common::visitors{
444                      [](const evaluate::Assignment::Intrinsic &) {},
445                      [&](const evaluate::ProcedureRef &proc) {
446                        CheckForImpureCall(proc);
447                      },
448                      [&](const evaluate::Assignment::BoundsSpec &bounds) {
449                        for (const auto &bound : bounds) {
450                          CheckForImpureCall(SomeExpr{bound});
451                        }
452                      },
453                      [&](const evaluate::Assignment::BoundsRemapping &bounds) {
454                        for (const auto &bound : bounds) {
455                          CheckForImpureCall(SomeExpr{bound.first});
456                          CheckForImpureCall(SomeExpr{bound.second});
457                        }
458                      },
459                  },
460           assignment->u);
461     }
462   }
463 
464 private:
465   void SayBadDoControl(parser::CharBlock sourceLocation) {
466     context_.Say(sourceLocation, "DO controls should be INTEGER"_err_en_US);
467   }
468 
469   void CheckDoControl(const parser::CharBlock &sourceLocation, bool isReal) {
470     const bool warn{context_.warnOnNonstandardUsage() ||
471         context_.ShouldWarn(common::LanguageFeature::RealDoControls)};
472     if (isReal && !warn) {
473       // No messages for the default case
474     } else if (isReal && warn) {
475       context_.Say(sourceLocation, "DO controls should be INTEGER"_en_US);
476     } else {
477       SayBadDoControl(sourceLocation);
478     }
479   }
480 
481   void CheckDoVariable(const parser::ScalarName &scalarName) {
482     const parser::CharBlock &sourceLocation{scalarName.thing.source};
483     if (const Symbol * symbol{scalarName.thing.symbol}) {
484       if (!IsVariableName(*symbol)) {
485         context_.Say(
486             sourceLocation, "DO control must be an INTEGER variable"_err_en_US);
487       } else {
488         const DeclTypeSpec *symType{symbol->GetType()};
489         if (!symType) {
490           SayBadDoControl(sourceLocation);
491         } else {
492           if (!symType->IsNumeric(TypeCategory::Integer)) {
493             CheckDoControl(
494                 sourceLocation, symType->IsNumeric(TypeCategory::Real));
495           }
496         }
497       } // No messages for INTEGER
498     }
499   }
500 
501   // Semantic checks for the limit and step expressions
502   void CheckDoExpression(const parser::ScalarExpr &scalarExpression) {
503     if (const SomeExpr * expr{GetExpr(scalarExpression)}) {
504       if (!ExprHasTypeCategory(*expr, TypeCategory::Integer)) {
505         // No warnings or errors for type INTEGER
506         const parser::CharBlock &loc{scalarExpression.thing.value().source};
507         CheckDoControl(loc, ExprHasTypeCategory(*expr, TypeCategory::Real));
508       }
509     }
510   }
511 
512   void CheckDoNormal(const parser::DoConstruct &doConstruct) {
513     // C1120 -- types of DO variables must be INTEGER, extended by allowing
514     // REAL and DOUBLE PRECISION
515     const Bounds &bounds{GetBounds(doConstruct)};
516     CheckDoVariable(bounds.name);
517     CheckDoExpression(bounds.lower);
518     CheckDoExpression(bounds.upper);
519     if (bounds.step) {
520       CheckDoExpression(*bounds.step);
521       if (IsZero(*bounds.step)) {
522         context_.Say(bounds.step->thing.value().source,
523             "DO step expression should not be zero"_en_US);
524       }
525     }
526   }
527 
528   void CheckDoConcurrent(const parser::DoConstruct &doConstruct) {
529     auto &doStmt{
530         std::get<parser::Statement<parser::NonLabelDoStmt>>(doConstruct.t)};
531     currentStatementSourcePosition_ = doStmt.source;
532 
533     const parser::Block &block{std::get<parser::Block>(doConstruct.t)};
534     DoConcurrentBodyEnforce doConcurrentBodyEnforce{context_, doStmt.source};
535     parser::Walk(block, doConcurrentBodyEnforce);
536 
537     LabelEnforce doConcurrentLabelEnforce{context_,
538         doConcurrentBodyEnforce.labels(), currentStatementSourcePosition_,
539         "DO CONCURRENT"};
540     parser::Walk(block, doConcurrentLabelEnforce);
541 
542     const auto &loopControl{doConstruct.GetLoopControl()};
543     CheckConcurrentLoopControl(*loopControl);
544     CheckLocalitySpecs(*loopControl, block);
545   }
546 
547   // Return a set of symbols whose names are in a Local locality-spec.  Look
548   // the names up in the scope that encloses the DO construct to avoid getting
549   // the local versions of them.  Then follow the host-, use-, and
550   // construct-associations to get the root symbols
551   SymbolSet GatherLocals(
552       const std::list<parser::LocalitySpec> &localitySpecs) const {
553     SymbolSet symbols;
554     const Scope &parentScope{
555         context_.FindScope(currentStatementSourcePosition_).parent()};
556     // Loop through the LocalitySpec::Local locality-specs
557     for (const auto &ls : localitySpecs) {
558       if (const auto *names{std::get_if<parser::LocalitySpec::Local>(&ls.u)}) {
559         // Loop through the names in the Local locality-spec getting their
560         // symbols
561         for (const parser::Name &name : names->v) {
562           if (const Symbol * symbol{parentScope.FindSymbol(name.source)}) {
563             symbols.insert(ResolveAssociations(*symbol));
564           }
565         }
566       }
567     }
568     return symbols;
569   }
570 
571   static SymbolSet GatherSymbolsFromExpression(const parser::Expr &expression) {
572     SymbolSet result;
573     if (const auto *expr{GetExpr(expression)}) {
574       for (const Symbol &symbol : evaluate::CollectSymbols(*expr)) {
575         result.insert(ResolveAssociations(symbol));
576       }
577     }
578     return result;
579   }
580 
581   // C1121 - procedures in mask must be pure
582   void CheckMaskIsPure(const parser::ScalarLogicalExpr &mask) const {
583     SymbolSet references{GatherSymbolsFromExpression(mask.thing.thing.value())};
584     for (const Symbol &ref : references) {
585       if (IsProcedure(ref) && !IsPureProcedure(ref)) {
586         context_.SayWithDecl(ref, parser::Unwrap<parser::Expr>(mask)->source,
587             "%s mask expression may not reference impure procedure '%s'"_err_en_US,
588             LoopKindName(), ref.name());
589         return;
590       }
591     }
592   }
593 
594   void CheckNoCollisions(const SymbolSet &refs, const SymbolSet &uses,
595       parser::MessageFixedText &&errorMessage,
596       const parser::CharBlock &refPosition) const {
597     for (const Symbol &ref : refs) {
598       if (uses.find(ref) != uses.end()) {
599         context_.SayWithDecl(ref, refPosition, std::move(errorMessage),
600             LoopKindName(), ref.name());
601         return;
602       }
603     }
604   }
605 
606   void HasNoReferences(
607       const SymbolSet &indexNames, const parser::ScalarIntExpr &expr) const {
608     CheckNoCollisions(GatherSymbolsFromExpression(expr.thing.thing.value()),
609         indexNames,
610         "%s limit expression may not reference index variable '%s'"_err_en_US,
611         expr.thing.thing.value().source);
612   }
613 
614   // C1129, names in local locality-specs can't be in mask expressions
615   void CheckMaskDoesNotReferenceLocal(
616       const parser::ScalarLogicalExpr &mask, const SymbolSet &localVars) const {
617     CheckNoCollisions(GatherSymbolsFromExpression(mask.thing.thing.value()),
618         localVars,
619         "%s mask expression references variable '%s'"
620         " in LOCAL locality-spec"_err_en_US,
621         mask.thing.thing.value().source);
622   }
623 
624   // C1129, names in local locality-specs can't be in limit or step
625   // expressions
626   void CheckExprDoesNotReferenceLocal(
627       const parser::ScalarIntExpr &expr, const SymbolSet &localVars) const {
628     CheckNoCollisions(GatherSymbolsFromExpression(expr.thing.thing.value()),
629         localVars,
630         "%s expression references variable '%s'"
631         " in LOCAL locality-spec"_err_en_US,
632         expr.thing.thing.value().source);
633   }
634 
635   // C1130, DEFAULT(NONE) locality requires names to be in locality-specs to
636   // be used in the body of the DO loop
637   void CheckDefaultNoneImpliesExplicitLocality(
638       const std::list<parser::LocalitySpec> &localitySpecs,
639       const parser::Block &block) const {
640     bool hasDefaultNone{false};
641     for (auto &ls : localitySpecs) {
642       if (std::holds_alternative<parser::LocalitySpec::DefaultNone>(ls.u)) {
643         if (hasDefaultNone) {
644           // C1127, you can only have one DEFAULT(NONE)
645           context_.Say(currentStatementSourcePosition_,
646               "Only one DEFAULT(NONE) may appear"_en_US);
647           break;
648         }
649         hasDefaultNone = true;
650       }
651     }
652     if (hasDefaultNone) {
653       DoConcurrentVariableEnforce doConcurrentVariableEnforce{
654           context_, currentStatementSourcePosition_};
655       parser::Walk(block, doConcurrentVariableEnforce);
656     }
657   }
658 
659   // C1123, concurrent limit or step expressions can't reference index-names
660   void CheckConcurrentHeader(const parser::ConcurrentHeader &header) const {
661     if (const auto &mask{
662             std::get<std::optional<parser::ScalarLogicalExpr>>(header.t)}) {
663       CheckMaskIsPure(*mask);
664     }
665     auto &controls{std::get<std::list<parser::ConcurrentControl>>(header.t)};
666     SymbolSet indexNames;
667     for (const parser::ConcurrentControl &control : controls) {
668       const auto &indexName{std::get<parser::Name>(control.t)};
669       if (indexName.symbol) {
670         indexNames.insert(*indexName.symbol);
671       }
672     }
673     if (!indexNames.empty()) {
674       for (const parser::ConcurrentControl &control : controls) {
675         HasNoReferences(indexNames, std::get<1>(control.t));
676         HasNoReferences(indexNames, std::get<2>(control.t));
677         if (const auto &intExpr{
678                 std::get<std::optional<parser::ScalarIntExpr>>(control.t)}) {
679           const parser::Expr &expr{intExpr->thing.thing.value()};
680           CheckNoCollisions(GatherSymbolsFromExpression(expr), indexNames,
681               "%s step expression may not reference index variable '%s'"_err_en_US,
682               expr.source);
683           if (IsZero(expr)) {
684             context_.Say(expr.source,
685                 "%s step expression may not be zero"_err_en_US, LoopKindName());
686           }
687         }
688       }
689     }
690   }
691 
692   void CheckLocalitySpecs(
693       const parser::LoopControl &control, const parser::Block &block) const {
694     const auto &concurrent{
695         std::get<parser::LoopControl::Concurrent>(control.u)};
696     const auto &header{std::get<parser::ConcurrentHeader>(concurrent.t)};
697     const auto &localitySpecs{
698         std::get<std::list<parser::LocalitySpec>>(concurrent.t)};
699     if (!localitySpecs.empty()) {
700       const SymbolSet &localVars{GatherLocals(localitySpecs)};
701       for (const auto &c : GetControls(control)) {
702         CheckExprDoesNotReferenceLocal(std::get<1>(c.t), localVars);
703         CheckExprDoesNotReferenceLocal(std::get<2>(c.t), localVars);
704         if (const auto &expr{
705                 std::get<std::optional<parser::ScalarIntExpr>>(c.t)}) {
706           CheckExprDoesNotReferenceLocal(*expr, localVars);
707         }
708       }
709       if (const auto &mask{
710               std::get<std::optional<parser::ScalarLogicalExpr>>(header.t)}) {
711         CheckMaskDoesNotReferenceLocal(*mask, localVars);
712       }
713       CheckDefaultNoneImpliesExplicitLocality(localitySpecs, block);
714     }
715   }
716 
717   // check constraints [C1121 .. C1130]
718   void CheckConcurrentLoopControl(const parser::LoopControl &control) const {
719     const auto &concurrent{
720         std::get<parser::LoopControl::Concurrent>(control.u)};
721     CheckConcurrentHeader(std::get<parser::ConcurrentHeader>(concurrent.t));
722   }
723 
724   template <typename T> void CheckForImpureCall(const T &x) {
725     if (auto bad{FindImpureCall(context_.foldingContext(), x)}) {
726       context_.Say(
727           "Impure procedure '%s' may not be referenced in a %s"_err_en_US, *bad,
728           LoopKindName());
729     }
730   }
731 
732   // Each index should be used on the LHS of each assignment in a FORALL
733   void CheckForallIndexesUsed(const evaluate::Assignment &assignment) {
734     SymbolVector indexVars{context_.GetIndexVars(IndexVarKind::FORALL)};
735     if (!indexVars.empty()) {
736       SymbolSet symbols{evaluate::CollectSymbols(assignment.lhs)};
737       std::visit(
738           common::visitors{
739               [&](const evaluate::Assignment::BoundsSpec &spec) {
740                 for (const auto &bound : spec) {
741 // TODO: this is working around missing std::set::merge in some versions of
742 // clang that we are building with
743 #ifdef __clang__
744                   auto boundSymbols{evaluate::CollectSymbols(bound)};
745                   symbols.insert(boundSymbols.begin(), boundSymbols.end());
746 #else
747                   symbols.merge(evaluate::CollectSymbols(bound));
748 #endif
749                 }
750               },
751               [&](const evaluate::Assignment::BoundsRemapping &remapping) {
752                 for (const auto &bounds : remapping) {
753 #ifdef __clang__
754                   auto lbSymbols{evaluate::CollectSymbols(bounds.first)};
755                   symbols.insert(lbSymbols.begin(), lbSymbols.end());
756                   auto ubSymbols{evaluate::CollectSymbols(bounds.second)};
757                   symbols.insert(ubSymbols.begin(), ubSymbols.end());
758 #else
759                   symbols.merge(evaluate::CollectSymbols(bounds.first));
760                   symbols.merge(evaluate::CollectSymbols(bounds.second));
761 #endif
762                 }
763               },
764               [](const auto &) {},
765           },
766           assignment.u);
767       for (const Symbol &index : indexVars) {
768         if (symbols.count(index) == 0) {
769           context_.Say(
770               "Warning: FORALL index variable '%s' not used on left-hand side"
771               " of assignment"_en_US,
772               index.name());
773         }
774       }
775     }
776   }
777 
778   // For messages where the DO loop must be DO CONCURRENT, make that explicit.
779   const char *LoopKindName() const {
780     return kind_ == IndexVarKind::DO ? "DO CONCURRENT" : "FORALL";
781   }
782 
  // Semantics context used for diagnostics (Say), expression folding
  // (foldingContext) and index-variable queries (GetIndexVars).
  SemanticsContext &context_;
  // Whether this context checks a DO loop or a FORALL; LoopKindName() uses it
  // to choose the construct name printed in messages.
  const IndexVarKind kind_;
  // NOTE(review): not referenced in the visible part of this class; presumably
  // the source position of the statement currently being checked -- confirm.
  parser::CharBlock currentStatementSourcePosition_;
}; // class DoContext
787 
788 void DoForallChecker::Enter(const parser::DoConstruct &doConstruct) {
789   DoContext doContext{context_, IndexVarKind::DO};
790   doContext.DefineDoVariables(doConstruct);
791 }
792 
793 void DoForallChecker::Leave(const parser::DoConstruct &doConstruct) {
794   DoContext doContext{context_, IndexVarKind::DO};
795   doContext.Check(doConstruct);
796   doContext.ResetDoVariables(doConstruct);
797 }
798 
799 void DoForallChecker::Enter(const parser::ForallConstruct &construct) {
800   DoContext doContext{context_, IndexVarKind::FORALL};
801   doContext.ActivateIndexVars(GetControls(construct));
802 }
803 void DoForallChecker::Leave(const parser::ForallConstruct &construct) {
804   DoContext doContext{context_, IndexVarKind::FORALL};
805   doContext.Check(construct);
806   doContext.DeactivateIndexVars(GetControls(construct));
807 }
808 
809 void DoForallChecker::Enter(const parser::ForallStmt &stmt) {
810   DoContext doContext{context_, IndexVarKind::FORALL};
811   doContext.ActivateIndexVars(GetControls(stmt));
812 }
813 void DoForallChecker::Leave(const parser::ForallStmt &stmt) {
814   DoContext doContext{context_, IndexVarKind::FORALL};
815   doContext.Check(stmt);
816   doContext.DeactivateIndexVars(GetControls(stmt));
817 }
818 void DoForallChecker::Leave(const parser::ForallAssignmentStmt &stmt) {
819   DoContext doContext{context_, IndexVarKind::FORALL};
820   doContext.Check(stmt);
821 }
822 
823 template <typename A>
824 static parser::CharBlock GetConstructPosition(const A &a) {
825   return std::get<0>(a.t).source;
826 }
827 
828 static parser::CharBlock GetNodePosition(const ConstructNode &construct) {
829   return std::visit(
830       [&](const auto &x) { return GetConstructPosition(*x); }, construct);
831 }
832 
833 void DoForallChecker::SayBadLeave(StmtType stmtType,
834     const char *enclosingStmtName, const ConstructNode &construct) const {
835   context_
836       .Say("%s must not leave a %s statement"_err_en_US, EnumToString(stmtType),
837           enclosingStmtName)
838       .Attach(GetNodePosition(construct), "The construct that was left"_en_US);
839 }
840 
841 static const parser::DoConstruct *MaybeGetDoConstruct(
842     const ConstructNode &construct) {
843   if (const auto *doNode{
844           std::get_if<const parser::DoConstruct *>(&construct)}) {
845     return *doNode;
846   } else {
847     return nullptr;
848   }
849 }
850 
851 static bool ConstructIsDoConcurrent(const ConstructNode &construct) {
852   const parser::DoConstruct *doConstruct{MaybeGetDoConstruct(construct)};
853   return doConstruct && doConstruct->IsDoConcurrent();
854 }
855 
856 // Check that CYCLE and EXIT statements do not cause flow of control to
857 // leave DO CONCURRENT, CRITICAL, or CHANGE TEAM constructs.
858 void DoForallChecker::CheckForBadLeave(
859     StmtType stmtType, const ConstructNode &construct) const {
860   std::visit(common::visitors{
861                  [&](const parser::DoConstruct *doConstructPtr) {
862                    if (doConstructPtr->IsDoConcurrent()) {
863                      // C1135 and C1167 -- CYCLE and EXIT statements can't leave
864                      // a DO CONCURRENT
865                      SayBadLeave(stmtType, "DO CONCURRENT", construct);
866                    }
867                  },
868                  [&](const parser::CriticalConstruct *) {
869                    // C1135 and C1168 -- similarly, for CRITICAL
870                    SayBadLeave(stmtType, "CRITICAL", construct);
871                  },
872                  [&](const parser::ChangeTeamConstruct *) {
873                    // C1135 and C1168 -- similarly, for CHANGE TEAM
874                    SayBadLeave(stmtType, "CHANGE TEAM", construct);
875                  },
876                  [](const auto *) {},
877              },
878       construct);
879 }
880 
881 static bool StmtMatchesConstruct(const parser::Name *stmtName,
882     StmtType stmtType, const std::optional<parser::Name> &constructName,
883     const ConstructNode &construct) {
884   bool inDoConstruct{MaybeGetDoConstruct(construct) != nullptr};
885   if (!stmtName) {
886     return inDoConstruct; // Unlabeled statements match all DO constructs
887   } else if (constructName && constructName->source == stmtName->source) {
888     return stmtType == StmtType::EXIT || inDoConstruct;
889   } else {
890     return false;
891   }
892 }
893 
894 // C1167 Can't EXIT from a DO CONCURRENT
895 void DoForallChecker::CheckDoConcurrentExit(
896     StmtType stmtType, const ConstructNode &construct) const {
897   if (stmtType == StmtType::EXIT && ConstructIsDoConcurrent(construct)) {
898     SayBadLeave(StmtType::EXIT, "DO CONCURRENT", construct);
899   }
900 }
901 
902 // Check nesting violations for a CYCLE or EXIT statement.  Loop up the
903 // nesting levels looking for a construct that matches the CYCLE or EXIT
904 // statment.  At every construct, check for a violation.  If we find a match
905 // without finding a violation, the check is complete.
906 void DoForallChecker::CheckNesting(
907     StmtType stmtType, const parser::Name *stmtName) const {
908   const ConstructStack &stack{context_.constructStack()};
909   for (auto iter{stack.cend()}; iter-- != stack.cbegin();) {
910     const ConstructNode &construct{*iter};
911     const std::optional<parser::Name> &constructName{
912         MaybeGetNodeName(construct)};
913     if (StmtMatchesConstruct(stmtName, stmtType, constructName, construct)) {
914       CheckDoConcurrentExit(stmtType, construct);
915       return; // We got a match, so we're finished checking
916     }
917     CheckForBadLeave(stmtType, construct);
918   }
919 
920   // We haven't found a match in the enclosing constructs
921   if (stmtType == StmtType::EXIT) {
922     context_.Say("No matching construct for EXIT statement"_err_en_US);
923   } else {
924     context_.Say("No matching DO construct for CYCLE statement"_err_en_US);
925   }
926 }
927 
928 // C1135 -- Nesting for CYCLE statements
929 void DoForallChecker::Enter(const parser::CycleStmt &cycleStmt) {
930   CheckNesting(StmtType::CYCLE, common::GetPtrFromOptional(cycleStmt.v));
931 }
932 
933 // C1167 and C1168 -- Nesting for EXIT statements
934 void DoForallChecker::Enter(const parser::ExitStmt &exitStmt) {
935   CheckNesting(StmtType::EXIT, common::GetPtrFromOptional(exitStmt.v));
936 }
937 
938 void DoForallChecker::Leave(const parser::AssignmentStmt &stmt) {
939   const auto &variable{std::get<parser::Variable>(stmt.t)};
940   context_.CheckIndexVarRedefine(variable);
941 }
942 
943 static void CheckIfArgIsDoVar(const evaluate::ActualArgument &arg,
944     const parser::CharBlock location, SemanticsContext &context) {
945   common::Intent intent{arg.dummyIntent()};
946   if (intent == common::Intent::Out || intent == common::Intent::InOut) {
947     if (const SomeExpr * argExpr{arg.UnwrapExpr()}) {
948       if (const Symbol * var{evaluate::UnwrapWholeSymbolDataRef(*argExpr)}) {
949         if (intent == common::Intent::Out) {
950           context.CheckIndexVarRedefine(location, *var);
951         } else {
952           context.WarnIndexVarRedefine(location, *var); // INTENT(INOUT)
953         }
954       }
955     }
956   }
957 }
958 
959 // Check to see if a DO variable is being passed as an actual argument to a
960 // dummy argument whose intent is OUT or INOUT.  To do this, we need to find
961 // the expressions for actual arguments which contain DO variables.  We get the
962 // intents of the dummy arguments from the ProcedureRef in the "typedCall"
963 // field of the CallStmt which was filled in during expression checking.  At
964 // the same time, we need to iterate over the parser::Expr versions of the
965 // actual arguments to get their source locations of the arguments for the
966 // messages.
967 void DoForallChecker::Leave(const parser::CallStmt &callStmt) {
968   if (const auto &typedCall{callStmt.typedCall}) {
969     const auto &parsedArgs{
970         std::get<std::list<parser::ActualArgSpec>>(callStmt.v.t)};
971     auto parsedArgIter{parsedArgs.begin()};
972     const evaluate::ActualArguments &checkedArgs{typedCall->arguments()};
973     for (const auto &checkedOptionalArg : checkedArgs) {
974       if (parsedArgIter == parsedArgs.end()) {
975         break; // No more parsed arguments, we're done.
976       }
977       const auto &parsedArg{std::get<parser::ActualArg>(parsedArgIter->t)};
978       ++parsedArgIter;
979       if (checkedOptionalArg) {
980         const evaluate::ActualArgument &checkedArg{*checkedOptionalArg};
981         if (const auto *parsedExpr{
982                 std::get_if<common::Indirection<parser::Expr>>(&parsedArg.u)}) {
983           CheckIfArgIsDoVar(checkedArg, parsedExpr->value().source, context_);
984         }
985       }
986     }
987   }
988 }
989 
990 void DoForallChecker::Leave(const parser::ConnectSpec &connectSpec) {
991   const auto *newunit{
992       std::get_if<parser::ConnectSpec::Newunit>(&connectSpec.u)};
993   if (newunit) {
994     context_.CheckIndexVarRedefine(newunit->v.thing.thing);
995   }
996 }
997 
998 using ActualArgumentSet = std::set<evaluate::ActualArgumentRef>;
999 
1000 struct CollectActualArgumentsHelper
1001     : public evaluate::SetTraverse<CollectActualArgumentsHelper,
1002           ActualArgumentSet> {
1003   using Base = SetTraverse<CollectActualArgumentsHelper, ActualArgumentSet>;
1004   CollectActualArgumentsHelper() : Base{*this} {}
1005   using Base::operator();
1006   ActualArgumentSet operator()(const evaluate::ActualArgument &arg) const {
1007     return Combine(ActualArgumentSet{arg},
1008         CollectActualArgumentsHelper{}(arg.UnwrapExpr()));
1009   }
1010 };
1011 
1012 template <typename A> ActualArgumentSet CollectActualArguments(const A &x) {
1013   return CollectActualArgumentsHelper{}(x);
1014 }
1015 
1016 template ActualArgumentSet CollectActualArguments(const SomeExpr &);
1017 
1018 void DoForallChecker::Enter(const parser::Expr &parsedExpr) { ++exprDepth_; }
1019 
1020 void DoForallChecker::Leave(const parser::Expr &parsedExpr) {
1021   CHECK(exprDepth_ > 0);
1022   if (--exprDepth_ == 0) { // Only check top level expressions
1023     if (const SomeExpr * expr{GetExpr(parsedExpr)}) {
1024       ActualArgumentSet argSet{CollectActualArguments(*expr)};
1025       for (const evaluate::ActualArgumentRef &argRef : argSet) {
1026         CheckIfArgIsDoVar(*argRef, parsedExpr.source, context_);
1027       }
1028     }
1029   }
1030 }
1031 
1032 void DoForallChecker::Leave(const parser::InquireSpec &inquireSpec) {
1033   const auto *intVar{std::get_if<parser::InquireSpec::IntVar>(&inquireSpec.u)};
1034   if (intVar) {
1035     const auto &scalar{std::get<parser::ScalarIntVariable>(intVar->t)};
1036     context_.CheckIndexVarRedefine(scalar.thing.thing);
1037   }
1038 }
1039 
1040 void DoForallChecker::Leave(const parser::IoControlSpec &ioControlSpec) {
1041   const auto *size{std::get_if<parser::IoControlSpec::Size>(&ioControlSpec.u)};
1042   if (size) {
1043     context_.CheckIndexVarRedefine(size->v.thing.thing);
1044   }
1045 }
1046 
1047 void DoForallChecker::Leave(const parser::OutputImpliedDo &outputImpliedDo) {
1048   const auto &control{std::get<parser::IoImpliedDoControl>(outputImpliedDo.t)};
1049   const parser::Name &name{control.name.thing.thing};
1050   context_.CheckIndexVarRedefine(name.source, *name.symbol);
1051 }
1052 
1053 void DoForallChecker::Leave(const parser::StatVariable &statVariable) {
1054   context_.CheckIndexVarRedefine(statVariable.v.thing.thing);
1055 }
1056 
1057 } // namespace Fortran::semantics
1058