//===- ConvertFromLLVMIR.cpp - LLVM IR to MLIR conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between LLVM IR and the MLIR LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Target/LLVMIR.h"
#include "mlir/Translation.h"

#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Type.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/SourceMgr.h"
29 
30 using namespace mlir;
31 using namespace mlir::LLVM;
32 
33 // Utility to print an LLVM value as a string for passing to emitError().
34 // FIXME: Diagnostic should be able to natively handle types that have
35 // operator << (raw_ostream&) defined.
36 static std::string diag(llvm::Value &v) {
37   std::string s;
38   llvm::raw_string_ostream os(s);
39   os << v;
40   return os.str();
41 }
42 
43 // Handles importing globals and functions from an LLVM module.
44 namespace {
45 class Importer {
46 public:
47   Importer(MLIRContext *context, ModuleOp module)
48       : b(context), context(context), module(module),
49         unknownLoc(FileLineColLoc::get("imported-bitcode", 0, 0, context)) {
50     b.setInsertionPointToStart(module.getBody());
51     dialect = context->getRegisteredDialect<LLVMDialect>();
52   }
53 
54   /// Imports `f` into the current module.
55   LogicalResult processFunction(llvm::Function *f);
56 
57   /// Imports GV as a GlobalOp, creating it if it doesn't exist.
58   GlobalOp processGlobal(llvm::GlobalVariable *GV);
59 
60 private:
61   /// Imports `bb` into `block`, which must be initially empty.
62   LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block);
63   /// Imports `inst` and populates instMap[inst] with the imported Value.
64   LogicalResult processInstruction(llvm::Instruction *inst);
65   /// Creates an LLVMType for `type`.
66   LLVMType processType(llvm::Type *type);
67   /// `value` is an SSA-use. Return the remapped version of `value` or a
68   /// placeholder that will be remapped later if this is an instruction that
69   /// has not yet been visited.
70   Value processValue(llvm::Value *value);
71   /// Create the most accurate Location possible using a llvm::DebugLoc and
72   /// possibly an llvm::Instruction to narrow the Location if debug information
73   /// is unavailable.
74   Location processDebugLoc(const llvm::DebugLoc &loc,
75                            llvm::Instruction *inst = nullptr);
76   /// `br` branches to `target`. Append the block arguments to attach to the
77   /// generated branch op to `blockArguments`. These should be in the same order
78   /// as the PHIs in `target`.
79   LogicalResult processBranchArgs(llvm::BranchInst *br,
80                                   llvm::BasicBlock *target,
81                                   SmallVectorImpl<Value> &blockArguments);
82   /// Returns the standard type equivalent to be used in attributes for the
83   /// given LLVM IR dialect type.
84   Type getStdTypeForAttr(LLVMType type);
85   /// Return `value` as an attribute to attach to a GlobalOp.
86   Attribute getConstantAsAttr(llvm::Constant *value);
87   /// Return `c` as an MLIR Value. This could either be a ConstantOp, or
88   /// an expanded sequence of ops in the current function's entry block (for
89   /// ConstantExprs or ConstantGEPs).
90   Value processConstant(llvm::Constant *c);
91 
92   /// The current builder, pointing at where the next Instruction should be
93   /// generated.
94   OpBuilder b;
95   /// The current context.
96   MLIRContext *context;
97   /// The current module being created.
98   ModuleOp module;
99   /// The entry block of the current function being processed.
100   Block *currentEntryBlock;
101 
102   /// Globals are inserted before the first function, if any.
103   Block::iterator getGlobalInsertPt() {
104     auto i = module.getBody()->begin();
105     while (!isa<LLVMFuncOp>(i) && !isa<ModuleTerminatorOp>(i))
106       ++i;
107     return i;
108   }
109 
110   /// Functions are always inserted before the module terminator.
111   Block::iterator getFuncInsertPt() {
112     return std::prev(module.getBody()->end());
113   }
114 
115   /// Remapped blocks, for the current function.
116   DenseMap<llvm::BasicBlock *, Block *> blocks;
117   /// Remapped values. These are function-local.
118   DenseMap<llvm::Value *, Value> instMap;
119   /// Instructions that had not been defined when first encountered as a use.
120   /// Maps to the dummy Operation that was created in processValue().
121   DenseMap<llvm::Value *, Operation *> unknownInstMap;
122   /// Uniquing map of GlobalVariables.
123   DenseMap<llvm::GlobalVariable *, GlobalOp> globals;
124   /// Cached FileLineColLoc::get("imported-bitcode", 0, 0).
125   Location unknownLoc;
126   /// Cached dialect.
127   LLVMDialect *dialect;
128 };
129 } // namespace
130 
131 Location Importer::processDebugLoc(const llvm::DebugLoc &loc,
132                                    llvm::Instruction *inst) {
133   if (!loc && inst) {
134     std::string s;
135     llvm::raw_string_ostream os(s);
136     os << "llvm-imported-inst-%";
137     inst->printAsOperand(os, /*PrintType=*/false);
138     return FileLineColLoc::get(os.str(), 0, 0, context);
139   } else if (!loc) {
140     return unknownLoc;
141   }
142   // FIXME: Obtain the filename from DILocationInfo.
143   return FileLineColLoc::get("imported-bitcode", loc.getLine(), loc.getCol(),
144                              context);
145 }
146 
147 LLVMType Importer::processType(llvm::Type *type) {
148   switch (type->getTypeID()) {
149   case llvm::Type::FloatTyID:
150     return LLVMType::getFloatTy(dialect);
151   case llvm::Type::DoubleTyID:
152     return LLVMType::getDoubleTy(dialect);
153   case llvm::Type::IntegerTyID:
154     return LLVMType::getIntNTy(dialect, type->getIntegerBitWidth());
155   case llvm::Type::PointerTyID: {
156     LLVMType elementType = processType(type->getPointerElementType());
157     if (!elementType)
158       return nullptr;
159     return elementType.getPointerTo(type->getPointerAddressSpace());
160   }
161   case llvm::Type::ArrayTyID: {
162     LLVMType elementType = processType(type->getArrayElementType());
163     if (!elementType)
164       return nullptr;
165     return LLVMType::getArrayTy(elementType, type->getArrayNumElements());
166   }
167   case llvm::Type::VectorTyID: {
168     if (type->getVectorIsScalable()) {
169       emitError(unknownLoc) << "scalable vector types not supported";
170       return nullptr;
171     }
172     LLVMType elementType = processType(type->getVectorElementType());
173     if (!elementType)
174       return nullptr;
175     return LLVMType::getVectorTy(elementType, type->getVectorNumElements());
176   }
177   case llvm::Type::VoidTyID:
178     return LLVMType::getVoidTy(dialect);
179   case llvm::Type::FP128TyID:
180     return LLVMType::getFP128Ty(dialect);
181   case llvm::Type::X86_FP80TyID:
182     return LLVMType::getX86_FP80Ty(dialect);
183   case llvm::Type::StructTyID: {
184     SmallVector<LLVMType, 4> elementTypes;
185     elementTypes.reserve(type->getStructNumElements());
186     for (unsigned i = 0, e = type->getStructNumElements(); i != e; ++i) {
187       LLVMType ty = processType(type->getStructElementType(i));
188       if (!ty)
189         return nullptr;
190       elementTypes.push_back(ty);
191     }
192     return LLVMType::getStructTy(dialect, elementTypes,
193                                  cast<llvm::StructType>(type)->isPacked());
194   }
195   case llvm::Type::FunctionTyID: {
196     llvm::FunctionType *fty = cast<llvm::FunctionType>(type);
197     SmallVector<LLVMType, 4> paramTypes;
198     for (unsigned i = 0, e = fty->getNumParams(); i != e; ++i) {
199       LLVMType ty = processType(fty->getParamType(i));
200       if (!ty)
201         return nullptr;
202       paramTypes.push_back(ty);
203     }
204     LLVMType result = processType(fty->getReturnType());
205     if (!result)
206       return nullptr;
207 
208     return LLVMType::getFunctionTy(result, paramTypes, fty->isVarArg());
209   }
210   default: {
211     // FIXME: Diagnostic should be able to natively handle types that have
212     // operator<<(raw_ostream&) defined.
213     std::string s;
214     llvm::raw_string_ostream os(s);
215     os << *type;
216     emitError(unknownLoc) << "unhandled type: " << os.str();
217     return nullptr;
218   }
219   }
220 }
221 
222 // We only need integers, floats, doubles, and vectors and tensors thereof for
223 // attributes. Scalar and vector types are converted to the standard
224 // equivalents. Array types are converted to ranked tensors; nested array types
225 // are converted to multi-dimensional tensors or vectors, depending on the
226 // innermost type being a scalar or a vector.
227 Type Importer::getStdTypeForAttr(LLVMType type) {
228   if (!type)
229     return nullptr;
230 
231   if (type.isIntegerTy())
232     return b.getIntegerType(type.getUnderlyingType()->getIntegerBitWidth());
233 
234   if (type.getUnderlyingType()->isFloatTy())
235     return b.getF32Type();
236 
237   if (type.getUnderlyingType()->isDoubleTy())
238     return b.getF64Type();
239 
240   // LLVM vectors can only contain scalars.
241   if (type.isVectorTy()) {
242     auto numElements = type.getUnderlyingType()->getVectorElementCount();
243     if (numElements.Scalable) {
244       emitError(unknownLoc) << "scalable vectors not supported";
245       return nullptr;
246     }
247     Type elementType = getStdTypeForAttr(type.getVectorElementType());
248     if (!elementType)
249       return nullptr;
250     return VectorType::get(numElements.Min, elementType);
251   }
252 
253   // LLVM arrays can contain other arrays or vectors.
254   if (type.isArrayTy()) {
255     // Recover the nested array shape.
256     SmallVector<int64_t, 4> shape;
257     shape.push_back(type.getArrayNumElements());
258     while (type.getArrayElementType().isArrayTy()) {
259       type = type.getArrayElementType();
260       shape.push_back(type.getArrayNumElements());
261     }
262 
263     // If the innermost type is a vector, use the multi-dimensional vector as
264     // attribute type.
265     if (type.getArrayElementType().isVectorTy()) {
266       LLVMType vectorType = type.getArrayElementType();
267       auto numElements =
268           vectorType.getUnderlyingType()->getVectorElementCount();
269       if (numElements.Scalable) {
270         emitError(unknownLoc) << "scalable vectors not supported";
271         return nullptr;
272       }
273       shape.push_back(numElements.Min);
274 
275       Type elementType = getStdTypeForAttr(vectorType.getVectorElementType());
276       if (!elementType)
277         return nullptr;
278       return VectorType::get(shape, elementType);
279     }
280 
281     // Otherwise use a tensor.
282     Type elementType = getStdTypeForAttr(type.getArrayElementType());
283     if (!elementType)
284       return nullptr;
285     return RankedTensorType::get(shape, elementType);
286   }
287 
288   return nullptr;
289 }
290 
291 // Get the given constant as an attribute. Not all constants can be represented
292 // as attributes.
293 Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
294   if (auto *ci = dyn_cast<llvm::ConstantInt>(value))
295     return b.getIntegerAttr(
296         IntegerType::get(ci->getType()->getBitWidth(), context),
297         ci->getValue());
298   if (auto *c = dyn_cast<llvm::ConstantDataArray>(value))
299     if (c->isString())
300       return b.getStringAttr(c->getAsString());
301   if (auto *c = dyn_cast<llvm::ConstantFP>(value)) {
302     if (c->getType()->isDoubleTy())
303       return b.getFloatAttr(FloatType::getF64(context), c->getValueAPF());
304     else if (c->getType()->isFloatingPointTy())
305       return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF());
306   }
307   if (auto *f = dyn_cast<llvm::Function>(value))
308     return b.getSymbolRefAttr(f->getName());
309 
310   // Convert constant data to a dense elements attribute.
311   if (auto *cd = dyn_cast<llvm::ConstantDataSequential>(value)) {
312     LLVMType type = processType(cd->getElementType());
313     if (!type)
314       return nullptr;
315 
316     auto attrType = getStdTypeForAttr(processType(cd->getType()))
317                         .dyn_cast_or_null<ShapedType>();
318     if (!attrType)
319       return nullptr;
320 
321     if (type.isIntegerTy()) {
322       SmallVector<APInt, 8> values;
323       values.reserve(cd->getNumElements());
324       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
325         values.push_back(cd->getElementAsAPInt(i));
326       return DenseElementsAttr::get(attrType, values);
327     }
328 
329     if (type.isFloatTy() || type.isDoubleTy()) {
330       SmallVector<APFloat, 8> values;
331       values.reserve(cd->getNumElements());
332       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
333         values.push_back(cd->getElementAsAPFloat(i));
334       return DenseElementsAttr::get(attrType, values);
335     }
336 
337     return nullptr;
338   }
339 
340   // Unpack constant aggregates to create dense elements attribute whenever
341   // possible. Return nullptr (failure) otherwise.
342   if (auto *ca = dyn_cast<llvm::ConstantAggregate>(value)) {
343     auto outerType = getStdTypeForAttr(processType(value->getType()))
344                          .dyn_cast_or_null<ShapedType>();
345     if (!outerType)
346       return nullptr;
347 
348     SmallVector<Attribute, 8> values;
349     SmallVector<int64_t, 8> shape;
350 
351     for (unsigned i = 0, e = value->getNumOperands(); i < e; ++i) {
352       auto nested = getConstantAsAttr(value->getAggregateElement(i))
353                         .dyn_cast_or_null<DenseElementsAttr>();
354       if (!nested)
355         return nullptr;
356 
357       values.append(nested.attr_value_begin(), nested.attr_value_end());
358     }
359 
360     return DenseElementsAttr::get(outerType, values);
361   }
362 
363   return nullptr;
364 }
365 
366 /// Converts LLVM global variable linkage type into the LLVM dialect predicate.
367 static LLVM::Linkage
368 processLinkage(llvm::GlobalVariable::LinkageTypes linkage) {
369   switch (linkage) {
370   case llvm::GlobalValue::PrivateLinkage:
371     return LLVM::Linkage::Private;
372   case llvm::GlobalValue::InternalLinkage:
373     return LLVM::Linkage::Internal;
374   case llvm::GlobalValue::AvailableExternallyLinkage:
375     return LLVM::Linkage::AvailableExternally;
376   case llvm::GlobalValue::LinkOnceAnyLinkage:
377     return LLVM::Linkage::Linkonce;
378   case llvm::GlobalValue::WeakAnyLinkage:
379     return LLVM::Linkage::Weak;
380   case llvm::GlobalValue::CommonLinkage:
381     return LLVM::Linkage::Common;
382   case llvm::GlobalValue::AppendingLinkage:
383     return LLVM::Linkage::Appending;
384   case llvm::GlobalValue::ExternalWeakLinkage:
385     return LLVM::Linkage::ExternWeak;
386   case llvm::GlobalValue::LinkOnceODRLinkage:
387     return LLVM::Linkage::LinkonceODR;
388   case llvm::GlobalValue::WeakODRLinkage:
389     return LLVM::Linkage::WeakODR;
390   case llvm::GlobalValue::ExternalLinkage:
391     return LLVM::Linkage::External;
392   }
393 
394   llvm_unreachable("unhandled linkage type");
395 }
396 
397 GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
398   auto it = globals.find(GV);
399   if (it != globals.end())
400     return it->second;
401 
402   OpBuilder b(module.getBody(), getGlobalInsertPt());
403   Attribute valueAttr;
404   if (GV->hasInitializer())
405     valueAttr = getConstantAsAttr(GV->getInitializer());
406   LLVMType type = processType(GV->getValueType());
407   if (!type)
408     return nullptr;
409   GlobalOp op = b.create<GlobalOp>(
410       UnknownLoc::get(context), type, GV->isConstant(),
411       processLinkage(GV->getLinkage()), GV->getName(), valueAttr);
412   if (GV->hasInitializer() && !valueAttr) {
413     Region &r = op.getInitializerRegion();
414     currentEntryBlock = b.createBlock(&r);
415     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
416     Value v = processConstant(GV->getInitializer());
417     if (!v)
418       return nullptr;
419     b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v}));
420   }
421   return globals[GV] = op;
422 }
423 
424 Value Importer::processConstant(llvm::Constant *c) {
425   if (Attribute attr = getConstantAsAttr(c)) {
426     // These constants can be represented as attributes.
427     OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
428     LLVMType type = processType(c->getType());
429     if (!type)
430       return nullptr;
431     return instMap[c] = b.create<ConstantOp>(unknownLoc, type, attr);
432   }
433   if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) {
434     OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
435     LLVMType type = processType(cn->getType());
436     if (!type)
437       return nullptr;
438     return instMap[c] = b.create<NullOp>(unknownLoc, type);
439   }
440   if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) {
441     llvm::Instruction *i = ce->getAsInstruction();
442     OpBuilder::InsertionGuard guard(b);
443     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
444     if (failed(processInstruction(i)))
445       return nullptr;
446     assert(instMap.count(i));
447 
448     // Remove this zombie LLVM instruction now, leaving us only with the MLIR
449     // op.
450     i->deleteValue();
451     return instMap[c] = instMap[i];
452   }
453   emitError(unknownLoc) << "unhandled constant: " << diag(*c);
454   return nullptr;
455 }
456 
457 Value Importer::processValue(llvm::Value *value) {
458   auto it = instMap.find(value);
459   if (it != instMap.end())
460     return it->second;
461 
462   // We don't expect to see instructions in dominator order. If we haven't seen
463   // this instruction yet, create an unknown op and remap it later.
464   if (isa<llvm::Instruction>(value)) {
465     OperationState state(UnknownLoc::get(context), "unknown");
466     LLVMType type = processType(value->getType());
467     if (!type)
468       return nullptr;
469     state.addTypes(type);
470     unknownInstMap[value] = b.createOperation(state);
471     return unknownInstMap[value]->getResult(0);
472   }
473 
474   if (auto *GV = dyn_cast<llvm::GlobalVariable>(value)) {
475     auto global = processGlobal(GV);
476     if (!global)
477       return nullptr;
478     return b.create<AddressOfOp>(UnknownLoc::get(context), global,
479                                  ArrayRef<NamedAttribute>());
480   }
481 
482   // Note, constant global variables are both GlobalVariables and Constants,
483   // so we handle GlobalVariables first above.
484   if (auto *c = dyn_cast<llvm::Constant>(value))
485     return processConstant(c);
486 
487   emitError(unknownLoc) << "unhandled value: " << diag(*value);
488   return nullptr;
489 }
490 
491 // Maps from LLVM opcode to MLIR OperationName. This is deliberately ordered
492 // as in llvm/IR/Instructions.def to aid comprehension and spot missing
493 // instructions.
494 #define INST(llvm_n, mlir_n)                                                   \
495   { llvm::Instruction::llvm_n, LLVM::mlir_n##Op::getOperationName() }
496 static const DenseMap<unsigned, StringRef> opcMap = {
497     // Ret is handled specially.
498     // Br is handled specially.
499     // FIXME: switch
500     // FIXME: indirectbr
501     // FIXME: invoke
502     // FIXME: resume
503     // FIXME: unreachable
504     // FIXME: cleanupret
505     // FIXME: catchret
506     // FIXME: catchswitch
507     // FIXME: callbr
508     // FIXME: fneg
509     INST(Add, Add), INST(FAdd, FAdd), INST(Sub, Sub), INST(FSub, FSub),
510     INST(Mul, Mul), INST(FMul, FMul), INST(UDiv, UDiv), INST(SDiv, SDiv),
511     INST(FDiv, FDiv), INST(URem, URem), INST(SRem, SRem), INST(FRem, FRem),
512     INST(Shl, Shl), INST(LShr, LShr), INST(AShr, AShr), INST(And, And),
513     INST(Or, Or), INST(Xor, XOr), INST(Alloca, Alloca), INST(Load, Load),
514     INST(Store, Store),
515     // Getelementptr is handled specially.
516     INST(Ret, Return),
517     // FIXME: fence
518     // FIXME: atomiccmpxchg
519     // FIXME: atomicrmw
520     INST(Trunc, Trunc), INST(ZExt, ZExt), INST(SExt, SExt),
521     INST(FPToUI, FPToUI), INST(FPToSI, FPToSI), INST(UIToFP, UIToFP),
522     INST(SIToFP, SIToFP), INST(FPTrunc, FPTrunc), INST(FPExt, FPExt),
523     INST(PtrToInt, PtrToInt), INST(IntToPtr, IntToPtr), INST(BitCast, Bitcast),
524     INST(AddrSpaceCast, AddrSpaceCast),
525     // FIXME: cleanuppad
526     // FIXME: catchpad
527     // ICmp is handled specially.
528     // FIXME: fcmp
529     // PHI is handled specially.
530     INST(Call, Call),
531     // FIXME: select
532     // FIXME: vaarg
533     // FIXME: extractelement
534     // FIXME: insertelement
535     // FIXME: shufflevector
536     // FIXME: extractvalue
537     // FIXME: insertvalue
538     // FIXME: landingpad
539 };
540 #undef INST
541 
542 static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) {
543   switch (p) {
544   default:
545     llvm_unreachable("incorrect comparison predicate");
546   case llvm::CmpInst::Predicate::ICMP_EQ:
547     return LLVM::ICmpPredicate::eq;
548   case llvm::CmpInst::Predicate::ICMP_NE:
549     return LLVM::ICmpPredicate::ne;
550   case llvm::CmpInst::Predicate::ICMP_SLT:
551     return LLVM::ICmpPredicate::slt;
552   case llvm::CmpInst::Predicate::ICMP_SLE:
553     return LLVM::ICmpPredicate::sle;
554   case llvm::CmpInst::Predicate::ICMP_SGT:
555     return LLVM::ICmpPredicate::sgt;
556   case llvm::CmpInst::Predicate::ICMP_SGE:
557     return LLVM::ICmpPredicate::sge;
558   case llvm::CmpInst::Predicate::ICMP_ULT:
559     return LLVM::ICmpPredicate::ult;
560   case llvm::CmpInst::Predicate::ICMP_ULE:
561     return LLVM::ICmpPredicate::ule;
562   case llvm::CmpInst::Predicate::ICMP_UGT:
563     return LLVM::ICmpPredicate::ugt;
564   case llvm::CmpInst::Predicate::ICMP_UGE:
565     return LLVM::ICmpPredicate::uge;
566   }
567   llvm_unreachable("incorrect comparison predicate");
568 }
569 
570 // `br` branches to `target`. Return the branch arguments to `br`, in the
571 // same order of the PHIs in `target`.
572 LogicalResult
573 Importer::processBranchArgs(llvm::BranchInst *br, llvm::BasicBlock *target,
574                             SmallVectorImpl<Value> &blockArguments) {
575   for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) {
576     auto *PN = cast<llvm::PHINode>(&*inst);
577     Value value = processValue(PN->getIncomingValueForBlock(br->getParent()));
578     if (!value)
579       return failure();
580     blockArguments.push_back(value);
581   }
582   return success();
583 }
584 
585 LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
586   // FIXME: Support uses of SubtargetData. Currently inbounds GEPs, fast-math
587   // flags and call / operand attributes are not supported.
588   Location loc = processDebugLoc(inst->getDebugLoc(), inst);
589   Value &v = instMap[inst];
590   assert(!v && "processInstruction must be called only once per instruction!");
591   switch (inst->getOpcode()) {
592   default:
593     return emitError(loc) << "unknown instruction: " << diag(*inst);
594   case llvm::Instruction::Add:
595   case llvm::Instruction::FAdd:
596   case llvm::Instruction::Sub:
597   case llvm::Instruction::FSub:
598   case llvm::Instruction::Mul:
599   case llvm::Instruction::FMul:
600   case llvm::Instruction::UDiv:
601   case llvm::Instruction::SDiv:
602   case llvm::Instruction::FDiv:
603   case llvm::Instruction::URem:
604   case llvm::Instruction::SRem:
605   case llvm::Instruction::FRem:
606   case llvm::Instruction::Shl:
607   case llvm::Instruction::LShr:
608   case llvm::Instruction::AShr:
609   case llvm::Instruction::And:
610   case llvm::Instruction::Or:
611   case llvm::Instruction::Xor:
612   case llvm::Instruction::Alloca:
613   case llvm::Instruction::Load:
614   case llvm::Instruction::Store:
615   case llvm::Instruction::Ret:
616   case llvm::Instruction::Trunc:
617   case llvm::Instruction::ZExt:
618   case llvm::Instruction::SExt:
619   case llvm::Instruction::FPToUI:
620   case llvm::Instruction::FPToSI:
621   case llvm::Instruction::UIToFP:
622   case llvm::Instruction::SIToFP:
623   case llvm::Instruction::FPTrunc:
624   case llvm::Instruction::FPExt:
625   case llvm::Instruction::PtrToInt:
626   case llvm::Instruction::IntToPtr:
627   case llvm::Instruction::AddrSpaceCast:
628   case llvm::Instruction::BitCast: {
629     OperationState state(loc, opcMap.lookup(inst->getOpcode()));
630     SmallVector<Value, 4> ops;
631     ops.reserve(inst->getNumOperands());
632     for (auto *op : inst->operand_values()) {
633       Value value = processValue(op);
634       if (!value)
635         return failure();
636       ops.push_back(value);
637     }
638     state.addOperands(ops);
639     if (!inst->getType()->isVoidTy()) {
640       LLVMType type = processType(inst->getType());
641       if (!type)
642         return failure();
643       state.addTypes(type);
644     }
645     Operation *op = b.createOperation(state);
646     if (!inst->getType()->isVoidTy())
647       v = op->getResult(0);
648     return success();
649   }
650   case llvm::Instruction::ICmp: {
651     Value lhs = processValue(inst->getOperand(0));
652     Value rhs = processValue(inst->getOperand(1));
653     if (!lhs || !rhs)
654       return failure();
655     v = b.create<ICmpOp>(
656         loc, getICmpPredicate(cast<llvm::ICmpInst>(inst)->getPredicate()), lhs,
657         rhs);
658     return success();
659   }
660   case llvm::Instruction::Br: {
661     auto *brInst = cast<llvm::BranchInst>(inst);
662     OperationState state(loc,
663                          brInst->isConditional() ? "llvm.cond_br" : "llvm.br");
664     SmallVector<Value, 4> ops;
665     if (brInst->isConditional()) {
666       Value condition = processValue(brInst->getCondition());
667       if (!condition)
668         return failure();
669       ops.push_back(condition);
670     }
671     state.addOperands(ops);
672     SmallVector<Block *, 4> succs;
673     for (auto *succ : llvm::reverse(brInst->successors())) {
674       SmallVector<Value, 4> blockArguments;
675       if (failed(processBranchArgs(brInst, succ, blockArguments)))
676         return failure();
677       state.addSuccessor(blocks[succ], blockArguments);
678     }
679     b.createOperation(state);
680     return success();
681   }
682   case llvm::Instruction::PHI: {
683     LLVMType type = processType(inst->getType());
684     if (!type)
685       return failure();
686     v = b.getInsertionBlock()->addArgument(type);
687     return success();
688   }
689   case llvm::Instruction::Call: {
690     llvm::CallInst *ci = cast<llvm::CallInst>(inst);
691     SmallVector<Value, 4> ops;
692     ops.reserve(inst->getNumOperands());
693     for (auto &op : ci->arg_operands()) {
694       Value arg = processValue(op.get());
695       if (!arg)
696         return failure();
697       ops.push_back(arg);
698     }
699 
700     SmallVector<Type, 2> tys;
701     if (!ci->getType()->isVoidTy()) {
702       LLVMType type = processType(inst->getType());
703       if (!type)
704         return failure();
705       tys.push_back(type);
706     }
707     Operation *op;
708     if (llvm::Function *callee = ci->getCalledFunction()) {
709       op = b.create<CallOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
710                             ops);
711     } else {
712       Value calledValue = processValue(ci->getCalledValue());
713       if (!calledValue)
714         return failure();
715       ops.insert(ops.begin(), calledValue);
716       op = b.create<CallOp>(loc, tys, ops, ArrayRef<NamedAttribute>());
717     }
718     if (!ci->getType()->isVoidTy())
719       v = op->getResult(0);
720     return success();
721   }
722   case llvm::Instruction::GetElementPtr: {
723     // FIXME: Support inbounds GEPs.
724     llvm::GetElementPtrInst *gep = cast<llvm::GetElementPtrInst>(inst);
725     SmallVector<Value, 4> ops;
726     for (auto *op : gep->operand_values()) {
727       Value value = processValue(op);
728       if (!value)
729         return failure();
730       ops.push_back(value);
731     }
732     Type type = processType(inst->getType());
733     if (!type)
734       return failure();
735     v = b.create<GEPOp>(loc, type, ops, ArrayRef<NamedAttribute>());
736     return success();
737   }
738   }
739 }
740 
741 LogicalResult Importer::processFunction(llvm::Function *f) {
742   blocks.clear();
743   instMap.clear();
744   unknownInstMap.clear();
745 
746   LLVMType functionType = processType(f->getFunctionType());
747   if (!functionType)
748     return failure();
749 
750   b.setInsertionPoint(module.getBody(), getFuncInsertPt());
751   LLVMFuncOp fop = b.create<LLVMFuncOp>(UnknownLoc::get(context), f->getName(),
752                                         functionType);
753   if (f->isDeclaration())
754     return success();
755 
756   // Eagerly create all blocks.
757   SmallVector<Block *, 4> blockList;
758   for (llvm::BasicBlock &bb : *f) {
759     blockList.push_back(b.createBlock(&fop.body(), fop.body().end()));
760     blocks[&bb] = blockList.back();
761   }
762   currentEntryBlock = blockList[0];
763 
764   // Add function arguments to the entry block.
765   for (auto kv : llvm::enumerate(f->args()))
766     instMap[&kv.value()] = blockList[0]->addArgument(
767         functionType.getFunctionParamType(kv.index()));
768 
769   for (auto bbs : llvm::zip(*f, blockList)) {
770     if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs))))
771       return failure();
772   }
773 
774   // Now that all instructions are guaranteed to have been visited, ensure
775   // any unknown uses we encountered are remapped.
776   for (auto &llvmAndUnknown : unknownInstMap) {
777     assert(instMap.count(llvmAndUnknown.first));
778     Value newValue = instMap[llvmAndUnknown.first];
779     Value oldValue = llvmAndUnknown.second->getResult(0);
780     oldValue.replaceAllUsesWith(newValue);
781     llvmAndUnknown.second->erase();
782   }
783   return success();
784 }
785 
786 LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) {
787   b.setInsertionPointToStart(block);
788   for (llvm::Instruction &inst : *bb) {
789     if (failed(processInstruction(&inst)))
790       return failure();
791   }
792   return success();
793 }
794 
795 OwningModuleRef
796 mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
797                               MLIRContext *context) {
798   OwningModuleRef module(ModuleOp::create(
799       FileLineColLoc::get("", /*line=*/0, /*column=*/0, context)));
800 
801   Importer deserializer(context, module.get());
802   for (llvm::GlobalVariable &gv : llvmModule->globals()) {
803     if (!deserializer.processGlobal(&gv))
804       return {};
805   }
806   for (llvm::Function &f : llvmModule->functions()) {
807     if (failed(deserializer.processFunction(&f)))
808       return {};
809   }
810 
811   return module;
812 }
813 
814 // Deserializes the LLVM bitcode stored in `input` into an MLIR module in the
815 // LLVM dialect.
816 OwningModuleRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr,
817                                         MLIRContext *context) {
818   LLVMDialect *dialect = context->getRegisteredDialect<LLVMDialect>();
819   assert(dialect && "Could not find LLVMDialect?");
820 
821   llvm::SMDiagnostic err;
822   std::unique_ptr<llvm::Module> llvmModule =
823       llvm::parseIR(*sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), err,
824                     dialect->getLLVMContext(),
825                     /*UpgradeDebugInfo=*/true,
826                     /*DataLayoutString=*/"");
827   if (!llvmModule) {
828     std::string errStr;
829     llvm::raw_string_ostream errStream(errStr);
830     err.print(/*ProgName=*/"", errStream);
831     emitError(UnknownLoc::get(context)) << errStream.str();
832     return {};
833   }
834   return translateLLVMIRToModule(std::move(llvmModule), context);
835 }
836 
837 static TranslateToMLIRRegistration
838     fromLLVM("import-llvm",
839              [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
840                return translateLLVMIRToModule(sourceMgr, context);
841              });
842