1 //===- ConvertFromLLVMIR.cpp - MLIR to LLVM IR conversion -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a translation between LLVM IR and the MLIR LLVM dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/IR/Builders.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/IR/BuiltinTypes.h" 17 #include "mlir/IR/MLIRContext.h" 18 #include "mlir/Target/LLVMIR/Import.h" 19 #include "mlir/Translation.h" 20 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/IR/Attributes.h" 23 #include "llvm/IR/Constants.h" 24 #include "llvm/IR/DerivedTypes.h" 25 #include "llvm/IR/Function.h" 26 #include "llvm/IR/InlineAsm.h" 27 #include "llvm/IR/Instructions.h" 28 #include "llvm/IR/Type.h" 29 #include "llvm/IRReader/IRReader.h" 30 #include "llvm/Support/Error.h" 31 #include "llvm/Support/SourceMgr.h" 32 33 using namespace mlir; 34 using namespace mlir::LLVM; 35 36 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc" 37 38 // Utility to print an LLVM value as a string for passing to emitError(). 39 // FIXME: Diagnostic should be able to natively handle types that have 40 // operator << (raw_ostream&) defined. 41 static std::string diag(llvm::Value &v) { 42 std::string s; 43 llvm::raw_string_ostream os(s); 44 os << v; 45 return os.str(); 46 } 47 48 namespace mlir { 49 namespace LLVM { 50 namespace detail { 51 /// Support for translating LLVM IR types to MLIR LLVM dialect types. 52 class TypeFromLLVMIRTranslatorImpl { 53 public: 54 /// Constructs a class creating types in the given MLIR context. 55 TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {} 56 57 /// Translates the given type. 58 Type translateType(llvm::Type *type) { 59 if (knownTranslations.count(type)) 60 return knownTranslations.lookup(type); 61 62 Type translated = 63 llvm::TypeSwitch<llvm::Type *, Type>(type) 64 .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType, 65 llvm::PointerType, llvm::StructType, llvm::FixedVectorType, 66 llvm::ScalableVectorType>( 67 [this](auto *type) { return this->translate(type); }) 68 .Default([this](llvm::Type *type) { 69 return translatePrimitiveType(type); 70 }); 71 knownTranslations.try_emplace(type, translated); 72 return translated; 73 } 74 75 private: 76 /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature, 77 /// type. 78 Type translatePrimitiveType(llvm::Type *type) { 79 if (type->isVoidTy()) 80 return LLVM::LLVMVoidType::get(&context); 81 if (type->isHalfTy()) 82 return Float16Type::get(&context); 83 if (type->isBFloatTy()) 84 return BFloat16Type::get(&context); 85 if (type->isFloatTy()) 86 return Float32Type::get(&context); 87 if (type->isDoubleTy()) 88 return Float64Type::get(&context); 89 if (type->isFP128Ty()) 90 return Float128Type::get(&context); 91 if (type->isX86_FP80Ty()) 92 return Float80Type::get(&context); 93 if (type->isPPC_FP128Ty()) 94 return LLVM::LLVMPPCFP128Type::get(&context); 95 if (type->isX86_MMXTy()) 96 return LLVM::LLVMX86MMXType::get(&context); 97 if (type->isLabelTy()) 98 return LLVM::LLVMLabelType::get(&context); 99 if (type->isMetadataTy()) 100 return LLVM::LLVMMetadataType::get(&context); 101 llvm_unreachable("not a primitive type"); 102 } 103 104 /// Translates the given array type. 105 Type translate(llvm::ArrayType *type) { 106 return LLVM::LLVMArrayType::get(translateType(type->getElementType()), 107 type->getNumElements()); 108 } 109 110 /// Translates the given function type. 111 Type translate(llvm::FunctionType *type) { 112 SmallVector<Type, 8> paramTypes; 113 translateTypes(type->params(), paramTypes); 114 return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()), 115 paramTypes, type->isVarArg()); 116 } 117 118 /// Translates the given integer type. 119 Type translate(llvm::IntegerType *type) { 120 return IntegerType::get(&context, type->getBitWidth()); 121 } 122 123 /// Translates the given pointer type. 124 Type translate(llvm::PointerType *type) { 125 return LLVM::LLVMPointerType::get(translateType(type->getElementType()), 126 type->getAddressSpace()); 127 } 128 129 /// Translates the given structure type. 130 Type translate(llvm::StructType *type) { 131 SmallVector<Type, 8> subtypes; 132 if (type->isLiteral()) { 133 translateTypes(type->subtypes(), subtypes); 134 return LLVM::LLVMStructType::getLiteral(&context, subtypes, 135 type->isPacked()); 136 } 137 138 if (type->isOpaque()) 139 return LLVM::LLVMStructType::getOpaque(type->getName(), &context); 140 141 LLVM::LLVMStructType translated = 142 LLVM::LLVMStructType::getIdentified(&context, type->getName()); 143 knownTranslations.try_emplace(type, translated); 144 translateTypes(type->subtypes(), subtypes); 145 LogicalResult bodySet = translated.setBody(subtypes, type->isPacked()); 146 assert(succeeded(bodySet) && 147 "could not set the body of an identified struct"); 148 (void)bodySet; 149 return translated; 150 } 151 152 /// Translates the given fixed-vector type. 153 Type translate(llvm::FixedVectorType *type) { 154 return LLVM::getFixedVectorType(translateType(type->getElementType()), 155 type->getNumElements()); 156 } 157 158 /// Translates the given scalable-vector type. 159 Type translate(llvm::ScalableVectorType *type) { 160 return LLVM::LLVMScalableVectorType::get( 161 translateType(type->getElementType()), type->getMinNumElements()); 162 } 163 164 /// Translates a list of types. 165 void translateTypes(ArrayRef<llvm::Type *> types, 166 SmallVectorImpl<Type> &result) { 167 result.reserve(result.size() + types.size()); 168 for (llvm::Type *type : types) 169 result.push_back(translateType(type)); 170 } 171 172 /// Map of known translations. Serves as a cache and as recursion stopper for 173 /// translating recursive structs. 174 llvm::DenseMap<llvm::Type *, Type> knownTranslations; 175 176 /// The context in which MLIR types are created. 177 MLIRContext &context; 178 }; 179 } // end namespace detail 180 181 /// Utility class to translate LLVM IR types to the MLIR LLVM dialect. Stores 182 /// the translation state, in particular any identified structure types that are 183 /// reused across translations. 184 class TypeFromLLVMIRTranslator { 185 public: 186 TypeFromLLVMIRTranslator(MLIRContext &context); 187 ~TypeFromLLVMIRTranslator(); 188 189 /// Translates the given LLVM IR type to the MLIR LLVM dialect. 190 Type translateType(llvm::Type *type); 191 192 private: 193 /// Private implementation. 194 std::unique_ptr<detail::TypeFromLLVMIRTranslatorImpl> impl; 195 }; 196 197 } // end namespace LLVM 198 } // end namespace mlir 199 200 LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context) 201 : impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {} 202 203 LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() {} 204 205 Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) { 206 return impl->translateType(type); 207 } 208 209 // Handles importing globals and functions from an LLVM module. 210 namespace { 211 class Importer { 212 public: 213 Importer(MLIRContext *context, ModuleOp module) 214 : b(context), context(context), module(module), 215 unknownLoc(FileLineColLoc::get(context, "imported-bitcode", 0, 0)), 216 typeTranslator(*context) { 217 b.setInsertionPointToStart(module.getBody()); 218 } 219 220 /// Imports `f` into the current module. 221 LogicalResult processFunction(llvm::Function *f); 222 223 /// Imports GV as a GlobalOp, creating it if it doesn't exist. 224 GlobalOp processGlobal(llvm::GlobalVariable *GV); 225 226 private: 227 /// Returns personality of `f` as a FlatSymbolRefAttr. 228 FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *f); 229 /// Imports `bb` into `block`, which must be initially empty. 230 LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block); 231 /// Imports `inst` and populates instMap[inst] with the imported Value. 232 LogicalResult processInstruction(llvm::Instruction *inst); 233 /// Creates an LLVM-compatible MLIR type for `type`. 234 Type processType(llvm::Type *type); 235 /// `value` is an SSA-use. Return the remapped version of `value` or a 236 /// placeholder that will be remapped later if this is an instruction that 237 /// has not yet been visited. 238 Value processValue(llvm::Value *value); 239 /// Create the most accurate Location possible using a llvm::DebugLoc and 240 /// possibly an llvm::Instruction to narrow the Location if debug information 241 /// is unavailable. 242 Location processDebugLoc(const llvm::DebugLoc &loc, 243 llvm::Instruction *inst = nullptr); 244 /// `br` branches to `target`. Append the block arguments to attach to the 245 /// generated branch op to `blockArguments`. These should be in the same order 246 /// as the PHIs in `target`. 247 LogicalResult processBranchArgs(llvm::Instruction *br, 248 llvm::BasicBlock *target, 249 SmallVectorImpl<Value> &blockArguments); 250 /// Returns the builtin type equivalent to be used in attributes for the given 251 /// LLVM IR dialect type. 252 Type getStdTypeForAttr(Type type); 253 /// Return `value` as an attribute to attach to a GlobalOp. 254 Attribute getConstantAsAttr(llvm::Constant *value); 255 /// Return `c` as an MLIR Value. This could either be a ConstantOp, or 256 /// an expanded sequence of ops in the current function's entry block (for 257 /// ConstantExprs or ConstantGEPs). 258 Value processConstant(llvm::Constant *c); 259 260 /// The current builder, pointing at where the next Instruction should be 261 /// generated. 262 OpBuilder b; 263 /// The current context. 264 MLIRContext *context; 265 /// The current module being created. 266 ModuleOp module; 267 /// The entry block of the current function being processed. 268 Block *currentEntryBlock; 269 270 /// Globals are inserted before the first function, if any. 271 Block::iterator getGlobalInsertPt() { 272 auto it = module.getBody()->begin(); 273 auto endIt = module.getBody()->end(); 274 while (it != endIt && !isa<LLVMFuncOp>(it)) 275 ++it; 276 return it; 277 } 278 279 /// Functions are always inserted before the module terminator. 280 Block::iterator getFuncInsertPt() { 281 return std::prev(module.getBody()->end()); 282 } 283 284 /// Remapped blocks, for the current function. 285 DenseMap<llvm::BasicBlock *, Block *> blocks; 286 /// Remapped values. These are function-local. 287 DenseMap<llvm::Value *, Value> instMap; 288 /// Instructions that had not been defined when first encountered as a use. 289 /// Maps to the dummy Operation that was created in processValue(). 290 DenseMap<llvm::Value *, Operation *> unknownInstMap; 291 /// Uniquing map of GlobalVariables. 292 DenseMap<llvm::GlobalVariable *, GlobalOp> globals; 293 /// Cached FileLineColLoc::get("imported-bitcode", 0, 0). 294 Location unknownLoc; 295 /// The stateful type translator (contains named structs). 296 LLVM::TypeFromLLVMIRTranslator typeTranslator; 297 }; 298 } // namespace 299 300 Location Importer::processDebugLoc(const llvm::DebugLoc &loc, 301 llvm::Instruction *inst) { 302 if (!loc && inst) { 303 std::string s; 304 llvm::raw_string_ostream os(s); 305 os << "llvm-imported-inst-%"; 306 inst->printAsOperand(os, /*PrintType=*/false); 307 return FileLineColLoc::get(context, os.str(), 0, 0); 308 } else if (!loc) { 309 return unknownLoc; 310 } 311 // FIXME: Obtain the filename from DILocationInfo. 312 return FileLineColLoc::get(context, "imported-bitcode", loc.getLine(), 313 loc.getCol()); 314 } 315 316 Type Importer::processType(llvm::Type *type) { 317 if (Type result = typeTranslator.translateType(type)) 318 return result; 319 320 // FIXME: Diagnostic should be able to natively handle types that have 321 // operator<<(raw_ostream&) defined. 322 std::string s; 323 llvm::raw_string_ostream os(s); 324 os << *type; 325 emitError(unknownLoc) << "unhandled type: " << os.str(); 326 return nullptr; 327 } 328 329 // We only need integers, floats, doubles, and vectors and tensors thereof for 330 // attributes. Scalar and vector types are converted to the standard 331 // equivalents. Array types are converted to ranked tensors; nested array types 332 // are converted to multi-dimensional tensors or vectors, depending on the 333 // innermost type being a scalar or a vector. 334 Type Importer::getStdTypeForAttr(Type type) { 335 if (!type) 336 return nullptr; 337 338 if (type.isa<IntegerType, FloatType>()) 339 return type; 340 341 // LLVM vectors can only contain scalars. 342 if (LLVM::isCompatibleVectorType(type)) { 343 auto numElements = LLVM::getVectorNumElements(type); 344 if (numElements.isScalable()) { 345 emitError(unknownLoc) << "scalable vectors not supported"; 346 return nullptr; 347 } 348 Type elementType = getStdTypeForAttr(LLVM::getVectorElementType(type)); 349 if (!elementType) 350 return nullptr; 351 return VectorType::get(numElements.getKnownMinValue(), elementType); 352 } 353 354 // LLVM arrays can contain other arrays or vectors. 355 if (auto arrayType = type.dyn_cast<LLVMArrayType>()) { 356 // Recover the nested array shape. 357 SmallVector<int64_t, 4> shape; 358 shape.push_back(arrayType.getNumElements()); 359 while (arrayType.getElementType().isa<LLVMArrayType>()) { 360 arrayType = arrayType.getElementType().cast<LLVMArrayType>(); 361 shape.push_back(arrayType.getNumElements()); 362 } 363 364 // If the innermost type is a vector, use the multi-dimensional vector as 365 // attribute type. 366 if (LLVM::isCompatibleVectorType(arrayType.getElementType())) { 367 auto numElements = LLVM::getVectorNumElements(arrayType.getElementType()); 368 if (numElements.isScalable()) { 369 emitError(unknownLoc) << "scalable vectors not supported"; 370 return nullptr; 371 } 372 shape.push_back(numElements.getKnownMinValue()); 373 374 Type elementType = getStdTypeForAttr( 375 LLVM::getVectorElementType(arrayType.getElementType())); 376 if (!elementType) 377 return nullptr; 378 return VectorType::get(shape, elementType); 379 } 380 381 // Otherwise use a tensor. 382 Type elementType = getStdTypeForAttr(arrayType.getElementType()); 383 if (!elementType) 384 return nullptr; 385 return RankedTensorType::get(shape, elementType); 386 } 387 388 return nullptr; 389 } 390 391 // Get the given constant as an attribute. Not all constants can be represented 392 // as attributes. 393 Attribute Importer::getConstantAsAttr(llvm::Constant *value) { 394 if (auto *ci = dyn_cast<llvm::ConstantInt>(value)) 395 return b.getIntegerAttr( 396 IntegerType::get(context, ci->getType()->getBitWidth()), 397 ci->getValue()); 398 if (auto *c = dyn_cast<llvm::ConstantDataArray>(value)) 399 if (c->isString()) 400 return b.getStringAttr(c->getAsString()); 401 if (auto *c = dyn_cast<llvm::ConstantFP>(value)) { 402 if (c->getType()->isDoubleTy()) 403 return b.getFloatAttr(FloatType::getF64(context), c->getValueAPF()); 404 if (c->getType()->isFloatingPointTy()) 405 return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF()); 406 } 407 if (auto *f = dyn_cast<llvm::Function>(value)) 408 return b.getSymbolRefAttr(f->getName()); 409 410 // Convert constant data to a dense elements attribute. 411 if (auto *cd = dyn_cast<llvm::ConstantDataSequential>(value)) { 412 Type type = processType(cd->getElementType()); 413 if (!type) 414 return nullptr; 415 416 auto attrType = getStdTypeForAttr(processType(cd->getType())) 417 .dyn_cast_or_null<ShapedType>(); 418 if (!attrType) 419 return nullptr; 420 421 if (type.isa<IntegerType>()) { 422 SmallVector<APInt, 8> values; 423 values.reserve(cd->getNumElements()); 424 for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i) 425 values.push_back(cd->getElementAsAPInt(i)); 426 return DenseElementsAttr::get(attrType, values); 427 } 428 429 if (type.isa<Float32Type, Float64Type>()) { 430 SmallVector<APFloat, 8> values; 431 values.reserve(cd->getNumElements()); 432 for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i) 433 values.push_back(cd->getElementAsAPFloat(i)); 434 return DenseElementsAttr::get(attrType, values); 435 } 436 437 return nullptr; 438 } 439 440 // Unpack constant aggregates to create dense elements attribute whenever 441 // possible. Return nullptr (failure) otherwise. 442 if (isa<llvm::ConstantAggregate>(value)) { 443 auto outerType = getStdTypeForAttr(processType(value->getType())) 444 .dyn_cast_or_null<ShapedType>(); 445 if (!outerType) 446 return nullptr; 447 448 SmallVector<Attribute, 8> values; 449 SmallVector<int64_t, 8> shape; 450 451 for (unsigned i = 0, e = value->getNumOperands(); i < e; ++i) { 452 auto nested = getConstantAsAttr(value->getAggregateElement(i)) 453 .dyn_cast_or_null<DenseElementsAttr>(); 454 if (!nested) 455 return nullptr; 456 457 values.append(nested.attr_value_begin(), nested.attr_value_end()); 458 } 459 460 return DenseElementsAttr::get(outerType, values); 461 } 462 463 return nullptr; 464 } 465 466 GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) { 467 auto it = globals.find(GV); 468 if (it != globals.end()) 469 return it->second; 470 471 OpBuilder b(module.getBody(), getGlobalInsertPt()); 472 Attribute valueAttr; 473 if (GV->hasInitializer()) 474 valueAttr = getConstantAsAttr(GV->getInitializer()); 475 Type type = processType(GV->getValueType()); 476 if (!type) 477 return nullptr; 478 GlobalOp op = b.create<GlobalOp>( 479 UnknownLoc::get(context), type, GV->isConstant(), 480 convertLinkageFromLLVM(GV->getLinkage()), GV->getName(), valueAttr); 481 if (GV->hasInitializer() && !valueAttr) { 482 Region &r = op.getInitializerRegion(); 483 currentEntryBlock = b.createBlock(&r); 484 b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin()); 485 Value v = processConstant(GV->getInitializer()); 486 if (!v) 487 return nullptr; 488 b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v})); 489 } 490 if (GV->hasAtLeastLocalUnnamedAddr()) 491 op.unnamed_addrAttr(UnnamedAddrAttr::get( 492 context, convertUnnamedAddrFromLLVM(GV->getUnnamedAddr()))); 493 if (GV->hasSection()) 494 op.sectionAttr(b.getStringAttr(GV->getSection())); 495 return globals[GV] = op; 496 } 497 498 Value Importer::processConstant(llvm::Constant *c) { 499 OpBuilder bEntry(currentEntryBlock, currentEntryBlock->begin()); 500 if (Attribute attr = getConstantAsAttr(c)) { 501 // These constants can be represented as attributes. 502 OpBuilder b(currentEntryBlock, currentEntryBlock->begin()); 503 Type type = processType(c->getType()); 504 if (!type) 505 return nullptr; 506 if (auto symbolRef = attr.dyn_cast<FlatSymbolRefAttr>()) 507 return instMap[c] = bEntry.create<AddressOfOp>(unknownLoc, type, 508 symbolRef.getValue()); 509 return instMap[c] = bEntry.create<ConstantOp>(unknownLoc, type, attr); 510 } 511 if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) { 512 Type type = processType(cn->getType()); 513 if (!type) 514 return nullptr; 515 return instMap[c] = bEntry.create<NullOp>(unknownLoc, type); 516 } 517 if (auto *GV = dyn_cast<llvm::GlobalVariable>(c)) 518 return bEntry.create<AddressOfOp>(UnknownLoc::get(context), 519 processGlobal(GV)); 520 521 if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) { 522 llvm::Instruction *i = ce->getAsInstruction(); 523 OpBuilder::InsertionGuard guard(b); 524 b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin()); 525 if (failed(processInstruction(i))) 526 return nullptr; 527 assert(instMap.count(i)); 528 529 // Remove this zombie LLVM instruction now, leaving us only with the MLIR 530 // op. 531 i->deleteValue(); 532 return instMap[c] = instMap[i]; 533 } 534 if (auto *ue = dyn_cast<llvm::UndefValue>(c)) { 535 Type type = processType(ue->getType()); 536 if (!type) 537 return nullptr; 538 return instMap[c] = bEntry.create<UndefOp>(UnknownLoc::get(context), type); 539 } 540 emitError(unknownLoc) << "unhandled constant: " << diag(*c); 541 return nullptr; 542 } 543 544 Value Importer::processValue(llvm::Value *value) { 545 auto it = instMap.find(value); 546 if (it != instMap.end()) 547 return it->second; 548 549 // We don't expect to see instructions in dominator order. If we haven't seen 550 // this instruction yet, create an unknown op and remap it later. 551 if (isa<llvm::Instruction>(value)) { 552 OperationState state(UnknownLoc::get(context), "llvm.unknown"); 553 Type type = processType(value->getType()); 554 if (!type) 555 return nullptr; 556 state.addTypes(type); 557 unknownInstMap[value] = b.createOperation(state); 558 return unknownInstMap[value]->getResult(0); 559 } 560 561 if (auto *c = dyn_cast<llvm::Constant>(value)) 562 return processConstant(c); 563 564 emitError(unknownLoc) << "unhandled value: " << diag(*value); 565 return nullptr; 566 } 567 568 /// Return the MLIR OperationName for the given LLVM opcode. 569 static StringRef lookupOperationNameFromOpcode(unsigned opcode) { 570 // Maps from LLVM opcode to MLIR OperationName. This is deliberately ordered 571 // as in llvm/IR/Instructions.def to aid comprehension and spot missing 572 // instructions. 573 #define INST(llvm_n, mlir_n) \ 574 { llvm::Instruction::llvm_n, LLVM::mlir_n##Op::getOperationName() } 575 static const DenseMap<unsigned, StringRef> opcMap = { 576 // Ret is handled specially. 577 // Br is handled specially. 578 // FIXME: switch 579 // FIXME: indirectbr 580 // FIXME: invoke 581 INST(Resume, Resume), 582 // FIXME: unreachable 583 // FIXME: cleanupret 584 // FIXME: catchret 585 // FIXME: catchswitch 586 // FIXME: callbr 587 // FIXME: fneg 588 INST(Add, Add), INST(FAdd, FAdd), INST(Sub, Sub), INST(FSub, FSub), 589 INST(Mul, Mul), INST(FMul, FMul), INST(UDiv, UDiv), INST(SDiv, SDiv), 590 INST(FDiv, FDiv), INST(URem, URem), INST(SRem, SRem), INST(FRem, FRem), 591 INST(Shl, Shl), INST(LShr, LShr), INST(AShr, AShr), INST(And, And), 592 INST(Or, Or), INST(Xor, XOr), INST(Alloca, Alloca), INST(Load, Load), 593 INST(Store, Store), 594 // Getelementptr is handled specially. 595 INST(Ret, Return), INST(Fence, Fence), 596 // FIXME: atomiccmpxchg 597 // FIXME: atomicrmw 598 INST(Trunc, Trunc), INST(ZExt, ZExt), INST(SExt, SExt), 599 INST(FPToUI, FPToUI), INST(FPToSI, FPToSI), INST(UIToFP, UIToFP), 600 INST(SIToFP, SIToFP), INST(FPTrunc, FPTrunc), INST(FPExt, FPExt), 601 INST(PtrToInt, PtrToInt), INST(IntToPtr, IntToPtr), 602 INST(BitCast, Bitcast), INST(AddrSpaceCast, AddrSpaceCast), 603 // FIXME: cleanuppad 604 // FIXME: catchpad 605 // ICmp is handled specially. 606 // FIXME: fcmp 607 // PHI is handled specially. 608 INST(Freeze, Freeze), INST(Call, Call), 609 // FIXME: select 610 // FIXME: vaarg 611 // FIXME: extractelement 612 // FIXME: insertelement 613 // FIXME: shufflevector 614 // FIXME: extractvalue 615 // FIXME: insertvalue 616 // FIXME: landingpad 617 }; 618 #undef INST 619 620 return opcMap.lookup(opcode); 621 } 622 623 static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) { 624 switch (p) { 625 default: 626 llvm_unreachable("incorrect comparison predicate"); 627 case llvm::CmpInst::Predicate::ICMP_EQ: 628 return LLVM::ICmpPredicate::eq; 629 case llvm::CmpInst::Predicate::ICMP_NE: 630 return LLVM::ICmpPredicate::ne; 631 case llvm::CmpInst::Predicate::ICMP_SLT: 632 return LLVM::ICmpPredicate::slt; 633 case llvm::CmpInst::Predicate::ICMP_SLE: 634 return LLVM::ICmpPredicate::sle; 635 case llvm::CmpInst::Predicate::ICMP_SGT: 636 return LLVM::ICmpPredicate::sgt; 637 case llvm::CmpInst::Predicate::ICMP_SGE: 638 return LLVM::ICmpPredicate::sge; 639 case llvm::CmpInst::Predicate::ICMP_ULT: 640 return LLVM::ICmpPredicate::ult; 641 case llvm::CmpInst::Predicate::ICMP_ULE: 642 return LLVM::ICmpPredicate::ule; 643 case llvm::CmpInst::Predicate::ICMP_UGT: 644 return LLVM::ICmpPredicate::ugt; 645 case llvm::CmpInst::Predicate::ICMP_UGE: 646 return LLVM::ICmpPredicate::uge; 647 } 648 llvm_unreachable("incorrect comparison predicate"); 649 } 650 651 static AtomicOrdering getLLVMAtomicOrdering(llvm::AtomicOrdering ordering) { 652 switch (ordering) { 653 case llvm::AtomicOrdering::NotAtomic: 654 return LLVM::AtomicOrdering::not_atomic; 655 case llvm::AtomicOrdering::Unordered: 656 return LLVM::AtomicOrdering::unordered; 657 case llvm::AtomicOrdering::Monotonic: 658 return LLVM::AtomicOrdering::monotonic; 659 case llvm::AtomicOrdering::Acquire: 660 return LLVM::AtomicOrdering::acquire; 661 case llvm::AtomicOrdering::Release: 662 return LLVM::AtomicOrdering::release; 663 case llvm::AtomicOrdering::AcquireRelease: 664 return LLVM::AtomicOrdering::acq_rel; 665 case llvm::AtomicOrdering::SequentiallyConsistent: 666 return LLVM::AtomicOrdering::seq_cst; 667 } 668 llvm_unreachable("incorrect atomic ordering"); 669 } 670 671 // `br` branches to `target`. Return the branch arguments to `br`, in the 672 // same order of the PHIs in `target`. 673 LogicalResult 674 Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target, 675 SmallVectorImpl<Value> &blockArguments) { 676 for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) { 677 auto *PN = cast<llvm::PHINode>(&*inst); 678 Value value = processValue(PN->getIncomingValueForBlock(br->getParent())); 679 if (!value) 680 return failure(); 681 blockArguments.push_back(value); 682 } 683 return success(); 684 } 685 686 LogicalResult Importer::processInstruction(llvm::Instruction *inst) { 687 // FIXME: Support uses of SubtargetData. Currently inbounds GEPs, fast-math 688 // flags and call / operand attributes are not supported. 689 Location loc = processDebugLoc(inst->getDebugLoc(), inst); 690 Value &v = instMap[inst]; 691 assert(!v && "processInstruction must be called only once per instruction!"); 692 switch (inst->getOpcode()) { 693 default: 694 return emitError(loc) << "unknown instruction: " << diag(*inst); 695 case llvm::Instruction::Add: 696 case llvm::Instruction::FAdd: 697 case llvm::Instruction::Sub: 698 case llvm::Instruction::FSub: 699 case llvm::Instruction::Mul: 700 case llvm::Instruction::FMul: 701 case llvm::Instruction::UDiv: 702 case llvm::Instruction::SDiv: 703 case llvm::Instruction::FDiv: 704 case llvm::Instruction::URem: 705 case llvm::Instruction::SRem: 706 case llvm::Instruction::FRem: 707 case llvm::Instruction::Shl: 708 case llvm::Instruction::LShr: 709 case llvm::Instruction::AShr: 710 case llvm::Instruction::And: 711 case llvm::Instruction::Or: 712 case llvm::Instruction::Xor: 713 case llvm::Instruction::Alloca: 714 case llvm::Instruction::Load: 715 case llvm::Instruction::Store: 716 case llvm::Instruction::Ret: 717 case llvm::Instruction::Resume: 718 case llvm::Instruction::Trunc: 719 case llvm::Instruction::ZExt: 720 case llvm::Instruction::SExt: 721 case llvm::Instruction::FPToUI: 722 case llvm::Instruction::FPToSI: 723 case llvm::Instruction::UIToFP: 724 case llvm::Instruction::SIToFP: 725 case llvm::Instruction::FPTrunc: 726 case llvm::Instruction::FPExt: 727 case llvm::Instruction::PtrToInt: 728 case llvm::Instruction::IntToPtr: 729 case llvm::Instruction::AddrSpaceCast: 730 case llvm::Instruction::Freeze: 731 case llvm::Instruction::BitCast: { 732 OperationState state(loc, lookupOperationNameFromOpcode(inst->getOpcode())); 733 SmallVector<Value, 4> ops; 734 ops.reserve(inst->getNumOperands()); 735 for (auto *op : inst->operand_values()) { 736 Value value = processValue(op); 737 if (!value) 738 return failure(); 739 ops.push_back(value); 740 } 741 state.addOperands(ops); 742 if (!inst->getType()->isVoidTy()) { 743 Type type = processType(inst->getType()); 744 if (!type) 745 return failure(); 746 state.addTypes(type); 747 } 748 Operation *op = b.createOperation(state); 749 if (!inst->getType()->isVoidTy()) 750 v = op->getResult(0); 751 return success(); 752 } 753 case llvm::Instruction::ICmp: { 754 Value lhs = processValue(inst->getOperand(0)); 755 Value rhs = processValue(inst->getOperand(1)); 756 if (!lhs || !rhs) 757 return failure(); 758 v = b.create<ICmpOp>( 759 loc, getICmpPredicate(cast<llvm::ICmpInst>(inst)->getPredicate()), lhs, 760 rhs); 761 return success(); 762 } 763 case llvm::Instruction::Br: { 764 auto *brInst = cast<llvm::BranchInst>(inst); 765 OperationState state(loc, 766 brInst->isConditional() ? "llvm.cond_br" : "llvm.br"); 767 if (brInst->isConditional()) { 768 Value condition = processValue(brInst->getCondition()); 769 if (!condition) 770 return failure(); 771 state.addOperands(condition); 772 } 773 774 std::array<int32_t, 3> operandSegmentSizes = {1, 0, 0}; 775 for (int i : llvm::seq<int>(0, brInst->getNumSuccessors())) { 776 auto *succ = brInst->getSuccessor(i); 777 SmallVector<Value, 4> blockArguments; 778 if (failed(processBranchArgs(brInst, succ, blockArguments))) 779 return failure(); 780 state.addSuccessors(blocks[succ]); 781 state.addOperands(blockArguments); 782 operandSegmentSizes[i + 1] = blockArguments.size(); 783 } 784 785 if (brInst->isConditional()) { 786 state.addAttribute(LLVM::CondBrOp::getOperandSegmentSizeAttr(), 787 b.getI32VectorAttr(operandSegmentSizes)); 788 } 789 790 b.createOperation(state); 791 return success(); 792 } 793 case llvm::Instruction::PHI: { 794 Type type = processType(inst->getType()); 795 if (!type) 796 return failure(); 797 v = b.getInsertionBlock()->addArgument(type); 798 return success(); 799 } 800 case llvm::Instruction::Call: { 801 llvm::CallInst *ci = cast<llvm::CallInst>(inst); 802 SmallVector<Value, 4> ops; 803 ops.reserve(inst->getNumOperands()); 804 for (auto &op : ci->arg_operands()) { 805 Value arg = processValue(op.get()); 806 if (!arg) 807 return failure(); 808 ops.push_back(arg); 809 } 810 811 SmallVector<Type, 2> tys; 812 if (!ci->getType()->isVoidTy()) { 813 Type type = processType(inst->getType()); 814 if (!type) 815 return failure(); 816 tys.push_back(type); 817 } 818 Operation *op; 819 if (llvm::Function *callee = ci->getCalledFunction()) { 820 op = b.create<CallOp>(loc, tys, b.getSymbolRefAttr(callee->getName()), 821 ops); 822 } else { 823 Value calledValue = processValue(ci->getCalledOperand()); 824 if (!calledValue) 825 return failure(); 826 ops.insert(ops.begin(), calledValue); 827 op = b.create<CallOp>(loc, tys, ops); 828 } 829 if (!ci->getType()->isVoidTy()) 830 v = op->getResult(0); 831 return success(); 832 } 833 case llvm::Instruction::LandingPad: { 834 llvm::LandingPadInst *lpi = cast<llvm::LandingPadInst>(inst); 835 SmallVector<Value, 4> ops; 836 837 for (unsigned i = 0, ie = lpi->getNumClauses(); i < ie; i++) 838 ops.push_back(processConstant(lpi->getClause(i))); 839 840 Type ty = processType(lpi->getType()); 841 if (!ty) 842 return failure(); 843 844 v = b.create<LandingpadOp>(loc, ty, lpi->isCleanup(), ops); 845 return success(); 846 } 847 case llvm::Instruction::Invoke: { 848 llvm::InvokeInst *ii = cast<llvm::InvokeInst>(inst); 849 850 SmallVector<Type, 2> tys; 851 if (!ii->getType()->isVoidTy()) 852 tys.push_back(processType(inst->getType())); 853 854 SmallVector<Value, 4> ops; 855 ops.reserve(inst->getNumOperands() + 1); 856 for (auto &op : ii->arg_operands()) 857 ops.push_back(processValue(op.get())); 858 859 SmallVector<Value, 4> normalArgs, unwindArgs; 860 (void)processBranchArgs(ii, ii->getNormalDest(), normalArgs); 861 (void)processBranchArgs(ii, ii->getUnwindDest(), unwindArgs); 862 863 Operation *op; 864 if (llvm::Function *callee = ii->getCalledFunction()) { 865 op = b.create<InvokeOp>(loc, tys, b.getSymbolRefAttr(callee->getName()), 866 ops, blocks[ii->getNormalDest()], normalArgs, 867 blocks[ii->getUnwindDest()], unwindArgs); 868 } else { 869 ops.insert(ops.begin(), processValue(ii->getCalledOperand())); 870 op = b.create<InvokeOp>(loc, tys, ops, blocks[ii->getNormalDest()], 871 normalArgs, blocks[ii->getUnwindDest()], 872 unwindArgs); 873 } 874 875 if (!ii->getType()->isVoidTy()) 876 v = op->getResult(0); 877 return success(); 878 } 879 case llvm::Instruction::Fence: { 880 StringRef syncscope; 881 SmallVector<StringRef, 4> ssNs; 882 llvm::LLVMContext &llvmContext = inst->getContext(); 883 llvm::FenceInst *fence = cast<llvm::FenceInst>(inst); 884 llvmContext.getSyncScopeNames(ssNs); 885 int fenceSyncScopeID = fence->getSyncScopeID(); 886 for (unsigned i = 0, e = ssNs.size(); i != e; i++) { 887 if (fenceSyncScopeID == llvmContext.getOrInsertSyncScopeID(ssNs[i])) { 888 syncscope = ssNs[i]; 889 break; 890 } 891 } 892 b.create<FenceOp>(loc, getLLVMAtomicOrdering(fence->getOrdering()), 893 syncscope); 894 return success(); 895 } 896 case llvm::Instruction::GetElementPtr: { 897 // FIXME: Support inbounds GEPs. 898 llvm::GetElementPtrInst *gep = cast<llvm::GetElementPtrInst>(inst); 899 SmallVector<Value, 4> ops; 900 for (auto *op : gep->operand_values()) { 901 Value value = processValue(op); 902 if (!value) 903 return failure(); 904 ops.push_back(value); 905 } 906 Type type = processType(inst->getType()); 907 if (!type) 908 return failure(); 909 v = b.create<GEPOp>(loc, type, ops); 910 return success(); 911 } 912 } 913 } 914 915 FlatSymbolRefAttr Importer::getPersonalityAsAttr(llvm::Function *f) { 916 if (!f->hasPersonalityFn()) 917 return nullptr; 918 919 llvm::Constant *pf = f->getPersonalityFn(); 920 921 // If it directly has a name, we can use it. 922 if (pf->hasName()) 923 return b.getSymbolRefAttr(pf->getName()); 924 925 // If it doesn't have a name, currently, only function pointers that are 926 // bitcast to i8* are parsed. 927 if (auto ce = dyn_cast<llvm::ConstantExpr>(pf)) { 928 if (ce->getOpcode() == llvm::Instruction::BitCast && 929 ce->getType() == llvm::Type::getInt8PtrTy(f->getContext())) { 930 if (auto func = dyn_cast<llvm::Function>(ce->getOperand(0))) 931 return b.getSymbolRefAttr(func->getName()); 932 } 933 } 934 return FlatSymbolRefAttr(); 935 } 936 937 LogicalResult Importer::processFunction(llvm::Function *f) { 938 blocks.clear(); 939 instMap.clear(); 940 unknownInstMap.clear(); 941 942 auto functionType = 943 processType(f->getFunctionType()).dyn_cast<LLVMFunctionType>(); 944 if (!functionType) 945 return failure(); 946 947 b.setInsertionPoint(module.getBody(), getFuncInsertPt()); 948 LLVMFuncOp fop = 949 b.create<LLVMFuncOp>(UnknownLoc::get(context), f->getName(), functionType, 950 convertLinkageFromLLVM(f->getLinkage())); 951 952 if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f)) 953 fop->setAttr(b.getIdentifier("personality"), personality); 954 else if (f->hasPersonalityFn()) 955 emitWarning(UnknownLoc::get(context), 956 "could not deduce personality, skipping it"); 957 958 if (f->isDeclaration()) 959 return success(); 960 961 // Eagerly create all blocks. 962 SmallVector<Block *, 4> blockList; 963 for (llvm::BasicBlock &bb : *f) { 964 blockList.push_back(b.createBlock(&fop.body(), fop.body().end())); 965 blocks[&bb] = blockList.back(); 966 } 967 currentEntryBlock = blockList[0]; 968 969 // Add function arguments to the entry block. 970 for (auto kv : llvm::enumerate(f->args())) 971 instMap[&kv.value()] = 972 blockList[0]->addArgument(functionType.getParamType(kv.index())); 973 974 for (auto bbs : llvm::zip(*f, blockList)) { 975 if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs)))) 976 return failure(); 977 } 978 979 // Now that all instructions are guaranteed to have been visited, ensure 980 // any unknown uses we encountered are remapped. 981 for (auto &llvmAndUnknown : unknownInstMap) { 982 assert(instMap.count(llvmAndUnknown.first)); 983 Value newValue = instMap[llvmAndUnknown.first]; 984 Value oldValue = llvmAndUnknown.second->getResult(0); 985 oldValue.replaceAllUsesWith(newValue); 986 llvmAndUnknown.second->erase(); 987 } 988 return success(); 989 } 990 991 LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) { 992 b.setInsertionPointToStart(block); 993 for (llvm::Instruction &inst : *bb) { 994 if (failed(processInstruction(&inst))) 995 return failure(); 996 } 997 return success(); 998 } 999 1000 OwningModuleRef 1001 mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule, 1002 MLIRContext *context) { 1003 context->loadDialect<LLVMDialect>(); 1004 OwningModuleRef module(ModuleOp::create( 1005 FileLineColLoc::get(context, "", /*line=*/0, /*column=*/0))); 1006 1007 Importer deserializer(context, module.get()); 1008 for (llvm::GlobalVariable &gv : llvmModule->globals()) { 1009 if (!deserializer.processGlobal(&gv)) 1010 return {}; 1011 } 1012 for (llvm::Function &f : llvmModule->functions()) { 1013 if (failed(deserializer.processFunction(&f))) 1014 return {}; 1015 } 1016 1017 return module; 1018 } 1019 1020 // Deserializes the LLVM bitcode stored in `input` into an MLIR module in the 1021 // LLVM dialect. 1022 OwningModuleRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr, 1023 MLIRContext *context) { 1024 llvm::SMDiagnostic err; 1025 llvm::LLVMContext llvmContext; 1026 std::unique_ptr<llvm::Module> llvmModule = llvm::parseIR( 1027 *sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), err, llvmContext); 1028 if (!llvmModule) { 1029 std::string errStr; 1030 llvm::raw_string_ostream errStream(errStr); 1031 err.print(/*ProgName=*/"", errStream); 1032 emitError(UnknownLoc::get(context)) << errStream.str(); 1033 return {}; 1034 } 1035 return translateLLVMIRToModule(std::move(llvmModule), context); 1036 } 1037 1038 namespace mlir { 1039 void registerFromLLVMIRTranslation() { 1040 TranslateToMLIRRegistration fromLLVM( 1041 "import-llvm", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) { 1042 return ::translateLLVMIRToModule(sourceMgr, context); 1043 }); 1044 } 1045 } // namespace mlir 1046