1 //===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert MLIR Func and builtin dialects
10 // into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "../PassDetail.h"
15 #include "mlir/Analysis/DataLayoutAnalysis.h"
16 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
17 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
18 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
19 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
20 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
21 #include "mlir/Conversion/LLVMCommon/Pattern.h"
22 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"
24 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
25 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
26 #include "mlir/Dialect/Utils/StaticValueUtils.h"
27 #include "mlir/IR/Attributes.h"
28 #include "mlir/IR/BlockAndValueMapping.h"
29 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinOps.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
33 #include "mlir/Support/LogicalResult.h"
34 #include "mlir/Support/MathExtras.h"
35 #include "mlir/Transforms/DialectConversion.h"
36 #include "mlir/Transforms/Passes.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 #include "llvm/IR/DerivedTypes.h"
39 #include "llvm/IR/IRBuilder.h"
40 #include "llvm/IR/Type.h"
41 #include "llvm/Support/CommandLine.h"
42 #include "llvm/Support/FormatVariadic.h"
43 #include <algorithm>
44 #include <functional>
45 
46 using namespace mlir;
47 
48 #define PASS_NAME "convert-func-to-llvm"
49 
50 /// Only retain those attributes that are not constructed by
51 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
52 /// attributes.
53 static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
54                                  bool filterArgAndResAttrs,
55                                  SmallVectorImpl<NamedAttribute> &result) {
56   for (const auto &attr : attrs) {
57     if (attr.getName() == SymbolTable::getSymbolAttrName() ||
58         attr.getName() == FunctionOpInterface::getTypeAttrName() ||
59         attr.getName() == "func.varargs" ||
60         (filterArgAndResAttrs &&
61          (attr.getName() == FunctionOpInterface::getArgDictAttrName() ||
62           attr.getName() == FunctionOpInterface::getResultDictAttrName())))
63       continue;
64     result.push_back(attr);
65   }
66 }
67 
68 /// Helper function for wrapping all attributes into a single DictionaryAttr
69 static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
70   return DictionaryAttr::get(
71       b.getContext(),
72       b.getNamedAttr(LLVM::LLVMDialect::getStructAttrsAttrName(), attrs));
73 }
74 
75 /// Combines all result attributes into a single DictionaryAttr
76 /// and prepends to argument attrs.
77 /// This is intended to be used to format the attributes for a C wrapper
78 /// function when the result(s) is converted to the first function argument
79 /// (in the multiple return case, all returns get wrapped into a single
80 /// argument). The total number of argument attributes should be equal to
81 /// (number of function arguments) + 1.
82 static void
83 prependResAttrsToArgAttrs(OpBuilder &builder,
84                           SmallVectorImpl<NamedAttribute> &attributes,
85                           size_t numArguments) {
86   auto allAttrs = SmallVector<Attribute>(
87       numArguments + 1, DictionaryAttr::get(builder.getContext()));
88   NamedAttribute *argAttrs = nullptr;
89   for (auto it = attributes.begin(); it != attributes.end();) {
90     if (it->getName() == FunctionOpInterface::getArgDictAttrName()) {
91       auto arrayAttrs = it->getValue().cast<ArrayAttr>();
92       assert(arrayAttrs.size() == numArguments &&
93              "Number of arg attrs and args should match");
94       std::copy(arrayAttrs.begin(), arrayAttrs.end(), allAttrs.begin() + 1);
95       argAttrs = it;
96     } else if (it->getName() == FunctionOpInterface::getResultDictAttrName()) {
97       auto arrayAttrs = it->getValue().cast<ArrayAttr>();
98       assert(!arrayAttrs.empty() && "expected array to be non-empty");
99       allAttrs[0] = (arrayAttrs.size() == 1)
100                         ? arrayAttrs[0]
101                         : wrapAsStructAttrs(builder, arrayAttrs);
102       it = attributes.erase(it);
103       continue;
104     }
105     it++;
106   }
107 
108   auto newArgAttrs =
109       builder.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
110                            builder.getArrayAttr(allAttrs));
111   if (!argAttrs) {
112     attributes.emplace_back(newArgAttrs);
113     return;
114   }
115   *argAttrs = newArgAttrs;
116   return;
117 }
118 
119 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
120 /// arguments instead of unpacked arguments. This function can be called from C
121 /// by passing a pointer to a C struct corresponding to a memref descriptor.
122 /// Similarly, returned memrefs are passed via pointers to a C struct that is
123 /// passed as additional argument.
124 /// Internally, the auxiliary function unpacks the descriptor into individual
125 /// components and forwards them to `newFuncOp` and forwards the results to
126 /// the extra arguments.
127 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
128                                    LLVMTypeConverter &typeConverter,
129                                    FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
130   auto type = funcOp.getType();
131   SmallVector<NamedAttribute, 4> attributes;
132   filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
133                        attributes);
134   Type wrapperFuncType;
135   bool resultIsNowArg;
136   std::tie(wrapperFuncType, resultIsNowArg) =
137       typeConverter.convertFunctionTypeCWrapper(type);
138   if (resultIsNowArg)
139     prependResAttrsToArgAttrs(rewriter, attributes, funcOp.getNumArguments());
140   auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
141       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
142       wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false, attributes);
143 
144   OpBuilder::InsertionGuard guard(rewriter);
145   rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
146 
147   SmallVector<Value, 8> args;
148   size_t argOffset = resultIsNowArg ? 1 : 0;
149   for (auto &en : llvm::enumerate(type.getInputs())) {
150     Value arg = wrapperFuncOp.getArgument(en.index() + argOffset);
151     if (auto memrefType = en.value().dyn_cast<MemRefType>()) {
152       Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
153       MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
154       continue;
155     }
156     if (en.value().isa<UnrankedMemRefType>()) {
157       Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
158       UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
159       continue;
160     }
161 
162     args.push_back(arg);
163   }
164 
165   auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
166 
167   if (resultIsNowArg) {
168     rewriter.create<LLVM::StoreOp>(loc, call.getResult(0),
169                                    wrapperFuncOp.getArgument(0));
170     rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
171   } else {
172     rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
173   }
174 }
175 
176 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
177 /// arguments instead of unpacked arguments. Creates a body for the (external)
178 /// `newFuncOp` that allocates a memref descriptor on stack, packs the
179 /// individual arguments into this descriptor and passes a pointer to it into
180 /// the auxiliary function. If the result of the function cannot be directly
181 /// returned, we write it to a special first argument that provides a pointer
182 /// to a corresponding struct. This auxiliary external function is now
183 /// compatible with functions defined in C using pointers to C structs
184 /// corresponding to a memref descriptor.
185 static void wrapExternalFunction(OpBuilder &builder, Location loc,
186                                  LLVMTypeConverter &typeConverter,
187                                  FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
188   OpBuilder::InsertionGuard guard(builder);
189 
190   Type wrapperType;
191   bool resultIsNowArg;
192   std::tie(wrapperType, resultIsNowArg) =
193       typeConverter.convertFunctionTypeCWrapper(funcOp.getType());
194   // This conversion can only fail if it could not convert one of the argument
195   // types. But since it has been applied to a non-wrapper function before, it
196   // should have failed earlier and not reach this point at all.
197   assert(wrapperType && "unexpected type conversion failure");
198 
199   SmallVector<NamedAttribute, 4> attributes;
200   filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
201                        attributes);
202 
203   if (resultIsNowArg)
204     prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments());
205   // Create the auxiliary function.
206   auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
207       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
208       wrapperType, LLVM::Linkage::External, /*dsoLocal*/ false, attributes);
209 
210   builder.setInsertionPointToStart(newFuncOp.addEntryBlock());
211 
212   // Get a ValueRange containing arguments.
213   FunctionType type = funcOp.getType();
214   SmallVector<Value, 8> args;
215   args.reserve(type.getNumInputs());
216   ValueRange wrapperArgsRange(newFuncOp.getArguments());
217 
218   if (resultIsNowArg) {
219     // Allocate the struct on the stack and pass the pointer.
220     Type resultType =
221         wrapperType.cast<LLVM::LLVMFunctionType>().getParamType(0);
222     Value one = builder.create<LLVM::ConstantOp>(
223         loc, typeConverter.convertType(builder.getIndexType()),
224         builder.getIntegerAttr(builder.getIndexType(), 1));
225     Value result = builder.create<LLVM::AllocaOp>(loc, resultType, one);
226     args.push_back(result);
227   }
228 
229   // Iterate over the inputs of the original function and pack values into
230   // memref descriptors if the original type is a memref.
231   for (auto &en : llvm::enumerate(type.getInputs())) {
232     Value arg;
233     int numToDrop = 1;
234     auto memRefType = en.value().dyn_cast<MemRefType>();
235     auto unrankedMemRefType = en.value().dyn_cast<UnrankedMemRefType>();
236     if (memRefType || unrankedMemRefType) {
237       numToDrop = memRefType
238                       ? MemRefDescriptor::getNumUnpackedValues(memRefType)
239                       : UnrankedMemRefDescriptor::getNumUnpackedValues();
240       Value packed =
241           memRefType
242               ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
243                                        wrapperArgsRange.take_front(numToDrop))
244               : UnrankedMemRefDescriptor::pack(
245                     builder, loc, typeConverter, unrankedMemRefType,
246                     wrapperArgsRange.take_front(numToDrop));
247 
248       auto ptrTy = LLVM::LLVMPointerType::get(packed.getType());
249       Value one = builder.create<LLVM::ConstantOp>(
250           loc, typeConverter.convertType(builder.getIndexType()),
251           builder.getIntegerAttr(builder.getIndexType(), 1));
252       Value allocated =
253           builder.create<LLVM::AllocaOp>(loc, ptrTy, one, /*alignment=*/0);
254       builder.create<LLVM::StoreOp>(loc, packed, allocated);
255       arg = allocated;
256     } else {
257       arg = wrapperArgsRange[0];
258     }
259 
260     args.push_back(arg);
261     wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
262   }
263   assert(wrapperArgsRange.empty() && "did not map some of the arguments");
264 
265   auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
266 
267   if (resultIsNowArg) {
268     Value result = builder.create<LLVM::LoadOp>(loc, args.front());
269     builder.create<LLVM::ReturnOp>(loc, ValueRange{result});
270   } else {
271     builder.create<LLVM::ReturnOp>(loc, call.getResults());
272   }
273 }
274 
275 namespace {
276 
277 struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
278 protected:
279   using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
280 
281   // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
282   // to this legalization pattern.
283   LLVM::LLVMFuncOp
284   convertFuncOpToLLVMFuncOp(FuncOp funcOp,
285                             ConversionPatternRewriter &rewriter) const {
286     // Convert the original function arguments. They are converted using the
287     // LLVMTypeConverter provided to this legalization pattern.
288     auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs");
289     TypeConverter::SignatureConversion result(funcOp.getNumArguments());
290     auto llvmType = getTypeConverter()->convertFunctionSignature(
291         funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
292     if (!llvmType)
293       return nullptr;
294 
295     // Propagate argument/result attributes to all converted arguments/result
296     // obtained after converting a given original argument/result.
297     SmallVector<NamedAttribute, 4> attributes;
298     filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true,
299                          attributes);
300     if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
301       assert(!resAttrDicts.empty() && "expected array to be non-empty");
302       auto newResAttrDicts =
303           (funcOp.getNumResults() == 1)
304               ? resAttrDicts
305               : rewriter.getArrayAttr(
306                     {wrapAsStructAttrs(rewriter, resAttrDicts)});
307       attributes.push_back(rewriter.getNamedAttr(
308           FunctionOpInterface::getResultDictAttrName(), newResAttrDicts));
309     }
310     if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
311       SmallVector<Attribute, 4> newArgAttrs(
312           llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
313       for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
314         auto mapping = result.getInputMapping(i);
315         assert(mapping.hasValue() &&
316                "unexpected deletion of function argument");
317         for (size_t j = 0; j < mapping->size; ++j)
318           newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
319       }
320       attributes.push_back(
321           rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
322                                 rewriter.getArrayAttr(newArgAttrs)));
323     }
324     for (const auto &pair : llvm::enumerate(attributes)) {
325       if (pair.value().getName() == "llvm.linkage") {
326         attributes.erase(attributes.begin() + pair.index());
327         break;
328       }
329     }
330 
331     // Create an LLVM function, use external linkage by default until MLIR
332     // functions have linkage.
333     LLVM::Linkage linkage = LLVM::Linkage::External;
334     if (funcOp->hasAttr("llvm.linkage")) {
335       auto attr =
336           funcOp->getAttr("llvm.linkage").dyn_cast<mlir::LLVM::LinkageAttr>();
337       if (!attr) {
338         funcOp->emitError()
339             << "Contains llvm.linkage attribute not of type LLVM::LinkageAttr";
340         return nullptr;
341       }
342       linkage = attr.getLinkage();
343     }
344     auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
345         funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
346         /*dsoLocal*/ false, attributes);
347     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
348                                 newFuncOp.end());
349     if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
350                                            &result)))
351       return nullptr;
352 
353     return newFuncOp;
354   }
355 };
356 
357 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
358 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
359 /// information.
360 static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
361 struct FuncOpConversion : public FuncOpConversionBase {
362   FuncOpConversion(LLVMTypeConverter &converter)
363       : FuncOpConversionBase(converter) {}
364 
365   LogicalResult
366   matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
367                   ConversionPatternRewriter &rewriter) const override {
368     auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
369     if (!newFuncOp)
370       return failure();
371 
372     if (getTypeConverter()->getOptions().emitCWrappers ||
373         funcOp->getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) {
374       if (newFuncOp.isExternal())
375         wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(),
376                              funcOp, newFuncOp);
377       else
378         wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(),
379                                funcOp, newFuncOp);
380     }
381 
382     rewriter.eraseOp(funcOp);
383     return success();
384   }
385 };
386 
387 /// FuncOp legalization pattern that converts MemRef arguments to bare pointers
388 /// to the MemRef element type. This will impact the calling convention and ABI.
389 struct BarePtrFuncOpConversion : public FuncOpConversionBase {
390   using FuncOpConversionBase::FuncOpConversionBase;
391 
392   LogicalResult
393   matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
394                   ConversionPatternRewriter &rewriter) const override {
395 
396     // TODO: bare ptr conversion could be handled by argument materialization
397     // and most of the code below would go away. But to do this, we would need a
398     // way to distinguish between FuncOp and other regions in the
399     // addArgumentMaterialization hook.
400 
401     // Store the type of memref-typed arguments before the conversion so that we
402     // can promote them to MemRef descriptor at the beginning of the function.
403     SmallVector<Type, 8> oldArgTypes =
404         llvm::to_vector<8>(funcOp.getType().getInputs());
405 
406     auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
407     if (!newFuncOp)
408       return failure();
409     if (newFuncOp.getBody().empty()) {
410       rewriter.eraseOp(funcOp);
411       return success();
412     }
413 
414     // Promote bare pointers from memref arguments to memref descriptors at the
415     // beginning of the function so that all the memrefs in the function have a
416     // uniform representation.
417     Block *entryBlock = &newFuncOp.getBody().front();
418     auto blockArgs = entryBlock->getArguments();
419     assert(blockArgs.size() == oldArgTypes.size() &&
420            "The number of arguments and types doesn't match");
421 
422     OpBuilder::InsertionGuard guard(rewriter);
423     rewriter.setInsertionPointToStart(entryBlock);
424     for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
425       BlockArgument arg = std::get<0>(it);
426       Type argTy = std::get<1>(it);
427 
428       // Unranked memrefs are not supported in the bare pointer calling
429       // convention. We should have bailed out before in the presence of
430       // unranked memrefs.
431       assert(!argTy.isa<UnrankedMemRefType>() &&
432              "Unranked memref is not supported");
433       auto memrefTy = argTy.dyn_cast<MemRefType>();
434       if (!memrefTy)
435         continue;
436 
437       // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
438       // or unranked memref descriptor and replace placeholder with the last
439       // instruction of the memref descriptor.
440       // TODO: The placeholder is needed to avoid replacing barePtr uses in the
441       // MemRef descriptor instructions. We may want to have a utility in the
442       // rewriter to properly handle this use case.
443       Location loc = funcOp.getLoc();
444       auto placeholder = rewriter.create<LLVM::UndefOp>(
445           loc, getTypeConverter()->convertType(memrefTy));
446       rewriter.replaceUsesOfBlockArgument(arg, placeholder);
447 
448       Value desc = MemRefDescriptor::fromStaticShape(
449           rewriter, loc, *getTypeConverter(), memrefTy, arg);
450       rewriter.replaceOp(placeholder, {desc});
451     }
452 
453     rewriter.eraseOp(funcOp);
454     return success();
455   }
456 };
457 
458 struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
459   using ConvertOpToLLVMPattern<func::ConstantOp>::ConvertOpToLLVMPattern;
460 
461   LogicalResult
462   matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor,
463                   ConversionPatternRewriter &rewriter) const override {
464     auto type = typeConverter->convertType(op.getResult().getType());
465     if (!type || !LLVM::isCompatibleType(type))
466       return rewriter.notifyMatchFailure(op, "failed to convert result type");
467 
468     auto newOp =
469         rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
470     for (const NamedAttribute &attr : op->getAttrs()) {
471       if (attr.getName().strref() == "value")
472         continue;
473       newOp->setAttr(attr.getName(), attr.getValue());
474     }
475     rewriter.replaceOp(op, newOp->getResults());
476     return success();
477   }
478 };
479 
480 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and
481 // passes the pointer to the MemRef across function boundaries.
482 template <typename CallOpType>
483 struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
484   using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
485   using Super = CallOpInterfaceLowering<CallOpType>;
486   using Base = ConvertOpToLLVMPattern<CallOpType>;
487 
488   LogicalResult
489   matchAndRewrite(CallOpType callOp, typename CallOpType::Adaptor adaptor,
490                   ConversionPatternRewriter &rewriter) const override {
491     // Pack the result types into a struct.
492     Type packedResult = nullptr;
493     unsigned numResults = callOp.getNumResults();
494     auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
495 
496     if (numResults != 0) {
497       if (!(packedResult =
498                 this->getTypeConverter()->packFunctionResults(resultTypes)))
499         return failure();
500     }
501 
502     auto promoted = this->getTypeConverter()->promoteOperands(
503         callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
504         adaptor.getOperands(), rewriter);
505     auto newOp = rewriter.create<LLVM::CallOp>(
506         callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
507         promoted, callOp->getAttrs());
508 
509     SmallVector<Value, 4> results;
510     if (numResults < 2) {
511       // If < 2 results, packing did not do anything and we can just return.
512       results.append(newOp.result_begin(), newOp.result_end());
513     } else {
514       // Otherwise, it had been converted to an operation producing a structure.
515       // Extract individual results from the structure and return them as list.
516       results.reserve(numResults);
517       for (unsigned i = 0; i < numResults; ++i) {
518         auto type =
519             this->typeConverter->convertType(callOp.getResult(i).getType());
520         results.push_back(rewriter.create<LLVM::ExtractValueOp>(
521             callOp.getLoc(), type, newOp->getResult(0),
522             rewriter.getI64ArrayAttr(i)));
523       }
524     }
525 
526     if (this->getTypeConverter()->getOptions().useBarePtrCallConv) {
527       // For the bare-ptr calling convention, promote memref results to
528       // descriptors.
529       assert(results.size() == resultTypes.size() &&
530              "The number of arguments and types doesn't match");
531       this->getTypeConverter()->promoteBarePtrsToDescriptors(
532           rewriter, callOp.getLoc(), resultTypes, results);
533     } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
534                                                     resultTypes, results,
535                                                     /*toDynamic=*/false))) {
536       return failure();
537     }
538 
539     rewriter.replaceOp(callOp, results);
540     return success();
541   }
542 };
543 
544 struct CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
545   using Super::Super;
546 };
547 
548 struct CallIndirectOpLowering
549     : public CallOpInterfaceLowering<func::CallIndirectOp> {
550   using Super::Super;
551 };
552 
553 struct UnrealizedConversionCastOpLowering
554     : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> {
555   using ConvertOpToLLVMPattern<
556       UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
557 
558   LogicalResult
559   matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
560                   ConversionPatternRewriter &rewriter) const override {
561     SmallVector<Type> convertedTypes;
562     if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(),
563                                               convertedTypes)) &&
564         convertedTypes == adaptor.getInputs().getTypes()) {
565       rewriter.replaceOp(op, adaptor.getInputs());
566       return success();
567     }
568 
569     convertedTypes.clear();
570     if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(),
571                                               convertedTypes)) &&
572         convertedTypes == op.getOutputs().getType()) {
573       rewriter.replaceOp(op, adaptor.getInputs());
574       return success();
575     }
576     return failure();
577   }
578 };
579 
580 // Special lowering pattern for `ReturnOps`.  Unlike all other operations,
581 // `ReturnOp` interacts with the function signature and must have as many
582 // operands as the function has return values.  Because in LLVM IR, functions
583 // can only return 0 or 1 value, we pack multiple values into a structure type.
584 // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
585 // necessary before returning it
586 struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
587   using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
588 
589   LogicalResult
590   matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
591                   ConversionPatternRewriter &rewriter) const override {
592     Location loc = op.getLoc();
593     unsigned numArguments = op.getNumOperands();
594     SmallVector<Value, 4> updatedOperands;
595 
596     if (getTypeConverter()->getOptions().useBarePtrCallConv) {
597       // For the bare-ptr calling convention, extract the aligned pointer to
598       // be returned from the memref descriptor.
599       for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
600         Type oldTy = std::get<0>(it).getType();
601         Value newOperand = std::get<1>(it);
602         if (oldTy.isa<MemRefType>()) {
603           MemRefDescriptor memrefDesc(newOperand);
604           newOperand = memrefDesc.alignedPtr(rewriter, loc);
605         } else if (oldTy.isa<UnrankedMemRefType>()) {
606           // Unranked memref is not supported in the bare pointer calling
607           // convention.
608           return failure();
609         }
610         updatedOperands.push_back(newOperand);
611       }
612     } else {
613       updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
614       (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
615                                     updatedOperands,
616                                     /*toDynamic=*/true);
617     }
618 
619     // If ReturnOp has 0 or 1 operand, create it and return immediately.
620     if (numArguments == 0) {
621       rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
622                                                   op->getAttrs());
623       return success();
624     }
625     if (numArguments == 1) {
626       rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
627           op, TypeRange(), updatedOperands, op->getAttrs());
628       return success();
629     }
630 
631     // Otherwise, we need to pack the arguments into an LLVM struct type before
632     // returning.
633     auto packedType = getTypeConverter()->packFunctionResults(
634         llvm::to_vector<4>(op.getOperandTypes()));
635 
636     Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
637     for (unsigned i = 0; i < numArguments; ++i) {
638       packed = rewriter.create<LLVM::InsertValueOp>(
639           loc, packedType, packed, updatedOperands[i],
640           rewriter.getI64ArrayAttr(i));
641     }
642     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
643                                                 op->getAttrs());
644     return success();
645   }
646 };
647 } // namespace
648 
649 void mlir::populateFuncToLLVMFuncOpConversionPattern(
650     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
651   if (converter.getOptions().useBarePtrCallConv)
652     patterns.add<BarePtrFuncOpConversion>(converter);
653   else
654     patterns.add<FuncOpConversion>(converter);
655 }
656 
657 void mlir::populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter,
658                                                 RewritePatternSet &patterns) {
659   populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
660   // clang-format off
661   patterns.add<
662       CallIndirectOpLowering,
663       CallOpLowering,
664       ConstantOpLowering,
665       ReturnOpLowering>(converter);
666   // clang-format on
667 }
668 
669 namespace {
670 /// A pass converting Func operations into the LLVM IR dialect.
671 struct ConvertFuncToLLVMPass
672     : public ConvertFuncToLLVMBase<ConvertFuncToLLVMPass> {
673   ConvertFuncToLLVMPass() = default;
674   ConvertFuncToLLVMPass(bool useBarePtrCallConv, bool emitCWrappers,
675                         unsigned indexBitwidth, bool useAlignedAlloc,
676                         const llvm::DataLayout &dataLayout) {
677     this->useBarePtrCallConv = useBarePtrCallConv;
678     this->emitCWrappers = emitCWrappers;
679     this->indexBitwidth = indexBitwidth;
680     this->dataLayout = dataLayout.getStringRepresentation();
681   }
682 
683   /// Run the dialect converter on the module.
684   void runOnOperation() override {
685     if (useBarePtrCallConv && emitCWrappers) {
686       getOperation().emitError()
687           << "incompatible conversion options: bare-pointer calling convention "
688              "and C wrapper emission";
689       signalPassFailure();
690       return;
691     }
692     if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
693             this->dataLayout, [this](const Twine &message) {
694               getOperation().emitError() << message.str();
695             }))) {
696       signalPassFailure();
697       return;
698     }
699 
700     ModuleOp m = getOperation();
701     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
702 
703     LowerToLLVMOptions options(&getContext(),
704                                dataLayoutAnalysis.getAtOrAbove(m));
705     options.useBarePtrCallConv = useBarePtrCallConv;
706     options.emitCWrappers = emitCWrappers;
707     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
708       options.overrideIndexBitwidth(indexBitwidth);
709     options.dataLayout = llvm::DataLayout(this->dataLayout);
710 
711     LLVMTypeConverter typeConverter(&getContext(), options,
712                                     &dataLayoutAnalysis);
713 
714     RewritePatternSet patterns(&getContext());
715     populateFuncToLLVMConversionPatterns(typeConverter, patterns);
716 
717     // TODO: Remove these in favor of their dedicated conversion passes.
718     arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns);
719     cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
720 
721     LLVMConversionTarget target(getContext());
722     if (failed(applyPartialConversion(m, target, std::move(patterns))))
723       signalPassFailure();
724 
725     m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
726                StringAttr::get(m.getContext(), this->dataLayout));
727   }
728 };
729 } // namespace
730 
/// Creates a Func-to-LLVM conversion pass using the default constructor of
/// ConvertFuncToLLVMPass, i.e. without overriding any of its options.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertFuncToLLVMPass() {
  return std::make_unique<ConvertFuncToLLVMPass>();
}
734 
735 std::unique_ptr<OperationPass<ModuleOp>>
736 mlir::createConvertFuncToLLVMPass(const LowerToLLVMOptions &options) {
737   auto allocLowering = options.allocLowering;
738   // There is no way to provide additional patterns for pass, so
739   // AllocLowering::None will always fail.
740   assert(allocLowering != LowerToLLVMOptions::AllocLowering::None &&
741          "ConvertFuncToLLVMPass doesn't support AllocLowering::None");
742   bool useAlignedAlloc =
743       (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc);
744   return std::make_unique<ConvertFuncToLLVMPass>(
745       options.useBarePtrCallConv, options.emitCWrappers,
746       options.getIndexBitwidth(), useAlignedAlloc, options.dataLayout);
747 }
748