//===- ConvertFromLLVMIR.cpp - LLVM IR to MLIR conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between LLVM IR and the MLIR LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Target/LLVMIR/Import.h"
#include "mlir/Translation.h"

#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Type.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/SourceMgr.h"

#include <memory>
32 
33 using namespace mlir;
34 using namespace mlir::LLVM;
35 
36 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
37 
38 // Utility to print an LLVM value as a string for passing to emitError().
39 // FIXME: Diagnostic should be able to natively handle types that have
40 // operator << (raw_ostream&) defined.
41 static std::string diag(llvm::Value &v) {
42   std::string s;
43   llvm::raw_string_ostream os(s);
44   os << v;
45   return os.str();
46 }
47 
48 namespace mlir {
49 namespace LLVM {
50 namespace detail {
51 /// Support for translating LLVM IR types to MLIR LLVM dialect types.
52 class TypeFromLLVMIRTranslatorImpl {
53 public:
54   /// Constructs a class creating types in the given MLIR context.
55   TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {}
56 
57   /// Translates the given type.
58   Type translateType(llvm::Type *type) {
59     if (knownTranslations.count(type))
60       return knownTranslations.lookup(type);
61 
62     Type translated =
63         llvm::TypeSwitch<llvm::Type *, Type>(type)
64             .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
65                   llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
66                   llvm::ScalableVectorType>(
67                 [this](auto *type) { return this->translate(type); })
68             .Default([this](llvm::Type *type) {
69               return translatePrimitiveType(type);
70             });
71     knownTranslations.try_emplace(type, translated);
72     return translated;
73   }
74 
75 private:
76   /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
77   /// type.
78   Type translatePrimitiveType(llvm::Type *type) {
79     if (type->isVoidTy())
80       return LLVM::LLVMVoidType::get(&context);
81     if (type->isHalfTy())
82       return Float16Type::get(&context);
83     if (type->isBFloatTy())
84       return BFloat16Type::get(&context);
85     if (type->isFloatTy())
86       return Float32Type::get(&context);
87     if (type->isDoubleTy())
88       return Float64Type::get(&context);
89     if (type->isFP128Ty())
90       return Float128Type::get(&context);
91     if (type->isX86_FP80Ty())
92       return Float80Type::get(&context);
93     if (type->isPPC_FP128Ty())
94       return LLVM::LLVMPPCFP128Type::get(&context);
95     if (type->isX86_MMXTy())
96       return LLVM::LLVMX86MMXType::get(&context);
97     if (type->isLabelTy())
98       return LLVM::LLVMLabelType::get(&context);
99     if (type->isMetadataTy())
100       return LLVM::LLVMMetadataType::get(&context);
101     llvm_unreachable("not a primitive type");
102   }
103 
104   /// Translates the given array type.
105   Type translate(llvm::ArrayType *type) {
106     return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
107                                     type->getNumElements());
108   }
109 
110   /// Translates the given function type.
111   Type translate(llvm::FunctionType *type) {
112     SmallVector<Type, 8> paramTypes;
113     translateTypes(type->params(), paramTypes);
114     return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
115                                        paramTypes, type->isVarArg());
116   }
117 
118   /// Translates the given integer type.
119   Type translate(llvm::IntegerType *type) {
120     return IntegerType::get(&context, type->getBitWidth());
121   }
122 
123   /// Translates the given pointer type.
124   Type translate(llvm::PointerType *type) {
125     return LLVM::LLVMPointerType::get(translateType(type->getElementType()),
126                                       type->getAddressSpace());
127   }
128 
129   /// Translates the given structure type.
130   Type translate(llvm::StructType *type) {
131     SmallVector<Type, 8> subtypes;
132     if (type->isLiteral()) {
133       translateTypes(type->subtypes(), subtypes);
134       return LLVM::LLVMStructType::getLiteral(&context, subtypes,
135                                               type->isPacked());
136     }
137 
138     if (type->isOpaque())
139       return LLVM::LLVMStructType::getOpaque(type->getName(), &context);
140 
141     LLVM::LLVMStructType translated =
142         LLVM::LLVMStructType::getIdentified(&context, type->getName());
143     knownTranslations.try_emplace(type, translated);
144     translateTypes(type->subtypes(), subtypes);
145     LogicalResult bodySet = translated.setBody(subtypes, type->isPacked());
146     assert(succeeded(bodySet) &&
147            "could not set the body of an identified struct");
148     (void)bodySet;
149     return translated;
150   }
151 
152   /// Translates the given fixed-vector type.
153   Type translate(llvm::FixedVectorType *type) {
154     return LLVM::getFixedVectorType(translateType(type->getElementType()),
155                                     type->getNumElements());
156   }
157 
158   /// Translates the given scalable-vector type.
159   Type translate(llvm::ScalableVectorType *type) {
160     return LLVM::LLVMScalableVectorType::get(
161         translateType(type->getElementType()), type->getMinNumElements());
162   }
163 
164   /// Translates a list of types.
165   void translateTypes(ArrayRef<llvm::Type *> types,
166                       SmallVectorImpl<Type> &result) {
167     result.reserve(result.size() + types.size());
168     for (llvm::Type *type : types)
169       result.push_back(translateType(type));
170   }
171 
172   /// Map of known translations. Serves as a cache and as recursion stopper for
173   /// translating recursive structs.
174   llvm::DenseMap<llvm::Type *, Type> knownTranslations;
175 
176   /// The context in which MLIR types are created.
177   MLIRContext &context;
178 };
179 } // end namespace detail
180 
181 /// Utility class to translate LLVM IR types to the MLIR LLVM dialect. Stores
182 /// the translation state, in particular any identified structure types that are
183 /// reused across translations.
184 class TypeFromLLVMIRTranslator {
185 public:
186   TypeFromLLVMIRTranslator(MLIRContext &context);
187   ~TypeFromLLVMIRTranslator();
188 
189   /// Translates the given LLVM IR type to the MLIR LLVM dialect.
190   Type translateType(llvm::Type *type);
191 
192 private:
193   /// Private implementation.
194   std::unique_ptr<detail::TypeFromLLVMIRTranslatorImpl> impl;
195 };
196 
197 } // end namespace LLVM
198 } // end namespace mlir
199 
200 LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context)
201     : impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {}
202 
203 LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() {}
204 
205 Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
206   return impl->translateType(type);
207 }
208 
209 // Handles importing globals and functions from an LLVM module.
210 namespace {
211 class Importer {
212 public:
213   Importer(MLIRContext *context, ModuleOp module)
214       : b(context), context(context), module(module),
215         unknownLoc(FileLineColLoc::get(context, "imported-bitcode", 0, 0)),
216         typeTranslator(*context) {
217     b.setInsertionPointToStart(module.getBody());
218   }
219 
220   /// Imports `f` into the current module.
221   LogicalResult processFunction(llvm::Function *f);
222 
223   /// Imports GV as a GlobalOp, creating it if it doesn't exist.
224   GlobalOp processGlobal(llvm::GlobalVariable *GV);
225 
226 private:
227   /// Returns personality of `f` as a FlatSymbolRefAttr.
228   FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *f);
229   /// Imports `bb` into `block`, which must be initially empty.
230   LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block);
231   /// Imports `inst` and populates instMap[inst] with the imported Value.
232   LogicalResult processInstruction(llvm::Instruction *inst);
233   /// Creates an LLVM-compatible MLIR type for `type`.
234   Type processType(llvm::Type *type);
235   /// `value` is an SSA-use. Return the remapped version of `value` or a
236   /// placeholder that will be remapped later if this is an instruction that
237   /// has not yet been visited.
238   Value processValue(llvm::Value *value);
239   /// Create the most accurate Location possible using a llvm::DebugLoc and
240   /// possibly an llvm::Instruction to narrow the Location if debug information
241   /// is unavailable.
242   Location processDebugLoc(const llvm::DebugLoc &loc,
243                            llvm::Instruction *inst = nullptr);
244   /// `br` branches to `target`. Append the block arguments to attach to the
245   /// generated branch op to `blockArguments`. These should be in the same order
246   /// as the PHIs in `target`.
247   LogicalResult processBranchArgs(llvm::Instruction *br,
248                                   llvm::BasicBlock *target,
249                                   SmallVectorImpl<Value> &blockArguments);
250   /// Returns the builtin type equivalent to be used in attributes for the given
251   /// LLVM IR dialect type.
252   Type getStdTypeForAttr(Type type);
253   /// Return `value` as an attribute to attach to a GlobalOp.
254   Attribute getConstantAsAttr(llvm::Constant *value);
255   /// Return `c` as an MLIR Value. This could either be a ConstantOp, or
256   /// an expanded sequence of ops in the current function's entry block (for
257   /// ConstantExprs or ConstantGEPs).
258   Value processConstant(llvm::Constant *c);
259 
260   /// The current builder, pointing at where the next Instruction should be
261   /// generated.
262   OpBuilder b;
263   /// The current context.
264   MLIRContext *context;
265   /// The current module being created.
266   ModuleOp module;
267   /// The entry block of the current function being processed.
268   Block *currentEntryBlock;
269 
270   /// Globals are inserted before the first function, if any.
271   Block::iterator getGlobalInsertPt() {
272     auto it = module.getBody()->begin();
273     auto endIt = module.getBody()->end();
274     while (it != endIt && !isa<LLVMFuncOp>(it))
275       ++it;
276     return it;
277   }
278 
279   /// Functions are always inserted before the module terminator.
280   Block::iterator getFuncInsertPt() {
281     return std::prev(module.getBody()->end());
282   }
283 
284   /// Remapped blocks, for the current function.
285   DenseMap<llvm::BasicBlock *, Block *> blocks;
286   /// Remapped values. These are function-local.
287   DenseMap<llvm::Value *, Value> instMap;
288   /// Instructions that had not been defined when first encountered as a use.
289   /// Maps to the dummy Operation that was created in processValue().
290   DenseMap<llvm::Value *, Operation *> unknownInstMap;
291   /// Uniquing map of GlobalVariables.
292   DenseMap<llvm::GlobalVariable *, GlobalOp> globals;
293   /// Cached FileLineColLoc::get("imported-bitcode", 0, 0).
294   Location unknownLoc;
295   /// The stateful type translator (contains named structs).
296   LLVM::TypeFromLLVMIRTranslator typeTranslator;
297 };
298 } // namespace
299 
300 Location Importer::processDebugLoc(const llvm::DebugLoc &loc,
301                                    llvm::Instruction *inst) {
302   if (!loc && inst) {
303     std::string s;
304     llvm::raw_string_ostream os(s);
305     os << "llvm-imported-inst-%";
306     inst->printAsOperand(os, /*PrintType=*/false);
307     return FileLineColLoc::get(context, os.str(), 0, 0);
308   } else if (!loc) {
309     return unknownLoc;
310   }
311   // FIXME: Obtain the filename from DILocationInfo.
312   return FileLineColLoc::get(context, "imported-bitcode", loc.getLine(),
313                              loc.getCol());
314 }
315 
316 Type Importer::processType(llvm::Type *type) {
317   if (Type result = typeTranslator.translateType(type))
318     return result;
319 
320   // FIXME: Diagnostic should be able to natively handle types that have
321   // operator<<(raw_ostream&) defined.
322   std::string s;
323   llvm::raw_string_ostream os(s);
324   os << *type;
325   emitError(unknownLoc) << "unhandled type: " << os.str();
326   return nullptr;
327 }
328 
329 // We only need integers, floats, doubles, and vectors and tensors thereof for
330 // attributes. Scalar and vector types are converted to the standard
331 // equivalents. Array types are converted to ranked tensors; nested array types
332 // are converted to multi-dimensional tensors or vectors, depending on the
333 // innermost type being a scalar or a vector.
334 Type Importer::getStdTypeForAttr(Type type) {
335   if (!type)
336     return nullptr;
337 
338   if (type.isa<IntegerType, FloatType>())
339     return type;
340 
341   // LLVM vectors can only contain scalars.
342   if (LLVM::isCompatibleVectorType(type)) {
343     auto numElements = LLVM::getVectorNumElements(type);
344     if (numElements.isScalable()) {
345       emitError(unknownLoc) << "scalable vectors not supported";
346       return nullptr;
347     }
348     Type elementType = getStdTypeForAttr(LLVM::getVectorElementType(type));
349     if (!elementType)
350       return nullptr;
351     return VectorType::get(numElements.getKnownMinValue(), elementType);
352   }
353 
354   // LLVM arrays can contain other arrays or vectors.
355   if (auto arrayType = type.dyn_cast<LLVMArrayType>()) {
356     // Recover the nested array shape.
357     SmallVector<int64_t, 4> shape;
358     shape.push_back(arrayType.getNumElements());
359     while (arrayType.getElementType().isa<LLVMArrayType>()) {
360       arrayType = arrayType.getElementType().cast<LLVMArrayType>();
361       shape.push_back(arrayType.getNumElements());
362     }
363 
364     // If the innermost type is a vector, use the multi-dimensional vector as
365     // attribute type.
366     if (LLVM::isCompatibleVectorType(arrayType.getElementType())) {
367       auto numElements = LLVM::getVectorNumElements(arrayType.getElementType());
368       if (numElements.isScalable()) {
369         emitError(unknownLoc) << "scalable vectors not supported";
370         return nullptr;
371       }
372       shape.push_back(numElements.getKnownMinValue());
373 
374       Type elementType = getStdTypeForAttr(
375           LLVM::getVectorElementType(arrayType.getElementType()));
376       if (!elementType)
377         return nullptr;
378       return VectorType::get(shape, elementType);
379     }
380 
381     // Otherwise use a tensor.
382     Type elementType = getStdTypeForAttr(arrayType.getElementType());
383     if (!elementType)
384       return nullptr;
385     return RankedTensorType::get(shape, elementType);
386   }
387 
388   return nullptr;
389 }
390 
391 // Get the given constant as an attribute. Not all constants can be represented
392 // as attributes.
393 Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
394   if (auto *ci = dyn_cast<llvm::ConstantInt>(value))
395     return b.getIntegerAttr(
396         IntegerType::get(context, ci->getType()->getBitWidth()),
397         ci->getValue());
398   if (auto *c = dyn_cast<llvm::ConstantDataArray>(value))
399     if (c->isString())
400       return b.getStringAttr(c->getAsString());
401   if (auto *c = dyn_cast<llvm::ConstantFP>(value)) {
402     if (c->getType()->isDoubleTy())
403       return b.getFloatAttr(FloatType::getF64(context), c->getValueAPF());
404     if (c->getType()->isFloatingPointTy())
405       return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF());
406   }
407   if (auto *f = dyn_cast<llvm::Function>(value))
408     return b.getSymbolRefAttr(f->getName());
409 
410   // Convert constant data to a dense elements attribute.
411   if (auto *cd = dyn_cast<llvm::ConstantDataSequential>(value)) {
412     Type type = processType(cd->getElementType());
413     if (!type)
414       return nullptr;
415 
416     auto attrType = getStdTypeForAttr(processType(cd->getType()))
417                         .dyn_cast_or_null<ShapedType>();
418     if (!attrType)
419       return nullptr;
420 
421     if (type.isa<IntegerType>()) {
422       SmallVector<APInt, 8> values;
423       values.reserve(cd->getNumElements());
424       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
425         values.push_back(cd->getElementAsAPInt(i));
426       return DenseElementsAttr::get(attrType, values);
427     }
428 
429     if (type.isa<Float32Type, Float64Type>()) {
430       SmallVector<APFloat, 8> values;
431       values.reserve(cd->getNumElements());
432       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
433         values.push_back(cd->getElementAsAPFloat(i));
434       return DenseElementsAttr::get(attrType, values);
435     }
436 
437     return nullptr;
438   }
439 
440   // Unpack constant aggregates to create dense elements attribute whenever
441   // possible. Return nullptr (failure) otherwise.
442   if (isa<llvm::ConstantAggregate>(value)) {
443     auto outerType = getStdTypeForAttr(processType(value->getType()))
444                          .dyn_cast_or_null<ShapedType>();
445     if (!outerType)
446       return nullptr;
447 
448     SmallVector<Attribute, 8> values;
449     SmallVector<int64_t, 8> shape;
450 
451     for (unsigned i = 0, e = value->getNumOperands(); i < e; ++i) {
452       auto nested = getConstantAsAttr(value->getAggregateElement(i))
453                         .dyn_cast_or_null<DenseElementsAttr>();
454       if (!nested)
455         return nullptr;
456 
457       values.append(nested.attr_value_begin(), nested.attr_value_end());
458     }
459 
460     return DenseElementsAttr::get(outerType, values);
461   }
462 
463   return nullptr;
464 }
465 
466 GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
467   auto it = globals.find(GV);
468   if (it != globals.end())
469     return it->second;
470 
471   OpBuilder b(module.getBody(), getGlobalInsertPt());
472   Attribute valueAttr;
473   if (GV->hasInitializer())
474     valueAttr = getConstantAsAttr(GV->getInitializer());
475   Type type = processType(GV->getValueType());
476   if (!type)
477     return nullptr;
478 
479   uint64_t alignment = 0;
480   llvm::MaybeAlign maybeAlign = GV->getAlign();
481   if (maybeAlign.hasValue()) {
482     llvm::Align align = maybeAlign.getValue();
483     alignment = align.value();
484   }
485 
486   GlobalOp op =
487       b.create<GlobalOp>(UnknownLoc::get(context), type, GV->isConstant(),
488                          convertLinkageFromLLVM(GV->getLinkage()),
489                          GV->getName(), valueAttr, alignment);
490 
491   if (GV->hasInitializer() && !valueAttr) {
492     Region &r = op.getInitializerRegion();
493     currentEntryBlock = b.createBlock(&r);
494     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
495     Value v = processConstant(GV->getInitializer());
496     if (!v)
497       return nullptr;
498     b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v}));
499   }
500   if (GV->hasAtLeastLocalUnnamedAddr())
501     op.unnamed_addrAttr(UnnamedAddrAttr::get(
502         context, convertUnnamedAddrFromLLVM(GV->getUnnamedAddr())));
503   if (GV->hasSection())
504     op.sectionAttr(b.getStringAttr(GV->getSection()));
505 
506   return globals[GV] = op;
507 }
508 
509 Value Importer::processConstant(llvm::Constant *c) {
510   OpBuilder bEntry(currentEntryBlock, currentEntryBlock->begin());
511   if (Attribute attr = getConstantAsAttr(c)) {
512     // These constants can be represented as attributes.
513     OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
514     Type type = processType(c->getType());
515     if (!type)
516       return nullptr;
517     if (auto symbolRef = attr.dyn_cast<FlatSymbolRefAttr>())
518       return instMap[c] = bEntry.create<AddressOfOp>(unknownLoc, type,
519                                                      symbolRef.getValue());
520     return instMap[c] = bEntry.create<ConstantOp>(unknownLoc, type, attr);
521   }
522   if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) {
523     Type type = processType(cn->getType());
524     if (!type)
525       return nullptr;
526     return instMap[c] = bEntry.create<NullOp>(unknownLoc, type);
527   }
528   if (auto *GV = dyn_cast<llvm::GlobalVariable>(c))
529     return bEntry.create<AddressOfOp>(UnknownLoc::get(context),
530                                       processGlobal(GV));
531 
532   if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) {
533     llvm::Instruction *i = ce->getAsInstruction();
534     OpBuilder::InsertionGuard guard(b);
535     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
536     if (failed(processInstruction(i)))
537       return nullptr;
538     assert(instMap.count(i));
539 
540     // Remove this zombie LLVM instruction now, leaving us only with the MLIR
541     // op.
542     i->deleteValue();
543     return instMap[c] = instMap[i];
544   }
545   if (auto *ue = dyn_cast<llvm::UndefValue>(c)) {
546     Type type = processType(ue->getType());
547     if (!type)
548       return nullptr;
549     return instMap[c] = bEntry.create<UndefOp>(UnknownLoc::get(context), type);
550   }
551   emitError(unknownLoc) << "unhandled constant: " << diag(*c);
552   return nullptr;
553 }
554 
555 Value Importer::processValue(llvm::Value *value) {
556   auto it = instMap.find(value);
557   if (it != instMap.end())
558     return it->second;
559 
560   // We don't expect to see instructions in dominator order. If we haven't seen
561   // this instruction yet, create an unknown op and remap it later.
562   if (isa<llvm::Instruction>(value)) {
563     OperationState state(UnknownLoc::get(context), "llvm.unknown");
564     Type type = processType(value->getType());
565     if (!type)
566       return nullptr;
567     state.addTypes(type);
568     unknownInstMap[value] = b.createOperation(state);
569     return unknownInstMap[value]->getResult(0);
570   }
571 
572   if (auto *c = dyn_cast<llvm::Constant>(value))
573     return processConstant(c);
574 
575   emitError(unknownLoc) << "unhandled value: " << diag(*value);
576   return nullptr;
577 }
578 
579 /// Return the MLIR OperationName for the given LLVM opcode.
580 static StringRef lookupOperationNameFromOpcode(unsigned opcode) {
581 // Maps from LLVM opcode to MLIR OperationName. This is deliberately ordered
582 // as in llvm/IR/Instructions.def to aid comprehension and spot missing
583 // instructions.
584 #define INST(llvm_n, mlir_n)                                                   \
585   { llvm::Instruction::llvm_n, LLVM::mlir_n##Op::getOperationName() }
586   static const DenseMap<unsigned, StringRef> opcMap = {
587       // Ret is handled specially.
588       // Br is handled specially.
589       // FIXME: switch
590       // FIXME: indirectbr
591       // FIXME: invoke
592       INST(Resume, Resume),
593       // FIXME: unreachable
594       // FIXME: cleanupret
595       // FIXME: catchret
596       // FIXME: catchswitch
597       // FIXME: callbr
598       // FIXME: fneg
599       INST(Add, Add), INST(FAdd, FAdd), INST(Sub, Sub), INST(FSub, FSub),
600       INST(Mul, Mul), INST(FMul, FMul), INST(UDiv, UDiv), INST(SDiv, SDiv),
601       INST(FDiv, FDiv), INST(URem, URem), INST(SRem, SRem), INST(FRem, FRem),
602       INST(Shl, Shl), INST(LShr, LShr), INST(AShr, AShr), INST(And, And),
603       INST(Or, Or), INST(Xor, XOr), INST(Alloca, Alloca), INST(Load, Load),
604       INST(Store, Store),
605       // Getelementptr is handled specially.
606       INST(Ret, Return), INST(Fence, Fence),
607       // FIXME: atomiccmpxchg
608       // FIXME: atomicrmw
609       INST(Trunc, Trunc), INST(ZExt, ZExt), INST(SExt, SExt),
610       INST(FPToUI, FPToUI), INST(FPToSI, FPToSI), INST(UIToFP, UIToFP),
611       INST(SIToFP, SIToFP), INST(FPTrunc, FPTrunc), INST(FPExt, FPExt),
612       INST(PtrToInt, PtrToInt), INST(IntToPtr, IntToPtr),
613       INST(BitCast, Bitcast), INST(AddrSpaceCast, AddrSpaceCast),
614       // FIXME: cleanuppad
615       // FIXME: catchpad
616       // ICmp is handled specially.
617       // FIXME: fcmp
618       // PHI is handled specially.
619       INST(Freeze, Freeze), INST(Call, Call),
620       // FIXME: select
621       // FIXME: vaarg
622       // FIXME: extractelement
623       // FIXME: insertelement
624       // FIXME: shufflevector
625       // FIXME: extractvalue
626       // FIXME: insertvalue
627       // FIXME: landingpad
628   };
629 #undef INST
630 
631   return opcMap.lookup(opcode);
632 }
633 
634 static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) {
635   switch (p) {
636   default:
637     llvm_unreachable("incorrect comparison predicate");
638   case llvm::CmpInst::Predicate::ICMP_EQ:
639     return LLVM::ICmpPredicate::eq;
640   case llvm::CmpInst::Predicate::ICMP_NE:
641     return LLVM::ICmpPredicate::ne;
642   case llvm::CmpInst::Predicate::ICMP_SLT:
643     return LLVM::ICmpPredicate::slt;
644   case llvm::CmpInst::Predicate::ICMP_SLE:
645     return LLVM::ICmpPredicate::sle;
646   case llvm::CmpInst::Predicate::ICMP_SGT:
647     return LLVM::ICmpPredicate::sgt;
648   case llvm::CmpInst::Predicate::ICMP_SGE:
649     return LLVM::ICmpPredicate::sge;
650   case llvm::CmpInst::Predicate::ICMP_ULT:
651     return LLVM::ICmpPredicate::ult;
652   case llvm::CmpInst::Predicate::ICMP_ULE:
653     return LLVM::ICmpPredicate::ule;
654   case llvm::CmpInst::Predicate::ICMP_UGT:
655     return LLVM::ICmpPredicate::ugt;
656   case llvm::CmpInst::Predicate::ICMP_UGE:
657     return LLVM::ICmpPredicate::uge;
658   }
659   llvm_unreachable("incorrect comparison predicate");
660 }
661 
662 static AtomicOrdering getLLVMAtomicOrdering(llvm::AtomicOrdering ordering) {
663   switch (ordering) {
664   case llvm::AtomicOrdering::NotAtomic:
665     return LLVM::AtomicOrdering::not_atomic;
666   case llvm::AtomicOrdering::Unordered:
667     return LLVM::AtomicOrdering::unordered;
668   case llvm::AtomicOrdering::Monotonic:
669     return LLVM::AtomicOrdering::monotonic;
670   case llvm::AtomicOrdering::Acquire:
671     return LLVM::AtomicOrdering::acquire;
672   case llvm::AtomicOrdering::Release:
673     return LLVM::AtomicOrdering::release;
674   case llvm::AtomicOrdering::AcquireRelease:
675     return LLVM::AtomicOrdering::acq_rel;
676   case llvm::AtomicOrdering::SequentiallyConsistent:
677     return LLVM::AtomicOrdering::seq_cst;
678   }
679   llvm_unreachable("incorrect atomic ordering");
680 }
681 
682 // `br` branches to `target`. Return the branch arguments to `br`, in the
683 // same order of the PHIs in `target`.
684 LogicalResult
685 Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target,
686                             SmallVectorImpl<Value> &blockArguments) {
687   for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) {
688     auto *PN = cast<llvm::PHINode>(&*inst);
689     Value value = processValue(PN->getIncomingValueForBlock(br->getParent()));
690     if (!value)
691       return failure();
692     blockArguments.push_back(value);
693   }
694   return success();
695 }
696 
697 LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
698   // FIXME: Support uses of SubtargetData. Currently inbounds GEPs, fast-math
699   // flags and call / operand attributes are not supported.
700   Location loc = processDebugLoc(inst->getDebugLoc(), inst);
701   Value &v = instMap[inst];
702   assert(!v && "processInstruction must be called only once per instruction!");
703   switch (inst->getOpcode()) {
704   default:
705     return emitError(loc) << "unknown instruction: " << diag(*inst);
706   case llvm::Instruction::Add:
707   case llvm::Instruction::FAdd:
708   case llvm::Instruction::Sub:
709   case llvm::Instruction::FSub:
710   case llvm::Instruction::Mul:
711   case llvm::Instruction::FMul:
712   case llvm::Instruction::UDiv:
713   case llvm::Instruction::SDiv:
714   case llvm::Instruction::FDiv:
715   case llvm::Instruction::URem:
716   case llvm::Instruction::SRem:
717   case llvm::Instruction::FRem:
718   case llvm::Instruction::Shl:
719   case llvm::Instruction::LShr:
720   case llvm::Instruction::AShr:
721   case llvm::Instruction::And:
722   case llvm::Instruction::Or:
723   case llvm::Instruction::Xor:
724   case llvm::Instruction::Alloca:
725   case llvm::Instruction::Load:
726   case llvm::Instruction::Store:
727   case llvm::Instruction::Ret:
728   case llvm::Instruction::Resume:
729   case llvm::Instruction::Trunc:
730   case llvm::Instruction::ZExt:
731   case llvm::Instruction::SExt:
732   case llvm::Instruction::FPToUI:
733   case llvm::Instruction::FPToSI:
734   case llvm::Instruction::UIToFP:
735   case llvm::Instruction::SIToFP:
736   case llvm::Instruction::FPTrunc:
737   case llvm::Instruction::FPExt:
738   case llvm::Instruction::PtrToInt:
739   case llvm::Instruction::IntToPtr:
740   case llvm::Instruction::AddrSpaceCast:
741   case llvm::Instruction::Freeze:
742   case llvm::Instruction::BitCast: {
743     OperationState state(loc, lookupOperationNameFromOpcode(inst->getOpcode()));
744     SmallVector<Value, 4> ops;
745     ops.reserve(inst->getNumOperands());
746     for (auto *op : inst->operand_values()) {
747       Value value = processValue(op);
748       if (!value)
749         return failure();
750       ops.push_back(value);
751     }
752     state.addOperands(ops);
753     if (!inst->getType()->isVoidTy()) {
754       Type type = processType(inst->getType());
755       if (!type)
756         return failure();
757       state.addTypes(type);
758     }
759     Operation *op = b.createOperation(state);
760     if (!inst->getType()->isVoidTy())
761       v = op->getResult(0);
762     return success();
763   }
764   case llvm::Instruction::ICmp: {
765     Value lhs = processValue(inst->getOperand(0));
766     Value rhs = processValue(inst->getOperand(1));
767     if (!lhs || !rhs)
768       return failure();
769     v = b.create<ICmpOp>(
770         loc, getICmpPredicate(cast<llvm::ICmpInst>(inst)->getPredicate()), lhs,
771         rhs);
772     return success();
773   }
774   case llvm::Instruction::Br: {
775     auto *brInst = cast<llvm::BranchInst>(inst);
776     OperationState state(loc,
777                          brInst->isConditional() ? "llvm.cond_br" : "llvm.br");
778     if (brInst->isConditional()) {
779       Value condition = processValue(brInst->getCondition());
780       if (!condition)
781         return failure();
782       state.addOperands(condition);
783     }
784 
785     std::array<int32_t, 3> operandSegmentSizes = {1, 0, 0};
786     for (int i : llvm::seq<int>(0, brInst->getNumSuccessors())) {
787       auto *succ = brInst->getSuccessor(i);
788       SmallVector<Value, 4> blockArguments;
789       if (failed(processBranchArgs(brInst, succ, blockArguments)))
790         return failure();
791       state.addSuccessors(blocks[succ]);
792       state.addOperands(blockArguments);
793       operandSegmentSizes[i + 1] = blockArguments.size();
794     }
795 
796     if (brInst->isConditional()) {
797       state.addAttribute(LLVM::CondBrOp::getOperandSegmentSizeAttr(),
798                          b.getI32VectorAttr(operandSegmentSizes));
799     }
800 
801     b.createOperation(state);
802     return success();
803   }
804   case llvm::Instruction::PHI: {
805     Type type = processType(inst->getType());
806     if (!type)
807       return failure();
808     v = b.getInsertionBlock()->addArgument(type);
809     return success();
810   }
811   case llvm::Instruction::Call: {
812     llvm::CallInst *ci = cast<llvm::CallInst>(inst);
813     SmallVector<Value, 4> ops;
814     ops.reserve(inst->getNumOperands());
815     for (auto &op : ci->arg_operands()) {
816       Value arg = processValue(op.get());
817       if (!arg)
818         return failure();
819       ops.push_back(arg);
820     }
821 
822     SmallVector<Type, 2> tys;
823     if (!ci->getType()->isVoidTy()) {
824       Type type = processType(inst->getType());
825       if (!type)
826         return failure();
827       tys.push_back(type);
828     }
829     Operation *op;
830     if (llvm::Function *callee = ci->getCalledFunction()) {
831       op = b.create<CallOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
832                             ops);
833     } else {
834       Value calledValue = processValue(ci->getCalledOperand());
835       if (!calledValue)
836         return failure();
837       ops.insert(ops.begin(), calledValue);
838       op = b.create<CallOp>(loc, tys, ops);
839     }
840     if (!ci->getType()->isVoidTy())
841       v = op->getResult(0);
842     return success();
843   }
844   case llvm::Instruction::LandingPad: {
845     llvm::LandingPadInst *lpi = cast<llvm::LandingPadInst>(inst);
846     SmallVector<Value, 4> ops;
847 
848     for (unsigned i = 0, ie = lpi->getNumClauses(); i < ie; i++)
849       ops.push_back(processConstant(lpi->getClause(i)));
850 
851     Type ty = processType(lpi->getType());
852     if (!ty)
853       return failure();
854 
855     v = b.create<LandingpadOp>(loc, ty, lpi->isCleanup(), ops);
856     return success();
857   }
858   case llvm::Instruction::Invoke: {
859     llvm::InvokeInst *ii = cast<llvm::InvokeInst>(inst);
860 
861     SmallVector<Type, 2> tys;
862     if (!ii->getType()->isVoidTy())
863       tys.push_back(processType(inst->getType()));
864 
865     SmallVector<Value, 4> ops;
866     ops.reserve(inst->getNumOperands() + 1);
867     for (auto &op : ii->arg_operands())
868       ops.push_back(processValue(op.get()));
869 
870     SmallVector<Value, 4> normalArgs, unwindArgs;
871     (void)processBranchArgs(ii, ii->getNormalDest(), normalArgs);
872     (void)processBranchArgs(ii, ii->getUnwindDest(), unwindArgs);
873 
874     Operation *op;
875     if (llvm::Function *callee = ii->getCalledFunction()) {
876       op = b.create<InvokeOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
877                               ops, blocks[ii->getNormalDest()], normalArgs,
878                               blocks[ii->getUnwindDest()], unwindArgs);
879     } else {
880       ops.insert(ops.begin(), processValue(ii->getCalledOperand()));
881       op = b.create<InvokeOp>(loc, tys, ops, blocks[ii->getNormalDest()],
882                               normalArgs, blocks[ii->getUnwindDest()],
883                               unwindArgs);
884     }
885 
886     if (!ii->getType()->isVoidTy())
887       v = op->getResult(0);
888     return success();
889   }
890   case llvm::Instruction::Fence: {
891     StringRef syncscope;
892     SmallVector<StringRef, 4> ssNs;
893     llvm::LLVMContext &llvmContext = inst->getContext();
894     llvm::FenceInst *fence = cast<llvm::FenceInst>(inst);
895     llvmContext.getSyncScopeNames(ssNs);
896     int fenceSyncScopeID = fence->getSyncScopeID();
897     for (unsigned i = 0, e = ssNs.size(); i != e; i++) {
898       if (fenceSyncScopeID == llvmContext.getOrInsertSyncScopeID(ssNs[i])) {
899         syncscope = ssNs[i];
900         break;
901       }
902     }
903     b.create<FenceOp>(loc, getLLVMAtomicOrdering(fence->getOrdering()),
904                       syncscope);
905     return success();
906   }
907   case llvm::Instruction::GetElementPtr: {
908     // FIXME: Support inbounds GEPs.
909     llvm::GetElementPtrInst *gep = cast<llvm::GetElementPtrInst>(inst);
910     SmallVector<Value, 4> ops;
911     for (auto *op : gep->operand_values()) {
912       Value value = processValue(op);
913       if (!value)
914         return failure();
915       ops.push_back(value);
916     }
917     Type type = processType(inst->getType());
918     if (!type)
919       return failure();
920     v = b.create<GEPOp>(loc, type, ops);
921     return success();
922   }
923   }
924 }
925 
926 FlatSymbolRefAttr Importer::getPersonalityAsAttr(llvm::Function *f) {
927   if (!f->hasPersonalityFn())
928     return nullptr;
929 
930   llvm::Constant *pf = f->getPersonalityFn();
931 
932   // If it directly has a name, we can use it.
933   if (pf->hasName())
934     return b.getSymbolRefAttr(pf->getName());
935 
936   // If it doesn't have a name, currently, only function pointers that are
937   // bitcast to i8* are parsed.
938   if (auto ce = dyn_cast<llvm::ConstantExpr>(pf)) {
939     if (ce->getOpcode() == llvm::Instruction::BitCast &&
940         ce->getType() == llvm::Type::getInt8PtrTy(f->getContext())) {
941       if (auto func = dyn_cast<llvm::Function>(ce->getOperand(0)))
942         return b.getSymbolRefAttr(func->getName());
943     }
944   }
945   return FlatSymbolRefAttr();
946 }
947 
948 LogicalResult Importer::processFunction(llvm::Function *f) {
949   blocks.clear();
950   instMap.clear();
951   unknownInstMap.clear();
952 
953   auto functionType =
954       processType(f->getFunctionType()).dyn_cast<LLVMFunctionType>();
955   if (!functionType)
956     return failure();
957 
958   b.setInsertionPoint(module.getBody(), getFuncInsertPt());
959   LLVMFuncOp fop =
960       b.create<LLVMFuncOp>(UnknownLoc::get(context), f->getName(), functionType,
961                            convertLinkageFromLLVM(f->getLinkage()));
962 
963   if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f))
964     fop->setAttr(b.getIdentifier("personality"), personality);
965   else if (f->hasPersonalityFn())
966     emitWarning(UnknownLoc::get(context),
967                 "could not deduce personality, skipping it");
968 
969   if (f->isDeclaration())
970     return success();
971 
972   // Eagerly create all blocks.
973   SmallVector<Block *, 4> blockList;
974   for (llvm::BasicBlock &bb : *f) {
975     blockList.push_back(b.createBlock(&fop.body(), fop.body().end()));
976     blocks[&bb] = blockList.back();
977   }
978   currentEntryBlock = blockList[0];
979 
980   // Add function arguments to the entry block.
981   for (auto kv : llvm::enumerate(f->args()))
982     instMap[&kv.value()] =
983         blockList[0]->addArgument(functionType.getParamType(kv.index()));
984 
985   for (auto bbs : llvm::zip(*f, blockList)) {
986     if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs))))
987       return failure();
988   }
989 
990   // Now that all instructions are guaranteed to have been visited, ensure
991   // any unknown uses we encountered are remapped.
992   for (auto &llvmAndUnknown : unknownInstMap) {
993     assert(instMap.count(llvmAndUnknown.first));
994     Value newValue = instMap[llvmAndUnknown.first];
995     Value oldValue = llvmAndUnknown.second->getResult(0);
996     oldValue.replaceAllUsesWith(newValue);
997     llvmAndUnknown.second->erase();
998   }
999   return success();
1000 }
1001 
1002 LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) {
1003   b.setInsertionPointToStart(block);
1004   for (llvm::Instruction &inst : *bb) {
1005     if (failed(processInstruction(&inst)))
1006       return failure();
1007   }
1008   return success();
1009 }
1010 
1011 OwningModuleRef
1012 mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
1013                               MLIRContext *context) {
1014   context->loadDialect<LLVMDialect>();
1015   OwningModuleRef module(ModuleOp::create(
1016       FileLineColLoc::get(context, "", /*line=*/0, /*column=*/0)));
1017 
1018   Importer deserializer(context, module.get());
1019   for (llvm::GlobalVariable &gv : llvmModule->globals()) {
1020     if (!deserializer.processGlobal(&gv))
1021       return {};
1022   }
1023   for (llvm::Function &f : llvmModule->functions()) {
1024     if (failed(deserializer.processFunction(&f)))
1025       return {};
1026   }
1027 
1028   return module;
1029 }
1030 
1031 // Deserializes the LLVM bitcode stored in `input` into an MLIR module in the
1032 // LLVM dialect.
1033 OwningModuleRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr,
1034                                         MLIRContext *context) {
1035   llvm::SMDiagnostic err;
1036   llvm::LLVMContext llvmContext;
1037   std::unique_ptr<llvm::Module> llvmModule = llvm::parseIR(
1038       *sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), err, llvmContext);
1039   if (!llvmModule) {
1040     std::string errStr;
1041     llvm::raw_string_ostream errStream(errStr);
1042     err.print(/*ProgName=*/"", errStream);
1043     emitError(UnknownLoc::get(context)) << errStream.str();
1044     return {};
1045   }
1046   return translateLLVMIRToModule(std::move(llvmModule), context);
1047 }
1048 
1049 namespace mlir {
1050 void registerFromLLVMIRTranslation() {
1051   TranslateToMLIRRegistration fromLLVM(
1052       "import-llvm", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
1053         return ::translateLLVMIRToModule(sourceMgr, context);
1054       });
1055 }
1056 } // namespace mlir
1057