//===- ConvertFromLLVMIR.cpp - LLVM IR to MLIR conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between LLVM IR and the MLIR LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/MLIRContext.h"
16 #include "mlir/IR/Module.h"
17 #include "mlir/IR/StandardTypes.h"
18 #include "mlir/Target/LLVMIR.h"
19 #include "mlir/Translation.h"
20 
21 #include "llvm/IR/Attributes.h"
22 #include "llvm/IR/Constants.h"
23 #include "llvm/IR/Function.h"
24 #include "llvm/IR/Instructions.h"
25 #include "llvm/IR/Type.h"
26 #include "llvm/IRReader/IRReader.h"
27 #include "llvm/Support/Error.h"
28 #include "llvm/Support/SourceMgr.h"
29 
30 using namespace mlir;
31 using namespace mlir::LLVM;
32 
33 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
34 
35 // Utility to print an LLVM value as a string for passing to emitError().
36 // FIXME: Diagnostic should be able to natively handle types that have
37 // operator << (raw_ostream&) defined.
38 static std::string diag(llvm::Value &v) {
39   std::string s;
40   llvm::raw_string_ostream os(s);
41   os << v;
42   return os.str();
43 }
44 
45 // Handles importing globals and functions from an LLVM module.
46 namespace {
47 class Importer {
48 public:
49   Importer(MLIRContext *context, ModuleOp module)
50       : b(context), context(context), module(module),
51         unknownLoc(FileLineColLoc::get("imported-bitcode", 0, 0, context)) {
52     b.setInsertionPointToStart(module.getBody());
53     dialect = context->getRegisteredDialect<LLVMDialect>();
54   }
55 
56   /// Imports `f` into the current module.
57   LogicalResult processFunction(llvm::Function *f);
58 
59   /// Imports GV as a GlobalOp, creating it if it doesn't exist.
60   GlobalOp processGlobal(llvm::GlobalVariable *GV);
61 
62 private:
63   /// Returns personality of `f` as a FlatSymbolRefAttr.
64   FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *f);
65   /// Imports `bb` into `block`, which must be initially empty.
66   LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block);
67   /// Imports `inst` and populates instMap[inst] with the imported Value.
68   LogicalResult processInstruction(llvm::Instruction *inst);
69   /// Creates an LLVMType for `type`.
70   LLVMType processType(llvm::Type *type);
71   /// `value` is an SSA-use. Return the remapped version of `value` or a
72   /// placeholder that will be remapped later if this is an instruction that
73   /// has not yet been visited.
74   Value processValue(llvm::Value *value);
75   /// Create the most accurate Location possible using a llvm::DebugLoc and
76   /// possibly an llvm::Instruction to narrow the Location if debug information
77   /// is unavailable.
78   Location processDebugLoc(const llvm::DebugLoc &loc,
79                            llvm::Instruction *inst = nullptr);
80   /// `br` branches to `target`. Append the block arguments to attach to the
81   /// generated branch op to `blockArguments`. These should be in the same order
82   /// as the PHIs in `target`.
83   LogicalResult processBranchArgs(llvm::Instruction *br,
84                                   llvm::BasicBlock *target,
85                                   SmallVectorImpl<Value> &blockArguments);
86   /// Returns the standard type equivalent to be used in attributes for the
87   /// given LLVM IR dialect type.
88   Type getStdTypeForAttr(LLVMType type);
89   /// Return `value` as an attribute to attach to a GlobalOp.
90   Attribute getConstantAsAttr(llvm::Constant *value);
91   /// Return `c` as an MLIR Value. This could either be a ConstantOp, or
92   /// an expanded sequence of ops in the current function's entry block (for
93   /// ConstantExprs or ConstantGEPs).
94   Value processConstant(llvm::Constant *c);
95 
96   /// The current builder, pointing at where the next Instruction should be
97   /// generated.
98   OpBuilder b;
99   /// The current context.
100   MLIRContext *context;
101   /// The current module being created.
102   ModuleOp module;
103   /// The entry block of the current function being processed.
104   Block *currentEntryBlock;
105 
106   /// Globals are inserted before the first function, if any.
107   Block::iterator getGlobalInsertPt() {
108     auto i = module.getBody()->begin();
109     while (!isa<LLVMFuncOp, ModuleTerminatorOp>(i))
110       ++i;
111     return i;
112   }
113 
114   /// Functions are always inserted before the module terminator.
115   Block::iterator getFuncInsertPt() {
116     return std::prev(module.getBody()->end());
117   }
118 
119   /// Remapped blocks, for the current function.
120   DenseMap<llvm::BasicBlock *, Block *> blocks;
121   /// Remapped values. These are function-local.
122   DenseMap<llvm::Value *, Value> instMap;
123   /// Instructions that had not been defined when first encountered as a use.
124   /// Maps to the dummy Operation that was created in processValue().
125   DenseMap<llvm::Value *, Operation *> unknownInstMap;
126   /// Uniquing map of GlobalVariables.
127   DenseMap<llvm::GlobalVariable *, GlobalOp> globals;
128   /// Cached FileLineColLoc::get("imported-bitcode", 0, 0).
129   Location unknownLoc;
130   /// Cached dialect.
131   LLVMDialect *dialect;
132 };
133 } // namespace
134 
135 Location Importer::processDebugLoc(const llvm::DebugLoc &loc,
136                                    llvm::Instruction *inst) {
137   if (!loc && inst) {
138     std::string s;
139     llvm::raw_string_ostream os(s);
140     os << "llvm-imported-inst-%";
141     inst->printAsOperand(os, /*PrintType=*/false);
142     return FileLineColLoc::get(os.str(), 0, 0, context);
143   } else if (!loc) {
144     return unknownLoc;
145   }
146   // FIXME: Obtain the filename from DILocationInfo.
147   return FileLineColLoc::get("imported-bitcode", loc.getLine(), loc.getCol(),
148                              context);
149 }
150 
151 LLVMType Importer::processType(llvm::Type *type) {
152   switch (type->getTypeID()) {
153   case llvm::Type::FloatTyID:
154     return LLVMType::getFloatTy(dialect);
155   case llvm::Type::DoubleTyID:
156     return LLVMType::getDoubleTy(dialect);
157   case llvm::Type::IntegerTyID:
158     return LLVMType::getIntNTy(dialect, type->getIntegerBitWidth());
159   case llvm::Type::PointerTyID: {
160     LLVMType elementType = processType(type->getPointerElementType());
161     if (!elementType)
162       return nullptr;
163     return elementType.getPointerTo(type->getPointerAddressSpace());
164   }
165   case llvm::Type::ArrayTyID: {
166     LLVMType elementType = processType(type->getArrayElementType());
167     if (!elementType)
168       return nullptr;
169     return LLVMType::getArrayTy(elementType, type->getArrayNumElements());
170   }
171   case llvm::Type::ScalableVectorTyID: {
172     emitError(unknownLoc) << "scalable vector types not supported";
173     return nullptr;
174   }
175   case llvm::Type::FixedVectorTyID: {
176     auto *typeVTy = llvm::cast<llvm::FixedVectorType>(type);
177     LLVMType elementType = processType(typeVTy->getElementType());
178     if (!elementType)
179       return nullptr;
180     return LLVMType::getVectorTy(elementType, typeVTy->getNumElements());
181   }
182   case llvm::Type::VoidTyID:
183     return LLVMType::getVoidTy(dialect);
184   case llvm::Type::FP128TyID:
185     return LLVMType::getFP128Ty(dialect);
186   case llvm::Type::X86_FP80TyID:
187     return LLVMType::getX86_FP80Ty(dialect);
188   case llvm::Type::StructTyID: {
189     SmallVector<LLVMType, 4> elementTypes;
190     elementTypes.reserve(type->getStructNumElements());
191     for (unsigned i = 0, e = type->getStructNumElements(); i != e; ++i) {
192       LLVMType ty = processType(type->getStructElementType(i));
193       if (!ty)
194         return nullptr;
195       elementTypes.push_back(ty);
196     }
197     return LLVMType::getStructTy(dialect, elementTypes,
198                                  cast<llvm::StructType>(type)->isPacked());
199   }
200   case llvm::Type::FunctionTyID: {
201     llvm::FunctionType *fty = cast<llvm::FunctionType>(type);
202     SmallVector<LLVMType, 4> paramTypes;
203     for (unsigned i = 0, e = fty->getNumParams(); i != e; ++i) {
204       LLVMType ty = processType(fty->getParamType(i));
205       if (!ty)
206         return nullptr;
207       paramTypes.push_back(ty);
208     }
209     LLVMType result = processType(fty->getReturnType());
210     if (!result)
211       return nullptr;
212 
213     return LLVMType::getFunctionTy(result, paramTypes, fty->isVarArg());
214   }
215   default: {
216     // FIXME: Diagnostic should be able to natively handle types that have
217     // operator<<(raw_ostream&) defined.
218     std::string s;
219     llvm::raw_string_ostream os(s);
220     os << *type;
221     emitError(unknownLoc) << "unhandled type: " << os.str();
222     return nullptr;
223   }
224   }
225 }
226 
227 // We only need integers, floats, doubles, and vectors and tensors thereof for
228 // attributes. Scalar and vector types are converted to the standard
229 // equivalents. Array types are converted to ranked tensors; nested array types
230 // are converted to multi-dimensional tensors or vectors, depending on the
231 // innermost type being a scalar or a vector.
232 Type Importer::getStdTypeForAttr(LLVMType type) {
233   if (!type)
234     return nullptr;
235 
236   if (type.isIntegerTy())
237     return b.getIntegerType(type.getUnderlyingType()->getIntegerBitWidth());
238 
239   if (type.getUnderlyingType()->isFloatTy())
240     return b.getF32Type();
241 
242   if (type.getUnderlyingType()->isDoubleTy())
243     return b.getF64Type();
244 
245   // LLVM vectors can only contain scalars.
246   if (type.isVectorTy()) {
247     auto numElements = llvm::cast<llvm::VectorType>(type.getUnderlyingType())
248                            ->getElementCount();
249     if (numElements.Scalable) {
250       emitError(unknownLoc) << "scalable vectors not supported";
251       return nullptr;
252     }
253     Type elementType = getStdTypeForAttr(type.getVectorElementType());
254     if (!elementType)
255       return nullptr;
256     return VectorType::get(numElements.Min, elementType);
257   }
258 
259   // LLVM arrays can contain other arrays or vectors.
260   if (type.isArrayTy()) {
261     // Recover the nested array shape.
262     SmallVector<int64_t, 4> shape;
263     shape.push_back(type.getArrayNumElements());
264     while (type.getArrayElementType().isArrayTy()) {
265       type = type.getArrayElementType();
266       shape.push_back(type.getArrayNumElements());
267     }
268 
269     // If the innermost type is a vector, use the multi-dimensional vector as
270     // attribute type.
271     if (type.getArrayElementType().isVectorTy()) {
272       LLVMType vectorType = type.getArrayElementType();
273       auto numElements =
274           llvm::cast<llvm::VectorType>(vectorType.getUnderlyingType())
275               ->getElementCount();
276       if (numElements.Scalable) {
277         emitError(unknownLoc) << "scalable vectors not supported";
278         return nullptr;
279       }
280       shape.push_back(numElements.Min);
281 
282       Type elementType = getStdTypeForAttr(vectorType.getVectorElementType());
283       if (!elementType)
284         return nullptr;
285       return VectorType::get(shape, elementType);
286     }
287 
288     // Otherwise use a tensor.
289     Type elementType = getStdTypeForAttr(type.getArrayElementType());
290     if (!elementType)
291       return nullptr;
292     return RankedTensorType::get(shape, elementType);
293   }
294 
295   return nullptr;
296 }
297 
298 // Get the given constant as an attribute. Not all constants can be represented
299 // as attributes.
300 Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
301   if (auto *ci = dyn_cast<llvm::ConstantInt>(value))
302     return b.getIntegerAttr(
303         IntegerType::get(ci->getType()->getBitWidth(), context),
304         ci->getValue());
305   if (auto *c = dyn_cast<llvm::ConstantDataArray>(value))
306     if (c->isString())
307       return b.getStringAttr(c->getAsString());
308   if (auto *c = dyn_cast<llvm::ConstantFP>(value)) {
309     if (c->getType()->isDoubleTy())
310       return b.getFloatAttr(FloatType::getF64(context), c->getValueAPF());
311     else if (c->getType()->isFloatingPointTy())
312       return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF());
313   }
314   if (auto *f = dyn_cast<llvm::Function>(value))
315     return b.getSymbolRefAttr(f->getName());
316 
317   // Convert constant data to a dense elements attribute.
318   if (auto *cd = dyn_cast<llvm::ConstantDataSequential>(value)) {
319     LLVMType type = processType(cd->getElementType());
320     if (!type)
321       return nullptr;
322 
323     auto attrType = getStdTypeForAttr(processType(cd->getType()))
324                         .dyn_cast_or_null<ShapedType>();
325     if (!attrType)
326       return nullptr;
327 
328     if (type.isIntegerTy()) {
329       SmallVector<APInt, 8> values;
330       values.reserve(cd->getNumElements());
331       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
332         values.push_back(cd->getElementAsAPInt(i));
333       return DenseElementsAttr::get(attrType, values);
334     }
335 
336     if (type.isFloatTy() || type.isDoubleTy()) {
337       SmallVector<APFloat, 8> values;
338       values.reserve(cd->getNumElements());
339       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
340         values.push_back(cd->getElementAsAPFloat(i));
341       return DenseElementsAttr::get(attrType, values);
342     }
343 
344     return nullptr;
345   }
346 
347   // Unpack constant aggregates to create dense elements attribute whenever
348   // possible. Return nullptr (failure) otherwise.
349   if (isa<llvm::ConstantAggregate>(value)) {
350     auto outerType = getStdTypeForAttr(processType(value->getType()))
351                          .dyn_cast_or_null<ShapedType>();
352     if (!outerType)
353       return nullptr;
354 
355     SmallVector<Attribute, 8> values;
356     SmallVector<int64_t, 8> shape;
357 
358     for (unsigned i = 0, e = value->getNumOperands(); i < e; ++i) {
359       auto nested = getConstantAsAttr(value->getAggregateElement(i))
360                         .dyn_cast_or_null<DenseElementsAttr>();
361       if (!nested)
362         return nullptr;
363 
364       values.append(nested.attr_value_begin(), nested.attr_value_end());
365     }
366 
367     return DenseElementsAttr::get(outerType, values);
368   }
369 
370   return nullptr;
371 }
372 
373 GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
374   auto it = globals.find(GV);
375   if (it != globals.end())
376     return it->second;
377 
378   OpBuilder b(module.getBody(), getGlobalInsertPt());
379   Attribute valueAttr;
380   if (GV->hasInitializer())
381     valueAttr = getConstantAsAttr(GV->getInitializer());
382   LLVMType type = processType(GV->getValueType());
383   if (!type)
384     return nullptr;
385   GlobalOp op = b.create<GlobalOp>(
386       UnknownLoc::get(context), type, GV->isConstant(),
387       convertLinkageFromLLVM(GV->getLinkage()), GV->getName(), valueAttr);
388   if (GV->hasInitializer() && !valueAttr) {
389     Region &r = op.getInitializerRegion();
390     currentEntryBlock = b.createBlock(&r);
391     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
392     Value v = processConstant(GV->getInitializer());
393     if (!v)
394       return nullptr;
395     b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v}));
396   }
397   return globals[GV] = op;
398 }
399 
400 Value Importer::processConstant(llvm::Constant *c) {
401   OpBuilder bEntry(currentEntryBlock, currentEntryBlock->begin());
402   if (Attribute attr = getConstantAsAttr(c)) {
403     // These constants can be represented as attributes.
404     OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
405     LLVMType type = processType(c->getType());
406     if (!type)
407       return nullptr;
408     if (auto symbolRef = attr.dyn_cast<FlatSymbolRefAttr>())
409       return instMap[c] = bEntry.create<AddressOfOp>(unknownLoc, type,
410                                                      symbolRef.getValue());
411     return instMap[c] = bEntry.create<ConstantOp>(unknownLoc, type, attr);
412   }
413   if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) {
414     LLVMType type = processType(cn->getType());
415     if (!type)
416       return nullptr;
417     return instMap[c] = bEntry.create<NullOp>(unknownLoc, type);
418   }
419   if (auto *GV = dyn_cast<llvm::GlobalVariable>(c))
420     return bEntry.create<AddressOfOp>(UnknownLoc::get(context),
421                                       processGlobal(GV));
422 
423   if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) {
424     llvm::Instruction *i = ce->getAsInstruction();
425     OpBuilder::InsertionGuard guard(b);
426     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
427     if (failed(processInstruction(i)))
428       return nullptr;
429     assert(instMap.count(i));
430 
431     // Remove this zombie LLVM instruction now, leaving us only with the MLIR
432     // op.
433     i->deleteValue();
434     return instMap[c] = instMap[i];
435   }
436   if (auto *ue = dyn_cast<llvm::UndefValue>(c)) {
437     LLVMType type = processType(ue->getType());
438     if (!type)
439       return nullptr;
440     return instMap[c] = bEntry.create<UndefOp>(UnknownLoc::get(context), type);
441   }
442   emitError(unknownLoc) << "unhandled constant: " << diag(*c);
443   return nullptr;
444 }
445 
446 Value Importer::processValue(llvm::Value *value) {
447   auto it = instMap.find(value);
448   if (it != instMap.end())
449     return it->second;
450 
451   // We don't expect to see instructions in dominator order. If we haven't seen
452   // this instruction yet, create an unknown op and remap it later.
453   if (isa<llvm::Instruction>(value)) {
454     OperationState state(UnknownLoc::get(context), "llvm.unknown");
455     LLVMType type = processType(value->getType());
456     if (!type)
457       return nullptr;
458     state.addTypes(type);
459     unknownInstMap[value] = b.createOperation(state);
460     return unknownInstMap[value]->getResult(0);
461   }
462 
463   if (auto *c = dyn_cast<llvm::Constant>(value))
464     return processConstant(c);
465 
466   emitError(unknownLoc) << "unhandled value: " << diag(*value);
467   return nullptr;
468 }
469 
470 /// Return the MLIR OperationName for the given LLVM opcode.
471 static StringRef lookupOperationNameFromOpcode(unsigned opcode) {
472 // Maps from LLVM opcode to MLIR OperationName. This is deliberately ordered
473 // as in llvm/IR/Instructions.def to aid comprehension and spot missing
474 // instructions.
475 #define INST(llvm_n, mlir_n)                                                   \
476   { llvm::Instruction::llvm_n, LLVM::mlir_n##Op::getOperationName() }
477   static const DenseMap<unsigned, StringRef> opcMap = {
478       // Ret is handled specially.
479       // Br is handled specially.
480       // FIXME: switch
481       // FIXME: indirectbr
482       // FIXME: invoke
483       INST(Resume, Resume),
484       // FIXME: unreachable
485       // FIXME: cleanupret
486       // FIXME: catchret
487       // FIXME: catchswitch
488       // FIXME: callbr
489       // FIXME: fneg
490       INST(Add, Add), INST(FAdd, FAdd), INST(Sub, Sub), INST(FSub, FSub),
491       INST(Mul, Mul), INST(FMul, FMul), INST(UDiv, UDiv), INST(SDiv, SDiv),
492       INST(FDiv, FDiv), INST(URem, URem), INST(SRem, SRem), INST(FRem, FRem),
493       INST(Shl, Shl), INST(LShr, LShr), INST(AShr, AShr), INST(And, And),
494       INST(Or, Or), INST(Xor, XOr), INST(Alloca, Alloca), INST(Load, Load),
495       INST(Store, Store),
496       // Getelementptr is handled specially.
497       INST(Ret, Return), INST(Fence, Fence),
498       // FIXME: atomiccmpxchg
499       // FIXME: atomicrmw
500       INST(Trunc, Trunc), INST(ZExt, ZExt), INST(SExt, SExt),
501       INST(FPToUI, FPToUI), INST(FPToSI, FPToSI), INST(UIToFP, UIToFP),
502       INST(SIToFP, SIToFP), INST(FPTrunc, FPTrunc), INST(FPExt, FPExt),
503       INST(PtrToInt, PtrToInt), INST(IntToPtr, IntToPtr),
504       INST(BitCast, Bitcast), INST(AddrSpaceCast, AddrSpaceCast),
505       // FIXME: cleanuppad
506       // FIXME: catchpad
507       // ICmp is handled specially.
508       // FIXME: fcmp
509       // PHI is handled specially.
510       INST(Freeze, Freeze), INST(Call, Call),
511       // FIXME: select
512       // FIXME: vaarg
513       // FIXME: extractelement
514       // FIXME: insertelement
515       // FIXME: shufflevector
516       // FIXME: extractvalue
517       // FIXME: insertvalue
518       // FIXME: landingpad
519   };
520 #undef INST
521 
522   return opcMap.lookup(opcode);
523 }
524 
525 static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) {
526   switch (p) {
527   default:
528     llvm_unreachable("incorrect comparison predicate");
529   case llvm::CmpInst::Predicate::ICMP_EQ:
530     return LLVM::ICmpPredicate::eq;
531   case llvm::CmpInst::Predicate::ICMP_NE:
532     return LLVM::ICmpPredicate::ne;
533   case llvm::CmpInst::Predicate::ICMP_SLT:
534     return LLVM::ICmpPredicate::slt;
535   case llvm::CmpInst::Predicate::ICMP_SLE:
536     return LLVM::ICmpPredicate::sle;
537   case llvm::CmpInst::Predicate::ICMP_SGT:
538     return LLVM::ICmpPredicate::sgt;
539   case llvm::CmpInst::Predicate::ICMP_SGE:
540     return LLVM::ICmpPredicate::sge;
541   case llvm::CmpInst::Predicate::ICMP_ULT:
542     return LLVM::ICmpPredicate::ult;
543   case llvm::CmpInst::Predicate::ICMP_ULE:
544     return LLVM::ICmpPredicate::ule;
545   case llvm::CmpInst::Predicate::ICMP_UGT:
546     return LLVM::ICmpPredicate::ugt;
547   case llvm::CmpInst::Predicate::ICMP_UGE:
548     return LLVM::ICmpPredicate::uge;
549   }
550   llvm_unreachable("incorrect comparison predicate");
551 }
552 
553 static AtomicOrdering getLLVMAtomicOrdering(llvm::AtomicOrdering ordering) {
554   switch (ordering) {
555   case llvm::AtomicOrdering::NotAtomic:
556     return LLVM::AtomicOrdering::not_atomic;
557   case llvm::AtomicOrdering::Unordered:
558     return LLVM::AtomicOrdering::unordered;
559   case llvm::AtomicOrdering::Monotonic:
560     return LLVM::AtomicOrdering::monotonic;
561   case llvm::AtomicOrdering::Acquire:
562     return LLVM::AtomicOrdering::acquire;
563   case llvm::AtomicOrdering::Release:
564     return LLVM::AtomicOrdering::release;
565   case llvm::AtomicOrdering::AcquireRelease:
566     return LLVM::AtomicOrdering::acq_rel;
567   case llvm::AtomicOrdering::SequentiallyConsistent:
568     return LLVM::AtomicOrdering::seq_cst;
569   }
570   llvm_unreachable("incorrect atomic ordering");
571 }
572 
573 // `br` branches to `target`. Return the branch arguments to `br`, in the
574 // same order of the PHIs in `target`.
575 LogicalResult
576 Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target,
577                             SmallVectorImpl<Value> &blockArguments) {
578   for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) {
579     auto *PN = cast<llvm::PHINode>(&*inst);
580     Value value = processValue(PN->getIncomingValueForBlock(br->getParent()));
581     if (!value)
582       return failure();
583     blockArguments.push_back(value);
584   }
585   return success();
586 }
587 
588 LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
589   // FIXME: Support uses of SubtargetData. Currently inbounds GEPs, fast-math
590   // flags and call / operand attributes are not supported.
591   Location loc = processDebugLoc(inst->getDebugLoc(), inst);
592   Value &v = instMap[inst];
593   assert(!v && "processInstruction must be called only once per instruction!");
594   switch (inst->getOpcode()) {
595   default:
596     return emitError(loc) << "unknown instruction: " << diag(*inst);
597   case llvm::Instruction::Add:
598   case llvm::Instruction::FAdd:
599   case llvm::Instruction::Sub:
600   case llvm::Instruction::FSub:
601   case llvm::Instruction::Mul:
602   case llvm::Instruction::FMul:
603   case llvm::Instruction::UDiv:
604   case llvm::Instruction::SDiv:
605   case llvm::Instruction::FDiv:
606   case llvm::Instruction::URem:
607   case llvm::Instruction::SRem:
608   case llvm::Instruction::FRem:
609   case llvm::Instruction::Shl:
610   case llvm::Instruction::LShr:
611   case llvm::Instruction::AShr:
612   case llvm::Instruction::And:
613   case llvm::Instruction::Or:
614   case llvm::Instruction::Xor:
615   case llvm::Instruction::Alloca:
616   case llvm::Instruction::Load:
617   case llvm::Instruction::Store:
618   case llvm::Instruction::Ret:
619   case llvm::Instruction::Resume:
620   case llvm::Instruction::Trunc:
621   case llvm::Instruction::ZExt:
622   case llvm::Instruction::SExt:
623   case llvm::Instruction::FPToUI:
624   case llvm::Instruction::FPToSI:
625   case llvm::Instruction::UIToFP:
626   case llvm::Instruction::SIToFP:
627   case llvm::Instruction::FPTrunc:
628   case llvm::Instruction::FPExt:
629   case llvm::Instruction::PtrToInt:
630   case llvm::Instruction::IntToPtr:
631   case llvm::Instruction::AddrSpaceCast:
632   case llvm::Instruction::Freeze:
633   case llvm::Instruction::BitCast: {
634     OperationState state(loc, lookupOperationNameFromOpcode(inst->getOpcode()));
635     SmallVector<Value, 4> ops;
636     ops.reserve(inst->getNumOperands());
637     for (auto *op : inst->operand_values()) {
638       Value value = processValue(op);
639       if (!value)
640         return failure();
641       ops.push_back(value);
642     }
643     state.addOperands(ops);
644     if (!inst->getType()->isVoidTy()) {
645       LLVMType type = processType(inst->getType());
646       if (!type)
647         return failure();
648       state.addTypes(type);
649     }
650     Operation *op = b.createOperation(state);
651     if (!inst->getType()->isVoidTy())
652       v = op->getResult(0);
653     return success();
654   }
655   case llvm::Instruction::ICmp: {
656     Value lhs = processValue(inst->getOperand(0));
657     Value rhs = processValue(inst->getOperand(1));
658     if (!lhs || !rhs)
659       return failure();
660     v = b.create<ICmpOp>(
661         loc, getICmpPredicate(cast<llvm::ICmpInst>(inst)->getPredicate()), lhs,
662         rhs);
663     return success();
664   }
665   case llvm::Instruction::Br: {
666     auto *brInst = cast<llvm::BranchInst>(inst);
667     OperationState state(loc,
668                          brInst->isConditional() ? "llvm.cond_br" : "llvm.br");
669     if (brInst->isConditional()) {
670       Value condition = processValue(brInst->getCondition());
671       if (!condition)
672         return failure();
673       state.addOperands(condition);
674     }
675 
676     std::array<int32_t, 3> operandSegmentSizes = {1, 0, 0};
677     for (int i : llvm::seq<int>(0, brInst->getNumSuccessors())) {
678       auto *succ = brInst->getSuccessor(i);
679       SmallVector<Value, 4> blockArguments;
680       if (failed(processBranchArgs(brInst, succ, blockArguments)))
681         return failure();
682       state.addSuccessors(blocks[succ]);
683       state.addOperands(blockArguments);
684       operandSegmentSizes[i + 1] = blockArguments.size();
685     }
686 
687     if (brInst->isConditional()) {
688       state.addAttribute(LLVM::CondBrOp::getOperandSegmentSizeAttr(),
689                          b.getI32VectorAttr(operandSegmentSizes));
690     }
691 
692     b.createOperation(state);
693     return success();
694   }
695   case llvm::Instruction::PHI: {
696     LLVMType type = processType(inst->getType());
697     if (!type)
698       return failure();
699     v = b.getInsertionBlock()->addArgument(type);
700     return success();
701   }
702   case llvm::Instruction::Call: {
703     llvm::CallInst *ci = cast<llvm::CallInst>(inst);
704     SmallVector<Value, 4> ops;
705     ops.reserve(inst->getNumOperands());
706     for (auto &op : ci->arg_operands()) {
707       Value arg = processValue(op.get());
708       if (!arg)
709         return failure();
710       ops.push_back(arg);
711     }
712 
713     SmallVector<Type, 2> tys;
714     if (!ci->getType()->isVoidTy()) {
715       LLVMType type = processType(inst->getType());
716       if (!type)
717         return failure();
718       tys.push_back(type);
719     }
720     Operation *op;
721     if (llvm::Function *callee = ci->getCalledFunction()) {
722       op = b.create<CallOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
723                             ops);
724     } else {
725       Value calledValue = processValue(ci->getCalledOperand());
726       if (!calledValue)
727         return failure();
728       ops.insert(ops.begin(), calledValue);
729       op = b.create<CallOp>(loc, tys, ops);
730     }
731     if (!ci->getType()->isVoidTy())
732       v = op->getResult(0);
733     return success();
734   }
735   case llvm::Instruction::LandingPad: {
736     llvm::LandingPadInst *lpi = cast<llvm::LandingPadInst>(inst);
737     SmallVector<Value, 4> ops;
738 
739     for (unsigned i = 0, ie = lpi->getNumClauses(); i < ie; i++)
740       ops.push_back(processConstant(lpi->getClause(i)));
741 
742     Type ty = processType(lpi->getType());
743     if (!ty)
744       return failure();
745 
746     v = b.create<LandingpadOp>(loc, ty, lpi->isCleanup(), ops);
747     return success();
748   }
749   case llvm::Instruction::Invoke: {
750     llvm::InvokeInst *ii = cast<llvm::InvokeInst>(inst);
751 
752     SmallVector<Type, 2> tys;
753     if (!ii->getType()->isVoidTy())
754       tys.push_back(processType(inst->getType()));
755 
756     SmallVector<Value, 4> ops;
757     ops.reserve(inst->getNumOperands() + 1);
758     for (auto &op : ii->arg_operands())
759       ops.push_back(processValue(op.get()));
760 
761     SmallVector<Value, 4> normalArgs, unwindArgs;
762     processBranchArgs(ii, ii->getNormalDest(), normalArgs);
763     processBranchArgs(ii, ii->getUnwindDest(), unwindArgs);
764 
765     Operation *op;
766     if (llvm::Function *callee = ii->getCalledFunction()) {
767       op = b.create<InvokeOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
768                               ops, blocks[ii->getNormalDest()], normalArgs,
769                               blocks[ii->getUnwindDest()], unwindArgs);
770     } else {
771       ops.insert(ops.begin(), processValue(ii->getCalledOperand()));
772       op = b.create<InvokeOp>(loc, tys, ops, blocks[ii->getNormalDest()],
773                               normalArgs, blocks[ii->getUnwindDest()],
774                               unwindArgs);
775     }
776 
777     if (!ii->getType()->isVoidTy())
778       v = op->getResult(0);
779     return success();
780   }
781   case llvm::Instruction::Fence: {
782     StringRef syncscope;
783     SmallVector<StringRef, 4> ssNs;
784     llvm::LLVMContext &llvmContext = dialect->getLLVMContext();
785     llvm::FenceInst *fence = cast<llvm::FenceInst>(inst);
786     llvmContext.getSyncScopeNames(ssNs);
787     int fenceSyncScopeID = fence->getSyncScopeID();
788     for (unsigned i = 0, e = ssNs.size(); i != e; i++) {
789       if (fenceSyncScopeID == llvmContext.getOrInsertSyncScopeID(ssNs[i])) {
790         syncscope = ssNs[i];
791         break;
792       }
793     }
794     b.create<FenceOp>(loc, getLLVMAtomicOrdering(fence->getOrdering()),
795                       syncscope);
796     return success();
797   }
798   case llvm::Instruction::GetElementPtr: {
799     // FIXME: Support inbounds GEPs.
800     llvm::GetElementPtrInst *gep = cast<llvm::GetElementPtrInst>(inst);
801     SmallVector<Value, 4> ops;
802     for (auto *op : gep->operand_values()) {
803       Value value = processValue(op);
804       if (!value)
805         return failure();
806       ops.push_back(value);
807     }
808     Type type = processType(inst->getType());
809     if (!type)
810       return failure();
811     v = b.create<GEPOp>(loc, type, ops);
812     return success();
813   }
814   }
815 }
816 
817 FlatSymbolRefAttr Importer::getPersonalityAsAttr(llvm::Function *f) {
818   if (!f->hasPersonalityFn())
819     return nullptr;
820 
821   llvm::Constant *pf = f->getPersonalityFn();
822 
823   // If it directly has a name, we can use it.
824   if (pf->hasName())
825     return b.getSymbolRefAttr(pf->getName());
826 
827   // If it doesn't have a name, currently, only function pointers that are
828   // bitcast to i8* are parsed.
829   if (auto ce = dyn_cast<llvm::ConstantExpr>(pf)) {
830     if (ce->getOpcode() == llvm::Instruction::BitCast &&
831         ce->getType() == llvm::Type::getInt8PtrTy(dialect->getLLVMContext())) {
832       if (auto func = dyn_cast<llvm::Function>(ce->getOperand(0)))
833         return b.getSymbolRefAttr(func->getName());
834     }
835   }
836   return FlatSymbolRefAttr();
837 }
838 
839 LogicalResult Importer::processFunction(llvm::Function *f) {
840   blocks.clear();
841   instMap.clear();
842   unknownInstMap.clear();
843 
844   LLVMType functionType = processType(f->getFunctionType());
845   if (!functionType)
846     return failure();
847 
848   b.setInsertionPoint(module.getBody(), getFuncInsertPt());
849   LLVMFuncOp fop = b.create<LLVMFuncOp>(UnknownLoc::get(context), f->getName(),
850                                         functionType);
851 
852   if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f))
853     fop.setAttr(b.getIdentifier("personality"), personality);
854   else if (f->hasPersonalityFn())
855     emitWarning(UnknownLoc::get(context),
856                 "could not deduce personality, skipping it");
857 
858   if (f->isDeclaration())
859     return success();
860 
861   // Eagerly create all blocks.
862   SmallVector<Block *, 4> blockList;
863   for (llvm::BasicBlock &bb : *f) {
864     blockList.push_back(b.createBlock(&fop.body(), fop.body().end()));
865     blocks[&bb] = blockList.back();
866   }
867   currentEntryBlock = blockList[0];
868 
869   // Add function arguments to the entry block.
870   for (auto kv : llvm::enumerate(f->args()))
871     instMap[&kv.value()] = blockList[0]->addArgument(
872         functionType.getFunctionParamType(kv.index()));
873 
874   for (auto bbs : llvm::zip(*f, blockList)) {
875     if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs))))
876       return failure();
877   }
878 
879   // Now that all instructions are guaranteed to have been visited, ensure
880   // any unknown uses we encountered are remapped.
881   for (auto &llvmAndUnknown : unknownInstMap) {
882     assert(instMap.count(llvmAndUnknown.first));
883     Value newValue = instMap[llvmAndUnknown.first];
884     Value oldValue = llvmAndUnknown.second->getResult(0);
885     oldValue.replaceAllUsesWith(newValue);
886     llvmAndUnknown.second->erase();
887   }
888   return success();
889 }
890 
891 LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) {
892   b.setInsertionPointToStart(block);
893   for (llvm::Instruction &inst : *bb) {
894     if (failed(processInstruction(&inst)))
895       return failure();
896   }
897   return success();
898 }
899 
900 OwningModuleRef
901 mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
902                               MLIRContext *context) {
903   OwningModuleRef module(ModuleOp::create(
904       FileLineColLoc::get("", /*line=*/0, /*column=*/0, context)));
905 
906   Importer deserializer(context, module.get());
907   for (llvm::GlobalVariable &gv : llvmModule->globals()) {
908     if (!deserializer.processGlobal(&gv))
909       return {};
910   }
911   for (llvm::Function &f : llvmModule->functions()) {
912     if (failed(deserializer.processFunction(&f)))
913       return {};
914   }
915 
916   return module;
917 }
918 
919 // Deserializes the LLVM bitcode stored in `input` into an MLIR module in the
920 // LLVM dialect.
921 OwningModuleRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr,
922                                         MLIRContext *context) {
923   LLVMDialect *dialect = context->getRegisteredDialect<LLVMDialect>();
924   assert(dialect && "Could not find LLVMDialect?");
925 
926   llvm::SMDiagnostic err;
927   std::unique_ptr<llvm::Module> llvmModule =
928       llvm::parseIR(*sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), err,
929                     dialect->getLLVMContext());
930   if (!llvmModule) {
931     std::string errStr;
932     llvm::raw_string_ostream errStream(errStr);
933     err.print(/*ProgName=*/"", errStream);
934     emitError(UnknownLoc::get(context)) << errStream.str();
935     return {};
936   }
937   return translateLLVMIRToModule(std::move(llvmModule), context);
938 }
939 
940 namespace mlir {
941 void registerFromLLVMIRTranslation() {
942   TranslateToMLIRRegistration fromLLVM(
943       "import-llvm", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
944         return ::translateLLVMIRToModule(sourceMgr, context);
945       });
946 }
947 } // namespace mlir
948