1 //===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert MLIR Func and builtin dialects 10 // into the LLVM IR dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "../PassDetail.h" 15 #include "mlir/Analysis/DataLayoutAnalysis.h" 16 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 17 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" 18 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" 19 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" 20 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 21 #include "mlir/Conversion/LLVMCommon/Pattern.h" 22 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 23 #include "mlir/Dialect/Func/IR/FuncOps.h" 24 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 25 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 26 #include "mlir/Dialect/Utils/StaticValueUtils.h" 27 #include "mlir/IR/Attributes.h" 28 #include "mlir/IR/BlockAndValueMapping.h" 29 #include "mlir/IR/Builders.h" 30 #include "mlir/IR/BuiltinOps.h" 31 #include "mlir/IR/PatternMatch.h" 32 #include "mlir/IR/TypeUtilities.h" 33 #include "mlir/Support/LogicalResult.h" 34 #include "mlir/Support/MathExtras.h" 35 #include "mlir/Transforms/DialectConversion.h" 36 #include "mlir/Transforms/Passes.h" 37 #include "llvm/ADT/TypeSwitch.h" 38 #include "llvm/IR/DerivedTypes.h" 39 #include "llvm/IR/IRBuilder.h" 40 #include "llvm/IR/Type.h" 41 #include "llvm/Support/CommandLine.h" 42 #include "llvm/Support/FormatVariadic.h" 43 #include <algorithm> 44 #include <functional> 45 46 using namespace mlir; 47 48 #define PASS_NAME "convert-func-to-llvm" 49 50 /// Only retain those attributes that are not constructed by 51 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument 52 /// attributes. 53 static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs, 54 bool filterArgAndResAttrs, 55 SmallVectorImpl<NamedAttribute> &result) { 56 for (const auto &attr : attrs) { 57 if (attr.getName() == SymbolTable::getSymbolAttrName() || 58 attr.getName() == FunctionOpInterface::getTypeAttrName() || 59 attr.getName() == "func.varargs" || 60 (filterArgAndResAttrs && 61 (attr.getName() == FunctionOpInterface::getArgDictAttrName() || 62 attr.getName() == FunctionOpInterface::getResultDictAttrName()))) 63 continue; 64 result.push_back(attr); 65 } 66 } 67 68 /// Helper function for wrapping all attributes into a single DictionaryAttr 69 static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) { 70 return DictionaryAttr::get( 71 b.getContext(), 72 b.getNamedAttr(LLVM::LLVMDialect::getStructAttrsAttrName(), attrs)); 73 } 74 75 /// Combines all result attributes into a single DictionaryAttr 76 /// and prepends to argument attrs. 77 /// This is intended to be used to format the attributes for a C wrapper 78 /// function when the result(s) is converted to the first function argument 79 /// (in the multiple return case, all returns get wrapped into a single 80 /// argument). The total number of argument attributes should be equal to 81 /// (number of function arguments) + 1. 82 static void 83 prependResAttrsToArgAttrs(OpBuilder &builder, 84 SmallVectorImpl<NamedAttribute> &attributes, 85 size_t numArguments) { 86 auto allAttrs = SmallVector<Attribute>( 87 numArguments + 1, DictionaryAttr::get(builder.getContext())); 88 NamedAttribute *argAttrs = nullptr; 89 for (auto it = attributes.begin(); it != attributes.end();) { 90 if (it->getName() == FunctionOpInterface::getArgDictAttrName()) { 91 auto arrayAttrs = it->getValue().cast<ArrayAttr>(); 92 assert(arrayAttrs.size() == numArguments && 93 "Number of arg attrs and args should match"); 94 std::copy(arrayAttrs.begin(), arrayAttrs.end(), allAttrs.begin() + 1); 95 argAttrs = it; 96 } else if (it->getName() == FunctionOpInterface::getResultDictAttrName()) { 97 auto arrayAttrs = it->getValue().cast<ArrayAttr>(); 98 assert(!arrayAttrs.empty() && "expected array to be non-empty"); 99 allAttrs[0] = (arrayAttrs.size() == 1) 100 ? arrayAttrs[0] 101 : wrapAsStructAttrs(builder, arrayAttrs); 102 it = attributes.erase(it); 103 continue; 104 } 105 it++; 106 } 107 108 auto newArgAttrs = 109 builder.getNamedAttr(FunctionOpInterface::getArgDictAttrName(), 110 builder.getArrayAttr(allAttrs)); 111 if (!argAttrs) { 112 attributes.emplace_back(newArgAttrs); 113 return; 114 } 115 *argAttrs = newArgAttrs; 116 return; 117 } 118 119 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct 120 /// arguments instead of unpacked arguments. This function can be called from C 121 /// by passing a pointer to a C struct corresponding to a memref descriptor. 122 /// Similarly, returned memrefs are passed via pointers to a C struct that is 123 /// passed as additional argument. 124 /// Internally, the auxiliary function unpacks the descriptor into individual 125 /// components and forwards them to `newFuncOp` and forwards the results to 126 /// the extra arguments. 127 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, 128 LLVMTypeConverter &typeConverter, 129 FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) { 130 auto type = funcOp.getType(); 131 SmallVector<NamedAttribute, 4> attributes; 132 filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false, 133 attributes); 134 Type wrapperFuncType; 135 bool resultIsNowArg; 136 std::tie(wrapperFuncType, resultIsNowArg) = 137 typeConverter.convertFunctionTypeCWrapper(type); 138 if (resultIsNowArg) 139 prependResAttrsToArgAttrs(rewriter, attributes, funcOp.getNumArguments()); 140 auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>( 141 loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), 142 wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false, attributes); 143 144 OpBuilder::InsertionGuard guard(rewriter); 145 rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock()); 146 147 SmallVector<Value, 8> args; 148 size_t argOffset = resultIsNowArg ? 1 : 0; 149 for (auto &en : llvm::enumerate(type.getInputs())) { 150 Value arg = wrapperFuncOp.getArgument(en.index() + argOffset); 151 if (auto memrefType = en.value().dyn_cast<MemRefType>()) { 152 Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg); 153 MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args); 154 continue; 155 } 156 if (en.value().isa<UnrankedMemRefType>()) { 157 Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg); 158 UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args); 159 continue; 160 } 161 162 args.push_back(arg); 163 } 164 165 auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args); 166 167 if (resultIsNowArg) { 168 rewriter.create<LLVM::StoreOp>(loc, call.getResult(0), 169 wrapperFuncOp.getArgument(0)); 170 rewriter.create<LLVM::ReturnOp>(loc, ValueRange{}); 171 } else { 172 rewriter.create<LLVM::ReturnOp>(loc, call.getResults()); 173 } 174 } 175 176 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct 177 /// arguments instead of unpacked arguments. Creates a body for the (external) 178 /// `newFuncOp` that allocates a memref descriptor on stack, packs the 179 /// individual arguments into this descriptor and passes a pointer to it into 180 /// the auxiliary function. If the result of the function cannot be directly 181 /// returned, we write it to a special first argument that provides a pointer 182 /// to a corresponding struct. This auxiliary external function is now 183 /// compatible with functions defined in C using pointers to C structs 184 /// corresponding to a memref descriptor. 185 static void wrapExternalFunction(OpBuilder &builder, Location loc, 186 LLVMTypeConverter &typeConverter, 187 FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) { 188 OpBuilder::InsertionGuard guard(builder); 189 190 Type wrapperType; 191 bool resultIsNowArg; 192 std::tie(wrapperType, resultIsNowArg) = 193 typeConverter.convertFunctionTypeCWrapper(funcOp.getType()); 194 // This conversion can only fail if it could not convert one of the argument 195 // types. But since it has been applied to a non-wrapper function before, it 196 // should have failed earlier and not reach this point at all. 197 assert(wrapperType && "unexpected type conversion failure"); 198 199 SmallVector<NamedAttribute, 4> attributes; 200 filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false, 201 attributes); 202 203 if (resultIsNowArg) 204 prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments()); 205 // Create the auxiliary function. 206 auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>( 207 loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), 208 wrapperType, LLVM::Linkage::External, /*dsoLocal*/ false, attributes); 209 210 builder.setInsertionPointToStart(newFuncOp.addEntryBlock()); 211 212 // Get a ValueRange containing arguments. 213 FunctionType type = funcOp.getType(); 214 SmallVector<Value, 8> args; 215 args.reserve(type.getNumInputs()); 216 ValueRange wrapperArgsRange(newFuncOp.getArguments()); 217 218 if (resultIsNowArg) { 219 // Allocate the struct on the stack and pass the pointer. 220 Type resultType = 221 wrapperType.cast<LLVM::LLVMFunctionType>().getParamType(0); 222 Value one = builder.create<LLVM::ConstantOp>( 223 loc, typeConverter.convertType(builder.getIndexType()), 224 builder.getIntegerAttr(builder.getIndexType(), 1)); 225 Value result = builder.create<LLVM::AllocaOp>(loc, resultType, one); 226 args.push_back(result); 227 } 228 229 // Iterate over the inputs of the original function and pack values into 230 // memref descriptors if the original type is a memref. 231 for (auto &en : llvm::enumerate(type.getInputs())) { 232 Value arg; 233 int numToDrop = 1; 234 auto memRefType = en.value().dyn_cast<MemRefType>(); 235 auto unrankedMemRefType = en.value().dyn_cast<UnrankedMemRefType>(); 236 if (memRefType || unrankedMemRefType) { 237 numToDrop = memRefType 238 ? MemRefDescriptor::getNumUnpackedValues(memRefType) 239 : UnrankedMemRefDescriptor::getNumUnpackedValues(); 240 Value packed = 241 memRefType 242 ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType, 243 wrapperArgsRange.take_front(numToDrop)) 244 : UnrankedMemRefDescriptor::pack( 245 builder, loc, typeConverter, unrankedMemRefType, 246 wrapperArgsRange.take_front(numToDrop)); 247 248 auto ptrTy = LLVM::LLVMPointerType::get(packed.getType()); 249 Value one = builder.create<LLVM::ConstantOp>( 250 loc, typeConverter.convertType(builder.getIndexType()), 251 builder.getIntegerAttr(builder.getIndexType(), 1)); 252 Value allocated = 253 builder.create<LLVM::AllocaOp>(loc, ptrTy, one, /*alignment=*/0); 254 builder.create<LLVM::StoreOp>(loc, packed, allocated); 255 arg = allocated; 256 } else { 257 arg = wrapperArgsRange[0]; 258 } 259 260 args.push_back(arg); 261 wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop); 262 } 263 assert(wrapperArgsRange.empty() && "did not map some of the arguments"); 264 265 auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args); 266 267 if (resultIsNowArg) { 268 Value result = builder.create<LLVM::LoadOp>(loc, args.front()); 269 builder.create<LLVM::ReturnOp>(loc, ValueRange{result}); 270 } else { 271 builder.create<LLVM::ReturnOp>(loc, call.getResults()); 272 } 273 } 274 275 namespace { 276 277 struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> { 278 protected: 279 using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern; 280 281 // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided 282 // to this legalization pattern. 283 LLVM::LLVMFuncOp 284 convertFuncOpToLLVMFuncOp(FuncOp funcOp, 285 ConversionPatternRewriter &rewriter) const { 286 // Convert the original function arguments. They are converted using the 287 // LLVMTypeConverter provided to this legalization pattern. 288 auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs"); 289 TypeConverter::SignatureConversion result(funcOp.getNumArguments()); 290 auto llvmType = getTypeConverter()->convertFunctionSignature( 291 funcOp.getType(), varargsAttr && varargsAttr.getValue(), result); 292 if (!llvmType) 293 return nullptr; 294 295 // Propagate argument/result attributes to all converted arguments/result 296 // obtained after converting a given original argument/result. 297 SmallVector<NamedAttribute, 4> attributes; 298 filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true, 299 attributes); 300 if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) { 301 assert(!resAttrDicts.empty() && "expected array to be non-empty"); 302 auto newResAttrDicts = 303 (funcOp.getNumResults() == 1) 304 ? resAttrDicts 305 : rewriter.getArrayAttr( 306 {wrapAsStructAttrs(rewriter, resAttrDicts)}); 307 attributes.push_back(rewriter.getNamedAttr( 308 FunctionOpInterface::getResultDictAttrName(), newResAttrDicts)); 309 } 310 if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) { 311 SmallVector<Attribute, 4> newArgAttrs( 312 llvmType.cast<LLVM::LLVMFunctionType>().getNumParams()); 313 for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) { 314 auto mapping = result.getInputMapping(i); 315 assert(mapping.hasValue() && 316 "unexpected deletion of function argument"); 317 for (size_t j = 0; j < mapping->size; ++j) 318 newArgAttrs[mapping->inputNo + j] = argAttrDicts[i]; 319 } 320 attributes.push_back( 321 rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(), 322 rewriter.getArrayAttr(newArgAttrs))); 323 } 324 for (const auto &pair : llvm::enumerate(attributes)) { 325 if (pair.value().getName() == "llvm.linkage") { 326 attributes.erase(attributes.begin() + pair.index()); 327 break; 328 } 329 } 330 331 // Create an LLVM function, use external linkage by default until MLIR 332 // functions have linkage. 333 LLVM::Linkage linkage = LLVM::Linkage::External; 334 if (funcOp->hasAttr("llvm.linkage")) { 335 auto attr = 336 funcOp->getAttr("llvm.linkage").dyn_cast<mlir::LLVM::LinkageAttr>(); 337 if (!attr) { 338 funcOp->emitError() 339 << "Contains llvm.linkage attribute not of type LLVM::LinkageAttr"; 340 return nullptr; 341 } 342 linkage = attr.getLinkage(); 343 } 344 auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>( 345 funcOp.getLoc(), funcOp.getName(), llvmType, linkage, 346 /*dsoLocal*/ false, attributes); 347 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), 348 newFuncOp.end()); 349 if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter, 350 &result))) 351 return nullptr; 352 353 return newFuncOp; 354 } 355 }; 356 357 /// FuncOp legalization pattern that converts MemRef arguments to pointers to 358 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type 359 /// information. 360 static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface"; 361 struct FuncOpConversion : public FuncOpConversionBase { 362 FuncOpConversion(LLVMTypeConverter &converter) 363 : FuncOpConversionBase(converter) {} 364 365 LogicalResult 366 matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor, 367 ConversionPatternRewriter &rewriter) const override { 368 auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); 369 if (!newFuncOp) 370 return failure(); 371 372 if (getTypeConverter()->getOptions().emitCWrappers || 373 funcOp->getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) { 374 if (newFuncOp.isExternal()) 375 wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(), 376 funcOp, newFuncOp); 377 else 378 wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(), 379 funcOp, newFuncOp); 380 } 381 382 rewriter.eraseOp(funcOp); 383 return success(); 384 } 385 }; 386 387 /// FuncOp legalization pattern that converts MemRef arguments to bare pointers 388 /// to the MemRef element type. This will impact the calling convention and ABI. 389 struct BarePtrFuncOpConversion : public FuncOpConversionBase { 390 using FuncOpConversionBase::FuncOpConversionBase; 391 392 LogicalResult 393 matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor, 394 ConversionPatternRewriter &rewriter) const override { 395 396 // TODO: bare ptr conversion could be handled by argument materialization 397 // and most of the code below would go away. But to do this, we would need a 398 // way to distinguish between FuncOp and other regions in the 399 // addArgumentMaterialization hook. 400 401 // Store the type of memref-typed arguments before the conversion so that we 402 // can promote them to MemRef descriptor at the beginning of the function. 403 SmallVector<Type, 8> oldArgTypes = 404 llvm::to_vector<8>(funcOp.getType().getInputs()); 405 406 auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); 407 if (!newFuncOp) 408 return failure(); 409 if (newFuncOp.getBody().empty()) { 410 rewriter.eraseOp(funcOp); 411 return success(); 412 } 413 414 // Promote bare pointers from memref arguments to memref descriptors at the 415 // beginning of the function so that all the memrefs in the function have a 416 // uniform representation. 417 Block *entryBlock = &newFuncOp.getBody().front(); 418 auto blockArgs = entryBlock->getArguments(); 419 assert(blockArgs.size() == oldArgTypes.size() && 420 "The number of arguments and types doesn't match"); 421 422 OpBuilder::InsertionGuard guard(rewriter); 423 rewriter.setInsertionPointToStart(entryBlock); 424 for (auto it : llvm::zip(blockArgs, oldArgTypes)) { 425 BlockArgument arg = std::get<0>(it); 426 Type argTy = std::get<1>(it); 427 428 // Unranked memrefs are not supported in the bare pointer calling 429 // convention. We should have bailed out before in the presence of 430 // unranked memrefs. 431 assert(!argTy.isa<UnrankedMemRefType>() && 432 "Unranked memref is not supported"); 433 auto memrefTy = argTy.dyn_cast<MemRefType>(); 434 if (!memrefTy) 435 continue; 436 437 // Replace barePtr with a placeholder (undef), promote barePtr to a ranked 438 // or unranked memref descriptor and replace placeholder with the last 439 // instruction of the memref descriptor. 440 // TODO: The placeholder is needed to avoid replacing barePtr uses in the 441 // MemRef descriptor instructions. We may want to have a utility in the 442 // rewriter to properly handle this use case. 443 Location loc = funcOp.getLoc(); 444 auto placeholder = rewriter.create<LLVM::UndefOp>( 445 loc, getTypeConverter()->convertType(memrefTy)); 446 rewriter.replaceUsesOfBlockArgument(arg, placeholder); 447 448 Value desc = MemRefDescriptor::fromStaticShape( 449 rewriter, loc, *getTypeConverter(), memrefTy, arg); 450 rewriter.replaceOp(placeholder, {desc}); 451 } 452 453 rewriter.eraseOp(funcOp); 454 return success(); 455 } 456 }; 457 458 struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> { 459 using ConvertOpToLLVMPattern<func::ConstantOp>::ConvertOpToLLVMPattern; 460 461 LogicalResult 462 matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor, 463 ConversionPatternRewriter &rewriter) const override { 464 auto type = typeConverter->convertType(op.getResult().getType()); 465 if (!type || !LLVM::isCompatibleType(type)) 466 return rewriter.notifyMatchFailure(op, "failed to convert result type"); 467 468 auto newOp = 469 rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue()); 470 for (const NamedAttribute &attr : op->getAttrs()) { 471 if (attr.getName().strref() == "value") 472 continue; 473 newOp->setAttr(attr.getName(), attr.getValue()); 474 } 475 rewriter.replaceOp(op, newOp->getResults()); 476 return success(); 477 } 478 }; 479 480 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and 481 // passes the pointer to the MemRef across function boundaries. 482 template <typename CallOpType> 483 struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> { 484 using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern; 485 using Super = CallOpInterfaceLowering<CallOpType>; 486 using Base = ConvertOpToLLVMPattern<CallOpType>; 487 488 LogicalResult 489 matchAndRewrite(CallOpType callOp, typename CallOpType::Adaptor adaptor, 490 ConversionPatternRewriter &rewriter) const override { 491 // Pack the result types into a struct. 492 Type packedResult = nullptr; 493 unsigned numResults = callOp.getNumResults(); 494 auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); 495 496 if (numResults != 0) { 497 if (!(packedResult = 498 this->getTypeConverter()->packFunctionResults(resultTypes))) 499 return failure(); 500 } 501 502 auto promoted = this->getTypeConverter()->promoteOperands( 503 callOp.getLoc(), /*opOperands=*/callOp->getOperands(), 504 adaptor.getOperands(), rewriter); 505 auto newOp = rewriter.create<LLVM::CallOp>( 506 callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), 507 promoted, callOp->getAttrs()); 508 509 SmallVector<Value, 4> results; 510 if (numResults < 2) { 511 // If < 2 results, packing did not do anything and we can just return. 512 results.append(newOp.result_begin(), newOp.result_end()); 513 } else { 514 // Otherwise, it had been converted to an operation producing a structure. 515 // Extract individual results from the structure and return them as list. 516 results.reserve(numResults); 517 for (unsigned i = 0; i < numResults; ++i) { 518 auto type = 519 this->typeConverter->convertType(callOp.getResult(i).getType()); 520 results.push_back(rewriter.create<LLVM::ExtractValueOp>( 521 callOp.getLoc(), type, newOp->getResult(0), 522 rewriter.getI64ArrayAttr(i))); 523 } 524 } 525 526 if (this->getTypeConverter()->getOptions().useBarePtrCallConv) { 527 // For the bare-ptr calling convention, promote memref results to 528 // descriptors. 529 assert(results.size() == resultTypes.size() && 530 "The number of arguments and types doesn't match"); 531 this->getTypeConverter()->promoteBarePtrsToDescriptors( 532 rewriter, callOp.getLoc(), resultTypes, results); 533 } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(), 534 resultTypes, results, 535 /*toDynamic=*/false))) { 536 return failure(); 537 } 538 539 rewriter.replaceOp(callOp, results); 540 return success(); 541 } 542 }; 543 544 struct CallOpLowering : public CallOpInterfaceLowering<func::CallOp> { 545 using Super::Super; 546 }; 547 548 struct CallIndirectOpLowering 549 : public CallOpInterfaceLowering<func::CallIndirectOp> { 550 using Super::Super; 551 }; 552 553 struct UnrealizedConversionCastOpLowering 554 : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> { 555 using ConvertOpToLLVMPattern< 556 UnrealizedConversionCastOp>::ConvertOpToLLVMPattern; 557 558 LogicalResult 559 matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, 560 ConversionPatternRewriter &rewriter) const override { 561 SmallVector<Type> convertedTypes; 562 if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(), 563 convertedTypes)) && 564 convertedTypes == adaptor.getInputs().getTypes()) { 565 rewriter.replaceOp(op, adaptor.getInputs()); 566 return success(); 567 } 568 569 convertedTypes.clear(); 570 if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(), 571 convertedTypes)) && 572 convertedTypes == op.getOutputs().getType()) { 573 rewriter.replaceOp(op, adaptor.getInputs()); 574 return success(); 575 } 576 return failure(); 577 } 578 }; 579 580 // Special lowering pattern for `ReturnOps`. Unlike all other operations, 581 // `ReturnOp` interacts with the function signature and must have as many 582 // operands as the function has return values. Because in LLVM IR, functions 583 // can only return 0 or 1 value, we pack multiple values into a structure type. 584 // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if 585 // necessary before returning it 586 struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> { 587 using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern; 588 589 LogicalResult 590 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, 591 ConversionPatternRewriter &rewriter) const override { 592 Location loc = op.getLoc(); 593 unsigned numArguments = op.getNumOperands(); 594 SmallVector<Value, 4> updatedOperands; 595 596 if (getTypeConverter()->getOptions().useBarePtrCallConv) { 597 // For the bare-ptr calling convention, extract the aligned pointer to 598 // be returned from the memref descriptor. 599 for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) { 600 Type oldTy = std::get<0>(it).getType(); 601 Value newOperand = std::get<1>(it); 602 if (oldTy.isa<MemRefType>()) { 603 MemRefDescriptor memrefDesc(newOperand); 604 newOperand = memrefDesc.alignedPtr(rewriter, loc); 605 } else if (oldTy.isa<UnrankedMemRefType>()) { 606 // Unranked memref is not supported in the bare pointer calling 607 // convention. 608 return failure(); 609 } 610 updatedOperands.push_back(newOperand); 611 } 612 } else { 613 updatedOperands = llvm::to_vector<4>(adaptor.getOperands()); 614 (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(), 615 updatedOperands, 616 /*toDynamic=*/true); 617 } 618 619 // If ReturnOp has 0 or 1 operand, create it and return immediately. 620 if (numArguments == 0) { 621 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(), 622 op->getAttrs()); 623 return success(); 624 } 625 if (numArguments == 1) { 626 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( 627 op, TypeRange(), updatedOperands, op->getAttrs()); 628 return success(); 629 } 630 631 // Otherwise, we need to pack the arguments into an LLVM struct type before 632 // returning. 633 auto packedType = getTypeConverter()->packFunctionResults( 634 llvm::to_vector<4>(op.getOperandTypes())); 635 636 Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType); 637 for (unsigned i = 0; i < numArguments; ++i) { 638 packed = rewriter.create<LLVM::InsertValueOp>( 639 loc, packedType, packed, updatedOperands[i], 640 rewriter.getI64ArrayAttr(i)); 641 } 642 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed, 643 op->getAttrs()); 644 return success(); 645 } 646 }; 647 } // namespace 648 649 void mlir::populateFuncToLLVMFuncOpConversionPattern( 650 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 651 if (converter.getOptions().useBarePtrCallConv) 652 patterns.add<BarePtrFuncOpConversion>(converter); 653 else 654 patterns.add<FuncOpConversion>(converter); 655 } 656 657 void mlir::populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, 658 RewritePatternSet &patterns) { 659 populateFuncToLLVMFuncOpConversionPattern(converter, patterns); 660 // clang-format off 661 patterns.add< 662 CallIndirectOpLowering, 663 CallOpLowering, 664 ConstantOpLowering, 665 ReturnOpLowering>(converter); 666 // clang-format on 667 } 668 669 namespace { 670 /// A pass converting Func operations into the LLVM IR dialect. 671 struct ConvertFuncToLLVMPass 672 : public ConvertFuncToLLVMBase<ConvertFuncToLLVMPass> { 673 ConvertFuncToLLVMPass() = default; 674 ConvertFuncToLLVMPass(bool useBarePtrCallConv, bool emitCWrappers, 675 unsigned indexBitwidth, bool useAlignedAlloc, 676 const llvm::DataLayout &dataLayout) { 677 this->useBarePtrCallConv = useBarePtrCallConv; 678 this->emitCWrappers = emitCWrappers; 679 this->indexBitwidth = indexBitwidth; 680 this->dataLayout = dataLayout.getStringRepresentation(); 681 } 682 683 /// Run the dialect converter on the module. 684 void runOnOperation() override { 685 if (useBarePtrCallConv && emitCWrappers) { 686 getOperation().emitError() 687 << "incompatible conversion options: bare-pointer calling convention " 688 "and C wrapper emission"; 689 signalPassFailure(); 690 return; 691 } 692 if (failed(LLVM::LLVMDialect::verifyDataLayoutString( 693 this->dataLayout, [this](const Twine &message) { 694 getOperation().emitError() << message.str(); 695 }))) { 696 signalPassFailure(); 697 return; 698 } 699 700 ModuleOp m = getOperation(); 701 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>(); 702 703 LowerToLLVMOptions options(&getContext(), 704 dataLayoutAnalysis.getAtOrAbove(m)); 705 options.useBarePtrCallConv = useBarePtrCallConv; 706 options.emitCWrappers = emitCWrappers; 707 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 708 options.overrideIndexBitwidth(indexBitwidth); 709 options.dataLayout = llvm::DataLayout(this->dataLayout); 710 711 LLVMTypeConverter typeConverter(&getContext(), options, 712 &dataLayoutAnalysis); 713 714 RewritePatternSet patterns(&getContext()); 715 populateFuncToLLVMConversionPatterns(typeConverter, patterns); 716 717 // TODO: Remove these in favor of their dedicated conversion passes. 718 arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); 719 cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); 720 721 LLVMConversionTarget target(getContext()); 722 if (failed(applyPartialConversion(m, target, std::move(patterns)))) 723 signalPassFailure(); 724 725 m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), 726 StringAttr::get(m.getContext(), this->dataLayout)); 727 } 728 }; 729 } // namespace 730 731 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertFuncToLLVMPass() { 732 return std::make_unique<ConvertFuncToLLVMPass>(); 733 } 734 735 std::unique_ptr<OperationPass<ModuleOp>> 736 mlir::createConvertFuncToLLVMPass(const LowerToLLVMOptions &options) { 737 auto allocLowering = options.allocLowering; 738 // There is no way to provide additional patterns for pass, so 739 // AllocLowering::None will always fail. 740 assert(allocLowering != LowerToLLVMOptions::AllocLowering::None && 741 "ConvertFuncToLLVMPass doesn't support AllocLowering::None"); 742 bool useAlignedAlloc = 743 (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc); 744 return std::make_unique<ConvertFuncToLLVMPass>( 745 options.useBarePtrCallConv, options.emitCWrappers, 746 options.getIndexBitwidth(), useAlignedAlloc, options.dataLayout); 747 } 748