1 //===- ConvertFromLLVMIR.cpp - MLIR to LLVM IR conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between LLVM IR and the MLIR LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/MLIRContext.h"
18 #include "mlir/Target/LLVMIR/Import.h"
19 #include "mlir/Translation.h"
20 
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/IR/Attributes.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/DerivedTypes.h"
25 #include "llvm/IR/Function.h"
26 #include "llvm/IR/InlineAsm.h"
27 #include "llvm/IR/Instructions.h"
28 #include "llvm/IR/Type.h"
29 #include "llvm/IRReader/IRReader.h"
30 #include "llvm/Support/Error.h"
31 #include "llvm/Support/SourceMgr.h"
32 
33 using namespace mlir;
34 using namespace mlir::LLVM;
35 
36 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
37 
38 // Utility to print an LLVM value as a string for passing to emitError().
39 // FIXME: Diagnostic should be able to natively handle types that have
40 // operator << (raw_ostream&) defined.
41 static std::string diag(llvm::Value &v) {
42   std::string s;
43   llvm::raw_string_ostream os(s);
44   os << v;
45   return os.str();
46 }
47 
48 namespace mlir {
49 namespace LLVM {
50 namespace detail {
51 /// Support for translating LLVM IR types to MLIR LLVM dialect types.
52 class TypeFromLLVMIRTranslatorImpl {
53 public:
54   /// Constructs a class creating types in the given MLIR context.
55   TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {}
56 
57   /// Translates the given type.
58   Type translateType(llvm::Type *type) {
59     if (knownTranslations.count(type))
60       return knownTranslations.lookup(type);
61 
62     Type translated =
63         llvm::TypeSwitch<llvm::Type *, Type>(type)
64             .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
65                   llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
66                   llvm::ScalableVectorType>(
67                 [this](auto *type) { return this->translate(type); })
68             .Default([this](llvm::Type *type) {
69               return translatePrimitiveType(type);
70             });
71     knownTranslations.try_emplace(type, translated);
72     return translated;
73   }
74 
75 private:
76   /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
77   /// type.
78   Type translatePrimitiveType(llvm::Type *type) {
79     if (type->isVoidTy())
80       return LLVM::LLVMVoidType::get(&context);
81     if (type->isHalfTy())
82       return Float16Type::get(&context);
83     if (type->isBFloatTy())
84       return BFloat16Type::get(&context);
85     if (type->isFloatTy())
86       return Float32Type::get(&context);
87     if (type->isDoubleTy())
88       return Float64Type::get(&context);
89     if (type->isFP128Ty())
90       return Float128Type::get(&context);
91     if (type->isX86_FP80Ty())
92       return Float80Type::get(&context);
93     if (type->isPPC_FP128Ty())
94       return LLVM::LLVMPPCFP128Type::get(&context);
95     if (type->isX86_MMXTy())
96       return LLVM::LLVMX86MMXType::get(&context);
97     if (type->isLabelTy())
98       return LLVM::LLVMLabelType::get(&context);
99     if (type->isMetadataTy())
100       return LLVM::LLVMMetadataType::get(&context);
101     llvm_unreachable("not a primitive type");
102   }
103 
104   /// Translates the given array type.
105   Type translate(llvm::ArrayType *type) {
106     return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
107                                     type->getNumElements());
108   }
109 
110   /// Translates the given function type.
111   Type translate(llvm::FunctionType *type) {
112     SmallVector<Type, 8> paramTypes;
113     translateTypes(type->params(), paramTypes);
114     return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
115                                        paramTypes, type->isVarArg());
116   }
117 
118   /// Translates the given integer type.
119   Type translate(llvm::IntegerType *type) {
120     return IntegerType::get(&context, type->getBitWidth());
121   }
122 
123   /// Translates the given pointer type.
124   Type translate(llvm::PointerType *type) {
125     return LLVM::LLVMPointerType::get(translateType(type->getElementType()),
126                                       type->getAddressSpace());
127   }
128 
129   /// Translates the given structure type.
130   Type translate(llvm::StructType *type) {
131     SmallVector<Type, 8> subtypes;
132     if (type->isLiteral()) {
133       translateTypes(type->subtypes(), subtypes);
134       return LLVM::LLVMStructType::getLiteral(&context, subtypes,
135                                               type->isPacked());
136     }
137 
138     if (type->isOpaque())
139       return LLVM::LLVMStructType::getOpaque(type->getName(), &context);
140 
141     LLVM::LLVMStructType translated =
142         LLVM::LLVMStructType::getIdentified(&context, type->getName());
143     knownTranslations.try_emplace(type, translated);
144     translateTypes(type->subtypes(), subtypes);
145     LogicalResult bodySet = translated.setBody(subtypes, type->isPacked());
146     assert(succeeded(bodySet) &&
147            "could not set the body of an identified struct");
148     (void)bodySet;
149     return translated;
150   }
151 
152   /// Translates the given fixed-vector type.
153   Type translate(llvm::FixedVectorType *type) {
154     return LLVM::getFixedVectorType(translateType(type->getElementType()),
155                                     type->getNumElements());
156   }
157 
158   /// Translates the given scalable-vector type.
159   Type translate(llvm::ScalableVectorType *type) {
160     return LLVM::LLVMScalableVectorType::get(
161         translateType(type->getElementType()), type->getMinNumElements());
162   }
163 
164   /// Translates a list of types.
165   void translateTypes(ArrayRef<llvm::Type *> types,
166                       SmallVectorImpl<Type> &result) {
167     result.reserve(result.size() + types.size());
168     for (llvm::Type *type : types)
169       result.push_back(translateType(type));
170   }
171 
172   /// Map of known translations. Serves as a cache and as recursion stopper for
173   /// translating recursive structs.
174   llvm::DenseMap<llvm::Type *, Type> knownTranslations;
175 
176   /// The context in which MLIR types are created.
177   MLIRContext &context;
178 };
179 } // end namespace detail
180 
181 /// Utility class to translate LLVM IR types to the MLIR LLVM dialect. Stores
182 /// the translation state, in particular any identified structure types that are
183 /// reused across translations.
184 class TypeFromLLVMIRTranslator {
185 public:
186   TypeFromLLVMIRTranslator(MLIRContext &context);
187   ~TypeFromLLVMIRTranslator();
188 
189   /// Translates the given LLVM IR type to the MLIR LLVM dialect.
190   Type translateType(llvm::Type *type);
191 
192 private:
193   /// Private implementation.
194   std::unique_ptr<detail::TypeFromLLVMIRTranslatorImpl> impl;
195 };
196 
197 } // end namespace LLVM
198 } // end namespace mlir
199 
200 LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context)
201     : impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {}
202 
203 LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() {}
204 
205 Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
206   return impl->translateType(type);
207 }
208 
209 // Handles importing globals and functions from an LLVM module.
210 namespace {
211 class Importer {
212 public:
213   Importer(MLIRContext *context, ModuleOp module)
214       : b(context), context(context), module(module),
215         unknownLoc(FileLineColLoc::get(context, "imported-bitcode", 0, 0)),
216         typeTranslator(*context) {
217     b.setInsertionPointToStart(module.getBody());
218   }
219 
220   /// Imports `f` into the current module.
221   LogicalResult processFunction(llvm::Function *f);
222 
223   /// Imports GV as a GlobalOp, creating it if it doesn't exist.
224   GlobalOp processGlobal(llvm::GlobalVariable *GV);
225 
226 private:
227   /// Returns personality of `f` as a FlatSymbolRefAttr.
228   FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *f);
229   /// Imports `bb` into `block`, which must be initially empty.
230   LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block);
231   /// Imports `inst` and populates instMap[inst] with the imported Value.
232   LogicalResult processInstruction(llvm::Instruction *inst);
233   /// Creates an LLVM-compatible MLIR type for `type`.
234   Type processType(llvm::Type *type);
235   /// `value` is an SSA-use. Return the remapped version of `value` or a
236   /// placeholder that will be remapped later if this is an instruction that
237   /// has not yet been visited.
238   Value processValue(llvm::Value *value);
239   /// Create the most accurate Location possible using a llvm::DebugLoc and
240   /// possibly an llvm::Instruction to narrow the Location if debug information
241   /// is unavailable.
242   Location processDebugLoc(const llvm::DebugLoc &loc,
243                            llvm::Instruction *inst = nullptr);
244   /// `br` branches to `target`. Append the block arguments to attach to the
245   /// generated branch op to `blockArguments`. These should be in the same order
246   /// as the PHIs in `target`.
247   LogicalResult processBranchArgs(llvm::Instruction *br,
248                                   llvm::BasicBlock *target,
249                                   SmallVectorImpl<Value> &blockArguments);
250   /// Returns the builtin type equivalent to be used in attributes for the given
251   /// LLVM IR dialect type.
252   Type getStdTypeForAttr(Type type);
253   /// Return `value` as an attribute to attach to a GlobalOp.
254   Attribute getConstantAsAttr(llvm::Constant *value);
255   /// Return `c` as an MLIR Value. This could either be a ConstantOp, or
256   /// an expanded sequence of ops in the current function's entry block (for
257   /// ConstantExprs or ConstantGEPs).
258   Value processConstant(llvm::Constant *c);
259 
260   /// The current builder, pointing at where the next Instruction should be
261   /// generated.
262   OpBuilder b;
263   /// The current context.
264   MLIRContext *context;
265   /// The current module being created.
266   ModuleOp module;
267   /// The entry block of the current function being processed.
268   Block *currentEntryBlock;
269 
270   /// Globals are inserted before the first function, if any.
271   Block::iterator getGlobalInsertPt() {
272     auto it = module.getBody()->begin();
273     auto endIt = module.getBody()->end();
274     while (it != endIt && !isa<LLVMFuncOp>(it))
275       ++it;
276     return it;
277   }
278 
279   /// Functions are always inserted before the module terminator.
280   Block::iterator getFuncInsertPt() {
281     return std::prev(module.getBody()->end());
282   }
283 
284   /// Remapped blocks, for the current function.
285   DenseMap<llvm::BasicBlock *, Block *> blocks;
286   /// Remapped values. These are function-local.
287   DenseMap<llvm::Value *, Value> instMap;
288   /// Instructions that had not been defined when first encountered as a use.
289   /// Maps to the dummy Operation that was created in processValue().
290   DenseMap<llvm::Value *, Operation *> unknownInstMap;
291   /// Uniquing map of GlobalVariables.
292   DenseMap<llvm::GlobalVariable *, GlobalOp> globals;
293   /// Cached FileLineColLoc::get("imported-bitcode", 0, 0).
294   Location unknownLoc;
295   /// The stateful type translator (contains named structs).
296   LLVM::TypeFromLLVMIRTranslator typeTranslator;
297 };
298 } // namespace
299 
300 Location Importer::processDebugLoc(const llvm::DebugLoc &loc,
301                                    llvm::Instruction *inst) {
302   if (!loc && inst) {
303     std::string s;
304     llvm::raw_string_ostream os(s);
305     os << "llvm-imported-inst-%";
306     inst->printAsOperand(os, /*PrintType=*/false);
307     return FileLineColLoc::get(context, os.str(), 0, 0);
308   } else if (!loc) {
309     return unknownLoc;
310   }
311   // FIXME: Obtain the filename from DILocationInfo.
312   return FileLineColLoc::get(context, "imported-bitcode", loc.getLine(),
313                              loc.getCol());
314 }
315 
316 Type Importer::processType(llvm::Type *type) {
317   if (Type result = typeTranslator.translateType(type))
318     return result;
319 
320   // FIXME: Diagnostic should be able to natively handle types that have
321   // operator<<(raw_ostream&) defined.
322   std::string s;
323   llvm::raw_string_ostream os(s);
324   os << *type;
325   emitError(unknownLoc) << "unhandled type: " << os.str();
326   return nullptr;
327 }
328 
329 // We only need integers, floats, doubles, and vectors and tensors thereof for
330 // attributes. Scalar and vector types are converted to the standard
331 // equivalents. Array types are converted to ranked tensors; nested array types
332 // are converted to multi-dimensional tensors or vectors, depending on the
333 // innermost type being a scalar or a vector.
334 Type Importer::getStdTypeForAttr(Type type) {
335   if (!type)
336     return nullptr;
337 
338   if (type.isa<IntegerType, FloatType>())
339     return type;
340 
341   // LLVM vectors can only contain scalars.
342   if (LLVM::isCompatibleVectorType(type)) {
343     auto numElements = LLVM::getVectorNumElements(type);
344     if (numElements.isScalable()) {
345       emitError(unknownLoc) << "scalable vectors not supported";
346       return nullptr;
347     }
348     Type elementType = getStdTypeForAttr(LLVM::getVectorElementType(type));
349     if (!elementType)
350       return nullptr;
351     return VectorType::get(numElements.getKnownMinValue(), elementType);
352   }
353 
354   // LLVM arrays can contain other arrays or vectors.
355   if (auto arrayType = type.dyn_cast<LLVMArrayType>()) {
356     // Recover the nested array shape.
357     SmallVector<int64_t, 4> shape;
358     shape.push_back(arrayType.getNumElements());
359     while (arrayType.getElementType().isa<LLVMArrayType>()) {
360       arrayType = arrayType.getElementType().cast<LLVMArrayType>();
361       shape.push_back(arrayType.getNumElements());
362     }
363 
364     // If the innermost type is a vector, use the multi-dimensional vector as
365     // attribute type.
366     if (LLVM::isCompatibleVectorType(arrayType.getElementType())) {
367       auto numElements = LLVM::getVectorNumElements(arrayType.getElementType());
368       if (numElements.isScalable()) {
369         emitError(unknownLoc) << "scalable vectors not supported";
370         return nullptr;
371       }
372       shape.push_back(numElements.getKnownMinValue());
373 
374       Type elementType = getStdTypeForAttr(
375           LLVM::getVectorElementType(arrayType.getElementType()));
376       if (!elementType)
377         return nullptr;
378       return VectorType::get(shape, elementType);
379     }
380 
381     // Otherwise use a tensor.
382     Type elementType = getStdTypeForAttr(arrayType.getElementType());
383     if (!elementType)
384       return nullptr;
385     return RankedTensorType::get(shape, elementType);
386   }
387 
388   return nullptr;
389 }
390 
391 // Get the given constant as an attribute. Not all constants can be represented
392 // as attributes.
393 Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
394   if (auto *ci = dyn_cast<llvm::ConstantInt>(value))
395     return b.getIntegerAttr(
396         IntegerType::get(context, ci->getType()->getBitWidth()),
397         ci->getValue());
398   if (auto *c = dyn_cast<llvm::ConstantDataArray>(value))
399     if (c->isString())
400       return b.getStringAttr(c->getAsString());
401   if (auto *c = dyn_cast<llvm::ConstantFP>(value)) {
402     if (c->getType()->isDoubleTy())
403       return b.getFloatAttr(FloatType::getF64(context), c->getValueAPF());
404     if (c->getType()->isFloatingPointTy())
405       return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF());
406   }
407   if (auto *f = dyn_cast<llvm::Function>(value))
408     return b.getSymbolRefAttr(f->getName());
409 
410   // Convert constant data to a dense elements attribute.
411   if (auto *cd = dyn_cast<llvm::ConstantDataSequential>(value)) {
412     Type type = processType(cd->getElementType());
413     if (!type)
414       return nullptr;
415 
416     auto attrType = getStdTypeForAttr(processType(cd->getType()))
417                         .dyn_cast_or_null<ShapedType>();
418     if (!attrType)
419       return nullptr;
420 
421     if (type.isa<IntegerType>()) {
422       SmallVector<APInt, 8> values;
423       values.reserve(cd->getNumElements());
424       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
425         values.push_back(cd->getElementAsAPInt(i));
426       return DenseElementsAttr::get(attrType, values);
427     }
428 
429     if (type.isa<Float32Type, Float64Type>()) {
430       SmallVector<APFloat, 8> values;
431       values.reserve(cd->getNumElements());
432       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
433         values.push_back(cd->getElementAsAPFloat(i));
434       return DenseElementsAttr::get(attrType, values);
435     }
436 
437     return nullptr;
438   }
439 
440   // Unpack constant aggregates to create dense elements attribute whenever
441   // possible. Return nullptr (failure) otherwise.
442   if (isa<llvm::ConstantAggregate>(value)) {
443     auto outerType = getStdTypeForAttr(processType(value->getType()))
444                          .dyn_cast_or_null<ShapedType>();
445     if (!outerType)
446       return nullptr;
447 
448     SmallVector<Attribute, 8> values;
449     SmallVector<int64_t, 8> shape;
450 
451     for (unsigned i = 0, e = value->getNumOperands(); i < e; ++i) {
452       auto nested = getConstantAsAttr(value->getAggregateElement(i))
453                         .dyn_cast_or_null<DenseElementsAttr>();
454       if (!nested)
455         return nullptr;
456 
457       values.append(nested.attr_value_begin(), nested.attr_value_end());
458     }
459 
460     return DenseElementsAttr::get(outerType, values);
461   }
462 
463   return nullptr;
464 }
465 
466 GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
467   auto it = globals.find(GV);
468   if (it != globals.end())
469     return it->second;
470 
471   OpBuilder b(module.getBody(), getGlobalInsertPt());
472   Attribute valueAttr;
473   if (GV->hasInitializer())
474     valueAttr = getConstantAsAttr(GV->getInitializer());
475   Type type = processType(GV->getValueType());
476   if (!type)
477     return nullptr;
478   GlobalOp op = b.create<GlobalOp>(
479       UnknownLoc::get(context), type, GV->isConstant(),
480       convertLinkageFromLLVM(GV->getLinkage()), GV->getName(), valueAttr);
481   if (GV->hasInitializer() && !valueAttr) {
482     Region &r = op.getInitializerRegion();
483     currentEntryBlock = b.createBlock(&r);
484     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
485     Value v = processConstant(GV->getInitializer());
486     if (!v)
487       return nullptr;
488     b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v}));
489   }
490   if (GV->hasAtLeastLocalUnnamedAddr())
491     op.unnamed_addrAttr(UnnamedAddrAttr::get(
492         context, convertUnnamedAddrFromLLVM(GV->getUnnamedAddr())));
493   return globals[GV] = op;
494 }
495 
496 Value Importer::processConstant(llvm::Constant *c) {
497   OpBuilder bEntry(currentEntryBlock, currentEntryBlock->begin());
498   if (Attribute attr = getConstantAsAttr(c)) {
499     // These constants can be represented as attributes.
500     OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
501     Type type = processType(c->getType());
502     if (!type)
503       return nullptr;
504     if (auto symbolRef = attr.dyn_cast<FlatSymbolRefAttr>())
505       return instMap[c] = bEntry.create<AddressOfOp>(unknownLoc, type,
506                                                      symbolRef.getValue());
507     return instMap[c] = bEntry.create<ConstantOp>(unknownLoc, type, attr);
508   }
509   if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) {
510     Type type = processType(cn->getType());
511     if (!type)
512       return nullptr;
513     return instMap[c] = bEntry.create<NullOp>(unknownLoc, type);
514   }
515   if (auto *GV = dyn_cast<llvm::GlobalVariable>(c))
516     return bEntry.create<AddressOfOp>(UnknownLoc::get(context),
517                                       processGlobal(GV));
518 
519   if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) {
520     llvm::Instruction *i = ce->getAsInstruction();
521     OpBuilder::InsertionGuard guard(b);
522     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
523     if (failed(processInstruction(i)))
524       return nullptr;
525     assert(instMap.count(i));
526 
527     // Remove this zombie LLVM instruction now, leaving us only with the MLIR
528     // op.
529     i->deleteValue();
530     return instMap[c] = instMap[i];
531   }
532   if (auto *ue = dyn_cast<llvm::UndefValue>(c)) {
533     Type type = processType(ue->getType());
534     if (!type)
535       return nullptr;
536     return instMap[c] = bEntry.create<UndefOp>(UnknownLoc::get(context), type);
537   }
538   emitError(unknownLoc) << "unhandled constant: " << diag(*c);
539   return nullptr;
540 }
541 
542 Value Importer::processValue(llvm::Value *value) {
543   auto it = instMap.find(value);
544   if (it != instMap.end())
545     return it->second;
546 
547   // We don't expect to see instructions in dominator order. If we haven't seen
548   // this instruction yet, create an unknown op and remap it later.
549   if (isa<llvm::Instruction>(value)) {
550     OperationState state(UnknownLoc::get(context), "llvm.unknown");
551     Type type = processType(value->getType());
552     if (!type)
553       return nullptr;
554     state.addTypes(type);
555     unknownInstMap[value] = b.createOperation(state);
556     return unknownInstMap[value]->getResult(0);
557   }
558 
559   if (auto *c = dyn_cast<llvm::Constant>(value))
560     return processConstant(c);
561 
562   emitError(unknownLoc) << "unhandled value: " << diag(*value);
563   return nullptr;
564 }
565 
566 /// Return the MLIR OperationName for the given LLVM opcode.
567 static StringRef lookupOperationNameFromOpcode(unsigned opcode) {
568 // Maps from LLVM opcode to MLIR OperationName. This is deliberately ordered
569 // as in llvm/IR/Instructions.def to aid comprehension and spot missing
570 // instructions.
571 #define INST(llvm_n, mlir_n)                                                   \
572   { llvm::Instruction::llvm_n, LLVM::mlir_n##Op::getOperationName() }
573   static const DenseMap<unsigned, StringRef> opcMap = {
574       // Ret is handled specially.
575       // Br is handled specially.
576       // FIXME: switch
577       // FIXME: indirectbr
578       // FIXME: invoke
579       INST(Resume, Resume),
580       // FIXME: unreachable
581       // FIXME: cleanupret
582       // FIXME: catchret
583       // FIXME: catchswitch
584       // FIXME: callbr
585       // FIXME: fneg
586       INST(Add, Add), INST(FAdd, FAdd), INST(Sub, Sub), INST(FSub, FSub),
587       INST(Mul, Mul), INST(FMul, FMul), INST(UDiv, UDiv), INST(SDiv, SDiv),
588       INST(FDiv, FDiv), INST(URem, URem), INST(SRem, SRem), INST(FRem, FRem),
589       INST(Shl, Shl), INST(LShr, LShr), INST(AShr, AShr), INST(And, And),
590       INST(Or, Or), INST(Xor, XOr), INST(Alloca, Alloca), INST(Load, Load),
591       INST(Store, Store),
592       // Getelementptr is handled specially.
593       INST(Ret, Return), INST(Fence, Fence),
594       // FIXME: atomiccmpxchg
595       // FIXME: atomicrmw
596       INST(Trunc, Trunc), INST(ZExt, ZExt), INST(SExt, SExt),
597       INST(FPToUI, FPToUI), INST(FPToSI, FPToSI), INST(UIToFP, UIToFP),
598       INST(SIToFP, SIToFP), INST(FPTrunc, FPTrunc), INST(FPExt, FPExt),
599       INST(PtrToInt, PtrToInt), INST(IntToPtr, IntToPtr),
600       INST(BitCast, Bitcast), INST(AddrSpaceCast, AddrSpaceCast),
601       // FIXME: cleanuppad
602       // FIXME: catchpad
603       // ICmp is handled specially.
604       // FIXME: fcmp
605       // PHI is handled specially.
606       INST(Freeze, Freeze), INST(Call, Call),
607       // FIXME: select
608       // FIXME: vaarg
609       // FIXME: extractelement
610       // FIXME: insertelement
611       // FIXME: shufflevector
612       // FIXME: extractvalue
613       // FIXME: insertvalue
614       // FIXME: landingpad
615   };
616 #undef INST
617 
618   return opcMap.lookup(opcode);
619 }
620 
621 static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) {
622   switch (p) {
623   default:
624     llvm_unreachable("incorrect comparison predicate");
625   case llvm::CmpInst::Predicate::ICMP_EQ:
626     return LLVM::ICmpPredicate::eq;
627   case llvm::CmpInst::Predicate::ICMP_NE:
628     return LLVM::ICmpPredicate::ne;
629   case llvm::CmpInst::Predicate::ICMP_SLT:
630     return LLVM::ICmpPredicate::slt;
631   case llvm::CmpInst::Predicate::ICMP_SLE:
632     return LLVM::ICmpPredicate::sle;
633   case llvm::CmpInst::Predicate::ICMP_SGT:
634     return LLVM::ICmpPredicate::sgt;
635   case llvm::CmpInst::Predicate::ICMP_SGE:
636     return LLVM::ICmpPredicate::sge;
637   case llvm::CmpInst::Predicate::ICMP_ULT:
638     return LLVM::ICmpPredicate::ult;
639   case llvm::CmpInst::Predicate::ICMP_ULE:
640     return LLVM::ICmpPredicate::ule;
641   case llvm::CmpInst::Predicate::ICMP_UGT:
642     return LLVM::ICmpPredicate::ugt;
643   case llvm::CmpInst::Predicate::ICMP_UGE:
644     return LLVM::ICmpPredicate::uge;
645   }
646   llvm_unreachable("incorrect comparison predicate");
647 }
648 
649 static AtomicOrdering getLLVMAtomicOrdering(llvm::AtomicOrdering ordering) {
650   switch (ordering) {
651   case llvm::AtomicOrdering::NotAtomic:
652     return LLVM::AtomicOrdering::not_atomic;
653   case llvm::AtomicOrdering::Unordered:
654     return LLVM::AtomicOrdering::unordered;
655   case llvm::AtomicOrdering::Monotonic:
656     return LLVM::AtomicOrdering::monotonic;
657   case llvm::AtomicOrdering::Acquire:
658     return LLVM::AtomicOrdering::acquire;
659   case llvm::AtomicOrdering::Release:
660     return LLVM::AtomicOrdering::release;
661   case llvm::AtomicOrdering::AcquireRelease:
662     return LLVM::AtomicOrdering::acq_rel;
663   case llvm::AtomicOrdering::SequentiallyConsistent:
664     return LLVM::AtomicOrdering::seq_cst;
665   }
666   llvm_unreachable("incorrect atomic ordering");
667 }
668 
669 // `br` branches to `target`. Return the branch arguments to `br`, in the
670 // same order of the PHIs in `target`.
671 LogicalResult
672 Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target,
673                             SmallVectorImpl<Value> &blockArguments) {
674   for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) {
675     auto *PN = cast<llvm::PHINode>(&*inst);
676     Value value = processValue(PN->getIncomingValueForBlock(br->getParent()));
677     if (!value)
678       return failure();
679     blockArguments.push_back(value);
680   }
681   return success();
682 }
683 
684 LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
685   // FIXME: Support uses of SubtargetData. Currently inbounds GEPs, fast-math
686   // flags and call / operand attributes are not supported.
687   Location loc = processDebugLoc(inst->getDebugLoc(), inst);
688   Value &v = instMap[inst];
689   assert(!v && "processInstruction must be called only once per instruction!");
690   switch (inst->getOpcode()) {
691   default:
692     return emitError(loc) << "unknown instruction: " << diag(*inst);
693   case llvm::Instruction::Add:
694   case llvm::Instruction::FAdd:
695   case llvm::Instruction::Sub:
696   case llvm::Instruction::FSub:
697   case llvm::Instruction::Mul:
698   case llvm::Instruction::FMul:
699   case llvm::Instruction::UDiv:
700   case llvm::Instruction::SDiv:
701   case llvm::Instruction::FDiv:
702   case llvm::Instruction::URem:
703   case llvm::Instruction::SRem:
704   case llvm::Instruction::FRem:
705   case llvm::Instruction::Shl:
706   case llvm::Instruction::LShr:
707   case llvm::Instruction::AShr:
708   case llvm::Instruction::And:
709   case llvm::Instruction::Or:
710   case llvm::Instruction::Xor:
711   case llvm::Instruction::Alloca:
712   case llvm::Instruction::Load:
713   case llvm::Instruction::Store:
714   case llvm::Instruction::Ret:
715   case llvm::Instruction::Resume:
716   case llvm::Instruction::Trunc:
717   case llvm::Instruction::ZExt:
718   case llvm::Instruction::SExt:
719   case llvm::Instruction::FPToUI:
720   case llvm::Instruction::FPToSI:
721   case llvm::Instruction::UIToFP:
722   case llvm::Instruction::SIToFP:
723   case llvm::Instruction::FPTrunc:
724   case llvm::Instruction::FPExt:
725   case llvm::Instruction::PtrToInt:
726   case llvm::Instruction::IntToPtr:
727   case llvm::Instruction::AddrSpaceCast:
728   case llvm::Instruction::Freeze:
729   case llvm::Instruction::BitCast: {
730     OperationState state(loc, lookupOperationNameFromOpcode(inst->getOpcode()));
731     SmallVector<Value, 4> ops;
732     ops.reserve(inst->getNumOperands());
733     for (auto *op : inst->operand_values()) {
734       Value value = processValue(op);
735       if (!value)
736         return failure();
737       ops.push_back(value);
738     }
739     state.addOperands(ops);
740     if (!inst->getType()->isVoidTy()) {
741       Type type = processType(inst->getType());
742       if (!type)
743         return failure();
744       state.addTypes(type);
745     }
746     Operation *op = b.createOperation(state);
747     if (!inst->getType()->isVoidTy())
748       v = op->getResult(0);
749     return success();
750   }
751   case llvm::Instruction::ICmp: {
752     Value lhs = processValue(inst->getOperand(0));
753     Value rhs = processValue(inst->getOperand(1));
754     if (!lhs || !rhs)
755       return failure();
756     v = b.create<ICmpOp>(
757         loc, getICmpPredicate(cast<llvm::ICmpInst>(inst)->getPredicate()), lhs,
758         rhs);
759     return success();
760   }
761   case llvm::Instruction::Br: {
762     auto *brInst = cast<llvm::BranchInst>(inst);
763     OperationState state(loc,
764                          brInst->isConditional() ? "llvm.cond_br" : "llvm.br");
765     if (brInst->isConditional()) {
766       Value condition = processValue(brInst->getCondition());
767       if (!condition)
768         return failure();
769       state.addOperands(condition);
770     }
771 
772     std::array<int32_t, 3> operandSegmentSizes = {1, 0, 0};
773     for (int i : llvm::seq<int>(0, brInst->getNumSuccessors())) {
774       auto *succ = brInst->getSuccessor(i);
775       SmallVector<Value, 4> blockArguments;
776       if (failed(processBranchArgs(brInst, succ, blockArguments)))
777         return failure();
778       state.addSuccessors(blocks[succ]);
779       state.addOperands(blockArguments);
780       operandSegmentSizes[i + 1] = blockArguments.size();
781     }
782 
783     if (brInst->isConditional()) {
784       state.addAttribute(LLVM::CondBrOp::getOperandSegmentSizeAttr(),
785                          b.getI32VectorAttr(operandSegmentSizes));
786     }
787 
788     b.createOperation(state);
789     return success();
790   }
791   case llvm::Instruction::PHI: {
792     Type type = processType(inst->getType());
793     if (!type)
794       return failure();
795     v = b.getInsertionBlock()->addArgument(type);
796     return success();
797   }
798   case llvm::Instruction::Call: {
799     llvm::CallInst *ci = cast<llvm::CallInst>(inst);
800     SmallVector<Value, 4> ops;
801     ops.reserve(inst->getNumOperands());
802     for (auto &op : ci->arg_operands()) {
803       Value arg = processValue(op.get());
804       if (!arg)
805         return failure();
806       ops.push_back(arg);
807     }
808 
809     SmallVector<Type, 2> tys;
810     if (!ci->getType()->isVoidTy()) {
811       Type type = processType(inst->getType());
812       if (!type)
813         return failure();
814       tys.push_back(type);
815     }
816     Operation *op;
817     if (llvm::Function *callee = ci->getCalledFunction()) {
818       op = b.create<CallOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
819                             ops);
820     } else {
821       Value calledValue = processValue(ci->getCalledOperand());
822       if (!calledValue)
823         return failure();
824       ops.insert(ops.begin(), calledValue);
825       op = b.create<CallOp>(loc, tys, ops);
826     }
827     if (!ci->getType()->isVoidTy())
828       v = op->getResult(0);
829     return success();
830   }
831   case llvm::Instruction::LandingPad: {
832     llvm::LandingPadInst *lpi = cast<llvm::LandingPadInst>(inst);
833     SmallVector<Value, 4> ops;
834 
835     for (unsigned i = 0, ie = lpi->getNumClauses(); i < ie; i++)
836       ops.push_back(processConstant(lpi->getClause(i)));
837 
838     Type ty = processType(lpi->getType());
839     if (!ty)
840       return failure();
841 
842     v = b.create<LandingpadOp>(loc, ty, lpi->isCleanup(), ops);
843     return success();
844   }
845   case llvm::Instruction::Invoke: {
846     llvm::InvokeInst *ii = cast<llvm::InvokeInst>(inst);
847 
848     SmallVector<Type, 2> tys;
849     if (!ii->getType()->isVoidTy())
850       tys.push_back(processType(inst->getType()));
851 
852     SmallVector<Value, 4> ops;
853     ops.reserve(inst->getNumOperands() + 1);
854     for (auto &op : ii->arg_operands())
855       ops.push_back(processValue(op.get()));
856 
857     SmallVector<Value, 4> normalArgs, unwindArgs;
858     (void)processBranchArgs(ii, ii->getNormalDest(), normalArgs);
859     (void)processBranchArgs(ii, ii->getUnwindDest(), unwindArgs);
860 
861     Operation *op;
862     if (llvm::Function *callee = ii->getCalledFunction()) {
863       op = b.create<InvokeOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
864                               ops, blocks[ii->getNormalDest()], normalArgs,
865                               blocks[ii->getUnwindDest()], unwindArgs);
866     } else {
867       ops.insert(ops.begin(), processValue(ii->getCalledOperand()));
868       op = b.create<InvokeOp>(loc, tys, ops, blocks[ii->getNormalDest()],
869                               normalArgs, blocks[ii->getUnwindDest()],
870                               unwindArgs);
871     }
872 
873     if (!ii->getType()->isVoidTy())
874       v = op->getResult(0);
875     return success();
876   }
877   case llvm::Instruction::Fence: {
878     StringRef syncscope;
879     SmallVector<StringRef, 4> ssNs;
880     llvm::LLVMContext &llvmContext = inst->getContext();
881     llvm::FenceInst *fence = cast<llvm::FenceInst>(inst);
882     llvmContext.getSyncScopeNames(ssNs);
883     int fenceSyncScopeID = fence->getSyncScopeID();
884     for (unsigned i = 0, e = ssNs.size(); i != e; i++) {
885       if (fenceSyncScopeID == llvmContext.getOrInsertSyncScopeID(ssNs[i])) {
886         syncscope = ssNs[i];
887         break;
888       }
889     }
890     b.create<FenceOp>(loc, getLLVMAtomicOrdering(fence->getOrdering()),
891                       syncscope);
892     return success();
893   }
894   case llvm::Instruction::GetElementPtr: {
895     // FIXME: Support inbounds GEPs.
896     llvm::GetElementPtrInst *gep = cast<llvm::GetElementPtrInst>(inst);
897     SmallVector<Value, 4> ops;
898     for (auto *op : gep->operand_values()) {
899       Value value = processValue(op);
900       if (!value)
901         return failure();
902       ops.push_back(value);
903     }
904     Type type = processType(inst->getType());
905     if (!type)
906       return failure();
907     v = b.create<GEPOp>(loc, type, ops);
908     return success();
909   }
910   }
911 }
912 
913 FlatSymbolRefAttr Importer::getPersonalityAsAttr(llvm::Function *f) {
914   if (!f->hasPersonalityFn())
915     return nullptr;
916 
917   llvm::Constant *pf = f->getPersonalityFn();
918 
919   // If it directly has a name, we can use it.
920   if (pf->hasName())
921     return b.getSymbolRefAttr(pf->getName());
922 
923   // If it doesn't have a name, currently, only function pointers that are
924   // bitcast to i8* are parsed.
925   if (auto ce = dyn_cast<llvm::ConstantExpr>(pf)) {
926     if (ce->getOpcode() == llvm::Instruction::BitCast &&
927         ce->getType() == llvm::Type::getInt8PtrTy(f->getContext())) {
928       if (auto func = dyn_cast<llvm::Function>(ce->getOperand(0)))
929         return b.getSymbolRefAttr(func->getName());
930     }
931   }
932   return FlatSymbolRefAttr();
933 }
934 
935 LogicalResult Importer::processFunction(llvm::Function *f) {
936   blocks.clear();
937   instMap.clear();
938   unknownInstMap.clear();
939 
940   auto functionType =
941       processType(f->getFunctionType()).dyn_cast<LLVMFunctionType>();
942   if (!functionType)
943     return failure();
944 
945   b.setInsertionPoint(module.getBody(), getFuncInsertPt());
946   LLVMFuncOp fop =
947       b.create<LLVMFuncOp>(UnknownLoc::get(context), f->getName(), functionType,
948                            convertLinkageFromLLVM(f->getLinkage()));
949 
950   if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f))
951     fop->setAttr(b.getIdentifier("personality"), personality);
952   else if (f->hasPersonalityFn())
953     emitWarning(UnknownLoc::get(context),
954                 "could not deduce personality, skipping it");
955 
956   if (f->isDeclaration())
957     return success();
958 
959   // Eagerly create all blocks.
960   SmallVector<Block *, 4> blockList;
961   for (llvm::BasicBlock &bb : *f) {
962     blockList.push_back(b.createBlock(&fop.body(), fop.body().end()));
963     blocks[&bb] = blockList.back();
964   }
965   currentEntryBlock = blockList[0];
966 
967   // Add function arguments to the entry block.
968   for (auto kv : llvm::enumerate(f->args()))
969     instMap[&kv.value()] =
970         blockList[0]->addArgument(functionType.getParamType(kv.index()));
971 
972   for (auto bbs : llvm::zip(*f, blockList)) {
973     if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs))))
974       return failure();
975   }
976 
977   // Now that all instructions are guaranteed to have been visited, ensure
978   // any unknown uses we encountered are remapped.
979   for (auto &llvmAndUnknown : unknownInstMap) {
980     assert(instMap.count(llvmAndUnknown.first));
981     Value newValue = instMap[llvmAndUnknown.first];
982     Value oldValue = llvmAndUnknown.second->getResult(0);
983     oldValue.replaceAllUsesWith(newValue);
984     llvmAndUnknown.second->erase();
985   }
986   return success();
987 }
988 
989 LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) {
990   b.setInsertionPointToStart(block);
991   for (llvm::Instruction &inst : *bb) {
992     if (failed(processInstruction(&inst)))
993       return failure();
994   }
995   return success();
996 }
997 
998 OwningModuleRef
999 mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
1000                               MLIRContext *context) {
1001   context->loadDialect<LLVMDialect>();
1002   OwningModuleRef module(ModuleOp::create(
1003       FileLineColLoc::get(context, "", /*line=*/0, /*column=*/0)));
1004 
1005   Importer deserializer(context, module.get());
1006   for (llvm::GlobalVariable &gv : llvmModule->globals()) {
1007     if (!deserializer.processGlobal(&gv))
1008       return {};
1009   }
1010   for (llvm::Function &f : llvmModule->functions()) {
1011     if (failed(deserializer.processFunction(&f)))
1012       return {};
1013   }
1014 
1015   return module;
1016 }
1017 
1018 // Deserializes the LLVM bitcode stored in `input` into an MLIR module in the
1019 // LLVM dialect.
1020 OwningModuleRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr,
1021                                         MLIRContext *context) {
1022   llvm::SMDiagnostic err;
1023   llvm::LLVMContext llvmContext;
1024   std::unique_ptr<llvm::Module> llvmModule = llvm::parseIR(
1025       *sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), err, llvmContext);
1026   if (!llvmModule) {
1027     std::string errStr;
1028     llvm::raw_string_ostream errStream(errStr);
1029     err.print(/*ProgName=*/"", errStream);
1030     emitError(UnknownLoc::get(context)) << errStream.str();
1031     return {};
1032   }
1033   return translateLLVMIRToModule(std::move(llvmModule), context);
1034 }
1035 
1036 namespace mlir {
1037 void registerFromLLVMIRTranslation() {
1038   TranslateToMLIRRegistration fromLLVM(
1039       "import-llvm", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
1040         return ::translateLLVMIRToModule(sourceMgr, context);
1041       });
1042 }
1043 } // namespace mlir
1044