1 //===-- lib/Semantics/check-case.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "check-case.h"
#include "flang/Common/idioms.h"
#include "flang/Common/reference.h"
#include "flang/Common/template.h"
#include "flang/Evaluate/fold.h"
#include "flang/Evaluate/type.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/semantics.h"
#include "flang/Semantics/tools.h"
#include <cstddef>
#include <tuple>
19 
20 namespace Fortran::semantics {
21 
22 template <typename T> class CaseValues {
23 public:
24   CaseValues(SemanticsContext &c, const evaluate::DynamicType &t)
25       : context_{c}, caseExprType_{t} {}
26 
27   void Check(const std::list<parser::CaseConstruct::Case> &cases) {
28     for (const parser::CaseConstruct::Case &c : cases) {
29       AddCase(c);
30     }
31     if (!hasErrors_) {
32       cases_.sort(Comparator{});
33       if (!AreCasesDisjoint()) { // C1149
34         ReportConflictingCases();
35       }
36     }
37   }
38 
39 private:
40   using Value = evaluate::Scalar<T>;
41 
42   void AddCase(const parser::CaseConstruct::Case &c) {
43     const auto &stmt{std::get<parser::Statement<parser::CaseStmt>>(c.t)};
44     const parser::CaseStmt &caseStmt{stmt.statement};
45     const auto &selector{std::get<parser::CaseSelector>(caseStmt.t)};
46     std::visit(
47         common::visitors{
48             [&](const std::list<parser::CaseValueRange> &ranges) {
49               for (const auto &range : ranges) {
50                 auto pair{ComputeBounds(range)};
51                 if (pair.first && pair.second && *pair.first > *pair.second) {
52                   context_.Say(stmt.source,
53                       "CASE has lower bound greater than upper bound"_warn_en_US);
54                 } else {
55                   if constexpr (T::category == TypeCategory::Logical) { // C1148
56                     if ((pair.first || pair.second) &&
57                         (!pair.first || !pair.second ||
58                             *pair.first != *pair.second)) {
59                       context_.Say(stmt.source,
60                           "CASE range is not allowed for LOGICAL"_err_en_US);
61                     }
62                   }
63                   cases_.emplace_back(stmt);
64                   cases_.back().lower = std::move(pair.first);
65                   cases_.back().upper = std::move(pair.second);
66                 }
67               }
68             },
69             [&](const parser::Default &) { cases_.emplace_front(stmt); },
70         },
71         selector.u);
72   }
73 
74   std::optional<Value> GetValue(const parser::CaseValue &caseValue) {
75     const parser::Expr &expr{caseValue.thing.thing.value()};
76     auto *x{expr.typedExpr.get()};
77     if (x && x->v) { // C1147
78       auto type{x->v->GetType()};
79       if (type && type->category() == caseExprType_.category() &&
80           (type->category() != TypeCategory::Character ||
81               type->kind() == caseExprType_.kind())) {
82         parser::Messages buffer; // discarded folding messages
83         parser::ContextualMessages foldingMessages{expr.source, &buffer};
84         evaluate::FoldingContext foldingContext{
85             context_.foldingContext(), foldingMessages};
86         auto folded{evaluate::Fold(foldingContext, SomeExpr{*x->v})};
87         if (auto converted{evaluate::Fold(foldingContext,
88                 evaluate::ConvertToType(T::GetType(), SomeExpr{folded}))}) {
89           if (auto value{evaluate::GetScalarConstantValue<T>(*converted)}) {
90             auto back{evaluate::Fold(foldingContext,
91                 evaluate::ConvertToType(*type, SomeExpr{*converted}))};
92             if (back == folded) {
93               x->v = converted;
94               return value;
95             } else {
96               context_.Say(expr.source,
97                   "CASE value (%s) overflows type (%s) of SELECT CASE expression"_err_en_US,
98                   folded.AsFortran(), caseExprType_.AsFortran());
99               hasErrors_ = true;
100               return std::nullopt;
101             }
102           }
103         }
104         context_.Say(expr.source,
105             "CASE value (%s) must be a constant scalar"_err_en_US,
106             x->v->AsFortran());
107       } else {
108         std::string typeStr{type ? type->AsFortran() : "typeless"s};
109         context_.Say(expr.source,
110             "CASE value has type '%s' which is not compatible with the SELECT CASE expression's type '%s'"_err_en_US,
111             typeStr, caseExprType_.AsFortran());
112       }
113       hasErrors_ = true;
114     }
115     return std::nullopt;
116   }
117 
118   using PairOfValues = std::pair<std::optional<Value>, std::optional<Value>>;
119   PairOfValues ComputeBounds(const parser::CaseValueRange &range) {
120     return std::visit(common::visitors{
121                           [&](const parser::CaseValue &x) {
122                             auto value{GetValue(x)};
123                             return PairOfValues{value, value};
124                           },
125                           [&](const parser::CaseValueRange::Range &x) {
126                             std::optional<Value> lo, hi;
127                             if (x.lower) {
128                               lo = GetValue(*x.lower);
129                             }
130                             if (x.upper) {
131                               hi = GetValue(*x.upper);
132                             }
133                             if ((x.lower && !lo) || (x.upper && !hi)) {
134                               return PairOfValues{}; // error case
135                             }
136                             return PairOfValues{std::move(lo), std::move(hi)};
137                           },
138                       },
139         range.u);
140   }
141 
142   struct Case {
143     explicit Case(const parser::Statement<parser::CaseStmt> &s) : stmt{s} {}
144     bool IsDefault() const { return !lower && !upper; }
145     std::string AsFortran() const {
146       std::string result;
147       {
148         llvm::raw_string_ostream bs{result};
149         if (lower) {
150           evaluate::Constant<T>{*lower}.AsFortran(bs << '(');
151           if (!upper) {
152             bs << ':';
153           } else if (*lower != *upper) {
154             evaluate::Constant<T>{*upper}.AsFortran(bs << ':');
155           }
156           bs << ')';
157         } else if (upper) {
158           evaluate::Constant<T>{*upper}.AsFortran(bs << "(:") << ')';
159         } else {
160           bs << "DEFAULT";
161         }
162       }
163       return result;
164     }
165 
166     const parser::Statement<parser::CaseStmt> &stmt;
167     std::optional<Value> lower, upper;
168   };
169 
170   // Defines a comparator for use with std::list<>::sort().
171   // Returns true if and only if the highest value in range x is less
172   // than the least value in range y.  The DEFAULT case is arbitrarily
173   // defined to be less than all others.  When two ranges overlap,
174   // neither is less than the other.
175   struct Comparator {
176     bool operator()(const Case &x, const Case &y) const {
177       if (x.IsDefault()) {
178         return !y.IsDefault();
179       } else {
180         return x.upper && y.lower && *x.upper < *y.lower;
181       }
182     }
183   };
184 
185   bool AreCasesDisjoint() const {
186     auto endIter{cases_.end()};
187     for (auto iter{cases_.begin()}; iter != endIter; ++iter) {
188       auto next{iter};
189       if (++next != endIter && !Comparator{}(*iter, *next)) {
190         return false;
191       }
192     }
193     return true;
194   }
195 
196   // This has quadratic time, but only runs in error cases
197   void ReportConflictingCases() {
198     for (auto iter{cases_.begin()}; iter != cases_.end(); ++iter) {
199       parser::Message *msg{nullptr};
200       for (auto p{cases_.begin()}; p != cases_.end(); ++p) {
201         if (p->stmt.source.begin() < iter->stmt.source.begin() &&
202             !Comparator{}(*p, *iter) && !Comparator{}(*iter, *p)) {
203           if (!msg) {
204             msg = &context_.Say(iter->stmt.source,
205                 "CASE %s conflicts with previous cases"_err_en_US,
206                 iter->AsFortran());
207           }
208           msg->Attach(
209               p->stmt.source, "Conflicting CASE %s"_en_US, p->AsFortran());
210         }
211       }
212     }
213   }
214 
215   SemanticsContext &context_;
216   const evaluate::DynamicType &caseExprType_;
217   std::list<Case> cases_;
218   bool hasErrors_{false};
219 };
220 
221 template <TypeCategory CAT> struct TypeVisitor {
222   using Result = bool;
223   using Types = evaluate::CategoryTypes<CAT>;
224   template <typename T> Result Test() {
225     if (T::kind == exprType.kind()) {
226       CaseValues<T>(context, exprType).Check(caseList);
227       return true;
228     } else {
229       return false;
230     }
231   }
232   SemanticsContext &context;
233   const evaluate::DynamicType &exprType;
234   const std::list<parser::CaseConstruct::Case> &caseList;
235 };
236 
237 void CaseChecker::Enter(const parser::CaseConstruct &construct) {
238   const auto &selectCaseStmt{
239       std::get<parser::Statement<parser::SelectCaseStmt>>(construct.t)};
240   const auto &selectCase{selectCaseStmt.statement};
241   const auto &selectExpr{
242       std::get<parser::Scalar<parser::Expr>>(selectCase.t).thing};
243   const auto *x{GetExpr(selectExpr)};
244   if (!x) {
245     return; // expression semantics failed
246   }
247   if (auto exprType{x->GetType()}) {
248     const auto &caseList{
249         std::get<std::list<parser::CaseConstruct::Case>>(construct.t)};
250     switch (exprType->category()) {
251     case TypeCategory::Integer:
252       common::SearchTypes(
253           TypeVisitor<TypeCategory::Integer>{context_, *exprType, caseList});
254       return;
255     case TypeCategory::Logical:
256       CaseValues<evaluate::Type<TypeCategory::Logical, 1>>{context_, *exprType}
257           .Check(caseList);
258       return;
259     case TypeCategory::Character:
260       common::SearchTypes(
261           TypeVisitor<TypeCategory::Character>{context_, *exprType, caseList});
262       return;
263     default:
264       break;
265     }
266   }
267   context_.Say(selectExpr.source,
268       "SELECT CASE expression must be integer, logical, or character"_err_en_US);
269 }
270 } // namespace Fortran::semantics
271