1 //===- ConvertFromLLVMIR.cpp - MLIR to LLVM IR conversion -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a translation between LLVM IR and the MLIR LLVM dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/IR/Builders.h" 15 #include "mlir/IR/MLIRContext.h" 16 #include "mlir/IR/Module.h" 17 #include "mlir/IR/StandardTypes.h" 18 #include "mlir/Target/LLVMIR.h" 19 #include "mlir/Translation.h" 20 21 #include "llvm/IR/Attributes.h" 22 #include "llvm/IR/Constants.h" 23 #include "llvm/IR/Function.h" 24 #include "llvm/IR/Instructions.h" 25 #include "llvm/IR/Type.h" 26 #include "llvm/IRReader/IRReader.h" 27 #include "llvm/Support/Error.h" 28 #include "llvm/Support/SourceMgr.h" 29 30 using namespace mlir; 31 using namespace mlir::LLVM; 32 33 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc" 34 35 // Utility to print an LLVM value as a string for passing to emitError(). 36 // FIXME: Diagnostic should be able to natively handle types that have 37 // operator << (raw_ostream&) defined. 38 static std::string diag(llvm::Value &v) { 39 std::string s; 40 llvm::raw_string_ostream os(s); 41 os << v; 42 return os.str(); 43 } 44 45 // Handles importing globals and functions from an LLVM module. 46 namespace { 47 class Importer { 48 public: 49 Importer(MLIRContext *context, ModuleOp module) 50 : b(context), context(context), module(module), 51 unknownLoc(FileLineColLoc::get("imported-bitcode", 0, 0, context)) { 52 b.setInsertionPointToStart(module.getBody()); 53 dialect = context->getRegisteredDialect<LLVMDialect>(); 54 } 55 56 /// Imports `f` into the current module. 57 LogicalResult processFunction(llvm::Function *f); 58 59 /// Imports GV as a GlobalOp, creating it if it doesn't exist. 60 GlobalOp processGlobal(llvm::GlobalVariable *GV); 61 62 private: 63 /// Returns personality of `f` as a FlatSymbolRefAttr. 64 FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *f); 65 /// Imports `bb` into `block`, which must be initially empty. 66 LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block); 67 /// Imports `inst` and populates instMap[inst] with the imported Value. 68 LogicalResult processInstruction(llvm::Instruction *inst); 69 /// Creates an LLVMType for `type`. 70 LLVMType processType(llvm::Type *type); 71 /// `value` is an SSA-use. Return the remapped version of `value` or a 72 /// placeholder that will be remapped later if this is an instruction that 73 /// has not yet been visited. 74 Value processValue(llvm::Value *value); 75 /// Create the most accurate Location possible using a llvm::DebugLoc and 76 /// possibly an llvm::Instruction to narrow the Location if debug information 77 /// is unavailable. 78 Location processDebugLoc(const llvm::DebugLoc &loc, 79 llvm::Instruction *inst = nullptr); 80 /// `br` branches to `target`. Append the block arguments to attach to the 81 /// generated branch op to `blockArguments`. These should be in the same order 82 /// as the PHIs in `target`. 83 LogicalResult processBranchArgs(llvm::Instruction *br, 84 llvm::BasicBlock *target, 85 SmallVectorImpl<Value> &blockArguments); 86 /// Returns the standard type equivalent to be used in attributes for the 87 /// given LLVM IR dialect type. 88 Type getStdTypeForAttr(LLVMType type); 89 /// Return `value` as an attribute to attach to a GlobalOp. 90 Attribute getConstantAsAttr(llvm::Constant *value); 91 /// Return `c` as an MLIR Value. This could either be a ConstantOp, or 92 /// an expanded sequence of ops in the current function's entry block (for 93 /// ConstantExprs or ConstantGEPs). 94 Value processConstant(llvm::Constant *c); 95 96 /// The current builder, pointing at where the next Instruction should be 97 /// generated. 98 OpBuilder b; 99 /// The current context. 100 MLIRContext *context; 101 /// The current module being created. 102 ModuleOp module; 103 /// The entry block of the current function being processed. 104 Block *currentEntryBlock; 105 106 /// Globals are inserted before the first function, if any. 107 Block::iterator getGlobalInsertPt() { 108 auto i = module.getBody()->begin(); 109 while (!isa<LLVMFuncOp, ModuleTerminatorOp>(i)) 110 ++i; 111 return i; 112 } 113 114 /// Functions are always inserted before the module terminator. 115 Block::iterator getFuncInsertPt() { 116 return std::prev(module.getBody()->end()); 117 } 118 119 /// Remapped blocks, for the current function. 120 DenseMap<llvm::BasicBlock *, Block *> blocks; 121 /// Remapped values. These are function-local. 122 DenseMap<llvm::Value *, Value> instMap; 123 /// Instructions that had not been defined when first encountered as a use. 124 /// Maps to the dummy Operation that was created in processValue(). 125 DenseMap<llvm::Value *, Operation *> unknownInstMap; 126 /// Uniquing map of GlobalVariables. 127 DenseMap<llvm::GlobalVariable *, GlobalOp> globals; 128 /// Cached FileLineColLoc::get("imported-bitcode", 0, 0). 129 Location unknownLoc; 130 /// Cached dialect. 131 LLVMDialect *dialect; 132 }; 133 } // namespace 134 135 Location Importer::processDebugLoc(const llvm::DebugLoc &loc, 136 llvm::Instruction *inst) { 137 if (!loc && inst) { 138 std::string s; 139 llvm::raw_string_ostream os(s); 140 os << "llvm-imported-inst-%"; 141 inst->printAsOperand(os, /*PrintType=*/false); 142 return FileLineColLoc::get(os.str(), 0, 0, context); 143 } else if (!loc) { 144 return unknownLoc; 145 } 146 // FIXME: Obtain the filename from DILocationInfo. 147 return FileLineColLoc::get("imported-bitcode", loc.getLine(), loc.getCol(), 148 context); 149 } 150 151 LLVMType Importer::processType(llvm::Type *type) { 152 switch (type->getTypeID()) { 153 case llvm::Type::FloatTyID: 154 return LLVMType::getFloatTy(dialect); 155 case llvm::Type::DoubleTyID: 156 return LLVMType::getDoubleTy(dialect); 157 case llvm::Type::IntegerTyID: 158 return LLVMType::getIntNTy(dialect, type->getIntegerBitWidth()); 159 case llvm::Type::PointerTyID: { 160 LLVMType elementType = processType(type->getPointerElementType()); 161 if (!elementType) 162 return nullptr; 163 return elementType.getPointerTo(type->getPointerAddressSpace()); 164 } 165 case llvm::Type::ArrayTyID: { 166 LLVMType elementType = processType(type->getArrayElementType()); 167 if (!elementType) 168 return nullptr; 169 return LLVMType::getArrayTy(elementType, type->getArrayNumElements()); 170 } 171 case llvm::Type::ScalableVectorTyID: { 172 emitError(unknownLoc) << "scalable vector types not supported"; 173 return nullptr; 174 } 175 case llvm::Type::FixedVectorTyID: { 176 auto *typeVTy = llvm::cast<llvm::FixedVectorType>(type); 177 LLVMType elementType = processType(typeVTy->getElementType()); 178 if (!elementType) 179 return nullptr; 180 return LLVMType::getVectorTy(elementType, typeVTy->getNumElements()); 181 } 182 case llvm::Type::VoidTyID: 183 return LLVMType::getVoidTy(dialect); 184 case llvm::Type::FP128TyID: 185 return LLVMType::getFP128Ty(dialect); 186 case llvm::Type::X86_FP80TyID: 187 return LLVMType::getX86_FP80Ty(dialect); 188 case llvm::Type::StructTyID: { 189 SmallVector<LLVMType, 4> elementTypes; 190 elementTypes.reserve(type->getStructNumElements()); 191 for (unsigned i = 0, e = type->getStructNumElements(); i != e; ++i) { 192 LLVMType ty = processType(type->getStructElementType(i)); 193 if (!ty) 194 return nullptr; 195 elementTypes.push_back(ty); 196 } 197 return LLVMType::getStructTy(dialect, elementTypes, 198 cast<llvm::StructType>(type)->isPacked()); 199 } 200 case llvm::Type::FunctionTyID: { 201 llvm::FunctionType *fty = cast<llvm::FunctionType>(type); 202 SmallVector<LLVMType, 4> paramTypes; 203 for (unsigned i = 0, e = fty->getNumParams(); i != e; ++i) { 204 LLVMType ty = processType(fty->getParamType(i)); 205 if (!ty) 206 return nullptr; 207 paramTypes.push_back(ty); 208 } 209 LLVMType result = processType(fty->getReturnType()); 210 if (!result) 211 return nullptr; 212 213 return LLVMType::getFunctionTy(result, paramTypes, fty->isVarArg()); 214 } 215 default: { 216 // FIXME: Diagnostic should be able to natively handle types that have 217 // operator<<(raw_ostream&) defined. 218 std::string s; 219 llvm::raw_string_ostream os(s); 220 os << *type; 221 emitError(unknownLoc) << "unhandled type: " << os.str(); 222 return nullptr; 223 } 224 } 225 } 226 227 // We only need integers, floats, doubles, and vectors and tensors thereof for 228 // attributes. Scalar and vector types are converted to the standard 229 // equivalents. Array types are converted to ranked tensors; nested array types 230 // are converted to multi-dimensional tensors or vectors, depending on the 231 // innermost type being a scalar or a vector. 232 Type Importer::getStdTypeForAttr(LLVMType type) { 233 if (!type) 234 return nullptr; 235 236 if (type.isIntegerTy()) 237 return b.getIntegerType(type.getUnderlyingType()->getIntegerBitWidth()); 238 239 if (type.getUnderlyingType()->isFloatTy()) 240 return b.getF32Type(); 241 242 if (type.getUnderlyingType()->isDoubleTy()) 243 return b.getF64Type(); 244 245 // LLVM vectors can only contain scalars. 246 if (type.isVectorTy()) { 247 auto numElements = llvm::cast<llvm::VectorType>(type.getUnderlyingType()) 248 ->getElementCount(); 249 if (numElements.Scalable) { 250 emitError(unknownLoc) << "scalable vectors not supported"; 251 return nullptr; 252 } 253 Type elementType = getStdTypeForAttr(type.getVectorElementType()); 254 if (!elementType) 255 return nullptr; 256 return VectorType::get(numElements.Min, elementType); 257 } 258 259 // LLVM arrays can contain other arrays or vectors. 260 if (type.isArrayTy()) { 261 // Recover the nested array shape. 262 SmallVector<int64_t, 4> shape; 263 shape.push_back(type.getArrayNumElements()); 264 while (type.getArrayElementType().isArrayTy()) { 265 type = type.getArrayElementType(); 266 shape.push_back(type.getArrayNumElements()); 267 } 268 269 // If the innermost type is a vector, use the multi-dimensional vector as 270 // attribute type. 271 if (type.getArrayElementType().isVectorTy()) { 272 LLVMType vectorType = type.getArrayElementType(); 273 auto numElements = 274 llvm::cast<llvm::VectorType>(vectorType.getUnderlyingType()) 275 ->getElementCount(); 276 if (numElements.Scalable) { 277 emitError(unknownLoc) << "scalable vectors not supported"; 278 return nullptr; 279 } 280 shape.push_back(numElements.Min); 281 282 Type elementType = getStdTypeForAttr(vectorType.getVectorElementType()); 283 if (!elementType) 284 return nullptr; 285 return VectorType::get(shape, elementType); 286 } 287 288 // Otherwise use a tensor. 289 Type elementType = getStdTypeForAttr(type.getArrayElementType()); 290 if (!elementType) 291 return nullptr; 292 return RankedTensorType::get(shape, elementType); 293 } 294 295 return nullptr; 296 } 297 298 // Get the given constant as an attribute. Not all constants can be represented 299 // as attributes. 300 Attribute Importer::getConstantAsAttr(llvm::Constant *value) { 301 if (auto *ci = dyn_cast<llvm::ConstantInt>(value)) 302 return b.getIntegerAttr( 303 IntegerType::get(ci->getType()->getBitWidth(), context), 304 ci->getValue()); 305 if (auto *c = dyn_cast<llvm::ConstantDataArray>(value)) 306 if (c->isString()) 307 return b.getStringAttr(c->getAsString()); 308 if (auto *c = dyn_cast<llvm::ConstantFP>(value)) { 309 if (c->getType()->isDoubleTy()) 310 return b.getFloatAttr(FloatType::getF64(context), c->getValueAPF()); 311 else if (c->getType()->isFloatingPointTy()) 312 return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF()); 313 } 314 if (auto *f = dyn_cast<llvm::Function>(value)) 315 return b.getSymbolRefAttr(f->getName()); 316 317 // Convert constant data to a dense elements attribute. 318 if (auto *cd = dyn_cast<llvm::ConstantDataSequential>(value)) { 319 LLVMType type = processType(cd->getElementType()); 320 if (!type) 321 return nullptr; 322 323 auto attrType = getStdTypeForAttr(processType(cd->getType())) 324 .dyn_cast_or_null<ShapedType>(); 325 if (!attrType) 326 return nullptr; 327 328 if (type.isIntegerTy()) { 329 SmallVector<APInt, 8> values; 330 values.reserve(cd->getNumElements()); 331 for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i) 332 values.push_back(cd->getElementAsAPInt(i)); 333 return DenseElementsAttr::get(attrType, values); 334 } 335 336 if (type.isFloatTy() || type.isDoubleTy()) { 337 SmallVector<APFloat, 8> values; 338 values.reserve(cd->getNumElements()); 339 for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i) 340 values.push_back(cd->getElementAsAPFloat(i)); 341 return DenseElementsAttr::get(attrType, values); 342 } 343 344 return nullptr; 345 } 346 347 // Unpack constant aggregates to create dense elements attribute whenever 348 // possible. Return nullptr (failure) otherwise. 349 if (isa<llvm::ConstantAggregate>(value)) { 350 auto outerType = getStdTypeForAttr(processType(value->getType())) 351 .dyn_cast_or_null<ShapedType>(); 352 if (!outerType) 353 return nullptr; 354 355 SmallVector<Attribute, 8> values; 356 SmallVector<int64_t, 8> shape; 357 358 for (unsigned i = 0, e = value->getNumOperands(); i < e; ++i) { 359 auto nested = getConstantAsAttr(value->getAggregateElement(i)) 360 .dyn_cast_or_null<DenseElementsAttr>(); 361 if (!nested) 362 return nullptr; 363 364 values.append(nested.attr_value_begin(), nested.attr_value_end()); 365 } 366 367 return DenseElementsAttr::get(outerType, values); 368 } 369 370 return nullptr; 371 } 372 373 GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) { 374 auto it = globals.find(GV); 375 if (it != globals.end()) 376 return it->second; 377 378 OpBuilder b(module.getBody(), getGlobalInsertPt()); 379 Attribute valueAttr; 380 if (GV->hasInitializer()) 381 valueAttr = getConstantAsAttr(GV->getInitializer()); 382 LLVMType type = processType(GV->getValueType()); 383 if (!type) 384 return nullptr; 385 GlobalOp op = b.create<GlobalOp>( 386 UnknownLoc::get(context), type, GV->isConstant(), 387 convertLinkageFromLLVM(GV->getLinkage()), GV->getName(), valueAttr); 388 if (GV->hasInitializer() && !valueAttr) { 389 Region &r = op.getInitializerRegion(); 390 currentEntryBlock = b.createBlock(&r); 391 b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin()); 392 Value v = processConstant(GV->getInitializer()); 393 if (!v) 394 return nullptr; 395 b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v})); 396 } 397 return globals[GV] = op; 398 } 399 400 Value Importer::processConstant(llvm::Constant *c) { 401 OpBuilder bEntry(currentEntryBlock, currentEntryBlock->begin()); 402 if (Attribute attr = getConstantAsAttr(c)) { 403 // These constants can be represented as attributes. 404 OpBuilder b(currentEntryBlock, currentEntryBlock->begin()); 405 LLVMType type = processType(c->getType()); 406 if (!type) 407 return nullptr; 408 if (auto symbolRef = attr.dyn_cast<FlatSymbolRefAttr>()) 409 return instMap[c] = bEntry.create<AddressOfOp>(unknownLoc, type, 410 symbolRef.getValue()); 411 return instMap[c] = bEntry.create<ConstantOp>(unknownLoc, type, attr); 412 } 413 if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) { 414 LLVMType type = processType(cn->getType()); 415 if (!type) 416 return nullptr; 417 return instMap[c] = bEntry.create<NullOp>(unknownLoc, type); 418 } 419 if (auto *GV = dyn_cast<llvm::GlobalVariable>(c)) 420 return bEntry.create<AddressOfOp>(UnknownLoc::get(context), 421 processGlobal(GV)); 422 423 if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) { 424 llvm::Instruction *i = ce->getAsInstruction(); 425 OpBuilder::InsertionGuard guard(b); 426 b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin()); 427 if (failed(processInstruction(i))) 428 return nullptr; 429 assert(instMap.count(i)); 430 431 // Remove this zombie LLVM instruction now, leaving us only with the MLIR 432 // op. 433 i->deleteValue(); 434 return instMap[c] = instMap[i]; 435 } 436 if (auto *ue = dyn_cast<llvm::UndefValue>(c)) { 437 LLVMType type = processType(ue->getType()); 438 if (!type) 439 return nullptr; 440 return instMap[c] = bEntry.create<UndefOp>(UnknownLoc::get(context), type); 441 } 442 emitError(unknownLoc) << "unhandled constant: " << diag(*c); 443 return nullptr; 444 } 445 446 Value Importer::processValue(llvm::Value *value) { 447 auto it = instMap.find(value); 448 if (it != instMap.end()) 449 return it->second; 450 451 // We don't expect to see instructions in dominator order. If we haven't seen 452 // this instruction yet, create an unknown op and remap it later. 453 if (isa<llvm::Instruction>(value)) { 454 OperationState state(UnknownLoc::get(context), "llvm.unknown"); 455 LLVMType type = processType(value->getType()); 456 if (!type) 457 return nullptr; 458 state.addTypes(type); 459 unknownInstMap[value] = b.createOperation(state); 460 return unknownInstMap[value]->getResult(0); 461 } 462 463 if (auto *c = dyn_cast<llvm::Constant>(value)) 464 return processConstant(c); 465 466 emitError(unknownLoc) << "unhandled value: " << diag(*value); 467 return nullptr; 468 } 469 470 /// Return the MLIR OperationName for the given LLVM opcode. 471 static StringRef lookupOperationNameFromOpcode(unsigned opcode) { 472 // Maps from LLVM opcode to MLIR OperationName. This is deliberately ordered 473 // as in llvm/IR/Instructions.def to aid comprehension and spot missing 474 // instructions. 475 #define INST(llvm_n, mlir_n) \ 476 { llvm::Instruction::llvm_n, LLVM::mlir_n##Op::getOperationName() } 477 static const DenseMap<unsigned, StringRef> opcMap = { 478 // Ret is handled specially. 479 // Br is handled specially. 480 // FIXME: switch 481 // FIXME: indirectbr 482 // FIXME: invoke 483 INST(Resume, Resume), 484 // FIXME: unreachable 485 // FIXME: cleanupret 486 // FIXME: catchret 487 // FIXME: catchswitch 488 // FIXME: callbr 489 // FIXME: fneg 490 INST(Add, Add), INST(FAdd, FAdd), INST(Sub, Sub), INST(FSub, FSub), 491 INST(Mul, Mul), INST(FMul, FMul), INST(UDiv, UDiv), INST(SDiv, SDiv), 492 INST(FDiv, FDiv), INST(URem, URem), INST(SRem, SRem), INST(FRem, FRem), 493 INST(Shl, Shl), INST(LShr, LShr), INST(AShr, AShr), INST(And, And), 494 INST(Or, Or), INST(Xor, XOr), INST(Alloca, Alloca), INST(Load, Load), 495 INST(Store, Store), 496 // Getelementptr is handled specially. 497 INST(Ret, Return), INST(Fence, Fence), 498 // FIXME: atomiccmpxchg 499 // FIXME: atomicrmw 500 INST(Trunc, Trunc), INST(ZExt, ZExt), INST(SExt, SExt), 501 INST(FPToUI, FPToUI), INST(FPToSI, FPToSI), INST(UIToFP, UIToFP), 502 INST(SIToFP, SIToFP), INST(FPTrunc, FPTrunc), INST(FPExt, FPExt), 503 INST(PtrToInt, PtrToInt), INST(IntToPtr, IntToPtr), 504 INST(BitCast, Bitcast), INST(AddrSpaceCast, AddrSpaceCast), 505 // FIXME: cleanuppad 506 // FIXME: catchpad 507 // ICmp is handled specially. 508 // FIXME: fcmp 509 // PHI is handled specially. 510 INST(Freeze, Freeze), INST(Call, Call), 511 // FIXME: select 512 // FIXME: vaarg 513 // FIXME: extractelement 514 // FIXME: insertelement 515 // FIXME: shufflevector 516 // FIXME: extractvalue 517 // FIXME: insertvalue 518 // FIXME: landingpad 519 }; 520 #undef INST 521 522 return opcMap.lookup(opcode); 523 } 524 525 static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) { 526 switch (p) { 527 default: 528 llvm_unreachable("incorrect comparison predicate"); 529 case llvm::CmpInst::Predicate::ICMP_EQ: 530 return LLVM::ICmpPredicate::eq; 531 case llvm::CmpInst::Predicate::ICMP_NE: 532 return LLVM::ICmpPredicate::ne; 533 case llvm::CmpInst::Predicate::ICMP_SLT: 534 return LLVM::ICmpPredicate::slt; 535 case llvm::CmpInst::Predicate::ICMP_SLE: 536 return LLVM::ICmpPredicate::sle; 537 case llvm::CmpInst::Predicate::ICMP_SGT: 538 return LLVM::ICmpPredicate::sgt; 539 case llvm::CmpInst::Predicate::ICMP_SGE: 540 return LLVM::ICmpPredicate::sge; 541 case llvm::CmpInst::Predicate::ICMP_ULT: 542 return LLVM::ICmpPredicate::ult; 543 case llvm::CmpInst::Predicate::ICMP_ULE: 544 return LLVM::ICmpPredicate::ule; 545 case llvm::CmpInst::Predicate::ICMP_UGT: 546 return LLVM::ICmpPredicate::ugt; 547 case llvm::CmpInst::Predicate::ICMP_UGE: 548 return LLVM::ICmpPredicate::uge; 549 } 550 llvm_unreachable("incorrect comparison predicate"); 551 } 552 553 static AtomicOrdering getLLVMAtomicOrdering(llvm::AtomicOrdering ordering) { 554 switch (ordering) { 555 case llvm::AtomicOrdering::NotAtomic: 556 return LLVM::AtomicOrdering::not_atomic; 557 case llvm::AtomicOrdering::Unordered: 558 return LLVM::AtomicOrdering::unordered; 559 case llvm::AtomicOrdering::Monotonic: 560 return LLVM::AtomicOrdering::monotonic; 561 case llvm::AtomicOrdering::Acquire: 562 return LLVM::AtomicOrdering::acquire; 563 case llvm::AtomicOrdering::Release: 564 return LLVM::AtomicOrdering::release; 565 case llvm::AtomicOrdering::AcquireRelease: 566 return LLVM::AtomicOrdering::acq_rel; 567 case llvm::AtomicOrdering::SequentiallyConsistent: 568 return LLVM::AtomicOrdering::seq_cst; 569 } 570 llvm_unreachable("incorrect atomic ordering"); 571 } 572 573 // `br` branches to `target`. Return the branch arguments to `br`, in the 574 // same order of the PHIs in `target`. 575 LogicalResult 576 Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target, 577 SmallVectorImpl<Value> &blockArguments) { 578 for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) { 579 auto *PN = cast<llvm::PHINode>(&*inst); 580 Value value = processValue(PN->getIncomingValueForBlock(br->getParent())); 581 if (!value) 582 return failure(); 583 blockArguments.push_back(value); 584 } 585 return success(); 586 } 587 588 LogicalResult Importer::processInstruction(llvm::Instruction *inst) { 589 // FIXME: Support uses of SubtargetData. Currently inbounds GEPs, fast-math 590 // flags and call / operand attributes are not supported. 591 Location loc = processDebugLoc(inst->getDebugLoc(), inst); 592 Value &v = instMap[inst]; 593 assert(!v && "processInstruction must be called only once per instruction!"); 594 switch (inst->getOpcode()) { 595 default: 596 return emitError(loc) << "unknown instruction: " << diag(*inst); 597 case llvm::Instruction::Add: 598 case llvm::Instruction::FAdd: 599 case llvm::Instruction::Sub: 600 case llvm::Instruction::FSub: 601 case llvm::Instruction::Mul: 602 case llvm::Instruction::FMul: 603 case llvm::Instruction::UDiv: 604 case llvm::Instruction::SDiv: 605 case llvm::Instruction::FDiv: 606 case llvm::Instruction::URem: 607 case llvm::Instruction::SRem: 608 case llvm::Instruction::FRem: 609 case llvm::Instruction::Shl: 610 case llvm::Instruction::LShr: 611 case llvm::Instruction::AShr: 612 case llvm::Instruction::And: 613 case llvm::Instruction::Or: 614 case llvm::Instruction::Xor: 615 case llvm::Instruction::Alloca: 616 case llvm::Instruction::Load: 617 case llvm::Instruction::Store: 618 case llvm::Instruction::Ret: 619 case llvm::Instruction::Resume: 620 case llvm::Instruction::Trunc: 621 case llvm::Instruction::ZExt: 622 case llvm::Instruction::SExt: 623 case llvm::Instruction::FPToUI: 624 case llvm::Instruction::FPToSI: 625 case llvm::Instruction::UIToFP: 626 case llvm::Instruction::SIToFP: 627 case llvm::Instruction::FPTrunc: 628 case llvm::Instruction::FPExt: 629 case llvm::Instruction::PtrToInt: 630 case llvm::Instruction::IntToPtr: 631 case llvm::Instruction::AddrSpaceCast: 632 case llvm::Instruction::Freeze: 633 case llvm::Instruction::BitCast: { 634 OperationState state(loc, lookupOperationNameFromOpcode(inst->getOpcode())); 635 SmallVector<Value, 4> ops; 636 ops.reserve(inst->getNumOperands()); 637 for (auto *op : inst->operand_values()) { 638 Value value = processValue(op); 639 if (!value) 640 return failure(); 641 ops.push_back(value); 642 } 643 state.addOperands(ops); 644 if (!inst->getType()->isVoidTy()) { 645 LLVMType type = processType(inst->getType()); 646 if (!type) 647 return failure(); 648 state.addTypes(type); 649 } 650 Operation *op = b.createOperation(state); 651 if (!inst->getType()->isVoidTy()) 652 v = op->getResult(0); 653 return success(); 654 } 655 case llvm::Instruction::ICmp: { 656 Value lhs = processValue(inst->getOperand(0)); 657 Value rhs = processValue(inst->getOperand(1)); 658 if (!lhs || !rhs) 659 return failure(); 660 v = b.create<ICmpOp>( 661 loc, getICmpPredicate(cast<llvm::ICmpInst>(inst)->getPredicate()), lhs, 662 rhs); 663 return success(); 664 } 665 case llvm::Instruction::Br: { 666 auto *brInst = cast<llvm::BranchInst>(inst); 667 OperationState state(loc, 668 brInst->isConditional() ? "llvm.cond_br" : "llvm.br"); 669 if (brInst->isConditional()) { 670 Value condition = processValue(brInst->getCondition()); 671 if (!condition) 672 return failure(); 673 state.addOperands(condition); 674 } 675 676 std::array<int32_t, 3> operandSegmentSizes = {1, 0, 0}; 677 for (int i : llvm::seq<int>(0, brInst->getNumSuccessors())) { 678 auto *succ = brInst->getSuccessor(i); 679 SmallVector<Value, 4> blockArguments; 680 if (failed(processBranchArgs(brInst, succ, blockArguments))) 681 return failure(); 682 state.addSuccessors(blocks[succ]); 683 state.addOperands(blockArguments); 684 operandSegmentSizes[i + 1] = blockArguments.size(); 685 } 686 687 if (brInst->isConditional()) { 688 state.addAttribute(LLVM::CondBrOp::getOperandSegmentSizeAttr(), 689 b.getI32VectorAttr(operandSegmentSizes)); 690 } 691 692 b.createOperation(state); 693 return success(); 694 } 695 case llvm::Instruction::PHI: { 696 LLVMType type = processType(inst->getType()); 697 if (!type) 698 return failure(); 699 v = b.getInsertionBlock()->addArgument(type); 700 return success(); 701 } 702 case llvm::Instruction::Call: { 703 llvm::CallInst *ci = cast<llvm::CallInst>(inst); 704 SmallVector<Value, 4> ops; 705 ops.reserve(inst->getNumOperands()); 706 for (auto &op : ci->arg_operands()) { 707 Value arg = processValue(op.get()); 708 if (!arg) 709 return failure(); 710 ops.push_back(arg); 711 } 712 713 SmallVector<Type, 2> tys; 714 if (!ci->getType()->isVoidTy()) { 715 LLVMType type = processType(inst->getType()); 716 if (!type) 717 return failure(); 718 tys.push_back(type); 719 } 720 Operation *op; 721 if (llvm::Function *callee = ci->getCalledFunction()) { 722 op = b.create<CallOp>(loc, tys, b.getSymbolRefAttr(callee->getName()), 723 ops); 724 } else { 725 Value calledValue = processValue(ci->getCalledOperand()); 726 if (!calledValue) 727 return failure(); 728 ops.insert(ops.begin(), calledValue); 729 op = b.create<CallOp>(loc, tys, ops); 730 } 731 if (!ci->getType()->isVoidTy()) 732 v = op->getResult(0); 733 return success(); 734 } 735 case llvm::Instruction::LandingPad: { 736 llvm::LandingPadInst *lpi = cast<llvm::LandingPadInst>(inst); 737 SmallVector<Value, 4> ops; 738 739 for (unsigned i = 0, ie = lpi->getNumClauses(); i < ie; i++) 740 ops.push_back(processConstant(lpi->getClause(i))); 741 742 Type ty = processType(lpi->getType()); 743 if (!ty) 744 return failure(); 745 746 v = b.create<LandingpadOp>(loc, ty, lpi->isCleanup(), ops); 747 return success(); 748 } 749 case llvm::Instruction::Invoke: { 750 llvm::InvokeInst *ii = cast<llvm::InvokeInst>(inst); 751 752 SmallVector<Type, 2> tys; 753 if (!ii->getType()->isVoidTy()) 754 tys.push_back(processType(inst->getType())); 755 756 SmallVector<Value, 4> ops; 757 ops.reserve(inst->getNumOperands() + 1); 758 for (auto &op : ii->arg_operands()) 759 ops.push_back(processValue(op.get())); 760 761 SmallVector<Value, 4> normalArgs, unwindArgs; 762 processBranchArgs(ii, ii->getNormalDest(), normalArgs); 763 processBranchArgs(ii, ii->getUnwindDest(), unwindArgs); 764 765 Operation *op; 766 if (llvm::Function *callee = ii->getCalledFunction()) { 767 op = b.create<InvokeOp>(loc, tys, b.getSymbolRefAttr(callee->getName()), 768 ops, blocks[ii->getNormalDest()], normalArgs, 769 blocks[ii->getUnwindDest()], unwindArgs); 770 } else { 771 ops.insert(ops.begin(), processValue(ii->getCalledOperand())); 772 op = b.create<InvokeOp>(loc, tys, ops, blocks[ii->getNormalDest()], 773 normalArgs, blocks[ii->getUnwindDest()], 774 unwindArgs); 775 } 776 777 if (!ii->getType()->isVoidTy()) 778 v = op->getResult(0); 779 return success(); 780 } 781 case llvm::Instruction::Fence: { 782 StringRef syncscope; 783 SmallVector<StringRef, 4> ssNs; 784 llvm::LLVMContext &llvmContext = dialect->getLLVMContext(); 785 llvm::FenceInst *fence = cast<llvm::FenceInst>(inst); 786 llvmContext.getSyncScopeNames(ssNs); 787 int fenceSyncScopeID = fence->getSyncScopeID(); 788 for (unsigned i = 0, e = ssNs.size(); i != e; i++) { 789 if (fenceSyncScopeID == llvmContext.getOrInsertSyncScopeID(ssNs[i])) { 790 syncscope = ssNs[i]; 791 break; 792 } 793 } 794 b.create<FenceOp>(loc, getLLVMAtomicOrdering(fence->getOrdering()), 795 syncscope); 796 return success(); 797 } 798 case llvm::Instruction::GetElementPtr: { 799 // FIXME: Support inbounds GEPs. 800 llvm::GetElementPtrInst *gep = cast<llvm::GetElementPtrInst>(inst); 801 SmallVector<Value, 4> ops; 802 for (auto *op : gep->operand_values()) { 803 Value value = processValue(op); 804 if (!value) 805 return failure(); 806 ops.push_back(value); 807 } 808 Type type = processType(inst->getType()); 809 if (!type) 810 return failure(); 811 v = b.create<GEPOp>(loc, type, ops); 812 return success(); 813 } 814 } 815 } 816 817 FlatSymbolRefAttr Importer::getPersonalityAsAttr(llvm::Function *f) { 818 if (!f->hasPersonalityFn()) 819 return nullptr; 820 821 llvm::Constant *pf = f->getPersonalityFn(); 822 823 // If it directly has a name, we can use it. 824 if (pf->hasName()) 825 return b.getSymbolRefAttr(pf->getName()); 826 827 // If it doesn't have a name, currently, only function pointers that are 828 // bitcast to i8* are parsed. 829 if (auto ce = dyn_cast<llvm::ConstantExpr>(pf)) { 830 if (ce->getOpcode() == llvm::Instruction::BitCast && 831 ce->getType() == llvm::Type::getInt8PtrTy(dialect->getLLVMContext())) { 832 if (auto func = dyn_cast<llvm::Function>(ce->getOperand(0))) 833 return b.getSymbolRefAttr(func->getName()); 834 } 835 } 836 return FlatSymbolRefAttr(); 837 } 838 839 LogicalResult Importer::processFunction(llvm::Function *f) { 840 blocks.clear(); 841 instMap.clear(); 842 unknownInstMap.clear(); 843 844 LLVMType functionType = processType(f->getFunctionType()); 845 if (!functionType) 846 return failure(); 847 848 b.setInsertionPoint(module.getBody(), getFuncInsertPt()); 849 LLVMFuncOp fop = b.create<LLVMFuncOp>(UnknownLoc::get(context), f->getName(), 850 functionType); 851 852 if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f)) 853 fop.setAttr(b.getIdentifier("personality"), personality); 854 else if (f->hasPersonalityFn()) 855 emitWarning(UnknownLoc::get(context), 856 "could not deduce personality, skipping it"); 857 858 if (f->isDeclaration()) 859 return success(); 860 861 // Eagerly create all blocks. 862 SmallVector<Block *, 4> blockList; 863 for (llvm::BasicBlock &bb : *f) { 864 blockList.push_back(b.createBlock(&fop.body(), fop.body().end())); 865 blocks[&bb] = blockList.back(); 866 } 867 currentEntryBlock = blockList[0]; 868 869 // Add function arguments to the entry block. 870 for (auto kv : llvm::enumerate(f->args())) 871 instMap[&kv.value()] = blockList[0]->addArgument( 872 functionType.getFunctionParamType(kv.index())); 873 874 for (auto bbs : llvm::zip(*f, blockList)) { 875 if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs)))) 876 return failure(); 877 } 878 879 // Now that all instructions are guaranteed to have been visited, ensure 880 // any unknown uses we encountered are remapped. 881 for (auto &llvmAndUnknown : unknownInstMap) { 882 assert(instMap.count(llvmAndUnknown.first)); 883 Value newValue = instMap[llvmAndUnknown.first]; 884 Value oldValue = llvmAndUnknown.second->getResult(0); 885 oldValue.replaceAllUsesWith(newValue); 886 llvmAndUnknown.second->erase(); 887 } 888 return success(); 889 } 890 891 LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) { 892 b.setInsertionPointToStart(block); 893 for (llvm::Instruction &inst : *bb) { 894 if (failed(processInstruction(&inst))) 895 return failure(); 896 } 897 return success(); 898 } 899 900 OwningModuleRef 901 mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule, 902 MLIRContext *context) { 903 OwningModuleRef module(ModuleOp::create( 904 FileLineColLoc::get("", /*line=*/0, /*column=*/0, context))); 905 906 Importer deserializer(context, module.get()); 907 for (llvm::GlobalVariable &gv : llvmModule->globals()) { 908 if (!deserializer.processGlobal(&gv)) 909 return {}; 910 } 911 for (llvm::Function &f : llvmModule->functions()) { 912 if (failed(deserializer.processFunction(&f))) 913 return {}; 914 } 915 916 return module; 917 } 918 919 // Deserializes the LLVM bitcode stored in `input` into an MLIR module in the 920 // LLVM dialect. 921 OwningModuleRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr, 922 MLIRContext *context) { 923 LLVMDialect *dialect = context->getRegisteredDialect<LLVMDialect>(); 924 assert(dialect && "Could not find LLVMDialect?"); 925 926 llvm::SMDiagnostic err; 927 std::unique_ptr<llvm::Module> llvmModule = 928 llvm::parseIR(*sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), err, 929 dialect->getLLVMContext()); 930 if (!llvmModule) { 931 std::string errStr; 932 llvm::raw_string_ostream errStream(errStr); 933 err.print(/*ProgName=*/"", errStream); 934 emitError(UnknownLoc::get(context)) << errStream.str(); 935 return {}; 936 } 937 return translateLLVMIRToModule(std::move(llvmModule), context); 938 } 939 940 namespace mlir { 941 void registerFromLLVMIRTranslation() { 942 TranslateToMLIRRegistration fromLLVM( 943 "import-llvm", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) { 944 return ::translateLLVMIRToModule(sourceMgr, context); 945 }); 946 } 947 } // namespace mlir 948