1 //===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert MLIR Func and builtin dialects
10 // into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "../PassDetail.h"
15 #include "mlir/Analysis/DataLayoutAnalysis.h"
16 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
17 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
18 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
19 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
20 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
21 #include "mlir/Conversion/LLVMCommon/Pattern.h"
22 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"
24 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
25 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
26 #include "mlir/Dialect/Utils/StaticValueUtils.h"
27 #include "mlir/IR/Attributes.h"
28 #include "mlir/IR/BlockAndValueMapping.h"
29 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinOps.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
33 #include "mlir/Support/LogicalResult.h"
34 #include "mlir/Support/MathExtras.h"
35 #include "mlir/Transforms/DialectConversion.h"
36 #include "mlir/Transforms/Passes.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 #include "llvm/IR/DerivedTypes.h"
39 #include "llvm/IR/IRBuilder.h"
40 #include "llvm/IR/Type.h"
41 #include "llvm/Support/CommandLine.h"
42 #include "llvm/Support/FormatVariadic.h"
43 #include <algorithm>
44 #include <functional>
45
46 using namespace mlir;
47
48 #define PASS_NAME "convert-func-to-llvm"
49
50 /// Only retain those attributes that are not constructed by
51 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
52 /// attributes.
filterFuncAttributes(ArrayRef<NamedAttribute> attrs,bool filterArgAndResAttrs,SmallVectorImpl<NamedAttribute> & result)53 static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
54 bool filterArgAndResAttrs,
55 SmallVectorImpl<NamedAttribute> &result) {
56 for (const auto &attr : attrs) {
57 if (attr.getName() == SymbolTable::getSymbolAttrName() ||
58 attr.getName() == FunctionOpInterface::getTypeAttrName() ||
59 attr.getName() == "func.varargs" ||
60 (filterArgAndResAttrs &&
61 (attr.getName() == FunctionOpInterface::getArgDictAttrName() ||
62 attr.getName() == FunctionOpInterface::getResultDictAttrName())))
63 continue;
64 result.push_back(attr);
65 }
66 }
67
68 /// Helper function for wrapping all attributes into a single DictionaryAttr
wrapAsStructAttrs(OpBuilder & b,ArrayAttr attrs)69 static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
70 return DictionaryAttr::get(
71 b.getContext(),
72 b.getNamedAttr(LLVM::LLVMDialect::getStructAttrsAttrName(), attrs));
73 }
74
75 /// Combines all result attributes into a single DictionaryAttr
76 /// and prepends to argument attrs.
77 /// This is intended to be used to format the attributes for a C wrapper
78 /// function when the result(s) is converted to the first function argument
79 /// (in the multiple return case, all returns get wrapped into a single
80 /// argument). The total number of argument attributes should be equal to
81 /// (number of function arguments) + 1.
82 static void
prependResAttrsToArgAttrs(OpBuilder & builder,SmallVectorImpl<NamedAttribute> & attributes,size_t numArguments)83 prependResAttrsToArgAttrs(OpBuilder &builder,
84 SmallVectorImpl<NamedAttribute> &attributes,
85 size_t numArguments) {
86 auto allAttrs = SmallVector<Attribute>(
87 numArguments + 1, DictionaryAttr::get(builder.getContext()));
88 NamedAttribute *argAttrs = nullptr;
89 for (auto *it = attributes.begin(); it != attributes.end();) {
90 if (it->getName() == FunctionOpInterface::getArgDictAttrName()) {
91 auto arrayAttrs = it->getValue().cast<ArrayAttr>();
92 assert(arrayAttrs.size() == numArguments &&
93 "Number of arg attrs and args should match");
94 std::copy(arrayAttrs.begin(), arrayAttrs.end(), allAttrs.begin() + 1);
95 argAttrs = it;
96 } else if (it->getName() == FunctionOpInterface::getResultDictAttrName()) {
97 auto arrayAttrs = it->getValue().cast<ArrayAttr>();
98 assert(!arrayAttrs.empty() && "expected array to be non-empty");
99 allAttrs[0] = (arrayAttrs.size() == 1)
100 ? arrayAttrs[0]
101 : wrapAsStructAttrs(builder, arrayAttrs);
102 it = attributes.erase(it);
103 continue;
104 }
105 it++;
106 }
107
108 auto newArgAttrs =
109 builder.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
110 builder.getArrayAttr(allAttrs));
111 if (!argAttrs) {
112 attributes.emplace_back(newArgAttrs);
113 return;
114 }
115 *argAttrs = newArgAttrs;
116 }
117
118 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
119 /// arguments instead of unpacked arguments. This function can be called from C
120 /// by passing a pointer to a C struct corresponding to a memref descriptor.
121 /// Similarly, returned memrefs are passed via pointers to a C struct that is
122 /// passed as additional argument.
123 /// Internally, the auxiliary function unpacks the descriptor into individual
124 /// components and forwards them to `newFuncOp` and forwards the results to
125 /// the extra arguments.
wrapForExternalCallers(OpBuilder & rewriter,Location loc,LLVMTypeConverter & typeConverter,func::FuncOp funcOp,LLVM::LLVMFuncOp newFuncOp)126 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
127 LLVMTypeConverter &typeConverter,
128 func::FuncOp funcOp,
129 LLVM::LLVMFuncOp newFuncOp) {
130 auto type = funcOp.getFunctionType();
131 SmallVector<NamedAttribute, 4> attributes;
132 filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
133 attributes);
134 Type wrapperFuncType;
135 bool resultIsNowArg;
136 std::tie(wrapperFuncType, resultIsNowArg) =
137 typeConverter.convertFunctionTypeCWrapper(type);
138 if (resultIsNowArg)
139 prependResAttrsToArgAttrs(rewriter, attributes, funcOp.getNumArguments());
140 auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
141 loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
142 wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false,
143 /*cconv*/ LLVM::CConv::C, attributes);
144
145 OpBuilder::InsertionGuard guard(rewriter);
146 rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
147
148 SmallVector<Value, 8> args;
149 size_t argOffset = resultIsNowArg ? 1 : 0;
150 for (auto &en : llvm::enumerate(type.getInputs())) {
151 Value arg = wrapperFuncOp.getArgument(en.index() + argOffset);
152 if (auto memrefType = en.value().dyn_cast<MemRefType>()) {
153 Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
154 MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
155 continue;
156 }
157 if (en.value().isa<UnrankedMemRefType>()) {
158 Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
159 UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
160 continue;
161 }
162
163 args.push_back(arg);
164 }
165
166 auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
167
168 if (resultIsNowArg) {
169 rewriter.create<LLVM::StoreOp>(loc, call.getResult(0),
170 wrapperFuncOp.getArgument(0));
171 rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
172 } else {
173 rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
174 }
175 }
176
177 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
178 /// arguments instead of unpacked arguments. Creates a body for the (external)
179 /// `newFuncOp` that allocates a memref descriptor on stack, packs the
180 /// individual arguments into this descriptor and passes a pointer to it into
181 /// the auxiliary function. If the result of the function cannot be directly
182 /// returned, we write it to a special first argument that provides a pointer
183 /// to a corresponding struct. This auxiliary external function is now
184 /// compatible with functions defined in C using pointers to C structs
185 /// corresponding to a memref descriptor.
wrapExternalFunction(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,func::FuncOp funcOp,LLVM::LLVMFuncOp newFuncOp)186 static void wrapExternalFunction(OpBuilder &builder, Location loc,
187 LLVMTypeConverter &typeConverter,
188 func::FuncOp funcOp,
189 LLVM::LLVMFuncOp newFuncOp) {
190 OpBuilder::InsertionGuard guard(builder);
191
192 Type wrapperType;
193 bool resultIsNowArg;
194 std::tie(wrapperType, resultIsNowArg) =
195 typeConverter.convertFunctionTypeCWrapper(funcOp.getFunctionType());
196 // This conversion can only fail if it could not convert one of the argument
197 // types. But since it has been applied to a non-wrapper function before, it
198 // should have failed earlier and not reach this point at all.
199 assert(wrapperType && "unexpected type conversion failure");
200
201 SmallVector<NamedAttribute, 4> attributes;
202 filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
203 attributes);
204
205 if (resultIsNowArg)
206 prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments());
207 // Create the auxiliary function.
208 auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
209 loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
210 wrapperType, LLVM::Linkage::External, /*dsoLocal*/ false,
211 /*cconv*/ LLVM::CConv::C, attributes);
212
213 builder.setInsertionPointToStart(newFuncOp.addEntryBlock());
214
215 // Get a ValueRange containing arguments.
216 FunctionType type = funcOp.getFunctionType();
217 SmallVector<Value, 8> args;
218 args.reserve(type.getNumInputs());
219 ValueRange wrapperArgsRange(newFuncOp.getArguments());
220
221 if (resultIsNowArg) {
222 // Allocate the struct on the stack and pass the pointer.
223 Type resultType =
224 wrapperType.cast<LLVM::LLVMFunctionType>().getParamType(0);
225 Value one = builder.create<LLVM::ConstantOp>(
226 loc, typeConverter.convertType(builder.getIndexType()),
227 builder.getIntegerAttr(builder.getIndexType(), 1));
228 Value result = builder.create<LLVM::AllocaOp>(loc, resultType, one);
229 args.push_back(result);
230 }
231
232 // Iterate over the inputs of the original function and pack values into
233 // memref descriptors if the original type is a memref.
234 for (auto &en : llvm::enumerate(type.getInputs())) {
235 Value arg;
236 int numToDrop = 1;
237 auto memRefType = en.value().dyn_cast<MemRefType>();
238 auto unrankedMemRefType = en.value().dyn_cast<UnrankedMemRefType>();
239 if (memRefType || unrankedMemRefType) {
240 numToDrop = memRefType
241 ? MemRefDescriptor::getNumUnpackedValues(memRefType)
242 : UnrankedMemRefDescriptor::getNumUnpackedValues();
243 Value packed =
244 memRefType
245 ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
246 wrapperArgsRange.take_front(numToDrop))
247 : UnrankedMemRefDescriptor::pack(
248 builder, loc, typeConverter, unrankedMemRefType,
249 wrapperArgsRange.take_front(numToDrop));
250
251 auto ptrTy = LLVM::LLVMPointerType::get(packed.getType());
252 Value one = builder.create<LLVM::ConstantOp>(
253 loc, typeConverter.convertType(builder.getIndexType()),
254 builder.getIntegerAttr(builder.getIndexType(), 1));
255 Value allocated =
256 builder.create<LLVM::AllocaOp>(loc, ptrTy, one, /*alignment=*/0);
257 builder.create<LLVM::StoreOp>(loc, packed, allocated);
258 arg = allocated;
259 } else {
260 arg = wrapperArgsRange[0];
261 }
262
263 args.push_back(arg);
264 wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
265 }
266 assert(wrapperArgsRange.empty() && "did not map some of the arguments");
267
268 auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
269
270 if (resultIsNowArg) {
271 Value result = builder.create<LLVM::LoadOp>(loc, args.front());
272 builder.create<LLVM::ReturnOp>(loc, ValueRange{result});
273 } else {
274 builder.create<LLVM::ReturnOp>(loc, call.getResults());
275 }
276 }
277
278 namespace {
279
280 struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
281 protected:
282 using ConvertOpToLLVMPattern<func::FuncOp>::ConvertOpToLLVMPattern;
283
284 // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
285 // to this legalization pattern.
286 LLVM::LLVMFuncOp
convertFuncOpToLLVMFuncOp__anon290525710111::FuncOpConversionBase287 convertFuncOpToLLVMFuncOp(func::FuncOp funcOp,
288 ConversionPatternRewriter &rewriter) const {
289 // Convert the original function arguments. They are converted using the
290 // LLVMTypeConverter provided to this legalization pattern.
291 auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs");
292 TypeConverter::SignatureConversion result(funcOp.getNumArguments());
293 auto llvmType = getTypeConverter()->convertFunctionSignature(
294 funcOp.getFunctionType(), varargsAttr && varargsAttr.getValue(),
295 result);
296 if (!llvmType)
297 return nullptr;
298
299 // Propagate argument/result attributes to all converted arguments/result
300 // obtained after converting a given original argument/result.
301 SmallVector<NamedAttribute, 4> attributes;
302 filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true,
303 attributes);
304 if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
305 assert(!resAttrDicts.empty() && "expected array to be non-empty");
306 auto newResAttrDicts =
307 (funcOp.getNumResults() == 1)
308 ? resAttrDicts
309 : rewriter.getArrayAttr(
310 {wrapAsStructAttrs(rewriter, resAttrDicts)});
311 attributes.push_back(rewriter.getNamedAttr(
312 FunctionOpInterface::getResultDictAttrName(), newResAttrDicts));
313 }
314 if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
315 SmallVector<Attribute, 4> newArgAttrs(
316 llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
317 for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
318 auto mapping = result.getInputMapping(i);
319 assert(mapping && "unexpected deletion of function argument");
320 for (size_t j = 0; j < mapping->size; ++j)
321 newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
322 }
323 attributes.push_back(
324 rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
325 rewriter.getArrayAttr(newArgAttrs)));
326 }
327 for (const auto &pair : llvm::enumerate(attributes)) {
328 if (pair.value().getName() == "llvm.linkage") {
329 attributes.erase(attributes.begin() + pair.index());
330 break;
331 }
332 }
333
334 // Create an LLVM function, use external linkage by default until MLIR
335 // functions have linkage.
336 LLVM::Linkage linkage = LLVM::Linkage::External;
337 if (funcOp->hasAttr("llvm.linkage")) {
338 auto attr =
339 funcOp->getAttr("llvm.linkage").dyn_cast<mlir::LLVM::LinkageAttr>();
340 if (!attr) {
341 funcOp->emitError()
342 << "Contains llvm.linkage attribute not of type LLVM::LinkageAttr";
343 return nullptr;
344 }
345 linkage = attr.getLinkage();
346 }
347 auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
348 funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
349 /*dsoLocal*/ false, /*cconv*/ LLVM::CConv::C, attributes);
350 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
351 newFuncOp.end());
352 if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
353 &result)))
354 return nullptr;
355
356 return newFuncOp;
357 }
358 };
359
360 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
361 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
362 /// information.
363 struct FuncOpConversion : public FuncOpConversionBase {
FuncOpConversion__anon290525710111::FuncOpConversion364 FuncOpConversion(LLVMTypeConverter &converter)
365 : FuncOpConversionBase(converter) {}
366
367 LogicalResult
matchAndRewrite__anon290525710111::FuncOpConversion368 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
369 ConversionPatternRewriter &rewriter) const override {
370 auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
371 if (!newFuncOp)
372 return failure();
373
374 if (funcOp->getAttrOfType<UnitAttr>(
375 LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
376 if (newFuncOp.isVarArg())
377 return funcOp->emitError("C interface for variadic functions is not "
378 "supported yet.");
379
380 if (newFuncOp.isExternal())
381 wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(),
382 funcOp, newFuncOp);
383 else
384 wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(),
385 funcOp, newFuncOp);
386 }
387
388 rewriter.eraseOp(funcOp);
389 return success();
390 }
391 };
392
393 /// FuncOp legalization pattern that converts MemRef arguments to bare pointers
394 /// to the MemRef element type. This will impact the calling convention and ABI.
395 struct BarePtrFuncOpConversion : public FuncOpConversionBase {
396 using FuncOpConversionBase::FuncOpConversionBase;
397
398 LogicalResult
matchAndRewrite__anon290525710111::BarePtrFuncOpConversion399 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
400 ConversionPatternRewriter &rewriter) const override {
401
402 // TODO: bare ptr conversion could be handled by argument materialization
403 // and most of the code below would go away. But to do this, we would need a
404 // way to distinguish between FuncOp and other regions in the
405 // addArgumentMaterialization hook.
406
407 // Store the type of memref-typed arguments before the conversion so that we
408 // can promote them to MemRef descriptor at the beginning of the function.
409 SmallVector<Type, 8> oldArgTypes =
410 llvm::to_vector<8>(funcOp.getFunctionType().getInputs());
411
412 auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
413 if (!newFuncOp)
414 return failure();
415 if (newFuncOp.getBody().empty()) {
416 rewriter.eraseOp(funcOp);
417 return success();
418 }
419
420 // Promote bare pointers from memref arguments to memref descriptors at the
421 // beginning of the function so that all the memrefs in the function have a
422 // uniform representation.
423 Block *entryBlock = &newFuncOp.getBody().front();
424 auto blockArgs = entryBlock->getArguments();
425 assert(blockArgs.size() == oldArgTypes.size() &&
426 "The number of arguments and types doesn't match");
427
428 OpBuilder::InsertionGuard guard(rewriter);
429 rewriter.setInsertionPointToStart(entryBlock);
430 for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
431 BlockArgument arg = std::get<0>(it);
432 Type argTy = std::get<1>(it);
433
434 // Unranked memrefs are not supported in the bare pointer calling
435 // convention. We should have bailed out before in the presence of
436 // unranked memrefs.
437 assert(!argTy.isa<UnrankedMemRefType>() &&
438 "Unranked memref is not supported");
439 auto memrefTy = argTy.dyn_cast<MemRefType>();
440 if (!memrefTy)
441 continue;
442
443 // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
444 // or unranked memref descriptor and replace placeholder with the last
445 // instruction of the memref descriptor.
446 // TODO: The placeholder is needed to avoid replacing barePtr uses in the
447 // MemRef descriptor instructions. We may want to have a utility in the
448 // rewriter to properly handle this use case.
449 Location loc = funcOp.getLoc();
450 auto placeholder = rewriter.create<LLVM::UndefOp>(
451 loc, getTypeConverter()->convertType(memrefTy));
452 rewriter.replaceUsesOfBlockArgument(arg, placeholder);
453
454 Value desc = MemRefDescriptor::fromStaticShape(
455 rewriter, loc, *getTypeConverter(), memrefTy, arg);
456 rewriter.replaceOp(placeholder, {desc});
457 }
458
459 rewriter.eraseOp(funcOp);
460 return success();
461 }
462 };
463
464 struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
465 using ConvertOpToLLVMPattern<func::ConstantOp>::ConvertOpToLLVMPattern;
466
467 LogicalResult
matchAndRewrite__anon290525710111::ConstantOpLowering468 matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor,
469 ConversionPatternRewriter &rewriter) const override {
470 auto type = typeConverter->convertType(op.getResult().getType());
471 if (!type || !LLVM::isCompatibleType(type))
472 return rewriter.notifyMatchFailure(op, "failed to convert result type");
473
474 auto newOp =
475 rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
476 for (const NamedAttribute &attr : op->getAttrs()) {
477 if (attr.getName().strref() == "value")
478 continue;
479 newOp->setAttr(attr.getName(), attr.getValue());
480 }
481 rewriter.replaceOp(op, newOp->getResults());
482 return success();
483 }
484 };
485
486 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and
487 // passes the pointer to the MemRef across function boundaries.
488 template <typename CallOpType>
489 struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
490 using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
491 using Super = CallOpInterfaceLowering<CallOpType>;
492 using Base = ConvertOpToLLVMPattern<CallOpType>;
493
494 LogicalResult
matchAndRewrite__anon290525710111::CallOpInterfaceLowering495 matchAndRewrite(CallOpType callOp, typename CallOpType::Adaptor adaptor,
496 ConversionPatternRewriter &rewriter) const override {
497 // Pack the result types into a struct.
498 Type packedResult = nullptr;
499 unsigned numResults = callOp.getNumResults();
500 auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
501
502 if (numResults != 0) {
503 if (!(packedResult =
504 this->getTypeConverter()->packFunctionResults(resultTypes)))
505 return failure();
506 }
507
508 auto promoted = this->getTypeConverter()->promoteOperands(
509 callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
510 adaptor.getOperands(), rewriter);
511 auto newOp = rewriter.create<LLVM::CallOp>(
512 callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
513 promoted, callOp->getAttrs());
514
515 SmallVector<Value, 4> results;
516 if (numResults < 2) {
517 // If < 2 results, packing did not do anything and we can just return.
518 results.append(newOp.result_begin(), newOp.result_end());
519 } else {
520 // Otherwise, it had been converted to an operation producing a structure.
521 // Extract individual results from the structure and return them as list.
522 results.reserve(numResults);
523 for (unsigned i = 0; i < numResults; ++i) {
524 auto type =
525 this->typeConverter->convertType(callOp.getResult(i).getType());
526 results.push_back(rewriter.create<LLVM::ExtractValueOp>(
527 callOp.getLoc(), type, newOp->getResult(0),
528 rewriter.getI64ArrayAttr(i)));
529 }
530 }
531
532 if (this->getTypeConverter()->getOptions().useBarePtrCallConv) {
533 // For the bare-ptr calling convention, promote memref results to
534 // descriptors.
535 assert(results.size() == resultTypes.size() &&
536 "The number of arguments and types doesn't match");
537 this->getTypeConverter()->promoteBarePtrsToDescriptors(
538 rewriter, callOp.getLoc(), resultTypes, results);
539 } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
540 resultTypes, results,
541 /*toDynamic=*/false))) {
542 return failure();
543 }
544
545 rewriter.replaceOp(callOp, results);
546 return success();
547 }
548 };
549
550 struct CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
551 using Super::Super;
552 };
553
554 struct CallIndirectOpLowering
555 : public CallOpInterfaceLowering<func::CallIndirectOp> {
556 using Super::Super;
557 };
558
559 struct UnrealizedConversionCastOpLowering
560 : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> {
561 using ConvertOpToLLVMPattern<
562 UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
563
564 LogicalResult
matchAndRewrite__anon290525710111::UnrealizedConversionCastOpLowering565 matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
566 ConversionPatternRewriter &rewriter) const override {
567 SmallVector<Type> convertedTypes;
568 if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(),
569 convertedTypes)) &&
570 convertedTypes == adaptor.getInputs().getTypes()) {
571 rewriter.replaceOp(op, adaptor.getInputs());
572 return success();
573 }
574
575 convertedTypes.clear();
576 if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(),
577 convertedTypes)) &&
578 convertedTypes == op.getOutputs().getType()) {
579 rewriter.replaceOp(op, adaptor.getInputs());
580 return success();
581 }
582 return failure();
583 }
584 };
585
586 // Special lowering pattern for `ReturnOps`. Unlike all other operations,
587 // `ReturnOp` interacts with the function signature and must have as many
588 // operands as the function has return values. Because in LLVM IR, functions
589 // can only return 0 or 1 value, we pack multiple values into a structure type.
590 // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
591 // necessary before returning it
592 struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
593 using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
594
595 LogicalResult
matchAndRewrite__anon290525710111::ReturnOpLowering596 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
597 ConversionPatternRewriter &rewriter) const override {
598 Location loc = op.getLoc();
599 unsigned numArguments = op.getNumOperands();
600 SmallVector<Value, 4> updatedOperands;
601
602 if (getTypeConverter()->getOptions().useBarePtrCallConv) {
603 // For the bare-ptr calling convention, extract the aligned pointer to
604 // be returned from the memref descriptor.
605 for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
606 Type oldTy = std::get<0>(it).getType();
607 Value newOperand = std::get<1>(it);
608 if (oldTy.isa<MemRefType>() && getTypeConverter()->canConvertToBarePtr(
609 oldTy.cast<BaseMemRefType>())) {
610 MemRefDescriptor memrefDesc(newOperand);
611 newOperand = memrefDesc.alignedPtr(rewriter, loc);
612 } else if (oldTy.isa<UnrankedMemRefType>()) {
613 // Unranked memref is not supported in the bare pointer calling
614 // convention.
615 return failure();
616 }
617 updatedOperands.push_back(newOperand);
618 }
619 } else {
620 updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
621 (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
622 updatedOperands,
623 /*toDynamic=*/true);
624 }
625
626 // If ReturnOp has 0 or 1 operand, create it and return immediately.
627 if (numArguments == 0) {
628 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
629 op->getAttrs());
630 return success();
631 }
632 if (numArguments == 1) {
633 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
634 op, TypeRange(), updatedOperands, op->getAttrs());
635 return success();
636 }
637
638 // Otherwise, we need to pack the arguments into an LLVM struct type before
639 // returning.
640 auto packedType = getTypeConverter()->packFunctionResults(
641 llvm::to_vector<4>(op.getOperandTypes()));
642
643 Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
644 for (unsigned i = 0; i < numArguments; ++i) {
645 packed = rewriter.create<LLVM::InsertValueOp>(
646 loc, packedType, packed, updatedOperands[i],
647 rewriter.getI64ArrayAttr(i));
648 }
649 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
650 op->getAttrs());
651 return success();
652 }
653 };
654 } // namespace
655
populateFuncToLLVMFuncOpConversionPattern(LLVMTypeConverter & converter,RewritePatternSet & patterns)656 void mlir::populateFuncToLLVMFuncOpConversionPattern(
657 LLVMTypeConverter &converter, RewritePatternSet &patterns) {
658 if (converter.getOptions().useBarePtrCallConv)
659 patterns.add<BarePtrFuncOpConversion>(converter);
660 else
661 patterns.add<FuncOpConversion>(converter);
662 }
663
populateFuncToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)664 void mlir::populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter,
665 RewritePatternSet &patterns) {
666 populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
667 // clang-format off
668 patterns.add<
669 CallIndirectOpLowering,
670 CallOpLowering,
671 ConstantOpLowering,
672 ReturnOpLowering>(converter);
673 // clang-format on
674 }
675
676 namespace {
677 /// A pass converting Func operations into the LLVM IR dialect.
678 struct ConvertFuncToLLVMPass
679 : public ConvertFuncToLLVMBase<ConvertFuncToLLVMPass> {
680 ConvertFuncToLLVMPass() = default;
ConvertFuncToLLVMPass__anon290525710211::ConvertFuncToLLVMPass681 ConvertFuncToLLVMPass(bool useBarePtrCallConv, unsigned indexBitwidth,
682 bool useAlignedAlloc,
683 const llvm::DataLayout &dataLayout) {
684 this->useBarePtrCallConv = useBarePtrCallConv;
685 this->indexBitwidth = indexBitwidth;
686 this->dataLayout = dataLayout.getStringRepresentation();
687 }
688
689 /// Run the dialect converter on the module.
runOnOperation__anon290525710211::ConvertFuncToLLVMPass690 void runOnOperation() override {
691 if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
692 this->dataLayout, [this](const Twine &message) {
693 getOperation().emitError() << message.str();
694 }))) {
695 signalPassFailure();
696 return;
697 }
698
699 ModuleOp m = getOperation();
700 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
701
702 LowerToLLVMOptions options(&getContext(),
703 dataLayoutAnalysis.getAtOrAbove(m));
704 options.useBarePtrCallConv = useBarePtrCallConv;
705 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
706 options.overrideIndexBitwidth(indexBitwidth);
707 options.dataLayout = llvm::DataLayout(this->dataLayout);
708
709 LLVMTypeConverter typeConverter(&getContext(), options,
710 &dataLayoutAnalysis);
711
712 RewritePatternSet patterns(&getContext());
713 populateFuncToLLVMConversionPatterns(typeConverter, patterns);
714
715 // TODO: Remove these in favor of their dedicated conversion passes.
716 arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns);
717 cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
718
719 LLVMConversionTarget target(getContext());
720 if (failed(applyPartialConversion(m, target, std::move(patterns))))
721 signalPassFailure();
722
723 m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
724 StringAttr::get(m.getContext(), this->dataLayout));
725 }
726 };
727 } // namespace
728
createConvertFuncToLLVMPass()729 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertFuncToLLVMPass() {
730 return std::make_unique<ConvertFuncToLLVMPass>();
731 }
732
733 std::unique_ptr<OperationPass<ModuleOp>>
createConvertFuncToLLVMPass(const LowerToLLVMOptions & options)734 mlir::createConvertFuncToLLVMPass(const LowerToLLVMOptions &options) {
735 auto allocLowering = options.allocLowering;
736 // There is no way to provide additional patterns for pass, so
737 // AllocLowering::None will always fail.
738 assert(allocLowering != LowerToLLVMOptions::AllocLowering::None &&
739 "ConvertFuncToLLVMPass doesn't support AllocLowering::None");
740 bool useAlignedAlloc =
741 (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc);
742 return std::make_unique<ConvertFuncToLLVMPass>(
743 options.useBarePtrCallConv, options.getIndexBitwidth(), useAlignedAlloc,
744 options.dataLayout);
745 }
746