1 //===-- lib/Evaluate/shape.cpp --------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Evaluate/shape.h" 10 #include "flang/Common/idioms.h" 11 #include "flang/Common/template.h" 12 #include "flang/Evaluate/characteristics.h" 13 #include "flang/Evaluate/fold.h" 14 #include "flang/Evaluate/intrinsics.h" 15 #include "flang/Evaluate/tools.h" 16 #include "flang/Evaluate/type.h" 17 #include "flang/Parser/message.h" 18 #include "flang/Semantics/symbol.h" 19 #include <functional> 20 21 using namespace std::placeholders; // _1, _2, &c. for std::bind() 22 23 namespace Fortran::evaluate { 24 25 bool IsImpliedShape(const Symbol &original) { 26 const Symbol &symbol{ResolveAssociations(original)}; 27 const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}; 28 return details && symbol.attrs().test(semantics::Attr::PARAMETER) && 29 details->shape().IsImpliedShape(); 30 } 31 32 bool IsExplicitShape(const Symbol &original) { 33 const Symbol &symbol{ResolveAssociations(original)}; 34 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) { 35 const auto &shape{details->shape()}; 36 return shape.Rank() == 0 || 37 shape.IsExplicitShape(); // true when scalar, too 38 } else { 39 return symbol 40 .has<semantics::AssocEntityDetails>(); // exprs have explicit shape 41 } 42 } 43 44 Shape GetShapeHelper::ConstantShape(const Constant<ExtentType> &arrayConstant) { 45 CHECK(arrayConstant.Rank() == 1); 46 Shape result; 47 std::size_t dimensions{arrayConstant.size()}; 48 for (std::size_t j{0}; j < dimensions; ++j) { 49 Scalar<ExtentType> extent{arrayConstant.values().at(j)}; 50 result.emplace_back(MaybeExtentExpr{ExtentExpr{std::move(extent)}}); 51 } 52 return result; 53 } 54 55 auto GetShapeHelper::AsShape(ExtentExpr &&arrayExpr) const -> Result { 56 if (context_) { 57 arrayExpr = Fold(*context_, std::move(arrayExpr)); 58 } 59 if (const auto *constArray{UnwrapConstantValue<ExtentType>(arrayExpr)}) { 60 return ConstantShape(*constArray); 61 } 62 if (auto *constructor{UnwrapExpr<ArrayConstructor<ExtentType>>(arrayExpr)}) { 63 Shape result; 64 for (auto &value : *constructor) { 65 if (auto *expr{std::get_if<ExtentExpr>(&value.u)}) { 66 if (expr->Rank() == 0) { 67 result.emplace_back(std::move(*expr)); 68 continue; 69 } 70 } 71 return std::nullopt; 72 } 73 return result; 74 } 75 return std::nullopt; 76 } 77 78 Shape GetShapeHelper::CreateShape(int rank, NamedEntity &base) { 79 Shape shape; 80 for (int dimension{0}; dimension < rank; ++dimension) { 81 shape.emplace_back(GetExtent(base, dimension)); 82 } 83 return shape; 84 } 85 86 std::optional<ExtentExpr> AsExtentArrayExpr(const Shape &shape) { 87 ArrayConstructorValues<ExtentType> values; 88 for (const auto &dim : shape) { 89 if (dim) { 90 values.Push(common::Clone(*dim)); 91 } else { 92 return std::nullopt; 93 } 94 } 95 return ExtentExpr{ArrayConstructor<ExtentType>{std::move(values)}}; 96 } 97 98 std::optional<Constant<ExtentType>> AsConstantShape( 99 FoldingContext &context, const Shape &shape) { 100 if (auto shapeArray{AsExtentArrayExpr(shape)}) { 101 auto folded{Fold(context, std::move(*shapeArray))}; 102 if (auto *p{UnwrapConstantValue<ExtentType>(folded)}) { 103 return std::move(*p); 104 } 105 } 106 return std::nullopt; 107 } 108 109 Constant<SubscriptInteger> AsConstantShape(const ConstantSubscripts &shape) { 110 using IntType = Scalar<SubscriptInteger>; 111 std::vector<IntType> result; 112 for (auto dim : shape) { 113 result.emplace_back(dim); 114 } 115 return {std::move(result), ConstantSubscripts{GetRank(shape)}}; 116 } 117 118 ConstantSubscripts AsConstantExtents(const Constant<ExtentType> &shape) { 119 ConstantSubscripts result; 120 for (const auto &extent : shape.values()) { 121 result.push_back(extent.ToInt64()); 122 } 123 return result; 124 } 125 126 std::optional<ConstantSubscripts> AsConstantExtents( 127 FoldingContext &context, const Shape &shape) { 128 if (auto shapeConstant{AsConstantShape(context, shape)}) { 129 return AsConstantExtents(*shapeConstant); 130 } else { 131 return std::nullopt; 132 } 133 } 134 135 Shape Fold(FoldingContext &context, Shape &&shape) { 136 for (auto &dim : shape) { 137 dim = Fold(context, std::move(dim)); 138 } 139 return std::move(shape); 140 } 141 142 std::optional<Shape> Fold( 143 FoldingContext &context, std::optional<Shape> &&shape) { 144 if (shape) { 145 return Fold(context, std::move(*shape)); 146 } else { 147 return std::nullopt; 148 } 149 } 150 151 static ExtentExpr ComputeTripCount( 152 ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) { 153 ExtentExpr strideCopy{common::Clone(stride)}; 154 ExtentExpr span{ 155 (std::move(upper) - std::move(lower) + std::move(strideCopy)) / 156 std::move(stride)}; 157 return ExtentExpr{ 158 Extremum<ExtentType>{Ordering::Greater, std::move(span), ExtentExpr{0}}}; 159 } 160 161 ExtentExpr CountTrips( 162 ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) { 163 return ComputeTripCount( 164 std::move(lower), std::move(upper), std::move(stride)); 165 } 166 167 ExtentExpr CountTrips(const ExtentExpr &lower, const ExtentExpr &upper, 168 const ExtentExpr &stride) { 169 return ComputeTripCount( 170 common::Clone(lower), common::Clone(upper), common::Clone(stride)); 171 } 172 173 MaybeExtentExpr CountTrips(MaybeExtentExpr &&lower, MaybeExtentExpr &&upper, 174 MaybeExtentExpr &&stride) { 175 std::function<ExtentExpr(ExtentExpr &&, ExtentExpr &&, ExtentExpr &&)> bound{ 176 std::bind(ComputeTripCount, _1, _2, _3)}; 177 return common::MapOptional( 178 std::move(bound), std::move(lower), std::move(upper), std::move(stride)); 179 } 180 181 MaybeExtentExpr GetSize(Shape &&shape) { 182 ExtentExpr extent{1}; 183 for (auto &&dim : std::move(shape)) { 184 if (dim) { 185 extent = std::move(extent) * std::move(*dim); 186 } else { 187 return std::nullopt; 188 } 189 } 190 return extent; 191 } 192 193 bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) { 194 struct MyVisitor : public AnyTraverse<MyVisitor> { 195 using Base = AnyTraverse<MyVisitor>; 196 MyVisitor() : Base{*this} {} 197 using Base::operator(); 198 bool operator()(const ImpliedDoIndex &) { return true; } 199 }; 200 return MyVisitor{}(expr); 201 } 202 203 // Determines lower bound on a dimension. This can be other than 1 only 204 // for a reference to a whole array object or component. (See LBOUND, 16.9.109). 205 // ASSOCIATE construct entities may require tranversal of their referents. 206 class GetLowerBoundHelper : public Traverse<GetLowerBoundHelper, ExtentExpr> { 207 public: 208 using Result = ExtentExpr; 209 using Base = Traverse<GetLowerBoundHelper, ExtentExpr>; 210 using Base::operator(); 211 explicit GetLowerBoundHelper(int d) : Base{*this}, dimension_{d} {} 212 static ExtentExpr Default() { return ExtentExpr{1}; } 213 static ExtentExpr Combine(Result &&, Result &&) { return Default(); } 214 ExtentExpr operator()(const Symbol &); 215 ExtentExpr operator()(const Component &); 216 217 private: 218 int dimension_; 219 }; 220 221 auto GetLowerBoundHelper::operator()(const Symbol &symbol0) -> Result { 222 const Symbol &symbol{symbol0.GetUltimate()}; 223 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) { 224 int j{0}; 225 for (const auto &shapeSpec : details->shape()) { 226 if (j++ == dimension_) { 227 if (const auto &bound{shapeSpec.lbound().GetExplicit()}) { 228 return *bound; 229 } else if (IsDescriptor(symbol)) { 230 return ExtentExpr{DescriptorInquiry{NamedEntity{symbol0}, 231 DescriptorInquiry::Field::LowerBound, dimension_}}; 232 } else { 233 break; 234 } 235 } 236 } 237 } else if (const auto *assoc{ 238 symbol.detailsIf<semantics::AssocEntityDetails>()}) { 239 return (*this)(assoc->expr()); 240 } 241 return Default(); 242 } 243 244 auto GetLowerBoundHelper::operator()(const Component &component) -> Result { 245 if (component.base().Rank() == 0) { 246 const Symbol &symbol{component.GetLastSymbol().GetUltimate()}; 247 if (const auto *details{ 248 symbol.detailsIf<semantics::ObjectEntityDetails>()}) { 249 int j{0}; 250 for (const auto &shapeSpec : details->shape()) { 251 if (j++ == dimension_) { 252 if (const auto &bound{shapeSpec.lbound().GetExplicit()}) { 253 return *bound; 254 } else if (IsDescriptor(symbol)) { 255 return ExtentExpr{ 256 DescriptorInquiry{NamedEntity{common::Clone(component)}, 257 DescriptorInquiry::Field::LowerBound, dimension_}}; 258 } else { 259 break; 260 } 261 } 262 } 263 } 264 } 265 return Default(); 266 } 267 268 ExtentExpr GetLowerBound(const NamedEntity &base, int dimension) { 269 return GetLowerBoundHelper{dimension}(base); 270 } 271 272 ExtentExpr GetLowerBound( 273 FoldingContext &context, const NamedEntity &base, int dimension) { 274 return Fold(context, GetLowerBound(base, dimension)); 275 } 276 277 Shape GetLowerBounds(const NamedEntity &base) { 278 Shape result; 279 int rank{base.Rank()}; 280 for (int dim{0}; dim < rank; ++dim) { 281 result.emplace_back(GetLowerBound(base, dim)); 282 } 283 return result; 284 } 285 286 Shape GetLowerBounds(FoldingContext &context, const NamedEntity &base) { 287 Shape result; 288 int rank{base.Rank()}; 289 for (int dim{0}; dim < rank; ++dim) { 290 result.emplace_back(GetLowerBound(context, base, dim)); 291 } 292 return result; 293 } 294 295 MaybeExtentExpr GetExtent(const NamedEntity &base, int dimension) { 296 CHECK(dimension >= 0); 297 const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())}; 298 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) { 299 if (IsImpliedShape(symbol) && details->init()) { 300 if (auto shape{GetShape(symbol)}) { 301 if (dimension < static_cast<int>(shape->size())) { 302 return std::move(shape->at(dimension)); 303 } 304 } 305 } else { 306 int j{0}; 307 for (const auto &shapeSpec : details->shape()) { 308 if (j++ == dimension) { 309 if (shapeSpec.ubound().isExplicit()) { 310 if (const auto &ubound{shapeSpec.ubound().GetExplicit()}) { 311 if (const auto &lbound{shapeSpec.lbound().GetExplicit()}) { 312 return common::Clone(ubound.value()) - 313 common::Clone(lbound.value()) + ExtentExpr{1}; 314 } else { 315 return ubound.value(); 316 } 317 } 318 } else if (details->IsAssumedSize() && j == symbol.Rank()) { 319 return std::nullopt; 320 } else if (semantics::IsDescriptor(symbol)) { 321 return ExtentExpr{DescriptorInquiry{NamedEntity{base}, 322 DescriptorInquiry::Field::Extent, dimension}}; 323 } 324 } 325 } 326 } 327 } else if (const auto *assoc{ 328 symbol.detailsIf<semantics::AssocEntityDetails>()}) { 329 if (auto shape{GetShape(assoc->expr())}) { 330 if (dimension < static_cast<int>(shape->size())) { 331 return std::move(shape->at(dimension)); 332 } 333 } 334 } 335 return std::nullopt; 336 } 337 338 MaybeExtentExpr GetExtent( 339 FoldingContext &context, const NamedEntity &base, int dimension) { 340 return Fold(context, GetExtent(base, dimension)); 341 } 342 343 MaybeExtentExpr GetExtent( 344 const Subscript &subscript, const NamedEntity &base, int dimension) { 345 return std::visit( 346 common::visitors{ 347 [&](const Triplet &triplet) -> MaybeExtentExpr { 348 MaybeExtentExpr upper{triplet.upper()}; 349 if (!upper) { 350 upper = GetUpperBound(base, dimension); 351 } 352 MaybeExtentExpr lower{triplet.lower()}; 353 if (!lower) { 354 lower = GetLowerBound(base, dimension); 355 } 356 return CountTrips(std::move(lower), std::move(upper), 357 MaybeExtentExpr{triplet.stride()}); 358 }, 359 [&](const IndirectSubscriptIntegerExpr &subs) -> MaybeExtentExpr { 360 if (auto shape{GetShape(subs.value())}) { 361 if (GetRank(*shape) > 0) { 362 CHECK(GetRank(*shape) == 1); // vector-valued subscript 363 return std::move(shape->at(0)); 364 } 365 } 366 return std::nullopt; 367 }, 368 }, 369 subscript.u); 370 } 371 372 MaybeExtentExpr GetExtent(FoldingContext &context, const Subscript &subscript, 373 const NamedEntity &base, int dimension) { 374 return Fold(context, GetExtent(subscript, base, dimension)); 375 } 376 377 MaybeExtentExpr ComputeUpperBound( 378 ExtentExpr &&lower, MaybeExtentExpr &&extent) { 379 if (extent) { 380 return std::move(*extent) + std::move(lower) - ExtentExpr{1}; 381 } else { 382 return std::nullopt; 383 } 384 } 385 386 MaybeExtentExpr ComputeUpperBound( 387 FoldingContext &context, ExtentExpr &&lower, MaybeExtentExpr &&extent) { 388 return Fold(context, ComputeUpperBound(std::move(lower), std::move(extent))); 389 } 390 391 MaybeExtentExpr GetUpperBound(const NamedEntity &base, int dimension) { 392 const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())}; 393 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) { 394 int j{0}; 395 for (const auto &shapeSpec : details->shape()) { 396 if (j++ == dimension) { 397 if (const auto &bound{shapeSpec.ubound().GetExplicit()}) { 398 return *bound; 399 } else if (details->IsAssumedSize() && dimension + 1 == symbol.Rank()) { 400 break; 401 } else { 402 return ComputeUpperBound( 403 GetLowerBound(base, dimension), GetExtent(base, dimension)); 404 } 405 } 406 } 407 } else if (const auto *assoc{ 408 symbol.detailsIf<semantics::AssocEntityDetails>()}) { 409 if (auto shape{GetShape(assoc->expr())}) { 410 if (dimension < static_cast<int>(shape->size())) { 411 return ComputeUpperBound( 412 GetLowerBound(base, dimension), std::move(shape->at(dimension))); 413 } 414 } 415 } 416 return std::nullopt; 417 } 418 419 MaybeExtentExpr GetUpperBound( 420 FoldingContext &context, const NamedEntity &base, int dimension) { 421 return Fold(context, GetUpperBound(base, dimension)); 422 } 423 424 Shape GetUpperBounds(const NamedEntity &base) { 425 const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())}; 426 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) { 427 Shape result; 428 int dim{0}; 429 for (const auto &shapeSpec : details->shape()) { 430 if (const auto &bound{shapeSpec.ubound().GetExplicit()}) { 431 result.push_back(*bound); 432 } else if (details->IsAssumedSize()) { 433 CHECK(dim + 1 == base.Rank()); 434 result.emplace_back(std::nullopt); // UBOUND folding replaces with -1 435 } else { 436 result.emplace_back( 437 ComputeUpperBound(GetLowerBound(base, dim), GetExtent(base, dim))); 438 } 439 ++dim; 440 } 441 CHECK(GetRank(result) == symbol.Rank()); 442 return result; 443 } else { 444 return std::move(GetShape(symbol).value()); 445 } 446 } 447 448 Shape GetUpperBounds(FoldingContext &context, const NamedEntity &base) { 449 return Fold(context, GetUpperBounds(base)); 450 } 451 452 auto GetShapeHelper::operator()(const Symbol &symbol) const -> Result { 453 return std::visit( 454 common::visitors{ 455 [&](const semantics::ObjectEntityDetails &object) { 456 if (IsImpliedShape(symbol) && object.init()) { 457 return (*this)(object.init()); 458 } else { 459 int n{object.shape().Rank()}; 460 NamedEntity base{symbol}; 461 return Result{CreateShape(n, base)}; 462 } 463 }, 464 [](const semantics::EntityDetails &) { 465 return ScalarShape(); // no dimensions seen 466 }, 467 [&](const semantics::ProcEntityDetails &proc) { 468 if (const Symbol * interface{proc.interface().symbol()}) { 469 return (*this)(*interface); 470 } else { 471 return ScalarShape(); 472 } 473 }, 474 [&](const semantics::AssocEntityDetails &assoc) { 475 if (!assoc.rank()) { 476 return (*this)(assoc.expr()); 477 } else { 478 int n{assoc.rank().value()}; 479 NamedEntity base{symbol}; 480 return Result{CreateShape(n, base)}; 481 } 482 }, 483 [&](const semantics::SubprogramDetails &subp) { 484 if (subp.isFunction()) { 485 return (*this)(subp.result()); 486 } else { 487 return Result{}; 488 } 489 }, 490 [&](const semantics::ProcBindingDetails &binding) { 491 return (*this)(binding.symbol()); 492 }, 493 [&](const semantics::UseDetails &use) { 494 return (*this)(use.symbol()); 495 }, 496 [&](const semantics::HostAssocDetails &assoc) { 497 return (*this)(assoc.symbol()); 498 }, 499 [](const semantics::TypeParamDetails &) { return ScalarShape(); }, 500 [](const auto &) { return Result{}; }, 501 }, 502 symbol.details()); 503 } 504 505 auto GetShapeHelper::operator()(const Component &component) const -> Result { 506 const Symbol &symbol{component.GetLastSymbol()}; 507 int rank{symbol.Rank()}; 508 if (rank == 0) { 509 return (*this)(component.base()); 510 } else if (symbol.has<semantics::ObjectEntityDetails>()) { 511 NamedEntity base{Component{component}}; 512 return CreateShape(rank, base); 513 } else if (symbol.has<semantics::AssocEntityDetails>()) { 514 NamedEntity base{Component{component}}; 515 return Result{CreateShape(rank, base)}; 516 } else { 517 return (*this)(symbol); 518 } 519 } 520 521 auto GetShapeHelper::operator()(const ArrayRef &arrayRef) const -> Result { 522 Shape shape; 523 int dimension{0}; 524 const NamedEntity &base{arrayRef.base()}; 525 for (const Subscript &ss : arrayRef.subscript()) { 526 if (ss.Rank() > 0) { 527 shape.emplace_back(GetExtent(ss, base, dimension)); 528 } 529 ++dimension; 530 } 531 if (shape.empty()) { 532 if (const Component * component{base.UnwrapComponent()}) { 533 return (*this)(component->base()); 534 } 535 } 536 return shape; 537 } 538 539 auto GetShapeHelper::operator()(const CoarrayRef &coarrayRef) const -> Result { 540 NamedEntity base{coarrayRef.GetBase()}; 541 if (coarrayRef.subscript().empty()) { 542 return (*this)(base); 543 } else { 544 Shape shape; 545 int dimension{0}; 546 for (const Subscript &ss : coarrayRef.subscript()) { 547 if (ss.Rank() > 0) { 548 shape.emplace_back(GetExtent(ss, base, dimension)); 549 } 550 ++dimension; 551 } 552 return shape; 553 } 554 } 555 556 auto GetShapeHelper::operator()(const Substring &substring) const -> Result { 557 return (*this)(substring.parent()); 558 } 559 560 auto GetShapeHelper::operator()(const ProcedureRef &call) const -> Result { 561 if (call.Rank() == 0) { 562 return ScalarShape(); 563 } else if (call.IsElemental()) { 564 for (const auto &arg : call.arguments()) { 565 if (arg && arg->Rank() > 0) { 566 return (*this)(*arg); 567 } 568 } 569 return ScalarShape(); 570 } else if (const Symbol * symbol{call.proc().GetSymbol()}) { 571 return (*this)(*symbol); 572 } else if (const auto *intrinsic{call.proc().GetSpecificIntrinsic()}) { 573 if (intrinsic->name == "shape" || intrinsic->name == "lbound" || 574 intrinsic->name == "ubound") { 575 // These are the array-valued cases for LBOUND and UBOUND (no DIM=). 576 const auto *expr{call.arguments().front().value().UnwrapExpr()}; 577 CHECK(expr); 578 return Shape{MaybeExtentExpr{ExtentExpr{expr->Rank()}}}; 579 } else if (intrinsic->name == "all" || intrinsic->name == "any" || 580 intrinsic->name == "count" || intrinsic->name == "iall" || 581 intrinsic->name == "iany" || intrinsic->name == "iparity" || 582 intrinsic->name == "maxval" || intrinsic->name == "minval" || 583 intrinsic->name == "norm2" || intrinsic->name == "parity" || 584 intrinsic->name == "product" || intrinsic->name == "sum") { 585 // Reduction with DIM= 586 if (call.arguments().size() >= 2) { 587 auto arrayShape{ 588 (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))}; 589 const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))}; 590 if (arrayShape && dimArg) { 591 if (auto dim{ToInt64(*dimArg)}) { 592 if (*dim >= 1 && 593 static_cast<std::size_t>(*dim) <= arrayShape->size()) { 594 arrayShape->erase(arrayShape->begin() + (*dim - 1)); 595 return std::move(*arrayShape); 596 } 597 } 598 } 599 } 600 } else if (intrinsic->name == "maxloc" || intrinsic->name == "minloc") { 601 // TODO: FINDLOC 602 if (call.arguments().size() >= 2) { 603 if (auto arrayShape{ 604 (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))}) { 605 auto rank{static_cast<int>(arrayShape->size())}; 606 if (const auto *dimArg{ 607 UnwrapExpr<Expr<SomeType>>(call.arguments()[1])}) { 608 auto dim{ToInt64(*dimArg)}; 609 if (dim && *dim >= 1 && *dim <= rank) { 610 arrayShape->erase(arrayShape->begin() + (*dim - 1)); 611 return std::move(*arrayShape); 612 } 613 } else { 614 // xxxLOC(no DIM=) result is vector(1:RANK(ARRAY=)) 615 return Shape{ExtentExpr{rank}}; 616 } 617 } 618 } 619 } else if (intrinsic->name == "cshift" || intrinsic->name == "eoshift") { 620 if (!call.arguments().empty()) { 621 return (*this)(call.arguments()[0]); 622 } 623 } else if (intrinsic->name == "matmul") { 624 if (call.arguments().size() == 2) { 625 if (auto ashape{(*this)(call.arguments()[0])}) { 626 if (auto bshape{(*this)(call.arguments()[1])}) { 627 if (ashape->size() == 1 && bshape->size() == 2) { 628 bshape->erase(bshape->begin()); 629 return std::move(*bshape); // matmul(vector, matrix) 630 } else if (ashape->size() == 2 && bshape->size() == 1) { 631 ashape->pop_back(); 632 return std::move(*ashape); // matmul(matrix, vector) 633 } else if (ashape->size() == 2 && bshape->size() == 2) { 634 (*ashape)[1] = std::move((*bshape)[1]); 635 return std::move(*ashape); // matmul(matrix, matrix) 636 } 637 } 638 } 639 } 640 } else if (intrinsic->name == "reshape") { 641 if (call.arguments().size() >= 2 && call.arguments().at(1)) { 642 // SHAPE(RESHAPE(array,shape)) -> shape 643 if (const auto *shapeExpr{ 644 call.arguments().at(1).value().UnwrapExpr()}) { 645 auto shape{std::get<Expr<SomeInteger>>(shapeExpr->u)}; 646 return AsShape(ConvertToType<ExtentType>(std::move(shape))); 647 } 648 } 649 } else if (intrinsic->name == "pack") { 650 if (call.arguments().size() >= 3 && call.arguments().at(2)) { 651 // SHAPE(PACK(,,VECTOR=v)) -> SHAPE(v) 652 return (*this)(call.arguments().at(2)); 653 } else if (call.arguments().size() >= 2 && context_) { 654 if (auto maskShape{(*this)(call.arguments().at(1))}) { 655 if (maskShape->size() == 0) { 656 // Scalar MASK= -> [MERGE(SIZE(ARRAY=), 0, mask)] 657 if (auto arrayShape{(*this)(call.arguments().at(0))}) { 658 auto arraySize{GetSize(std::move(*arrayShape))}; 659 CHECK(arraySize); 660 ActualArguments toMerge{ 661 ActualArgument{AsGenericExpr(std::move(*arraySize))}, 662 ActualArgument{AsGenericExpr(ExtentExpr{0})}, 663 common::Clone(call.arguments().at(1))}; 664 auto specific{context_->intrinsics().Probe( 665 CallCharacteristics{"merge"}, toMerge, *context_)}; 666 CHECK(specific); 667 return Shape{ExtentExpr{FunctionRef<ExtentType>{ 668 ProcedureDesignator{std::move(specific->specificIntrinsic)}, 669 std::move(specific->arguments)}}}; 670 } 671 } else { 672 // Non-scalar MASK= -> [COUNT(mask)] 673 ActualArguments toCount{ActualArgument{common::Clone( 674 DEREF(call.arguments().at(1).value().UnwrapExpr()))}}; 675 auto specific{context_->intrinsics().Probe( 676 CallCharacteristics{"count"}, toCount, *context_)}; 677 CHECK(specific); 678 return Shape{ExtentExpr{FunctionRef<ExtentType>{ 679 ProcedureDesignator{std::move(specific->specificIntrinsic)}, 680 std::move(specific->arguments)}}}; 681 } 682 } 683 } 684 } else if (intrinsic->name == "spread") { 685 // SHAPE(SPREAD(ARRAY,DIM,NCOPIES)) = SHAPE(ARRAY) with NCOPIES inserted 686 // at position DIM. 687 if (call.arguments().size() == 3) { 688 auto arrayShape{ 689 (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))}; 690 const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))}; 691 const auto *nCopies{ 692 UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))}; 693 if (arrayShape && dimArg && nCopies) { 694 if (auto dim{ToInt64(*dimArg)}) { 695 if (*dim >= 1 && 696 static_cast<std::size_t>(*dim) <= arrayShape->size() + 1) { 697 arrayShape->emplace(arrayShape->begin() + *dim - 1, 698 ConvertToType<ExtentType>(common::Clone(*nCopies))); 699 return std::move(*arrayShape); 700 } 701 } 702 } 703 } 704 } else if (intrinsic->name == "transfer") { 705 if (call.arguments().size() == 3 && call.arguments().at(2)) { 706 // SIZE= is present; shape is vector [SIZE=] 707 if (const auto *size{ 708 UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))}) { 709 return Shape{ 710 MaybeExtentExpr{ConvertToType<ExtentType>(common::Clone(*size))}}; 711 } 712 } else if (context_) { 713 if (auto moldTypeAndShape{characteristics::TypeAndShape::Characterize( 714 call.arguments().at(1), *context_)}) { 715 if (GetRank(moldTypeAndShape->shape()) == 0) { 716 // SIZE= is absent and MOLD= is scalar: result is scalar 717 return ScalarShape(); 718 } else { 719 // SIZE= is absent and MOLD= is array: result is vector whose 720 // length is determined by sizes of types. See 16.9.193p4 case(ii). 721 if (auto sourceTypeAndShape{ 722 characteristics::TypeAndShape::Characterize( 723 call.arguments().at(0), *context_)}) { 724 auto sourceBytes{ 725 sourceTypeAndShape->MeasureSizeInBytes(*context_)}; 726 auto moldElementBytes{ 727 moldTypeAndShape->MeasureElementSizeInBytes(*context_, true)}; 728 if (sourceBytes && moldElementBytes) { 729 ExtentExpr extent{Fold(*context_, 730 (std::move(*sourceBytes) + 731 common::Clone(*moldElementBytes) - ExtentExpr{1}) / 732 common::Clone(*moldElementBytes))}; 733 return Shape{MaybeExtentExpr{std::move(extent)}}; 734 } 735 } 736 } 737 } 738 } 739 } else if (intrinsic->name == "transpose") { 740 if (call.arguments().size() >= 1) { 741 if (auto shape{(*this)(call.arguments().at(0))}) { 742 if (shape->size() == 2) { 743 std::swap((*shape)[0], (*shape)[1]); 744 return shape; 745 } 746 } 747 } 748 } else if (intrinsic->name == "unpack") { 749 if (call.arguments().size() >= 2) { 750 return (*this)(call.arguments()[1]); // MASK= 751 } 752 } else if (intrinsic->characteristics.value().attrs.test(characteristics:: 753 Procedure::Attr::NullPointer)) { // NULL(MOLD=) 754 return (*this)(call.arguments()); 755 } else { 756 // TODO: shapes of other non-elemental intrinsic results 757 } 758 } 759 return std::nullopt; 760 } 761 762 // Check conformance of the passed shapes. Only return true if we can verify 763 // that they conform 764 bool CheckConformance(parser::ContextualMessages &messages, const Shape &left, 765 const Shape &right, const char *leftIs, const char *rightIs, 766 bool leftScalarExpandable, bool rightScalarExpandable, 767 bool leftIsDeferredShape, bool rightIsDeferredShape) { 768 int n{GetRank(left)}; 769 if (n == 0 && leftScalarExpandable) { 770 return true; 771 } 772 int rn{GetRank(right)}; 773 if (rn == 0 && rightScalarExpandable) { 774 return true; 775 } 776 if (n != rn) { 777 messages.Say("Rank of %1$s is %2$d, but %3$s has rank %4$d"_err_en_US, 778 leftIs, n, rightIs, rn); 779 return false; 780 } 781 for (int j{0}; j < n; ++j) { 782 if (auto leftDim{ToInt64(left[j])}) { 783 if (auto rightDim{ToInt64(right[j])}) { 784 if (*leftDim != *rightDim) { 785 messages.Say("Dimension %1$d of %2$s has extent %3$jd, " 786 "but %4$s has extent %5$jd"_err_en_US, 787 j + 1, leftIs, *leftDim, rightIs, *rightDim); 788 return false; 789 } 790 } else if (!rightIsDeferredShape) { 791 return false; 792 } 793 } else if (!leftIsDeferredShape) { 794 return false; 795 } 796 } 797 return true; 798 } 799 800 bool IncrementSubscripts( 801 ConstantSubscripts &indices, const ConstantSubscripts &extents) { 802 std::size_t rank(indices.size()); 803 CHECK(rank <= extents.size()); 804 for (std::size_t j{0}; j < rank; ++j) { 805 if (extents[j] < 1) { 806 return false; 807 } 808 } 809 for (std::size_t j{0}; j < rank; ++j) { 810 if (indices[j]++ < extents[j]) { 811 return true; 812 } 813 indices[j] = 1; 814 } 815 return false; 816 } 817 818 } // namespace Fortran::evaluate 819