1 //===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert MLIR Func and builtin dialects 10 // into the LLVM IR dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "../PassDetail.h" 15 #include "mlir/Analysis/DataLayoutAnalysis.h" 16 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 17 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" 18 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" 19 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" 20 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 21 #include "mlir/Conversion/LLVMCommon/Pattern.h" 22 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 23 #include "mlir/Dialect/Func/IR/FuncOps.h" 24 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 25 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 26 #include "mlir/Dialect/Utils/StaticValueUtils.h" 27 #include "mlir/IR/Attributes.h" 28 #include "mlir/IR/BlockAndValueMapping.h" 29 #include "mlir/IR/Builders.h" 30 #include "mlir/IR/BuiltinOps.h" 31 #include "mlir/IR/PatternMatch.h" 32 #include "mlir/IR/TypeUtilities.h" 33 #include "mlir/Support/LogicalResult.h" 34 #include "mlir/Support/MathExtras.h" 35 #include "mlir/Transforms/DialectConversion.h" 36 #include "mlir/Transforms/Passes.h" 37 #include "llvm/ADT/TypeSwitch.h" 38 #include "llvm/IR/DerivedTypes.h" 39 #include "llvm/IR/IRBuilder.h" 40 #include "llvm/IR/Type.h" 41 #include "llvm/Support/CommandLine.h" 42 #include "llvm/Support/FormatVariadic.h" 43 #include <algorithm> 44 #include <functional> 45 46 using namespace mlir; 47 48 #define PASS_NAME "convert-func-to-llvm" 49 50 /// Only retain those attributes that are not constructed by 51 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument 52 /// attributes. 53 static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs, 54 bool filterArgAndResAttrs, 55 SmallVectorImpl<NamedAttribute> &result) { 56 for (const auto &attr : attrs) { 57 if (attr.getName() == SymbolTable::getSymbolAttrName() || 58 attr.getName() == FunctionOpInterface::getTypeAttrName() || 59 attr.getName() == "func.varargs" || 60 (filterArgAndResAttrs && 61 (attr.getName() == FunctionOpInterface::getArgDictAttrName() || 62 attr.getName() == FunctionOpInterface::getResultDictAttrName()))) 63 continue; 64 result.push_back(attr); 65 } 66 } 67 68 /// Helper function for wrapping all attributes into a single DictionaryAttr 69 static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) { 70 return DictionaryAttr::get( 71 b.getContext(), 72 b.getNamedAttr(LLVM::LLVMDialect::getStructAttrsAttrName(), attrs)); 73 } 74 75 /// Combines all result attributes into a single DictionaryAttr 76 /// and prepends to argument attrs. 77 /// This is intended to be used to format the attributes for a C wrapper 78 /// function when the result(s) is converted to the first function argument 79 /// (in the multiple return case, all returns get wrapped into a single 80 /// argument). The total number of argument attributes should be equal to 81 /// (number of function arguments) + 1. 82 static void 83 prependResAttrsToArgAttrs(OpBuilder &builder, 84 SmallVectorImpl<NamedAttribute> &attributes, 85 size_t numArguments) { 86 auto allAttrs = SmallVector<Attribute>( 87 numArguments + 1, DictionaryAttr::get(builder.getContext())); 88 NamedAttribute *argAttrs = nullptr; 89 for (auto *it = attributes.begin(); it != attributes.end();) { 90 if (it->getName() == FunctionOpInterface::getArgDictAttrName()) { 91 auto arrayAttrs = it->getValue().cast<ArrayAttr>(); 92 assert(arrayAttrs.size() == numArguments && 93 "Number of arg attrs and args should match"); 94 std::copy(arrayAttrs.begin(), arrayAttrs.end(), allAttrs.begin() + 1); 95 argAttrs = it; 96 } else if (it->getName() == FunctionOpInterface::getResultDictAttrName()) { 97 auto arrayAttrs = it->getValue().cast<ArrayAttr>(); 98 assert(!arrayAttrs.empty() && "expected array to be non-empty"); 99 allAttrs[0] = (arrayAttrs.size() == 1) 100 ? arrayAttrs[0] 101 : wrapAsStructAttrs(builder, arrayAttrs); 102 it = attributes.erase(it); 103 continue; 104 } 105 it++; 106 } 107 108 auto newArgAttrs = 109 builder.getNamedAttr(FunctionOpInterface::getArgDictAttrName(), 110 builder.getArrayAttr(allAttrs)); 111 if (!argAttrs) { 112 attributes.emplace_back(newArgAttrs); 113 return; 114 } 115 *argAttrs = newArgAttrs; 116 } 117 118 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct 119 /// arguments instead of unpacked arguments. This function can be called from C 120 /// by passing a pointer to a C struct corresponding to a memref descriptor. 121 /// Similarly, returned memrefs are passed via pointers to a C struct that is 122 /// passed as additional argument. 123 /// Internally, the auxiliary function unpacks the descriptor into individual 124 /// components and forwards them to `newFuncOp` and forwards the results to 125 /// the extra arguments. 126 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, 127 LLVMTypeConverter &typeConverter, 128 func::FuncOp funcOp, 129 LLVM::LLVMFuncOp newFuncOp) { 130 auto type = funcOp.getFunctionType(); 131 SmallVector<NamedAttribute, 4> attributes; 132 filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false, 133 attributes); 134 Type wrapperFuncType; 135 bool resultIsNowArg; 136 std::tie(wrapperFuncType, resultIsNowArg) = 137 typeConverter.convertFunctionTypeCWrapper(type); 138 if (resultIsNowArg) 139 prependResAttrsToArgAttrs(rewriter, attributes, funcOp.getNumArguments()); 140 auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>( 141 loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), 142 wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false, 143 /*cconv*/ LLVM::CConv::C, attributes); 144 145 OpBuilder::InsertionGuard guard(rewriter); 146 rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock()); 147 148 SmallVector<Value, 8> args; 149 size_t argOffset = resultIsNowArg ? 1 : 0; 150 for (auto &en : llvm::enumerate(type.getInputs())) { 151 Value arg = wrapperFuncOp.getArgument(en.index() + argOffset); 152 if (auto memrefType = en.value().dyn_cast<MemRefType>()) { 153 Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg); 154 MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args); 155 continue; 156 } 157 if (en.value().isa<UnrankedMemRefType>()) { 158 Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg); 159 UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args); 160 continue; 161 } 162 163 args.push_back(arg); 164 } 165 166 auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args); 167 168 if (resultIsNowArg) { 169 rewriter.create<LLVM::StoreOp>(loc, call.getResult(0), 170 wrapperFuncOp.getArgument(0)); 171 rewriter.create<LLVM::ReturnOp>(loc, ValueRange{}); 172 } else { 173 rewriter.create<LLVM::ReturnOp>(loc, call.getResults()); 174 } 175 } 176 177 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct 178 /// arguments instead of unpacked arguments. Creates a body for the (external) 179 /// `newFuncOp` that allocates a memref descriptor on stack, packs the 180 /// individual arguments into this descriptor and passes a pointer to it into 181 /// the auxiliary function. If the result of the function cannot be directly 182 /// returned, we write it to a special first argument that provides a pointer 183 /// to a corresponding struct. This auxiliary external function is now 184 /// compatible with functions defined in C using pointers to C structs 185 /// corresponding to a memref descriptor. 186 static void wrapExternalFunction(OpBuilder &builder, Location loc, 187 LLVMTypeConverter &typeConverter, 188 func::FuncOp funcOp, 189 LLVM::LLVMFuncOp newFuncOp) { 190 OpBuilder::InsertionGuard guard(builder); 191 192 Type wrapperType; 193 bool resultIsNowArg; 194 std::tie(wrapperType, resultIsNowArg) = 195 typeConverter.convertFunctionTypeCWrapper(funcOp.getFunctionType()); 196 // This conversion can only fail if it could not convert one of the argument 197 // types. But since it has been applied to a non-wrapper function before, it 198 // should have failed earlier and not reach this point at all. 199 assert(wrapperType && "unexpected type conversion failure"); 200 201 SmallVector<NamedAttribute, 4> attributes; 202 filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false, 203 attributes); 204 205 if (resultIsNowArg) 206 prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments()); 207 // Create the auxiliary function. 208 auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>( 209 loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), 210 wrapperType, LLVM::Linkage::External, /*dsoLocal*/ false, 211 /*cconv*/ LLVM::CConv::C, attributes); 212 213 builder.setInsertionPointToStart(newFuncOp.addEntryBlock()); 214 215 // Get a ValueRange containing arguments. 216 FunctionType type = funcOp.getFunctionType(); 217 SmallVector<Value, 8> args; 218 args.reserve(type.getNumInputs()); 219 ValueRange wrapperArgsRange(newFuncOp.getArguments()); 220 221 if (resultIsNowArg) { 222 // Allocate the struct on the stack and pass the pointer. 223 Type resultType = 224 wrapperType.cast<LLVM::LLVMFunctionType>().getParamType(0); 225 Value one = builder.create<LLVM::ConstantOp>( 226 loc, typeConverter.convertType(builder.getIndexType()), 227 builder.getIntegerAttr(builder.getIndexType(), 1)); 228 Value result = builder.create<LLVM::AllocaOp>(loc, resultType, one); 229 args.push_back(result); 230 } 231 232 // Iterate over the inputs of the original function and pack values into 233 // memref descriptors if the original type is a memref. 234 for (auto &en : llvm::enumerate(type.getInputs())) { 235 Value arg; 236 int numToDrop = 1; 237 auto memRefType = en.value().dyn_cast<MemRefType>(); 238 auto unrankedMemRefType = en.value().dyn_cast<UnrankedMemRefType>(); 239 if (memRefType || unrankedMemRefType) { 240 numToDrop = memRefType 241 ? MemRefDescriptor::getNumUnpackedValues(memRefType) 242 : UnrankedMemRefDescriptor::getNumUnpackedValues(); 243 Value packed = 244 memRefType 245 ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType, 246 wrapperArgsRange.take_front(numToDrop)) 247 : UnrankedMemRefDescriptor::pack( 248 builder, loc, typeConverter, unrankedMemRefType, 249 wrapperArgsRange.take_front(numToDrop)); 250 251 auto ptrTy = LLVM::LLVMPointerType::get(packed.getType()); 252 Value one = builder.create<LLVM::ConstantOp>( 253 loc, typeConverter.convertType(builder.getIndexType()), 254 builder.getIntegerAttr(builder.getIndexType(), 1)); 255 Value allocated = 256 builder.create<LLVM::AllocaOp>(loc, ptrTy, one, /*alignment=*/0); 257 builder.create<LLVM::StoreOp>(loc, packed, allocated); 258 arg = allocated; 259 } else { 260 arg = wrapperArgsRange[0]; 261 } 262 263 args.push_back(arg); 264 wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop); 265 } 266 assert(wrapperArgsRange.empty() && "did not map some of the arguments"); 267 268 auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args); 269 270 if (resultIsNowArg) { 271 Value result = builder.create<LLVM::LoadOp>(loc, args.front()); 272 builder.create<LLVM::ReturnOp>(loc, ValueRange{result}); 273 } else { 274 builder.create<LLVM::ReturnOp>(loc, call.getResults()); 275 } 276 } 277 278 namespace { 279 280 struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> { 281 protected: 282 using ConvertOpToLLVMPattern<func::FuncOp>::ConvertOpToLLVMPattern; 283 284 // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided 285 // to this legalization pattern. 286 LLVM::LLVMFuncOp 287 convertFuncOpToLLVMFuncOp(func::FuncOp funcOp, 288 ConversionPatternRewriter &rewriter) const { 289 // Convert the original function arguments. They are converted using the 290 // LLVMTypeConverter provided to this legalization pattern. 291 auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs"); 292 TypeConverter::SignatureConversion result(funcOp.getNumArguments()); 293 auto llvmType = getTypeConverter()->convertFunctionSignature( 294 funcOp.getFunctionType(), varargsAttr && varargsAttr.getValue(), 295 result); 296 if (!llvmType) 297 return nullptr; 298 299 // Propagate argument/result attributes to all converted arguments/result 300 // obtained after converting a given original argument/result. 301 SmallVector<NamedAttribute, 4> attributes; 302 filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true, 303 attributes); 304 if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) { 305 assert(!resAttrDicts.empty() && "expected array to be non-empty"); 306 auto newResAttrDicts = 307 (funcOp.getNumResults() == 1) 308 ? resAttrDicts 309 : rewriter.getArrayAttr( 310 {wrapAsStructAttrs(rewriter, resAttrDicts)}); 311 attributes.push_back(rewriter.getNamedAttr( 312 FunctionOpInterface::getResultDictAttrName(), newResAttrDicts)); 313 } 314 if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) { 315 SmallVector<Attribute, 4> newArgAttrs( 316 llvmType.cast<LLVM::LLVMFunctionType>().getNumParams()); 317 for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) { 318 auto mapping = result.getInputMapping(i); 319 assert(mapping && "unexpected deletion of function argument"); 320 for (size_t j = 0; j < mapping->size; ++j) 321 newArgAttrs[mapping->inputNo + j] = argAttrDicts[i]; 322 } 323 attributes.push_back( 324 rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(), 325 rewriter.getArrayAttr(newArgAttrs))); 326 } 327 for (const auto &pair : llvm::enumerate(attributes)) { 328 if (pair.value().getName() == "llvm.linkage") { 329 attributes.erase(attributes.begin() + pair.index()); 330 break; 331 } 332 } 333 334 // Create an LLVM function, use external linkage by default until MLIR 335 // functions have linkage. 336 LLVM::Linkage linkage = LLVM::Linkage::External; 337 if (funcOp->hasAttr("llvm.linkage")) { 338 auto attr = 339 funcOp->getAttr("llvm.linkage").dyn_cast<mlir::LLVM::LinkageAttr>(); 340 if (!attr) { 341 funcOp->emitError() 342 << "Contains llvm.linkage attribute not of type LLVM::LinkageAttr"; 343 return nullptr; 344 } 345 linkage = attr.getLinkage(); 346 } 347 auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>( 348 funcOp.getLoc(), funcOp.getName(), llvmType, linkage, 349 /*dsoLocal*/ false, /*cconv*/ LLVM::CConv::C, attributes); 350 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), 351 newFuncOp.end()); 352 if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter, 353 &result))) 354 return nullptr; 355 356 return newFuncOp; 357 } 358 }; 359 360 /// FuncOp legalization pattern that converts MemRef arguments to pointers to 361 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type 362 /// information. 363 struct FuncOpConversion : public FuncOpConversionBase { 364 FuncOpConversion(LLVMTypeConverter &converter) 365 : FuncOpConversionBase(converter) {} 366 367 LogicalResult 368 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, 369 ConversionPatternRewriter &rewriter) const override { 370 auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); 371 if (!newFuncOp) 372 return failure(); 373 374 if (funcOp->getAttrOfType<UnitAttr>( 375 LLVM::LLVMDialect::getEmitCWrapperAttrName())) { 376 if (newFuncOp.isVarArg()) 377 return funcOp->emitError("C interface for variadic functions is not " 378 "supported yet."); 379 380 if (newFuncOp.isExternal()) 381 wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(), 382 funcOp, newFuncOp); 383 else 384 wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(), 385 funcOp, newFuncOp); 386 } 387 388 rewriter.eraseOp(funcOp); 389 return success(); 390 } 391 }; 392 393 /// FuncOp legalization pattern that converts MemRef arguments to bare pointers 394 /// to the MemRef element type. This will impact the calling convention and ABI. 395 struct BarePtrFuncOpConversion : public FuncOpConversionBase { 396 using FuncOpConversionBase::FuncOpConversionBase; 397 398 LogicalResult 399 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, 400 ConversionPatternRewriter &rewriter) const override { 401 402 // TODO: bare ptr conversion could be handled by argument materialization 403 // and most of the code below would go away. But to do this, we would need a 404 // way to distinguish between FuncOp and other regions in the 405 // addArgumentMaterialization hook. 406 407 // Store the type of memref-typed arguments before the conversion so that we 408 // can promote them to MemRef descriptor at the beginning of the function. 409 SmallVector<Type, 8> oldArgTypes = 410 llvm::to_vector<8>(funcOp.getFunctionType().getInputs()); 411 412 auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); 413 if (!newFuncOp) 414 return failure(); 415 if (newFuncOp.getBody().empty()) { 416 rewriter.eraseOp(funcOp); 417 return success(); 418 } 419 420 // Promote bare pointers from memref arguments to memref descriptors at the 421 // beginning of the function so that all the memrefs in the function have a 422 // uniform representation. 423 Block *entryBlock = &newFuncOp.getBody().front(); 424 auto blockArgs = entryBlock->getArguments(); 425 assert(blockArgs.size() == oldArgTypes.size() && 426 "The number of arguments and types doesn't match"); 427 428 OpBuilder::InsertionGuard guard(rewriter); 429 rewriter.setInsertionPointToStart(entryBlock); 430 for (auto it : llvm::zip(blockArgs, oldArgTypes)) { 431 BlockArgument arg = std::get<0>(it); 432 Type argTy = std::get<1>(it); 433 434 // Unranked memrefs are not supported in the bare pointer calling 435 // convention. We should have bailed out before in the presence of 436 // unranked memrefs. 437 assert(!argTy.isa<UnrankedMemRefType>() && 438 "Unranked memref is not supported"); 439 auto memrefTy = argTy.dyn_cast<MemRefType>(); 440 if (!memrefTy) 441 continue; 442 443 // Replace barePtr with a placeholder (undef), promote barePtr to a ranked 444 // or unranked memref descriptor and replace placeholder with the last 445 // instruction of the memref descriptor. 446 // TODO: The placeholder is needed to avoid replacing barePtr uses in the 447 // MemRef descriptor instructions. We may want to have a utility in the 448 // rewriter to properly handle this use case. 449 Location loc = funcOp.getLoc(); 450 auto placeholder = rewriter.create<LLVM::UndefOp>( 451 loc, getTypeConverter()->convertType(memrefTy)); 452 rewriter.replaceUsesOfBlockArgument(arg, placeholder); 453 454 Value desc = MemRefDescriptor::fromStaticShape( 455 rewriter, loc, *getTypeConverter(), memrefTy, arg); 456 rewriter.replaceOp(placeholder, {desc}); 457 } 458 459 rewriter.eraseOp(funcOp); 460 return success(); 461 } 462 }; 463 464 struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> { 465 using ConvertOpToLLVMPattern<func::ConstantOp>::ConvertOpToLLVMPattern; 466 467 LogicalResult 468 matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor, 469 ConversionPatternRewriter &rewriter) const override { 470 auto type = typeConverter->convertType(op.getResult().getType()); 471 if (!type || !LLVM::isCompatibleType(type)) 472 return rewriter.notifyMatchFailure(op, "failed to convert result type"); 473 474 auto newOp = 475 rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue()); 476 for (const NamedAttribute &attr : op->getAttrs()) { 477 if (attr.getName().strref() == "value") 478 continue; 479 newOp->setAttr(attr.getName(), attr.getValue()); 480 } 481 rewriter.replaceOp(op, newOp->getResults()); 482 return success(); 483 } 484 }; 485 486 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and 487 // passes the pointer to the MemRef across function boundaries. 488 template <typename CallOpType> 489 struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> { 490 using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern; 491 using Super = CallOpInterfaceLowering<CallOpType>; 492 using Base = ConvertOpToLLVMPattern<CallOpType>; 493 494 LogicalResult 495 matchAndRewrite(CallOpType callOp, typename CallOpType::Adaptor adaptor, 496 ConversionPatternRewriter &rewriter) const override { 497 // Pack the result types into a struct. 498 Type packedResult = nullptr; 499 unsigned numResults = callOp.getNumResults(); 500 auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); 501 502 if (numResults != 0) { 503 if (!(packedResult = 504 this->getTypeConverter()->packFunctionResults(resultTypes))) 505 return failure(); 506 } 507 508 auto promoted = this->getTypeConverter()->promoteOperands( 509 callOp.getLoc(), /*opOperands=*/callOp->getOperands(), 510 adaptor.getOperands(), rewriter); 511 auto newOp = rewriter.create<LLVM::CallOp>( 512 callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), 513 promoted, callOp->getAttrs()); 514 515 SmallVector<Value, 4> results; 516 if (numResults < 2) { 517 // If < 2 results, packing did not do anything and we can just return. 518 results.append(newOp.result_begin(), newOp.result_end()); 519 } else { 520 // Otherwise, it had been converted to an operation producing a structure. 521 // Extract individual results from the structure and return them as list. 522 results.reserve(numResults); 523 for (unsigned i = 0; i < numResults; ++i) { 524 auto type = 525 this->typeConverter->convertType(callOp.getResult(i).getType()); 526 results.push_back(rewriter.create<LLVM::ExtractValueOp>( 527 callOp.getLoc(), type, newOp->getResult(0), 528 rewriter.getI64ArrayAttr(i))); 529 } 530 } 531 532 if (this->getTypeConverter()->getOptions().useBarePtrCallConv) { 533 // For the bare-ptr calling convention, promote memref results to 534 // descriptors. 535 assert(results.size() == resultTypes.size() && 536 "The number of arguments and types doesn't match"); 537 this->getTypeConverter()->promoteBarePtrsToDescriptors( 538 rewriter, callOp.getLoc(), resultTypes, results); 539 } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(), 540 resultTypes, results, 541 /*toDynamic=*/false))) { 542 return failure(); 543 } 544 545 rewriter.replaceOp(callOp, results); 546 return success(); 547 } 548 }; 549 550 struct CallOpLowering : public CallOpInterfaceLowering<func::CallOp> { 551 using Super::Super; 552 }; 553 554 struct CallIndirectOpLowering 555 : public CallOpInterfaceLowering<func::CallIndirectOp> { 556 using Super::Super; 557 }; 558 559 struct UnrealizedConversionCastOpLowering 560 : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> { 561 using ConvertOpToLLVMPattern< 562 UnrealizedConversionCastOp>::ConvertOpToLLVMPattern; 563 564 LogicalResult 565 matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, 566 ConversionPatternRewriter &rewriter) const override { 567 SmallVector<Type> convertedTypes; 568 if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(), 569 convertedTypes)) && 570 convertedTypes == adaptor.getInputs().getTypes()) { 571 rewriter.replaceOp(op, adaptor.getInputs()); 572 return success(); 573 } 574 575 convertedTypes.clear(); 576 if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(), 577 convertedTypes)) && 578 convertedTypes == op.getOutputs().getType()) { 579 rewriter.replaceOp(op, adaptor.getInputs()); 580 return success(); 581 } 582 return failure(); 583 } 584 }; 585 586 // Special lowering pattern for `ReturnOps`. Unlike all other operations, 587 // `ReturnOp` interacts with the function signature and must have as many 588 // operands as the function has return values. Because in LLVM IR, functions 589 // can only return 0 or 1 value, we pack multiple values into a structure type. 590 // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if 591 // necessary before returning it 592 struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> { 593 using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern; 594 595 LogicalResult 596 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, 597 ConversionPatternRewriter &rewriter) const override { 598 Location loc = op.getLoc(); 599 unsigned numArguments = op.getNumOperands(); 600 SmallVector<Value, 4> updatedOperands; 601 602 if (getTypeConverter()->getOptions().useBarePtrCallConv) { 603 // For the bare-ptr calling convention, extract the aligned pointer to 604 // be returned from the memref descriptor. 605 for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) { 606 Type oldTy = std::get<0>(it).getType(); 607 Value newOperand = std::get<1>(it); 608 if (oldTy.isa<MemRefType>() && getTypeConverter()->canConvertToBarePtr( 609 oldTy.cast<BaseMemRefType>())) { 610 MemRefDescriptor memrefDesc(newOperand); 611 newOperand = memrefDesc.alignedPtr(rewriter, loc); 612 } else if (oldTy.isa<UnrankedMemRefType>()) { 613 // Unranked memref is not supported in the bare pointer calling 614 // convention. 615 return failure(); 616 } 617 updatedOperands.push_back(newOperand); 618 } 619 } else { 620 updatedOperands = llvm::to_vector<4>(adaptor.getOperands()); 621 (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(), 622 updatedOperands, 623 /*toDynamic=*/true); 624 } 625 626 // If ReturnOp has 0 or 1 operand, create it and return immediately. 627 if (numArguments == 0) { 628 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(), 629 op->getAttrs()); 630 return success(); 631 } 632 if (numArguments == 1) { 633 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( 634 op, TypeRange(), updatedOperands, op->getAttrs()); 635 return success(); 636 } 637 638 // Otherwise, we need to pack the arguments into an LLVM struct type before 639 // returning. 640 auto packedType = getTypeConverter()->packFunctionResults( 641 llvm::to_vector<4>(op.getOperandTypes())); 642 643 Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType); 644 for (unsigned i = 0; i < numArguments; ++i) { 645 packed = rewriter.create<LLVM::InsertValueOp>( 646 loc, packedType, packed, updatedOperands[i], 647 rewriter.getI64ArrayAttr(i)); 648 } 649 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed, 650 op->getAttrs()); 651 return success(); 652 } 653 }; 654 } // namespace 655 656 void mlir::populateFuncToLLVMFuncOpConversionPattern( 657 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 658 if (converter.getOptions().useBarePtrCallConv) 659 patterns.add<BarePtrFuncOpConversion>(converter); 660 else 661 patterns.add<FuncOpConversion>(converter); 662 } 663 664 void mlir::populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, 665 RewritePatternSet &patterns) { 666 populateFuncToLLVMFuncOpConversionPattern(converter, patterns); 667 // clang-format off 668 patterns.add< 669 CallIndirectOpLowering, 670 CallOpLowering, 671 ConstantOpLowering, 672 ReturnOpLowering>(converter); 673 // clang-format on 674 } 675 676 namespace { 677 /// A pass converting Func operations into the LLVM IR dialect. 678 struct ConvertFuncToLLVMPass 679 : public ConvertFuncToLLVMBase<ConvertFuncToLLVMPass> { 680 ConvertFuncToLLVMPass() = default; 681 ConvertFuncToLLVMPass(bool useBarePtrCallConv, unsigned indexBitwidth, 682 bool useAlignedAlloc, 683 const llvm::DataLayout &dataLayout) { 684 this->useBarePtrCallConv = useBarePtrCallConv; 685 this->indexBitwidth = indexBitwidth; 686 this->dataLayout = dataLayout.getStringRepresentation(); 687 } 688 689 /// Run the dialect converter on the module. 690 void runOnOperation() override { 691 if (failed(LLVM::LLVMDialect::verifyDataLayoutString( 692 this->dataLayout, [this](const Twine &message) { 693 getOperation().emitError() << message.str(); 694 }))) { 695 signalPassFailure(); 696 return; 697 } 698 699 ModuleOp m = getOperation(); 700 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>(); 701 702 LowerToLLVMOptions options(&getContext(), 703 dataLayoutAnalysis.getAtOrAbove(m)); 704 options.useBarePtrCallConv = useBarePtrCallConv; 705 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 706 options.overrideIndexBitwidth(indexBitwidth); 707 options.dataLayout = llvm::DataLayout(this->dataLayout); 708 709 LLVMTypeConverter typeConverter(&getContext(), options, 710 &dataLayoutAnalysis); 711 712 RewritePatternSet patterns(&getContext()); 713 populateFuncToLLVMConversionPatterns(typeConverter, patterns); 714 715 // TODO: Remove these in favor of their dedicated conversion passes. 716 arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); 717 cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); 718 719 LLVMConversionTarget target(getContext()); 720 if (failed(applyPartialConversion(m, target, std::move(patterns)))) 721 signalPassFailure(); 722 723 m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), 724 StringAttr::get(m.getContext(), this->dataLayout)); 725 } 726 }; 727 } // namespace 728 729 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertFuncToLLVMPass() { 730 return std::make_unique<ConvertFuncToLLVMPass>(); 731 } 732 733 std::unique_ptr<OperationPass<ModuleOp>> 734 mlir::createConvertFuncToLLVMPass(const LowerToLLVMOptions &options) { 735 auto allocLowering = options.allocLowering; 736 // There is no way to provide additional patterns for pass, so 737 // AllocLowering::None will always fail. 738 assert(allocLowering != LowerToLLVMOptions::AllocLowering::None && 739 "ConvertFuncToLLVMPass doesn't support AllocLowering::None"); 740 bool useAlignedAlloc = 741 (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc); 742 return std::make_unique<ConvertFuncToLLVMPass>( 743 options.useBarePtrCallConv, options.getIndexBitwidth(), useAlignedAlloc, 744 options.dataLayout); 745 } 746