1 //===-- VectorSubscripts.cpp -- Vector subscripts tools -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Lower/VectorSubscripts.h" 14 #include "flang/Lower/AbstractConverter.h" 15 #include "flang/Lower/Support/Utils.h" 16 #include "flang/Lower/Todo.h" 17 #include "flang/Optimizer/Builder/Character.h" 18 #include "flang/Optimizer/Builder/Complex.h" 19 #include "flang/Optimizer/Builder/FIRBuilder.h" 20 #include "flang/Semantics/expression.h" 21 22 namespace { 23 /// Helper class to lower a designator containing vector subscripts into a 24 /// lowered representation that can be worked with. 25 class VectorSubscriptBoxBuilder { 26 public: 27 VectorSubscriptBoxBuilder(mlir::Location loc, 28 Fortran::lower::AbstractConverter &converter, 29 Fortran::lower::StatementContext &stmtCtx) 30 : converter{converter}, stmtCtx{stmtCtx}, loc{loc} {} 31 32 Fortran::lower::VectorSubscriptBox gen(const Fortran::lower::SomeExpr &expr) { 33 elementType = genDesignator(expr); 34 return Fortran::lower::VectorSubscriptBox( 35 std::move(loweredBase), std::move(loweredSubscripts), 36 std::move(componentPath), substringBounds, elementType); 37 } 38 39 private: 40 using LoweredVectorSubscript = 41 Fortran::lower::VectorSubscriptBox::LoweredVectorSubscript; 42 using LoweredTriplet = Fortran::lower::VectorSubscriptBox::LoweredTriplet; 43 using LoweredSubscript = Fortran::lower::VectorSubscriptBox::LoweredSubscript; 44 using MaybeSubstring = Fortran::lower::VectorSubscriptBox::MaybeSubstring; 45 46 /// genDesignator unwraps a Designator<T> and calls `gen` on what the 47 /// designator actually contains. 48 template <typename A> 49 mlir::Type genDesignator(const A &) { 50 fir::emitFatalError(loc, "expr must contain a designator"); 51 } 52 template <typename T> 53 mlir::Type genDesignator(const Fortran::evaluate::Expr<T> &expr) { 54 using ExprVariant = decltype(Fortran::evaluate::Expr<T>::u); 55 using Designator = Fortran::evaluate::Designator<T>; 56 if constexpr (Fortran::common::HasMember<Designator, ExprVariant>) { 57 const auto &designator = std::get<Designator>(expr.u); 58 return std::visit([&](const auto &x) { return gen(x); }, designator.u); 59 } else { 60 return std::visit([&](const auto &x) { return genDesignator(x); }, 61 expr.u); 62 } 63 } 64 65 // The gen(X) methods visit X to lower its base and subscripts and return the 66 // type of X elements. 67 68 mlir::Type gen(const Fortran::evaluate::DataRef &dataRef) { 69 return std::visit([&](const auto &ref) -> mlir::Type { return gen(ref); }, 70 dataRef.u); 71 } 72 73 mlir::Type gen(const Fortran::evaluate::SymbolRef &symRef) { 74 // Never visited because expr lowering is used to lowered the ranked 75 // ArrayRef. 76 fir::emitFatalError( 77 loc, "expected at least one ArrayRef with vector susbcripts"); 78 } 79 80 mlir::Type gen(const Fortran::evaluate::Substring &substring) { 81 // StaticDataObject::Pointer bases are constants and cannot be 82 // subscripted, so the base must be a DataRef here. 83 mlir::Type baseElementType = 84 gen(std::get<Fortran::evaluate::DataRef>(substring.parent())); 85 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 86 mlir::Type idxTy = builder.getIndexType(); 87 mlir::Value lb = genScalarValue(substring.lower()); 88 substringBounds.emplace_back(builder.createConvert(loc, idxTy, lb)); 89 if (const auto &ubExpr = substring.upper()) { 90 mlir::Value ub = genScalarValue(*ubExpr); 91 substringBounds.emplace_back(builder.createConvert(loc, idxTy, ub)); 92 } 93 return baseElementType; 94 } 95 96 mlir::Type gen(const Fortran::evaluate::ComplexPart &complexPart) { 97 auto complexType = gen(complexPart.complex()); 98 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 99 mlir::Type i32Ty = builder.getI32Type(); // llvm's GEP requires i32 100 mlir::Value offset = builder.createIntegerConstant( 101 loc, i32Ty, 102 complexPart.part() == Fortran::evaluate::ComplexPart::Part::RE ? 0 : 1); 103 componentPath.emplace_back(offset); 104 return fir::factory::Complex{builder, loc}.getComplexPartType(complexType); 105 } 106 107 mlir::Type gen(const Fortran::evaluate::Component &component) { 108 auto recTy = gen(component.base()).cast<fir::RecordType>(); 109 const Fortran::semantics::Symbol &componentSymbol = 110 component.GetLastSymbol(); 111 // Parent components will not be found here, they are not part 112 // of the FIR type and cannot be used in the path yet. 113 if (componentSymbol.test(Fortran::semantics::Symbol::Flag::ParentComp)) 114 TODO(loc, "Reference to parent component"); 115 mlir::Type fldTy = fir::FieldType::get(&converter.getMLIRContext()); 116 llvm::StringRef componentName = toStringRef(componentSymbol.name()); 117 // Parameters threading in field_index is not yet very clear. We only 118 // have the ones of the ranked array ref at hand, but it looks like 119 // the fir.field_index expects the one of the direct base. 120 if (recTy.getNumLenParams() != 0) 121 TODO(loc, "threading length parameters in field index op"); 122 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 123 componentPath.emplace_back(builder.create<fir::FieldIndexOp>( 124 loc, fldTy, componentName, recTy, /*typeParams*/ llvm::None)); 125 return fir::unwrapSequenceType(recTy.getType(componentName)); 126 } 127 128 mlir::Type gen(const Fortran::evaluate::ArrayRef &arrayRef) { 129 auto isTripletOrVector = 130 [](const Fortran::evaluate::Subscript &subscript) -> bool { 131 return std::visit( 132 Fortran::common::visitors{ 133 [](const Fortran::evaluate::IndirectSubscriptIntegerExpr &expr) { 134 return expr.value().Rank() != 0; 135 }, 136 [&](const Fortran::evaluate::Triplet &) { return true; }}, 137 subscript.u); 138 }; 139 if (llvm::any_of(arrayRef.subscript(), isTripletOrVector)) 140 return genRankedArrayRefSubscriptAndBase(arrayRef); 141 142 // This is a scalar ArrayRef (only scalar indexes), collect the indexes and 143 // visit the base that must contain another arrayRef with the vector 144 // subscript. 145 mlir::Type elementType = gen(namedEntityToDataRef(arrayRef.base())); 146 for (const Fortran::evaluate::Subscript &subscript : arrayRef.subscript()) { 147 const auto &expr = 148 std::get<Fortran::evaluate::IndirectSubscriptIntegerExpr>( 149 subscript.u); 150 componentPath.emplace_back(genScalarValue(expr.value())); 151 } 152 return elementType; 153 } 154 155 /// Lower the subscripts and base of the ArrayRef that is an array (there must 156 /// be one since there is a vector subscript, and there can only be one 157 /// according to C925). 158 mlir::Type genRankedArrayRefSubscriptAndBase( 159 const Fortran::evaluate::ArrayRef &arrayRef) { 160 // Lower the save the base 161 Fortran::lower::SomeExpr baseExpr = namedEntityToExpr(arrayRef.base()); 162 loweredBase = converter.genExprAddr(baseExpr, stmtCtx); 163 // Lower and save the subscripts 164 fir::FirOpBuilder &builder = converter.getFirOpBuilder(); 165 mlir::Type idxTy = builder.getIndexType(); 166 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); 167 for (const auto &subscript : llvm::enumerate(arrayRef.subscript())) { 168 std::visit( 169 Fortran::common::visitors{ 170 [&](const Fortran::evaluate::IndirectSubscriptIntegerExpr &expr) { 171 if (expr.value().Rank() == 0) { 172 // Simple scalar subscript 173 loweredSubscripts.emplace_back(genScalarValue(expr.value())); 174 } else { 175 // Vector subscript. 176 // Remove conversion if any to avoid temp creation that may 177 // have been added by the front-end to avoid the creation of a 178 // temp array value. 179 auto vector = converter.genExprAddr( 180 ignoreEvConvert(expr.value()), stmtCtx); 181 mlir::Value size = 182 fir::factory::readExtent(builder, loc, vector, /*dim=*/0); 183 size = builder.createConvert(loc, idxTy, size); 184 loweredSubscripts.emplace_back( 185 LoweredVectorSubscript{std::move(vector), size}); 186 } 187 }, 188 [&](const Fortran::evaluate::Triplet &triplet) { 189 mlir::Value lb, ub; 190 if (const auto &lbExpr = triplet.lower()) 191 lb = genScalarValue(*lbExpr); 192 else 193 lb = fir::factory::readLowerBound(builder, loc, loweredBase, 194 subscript.index(), one); 195 if (const auto &ubExpr = triplet.upper()) 196 ub = genScalarValue(*ubExpr); 197 else 198 ub = fir::factory::readExtent(builder, loc, loweredBase, 199 subscript.index()); 200 lb = builder.createConvert(loc, idxTy, lb); 201 ub = builder.createConvert(loc, idxTy, ub); 202 mlir::Value stride = genScalarValue(triplet.stride()); 203 stride = builder.createConvert(loc, idxTy, stride); 204 loweredSubscripts.emplace_back(LoweredTriplet{lb, ub, stride}); 205 }, 206 }, 207 subscript.value().u); 208 } 209 return fir::unwrapSequenceType( 210 fir::unwrapPassByRefType(fir::getBase(loweredBase).getType())); 211 } 212 213 mlir::Type gen(const Fortran::evaluate::CoarrayRef &) { 214 // Is this possible/legal ? 215 TODO(loc, "Coarray ref with vector subscript in IO input"); 216 } 217 218 template <typename A> 219 mlir::Value genScalarValue(const A &expr) { 220 return fir::getBase(converter.genExprValue(toEvExpr(expr), stmtCtx)); 221 } 222 223 Fortran::evaluate::DataRef 224 namedEntityToDataRef(const Fortran::evaluate::NamedEntity &namedEntity) { 225 if (namedEntity.IsSymbol()) 226 return Fortran::evaluate::DataRef{namedEntity.GetFirstSymbol()}; 227 return Fortran::evaluate::DataRef{namedEntity.GetComponent()}; 228 } 229 230 Fortran::lower::SomeExpr 231 namedEntityToExpr(const Fortran::evaluate::NamedEntity &namedEntity) { 232 return Fortran::evaluate::AsGenericExpr(namedEntityToDataRef(namedEntity)) 233 .value(); 234 } 235 236 Fortran::lower::AbstractConverter &converter; 237 Fortran::lower::StatementContext &stmtCtx; 238 mlir::Location loc; 239 /// Elements of VectorSubscriptBox being built. 240 fir::ExtendedValue loweredBase; 241 llvm::SmallVector<LoweredSubscript, 16> loweredSubscripts; 242 llvm::SmallVector<mlir::Value> componentPath; 243 MaybeSubstring substringBounds; 244 mlir::Type elementType; 245 }; 246 } // namespace 247 248 Fortran::lower::VectorSubscriptBox Fortran::lower::genVectorSubscriptBox( 249 mlir::Location loc, Fortran::lower::AbstractConverter &converter, 250 Fortran::lower::StatementContext &stmtCtx, 251 const Fortran::lower::SomeExpr &expr) { 252 return VectorSubscriptBoxBuilder(loc, converter, stmtCtx).gen(expr); 253 } 254 255 template <typename LoopType, typename Generator> 256 mlir::Value Fortran::lower::VectorSubscriptBox::loopOverElementsBase( 257 fir::FirOpBuilder &builder, mlir::Location loc, 258 const Generator &elementalGenerator, 259 [[maybe_unused]] mlir::Value initialCondition) { 260 mlir::Value shape = builder.createShape(loc, loweredBase); 261 mlir::Value slice = createSlice(builder, loc); 262 263 // Create loop nest for triplets and vector subscripts in column 264 // major order. 265 llvm::SmallVector<mlir::Value> inductionVariables; 266 LoopType outerLoop; 267 for (auto [lb, ub, step] : genLoopBounds(builder, loc)) { 268 LoopType loop; 269 if constexpr (std::is_same_v<LoopType, fir::IterWhileOp>) { 270 loop = 271 builder.create<fir::IterWhileOp>(loc, lb, ub, step, initialCondition); 272 initialCondition = loop.getIterateVar(); 273 if (!outerLoop) 274 outerLoop = loop; 275 else 276 builder.create<fir::ResultOp>(loc, loop.getResult(0)); 277 } else { 278 loop = 279 builder.create<fir::DoLoopOp>(loc, lb, ub, step, /*unordered=*/false); 280 if (!outerLoop) 281 outerLoop = loop; 282 } 283 builder.setInsertionPointToStart(loop.getBody()); 284 inductionVariables.push_back(loop.getInductionVar()); 285 } 286 assert(outerLoop && !inductionVariables.empty() && 287 "at least one loop should be created"); 288 289 fir::ExtendedValue elem = 290 getElementAt(builder, loc, shape, slice, inductionVariables); 291 292 if constexpr (std::is_same_v<LoopType, fir::IterWhileOp>) { 293 auto res = elementalGenerator(elem); 294 builder.create<fir::ResultOp>(loc, res); 295 builder.setInsertionPointAfter(outerLoop); 296 return outerLoop.getResult(0); 297 } else { 298 elementalGenerator(elem); 299 builder.setInsertionPointAfter(outerLoop); 300 return {}; 301 } 302 } 303 304 void Fortran::lower::VectorSubscriptBox::loopOverElements( 305 fir::FirOpBuilder &builder, mlir::Location loc, 306 const ElementalGenerator &elementalGenerator) { 307 mlir::Value initialCondition; 308 loopOverElementsBase<fir::DoLoopOp, ElementalGenerator>( 309 builder, loc, elementalGenerator, initialCondition); 310 } 311 312 mlir::Value Fortran::lower::VectorSubscriptBox::loopOverElementsWhile( 313 fir::FirOpBuilder &builder, mlir::Location loc, 314 const ElementalGeneratorWithBoolReturn &elementalGenerator, 315 mlir::Value initialCondition) { 316 return loopOverElementsBase<fir::IterWhileOp, 317 ElementalGeneratorWithBoolReturn>( 318 builder, loc, elementalGenerator, initialCondition); 319 } 320 321 mlir::Value 322 Fortran::lower::VectorSubscriptBox::createSlice(fir::FirOpBuilder &builder, 323 mlir::Location loc) { 324 mlir::Type idxTy = builder.getIndexType(); 325 llvm::SmallVector<mlir::Value> triples; 326 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); 327 auto undef = builder.create<fir::UndefOp>(loc, idxTy); 328 for (const LoweredSubscript &subscript : loweredSubscripts) 329 std::visit(Fortran::common::visitors{ 330 [&](const LoweredTriplet &triplet) { 331 triples.emplace_back(triplet.lb); 332 triples.emplace_back(triplet.ub); 333 triples.emplace_back(triplet.stride); 334 }, 335 [&](const LoweredVectorSubscript &vector) { 336 triples.emplace_back(one); 337 triples.emplace_back(vector.size); 338 triples.emplace_back(one); 339 }, 340 [&](const mlir::Value &i) { 341 triples.emplace_back(i); 342 triples.emplace_back(undef); 343 triples.emplace_back(undef); 344 }, 345 }, 346 subscript); 347 return builder.create<fir::SliceOp>(loc, triples, componentPath); 348 } 349 350 llvm::SmallVector<std::tuple<mlir::Value, mlir::Value, mlir::Value>> 351 Fortran::lower::VectorSubscriptBox::genLoopBounds(fir::FirOpBuilder &builder, 352 mlir::Location loc) { 353 mlir::Type idxTy = builder.getIndexType(); 354 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); 355 mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0); 356 llvm::SmallVector<std::tuple<mlir::Value, mlir::Value, mlir::Value>> bounds; 357 size_t dimension = loweredSubscripts.size(); 358 for (const LoweredSubscript &subscript : llvm::reverse(loweredSubscripts)) { 359 --dimension; 360 if (std::holds_alternative<mlir::Value>(subscript)) 361 continue; 362 mlir::Value lb, ub, step; 363 if (const auto *triplet = std::get_if<LoweredTriplet>(&subscript)) { 364 mlir::Value extent = builder.genExtentFromTriplet( 365 loc, triplet->lb, triplet->ub, triplet->stride, idxTy); 366 mlir::Value baseLb = fir::factory::readLowerBound( 367 builder, loc, loweredBase, dimension, one); 368 baseLb = builder.createConvert(loc, idxTy, baseLb); 369 lb = baseLb; 370 ub = builder.create<mlir::arith::SubIOp>(loc, idxTy, extent, one); 371 ub = builder.create<mlir::arith::AddIOp>(loc, idxTy, ub, baseLb); 372 step = one; 373 } else { 374 const auto &vector = std::get<LoweredVectorSubscript>(subscript); 375 lb = zero; 376 ub = builder.create<mlir::arith::SubIOp>(loc, idxTy, vector.size, one); 377 step = one; 378 } 379 bounds.emplace_back(lb, ub, step); 380 } 381 return bounds; 382 } 383 384 fir::ExtendedValue Fortran::lower::VectorSubscriptBox::getElementAt( 385 fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value shape, 386 mlir::Value slice, mlir::ValueRange inductionVariables) { 387 /// Generate the indexes for the array_coor inside the loops. 388 mlir::Type idxTy = builder.getIndexType(); 389 llvm::SmallVector<mlir::Value> indexes; 390 size_t inductionIdx = inductionVariables.size() - 1; 391 for (const LoweredSubscript &subscript : loweredSubscripts) 392 std::visit(Fortran::common::visitors{ 393 [&](const LoweredTriplet &triplet) { 394 indexes.emplace_back(inductionVariables[inductionIdx--]); 395 }, 396 [&](const LoweredVectorSubscript &vector) { 397 mlir::Value vecIndex = inductionVariables[inductionIdx--]; 398 mlir::Value vecBase = fir::getBase(vector.vector); 399 mlir::Type vecEleTy = fir::unwrapSequenceType( 400 fir::unwrapPassByRefType(vecBase.getType())); 401 mlir::Type refTy = builder.getRefType(vecEleTy); 402 auto vecEltRef = builder.create<fir::CoordinateOp>( 403 loc, refTy, vecBase, vecIndex); 404 auto vecElt = 405 builder.create<fir::LoadOp>(loc, vecEleTy, vecEltRef); 406 indexes.emplace_back( 407 builder.createConvert(loc, idxTy, vecElt)); 408 }, 409 [&](const mlir::Value &i) { 410 indexes.emplace_back(builder.createConvert(loc, idxTy, i)); 411 }, 412 }, 413 subscript); 414 mlir::Type refTy = builder.getRefType(getElementType()); 415 auto elementAddr = builder.create<fir::ArrayCoorOp>( 416 loc, refTy, fir::getBase(loweredBase), shape, slice, indexes, 417 fir::getTypeParams(loweredBase)); 418 fir::ExtendedValue element = fir::factory::arraySectionElementToExtendedValue( 419 builder, loc, loweredBase, elementAddr, slice); 420 if (!substringBounds.empty()) { 421 const fir::CharBoxValue *charBox = element.getCharBox(); 422 assert(charBox && "substring requires CharBox base"); 423 fir::factory::CharacterExprHelper helper{builder, loc}; 424 return helper.createSubstring(*charBox, substringBounds); 425 } 426 return element; 427 } 428