1 //===-- IterationSpace.cpp ------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Lower/IterationSpace.h" 14 #include "flang/Evaluate/expression.h" 15 #include "flang/Lower/AbstractConverter.h" 16 #include "flang/Lower/Support/Utils.h" 17 #include "llvm/Support/Debug.h" 18 19 #define DEBUG_TYPE "flang-lower-iteration-space" 20 21 namespace { 22 // Fortran::evaluate::Expr are functional values organized like an AST. A 23 // Fortran::evaluate::Expr is meant to be moved and cloned. Using the front end 24 // tools can often cause copies and extra wrapper classes to be added to any 25 // Fortran::evalute::Expr. These values should not be assumed or relied upon to 26 // have an *object* identity. They are deeply recursive, irregular structures 27 // built from a large number of classes which do not use inheritance and 28 // necessitate a large volume of boilerplate code as a result. 29 // 30 // Contrastingly, LLVM data structures make ubiquitous assumptions about an 31 // object's identity via pointers to the object. An object's location in memory 32 // is thus very often an identifying relation. 33 34 // This class defines a hash computation of a Fortran::evaluate::Expr tree value 35 // so it can be used with llvm::DenseMap. The Fortran::evaluate::Expr need not 36 // have the same address. 37 class HashEvaluateExpr { 38 public: 39 // A Se::Symbol is the only part of an Fortran::evaluate::Expr with an 40 // identity property. 41 static unsigned getHashValue(const Fortran::semantics::Symbol &x) { 42 return static_cast<unsigned>(reinterpret_cast<std::intptr_t>(&x)); 43 } 44 template <typename A, bool COPY> 45 static unsigned getHashValue(const Fortran::common::Indirection<A, COPY> &x) { 46 return getHashValue(x.value()); 47 } 48 template <typename A> 49 static unsigned getHashValue(const std::optional<A> &x) { 50 if (x.has_value()) 51 return getHashValue(x.value()); 52 return 0u; 53 } 54 static unsigned getHashValue(const Fortran::evaluate::Subscript &x) { 55 return std::visit([&](const auto &v) { return getHashValue(v); }, x.u); 56 } 57 static unsigned getHashValue(const Fortran::evaluate::Triplet &x) { 58 return getHashValue(x.lower()) - getHashValue(x.upper()) * 5u - 59 getHashValue(x.stride()) * 11u; 60 } 61 static unsigned getHashValue(const Fortran::evaluate::Component &x) { 62 return getHashValue(x.base()) * 83u - getHashValue(x.GetLastSymbol()); 63 } 64 static unsigned getHashValue(const Fortran::evaluate::ArrayRef &x) { 65 unsigned subs = 1u; 66 for (const Fortran::evaluate::Subscript &v : x.subscript()) 67 subs -= getHashValue(v); 68 return getHashValue(x.base()) * 89u - subs; 69 } 70 static unsigned getHashValue(const Fortran::evaluate::CoarrayRef &x) { 71 unsigned subs = 1u; 72 for (const Fortran::evaluate::Subscript &v : x.subscript()) 73 subs -= getHashValue(v); 74 unsigned cosubs = 3u; 75 for (const Fortran::evaluate::Expr<Fortran::evaluate::SubscriptInteger> &v : 76 x.cosubscript()) 77 cosubs -= getHashValue(v); 78 unsigned syms = 7u; 79 for (const Fortran::evaluate::SymbolRef &v : x.base()) 80 syms += getHashValue(v); 81 return syms * 97u - subs - cosubs + getHashValue(x.stat()) + 257u + 82 getHashValue(x.team()); 83 } 84 static unsigned getHashValue(const Fortran::evaluate::NamedEntity &x) { 85 if (x.IsSymbol()) 86 return getHashValue(x.GetFirstSymbol()) * 11u; 87 return getHashValue(x.GetComponent()) * 13u; 88 } 89 static unsigned getHashValue(const Fortran::evaluate::DataRef &x) { 90 return std::visit([&](const auto &v) { return getHashValue(v); }, x.u); 91 } 92 static unsigned getHashValue(const Fortran::evaluate::ComplexPart &x) { 93 return getHashValue(x.complex()) - static_cast<unsigned>(x.part()); 94 } 95 template <Fortran::common::TypeCategory TC1, int KIND, 96 Fortran::common::TypeCategory TC2> 97 static unsigned getHashValue( 98 const Fortran::evaluate::Convert<Fortran::evaluate::Type<TC1, KIND>, TC2> 99 &x) { 100 return getHashValue(x.left()) - (static_cast<unsigned>(TC1) + 2u) - 101 (static_cast<unsigned>(KIND) + 5u); 102 } 103 template <int KIND> 104 static unsigned 105 getHashValue(const Fortran::evaluate::ComplexComponent<KIND> &x) { 106 return getHashValue(x.left()) - 107 (static_cast<unsigned>(x.isImaginaryPart) + 1u) * 3u; 108 } 109 template <typename T> 110 static unsigned getHashValue(const Fortran::evaluate::Parentheses<T> &x) { 111 return getHashValue(x.left()) * 17u; 112 } 113 template <Fortran::common::TypeCategory TC, int KIND> 114 static unsigned getHashValue( 115 const Fortran::evaluate::Negate<Fortran::evaluate::Type<TC, KIND>> &x) { 116 return getHashValue(x.left()) - (static_cast<unsigned>(TC) + 5u) - 117 (static_cast<unsigned>(KIND) + 7u); 118 } 119 template <Fortran::common::TypeCategory TC, int KIND> 120 static unsigned getHashValue( 121 const Fortran::evaluate::Add<Fortran::evaluate::Type<TC, KIND>> &x) { 122 return (getHashValue(x.left()) + getHashValue(x.right())) * 23u + 123 static_cast<unsigned>(TC) + static_cast<unsigned>(KIND); 124 } 125 template <Fortran::common::TypeCategory TC, int KIND> 126 static unsigned getHashValue( 127 const Fortran::evaluate::Subtract<Fortran::evaluate::Type<TC, KIND>> &x) { 128 return (getHashValue(x.left()) - getHashValue(x.right())) * 19u + 129 static_cast<unsigned>(TC) + static_cast<unsigned>(KIND); 130 } 131 template <Fortran::common::TypeCategory TC, int KIND> 132 static unsigned getHashValue( 133 const Fortran::evaluate::Multiply<Fortran::evaluate::Type<TC, KIND>> &x) { 134 return (getHashValue(x.left()) + getHashValue(x.right())) * 29u + 135 static_cast<unsigned>(TC) + static_cast<unsigned>(KIND); 136 } 137 template <Fortran::common::TypeCategory TC, int KIND> 138 static unsigned getHashValue( 139 const Fortran::evaluate::Divide<Fortran::evaluate::Type<TC, KIND>> &x) { 140 return (getHashValue(x.left()) - getHashValue(x.right())) * 31u + 141 static_cast<unsigned>(TC) + static_cast<unsigned>(KIND); 142 } 143 template <Fortran::common::TypeCategory TC, int KIND> 144 static unsigned getHashValue( 145 const Fortran::evaluate::Power<Fortran::evaluate::Type<TC, KIND>> &x) { 146 return (getHashValue(x.left()) - getHashValue(x.right())) * 37u + 147 static_cast<unsigned>(TC) + static_cast<unsigned>(KIND); 148 } 149 template <Fortran::common::TypeCategory TC, int KIND> 150 static unsigned getHashValue( 151 const Fortran::evaluate::Extremum<Fortran::evaluate::Type<TC, KIND>> &x) { 152 return (getHashValue(x.left()) + getHashValue(x.right())) * 41u + 153 static_cast<unsigned>(TC) + static_cast<unsigned>(KIND) + 154 static_cast<unsigned>(x.ordering) * 7u; 155 } 156 template <Fortran::common::TypeCategory TC, int KIND> 157 static unsigned getHashValue( 158 const Fortran::evaluate::RealToIntPower<Fortran::evaluate::Type<TC, KIND>> 159 &x) { 160 return (getHashValue(x.left()) - getHashValue(x.right())) * 43u + 161 static_cast<unsigned>(TC) + static_cast<unsigned>(KIND); 162 } 163 template <int KIND> 164 static unsigned 165 getHashValue(const Fortran::evaluate::ComplexConstructor<KIND> &x) { 166 return (getHashValue(x.left()) - getHashValue(x.right())) * 47u + 167 static_cast<unsigned>(KIND); 168 } 169 template <int KIND> 170 static unsigned getHashValue(const Fortran::evaluate::Concat<KIND> &x) { 171 return (getHashValue(x.left()) - getHashValue(x.right())) * 53u + 172 static_cast<unsigned>(KIND); 173 } 174 template <int KIND> 175 static unsigned getHashValue(const Fortran::evaluate::SetLength<KIND> &x) { 176 return (getHashValue(x.left()) - getHashValue(x.right())) * 59u + 177 static_cast<unsigned>(KIND); 178 } 179 static unsigned getHashValue(const Fortran::semantics::SymbolRef &sym) { 180 return getHashValue(sym.get()); 181 } 182 static unsigned getHashValue(const Fortran::evaluate::Substring &x) { 183 return 61u * std::visit([&](const auto &p) { return getHashValue(p); }, 184 x.parent()) - 185 getHashValue(x.lower()) - (getHashValue(x.lower()) + 1u); 186 } 187 static unsigned 188 getHashValue(const Fortran::evaluate::StaticDataObject::Pointer &x) { 189 return llvm::hash_value(x->name()); 190 } 191 static unsigned getHashValue(const Fortran::evaluate::SpecificIntrinsic &x) { 192 return llvm::hash_value(x.name); 193 } 194 template <typename A> 195 static unsigned getHashValue(const Fortran::evaluate::Constant<A> &x) { 196 // FIXME: Should hash the content. 197 return 103u; 198 } 199 static unsigned getHashValue(const Fortran::evaluate::ActualArgument &x) { 200 if (const Fortran::evaluate::Symbol *sym = x.GetAssumedTypeDummy()) 201 return getHashValue(*sym); 202 return getHashValue(*x.UnwrapExpr()); 203 } 204 static unsigned 205 getHashValue(const Fortran::evaluate::ProcedureDesignator &x) { 206 return std::visit([&](const auto &v) { return getHashValue(v); }, x.u); 207 } 208 static unsigned getHashValue(const Fortran::evaluate::ProcedureRef &x) { 209 unsigned args = 13u; 210 for (const std::optional<Fortran::evaluate::ActualArgument> &v : 211 x.arguments()) 212 args -= getHashValue(v); 213 return getHashValue(x.proc()) * 101u - args; 214 } 215 template <typename A> 216 static unsigned 217 getHashValue(const Fortran::evaluate::ArrayConstructor<A> &x) { 218 // FIXME: hash the contents. 219 return 127u; 220 } 221 static unsigned getHashValue(const Fortran::evaluate::ImpliedDoIndex &x) { 222 return llvm::hash_value(toStringRef(x.name).str()) * 131u; 223 } 224 static unsigned getHashValue(const Fortran::evaluate::TypeParamInquiry &x) { 225 return getHashValue(x.base()) * 137u - getHashValue(x.parameter()) * 3u; 226 } 227 static unsigned getHashValue(const Fortran::evaluate::DescriptorInquiry &x) { 228 return getHashValue(x.base()) * 139u - 229 static_cast<unsigned>(x.field()) * 13u + 230 static_cast<unsigned>(x.dimension()); 231 } 232 static unsigned 233 getHashValue(const Fortran::evaluate::StructureConstructor &x) { 234 // FIXME: hash the contents. 235 return 149u; 236 } 237 template <int KIND> 238 static unsigned getHashValue(const Fortran::evaluate::Not<KIND> &x) { 239 return getHashValue(x.left()) * 61u + static_cast<unsigned>(KIND); 240 } 241 template <int KIND> 242 static unsigned 243 getHashValue(const Fortran::evaluate::LogicalOperation<KIND> &x) { 244 unsigned result = getHashValue(x.left()) + getHashValue(x.right()); 245 return result * 67u + static_cast<unsigned>(x.logicalOperator) * 5u; 246 } 247 template <Fortran::common::TypeCategory TC, int KIND> 248 static unsigned getHashValue( 249 const Fortran::evaluate::Relational<Fortran::evaluate::Type<TC, KIND>> 250 &x) { 251 return (getHashValue(x.left()) + getHashValue(x.right())) * 71u + 252 static_cast<unsigned>(TC) + static_cast<unsigned>(KIND) + 253 static_cast<unsigned>(x.opr) * 11u; 254 } 255 template <typename A> 256 static unsigned getHashValue(const Fortran::evaluate::Expr<A> &x) { 257 return std::visit([&](const auto &v) { return getHashValue(v); }, x.u); 258 } 259 static unsigned getHashValue( 260 const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x) { 261 return std::visit([&](const auto &v) { return getHashValue(v); }, x.u); 262 } 263 template <typename A> 264 static unsigned getHashValue(const Fortran::evaluate::Designator<A> &x) { 265 return std::visit([&](const auto &v) { return getHashValue(v); }, x.u); 266 } 267 template <int BITS> 268 static unsigned 269 getHashValue(const Fortran::evaluate::value::Integer<BITS> &x) { 270 return static_cast<unsigned>(x.ToSInt()); 271 } 272 static unsigned getHashValue(const Fortran::evaluate::NullPointer &x) { 273 return ~179u; 274 } 275 }; 276 } // namespace 277 278 unsigned Fortran::lower::getHashValue( 279 const Fortran::lower::ExplicitIterSpace::ArrayBases &x) { 280 return std::visit( 281 [&](const auto *p) { return HashEvaluateExpr::getHashValue(*p); }, x); 282 } 283 284 unsigned Fortran::lower::getHashValue(Fortran::lower::FrontEndExpr x) { 285 return HashEvaluateExpr::getHashValue(*x); 286 } 287 288 namespace { 289 // Define the is equals test for using Fortran::evaluate::Expr values with 290 // llvm::DenseMap. 291 class IsEqualEvaluateExpr { 292 public: 293 // A Se::Symbol is the only part of an Fortran::evaluate::Expr with an 294 // identity property. 295 static bool isEqual(const Fortran::semantics::Symbol &x, 296 const Fortran::semantics::Symbol &y) { 297 return isEqual(&x, &y); 298 } 299 static bool isEqual(const Fortran::semantics::Symbol *x, 300 const Fortran::semantics::Symbol *y) { 301 return x == y; 302 } 303 template <typename A, bool COPY> 304 static bool isEqual(const Fortran::common::Indirection<A, COPY> &x, 305 const Fortran::common::Indirection<A, COPY> &y) { 306 return isEqual(x.value(), y.value()); 307 } 308 template <typename A> 309 static bool isEqual(const std::optional<A> &x, const std::optional<A> &y) { 310 if (x.has_value() && y.has_value()) 311 return isEqual(x.value(), y.value()); 312 return !x.has_value() && !y.has_value(); 313 } 314 template <typename A> 315 static bool isEqual(const std::vector<A> &x, const std::vector<A> &y) { 316 if (x.size() != y.size()) 317 return false; 318 const std::size_t size = x.size(); 319 for (std::remove_const_t<decltype(size)> i = 0; i < size; ++i) 320 if (!isEqual(x[i], y[i])) 321 return false; 322 return true; 323 } 324 static bool isEqual(const Fortran::evaluate::Subscript &x, 325 const Fortran::evaluate::Subscript &y) { 326 return std::visit( 327 [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); 328 } 329 static bool isEqual(const Fortran::evaluate::Triplet &x, 330 const Fortran::evaluate::Triplet &y) { 331 return isEqual(x.lower(), y.lower()) && isEqual(x.upper(), y.upper()) && 332 isEqual(x.stride(), y.stride()); 333 } 334 static bool isEqual(const Fortran::evaluate::Component &x, 335 const Fortran::evaluate::Component &y) { 336 return isEqual(x.base(), y.base()) && 337 isEqual(x.GetLastSymbol(), y.GetLastSymbol()); 338 } 339 static bool isEqual(const Fortran::evaluate::ArrayRef &x, 340 const Fortran::evaluate::ArrayRef &y) { 341 return isEqual(x.base(), y.base()) && isEqual(x.subscript(), y.subscript()); 342 } 343 static bool isEqual(const Fortran::evaluate::CoarrayRef &x, 344 const Fortran::evaluate::CoarrayRef &y) { 345 return isEqual(x.base(), y.base()) && 346 isEqual(x.subscript(), y.subscript()) && 347 isEqual(x.cosubscript(), y.cosubscript()) && 348 isEqual(x.stat(), y.stat()) && isEqual(x.team(), y.team()); 349 } 350 static bool isEqual(const Fortran::evaluate::NamedEntity &x, 351 const Fortran::evaluate::NamedEntity &y) { 352 if (x.IsSymbol() && y.IsSymbol()) 353 return isEqual(x.GetFirstSymbol(), y.GetFirstSymbol()); 354 return !x.IsSymbol() && !y.IsSymbol() && 355 isEqual(x.GetComponent(), y.GetComponent()); 356 } 357 static bool isEqual(const Fortran::evaluate::DataRef &x, 358 const Fortran::evaluate::DataRef &y) { 359 return std::visit( 360 [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); 361 } 362 static bool isEqual(const Fortran::evaluate::ComplexPart &x, 363 const Fortran::evaluate::ComplexPart &y) { 364 return isEqual(x.complex(), y.complex()) && x.part() == y.part(); 365 } 366 template <typename A, Fortran::common::TypeCategory TC2> 367 static bool isEqual(const Fortran::evaluate::Convert<A, TC2> &x, 368 const Fortran::evaluate::Convert<A, TC2> &y) { 369 return isEqual(x.left(), y.left()); 370 } 371 template <int KIND> 372 static bool isEqual(const Fortran::evaluate::ComplexComponent<KIND> &x, 373 const Fortran::evaluate::ComplexComponent<KIND> &y) { 374 return isEqual(x.left(), y.left()) && 375 x.isImaginaryPart == y.isImaginaryPart; 376 } 377 template <typename T> 378 static bool isEqual(const Fortran::evaluate::Parentheses<T> &x, 379 const Fortran::evaluate::Parentheses<T> &y) { 380 return isEqual(x.left(), y.left()); 381 } 382 template <typename A> 383 static bool isEqual(const Fortran::evaluate::Negate<A> &x, 384 const Fortran::evaluate::Negate<A> &y) { 385 return isEqual(x.left(), y.left()); 386 } 387 template <typename A> 388 static bool isBinaryEqual(const A &x, const A &y) { 389 return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right()); 390 } 391 template <typename A> 392 static bool isEqual(const Fortran::evaluate::Add<A> &x, 393 const Fortran::evaluate::Add<A> &y) { 394 return isBinaryEqual(x, y); 395 } 396 template <typename A> 397 static bool isEqual(const Fortran::evaluate::Subtract<A> &x, 398 const Fortran::evaluate::Subtract<A> &y) { 399 return isBinaryEqual(x, y); 400 } 401 template <typename A> 402 static bool isEqual(const Fortran::evaluate::Multiply<A> &x, 403 const Fortran::evaluate::Multiply<A> &y) { 404 return isBinaryEqual(x, y); 405 } 406 template <typename A> 407 static bool isEqual(const Fortran::evaluate::Divide<A> &x, 408 const Fortran::evaluate::Divide<A> &y) { 409 return isBinaryEqual(x, y); 410 } 411 template <typename A> 412 static bool isEqual(const Fortran::evaluate::Power<A> &x, 413 const Fortran::evaluate::Power<A> &y) { 414 return isBinaryEqual(x, y); 415 } 416 template <typename A> 417 static bool isEqual(const Fortran::evaluate::Extremum<A> &x, 418 const Fortran::evaluate::Extremum<A> &y) { 419 return isBinaryEqual(x, y); 420 } 421 template <typename A> 422 static bool isEqual(const Fortran::evaluate::RealToIntPower<A> &x, 423 const Fortran::evaluate::RealToIntPower<A> &y) { 424 return isBinaryEqual(x, y); 425 } 426 template <int KIND> 427 static bool isEqual(const Fortran::evaluate::ComplexConstructor<KIND> &x, 428 const Fortran::evaluate::ComplexConstructor<KIND> &y) { 429 return isBinaryEqual(x, y); 430 } 431 template <int KIND> 432 static bool isEqual(const Fortran::evaluate::Concat<KIND> &x, 433 const Fortran::evaluate::Concat<KIND> &y) { 434 return isBinaryEqual(x, y); 435 } 436 template <int KIND> 437 static bool isEqual(const Fortran::evaluate::SetLength<KIND> &x, 438 const Fortran::evaluate::SetLength<KIND> &y) { 439 return isBinaryEqual(x, y); 440 } 441 static bool isEqual(const Fortran::semantics::SymbolRef &x, 442 const Fortran::semantics::SymbolRef &y) { 443 return isEqual(x.get(), y.get()); 444 } 445 static bool isEqual(const Fortran::evaluate::Substring &x, 446 const Fortran::evaluate::Substring &y) { 447 return std::visit( 448 [&](const auto &p, const auto &q) { return isEqual(p, q); }, 449 x.parent(), y.parent()) && 450 isEqual(x.lower(), y.lower()) && isEqual(x.lower(), y.lower()); 451 } 452 static bool isEqual(const Fortran::evaluate::StaticDataObject::Pointer &x, 453 const Fortran::evaluate::StaticDataObject::Pointer &y) { 454 return x->name() == y->name(); 455 } 456 static bool isEqual(const Fortran::evaluate::SpecificIntrinsic &x, 457 const Fortran::evaluate::SpecificIntrinsic &y) { 458 return x.name == y.name; 459 } 460 template <typename A> 461 static bool isEqual(const Fortran::evaluate::Constant<A> &x, 462 const Fortran::evaluate::Constant<A> &y) { 463 return x == y; 464 } 465 static bool isEqual(const Fortran::evaluate::ActualArgument &x, 466 const Fortran::evaluate::ActualArgument &y) { 467 if (const Fortran::evaluate::Symbol *xs = x.GetAssumedTypeDummy()) { 468 if (const Fortran::evaluate::Symbol *ys = y.GetAssumedTypeDummy()) 469 return isEqual(*xs, *ys); 470 return false; 471 } 472 return !y.GetAssumedTypeDummy() && 473 isEqual(*x.UnwrapExpr(), *y.UnwrapExpr()); 474 } 475 static bool isEqual(const Fortran::evaluate::ProcedureDesignator &x, 476 const Fortran::evaluate::ProcedureDesignator &y) { 477 return std::visit( 478 [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); 479 } 480 static bool isEqual(const Fortran::evaluate::ProcedureRef &x, 481 const Fortran::evaluate::ProcedureRef &y) { 482 return isEqual(x.proc(), y.proc()) && isEqual(x.arguments(), y.arguments()); 483 } 484 template <typename A> 485 static bool isEqual(const Fortran::evaluate::ArrayConstructor<A> &x, 486 const Fortran::evaluate::ArrayConstructor<A> &y) { 487 llvm::report_fatal_error("not implemented"); 488 } 489 static bool isEqual(const Fortran::evaluate::ImpliedDoIndex &x, 490 const Fortran::evaluate::ImpliedDoIndex &y) { 491 return toStringRef(x.name) == toStringRef(y.name); 492 } 493 static bool isEqual(const Fortran::evaluate::TypeParamInquiry &x, 494 const Fortran::evaluate::TypeParamInquiry &y) { 495 return isEqual(x.base(), y.base()) && isEqual(x.parameter(), y.parameter()); 496 } 497 static bool isEqual(const Fortran::evaluate::DescriptorInquiry &x, 498 const Fortran::evaluate::DescriptorInquiry &y) { 499 return isEqual(x.base(), y.base()) && x.field() == y.field() && 500 x.dimension() == y.dimension(); 501 } 502 static bool isEqual(const Fortran::evaluate::StructureConstructor &x, 503 const Fortran::evaluate::StructureConstructor &y) { 504 llvm::report_fatal_error("not implemented"); 505 } 506 template <int KIND> 507 static bool isEqual(const Fortran::evaluate::Not<KIND> &x, 508 const Fortran::evaluate::Not<KIND> &y) { 509 return isEqual(x.left(), y.left()); 510 } 511 template <int KIND> 512 static bool isEqual(const Fortran::evaluate::LogicalOperation<KIND> &x, 513 const Fortran::evaluate::LogicalOperation<KIND> &y) { 514 return isEqual(x.left(), y.left()) && isEqual(x.right(), x.right()); 515 } 516 template <typename A> 517 static bool isEqual(const Fortran::evaluate::Relational<A> &x, 518 const Fortran::evaluate::Relational<A> &y) { 519 return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right()); 520 } 521 template <typename A> 522 static bool isEqual(const Fortran::evaluate::Expr<A> &x, 523 const Fortran::evaluate::Expr<A> &y) { 524 return std::visit( 525 [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); 526 } 527 static bool 528 isEqual(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x, 529 const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &y) { 530 return std::visit( 531 [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); 532 } 533 template <typename A> 534 static bool isEqual(const Fortran::evaluate::Designator<A> &x, 535 const Fortran::evaluate::Designator<A> &y) { 536 return std::visit( 537 [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); 538 } 539 template <int BITS> 540 static bool isEqual(const Fortran::evaluate::value::Integer<BITS> &x, 541 const Fortran::evaluate::value::Integer<BITS> &y) { 542 return x == y; 543 } 544 static bool isEqual(const Fortran::evaluate::NullPointer &x, 545 const Fortran::evaluate::NullPointer &y) { 546 return true; 547 } 548 template <typename A, typename B, 549 std::enable_if_t<!std::is_same_v<A, B>, bool> = true> 550 static bool isEqual(const A &, const B &) { 551 return false; 552 } 553 }; 554 } // namespace 555 556 bool Fortran::lower::isEqual( 557 const Fortran::lower::ExplicitIterSpace::ArrayBases &x, 558 const Fortran::lower::ExplicitIterSpace::ArrayBases &y) { 559 return std::visit( 560 Fortran::common::visitors{ 561 // Fortran::semantics::Symbol * are the exception here. These pointers 562 // have identity; if two Symbol * values are the same (different) then 563 // they are the same (different) logical symbol. 564 [&](Fortran::lower::FrontEndSymbol p, 565 Fortran::lower::FrontEndSymbol q) { return p == q; }, 566 [&](const auto *p, const auto *q) { 567 if constexpr (std::is_same_v<decltype(p), decltype(q)>) { 568 LLVM_DEBUG(llvm::dbgs() 569 << "is equal: " << p << ' ' << q << ' ' 570 << IsEqualEvaluateExpr::isEqual(*p, *q) << '\n'); 571 return IsEqualEvaluateExpr::isEqual(*p, *q); 572 } else { 573 // Different subtree types are never equal. 574 return false; 575 } 576 }}, 577 x, y); 578 } 579 580 bool Fortran::lower::isEqual(Fortran::lower::FrontEndExpr x, 581 Fortran::lower::FrontEndExpr y) { 582 auto empty = llvm::DenseMapInfo<Fortran::lower::FrontEndExpr>::getEmptyKey(); 583 auto tombstone = 584 llvm::DenseMapInfo<Fortran::lower::FrontEndExpr>::getTombstoneKey(); 585 if (x == empty || y == empty || x == tombstone || y == tombstone) 586 return x == y; 587 return x == y || IsEqualEvaluateExpr::isEqual(*x, *y); 588 } 589 590 namespace { 591 592 /// This class can recover the base array in an expression that contains 593 /// explicit iteration space symbols. Most of the class can be ignored as it is 594 /// boilerplate Fortran::evaluate::Expr traversal. 595 class ArrayBaseFinder { 596 public: 597 using RT = bool; 598 599 ArrayBaseFinder(llvm::ArrayRef<Fortran::lower::FrontEndSymbol> syms) 600 : controlVars(syms.begin(), syms.end()) {} 601 602 template <typename T> 603 void operator()(const T &x) { 604 (void)find(x); 605 } 606 607 /// Get the list of bases. 608 llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases> 609 getBases() const { 610 LLVM_DEBUG(llvm::dbgs() 611 << "number of array bases found: " << bases.size() << '\n'); 612 return bases; 613 } 614 615 private: 616 // First, the cases that are of interest. 617 RT find(const Fortran::semantics::Symbol &symbol) { 618 if (symbol.Rank() > 0) { 619 bases.push_back(&symbol); 620 return true; 621 } 622 return {}; 623 } 624 RT find(const Fortran::evaluate::Component &x) { 625 auto found = find(x.base()); 626 if (!found && x.base().Rank() == 0 && x.Rank() > 0) { 627 bases.push_back(&x); 628 return true; 629 } 630 return found; 631 } 632 RT find(const Fortran::evaluate::ArrayRef &x) { 633 for (const auto &sub : x.subscript()) 634 (void)find(sub); 635 if (x.base().IsSymbol()) { 636 if (x.Rank() > 0 || intersection(x.subscript())) { 637 bases.push_back(&x); 638 return true; 639 } 640 return {}; 641 } 642 auto found = find(x.base()); 643 if (!found && ((x.base().Rank() == 0 && x.Rank() > 0) || 644 intersection(x.subscript()))) { 645 bases.push_back(&x); 646 return true; 647 } 648 return found; 649 } 650 RT find(const Fortran::evaluate::Triplet &x) { 651 if (const auto *lower = x.GetLower()) 652 (void)find(*lower); 653 if (const auto *upper = x.GetUpper()) 654 (void)find(*upper); 655 return find(x.GetStride()); 656 } 657 RT find(const Fortran::evaluate::IndirectSubscriptIntegerExpr &x) { 658 return find(x.value()); 659 } 660 RT find(const Fortran::evaluate::Subscript &x) { return find(x.u); } 661 RT find(const Fortran::evaluate::DataRef &x) { return find(x.u); } 662 RT find(const Fortran::evaluate::CoarrayRef &x) { 663 assert(false && "coarray reference"); 664 return {}; 665 } 666 667 template <typename A> 668 bool intersection(const A &subscripts) { 669 return Fortran::lower::symbolsIntersectSubscripts(controlVars, subscripts); 670 } 671 672 // The rest is traversal boilerplate and can be ignored. 673 RT find(const Fortran::evaluate::Substring &x) { return find(x.parent()); } 674 template <typename A> 675 RT find(const Fortran::semantics::SymbolRef x) { 676 return find(*x); 677 } 678 RT find(const Fortran::evaluate::NamedEntity &x) { 679 if (x.IsSymbol()) 680 return find(x.GetFirstSymbol()); 681 return find(x.GetComponent()); 682 } 683 684 template <typename A, bool C> 685 RT find(const Fortran::common::Indirection<A, C> &x) { 686 return find(x.value()); 687 } 688 template <typename A> 689 RT find(const std::unique_ptr<A> &x) { 690 return find(x.get()); 691 } 692 template <typename A> 693 RT find(const std::shared_ptr<A> &x) { 694 return find(x.get()); 695 } 696 template <typename A> 697 RT find(const A *x) { 698 if (x) 699 return find(*x); 700 return {}; 701 } 702 template <typename A> 703 RT find(const std::optional<A> &x) { 704 if (x) 705 return find(*x); 706 return {}; 707 } 708 template <typename... A> 709 RT find(const std::variant<A...> &u) { 710 return std::visit([&](const auto &v) { return find(v); }, u); 711 } 712 template <typename A> 713 RT find(const std::vector<A> &x) { 714 for (auto &v : x) 715 (void)find(v); 716 return {}; 717 } 718 RT find(const Fortran::evaluate::BOZLiteralConstant &) { return {}; } 719 RT find(const Fortran::evaluate::NullPointer &) { return {}; } 720 template <typename T> 721 RT find(const Fortran::evaluate::Constant<T> &x) { 722 return {}; 723 } 724 RT find(const Fortran::evaluate::StaticDataObject &) { return {}; } 725 RT find(const Fortran::evaluate::ImpliedDoIndex &) { return {}; } 726 RT find(const Fortran::evaluate::BaseObject &x) { 727 (void)find(x.u); 728 return {}; 729 } 730 RT find(const Fortran::evaluate::TypeParamInquiry &) { return {}; } 731 RT find(const Fortran::evaluate::ComplexPart &x) { return {}; } 732 template <typename T> 733 RT find(const Fortran::evaluate::Designator<T> &x) { 734 return find(x.u); 735 } 736 template <typename T> 737 RT find(const Fortran::evaluate::Variable<T> &x) { 738 return find(x.u); 739 } 740 RT find(const Fortran::evaluate::DescriptorInquiry &) { return {}; } 741 RT find(const Fortran::evaluate::SpecificIntrinsic &) { return {}; } 742 RT find(const Fortran::evaluate::ProcedureDesignator &x) { return {}; } 743 RT find(const Fortran::evaluate::ProcedureRef &x) { 744 (void)find(x.proc()); 745 if (x.IsElemental()) 746 (void)find(x.arguments()); 747 return {}; 748 } 749 RT find(const Fortran::evaluate::ActualArgument &x) { 750 if (const auto *sym = x.GetAssumedTypeDummy()) 751 (void)find(*sym); 752 else 753 (void)find(x.UnwrapExpr()); 754 return {}; 755 } 756 template <typename T> 757 RT find(const Fortran::evaluate::FunctionRef<T> &x) { 758 (void)find(static_cast<const Fortran::evaluate::ProcedureRef &>(x)); 759 return {}; 760 } 761 template <typename T> 762 RT find(const Fortran::evaluate::ArrayConstructorValue<T> &) { 763 return {}; 764 } 765 template <typename T> 766 RT find(const Fortran::evaluate::ArrayConstructorValues<T> &) { 767 return {}; 768 } 769 template <typename T> 770 RT find(const Fortran::evaluate::ImpliedDo<T> &) { 771 return {}; 772 } 773 RT find(const Fortran::semantics::ParamValue &) { return {}; } 774 RT find(const Fortran::semantics::DerivedTypeSpec &) { return {}; } 775 RT find(const Fortran::evaluate::StructureConstructor &) { return {}; } 776 template <typename D, typename R, typename O> 777 RT find(const Fortran::evaluate::Operation<D, R, O> &op) { 778 (void)find(op.left()); 779 return false; 780 } 781 template <typename D, typename R, typename LO, typename RO> 782 RT find(const Fortran::evaluate::Operation<D, R, LO, RO> &op) { 783 (void)find(op.left()); 784 (void)find(op.right()); 785 return false; 786 } 787 RT find(const Fortran::evaluate::Relational<Fortran::evaluate::SomeType> &x) { 788 (void)find(x.u); 789 return {}; 790 } 791 template <typename T> 792 RT find(const Fortran::evaluate::Expr<T> &x) { 793 (void)find(x.u); 794 return {}; 795 } 796 797 llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases> bases; 798 llvm::SmallVector<Fortran::lower::FrontEndSymbol> controlVars; 799 }; 800 801 } // namespace 802 803 void Fortran::lower::ExplicitIterSpace::leave() { 804 ccLoopNest.pop_back(); 805 --forallContextOpen; 806 conditionalCleanup(); 807 } 808 809 void Fortran::lower::ExplicitIterSpace::addSymbol( 810 Fortran::lower::FrontEndSymbol sym) { 811 assert(!symbolStack.empty()); 812 symbolStack.back().push_back(sym); 813 } 814 815 void Fortran::lower::ExplicitIterSpace::exprBase(Fortran::lower::FrontEndExpr x, 816 bool lhs) { 817 ArrayBaseFinder finder(collectAllSymbols()); 818 finder(*x); 819 llvm::ArrayRef<Fortran::lower::ExplicitIterSpace::ArrayBases> bases = 820 finder.getBases(); 821 if (rhsBases.empty()) 822 endAssign(); 823 if (lhs) { 824 if (bases.empty()) { 825 lhsBases.push_back(llvm::None); 826 return; 827 } 828 assert(bases.size() >= 1 && "must detect an array reference on lhs"); 829 if (bases.size() > 1) 830 rhsBases.back().append(bases.begin(), bases.end() - 1); 831 lhsBases.push_back(bases.back()); 832 return; 833 } 834 rhsBases.back().append(bases.begin(), bases.end()); 835 } 836 837 void Fortran::lower::ExplicitIterSpace::endAssign() { rhsBases.emplace_back(); } 838 839 void Fortran::lower::ExplicitIterSpace::pushLevel() { 840 symbolStack.push_back(llvm::SmallVector<Fortran::lower::FrontEndSymbol>{}); 841 } 842 843 void Fortran::lower::ExplicitIterSpace::popLevel() { symbolStack.pop_back(); } 844 845 void Fortran::lower::ExplicitIterSpace::conditionalCleanup() { 846 if (forallContextOpen == 0) { 847 // Exiting the outermost FORALL context. 848 // Cleanup any residual mask buffers. 849 outermostContext().finalize(); 850 // Clear and reset all the cached information. 851 symbolStack.clear(); 852 lhsBases.clear(); 853 rhsBases.clear(); 854 loadBindings.clear(); 855 ccLoopNest.clear(); 856 innerArgs.clear(); 857 outerLoop = llvm::None; 858 clearLoops(); 859 counter = 0; 860 } 861 } 862 863 llvm::Optional<size_t> 864 Fortran::lower::ExplicitIterSpace::findArgPosition(fir::ArrayLoadOp load) { 865 if (lhsBases[counter].hasValue()) { 866 auto ld = loadBindings.find(lhsBases[counter].getValue()); 867 llvm::Optional<size_t> optPos; 868 if (ld != loadBindings.end() && ld->second == load) 869 optPos = static_cast<size_t>(0u); 870 assert(optPos.hasValue() && "load does not correspond to lhs"); 871 return optPos; 872 } 873 return llvm::None; 874 } 875 876 llvm::SmallVector<Fortran::lower::FrontEndSymbol> 877 Fortran::lower::ExplicitIterSpace::collectAllSymbols() { 878 llvm::SmallVector<Fortran::lower::FrontEndSymbol> result; 879 for (llvm::SmallVector<FrontEndSymbol> vec : symbolStack) 880 result.append(vec.begin(), vec.end()); 881 return result; 882 } 883 884 llvm::raw_ostream & 885 Fortran::lower::operator<<(llvm::raw_ostream &s, 886 const Fortran::lower::ImplicitIterSpace &e) { 887 for (const llvm::SmallVector< 888 Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr> &xs : 889 e.getMasks()) { 890 s << "{ "; 891 for (const Fortran::lower::ImplicitIterSpace::FrontEndMaskExpr &x : xs) 892 x->AsFortran(s << '(') << "), "; 893 s << "}\n"; 894 } 895 return s; 896 } 897 898 llvm::raw_ostream & 899 Fortran::lower::operator<<(llvm::raw_ostream &s, 900 const Fortran::lower::ExplicitIterSpace &e) { 901 auto dump = [&](const auto &u) { 902 std::visit(Fortran::common::visitors{ 903 [&](const Fortran::semantics::Symbol *y) { 904 s << " " << *y << '\n'; 905 }, 906 [&](const Fortran::evaluate::ArrayRef *y) { 907 s << " "; 908 if (y->base().IsSymbol()) 909 s << y->base().GetFirstSymbol(); 910 else 911 s << y->base().GetComponent().GetLastSymbol(); 912 s << '\n'; 913 }, 914 [&](const Fortran::evaluate::Component *y) { 915 s << " " << y->GetLastSymbol() << '\n'; 916 }}, 917 u); 918 }; 919 s << "LHS bases:\n"; 920 for (const llvm::Optional<Fortran::lower::ExplicitIterSpace::ArrayBases> &u : 921 e.lhsBases) 922 if (u.hasValue()) 923 dump(u.getValue()); 924 s << "RHS bases:\n"; 925 for (const llvm::SmallVector<Fortran::lower::ExplicitIterSpace::ArrayBases> 926 &bases : e.rhsBases) { 927 for (const Fortran::lower::ExplicitIterSpace::ArrayBases &u : bases) 928 dump(u); 929 s << '\n'; 930 } 931 return s; 932 } 933 934 void Fortran::lower::ImplicitIterSpace::dump() const { 935 llvm::errs() << *this << '\n'; 936 } 937 938 void Fortran::lower::ExplicitIterSpace::dump() const { 939 llvm::errs() << *this << '\n'; 940 } 941