1 //===-- VectorSubscripts.cpp -- Vector subscripts tools -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Lower/VectorSubscripts.h"
14 #include "flang/Lower/AbstractConverter.h"
15 #include "flang/Lower/Support/Utils.h"
16 #include "flang/Optimizer/Builder/Character.h"
17 #include "flang/Optimizer/Builder/Complex.h"
18 #include "flang/Optimizer/Builder/FIRBuilder.h"
19 #include "flang/Optimizer/Builder/Todo.h"
20 #include "flang/Semantics/expression.h"
21 
22 namespace {
23 /// Helper class to lower a designator containing vector subscripts into a
24 /// lowered representation that can be worked with.
25 class VectorSubscriptBoxBuilder {
26 public:
VectorSubscriptBoxBuilder(mlir::Location loc,Fortran::lower::AbstractConverter & converter,Fortran::lower::StatementContext & stmtCtx)27   VectorSubscriptBoxBuilder(mlir::Location loc,
28                             Fortran::lower::AbstractConverter &converter,
29                             Fortran::lower::StatementContext &stmtCtx)
30       : converter{converter}, stmtCtx{stmtCtx}, loc{loc} {}
31 
gen(const Fortran::lower::SomeExpr & expr)32   Fortran::lower::VectorSubscriptBox gen(const Fortran::lower::SomeExpr &expr) {
33     elementType = genDesignator(expr);
34     return Fortran::lower::VectorSubscriptBox(
35         std::move(loweredBase), std::move(loweredSubscripts),
36         std::move(componentPath), substringBounds, elementType);
37   }
38 
39 private:
40   using LoweredVectorSubscript =
41       Fortran::lower::VectorSubscriptBox::LoweredVectorSubscript;
42   using LoweredTriplet = Fortran::lower::VectorSubscriptBox::LoweredTriplet;
43   using LoweredSubscript = Fortran::lower::VectorSubscriptBox::LoweredSubscript;
44   using MaybeSubstring = Fortran::lower::VectorSubscriptBox::MaybeSubstring;
45 
46   /// genDesignator unwraps a Designator<T> and calls `gen` on what the
47   /// designator actually contains.
48   template <typename A>
genDesignator(const A &)49   mlir::Type genDesignator(const A &) {
50     fir::emitFatalError(loc, "expr must contain a designator");
51   }
52   template <typename T>
genDesignator(const Fortran::evaluate::Expr<T> & expr)53   mlir::Type genDesignator(const Fortran::evaluate::Expr<T> &expr) {
54     using ExprVariant = decltype(Fortran::evaluate::Expr<T>::u);
55     using Designator = Fortran::evaluate::Designator<T>;
56     if constexpr (Fortran::common::HasMember<Designator, ExprVariant>) {
57       const auto &designator = std::get<Designator>(expr.u);
58       return std::visit([&](const auto &x) { return gen(x); }, designator.u);
59     } else {
60       return std::visit([&](const auto &x) { return genDesignator(x); },
61                         expr.u);
62     }
63   }
64 
65   // The gen(X) methods visit X to lower its base and subscripts and return the
66   // type of X elements.
67 
gen(const Fortran::evaluate::DataRef & dataRef)68   mlir::Type gen(const Fortran::evaluate::DataRef &dataRef) {
69     return std::visit([&](const auto &ref) -> mlir::Type { return gen(ref); },
70                       dataRef.u);
71   }
72 
gen(const Fortran::evaluate::SymbolRef & symRef)73   mlir::Type gen(const Fortran::evaluate::SymbolRef &symRef) {
74     // Never visited because expr lowering is used to lowered the ranked
75     // ArrayRef.
76     fir::emitFatalError(
77         loc, "expected at least one ArrayRef with vector susbcripts");
78   }
79 
gen(const Fortran::evaluate::Substring & substring)80   mlir::Type gen(const Fortran::evaluate::Substring &substring) {
81     // StaticDataObject::Pointer bases are constants and cannot be
82     // subscripted, so the base must be a DataRef here.
83     mlir::Type baseElementType =
84         gen(std::get<Fortran::evaluate::DataRef>(substring.parent()));
85     fir::FirOpBuilder &builder = converter.getFirOpBuilder();
86     mlir::Type idxTy = builder.getIndexType();
87     mlir::Value lb = genScalarValue(substring.lower());
88     substringBounds.emplace_back(builder.createConvert(loc, idxTy, lb));
89     if (const auto &ubExpr = substring.upper()) {
90       mlir::Value ub = genScalarValue(*ubExpr);
91       substringBounds.emplace_back(builder.createConvert(loc, idxTy, ub));
92     }
93     return baseElementType;
94   }
95 
gen(const Fortran::evaluate::ComplexPart & complexPart)96   mlir::Type gen(const Fortran::evaluate::ComplexPart &complexPart) {
97     auto complexType = gen(complexPart.complex());
98     fir::FirOpBuilder &builder = converter.getFirOpBuilder();
99     mlir::Type i32Ty = builder.getI32Type(); // llvm's GEP requires i32
100     mlir::Value offset = builder.createIntegerConstant(
101         loc, i32Ty,
102         complexPart.part() == Fortran::evaluate::ComplexPart::Part::RE ? 0 : 1);
103     componentPath.emplace_back(offset);
104     return fir::factory::Complex{builder, loc}.getComplexPartType(complexType);
105   }
106 
gen(const Fortran::evaluate::Component & component)107   mlir::Type gen(const Fortran::evaluate::Component &component) {
108     auto recTy = gen(component.base()).cast<fir::RecordType>();
109     const Fortran::semantics::Symbol &componentSymbol =
110         component.GetLastSymbol();
111     // Parent components will not be found here, they are not part
112     // of the FIR type and cannot be used in the path yet.
113     if (componentSymbol.test(Fortran::semantics::Symbol::Flag::ParentComp))
114       TODO(loc, "reference to parent component");
115     mlir::Type fldTy = fir::FieldType::get(&converter.getMLIRContext());
116     llvm::StringRef componentName = toStringRef(componentSymbol.name());
117     // Parameters threading in field_index is not yet very clear. We only
118     // have the ones of the ranked array ref at hand, but it looks like
119     // the fir.field_index expects the one of the direct base.
120     if (recTy.getNumLenParams() != 0)
121       TODO(loc, "threading length parameters in field index op");
122     fir::FirOpBuilder &builder = converter.getFirOpBuilder();
123     componentPath.emplace_back(builder.create<fir::FieldIndexOp>(
124         loc, fldTy, componentName, recTy, /*typeParams*/ llvm::None));
125     return fir::unwrapSequenceType(recTy.getType(componentName));
126   }
127 
gen(const Fortran::evaluate::ArrayRef & arrayRef)128   mlir::Type gen(const Fortran::evaluate::ArrayRef &arrayRef) {
129     auto isTripletOrVector =
130         [](const Fortran::evaluate::Subscript &subscript) -> bool {
131       return std::visit(
132           Fortran::common::visitors{
133               [](const Fortran::evaluate::IndirectSubscriptIntegerExpr &expr) {
134                 return expr.value().Rank() != 0;
135               },
136               [&](const Fortran::evaluate::Triplet &) { return true; }},
137           subscript.u);
138     };
139     if (llvm::any_of(arrayRef.subscript(), isTripletOrVector))
140       return genRankedArrayRefSubscriptAndBase(arrayRef);
141 
142     // This is a scalar ArrayRef (only scalar indexes), collect the indexes and
143     // visit the base that must contain another arrayRef with the vector
144     // subscript.
145     mlir::Type elementType = gen(namedEntityToDataRef(arrayRef.base()));
146     for (const Fortran::evaluate::Subscript &subscript : arrayRef.subscript()) {
147       const auto &expr =
148           std::get<Fortran::evaluate::IndirectSubscriptIntegerExpr>(
149               subscript.u);
150       componentPath.emplace_back(genScalarValue(expr.value()));
151     }
152     return elementType;
153   }
154 
155   /// Lower the subscripts and base of the ArrayRef that is an array (there must
156   /// be one since there is a vector subscript, and there can only be one
157   /// according to C925).
genRankedArrayRefSubscriptAndBase(const Fortran::evaluate::ArrayRef & arrayRef)158   mlir::Type genRankedArrayRefSubscriptAndBase(
159       const Fortran::evaluate::ArrayRef &arrayRef) {
160     // Lower the save the base
161     Fortran::lower::SomeExpr baseExpr = namedEntityToExpr(arrayRef.base());
162     loweredBase = converter.genExprAddr(baseExpr, stmtCtx);
163     // Lower and save the subscripts
164     fir::FirOpBuilder &builder = converter.getFirOpBuilder();
165     mlir::Type idxTy = builder.getIndexType();
166     mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
167     for (const auto &subscript : llvm::enumerate(arrayRef.subscript())) {
168       std::visit(
169           Fortran::common::visitors{
170               [&](const Fortran::evaluate::IndirectSubscriptIntegerExpr &expr) {
171                 if (expr.value().Rank() == 0) {
172                   // Simple scalar subscript
173                   loweredSubscripts.emplace_back(genScalarValue(expr.value()));
174                 } else {
175                   // Vector subscript.
176                   // Remove conversion if any to avoid temp creation that may
177                   // have been added by the front-end to avoid the creation of a
178                   // temp array value.
179                   auto vector = converter.genExprAddr(
180                       ignoreEvConvert(expr.value()), stmtCtx);
181                   mlir::Value size =
182                       fir::factory::readExtent(builder, loc, vector, /*dim=*/0);
183                   size = builder.createConvert(loc, idxTy, size);
184                   loweredSubscripts.emplace_back(
185                       LoweredVectorSubscript{std::move(vector), size});
186                 }
187               },
188               [&](const Fortran::evaluate::Triplet &triplet) {
189                 mlir::Value lb, ub;
190                 if (const auto &lbExpr = triplet.lower())
191                   lb = genScalarValue(*lbExpr);
192                 else
193                   lb = fir::factory::readLowerBound(builder, loc, loweredBase,
194                                                     subscript.index(), one);
195                 if (const auto &ubExpr = triplet.upper())
196                   ub = genScalarValue(*ubExpr);
197                 else
198                   ub = fir::factory::readExtent(builder, loc, loweredBase,
199                                                 subscript.index());
200                 lb = builder.createConvert(loc, idxTy, lb);
201                 ub = builder.createConvert(loc, idxTy, ub);
202                 mlir::Value stride = genScalarValue(triplet.stride());
203                 stride = builder.createConvert(loc, idxTy, stride);
204                 loweredSubscripts.emplace_back(LoweredTriplet{lb, ub, stride});
205               },
206           },
207           subscript.value().u);
208     }
209     return fir::unwrapSequenceType(
210         fir::unwrapPassByRefType(fir::getBase(loweredBase).getType()));
211   }
212 
gen(const Fortran::evaluate::CoarrayRef &)213   mlir::Type gen(const Fortran::evaluate::CoarrayRef &) {
214     // Is this possible/legal ?
215     TODO(loc, "coarray ref with vector subscript in IO input");
216   }
217 
218   template <typename A>
genScalarValue(const A & expr)219   mlir::Value genScalarValue(const A &expr) {
220     return fir::getBase(converter.genExprValue(toEvExpr(expr), stmtCtx));
221   }
222 
223   Fortran::evaluate::DataRef
namedEntityToDataRef(const Fortran::evaluate::NamedEntity & namedEntity)224   namedEntityToDataRef(const Fortran::evaluate::NamedEntity &namedEntity) {
225     if (namedEntity.IsSymbol())
226       return Fortran::evaluate::DataRef{namedEntity.GetFirstSymbol()};
227     return Fortran::evaluate::DataRef{namedEntity.GetComponent()};
228   }
229 
230   Fortran::lower::SomeExpr
namedEntityToExpr(const Fortran::evaluate::NamedEntity & namedEntity)231   namedEntityToExpr(const Fortran::evaluate::NamedEntity &namedEntity) {
232     return Fortran::evaluate::AsGenericExpr(namedEntityToDataRef(namedEntity))
233         .value();
234   }
235 
236   Fortran::lower::AbstractConverter &converter;
237   Fortran::lower::StatementContext &stmtCtx;
238   mlir::Location loc;
239   /// Elements of VectorSubscriptBox being built.
240   fir::ExtendedValue loweredBase;
241   llvm::SmallVector<LoweredSubscript, 16> loweredSubscripts;
242   llvm::SmallVector<mlir::Value> componentPath;
243   MaybeSubstring substringBounds;
244   mlir::Type elementType;
245 };
246 } // namespace
247 
genVectorSubscriptBox(mlir::Location loc,Fortran::lower::AbstractConverter & converter,Fortran::lower::StatementContext & stmtCtx,const Fortran::lower::SomeExpr & expr)248 Fortran::lower::VectorSubscriptBox Fortran::lower::genVectorSubscriptBox(
249     mlir::Location loc, Fortran::lower::AbstractConverter &converter,
250     Fortran::lower::StatementContext &stmtCtx,
251     const Fortran::lower::SomeExpr &expr) {
252   return VectorSubscriptBoxBuilder(loc, converter, stmtCtx).gen(expr);
253 }
254 
/// Generate a loop nest over every element designated by this box and call
/// \p elementalGenerator to generate the per-element code.
///
/// \tparam LoopType the loop operation to create; in this file it is
///         instantiated with fir::DoLoopOp and fir::IterWhileOp.
/// \param builder builder used for all created operations. On return, its
///        insertion point is right after the outermost loop.
/// \param elementalGenerator invoked once, at generation time, with the
///        insertion point at the start of the innermost loop body and the
///        element addressed by the current induction variables.
/// \param initialCondition for fir::IterWhileOp, the initial iterate-while
///        condition of the outermost loop; not read for fir::DoLoopOp.
/// \return for fir::IterWhileOp, the first result of the outermost loop;
///         a null value for fir::DoLoopOp.
template <typename LoopType, typename Generator>
mlir::Value Fortran::lower::VectorSubscriptBox::loopOverElementsBase(
    fir::FirOpBuilder &builder, mlir::Location loc,
    const Generator &elementalGenerator,
    [[maybe_unused]] mlir::Value initialCondition) {
  // Shape and slice are created once, before the loop nest, and are reused
  // by getElementAt inside the innermost body.
  mlir::Value shape = builder.createShape(loc, loweredBase);
  mlir::Value slice = createSlice(builder, loc);

  // Create loop nest for triplets and vector subscripts in column
  // major order.
  // genLoopBounds returns one bound tuple per non-scalar subscript, the first
  // tuple being for the outermost loop.
  llvm::SmallVector<mlir::Value> inductionVariables;
  LoopType outerLoop;
  for (auto [lb, ub, step] : genLoopBounds(builder, loc)) {
    LoopType loop;
    if constexpr (std::is_same_v<LoopType, fir::IterWhileOp>) {
      // Each inner loop starts from the iterate variable of its parent loop,
      // so the condition is threaded through the whole nest.
      loop =
          builder.create<fir::IterWhileOp>(loc, lb, ub, step, initialCondition);
      initialCondition = loop.getIterateVar();
      // For nested loops the insertion point is still in the parent loop
      // body here, so the fir.result yields this loop's first result from
      // the parent body.
      if (!outerLoop)
        outerLoop = loop;
      else
        builder.create<fir::ResultOp>(loc, loop.getResult(0));
    } else {
      loop =
          builder.create<fir::DoLoopOp>(loc, lb, ub, step, /*unordered=*/false);
      if (!outerLoop)
        outerLoop = loop;
    }
    // Descend into the new loop: the next loop (or the element code) goes at
    // the start of its body.
    builder.setInsertionPointToStart(loop.getBody());
    inductionVariables.push_back(loop.getInductionVar());
  }
  assert(outerLoop && !inductionVariables.empty() &&
         "at least one loop should be created");

  // inductionVariables is ordered outermost loop first.
  fir::ExtendedValue elem =
      getElementAt(builder, loc, shape, slice, inductionVariables);

  if constexpr (std::is_same_v<LoopType, fir::IterWhileOp>) {
    // The generator's value is yielded by the innermost loop body.
    auto res = elementalGenerator(elem);
    builder.create<fir::ResultOp>(loc, res);
    builder.setInsertionPointAfter(outerLoop);
    return outerLoop.getResult(0);
  } else {
    elementalGenerator(elem);
    builder.setInsertionPointAfter(outerLoop);
    return {};
  }
}
303 
loopOverElements(fir::FirOpBuilder & builder,mlir::Location loc,const ElementalGenerator & elementalGenerator)304 void Fortran::lower::VectorSubscriptBox::loopOverElements(
305     fir::FirOpBuilder &builder, mlir::Location loc,
306     const ElementalGenerator &elementalGenerator) {
307   mlir::Value initialCondition;
308   loopOverElementsBase<fir::DoLoopOp, ElementalGenerator>(
309       builder, loc, elementalGenerator, initialCondition);
310 }
311 
loopOverElementsWhile(fir::FirOpBuilder & builder,mlir::Location loc,const ElementalGeneratorWithBoolReturn & elementalGenerator,mlir::Value initialCondition)312 mlir::Value Fortran::lower::VectorSubscriptBox::loopOverElementsWhile(
313     fir::FirOpBuilder &builder, mlir::Location loc,
314     const ElementalGeneratorWithBoolReturn &elementalGenerator,
315     mlir::Value initialCondition) {
316   return loopOverElementsBase<fir::IterWhileOp,
317                               ElementalGeneratorWithBoolReturn>(
318       builder, loc, elementalGenerator, initialCondition);
319 }
320 
321 mlir::Value
createSlice(fir::FirOpBuilder & builder,mlir::Location loc)322 Fortran::lower::VectorSubscriptBox::createSlice(fir::FirOpBuilder &builder,
323                                                 mlir::Location loc) {
324   mlir::Type idxTy = builder.getIndexType();
325   llvm::SmallVector<mlir::Value> triples;
326   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
327   auto undef = builder.create<fir::UndefOp>(loc, idxTy);
328   for (const LoweredSubscript &subscript : loweredSubscripts)
329     std::visit(Fortran::common::visitors{
330                    [&](const LoweredTriplet &triplet) {
331                      triples.emplace_back(triplet.lb);
332                      triples.emplace_back(triplet.ub);
333                      triples.emplace_back(triplet.stride);
334                    },
335                    [&](const LoweredVectorSubscript &vector) {
336                      triples.emplace_back(one);
337                      triples.emplace_back(vector.size);
338                      triples.emplace_back(one);
339                    },
340                    [&](const mlir::Value &i) {
341                      triples.emplace_back(i);
342                      triples.emplace_back(undef);
343                      triples.emplace_back(undef);
344                    },
345                },
346                subscript);
347   return builder.create<fir::SliceOp>(loc, triples, componentPath);
348 }
349 
350 llvm::SmallVector<std::tuple<mlir::Value, mlir::Value, mlir::Value>>
genLoopBounds(fir::FirOpBuilder & builder,mlir::Location loc)351 Fortran::lower::VectorSubscriptBox::genLoopBounds(fir::FirOpBuilder &builder,
352                                                   mlir::Location loc) {
353   mlir::Type idxTy = builder.getIndexType();
354   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
355   mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
356   llvm::SmallVector<std::tuple<mlir::Value, mlir::Value, mlir::Value>> bounds;
357   size_t dimension = loweredSubscripts.size();
358   for (const LoweredSubscript &subscript : llvm::reverse(loweredSubscripts)) {
359     --dimension;
360     if (std::holds_alternative<mlir::Value>(subscript))
361       continue;
362     mlir::Value lb, ub, step;
363     if (const auto *triplet = std::get_if<LoweredTriplet>(&subscript)) {
364       mlir::Value extent = builder.genExtentFromTriplet(
365           loc, triplet->lb, triplet->ub, triplet->stride, idxTy);
366       mlir::Value baseLb = fir::factory::readLowerBound(
367           builder, loc, loweredBase, dimension, one);
368       baseLb = builder.createConvert(loc, idxTy, baseLb);
369       lb = baseLb;
370       ub = builder.create<mlir::arith::SubIOp>(loc, idxTy, extent, one);
371       ub = builder.create<mlir::arith::AddIOp>(loc, idxTy, ub, baseLb);
372       step = one;
373     } else {
374       const auto &vector = std::get<LoweredVectorSubscript>(subscript);
375       lb = zero;
376       ub = builder.create<mlir::arith::SubIOp>(loc, idxTy, vector.size, one);
377       step = one;
378     }
379     bounds.emplace_back(lb, ub, step);
380   }
381   return bounds;
382 }
383 
getElementAt(fir::FirOpBuilder & builder,mlir::Location loc,mlir::Value shape,mlir::Value slice,mlir::ValueRange inductionVariables)384 fir::ExtendedValue Fortran::lower::VectorSubscriptBox::getElementAt(
385     fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value shape,
386     mlir::Value slice, mlir::ValueRange inductionVariables) {
387   /// Generate the indexes for the array_coor inside the loops.
388   mlir::Type idxTy = builder.getIndexType();
389   llvm::SmallVector<mlir::Value> indexes;
390   size_t inductionIdx = inductionVariables.size() - 1;
391   for (const LoweredSubscript &subscript : loweredSubscripts)
392     std::visit(Fortran::common::visitors{
393                    [&](const LoweredTriplet &triplet) {
394                      indexes.emplace_back(inductionVariables[inductionIdx--]);
395                    },
396                    [&](const LoweredVectorSubscript &vector) {
397                      mlir::Value vecIndex = inductionVariables[inductionIdx--];
398                      mlir::Value vecBase = fir::getBase(vector.vector);
399                      mlir::Type vecEleTy = fir::unwrapSequenceType(
400                          fir::unwrapPassByRefType(vecBase.getType()));
401                      mlir::Type refTy = builder.getRefType(vecEleTy);
402                      auto vecEltRef = builder.create<fir::CoordinateOp>(
403                          loc, refTy, vecBase, vecIndex);
404                      auto vecElt =
405                          builder.create<fir::LoadOp>(loc, vecEleTy, vecEltRef);
406                      indexes.emplace_back(
407                          builder.createConvert(loc, idxTy, vecElt));
408                    },
409                    [&](const mlir::Value &i) {
410                      indexes.emplace_back(builder.createConvert(loc, idxTy, i));
411                    },
412                },
413                subscript);
414   mlir::Type refTy = builder.getRefType(getElementType());
415   auto elementAddr = builder.create<fir::ArrayCoorOp>(
416       loc, refTy, fir::getBase(loweredBase), shape, slice, indexes,
417       fir::getTypeParams(loweredBase));
418   fir::ExtendedValue element = fir::factory::arraySectionElementToExtendedValue(
419       builder, loc, loweredBase, elementAddr, slice);
420   if (!substringBounds.empty()) {
421     const fir::CharBoxValue *charBox = element.getCharBox();
422     assert(charBox && "substring requires CharBox base");
423     fir::factory::CharacterExprHelper helper{builder, loc};
424     return helper.createSubstring(*charBox, substringBounds);
425   }
426   return element;
427 }
428