1 //===-- lib/Semantics/check-case.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "check-case.h"
10 #include "flang/Common/idioms.h"
11 #include "flang/Common/reference.h"
12 #include "flang/Evaluate/fold.h"
13 #include "flang/Evaluate/type.h"
14 #include "flang/Parser/parse-tree.h"
15 #include "flang/Semantics/semantics.h"
16 #include "flang/Semantics/tools.h"
17 #include <tuple>
18 
19 namespace Fortran::semantics {
20 
21 template <typename T> class CaseValues {
22 public:
23   CaseValues(SemanticsContext &c, const evaluate::DynamicType &t)
24       : context_{c}, caseExprType_{t} {}
25 
26   void Check(const std::list<parser::CaseConstruct::Case> &cases) {
27     for (const parser::CaseConstruct::Case &c : cases) {
28       AddCase(c);
29     }
30     if (!hasErrors_) {
31       cases_.sort(Comparator{});
32       if (!AreCasesDisjoint()) { // C1149
33         ReportConflictingCases();
34       }
35     }
36   }
37 
38 private:
39   using Value = evaluate::Scalar<T>;
40 
41   void AddCase(const parser::CaseConstruct::Case &c) {
42     const auto &stmt{std::get<parser::Statement<parser::CaseStmt>>(c.t)};
43     const parser::CaseStmt &caseStmt{stmt.statement};
44     const auto &selector{std::get<parser::CaseSelector>(caseStmt.t)};
45     std::visit(
46         common::visitors{
47             [&](const std::list<parser::CaseValueRange> &ranges) {
48               for (const auto &range : ranges) {
49                 auto pair{ComputeBounds(range)};
50                 if (pair.first && pair.second && *pair.first > *pair.second) {
51                   context_.Say(stmt.source,
52                       "CASE has lower bound greater than upper bound"_en_US);
53                 } else {
54                   if constexpr (T::category == TypeCategory::Logical) { // C1148
55                     if ((pair.first || pair.second) &&
56                         (!pair.first || !pair.second ||
57                             *pair.first != *pair.second)) {
58                       context_.Say(stmt.source,
59                           "CASE range is not allowed for LOGICAL"_err_en_US);
60                     }
61                   }
62                   cases_.emplace_back(stmt);
63                   cases_.back().lower = std::move(pair.first);
64                   cases_.back().upper = std::move(pair.second);
65                 }
66               }
67             },
68             [&](const parser::Default &) { cases_.emplace_front(stmt); },
69         },
70         selector.u);
71   }
72 
73   std::optional<Value> GetValue(const parser::CaseValue &caseValue) {
74     const parser::Expr &expr{caseValue.thing.thing.value()};
75     auto *x{expr.typedExpr.get()};
76     if (x && x->v) { // C1147
77       auto type{x->v->GetType()};
78       if (type && type->category() == caseExprType_.category() &&
79           (type->category() != TypeCategory::Character ||
80               type->kind() == caseExprType_.kind())) {
81         x->v = evaluate::Fold(context_.foldingContext(),
82             evaluate::ConvertToType(T::GetType(), std::move(*x->v)));
83         if (x->v) {
84           if (auto value{evaluate::GetScalarConstantValue<T>(*x->v)}) {
85             return *value;
86           }
87         }
88         context_.Say(
89             expr.source, "CASE value must be a constant scalar"_err_en_US);
90       } else {
91         std::string typeStr{type ? type->AsFortran() : "typeless"s};
92         context_.Say(expr.source,
93             "CASE value has type '%s' which is not compatible with the SELECT CASE expression's type '%s'"_err_en_US,
94             typeStr, caseExprType_.AsFortran());
95       }
96       hasErrors_ = true;
97     }
98     return std::nullopt;
99   }
100 
101   using PairOfValues = std::pair<std::optional<Value>, std::optional<Value>>;
102   PairOfValues ComputeBounds(const parser::CaseValueRange &range) {
103     return std::visit(common::visitors{
104                           [&](const parser::CaseValue &x) {
105                             auto value{GetValue(x)};
106                             return PairOfValues{value, value};
107                           },
108                           [&](const parser::CaseValueRange::Range &x) {
109                             std::optional<Value> lo, hi;
110                             if (x.lower) {
111                               lo = GetValue(*x.lower);
112                             }
113                             if (x.upper) {
114                               hi = GetValue(*x.upper);
115                             }
116                             if ((x.lower && !lo) || (x.upper && !hi)) {
117                               return PairOfValues{}; // error case
118                             }
119                             return PairOfValues{std::move(lo), std::move(hi)};
120                           },
121                       },
122         range.u);
123   }
124 
125   struct Case {
126     explicit Case(const parser::Statement<parser::CaseStmt> &s) : stmt{s} {}
127     bool IsDefault() const { return !lower && !upper; }
128     std::string AsFortran() const {
129       std::string result;
130       {
131         llvm::raw_string_ostream bs{result};
132         if (lower) {
133           evaluate::Constant<T>{*lower}.AsFortran(bs << '(');
134           if (!upper) {
135             bs << ':';
136           } else if (*lower != *upper) {
137             evaluate::Constant<T>{*upper}.AsFortran(bs << ':');
138           }
139           bs << ')';
140         } else if (upper) {
141           evaluate::Constant<T>{*upper}.AsFortran(bs << "(:") << ')';
142         } else {
143           bs << "DEFAULT";
144         }
145       }
146       return result;
147     }
148 
149     const parser::Statement<parser::CaseStmt> &stmt;
150     std::optional<Value> lower, upper;
151   };
152 
153   // Defines a comparator for use with std::list<>::sort().
154   // Returns true if and only if the highest value in range x is less
155   // than the least value in range y.  The DEFAULT case is arbitrarily
156   // defined to be less than all others.  When two ranges overlap,
157   // neither is less than the other.
158   struct Comparator {
159     bool operator()(const Case &x, const Case &y) const {
160       if (x.IsDefault()) {
161         return !y.IsDefault();
162       } else {
163         return x.upper && y.lower && *x.upper < *y.lower;
164       }
165     }
166   };
167 
168   bool AreCasesDisjoint() const {
169     auto endIter{cases_.end()};
170     for (auto iter{cases_.begin()}; iter != endIter; ++iter) {
171       auto next{iter};
172       if (++next != endIter && !Comparator{}(*iter, *next)) {
173         return false;
174       }
175     }
176     return true;
177   }
178 
179   // This has quadratic time, but only runs in error cases
180   void ReportConflictingCases() {
181     for (auto iter{cases_.begin()}; iter != cases_.end(); ++iter) {
182       parser::Message *msg{nullptr};
183       for (auto p{cases_.begin()}; p != cases_.end(); ++p) {
184         if (p->stmt.source.begin() < iter->stmt.source.begin() &&
185             !Comparator{}(*p, *iter) && !Comparator{}(*iter, *p)) {
186           if (!msg) {
187             msg = &context_.Say(iter->stmt.source,
188                 "CASE %s conflicts with previous cases"_err_en_US,
189                 iter->AsFortran());
190           }
191           msg->Attach(
192               p->stmt.source, "Conflicting CASE %s"_en_US, p->AsFortran());
193         }
194       }
195     }
196   }
197 
198   SemanticsContext &context_;
199   const evaluate::DynamicType &caseExprType_;
200   std::list<Case> cases_;
201   bool hasErrors_{false};
202 };
203 
204 void CaseChecker::Enter(const parser::CaseConstruct &construct) {
205   const auto &selectCaseStmt{
206       std::get<parser::Statement<parser::SelectCaseStmt>>(construct.t)};
207   const auto &selectCase{selectCaseStmt.statement};
208   const auto &selectExpr{
209       std::get<parser::Scalar<parser::Expr>>(selectCase.t).thing};
210   const auto *x{GetExpr(selectExpr)};
211   if (!x) {
212     return; // expression semantics failed
213   }
214   if (auto exprType{x->GetType()}) {
215     const auto &caseList{
216         std::get<std::list<parser::CaseConstruct::Case>>(construct.t)};
217     switch (exprType->category()) {
218     case TypeCategory::Integer:
219       CaseValues<evaluate::Type<TypeCategory::Integer, 16>>{context_, *exprType}
220           .Check(caseList);
221       return;
222     case TypeCategory::Logical:
223       CaseValues<evaluate::Type<TypeCategory::Logical, 1>>{context_, *exprType}
224           .Check(caseList);
225       return;
226     case TypeCategory::Character:
227       switch (exprType->kind()) {
228         SWITCH_COVERS_ALL_CASES
229       case 1:
230         CaseValues<evaluate::Type<TypeCategory::Character, 1>>{
231             context_, *exprType}
232             .Check(caseList);
233         return;
234       case 2:
235         CaseValues<evaluate::Type<TypeCategory::Character, 2>>{
236             context_, *exprType}
237             .Check(caseList);
238         return;
239       case 4:
240         CaseValues<evaluate::Type<TypeCategory::Character, 4>>{
241             context_, *exprType}
242             .Check(caseList);
243         return;
244       }
245     default:
246       break;
247     }
248   }
249   context_.Say(selectExpr.source,
250       "SELECT CASE expression must be integer, logical, or character"_err_en_US);
251 }
252 } // namespace Fortran::semantics
253