1 //===- ConvertFromLLVMIR.cpp - MLIR to LLVM IR conversion -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a translation between LLVM IR and the MLIR LLVM dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/IR/Builders.h" 15 #include "mlir/IR/MLIRContext.h" 16 #include "mlir/IR/Module.h" 17 #include "mlir/IR/StandardTypes.h" 18 #include "mlir/Target/LLVMIR.h" 19 #include "mlir/Translation.h" 20 21 #include "llvm/IR/Attributes.h" 22 #include "llvm/IR/Constants.h" 23 #include "llvm/IR/Function.h" 24 #include "llvm/IR/Instructions.h" 25 #include "llvm/IR/Type.h" 26 #include "llvm/IRReader/IRReader.h" 27 #include "llvm/Support/Error.h" 28 #include "llvm/Support/SourceMgr.h" 29 30 using namespace mlir; 31 using namespace mlir::LLVM; 32 33 // Utility to print an LLVM value as a string for passing to emitError(). 34 // FIXME: Diagnostic should be able to natively handle types that have 35 // operator << (raw_ostream&) defined. 36 static std::string diag(llvm::Value &v) { 37 std::string s; 38 llvm::raw_string_ostream os(s); 39 os << v; 40 return os.str(); 41 } 42 43 // Handles importing globals and functions from an LLVM module. 44 namespace { 45 class Importer { 46 public: 47 Importer(MLIRContext *context, ModuleOp module) 48 : b(context), context(context), module(module), 49 unknownLoc(FileLineColLoc::get("imported-bitcode", 0, 0, context)) { 50 b.setInsertionPointToStart(module.getBody()); 51 dialect = context->getRegisteredDialect<LLVMDialect>(); 52 } 53 54 /// Imports `f` into the current module. 55 LogicalResult processFunction(llvm::Function *f); 56 57 /// Imports GV as a GlobalOp, creating it if it doesn't exist. 58 GlobalOp processGlobal(llvm::GlobalVariable *GV); 59 60 private: 61 /// Imports `bb` into `block`, which must be initially empty. 62 LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block); 63 /// Imports `inst` and populates instMap[inst] with the imported Value. 64 LogicalResult processInstruction(llvm::Instruction *inst); 65 /// Creates an LLVMType for `type`. 66 LLVMType processType(llvm::Type *type); 67 /// `value` is an SSA-use. Return the remapped version of `value` or a 68 /// placeholder that will be remapped later if this is an instruction that 69 /// has not yet been visited. 70 Value processValue(llvm::Value *value); 71 /// Create the most accurate Location possible using a llvm::DebugLoc and 72 /// possibly an llvm::Instruction to narrow the Location if debug information 73 /// is unavailable. 74 Location processDebugLoc(const llvm::DebugLoc &loc, 75 llvm::Instruction *inst = nullptr); 76 /// `br` branches to `target`. Append the block arguments to attach to the 77 /// generated branch op to `blockArguments`. These should be in the same order 78 /// as the PHIs in `target`. 79 LogicalResult processBranchArgs(llvm::BranchInst *br, 80 llvm::BasicBlock *target, 81 SmallVectorImpl<Value> &blockArguments); 82 /// Returns the standard type equivalent to be used in attributes for the 83 /// given LLVM IR dialect type. 84 Type getStdTypeForAttr(LLVMType type); 85 /// Return `value` as an attribute to attach to a GlobalOp. 86 Attribute getConstantAsAttr(llvm::Constant *value); 87 /// Return `c` as an MLIR Value. This could either be a ConstantOp, or 88 /// an expanded sequence of ops in the current function's entry block (for 89 /// ConstantExprs or ConstantGEPs). 90 Value processConstant(llvm::Constant *c); 91 92 /// The current builder, pointing at where the next Instruction should be 93 /// generated. 94 OpBuilder b; 95 /// The current context. 96 MLIRContext *context; 97 /// The current module being created. 98 ModuleOp module; 99 /// The entry block of the current function being processed. 100 Block *currentEntryBlock; 101 102 /// Globals are inserted before the first function, if any. 103 Block::iterator getGlobalInsertPt() { 104 auto i = module.getBody()->begin(); 105 while (!isa<LLVMFuncOp>(i) && !isa<ModuleTerminatorOp>(i)) 106 ++i; 107 return i; 108 } 109 110 /// Functions are always inserted before the module terminator. 111 Block::iterator getFuncInsertPt() { 112 return std::prev(module.getBody()->end()); 113 } 114 115 /// Remapped blocks, for the current function. 116 DenseMap<llvm::BasicBlock *, Block *> blocks; 117 /// Remapped values. These are function-local. 118 DenseMap<llvm::Value *, Value> instMap; 119 /// Instructions that had not been defined when first encountered as a use. 120 /// Maps to the dummy Operation that was created in processValue(). 121 DenseMap<llvm::Value *, Operation *> unknownInstMap; 122 /// Uniquing map of GlobalVariables. 123 DenseMap<llvm::GlobalVariable *, GlobalOp> globals; 124 /// Cached FileLineColLoc::get("imported-bitcode", 0, 0). 125 Location unknownLoc; 126 /// Cached dialect. 127 LLVMDialect *dialect; 128 }; 129 } // namespace 130 131 Location Importer::processDebugLoc(const llvm::DebugLoc &loc, 132 llvm::Instruction *inst) { 133 if (!loc && inst) { 134 std::string s; 135 llvm::raw_string_ostream os(s); 136 os << "llvm-imported-inst-%"; 137 inst->printAsOperand(os, /*PrintType=*/false); 138 return FileLineColLoc::get(os.str(), 0, 0, context); 139 } else if (!loc) { 140 return unknownLoc; 141 } 142 // FIXME: Obtain the filename from DILocationInfo. 143 return FileLineColLoc::get("imported-bitcode", loc.getLine(), loc.getCol(), 144 context); 145 } 146 147 LLVMType Importer::processType(llvm::Type *type) { 148 switch (type->getTypeID()) { 149 case llvm::Type::FloatTyID: 150 return LLVMType::getFloatTy(dialect); 151 case llvm::Type::DoubleTyID: 152 return LLVMType::getDoubleTy(dialect); 153 case llvm::Type::IntegerTyID: 154 return LLVMType::getIntNTy(dialect, type->getIntegerBitWidth()); 155 case llvm::Type::PointerTyID: { 156 LLVMType elementType = processType(type->getPointerElementType()); 157 if (!elementType) 158 return nullptr; 159 return elementType.getPointerTo(type->getPointerAddressSpace()); 160 } 161 case llvm::Type::ArrayTyID: { 162 LLVMType elementType = processType(type->getArrayElementType()); 163 if (!elementType) 164 return nullptr; 165 return LLVMType::getArrayTy(elementType, type->getArrayNumElements()); 166 } 167 case llvm::Type::VectorTyID: { 168 if (type->getVectorIsScalable()) { 169 emitError(unknownLoc) << "scalable vector types not supported"; 170 return nullptr; 171 } 172 LLVMType elementType = processType(type->getVectorElementType()); 173 if (!elementType) 174 return nullptr; 175 return LLVMType::getVectorTy(elementType, type->getVectorNumElements()); 176 } 177 case llvm::Type::VoidTyID: 178 return LLVMType::getVoidTy(dialect); 179 case llvm::Type::FP128TyID: 180 return LLVMType::getFP128Ty(dialect); 181 case llvm::Type::X86_FP80TyID: 182 return LLVMType::getX86_FP80Ty(dialect); 183 case llvm::Type::StructTyID: { 184 SmallVector<LLVMType, 4> elementTypes; 185 elementTypes.reserve(type->getStructNumElements()); 186 for (unsigned i = 0, e = type->getStructNumElements(); i != e; ++i) { 187 LLVMType ty = processType(type->getStructElementType(i)); 188 if (!ty) 189 return nullptr; 190 elementTypes.push_back(ty); 191 } 192 return LLVMType::getStructTy(dialect, elementTypes, 193 cast<llvm::StructType>(type)->isPacked()); 194 } 195 case llvm::Type::FunctionTyID: { 196 llvm::FunctionType *fty = cast<llvm::FunctionType>(type); 197 SmallVector<LLVMType, 4> paramTypes; 198 for (unsigned i = 0, e = fty->getNumParams(); i != e; ++i) { 199 LLVMType ty = processType(fty->getParamType(i)); 200 if (!ty) 201 return nullptr; 202 paramTypes.push_back(ty); 203 } 204 LLVMType result = processType(fty->getReturnType()); 205 if (!result) 206 return nullptr; 207 208 return LLVMType::getFunctionTy(result, paramTypes, fty->isVarArg()); 209 } 210 default: { 211 // FIXME: Diagnostic should be able to natively handle types that have 212 // operator<<(raw_ostream&) defined. 213 std::string s; 214 llvm::raw_string_ostream os(s); 215 os << *type; 216 emitError(unknownLoc) << "unhandled type: " << os.str(); 217 return nullptr; 218 } 219 } 220 } 221 222 // We only need integers, floats, doubles, and vectors and tensors thereof for 223 // attributes. Scalar and vector types are converted to the standard 224 // equivalents. Array types are converted to ranked tensors; nested array types 225 // are converted to multi-dimensional tensors or vectors, depending on the 226 // innermost type being a scalar or a vector. 227 Type Importer::getStdTypeForAttr(LLVMType type) { 228 if (!type) 229 return nullptr; 230 231 if (type.isIntegerTy()) 232 return b.getIntegerType(type.getUnderlyingType()->getIntegerBitWidth()); 233 234 if (type.getUnderlyingType()->isFloatTy()) 235 return b.getF32Type(); 236 237 if (type.getUnderlyingType()->isDoubleTy()) 238 return b.getF64Type(); 239 240 // LLVM vectors can only contain scalars. 241 if (type.isVectorTy()) { 242 auto numElements = type.getUnderlyingType()->getVectorElementCount(); 243 if (numElements.Scalable) { 244 emitError(unknownLoc) << "scalable vectors not supported"; 245 return nullptr; 246 } 247 Type elementType = getStdTypeForAttr(type.getVectorElementType()); 248 if (!elementType) 249 return nullptr; 250 return VectorType::get(numElements.Min, elementType); 251 } 252 253 // LLVM arrays can contain other arrays or vectors. 254 if (type.isArrayTy()) { 255 // Recover the nested array shape. 256 SmallVector<int64_t, 4> shape; 257 shape.push_back(type.getArrayNumElements()); 258 while (type.getArrayElementType().isArrayTy()) { 259 type = type.getArrayElementType(); 260 shape.push_back(type.getArrayNumElements()); 261 } 262 263 // If the innermost type is a vector, use the multi-dimensional vector as 264 // attribute type. 265 if (type.getArrayElementType().isVectorTy()) { 266 LLVMType vectorType = type.getArrayElementType(); 267 auto numElements = 268 vectorType.getUnderlyingType()->getVectorElementCount(); 269 if (numElements.Scalable) { 270 emitError(unknownLoc) << "scalable vectors not supported"; 271 return nullptr; 272 } 273 shape.push_back(numElements.Min); 274 275 Type elementType = getStdTypeForAttr(vectorType.getVectorElementType()); 276 if (!elementType) 277 return nullptr; 278 return VectorType::get(shape, elementType); 279 } 280 281 // Otherwise use a tensor. 282 Type elementType = getStdTypeForAttr(type.getArrayElementType()); 283 if (!elementType) 284 return nullptr; 285 return RankedTensorType::get(shape, elementType); 286 } 287 288 return nullptr; 289 } 290 291 // Get the given constant as an attribute. Not all constants can be represented 292 // as attributes. 293 Attribute Importer::getConstantAsAttr(llvm::Constant *value) { 294 if (auto *ci = dyn_cast<llvm::ConstantInt>(value)) 295 return b.getIntegerAttr( 296 IntegerType::get(ci->getType()->getBitWidth(), context), 297 ci->getValue()); 298 if (auto *c = dyn_cast<llvm::ConstantDataArray>(value)) 299 if (c->isString()) 300 return b.getStringAttr(c->getAsString()); 301 if (auto *c = dyn_cast<llvm::ConstantFP>(value)) { 302 if (c->getType()->isDoubleTy()) 303 return b.getFloatAttr(FloatType::getF64(context), c->getValueAPF()); 304 else if (c->getType()->isFloatingPointTy()) 305 return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF()); 306 } 307 if (auto *f = dyn_cast<llvm::Function>(value)) 308 return b.getSymbolRefAttr(f->getName()); 309 310 // Convert constant data to a dense elements attribute. 311 if (auto *cd = dyn_cast<llvm::ConstantDataSequential>(value)) { 312 LLVMType type = processType(cd->getElementType()); 313 if (!type) 314 return nullptr; 315 316 auto attrType = getStdTypeForAttr(processType(cd->getType())) 317 .dyn_cast_or_null<ShapedType>(); 318 if (!attrType) 319 return nullptr; 320 321 if (type.isIntegerTy()) { 322 SmallVector<APInt, 8> values; 323 values.reserve(cd->getNumElements()); 324 for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i) 325 values.push_back(cd->getElementAsAPInt(i)); 326 return DenseElementsAttr::get(attrType, values); 327 } 328 329 if (type.isFloatTy() || type.isDoubleTy()) { 330 SmallVector<APFloat, 8> values; 331 values.reserve(cd->getNumElements()); 332 for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i) 333 values.push_back(cd->getElementAsAPFloat(i)); 334 return DenseElementsAttr::get(attrType, values); 335 } 336 337 return nullptr; 338 } 339 340 // Unpack constant aggregates to create dense elements attribute whenever 341 // possible. Return nullptr (failure) otherwise. 342 if (auto *ca = dyn_cast<llvm::ConstantAggregate>(value)) { 343 auto outerType = getStdTypeForAttr(processType(value->getType())) 344 .dyn_cast_or_null<ShapedType>(); 345 if (!outerType) 346 return nullptr; 347 348 SmallVector<Attribute, 8> values; 349 SmallVector<int64_t, 8> shape; 350 351 for (unsigned i = 0, e = value->getNumOperands(); i < e; ++i) { 352 auto nested = getConstantAsAttr(value->getAggregateElement(i)) 353 .dyn_cast_or_null<DenseElementsAttr>(); 354 if (!nested) 355 return nullptr; 356 357 values.append(nested.attr_value_begin(), nested.attr_value_end()); 358 } 359 360 return DenseElementsAttr::get(outerType, values); 361 } 362 363 return nullptr; 364 } 365 366 /// Converts LLVM global variable linkage type into the LLVM dialect predicate. 367 static LLVM::Linkage 368 processLinkage(llvm::GlobalVariable::LinkageTypes linkage) { 369 switch (linkage) { 370 case llvm::GlobalValue::PrivateLinkage: 371 return LLVM::Linkage::Private; 372 case llvm::GlobalValue::InternalLinkage: 373 return LLVM::Linkage::Internal; 374 case llvm::GlobalValue::AvailableExternallyLinkage: 375 return LLVM::Linkage::AvailableExternally; 376 case llvm::GlobalValue::LinkOnceAnyLinkage: 377 return LLVM::Linkage::Linkonce; 378 case llvm::GlobalValue::WeakAnyLinkage: 379 return LLVM::Linkage::Weak; 380 case llvm::GlobalValue::CommonLinkage: 381 return LLVM::Linkage::Common; 382 case llvm::GlobalValue::AppendingLinkage: 383 return LLVM::Linkage::Appending; 384 case llvm::GlobalValue::ExternalWeakLinkage: 385 return LLVM::Linkage::ExternWeak; 386 case llvm::GlobalValue::LinkOnceODRLinkage: 387 return LLVM::Linkage::LinkonceODR; 388 case llvm::GlobalValue::WeakODRLinkage: 389 return LLVM::Linkage::WeakODR; 390 case llvm::GlobalValue::ExternalLinkage: 391 return LLVM::Linkage::External; 392 } 393 394 llvm_unreachable("unhandled linkage type"); 395 } 396 397 GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) { 398 auto it = globals.find(GV); 399 if (it != globals.end()) 400 return it->second; 401 402 OpBuilder b(module.getBody(), getGlobalInsertPt()); 403 Attribute valueAttr; 404 if (GV->hasInitializer()) 405 valueAttr = getConstantAsAttr(GV->getInitializer()); 406 LLVMType type = processType(GV->getValueType()); 407 if (!type) 408 return nullptr; 409 GlobalOp op = b.create<GlobalOp>( 410 UnknownLoc::get(context), type, GV->isConstant(), 411 processLinkage(GV->getLinkage()), GV->getName(), valueAttr); 412 if (GV->hasInitializer() && !valueAttr) { 413 Region &r = op.getInitializerRegion(); 414 currentEntryBlock = b.createBlock(&r); 415 b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin()); 416 Value v = processConstant(GV->getInitializer()); 417 if (!v) 418 return nullptr; 419 b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v})); 420 } 421 return globals[GV] = op; 422 } 423 424 Value Importer::processConstant(llvm::Constant *c) { 425 if (Attribute attr = getConstantAsAttr(c)) { 426 // These constants can be represented as attributes. 427 OpBuilder b(currentEntryBlock, currentEntryBlock->begin()); 428 LLVMType type = processType(c->getType()); 429 if (!type) 430 return nullptr; 431 return instMap[c] = b.create<ConstantOp>(unknownLoc, type, attr); 432 } 433 if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) { 434 OpBuilder b(currentEntryBlock, currentEntryBlock->begin()); 435 LLVMType type = processType(cn->getType()); 436 if (!type) 437 return nullptr; 438 return instMap[c] = b.create<NullOp>(unknownLoc, type); 439 } 440 if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) { 441 llvm::Instruction *i = ce->getAsInstruction(); 442 OpBuilder::InsertionGuard guard(b); 443 b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin()); 444 if (failed(processInstruction(i))) 445 return nullptr; 446 assert(instMap.count(i)); 447 448 // Remove this zombie LLVM instruction now, leaving us only with the MLIR 449 // op. 450 i->deleteValue(); 451 return instMap[c] = instMap[i]; 452 } 453 emitError(unknownLoc) << "unhandled constant: " << diag(*c); 454 return nullptr; 455 } 456 457 Value Importer::processValue(llvm::Value *value) { 458 auto it = instMap.find(value); 459 if (it != instMap.end()) 460 return it->second; 461 462 // We don't expect to see instructions in dominator order. If we haven't seen 463 // this instruction yet, create an unknown op and remap it later. 464 if (isa<llvm::Instruction>(value)) { 465 OperationState state(UnknownLoc::get(context), "unknown"); 466 LLVMType type = processType(value->getType()); 467 if (!type) 468 return nullptr; 469 state.addTypes(type); 470 unknownInstMap[value] = b.createOperation(state); 471 return unknownInstMap[value]->getResult(0); 472 } 473 474 if (auto *GV = dyn_cast<llvm::GlobalVariable>(value)) { 475 auto global = processGlobal(GV); 476 if (!global) 477 return nullptr; 478 return b.create<AddressOfOp>(UnknownLoc::get(context), global, 479 ArrayRef<NamedAttribute>()); 480 } 481 482 // Note, constant global variables are both GlobalVariables and Constants, 483 // so we handle GlobalVariables first above. 484 if (auto *c = dyn_cast<llvm::Constant>(value)) 485 return processConstant(c); 486 487 emitError(unknownLoc) << "unhandled value: " << diag(*value); 488 return nullptr; 489 } 490 491 // Maps from LLVM opcode to MLIR OperationName. This is deliberately ordered 492 // as in llvm/IR/Instructions.def to aid comprehension and spot missing 493 // instructions. 494 #define INST(llvm_n, mlir_n) \ 495 { llvm::Instruction::llvm_n, LLVM::mlir_n##Op::getOperationName() } 496 static const DenseMap<unsigned, StringRef> opcMap = { 497 // Ret is handled specially. 498 // Br is handled specially. 499 // FIXME: switch 500 // FIXME: indirectbr 501 // FIXME: invoke 502 // FIXME: resume 503 // FIXME: unreachable 504 // FIXME: cleanupret 505 // FIXME: catchret 506 // FIXME: catchswitch 507 // FIXME: callbr 508 // FIXME: fneg 509 INST(Add, Add), INST(FAdd, FAdd), INST(Sub, Sub), INST(FSub, FSub), 510 INST(Mul, Mul), INST(FMul, FMul), INST(UDiv, UDiv), INST(SDiv, SDiv), 511 INST(FDiv, FDiv), INST(URem, URem), INST(SRem, SRem), INST(FRem, FRem), 512 INST(Shl, Shl), INST(LShr, LShr), INST(AShr, AShr), INST(And, And), 513 INST(Or, Or), INST(Xor, XOr), INST(Alloca, Alloca), INST(Load, Load), 514 INST(Store, Store), 515 // Getelementptr is handled specially. 516 INST(Ret, Return), 517 // FIXME: fence 518 // FIXME: atomiccmpxchg 519 // FIXME: atomicrmw 520 INST(Trunc, Trunc), INST(ZExt, ZExt), INST(SExt, SExt), 521 INST(FPToUI, FPToUI), INST(FPToSI, FPToSI), INST(UIToFP, UIToFP), 522 INST(SIToFP, SIToFP), INST(FPTrunc, FPTrunc), INST(FPExt, FPExt), 523 INST(PtrToInt, PtrToInt), INST(IntToPtr, IntToPtr), INST(BitCast, Bitcast), 524 INST(AddrSpaceCast, AddrSpaceCast), 525 // FIXME: cleanuppad 526 // FIXME: catchpad 527 // ICmp is handled specially. 528 // FIXME: fcmp 529 // PHI is handled specially. 530 INST(Call, Call), 531 // FIXME: select 532 // FIXME: vaarg 533 // FIXME: extractelement 534 // FIXME: insertelement 535 // FIXME: shufflevector 536 // FIXME: extractvalue 537 // FIXME: insertvalue 538 // FIXME: landingpad 539 }; 540 #undef INST 541 542 static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) { 543 switch (p) { 544 default: 545 llvm_unreachable("incorrect comparison predicate"); 546 case llvm::CmpInst::Predicate::ICMP_EQ: 547 return LLVM::ICmpPredicate::eq; 548 case llvm::CmpInst::Predicate::ICMP_NE: 549 return LLVM::ICmpPredicate::ne; 550 case llvm::CmpInst::Predicate::ICMP_SLT: 551 return LLVM::ICmpPredicate::slt; 552 case llvm::CmpInst::Predicate::ICMP_SLE: 553 return LLVM::ICmpPredicate::sle; 554 case llvm::CmpInst::Predicate::ICMP_SGT: 555 return LLVM::ICmpPredicate::sgt; 556 case llvm::CmpInst::Predicate::ICMP_SGE: 557 return LLVM::ICmpPredicate::sge; 558 case llvm::CmpInst::Predicate::ICMP_ULT: 559 return LLVM::ICmpPredicate::ult; 560 case llvm::CmpInst::Predicate::ICMP_ULE: 561 return LLVM::ICmpPredicate::ule; 562 case llvm::CmpInst::Predicate::ICMP_UGT: 563 return LLVM::ICmpPredicate::ugt; 564 case llvm::CmpInst::Predicate::ICMP_UGE: 565 return LLVM::ICmpPredicate::uge; 566 } 567 llvm_unreachable("incorrect comparison predicate"); 568 } 569 570 // `br` branches to `target`. Return the branch arguments to `br`, in the 571 // same order of the PHIs in `target`. 572 LogicalResult 573 Importer::processBranchArgs(llvm::BranchInst *br, llvm::BasicBlock *target, 574 SmallVectorImpl<Value> &blockArguments) { 575 for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) { 576 auto *PN = cast<llvm::PHINode>(&*inst); 577 Value value = processValue(PN->getIncomingValueForBlock(br->getParent())); 578 if (!value) 579 return failure(); 580 blockArguments.push_back(value); 581 } 582 return success(); 583 } 584 585 LogicalResult Importer::processInstruction(llvm::Instruction *inst) { 586 // FIXME: Support uses of SubtargetData. Currently inbounds GEPs, fast-math 587 // flags and call / operand attributes are not supported. 588 Location loc = processDebugLoc(inst->getDebugLoc(), inst); 589 Value &v = instMap[inst]; 590 assert(!v && "processInstruction must be called only once per instruction!"); 591 switch (inst->getOpcode()) { 592 default: 593 return emitError(loc) << "unknown instruction: " << diag(*inst); 594 case llvm::Instruction::Add: 595 case llvm::Instruction::FAdd: 596 case llvm::Instruction::Sub: 597 case llvm::Instruction::FSub: 598 case llvm::Instruction::Mul: 599 case llvm::Instruction::FMul: 600 case llvm::Instruction::UDiv: 601 case llvm::Instruction::SDiv: 602 case llvm::Instruction::FDiv: 603 case llvm::Instruction::URem: 604 case llvm::Instruction::SRem: 605 case llvm::Instruction::FRem: 606 case llvm::Instruction::Shl: 607 case llvm::Instruction::LShr: 608 case llvm::Instruction::AShr: 609 case llvm::Instruction::And: 610 case llvm::Instruction::Or: 611 case llvm::Instruction::Xor: 612 case llvm::Instruction::Alloca: 613 case llvm::Instruction::Load: 614 case llvm::Instruction::Store: 615 case llvm::Instruction::Ret: 616 case llvm::Instruction::Trunc: 617 case llvm::Instruction::ZExt: 618 case llvm::Instruction::SExt: 619 case llvm::Instruction::FPToUI: 620 case llvm::Instruction::FPToSI: 621 case llvm::Instruction::UIToFP: 622 case llvm::Instruction::SIToFP: 623 case llvm::Instruction::FPTrunc: 624 case llvm::Instruction::FPExt: 625 case llvm::Instruction::PtrToInt: 626 case llvm::Instruction::IntToPtr: 627 case llvm::Instruction::AddrSpaceCast: 628 case llvm::Instruction::BitCast: { 629 OperationState state(loc, opcMap.lookup(inst->getOpcode())); 630 SmallVector<Value, 4> ops; 631 ops.reserve(inst->getNumOperands()); 632 for (auto *op : inst->operand_values()) { 633 Value value = processValue(op); 634 if (!value) 635 return failure(); 636 ops.push_back(value); 637 } 638 state.addOperands(ops); 639 if (!inst->getType()->isVoidTy()) { 640 LLVMType type = processType(inst->getType()); 641 if (!type) 642 return failure(); 643 state.addTypes(type); 644 } 645 Operation *op = b.createOperation(state); 646 if (!inst->getType()->isVoidTy()) 647 v = op->getResult(0); 648 return success(); 649 } 650 case llvm::Instruction::ICmp: { 651 Value lhs = processValue(inst->getOperand(0)); 652 Value rhs = processValue(inst->getOperand(1)); 653 if (!lhs || !rhs) 654 return failure(); 655 v = b.create<ICmpOp>( 656 loc, getICmpPredicate(cast<llvm::ICmpInst>(inst)->getPredicate()), lhs, 657 rhs); 658 return success(); 659 } 660 case llvm::Instruction::Br: { 661 auto *brInst = cast<llvm::BranchInst>(inst); 662 OperationState state(loc, 663 brInst->isConditional() ? "llvm.cond_br" : "llvm.br"); 664 SmallVector<Value, 4> ops; 665 if (brInst->isConditional()) { 666 Value condition = processValue(brInst->getCondition()); 667 if (!condition) 668 return failure(); 669 ops.push_back(condition); 670 } 671 state.addOperands(ops); 672 SmallVector<Block *, 4> succs; 673 for (auto *succ : llvm::reverse(brInst->successors())) { 674 SmallVector<Value, 4> blockArguments; 675 if (failed(processBranchArgs(brInst, succ, blockArguments))) 676 return failure(); 677 state.addSuccessor(blocks[succ], blockArguments); 678 } 679 b.createOperation(state); 680 return success(); 681 } 682 case llvm::Instruction::PHI: { 683 LLVMType type = processType(inst->getType()); 684 if (!type) 685 return failure(); 686 v = b.getInsertionBlock()->addArgument(type); 687 return success(); 688 } 689 case llvm::Instruction::Call: { 690 llvm::CallInst *ci = cast<llvm::CallInst>(inst); 691 SmallVector<Value, 4> ops; 692 ops.reserve(inst->getNumOperands()); 693 for (auto &op : ci->arg_operands()) { 694 Value arg = processValue(op.get()); 695 if (!arg) 696 return failure(); 697 ops.push_back(arg); 698 } 699 700 SmallVector<Type, 2> tys; 701 if (!ci->getType()->isVoidTy()) { 702 LLVMType type = processType(inst->getType()); 703 if (!type) 704 return failure(); 705 tys.push_back(type); 706 } 707 Operation *op; 708 if (llvm::Function *callee = ci->getCalledFunction()) { 709 op = b.create<CallOp>(loc, tys, b.getSymbolRefAttr(callee->getName()), 710 ops); 711 } else { 712 Value calledValue = processValue(ci->getCalledValue()); 713 if (!calledValue) 714 return failure(); 715 ops.insert(ops.begin(), calledValue); 716 op = b.create<CallOp>(loc, tys, ops, ArrayRef<NamedAttribute>()); 717 } 718 if (!ci->getType()->isVoidTy()) 719 v = op->getResult(0); 720 return success(); 721 } 722 case llvm::Instruction::GetElementPtr: { 723 // FIXME: Support inbounds GEPs. 724 llvm::GetElementPtrInst *gep = cast<llvm::GetElementPtrInst>(inst); 725 SmallVector<Value, 4> ops; 726 for (auto *op : gep->operand_values()) { 727 Value value = processValue(op); 728 if (!value) 729 return failure(); 730 ops.push_back(value); 731 } 732 Type type = processType(inst->getType()); 733 if (!type) 734 return failure(); 735 v = b.create<GEPOp>(loc, type, ops, ArrayRef<NamedAttribute>()); 736 return success(); 737 } 738 } 739 } 740 741 LogicalResult Importer::processFunction(llvm::Function *f) { 742 blocks.clear(); 743 instMap.clear(); 744 unknownInstMap.clear(); 745 746 LLVMType functionType = processType(f->getFunctionType()); 747 if (!functionType) 748 return failure(); 749 750 b.setInsertionPoint(module.getBody(), getFuncInsertPt()); 751 LLVMFuncOp fop = b.create<LLVMFuncOp>(UnknownLoc::get(context), f->getName(), 752 functionType); 753 if (f->isDeclaration()) 754 return success(); 755 756 // Eagerly create all blocks. 757 SmallVector<Block *, 4> blockList; 758 for (llvm::BasicBlock &bb : *f) { 759 blockList.push_back(b.createBlock(&fop.body(), fop.body().end())); 760 blocks[&bb] = blockList.back(); 761 } 762 currentEntryBlock = blockList[0]; 763 764 // Add function arguments to the entry block. 765 for (auto kv : llvm::enumerate(f->args())) 766 instMap[&kv.value()] = blockList[0]->addArgument( 767 functionType.getFunctionParamType(kv.index())); 768 769 for (auto bbs : llvm::zip(*f, blockList)) { 770 if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs)))) 771 return failure(); 772 } 773 774 // Now that all instructions are guaranteed to have been visited, ensure 775 // any unknown uses we encountered are remapped. 776 for (auto &llvmAndUnknown : unknownInstMap) { 777 assert(instMap.count(llvmAndUnknown.first)); 778 Value newValue = instMap[llvmAndUnknown.first]; 779 Value oldValue = llvmAndUnknown.second->getResult(0); 780 oldValue.replaceAllUsesWith(newValue); 781 llvmAndUnknown.second->erase(); 782 } 783 return success(); 784 } 785 786 LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) { 787 b.setInsertionPointToStart(block); 788 for (llvm::Instruction &inst : *bb) { 789 if (failed(processInstruction(&inst))) 790 return failure(); 791 } 792 return success(); 793 } 794 795 OwningModuleRef 796 mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule, 797 MLIRContext *context) { 798 OwningModuleRef module(ModuleOp::create( 799 FileLineColLoc::get("", /*line=*/0, /*column=*/0, context))); 800 801 Importer deserializer(context, module.get()); 802 for (llvm::GlobalVariable &gv : llvmModule->globals()) { 803 if (!deserializer.processGlobal(&gv)) 804 return {}; 805 } 806 for (llvm::Function &f : llvmModule->functions()) { 807 if (failed(deserializer.processFunction(&f))) 808 return {}; 809 } 810 811 return module; 812 } 813 814 // Deserializes the LLVM bitcode stored in `input` into an MLIR module in the 815 // LLVM dialect. 816 OwningModuleRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr, 817 MLIRContext *context) { 818 LLVMDialect *dialect = context->getRegisteredDialect<LLVMDialect>(); 819 assert(dialect && "Could not find LLVMDialect?"); 820 821 llvm::SMDiagnostic err; 822 std::unique_ptr<llvm::Module> llvmModule = 823 llvm::parseIR(*sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), err, 824 dialect->getLLVMContext(), 825 /*UpgradeDebugInfo=*/true, 826 /*DataLayoutString=*/""); 827 if (!llvmModule) { 828 std::string errStr; 829 llvm::raw_string_ostream errStream(errStr); 830 err.print(/*ProgName=*/"", errStream); 831 emitError(UnknownLoc::get(context)) << errStream.str(); 832 return {}; 833 } 834 return translateLLVMIRToModule(std::move(llvmModule), context); 835 } 836 837 static TranslateToMLIRRegistration 838 fromLLVM("import-llvm", 839 [](llvm::SourceMgr &sourceMgr, MLIRContext *context) { 840 return translateLLVMIRToModule(sourceMgr, context); 841 }); 842