//===- ConvertFromLLVMIR.cpp - LLVM IR to MLIR conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a translation from LLVM IR to the MLIR LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/MLIRContext.h"
16 #include "mlir/IR/Module.h"
17 #include "mlir/IR/StandardTypes.h"
18 #include "mlir/Target/LLVMIR.h"
19 #include "mlir/Translation.h"
20 
21 #include "llvm/IR/Attributes.h"
22 #include "llvm/IR/Constants.h"
23 #include "llvm/IR/Function.h"
24 #include "llvm/IR/Instructions.h"
25 #include "llvm/IR/Type.h"
26 #include "llvm/IRReader/IRReader.h"
27 #include "llvm/Support/Error.h"
28 #include "llvm/Support/SourceMgr.h"
29 
30 using namespace mlir;
31 using namespace mlir::LLVM;
32 
33 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
34 
35 // Utility to print an LLVM value as a string for passing to emitError().
36 // FIXME: Diagnostic should be able to natively handle types that have
37 // operator << (raw_ostream&) defined.
38 static std::string diag(llvm::Value &v) {
39   std::string s;
40   llvm::raw_string_ostream os(s);
41   os << v;
42   return os.str();
43 }
44 
45 // Handles importing globals and functions from an LLVM module.
46 namespace {
47 class Importer {
48 public:
49   Importer(MLIRContext *context, ModuleOp module)
50       : b(context), context(context), module(module),
51         unknownLoc(FileLineColLoc::get("imported-bitcode", 0, 0, context)) {
52     b.setInsertionPointToStart(module.getBody());
53     dialect = context->getRegisteredDialect<LLVMDialect>();
54   }
55 
56   /// Imports `f` into the current module.
57   LogicalResult processFunction(llvm::Function *f);
58 
59   /// Imports GV as a GlobalOp, creating it if it doesn't exist.
60   GlobalOp processGlobal(llvm::GlobalVariable *GV);
61 
62 private:
63   /// Returns personality of `f` as a FlatSymbolRefAttr.
64   FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *f);
65   /// Imports `bb` into `block`, which must be initially empty.
66   LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block);
67   /// Imports `inst` and populates instMap[inst] with the imported Value.
68   LogicalResult processInstruction(llvm::Instruction *inst);
69   /// Creates an LLVMType for `type`.
70   LLVMType processType(llvm::Type *type);
71   /// `value` is an SSA-use. Return the remapped version of `value` or a
72   /// placeholder that will be remapped later if this is an instruction that
73   /// has not yet been visited.
74   Value processValue(llvm::Value *value);
75   /// Create the most accurate Location possible using a llvm::DebugLoc and
76   /// possibly an llvm::Instruction to narrow the Location if debug information
77   /// is unavailable.
78   Location processDebugLoc(const llvm::DebugLoc &loc,
79                            llvm::Instruction *inst = nullptr);
80   /// `br` branches to `target`. Append the block arguments to attach to the
81   /// generated branch op to `blockArguments`. These should be in the same order
82   /// as the PHIs in `target`.
83   LogicalResult processBranchArgs(llvm::Instruction *br,
84                                   llvm::BasicBlock *target,
85                                   SmallVectorImpl<Value> &blockArguments);
86   /// Returns the standard type equivalent to be used in attributes for the
87   /// given LLVM IR dialect type.
88   Type getStdTypeForAttr(LLVMType type);
89   /// Return `value` as an attribute to attach to a GlobalOp.
90   Attribute getConstantAsAttr(llvm::Constant *value);
91   /// Return `c` as an MLIR Value. This could either be a ConstantOp, or
92   /// an expanded sequence of ops in the current function's entry block (for
93   /// ConstantExprs or ConstantGEPs).
94   Value processConstant(llvm::Constant *c);
95 
96   /// The current builder, pointing at where the next Instruction should be
97   /// generated.
98   OpBuilder b;
99   /// The current context.
100   MLIRContext *context;
101   /// The current module being created.
102   ModuleOp module;
103   /// The entry block of the current function being processed.
104   Block *currentEntryBlock;
105 
106   /// Globals are inserted before the first function, if any.
107   Block::iterator getGlobalInsertPt() {
108     auto i = module.getBody()->begin();
109     while (!isa<LLVMFuncOp>(i) && !isa<ModuleTerminatorOp>(i))
110       ++i;
111     return i;
112   }
113 
114   /// Functions are always inserted before the module terminator.
115   Block::iterator getFuncInsertPt() {
116     return std::prev(module.getBody()->end());
117   }
118 
119   /// Remapped blocks, for the current function.
120   DenseMap<llvm::BasicBlock *, Block *> blocks;
121   /// Remapped values. These are function-local.
122   DenseMap<llvm::Value *, Value> instMap;
123   /// Instructions that had not been defined when first encountered as a use.
124   /// Maps to the dummy Operation that was created in processValue().
125   DenseMap<llvm::Value *, Operation *> unknownInstMap;
126   /// Uniquing map of GlobalVariables.
127   DenseMap<llvm::GlobalVariable *, GlobalOp> globals;
128   /// Cached FileLineColLoc::get("imported-bitcode", 0, 0).
129   Location unknownLoc;
130   /// Cached dialect.
131   LLVMDialect *dialect;
132 };
133 } // namespace
134 
135 Location Importer::processDebugLoc(const llvm::DebugLoc &loc,
136                                    llvm::Instruction *inst) {
137   if (!loc && inst) {
138     std::string s;
139     llvm::raw_string_ostream os(s);
140     os << "llvm-imported-inst-%";
141     inst->printAsOperand(os, /*PrintType=*/false);
142     return FileLineColLoc::get(os.str(), 0, 0, context);
143   } else if (!loc) {
144     return unknownLoc;
145   }
146   // FIXME: Obtain the filename from DILocationInfo.
147   return FileLineColLoc::get("imported-bitcode", loc.getLine(), loc.getCol(),
148                              context);
149 }
150 
151 LLVMType Importer::processType(llvm::Type *type) {
152   switch (type->getTypeID()) {
153   case llvm::Type::FloatTyID:
154     return LLVMType::getFloatTy(dialect);
155   case llvm::Type::DoubleTyID:
156     return LLVMType::getDoubleTy(dialect);
157   case llvm::Type::IntegerTyID:
158     return LLVMType::getIntNTy(dialect, type->getIntegerBitWidth());
159   case llvm::Type::PointerTyID: {
160     LLVMType elementType = processType(type->getPointerElementType());
161     if (!elementType)
162       return nullptr;
163     return elementType.getPointerTo(type->getPointerAddressSpace());
164   }
165   case llvm::Type::ArrayTyID: {
166     LLVMType elementType = processType(type->getArrayElementType());
167     if (!elementType)
168       return nullptr;
169     return LLVMType::getArrayTy(elementType, type->getArrayNumElements());
170   }
171   case llvm::Type::VectorTyID: {
172     if (type->getVectorIsScalable()) {
173       emitError(unknownLoc) << "scalable vector types not supported";
174       return nullptr;
175     }
176     LLVMType elementType = processType(type->getVectorElementType());
177     if (!elementType)
178       return nullptr;
179     return LLVMType::getVectorTy(elementType, type->getVectorNumElements());
180   }
181   case llvm::Type::VoidTyID:
182     return LLVMType::getVoidTy(dialect);
183   case llvm::Type::FP128TyID:
184     return LLVMType::getFP128Ty(dialect);
185   case llvm::Type::X86_FP80TyID:
186     return LLVMType::getX86_FP80Ty(dialect);
187   case llvm::Type::StructTyID: {
188     SmallVector<LLVMType, 4> elementTypes;
189     elementTypes.reserve(type->getStructNumElements());
190     for (unsigned i = 0, e = type->getStructNumElements(); i != e; ++i) {
191       LLVMType ty = processType(type->getStructElementType(i));
192       if (!ty)
193         return nullptr;
194       elementTypes.push_back(ty);
195     }
196     return LLVMType::getStructTy(dialect, elementTypes,
197                                  cast<llvm::StructType>(type)->isPacked());
198   }
199   case llvm::Type::FunctionTyID: {
200     llvm::FunctionType *fty = cast<llvm::FunctionType>(type);
201     SmallVector<LLVMType, 4> paramTypes;
202     for (unsigned i = 0, e = fty->getNumParams(); i != e; ++i) {
203       LLVMType ty = processType(fty->getParamType(i));
204       if (!ty)
205         return nullptr;
206       paramTypes.push_back(ty);
207     }
208     LLVMType result = processType(fty->getReturnType());
209     if (!result)
210       return nullptr;
211 
212     return LLVMType::getFunctionTy(result, paramTypes, fty->isVarArg());
213   }
214   default: {
215     // FIXME: Diagnostic should be able to natively handle types that have
216     // operator<<(raw_ostream&) defined.
217     std::string s;
218     llvm::raw_string_ostream os(s);
219     os << *type;
220     emitError(unknownLoc) << "unhandled type: " << os.str();
221     return nullptr;
222   }
223   }
224 }
225 
226 // We only need integers, floats, doubles, and vectors and tensors thereof for
227 // attributes. Scalar and vector types are converted to the standard
228 // equivalents. Array types are converted to ranked tensors; nested array types
229 // are converted to multi-dimensional tensors or vectors, depending on the
230 // innermost type being a scalar or a vector.
231 Type Importer::getStdTypeForAttr(LLVMType type) {
232   if (!type)
233     return nullptr;
234 
235   if (type.isIntegerTy())
236     return b.getIntegerType(type.getUnderlyingType()->getIntegerBitWidth());
237 
238   if (type.getUnderlyingType()->isFloatTy())
239     return b.getF32Type();
240 
241   if (type.getUnderlyingType()->isDoubleTy())
242     return b.getF64Type();
243 
244   // LLVM vectors can only contain scalars.
245   if (type.isVectorTy()) {
246     auto numElements = type.getUnderlyingType()->getVectorElementCount();
247     if (numElements.Scalable) {
248       emitError(unknownLoc) << "scalable vectors not supported";
249       return nullptr;
250     }
251     Type elementType = getStdTypeForAttr(type.getVectorElementType());
252     if (!elementType)
253       return nullptr;
254     return VectorType::get(numElements.Min, elementType);
255   }
256 
257   // LLVM arrays can contain other arrays or vectors.
258   if (type.isArrayTy()) {
259     // Recover the nested array shape.
260     SmallVector<int64_t, 4> shape;
261     shape.push_back(type.getArrayNumElements());
262     while (type.getArrayElementType().isArrayTy()) {
263       type = type.getArrayElementType();
264       shape.push_back(type.getArrayNumElements());
265     }
266 
267     // If the innermost type is a vector, use the multi-dimensional vector as
268     // attribute type.
269     if (type.getArrayElementType().isVectorTy()) {
270       LLVMType vectorType = type.getArrayElementType();
271       auto numElements =
272           vectorType.getUnderlyingType()->getVectorElementCount();
273       if (numElements.Scalable) {
274         emitError(unknownLoc) << "scalable vectors not supported";
275         return nullptr;
276       }
277       shape.push_back(numElements.Min);
278 
279       Type elementType = getStdTypeForAttr(vectorType.getVectorElementType());
280       if (!elementType)
281         return nullptr;
282       return VectorType::get(shape, elementType);
283     }
284 
285     // Otherwise use a tensor.
286     Type elementType = getStdTypeForAttr(type.getArrayElementType());
287     if (!elementType)
288       return nullptr;
289     return RankedTensorType::get(shape, elementType);
290   }
291 
292   return nullptr;
293 }
294 
295 // Get the given constant as an attribute. Not all constants can be represented
296 // as attributes.
297 Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
298   if (auto *ci = dyn_cast<llvm::ConstantInt>(value))
299     return b.getIntegerAttr(
300         IntegerType::get(ci->getType()->getBitWidth(), context),
301         ci->getValue());
302   if (auto *c = dyn_cast<llvm::ConstantDataArray>(value))
303     if (c->isString())
304       return b.getStringAttr(c->getAsString());
305   if (auto *c = dyn_cast<llvm::ConstantFP>(value)) {
306     if (c->getType()->isDoubleTy())
307       return b.getFloatAttr(FloatType::getF64(context), c->getValueAPF());
308     else if (c->getType()->isFloatingPointTy())
309       return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF());
310   }
311   if (auto *f = dyn_cast<llvm::Function>(value))
312     return b.getSymbolRefAttr(f->getName());
313 
314   // Convert constant data to a dense elements attribute.
315   if (auto *cd = dyn_cast<llvm::ConstantDataSequential>(value)) {
316     LLVMType type = processType(cd->getElementType());
317     if (!type)
318       return nullptr;
319 
320     auto attrType = getStdTypeForAttr(processType(cd->getType()))
321                         .dyn_cast_or_null<ShapedType>();
322     if (!attrType)
323       return nullptr;
324 
325     if (type.isIntegerTy()) {
326       SmallVector<APInt, 8> values;
327       values.reserve(cd->getNumElements());
328       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
329         values.push_back(cd->getElementAsAPInt(i));
330       return DenseElementsAttr::get(attrType, values);
331     }
332 
333     if (type.isFloatTy() || type.isDoubleTy()) {
334       SmallVector<APFloat, 8> values;
335       values.reserve(cd->getNumElements());
336       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
337         values.push_back(cd->getElementAsAPFloat(i));
338       return DenseElementsAttr::get(attrType, values);
339     }
340 
341     return nullptr;
342   }
343 
344   // Unpack constant aggregates to create dense elements attribute whenever
345   // possible. Return nullptr (failure) otherwise.
346   if (isa<llvm::ConstantAggregate>(value)) {
347     auto outerType = getStdTypeForAttr(processType(value->getType()))
348                          .dyn_cast_or_null<ShapedType>();
349     if (!outerType)
350       return nullptr;
351 
352     SmallVector<Attribute, 8> values;
353     SmallVector<int64_t, 8> shape;
354 
355     for (unsigned i = 0, e = value->getNumOperands(); i < e; ++i) {
356       auto nested = getConstantAsAttr(value->getAggregateElement(i))
357                         .dyn_cast_or_null<DenseElementsAttr>();
358       if (!nested)
359         return nullptr;
360 
361       values.append(nested.attr_value_begin(), nested.attr_value_end());
362     }
363 
364     return DenseElementsAttr::get(outerType, values);
365   }
366 
367   return nullptr;
368 }
369 
370 GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
371   auto it = globals.find(GV);
372   if (it != globals.end())
373     return it->second;
374 
375   OpBuilder b(module.getBody(), getGlobalInsertPt());
376   Attribute valueAttr;
377   if (GV->hasInitializer())
378     valueAttr = getConstantAsAttr(GV->getInitializer());
379   LLVMType type = processType(GV->getValueType());
380   if (!type)
381     return nullptr;
382   GlobalOp op = b.create<GlobalOp>(
383       UnknownLoc::get(context), type, GV->isConstant(),
384       convertLinkageFromLLVM(GV->getLinkage()), GV->getName(), valueAttr);
385   if (GV->hasInitializer() && !valueAttr) {
386     Region &r = op.getInitializerRegion();
387     currentEntryBlock = b.createBlock(&r);
388     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
389     Value v = processConstant(GV->getInitializer());
390     if (!v)
391       return nullptr;
392     b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v}));
393   }
394   return globals[GV] = op;
395 }
396 
397 Value Importer::processConstant(llvm::Constant *c) {
398   OpBuilder bEntry(currentEntryBlock, currentEntryBlock->begin());
399   if (Attribute attr = getConstantAsAttr(c)) {
400     // These constants can be represented as attributes.
401     OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
402     LLVMType type = processType(c->getType());
403     if (!type)
404       return nullptr;
405     return instMap[c] = bEntry.create<ConstantOp>(unknownLoc, type, attr);
406   }
407   if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) {
408     LLVMType type = processType(cn->getType());
409     if (!type)
410       return nullptr;
411     return instMap[c] = bEntry.create<NullOp>(unknownLoc, type);
412   }
413   if (auto *GV = dyn_cast<llvm::GlobalVariable>(c))
414     return bEntry.create<AddressOfOp>(UnknownLoc::get(context),
415                                       processGlobal(GV),
416                                       ArrayRef<NamedAttribute>());
417 
418   if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) {
419     llvm::Instruction *i = ce->getAsInstruction();
420     OpBuilder::InsertionGuard guard(b);
421     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
422     if (failed(processInstruction(i)))
423       return nullptr;
424     assert(instMap.count(i));
425 
426     // Remove this zombie LLVM instruction now, leaving us only with the MLIR
427     // op.
428     i->deleteValue();
429     return instMap[c] = instMap[i];
430   }
431   if (auto *ue = dyn_cast<llvm::UndefValue>(c)) {
432     LLVMType type = processType(ue->getType());
433     if (!type)
434       return nullptr;
435     return instMap[c] = bEntry.create<UndefOp>(UnknownLoc::get(context), type);
436   }
437   emitError(unknownLoc) << "unhandled constant: " << diag(*c);
438   return nullptr;
439 }
440 
441 Value Importer::processValue(llvm::Value *value) {
442   auto it = instMap.find(value);
443   if (it != instMap.end())
444     return it->second;
445 
446   // We don't expect to see instructions in dominator order. If we haven't seen
447   // this instruction yet, create an unknown op and remap it later.
448   if (isa<llvm::Instruction>(value)) {
449     OperationState state(UnknownLoc::get(context), "unknown");
450     LLVMType type = processType(value->getType());
451     if (!type)
452       return nullptr;
453     state.addTypes(type);
454     unknownInstMap[value] = b.createOperation(state);
455     return unknownInstMap[value]->getResult(0);
456   }
457 
458   if (auto *c = dyn_cast<llvm::Constant>(value))
459     return processConstant(c);
460 
461   emitError(unknownLoc) << "unhandled value: " << diag(*value);
462   return nullptr;
463 }
464 
465 // Maps from LLVM opcode to MLIR OperationName. This is deliberately ordered
466 // as in llvm/IR/Instructions.def to aid comprehension and spot missing
467 // instructions.
468 #define INST(llvm_n, mlir_n)                                                   \
469   { llvm::Instruction::llvm_n, LLVM::mlir_n##Op::getOperationName() }
470 static const DenseMap<unsigned, StringRef> opcMap = {
471     // Ret is handled specially.
472     // Br is handled specially.
473     // FIXME: switch
474     // FIXME: indirectbr
475     // FIXME: invoke
476     INST(Resume, Resume),
477     // FIXME: unreachable
478     // FIXME: cleanupret
479     // FIXME: catchret
480     // FIXME: catchswitch
481     // FIXME: callbr
482     // FIXME: fneg
483     INST(Add, Add), INST(FAdd, FAdd), INST(Sub, Sub), INST(FSub, FSub),
484     INST(Mul, Mul), INST(FMul, FMul), INST(UDiv, UDiv), INST(SDiv, SDiv),
485     INST(FDiv, FDiv), INST(URem, URem), INST(SRem, SRem), INST(FRem, FRem),
486     INST(Shl, Shl), INST(LShr, LShr), INST(AShr, AShr), INST(And, And),
487     INST(Or, Or), INST(Xor, XOr), INST(Alloca, Alloca), INST(Load, Load),
488     INST(Store, Store),
489     // Getelementptr is handled specially.
490     INST(Ret, Return), INST(Fence, Fence),
491     // FIXME: atomiccmpxchg
492     // FIXME: atomicrmw
493     INST(Trunc, Trunc), INST(ZExt, ZExt), INST(SExt, SExt),
494     INST(FPToUI, FPToUI), INST(FPToSI, FPToSI), INST(UIToFP, UIToFP),
495     INST(SIToFP, SIToFP), INST(FPTrunc, FPTrunc), INST(FPExt, FPExt),
496     INST(PtrToInt, PtrToInt), INST(IntToPtr, IntToPtr), INST(BitCast, Bitcast),
497     INST(AddrSpaceCast, AddrSpaceCast),
498     // FIXME: cleanuppad
499     // FIXME: catchpad
500     // ICmp is handled specially.
501     // FIXME: fcmp
502     // PHI is handled specially.
503     INST(Freeze, Freeze), INST(Call, Call),
504     // FIXME: select
505     // FIXME: vaarg
506     // FIXME: extractelement
507     // FIXME: insertelement
508     // FIXME: shufflevector
509     // FIXME: extractvalue
510     // FIXME: insertvalue
511     // FIXME: landingpad
512 };
513 #undef INST
514 
515 static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) {
516   switch (p) {
517   default:
518     llvm_unreachable("incorrect comparison predicate");
519   case llvm::CmpInst::Predicate::ICMP_EQ:
520     return LLVM::ICmpPredicate::eq;
521   case llvm::CmpInst::Predicate::ICMP_NE:
522     return LLVM::ICmpPredicate::ne;
523   case llvm::CmpInst::Predicate::ICMP_SLT:
524     return LLVM::ICmpPredicate::slt;
525   case llvm::CmpInst::Predicate::ICMP_SLE:
526     return LLVM::ICmpPredicate::sle;
527   case llvm::CmpInst::Predicate::ICMP_SGT:
528     return LLVM::ICmpPredicate::sgt;
529   case llvm::CmpInst::Predicate::ICMP_SGE:
530     return LLVM::ICmpPredicate::sge;
531   case llvm::CmpInst::Predicate::ICMP_ULT:
532     return LLVM::ICmpPredicate::ult;
533   case llvm::CmpInst::Predicate::ICMP_ULE:
534     return LLVM::ICmpPredicate::ule;
535   case llvm::CmpInst::Predicate::ICMP_UGT:
536     return LLVM::ICmpPredicate::ugt;
537   case llvm::CmpInst::Predicate::ICMP_UGE:
538     return LLVM::ICmpPredicate::uge;
539   }
540   llvm_unreachable("incorrect comparison predicate");
541 }
542 
543 static AtomicOrdering getLLVMAtomicOrdering(llvm::AtomicOrdering ordering) {
544   switch (ordering) {
545   case llvm::AtomicOrdering::NotAtomic:
546     return LLVM::AtomicOrdering::not_atomic;
547   case llvm::AtomicOrdering::Unordered:
548     return LLVM::AtomicOrdering::unordered;
549   case llvm::AtomicOrdering::Monotonic:
550     return LLVM::AtomicOrdering::monotonic;
551   case llvm::AtomicOrdering::Acquire:
552     return LLVM::AtomicOrdering::acquire;
553   case llvm::AtomicOrdering::Release:
554     return LLVM::AtomicOrdering::release;
555   case llvm::AtomicOrdering::AcquireRelease:
556     return LLVM::AtomicOrdering::acq_rel;
557   case llvm::AtomicOrdering::SequentiallyConsistent:
558     return LLVM::AtomicOrdering::seq_cst;
559   }
560   llvm_unreachable("incorrect atomic ordering");
561 }
562 
563 // `br` branches to `target`. Return the branch arguments to `br`, in the
564 // same order of the PHIs in `target`.
565 LogicalResult
566 Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target,
567                             SmallVectorImpl<Value> &blockArguments) {
568   for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) {
569     auto *PN = cast<llvm::PHINode>(&*inst);
570     Value value = processValue(PN->getIncomingValueForBlock(br->getParent()));
571     if (!value)
572       return failure();
573     blockArguments.push_back(value);
574   }
575   return success();
576 }
577 
578 LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
579   // FIXME: Support uses of SubtargetData. Currently inbounds GEPs, fast-math
580   // flags and call / operand attributes are not supported.
581   Location loc = processDebugLoc(inst->getDebugLoc(), inst);
582   Value &v = instMap[inst];
583   assert(!v && "processInstruction must be called only once per instruction!");
584   switch (inst->getOpcode()) {
585   default:
586     return emitError(loc) << "unknown instruction: " << diag(*inst);
587   case llvm::Instruction::Add:
588   case llvm::Instruction::FAdd:
589   case llvm::Instruction::Sub:
590   case llvm::Instruction::FSub:
591   case llvm::Instruction::Mul:
592   case llvm::Instruction::FMul:
593   case llvm::Instruction::UDiv:
594   case llvm::Instruction::SDiv:
595   case llvm::Instruction::FDiv:
596   case llvm::Instruction::URem:
597   case llvm::Instruction::SRem:
598   case llvm::Instruction::FRem:
599   case llvm::Instruction::Shl:
600   case llvm::Instruction::LShr:
601   case llvm::Instruction::AShr:
602   case llvm::Instruction::And:
603   case llvm::Instruction::Or:
604   case llvm::Instruction::Xor:
605   case llvm::Instruction::Alloca:
606   case llvm::Instruction::Load:
607   case llvm::Instruction::Store:
608   case llvm::Instruction::Ret:
609   case llvm::Instruction::Resume:
610   case llvm::Instruction::Trunc:
611   case llvm::Instruction::ZExt:
612   case llvm::Instruction::SExt:
613   case llvm::Instruction::FPToUI:
614   case llvm::Instruction::FPToSI:
615   case llvm::Instruction::UIToFP:
616   case llvm::Instruction::SIToFP:
617   case llvm::Instruction::FPTrunc:
618   case llvm::Instruction::FPExt:
619   case llvm::Instruction::PtrToInt:
620   case llvm::Instruction::IntToPtr:
621   case llvm::Instruction::AddrSpaceCast:
622   case llvm::Instruction::Freeze:
623   case llvm::Instruction::BitCast: {
624     OperationState state(loc, opcMap.lookup(inst->getOpcode()));
625     SmallVector<Value, 4> ops;
626     ops.reserve(inst->getNumOperands());
627     for (auto *op : inst->operand_values()) {
628       Value value = processValue(op);
629       if (!value)
630         return failure();
631       ops.push_back(value);
632     }
633     state.addOperands(ops);
634     if (!inst->getType()->isVoidTy()) {
635       LLVMType type = processType(inst->getType());
636       if (!type)
637         return failure();
638       state.addTypes(type);
639     }
640     Operation *op = b.createOperation(state);
641     if (!inst->getType()->isVoidTy())
642       v = op->getResult(0);
643     return success();
644   }
645   case llvm::Instruction::ICmp: {
646     Value lhs = processValue(inst->getOperand(0));
647     Value rhs = processValue(inst->getOperand(1));
648     if (!lhs || !rhs)
649       return failure();
650     v = b.create<ICmpOp>(
651         loc, getICmpPredicate(cast<llvm::ICmpInst>(inst)->getPredicate()), lhs,
652         rhs);
653     return success();
654   }
655   case llvm::Instruction::Br: {
656     auto *brInst = cast<llvm::BranchInst>(inst);
657     OperationState state(loc,
658                          brInst->isConditional() ? "llvm.cond_br" : "llvm.br");
659     if (brInst->isConditional()) {
660       Value condition = processValue(brInst->getCondition());
661       if (!condition)
662         return failure();
663       state.addOperands(condition);
664     }
665 
666     std::array<int32_t, 3> operandSegmentSizes = {1, 0, 0};
667     for (int i : llvm::seq<int>(0, brInst->getNumSuccessors())) {
668       auto *succ = brInst->getSuccessor(i);
669       SmallVector<Value, 4> blockArguments;
670       if (failed(processBranchArgs(brInst, succ, blockArguments)))
671         return failure();
672       state.addSuccessors(blocks[succ]);
673       state.addOperands(blockArguments);
674       operandSegmentSizes[i + 1] = blockArguments.size();
675     }
676 
677     if (brInst->isConditional()) {
678       state.addAttribute(LLVM::CondBrOp::getOperandSegmentSizeAttr(),
679                          b.getI32VectorAttr(operandSegmentSizes));
680     }
681 
682     b.createOperation(state);
683     return success();
684   }
685   case llvm::Instruction::PHI: {
686     LLVMType type = processType(inst->getType());
687     if (!type)
688       return failure();
689     v = b.getInsertionBlock()->addArgument(type);
690     return success();
691   }
692   case llvm::Instruction::Call: {
693     llvm::CallInst *ci = cast<llvm::CallInst>(inst);
694     SmallVector<Value, 4> ops;
695     ops.reserve(inst->getNumOperands());
696     for (auto &op : ci->arg_operands()) {
697       Value arg = processValue(op.get());
698       if (!arg)
699         return failure();
700       ops.push_back(arg);
701     }
702 
703     SmallVector<Type, 2> tys;
704     if (!ci->getType()->isVoidTy()) {
705       LLVMType type = processType(inst->getType());
706       if (!type)
707         return failure();
708       tys.push_back(type);
709     }
710     Operation *op;
711     if (llvm::Function *callee = ci->getCalledFunction()) {
712       op = b.create<CallOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
713                             ops);
714     } else {
715       Value calledValue = processValue(ci->getCalledValue());
716       if (!calledValue)
717         return failure();
718       ops.insert(ops.begin(), calledValue);
719       op = b.create<CallOp>(loc, tys, ops, ArrayRef<NamedAttribute>());
720     }
721     if (!ci->getType()->isVoidTy())
722       v = op->getResult(0);
723     return success();
724   }
725   case llvm::Instruction::LandingPad: {
726     llvm::LandingPadInst *lpi = cast<llvm::LandingPadInst>(inst);
727     SmallVector<Value, 4> ops;
728 
729     for (unsigned i = 0, ie = lpi->getNumClauses(); i < ie; i++)
730       ops.push_back(processConstant(lpi->getClause(i)));
731 
732     Type ty = processType(lpi->getType());
733     if (!ty)
734       return failure();
735 
736     v = b.create<LandingpadOp>(loc, ty, lpi->isCleanup(), ops);
737     return success();
738   }
739   case llvm::Instruction::Invoke: {
740     llvm::InvokeInst *ii = cast<llvm::InvokeInst>(inst);
741 
742     SmallVector<Type, 2> tys;
743     if (!ii->getType()->isVoidTy())
744       tys.push_back(processType(inst->getType()));
745 
746     SmallVector<Value, 4> ops;
747     ops.reserve(inst->getNumOperands() + 1);
748     for (auto &op : ii->arg_operands())
749       ops.push_back(processValue(op.get()));
750 
751     SmallVector<Value, 4> normalArgs, unwindArgs;
752     processBranchArgs(ii, ii->getNormalDest(), normalArgs);
753     processBranchArgs(ii, ii->getUnwindDest(), unwindArgs);
754 
755     Operation *op;
756     if (llvm::Function *callee = ii->getCalledFunction()) {
757       op = b.create<InvokeOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
758                               ops, blocks[ii->getNormalDest()], normalArgs,
759                               blocks[ii->getUnwindDest()], unwindArgs);
760     } else {
761       ops.insert(ops.begin(), processValue(ii->getCalledValue()));
762       op = b.create<InvokeOp>(loc, tys, ops, blocks[ii->getNormalDest()],
763                               normalArgs, blocks[ii->getUnwindDest()],
764                               unwindArgs);
765     }
766 
767     if (!ii->getType()->isVoidTy())
768       v = op->getResult(0);
769     return success();
770   }
771   case llvm::Instruction::Fence: {
772     StringRef syncscope;
773     SmallVector<StringRef, 4> ssNs;
774     llvm::LLVMContext &llvmContext = dialect->getLLVMContext();
775     llvm::FenceInst *fence = cast<llvm::FenceInst>(inst);
776     llvmContext.getSyncScopeNames(ssNs);
777     int fenceSyncScopeID = fence->getSyncScopeID();
778     for (unsigned i = 0, e = ssNs.size(); i != e; i++) {
779       if (fenceSyncScopeID == llvmContext.getOrInsertSyncScopeID(ssNs[i])) {
780         syncscope = ssNs[i];
781         break;
782       }
783     }
784     b.create<FenceOp>(loc, getLLVMAtomicOrdering(fence->getOrdering()),
785                       syncscope);
786     return success();
787   }
788   case llvm::Instruction::GetElementPtr: {
789     // FIXME: Support inbounds GEPs.
790     llvm::GetElementPtrInst *gep = cast<llvm::GetElementPtrInst>(inst);
791     SmallVector<Value, 4> ops;
792     for (auto *op : gep->operand_values()) {
793       Value value = processValue(op);
794       if (!value)
795         return failure();
796       ops.push_back(value);
797     }
798     Type type = processType(inst->getType());
799     if (!type)
800       return failure();
801     v = b.create<GEPOp>(loc, type, ops, ArrayRef<NamedAttribute>());
802     return success();
803   }
804   }
805 }
806 
807 FlatSymbolRefAttr Importer::getPersonalityAsAttr(llvm::Function *f) {
808   if (!f->hasPersonalityFn())
809     return nullptr;
810 
811   llvm::Constant *pf = f->getPersonalityFn();
812 
813   // If it directly has a name, we can use it.
814   if (pf->hasName())
815     return b.getSymbolRefAttr(pf->getName());
816 
817   // If it doesn't have a name, currently, only function pointers that are
818   // bitcast to i8* are parsed.
819   if (auto ce = dyn_cast<llvm::ConstantExpr>(pf)) {
820     if (ce->getOpcode() == llvm::Instruction::BitCast &&
821         ce->getType() == llvm::Type::getInt8PtrTy(dialect->getLLVMContext())) {
822       if (auto func = dyn_cast<llvm::Function>(ce->getOperand(0)))
823         return b.getSymbolRefAttr(func->getName());
824     }
825   }
826   return FlatSymbolRefAttr();
827 }
828 
829 LogicalResult Importer::processFunction(llvm::Function *f) {
830   blocks.clear();
831   instMap.clear();
832   unknownInstMap.clear();
833 
834   LLVMType functionType = processType(f->getFunctionType());
835   if (!functionType)
836     return failure();
837 
838   b.setInsertionPoint(module.getBody(), getFuncInsertPt());
839   LLVMFuncOp fop = b.create<LLVMFuncOp>(UnknownLoc::get(context), f->getName(),
840                                         functionType);
841 
842   if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f))
843     fop.setAttr(b.getIdentifier("personality"), personality);
844   else if (f->hasPersonalityFn())
845     emitWarning(UnknownLoc::get(context),
846                 "could not deduce personality, skipping it");
847 
848   if (f->isDeclaration())
849     return success();
850 
851   // Eagerly create all blocks.
852   SmallVector<Block *, 4> blockList;
853   for (llvm::BasicBlock &bb : *f) {
854     blockList.push_back(b.createBlock(&fop.body(), fop.body().end()));
855     blocks[&bb] = blockList.back();
856   }
857   currentEntryBlock = blockList[0];
858 
859   // Add function arguments to the entry block.
860   for (auto kv : llvm::enumerate(f->args()))
861     instMap[&kv.value()] = blockList[0]->addArgument(
862         functionType.getFunctionParamType(kv.index()));
863 
864   for (auto bbs : llvm::zip(*f, blockList)) {
865     if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs))))
866       return failure();
867   }
868 
869   // Now that all instructions are guaranteed to have been visited, ensure
870   // any unknown uses we encountered are remapped.
871   for (auto &llvmAndUnknown : unknownInstMap) {
872     assert(instMap.count(llvmAndUnknown.first));
873     Value newValue = instMap[llvmAndUnknown.first];
874     Value oldValue = llvmAndUnknown.second->getResult(0);
875     oldValue.replaceAllUsesWith(newValue);
876     llvmAndUnknown.second->erase();
877   }
878   return success();
879 }
880 
881 LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) {
882   b.setInsertionPointToStart(block);
883   for (llvm::Instruction &inst : *bb) {
884     if (failed(processInstruction(&inst)))
885       return failure();
886   }
887   return success();
888 }
889 
890 OwningModuleRef
891 mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
892                               MLIRContext *context) {
893   OwningModuleRef module(ModuleOp::create(
894       FileLineColLoc::get("", /*line=*/0, /*column=*/0, context)));
895 
896   Importer deserializer(context, module.get());
897   for (llvm::GlobalVariable &gv : llvmModule->globals()) {
898     if (!deserializer.processGlobal(&gv))
899       return {};
900   }
901   for (llvm::Function &f : llvmModule->functions()) {
902     if (failed(deserializer.processFunction(&f)))
903       return {};
904   }
905 
906   return module;
907 }
908 
909 // Deserializes the LLVM bitcode stored in `input` into an MLIR module in the
910 // LLVM dialect.
911 OwningModuleRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr,
912                                         MLIRContext *context) {
913   LLVMDialect *dialect = context->getRegisteredDialect<LLVMDialect>();
914   assert(dialect && "Could not find LLVMDialect?");
915 
916   llvm::SMDiagnostic err;
917   std::unique_ptr<llvm::Module> llvmModule =
918       llvm::parseIR(*sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), err,
919                     dialect->getLLVMContext(),
920                     /*UpgradeDebugInfo=*/true,
921                     /*DataLayoutString=*/"");
922   if (!llvmModule) {
923     std::string errStr;
924     llvm::raw_string_ostream errStream(errStr);
925     err.print(/*ProgName=*/"", errStream);
926     emitError(UnknownLoc::get(context)) << errStream.str();
927     return {};
928   }
929   return translateLLVMIRToModule(std::move(llvmModule), context);
930 }
931 
932 static TranslateToMLIRRegistration
933     fromLLVM("import-llvm",
934              [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
935                return translateLLVMIRToModule(sourceMgr, context);
936              });
937