1 //===-- lib/Semantics/check-do-forall.cpp ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "check-do-forall.h"
10 #include "flang/Common/template.h"
11 #include "flang/Evaluate/call.h"
12 #include "flang/Evaluate/expression.h"
13 #include "flang/Evaluate/tools.h"
14 #include "flang/Parser/message.h"
15 #include "flang/Parser/parse-tree-visitor.h"
16 #include "flang/Parser/tools.h"
17 #include "flang/Semantics/attr.h"
18 #include "flang/Semantics/scope.h"
19 #include "flang/Semantics/semantics.h"
20 #include "flang/Semantics/symbol.h"
21 #include "flang/Semantics/tools.h"
22 #include "flang/Semantics/type.h"
23 
24 namespace Fortran::evaluate {
25 using ActualArgumentRef = common::Reference<const ActualArgument>;
26 
27 inline bool operator<(ActualArgumentRef x, ActualArgumentRef y) {
28   return &*x < &*y;
29 }
30 } // namespace Fortran::evaluate
31 
32 namespace Fortran::semantics {
33 
34 using namespace parser::literals;
35 
36 using Bounds = parser::LoopControl::Bounds;
37 using IndexVarKind = SemanticsContext::IndexVarKind;
38 
39 static const parser::ConcurrentHeader &GetConcurrentHeader(
40     const parser::LoopControl &loopControl) {
41   const auto &concurrent{
42       std::get<parser::LoopControl::Concurrent>(loopControl.u)};
43   return std::get<parser::ConcurrentHeader>(concurrent.t);
44 }
45 static const parser::ConcurrentHeader &GetConcurrentHeader(
46     const parser::ForallConstruct &construct) {
47   const auto &stmt{
48       std::get<parser::Statement<parser::ForallConstructStmt>>(construct.t)};
49   return std::get<common::Indirection<parser::ConcurrentHeader>>(
50       stmt.statement.t)
51       .value();
52 }
53 static const parser::ConcurrentHeader &GetConcurrentHeader(
54     const parser::ForallStmt &stmt) {
55   return std::get<common::Indirection<parser::ConcurrentHeader>>(stmt.t)
56       .value();
57 }
58 template <typename T>
59 static const std::list<parser::ConcurrentControl> &GetControls(const T &x) {
60   return std::get<std::list<parser::ConcurrentControl>>(
61       GetConcurrentHeader(x).t);
62 }
63 
64 static const Bounds &GetBounds(const parser::DoConstruct &doConstruct) {
65   auto &loopControl{doConstruct.GetLoopControl().value()};
66   return std::get<Bounds>(loopControl.u);
67 }
68 
69 static const parser::Name &GetDoVariable(
70     const parser::DoConstruct &doConstruct) {
71   const Bounds &bounds{GetBounds(doConstruct)};
72   return bounds.name.thing;
73 }
74 
75 // Return the (possibly null)  name of the construct
76 template <typename A>
77 static const parser::Name *MaybeGetConstructName(const A &a) {
78   return common::GetPtrFromOptional(std::get<0>(std::get<0>(a.t).statement.t));
79 }
80 
81 static parser::MessageFixedText GetEnclosingDoMsg() {
82   return "Enclosing DO CONCURRENT statement"_en_US;
83 }
84 
85 static const parser::Name *MaybeGetConstructName(
86     const parser::BlockConstruct &blockConstruct) {
87   return common::GetPtrFromOptional(
88       std::get<parser::Statement<parser::BlockStmt>>(blockConstruct.t)
89           .statement.v);
90 }
91 
92 static void SayWithDo(SemanticsContext &context, parser::CharBlock stmtLocation,
93     parser::MessageFixedText &&message, parser::CharBlock doLocation) {
94   context.Say(stmtLocation, message).Attach(doLocation, GetEnclosingDoMsg());
95 }
96 
97 // 11.1.7.5 - enforce semantics constraints on a DO CONCURRENT loop body
98 class DoConcurrentBodyEnforce {
99 public:
100   DoConcurrentBodyEnforce(
101       SemanticsContext &context, parser::CharBlock doConcurrentSourcePosition)
102       : context_{context}, doConcurrentSourcePosition_{
103                                doConcurrentSourcePosition} {}
104   std::set<parser::Label> labels() { return labels_; }
105   template <typename T> bool Pre(const T &) { return true; }
106   template <typename T> void Post(const T &) {}
107 
108   template <typename T> bool Pre(const parser::Statement<T> &statement) {
109     currentStatementSourcePosition_ = statement.source;
110     if (statement.label.has_value()) {
111       labels_.insert(*statement.label);
112     }
113     return true;
114   }
115 
116   template <typename T> bool Pre(const parser::UnlabeledStatement<T> &stmt) {
117     currentStatementSourcePosition_ = stmt.source;
118     return true;
119   }
120 
121   // C1140 -- Can't deallocate a polymorphic entity in a DO CONCURRENT.
122   // Deallocation can be caused by exiting a block that declares an allocatable
123   // entity, assignment to an allocatable variable, or an actual DEALLOCATE
124   // statement
125   //
126   // Note also that the deallocation of a derived type entity might cause the
127   // invocation of an IMPURE final subroutine. (C1139)
128   //
129 
130   // Only to be called for symbols with ObjectEntityDetails
131   static bool HasImpureFinal(const Symbol &symbol) {
132     if (const Symbol * root{GetAssociationRoot(symbol)}) {
133       CHECK(root->has<ObjectEntityDetails>());
134       if (const DeclTypeSpec * symType{root->GetType()}) {
135         if (const DerivedTypeSpec * derived{symType->AsDerived()}) {
136           return semantics::HasImpureFinal(*derived);
137         }
138       }
139     }
140     return false;
141   }
142 
143   // Predicate for deallocations caused by block exit and direct deallocation
144   static bool DeallocateAll(const Symbol &) { return true; }
145 
146   // Predicate for deallocations caused by intrinsic assignment
147   static bool DeallocateNonCoarray(const Symbol &component) {
148     return !IsCoarray(component);
149   }
150 
151   static bool WillDeallocatePolymorphic(const Symbol &entity,
152       const std::function<bool(const Symbol &)> &WillDeallocate) {
153     return WillDeallocate(entity) && IsPolymorphicAllocatable(entity);
154   }
155 
156   // Is it possible that we will we deallocate a polymorphic entity or one
157   // of its components?
158   static bool MightDeallocatePolymorphic(const Symbol &entity,
159       const std::function<bool(const Symbol &)> &WillDeallocate) {
160     if (const Symbol * root{GetAssociationRoot(entity)}) {
161       // Check the entity itself, no coarray exception here
162       if (IsPolymorphicAllocatable(*root)) {
163         return true;
164       }
165       // Check the components
166       if (const auto *details{root->detailsIf<ObjectEntityDetails>()}) {
167         if (const DeclTypeSpec * entityType{details->type()}) {
168           if (const DerivedTypeSpec * derivedType{entityType->AsDerived()}) {
169             UltimateComponentIterator ultimates{*derivedType};
170             for (const auto &ultimate : ultimates) {
171               if (WillDeallocatePolymorphic(ultimate, WillDeallocate)) {
172                 return true;
173               }
174             }
175           }
176         }
177       }
178     }
179     return false;
180   }
181 
182   void SayDeallocateWithImpureFinal(const Symbol &entity, const char *reason) {
183     context_.SayWithDecl(entity, currentStatementSourcePosition_,
184         "Deallocation of an entity with an IMPURE FINAL procedure"
185         " caused by %s not allowed in DO CONCURRENT"_err_en_US,
186         reason);
187   }
188 
189   void SayDeallocateOfPolymorph(
190       parser::CharBlock location, const Symbol &entity, const char *reason) {
191     context_.SayWithDecl(entity, location,
192         "Deallocation of a polymorphic entity caused by %s"
193         " not allowed in DO CONCURRENT"_err_en_US,
194         reason);
195   }
196 
197   // Deallocation caused by block exit
198   // Allocatable entities and all of their allocatable subcomponents will be
199   // deallocated.  This test is different from the other two because it does
200   // not deallocate in cases where the entity itself is not allocatable but
201   // has allocatable polymorphic components
202   void Post(const parser::BlockConstruct &blockConstruct) {
203     const auto &endBlockStmt{
204         std::get<parser::Statement<parser::EndBlockStmt>>(blockConstruct.t)};
205     const Scope &blockScope{context_.FindScope(endBlockStmt.source)};
206     const Scope &doScope{context_.FindScope(doConcurrentSourcePosition_)};
207     if (DoesScopeContain(&doScope, blockScope)) {
208       const char *reason{"block exit"};
209       for (auto &pair : blockScope) {
210         const Symbol &entity{*pair.second};
211         if (IsAllocatable(entity) && !IsSaved(entity) &&
212             MightDeallocatePolymorphic(entity, DeallocateAll)) {
213           SayDeallocateOfPolymorph(endBlockStmt.source, entity, reason);
214         }
215         if (HasImpureFinal(entity)) {
216           SayDeallocateWithImpureFinal(entity, reason);
217         }
218       }
219     }
220   }
221 
222   // Deallocation caused by assignment
223   // Note that this case does not cause deallocation of coarray components
224   void Post(const parser::AssignmentStmt &stmt) {
225     const auto &variable{std::get<parser::Variable>(stmt.t)};
226     if (const Symbol * entity{GetLastName(variable).symbol}) {
227       const char *reason{"assignment"};
228       if (MightDeallocatePolymorphic(*entity, DeallocateNonCoarray)) {
229         SayDeallocateOfPolymorph(variable.GetSource(), *entity, reason);
230       }
231       if (HasImpureFinal(*entity)) {
232         SayDeallocateWithImpureFinal(*entity, reason);
233       }
234     }
235   }
236 
237   // Deallocation from a DEALLOCATE statement
238   // This case is different because DEALLOCATE statements deallocate both
239   // ALLOCATABLE and POINTER entities
240   void Post(const parser::DeallocateStmt &stmt) {
241     const auto &allocateObjectList{
242         std::get<std::list<parser::AllocateObject>>(stmt.t)};
243     for (const auto &allocateObject : allocateObjectList) {
244       const parser::Name &name{GetLastName(allocateObject)};
245       const char *reason{"a DEALLOCATE statement"};
246       if (name.symbol) {
247         const Symbol &entity{*name.symbol};
248         const DeclTypeSpec *entityType{entity.GetType()};
249         if ((entityType && entityType->IsPolymorphic()) || // POINTER case
250             MightDeallocatePolymorphic(entity, DeallocateAll)) {
251           SayDeallocateOfPolymorph(
252               currentStatementSourcePosition_, entity, reason);
253         }
254         if (HasImpureFinal(entity)) {
255           SayDeallocateWithImpureFinal(entity, reason);
256         }
257       }
258     }
259   }
260 
261   // C1137 -- No image control statements in a DO CONCURRENT
262   void Post(const parser::ExecutableConstruct &construct) {
263     if (IsImageControlStmt(construct)) {
264       const parser::CharBlock statementLocation{
265           GetImageControlStmtLocation(construct)};
266       auto &msg{context_.Say(statementLocation,
267           "An image control statement is not allowed in DO"
268           " CONCURRENT"_err_en_US)};
269       if (auto coarrayMsg{GetImageControlStmtCoarrayMsg(construct)}) {
270         msg.Attach(statementLocation, *coarrayMsg);
271       }
272       msg.Attach(doConcurrentSourcePosition_, GetEnclosingDoMsg());
273     }
274   }
275 
276   // C1136 -- No RETURN statements in a DO CONCURRENT
277   void Post(const parser::ReturnStmt &) {
278     context_
279         .Say(currentStatementSourcePosition_,
280             "RETURN is not allowed in DO CONCURRENT"_err_en_US)
281         .Attach(doConcurrentSourcePosition_, GetEnclosingDoMsg());
282   }
283 
284   // C1139: call to impure procedure and ...
285   // C1141: cannot call ieee_get_flag, ieee_[gs]et_halting_mode
286   // It's not necessary to check the ieee_get* procedures because they're
287   // not pure, and impure procedures are caught by checks for constraint C1139
288   void Post(const parser::ProcedureDesignator &procedureDesignator) {
289     if (auto *name{std::get_if<parser::Name>(&procedureDesignator.u)}) {
290       if (name->symbol && !IsPureProcedure(*name->symbol)) {
291         SayWithDo(context_, currentStatementSourcePosition_,
292             "Call to an impure procedure is not allowed in DO"
293             " CONCURRENT"_err_en_US,
294             doConcurrentSourcePosition_);
295       }
296       if (name->symbol && fromScope(*name->symbol, "ieee_exceptions"s)) {
297         if (name->source == "ieee_set_halting_mode") {
298           SayWithDo(context_, currentStatementSourcePosition_,
299               "IEEE_SET_HALTING_MODE is not allowed in DO "
300               "CONCURRENT"_err_en_US,
301               doConcurrentSourcePosition_);
302         }
303       }
304     } else {
305       // C1139: this a procedure component
306       auto &component{std::get<parser::ProcComponentRef>(procedureDesignator.u)
307                           .v.thing.component};
308       if (component.symbol && !IsPureProcedure(*component.symbol)) {
309         SayWithDo(context_, currentStatementSourcePosition_,
310             "Call to an impure procedure component is not allowed"
311             " in DO CONCURRENT"_err_en_US,
312             doConcurrentSourcePosition_);
313       }
314     }
315   }
316 
317   // 11.1.7.5, paragraph 5, no ADVANCE specifier in a DO CONCURRENT
318   void Post(const parser::IoControlSpec &ioControlSpec) {
319     if (auto *charExpr{
320             std::get_if<parser::IoControlSpec::CharExpr>(&ioControlSpec.u)}) {
321       if (std::get<parser::IoControlSpec::CharExpr::Kind>(charExpr->t) ==
322           parser::IoControlSpec::CharExpr::Kind::Advance) {
323         SayWithDo(context_, currentStatementSourcePosition_,
324             "ADVANCE specifier is not allowed in DO"
325             " CONCURRENT"_err_en_US,
326             doConcurrentSourcePosition_);
327       }
328     }
329   }
330 
331 private:
332   // Return the (possibly null) name of the statement
333   template <typename A>
334   static const parser::Name *MaybeGetStmtName(const A &a) {
335     return common::GetPtrFromOptional(std::get<0>(a.t));
336   }
337 
338   bool fromScope(const Symbol &symbol, const std::string &moduleName) {
339     if (symbol.GetUltimate().owner().IsModule() &&
340         symbol.GetUltimate().owner().GetName().value().ToString() ==
341             moduleName) {
342       return true;
343     }
344     return false;
345   }
346 
347   std::set<parser::Label> labels_;
348   parser::CharBlock currentStatementSourcePosition_;
349   SemanticsContext &context_;
350   parser::CharBlock doConcurrentSourcePosition_;
351 }; // class DoConcurrentBodyEnforce
352 
353 // Class for enforcing C1130 -- in a DO CONCURRENT with DEFAULT(NONE),
354 // variables from enclosing scopes must have their locality specified
355 class DoConcurrentVariableEnforce {
356 public:
357   DoConcurrentVariableEnforce(
358       SemanticsContext &context, parser::CharBlock doConcurrentSourcePosition)
359       : context_{context},
360         doConcurrentSourcePosition_{doConcurrentSourcePosition},
361         blockScope_{context.FindScope(doConcurrentSourcePosition_)} {}
362 
363   template <typename T> bool Pre(const T &) { return true; }
364   template <typename T> void Post(const T &) {}
365 
366   // Check to see if the name is a variable from an enclosing scope
367   void Post(const parser::Name &name) {
368     if (const Symbol * symbol{name.symbol}) {
369       if (IsVariableName(*symbol)) {
370         const Scope &variableScope{symbol->owner()};
371         if (DoesScopeContain(&variableScope, blockScope_)) {
372           context_.SayWithDecl(*symbol, name.source,
373               "Variable '%s' from an enclosing scope referenced in DO "
374               "CONCURRENT with DEFAULT(NONE) must appear in a "
375               "locality-spec"_err_en_US,
376               symbol->name());
377         }
378       }
379     }
380   }
381 
382 private:
383   SemanticsContext &context_;
384   parser::CharBlock doConcurrentSourcePosition_;
385   const Scope &blockScope_;
386 }; // class DoConcurrentVariableEnforce
387 
388 // Find a DO or FORALL and enforce semantics checks on its body
389 class DoContext {
390 public:
391   DoContext(SemanticsContext &context, IndexVarKind kind)
392       : context_{context}, kind_{kind} {}
393 
394   // Mark this DO construct as a point of definition for the DO variables
395   // or index-names it contains.  If they're already defined, emit an error
396   // message.  We need to remember both the variable and the source location of
397   // the variable in the DO construct so that we can remove it when we leave
398   // the DO construct and use its location in error messages.
399   void DefineDoVariables(const parser::DoConstruct &doConstruct) {
400     if (doConstruct.IsDoNormal()) {
401       context_.ActivateIndexVar(GetDoVariable(doConstruct), IndexVarKind::DO);
402     } else if (doConstruct.IsDoConcurrent()) {
403       if (const auto &loopControl{doConstruct.GetLoopControl()}) {
404         ActivateIndexVars(GetControls(*loopControl));
405       }
406     }
407   }
408 
409   // Called at the end of a DO construct to deactivate the DO construct
410   void ResetDoVariables(const parser::DoConstruct &doConstruct) {
411     if (doConstruct.IsDoNormal()) {
412       context_.DeactivateIndexVar(GetDoVariable(doConstruct));
413     } else if (doConstruct.IsDoConcurrent()) {
414       if (const auto &loopControl{doConstruct.GetLoopControl()}) {
415         DeactivateIndexVars(GetControls(*loopControl));
416       }
417     }
418   }
419 
420   void ActivateIndexVars(const std::list<parser::ConcurrentControl> &controls) {
421     for (const auto &control : controls) {
422       context_.ActivateIndexVar(std::get<parser::Name>(control.t), kind_);
423     }
424   }
425   void DeactivateIndexVars(
426       const std::list<parser::ConcurrentControl> &controls) {
427     for (const auto &control : controls) {
428       context_.DeactivateIndexVar(std::get<parser::Name>(control.t));
429     }
430   }
431 
432   void Check(const parser::DoConstruct &doConstruct) {
433     if (doConstruct.IsDoConcurrent()) {
434       CheckDoConcurrent(doConstruct);
435       return;
436     }
437     if (doConstruct.IsDoNormal()) {
438       CheckDoNormal(doConstruct);
439       return;
440     }
441     // TODO: handle the other cases
442   }
443 
444   void Check(const parser::ForallStmt &stmt) {
445     CheckConcurrentHeader(GetConcurrentHeader(stmt));
446   }
447   void Check(const parser::ForallConstruct &construct) {
448     CheckConcurrentHeader(GetConcurrentHeader(construct));
449   }
450 
451   void Check(const parser::ForallAssignmentStmt &stmt) {
452     const evaluate::Assignment *assignment{std::visit(
453         common::visitors{[&](const auto &x) { return GetAssignment(x); }},
454         stmt.u)};
455     if (assignment) {
456       CheckForallIndexesUsed(*assignment);
457       CheckForImpureCall(assignment->lhs);
458       CheckForImpureCall(assignment->rhs);
459       if (const auto *proc{
460               std::get_if<evaluate::ProcedureRef>(&assignment->u)}) {
461         CheckForImpureCall(*proc);
462       }
463       std::visit(common::visitors{
464                      [](const evaluate::Assignment::Intrinsic &) {},
465                      [&](const evaluate::ProcedureRef &proc) {
466                        CheckForImpureCall(proc);
467                      },
468                      [&](const evaluate::Assignment::BoundsSpec &bounds) {
469                        for (const auto &bound : bounds) {
470                          CheckForImpureCall(SomeExpr{bound});
471                        }
472                      },
473                      [&](const evaluate::Assignment::BoundsRemapping &bounds) {
474                        for (const auto &bound : bounds) {
475                          CheckForImpureCall(SomeExpr{bound.first});
476                          CheckForImpureCall(SomeExpr{bound.second});
477                        }
478                      },
479                  },
480           assignment->u);
481     }
482   }
483 
484 private:
485   void SayBadDoControl(parser::CharBlock sourceLocation) {
486     context_.Say(sourceLocation, "DO controls should be INTEGER"_err_en_US);
487   }
488 
489   void CheckDoControl(const parser::CharBlock &sourceLocation, bool isReal) {
490     const bool warn{context_.warnOnNonstandardUsage() ||
491         context_.ShouldWarn(common::LanguageFeature::RealDoControls)};
492     if (isReal && !warn) {
493       // No messages for the default case
494     } else if (isReal && warn) {
495       context_.Say(sourceLocation, "DO controls should be INTEGER"_en_US);
496     } else {
497       SayBadDoControl(sourceLocation);
498     }
499   }
500 
501   void CheckDoVariable(const parser::ScalarName &scalarName) {
502     const parser::CharBlock &sourceLocation{scalarName.thing.source};
503     if (const Symbol * symbol{scalarName.thing.symbol}) {
504       if (!IsVariableName(*symbol)) {
505         context_.Say(
506             sourceLocation, "DO control must be an INTEGER variable"_err_en_US);
507       } else {
508         const DeclTypeSpec *symType{symbol->GetType()};
509         if (!symType) {
510           SayBadDoControl(sourceLocation);
511         } else {
512           if (!symType->IsNumeric(TypeCategory::Integer)) {
513             CheckDoControl(
514                 sourceLocation, symType->IsNumeric(TypeCategory::Real));
515           }
516         }
517       } // No messages for INTEGER
518     }
519   }
520 
521   // Semantic checks for the limit and step expressions
522   void CheckDoExpression(const parser::ScalarExpr &scalarExpression) {
523     if (const SomeExpr * expr{GetExpr(scalarExpression)}) {
524       if (!ExprHasTypeCategory(*expr, TypeCategory::Integer)) {
525         // No warnings or errors for type INTEGER
526         const parser::CharBlock &loc{scalarExpression.thing.value().source};
527         CheckDoControl(loc, ExprHasTypeCategory(*expr, TypeCategory::Real));
528       }
529     }
530   }
531 
532   void CheckDoNormal(const parser::DoConstruct &doConstruct) {
533     // C1120 -- types of DO variables must be INTEGER, extended by allowing
534     // REAL and DOUBLE PRECISION
535     const Bounds &bounds{GetBounds(doConstruct)};
536     CheckDoVariable(bounds.name);
537     CheckDoExpression(bounds.lower);
538     CheckDoExpression(bounds.upper);
539     if (bounds.step) {
540       CheckDoExpression(*bounds.step);
541       if (IsZero(*bounds.step)) {
542         context_.Say(bounds.step->thing.value().source,
543             "DO step expression should not be zero"_en_US);
544       }
545     }
546   }
547 
548   void CheckDoConcurrent(const parser::DoConstruct &doConstruct) {
549     auto &doStmt{
550         std::get<parser::Statement<parser::NonLabelDoStmt>>(doConstruct.t)};
551     currentStatementSourcePosition_ = doStmt.source;
552 
553     const parser::Block &block{std::get<parser::Block>(doConstruct.t)};
554     DoConcurrentBodyEnforce doConcurrentBodyEnforce{context_, doStmt.source};
555     parser::Walk(block, doConcurrentBodyEnforce);
556 
557     LabelEnforce doConcurrentLabelEnforce{context_,
558         doConcurrentBodyEnforce.labels(), currentStatementSourcePosition_,
559         "DO CONCURRENT"};
560     parser::Walk(block, doConcurrentLabelEnforce);
561 
562     const auto &loopControl{doConstruct.GetLoopControl()};
563     CheckConcurrentLoopControl(*loopControl);
564     CheckLocalitySpecs(*loopControl, block);
565   }
566 
567   // Return a set of symbols whose names are in a Local locality-spec.  Look
568   // the names up in the scope that encloses the DO construct to avoid getting
569   // the local versions of them.  Then follow the host-, use-, and
570   // construct-associations to get the root symbols
571   SymbolSet GatherLocals(
572       const std::list<parser::LocalitySpec> &localitySpecs) const {
573     SymbolSet symbols;
574     const Scope &parentScope{
575         context_.FindScope(currentStatementSourcePosition_).parent()};
576     // Loop through the LocalitySpec::Local locality-specs
577     for (const auto &ls : localitySpecs) {
578       if (const auto *names{std::get_if<parser::LocalitySpec::Local>(&ls.u)}) {
579         // Loop through the names in the Local locality-spec getting their
580         // symbols
581         for (const parser::Name &name : names->v) {
582           if (const Symbol * symbol{parentScope.FindSymbol(name.source)}) {
583             if (const Symbol * root{GetAssociationRoot(*symbol)}) {
584               symbols.insert(*root);
585             }
586           }
587         }
588       }
589     }
590     return symbols;
591   }
592 
593   static SymbolSet GatherSymbolsFromExpression(const parser::Expr &expression) {
594     SymbolSet result;
595     if (const auto *expr{GetExpr(expression)}) {
596       for (const Symbol &symbol : evaluate::CollectSymbols(*expr)) {
597         if (const Symbol * root{GetAssociationRoot(symbol)}) {
598           result.insert(*root);
599         }
600       }
601     }
602     return result;
603   }
604 
605   // C1121 - procedures in mask must be pure
606   void CheckMaskIsPure(const parser::ScalarLogicalExpr &mask) const {
607     SymbolSet references{GatherSymbolsFromExpression(mask.thing.thing.value())};
608     for (const Symbol &ref : references) {
609       if (IsProcedure(ref) && !IsPureProcedure(ref)) {
610         context_.SayWithDecl(ref, parser::Unwrap<parser::Expr>(mask)->source,
611             "%s mask expression may not reference impure procedure '%s'"_err_en_US,
612             LoopKindName(), ref.name());
613         return;
614       }
615     }
616   }
617 
618   void CheckNoCollisions(const SymbolSet &refs, const SymbolSet &uses,
619       parser::MessageFixedText &&errorMessage,
620       const parser::CharBlock &refPosition) const {
621     for (const Symbol &ref : refs) {
622       if (uses.find(ref) != uses.end()) {
623         context_.SayWithDecl(ref, refPosition, std::move(errorMessage),
624             LoopKindName(), ref.name());
625         return;
626       }
627     }
628   }
629 
630   void HasNoReferences(
631       const SymbolSet &indexNames, const parser::ScalarIntExpr &expr) const {
632     CheckNoCollisions(GatherSymbolsFromExpression(expr.thing.thing.value()),
633         indexNames,
634         "%s limit expression may not reference index variable '%s'"_err_en_US,
635         expr.thing.thing.value().source);
636   }
637 
638   // C1129, names in local locality-specs can't be in mask expressions
639   void CheckMaskDoesNotReferenceLocal(
640       const parser::ScalarLogicalExpr &mask, const SymbolSet &localVars) const {
641     CheckNoCollisions(GatherSymbolsFromExpression(mask.thing.thing.value()),
642         localVars,
643         "%s mask expression references variable '%s'"
644         " in LOCAL locality-spec"_err_en_US,
645         mask.thing.thing.value().source);
646   }
647 
648   // C1129, names in local locality-specs can't be in limit or step
649   // expressions
650   void CheckExprDoesNotReferenceLocal(
651       const parser::ScalarIntExpr &expr, const SymbolSet &localVars) const {
652     CheckNoCollisions(GatherSymbolsFromExpression(expr.thing.thing.value()),
653         localVars,
654         "%s expression references variable '%s'"
655         " in LOCAL locality-spec"_err_en_US,
656         expr.thing.thing.value().source);
657   }
658 
659   // C1130, DEFAULT(NONE) locality requires names to be in locality-specs to
660   // be used in the body of the DO loop
661   void CheckDefaultNoneImpliesExplicitLocality(
662       const std::list<parser::LocalitySpec> &localitySpecs,
663       const parser::Block &block) const {
664     bool hasDefaultNone{false};
665     for (auto &ls : localitySpecs) {
666       if (std::holds_alternative<parser::LocalitySpec::DefaultNone>(ls.u)) {
667         if (hasDefaultNone) {
668           // C1127, you can only have one DEFAULT(NONE)
669           context_.Say(currentStatementSourcePosition_,
670               "Only one DEFAULT(NONE) may appear"_en_US);
671           break;
672         }
673         hasDefaultNone = true;
674       }
675     }
676     if (hasDefaultNone) {
677       DoConcurrentVariableEnforce doConcurrentVariableEnforce{
678           context_, currentStatementSourcePosition_};
679       parser::Walk(block, doConcurrentVariableEnforce);
680     }
681   }
682 
683   // C1123, concurrent limit or step expressions can't reference index-names
684   void CheckConcurrentHeader(const parser::ConcurrentHeader &header) const {
685     if (const auto &mask{
686             std::get<std::optional<parser::ScalarLogicalExpr>>(header.t)}) {
687       CheckMaskIsPure(*mask);
688     }
689     auto &controls{std::get<std::list<parser::ConcurrentControl>>(header.t)};
690     SymbolSet indexNames;
691     for (const parser::ConcurrentControl &control : controls) {
692       const auto &indexName{std::get<parser::Name>(control.t)};
693       if (indexName.symbol) {
694         indexNames.insert(*indexName.symbol);
695       }
696     }
697     if (!indexNames.empty()) {
698       for (const parser::ConcurrentControl &control : controls) {
699         HasNoReferences(indexNames, std::get<1>(control.t));
700         HasNoReferences(indexNames, std::get<2>(control.t));
701         if (const auto &intExpr{
702                 std::get<std::optional<parser::ScalarIntExpr>>(control.t)}) {
703           const parser::Expr &expr{intExpr->thing.thing.value()};
704           CheckNoCollisions(GatherSymbolsFromExpression(expr), indexNames,
705               "%s step expression may not reference index variable '%s'"_err_en_US,
706               expr.source);
707           if (IsZero(expr)) {
708             context_.Say(expr.source,
709                 "%s step expression may not be zero"_err_en_US, LoopKindName());
710           }
711         }
712       }
713     }
714   }
715 
716   void CheckLocalitySpecs(
717       const parser::LoopControl &control, const parser::Block &block) const {
718     const auto &concurrent{
719         std::get<parser::LoopControl::Concurrent>(control.u)};
720     const auto &header{std::get<parser::ConcurrentHeader>(concurrent.t)};
721     const auto &localitySpecs{
722         std::get<std::list<parser::LocalitySpec>>(concurrent.t)};
723     if (!localitySpecs.empty()) {
724       const SymbolSet &localVars{GatherLocals(localitySpecs)};
725       for (const auto &c : GetControls(control)) {
726         CheckExprDoesNotReferenceLocal(std::get<1>(c.t), localVars);
727         CheckExprDoesNotReferenceLocal(std::get<2>(c.t), localVars);
728         if (const auto &expr{
729                 std::get<std::optional<parser::ScalarIntExpr>>(c.t)}) {
730           CheckExprDoesNotReferenceLocal(*expr, localVars);
731         }
732       }
733       if (const auto &mask{
734               std::get<std::optional<parser::ScalarLogicalExpr>>(header.t)}) {
735         CheckMaskDoesNotReferenceLocal(*mask, localVars);
736       }
737       CheckDefaultNoneImpliesExplicitLocality(localitySpecs, block);
738     }
739   }
740 
741   // check constraints [C1121 .. C1130]
742   void CheckConcurrentLoopControl(const parser::LoopControl &control) const {
743     const auto &concurrent{
744         std::get<parser::LoopControl::Concurrent>(control.u)};
745     CheckConcurrentHeader(std::get<parser::ConcurrentHeader>(concurrent.t));
746   }
747 
748   template <typename T> void CheckForImpureCall(const T &x) {
749     const auto &intrinsics{context_.foldingContext().intrinsics()};
750     if (auto bad{FindImpureCall(intrinsics, x)}) {
751       context_.Say(
752           "Impure procedure '%s' may not be referenced in a %s"_err_en_US, *bad,
753           LoopKindName());
754     }
755   }
756 
757   // Each index should be used on the LHS of each assignment in a FORALL
758   void CheckForallIndexesUsed(const evaluate::Assignment &assignment) {
759     SymbolVector indexVars{context_.GetIndexVars(IndexVarKind::FORALL)};
760     if (!indexVars.empty()) {
761       SymbolSet symbols{evaluate::CollectSymbols(assignment.lhs)};
762       std::visit(
763           common::visitors{
764               [&](const evaluate::Assignment::BoundsSpec &spec) {
765                 for (const auto &bound : spec) {
766 // TODO: this is working around missing std::set::merge in some versions of
767 // clang that we are building with
768 #ifdef __clang__
769                   auto boundSymbols{evaluate::CollectSymbols(bound)};
770                   symbols.insert(boundSymbols.begin(), boundSymbols.end());
771 #else
772                   symbols.merge(evaluate::CollectSymbols(bound));
773 #endif
774                 }
775               },
776               [&](const evaluate::Assignment::BoundsRemapping &remapping) {
777                 for (const auto &bounds : remapping) {
778 #ifdef __clang__
779                   auto lbSymbols{evaluate::CollectSymbols(bounds.first)};
780                   symbols.insert(lbSymbols.begin(), lbSymbols.end());
781                   auto ubSymbols{evaluate::CollectSymbols(bounds.second)};
782                   symbols.insert(ubSymbols.begin(), ubSymbols.end());
783 #else
784                   symbols.merge(evaluate::CollectSymbols(bounds.first));
785                   symbols.merge(evaluate::CollectSymbols(bounds.second));
786 #endif
787                 }
788               },
789               [](const auto &) {},
790           },
791           assignment.u);
792       for (const Symbol &index : indexVars) {
793         if (symbols.count(index) == 0) {
794           context_.Say(
795               "Warning: FORALL index variable '%s' not used on left-hand side"
796               " of assignment"_en_US,
797               index.name());
798         }
799       }
800     }
801   }
802 
803   // For messages where the DO loop must be DO CONCURRENT, make that explicit.
804   const char *LoopKindName() const {
805     return kind_ == IndexVarKind::DO ? "DO CONCURRENT" : "FORALL";
806   }
807 
808   SemanticsContext &context_;
809   const IndexVarKind kind_;
810   parser::CharBlock currentStatementSourcePosition_;
811 }; // class DoContext
812 
813 void DoForallChecker::Enter(const parser::DoConstruct &doConstruct) {
814   DoContext doContext{context_, IndexVarKind::DO};
815   doContext.DefineDoVariables(doConstruct);
816 }
817 
818 void DoForallChecker::Leave(const parser::DoConstruct &doConstruct) {
819   DoContext doContext{context_, IndexVarKind::DO};
820   doContext.Check(doConstruct);
821   doContext.ResetDoVariables(doConstruct);
822 }
823 
824 void DoForallChecker::Enter(const parser::ForallConstruct &construct) {
825   DoContext doContext{context_, IndexVarKind::FORALL};
826   doContext.ActivateIndexVars(GetControls(construct));
827 }
828 void DoForallChecker::Leave(const parser::ForallConstruct &construct) {
829   DoContext doContext{context_, IndexVarKind::FORALL};
830   doContext.Check(construct);
831   doContext.DeactivateIndexVars(GetControls(construct));
832 }
833 
834 void DoForallChecker::Enter(const parser::ForallStmt &stmt) {
835   DoContext doContext{context_, IndexVarKind::FORALL};
836   doContext.ActivateIndexVars(GetControls(stmt));
837 }
838 void DoForallChecker::Leave(const parser::ForallStmt &stmt) {
839   DoContext doContext{context_, IndexVarKind::FORALL};
840   doContext.Check(stmt);
841   doContext.DeactivateIndexVars(GetControls(stmt));
842 }
843 void DoForallChecker::Leave(const parser::ForallAssignmentStmt &stmt) {
844   DoContext doContext{context_, IndexVarKind::FORALL};
845   doContext.Check(stmt);
846 }
847 
848 // Return the (possibly null) name of the ConstructNode
849 static const parser::Name *MaybeGetNodeName(const ConstructNode &construct) {
850   return std::visit(
851       [&](const auto &x) { return MaybeGetConstructName(*x); }, construct);
852 }
853 
854 template <typename A>
855 static parser::CharBlock GetConstructPosition(const A &a) {
856   return std::get<0>(a.t).source;
857 }
858 
859 static parser::CharBlock GetNodePosition(const ConstructNode &construct) {
860   return std::visit(
861       [&](const auto &x) { return GetConstructPosition(*x); }, construct);
862 }
863 
864 void DoForallChecker::SayBadLeave(StmtType stmtType,
865     const char *enclosingStmtName, const ConstructNode &construct) const {
866   context_
867       .Say("%s must not leave a %s statement"_err_en_US, EnumToString(stmtType),
868           enclosingStmtName)
869       .Attach(GetNodePosition(construct), "The construct that was left"_en_US);
870 }
871 
872 static const parser::DoConstruct *MaybeGetDoConstruct(
873     const ConstructNode &construct) {
874   if (const auto *doNode{
875           std::get_if<const parser::DoConstruct *>(&construct)}) {
876     return *doNode;
877   } else {
878     return nullptr;
879   }
880 }
881 
882 static bool ConstructIsDoConcurrent(const ConstructNode &construct) {
883   const parser::DoConstruct *doConstruct{MaybeGetDoConstruct(construct)};
884   return doConstruct && doConstruct->IsDoConcurrent();
885 }
886 
887 // Check that CYCLE and EXIT statements do not cause flow of control to
888 // leave DO CONCURRENT, CRITICAL, or CHANGE TEAM constructs.
889 void DoForallChecker::CheckForBadLeave(
890     StmtType stmtType, const ConstructNode &construct) const {
891   std::visit(common::visitors{
892                  [&](const parser::DoConstruct *doConstructPtr) {
893                    if (doConstructPtr->IsDoConcurrent()) {
894                      // C1135 and C1167 -- CYCLE and EXIT statements can't leave
895                      // a DO CONCURRENT
896                      SayBadLeave(stmtType, "DO CONCURRENT", construct);
897                    }
898                  },
899                  [&](const parser::CriticalConstruct *) {
900                    // C1135 and C1168 -- similarly, for CRITICAL
901                    SayBadLeave(stmtType, "CRITICAL", construct);
902                  },
903                  [&](const parser::ChangeTeamConstruct *) {
904                    // C1135 and C1168 -- similarly, for CHANGE TEAM
905                    SayBadLeave(stmtType, "CHANGE TEAM", construct);
906                  },
907                  [](const auto *) {},
908              },
909       construct);
910 }
911 
912 static bool StmtMatchesConstruct(const parser::Name *stmtName,
913     StmtType stmtType, const parser::Name *constructName,
914     const ConstructNode &construct) {
915   bool inDoConstruct{MaybeGetDoConstruct(construct) != nullptr};
916   if (!stmtName) {
917     return inDoConstruct; // Unlabeled statements match all DO constructs
918   } else if (constructName && constructName->source == stmtName->source) {
919     return stmtType == StmtType::EXIT || inDoConstruct;
920   } else {
921     return false;
922   }
923 }
924 
925 // C1167 Can't EXIT from a DO CONCURRENT
926 void DoForallChecker::CheckDoConcurrentExit(
927     StmtType stmtType, const ConstructNode &construct) const {
928   if (stmtType == StmtType::EXIT && ConstructIsDoConcurrent(construct)) {
929     SayBadLeave(StmtType::EXIT, "DO CONCURRENT", construct);
930   }
931 }
932 
933 // Check nesting violations for a CYCLE or EXIT statement.  Loop up the
934 // nesting levels looking for a construct that matches the CYCLE or EXIT
935 // statment.  At every construct, check for a violation.  If we find a match
936 // without finding a violation, the check is complete.
937 void DoForallChecker::CheckNesting(
938     StmtType stmtType, const parser::Name *stmtName) const {
939   const ConstructStack &stack{context_.constructStack()};
940   for (auto iter{stack.cend()}; iter-- != stack.cbegin();) {
941     const ConstructNode &construct{*iter};
942     const parser::Name *constructName{MaybeGetNodeName(construct)};
943     if (StmtMatchesConstruct(stmtName, stmtType, constructName, construct)) {
944       CheckDoConcurrentExit(stmtType, construct);
945       return; // We got a match, so we're finished checking
946     }
947     CheckForBadLeave(stmtType, construct);
948   }
949 
950   // We haven't found a match in the enclosing constructs
951   if (stmtType == StmtType::EXIT) {
952     context_.Say("No matching construct for EXIT statement"_err_en_US);
953   } else {
954     context_.Say("No matching DO construct for CYCLE statement"_err_en_US);
955   }
956 }
957 
958 // C1135 -- Nesting for CYCLE statements
959 void DoForallChecker::Enter(const parser::CycleStmt &cycleStmt) {
960   CheckNesting(StmtType::CYCLE, common::GetPtrFromOptional(cycleStmt.v));
961 }
962 
963 // C1167 and C1168 -- Nesting for EXIT statements
964 void DoForallChecker::Enter(const parser::ExitStmt &exitStmt) {
965   CheckNesting(StmtType::EXIT, common::GetPtrFromOptional(exitStmt.v));
966 }
967 
968 void DoForallChecker::Leave(const parser::AssignmentStmt &stmt) {
969   const auto &variable{std::get<parser::Variable>(stmt.t)};
970   context_.CheckIndexVarRedefine(variable);
971 }
972 
973 static void CheckIfArgIsDoVar(const evaluate::ActualArgument &arg,
974     const parser::CharBlock location, SemanticsContext &context) {
975   common::Intent intent{arg.dummyIntent()};
976   if (intent == common::Intent::Out || intent == common::Intent::InOut) {
977     if (const SomeExpr * argExpr{arg.UnwrapExpr()}) {
978       if (const Symbol * var{evaluate::UnwrapWholeSymbolDataRef(*argExpr)}) {
979         if (intent == common::Intent::Out) {
980           context.CheckIndexVarRedefine(location, *var);
981         } else {
982           context.WarnIndexVarRedefine(location, *var); // INTENT(INOUT)
983         }
984       }
985     }
986   }
987 }
988 
// Check whether a DO variable is passed as an actual argument to a dummy
// argument whose intent is OUT or INOUT.  Each evaluated actual argument is
// handed to CheckIfArgIsDoVar, which looks only at whole-variable arguments.
// The dummy intents come from the evaluate::ProcedureRef in the "typedCall"
// field of the CallStmt, which is filled in during expression checking; when
// typedCall is absent, nothing is checked.  The parsed parser::ActualArgSpec
// list is walked at the same time only to get a source location for each
// argument's message.
//
// NOTE(review): the checked and parsed argument lists are paired purely by
// position.  If they do not correspond one-to-one -- e.g. keyword arguments
// reordered into dummy order, empty entries for absent optional arguments,
// or an implicit passed-object argument -- an argument could be reported at
// the wrong location, or skipped once the parsed list runs out.  Confirm the
// layout of typedCall->arguments() against the parsed list.
void DoForallChecker::Leave(const parser::CallStmt &callStmt) {
  if (const auto &typedCall{callStmt.typedCall}) {
    const auto &parsedArgs{
        std::get<std::list<parser::ActualArgSpec>>(callStmt.v.t)};
    auto parsedArgIter{parsedArgs.begin()};
    const evaluate::ActualArguments &checkedArgs{typedCall->arguments()};
    for (const auto &checkedOptionalArg : checkedArgs) {
      if (parsedArgIter == parsedArgs.end()) {
        break; // No more parsed arguments, we're done.
      }
      const auto &parsedArg{std::get<parser::ActualArg>(parsedArgIter->t)};
      // Advance before the emptiness test so positions stay paired even when
      // this checked entry is empty.
      ++parsedArgIter;
      if (checkedOptionalArg) {
        const evaluate::ActualArgument &checkedArg{*checkedOptionalArg};
        // Only arguments parsed as expressions have a source range to report;
        // other parser::ActualArg forms are skipped.
        if (const auto *parsedExpr{
                std::get_if<common::Indirection<parser::Expr>>(&parsedArg.u)}) {
          CheckIfArgIsDoVar(checkedArg, parsedExpr->value().source, context_);
        }
      }
    }
  }
}
1019 
1020 void DoForallChecker::Leave(const parser::ConnectSpec &connectSpec) {
1021   const auto *newunit{
1022       std::get_if<parser::ConnectSpec::Newunit>(&connectSpec.u)};
1023   if (newunit) {
1024     context_.CheckIndexVarRedefine(newunit->v.thing.thing);
1025   }
1026 }
1027 
// Set of references to actual arguments.  Ordering uses the operator< for
// evaluate::ActualArgumentRef declared in namespace Fortran::evaluate, which
// compares object addresses, so each distinct ActualArgument appears once.
using ActualArgumentSet = std::set<evaluate::ActualArgumentRef>;

// Traversal that gathers every evaluate::ActualArgument reachable from an
// expression, including arguments nested inside the expressions of other
// actual arguments (operator() below recurses into arg.UnwrapExpr()).
// NOTE(review): presumably SetTraverse unions the sets produced for
// subexpressions; confirm against evaluate's traverse framework.
struct CollectActualArgumentsHelper
    : public evaluate::SetTraverse<CollectActualArgumentsHelper,
          ActualArgumentSet> {
  using Base = SetTraverse<CollectActualArgumentsHelper, ActualArgumentSet>;
  CollectActualArgumentsHelper() : Base{*this} {}
  using Base::operator();
  // Collect this argument itself plus any arguments found inside its
  // expression (UnwrapExpr() may yield null for non-expression arguments).
  ActualArgumentSet operator()(const evaluate::ActualArgument &arg) const {
    return Combine(ActualArgumentSet{arg},
        CollectActualArgumentsHelper{}(arg.UnwrapExpr()));
  }
};

// Returns the set of all actual arguments appearing anywhere within x.
template <typename A> ActualArgumentSet CollectActualArguments(const A &x) {
  return CollectActualArgumentsHelper{}(x);
}

// Explicit instantiation for SomeExpr, the form used by Leave(parser::Expr).
template ActualArgumentSet CollectActualArguments(const SomeExpr &);
1047 
1048 void DoForallChecker::Enter(const parser::Expr &parsedExpr) { ++exprDepth_; }
1049 
1050 void DoForallChecker::Leave(const parser::Expr &parsedExpr) {
1051   CHECK(exprDepth_ > 0);
1052   if (--exprDepth_ == 0) { // Only check top level expressions
1053     if (const SomeExpr * expr{GetExpr(parsedExpr)}) {
1054       ActualArgumentSet argSet{CollectActualArguments(*expr)};
1055       for (const evaluate::ActualArgumentRef &argRef : argSet) {
1056         CheckIfArgIsDoVar(*argRef, parsedExpr.source, context_);
1057       }
1058     }
1059   }
1060 }
1061 
1062 void DoForallChecker::Leave(const parser::InquireSpec &inquireSpec) {
1063   const auto *intVar{std::get_if<parser::InquireSpec::IntVar>(&inquireSpec.u)};
1064   if (intVar) {
1065     const auto &scalar{std::get<parser::ScalarIntVariable>(intVar->t)};
1066     context_.CheckIndexVarRedefine(scalar.thing.thing);
1067   }
1068 }
1069 
1070 void DoForallChecker::Leave(const parser::IoControlSpec &ioControlSpec) {
1071   const auto *size{std::get_if<parser::IoControlSpec::Size>(&ioControlSpec.u)};
1072   if (size) {
1073     context_.CheckIndexVarRedefine(size->v.thing.thing);
1074   }
1075 }
1076 
1077 void DoForallChecker::Leave(const parser::OutputImpliedDo &outputImpliedDo) {
1078   const auto &control{std::get<parser::IoImpliedDoControl>(outputImpliedDo.t)};
1079   const parser::Name &name{control.name.thing.thing};
1080   context_.CheckIndexVarRedefine(name.source, *name.symbol);
1081 }
1082 
1083 void DoForallChecker::Leave(const parser::StatVariable &statVariable) {
1084   context_.CheckIndexVarRedefine(statVariable.v.thing.thing);
1085 }
1086 
1087 } // namespace Fortran::semantics
1088