1 //===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert MLIR Func and builtin dialects 10 // into the LLVM IR dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "../PassDetail.h" 15 #include "mlir/Analysis/DataLayoutAnalysis.h" 16 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 17 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" 18 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" 19 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" 20 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 21 #include "mlir/Conversion/LLVMCommon/Pattern.h" 22 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 23 #include "mlir/Dialect/Func/IR/FuncOps.h" 24 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 25 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 26 #include "mlir/Dialect/Utils/StaticValueUtils.h" 27 #include "mlir/IR/Attributes.h" 28 #include "mlir/IR/BlockAndValueMapping.h" 29 #include "mlir/IR/Builders.h" 30 #include "mlir/IR/BuiltinOps.h" 31 #include "mlir/IR/PatternMatch.h" 32 #include "mlir/IR/TypeUtilities.h" 33 #include "mlir/Support/LogicalResult.h" 34 #include "mlir/Support/MathExtras.h" 35 #include "mlir/Transforms/DialectConversion.h" 36 #include "mlir/Transforms/Passes.h" 37 #include "llvm/ADT/TypeSwitch.h" 38 #include "llvm/IR/DerivedTypes.h" 39 #include "llvm/IR/IRBuilder.h" 40 #include "llvm/IR/Type.h" 41 #include "llvm/Support/CommandLine.h" 42 #include "llvm/Support/FormatVariadic.h" 43 #include <functional> 44 45 using namespace mlir; 46 47 #define PASS_NAME 
"convert-func-to-llvm" 48 49 /// Only retain those attributes that are not constructed by 50 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument 51 /// attributes. 52 static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs, 53 bool filterArgAttrs, 54 SmallVectorImpl<NamedAttribute> &result) { 55 for (const auto &attr : attrs) { 56 if (attr.getName() == SymbolTable::getSymbolAttrName() || 57 attr.getName() == FunctionOpInterface::getTypeAttrName() || 58 attr.getName() == "func.varargs" || 59 (filterArgAttrs && 60 attr.getName() == FunctionOpInterface::getArgDictAttrName())) 61 continue; 62 result.push_back(attr); 63 } 64 } 65 66 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct 67 /// arguments instead of unpacked arguments. This function can be called from C 68 /// by passing a pointer to a C struct corresponding to a memref descriptor. 69 /// Similarly, returned memrefs are passed via pointers to a C struct that is 70 /// passed as additional argument. 71 /// Internally, the auxiliary function unpacks the descriptor into individual 72 /// components and forwards them to `newFuncOp` and forwards the results to 73 /// the extra arguments. 
74 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, 75 LLVMTypeConverter &typeConverter, 76 FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) { 77 auto type = funcOp.getType(); 78 SmallVector<NamedAttribute, 4> attributes; 79 filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/false, 80 attributes); 81 Type wrapperFuncType; 82 bool resultIsNowArg; 83 std::tie(wrapperFuncType, resultIsNowArg) = 84 typeConverter.convertFunctionTypeCWrapper(type); 85 auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>( 86 loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), 87 wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false, attributes); 88 89 OpBuilder::InsertionGuard guard(rewriter); 90 rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock()); 91 92 SmallVector<Value, 8> args; 93 size_t argOffset = resultIsNowArg ? 1 : 0; 94 for (auto &en : llvm::enumerate(type.getInputs())) { 95 Value arg = wrapperFuncOp.getArgument(en.index() + argOffset); 96 if (auto memrefType = en.value().dyn_cast<MemRefType>()) { 97 Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg); 98 MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args); 99 continue; 100 } 101 if (en.value().isa<UnrankedMemRefType>()) { 102 Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg); 103 UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args); 104 continue; 105 } 106 107 args.push_back(arg); 108 } 109 110 auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args); 111 112 if (resultIsNowArg) { 113 rewriter.create<LLVM::StoreOp>(loc, call.getResult(0), 114 wrapperFuncOp.getArgument(0)); 115 rewriter.create<LLVM::ReturnOp>(loc, ValueRange{}); 116 } else { 117 rewriter.create<LLVM::ReturnOp>(loc, call.getResults()); 118 } 119 } 120 121 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct 122 /// arguments instead of unpacked arguments. 
/// Creates a body for the (external) `newFuncOp` that allocates a memref
/// descriptor on stack, packs the individual arguments into this descriptor and
/// passes a pointer to it into the auxiliary `_mlir_ciface_`-prefixed function,
/// which is declared here (external linkage, no body). If the result of the
/// function cannot be directly returned, a stack slot for it is allocated and
/// passed as a special first argument; the value is then loaded back from that
/// slot and returned from `newFuncOp`. This auxiliary external function is thus
/// compatible with functions defined in C using pointers to C structs
/// corresponding to a memref descriptor.
static void wrapExternalFunction(OpBuilder &builder, Location loc,
                                 LLVMTypeConverter &typeConverter,
                                 FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
  OpBuilder::InsertionGuard guard(builder);

  Type wrapperType;
  bool resultIsNowArg;
  std::tie(wrapperType, resultIsNowArg) =
      typeConverter.convertFunctionTypeCWrapper(funcOp.getType());
  // This conversion can only fail if it could not convert one of the argument
  // types. But since it has been applied to a non-wrapper function before, it
  // should have failed earlier and not reach this point at all.
  assert(wrapperType && "unexpected type conversion failure");

  SmallVector<NamedAttribute, 4> attributes;
  filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/false,
                       attributes);

  // Create the auxiliary function. It is only declared (no entry block is
  // added), so it must be provided externally, e.g. by C code.
  auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
      loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
      wrapperType, LLVM::Linkage::External, /*dsoLocal*/ false, attributes);

  // The previously external `newFuncOp` gets a body that forwards to the
  // auxiliary function.
  builder.setInsertionPointToStart(newFuncOp.addEntryBlock());

  // Get a ValueRange containing arguments. These are the (unpacked) arguments
  // of `newFuncOp`; they are consumed from the front as they are mapped.
  FunctionType type = funcOp.getType();
  SmallVector<Value, 8> args;
  args.reserve(type.getNumInputs());
  ValueRange wrapperArgsRange(newFuncOp.getArguments());

  if (resultIsNowArg) {
    // Allocate the struct on the stack and pass the pointer.
    Type resultType =
        wrapperType.cast<LLVM::LLVMFunctionType>().getParamType(0);
    Value one = builder.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(builder.getIndexType()),
        builder.getIntegerAttr(builder.getIndexType(), 1));
    Value result = builder.create<LLVM::AllocaOp>(loc, resultType, one);
    args.push_back(result);
  }

  // Iterate over the inputs of the original function and pack values into
  // memref descriptors if the original type is a memref.
  for (auto &en : llvm::enumerate(type.getInputs())) {
    Value arg;
    // Number of unpacked `newFuncOp` arguments consumed by this input: one for
    // plain values, the descriptor's unpacked size for (unranked) memrefs.
    int numToDrop = 1;
    auto memRefType = en.value().dyn_cast<MemRefType>();
    auto unrankedMemRefType = en.value().dyn_cast<UnrankedMemRefType>();
    if (memRefType || unrankedMemRefType) {
      numToDrop = memRefType
                      ? MemRefDescriptor::getNumUnpackedValues(memRefType)
                      : UnrankedMemRefDescriptor::getNumUnpackedValues();
      Value packed =
          memRefType
              ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
                                       wrapperArgsRange.take_front(numToDrop))
              : UnrankedMemRefDescriptor::pack(
                    builder, loc, typeConverter, unrankedMemRefType,
                    wrapperArgsRange.take_front(numToDrop));

      // Spill the packed descriptor to a stack slot and pass its address,
      // matching the pointer-to-struct convention of the C interface.
      auto ptrTy = LLVM::LLVMPointerType::get(packed.getType());
      Value one = builder.create<LLVM::ConstantOp>(
          loc, typeConverter.convertType(builder.getIndexType()),
          builder.getIntegerAttr(builder.getIndexType(), 1));
      Value allocated =
          builder.create<LLVM::AllocaOp>(loc, ptrTy, one, /*alignment=*/0);
      builder.create<LLVM::StoreOp>(loc, packed, allocated);
      arg = allocated;
    } else {
      arg = wrapperArgsRange[0];
    }

    args.push_back(arg);
    wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
  }
  assert(wrapperArgsRange.empty() && "did not map some of the arguments");

  auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);

  if (resultIsNowArg) {
    // The callee wrote the result into the stack slot passed as args[0].
    Value result = builder.create<LLVM::LoadOp>(loc, args.front());
    builder.create<LLVM::ReturnOp>(loc, ValueRange{result});
  } else {
    builder.create<LLVM::ReturnOp>(loc, call.getResults());
  }
}

namespace {

/// Common base of the FuncOp lowering patterns: provides the signature, region
/// and attribute conversion shared by the default and bare-pointer variants.
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
protected:
  using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;

  // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
  // to this legalization pattern. Returns a null op if the signature cannot be
  // converted, if `llvm.linkage` has the wrong attribute kind, or if the
  // region's block signatures cannot be converted.
  LLVM::LLVMFuncOp
  convertFuncOpToLLVMFuncOp(FuncOp funcOp,
                            ConversionPatternRewriter &rewriter) const {
    // Convert the original function arguments. They are converted using the
    // LLVMTypeConverter provided to this legalization pattern.
    auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs");
    TypeConverter::SignatureConversion result(funcOp.getNumArguments());
    auto llvmType = getTypeConverter()->convertFunctionSignature(
        funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
    if (!llvmType)
      return nullptr;

    // Propagate argument attributes to all converted arguments obtained after
    // converting a given original argument.
    // The original argument-attribute array is filtered out here and rebuilt
    // below against the converted (possibly expanded) parameter list.
    SmallVector<NamedAttribute, 4> attributes;
    filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
                         attributes);
    if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
      SmallVector<Attribute, 4> newArgAttrs(
          llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
      for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
        auto mapping = result.getInputMapping(i);
        assert(mapping.hasValue() &&
               "unexpected deletion of function argument");
        // An original argument may map to several converted parameters (e.g.
        // an unpacked memref); each of them receives the same dictionary.
        for (size_t j = 0; j < mapping->size; ++j)
          newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
      }
      attributes.push_back(
          rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
                                rewriter.getArrayAttr(newArgAttrs)));
    }
    // The linkage is passed to the LLVMFuncOp builder explicitly below, so do
    // not also copy `llvm.linkage` over as a plain attribute.
    for (const auto &pair : llvm::enumerate(attributes)) {
      if (pair.value().getName() == "llvm.linkage") {
        attributes.erase(attributes.begin() + pair.index());
        break;
      }
    }

    // Create an LLVM function, use external linkage by default until MLIR
    // functions have linkage.
    LLVM::Linkage linkage = LLVM::Linkage::External;
    if (funcOp->hasAttr("llvm.linkage")) {
      auto attr =
          funcOp->getAttr("llvm.linkage").dyn_cast<mlir::LLVM::LinkageAttr>();
      if (!attr) {
        funcOp->emitError()
            << "Contains llvm.linkage attribute not of type LLVM::LinkageAttr";
        return nullptr;
      }
      linkage = attr.getLinkage();
    }
    auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
        funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
        /*dsoLocal*/ false, attributes);
    // Move the body over and retype its block arguments according to the
    // signature conversion computed above.
    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());
    if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
                                           &result)))
      return nullptr;

    return newFuncOp;
  }
};

/// Unit attribute that requests a C-interface wrapper for a single function,
/// independently of the type converter's `emitCWrappers` option.
static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
/// Default FuncOp legalization pattern. It converts the signature with the
/// type converter and, when C wrappers are requested (globally through
/// `emitCWrappers` or per function through `llvm.emit_c_interface`), also
/// emits an `_mlir_ciface_` wrapper whose memref arguments are pointers to
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
/// information.
struct FuncOpConversion : public FuncOpConversionBase {
  FuncOpConversion(LLVMTypeConverter &converter)
      : FuncOpConversionBase(converter) {}

  LogicalResult
  matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
    if (!newFuncOp)
      return failure();

    if (getTypeConverter()->getOptions().emitCWrappers ||
        funcOp->getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) {
      // Declarations get a body that calls an external C-interface function;
      // definitions get an extra C-callable entry point that calls them.
      if (newFuncOp.isExternal())
        wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(),
                             funcOp, newFuncOp);
      else
        wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(),
                               funcOp, newFuncOp);
    }

    rewriter.eraseOp(funcOp);
    return success();
  }
};

/// FuncOp legalization pattern that converts MemRef arguments to bare pointers
/// to the MemRef element type.
This will impact the calling convention and ABI. 322 struct BarePtrFuncOpConversion : public FuncOpConversionBase { 323 using FuncOpConversionBase::FuncOpConversionBase; 324 325 LogicalResult 326 matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor, 327 ConversionPatternRewriter &rewriter) const override { 328 329 // TODO: bare ptr conversion could be handled by argument materialization 330 // and most of the code below would go away. But to do this, we would need a 331 // way to distinguish between FuncOp and other regions in the 332 // addArgumentMaterialization hook. 333 334 // Store the type of memref-typed arguments before the conversion so that we 335 // can promote them to MemRef descriptor at the beginning of the function. 336 SmallVector<Type, 8> oldArgTypes = 337 llvm::to_vector<8>(funcOp.getType().getInputs()); 338 339 auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); 340 if (!newFuncOp) 341 return failure(); 342 if (newFuncOp.getBody().empty()) { 343 rewriter.eraseOp(funcOp); 344 return success(); 345 } 346 347 // Promote bare pointers from memref arguments to memref descriptors at the 348 // beginning of the function so that all the memrefs in the function have a 349 // uniform representation. 350 Block *entryBlock = &newFuncOp.getBody().front(); 351 auto blockArgs = entryBlock->getArguments(); 352 assert(blockArgs.size() == oldArgTypes.size() && 353 "The number of arguments and types doesn't match"); 354 355 OpBuilder::InsertionGuard guard(rewriter); 356 rewriter.setInsertionPointToStart(entryBlock); 357 for (auto it : llvm::zip(blockArgs, oldArgTypes)) { 358 BlockArgument arg = std::get<0>(it); 359 Type argTy = std::get<1>(it); 360 361 // Unranked memrefs are not supported in the bare pointer calling 362 // convention. We should have bailed out before in the presence of 363 // unranked memrefs. 
364 assert(!argTy.isa<UnrankedMemRefType>() && 365 "Unranked memref is not supported"); 366 auto memrefTy = argTy.dyn_cast<MemRefType>(); 367 if (!memrefTy) 368 continue; 369 370 // Replace barePtr with a placeholder (undef), promote barePtr to a ranked 371 // or unranked memref descriptor and replace placeholder with the last 372 // instruction of the memref descriptor. 373 // TODO: The placeholder is needed to avoid replacing barePtr uses in the 374 // MemRef descriptor instructions. We may want to have a utility in the 375 // rewriter to properly handle this use case. 376 Location loc = funcOp.getLoc(); 377 auto placeholder = rewriter.create<LLVM::UndefOp>( 378 loc, getTypeConverter()->convertType(memrefTy)); 379 rewriter.replaceUsesOfBlockArgument(arg, placeholder); 380 381 Value desc = MemRefDescriptor::fromStaticShape( 382 rewriter, loc, *getTypeConverter(), memrefTy, arg); 383 rewriter.replaceOp(placeholder, {desc}); 384 } 385 386 rewriter.eraseOp(funcOp); 387 return success(); 388 } 389 }; 390 391 struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> { 392 using ConvertOpToLLVMPattern<func::ConstantOp>::ConvertOpToLLVMPattern; 393 394 LogicalResult 395 matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor, 396 ConversionPatternRewriter &rewriter) const override { 397 auto type = typeConverter->convertType(op.getResult().getType()); 398 if (!type || !LLVM::isCompatibleType(type)) 399 return rewriter.notifyMatchFailure(op, "failed to convert result type"); 400 401 auto newOp = 402 rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue()); 403 for (const NamedAttribute &attr : op->getAttrs()) { 404 if (attr.getName().strref() == "value") 405 continue; 406 newOp->setAttr(attr.getName(), attr.getValue()); 407 } 408 rewriter.replaceOp(op, newOp->getResults()); 409 return success(); 410 } 411 }; 412 413 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and 414 // passes the pointer to the MemRef across 
function boundaries. 415 template <typename CallOpType> 416 struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> { 417 using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern; 418 using Super = CallOpInterfaceLowering<CallOpType>; 419 using Base = ConvertOpToLLVMPattern<CallOpType>; 420 421 LogicalResult 422 matchAndRewrite(CallOpType callOp, typename CallOpType::Adaptor adaptor, 423 ConversionPatternRewriter &rewriter) const override { 424 // Pack the result types into a struct. 425 Type packedResult = nullptr; 426 unsigned numResults = callOp.getNumResults(); 427 auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); 428 429 if (numResults != 0) { 430 if (!(packedResult = 431 this->getTypeConverter()->packFunctionResults(resultTypes))) 432 return failure(); 433 } 434 435 auto promoted = this->getTypeConverter()->promoteOperands( 436 callOp.getLoc(), /*opOperands=*/callOp->getOperands(), 437 adaptor.getOperands(), rewriter); 438 auto newOp = rewriter.create<LLVM::CallOp>( 439 callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), 440 promoted, callOp->getAttrs()); 441 442 SmallVector<Value, 4> results; 443 if (numResults < 2) { 444 // If < 2 results, packing did not do anything and we can just return. 445 results.append(newOp.result_begin(), newOp.result_end()); 446 } else { 447 // Otherwise, it had been converted to an operation producing a structure. 448 // Extract individual results from the structure and return them as list. 449 results.reserve(numResults); 450 for (unsigned i = 0; i < numResults; ++i) { 451 auto type = 452 this->typeConverter->convertType(callOp.getResult(i).getType()); 453 results.push_back(rewriter.create<LLVM::ExtractValueOp>( 454 callOp.getLoc(), type, newOp->getResult(0), 455 rewriter.getI64ArrayAttr(i))); 456 } 457 } 458 459 if (this->getTypeConverter()->getOptions().useBarePtrCallConv) { 460 // For the bare-ptr calling convention, promote memref results to 461 // descriptors. 
462 assert(results.size() == resultTypes.size() && 463 "The number of arguments and types doesn't match"); 464 this->getTypeConverter()->promoteBarePtrsToDescriptors( 465 rewriter, callOp.getLoc(), resultTypes, results); 466 } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(), 467 resultTypes, results, 468 /*toDynamic=*/false))) { 469 return failure(); 470 } 471 472 rewriter.replaceOp(callOp, results); 473 return success(); 474 } 475 }; 476 477 struct CallOpLowering : public CallOpInterfaceLowering<func::CallOp> { 478 using Super::Super; 479 }; 480 481 struct CallIndirectOpLowering 482 : public CallOpInterfaceLowering<func::CallIndirectOp> { 483 using Super::Super; 484 }; 485 486 struct UnrealizedConversionCastOpLowering 487 : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> { 488 using ConvertOpToLLVMPattern< 489 UnrealizedConversionCastOp>::ConvertOpToLLVMPattern; 490 491 LogicalResult 492 matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, 493 ConversionPatternRewriter &rewriter) const override { 494 SmallVector<Type> convertedTypes; 495 if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(), 496 convertedTypes)) && 497 convertedTypes == adaptor.getInputs().getTypes()) { 498 rewriter.replaceOp(op, adaptor.getInputs()); 499 return success(); 500 } 501 502 convertedTypes.clear(); 503 if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(), 504 convertedTypes)) && 505 convertedTypes == op.getOutputs().getType()) { 506 rewriter.replaceOp(op, adaptor.getInputs()); 507 return success(); 508 } 509 return failure(); 510 } 511 }; 512 513 // Special lowering pattern for `ReturnOps`. Unlike all other operations, 514 // `ReturnOp` interacts with the function signature and must have as many 515 // operands as the function has return values. Because in LLVM IR, functions 516 // can only return 0 or 1 value, we pack multiple values into a structure type. 
517 // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if 518 // necessary before returning it 519 struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> { 520 using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern; 521 522 LogicalResult 523 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, 524 ConversionPatternRewriter &rewriter) const override { 525 Location loc = op.getLoc(); 526 unsigned numArguments = op.getNumOperands(); 527 SmallVector<Value, 4> updatedOperands; 528 529 if (getTypeConverter()->getOptions().useBarePtrCallConv) { 530 // For the bare-ptr calling convention, extract the aligned pointer to 531 // be returned from the memref descriptor. 532 for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) { 533 Type oldTy = std::get<0>(it).getType(); 534 Value newOperand = std::get<1>(it); 535 if (oldTy.isa<MemRefType>()) { 536 MemRefDescriptor memrefDesc(newOperand); 537 newOperand = memrefDesc.alignedPtr(rewriter, loc); 538 } else if (oldTy.isa<UnrankedMemRefType>()) { 539 // Unranked memref is not supported in the bare pointer calling 540 // convention. 541 return failure(); 542 } 543 updatedOperands.push_back(newOperand); 544 } 545 } else { 546 updatedOperands = llvm::to_vector<4>(adaptor.getOperands()); 547 (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(), 548 updatedOperands, 549 /*toDynamic=*/true); 550 } 551 552 // If ReturnOp has 0 or 1 operand, create it and return immediately. 553 if (numArguments == 0) { 554 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(), 555 op->getAttrs()); 556 return success(); 557 } 558 if (numArguments == 1) { 559 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( 560 op, TypeRange(), updatedOperands, op->getAttrs()); 561 return success(); 562 } 563 564 // Otherwise, we need to pack the arguments into an LLVM struct type before 565 // returning. 
566 auto packedType = getTypeConverter()->packFunctionResults( 567 llvm::to_vector<4>(op.getOperandTypes())); 568 569 Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType); 570 for (unsigned i = 0; i < numArguments; ++i) { 571 packed = rewriter.create<LLVM::InsertValueOp>( 572 loc, packedType, packed, updatedOperands[i], 573 rewriter.getI64ArrayAttr(i)); 574 } 575 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed, 576 op->getAttrs()); 577 return success(); 578 } 579 }; 580 } // namespace 581 582 void mlir::populateFuncToLLVMFuncOpConversionPattern( 583 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 584 if (converter.getOptions().useBarePtrCallConv) 585 patterns.add<BarePtrFuncOpConversion>(converter); 586 else 587 patterns.add<FuncOpConversion>(converter); 588 } 589 590 void mlir::populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, 591 RewritePatternSet &patterns) { 592 populateFuncToLLVMFuncOpConversionPattern(converter, patterns); 593 // clang-format off 594 patterns.add< 595 CallIndirectOpLowering, 596 CallOpLowering, 597 ConstantOpLowering, 598 ReturnOpLowering>(converter); 599 // clang-format on 600 } 601 602 namespace { 603 /// A pass converting Func operations into the LLVM IR dialect. 604 struct ConvertFuncToLLVMPass 605 : public ConvertFuncToLLVMBase<ConvertFuncToLLVMPass> { 606 ConvertFuncToLLVMPass() = default; 607 ConvertFuncToLLVMPass(bool useBarePtrCallConv, bool emitCWrappers, 608 unsigned indexBitwidth, bool useAlignedAlloc, 609 const llvm::DataLayout &dataLayout) { 610 this->useBarePtrCallConv = useBarePtrCallConv; 611 this->emitCWrappers = emitCWrappers; 612 this->indexBitwidth = indexBitwidth; 613 this->dataLayout = dataLayout.getStringRepresentation(); 614 } 615 616 /// Run the dialect converter on the module. 
617 void runOnOperation() override { 618 if (useBarePtrCallConv && emitCWrappers) { 619 getOperation().emitError() 620 << "incompatible conversion options: bare-pointer calling convention " 621 "and C wrapper emission"; 622 signalPassFailure(); 623 return; 624 } 625 if (failed(LLVM::LLVMDialect::verifyDataLayoutString( 626 this->dataLayout, [this](const Twine &message) { 627 getOperation().emitError() << message.str(); 628 }))) { 629 signalPassFailure(); 630 return; 631 } 632 633 ModuleOp m = getOperation(); 634 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>(); 635 636 LowerToLLVMOptions options(&getContext(), 637 dataLayoutAnalysis.getAtOrAbove(m)); 638 options.useBarePtrCallConv = useBarePtrCallConv; 639 options.emitCWrappers = emitCWrappers; 640 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 641 options.overrideIndexBitwidth(indexBitwidth); 642 options.dataLayout = llvm::DataLayout(this->dataLayout); 643 644 LLVMTypeConverter typeConverter(&getContext(), options, 645 &dataLayoutAnalysis); 646 647 RewritePatternSet patterns(&getContext()); 648 populateFuncToLLVMConversionPatterns(typeConverter, patterns); 649 650 // TODO: Remove these in favor of their dedicated conversion passes. 
651 arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); 652 cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); 653 654 LLVMConversionTarget target(getContext()); 655 if (failed(applyPartialConversion(m, target, std::move(patterns)))) 656 signalPassFailure(); 657 658 m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), 659 StringAttr::get(m.getContext(), this->dataLayout)); 660 } 661 }; 662 } // namespace 663 664 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertFuncToLLVMPass() { 665 return std::make_unique<ConvertFuncToLLVMPass>(); 666 } 667 668 std::unique_ptr<OperationPass<ModuleOp>> 669 mlir::createConvertFuncToLLVMPass(const LowerToLLVMOptions &options) { 670 auto allocLowering = options.allocLowering; 671 // There is no way to provide additional patterns for pass, so 672 // AllocLowering::None will always fail. 673 assert(allocLowering != LowerToLLVMOptions::AllocLowering::None && 674 "ConvertFuncToLLVMPass doesn't support AllocLowering::None"); 675 bool useAlignedAlloc = 676 (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc); 677 return std::make_unique<ConvertFuncToLLVMPass>( 678 options.useBarePtrCallConv, options.emitCWrappers, 679 options.getIndexBitwidth(), useAlignedAlloc, options.dataLayout); 680 } 681