//===- ConvertFromLLVMIR.cpp - LLVM IR to MLIR conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between LLVM IR and the MLIR LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/MLIRContext.h"
18 #include "mlir/Target/LLVMIR/Import.h"
19 #include "mlir/Translation.h"
20 
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/IR/Attributes.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/DerivedTypes.h"
25 #include "llvm/IR/Function.h"
26 #include "llvm/IR/InlineAsm.h"
27 #include "llvm/IR/Instructions.h"
28 #include "llvm/IR/Type.h"
29 #include "llvm/IRReader/IRReader.h"
30 #include "llvm/Support/Error.h"
31 #include "llvm/Support/SourceMgr.h"
32 
33 using namespace mlir;
34 using namespace mlir::LLVM;
35 
36 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
37 
38 // Utility to print an LLVM value as a string for passing to emitError().
39 // FIXME: Diagnostic should be able to natively handle types that have
40 // operator << (raw_ostream&) defined.
41 static std::string diag(llvm::Value &v) {
42   std::string s;
43   llvm::raw_string_ostream os(s);
44   os << v;
45   return os.str();
46 }
47 
48 namespace mlir {
49 namespace LLVM {
50 namespace detail {
51 /// Support for translating LLVM IR types to MLIR LLVM dialect types.
52 class TypeFromLLVMIRTranslatorImpl {
53 public:
54   /// Constructs a class creating types in the given MLIR context.
55   TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {}
56 
57   /// Translates the given type.
58   Type translateType(llvm::Type *type) {
59     if (knownTranslations.count(type))
60       return knownTranslations.lookup(type);
61 
62     Type translated =
63         llvm::TypeSwitch<llvm::Type *, Type>(type)
64             .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
65                   llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
66                   llvm::ScalableVectorType>(
67                 [this](auto *type) { return this->translate(type); })
68             .Default([this](llvm::Type *type) {
69               return translatePrimitiveType(type);
70             });
71     knownTranslations.try_emplace(type, translated);
72     return translated;
73   }
74 
75 private:
76   /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
77   /// type.
78   Type translatePrimitiveType(llvm::Type *type) {
79     if (type->isVoidTy())
80       return LLVM::LLVMVoidType::get(&context);
81     if (type->isHalfTy())
82       return Float16Type::get(&context);
83     if (type->isBFloatTy())
84       return BFloat16Type::get(&context);
85     if (type->isFloatTy())
86       return Float32Type::get(&context);
87     if (type->isDoubleTy())
88       return Float64Type::get(&context);
89     if (type->isFP128Ty())
90       return Float128Type::get(&context);
91     if (type->isX86_FP80Ty())
92       return Float80Type::get(&context);
93     if (type->isPPC_FP128Ty())
94       return LLVM::LLVMPPCFP128Type::get(&context);
95     if (type->isX86_MMXTy())
96       return LLVM::LLVMX86MMXType::get(&context);
97     if (type->isLabelTy())
98       return LLVM::LLVMLabelType::get(&context);
99     if (type->isMetadataTy())
100       return LLVM::LLVMMetadataType::get(&context);
101     llvm_unreachable("not a primitive type");
102   }
103 
104   /// Translates the given array type.
105   Type translate(llvm::ArrayType *type) {
106     return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
107                                     type->getNumElements());
108   }
109 
110   /// Translates the given function type.
111   Type translate(llvm::FunctionType *type) {
112     SmallVector<Type, 8> paramTypes;
113     translateTypes(type->params(), paramTypes);
114     return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
115                                        paramTypes, type->isVarArg());
116   }
117 
118   /// Translates the given integer type.
119   Type translate(llvm::IntegerType *type) {
120     return IntegerType::get(&context, type->getBitWidth());
121   }
122 
123   /// Translates the given pointer type.
124   Type translate(llvm::PointerType *type) {
125     return LLVM::LLVMPointerType::get(translateType(type->getElementType()),
126                                       type->getAddressSpace());
127   }
128 
129   /// Translates the given structure type.
130   Type translate(llvm::StructType *type) {
131     SmallVector<Type, 8> subtypes;
132     if (type->isLiteral()) {
133       translateTypes(type->subtypes(), subtypes);
134       return LLVM::LLVMStructType::getLiteral(&context, subtypes,
135                                               type->isPacked());
136     }
137 
138     if (type->isOpaque())
139       return LLVM::LLVMStructType::getOpaque(type->getName(), &context);
140 
141     LLVM::LLVMStructType translated =
142         LLVM::LLVMStructType::getIdentified(&context, type->getName());
143     knownTranslations.try_emplace(type, translated);
144     translateTypes(type->subtypes(), subtypes);
145     LogicalResult bodySet = translated.setBody(subtypes, type->isPacked());
146     assert(succeeded(bodySet) &&
147            "could not set the body of an identified struct");
148     (void)bodySet;
149     return translated;
150   }
151 
152   /// Translates the given fixed-vector type.
153   Type translate(llvm::FixedVectorType *type) {
154     return LLVM::getFixedVectorType(translateType(type->getElementType()),
155                                     type->getNumElements());
156   }
157 
158   /// Translates the given scalable-vector type.
159   Type translate(llvm::ScalableVectorType *type) {
160     return LLVM::LLVMScalableVectorType::get(
161         translateType(type->getElementType()), type->getMinNumElements());
162   }
163 
164   /// Translates a list of types.
165   void translateTypes(ArrayRef<llvm::Type *> types,
166                       SmallVectorImpl<Type> &result) {
167     result.reserve(result.size() + types.size());
168     for (llvm::Type *type : types)
169       result.push_back(translateType(type));
170   }
171 
172   /// Map of known translations. Serves as a cache and as recursion stopper for
173   /// translating recursive structs.
174   llvm::DenseMap<llvm::Type *, Type> knownTranslations;
175 
176   /// The context in which MLIR types are created.
177   MLIRContext &context;
178 };
179 } // end namespace detail
180 
181 /// Utility class to translate LLVM IR types to the MLIR LLVM dialect. Stores
182 /// the translation state, in particular any identified structure types that are
183 /// reused across translations.
184 class TypeFromLLVMIRTranslator {
185 public:
186   TypeFromLLVMIRTranslator(MLIRContext &context);
187   ~TypeFromLLVMIRTranslator();
188 
189   /// Translates the given LLVM IR type to the MLIR LLVM dialect.
190   Type translateType(llvm::Type *type);
191 
192 private:
193   /// Private implementation.
194   std::unique_ptr<detail::TypeFromLLVMIRTranslatorImpl> impl;
195 };
196 
197 } // end namespace LLVM
198 } // end namespace mlir
199 
200 LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context)
201     : impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {}
202 
203 LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() {}
204 
205 Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
206   return impl->translateType(type);
207 }
208 
209 // Handles importing globals and functions from an LLVM module.
210 namespace {
211 class Importer {
212 public:
213   Importer(MLIRContext *context, ModuleOp module)
214       : b(context), context(context), module(module),
215         unknownLoc(FileLineColLoc::get(context, "imported-bitcode", 0, 0)),
216         typeTranslator(*context) {
217     b.setInsertionPointToStart(module.getBody());
218   }
219 
220   /// Imports `f` into the current module.
221   LogicalResult processFunction(llvm::Function *f);
222 
223   /// Imports GV as a GlobalOp, creating it if it doesn't exist.
224   GlobalOp processGlobal(llvm::GlobalVariable *GV);
225 
226 private:
227   /// Returns personality of `f` as a FlatSymbolRefAttr.
228   FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *f);
229   /// Imports `bb` into `block`, which must be initially empty.
230   LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block);
231   /// Imports `inst` and populates instMap[inst] with the imported Value.
232   LogicalResult processInstruction(llvm::Instruction *inst);
233   /// Creates an LLVM-compatible MLIR type for `type`.
234   Type processType(llvm::Type *type);
235   /// `value` is an SSA-use. Return the remapped version of `value` or a
236   /// placeholder that will be remapped later if this is an instruction that
237   /// has not yet been visited.
238   Value processValue(llvm::Value *value);
239   /// Create the most accurate Location possible using a llvm::DebugLoc and
240   /// possibly an llvm::Instruction to narrow the Location if debug information
241   /// is unavailable.
242   Location processDebugLoc(const llvm::DebugLoc &loc,
243                            llvm::Instruction *inst = nullptr);
244   /// `br` branches to `target`. Append the block arguments to attach to the
245   /// generated branch op to `blockArguments`. These should be in the same order
246   /// as the PHIs in `target`.
247   LogicalResult processBranchArgs(llvm::Instruction *br,
248                                   llvm::BasicBlock *target,
249                                   SmallVectorImpl<Value> &blockArguments);
250   /// Returns the builtin type equivalent to be used in attributes for the given
251   /// LLVM IR dialect type.
252   Type getStdTypeForAttr(Type type);
253   /// Return `value` as an attribute to attach to a GlobalOp.
254   Attribute getConstantAsAttr(llvm::Constant *value);
255   /// Return `c` as an MLIR Value. This could either be a ConstantOp, or
256   /// an expanded sequence of ops in the current function's entry block (for
257   /// ConstantExprs or ConstantGEPs).
258   Value processConstant(llvm::Constant *c);
259 
260   /// The current builder, pointing at where the next Instruction should be
261   /// generated.
262   OpBuilder b;
263   /// The current context.
264   MLIRContext *context;
265   /// The current module being created.
266   ModuleOp module;
267   /// The entry block of the current function being processed.
268   Block *currentEntryBlock;
269 
270   /// Globals are inserted before the first function, if any.
271   Block::iterator getGlobalInsertPt() {
272     auto it = module.getBody()->begin();
273     auto endIt = module.getBody()->end();
274     while (it != endIt && !isa<LLVMFuncOp>(it))
275       ++it;
276     return it;
277   }
278 
279   /// Functions are always inserted before the module terminator.
280   Block::iterator getFuncInsertPt() {
281     return std::prev(module.getBody()->end());
282   }
283 
284   /// Remapped blocks, for the current function.
285   DenseMap<llvm::BasicBlock *, Block *> blocks;
286   /// Remapped values. These are function-local.
287   DenseMap<llvm::Value *, Value> instMap;
288   /// Instructions that had not been defined when first encountered as a use.
289   /// Maps to the dummy Operation that was created in processValue().
290   DenseMap<llvm::Value *, Operation *> unknownInstMap;
291   /// Uniquing map of GlobalVariables.
292   DenseMap<llvm::GlobalVariable *, GlobalOp> globals;
293   /// Cached FileLineColLoc::get("imported-bitcode", 0, 0).
294   Location unknownLoc;
295   /// The stateful type translator (contains named structs).
296   LLVM::TypeFromLLVMIRTranslator typeTranslator;
297 };
298 } // namespace
299 
300 Location Importer::processDebugLoc(const llvm::DebugLoc &loc,
301                                    llvm::Instruction *inst) {
302   if (!loc && inst) {
303     std::string s;
304     llvm::raw_string_ostream os(s);
305     os << "llvm-imported-inst-%";
306     inst->printAsOperand(os, /*PrintType=*/false);
307     return FileLineColLoc::get(context, os.str(), 0, 0);
308   } else if (!loc) {
309     return unknownLoc;
310   }
311   // FIXME: Obtain the filename from DILocationInfo.
312   return FileLineColLoc::get(context, "imported-bitcode", loc.getLine(),
313                              loc.getCol());
314 }
315 
316 Type Importer::processType(llvm::Type *type) {
317   if (Type result = typeTranslator.translateType(type))
318     return result;
319 
320   // FIXME: Diagnostic should be able to natively handle types that have
321   // operator<<(raw_ostream&) defined.
322   std::string s;
323   llvm::raw_string_ostream os(s);
324   os << *type;
325   emitError(unknownLoc) << "unhandled type: " << os.str();
326   return nullptr;
327 }
328 
329 // We only need integers, floats, doubles, and vectors and tensors thereof for
330 // attributes. Scalar and vector types are converted to the standard
331 // equivalents. Array types are converted to ranked tensors; nested array types
332 // are converted to multi-dimensional tensors or vectors, depending on the
333 // innermost type being a scalar or a vector.
334 Type Importer::getStdTypeForAttr(Type type) {
335   if (!type)
336     return nullptr;
337 
338   if (type.isa<IntegerType, FloatType>())
339     return type;
340 
341   // LLVM vectors can only contain scalars.
342   if (LLVM::isCompatibleVectorType(type)) {
343     auto numElements = LLVM::getVectorNumElements(type);
344     if (numElements.isScalable()) {
345       emitError(unknownLoc) << "scalable vectors not supported";
346       return nullptr;
347     }
348     Type elementType = getStdTypeForAttr(LLVM::getVectorElementType(type));
349     if (!elementType)
350       return nullptr;
351     return VectorType::get(numElements.getKnownMinValue(), elementType);
352   }
353 
354   // LLVM arrays can contain other arrays or vectors.
355   if (auto arrayType = type.dyn_cast<LLVMArrayType>()) {
356     // Recover the nested array shape.
357     SmallVector<int64_t, 4> shape;
358     shape.push_back(arrayType.getNumElements());
359     while (arrayType.getElementType().isa<LLVMArrayType>()) {
360       arrayType = arrayType.getElementType().cast<LLVMArrayType>();
361       shape.push_back(arrayType.getNumElements());
362     }
363 
364     // If the innermost type is a vector, use the multi-dimensional vector as
365     // attribute type.
366     if (LLVM::isCompatibleVectorType(arrayType.getElementType())) {
367       auto numElements = LLVM::getVectorNumElements(arrayType.getElementType());
368       if (numElements.isScalable()) {
369         emitError(unknownLoc) << "scalable vectors not supported";
370         return nullptr;
371       }
372       shape.push_back(numElements.getKnownMinValue());
373 
374       Type elementType = getStdTypeForAttr(
375           LLVM::getVectorElementType(arrayType.getElementType()));
376       if (!elementType)
377         return nullptr;
378       return VectorType::get(shape, elementType);
379     }
380 
381     // Otherwise use a tensor.
382     Type elementType = getStdTypeForAttr(arrayType.getElementType());
383     if (!elementType)
384       return nullptr;
385     return RankedTensorType::get(shape, elementType);
386   }
387 
388   return nullptr;
389 }
390 
391 // Get the given constant as an attribute. Not all constants can be represented
392 // as attributes.
393 Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
394   if (auto *ci = dyn_cast<llvm::ConstantInt>(value))
395     return b.getIntegerAttr(
396         IntegerType::get(context, ci->getType()->getBitWidth()),
397         ci->getValue());
398   if (auto *c = dyn_cast<llvm::ConstantDataArray>(value))
399     if (c->isString())
400       return b.getStringAttr(c->getAsString());
401   if (auto *c = dyn_cast<llvm::ConstantFP>(value)) {
402     if (c->getType()->isDoubleTy())
403       return b.getFloatAttr(FloatType::getF64(context), c->getValueAPF());
404     if (c->getType()->isFloatingPointTy())
405       return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF());
406   }
407   if (auto *f = dyn_cast<llvm::Function>(value))
408     return b.getSymbolRefAttr(f->getName());
409 
410   // Convert constant data to a dense elements attribute.
411   if (auto *cd = dyn_cast<llvm::ConstantDataSequential>(value)) {
412     Type type = processType(cd->getElementType());
413     if (!type)
414       return nullptr;
415 
416     auto attrType = getStdTypeForAttr(processType(cd->getType()))
417                         .dyn_cast_or_null<ShapedType>();
418     if (!attrType)
419       return nullptr;
420 
421     if (type.isa<IntegerType>()) {
422       SmallVector<APInt, 8> values;
423       values.reserve(cd->getNumElements());
424       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
425         values.push_back(cd->getElementAsAPInt(i));
426       return DenseElementsAttr::get(attrType, values);
427     }
428 
429     if (type.isa<Float32Type, Float64Type>()) {
430       SmallVector<APFloat, 8> values;
431       values.reserve(cd->getNumElements());
432       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
433         values.push_back(cd->getElementAsAPFloat(i));
434       return DenseElementsAttr::get(attrType, values);
435     }
436 
437     return nullptr;
438   }
439 
440   // Unpack constant aggregates to create dense elements attribute whenever
441   // possible. Return nullptr (failure) otherwise.
442   if (isa<llvm::ConstantAggregate>(value)) {
443     auto outerType = getStdTypeForAttr(processType(value->getType()))
444                          .dyn_cast_or_null<ShapedType>();
445     if (!outerType)
446       return nullptr;
447 
448     SmallVector<Attribute, 8> values;
449     SmallVector<int64_t, 8> shape;
450 
451     for (unsigned i = 0, e = value->getNumOperands(); i < e; ++i) {
452       auto nested = getConstantAsAttr(value->getAggregateElement(i))
453                         .dyn_cast_or_null<DenseElementsAttr>();
454       if (!nested)
455         return nullptr;
456 
457       values.append(nested.attr_value_begin(), nested.attr_value_end());
458     }
459 
460     return DenseElementsAttr::get(outerType, values);
461   }
462 
463   return nullptr;
464 }
465 
466 GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
467   auto it = globals.find(GV);
468   if (it != globals.end())
469     return it->second;
470 
471   OpBuilder b(module.getBody(), getGlobalInsertPt());
472   Attribute valueAttr;
473   if (GV->hasInitializer())
474     valueAttr = getConstantAsAttr(GV->getInitializer());
475   Type type = processType(GV->getValueType());
476   if (!type)
477     return nullptr;
478   GlobalOp op = b.create<GlobalOp>(
479       UnknownLoc::get(context), type, GV->isConstant(),
480       convertLinkageFromLLVM(GV->getLinkage()), GV->getName(), valueAttr);
481   if (GV->hasInitializer() && !valueAttr) {
482     Region &r = op.getInitializerRegion();
483     currentEntryBlock = b.createBlock(&r);
484     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
485     Value v = processConstant(GV->getInitializer());
486     if (!v)
487       return nullptr;
488     b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v}));
489   }
490   return globals[GV] = op;
491 }
492 
493 Value Importer::processConstant(llvm::Constant *c) {
494   OpBuilder bEntry(currentEntryBlock, currentEntryBlock->begin());
495   if (Attribute attr = getConstantAsAttr(c)) {
496     // These constants can be represented as attributes.
497     OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
498     Type type = processType(c->getType());
499     if (!type)
500       return nullptr;
501     if (auto symbolRef = attr.dyn_cast<FlatSymbolRefAttr>())
502       return instMap[c] = bEntry.create<AddressOfOp>(unknownLoc, type,
503                                                      symbolRef.getValue());
504     return instMap[c] = bEntry.create<ConstantOp>(unknownLoc, type, attr);
505   }
506   if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) {
507     Type type = processType(cn->getType());
508     if (!type)
509       return nullptr;
510     return instMap[c] = bEntry.create<NullOp>(unknownLoc, type);
511   }
512   if (auto *GV = dyn_cast<llvm::GlobalVariable>(c))
513     return bEntry.create<AddressOfOp>(UnknownLoc::get(context),
514                                       processGlobal(GV));
515 
516   if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) {
517     llvm::Instruction *i = ce->getAsInstruction();
518     OpBuilder::InsertionGuard guard(b);
519     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
520     if (failed(processInstruction(i)))
521       return nullptr;
522     assert(instMap.count(i));
523 
524     // Remove this zombie LLVM instruction now, leaving us only with the MLIR
525     // op.
526     i->deleteValue();
527     return instMap[c] = instMap[i];
528   }
529   if (auto *ue = dyn_cast<llvm::UndefValue>(c)) {
530     Type type = processType(ue->getType());
531     if (!type)
532       return nullptr;
533     return instMap[c] = bEntry.create<UndefOp>(UnknownLoc::get(context), type);
534   }
535   emitError(unknownLoc) << "unhandled constant: " << diag(*c);
536   return nullptr;
537 }
538 
539 Value Importer::processValue(llvm::Value *value) {
540   auto it = instMap.find(value);
541   if (it != instMap.end())
542     return it->second;
543 
544   // We don't expect to see instructions in dominator order. If we haven't seen
545   // this instruction yet, create an unknown op and remap it later.
546   if (isa<llvm::Instruction>(value)) {
547     OperationState state(UnknownLoc::get(context), "llvm.unknown");
548     Type type = processType(value->getType());
549     if (!type)
550       return nullptr;
551     state.addTypes(type);
552     unknownInstMap[value] = b.createOperation(state);
553     return unknownInstMap[value]->getResult(0);
554   }
555 
556   if (auto *c = dyn_cast<llvm::Constant>(value))
557     return processConstant(c);
558 
559   emitError(unknownLoc) << "unhandled value: " << diag(*value);
560   return nullptr;
561 }
562 
563 /// Return the MLIR OperationName for the given LLVM opcode.
564 static StringRef lookupOperationNameFromOpcode(unsigned opcode) {
565 // Maps from LLVM opcode to MLIR OperationName. This is deliberately ordered
566 // as in llvm/IR/Instructions.def to aid comprehension and spot missing
567 // instructions.
568 #define INST(llvm_n, mlir_n)                                                   \
569   { llvm::Instruction::llvm_n, LLVM::mlir_n##Op::getOperationName() }
570   static const DenseMap<unsigned, StringRef> opcMap = {
571       // Ret is handled specially.
572       // Br is handled specially.
573       // FIXME: switch
574       // FIXME: indirectbr
575       // FIXME: invoke
576       INST(Resume, Resume),
577       // FIXME: unreachable
578       // FIXME: cleanupret
579       // FIXME: catchret
580       // FIXME: catchswitch
581       // FIXME: callbr
582       // FIXME: fneg
583       INST(Add, Add), INST(FAdd, FAdd), INST(Sub, Sub), INST(FSub, FSub),
584       INST(Mul, Mul), INST(FMul, FMul), INST(UDiv, UDiv), INST(SDiv, SDiv),
585       INST(FDiv, FDiv), INST(URem, URem), INST(SRem, SRem), INST(FRem, FRem),
586       INST(Shl, Shl), INST(LShr, LShr), INST(AShr, AShr), INST(And, And),
587       INST(Or, Or), INST(Xor, XOr), INST(Alloca, Alloca), INST(Load, Load),
588       INST(Store, Store),
589       // Getelementptr is handled specially.
590       INST(Ret, Return), INST(Fence, Fence),
591       // FIXME: atomiccmpxchg
592       // FIXME: atomicrmw
593       INST(Trunc, Trunc), INST(ZExt, ZExt), INST(SExt, SExt),
594       INST(FPToUI, FPToUI), INST(FPToSI, FPToSI), INST(UIToFP, UIToFP),
595       INST(SIToFP, SIToFP), INST(FPTrunc, FPTrunc), INST(FPExt, FPExt),
596       INST(PtrToInt, PtrToInt), INST(IntToPtr, IntToPtr),
597       INST(BitCast, Bitcast), INST(AddrSpaceCast, AddrSpaceCast),
598       // FIXME: cleanuppad
599       // FIXME: catchpad
600       // ICmp is handled specially.
601       // FIXME: fcmp
602       // PHI is handled specially.
603       INST(Freeze, Freeze), INST(Call, Call),
604       // FIXME: select
605       // FIXME: vaarg
606       // FIXME: extractelement
607       // FIXME: insertelement
608       // FIXME: shufflevector
609       // FIXME: extractvalue
610       // FIXME: insertvalue
611       // FIXME: landingpad
612   };
613 #undef INST
614 
615   return opcMap.lookup(opcode);
616 }
617 
618 static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) {
619   switch (p) {
620   default:
621     llvm_unreachable("incorrect comparison predicate");
622   case llvm::CmpInst::Predicate::ICMP_EQ:
623     return LLVM::ICmpPredicate::eq;
624   case llvm::CmpInst::Predicate::ICMP_NE:
625     return LLVM::ICmpPredicate::ne;
626   case llvm::CmpInst::Predicate::ICMP_SLT:
627     return LLVM::ICmpPredicate::slt;
628   case llvm::CmpInst::Predicate::ICMP_SLE:
629     return LLVM::ICmpPredicate::sle;
630   case llvm::CmpInst::Predicate::ICMP_SGT:
631     return LLVM::ICmpPredicate::sgt;
632   case llvm::CmpInst::Predicate::ICMP_SGE:
633     return LLVM::ICmpPredicate::sge;
634   case llvm::CmpInst::Predicate::ICMP_ULT:
635     return LLVM::ICmpPredicate::ult;
636   case llvm::CmpInst::Predicate::ICMP_ULE:
637     return LLVM::ICmpPredicate::ule;
638   case llvm::CmpInst::Predicate::ICMP_UGT:
639     return LLVM::ICmpPredicate::ugt;
640   case llvm::CmpInst::Predicate::ICMP_UGE:
641     return LLVM::ICmpPredicate::uge;
642   }
643   llvm_unreachable("incorrect comparison predicate");
644 }
645 
646 static AtomicOrdering getLLVMAtomicOrdering(llvm::AtomicOrdering ordering) {
647   switch (ordering) {
648   case llvm::AtomicOrdering::NotAtomic:
649     return LLVM::AtomicOrdering::not_atomic;
650   case llvm::AtomicOrdering::Unordered:
651     return LLVM::AtomicOrdering::unordered;
652   case llvm::AtomicOrdering::Monotonic:
653     return LLVM::AtomicOrdering::monotonic;
654   case llvm::AtomicOrdering::Acquire:
655     return LLVM::AtomicOrdering::acquire;
656   case llvm::AtomicOrdering::Release:
657     return LLVM::AtomicOrdering::release;
658   case llvm::AtomicOrdering::AcquireRelease:
659     return LLVM::AtomicOrdering::acq_rel;
660   case llvm::AtomicOrdering::SequentiallyConsistent:
661     return LLVM::AtomicOrdering::seq_cst;
662   }
663   llvm_unreachable("incorrect atomic ordering");
664 }
665 
666 // `br` branches to `target`. Return the branch arguments to `br`, in the
667 // same order of the PHIs in `target`.
668 LogicalResult
669 Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target,
670                             SmallVectorImpl<Value> &blockArguments) {
671   for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) {
672     auto *PN = cast<llvm::PHINode>(&*inst);
673     Value value = processValue(PN->getIncomingValueForBlock(br->getParent()));
674     if (!value)
675       return failure();
676     blockArguments.push_back(value);
677   }
678   return success();
679 }
680 
681 LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
682   // FIXME: Support uses of SubtargetData. Currently inbounds GEPs, fast-math
683   // flags and call / operand attributes are not supported.
684   Location loc = processDebugLoc(inst->getDebugLoc(), inst);
685   Value &v = instMap[inst];
686   assert(!v && "processInstruction must be called only once per instruction!");
687   switch (inst->getOpcode()) {
688   default:
689     return emitError(loc) << "unknown instruction: " << diag(*inst);
690   case llvm::Instruction::Add:
691   case llvm::Instruction::FAdd:
692   case llvm::Instruction::Sub:
693   case llvm::Instruction::FSub:
694   case llvm::Instruction::Mul:
695   case llvm::Instruction::FMul:
696   case llvm::Instruction::UDiv:
697   case llvm::Instruction::SDiv:
698   case llvm::Instruction::FDiv:
699   case llvm::Instruction::URem:
700   case llvm::Instruction::SRem:
701   case llvm::Instruction::FRem:
702   case llvm::Instruction::Shl:
703   case llvm::Instruction::LShr:
704   case llvm::Instruction::AShr:
705   case llvm::Instruction::And:
706   case llvm::Instruction::Or:
707   case llvm::Instruction::Xor:
708   case llvm::Instruction::Alloca:
709   case llvm::Instruction::Load:
710   case llvm::Instruction::Store:
711   case llvm::Instruction::Ret:
712   case llvm::Instruction::Resume:
713   case llvm::Instruction::Trunc:
714   case llvm::Instruction::ZExt:
715   case llvm::Instruction::SExt:
716   case llvm::Instruction::FPToUI:
717   case llvm::Instruction::FPToSI:
718   case llvm::Instruction::UIToFP:
719   case llvm::Instruction::SIToFP:
720   case llvm::Instruction::FPTrunc:
721   case llvm::Instruction::FPExt:
722   case llvm::Instruction::PtrToInt:
723   case llvm::Instruction::IntToPtr:
724   case llvm::Instruction::AddrSpaceCast:
725   case llvm::Instruction::Freeze:
726   case llvm::Instruction::BitCast: {
727     OperationState state(loc, lookupOperationNameFromOpcode(inst->getOpcode()));
728     SmallVector<Value, 4> ops;
729     ops.reserve(inst->getNumOperands());
730     for (auto *op : inst->operand_values()) {
731       Value value = processValue(op);
732       if (!value)
733         return failure();
734       ops.push_back(value);
735     }
736     state.addOperands(ops);
737     if (!inst->getType()->isVoidTy()) {
738       Type type = processType(inst->getType());
739       if (!type)
740         return failure();
741       state.addTypes(type);
742     }
743     Operation *op = b.createOperation(state);
744     if (!inst->getType()->isVoidTy())
745       v = op->getResult(0);
746     return success();
747   }
748   case llvm::Instruction::ICmp: {
749     Value lhs = processValue(inst->getOperand(0));
750     Value rhs = processValue(inst->getOperand(1));
751     if (!lhs || !rhs)
752       return failure();
753     v = b.create<ICmpOp>(
754         loc, getICmpPredicate(cast<llvm::ICmpInst>(inst)->getPredicate()), lhs,
755         rhs);
756     return success();
757   }
758   case llvm::Instruction::Br: {
759     auto *brInst = cast<llvm::BranchInst>(inst);
760     OperationState state(loc,
761                          brInst->isConditional() ? "llvm.cond_br" : "llvm.br");
762     if (brInst->isConditional()) {
763       Value condition = processValue(brInst->getCondition());
764       if (!condition)
765         return failure();
766       state.addOperands(condition);
767     }
768 
769     std::array<int32_t, 3> operandSegmentSizes = {1, 0, 0};
770     for (int i : llvm::seq<int>(0, brInst->getNumSuccessors())) {
771       auto *succ = brInst->getSuccessor(i);
772       SmallVector<Value, 4> blockArguments;
773       if (failed(processBranchArgs(brInst, succ, blockArguments)))
774         return failure();
775       state.addSuccessors(blocks[succ]);
776       state.addOperands(blockArguments);
777       operandSegmentSizes[i + 1] = blockArguments.size();
778     }
779 
780     if (brInst->isConditional()) {
781       state.addAttribute(LLVM::CondBrOp::getOperandSegmentSizeAttr(),
782                          b.getI32VectorAttr(operandSegmentSizes));
783     }
784 
785     b.createOperation(state);
786     return success();
787   }
788   case llvm::Instruction::PHI: {
789     Type type = processType(inst->getType());
790     if (!type)
791       return failure();
792     v = b.getInsertionBlock()->addArgument(type);
793     return success();
794   }
795   case llvm::Instruction::Call: {
796     llvm::CallInst *ci = cast<llvm::CallInst>(inst);
797     SmallVector<Value, 4> ops;
798     ops.reserve(inst->getNumOperands());
799     for (auto &op : ci->arg_operands()) {
800       Value arg = processValue(op.get());
801       if (!arg)
802         return failure();
803       ops.push_back(arg);
804     }
805 
806     SmallVector<Type, 2> tys;
807     if (!ci->getType()->isVoidTy()) {
808       Type type = processType(inst->getType());
809       if (!type)
810         return failure();
811       tys.push_back(type);
812     }
813     Operation *op;
814     if (llvm::Function *callee = ci->getCalledFunction()) {
815       op = b.create<CallOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
816                             ops);
817     } else {
818       Value calledValue = processValue(ci->getCalledOperand());
819       if (!calledValue)
820         return failure();
821       ops.insert(ops.begin(), calledValue);
822       op = b.create<CallOp>(loc, tys, ops);
823     }
824     if (!ci->getType()->isVoidTy())
825       v = op->getResult(0);
826     return success();
827   }
828   case llvm::Instruction::LandingPad: {
829     llvm::LandingPadInst *lpi = cast<llvm::LandingPadInst>(inst);
830     SmallVector<Value, 4> ops;
831 
832     for (unsigned i = 0, ie = lpi->getNumClauses(); i < ie; i++)
833       ops.push_back(processConstant(lpi->getClause(i)));
834 
835     Type ty = processType(lpi->getType());
836     if (!ty)
837       return failure();
838 
839     v = b.create<LandingpadOp>(loc, ty, lpi->isCleanup(), ops);
840     return success();
841   }
842   case llvm::Instruction::Invoke: {
843     llvm::InvokeInst *ii = cast<llvm::InvokeInst>(inst);
844 
845     SmallVector<Type, 2> tys;
846     if (!ii->getType()->isVoidTy())
847       tys.push_back(processType(inst->getType()));
848 
849     SmallVector<Value, 4> ops;
850     ops.reserve(inst->getNumOperands() + 1);
851     for (auto &op : ii->arg_operands())
852       ops.push_back(processValue(op.get()));
853 
854     SmallVector<Value, 4> normalArgs, unwindArgs;
855     (void)processBranchArgs(ii, ii->getNormalDest(), normalArgs);
856     (void)processBranchArgs(ii, ii->getUnwindDest(), unwindArgs);
857 
858     Operation *op;
859     if (llvm::Function *callee = ii->getCalledFunction()) {
860       op = b.create<InvokeOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
861                               ops, blocks[ii->getNormalDest()], normalArgs,
862                               blocks[ii->getUnwindDest()], unwindArgs);
863     } else {
864       ops.insert(ops.begin(), processValue(ii->getCalledOperand()));
865       op = b.create<InvokeOp>(loc, tys, ops, blocks[ii->getNormalDest()],
866                               normalArgs, blocks[ii->getUnwindDest()],
867                               unwindArgs);
868     }
869 
870     if (!ii->getType()->isVoidTy())
871       v = op->getResult(0);
872     return success();
873   }
874   case llvm::Instruction::Fence: {
875     StringRef syncscope;
876     SmallVector<StringRef, 4> ssNs;
877     llvm::LLVMContext &llvmContext = inst->getContext();
878     llvm::FenceInst *fence = cast<llvm::FenceInst>(inst);
879     llvmContext.getSyncScopeNames(ssNs);
880     int fenceSyncScopeID = fence->getSyncScopeID();
881     for (unsigned i = 0, e = ssNs.size(); i != e; i++) {
882       if (fenceSyncScopeID == llvmContext.getOrInsertSyncScopeID(ssNs[i])) {
883         syncscope = ssNs[i];
884         break;
885       }
886     }
887     b.create<FenceOp>(loc, getLLVMAtomicOrdering(fence->getOrdering()),
888                       syncscope);
889     return success();
890   }
891   case llvm::Instruction::GetElementPtr: {
892     // FIXME: Support inbounds GEPs.
893     llvm::GetElementPtrInst *gep = cast<llvm::GetElementPtrInst>(inst);
894     SmallVector<Value, 4> ops;
895     for (auto *op : gep->operand_values()) {
896       Value value = processValue(op);
897       if (!value)
898         return failure();
899       ops.push_back(value);
900     }
901     Type type = processType(inst->getType());
902     if (!type)
903       return failure();
904     v = b.create<GEPOp>(loc, type, ops);
905     return success();
906   }
907   }
908 }
909 
910 FlatSymbolRefAttr Importer::getPersonalityAsAttr(llvm::Function *f) {
911   if (!f->hasPersonalityFn())
912     return nullptr;
913 
914   llvm::Constant *pf = f->getPersonalityFn();
915 
916   // If it directly has a name, we can use it.
917   if (pf->hasName())
918     return b.getSymbolRefAttr(pf->getName());
919 
920   // If it doesn't have a name, currently, only function pointers that are
921   // bitcast to i8* are parsed.
922   if (auto ce = dyn_cast<llvm::ConstantExpr>(pf)) {
923     if (ce->getOpcode() == llvm::Instruction::BitCast &&
924         ce->getType() == llvm::Type::getInt8PtrTy(f->getContext())) {
925       if (auto func = dyn_cast<llvm::Function>(ce->getOperand(0)))
926         return b.getSymbolRefAttr(func->getName());
927     }
928   }
929   return FlatSymbolRefAttr();
930 }
931 
932 LogicalResult Importer::processFunction(llvm::Function *f) {
933   blocks.clear();
934   instMap.clear();
935   unknownInstMap.clear();
936 
937   auto functionType =
938       processType(f->getFunctionType()).dyn_cast<LLVMFunctionType>();
939   if (!functionType)
940     return failure();
941 
942   b.setInsertionPoint(module.getBody(), getFuncInsertPt());
943   LLVMFuncOp fop =
944       b.create<LLVMFuncOp>(UnknownLoc::get(context), f->getName(), functionType,
945                            convertLinkageFromLLVM(f->getLinkage()));
946 
947   if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f))
948     fop->setAttr(b.getIdentifier("personality"), personality);
949   else if (f->hasPersonalityFn())
950     emitWarning(UnknownLoc::get(context),
951                 "could not deduce personality, skipping it");
952 
953   if (f->isDeclaration())
954     return success();
955 
956   // Eagerly create all blocks.
957   SmallVector<Block *, 4> blockList;
958   for (llvm::BasicBlock &bb : *f) {
959     blockList.push_back(b.createBlock(&fop.body(), fop.body().end()));
960     blocks[&bb] = blockList.back();
961   }
962   currentEntryBlock = blockList[0];
963 
964   // Add function arguments to the entry block.
965   for (auto kv : llvm::enumerate(f->args()))
966     instMap[&kv.value()] =
967         blockList[0]->addArgument(functionType.getParamType(kv.index()));
968 
969   for (auto bbs : llvm::zip(*f, blockList)) {
970     if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs))))
971       return failure();
972   }
973 
974   // Now that all instructions are guaranteed to have been visited, ensure
975   // any unknown uses we encountered are remapped.
976   for (auto &llvmAndUnknown : unknownInstMap) {
977     assert(instMap.count(llvmAndUnknown.first));
978     Value newValue = instMap[llvmAndUnknown.first];
979     Value oldValue = llvmAndUnknown.second->getResult(0);
980     oldValue.replaceAllUsesWith(newValue);
981     llvmAndUnknown.second->erase();
982   }
983   return success();
984 }
985 
986 LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) {
987   b.setInsertionPointToStart(block);
988   for (llvm::Instruction &inst : *bb) {
989     if (failed(processInstruction(&inst)))
990       return failure();
991   }
992   return success();
993 }
994 
995 OwningModuleRef
996 mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
997                               MLIRContext *context) {
998   context->loadDialect<LLVMDialect>();
999   OwningModuleRef module(ModuleOp::create(
1000       FileLineColLoc::get(context, "", /*line=*/0, /*column=*/0)));
1001 
1002   Importer deserializer(context, module.get());
1003   for (llvm::GlobalVariable &gv : llvmModule->globals()) {
1004     if (!deserializer.processGlobal(&gv))
1005       return {};
1006   }
1007   for (llvm::Function &f : llvmModule->functions()) {
1008     if (failed(deserializer.processFunction(&f)))
1009       return {};
1010   }
1011 
1012   return module;
1013 }
1014 
1015 // Deserializes the LLVM bitcode stored in `input` into an MLIR module in the
1016 // LLVM dialect.
1017 OwningModuleRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr,
1018                                         MLIRContext *context) {
1019   llvm::SMDiagnostic err;
1020   llvm::LLVMContext llvmContext;
1021   std::unique_ptr<llvm::Module> llvmModule = llvm::parseIR(
1022       *sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), err, llvmContext);
1023   if (!llvmModule) {
1024     std::string errStr;
1025     llvm::raw_string_ostream errStream(errStr);
1026     err.print(/*ProgName=*/"", errStream);
1027     emitError(UnknownLoc::get(context)) << errStream.str();
1028     return {};
1029   }
1030   return translateLLVMIRToModule(std::move(llvmModule), context);
1031 }
1032 
1033 namespace mlir {
1034 void registerFromLLVMIRTranslation() {
1035   TranslateToMLIRRegistration fromLLVM(
1036       "import-llvm", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
1037         return ::translateLLVMIRToModule(sourceMgr, context);
1038       });
1039 }
1040 } // namespace mlir
1041