//===- ConvertFromLLVMIR.cpp - LLVM IR to MLIR conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between LLVM IR and the MLIR LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/MLIRContext.h"
18 #include "mlir/Target/LLVMIR/Import.h"
19 #include "mlir/Translation.h"
20 
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/IR/Attributes.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/DerivedTypes.h"
25 #include "llvm/IR/Function.h"
26 #include "llvm/IR/InlineAsm.h"
27 #include "llvm/IR/Instructions.h"
28 #include "llvm/IR/Type.h"
29 #include "llvm/IRReader/IRReader.h"
30 #include "llvm/Support/Error.h"
31 #include "llvm/Support/SourceMgr.h"
32 
33 using namespace mlir;
34 using namespace mlir::LLVM;
35 
36 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
37 
38 // Utility to print an LLVM value as a string for passing to emitError().
39 // FIXME: Diagnostic should be able to natively handle types that have
40 // operator << (raw_ostream&) defined.
41 static std::string diag(llvm::Value &v) {
42   std::string s;
43   llvm::raw_string_ostream os(s);
44   os << v;
45   return os.str();
46 }
47 
48 namespace mlir {
49 namespace LLVM {
50 namespace detail {
51 /// Support for translating LLVM IR types to MLIR LLVM dialect types.
52 class TypeFromLLVMIRTranslatorImpl {
53 public:
54   /// Constructs a class creating types in the given MLIR context.
55   TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {}
56 
57   /// Translates the given type.
58   Type translateType(llvm::Type *type) {
59     if (knownTranslations.count(type))
60       return knownTranslations.lookup(type);
61 
62     Type translated =
63         llvm::TypeSwitch<llvm::Type *, Type>(type)
64             .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
65                   llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
66                   llvm::ScalableVectorType>(
67                 [this](auto *type) { return this->translate(type); })
68             .Default([this](llvm::Type *type) {
69               return translatePrimitiveType(type);
70             });
71     knownTranslations.try_emplace(type, translated);
72     return translated;
73   }
74 
75 private:
76   /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
77   /// type.
78   Type translatePrimitiveType(llvm::Type *type) {
79     if (type->isVoidTy())
80       return LLVM::LLVMVoidType::get(&context);
81     if (type->isHalfTy())
82       return Float16Type::get(&context);
83     if (type->isBFloatTy())
84       return BFloat16Type::get(&context);
85     if (type->isFloatTy())
86       return Float32Type::get(&context);
87     if (type->isDoubleTy())
88       return Float64Type::get(&context);
89     if (type->isFP128Ty())
90       return Float128Type::get(&context);
91     if (type->isX86_FP80Ty())
92       return Float80Type::get(&context);
93     if (type->isPPC_FP128Ty())
94       return LLVM::LLVMPPCFP128Type::get(&context);
95     if (type->isX86_MMXTy())
96       return LLVM::LLVMX86MMXType::get(&context);
97     if (type->isLabelTy())
98       return LLVM::LLVMLabelType::get(&context);
99     if (type->isMetadataTy())
100       return LLVM::LLVMMetadataType::get(&context);
101     llvm_unreachable("not a primitive type");
102   }
103 
104   /// Translates the given array type.
105   Type translate(llvm::ArrayType *type) {
106     return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
107                                     type->getNumElements());
108   }
109 
110   /// Translates the given function type.
111   Type translate(llvm::FunctionType *type) {
112     SmallVector<Type, 8> paramTypes;
113     translateTypes(type->params(), paramTypes);
114     return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
115                                        paramTypes, type->isVarArg());
116   }
117 
118   /// Translates the given integer type.
119   Type translate(llvm::IntegerType *type) {
120     return IntegerType::get(&context, type->getBitWidth());
121   }
122 
123   /// Translates the given pointer type.
124   Type translate(llvm::PointerType *type) {
125     return LLVM::LLVMPointerType::get(translateType(type->getElementType()),
126                                       type->getAddressSpace());
127   }
128 
129   /// Translates the given structure type.
130   Type translate(llvm::StructType *type) {
131     SmallVector<Type, 8> subtypes;
132     if (type->isLiteral()) {
133       translateTypes(type->subtypes(), subtypes);
134       return LLVM::LLVMStructType::getLiteral(&context, subtypes,
135                                               type->isPacked());
136     }
137 
138     if (type->isOpaque())
139       return LLVM::LLVMStructType::getOpaque(type->getName(), &context);
140 
141     LLVM::LLVMStructType translated =
142         LLVM::LLVMStructType::getIdentified(&context, type->getName());
143     knownTranslations.try_emplace(type, translated);
144     translateTypes(type->subtypes(), subtypes);
145     LogicalResult bodySet = translated.setBody(subtypes, type->isPacked());
146     assert(succeeded(bodySet) &&
147            "could not set the body of an identified struct");
148     (void)bodySet;
149     return translated;
150   }
151 
152   /// Translates the given fixed-vector type.
153   Type translate(llvm::FixedVectorType *type) {
154     return LLVM::getFixedVectorType(translateType(type->getElementType()),
155                                     type->getNumElements());
156   }
157 
158   /// Translates the given scalable-vector type.
159   Type translate(llvm::ScalableVectorType *type) {
160     return LLVM::LLVMScalableVectorType::get(
161         translateType(type->getElementType()), type->getMinNumElements());
162   }
163 
164   /// Translates a list of types.
165   void translateTypes(ArrayRef<llvm::Type *> types,
166                       SmallVectorImpl<Type> &result) {
167     result.reserve(result.size() + types.size());
168     for (llvm::Type *type : types)
169       result.push_back(translateType(type));
170   }
171 
172   /// Map of known translations. Serves as a cache and as recursion stopper for
173   /// translating recursive structs.
174   llvm::DenseMap<llvm::Type *, Type> knownTranslations;
175 
176   /// The context in which MLIR types are created.
177   MLIRContext &context;
178 };
179 } // end namespace detail
180 
/// Utility class to translate LLVM IR types to the MLIR LLVM dialect. Stores
/// the translation state, in particular any identified structure types that are
/// reused across translations.
class TypeFromLLVMIRTranslator {
public:
  /// Creates a translator that produces types in `context`.
  TypeFromLLVMIRTranslator(MLIRContext &context);
  /// Defined out of line, where the implementation class is a complete type.
  ~TypeFromLLVMIRTranslator();

  /// Translates the given LLVM IR type to the MLIR LLVM dialect.
  Type translateType(llvm::Type *type);

private:
  /// Private implementation (pimpl); holds the cache of translated types.
  std::unique_ptr<detail::TypeFromLLVMIRTranslatorImpl> impl;
};
196 
197 } // end namespace LLVM
198 } // end namespace mlir
199 
200 LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context)
201     : impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {}
202 
203 LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() {}
204 
205 Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
206   return impl->translateType(type);
207 }
208 
209 // Handles importing globals and functions from an LLVM module.
210 namespace {
211 class Importer {
212 public:
213   Importer(MLIRContext *context, ModuleOp module)
214       : b(context), context(context), module(module),
215         unknownLoc(FileLineColLoc::get(context, "imported-bitcode", 0, 0)),
216         typeTranslator(*context) {
217     b.setInsertionPointToStart(module.getBody());
218   }
219 
220   /// Imports `f` into the current module.
221   LogicalResult processFunction(llvm::Function *f);
222 
223   /// Imports GV as a GlobalOp, creating it if it doesn't exist.
224   GlobalOp processGlobal(llvm::GlobalVariable *GV);
225 
226 private:
227   /// Returns personality of `f` as a FlatSymbolRefAttr.
228   FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *f);
229   /// Imports `bb` into `block`, which must be initially empty.
230   LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block);
231   /// Imports `inst` and populates instMap[inst] with the imported Value.
232   LogicalResult processInstruction(llvm::Instruction *inst);
233   /// Creates an LLVM-compatible MLIR type for `type`.
234   Type processType(llvm::Type *type);
235   /// `value` is an SSA-use. Return the remapped version of `value` or a
236   /// placeholder that will be remapped later if this is an instruction that
237   /// has not yet been visited.
238   Value processValue(llvm::Value *value);
239   /// Create the most accurate Location possible using a llvm::DebugLoc and
240   /// possibly an llvm::Instruction to narrow the Location if debug information
241   /// is unavailable.
242   Location processDebugLoc(const llvm::DebugLoc &loc,
243                            llvm::Instruction *inst = nullptr);
244   /// `br` branches to `target`. Append the block arguments to attach to the
245   /// generated branch op to `blockArguments`. These should be in the same order
246   /// as the PHIs in `target`.
247   LogicalResult processBranchArgs(llvm::Instruction *br,
248                                   llvm::BasicBlock *target,
249                                   SmallVectorImpl<Value> &blockArguments);
250   /// Returns the builtin type equivalent to be used in attributes for the given
251   /// LLVM IR dialect type.
252   Type getStdTypeForAttr(Type type);
253   /// Return `value` as an attribute to attach to a GlobalOp.
254   Attribute getConstantAsAttr(llvm::Constant *value);
255   /// Return `c` as an MLIR Value. This could either be a ConstantOp, or
256   /// an expanded sequence of ops in the current function's entry block (for
257   /// ConstantExprs or ConstantGEPs).
258   Value processConstant(llvm::Constant *c);
259 
260   /// The current builder, pointing at where the next Instruction should be
261   /// generated.
262   OpBuilder b;
263   /// The current context.
264   MLIRContext *context;
265   /// The current module being created.
266   ModuleOp module;
267   /// The entry block of the current function being processed.
268   Block *currentEntryBlock;
269 
270   /// Globals are inserted before the first function, if any.
271   Block::iterator getGlobalInsertPt() {
272     auto it = module.getBody()->begin();
273     auto endIt = module.getBody()->end();
274     while (it != endIt && !isa<LLVMFuncOp>(it))
275       ++it;
276     return it;
277   }
278 
279   /// Functions are always inserted before the module terminator.
280   Block::iterator getFuncInsertPt() {
281     return std::prev(module.getBody()->end());
282   }
283 
284   /// Remapped blocks, for the current function.
285   DenseMap<llvm::BasicBlock *, Block *> blocks;
286   /// Remapped values. These are function-local.
287   DenseMap<llvm::Value *, Value> instMap;
288   /// Instructions that had not been defined when first encountered as a use.
289   /// Maps to the dummy Operation that was created in processValue().
290   DenseMap<llvm::Value *, Operation *> unknownInstMap;
291   /// Uniquing map of GlobalVariables.
292   DenseMap<llvm::GlobalVariable *, GlobalOp> globals;
293   /// Cached FileLineColLoc::get("imported-bitcode", 0, 0).
294   Location unknownLoc;
295   /// The stateful type translator (contains named structs).
296   LLVM::TypeFromLLVMIRTranslator typeTranslator;
297 };
298 } // namespace
299 
300 Location Importer::processDebugLoc(const llvm::DebugLoc &loc,
301                                    llvm::Instruction *inst) {
302   if (!loc && inst) {
303     std::string s;
304     llvm::raw_string_ostream os(s);
305     os << "llvm-imported-inst-%";
306     inst->printAsOperand(os, /*PrintType=*/false);
307     return FileLineColLoc::get(context, os.str(), 0, 0);
308   } else if (!loc) {
309     return unknownLoc;
310   }
311   // FIXME: Obtain the filename from DILocationInfo.
312   return FileLineColLoc::get(context, "imported-bitcode", loc.getLine(),
313                              loc.getCol());
314 }
315 
316 Type Importer::processType(llvm::Type *type) {
317   if (Type result = typeTranslator.translateType(type))
318     return result;
319 
320   // FIXME: Diagnostic should be able to natively handle types that have
321   // operator<<(raw_ostream&) defined.
322   std::string s;
323   llvm::raw_string_ostream os(s);
324   os << *type;
325   emitError(unknownLoc) << "unhandled type: " << os.str();
326   return nullptr;
327 }
328 
329 // We only need integers, floats, doubles, and vectors and tensors thereof for
330 // attributes. Scalar and vector types are converted to the standard
331 // equivalents. Array types are converted to ranked tensors; nested array types
332 // are converted to multi-dimensional tensors or vectors, depending on the
333 // innermost type being a scalar or a vector.
334 Type Importer::getStdTypeForAttr(Type type) {
335   if (!type)
336     return nullptr;
337 
338   if (type.isa<IntegerType, FloatType>())
339     return type;
340 
341   // LLVM vectors can only contain scalars.
342   if (LLVM::isCompatibleVectorType(type)) {
343     auto numElements = LLVM::getVectorNumElements(type);
344     if (numElements.isScalable()) {
345       emitError(unknownLoc) << "scalable vectors not supported";
346       return nullptr;
347     }
348     Type elementType = getStdTypeForAttr(LLVM::getVectorElementType(type));
349     if (!elementType)
350       return nullptr;
351     return VectorType::get(numElements.getKnownMinValue(), elementType);
352   }
353 
354   // LLVM arrays can contain other arrays or vectors.
355   if (auto arrayType = type.dyn_cast<LLVMArrayType>()) {
356     // Recover the nested array shape.
357     SmallVector<int64_t, 4> shape;
358     shape.push_back(arrayType.getNumElements());
359     while (arrayType.getElementType().isa<LLVMArrayType>()) {
360       arrayType = arrayType.getElementType().cast<LLVMArrayType>();
361       shape.push_back(arrayType.getNumElements());
362     }
363 
364     // If the innermost type is a vector, use the multi-dimensional vector as
365     // attribute type.
366     if (LLVM::isCompatibleVectorType(arrayType.getElementType())) {
367       auto numElements = LLVM::getVectorNumElements(arrayType.getElementType());
368       if (numElements.isScalable()) {
369         emitError(unknownLoc) << "scalable vectors not supported";
370         return nullptr;
371       }
372       shape.push_back(numElements.getKnownMinValue());
373 
374       Type elementType = getStdTypeForAttr(
375           LLVM::getVectorElementType(arrayType.getElementType()));
376       if (!elementType)
377         return nullptr;
378       return VectorType::get(shape, elementType);
379     }
380 
381     // Otherwise use a tensor.
382     Type elementType = getStdTypeForAttr(arrayType.getElementType());
383     if (!elementType)
384       return nullptr;
385     return RankedTensorType::get(shape, elementType);
386   }
387 
388   return nullptr;
389 }
390 
391 // Get the given constant as an attribute. Not all constants can be represented
392 // as attributes.
393 Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
394   if (auto *ci = dyn_cast<llvm::ConstantInt>(value))
395     return b.getIntegerAttr(
396         IntegerType::get(context, ci->getType()->getBitWidth()),
397         ci->getValue());
398   if (auto *c = dyn_cast<llvm::ConstantDataArray>(value))
399     if (c->isString())
400       return b.getStringAttr(c->getAsString());
401   if (auto *c = dyn_cast<llvm::ConstantFP>(value)) {
402     if (c->getType()->isDoubleTy())
403       return b.getFloatAttr(FloatType::getF64(context), c->getValueAPF());
404     if (c->getType()->isFloatingPointTy())
405       return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF());
406   }
407   if (auto *f = dyn_cast<llvm::Function>(value))
408     return b.getSymbolRefAttr(f->getName());
409 
410   // Convert constant data to a dense elements attribute.
411   if (auto *cd = dyn_cast<llvm::ConstantDataSequential>(value)) {
412     Type type = processType(cd->getElementType());
413     if (!type)
414       return nullptr;
415 
416     auto attrType = getStdTypeForAttr(processType(cd->getType()))
417                         .dyn_cast_or_null<ShapedType>();
418     if (!attrType)
419       return nullptr;
420 
421     if (type.isa<IntegerType>()) {
422       SmallVector<APInt, 8> values;
423       values.reserve(cd->getNumElements());
424       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
425         values.push_back(cd->getElementAsAPInt(i));
426       return DenseElementsAttr::get(attrType, values);
427     }
428 
429     if (type.isa<Float32Type, Float64Type>()) {
430       SmallVector<APFloat, 8> values;
431       values.reserve(cd->getNumElements());
432       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
433         values.push_back(cd->getElementAsAPFloat(i));
434       return DenseElementsAttr::get(attrType, values);
435     }
436 
437     return nullptr;
438   }
439 
440   // Unpack constant aggregates to create dense elements attribute whenever
441   // possible. Return nullptr (failure) otherwise.
442   if (isa<llvm::ConstantAggregate>(value)) {
443     auto outerType = getStdTypeForAttr(processType(value->getType()))
444                          .dyn_cast_or_null<ShapedType>();
445     if (!outerType)
446       return nullptr;
447 
448     SmallVector<Attribute, 8> values;
449     SmallVector<int64_t, 8> shape;
450 
451     for (unsigned i = 0, e = value->getNumOperands(); i < e; ++i) {
452       auto nested = getConstantAsAttr(value->getAggregateElement(i))
453                         .dyn_cast_or_null<DenseElementsAttr>();
454       if (!nested)
455         return nullptr;
456 
457       values.append(nested.attr_value_begin(), nested.attr_value_end());
458     }
459 
460     return DenseElementsAttr::get(outerType, values);
461   }
462 
463   return nullptr;
464 }
465 
466 GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
467   auto it = globals.find(GV);
468   if (it != globals.end())
469     return it->second;
470 
471   OpBuilder b(module.getBody(), getGlobalInsertPt());
472   Attribute valueAttr;
473   if (GV->hasInitializer())
474     valueAttr = getConstantAsAttr(GV->getInitializer());
475   Type type = processType(GV->getValueType());
476   if (!type)
477     return nullptr;
478   GlobalOp op = b.create<GlobalOp>(
479       UnknownLoc::get(context), type, GV->isConstant(),
480       convertLinkageFromLLVM(GV->getLinkage()), GV->getName(), valueAttr);
481   if (GV->hasInitializer() && !valueAttr) {
482     Region &r = op.getInitializerRegion();
483     currentEntryBlock = b.createBlock(&r);
484     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
485     Value v = processConstant(GV->getInitializer());
486     if (!v)
487       return nullptr;
488     b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v}));
489   }
490   if (GV->hasAtLeastLocalUnnamedAddr())
491     op.unnamed_addrAttr(UnnamedAddrAttr::get(
492         context, convertUnnamedAddrFromLLVM(GV->getUnnamedAddr())));
493   if (GV->hasSection())
494     op.sectionAttr(b.getStringAttr(GV->getSection()));
495   return globals[GV] = op;
496 }
497 
498 Value Importer::processConstant(llvm::Constant *c) {
499   OpBuilder bEntry(currentEntryBlock, currentEntryBlock->begin());
500   if (Attribute attr = getConstantAsAttr(c)) {
501     // These constants can be represented as attributes.
502     OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
503     Type type = processType(c->getType());
504     if (!type)
505       return nullptr;
506     if (auto symbolRef = attr.dyn_cast<FlatSymbolRefAttr>())
507       return instMap[c] = bEntry.create<AddressOfOp>(unknownLoc, type,
508                                                      symbolRef.getValue());
509     return instMap[c] = bEntry.create<ConstantOp>(unknownLoc, type, attr);
510   }
511   if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) {
512     Type type = processType(cn->getType());
513     if (!type)
514       return nullptr;
515     return instMap[c] = bEntry.create<NullOp>(unknownLoc, type);
516   }
517   if (auto *GV = dyn_cast<llvm::GlobalVariable>(c))
518     return bEntry.create<AddressOfOp>(UnknownLoc::get(context),
519                                       processGlobal(GV));
520 
521   if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) {
522     llvm::Instruction *i = ce->getAsInstruction();
523     OpBuilder::InsertionGuard guard(b);
524     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
525     if (failed(processInstruction(i)))
526       return nullptr;
527     assert(instMap.count(i));
528 
529     // Remove this zombie LLVM instruction now, leaving us only with the MLIR
530     // op.
531     i->deleteValue();
532     return instMap[c] = instMap[i];
533   }
534   if (auto *ue = dyn_cast<llvm::UndefValue>(c)) {
535     Type type = processType(ue->getType());
536     if (!type)
537       return nullptr;
538     return instMap[c] = bEntry.create<UndefOp>(UnknownLoc::get(context), type);
539   }
540   emitError(unknownLoc) << "unhandled constant: " << diag(*c);
541   return nullptr;
542 }
543 
544 Value Importer::processValue(llvm::Value *value) {
545   auto it = instMap.find(value);
546   if (it != instMap.end())
547     return it->second;
548 
549   // We don't expect to see instructions in dominator order. If we haven't seen
550   // this instruction yet, create an unknown op and remap it later.
551   if (isa<llvm::Instruction>(value)) {
552     OperationState state(UnknownLoc::get(context), "llvm.unknown");
553     Type type = processType(value->getType());
554     if (!type)
555       return nullptr;
556     state.addTypes(type);
557     unknownInstMap[value] = b.createOperation(state);
558     return unknownInstMap[value]->getResult(0);
559   }
560 
561   if (auto *c = dyn_cast<llvm::Constant>(value))
562     return processConstant(c);
563 
564   emitError(unknownLoc) << "unhandled value: " << diag(*value);
565   return nullptr;
566 }
567 
568 /// Return the MLIR OperationName for the given LLVM opcode.
569 static StringRef lookupOperationNameFromOpcode(unsigned opcode) {
570 // Maps from LLVM opcode to MLIR OperationName. This is deliberately ordered
571 // as in llvm/IR/Instructions.def to aid comprehension and spot missing
572 // instructions.
573 #define INST(llvm_n, mlir_n)                                                   \
574   { llvm::Instruction::llvm_n, LLVM::mlir_n##Op::getOperationName() }
575   static const DenseMap<unsigned, StringRef> opcMap = {
576       // Ret is handled specially.
577       // Br is handled specially.
578       // FIXME: switch
579       // FIXME: indirectbr
580       // FIXME: invoke
581       INST(Resume, Resume),
582       // FIXME: unreachable
583       // FIXME: cleanupret
584       // FIXME: catchret
585       // FIXME: catchswitch
586       // FIXME: callbr
587       // FIXME: fneg
588       INST(Add, Add), INST(FAdd, FAdd), INST(Sub, Sub), INST(FSub, FSub),
589       INST(Mul, Mul), INST(FMul, FMul), INST(UDiv, UDiv), INST(SDiv, SDiv),
590       INST(FDiv, FDiv), INST(URem, URem), INST(SRem, SRem), INST(FRem, FRem),
591       INST(Shl, Shl), INST(LShr, LShr), INST(AShr, AShr), INST(And, And),
592       INST(Or, Or), INST(Xor, XOr), INST(Alloca, Alloca), INST(Load, Load),
593       INST(Store, Store),
594       // Getelementptr is handled specially.
595       INST(Ret, Return), INST(Fence, Fence),
596       // FIXME: atomiccmpxchg
597       // FIXME: atomicrmw
598       INST(Trunc, Trunc), INST(ZExt, ZExt), INST(SExt, SExt),
599       INST(FPToUI, FPToUI), INST(FPToSI, FPToSI), INST(UIToFP, UIToFP),
600       INST(SIToFP, SIToFP), INST(FPTrunc, FPTrunc), INST(FPExt, FPExt),
601       INST(PtrToInt, PtrToInt), INST(IntToPtr, IntToPtr),
602       INST(BitCast, Bitcast), INST(AddrSpaceCast, AddrSpaceCast),
603       // FIXME: cleanuppad
604       // FIXME: catchpad
605       // ICmp is handled specially.
606       // FIXME: fcmp
607       // PHI is handled specially.
608       INST(Freeze, Freeze), INST(Call, Call),
609       // FIXME: select
610       // FIXME: vaarg
611       // FIXME: extractelement
612       // FIXME: insertelement
613       // FIXME: shufflevector
614       // FIXME: extractvalue
615       // FIXME: insertvalue
616       // FIXME: landingpad
617   };
618 #undef INST
619 
620   return opcMap.lookup(opcode);
621 }
622 
623 static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) {
624   switch (p) {
625   default:
626     llvm_unreachable("incorrect comparison predicate");
627   case llvm::CmpInst::Predicate::ICMP_EQ:
628     return LLVM::ICmpPredicate::eq;
629   case llvm::CmpInst::Predicate::ICMP_NE:
630     return LLVM::ICmpPredicate::ne;
631   case llvm::CmpInst::Predicate::ICMP_SLT:
632     return LLVM::ICmpPredicate::slt;
633   case llvm::CmpInst::Predicate::ICMP_SLE:
634     return LLVM::ICmpPredicate::sle;
635   case llvm::CmpInst::Predicate::ICMP_SGT:
636     return LLVM::ICmpPredicate::sgt;
637   case llvm::CmpInst::Predicate::ICMP_SGE:
638     return LLVM::ICmpPredicate::sge;
639   case llvm::CmpInst::Predicate::ICMP_ULT:
640     return LLVM::ICmpPredicate::ult;
641   case llvm::CmpInst::Predicate::ICMP_ULE:
642     return LLVM::ICmpPredicate::ule;
643   case llvm::CmpInst::Predicate::ICMP_UGT:
644     return LLVM::ICmpPredicate::ugt;
645   case llvm::CmpInst::Predicate::ICMP_UGE:
646     return LLVM::ICmpPredicate::uge;
647   }
648   llvm_unreachable("incorrect comparison predicate");
649 }
650 
651 static AtomicOrdering getLLVMAtomicOrdering(llvm::AtomicOrdering ordering) {
652   switch (ordering) {
653   case llvm::AtomicOrdering::NotAtomic:
654     return LLVM::AtomicOrdering::not_atomic;
655   case llvm::AtomicOrdering::Unordered:
656     return LLVM::AtomicOrdering::unordered;
657   case llvm::AtomicOrdering::Monotonic:
658     return LLVM::AtomicOrdering::monotonic;
659   case llvm::AtomicOrdering::Acquire:
660     return LLVM::AtomicOrdering::acquire;
661   case llvm::AtomicOrdering::Release:
662     return LLVM::AtomicOrdering::release;
663   case llvm::AtomicOrdering::AcquireRelease:
664     return LLVM::AtomicOrdering::acq_rel;
665   case llvm::AtomicOrdering::SequentiallyConsistent:
666     return LLVM::AtomicOrdering::seq_cst;
667   }
668   llvm_unreachable("incorrect atomic ordering");
669 }
670 
671 // `br` branches to `target`. Return the branch arguments to `br`, in the
672 // same order of the PHIs in `target`.
673 LogicalResult
674 Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target,
675                             SmallVectorImpl<Value> &blockArguments) {
676   for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) {
677     auto *PN = cast<llvm::PHINode>(&*inst);
678     Value value = processValue(PN->getIncomingValueForBlock(br->getParent()));
679     if (!value)
680       return failure();
681     blockArguments.push_back(value);
682   }
683   return success();
684 }
685 
686 LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
687   // FIXME: Support uses of SubtargetData. Currently inbounds GEPs, fast-math
688   // flags and call / operand attributes are not supported.
689   Location loc = processDebugLoc(inst->getDebugLoc(), inst);
690   Value &v = instMap[inst];
691   assert(!v && "processInstruction must be called only once per instruction!");
692   switch (inst->getOpcode()) {
693   default:
694     return emitError(loc) << "unknown instruction: " << diag(*inst);
695   case llvm::Instruction::Add:
696   case llvm::Instruction::FAdd:
697   case llvm::Instruction::Sub:
698   case llvm::Instruction::FSub:
699   case llvm::Instruction::Mul:
700   case llvm::Instruction::FMul:
701   case llvm::Instruction::UDiv:
702   case llvm::Instruction::SDiv:
703   case llvm::Instruction::FDiv:
704   case llvm::Instruction::URem:
705   case llvm::Instruction::SRem:
706   case llvm::Instruction::FRem:
707   case llvm::Instruction::Shl:
708   case llvm::Instruction::LShr:
709   case llvm::Instruction::AShr:
710   case llvm::Instruction::And:
711   case llvm::Instruction::Or:
712   case llvm::Instruction::Xor:
713   case llvm::Instruction::Alloca:
714   case llvm::Instruction::Load:
715   case llvm::Instruction::Store:
716   case llvm::Instruction::Ret:
717   case llvm::Instruction::Resume:
718   case llvm::Instruction::Trunc:
719   case llvm::Instruction::ZExt:
720   case llvm::Instruction::SExt:
721   case llvm::Instruction::FPToUI:
722   case llvm::Instruction::FPToSI:
723   case llvm::Instruction::UIToFP:
724   case llvm::Instruction::SIToFP:
725   case llvm::Instruction::FPTrunc:
726   case llvm::Instruction::FPExt:
727   case llvm::Instruction::PtrToInt:
728   case llvm::Instruction::IntToPtr:
729   case llvm::Instruction::AddrSpaceCast:
730   case llvm::Instruction::Freeze:
731   case llvm::Instruction::BitCast: {
732     OperationState state(loc, lookupOperationNameFromOpcode(inst->getOpcode()));
733     SmallVector<Value, 4> ops;
734     ops.reserve(inst->getNumOperands());
735     for (auto *op : inst->operand_values()) {
736       Value value = processValue(op);
737       if (!value)
738         return failure();
739       ops.push_back(value);
740     }
741     state.addOperands(ops);
742     if (!inst->getType()->isVoidTy()) {
743       Type type = processType(inst->getType());
744       if (!type)
745         return failure();
746       state.addTypes(type);
747     }
748     Operation *op = b.createOperation(state);
749     if (!inst->getType()->isVoidTy())
750       v = op->getResult(0);
751     return success();
752   }
753   case llvm::Instruction::ICmp: {
754     Value lhs = processValue(inst->getOperand(0));
755     Value rhs = processValue(inst->getOperand(1));
756     if (!lhs || !rhs)
757       return failure();
758     v = b.create<ICmpOp>(
759         loc, getICmpPredicate(cast<llvm::ICmpInst>(inst)->getPredicate()), lhs,
760         rhs);
761     return success();
762   }
763   case llvm::Instruction::Br: {
764     auto *brInst = cast<llvm::BranchInst>(inst);
765     OperationState state(loc,
766                          brInst->isConditional() ? "llvm.cond_br" : "llvm.br");
767     if (brInst->isConditional()) {
768       Value condition = processValue(brInst->getCondition());
769       if (!condition)
770         return failure();
771       state.addOperands(condition);
772     }
773 
774     std::array<int32_t, 3> operandSegmentSizes = {1, 0, 0};
775     for (int i : llvm::seq<int>(0, brInst->getNumSuccessors())) {
776       auto *succ = brInst->getSuccessor(i);
777       SmallVector<Value, 4> blockArguments;
778       if (failed(processBranchArgs(brInst, succ, blockArguments)))
779         return failure();
780       state.addSuccessors(blocks[succ]);
781       state.addOperands(blockArguments);
782       operandSegmentSizes[i + 1] = blockArguments.size();
783     }
784 
785     if (brInst->isConditional()) {
786       state.addAttribute(LLVM::CondBrOp::getOperandSegmentSizeAttr(),
787                          b.getI32VectorAttr(operandSegmentSizes));
788     }
789 
790     b.createOperation(state);
791     return success();
792   }
793   case llvm::Instruction::PHI: {
794     Type type = processType(inst->getType());
795     if (!type)
796       return failure();
797     v = b.getInsertionBlock()->addArgument(type);
798     return success();
799   }
800   case llvm::Instruction::Call: {
801     llvm::CallInst *ci = cast<llvm::CallInst>(inst);
802     SmallVector<Value, 4> ops;
803     ops.reserve(inst->getNumOperands());
804     for (auto &op : ci->arg_operands()) {
805       Value arg = processValue(op.get());
806       if (!arg)
807         return failure();
808       ops.push_back(arg);
809     }
810 
811     SmallVector<Type, 2> tys;
812     if (!ci->getType()->isVoidTy()) {
813       Type type = processType(inst->getType());
814       if (!type)
815         return failure();
816       tys.push_back(type);
817     }
818     Operation *op;
819     if (llvm::Function *callee = ci->getCalledFunction()) {
820       op = b.create<CallOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
821                             ops);
822     } else {
823       Value calledValue = processValue(ci->getCalledOperand());
824       if (!calledValue)
825         return failure();
826       ops.insert(ops.begin(), calledValue);
827       op = b.create<CallOp>(loc, tys, ops);
828     }
829     if (!ci->getType()->isVoidTy())
830       v = op->getResult(0);
831     return success();
832   }
833   case llvm::Instruction::LandingPad: {
834     llvm::LandingPadInst *lpi = cast<llvm::LandingPadInst>(inst);
835     SmallVector<Value, 4> ops;
836 
837     for (unsigned i = 0, ie = lpi->getNumClauses(); i < ie; i++)
838       ops.push_back(processConstant(lpi->getClause(i)));
839 
840     Type ty = processType(lpi->getType());
841     if (!ty)
842       return failure();
843 
844     v = b.create<LandingpadOp>(loc, ty, lpi->isCleanup(), ops);
845     return success();
846   }
847   case llvm::Instruction::Invoke: {
848     llvm::InvokeInst *ii = cast<llvm::InvokeInst>(inst);
849 
850     SmallVector<Type, 2> tys;
851     if (!ii->getType()->isVoidTy())
852       tys.push_back(processType(inst->getType()));
853 
854     SmallVector<Value, 4> ops;
855     ops.reserve(inst->getNumOperands() + 1);
856     for (auto &op : ii->arg_operands())
857       ops.push_back(processValue(op.get()));
858 
859     SmallVector<Value, 4> normalArgs, unwindArgs;
860     (void)processBranchArgs(ii, ii->getNormalDest(), normalArgs);
861     (void)processBranchArgs(ii, ii->getUnwindDest(), unwindArgs);
862 
863     Operation *op;
864     if (llvm::Function *callee = ii->getCalledFunction()) {
865       op = b.create<InvokeOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
866                               ops, blocks[ii->getNormalDest()], normalArgs,
867                               blocks[ii->getUnwindDest()], unwindArgs);
868     } else {
869       ops.insert(ops.begin(), processValue(ii->getCalledOperand()));
870       op = b.create<InvokeOp>(loc, tys, ops, blocks[ii->getNormalDest()],
871                               normalArgs, blocks[ii->getUnwindDest()],
872                               unwindArgs);
873     }
874 
875     if (!ii->getType()->isVoidTy())
876       v = op->getResult(0);
877     return success();
878   }
879   case llvm::Instruction::Fence: {
880     StringRef syncscope;
881     SmallVector<StringRef, 4> ssNs;
882     llvm::LLVMContext &llvmContext = inst->getContext();
883     llvm::FenceInst *fence = cast<llvm::FenceInst>(inst);
884     llvmContext.getSyncScopeNames(ssNs);
885     int fenceSyncScopeID = fence->getSyncScopeID();
886     for (unsigned i = 0, e = ssNs.size(); i != e; i++) {
887       if (fenceSyncScopeID == llvmContext.getOrInsertSyncScopeID(ssNs[i])) {
888         syncscope = ssNs[i];
889         break;
890       }
891     }
892     b.create<FenceOp>(loc, getLLVMAtomicOrdering(fence->getOrdering()),
893                       syncscope);
894     return success();
895   }
896   case llvm::Instruction::GetElementPtr: {
897     // FIXME: Support inbounds GEPs.
898     llvm::GetElementPtrInst *gep = cast<llvm::GetElementPtrInst>(inst);
899     SmallVector<Value, 4> ops;
900     for (auto *op : gep->operand_values()) {
901       Value value = processValue(op);
902       if (!value)
903         return failure();
904       ops.push_back(value);
905     }
906     Type type = processType(inst->getType());
907     if (!type)
908       return failure();
909     v = b.create<GEPOp>(loc, type, ops);
910     return success();
911   }
912   }
913 }
914 
915 FlatSymbolRefAttr Importer::getPersonalityAsAttr(llvm::Function *f) {
916   if (!f->hasPersonalityFn())
917     return nullptr;
918 
919   llvm::Constant *pf = f->getPersonalityFn();
920 
921   // If it directly has a name, we can use it.
922   if (pf->hasName())
923     return b.getSymbolRefAttr(pf->getName());
924 
925   // If it doesn't have a name, currently, only function pointers that are
926   // bitcast to i8* are parsed.
927   if (auto ce = dyn_cast<llvm::ConstantExpr>(pf)) {
928     if (ce->getOpcode() == llvm::Instruction::BitCast &&
929         ce->getType() == llvm::Type::getInt8PtrTy(f->getContext())) {
930       if (auto func = dyn_cast<llvm::Function>(ce->getOperand(0)))
931         return b.getSymbolRefAttr(func->getName());
932     }
933   }
934   return FlatSymbolRefAttr();
935 }
936 
937 LogicalResult Importer::processFunction(llvm::Function *f) {
938   blocks.clear();
939   instMap.clear();
940   unknownInstMap.clear();
941 
942   auto functionType =
943       processType(f->getFunctionType()).dyn_cast<LLVMFunctionType>();
944   if (!functionType)
945     return failure();
946 
947   b.setInsertionPoint(module.getBody(), getFuncInsertPt());
948   LLVMFuncOp fop =
949       b.create<LLVMFuncOp>(UnknownLoc::get(context), f->getName(), functionType,
950                            convertLinkageFromLLVM(f->getLinkage()));
951 
952   if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f))
953     fop->setAttr(b.getIdentifier("personality"), personality);
954   else if (f->hasPersonalityFn())
955     emitWarning(UnknownLoc::get(context),
956                 "could not deduce personality, skipping it");
957 
958   if (f->isDeclaration())
959     return success();
960 
961   // Eagerly create all blocks.
962   SmallVector<Block *, 4> blockList;
963   for (llvm::BasicBlock &bb : *f) {
964     blockList.push_back(b.createBlock(&fop.body(), fop.body().end()));
965     blocks[&bb] = blockList.back();
966   }
967   currentEntryBlock = blockList[0];
968 
969   // Add function arguments to the entry block.
970   for (auto kv : llvm::enumerate(f->args()))
971     instMap[&kv.value()] =
972         blockList[0]->addArgument(functionType.getParamType(kv.index()));
973 
974   for (auto bbs : llvm::zip(*f, blockList)) {
975     if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs))))
976       return failure();
977   }
978 
979   // Now that all instructions are guaranteed to have been visited, ensure
980   // any unknown uses we encountered are remapped.
981   for (auto &llvmAndUnknown : unknownInstMap) {
982     assert(instMap.count(llvmAndUnknown.first));
983     Value newValue = instMap[llvmAndUnknown.first];
984     Value oldValue = llvmAndUnknown.second->getResult(0);
985     oldValue.replaceAllUsesWith(newValue);
986     llvmAndUnknown.second->erase();
987   }
988   return success();
989 }
990 
991 LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) {
992   b.setInsertionPointToStart(block);
993   for (llvm::Instruction &inst : *bb) {
994     if (failed(processInstruction(&inst)))
995       return failure();
996   }
997   return success();
998 }
999 
1000 OwningModuleRef
1001 mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
1002                               MLIRContext *context) {
1003   context->loadDialect<LLVMDialect>();
1004   OwningModuleRef module(ModuleOp::create(
1005       FileLineColLoc::get(context, "", /*line=*/0, /*column=*/0)));
1006 
1007   Importer deserializer(context, module.get());
1008   for (llvm::GlobalVariable &gv : llvmModule->globals()) {
1009     if (!deserializer.processGlobal(&gv))
1010       return {};
1011   }
1012   for (llvm::Function &f : llvmModule->functions()) {
1013     if (failed(deserializer.processFunction(&f)))
1014       return {};
1015   }
1016 
1017   return module;
1018 }
1019 
1020 // Deserializes the LLVM bitcode stored in `input` into an MLIR module in the
1021 // LLVM dialect.
1022 OwningModuleRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr,
1023                                         MLIRContext *context) {
1024   llvm::SMDiagnostic err;
1025   llvm::LLVMContext llvmContext;
1026   std::unique_ptr<llvm::Module> llvmModule = llvm::parseIR(
1027       *sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), err, llvmContext);
1028   if (!llvmModule) {
1029     std::string errStr;
1030     llvm::raw_string_ostream errStream(errStr);
1031     err.print(/*ProgName=*/"", errStream);
1032     emitError(UnknownLoc::get(context)) << errStream.str();
1033     return {};
1034   }
1035   return translateLLVMIRToModule(std::move(llvmModule), context);
1036 }
1037 
1038 namespace mlir {
1039 void registerFromLLVMIRTranslation() {
1040   TranslateToMLIRRegistration fromLLVM(
1041       "import-llvm", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
1042         return ::translateLLVMIRToModule(sourceMgr, context);
1043       });
1044 }
1045 } // namespace mlir
1046