//===- ConvertFromLLVMIR.cpp - LLVM IR to MLIR conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between LLVM IR and the MLIR LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/MLIRContext.h"
18 #include "mlir/Target/LLVMIR/Import.h"
19 #include "mlir/Translation.h"
20 
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/IR/Attributes.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/DerivedTypes.h"
25 #include "llvm/IR/Function.h"
26 #include "llvm/IR/InlineAsm.h"
27 #include "llvm/IR/Instructions.h"
28 #include "llvm/IR/Type.h"
29 #include "llvm/IRReader/IRReader.h"
30 #include "llvm/Support/Error.h"
31 #include "llvm/Support/SourceMgr.h"
32 
33 using namespace mlir;
34 using namespace mlir::LLVM;
35 
36 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
37 
38 // Utility to print an LLVM value as a string for passing to emitError().
39 // FIXME: Diagnostic should be able to natively handle types that have
40 // operator << (raw_ostream&) defined.
41 static std::string diag(llvm::Value &v) {
42   std::string s;
43   llvm::raw_string_ostream os(s);
44   os << v;
45   return os.str();
46 }
47 
48 namespace mlir {
49 namespace LLVM {
50 namespace detail {
51 /// Support for translating LLVM IR types to MLIR LLVM dialect types.
52 class TypeFromLLVMIRTranslatorImpl {
53 public:
54   /// Constructs a class creating types in the given MLIR context.
55   TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {}
56 
57   /// Translates the given type.
58   Type translateType(llvm::Type *type) {
59     if (knownTranslations.count(type))
60       return knownTranslations.lookup(type);
61 
62     Type translated =
63         llvm::TypeSwitch<llvm::Type *, Type>(type)
64             .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
65                   llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
66                   llvm::ScalableVectorType>(
67                 [this](auto *type) { return this->translate(type); })
68             .Default([this](llvm::Type *type) {
69               return translatePrimitiveType(type);
70             });
71     knownTranslations.try_emplace(type, translated);
72     return translated;
73   }
74 
75 private:
76   /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
77   /// type.
78   Type translatePrimitiveType(llvm::Type *type) {
79     if (type->isVoidTy())
80       return LLVM::LLVMVoidType::get(&context);
81     if (type->isHalfTy())
82       return Float16Type::get(&context);
83     if (type->isBFloatTy())
84       return BFloat16Type::get(&context);
85     if (type->isFloatTy())
86       return Float32Type::get(&context);
87     if (type->isDoubleTy())
88       return Float64Type::get(&context);
89     if (type->isFP128Ty())
90       return Float128Type::get(&context);
91     if (type->isX86_FP80Ty())
92       return Float80Type::get(&context);
93     if (type->isPPC_FP128Ty())
94       return LLVM::LLVMPPCFP128Type::get(&context);
95     if (type->isX86_MMXTy())
96       return LLVM::LLVMX86MMXType::get(&context);
97     if (type->isLabelTy())
98       return LLVM::LLVMLabelType::get(&context);
99     if (type->isMetadataTy())
100       return LLVM::LLVMMetadataType::get(&context);
101     llvm_unreachable("not a primitive type");
102   }
103 
104   /// Translates the given array type.
105   Type translate(llvm::ArrayType *type) {
106     return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
107                                     type->getNumElements());
108   }
109 
110   /// Translates the given function type.
111   Type translate(llvm::FunctionType *type) {
112     SmallVector<Type, 8> paramTypes;
113     translateTypes(type->params(), paramTypes);
114     return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
115                                        paramTypes, type->isVarArg());
116   }
117 
118   /// Translates the given integer type.
119   Type translate(llvm::IntegerType *type) {
120     return IntegerType::get(&context, type->getBitWidth());
121   }
122 
123   /// Translates the given pointer type.
124   Type translate(llvm::PointerType *type) {
125     return LLVM::LLVMPointerType::get(translateType(type->getElementType()),
126                                       type->getAddressSpace());
127   }
128 
129   /// Translates the given structure type.
130   Type translate(llvm::StructType *type) {
131     SmallVector<Type, 8> subtypes;
132     if (type->isLiteral()) {
133       translateTypes(type->subtypes(), subtypes);
134       return LLVM::LLVMStructType::getLiteral(&context, subtypes,
135                                               type->isPacked());
136     }
137 
138     if (type->isOpaque())
139       return LLVM::LLVMStructType::getOpaque(type->getName(), &context);
140 
141     LLVM::LLVMStructType translated =
142         LLVM::LLVMStructType::getIdentified(&context, type->getName());
143     knownTranslations.try_emplace(type, translated);
144     translateTypes(type->subtypes(), subtypes);
145     LogicalResult bodySet = translated.setBody(subtypes, type->isPacked());
146     assert(succeeded(bodySet) &&
147            "could not set the body of an identified struct");
148     (void)bodySet;
149     return translated;
150   }
151 
152   /// Translates the given fixed-vector type.
153   Type translate(llvm::FixedVectorType *type) {
154     return LLVM::getFixedVectorType(translateType(type->getElementType()),
155                                     type->getNumElements());
156   }
157 
158   /// Translates the given scalable-vector type.
159   Type translate(llvm::ScalableVectorType *type) {
160     return LLVM::LLVMScalableVectorType::get(
161         translateType(type->getElementType()), type->getMinNumElements());
162   }
163 
164   /// Translates a list of types.
165   void translateTypes(ArrayRef<llvm::Type *> types,
166                       SmallVectorImpl<Type> &result) {
167     result.reserve(result.size() + types.size());
168     for (llvm::Type *type : types)
169       result.push_back(translateType(type));
170   }
171 
172   /// Map of known translations. Serves as a cache and as recursion stopper for
173   /// translating recursive structs.
174   llvm::DenseMap<llvm::Type *, Type> knownTranslations;
175 
176   /// The context in which MLIR types are created.
177   MLIRContext &context;
178 };
179 } // end namespace detail
180 
181 /// Utility class to translate LLVM IR types to the MLIR LLVM dialect. Stores
182 /// the translation state, in particular any identified structure types that are
183 /// reused across translations.
184 class TypeFromLLVMIRTranslator {
185 public:
186   TypeFromLLVMIRTranslator(MLIRContext &context);
187   ~TypeFromLLVMIRTranslator();
188 
189   /// Translates the given LLVM IR type to the MLIR LLVM dialect.
190   Type translateType(llvm::Type *type);
191 
192 private:
193   /// Private implementation.
194   std::unique_ptr<detail::TypeFromLLVMIRTranslatorImpl> impl;
195 };
196 
197 } // end namespace LLVM
198 } // end namespace mlir
199 
200 LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context)
201     : impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {}
202 
203 LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() {}
204 
205 Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
206   return impl->translateType(type);
207 }
208 
209 // Handles importing globals and functions from an LLVM module.
210 namespace {
211 class Importer {
212 public:
213   Importer(MLIRContext *context, ModuleOp module)
214       : b(context), context(context), module(module),
215         unknownLoc(FileLineColLoc::get(context, "imported-bitcode", 0, 0)),
216         typeTranslator(*context) {
217     b.setInsertionPointToStart(module.getBody());
218   }
219 
220   /// Imports `f` into the current module.
221   LogicalResult processFunction(llvm::Function *f);
222 
223   /// Imports GV as a GlobalOp, creating it if it doesn't exist.
224   GlobalOp processGlobal(llvm::GlobalVariable *GV);
225 
226 private:
227   /// Returns personality of `f` as a FlatSymbolRefAttr.
228   FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *f);
229   /// Imports `bb` into `block`, which must be initially empty.
230   LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block);
231   /// Imports `inst` and populates instMap[inst] with the imported Value.
232   LogicalResult processInstruction(llvm::Instruction *inst);
233   /// Creates an LLVM-compatible MLIR type for `type`.
234   Type processType(llvm::Type *type);
235   /// `value` is an SSA-use. Return the remapped version of `value` or a
236   /// placeholder that will be remapped later if this is an instruction that
237   /// has not yet been visited.
238   Value processValue(llvm::Value *value);
239   /// Create the most accurate Location possible using a llvm::DebugLoc and
240   /// possibly an llvm::Instruction to narrow the Location if debug information
241   /// is unavailable.
242   Location processDebugLoc(const llvm::DebugLoc &loc,
243                            llvm::Instruction *inst = nullptr);
244   /// `br` branches to `target`. Append the block arguments to attach to the
245   /// generated branch op to `blockArguments`. These should be in the same order
246   /// as the PHIs in `target`.
247   LogicalResult processBranchArgs(llvm::Instruction *br,
248                                   llvm::BasicBlock *target,
249                                   SmallVectorImpl<Value> &blockArguments);
250   /// Returns the builtin type equivalent to be used in attributes for the given
251   /// LLVM IR dialect type.
252   Type getStdTypeForAttr(Type type);
253   /// Return `value` as an attribute to attach to a GlobalOp.
254   Attribute getConstantAsAttr(llvm::Constant *value);
255   /// Return `c` as an MLIR Value. This could either be a ConstantOp, or
256   /// an expanded sequence of ops in the current function's entry block (for
257   /// ConstantExprs or ConstantGEPs).
258   Value processConstant(llvm::Constant *c);
259 
260   /// The current builder, pointing at where the next Instruction should be
261   /// generated.
262   OpBuilder b;
263   /// The current context.
264   MLIRContext *context;
265   /// The current module being created.
266   ModuleOp module;
267   /// The entry block of the current function being processed.
268   Block *currentEntryBlock;
269 
270   /// Globals are inserted before the first function, if any.
271   Block::iterator getGlobalInsertPt() {
272     auto i = module.getBody()->begin();
273     while (!isa<LLVMFuncOp, ModuleTerminatorOp>(i))
274       ++i;
275     return i;
276   }
277 
278   /// Functions are always inserted before the module terminator.
279   Block::iterator getFuncInsertPt() {
280     return std::prev(module.getBody()->end());
281   }
282 
283   /// Remapped blocks, for the current function.
284   DenseMap<llvm::BasicBlock *, Block *> blocks;
285   /// Remapped values. These are function-local.
286   DenseMap<llvm::Value *, Value> instMap;
287   /// Instructions that had not been defined when first encountered as a use.
288   /// Maps to the dummy Operation that was created in processValue().
289   DenseMap<llvm::Value *, Operation *> unknownInstMap;
290   /// Uniquing map of GlobalVariables.
291   DenseMap<llvm::GlobalVariable *, GlobalOp> globals;
292   /// Cached FileLineColLoc::get("imported-bitcode", 0, 0).
293   Location unknownLoc;
294   /// The stateful type translator (contains named structs).
295   LLVM::TypeFromLLVMIRTranslator typeTranslator;
296 };
297 } // namespace
298 
299 Location Importer::processDebugLoc(const llvm::DebugLoc &loc,
300                                    llvm::Instruction *inst) {
301   if (!loc && inst) {
302     std::string s;
303     llvm::raw_string_ostream os(s);
304     os << "llvm-imported-inst-%";
305     inst->printAsOperand(os, /*PrintType=*/false);
306     return FileLineColLoc::get(context, os.str(), 0, 0);
307   } else if (!loc) {
308     return unknownLoc;
309   }
310   // FIXME: Obtain the filename from DILocationInfo.
311   return FileLineColLoc::get(context, "imported-bitcode", loc.getLine(),
312                              loc.getCol());
313 }
314 
315 Type Importer::processType(llvm::Type *type) {
316   if (Type result = typeTranslator.translateType(type))
317     return result;
318 
319   // FIXME: Diagnostic should be able to natively handle types that have
320   // operator<<(raw_ostream&) defined.
321   std::string s;
322   llvm::raw_string_ostream os(s);
323   os << *type;
324   emitError(unknownLoc) << "unhandled type: " << os.str();
325   return nullptr;
326 }
327 
328 // We only need integers, floats, doubles, and vectors and tensors thereof for
329 // attributes. Scalar and vector types are converted to the standard
330 // equivalents. Array types are converted to ranked tensors; nested array types
331 // are converted to multi-dimensional tensors or vectors, depending on the
332 // innermost type being a scalar or a vector.
333 Type Importer::getStdTypeForAttr(Type type) {
334   if (!type)
335     return nullptr;
336 
337   if (type.isa<IntegerType, FloatType>())
338     return type;
339 
340   // LLVM vectors can only contain scalars.
341   if (LLVM::isCompatibleVectorType(type)) {
342     auto numElements = LLVM::getVectorNumElements(type);
343     if (numElements.isScalable()) {
344       emitError(unknownLoc) << "scalable vectors not supported";
345       return nullptr;
346     }
347     Type elementType = getStdTypeForAttr(LLVM::getVectorElementType(type));
348     if (!elementType)
349       return nullptr;
350     return VectorType::get(numElements.getKnownMinValue(), elementType);
351   }
352 
353   // LLVM arrays can contain other arrays or vectors.
354   if (auto arrayType = type.dyn_cast<LLVMArrayType>()) {
355     // Recover the nested array shape.
356     SmallVector<int64_t, 4> shape;
357     shape.push_back(arrayType.getNumElements());
358     while (arrayType.getElementType().isa<LLVMArrayType>()) {
359       arrayType = arrayType.getElementType().cast<LLVMArrayType>();
360       shape.push_back(arrayType.getNumElements());
361     }
362 
363     // If the innermost type is a vector, use the multi-dimensional vector as
364     // attribute type.
365     if (LLVM::isCompatibleVectorType(arrayType.getElementType())) {
366       auto numElements = LLVM::getVectorNumElements(arrayType.getElementType());
367       if (numElements.isScalable()) {
368         emitError(unknownLoc) << "scalable vectors not supported";
369         return nullptr;
370       }
371       shape.push_back(numElements.getKnownMinValue());
372 
373       Type elementType = getStdTypeForAttr(
374           LLVM::getVectorElementType(arrayType.getElementType()));
375       if (!elementType)
376         return nullptr;
377       return VectorType::get(shape, elementType);
378     }
379 
380     // Otherwise use a tensor.
381     Type elementType = getStdTypeForAttr(arrayType.getElementType());
382     if (!elementType)
383       return nullptr;
384     return RankedTensorType::get(shape, elementType);
385   }
386 
387   return nullptr;
388 }
389 
390 // Get the given constant as an attribute. Not all constants can be represented
391 // as attributes.
392 Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
393   if (auto *ci = dyn_cast<llvm::ConstantInt>(value))
394     return b.getIntegerAttr(
395         IntegerType::get(context, ci->getType()->getBitWidth()),
396         ci->getValue());
397   if (auto *c = dyn_cast<llvm::ConstantDataArray>(value))
398     if (c->isString())
399       return b.getStringAttr(c->getAsString());
400   if (auto *c = dyn_cast<llvm::ConstantFP>(value)) {
401     if (c->getType()->isDoubleTy())
402       return b.getFloatAttr(FloatType::getF64(context), c->getValueAPF());
403     if (c->getType()->isFloatingPointTy())
404       return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF());
405   }
406   if (auto *f = dyn_cast<llvm::Function>(value))
407     return b.getSymbolRefAttr(f->getName());
408 
409   // Convert constant data to a dense elements attribute.
410   if (auto *cd = dyn_cast<llvm::ConstantDataSequential>(value)) {
411     Type type = processType(cd->getElementType());
412     if (!type)
413       return nullptr;
414 
415     auto attrType = getStdTypeForAttr(processType(cd->getType()))
416                         .dyn_cast_or_null<ShapedType>();
417     if (!attrType)
418       return nullptr;
419 
420     if (type.isa<IntegerType>()) {
421       SmallVector<APInt, 8> values;
422       values.reserve(cd->getNumElements());
423       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
424         values.push_back(cd->getElementAsAPInt(i));
425       return DenseElementsAttr::get(attrType, values);
426     }
427 
428     if (type.isa<Float32Type, Float64Type>()) {
429       SmallVector<APFloat, 8> values;
430       values.reserve(cd->getNumElements());
431       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
432         values.push_back(cd->getElementAsAPFloat(i));
433       return DenseElementsAttr::get(attrType, values);
434     }
435 
436     return nullptr;
437   }
438 
439   // Unpack constant aggregates to create dense elements attribute whenever
440   // possible. Return nullptr (failure) otherwise.
441   if (isa<llvm::ConstantAggregate>(value)) {
442     auto outerType = getStdTypeForAttr(processType(value->getType()))
443                          .dyn_cast_or_null<ShapedType>();
444     if (!outerType)
445       return nullptr;
446 
447     SmallVector<Attribute, 8> values;
448     SmallVector<int64_t, 8> shape;
449 
450     for (unsigned i = 0, e = value->getNumOperands(); i < e; ++i) {
451       auto nested = getConstantAsAttr(value->getAggregateElement(i))
452                         .dyn_cast_or_null<DenseElementsAttr>();
453       if (!nested)
454         return nullptr;
455 
456       values.append(nested.attr_value_begin(), nested.attr_value_end());
457     }
458 
459     return DenseElementsAttr::get(outerType, values);
460   }
461 
462   return nullptr;
463 }
464 
465 GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
466   auto it = globals.find(GV);
467   if (it != globals.end())
468     return it->second;
469 
470   OpBuilder b(module.getBody(), getGlobalInsertPt());
471   Attribute valueAttr;
472   if (GV->hasInitializer())
473     valueAttr = getConstantAsAttr(GV->getInitializer());
474   Type type = processType(GV->getValueType());
475   if (!type)
476     return nullptr;
477   GlobalOp op = b.create<GlobalOp>(
478       UnknownLoc::get(context), type, GV->isConstant(),
479       convertLinkageFromLLVM(GV->getLinkage()), GV->getName(), valueAttr);
480   if (GV->hasInitializer() && !valueAttr) {
481     Region &r = op.getInitializerRegion();
482     currentEntryBlock = b.createBlock(&r);
483     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
484     Value v = processConstant(GV->getInitializer());
485     if (!v)
486       return nullptr;
487     b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v}));
488   }
489   return globals[GV] = op;
490 }
491 
492 Value Importer::processConstant(llvm::Constant *c) {
493   OpBuilder bEntry(currentEntryBlock, currentEntryBlock->begin());
494   if (Attribute attr = getConstantAsAttr(c)) {
495     // These constants can be represented as attributes.
496     OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
497     Type type = processType(c->getType());
498     if (!type)
499       return nullptr;
500     if (auto symbolRef = attr.dyn_cast<FlatSymbolRefAttr>())
501       return instMap[c] = bEntry.create<AddressOfOp>(unknownLoc, type,
502                                                      symbolRef.getValue());
503     return instMap[c] = bEntry.create<ConstantOp>(unknownLoc, type, attr);
504   }
505   if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) {
506     Type type = processType(cn->getType());
507     if (!type)
508       return nullptr;
509     return instMap[c] = bEntry.create<NullOp>(unknownLoc, type);
510   }
511   if (auto *GV = dyn_cast<llvm::GlobalVariable>(c))
512     return bEntry.create<AddressOfOp>(UnknownLoc::get(context),
513                                       processGlobal(GV));
514 
515   if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) {
516     llvm::Instruction *i = ce->getAsInstruction();
517     OpBuilder::InsertionGuard guard(b);
518     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
519     if (failed(processInstruction(i)))
520       return nullptr;
521     assert(instMap.count(i));
522 
523     // Remove this zombie LLVM instruction now, leaving us only with the MLIR
524     // op.
525     i->deleteValue();
526     return instMap[c] = instMap[i];
527   }
528   if (auto *ue = dyn_cast<llvm::UndefValue>(c)) {
529     Type type = processType(ue->getType());
530     if (!type)
531       return nullptr;
532     return instMap[c] = bEntry.create<UndefOp>(UnknownLoc::get(context), type);
533   }
534   emitError(unknownLoc) << "unhandled constant: " << diag(*c);
535   return nullptr;
536 }
537 
538 Value Importer::processValue(llvm::Value *value) {
539   auto it = instMap.find(value);
540   if (it != instMap.end())
541     return it->second;
542 
543   // We don't expect to see instructions in dominator order. If we haven't seen
544   // this instruction yet, create an unknown op and remap it later.
545   if (isa<llvm::Instruction>(value)) {
546     OperationState state(UnknownLoc::get(context), "llvm.unknown");
547     Type type = processType(value->getType());
548     if (!type)
549       return nullptr;
550     state.addTypes(type);
551     unknownInstMap[value] = b.createOperation(state);
552     return unknownInstMap[value]->getResult(0);
553   }
554 
555   if (auto *c = dyn_cast<llvm::Constant>(value))
556     return processConstant(c);
557 
558   emitError(unknownLoc) << "unhandled value: " << diag(*value);
559   return nullptr;
560 }
561 
562 /// Return the MLIR OperationName for the given LLVM opcode.
563 static StringRef lookupOperationNameFromOpcode(unsigned opcode) {
564 // Maps from LLVM opcode to MLIR OperationName. This is deliberately ordered
565 // as in llvm/IR/Instructions.def to aid comprehension and spot missing
566 // instructions.
567 #define INST(llvm_n, mlir_n)                                                   \
568   { llvm::Instruction::llvm_n, LLVM::mlir_n##Op::getOperationName() }
569   static const DenseMap<unsigned, StringRef> opcMap = {
570       // Ret is handled specially.
571       // Br is handled specially.
572       // FIXME: switch
573       // FIXME: indirectbr
574       // FIXME: invoke
575       INST(Resume, Resume),
576       // FIXME: unreachable
577       // FIXME: cleanupret
578       // FIXME: catchret
579       // FIXME: catchswitch
580       // FIXME: callbr
581       // FIXME: fneg
582       INST(Add, Add), INST(FAdd, FAdd), INST(Sub, Sub), INST(FSub, FSub),
583       INST(Mul, Mul), INST(FMul, FMul), INST(UDiv, UDiv), INST(SDiv, SDiv),
584       INST(FDiv, FDiv), INST(URem, URem), INST(SRem, SRem), INST(FRem, FRem),
585       INST(Shl, Shl), INST(LShr, LShr), INST(AShr, AShr), INST(And, And),
586       INST(Or, Or), INST(Xor, XOr), INST(Alloca, Alloca), INST(Load, Load),
587       INST(Store, Store),
588       // Getelementptr is handled specially.
589       INST(Ret, Return), INST(Fence, Fence),
590       // FIXME: atomiccmpxchg
591       // FIXME: atomicrmw
592       INST(Trunc, Trunc), INST(ZExt, ZExt), INST(SExt, SExt),
593       INST(FPToUI, FPToUI), INST(FPToSI, FPToSI), INST(UIToFP, UIToFP),
594       INST(SIToFP, SIToFP), INST(FPTrunc, FPTrunc), INST(FPExt, FPExt),
595       INST(PtrToInt, PtrToInt), INST(IntToPtr, IntToPtr),
596       INST(BitCast, Bitcast), INST(AddrSpaceCast, AddrSpaceCast),
597       // FIXME: cleanuppad
598       // FIXME: catchpad
599       // ICmp is handled specially.
600       // FIXME: fcmp
601       // PHI is handled specially.
602       INST(Freeze, Freeze), INST(Call, Call),
603       // FIXME: select
604       // FIXME: vaarg
605       // FIXME: extractelement
606       // FIXME: insertelement
607       // FIXME: shufflevector
608       // FIXME: extractvalue
609       // FIXME: insertvalue
610       // FIXME: landingpad
611   };
612 #undef INST
613 
614   return opcMap.lookup(opcode);
615 }
616 
617 static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) {
618   switch (p) {
619   default:
620     llvm_unreachable("incorrect comparison predicate");
621   case llvm::CmpInst::Predicate::ICMP_EQ:
622     return LLVM::ICmpPredicate::eq;
623   case llvm::CmpInst::Predicate::ICMP_NE:
624     return LLVM::ICmpPredicate::ne;
625   case llvm::CmpInst::Predicate::ICMP_SLT:
626     return LLVM::ICmpPredicate::slt;
627   case llvm::CmpInst::Predicate::ICMP_SLE:
628     return LLVM::ICmpPredicate::sle;
629   case llvm::CmpInst::Predicate::ICMP_SGT:
630     return LLVM::ICmpPredicate::sgt;
631   case llvm::CmpInst::Predicate::ICMP_SGE:
632     return LLVM::ICmpPredicate::sge;
633   case llvm::CmpInst::Predicate::ICMP_ULT:
634     return LLVM::ICmpPredicate::ult;
635   case llvm::CmpInst::Predicate::ICMP_ULE:
636     return LLVM::ICmpPredicate::ule;
637   case llvm::CmpInst::Predicate::ICMP_UGT:
638     return LLVM::ICmpPredicate::ugt;
639   case llvm::CmpInst::Predicate::ICMP_UGE:
640     return LLVM::ICmpPredicate::uge;
641   }
642   llvm_unreachable("incorrect comparison predicate");
643 }
644 
645 static AtomicOrdering getLLVMAtomicOrdering(llvm::AtomicOrdering ordering) {
646   switch (ordering) {
647   case llvm::AtomicOrdering::NotAtomic:
648     return LLVM::AtomicOrdering::not_atomic;
649   case llvm::AtomicOrdering::Unordered:
650     return LLVM::AtomicOrdering::unordered;
651   case llvm::AtomicOrdering::Monotonic:
652     return LLVM::AtomicOrdering::monotonic;
653   case llvm::AtomicOrdering::Acquire:
654     return LLVM::AtomicOrdering::acquire;
655   case llvm::AtomicOrdering::Release:
656     return LLVM::AtomicOrdering::release;
657   case llvm::AtomicOrdering::AcquireRelease:
658     return LLVM::AtomicOrdering::acq_rel;
659   case llvm::AtomicOrdering::SequentiallyConsistent:
660     return LLVM::AtomicOrdering::seq_cst;
661   }
662   llvm_unreachable("incorrect atomic ordering");
663 }
664 
665 // `br` branches to `target`. Return the branch arguments to `br`, in the
666 // same order of the PHIs in `target`.
667 LogicalResult
668 Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target,
669                             SmallVectorImpl<Value> &blockArguments) {
670   for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) {
671     auto *PN = cast<llvm::PHINode>(&*inst);
672     Value value = processValue(PN->getIncomingValueForBlock(br->getParent()));
673     if (!value)
674       return failure();
675     blockArguments.push_back(value);
676   }
677   return success();
678 }
679 
680 LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
681   // FIXME: Support uses of SubtargetData. Currently inbounds GEPs, fast-math
682   // flags and call / operand attributes are not supported.
683   Location loc = processDebugLoc(inst->getDebugLoc(), inst);
684   Value &v = instMap[inst];
685   assert(!v && "processInstruction must be called only once per instruction!");
686   switch (inst->getOpcode()) {
687   default:
688     return emitError(loc) << "unknown instruction: " << diag(*inst);
689   case llvm::Instruction::Add:
690   case llvm::Instruction::FAdd:
691   case llvm::Instruction::Sub:
692   case llvm::Instruction::FSub:
693   case llvm::Instruction::Mul:
694   case llvm::Instruction::FMul:
695   case llvm::Instruction::UDiv:
696   case llvm::Instruction::SDiv:
697   case llvm::Instruction::FDiv:
698   case llvm::Instruction::URem:
699   case llvm::Instruction::SRem:
700   case llvm::Instruction::FRem:
701   case llvm::Instruction::Shl:
702   case llvm::Instruction::LShr:
703   case llvm::Instruction::AShr:
704   case llvm::Instruction::And:
705   case llvm::Instruction::Or:
706   case llvm::Instruction::Xor:
707   case llvm::Instruction::Alloca:
708   case llvm::Instruction::Load:
709   case llvm::Instruction::Store:
710   case llvm::Instruction::Ret:
711   case llvm::Instruction::Resume:
712   case llvm::Instruction::Trunc:
713   case llvm::Instruction::ZExt:
714   case llvm::Instruction::SExt:
715   case llvm::Instruction::FPToUI:
716   case llvm::Instruction::FPToSI:
717   case llvm::Instruction::UIToFP:
718   case llvm::Instruction::SIToFP:
719   case llvm::Instruction::FPTrunc:
720   case llvm::Instruction::FPExt:
721   case llvm::Instruction::PtrToInt:
722   case llvm::Instruction::IntToPtr:
723   case llvm::Instruction::AddrSpaceCast:
724   case llvm::Instruction::Freeze:
725   case llvm::Instruction::BitCast: {
726     OperationState state(loc, lookupOperationNameFromOpcode(inst->getOpcode()));
727     SmallVector<Value, 4> ops;
728     ops.reserve(inst->getNumOperands());
729     for (auto *op : inst->operand_values()) {
730       Value value = processValue(op);
731       if (!value)
732         return failure();
733       ops.push_back(value);
734     }
735     state.addOperands(ops);
736     if (!inst->getType()->isVoidTy()) {
737       Type type = processType(inst->getType());
738       if (!type)
739         return failure();
740       state.addTypes(type);
741     }
742     Operation *op = b.createOperation(state);
743     if (!inst->getType()->isVoidTy())
744       v = op->getResult(0);
745     return success();
746   }
747   case llvm::Instruction::ICmp: {
748     Value lhs = processValue(inst->getOperand(0));
749     Value rhs = processValue(inst->getOperand(1));
750     if (!lhs || !rhs)
751       return failure();
752     v = b.create<ICmpOp>(
753         loc, getICmpPredicate(cast<llvm::ICmpInst>(inst)->getPredicate()), lhs,
754         rhs);
755     return success();
756   }
757   case llvm::Instruction::Br: {
758     auto *brInst = cast<llvm::BranchInst>(inst);
759     OperationState state(loc,
760                          brInst->isConditional() ? "llvm.cond_br" : "llvm.br");
761     if (brInst->isConditional()) {
762       Value condition = processValue(brInst->getCondition());
763       if (!condition)
764         return failure();
765       state.addOperands(condition);
766     }
767 
768     std::array<int32_t, 3> operandSegmentSizes = {1, 0, 0};
769     for (int i : llvm::seq<int>(0, brInst->getNumSuccessors())) {
770       auto *succ = brInst->getSuccessor(i);
771       SmallVector<Value, 4> blockArguments;
772       if (failed(processBranchArgs(brInst, succ, blockArguments)))
773         return failure();
774       state.addSuccessors(blocks[succ]);
775       state.addOperands(blockArguments);
776       operandSegmentSizes[i + 1] = blockArguments.size();
777     }
778 
779     if (brInst->isConditional()) {
780       state.addAttribute(LLVM::CondBrOp::getOperandSegmentSizeAttr(),
781                          b.getI32VectorAttr(operandSegmentSizes));
782     }
783 
784     b.createOperation(state);
785     return success();
786   }
787   case llvm::Instruction::PHI: {
788     Type type = processType(inst->getType());
789     if (!type)
790       return failure();
791     v = b.getInsertionBlock()->addArgument(type);
792     return success();
793   }
794   case llvm::Instruction::Call: {
795     llvm::CallInst *ci = cast<llvm::CallInst>(inst);
796     SmallVector<Value, 4> ops;
797     ops.reserve(inst->getNumOperands());
798     for (auto &op : ci->arg_operands()) {
799       Value arg = processValue(op.get());
800       if (!arg)
801         return failure();
802       ops.push_back(arg);
803     }
804 
805     SmallVector<Type, 2> tys;
806     if (!ci->getType()->isVoidTy()) {
807       Type type = processType(inst->getType());
808       if (!type)
809         return failure();
810       tys.push_back(type);
811     }
812     Operation *op;
813     if (llvm::Function *callee = ci->getCalledFunction()) {
814       op = b.create<CallOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
815                             ops);
816     } else {
817       Value calledValue = processValue(ci->getCalledOperand());
818       if (!calledValue)
819         return failure();
820       ops.insert(ops.begin(), calledValue);
821       op = b.create<CallOp>(loc, tys, ops);
822     }
823     if (!ci->getType()->isVoidTy())
824       v = op->getResult(0);
825     return success();
826   }
827   case llvm::Instruction::LandingPad: {
828     llvm::LandingPadInst *lpi = cast<llvm::LandingPadInst>(inst);
829     SmallVector<Value, 4> ops;
830 
831     for (unsigned i = 0, ie = lpi->getNumClauses(); i < ie; i++)
832       ops.push_back(processConstant(lpi->getClause(i)));
833 
834     Type ty = processType(lpi->getType());
835     if (!ty)
836       return failure();
837 
838     v = b.create<LandingpadOp>(loc, ty, lpi->isCleanup(), ops);
839     return success();
840   }
841   case llvm::Instruction::Invoke: {
842     llvm::InvokeInst *ii = cast<llvm::InvokeInst>(inst);
843 
844     SmallVector<Type, 2> tys;
845     if (!ii->getType()->isVoidTy())
846       tys.push_back(processType(inst->getType()));
847 
848     SmallVector<Value, 4> ops;
849     ops.reserve(inst->getNumOperands() + 1);
850     for (auto &op : ii->arg_operands())
851       ops.push_back(processValue(op.get()));
852 
853     SmallVector<Value, 4> normalArgs, unwindArgs;
854     (void)processBranchArgs(ii, ii->getNormalDest(), normalArgs);
855     (void)processBranchArgs(ii, ii->getUnwindDest(), unwindArgs);
856 
857     Operation *op;
858     if (llvm::Function *callee = ii->getCalledFunction()) {
859       op = b.create<InvokeOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
860                               ops, blocks[ii->getNormalDest()], normalArgs,
861                               blocks[ii->getUnwindDest()], unwindArgs);
862     } else {
863       ops.insert(ops.begin(), processValue(ii->getCalledOperand()));
864       op = b.create<InvokeOp>(loc, tys, ops, blocks[ii->getNormalDest()],
865                               normalArgs, blocks[ii->getUnwindDest()],
866                               unwindArgs);
867     }
868 
869     if (!ii->getType()->isVoidTy())
870       v = op->getResult(0);
871     return success();
872   }
873   case llvm::Instruction::Fence: {
874     StringRef syncscope;
875     SmallVector<StringRef, 4> ssNs;
876     llvm::LLVMContext &llvmContext = inst->getContext();
877     llvm::FenceInst *fence = cast<llvm::FenceInst>(inst);
878     llvmContext.getSyncScopeNames(ssNs);
879     int fenceSyncScopeID = fence->getSyncScopeID();
880     for (unsigned i = 0, e = ssNs.size(); i != e; i++) {
881       if (fenceSyncScopeID == llvmContext.getOrInsertSyncScopeID(ssNs[i])) {
882         syncscope = ssNs[i];
883         break;
884       }
885     }
886     b.create<FenceOp>(loc, getLLVMAtomicOrdering(fence->getOrdering()),
887                       syncscope);
888     return success();
889   }
890   case llvm::Instruction::GetElementPtr: {
891     // FIXME: Support inbounds GEPs.
892     llvm::GetElementPtrInst *gep = cast<llvm::GetElementPtrInst>(inst);
893     SmallVector<Value, 4> ops;
894     for (auto *op : gep->operand_values()) {
895       Value value = processValue(op);
896       if (!value)
897         return failure();
898       ops.push_back(value);
899     }
900     Type type = processType(inst->getType());
901     if (!type)
902       return failure();
903     v = b.create<GEPOp>(loc, type, ops);
904     return success();
905   }
906   }
907 }
908 
909 FlatSymbolRefAttr Importer::getPersonalityAsAttr(llvm::Function *f) {
910   if (!f->hasPersonalityFn())
911     return nullptr;
912 
913   llvm::Constant *pf = f->getPersonalityFn();
914 
915   // If it directly has a name, we can use it.
916   if (pf->hasName())
917     return b.getSymbolRefAttr(pf->getName());
918 
919   // If it doesn't have a name, currently, only function pointers that are
920   // bitcast to i8* are parsed.
921   if (auto ce = dyn_cast<llvm::ConstantExpr>(pf)) {
922     if (ce->getOpcode() == llvm::Instruction::BitCast &&
923         ce->getType() == llvm::Type::getInt8PtrTy(f->getContext())) {
924       if (auto func = dyn_cast<llvm::Function>(ce->getOperand(0)))
925         return b.getSymbolRefAttr(func->getName());
926     }
927   }
928   return FlatSymbolRefAttr();
929 }
930 
931 LogicalResult Importer::processFunction(llvm::Function *f) {
932   blocks.clear();
933   instMap.clear();
934   unknownInstMap.clear();
935 
936   auto functionType =
937       processType(f->getFunctionType()).dyn_cast<LLVMFunctionType>();
938   if (!functionType)
939     return failure();
940 
941   b.setInsertionPoint(module.getBody(), getFuncInsertPt());
942   LLVMFuncOp fop =
943       b.create<LLVMFuncOp>(UnknownLoc::get(context), f->getName(), functionType,
944                            convertLinkageFromLLVM(f->getLinkage()));
945 
946   if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f))
947     fop->setAttr(b.getIdentifier("personality"), personality);
948   else if (f->hasPersonalityFn())
949     emitWarning(UnknownLoc::get(context),
950                 "could not deduce personality, skipping it");
951 
952   if (f->isDeclaration())
953     return success();
954 
955   // Eagerly create all blocks.
956   SmallVector<Block *, 4> blockList;
957   for (llvm::BasicBlock &bb : *f) {
958     blockList.push_back(b.createBlock(&fop.body(), fop.body().end()));
959     blocks[&bb] = blockList.back();
960   }
961   currentEntryBlock = blockList[0];
962 
963   // Add function arguments to the entry block.
964   for (auto kv : llvm::enumerate(f->args()))
965     instMap[&kv.value()] =
966         blockList[0]->addArgument(functionType.getParamType(kv.index()));
967 
968   for (auto bbs : llvm::zip(*f, blockList)) {
969     if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs))))
970       return failure();
971   }
972 
973   // Now that all instructions are guaranteed to have been visited, ensure
974   // any unknown uses we encountered are remapped.
975   for (auto &llvmAndUnknown : unknownInstMap) {
976     assert(instMap.count(llvmAndUnknown.first));
977     Value newValue = instMap[llvmAndUnknown.first];
978     Value oldValue = llvmAndUnknown.second->getResult(0);
979     oldValue.replaceAllUsesWith(newValue);
980     llvmAndUnknown.second->erase();
981   }
982   return success();
983 }
984 
985 LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) {
986   b.setInsertionPointToStart(block);
987   for (llvm::Instruction &inst : *bb) {
988     if (failed(processInstruction(&inst)))
989       return failure();
990   }
991   return success();
992 }
993 
994 OwningModuleRef
995 mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
996                               MLIRContext *context) {
997   context->loadDialect<LLVMDialect>();
998   OwningModuleRef module(ModuleOp::create(
999       FileLineColLoc::get(context, "", /*line=*/0, /*column=*/0)));
1000 
1001   Importer deserializer(context, module.get());
1002   for (llvm::GlobalVariable &gv : llvmModule->globals()) {
1003     if (!deserializer.processGlobal(&gv))
1004       return {};
1005   }
1006   for (llvm::Function &f : llvmModule->functions()) {
1007     if (failed(deserializer.processFunction(&f)))
1008       return {};
1009   }
1010 
1011   return module;
1012 }
1013 
1014 // Deserializes the LLVM bitcode stored in `input` into an MLIR module in the
1015 // LLVM dialect.
1016 OwningModuleRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr,
1017                                         MLIRContext *context) {
1018   llvm::SMDiagnostic err;
1019   llvm::LLVMContext llvmContext;
1020   std::unique_ptr<llvm::Module> llvmModule = llvm::parseIR(
1021       *sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), err, llvmContext);
1022   if (!llvmModule) {
1023     std::string errStr;
1024     llvm::raw_string_ostream errStream(errStr);
1025     err.print(/*ProgName=*/"", errStream);
1026     emitError(UnknownLoc::get(context)) << errStream.str();
1027     return {};
1028   }
1029   return translateLLVMIRToModule(std::move(llvmModule), context);
1030 }
1031 
1032 namespace mlir {
1033 void registerFromLLVMIRTranslation() {
1034   TranslateToMLIRRegistration fromLLVM(
1035       "import-llvm", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
1036         return ::translateLLVMIRToModule(sourceMgr, context);
1037       });
1038 }
1039 } // namespace mlir
1040