1 //===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
10 #include "MemRefDescriptor.h"
11 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
14 
15 using namespace mlir;
16 
17 /// Create an LLVMTypeConverter using default LowerToLLVMOptions.
LLVMTypeConverter(MLIRContext * ctx,const DataLayoutAnalysis * analysis)18 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
19                                      const DataLayoutAnalysis *analysis)
20     : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
21 
22 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
LLVMTypeConverter(MLIRContext * ctx,const LowerToLLVMOptions & options,const DataLayoutAnalysis * analysis)23 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
24                                      const LowerToLLVMOptions &options,
25                                      const DataLayoutAnalysis *analysis)
26     : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
27       dataLayoutAnalysis(analysis) {
28   assert(llvmDialect && "LLVM IR dialect is not registered");
29 
30   // Register conversions for the builtin types.
31   addConversion([&](ComplexType type) { return convertComplexType(type); });
32   addConversion([&](FloatType type) { return convertFloatType(type); });
33   addConversion([&](FunctionType type) { return convertFunctionType(type); });
34   addConversion([&](IndexType type) { return convertIndexType(type); });
35   addConversion([&](IntegerType type) { return convertIntegerType(type); });
36   addConversion([&](MemRefType type) { return convertMemRefType(type); });
37   addConversion(
38       [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
39   addConversion([&](VectorType type) { return convertVectorType(type); });
40 
41   // LLVM-compatible types are legal, so add a pass-through conversion. Do this
42   // before the conversions below since conversions are attempted in reverse
43   // order and those should take priority.
44   addConversion([](Type type) {
45     return LLVM::isCompatibleType(type) ? llvm::Optional<Type>(type)
46                                         : llvm::None;
47   });
48 
49   // LLVM container types may (recursively) contain other types that must be
50   // converted even when the outer type is compatible.
51   addConversion([&](LLVM::LLVMPointerType type) -> llvm::Optional<Type> {
52     if (type.isOpaque())
53       return type;
54     if (auto pointee = convertType(type.getElementType()))
55       return LLVM::LLVMPointerType::get(pointee, type.getAddressSpace());
56     return llvm::None;
57   });
58   addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results,
59                     ArrayRef<Type> callStack) -> llvm::Optional<LogicalResult> {
60     // Fastpath for types that won't be converted by this callback anyway.
61     if (LLVM::isCompatibleType(type)) {
62       results.push_back(type);
63       return success();
64     }
65 
66     if (type.isIdentified()) {
67       auto convertedType = LLVM::LLVMStructType::getIdentified(
68           type.getContext(), ("_Converted_" + type.getName()).str());
69       unsigned counter = 1;
70       while (convertedType.isInitialized()) {
71         assert(counter != UINT_MAX &&
72                "about to overflow struct renaming counter in conversion");
73         convertedType = LLVM::LLVMStructType::getIdentified(
74             type.getContext(),
75             ("_Converted_" + std::to_string(counter) + type.getName()).str());
76       }
77       if (llvm::count(callStack, type) > 1) {
78         results.push_back(convertedType);
79         return success();
80       }
81 
82       SmallVector<Type> convertedElemTypes;
83       convertedElemTypes.reserve(type.getBody().size());
84       if (failed(convertTypes(type.getBody(), convertedElemTypes)))
85         return llvm::None;
86 
87       if (failed(convertedType.setBody(convertedElemTypes, type.isPacked())))
88         return failure();
89       results.push_back(convertedType);
90       return success();
91     }
92 
93     SmallVector<Type> convertedSubtypes;
94     convertedSubtypes.reserve(type.getBody().size());
95     if (failed(convertTypes(type.getBody(), convertedSubtypes)))
96       return llvm::None;
97 
98     results.push_back(LLVM::LLVMStructType::getLiteral(
99         type.getContext(), convertedSubtypes, type.isPacked()));
100     return success();
101   });
102   addConversion([&](LLVM::LLVMArrayType type) -> llvm::Optional<Type> {
103     if (auto element = convertType(type.getElementType()))
104       return LLVM::LLVMArrayType::get(element, type.getNumElements());
105     return llvm::None;
106   });
107   addConversion([&](LLVM::LLVMFunctionType type) -> llvm::Optional<Type> {
108     Type convertedResType = convertType(type.getReturnType());
109     if (!convertedResType)
110       return llvm::None;
111 
112     SmallVector<Type> convertedArgTypes;
113     convertedArgTypes.reserve(type.getNumParams());
114     if (failed(convertTypes(type.getParams(), convertedArgTypes)))
115       return llvm::None;
116 
117     return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
118                                        type.isVarArg());
119   });
120 
121   // Materialization for memrefs creates descriptor structs from individual
122   // values constituting them, when descriptors are used, i.e. more than one
123   // value represents a memref.
124   addArgumentMaterialization(
125       [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
126           Location loc) -> Optional<Value> {
127         if (inputs.size() == 1)
128           return llvm::None;
129         return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
130                                               inputs);
131       });
132   addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
133                                  ValueRange inputs,
134                                  Location loc) -> Optional<Value> {
135     // TODO: bare ptr conversion could be handled here but we would need a way
136     // to distinguish between FuncOp and other regions.
137     if (inputs.size() == 1)
138       return llvm::None;
139     return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
140   });
141   // Add generic source and target materializations to handle cases where
142   // non-LLVM types persist after an LLVM conversion.
143   addSourceMaterialization([&](OpBuilder &builder, Type resultType,
144                                ValueRange inputs,
145                                Location loc) -> Optional<Value> {
146     if (inputs.size() != 1)
147       return llvm::None;
148 
149     return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
150         .getResult(0);
151   });
152   addTargetMaterialization([&](OpBuilder &builder, Type resultType,
153                                ValueRange inputs,
154                                Location loc) -> Optional<Value> {
155     if (inputs.size() != 1)
156       return llvm::None;
157 
158     return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
159         .getResult(0);
160   });
161 }
162 
163 /// Returns the MLIR context.
getContext()164 MLIRContext &LLVMTypeConverter::getContext() {
165   return *getDialect()->getContext();
166 }
167 
getIndexType()168 Type LLVMTypeConverter::getIndexType() {
169   return IntegerType::get(&getContext(), getIndexTypeBitwidth());
170 }
171 
getPointerBitwidth(unsigned addressSpace)172 unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
173   return options.dataLayout.getPointerSizeInBits(addressSpace);
174 }
175 
convertIndexType(IndexType type)176 Type LLVMTypeConverter::convertIndexType(IndexType type) {
177   return getIndexType();
178 }
179 
convertIntegerType(IntegerType type)180 Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
181   return IntegerType::get(&getContext(), type.getWidth());
182 }
183 
convertFloatType(FloatType type)184 Type LLVMTypeConverter::convertFloatType(FloatType type) { return type; }
185 
186 // Convert a `ComplexType` to an LLVM type. The result is a complex number
187 // struct with entries for the
188 //   1. real part and for the
189 //   2. imaginary part.
convertComplexType(ComplexType type)190 Type LLVMTypeConverter::convertComplexType(ComplexType type) {
191   auto elementType = convertType(type.getElementType());
192   return LLVM::LLVMStructType::getLiteral(&getContext(),
193                                           {elementType, elementType});
194 }
195 
196 // Except for signatures, MLIR function types are converted into LLVM
197 // pointer-to-function types.
convertFunctionType(FunctionType type)198 Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
199   SignatureConversion conversion(type.getNumInputs());
200   Type converted =
201       convertFunctionSignature(type, /*isVariadic=*/false, conversion);
202   return LLVM::LLVMPointerType::get(converted);
203 }
204 
205 // Function types are converted to LLVM Function types by recursively converting
206 // argument and result types.  If MLIR Function has zero results, the LLVM
207 // Function has one VoidType result.  If MLIR Function has more than one result,
208 // they are into an LLVM StructType in their order of appearance.
convertFunctionSignature(FunctionType funcTy,bool isVariadic,LLVMTypeConverter::SignatureConversion & result)209 Type LLVMTypeConverter::convertFunctionSignature(
210     FunctionType funcTy, bool isVariadic,
211     LLVMTypeConverter::SignatureConversion &result) {
212   // Select the argument converter depending on the calling convention.
213   auto funcArgConverter = options.useBarePtrCallConv
214                               ? barePtrFuncArgTypeConverter
215                               : structFuncArgTypeConverter;
216   // Convert argument types one by one and check for errors.
217   for (auto &en : llvm::enumerate(funcTy.getInputs())) {
218     Type type = en.value();
219     SmallVector<Type, 8> converted;
220     if (failed(funcArgConverter(*this, type, converted)))
221       return {};
222     result.addInputs(en.index(), converted);
223   }
224 
225   // If function does not return anything, create the void result type,
226   // if it returns on element, convert it, otherwise pack the result types into
227   // a struct.
228   Type resultType = funcTy.getNumResults() == 0
229                         ? LLVM::LLVMVoidType::get(&getContext())
230                         : packFunctionResults(funcTy.getResults());
231   if (!resultType)
232     return {};
233   return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(),
234                                      isVariadic);
235 }
236 
237 /// Converts the function type to a C-compatible format, in particular using
238 /// pointers to memref descriptors for arguments.
239 std::pair<Type, bool>
convertFunctionTypeCWrapper(FunctionType type)240 LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
241   SmallVector<Type, 4> inputs;
242   bool resultIsNowArg = false;
243 
244   Type resultType = type.getNumResults() == 0
245                         ? LLVM::LLVMVoidType::get(&getContext())
246                         : packFunctionResults(type.getResults());
247   if (!resultType)
248     return {};
249 
250   if (auto structType = resultType.dyn_cast<LLVM::LLVMStructType>()) {
251     // Struct types cannot be safely returned via C interface. Make this a
252     // pointer argument, instead.
253     inputs.push_back(LLVM::LLVMPointerType::get(structType));
254     resultType = LLVM::LLVMVoidType::get(&getContext());
255     resultIsNowArg = true;
256   }
257 
258   for (Type t : type.getInputs()) {
259     auto converted = convertType(t);
260     if (!converted || !LLVM::isCompatibleType(converted))
261       return {};
262     if (t.isa<MemRefType, UnrankedMemRefType>())
263       converted = LLVM::LLVMPointerType::get(converted);
264     inputs.push_back(converted);
265   }
266 
267   return {LLVM::LLVMFunctionType::get(resultType, inputs), resultIsNowArg};
268 }
269 
270 /// Convert a memref type into a list of LLVM IR types that will form the
271 /// memref descriptor. The result contains the following types:
272 ///  1. The pointer to the allocated data buffer, followed by
273 ///  2. The pointer to the aligned data buffer, followed by
274 ///  3. A lowered `index`-type integer containing the distance between the
275 ///  beginning of the buffer and the first element to be accessed through the
276 ///  view, followed by
277 ///  4. An array containing as many `index`-type integers as the rank of the
278 ///  MemRef: the array represents the size, in number of elements, of the memref
279 ///  along the given dimension. For constant MemRef dimensions, the
280 ///  corresponding size entry is a constant whose runtime value must match the
281 ///  static value, followed by
282 ///  5. A second array containing as many `index`-type integers as the rank of
283 ///  the MemRef: the second array represents the "stride" (in tensor abstraction
284 ///  sense), i.e. the number of consecutive elements of the underlying buffer.
285 ///  TODO: add assertions for the static cases.
286 ///
287 ///  If `unpackAggregates` is set to true, the arrays described in (4) and (5)
288 ///  are expanded into individual index-type elements.
289 ///
290 ///  template <typename Elem, typename Index, size_t Rank>
291 ///  struct {
292 ///    Elem *allocatedPtr;
293 ///    Elem *alignedPtr;
294 ///    Index offset;
295 ///    Index sizes[Rank]; // omitted when rank == 0
296 ///    Index strides[Rank]; // omitted when rank == 0
297 ///  };
298 SmallVector<Type, 5>
getMemRefDescriptorFields(MemRefType type,bool unpackAggregates)299 LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
300                                              bool unpackAggregates) {
301   assert(isStrided(type) &&
302          "Non-strided layout maps must have been normalized away");
303 
304   Type elementType = convertType(type.getElementType());
305   if (!elementType)
306     return {};
307   auto ptrTy =
308       LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt());
309   auto indexTy = getIndexType();
310 
311   SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
312   auto rank = type.getRank();
313   if (rank == 0)
314     return results;
315 
316   if (unpackAggregates)
317     results.insert(results.end(), 2 * rank, indexTy);
318   else
319     results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
320   return results;
321 }
322 
getMemRefDescriptorSize(MemRefType type,const DataLayout & layout)323 unsigned LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
324                                                     const DataLayout &layout) {
325   // Compute the descriptor size given that of its components indicated above.
326   unsigned space = type.getMemorySpaceAsInt();
327   return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
328          (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
329 }
330 
331 /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
332 /// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
convertMemRefType(MemRefType type)333 Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
334   // When converting a MemRefType to a struct with descriptor fields, do not
335   // unpack the `sizes` and `strides` arrays.
336   SmallVector<Type, 5> types =
337       getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
338   if (types.empty())
339     return {};
340   return LLVM::LLVMStructType::getLiteral(&getContext(), types);
341 }
342 
343 /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
344 /// that will form the unranked memref descriptor. In particular, the fields
345 /// for an unranked memref descriptor are:
346 /// 1. index-typed rank, the dynamic rank of this MemRef
347 /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
348 ///    stack allocated (alloca) copy of a MemRef descriptor that got casted to
349 ///    be unranked.
getUnrankedMemRefDescriptorFields()350 SmallVector<Type, 2> LLVMTypeConverter::getUnrankedMemRefDescriptorFields() {
351   return {getIndexType(),
352           LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8))};
353 }
354 
355 unsigned
getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,const DataLayout & layout)356 LLVMTypeConverter::getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
357                                                    const DataLayout &layout) {
358   // Compute the descriptor size given that of its components indicated above.
359   unsigned space = type.getMemorySpaceAsInt();
360   return layout.getTypeSize(getIndexType()) +
361          llvm::divideCeil(getPointerBitwidth(space), 8);
362 }
363 
convertUnrankedMemRefType(UnrankedMemRefType type)364 Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
365   if (!convertType(type.getElementType()))
366     return {};
367   return LLVM::LLVMStructType::getLiteral(&getContext(),
368                                           getUnrankedMemRefDescriptorFields());
369 }
370 
371 // Check if a memref type can be converted to a bare pointer.
canConvertToBarePtr(BaseMemRefType type)372 bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
373   if (type.isa<UnrankedMemRefType>())
374     // Unranked memref is not supported in the bare pointer calling convention.
375     return false;
376 
377   // Check that the memref has static shape, strides and offset. Otherwise, it
378   // cannot be lowered to a bare pointer.
379   auto memrefTy = type.cast<MemRefType>();
380   if (!memrefTy.hasStaticShape())
381     return false;
382 
383   int64_t offset = 0;
384   SmallVector<int64_t, 4> strides;
385   if (failed(getStridesAndOffset(memrefTy, strides, offset)))
386     return false;
387 
388   for (int64_t stride : strides)
389     if (ShapedType::isDynamicStrideOrOffset(stride))
390       return false;
391 
392   return !ShapedType::isDynamicStrideOrOffset(offset);
393 }
394 
395 /// Convert a memref type to a bare pointer to the memref element type.
convertMemRefToBarePtr(BaseMemRefType type)396 Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
397   if (!canConvertToBarePtr(type))
398     return {};
399   Type elementType = convertType(type.getElementType());
400   if (!elementType)
401     return {};
402   return LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt());
403 }
404 
405 /// Convert an n-D vector type to an LLVM vector type:
406 ///  * 0-D `vector<T>` are converted to vector<1xT>
407 ///  * 1-D `vector<axT>` remains as is while,
408 ///  * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
409 ///    `!llvm.array<ax...array<jxvector<kxT>>>`.
convertVectorType(VectorType type)410 Type LLVMTypeConverter::convertVectorType(VectorType type) {
411   auto elementType = convertType(type.getElementType());
412   if (!elementType)
413     return {};
414   if (type.getShape().empty())
415     return VectorType::get({1}, elementType);
416   Type vectorType = VectorType::get(type.getShape().back(), elementType,
417                                     type.getNumScalableDims());
418   assert(LLVM::isCompatibleVectorType(vectorType) &&
419          "expected vector type compatible with the LLVM dialect");
420   auto shape = type.getShape();
421   for (int i = shape.size() - 2; i >= 0; --i)
422     vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
423   return vectorType;
424 }
425 
426 /// Convert a type in the context of the default or bare pointer calling
427 /// convention. Calling convention sensitive types, such as MemRefType and
428 /// UnrankedMemRefType, are converted following the specific rules for the
429 /// calling convention. Calling convention independent types are converted
430 /// following the default LLVM type conversions.
convertCallingConventionType(Type type)431 Type LLVMTypeConverter::convertCallingConventionType(Type type) {
432   if (options.useBarePtrCallConv)
433     if (auto memrefTy = type.dyn_cast<BaseMemRefType>())
434       return convertMemRefToBarePtr(memrefTy);
435 
436   return convertType(type);
437 }
438 
439 /// Promote the bare pointers in 'values' that resulted from memrefs to
440 /// descriptors. 'stdTypes' holds they types of 'values' before the conversion
441 /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
promoteBarePtrsToDescriptors(ConversionPatternRewriter & rewriter,Location loc,ArrayRef<Type> stdTypes,SmallVectorImpl<Value> & values)442 void LLVMTypeConverter::promoteBarePtrsToDescriptors(
443     ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
444     SmallVectorImpl<Value> &values) {
445   assert(stdTypes.size() == values.size() &&
446          "The number of types and values doesn't match");
447   for (unsigned i = 0, end = values.size(); i < end; ++i)
448     if (auto memrefTy = stdTypes[i].dyn_cast<MemRefType>())
449       values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
450                                                     memrefTy, values[i]);
451 }
452 
453 /// Convert a non-empty list of types to be returned from a function into a
454 /// supported LLVM IR type.  In particular, if more than one value is returned,
455 /// create an LLVM IR structure type with elements that correspond to each of
456 /// the MLIR types converted with `convertType`.
packFunctionResults(TypeRange types)457 Type LLVMTypeConverter::packFunctionResults(TypeRange types) {
458   assert(!types.empty() && "expected non-empty list of type");
459 
460   if (types.size() == 1)
461     return convertCallingConventionType(types.front());
462 
463   SmallVector<Type, 8> resultTypes;
464   resultTypes.reserve(types.size());
465   for (auto t : types) {
466     auto converted = convertCallingConventionType(t);
467     if (!converted || !LLVM::isCompatibleType(converted))
468       return {};
469     resultTypes.push_back(converted);
470   }
471 
472   return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
473 }
474 
promoteOneMemRefDescriptor(Location loc,Value operand,OpBuilder & builder)475 Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
476                                                     OpBuilder &builder) {
477   auto *context = builder.getContext();
478   auto int64Ty = IntegerType::get(builder.getContext(), 64);
479   auto indexType = IndexType::get(context);
480   // Alloca with proper alignment. We do not expect optimizations of this
481   // alloca op and so we omit allocating at the entry block.
482   auto ptrType = LLVM::LLVMPointerType::get(operand.getType());
483   Value one = builder.create<LLVM::ConstantOp>(loc, int64Ty,
484                                                IntegerAttr::get(indexType, 1));
485   Value allocated =
486       builder.create<LLVM::AllocaOp>(loc, ptrType, one, /*alignment=*/0);
487   // Store into the alloca'ed descriptor.
488   builder.create<LLVM::StoreOp>(loc, operand, allocated);
489   return allocated;
490 }
491 
promoteOperands(Location loc,ValueRange opOperands,ValueRange operands,OpBuilder & builder)492 SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(Location loc,
493                                                          ValueRange opOperands,
494                                                          ValueRange operands,
495                                                          OpBuilder &builder) {
496   SmallVector<Value, 4> promotedOperands;
497   promotedOperands.reserve(operands.size());
498   for (auto it : llvm::zip(opOperands, operands)) {
499     auto operand = std::get<0>(it);
500     auto llvmOperand = std::get<1>(it);
501 
502     if (options.useBarePtrCallConv) {
503       // For the bare-ptr calling convention, we only have to extract the
504       // aligned pointer of a memref.
505       if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
506         MemRefDescriptor desc(llvmOperand);
507         llvmOperand = desc.alignedPtr(builder, loc);
508       } else if (operand.getType().isa<UnrankedMemRefType>()) {
509         llvm_unreachable("Unranked memrefs are not supported");
510       }
511     } else {
512       if (operand.getType().isa<UnrankedMemRefType>()) {
513         UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
514                                          promotedOperands);
515         continue;
516       }
517       if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
518         MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
519                                  promotedOperands);
520         continue;
521       }
522     }
523 
524     promotedOperands.push_back(llvmOperand);
525   }
526   return promotedOperands;
527 }
528 
529 /// Callback to convert function argument types. It converts a MemRef function
530 /// argument to a list of non-aggregate types containing descriptor
531 /// information, and an UnrankedmemRef function argument to a list containing
532 /// the rank and a pointer to a descriptor struct.
structFuncArgTypeConverter(LLVMTypeConverter & converter,Type type,SmallVectorImpl<Type> & result)533 LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
534                                                Type type,
535                                                SmallVectorImpl<Type> &result) {
536   if (auto memref = type.dyn_cast<MemRefType>()) {
537     // In signatures, Memref descriptors are expanded into lists of
538     // non-aggregate values.
539     auto converted =
540         converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
541     if (converted.empty())
542       return failure();
543     result.append(converted.begin(), converted.end());
544     return success();
545   }
546   if (type.isa<UnrankedMemRefType>()) {
547     auto converted = converter.getUnrankedMemRefDescriptorFields();
548     if (converted.empty())
549       return failure();
550     result.append(converted.begin(), converted.end());
551     return success();
552   }
553   auto converted = converter.convertType(type);
554   if (!converted)
555     return failure();
556   result.push_back(converted);
557   return success();
558 }
559 
560 /// Callback to convert function argument types. It converts MemRef function
561 /// arguments to bare pointers to the MemRef element type.
barePtrFuncArgTypeConverter(LLVMTypeConverter & converter,Type type,SmallVectorImpl<Type> & result)562 LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
563                                                 Type type,
564                                                 SmallVectorImpl<Type> &result) {
565   auto llvmTy = converter.convertCallingConventionType(type);
566   if (!llvmTy)
567     return failure();
568 
569   result.push_back(llvmTy);
570   return success();
571 }
572