1 //===-- lib/Semantics/check-case.cpp --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "check-case.h" 10 #include "flang/Common/idioms.h" 11 #include "flang/Common/reference.h" 12 #include "flang/Evaluate/fold.h" 13 #include "flang/Evaluate/type.h" 14 #include "flang/Parser/parse-tree.h" 15 #include "flang/Semantics/semantics.h" 16 #include "flang/Semantics/tools.h" 17 #include <tuple> 18 19 namespace Fortran::semantics { 20 21 template <typename T> class CaseValues { 22 public: 23 CaseValues(SemanticsContext &c, const evaluate::DynamicType &t) 24 : context_{c}, caseExprType_{t} {} 25 26 void Check(const std::list<parser::CaseConstruct::Case> &cases) { 27 for (const parser::CaseConstruct::Case &c : cases) { 28 AddCase(c); 29 } 30 if (!hasErrors_) { 31 cases_.sort(Comparator{}); 32 if (!AreCasesDisjoint()) { // C1149 33 ReportConflictingCases(); 34 } 35 } 36 } 37 38 private: 39 using Value = evaluate::Scalar<T>; 40 41 void AddCase(const parser::CaseConstruct::Case &c) { 42 const auto &stmt{std::get<parser::Statement<parser::CaseStmt>>(c.t)}; 43 const parser::CaseStmt &caseStmt{stmt.statement}; 44 const auto &selector{std::get<parser::CaseSelector>(caseStmt.t)}; 45 std::visit( 46 common::visitors{ 47 [&](const std::list<parser::CaseValueRange> &ranges) { 48 for (const auto &range : ranges) { 49 auto pair{ComputeBounds(range)}; 50 if (pair.first && pair.second && *pair.first > *pair.second) { 51 context_.Say(stmt.source, 52 "CASE has lower bound greater than upper bound"_en_US); 53 } else { 54 if constexpr (T::category == TypeCategory::Logical) { // C1148 55 if ((pair.first || pair.second) && 56 (!pair.first || !pair.second || 57 *pair.first != *pair.second)) { 58 context_.Say(stmt.source, 59 "CASE range is not allowed for LOGICAL"_err_en_US); 60 } 61 } 62 cases_.emplace_back(stmt); 63 cases_.back().lower = std::move(pair.first); 64 cases_.back().upper = std::move(pair.second); 65 } 66 } 67 }, 68 [&](const parser::Default &) { cases_.emplace_front(stmt); }, 69 }, 70 selector.u); 71 } 72 73 std::optional<Value> GetValue(const parser::CaseValue &caseValue) { 74 const parser::Expr &expr{caseValue.thing.thing.value()}; 75 auto *x{expr.typedExpr.get()}; 76 if (x && x->v) { // C1147 77 auto type{x->v->GetType()}; 78 if (type && type->category() == caseExprType_.category() && 79 (type->category() != TypeCategory::Character || 80 type->kind() == caseExprType_.kind())) { 81 x->v = evaluate::Fold(context_.foldingContext(), 82 evaluate::ConvertToType(T::GetType(), std::move(*x->v))); 83 if (x->v) { 84 if (auto value{evaluate::GetScalarConstantValue<T>(*x->v)}) { 85 return *value; 86 } 87 } 88 context_.Say( 89 expr.source, "CASE value must be a constant scalar"_err_en_US); 90 } else { 91 std::string typeStr{type ? type->AsFortran() : "typeless"s}; 92 context_.Say(expr.source, 93 "CASE value has type '%s' which is not compatible with the SELECT CASE expression's type '%s'"_err_en_US, 94 typeStr, caseExprType_.AsFortran()); 95 } 96 hasErrors_ = true; 97 } 98 return std::nullopt; 99 } 100 101 using PairOfValues = std::pair<std::optional<Value>, std::optional<Value>>; 102 PairOfValues ComputeBounds(const parser::CaseValueRange &range) { 103 return std::visit(common::visitors{ 104 [&](const parser::CaseValue &x) { 105 auto value{GetValue(x)}; 106 return PairOfValues{value, value}; 107 }, 108 [&](const parser::CaseValueRange::Range &x) { 109 std::optional<Value> lo, hi; 110 if (x.lower) { 111 lo = GetValue(*x.lower); 112 } 113 if (x.upper) { 114 hi = GetValue(*x.upper); 115 } 116 if ((x.lower && !lo) || (x.upper && !hi)) { 117 return PairOfValues{}; // error case 118 } 119 return PairOfValues{std::move(lo), std::move(hi)}; 120 }, 121 }, 122 range.u); 123 } 124 125 struct Case { 126 explicit Case(const parser::Statement<parser::CaseStmt> &s) : stmt{s} {} 127 bool IsDefault() const { return !lower && !upper; } 128 std::string AsFortran() const { 129 std::string result; 130 { 131 llvm::raw_string_ostream bs{result}; 132 if (lower) { 133 evaluate::Constant<T>{*lower}.AsFortran(bs << '('); 134 if (!upper) { 135 bs << ':'; 136 } else if (*lower != *upper) { 137 evaluate::Constant<T>{*upper}.AsFortran(bs << ':'); 138 } 139 bs << ')'; 140 } else if (upper) { 141 evaluate::Constant<T>{*upper}.AsFortran(bs << "(:") << ')'; 142 } else { 143 bs << "DEFAULT"; 144 } 145 } 146 return result; 147 } 148 149 const parser::Statement<parser::CaseStmt> &stmt; 150 std::optional<Value> lower, upper; 151 }; 152 153 // Defines a comparator for use with std::list<>::sort(). 154 // Returns true if and only if the highest value in range x is less 155 // than the least value in range y. The DEFAULT case is arbitrarily 156 // defined to be less than all others. When two ranges overlap, 157 // neither is less than the other. 158 struct Comparator { 159 bool operator()(const Case &x, const Case &y) const { 160 if (x.IsDefault()) { 161 return !y.IsDefault(); 162 } else { 163 return x.upper && y.lower && *x.upper < *y.lower; 164 } 165 } 166 }; 167 168 bool AreCasesDisjoint() const { 169 auto endIter{cases_.end()}; 170 for (auto iter{cases_.begin()}; iter != endIter; ++iter) { 171 auto next{iter}; 172 if (++next != endIter && !Comparator{}(*iter, *next)) { 173 return false; 174 } 175 } 176 return true; 177 } 178 179 // This has quadratic time, but only runs in error cases 180 void ReportConflictingCases() { 181 for (auto iter{cases_.begin()}; iter != cases_.end(); ++iter) { 182 parser::Message *msg{nullptr}; 183 for (auto p{cases_.begin()}; p != cases_.end(); ++p) { 184 if (p->stmt.source.begin() < iter->stmt.source.begin() && 185 !Comparator{}(*p, *iter) && !Comparator{}(*iter, *p)) { 186 if (!msg) { 187 msg = &context_.Say(iter->stmt.source, 188 "CASE %s conflicts with previous cases"_err_en_US, 189 iter->AsFortran()); 190 } 191 msg->Attach( 192 p->stmt.source, "Conflicting CASE %s"_en_US, p->AsFortran()); 193 } 194 } 195 } 196 } 197 198 SemanticsContext &context_; 199 const evaluate::DynamicType &caseExprType_; 200 std::list<Case> cases_; 201 bool hasErrors_{false}; 202 }; 203 204 void CaseChecker::Enter(const parser::CaseConstruct &construct) { 205 const auto &selectCaseStmt{ 206 std::get<parser::Statement<parser::SelectCaseStmt>>(construct.t)}; 207 const auto &selectCase{selectCaseStmt.statement}; 208 const auto &selectExpr{ 209 std::get<parser::Scalar<parser::Expr>>(selectCase.t).thing}; 210 const auto *x{GetExpr(selectExpr)}; 211 if (!x) { 212 return; // expression semantics failed 213 } 214 if (auto exprType{x->GetType()}) { 215 const auto &caseList{ 216 std::get<std::list<parser::CaseConstruct::Case>>(construct.t)}; 217 switch (exprType->category()) { 218 case TypeCategory::Integer: 219 CaseValues<evaluate::Type<TypeCategory::Integer, 16>>{context_, *exprType} 220 .Check(caseList); 221 return; 222 case TypeCategory::Logical: 223 CaseValues<evaluate::Type<TypeCategory::Logical, 1>>{context_, *exprType} 224 .Check(caseList); 225 return; 226 case TypeCategory::Character: 227 switch (exprType->kind()) { 228 SWITCH_COVERS_ALL_CASES 229 case 1: 230 CaseValues<evaluate::Type<TypeCategory::Character, 1>>{ 231 context_, *exprType} 232 .Check(caseList); 233 return; 234 case 2: 235 CaseValues<evaluate::Type<TypeCategory::Character, 2>>{ 236 context_, *exprType} 237 .Check(caseList); 238 return; 239 case 4: 240 CaseValues<evaluate::Type<TypeCategory::Character, 4>>{ 241 context_, *exprType} 242 .Check(caseList); 243 return; 244 } 245 default: 246 break; 247 } 248 } 249 context_.Say(selectExpr.source, 250 "SELECT CASE expression must be integer, logical, or character"_err_en_US); 251 } 252 } // namespace Fortran::semantics 253