1 //===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert MLIR Func and builtin dialects
10 // into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "../PassDetail.h"
15 #include "mlir/Analysis/DataLayoutAnalysis.h"
16 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
17 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
18 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
19 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
20 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
21 #include "mlir/Conversion/LLVMCommon/Pattern.h"
22 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"
24 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
25 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
26 #include "mlir/Dialect/Utils/StaticValueUtils.h"
27 #include "mlir/IR/Attributes.h"
28 #include "mlir/IR/BlockAndValueMapping.h"
29 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinOps.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
33 #include "mlir/Support/LogicalResult.h"
34 #include "mlir/Support/MathExtras.h"
35 #include "mlir/Transforms/DialectConversion.h"
36 #include "mlir/Transforms/Passes.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 #include "llvm/IR/DerivedTypes.h"
39 #include "llvm/IR/IRBuilder.h"
40 #include "llvm/IR/Type.h"
41 #include "llvm/Support/CommandLine.h"
42 #include "llvm/Support/FormatVariadic.h"
43 #include <functional>
44 
45 using namespace mlir;
46 
47 #define PASS_NAME "convert-func-to-llvm"
48 
49 /// Only retain those attributes that are not constructed by
50 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
51 /// attributes.
52 static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
53                                  bool filterArgAttrs,
54                                  SmallVectorImpl<NamedAttribute> &result) {
55   for (const auto &attr : attrs) {
56     if (attr.getName() == SymbolTable::getSymbolAttrName() ||
57         attr.getName() == FunctionOpInterface::getTypeAttrName() ||
58         attr.getName() == "func.varargs" ||
59         (filterArgAttrs &&
60          attr.getName() == FunctionOpInterface::getArgDictAttrName()))
61       continue;
62     result.push_back(attr);
63   }
64 }
65 
66 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
67 /// arguments instead of unpacked arguments. This function can be called from C
68 /// by passing a pointer to a C struct corresponding to a memref descriptor.
69 /// Similarly, returned memrefs are passed via pointers to a C struct that is
70 /// passed as additional argument.
71 /// Internally, the auxiliary function unpacks the descriptor into individual
72 /// components and forwards them to `newFuncOp` and forwards the results to
73 /// the extra arguments.
74 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
75                                    LLVMTypeConverter &typeConverter,
76                                    FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
77   auto type = funcOp.getType();
78   SmallVector<NamedAttribute, 4> attributes;
79   filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/false,
80                        attributes);
81   Type wrapperFuncType;
82   bool resultIsNowArg;
83   std::tie(wrapperFuncType, resultIsNowArg) =
84       typeConverter.convertFunctionTypeCWrapper(type);
85   auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
86       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
87       wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false, attributes);
88 
89   OpBuilder::InsertionGuard guard(rewriter);
90   rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
91 
92   SmallVector<Value, 8> args;
93   size_t argOffset = resultIsNowArg ? 1 : 0;
94   for (auto &en : llvm::enumerate(type.getInputs())) {
95     Value arg = wrapperFuncOp.getArgument(en.index() + argOffset);
96     if (auto memrefType = en.value().dyn_cast<MemRefType>()) {
97       Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
98       MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
99       continue;
100     }
101     if (en.value().isa<UnrankedMemRefType>()) {
102       Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
103       UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
104       continue;
105     }
106 
107     args.push_back(arg);
108   }
109 
110   auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
111 
112   if (resultIsNowArg) {
113     rewriter.create<LLVM::StoreOp>(loc, call.getResult(0),
114                                    wrapperFuncOp.getArgument(0));
115     rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
116   } else {
117     rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
118   }
119 }
120 
121 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
122 /// arguments instead of unpacked arguments. Creates a body for the (external)
123 /// `newFuncOp` that allocates a memref descriptor on stack, packs the
124 /// individual arguments into this descriptor and passes a pointer to it into
125 /// the auxiliary function. If the result of the function cannot be directly
126 /// returned, we write it to a special first argument that provides a pointer
127 /// to a corresponding struct. This auxiliary external function is now
128 /// compatible with functions defined in C using pointers to C structs
129 /// corresponding to a memref descriptor.
130 static void wrapExternalFunction(OpBuilder &builder, Location loc,
131                                  LLVMTypeConverter &typeConverter,
132                                  FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
133   OpBuilder::InsertionGuard guard(builder);
134 
135   Type wrapperType;
136   bool resultIsNowArg;
137   std::tie(wrapperType, resultIsNowArg) =
138       typeConverter.convertFunctionTypeCWrapper(funcOp.getType());
139   // This conversion can only fail if it could not convert one of the argument
140   // types. But since it has been applied to a non-wrapper function before, it
141   // should have failed earlier and not reach this point at all.
142   assert(wrapperType && "unexpected type conversion failure");
143 
144   SmallVector<NamedAttribute, 4> attributes;
145   filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/false,
146                        attributes);
147 
148   // Create the auxiliary function.
149   auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
150       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
151       wrapperType, LLVM::Linkage::External, /*dsoLocal*/ false, attributes);
152 
153   builder.setInsertionPointToStart(newFuncOp.addEntryBlock());
154 
155   // Get a ValueRange containing arguments.
156   FunctionType type = funcOp.getType();
157   SmallVector<Value, 8> args;
158   args.reserve(type.getNumInputs());
159   ValueRange wrapperArgsRange(newFuncOp.getArguments());
160 
161   if (resultIsNowArg) {
162     // Allocate the struct on the stack and pass the pointer.
163     Type resultType =
164         wrapperType.cast<LLVM::LLVMFunctionType>().getParamType(0);
165     Value one = builder.create<LLVM::ConstantOp>(
166         loc, typeConverter.convertType(builder.getIndexType()),
167         builder.getIntegerAttr(builder.getIndexType(), 1));
168     Value result = builder.create<LLVM::AllocaOp>(loc, resultType, one);
169     args.push_back(result);
170   }
171 
172   // Iterate over the inputs of the original function and pack values into
173   // memref descriptors if the original type is a memref.
174   for (auto &en : llvm::enumerate(type.getInputs())) {
175     Value arg;
176     int numToDrop = 1;
177     auto memRefType = en.value().dyn_cast<MemRefType>();
178     auto unrankedMemRefType = en.value().dyn_cast<UnrankedMemRefType>();
179     if (memRefType || unrankedMemRefType) {
180       numToDrop = memRefType
181                       ? MemRefDescriptor::getNumUnpackedValues(memRefType)
182                       : UnrankedMemRefDescriptor::getNumUnpackedValues();
183       Value packed =
184           memRefType
185               ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
186                                        wrapperArgsRange.take_front(numToDrop))
187               : UnrankedMemRefDescriptor::pack(
188                     builder, loc, typeConverter, unrankedMemRefType,
189                     wrapperArgsRange.take_front(numToDrop));
190 
191       auto ptrTy = LLVM::LLVMPointerType::get(packed.getType());
192       Value one = builder.create<LLVM::ConstantOp>(
193           loc, typeConverter.convertType(builder.getIndexType()),
194           builder.getIntegerAttr(builder.getIndexType(), 1));
195       Value allocated =
196           builder.create<LLVM::AllocaOp>(loc, ptrTy, one, /*alignment=*/0);
197       builder.create<LLVM::StoreOp>(loc, packed, allocated);
198       arg = allocated;
199     } else {
200       arg = wrapperArgsRange[0];
201     }
202 
203     args.push_back(arg);
204     wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
205   }
206   assert(wrapperArgsRange.empty() && "did not map some of the arguments");
207 
208   auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
209 
210   if (resultIsNowArg) {
211     Value result = builder.create<LLVM::LoadOp>(loc, args.front());
212     builder.create<LLVM::ReturnOp>(loc, ValueRange{result});
213   } else {
214     builder.create<LLVM::ReturnOp>(loc, call.getResults());
215   }
216 }
217 
218 namespace {
219 
220 struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
221 protected:
222   using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
223 
224   // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
225   // to this legalization pattern.
226   LLVM::LLVMFuncOp
227   convertFuncOpToLLVMFuncOp(FuncOp funcOp,
228                             ConversionPatternRewriter &rewriter) const {
229     // Convert the original function arguments. They are converted using the
230     // LLVMTypeConverter provided to this legalization pattern.
231     auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs");
232     TypeConverter::SignatureConversion result(funcOp.getNumArguments());
233     auto llvmType = getTypeConverter()->convertFunctionSignature(
234         funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
235     if (!llvmType)
236       return nullptr;
237 
238     // Propagate argument attributes to all converted arguments obtained after
239     // converting a given original argument.
240     SmallVector<NamedAttribute, 4> attributes;
241     filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
242                          attributes);
243     if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
244       SmallVector<Attribute, 4> newArgAttrs(
245           llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
246       for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
247         auto mapping = result.getInputMapping(i);
248         assert(mapping.hasValue() &&
249                "unexpected deletion of function argument");
250         for (size_t j = 0; j < mapping->size; ++j)
251           newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
252       }
253       attributes.push_back(
254           rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
255                                 rewriter.getArrayAttr(newArgAttrs)));
256     }
257     for (const auto &pair : llvm::enumerate(attributes)) {
258       if (pair.value().getName() == "llvm.linkage") {
259         attributes.erase(attributes.begin() + pair.index());
260         break;
261       }
262     }
263 
264     // Create an LLVM function, use external linkage by default until MLIR
265     // functions have linkage.
266     LLVM::Linkage linkage = LLVM::Linkage::External;
267     if (funcOp->hasAttr("llvm.linkage")) {
268       auto attr =
269           funcOp->getAttr("llvm.linkage").dyn_cast<mlir::LLVM::LinkageAttr>();
270       if (!attr) {
271         funcOp->emitError()
272             << "Contains llvm.linkage attribute not of type LLVM::LinkageAttr";
273         return nullptr;
274       }
275       linkage = attr.getLinkage();
276     }
277     auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
278         funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
279         /*dsoLocal*/ false, attributes);
280     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
281                                 newFuncOp.end());
282     if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
283                                            &result)))
284       return nullptr;
285 
286     return newFuncOp;
287   }
288 };
289 
290 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
291 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
292 /// information.
293 static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
294 struct FuncOpConversion : public FuncOpConversionBase {
295   FuncOpConversion(LLVMTypeConverter &converter)
296       : FuncOpConversionBase(converter) {}
297 
298   LogicalResult
299   matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
300                   ConversionPatternRewriter &rewriter) const override {
301     auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
302     if (!newFuncOp)
303       return failure();
304 
305     if (getTypeConverter()->getOptions().emitCWrappers ||
306         funcOp->getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) {
307       if (newFuncOp.isExternal())
308         wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(),
309                              funcOp, newFuncOp);
310       else
311         wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(),
312                                funcOp, newFuncOp);
313     }
314 
315     rewriter.eraseOp(funcOp);
316     return success();
317   }
318 };
319 
320 /// FuncOp legalization pattern that converts MemRef arguments to bare pointers
321 /// to the MemRef element type. This will impact the calling convention and ABI.
322 struct BarePtrFuncOpConversion : public FuncOpConversionBase {
323   using FuncOpConversionBase::FuncOpConversionBase;
324 
325   LogicalResult
326   matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
327                   ConversionPatternRewriter &rewriter) const override {
328 
329     // TODO: bare ptr conversion could be handled by argument materialization
330     // and most of the code below would go away. But to do this, we would need a
331     // way to distinguish between FuncOp and other regions in the
332     // addArgumentMaterialization hook.
333 
334     // Store the type of memref-typed arguments before the conversion so that we
335     // can promote them to MemRef descriptor at the beginning of the function.
336     SmallVector<Type, 8> oldArgTypes =
337         llvm::to_vector<8>(funcOp.getType().getInputs());
338 
339     auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
340     if (!newFuncOp)
341       return failure();
342     if (newFuncOp.getBody().empty()) {
343       rewriter.eraseOp(funcOp);
344       return success();
345     }
346 
347     // Promote bare pointers from memref arguments to memref descriptors at the
348     // beginning of the function so that all the memrefs in the function have a
349     // uniform representation.
350     Block *entryBlock = &newFuncOp.getBody().front();
351     auto blockArgs = entryBlock->getArguments();
352     assert(blockArgs.size() == oldArgTypes.size() &&
353            "The number of arguments and types doesn't match");
354 
355     OpBuilder::InsertionGuard guard(rewriter);
356     rewriter.setInsertionPointToStart(entryBlock);
357     for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
358       BlockArgument arg = std::get<0>(it);
359       Type argTy = std::get<1>(it);
360 
361       // Unranked memrefs are not supported in the bare pointer calling
362       // convention. We should have bailed out before in the presence of
363       // unranked memrefs.
364       assert(!argTy.isa<UnrankedMemRefType>() &&
365              "Unranked memref is not supported");
366       auto memrefTy = argTy.dyn_cast<MemRefType>();
367       if (!memrefTy)
368         continue;
369 
370       // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
371       // or unranked memref descriptor and replace placeholder with the last
372       // instruction of the memref descriptor.
373       // TODO: The placeholder is needed to avoid replacing barePtr uses in the
374       // MemRef descriptor instructions. We may want to have a utility in the
375       // rewriter to properly handle this use case.
376       Location loc = funcOp.getLoc();
377       auto placeholder = rewriter.create<LLVM::UndefOp>(
378           loc, getTypeConverter()->convertType(memrefTy));
379       rewriter.replaceUsesOfBlockArgument(arg, placeholder);
380 
381       Value desc = MemRefDescriptor::fromStaticShape(
382           rewriter, loc, *getTypeConverter(), memrefTy, arg);
383       rewriter.replaceOp(placeholder, {desc});
384     }
385 
386     rewriter.eraseOp(funcOp);
387     return success();
388   }
389 };
390 
391 struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
392   using ConvertOpToLLVMPattern<func::ConstantOp>::ConvertOpToLLVMPattern;
393 
394   LogicalResult
395   matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor,
396                   ConversionPatternRewriter &rewriter) const override {
397     auto type = typeConverter->convertType(op.getResult().getType());
398     if (!type || !LLVM::isCompatibleType(type))
399       return rewriter.notifyMatchFailure(op, "failed to convert result type");
400 
401     auto newOp =
402         rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
403     for (const NamedAttribute &attr : op->getAttrs()) {
404       if (attr.getName().strref() == "value")
405         continue;
406       newOp->setAttr(attr.getName(), attr.getValue());
407     }
408     rewriter.replaceOp(op, newOp->getResults());
409     return success();
410   }
411 };
412 
413 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and
414 // passes the pointer to the MemRef across function boundaries.
415 template <typename CallOpType>
416 struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
417   using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
418   using Super = CallOpInterfaceLowering<CallOpType>;
419   using Base = ConvertOpToLLVMPattern<CallOpType>;
420 
421   LogicalResult
422   matchAndRewrite(CallOpType callOp, typename CallOpType::Adaptor adaptor,
423                   ConversionPatternRewriter &rewriter) const override {
424     // Pack the result types into a struct.
425     Type packedResult = nullptr;
426     unsigned numResults = callOp.getNumResults();
427     auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
428 
429     if (numResults != 0) {
430       if (!(packedResult =
431                 this->getTypeConverter()->packFunctionResults(resultTypes)))
432         return failure();
433     }
434 
435     auto promoted = this->getTypeConverter()->promoteOperands(
436         callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
437         adaptor.getOperands(), rewriter);
438     auto newOp = rewriter.create<LLVM::CallOp>(
439         callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
440         promoted, callOp->getAttrs());
441 
442     SmallVector<Value, 4> results;
443     if (numResults < 2) {
444       // If < 2 results, packing did not do anything and we can just return.
445       results.append(newOp.result_begin(), newOp.result_end());
446     } else {
447       // Otherwise, it had been converted to an operation producing a structure.
448       // Extract individual results from the structure and return them as list.
449       results.reserve(numResults);
450       for (unsigned i = 0; i < numResults; ++i) {
451         auto type =
452             this->typeConverter->convertType(callOp.getResult(i).getType());
453         results.push_back(rewriter.create<LLVM::ExtractValueOp>(
454             callOp.getLoc(), type, newOp->getResult(0),
455             rewriter.getI64ArrayAttr(i)));
456       }
457     }
458 
459     if (this->getTypeConverter()->getOptions().useBarePtrCallConv) {
460       // For the bare-ptr calling convention, promote memref results to
461       // descriptors.
462       assert(results.size() == resultTypes.size() &&
463              "The number of arguments and types doesn't match");
464       this->getTypeConverter()->promoteBarePtrsToDescriptors(
465           rewriter, callOp.getLoc(), resultTypes, results);
466     } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
467                                                     resultTypes, results,
468                                                     /*toDynamic=*/false))) {
469       return failure();
470     }
471 
472     rewriter.replaceOp(callOp, results);
473     return success();
474   }
475 };
476 
477 struct CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
478   using Super::Super;
479 };
480 
481 struct CallIndirectOpLowering
482     : public CallOpInterfaceLowering<func::CallIndirectOp> {
483   using Super::Super;
484 };
485 
486 struct UnrealizedConversionCastOpLowering
487     : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> {
488   using ConvertOpToLLVMPattern<
489       UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
490 
491   LogicalResult
492   matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
493                   ConversionPatternRewriter &rewriter) const override {
494     SmallVector<Type> convertedTypes;
495     if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(),
496                                               convertedTypes)) &&
497         convertedTypes == adaptor.getInputs().getTypes()) {
498       rewriter.replaceOp(op, adaptor.getInputs());
499       return success();
500     }
501 
502     convertedTypes.clear();
503     if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(),
504                                               convertedTypes)) &&
505         convertedTypes == op.getOutputs().getType()) {
506       rewriter.replaceOp(op, adaptor.getInputs());
507       return success();
508     }
509     return failure();
510   }
511 };
512 
513 // Special lowering pattern for `ReturnOps`.  Unlike all other operations,
514 // `ReturnOp` interacts with the function signature and must have as many
515 // operands as the function has return values.  Because in LLVM IR, functions
516 // can only return 0 or 1 value, we pack multiple values into a structure type.
517 // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
518 // necessary before returning it
519 struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
520   using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
521 
522   LogicalResult
523   matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
524                   ConversionPatternRewriter &rewriter) const override {
525     Location loc = op.getLoc();
526     unsigned numArguments = op.getNumOperands();
527     SmallVector<Value, 4> updatedOperands;
528 
529     if (getTypeConverter()->getOptions().useBarePtrCallConv) {
530       // For the bare-ptr calling convention, extract the aligned pointer to
531       // be returned from the memref descriptor.
532       for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
533         Type oldTy = std::get<0>(it).getType();
534         Value newOperand = std::get<1>(it);
535         if (oldTy.isa<MemRefType>()) {
536           MemRefDescriptor memrefDesc(newOperand);
537           newOperand = memrefDesc.alignedPtr(rewriter, loc);
538         } else if (oldTy.isa<UnrankedMemRefType>()) {
539           // Unranked memref is not supported in the bare pointer calling
540           // convention.
541           return failure();
542         }
543         updatedOperands.push_back(newOperand);
544       }
545     } else {
546       updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
547       (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
548                                     updatedOperands,
549                                     /*toDynamic=*/true);
550     }
551 
552     // If ReturnOp has 0 or 1 operand, create it and return immediately.
553     if (numArguments == 0) {
554       rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
555                                                   op->getAttrs());
556       return success();
557     }
558     if (numArguments == 1) {
559       rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
560           op, TypeRange(), updatedOperands, op->getAttrs());
561       return success();
562     }
563 
564     // Otherwise, we need to pack the arguments into an LLVM struct type before
565     // returning.
566     auto packedType = getTypeConverter()->packFunctionResults(
567         llvm::to_vector<4>(op.getOperandTypes()));
568 
569     Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
570     for (unsigned i = 0; i < numArguments; ++i) {
571       packed = rewriter.create<LLVM::InsertValueOp>(
572           loc, packedType, packed, updatedOperands[i],
573           rewriter.getI64ArrayAttr(i));
574     }
575     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
576                                                 op->getAttrs());
577     return success();
578   }
579 };
580 } // namespace
581 
582 void mlir::populateFuncToLLVMFuncOpConversionPattern(
583     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
584   if (converter.getOptions().useBarePtrCallConv)
585     patterns.add<BarePtrFuncOpConversion>(converter);
586   else
587     patterns.add<FuncOpConversion>(converter);
588 }
589 
590 void mlir::populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter,
591                                                 RewritePatternSet &patterns) {
592   populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
593   // clang-format off
594   patterns.add<
595       CallIndirectOpLowering,
596       CallOpLowering,
597       ConstantOpLowering,
598       ReturnOpLowering>(converter);
599   // clang-format on
600 }
601 
602 namespace {
603 /// A pass converting Func operations into the LLVM IR dialect.
604 struct ConvertFuncToLLVMPass
605     : public ConvertFuncToLLVMBase<ConvertFuncToLLVMPass> {
606   ConvertFuncToLLVMPass() = default;
607   ConvertFuncToLLVMPass(bool useBarePtrCallConv, bool emitCWrappers,
608                         unsigned indexBitwidth, bool useAlignedAlloc,
609                         const llvm::DataLayout &dataLayout) {
610     this->useBarePtrCallConv = useBarePtrCallConv;
611     this->emitCWrappers = emitCWrappers;
612     this->indexBitwidth = indexBitwidth;
613     this->dataLayout = dataLayout.getStringRepresentation();
614   }
615 
616   /// Run the dialect converter on the module.
617   void runOnOperation() override {
618     if (useBarePtrCallConv && emitCWrappers) {
619       getOperation().emitError()
620           << "incompatible conversion options: bare-pointer calling convention "
621              "and C wrapper emission";
622       signalPassFailure();
623       return;
624     }
625     if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
626             this->dataLayout, [this](const Twine &message) {
627               getOperation().emitError() << message.str();
628             }))) {
629       signalPassFailure();
630       return;
631     }
632 
633     ModuleOp m = getOperation();
634     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
635 
636     LowerToLLVMOptions options(&getContext(),
637                                dataLayoutAnalysis.getAtOrAbove(m));
638     options.useBarePtrCallConv = useBarePtrCallConv;
639     options.emitCWrappers = emitCWrappers;
640     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
641       options.overrideIndexBitwidth(indexBitwidth);
642     options.dataLayout = llvm::DataLayout(this->dataLayout);
643 
644     LLVMTypeConverter typeConverter(&getContext(), options,
645                                     &dataLayoutAnalysis);
646 
647     RewritePatternSet patterns(&getContext());
648     populateFuncToLLVMConversionPatterns(typeConverter, patterns);
649 
650     // TODO: Remove these in favor of their dedicated conversion passes.
651     arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns);
652     cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
653 
654     LLVMConversionTarget target(getContext());
655     if (failed(applyPartialConversion(m, target, std::move(patterns))))
656       signalPassFailure();
657 
658     m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
659                StringAttr::get(m.getContext(), this->dataLayout));
660   }
661 };
662 } // namespace
663 
664 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertFuncToLLVMPass() {
665   return std::make_unique<ConvertFuncToLLVMPass>();
666 }
667 
668 std::unique_ptr<OperationPass<ModuleOp>>
669 mlir::createConvertFuncToLLVMPass(const LowerToLLVMOptions &options) {
670   auto allocLowering = options.allocLowering;
671   // There is no way to provide additional patterns for pass, so
672   // AllocLowering::None will always fail.
673   assert(allocLowering != LowerToLLVMOptions::AllocLowering::None &&
674          "ConvertFuncToLLVMPass doesn't support AllocLowering::None");
675   bool useAlignedAlloc =
676       (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc);
677   return std::make_unique<ConvertFuncToLLVMPass>(
678       options.useBarePtrCallConv, options.emitCWrappers,
679       options.getIndexBitwidth(), useAlignedAlloc, options.dataLayout);
680 }
681