1 //===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert MLIR Func and builtin dialects
10 // into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "../PassDetail.h"
15 #include "mlir/Analysis/DataLayoutAnalysis.h"
16 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
17 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
18 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
19 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
20 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
21 #include "mlir/Conversion/LLVMCommon/Pattern.h"
22 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"
24 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
25 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
26 #include "mlir/Dialect/Utils/StaticValueUtils.h"
27 #include "mlir/IR/Attributes.h"
28 #include "mlir/IR/BlockAndValueMapping.h"
29 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinOps.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
33 #include "mlir/Support/LogicalResult.h"
34 #include "mlir/Support/MathExtras.h"
35 #include "mlir/Transforms/DialectConversion.h"
36 #include "mlir/Transforms/Passes.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 #include "llvm/IR/DerivedTypes.h"
39 #include "llvm/IR/IRBuilder.h"
40 #include "llvm/IR/Type.h"
41 #include "llvm/Support/CommandLine.h"
42 #include "llvm/Support/FormatVariadic.h"
43 #include <algorithm>
44 #include <functional>
45 
46 using namespace mlir;
47 
48 #define PASS_NAME "convert-func-to-llvm"
49 
50 /// Only retain those attributes that are not constructed by
51 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
52 /// attributes.
53 static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
54                                  bool filterArgAndResAttrs,
55                                  SmallVectorImpl<NamedAttribute> &result) {
56   for (const auto &attr : attrs) {
57     if (attr.getName() == SymbolTable::getSymbolAttrName() ||
58         attr.getName() == FunctionOpInterface::getTypeAttrName() ||
59         attr.getName() == "func.varargs" ||
60         (filterArgAndResAttrs &&
61          (attr.getName() == FunctionOpInterface::getArgDictAttrName() ||
62           attr.getName() == FunctionOpInterface::getResultDictAttrName())))
63       continue;
64     result.push_back(attr);
65   }
66 }
67 
68 /// Helper function for wrapping all attributes into a single DictionaryAttr
69 static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
70   return DictionaryAttr::get(
71       b.getContext(),
72       b.getNamedAttr(LLVM::LLVMDialect::getStructAttrsAttrName(), attrs));
73 }
74 
75 /// Combines all result attributes into a single DictionaryAttr
76 /// and prepends to argument attrs.
77 /// This is intended to be used to format the attributes for a C wrapper
78 /// function when the result(s) is converted to the first function argument
79 /// (in the multiple return case, all returns get wrapped into a single
80 /// argument). The total number of argument attributes should be equal to
81 /// (number of function arguments) + 1.
82 static void
83 prependResAttrsToArgAttrs(OpBuilder &builder,
84                           SmallVectorImpl<NamedAttribute> &attributes,
85                           size_t numArguments) {
86   auto allAttrs = SmallVector<Attribute>(
87       numArguments + 1, DictionaryAttr::get(builder.getContext()));
88   NamedAttribute *argAttrs = nullptr;
89   for (auto *it = attributes.begin(); it != attributes.end();) {
90     if (it->getName() == FunctionOpInterface::getArgDictAttrName()) {
91       auto arrayAttrs = it->getValue().cast<ArrayAttr>();
92       assert(arrayAttrs.size() == numArguments &&
93              "Number of arg attrs and args should match");
94       std::copy(arrayAttrs.begin(), arrayAttrs.end(), allAttrs.begin() + 1);
95       argAttrs = it;
96     } else if (it->getName() == FunctionOpInterface::getResultDictAttrName()) {
97       auto arrayAttrs = it->getValue().cast<ArrayAttr>();
98       assert(!arrayAttrs.empty() && "expected array to be non-empty");
99       allAttrs[0] = (arrayAttrs.size() == 1)
100                         ? arrayAttrs[0]
101                         : wrapAsStructAttrs(builder, arrayAttrs);
102       it = attributes.erase(it);
103       continue;
104     }
105     it++;
106   }
107 
108   auto newArgAttrs =
109       builder.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
110                            builder.getArrayAttr(allAttrs));
111   if (!argAttrs) {
112     attributes.emplace_back(newArgAttrs);
113     return;
114   }
115   *argAttrs = newArgAttrs;
116 }
117 
118 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
119 /// arguments instead of unpacked arguments. This function can be called from C
120 /// by passing a pointer to a C struct corresponding to a memref descriptor.
121 /// Similarly, returned memrefs are passed via pointers to a C struct that is
122 /// passed as additional argument.
123 /// Internally, the auxiliary function unpacks the descriptor into individual
124 /// components and forwards them to `newFuncOp` and forwards the results to
125 /// the extra arguments.
126 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
127                                    LLVMTypeConverter &typeConverter,
128                                    func::FuncOp funcOp,
129                                    LLVM::LLVMFuncOp newFuncOp) {
130   auto type = funcOp.getFunctionType();
131   SmallVector<NamedAttribute, 4> attributes;
132   filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
133                        attributes);
134   Type wrapperFuncType;
135   bool resultIsNowArg;
136   std::tie(wrapperFuncType, resultIsNowArg) =
137       typeConverter.convertFunctionTypeCWrapper(type);
138   if (resultIsNowArg)
139     prependResAttrsToArgAttrs(rewriter, attributes, funcOp.getNumArguments());
140   auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
141       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
142       wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false,
143       /*cconv*/ LLVM::CConv::C, attributes);
144 
145   OpBuilder::InsertionGuard guard(rewriter);
146   rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
147 
148   SmallVector<Value, 8> args;
149   size_t argOffset = resultIsNowArg ? 1 : 0;
150   for (auto &en : llvm::enumerate(type.getInputs())) {
151     Value arg = wrapperFuncOp.getArgument(en.index() + argOffset);
152     if (auto memrefType = en.value().dyn_cast<MemRefType>()) {
153       Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
154       MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
155       continue;
156     }
157     if (en.value().isa<UnrankedMemRefType>()) {
158       Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
159       UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
160       continue;
161     }
162 
163     args.push_back(arg);
164   }
165 
166   auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
167 
168   if (resultIsNowArg) {
169     rewriter.create<LLVM::StoreOp>(loc, call.getResult(0),
170                                    wrapperFuncOp.getArgument(0));
171     rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
172   } else {
173     rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
174   }
175 }
176 
/// Handles a body-less (external) `newFuncOp` that must be implemented in C.
///
/// Declares an auxiliary function `_mlir_ciface_<name>` (external linkage, C
/// calling convention, no body) whose signature comes from
/// `convertFunctionTypeCWrapper`. Its memref arguments are pointers to memref
/// descriptor structs.
///
/// Then gives `newFuncOp` an entry block that:
///   * packs each group of unpacked memref values received by `newFuncOp`
///     back into a descriptor (`MemRefDescriptor::pack` or
///     `UnrankedMemRefDescriptor::pack`);
///   * stores that descriptor into a one-element `llvm.alloca`;
///   * passes the pointer to the auxiliary function.
///
/// If the C signature demotes the result to a leading pointer argument, a
/// stack slot is allocated for it and passed first. After the call, the slot
/// is loaded and returned. Otherwise, the call's results are returned
/// directly.
///
/// The net effect is that `newFuncOp` can be implemented in C using pointers
/// to C structs corresponding to memref descriptors.
static void wrapExternalFunction(OpBuilder &builder, Location loc,
                                 LLVMTypeConverter &typeConverter,
                                 func::FuncOp funcOp,
                                 LLVM::LLVMFuncOp newFuncOp) {
  // Restores the caller's insertion point on exit; the insertion point is
  // moved into `newFuncOp`'s new entry block below.
  OpBuilder::InsertionGuard guard(builder);

  Type wrapperType;
  bool resultIsNowArg;
  std::tie(wrapperType, resultIsNowArg) =
      typeConverter.convertFunctionTypeCWrapper(funcOp.getFunctionType());
  // This conversion can only fail if it could not convert one of the argument
  // types. But since it has been applied to a non-wrapper function before, it
  // should have failed earlier and not reach this point at all.
  assert(wrapperType && "unexpected type conversion failure");

  SmallVector<NamedAttribute, 4> attributes;
  filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
                       attributes);

  // The C signature takes the result through argument #0, so its attributes
  // move to the argument list.
  if (resultIsNowArg)
    prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments());
  // Declare the auxiliary (C-interface) function; no body is added to it here.
  auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
      loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
      wrapperType, LLVM::Linkage::External, /*dsoLocal*/ false,
      /*cconv*/ LLVM::CConv::C, attributes);

  // Everything below is emitted into the new body of `newFuncOp`.
  builder.setInsertionPointToStart(newFuncOp.addEntryBlock());

  // Collect the operands of the call to the auxiliary function.
  // NOTE: despite its name, `wrapperArgsRange` holds the (unpacked) arguments
  // of `newFuncOp`; they are consumed front-to-back as the original inputs
  // are visited.
  FunctionType type = funcOp.getFunctionType();
  SmallVector<Value, 8> args;
  args.reserve(type.getNumInputs());
  ValueRange wrapperArgsRange(newFuncOp.getArguments());

  if (resultIsNowArg) {
    // Allocate the struct on the stack and pass the pointer.
    // Parameter #0 of the C signature is the pointer type used for this slot.
    Type resultType =
        wrapperType.cast<LLVM::LLVMFunctionType>().getParamType(0);
    Value one = builder.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(builder.getIndexType()),
        builder.getIntegerAttr(builder.getIndexType(), 1));
    Value result = builder.create<LLVM::AllocaOp>(loc, resultType, one);
    args.push_back(result);
  }

  // Iterate over the inputs of the original function and pack values into
  // memref descriptors if the original type is a memref.
  for (auto &en : llvm::enumerate(type.getInputs())) {
    Value arg;
    // Number of `newFuncOp` arguments that correspond to this original input:
    // one for non-memref types, the descriptor's unpacked size for memrefs.
    int numToDrop = 1;
    auto memRefType = en.value().dyn_cast<MemRefType>();
    auto unrankedMemRefType = en.value().dyn_cast<UnrankedMemRefType>();
    if (memRefType || unrankedMemRefType) {
      numToDrop = memRefType
                      ? MemRefDescriptor::getNumUnpackedValues(memRefType)
                      : UnrankedMemRefDescriptor::getNumUnpackedValues();
      Value packed =
          memRefType
              ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
                                       wrapperArgsRange.take_front(numToDrop))
              : UnrankedMemRefDescriptor::pack(
                    builder, loc, typeConverter, unrankedMemRefType,
                    wrapperArgsRange.take_front(numToDrop));

      // Spill the packed descriptor into a one-element stack slot and pass
      // the C wrapper a pointer to it.
      auto ptrTy = LLVM::LLVMPointerType::get(packed.getType());
      Value one = builder.create<LLVM::ConstantOp>(
          loc, typeConverter.convertType(builder.getIndexType()),
          builder.getIntegerAttr(builder.getIndexType(), 1));
      Value allocated =
          builder.create<LLVM::AllocaOp>(loc, ptrTy, one, /*alignment=*/0);
      builder.create<LLVM::StoreOp>(loc, packed, allocated);
      arg = allocated;
    } else {
      // Non-memref inputs are forwarded unchanged.
      arg = wrapperArgsRange[0];
    }

    args.push_back(arg);
    wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
  }
  assert(wrapperArgsRange.empty() && "did not map some of the arguments");

  auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);

  if (resultIsNowArg) {
    // The C function wrote the result into the stack slot passed as `args[0]`.
    Value result = builder.create<LLVM::LoadOp>(loc, args.front());
    builder.create<LLVM::ReturnOp>(loc, ValueRange{result});
  } else {
    builder.create<LLVM::ReturnOp>(loc, call.getResults());
  }
}
277 
278 namespace {
279 
280 struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
281 protected:
282   using ConvertOpToLLVMPattern<func::FuncOp>::ConvertOpToLLVMPattern;
283 
284   // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
285   // to this legalization pattern.
286   LLVM::LLVMFuncOp
287   convertFuncOpToLLVMFuncOp(func::FuncOp funcOp,
288                             ConversionPatternRewriter &rewriter) const {
289     // Convert the original function arguments. They are converted using the
290     // LLVMTypeConverter provided to this legalization pattern.
291     auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs");
292     TypeConverter::SignatureConversion result(funcOp.getNumArguments());
293     auto llvmType = getTypeConverter()->convertFunctionSignature(
294         funcOp.getFunctionType(), varargsAttr && varargsAttr.getValue(),
295         result);
296     if (!llvmType)
297       return nullptr;
298 
299     // Propagate argument/result attributes to all converted arguments/result
300     // obtained after converting a given original argument/result.
301     SmallVector<NamedAttribute, 4> attributes;
302     filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true,
303                          attributes);
304     if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
305       assert(!resAttrDicts.empty() && "expected array to be non-empty");
306       auto newResAttrDicts =
307           (funcOp.getNumResults() == 1)
308               ? resAttrDicts
309               : rewriter.getArrayAttr(
310                     {wrapAsStructAttrs(rewriter, resAttrDicts)});
311       attributes.push_back(rewriter.getNamedAttr(
312           FunctionOpInterface::getResultDictAttrName(), newResAttrDicts));
313     }
314     if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
315       SmallVector<Attribute, 4> newArgAttrs(
316           llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
317       for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
318         auto mapping = result.getInputMapping(i);
319         assert(mapping && "unexpected deletion of function argument");
320         for (size_t j = 0; j < mapping->size; ++j)
321           newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
322       }
323       attributes.push_back(
324           rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
325                                 rewriter.getArrayAttr(newArgAttrs)));
326     }
327     for (const auto &pair : llvm::enumerate(attributes)) {
328       if (pair.value().getName() == "llvm.linkage") {
329         attributes.erase(attributes.begin() + pair.index());
330         break;
331       }
332     }
333 
334     // Create an LLVM function, use external linkage by default until MLIR
335     // functions have linkage.
336     LLVM::Linkage linkage = LLVM::Linkage::External;
337     if (funcOp->hasAttr("llvm.linkage")) {
338       auto attr =
339           funcOp->getAttr("llvm.linkage").dyn_cast<mlir::LLVM::LinkageAttr>();
340       if (!attr) {
341         funcOp->emitError()
342             << "Contains llvm.linkage attribute not of type LLVM::LinkageAttr";
343         return nullptr;
344       }
345       linkage = attr.getLinkage();
346     }
347     auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
348         funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
349         /*dsoLocal*/ false, /*cconv*/ LLVM::CConv::C, attributes);
350     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
351                                 newFuncOp.end());
352     if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
353                                            &result)))
354       return nullptr;
355 
356     return newFuncOp;
357   }
358 };
359 
360 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
361 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
362 /// information.
363 struct FuncOpConversion : public FuncOpConversionBase {
364   FuncOpConversion(LLVMTypeConverter &converter)
365       : FuncOpConversionBase(converter) {}
366 
367   LogicalResult
368   matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
369                   ConversionPatternRewriter &rewriter) const override {
370     auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
371     if (!newFuncOp)
372       return failure();
373 
374     if (funcOp->getAttrOfType<UnitAttr>(
375             LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
376       if (newFuncOp.isVarArg())
377         return funcOp->emitError("C interface for variadic functions is not "
378                                  "supported yet.");
379 
380       if (newFuncOp.isExternal())
381         wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(),
382                              funcOp, newFuncOp);
383       else
384         wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(),
385                                funcOp, newFuncOp);
386     }
387 
388     rewriter.eraseOp(funcOp);
389     return success();
390   }
391 };
392 
393 /// FuncOp legalization pattern that converts MemRef arguments to bare pointers
394 /// to the MemRef element type. This will impact the calling convention and ABI.
395 struct BarePtrFuncOpConversion : public FuncOpConversionBase {
396   using FuncOpConversionBase::FuncOpConversionBase;
397 
398   LogicalResult
399   matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
400                   ConversionPatternRewriter &rewriter) const override {
401 
402     // TODO: bare ptr conversion could be handled by argument materialization
403     // and most of the code below would go away. But to do this, we would need a
404     // way to distinguish between FuncOp and other regions in the
405     // addArgumentMaterialization hook.
406 
407     // Store the type of memref-typed arguments before the conversion so that we
408     // can promote them to MemRef descriptor at the beginning of the function.
409     SmallVector<Type, 8> oldArgTypes =
410         llvm::to_vector<8>(funcOp.getFunctionType().getInputs());
411 
412     auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
413     if (!newFuncOp)
414       return failure();
415     if (newFuncOp.getBody().empty()) {
416       rewriter.eraseOp(funcOp);
417       return success();
418     }
419 
420     // Promote bare pointers from memref arguments to memref descriptors at the
421     // beginning of the function so that all the memrefs in the function have a
422     // uniform representation.
423     Block *entryBlock = &newFuncOp.getBody().front();
424     auto blockArgs = entryBlock->getArguments();
425     assert(blockArgs.size() == oldArgTypes.size() &&
426            "The number of arguments and types doesn't match");
427 
428     OpBuilder::InsertionGuard guard(rewriter);
429     rewriter.setInsertionPointToStart(entryBlock);
430     for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
431       BlockArgument arg = std::get<0>(it);
432       Type argTy = std::get<1>(it);
433 
434       // Unranked memrefs are not supported in the bare pointer calling
435       // convention. We should have bailed out before in the presence of
436       // unranked memrefs.
437       assert(!argTy.isa<UnrankedMemRefType>() &&
438              "Unranked memref is not supported");
439       auto memrefTy = argTy.dyn_cast<MemRefType>();
440       if (!memrefTy)
441         continue;
442 
443       // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
444       // or unranked memref descriptor and replace placeholder with the last
445       // instruction of the memref descriptor.
446       // TODO: The placeholder is needed to avoid replacing barePtr uses in the
447       // MemRef descriptor instructions. We may want to have a utility in the
448       // rewriter to properly handle this use case.
449       Location loc = funcOp.getLoc();
450       auto placeholder = rewriter.create<LLVM::UndefOp>(
451           loc, getTypeConverter()->convertType(memrefTy));
452       rewriter.replaceUsesOfBlockArgument(arg, placeholder);
453 
454       Value desc = MemRefDescriptor::fromStaticShape(
455           rewriter, loc, *getTypeConverter(), memrefTy, arg);
456       rewriter.replaceOp(placeholder, {desc});
457     }
458 
459     rewriter.eraseOp(funcOp);
460     return success();
461   }
462 };
463 
464 struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
465   using ConvertOpToLLVMPattern<func::ConstantOp>::ConvertOpToLLVMPattern;
466 
467   LogicalResult
468   matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor,
469                   ConversionPatternRewriter &rewriter) const override {
470     auto type = typeConverter->convertType(op.getResult().getType());
471     if (!type || !LLVM::isCompatibleType(type))
472       return rewriter.notifyMatchFailure(op, "failed to convert result type");
473 
474     auto newOp =
475         rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
476     for (const NamedAttribute &attr : op->getAttrs()) {
477       if (attr.getName().strref() == "value")
478         continue;
479       newOp->setAttr(attr.getName(), attr.getValue());
480     }
481     rewriter.replaceOp(op, newOp->getResults());
482     return success();
483   }
484 };
485 
486 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and
487 // passes the pointer to the MemRef across function boundaries.
488 template <typename CallOpType>
489 struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
490   using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
491   using Super = CallOpInterfaceLowering<CallOpType>;
492   using Base = ConvertOpToLLVMPattern<CallOpType>;
493 
494   LogicalResult
495   matchAndRewrite(CallOpType callOp, typename CallOpType::Adaptor adaptor,
496                   ConversionPatternRewriter &rewriter) const override {
497     // Pack the result types into a struct.
498     Type packedResult = nullptr;
499     unsigned numResults = callOp.getNumResults();
500     auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
501 
502     if (numResults != 0) {
503       if (!(packedResult =
504                 this->getTypeConverter()->packFunctionResults(resultTypes)))
505         return failure();
506     }
507 
508     auto promoted = this->getTypeConverter()->promoteOperands(
509         callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
510         adaptor.getOperands(), rewriter);
511     auto newOp = rewriter.create<LLVM::CallOp>(
512         callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
513         promoted, callOp->getAttrs());
514 
515     SmallVector<Value, 4> results;
516     if (numResults < 2) {
517       // If < 2 results, packing did not do anything and we can just return.
518       results.append(newOp.result_begin(), newOp.result_end());
519     } else {
520       // Otherwise, it had been converted to an operation producing a structure.
521       // Extract individual results from the structure and return them as list.
522       results.reserve(numResults);
523       for (unsigned i = 0; i < numResults; ++i) {
524         auto type =
525             this->typeConverter->convertType(callOp.getResult(i).getType());
526         results.push_back(rewriter.create<LLVM::ExtractValueOp>(
527             callOp.getLoc(), type, newOp->getResult(0),
528             rewriter.getI64ArrayAttr(i)));
529       }
530     }
531 
532     if (this->getTypeConverter()->getOptions().useBarePtrCallConv) {
533       // For the bare-ptr calling convention, promote memref results to
534       // descriptors.
535       assert(results.size() == resultTypes.size() &&
536              "The number of arguments and types doesn't match");
537       this->getTypeConverter()->promoteBarePtrsToDescriptors(
538           rewriter, callOp.getLoc(), resultTypes, results);
539     } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
540                                                     resultTypes, results,
541                                                     /*toDynamic=*/false))) {
542       return failure();
543     }
544 
545     rewriter.replaceOp(callOp, results);
546     return success();
547   }
548 };
549 
550 struct CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
551   using Super::Super;
552 };
553 
554 struct CallIndirectOpLowering
555     : public CallOpInterfaceLowering<func::CallIndirectOp> {
556   using Super::Super;
557 };
558 
559 struct UnrealizedConversionCastOpLowering
560     : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> {
561   using ConvertOpToLLVMPattern<
562       UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
563 
564   LogicalResult
565   matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
566                   ConversionPatternRewriter &rewriter) const override {
567     SmallVector<Type> convertedTypes;
568     if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(),
569                                               convertedTypes)) &&
570         convertedTypes == adaptor.getInputs().getTypes()) {
571       rewriter.replaceOp(op, adaptor.getInputs());
572       return success();
573     }
574 
575     convertedTypes.clear();
576     if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(),
577                                               convertedTypes)) &&
578         convertedTypes == op.getOutputs().getType()) {
579       rewriter.replaceOp(op, adaptor.getInputs());
580       return success();
581     }
582     return failure();
583   }
584 };
585 
586 // Special lowering pattern for `ReturnOps`.  Unlike all other operations,
587 // `ReturnOp` interacts with the function signature and must have as many
588 // operands as the function has return values.  Because in LLVM IR, functions
589 // can only return 0 or 1 value, we pack multiple values into a structure type.
590 // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
591 // necessary before returning it
592 struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
593   using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
594 
595   LogicalResult
596   matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
597                   ConversionPatternRewriter &rewriter) const override {
598     Location loc = op.getLoc();
599     unsigned numArguments = op.getNumOperands();
600     SmallVector<Value, 4> updatedOperands;
601 
602     if (getTypeConverter()->getOptions().useBarePtrCallConv) {
603       // For the bare-ptr calling convention, extract the aligned pointer to
604       // be returned from the memref descriptor.
605       for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
606         Type oldTy = std::get<0>(it).getType();
607         Value newOperand = std::get<1>(it);
608         if (oldTy.isa<MemRefType>() && getTypeConverter()->canConvertToBarePtr(
609                                            oldTy.cast<BaseMemRefType>())) {
610           MemRefDescriptor memrefDesc(newOperand);
611           newOperand = memrefDesc.alignedPtr(rewriter, loc);
612         } else if (oldTy.isa<UnrankedMemRefType>()) {
613           // Unranked memref is not supported in the bare pointer calling
614           // convention.
615           return failure();
616         }
617         updatedOperands.push_back(newOperand);
618       }
619     } else {
620       updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
621       (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
622                                     updatedOperands,
623                                     /*toDynamic=*/true);
624     }
625 
626     // If ReturnOp has 0 or 1 operand, create it and return immediately.
627     if (numArguments == 0) {
628       rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
629                                                   op->getAttrs());
630       return success();
631     }
632     if (numArguments == 1) {
633       rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
634           op, TypeRange(), updatedOperands, op->getAttrs());
635       return success();
636     }
637 
638     // Otherwise, we need to pack the arguments into an LLVM struct type before
639     // returning.
640     auto packedType = getTypeConverter()->packFunctionResults(
641         llvm::to_vector<4>(op.getOperandTypes()));
642 
643     Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
644     for (unsigned i = 0; i < numArguments; ++i) {
645       packed = rewriter.create<LLVM::InsertValueOp>(
646           loc, packedType, packed, updatedOperands[i],
647           rewriter.getI64ArrayAttr(i));
648     }
649     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
650                                                 op->getAttrs());
651     return success();
652   }
653 };
654 } // namespace
655 
656 void mlir::populateFuncToLLVMFuncOpConversionPattern(
657     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
658   if (converter.getOptions().useBarePtrCallConv)
659     patterns.add<BarePtrFuncOpConversion>(converter);
660   else
661     patterns.add<FuncOpConversion>(converter);
662 }
663 
664 void mlir::populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter,
665                                                 RewritePatternSet &patterns) {
666   populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
667   // clang-format off
668   patterns.add<
669       CallIndirectOpLowering,
670       CallOpLowering,
671       ConstantOpLowering,
672       ReturnOpLowering>(converter);
673   // clang-format on
674 }
675 
676 namespace {
677 /// A pass converting Func operations into the LLVM IR dialect.
678 struct ConvertFuncToLLVMPass
679     : public ConvertFuncToLLVMBase<ConvertFuncToLLVMPass> {
680   ConvertFuncToLLVMPass() = default;
681   ConvertFuncToLLVMPass(bool useBarePtrCallConv, unsigned indexBitwidth,
682                         bool useAlignedAlloc,
683                         const llvm::DataLayout &dataLayout) {
684     this->useBarePtrCallConv = useBarePtrCallConv;
685     this->indexBitwidth = indexBitwidth;
686     this->dataLayout = dataLayout.getStringRepresentation();
687   }
688 
689   /// Run the dialect converter on the module.
690   void runOnOperation() override {
691     if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
692             this->dataLayout, [this](const Twine &message) {
693               getOperation().emitError() << message.str();
694             }))) {
695       signalPassFailure();
696       return;
697     }
698 
699     ModuleOp m = getOperation();
700     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
701 
702     LowerToLLVMOptions options(&getContext(),
703                                dataLayoutAnalysis.getAtOrAbove(m));
704     options.useBarePtrCallConv = useBarePtrCallConv;
705     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
706       options.overrideIndexBitwidth(indexBitwidth);
707     options.dataLayout = llvm::DataLayout(this->dataLayout);
708 
709     LLVMTypeConverter typeConverter(&getContext(), options,
710                                     &dataLayoutAnalysis);
711 
712     RewritePatternSet patterns(&getContext());
713     populateFuncToLLVMConversionPatterns(typeConverter, patterns);
714 
715     // TODO: Remove these in favor of their dedicated conversion passes.
716     arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns);
717     cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
718 
719     LLVMConversionTarget target(getContext());
720     if (failed(applyPartialConversion(m, target, std::move(patterns))))
721       signalPassFailure();
722 
723     m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
724                StringAttr::get(m.getContext(), this->dataLayout));
725   }
726 };
727 } // namespace
728 
729 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertFuncToLLVMPass() {
730   return std::make_unique<ConvertFuncToLLVMPass>();
731 }
732 
733 std::unique_ptr<OperationPass<ModuleOp>>
734 mlir::createConvertFuncToLLVMPass(const LowerToLLVMOptions &options) {
735   auto allocLowering = options.allocLowering;
736   // There is no way to provide additional patterns for pass, so
737   // AllocLowering::None will always fail.
738   assert(allocLowering != LowerToLLVMOptions::AllocLowering::None &&
739          "ConvertFuncToLLVMPass doesn't support AllocLowering::None");
740   bool useAlignedAlloc =
741       (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc);
742   return std::make_unique<ConvertFuncToLLVMPass>(
743       options.useBarePtrCallConv, options.getIndexBitwidth(), useAlignedAlloc,
744       options.dataLayout);
745 }
746