//===- ConvertFromLLVMIR.cpp - LLVM IR to MLIR conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between LLVM IR and the MLIR LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/MLIRContext.h"
16 #include "mlir/IR/Module.h"
17 #include "mlir/IR/StandardTypes.h"
18 #include "mlir/Target/LLVMIR.h"
19 #include "mlir/Translation.h"
20 
21 #include "llvm/IR/Attributes.h"
22 #include "llvm/IR/Constants.h"
23 #include "llvm/IR/Function.h"
24 #include "llvm/IR/Instructions.h"
25 #include "llvm/IR/Type.h"
26 #include "llvm/IRReader/IRReader.h"
27 #include "llvm/Support/Error.h"
28 #include "llvm/Support/SourceMgr.h"
29 
30 using namespace mlir;
31 using namespace mlir::LLVM;
32 
33 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
34 
35 // Utility to print an LLVM value as a string for passing to emitError().
36 // FIXME: Diagnostic should be able to natively handle types that have
37 // operator << (raw_ostream&) defined.
38 static std::string diag(llvm::Value &v) {
39   std::string s;
40   llvm::raw_string_ostream os(s);
41   os << v;
42   return os.str();
43 }
44 
45 // Handles importing globals and functions from an LLVM module.
46 namespace {
47 class Importer {
48 public:
49   Importer(MLIRContext *context, ModuleOp module)
50       : b(context), context(context), module(module),
51         unknownLoc(FileLineColLoc::get("imported-bitcode", 0, 0, context)) {
52     b.setInsertionPointToStart(module.getBody());
53     dialect = context->getRegisteredDialect<LLVMDialect>();
54   }
55 
56   /// Imports `f` into the current module.
57   LogicalResult processFunction(llvm::Function *f);
58 
59   /// Imports GV as a GlobalOp, creating it if it doesn't exist.
60   GlobalOp processGlobal(llvm::GlobalVariable *GV);
61 
62 private:
63   /// Returns personality of `f` as a FlatSymbolRefAttr.
64   FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *f);
65   /// Imports `bb` into `block`, which must be initially empty.
66   LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block);
67   /// Imports `inst` and populates instMap[inst] with the imported Value.
68   LogicalResult processInstruction(llvm::Instruction *inst);
69   /// Creates an LLVMType for `type`.
70   LLVMType processType(llvm::Type *type);
71   /// `value` is an SSA-use. Return the remapped version of `value` or a
72   /// placeholder that will be remapped later if this is an instruction that
73   /// has not yet been visited.
74   Value processValue(llvm::Value *value);
75   /// Create the most accurate Location possible using a llvm::DebugLoc and
76   /// possibly an llvm::Instruction to narrow the Location if debug information
77   /// is unavailable.
78   Location processDebugLoc(const llvm::DebugLoc &loc,
79                            llvm::Instruction *inst = nullptr);
80   /// `br` branches to `target`. Append the block arguments to attach to the
81   /// generated branch op to `blockArguments`. These should be in the same order
82   /// as the PHIs in `target`.
83   LogicalResult processBranchArgs(llvm::Instruction *br,
84                                   llvm::BasicBlock *target,
85                                   SmallVectorImpl<Value> &blockArguments);
86   /// Returns the standard type equivalent to be used in attributes for the
87   /// given LLVM IR dialect type.
88   Type getStdTypeForAttr(LLVMType type);
89   /// Return `value` as an attribute to attach to a GlobalOp.
90   Attribute getConstantAsAttr(llvm::Constant *value);
91   /// Return `c` as an MLIR Value. This could either be a ConstantOp, or
92   /// an expanded sequence of ops in the current function's entry block (for
93   /// ConstantExprs or ConstantGEPs).
94   Value processConstant(llvm::Constant *c);
95 
96   /// The current builder, pointing at where the next Instruction should be
97   /// generated.
98   OpBuilder b;
99   /// The current context.
100   MLIRContext *context;
101   /// The current module being created.
102   ModuleOp module;
103   /// The entry block of the current function being processed.
104   Block *currentEntryBlock;
105 
106   /// Globals are inserted before the first function, if any.
107   Block::iterator getGlobalInsertPt() {
108     auto i = module.getBody()->begin();
109     while (!isa<LLVMFuncOp>(i) && !isa<ModuleTerminatorOp>(i))
110       ++i;
111     return i;
112   }
113 
114   /// Functions are always inserted before the module terminator.
115   Block::iterator getFuncInsertPt() {
116     return std::prev(module.getBody()->end());
117   }
118 
119   /// Remapped blocks, for the current function.
120   DenseMap<llvm::BasicBlock *, Block *> blocks;
121   /// Remapped values. These are function-local.
122   DenseMap<llvm::Value *, Value> instMap;
123   /// Instructions that had not been defined when first encountered as a use.
124   /// Maps to the dummy Operation that was created in processValue().
125   DenseMap<llvm::Value *, Operation *> unknownInstMap;
126   /// Uniquing map of GlobalVariables.
127   DenseMap<llvm::GlobalVariable *, GlobalOp> globals;
128   /// Cached FileLineColLoc::get("imported-bitcode", 0, 0).
129   Location unknownLoc;
130   /// Cached dialect.
131   LLVMDialect *dialect;
132 };
133 } // namespace
134 
135 Location Importer::processDebugLoc(const llvm::DebugLoc &loc,
136                                    llvm::Instruction *inst) {
137   if (!loc && inst) {
138     std::string s;
139     llvm::raw_string_ostream os(s);
140     os << "llvm-imported-inst-%";
141     inst->printAsOperand(os, /*PrintType=*/false);
142     return FileLineColLoc::get(os.str(), 0, 0, context);
143   } else if (!loc) {
144     return unknownLoc;
145   }
146   // FIXME: Obtain the filename from DILocationInfo.
147   return FileLineColLoc::get("imported-bitcode", loc.getLine(), loc.getCol(),
148                              context);
149 }
150 
151 LLVMType Importer::processType(llvm::Type *type) {
152   switch (type->getTypeID()) {
153   case llvm::Type::FloatTyID:
154     return LLVMType::getFloatTy(dialect);
155   case llvm::Type::DoubleTyID:
156     return LLVMType::getDoubleTy(dialect);
157   case llvm::Type::IntegerTyID:
158     return LLVMType::getIntNTy(dialect, type->getIntegerBitWidth());
159   case llvm::Type::PointerTyID: {
160     LLVMType elementType = processType(type->getPointerElementType());
161     if (!elementType)
162       return nullptr;
163     return elementType.getPointerTo(type->getPointerAddressSpace());
164   }
165   case llvm::Type::ArrayTyID: {
166     LLVMType elementType = processType(type->getArrayElementType());
167     if (!elementType)
168       return nullptr;
169     return LLVMType::getArrayTy(elementType, type->getArrayNumElements());
170   }
171   case llvm::Type::ScalableVectorTyID: {
172     emitError(unknownLoc) << "scalable vector types not supported";
173     return nullptr;
174   }
175   case llvm::Type::FixedVectorTyID: {
176     auto *typeVTy = llvm::cast<llvm::FixedVectorType>(type);
177     LLVMType elementType = processType(typeVTy->getElementType());
178     if (!elementType)
179       return nullptr;
180     return LLVMType::getVectorTy(elementType, typeVTy->getNumElements());
181   }
182   case llvm::Type::VoidTyID:
183     return LLVMType::getVoidTy(dialect);
184   case llvm::Type::FP128TyID:
185     return LLVMType::getFP128Ty(dialect);
186   case llvm::Type::X86_FP80TyID:
187     return LLVMType::getX86_FP80Ty(dialect);
188   case llvm::Type::StructTyID: {
189     SmallVector<LLVMType, 4> elementTypes;
190     elementTypes.reserve(type->getStructNumElements());
191     for (unsigned i = 0, e = type->getStructNumElements(); i != e; ++i) {
192       LLVMType ty = processType(type->getStructElementType(i));
193       if (!ty)
194         return nullptr;
195       elementTypes.push_back(ty);
196     }
197     return LLVMType::getStructTy(dialect, elementTypes,
198                                  cast<llvm::StructType>(type)->isPacked());
199   }
200   case llvm::Type::FunctionTyID: {
201     llvm::FunctionType *fty = cast<llvm::FunctionType>(type);
202     SmallVector<LLVMType, 4> paramTypes;
203     for (unsigned i = 0, e = fty->getNumParams(); i != e; ++i) {
204       LLVMType ty = processType(fty->getParamType(i));
205       if (!ty)
206         return nullptr;
207       paramTypes.push_back(ty);
208     }
209     LLVMType result = processType(fty->getReturnType());
210     if (!result)
211       return nullptr;
212 
213     return LLVMType::getFunctionTy(result, paramTypes, fty->isVarArg());
214   }
215   default: {
216     // FIXME: Diagnostic should be able to natively handle types that have
217     // operator<<(raw_ostream&) defined.
218     std::string s;
219     llvm::raw_string_ostream os(s);
220     os << *type;
221     emitError(unknownLoc) << "unhandled type: " << os.str();
222     return nullptr;
223   }
224   }
225 }
226 
227 // We only need integers, floats, doubles, and vectors and tensors thereof for
228 // attributes. Scalar and vector types are converted to the standard
229 // equivalents. Array types are converted to ranked tensors; nested array types
230 // are converted to multi-dimensional tensors or vectors, depending on the
231 // innermost type being a scalar or a vector.
232 Type Importer::getStdTypeForAttr(LLVMType type) {
233   if (!type)
234     return nullptr;
235 
236   if (type.isIntegerTy())
237     return b.getIntegerType(type.getUnderlyingType()->getIntegerBitWidth());
238 
239   if (type.getUnderlyingType()->isFloatTy())
240     return b.getF32Type();
241 
242   if (type.getUnderlyingType()->isDoubleTy())
243     return b.getF64Type();
244 
245   // LLVM vectors can only contain scalars.
246   if (type.isVectorTy()) {
247     auto numElements = llvm::cast<llvm::VectorType>(type.getUnderlyingType())
248                            ->getElementCount();
249     if (numElements.Scalable) {
250       emitError(unknownLoc) << "scalable vectors not supported";
251       return nullptr;
252     }
253     Type elementType = getStdTypeForAttr(type.getVectorElementType());
254     if (!elementType)
255       return nullptr;
256     return VectorType::get(numElements.Min, elementType);
257   }
258 
259   // LLVM arrays can contain other arrays or vectors.
260   if (type.isArrayTy()) {
261     // Recover the nested array shape.
262     SmallVector<int64_t, 4> shape;
263     shape.push_back(type.getArrayNumElements());
264     while (type.getArrayElementType().isArrayTy()) {
265       type = type.getArrayElementType();
266       shape.push_back(type.getArrayNumElements());
267     }
268 
269     // If the innermost type is a vector, use the multi-dimensional vector as
270     // attribute type.
271     if (type.getArrayElementType().isVectorTy()) {
272       LLVMType vectorType = type.getArrayElementType();
273       auto numElements =
274           llvm::cast<llvm::VectorType>(vectorType.getUnderlyingType())
275               ->getElementCount();
276       if (numElements.Scalable) {
277         emitError(unknownLoc) << "scalable vectors not supported";
278         return nullptr;
279       }
280       shape.push_back(numElements.Min);
281 
282       Type elementType = getStdTypeForAttr(vectorType.getVectorElementType());
283       if (!elementType)
284         return nullptr;
285       return VectorType::get(shape, elementType);
286     }
287 
288     // Otherwise use a tensor.
289     Type elementType = getStdTypeForAttr(type.getArrayElementType());
290     if (!elementType)
291       return nullptr;
292     return RankedTensorType::get(shape, elementType);
293   }
294 
295   return nullptr;
296 }
297 
298 // Get the given constant as an attribute. Not all constants can be represented
299 // as attributes.
300 Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
301   if (auto *ci = dyn_cast<llvm::ConstantInt>(value))
302     return b.getIntegerAttr(
303         IntegerType::get(ci->getType()->getBitWidth(), context),
304         ci->getValue());
305   if (auto *c = dyn_cast<llvm::ConstantDataArray>(value))
306     if (c->isString())
307       return b.getStringAttr(c->getAsString());
308   if (auto *c = dyn_cast<llvm::ConstantFP>(value)) {
309     if (c->getType()->isDoubleTy())
310       return b.getFloatAttr(FloatType::getF64(context), c->getValueAPF());
311     else if (c->getType()->isFloatingPointTy())
312       return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF());
313   }
314   if (auto *f = dyn_cast<llvm::Function>(value))
315     return b.getSymbolRefAttr(f->getName());
316 
317   // Convert constant data to a dense elements attribute.
318   if (auto *cd = dyn_cast<llvm::ConstantDataSequential>(value)) {
319     LLVMType type = processType(cd->getElementType());
320     if (!type)
321       return nullptr;
322 
323     auto attrType = getStdTypeForAttr(processType(cd->getType()))
324                         .dyn_cast_or_null<ShapedType>();
325     if (!attrType)
326       return nullptr;
327 
328     if (type.isIntegerTy()) {
329       SmallVector<APInt, 8> values;
330       values.reserve(cd->getNumElements());
331       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
332         values.push_back(cd->getElementAsAPInt(i));
333       return DenseElementsAttr::get(attrType, values);
334     }
335 
336     if (type.isFloatTy() || type.isDoubleTy()) {
337       SmallVector<APFloat, 8> values;
338       values.reserve(cd->getNumElements());
339       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
340         values.push_back(cd->getElementAsAPFloat(i));
341       return DenseElementsAttr::get(attrType, values);
342     }
343 
344     return nullptr;
345   }
346 
347   // Unpack constant aggregates to create dense elements attribute whenever
348   // possible. Return nullptr (failure) otherwise.
349   if (isa<llvm::ConstantAggregate>(value)) {
350     auto outerType = getStdTypeForAttr(processType(value->getType()))
351                          .dyn_cast_or_null<ShapedType>();
352     if (!outerType)
353       return nullptr;
354 
355     SmallVector<Attribute, 8> values;
356     SmallVector<int64_t, 8> shape;
357 
358     for (unsigned i = 0, e = value->getNumOperands(); i < e; ++i) {
359       auto nested = getConstantAsAttr(value->getAggregateElement(i))
360                         .dyn_cast_or_null<DenseElementsAttr>();
361       if (!nested)
362         return nullptr;
363 
364       values.append(nested.attr_value_begin(), nested.attr_value_end());
365     }
366 
367     return DenseElementsAttr::get(outerType, values);
368   }
369 
370   return nullptr;
371 }
372 
373 GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
374   auto it = globals.find(GV);
375   if (it != globals.end())
376     return it->second;
377 
378   OpBuilder b(module.getBody(), getGlobalInsertPt());
379   Attribute valueAttr;
380   if (GV->hasInitializer())
381     valueAttr = getConstantAsAttr(GV->getInitializer());
382   LLVMType type = processType(GV->getValueType());
383   if (!type)
384     return nullptr;
385   GlobalOp op = b.create<GlobalOp>(
386       UnknownLoc::get(context), type, GV->isConstant(),
387       convertLinkageFromLLVM(GV->getLinkage()), GV->getName(), valueAttr);
388   if (GV->hasInitializer() && !valueAttr) {
389     Region &r = op.getInitializerRegion();
390     currentEntryBlock = b.createBlock(&r);
391     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
392     Value v = processConstant(GV->getInitializer());
393     if (!v)
394       return nullptr;
395     b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v}));
396   }
397   return globals[GV] = op;
398 }
399 
400 Value Importer::processConstant(llvm::Constant *c) {
401   OpBuilder bEntry(currentEntryBlock, currentEntryBlock->begin());
402   if (Attribute attr = getConstantAsAttr(c)) {
403     // These constants can be represented as attributes.
404     OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
405     LLVMType type = processType(c->getType());
406     if (!type)
407       return nullptr;
408     return instMap[c] = bEntry.create<ConstantOp>(unknownLoc, type, attr);
409   }
410   if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) {
411     LLVMType type = processType(cn->getType());
412     if (!type)
413       return nullptr;
414     return instMap[c] = bEntry.create<NullOp>(unknownLoc, type);
415   }
416   if (auto *GV = dyn_cast<llvm::GlobalVariable>(c))
417     return bEntry.create<AddressOfOp>(UnknownLoc::get(context),
418                                       processGlobal(GV),
419                                       ArrayRef<NamedAttribute>());
420 
421   if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) {
422     llvm::Instruction *i = ce->getAsInstruction();
423     OpBuilder::InsertionGuard guard(b);
424     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
425     if (failed(processInstruction(i)))
426       return nullptr;
427     assert(instMap.count(i));
428 
429     // Remove this zombie LLVM instruction now, leaving us only with the MLIR
430     // op.
431     i->deleteValue();
432     return instMap[c] = instMap[i];
433   }
434   if (auto *ue = dyn_cast<llvm::UndefValue>(c)) {
435     LLVMType type = processType(ue->getType());
436     if (!type)
437       return nullptr;
438     return instMap[c] = bEntry.create<UndefOp>(UnknownLoc::get(context), type);
439   }
440   emitError(unknownLoc) << "unhandled constant: " << diag(*c);
441   return nullptr;
442 }
443 
444 Value Importer::processValue(llvm::Value *value) {
445   auto it = instMap.find(value);
446   if (it != instMap.end())
447     return it->second;
448 
449   // We don't expect to see instructions in dominator order. If we haven't seen
450   // this instruction yet, create an unknown op and remap it later.
451   if (isa<llvm::Instruction>(value)) {
452     OperationState state(UnknownLoc::get(context), "llvm.unknown");
453     LLVMType type = processType(value->getType());
454     if (!type)
455       return nullptr;
456     state.addTypes(type);
457     unknownInstMap[value] = b.createOperation(state);
458     return unknownInstMap[value]->getResult(0);
459   }
460 
461   if (auto *c = dyn_cast<llvm::Constant>(value))
462     return processConstant(c);
463 
464   emitError(unknownLoc) << "unhandled value: " << diag(*value);
465   return nullptr;
466 }
467 
468 /// Return the MLIR OperationName for the given LLVM opcode.
469 static StringRef lookupOperationNameFromOpcode(unsigned opcode) {
470 // Maps from LLVM opcode to MLIR OperationName. This is deliberately ordered
471 // as in llvm/IR/Instructions.def to aid comprehension and spot missing
472 // instructions.
473 #define INST(llvm_n, mlir_n)                                                   \
474   { llvm::Instruction::llvm_n, LLVM::mlir_n##Op::getOperationName() }
475   static const DenseMap<unsigned, StringRef> opcMap = {
476       // Ret is handled specially.
477       // Br is handled specially.
478       // FIXME: switch
479       // FIXME: indirectbr
480       // FIXME: invoke
481       INST(Resume, Resume),
482       // FIXME: unreachable
483       // FIXME: cleanupret
484       // FIXME: catchret
485       // FIXME: catchswitch
486       // FIXME: callbr
487       // FIXME: fneg
488       INST(Add, Add), INST(FAdd, FAdd), INST(Sub, Sub), INST(FSub, FSub),
489       INST(Mul, Mul), INST(FMul, FMul), INST(UDiv, UDiv), INST(SDiv, SDiv),
490       INST(FDiv, FDiv), INST(URem, URem), INST(SRem, SRem), INST(FRem, FRem),
491       INST(Shl, Shl), INST(LShr, LShr), INST(AShr, AShr), INST(And, And),
492       INST(Or, Or), INST(Xor, XOr), INST(Alloca, Alloca), INST(Load, Load),
493       INST(Store, Store),
494       // Getelementptr is handled specially.
495       INST(Ret, Return), INST(Fence, Fence),
496       // FIXME: atomiccmpxchg
497       // FIXME: atomicrmw
498       INST(Trunc, Trunc), INST(ZExt, ZExt), INST(SExt, SExt),
499       INST(FPToUI, FPToUI), INST(FPToSI, FPToSI), INST(UIToFP, UIToFP),
500       INST(SIToFP, SIToFP), INST(FPTrunc, FPTrunc), INST(FPExt, FPExt),
501       INST(PtrToInt, PtrToInt), INST(IntToPtr, IntToPtr),
502       INST(BitCast, Bitcast), INST(AddrSpaceCast, AddrSpaceCast),
503       // FIXME: cleanuppad
504       // FIXME: catchpad
505       // ICmp is handled specially.
506       // FIXME: fcmp
507       // PHI is handled specially.
508       INST(Freeze, Freeze), INST(Call, Call),
509       // FIXME: select
510       // FIXME: vaarg
511       // FIXME: extractelement
512       // FIXME: insertelement
513       // FIXME: shufflevector
514       // FIXME: extractvalue
515       // FIXME: insertvalue
516       // FIXME: landingpad
517   };
518 #undef INST
519 
520   return opcMap.lookup(opcode);
521 }
522 
523 static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) {
524   switch (p) {
525   default:
526     llvm_unreachable("incorrect comparison predicate");
527   case llvm::CmpInst::Predicate::ICMP_EQ:
528     return LLVM::ICmpPredicate::eq;
529   case llvm::CmpInst::Predicate::ICMP_NE:
530     return LLVM::ICmpPredicate::ne;
531   case llvm::CmpInst::Predicate::ICMP_SLT:
532     return LLVM::ICmpPredicate::slt;
533   case llvm::CmpInst::Predicate::ICMP_SLE:
534     return LLVM::ICmpPredicate::sle;
535   case llvm::CmpInst::Predicate::ICMP_SGT:
536     return LLVM::ICmpPredicate::sgt;
537   case llvm::CmpInst::Predicate::ICMP_SGE:
538     return LLVM::ICmpPredicate::sge;
539   case llvm::CmpInst::Predicate::ICMP_ULT:
540     return LLVM::ICmpPredicate::ult;
541   case llvm::CmpInst::Predicate::ICMP_ULE:
542     return LLVM::ICmpPredicate::ule;
543   case llvm::CmpInst::Predicate::ICMP_UGT:
544     return LLVM::ICmpPredicate::ugt;
545   case llvm::CmpInst::Predicate::ICMP_UGE:
546     return LLVM::ICmpPredicate::uge;
547   }
548   llvm_unreachable("incorrect comparison predicate");
549 }
550 
551 static AtomicOrdering getLLVMAtomicOrdering(llvm::AtomicOrdering ordering) {
552   switch (ordering) {
553   case llvm::AtomicOrdering::NotAtomic:
554     return LLVM::AtomicOrdering::not_atomic;
555   case llvm::AtomicOrdering::Unordered:
556     return LLVM::AtomicOrdering::unordered;
557   case llvm::AtomicOrdering::Monotonic:
558     return LLVM::AtomicOrdering::monotonic;
559   case llvm::AtomicOrdering::Acquire:
560     return LLVM::AtomicOrdering::acquire;
561   case llvm::AtomicOrdering::Release:
562     return LLVM::AtomicOrdering::release;
563   case llvm::AtomicOrdering::AcquireRelease:
564     return LLVM::AtomicOrdering::acq_rel;
565   case llvm::AtomicOrdering::SequentiallyConsistent:
566     return LLVM::AtomicOrdering::seq_cst;
567   }
568   llvm_unreachable("incorrect atomic ordering");
569 }
570 
571 // `br` branches to `target`. Return the branch arguments to `br`, in the
572 // same order of the PHIs in `target`.
573 LogicalResult
574 Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target,
575                             SmallVectorImpl<Value> &blockArguments) {
576   for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) {
577     auto *PN = cast<llvm::PHINode>(&*inst);
578     Value value = processValue(PN->getIncomingValueForBlock(br->getParent()));
579     if (!value)
580       return failure();
581     blockArguments.push_back(value);
582   }
583   return success();
584 }
585 
586 LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
587   // FIXME: Support uses of SubtargetData. Currently inbounds GEPs, fast-math
588   // flags and call / operand attributes are not supported.
589   Location loc = processDebugLoc(inst->getDebugLoc(), inst);
590   Value &v = instMap[inst];
591   assert(!v && "processInstruction must be called only once per instruction!");
592   switch (inst->getOpcode()) {
593   default:
594     return emitError(loc) << "unknown instruction: " << diag(*inst);
595   case llvm::Instruction::Add:
596   case llvm::Instruction::FAdd:
597   case llvm::Instruction::Sub:
598   case llvm::Instruction::FSub:
599   case llvm::Instruction::Mul:
600   case llvm::Instruction::FMul:
601   case llvm::Instruction::UDiv:
602   case llvm::Instruction::SDiv:
603   case llvm::Instruction::FDiv:
604   case llvm::Instruction::URem:
605   case llvm::Instruction::SRem:
606   case llvm::Instruction::FRem:
607   case llvm::Instruction::Shl:
608   case llvm::Instruction::LShr:
609   case llvm::Instruction::AShr:
610   case llvm::Instruction::And:
611   case llvm::Instruction::Or:
612   case llvm::Instruction::Xor:
613   case llvm::Instruction::Alloca:
614   case llvm::Instruction::Load:
615   case llvm::Instruction::Store:
616   case llvm::Instruction::Ret:
617   case llvm::Instruction::Resume:
618   case llvm::Instruction::Trunc:
619   case llvm::Instruction::ZExt:
620   case llvm::Instruction::SExt:
621   case llvm::Instruction::FPToUI:
622   case llvm::Instruction::FPToSI:
623   case llvm::Instruction::UIToFP:
624   case llvm::Instruction::SIToFP:
625   case llvm::Instruction::FPTrunc:
626   case llvm::Instruction::FPExt:
627   case llvm::Instruction::PtrToInt:
628   case llvm::Instruction::IntToPtr:
629   case llvm::Instruction::AddrSpaceCast:
630   case llvm::Instruction::Freeze:
631   case llvm::Instruction::BitCast: {
632     OperationState state(loc, lookupOperationNameFromOpcode(inst->getOpcode()));
633     SmallVector<Value, 4> ops;
634     ops.reserve(inst->getNumOperands());
635     for (auto *op : inst->operand_values()) {
636       Value value = processValue(op);
637       if (!value)
638         return failure();
639       ops.push_back(value);
640     }
641     state.addOperands(ops);
642     if (!inst->getType()->isVoidTy()) {
643       LLVMType type = processType(inst->getType());
644       if (!type)
645         return failure();
646       state.addTypes(type);
647     }
648     Operation *op = b.createOperation(state);
649     if (!inst->getType()->isVoidTy())
650       v = op->getResult(0);
651     return success();
652   }
653   case llvm::Instruction::ICmp: {
654     Value lhs = processValue(inst->getOperand(0));
655     Value rhs = processValue(inst->getOperand(1));
656     if (!lhs || !rhs)
657       return failure();
658     v = b.create<ICmpOp>(
659         loc, getICmpPredicate(cast<llvm::ICmpInst>(inst)->getPredicate()), lhs,
660         rhs);
661     return success();
662   }
663   case llvm::Instruction::Br: {
664     auto *brInst = cast<llvm::BranchInst>(inst);
665     OperationState state(loc,
666                          brInst->isConditional() ? "llvm.cond_br" : "llvm.br");
667     if (brInst->isConditional()) {
668       Value condition = processValue(brInst->getCondition());
669       if (!condition)
670         return failure();
671       state.addOperands(condition);
672     }
673 
674     std::array<int32_t, 3> operandSegmentSizes = {1, 0, 0};
675     for (int i : llvm::seq<int>(0, brInst->getNumSuccessors())) {
676       auto *succ = brInst->getSuccessor(i);
677       SmallVector<Value, 4> blockArguments;
678       if (failed(processBranchArgs(brInst, succ, blockArguments)))
679         return failure();
680       state.addSuccessors(blocks[succ]);
681       state.addOperands(blockArguments);
682       operandSegmentSizes[i + 1] = blockArguments.size();
683     }
684 
685     if (brInst->isConditional()) {
686       state.addAttribute(LLVM::CondBrOp::getOperandSegmentSizeAttr(),
687                          b.getI32VectorAttr(operandSegmentSizes));
688     }
689 
690     b.createOperation(state);
691     return success();
692   }
693   case llvm::Instruction::PHI: {
694     LLVMType type = processType(inst->getType());
695     if (!type)
696       return failure();
697     v = b.getInsertionBlock()->addArgument(type);
698     return success();
699   }
700   case llvm::Instruction::Call: {
701     llvm::CallInst *ci = cast<llvm::CallInst>(inst);
702     SmallVector<Value, 4> ops;
703     ops.reserve(inst->getNumOperands());
704     for (auto &op : ci->arg_operands()) {
705       Value arg = processValue(op.get());
706       if (!arg)
707         return failure();
708       ops.push_back(arg);
709     }
710 
711     SmallVector<Type, 2> tys;
712     if (!ci->getType()->isVoidTy()) {
713       LLVMType type = processType(inst->getType());
714       if (!type)
715         return failure();
716       tys.push_back(type);
717     }
718     Operation *op;
719     if (llvm::Function *callee = ci->getCalledFunction()) {
720       op = b.create<CallOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
721                             ops);
722     } else {
723       Value calledValue = processValue(ci->getCalledOperand());
724       if (!calledValue)
725         return failure();
726       ops.insert(ops.begin(), calledValue);
727       op = b.create<CallOp>(loc, tys, ops, ArrayRef<NamedAttribute>());
728     }
729     if (!ci->getType()->isVoidTy())
730       v = op->getResult(0);
731     return success();
732   }
733   case llvm::Instruction::LandingPad: {
734     llvm::LandingPadInst *lpi = cast<llvm::LandingPadInst>(inst);
735     SmallVector<Value, 4> ops;
736 
737     for (unsigned i = 0, ie = lpi->getNumClauses(); i < ie; i++)
738       ops.push_back(processConstant(lpi->getClause(i)));
739 
740     Type ty = processType(lpi->getType());
741     if (!ty)
742       return failure();
743 
744     v = b.create<LandingpadOp>(loc, ty, lpi->isCleanup(), ops);
745     return success();
746   }
747   case llvm::Instruction::Invoke: {
748     llvm::InvokeInst *ii = cast<llvm::InvokeInst>(inst);
749 
750     SmallVector<Type, 2> tys;
751     if (!ii->getType()->isVoidTy())
752       tys.push_back(processType(inst->getType()));
753 
754     SmallVector<Value, 4> ops;
755     ops.reserve(inst->getNumOperands() + 1);
756     for (auto &op : ii->arg_operands())
757       ops.push_back(processValue(op.get()));
758 
759     SmallVector<Value, 4> normalArgs, unwindArgs;
760     processBranchArgs(ii, ii->getNormalDest(), normalArgs);
761     processBranchArgs(ii, ii->getUnwindDest(), unwindArgs);
762 
763     Operation *op;
764     if (llvm::Function *callee = ii->getCalledFunction()) {
765       op = b.create<InvokeOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
766                               ops, blocks[ii->getNormalDest()], normalArgs,
767                               blocks[ii->getUnwindDest()], unwindArgs);
768     } else {
769       ops.insert(ops.begin(), processValue(ii->getCalledOperand()));
770       op = b.create<InvokeOp>(loc, tys, ops, blocks[ii->getNormalDest()],
771                               normalArgs, blocks[ii->getUnwindDest()],
772                               unwindArgs);
773     }
774 
775     if (!ii->getType()->isVoidTy())
776       v = op->getResult(0);
777     return success();
778   }
779   case llvm::Instruction::Fence: {
780     StringRef syncscope;
781     SmallVector<StringRef, 4> ssNs;
782     llvm::LLVMContext &llvmContext = dialect->getLLVMContext();
783     llvm::FenceInst *fence = cast<llvm::FenceInst>(inst);
784     llvmContext.getSyncScopeNames(ssNs);
785     int fenceSyncScopeID = fence->getSyncScopeID();
786     for (unsigned i = 0, e = ssNs.size(); i != e; i++) {
787       if (fenceSyncScopeID == llvmContext.getOrInsertSyncScopeID(ssNs[i])) {
788         syncscope = ssNs[i];
789         break;
790       }
791     }
792     b.create<FenceOp>(loc, getLLVMAtomicOrdering(fence->getOrdering()),
793                       syncscope);
794     return success();
795   }
796   case llvm::Instruction::GetElementPtr: {
797     // FIXME: Support inbounds GEPs.
798     llvm::GetElementPtrInst *gep = cast<llvm::GetElementPtrInst>(inst);
799     SmallVector<Value, 4> ops;
800     for (auto *op : gep->operand_values()) {
801       Value value = processValue(op);
802       if (!value)
803         return failure();
804       ops.push_back(value);
805     }
806     Type type = processType(inst->getType());
807     if (!type)
808       return failure();
809     v = b.create<GEPOp>(loc, type, ops, ArrayRef<NamedAttribute>());
810     return success();
811   }
812   }
813 }
814 
815 FlatSymbolRefAttr Importer::getPersonalityAsAttr(llvm::Function *f) {
816   if (!f->hasPersonalityFn())
817     return nullptr;
818 
819   llvm::Constant *pf = f->getPersonalityFn();
820 
821   // If it directly has a name, we can use it.
822   if (pf->hasName())
823     return b.getSymbolRefAttr(pf->getName());
824 
825   // If it doesn't have a name, currently, only function pointers that are
826   // bitcast to i8* are parsed.
827   if (auto ce = dyn_cast<llvm::ConstantExpr>(pf)) {
828     if (ce->getOpcode() == llvm::Instruction::BitCast &&
829         ce->getType() == llvm::Type::getInt8PtrTy(dialect->getLLVMContext())) {
830       if (auto func = dyn_cast<llvm::Function>(ce->getOperand(0)))
831         return b.getSymbolRefAttr(func->getName());
832     }
833   }
834   return FlatSymbolRefAttr();
835 }
836 
837 LogicalResult Importer::processFunction(llvm::Function *f) {
838   blocks.clear();
839   instMap.clear();
840   unknownInstMap.clear();
841 
842   LLVMType functionType = processType(f->getFunctionType());
843   if (!functionType)
844     return failure();
845 
846   b.setInsertionPoint(module.getBody(), getFuncInsertPt());
847   LLVMFuncOp fop = b.create<LLVMFuncOp>(UnknownLoc::get(context), f->getName(),
848                                         functionType);
849 
850   if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f))
851     fop.setAttr(b.getIdentifier("personality"), personality);
852   else if (f->hasPersonalityFn())
853     emitWarning(UnknownLoc::get(context),
854                 "could not deduce personality, skipping it");
855 
856   if (f->isDeclaration())
857     return success();
858 
859   // Eagerly create all blocks.
860   SmallVector<Block *, 4> blockList;
861   for (llvm::BasicBlock &bb : *f) {
862     blockList.push_back(b.createBlock(&fop.body(), fop.body().end()));
863     blocks[&bb] = blockList.back();
864   }
865   currentEntryBlock = blockList[0];
866 
867   // Add function arguments to the entry block.
868   for (auto kv : llvm::enumerate(f->args()))
869     instMap[&kv.value()] = blockList[0]->addArgument(
870         functionType.getFunctionParamType(kv.index()));
871 
872   for (auto bbs : llvm::zip(*f, blockList)) {
873     if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs))))
874       return failure();
875   }
876 
877   // Now that all instructions are guaranteed to have been visited, ensure
878   // any unknown uses we encountered are remapped.
879   for (auto &llvmAndUnknown : unknownInstMap) {
880     assert(instMap.count(llvmAndUnknown.first));
881     Value newValue = instMap[llvmAndUnknown.first];
882     Value oldValue = llvmAndUnknown.second->getResult(0);
883     oldValue.replaceAllUsesWith(newValue);
884     llvmAndUnknown.second->erase();
885   }
886   return success();
887 }
888 
889 LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) {
890   b.setInsertionPointToStart(block);
891   for (llvm::Instruction &inst : *bb) {
892     if (failed(processInstruction(&inst)))
893       return failure();
894   }
895   return success();
896 }
897 
898 OwningModuleRef
899 mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
900                               MLIRContext *context) {
901   OwningModuleRef module(ModuleOp::create(
902       FileLineColLoc::get("", /*line=*/0, /*column=*/0, context)));
903 
904   Importer deserializer(context, module.get());
905   for (llvm::GlobalVariable &gv : llvmModule->globals()) {
906     if (!deserializer.processGlobal(&gv))
907       return {};
908   }
909   for (llvm::Function &f : llvmModule->functions()) {
910     if (failed(deserializer.processFunction(&f)))
911       return {};
912   }
913 
914   return module;
915 }
916 
917 // Deserializes the LLVM bitcode stored in `input` into an MLIR module in the
918 // LLVM dialect.
919 OwningModuleRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr,
920                                         MLIRContext *context) {
921   LLVMDialect *dialect = context->getRegisteredDialect<LLVMDialect>();
922   assert(dialect && "Could not find LLVMDialect?");
923 
924   llvm::SMDiagnostic err;
925   std::unique_ptr<llvm::Module> llvmModule =
926       llvm::parseIR(*sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), err,
927                     dialect->getLLVMContext());
928   if (!llvmModule) {
929     std::string errStr;
930     llvm::raw_string_ostream errStream(errStr);
931     err.print(/*ProgName=*/"", errStream);
932     emitError(UnknownLoc::get(context)) << errStream.str();
933     return {};
934   }
935   return translateLLVMIRToModule(std::move(llvmModule), context);
936 }
937 
938 namespace mlir {
939 void registerFromLLVMIRTranslation() {
940   TranslateToMLIRRegistration fromLLVM(
941       "import-llvm", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
942         return ::translateLLVMIRToModule(sourceMgr, context);
943       });
944 }
945 } // namespace mlir
946