1 //===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
10 #include "MemRefDescriptor.h"
11 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
14
15 using namespace mlir;
16
17 /// Create an LLVMTypeConverter using default LowerToLLVMOptions.
LLVMTypeConverter(MLIRContext * ctx,const DataLayoutAnalysis * analysis)18 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
19 const DataLayoutAnalysis *analysis)
20 : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
21
22 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
LLVMTypeConverter(MLIRContext * ctx,const LowerToLLVMOptions & options,const DataLayoutAnalysis * analysis)23 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
24 const LowerToLLVMOptions &options,
25 const DataLayoutAnalysis *analysis)
26 : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
27 dataLayoutAnalysis(analysis) {
28 assert(llvmDialect && "LLVM IR dialect is not registered");
29
30 // Register conversions for the builtin types.
31 addConversion([&](ComplexType type) { return convertComplexType(type); });
32 addConversion([&](FloatType type) { return convertFloatType(type); });
33 addConversion([&](FunctionType type) { return convertFunctionType(type); });
34 addConversion([&](IndexType type) { return convertIndexType(type); });
35 addConversion([&](IntegerType type) { return convertIntegerType(type); });
36 addConversion([&](MemRefType type) { return convertMemRefType(type); });
37 addConversion(
38 [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
39 addConversion([&](VectorType type) { return convertVectorType(type); });
40
41 // LLVM-compatible types are legal, so add a pass-through conversion. Do this
42 // before the conversions below since conversions are attempted in reverse
43 // order and those should take priority.
44 addConversion([](Type type) {
45 return LLVM::isCompatibleType(type) ? llvm::Optional<Type>(type)
46 : llvm::None;
47 });
48
49 // LLVM container types may (recursively) contain other types that must be
50 // converted even when the outer type is compatible.
51 addConversion([&](LLVM::LLVMPointerType type) -> llvm::Optional<Type> {
52 if (type.isOpaque())
53 return type;
54 if (auto pointee = convertType(type.getElementType()))
55 return LLVM::LLVMPointerType::get(pointee, type.getAddressSpace());
56 return llvm::None;
57 });
58 addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results,
59 ArrayRef<Type> callStack) -> llvm::Optional<LogicalResult> {
60 // Fastpath for types that won't be converted by this callback anyway.
61 if (LLVM::isCompatibleType(type)) {
62 results.push_back(type);
63 return success();
64 }
65
66 if (type.isIdentified()) {
67 auto convertedType = LLVM::LLVMStructType::getIdentified(
68 type.getContext(), ("_Converted_" + type.getName()).str());
69 unsigned counter = 1;
70 while (convertedType.isInitialized()) {
71 assert(counter != UINT_MAX &&
72 "about to overflow struct renaming counter in conversion");
73 convertedType = LLVM::LLVMStructType::getIdentified(
74 type.getContext(),
75 ("_Converted_" + std::to_string(counter) + type.getName()).str());
76 }
77 if (llvm::count(callStack, type) > 1) {
78 results.push_back(convertedType);
79 return success();
80 }
81
82 SmallVector<Type> convertedElemTypes;
83 convertedElemTypes.reserve(type.getBody().size());
84 if (failed(convertTypes(type.getBody(), convertedElemTypes)))
85 return llvm::None;
86
87 if (failed(convertedType.setBody(convertedElemTypes, type.isPacked())))
88 return failure();
89 results.push_back(convertedType);
90 return success();
91 }
92
93 SmallVector<Type> convertedSubtypes;
94 convertedSubtypes.reserve(type.getBody().size());
95 if (failed(convertTypes(type.getBody(), convertedSubtypes)))
96 return llvm::None;
97
98 results.push_back(LLVM::LLVMStructType::getLiteral(
99 type.getContext(), convertedSubtypes, type.isPacked()));
100 return success();
101 });
102 addConversion([&](LLVM::LLVMArrayType type) -> llvm::Optional<Type> {
103 if (auto element = convertType(type.getElementType()))
104 return LLVM::LLVMArrayType::get(element, type.getNumElements());
105 return llvm::None;
106 });
107 addConversion([&](LLVM::LLVMFunctionType type) -> llvm::Optional<Type> {
108 Type convertedResType = convertType(type.getReturnType());
109 if (!convertedResType)
110 return llvm::None;
111
112 SmallVector<Type> convertedArgTypes;
113 convertedArgTypes.reserve(type.getNumParams());
114 if (failed(convertTypes(type.getParams(), convertedArgTypes)))
115 return llvm::None;
116
117 return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
118 type.isVarArg());
119 });
120
121 // Materialization for memrefs creates descriptor structs from individual
122 // values constituting them, when descriptors are used, i.e. more than one
123 // value represents a memref.
124 addArgumentMaterialization(
125 [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
126 Location loc) -> Optional<Value> {
127 if (inputs.size() == 1)
128 return llvm::None;
129 return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
130 inputs);
131 });
132 addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
133 ValueRange inputs,
134 Location loc) -> Optional<Value> {
135 // TODO: bare ptr conversion could be handled here but we would need a way
136 // to distinguish between FuncOp and other regions.
137 if (inputs.size() == 1)
138 return llvm::None;
139 return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
140 });
141 // Add generic source and target materializations to handle cases where
142 // non-LLVM types persist after an LLVM conversion.
143 addSourceMaterialization([&](OpBuilder &builder, Type resultType,
144 ValueRange inputs,
145 Location loc) -> Optional<Value> {
146 if (inputs.size() != 1)
147 return llvm::None;
148
149 return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
150 .getResult(0);
151 });
152 addTargetMaterialization([&](OpBuilder &builder, Type resultType,
153 ValueRange inputs,
154 Location loc) -> Optional<Value> {
155 if (inputs.size() != 1)
156 return llvm::None;
157
158 return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
159 .getResult(0);
160 });
161 }
162
163 /// Returns the MLIR context.
getContext()164 MLIRContext &LLVMTypeConverter::getContext() {
165 return *getDialect()->getContext();
166 }
167
getIndexType()168 Type LLVMTypeConverter::getIndexType() {
169 return IntegerType::get(&getContext(), getIndexTypeBitwidth());
170 }
171
getPointerBitwidth(unsigned addressSpace)172 unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
173 return options.dataLayout.getPointerSizeInBits(addressSpace);
174 }
175
convertIndexType(IndexType type)176 Type LLVMTypeConverter::convertIndexType(IndexType type) {
177 return getIndexType();
178 }
179
convertIntegerType(IntegerType type)180 Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
181 return IntegerType::get(&getContext(), type.getWidth());
182 }
183
convertFloatType(FloatType type)184 Type LLVMTypeConverter::convertFloatType(FloatType type) { return type; }
185
186 // Convert a `ComplexType` to an LLVM type. The result is a complex number
187 // struct with entries for the
188 // 1. real part and for the
189 // 2. imaginary part.
convertComplexType(ComplexType type)190 Type LLVMTypeConverter::convertComplexType(ComplexType type) {
191 auto elementType = convertType(type.getElementType());
192 return LLVM::LLVMStructType::getLiteral(&getContext(),
193 {elementType, elementType});
194 }
195
196 // Except for signatures, MLIR function types are converted into LLVM
197 // pointer-to-function types.
convertFunctionType(FunctionType type)198 Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
199 SignatureConversion conversion(type.getNumInputs());
200 Type converted =
201 convertFunctionSignature(type, /*isVariadic=*/false, conversion);
202 return LLVM::LLVMPointerType::get(converted);
203 }
204
205 // Function types are converted to LLVM Function types by recursively converting
206 // argument and result types. If MLIR Function has zero results, the LLVM
207 // Function has one VoidType result. If MLIR Function has more than one result,
208 // they are into an LLVM StructType in their order of appearance.
convertFunctionSignature(FunctionType funcTy,bool isVariadic,LLVMTypeConverter::SignatureConversion & result)209 Type LLVMTypeConverter::convertFunctionSignature(
210 FunctionType funcTy, bool isVariadic,
211 LLVMTypeConverter::SignatureConversion &result) {
212 // Select the argument converter depending on the calling convention.
213 auto funcArgConverter = options.useBarePtrCallConv
214 ? barePtrFuncArgTypeConverter
215 : structFuncArgTypeConverter;
216 // Convert argument types one by one and check for errors.
217 for (auto &en : llvm::enumerate(funcTy.getInputs())) {
218 Type type = en.value();
219 SmallVector<Type, 8> converted;
220 if (failed(funcArgConverter(*this, type, converted)))
221 return {};
222 result.addInputs(en.index(), converted);
223 }
224
225 // If function does not return anything, create the void result type,
226 // if it returns on element, convert it, otherwise pack the result types into
227 // a struct.
228 Type resultType = funcTy.getNumResults() == 0
229 ? LLVM::LLVMVoidType::get(&getContext())
230 : packFunctionResults(funcTy.getResults());
231 if (!resultType)
232 return {};
233 return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(),
234 isVariadic);
235 }
236
237 /// Converts the function type to a C-compatible format, in particular using
238 /// pointers to memref descriptors for arguments.
239 std::pair<Type, bool>
convertFunctionTypeCWrapper(FunctionType type)240 LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
241 SmallVector<Type, 4> inputs;
242 bool resultIsNowArg = false;
243
244 Type resultType = type.getNumResults() == 0
245 ? LLVM::LLVMVoidType::get(&getContext())
246 : packFunctionResults(type.getResults());
247 if (!resultType)
248 return {};
249
250 if (auto structType = resultType.dyn_cast<LLVM::LLVMStructType>()) {
251 // Struct types cannot be safely returned via C interface. Make this a
252 // pointer argument, instead.
253 inputs.push_back(LLVM::LLVMPointerType::get(structType));
254 resultType = LLVM::LLVMVoidType::get(&getContext());
255 resultIsNowArg = true;
256 }
257
258 for (Type t : type.getInputs()) {
259 auto converted = convertType(t);
260 if (!converted || !LLVM::isCompatibleType(converted))
261 return {};
262 if (t.isa<MemRefType, UnrankedMemRefType>())
263 converted = LLVM::LLVMPointerType::get(converted);
264 inputs.push_back(converted);
265 }
266
267 return {LLVM::LLVMFunctionType::get(resultType, inputs), resultIsNowArg};
268 }
269
270 /// Convert a memref type into a list of LLVM IR types that will form the
271 /// memref descriptor. The result contains the following types:
272 /// 1. The pointer to the allocated data buffer, followed by
273 /// 2. The pointer to the aligned data buffer, followed by
274 /// 3. A lowered `index`-type integer containing the distance between the
275 /// beginning of the buffer and the first element to be accessed through the
276 /// view, followed by
277 /// 4. An array containing as many `index`-type integers as the rank of the
278 /// MemRef: the array represents the size, in number of elements, of the memref
279 /// along the given dimension. For constant MemRef dimensions, the
280 /// corresponding size entry is a constant whose runtime value must match the
281 /// static value, followed by
282 /// 5. A second array containing as many `index`-type integers as the rank of
283 /// the MemRef: the second array represents the "stride" (in tensor abstraction
284 /// sense), i.e. the number of consecutive elements of the underlying buffer.
285 /// TODO: add assertions for the static cases.
286 ///
287 /// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
288 /// are expanded into individual index-type elements.
289 ///
290 /// template <typename Elem, typename Index, size_t Rank>
291 /// struct {
292 /// Elem *allocatedPtr;
293 /// Elem *alignedPtr;
294 /// Index offset;
295 /// Index sizes[Rank]; // omitted when rank == 0
296 /// Index strides[Rank]; // omitted when rank == 0
297 /// };
298 SmallVector<Type, 5>
getMemRefDescriptorFields(MemRefType type,bool unpackAggregates)299 LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
300 bool unpackAggregates) {
301 assert(isStrided(type) &&
302 "Non-strided layout maps must have been normalized away");
303
304 Type elementType = convertType(type.getElementType());
305 if (!elementType)
306 return {};
307 auto ptrTy =
308 LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt());
309 auto indexTy = getIndexType();
310
311 SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
312 auto rank = type.getRank();
313 if (rank == 0)
314 return results;
315
316 if (unpackAggregates)
317 results.insert(results.end(), 2 * rank, indexTy);
318 else
319 results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
320 return results;
321 }
322
getMemRefDescriptorSize(MemRefType type,const DataLayout & layout)323 unsigned LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
324 const DataLayout &layout) {
325 // Compute the descriptor size given that of its components indicated above.
326 unsigned space = type.getMemorySpaceAsInt();
327 return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
328 (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
329 }
330
331 /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
332 /// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
convertMemRefType(MemRefType type)333 Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
334 // When converting a MemRefType to a struct with descriptor fields, do not
335 // unpack the `sizes` and `strides` arrays.
336 SmallVector<Type, 5> types =
337 getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
338 if (types.empty())
339 return {};
340 return LLVM::LLVMStructType::getLiteral(&getContext(), types);
341 }
342
343 /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
344 /// that will form the unranked memref descriptor. In particular, the fields
345 /// for an unranked memref descriptor are:
346 /// 1. index-typed rank, the dynamic rank of this MemRef
347 /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
348 /// stack allocated (alloca) copy of a MemRef descriptor that got casted to
349 /// be unranked.
getUnrankedMemRefDescriptorFields()350 SmallVector<Type, 2> LLVMTypeConverter::getUnrankedMemRefDescriptorFields() {
351 return {getIndexType(),
352 LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8))};
353 }
354
355 unsigned
getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,const DataLayout & layout)356 LLVMTypeConverter::getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
357 const DataLayout &layout) {
358 // Compute the descriptor size given that of its components indicated above.
359 unsigned space = type.getMemorySpaceAsInt();
360 return layout.getTypeSize(getIndexType()) +
361 llvm::divideCeil(getPointerBitwidth(space), 8);
362 }
363
convertUnrankedMemRefType(UnrankedMemRefType type)364 Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
365 if (!convertType(type.getElementType()))
366 return {};
367 return LLVM::LLVMStructType::getLiteral(&getContext(),
368 getUnrankedMemRefDescriptorFields());
369 }
370
371 // Check if a memref type can be converted to a bare pointer.
canConvertToBarePtr(BaseMemRefType type)372 bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
373 if (type.isa<UnrankedMemRefType>())
374 // Unranked memref is not supported in the bare pointer calling convention.
375 return false;
376
377 // Check that the memref has static shape, strides and offset. Otherwise, it
378 // cannot be lowered to a bare pointer.
379 auto memrefTy = type.cast<MemRefType>();
380 if (!memrefTy.hasStaticShape())
381 return false;
382
383 int64_t offset = 0;
384 SmallVector<int64_t, 4> strides;
385 if (failed(getStridesAndOffset(memrefTy, strides, offset)))
386 return false;
387
388 for (int64_t stride : strides)
389 if (ShapedType::isDynamicStrideOrOffset(stride))
390 return false;
391
392 return !ShapedType::isDynamicStrideOrOffset(offset);
393 }
394
395 /// Convert a memref type to a bare pointer to the memref element type.
convertMemRefToBarePtr(BaseMemRefType type)396 Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
397 if (!canConvertToBarePtr(type))
398 return {};
399 Type elementType = convertType(type.getElementType());
400 if (!elementType)
401 return {};
402 return LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt());
403 }
404
405 /// Convert an n-D vector type to an LLVM vector type:
406 /// * 0-D `vector<T>` are converted to vector<1xT>
407 /// * 1-D `vector<axT>` remains as is while,
408 /// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
409 /// `!llvm.array<ax...array<jxvector<kxT>>>`.
convertVectorType(VectorType type)410 Type LLVMTypeConverter::convertVectorType(VectorType type) {
411 auto elementType = convertType(type.getElementType());
412 if (!elementType)
413 return {};
414 if (type.getShape().empty())
415 return VectorType::get({1}, elementType);
416 Type vectorType = VectorType::get(type.getShape().back(), elementType,
417 type.getNumScalableDims());
418 assert(LLVM::isCompatibleVectorType(vectorType) &&
419 "expected vector type compatible with the LLVM dialect");
420 auto shape = type.getShape();
421 for (int i = shape.size() - 2; i >= 0; --i)
422 vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
423 return vectorType;
424 }
425
426 /// Convert a type in the context of the default or bare pointer calling
427 /// convention. Calling convention sensitive types, such as MemRefType and
428 /// UnrankedMemRefType, are converted following the specific rules for the
429 /// calling convention. Calling convention independent types are converted
430 /// following the default LLVM type conversions.
convertCallingConventionType(Type type)431 Type LLVMTypeConverter::convertCallingConventionType(Type type) {
432 if (options.useBarePtrCallConv)
433 if (auto memrefTy = type.dyn_cast<BaseMemRefType>())
434 return convertMemRefToBarePtr(memrefTy);
435
436 return convertType(type);
437 }
438
439 /// Promote the bare pointers in 'values' that resulted from memrefs to
440 /// descriptors. 'stdTypes' holds they types of 'values' before the conversion
441 /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
promoteBarePtrsToDescriptors(ConversionPatternRewriter & rewriter,Location loc,ArrayRef<Type> stdTypes,SmallVectorImpl<Value> & values)442 void LLVMTypeConverter::promoteBarePtrsToDescriptors(
443 ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
444 SmallVectorImpl<Value> &values) {
445 assert(stdTypes.size() == values.size() &&
446 "The number of types and values doesn't match");
447 for (unsigned i = 0, end = values.size(); i < end; ++i)
448 if (auto memrefTy = stdTypes[i].dyn_cast<MemRefType>())
449 values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
450 memrefTy, values[i]);
451 }
452
453 /// Convert a non-empty list of types to be returned from a function into a
454 /// supported LLVM IR type. In particular, if more than one value is returned,
455 /// create an LLVM IR structure type with elements that correspond to each of
456 /// the MLIR types converted with `convertType`.
packFunctionResults(TypeRange types)457 Type LLVMTypeConverter::packFunctionResults(TypeRange types) {
458 assert(!types.empty() && "expected non-empty list of type");
459
460 if (types.size() == 1)
461 return convertCallingConventionType(types.front());
462
463 SmallVector<Type, 8> resultTypes;
464 resultTypes.reserve(types.size());
465 for (auto t : types) {
466 auto converted = convertCallingConventionType(t);
467 if (!converted || !LLVM::isCompatibleType(converted))
468 return {};
469 resultTypes.push_back(converted);
470 }
471
472 return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
473 }
474
promoteOneMemRefDescriptor(Location loc,Value operand,OpBuilder & builder)475 Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
476 OpBuilder &builder) {
477 auto *context = builder.getContext();
478 auto int64Ty = IntegerType::get(builder.getContext(), 64);
479 auto indexType = IndexType::get(context);
480 // Alloca with proper alignment. We do not expect optimizations of this
481 // alloca op and so we omit allocating at the entry block.
482 auto ptrType = LLVM::LLVMPointerType::get(operand.getType());
483 Value one = builder.create<LLVM::ConstantOp>(loc, int64Ty,
484 IntegerAttr::get(indexType, 1));
485 Value allocated =
486 builder.create<LLVM::AllocaOp>(loc, ptrType, one, /*alignment=*/0);
487 // Store into the alloca'ed descriptor.
488 builder.create<LLVM::StoreOp>(loc, operand, allocated);
489 return allocated;
490 }
491
promoteOperands(Location loc,ValueRange opOperands,ValueRange operands,OpBuilder & builder)492 SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(Location loc,
493 ValueRange opOperands,
494 ValueRange operands,
495 OpBuilder &builder) {
496 SmallVector<Value, 4> promotedOperands;
497 promotedOperands.reserve(operands.size());
498 for (auto it : llvm::zip(opOperands, operands)) {
499 auto operand = std::get<0>(it);
500 auto llvmOperand = std::get<1>(it);
501
502 if (options.useBarePtrCallConv) {
503 // For the bare-ptr calling convention, we only have to extract the
504 // aligned pointer of a memref.
505 if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
506 MemRefDescriptor desc(llvmOperand);
507 llvmOperand = desc.alignedPtr(builder, loc);
508 } else if (operand.getType().isa<UnrankedMemRefType>()) {
509 llvm_unreachable("Unranked memrefs are not supported");
510 }
511 } else {
512 if (operand.getType().isa<UnrankedMemRefType>()) {
513 UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
514 promotedOperands);
515 continue;
516 }
517 if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
518 MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
519 promotedOperands);
520 continue;
521 }
522 }
523
524 promotedOperands.push_back(llvmOperand);
525 }
526 return promotedOperands;
527 }
528
529 /// Callback to convert function argument types. It converts a MemRef function
530 /// argument to a list of non-aggregate types containing descriptor
531 /// information, and an UnrankedmemRef function argument to a list containing
532 /// the rank and a pointer to a descriptor struct.
structFuncArgTypeConverter(LLVMTypeConverter & converter,Type type,SmallVectorImpl<Type> & result)533 LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
534 Type type,
535 SmallVectorImpl<Type> &result) {
536 if (auto memref = type.dyn_cast<MemRefType>()) {
537 // In signatures, Memref descriptors are expanded into lists of
538 // non-aggregate values.
539 auto converted =
540 converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
541 if (converted.empty())
542 return failure();
543 result.append(converted.begin(), converted.end());
544 return success();
545 }
546 if (type.isa<UnrankedMemRefType>()) {
547 auto converted = converter.getUnrankedMemRefDescriptorFields();
548 if (converted.empty())
549 return failure();
550 result.append(converted.begin(), converted.end());
551 return success();
552 }
553 auto converted = converter.convertType(type);
554 if (!converted)
555 return failure();
556 result.push_back(converted);
557 return success();
558 }
559
560 /// Callback to convert function argument types. It converts MemRef function
561 /// arguments to bare pointers to the MemRef element type.
barePtrFuncArgTypeConverter(LLVMTypeConverter & converter,Type type,SmallVectorImpl<Type> & result)562 LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
563 Type type,
564 SmallVectorImpl<Type> &result) {
565 auto llvmTy = converter.convertCallingConventionType(type);
566 if (!llvmTy)
567 return failure();
568
569 result.push_back(llvmTy);
570 return success();
571 }
572