//===- ConvertFromLLVMIR.cpp - LLVM IR to MLIR conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between LLVM IR and the MLIR LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/MLIRContext.h"
18 #include "mlir/Target/LLVMIR/Import.h"
19 #include "mlir/Target/LLVMIR/TypeFromLLVM.h"
20 #include "mlir/Translation.h"
21 
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/IR/Attributes.h"
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/InlineAsm.h"
28 #include "llvm/IR/Instructions.h"
29 #include "llvm/IR/Type.h"
30 #include "llvm/IRReader/IRReader.h"
31 #include "llvm/Support/Error.h"
32 #include "llvm/Support/SourceMgr.h"
33 
34 using namespace mlir;
35 using namespace mlir::LLVM;
36 
37 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
38 
39 // Utility to print an LLVM value as a string for passing to emitError().
40 // FIXME: Diagnostic should be able to natively handle types that have
41 // operator << (raw_ostream&) defined.
42 static std::string diag(llvm::Value &v) {
43   std::string s;
44   llvm::raw_string_ostream os(s);
45   os << v;
46   return os.str();
47 }
48 
49 // Handles importing globals and functions from an LLVM module.
50 namespace {
51 class Importer {
52 public:
53   Importer(MLIRContext *context, ModuleOp module)
54       : b(context), context(context), module(module),
55         unknownLoc(FileLineColLoc::get(context, "imported-bitcode", 0, 0)),
56         typeTranslator(*context) {
57     b.setInsertionPointToStart(module.getBody());
58   }
59 
60   /// Imports `f` into the current module.
61   LogicalResult processFunction(llvm::Function *f);
62 
63   /// Imports GV as a GlobalOp, creating it if it doesn't exist.
64   GlobalOp processGlobal(llvm::GlobalVariable *GV);
65 
66 private:
67   /// Returns personality of `f` as a FlatSymbolRefAttr.
68   FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *f);
69   /// Imports `bb` into `block`, which must be initially empty.
70   LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block);
71   /// Imports `inst` and populates instMap[inst] with the imported Value.
72   LogicalResult processInstruction(llvm::Instruction *inst);
73   /// Creates an LLVM-compatible MLIR type for `type`.
74   Type processType(llvm::Type *type);
75   /// `value` is an SSA-use. Return the remapped version of `value` or a
76   /// placeholder that will be remapped later if this is an instruction that
77   /// has not yet been visited.
78   Value processValue(llvm::Value *value);
79   /// Create the most accurate Location possible using a llvm::DebugLoc and
80   /// possibly an llvm::Instruction to narrow the Location if debug information
81   /// is unavailable.
82   Location processDebugLoc(const llvm::DebugLoc &loc,
83                            llvm::Instruction *inst = nullptr);
84   /// `br` branches to `target`. Append the block arguments to attach to the
85   /// generated branch op to `blockArguments`. These should be in the same order
86   /// as the PHIs in `target`.
87   LogicalResult processBranchArgs(llvm::Instruction *br,
88                                   llvm::BasicBlock *target,
89                                   SmallVectorImpl<Value> &blockArguments);
90   /// Returns the builtin type equivalent to be used in attributes for the given
91   /// LLVM IR dialect type.
92   Type getStdTypeForAttr(Type type);
93   /// Return `value` as an attribute to attach to a GlobalOp.
94   Attribute getConstantAsAttr(llvm::Constant *value);
95   /// Return `c` as an MLIR Value. This could either be a ConstantOp, or
96   /// an expanded sequence of ops in the current function's entry block (for
97   /// ConstantExprs or ConstantGEPs).
98   Value processConstant(llvm::Constant *c);
99 
100   /// The current builder, pointing at where the next Instruction should be
101   /// generated.
102   OpBuilder b;
103   /// The current context.
104   MLIRContext *context;
105   /// The current module being created.
106   ModuleOp module;
107   /// The entry block of the current function being processed.
108   Block *currentEntryBlock;
109 
110   /// Globals are inserted before the first function, if any.
111   Block::iterator getGlobalInsertPt() {
112     auto it = module.getBody()->begin();
113     auto endIt = module.getBody()->end();
114     while (it != endIt && !isa<LLVMFuncOp>(it))
115       ++it;
116     return it;
117   }
118 
119   /// Functions are always inserted before the module terminator.
120   Block::iterator getFuncInsertPt() {
121     return std::prev(module.getBody()->end());
122   }
123 
124   /// Remapped blocks, for the current function.
125   DenseMap<llvm::BasicBlock *, Block *> blocks;
126   /// Remapped values. These are function-local.
127   DenseMap<llvm::Value *, Value> instMap;
128   /// Instructions that had not been defined when first encountered as a use.
129   /// Maps to the dummy Operation that was created in processValue().
130   DenseMap<llvm::Value *, Operation *> unknownInstMap;
131   /// Uniquing map of GlobalVariables.
132   DenseMap<llvm::GlobalVariable *, GlobalOp> globals;
133   /// Cached FileLineColLoc::get("imported-bitcode", 0, 0).
134   Location unknownLoc;
135   /// The stateful type translator (contains named structs).
136   LLVM::TypeFromLLVMIRTranslator typeTranslator;
137 };
138 } // namespace
139 
140 Location Importer::processDebugLoc(const llvm::DebugLoc &loc,
141                                    llvm::Instruction *inst) {
142   if (!loc && inst) {
143     std::string s;
144     llvm::raw_string_ostream os(s);
145     os << "llvm-imported-inst-%";
146     inst->printAsOperand(os, /*PrintType=*/false);
147     return FileLineColLoc::get(context, os.str(), 0, 0);
148   } else if (!loc) {
149     return unknownLoc;
150   }
151   // FIXME: Obtain the filename from DILocationInfo.
152   return FileLineColLoc::get(context, "imported-bitcode", loc.getLine(),
153                              loc.getCol());
154 }
155 
156 Type Importer::processType(llvm::Type *type) {
157   if (Type result = typeTranslator.translateType(type))
158     return result;
159 
160   // FIXME: Diagnostic should be able to natively handle types that have
161   // operator<<(raw_ostream&) defined.
162   std::string s;
163   llvm::raw_string_ostream os(s);
164   os << *type;
165   emitError(unknownLoc) << "unhandled type: " << os.str();
166   return nullptr;
167 }
168 
169 // We only need integers, floats, doubles, and vectors and tensors thereof for
170 // attributes. Scalar and vector types are converted to the standard
171 // equivalents. Array types are converted to ranked tensors; nested array types
172 // are converted to multi-dimensional tensors or vectors, depending on the
173 // innermost type being a scalar or a vector.
174 Type Importer::getStdTypeForAttr(Type type) {
175   if (!type)
176     return nullptr;
177 
178   if (type.isa<IntegerType, FloatType>())
179     return type;
180 
181   // LLVM vectors can only contain scalars.
182   if (LLVM::isCompatibleVectorType(type)) {
183     auto numElements = LLVM::getVectorNumElements(type);
184     if (numElements.isScalable()) {
185       emitError(unknownLoc) << "scalable vectors not supported";
186       return nullptr;
187     }
188     Type elementType = getStdTypeForAttr(LLVM::getVectorElementType(type));
189     if (!elementType)
190       return nullptr;
191     return VectorType::get(numElements.getKnownMinValue(), elementType);
192   }
193 
194   // LLVM arrays can contain other arrays or vectors.
195   if (auto arrayType = type.dyn_cast<LLVMArrayType>()) {
196     // Recover the nested array shape.
197     SmallVector<int64_t, 4> shape;
198     shape.push_back(arrayType.getNumElements());
199     while (arrayType.getElementType().isa<LLVMArrayType>()) {
200       arrayType = arrayType.getElementType().cast<LLVMArrayType>();
201       shape.push_back(arrayType.getNumElements());
202     }
203 
204     // If the innermost type is a vector, use the multi-dimensional vector as
205     // attribute type.
206     if (LLVM::isCompatibleVectorType(arrayType.getElementType())) {
207       auto numElements = LLVM::getVectorNumElements(arrayType.getElementType());
208       if (numElements.isScalable()) {
209         emitError(unknownLoc) << "scalable vectors not supported";
210         return nullptr;
211       }
212       shape.push_back(numElements.getKnownMinValue());
213 
214       Type elementType = getStdTypeForAttr(
215           LLVM::getVectorElementType(arrayType.getElementType()));
216       if (!elementType)
217         return nullptr;
218       return VectorType::get(shape, elementType);
219     }
220 
221     // Otherwise use a tensor.
222     Type elementType = getStdTypeForAttr(arrayType.getElementType());
223     if (!elementType)
224       return nullptr;
225     return RankedTensorType::get(shape, elementType);
226   }
227 
228   return nullptr;
229 }
230 
231 // Get the given constant as an attribute. Not all constants can be represented
232 // as attributes.
233 Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
234   if (auto *ci = dyn_cast<llvm::ConstantInt>(value))
235     return b.getIntegerAttr(
236         IntegerType::get(context, ci->getType()->getBitWidth()),
237         ci->getValue());
238   if (auto *c = dyn_cast<llvm::ConstantDataArray>(value))
239     if (c->isString())
240       return b.getStringAttr(c->getAsString());
241   if (auto *c = dyn_cast<llvm::ConstantFP>(value)) {
242     if (c->getType()->isDoubleTy())
243       return b.getFloatAttr(FloatType::getF64(context), c->getValueAPF());
244     if (c->getType()->isFloatingPointTy())
245       return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF());
246   }
247   if (auto *f = dyn_cast<llvm::Function>(value))
248     return SymbolRefAttr::get(b.getContext(), f->getName());
249 
250   // Convert constant data to a dense elements attribute.
251   if (auto *cd = dyn_cast<llvm::ConstantDataSequential>(value)) {
252     Type type = processType(cd->getElementType());
253     if (!type)
254       return nullptr;
255 
256     auto attrType = getStdTypeForAttr(processType(cd->getType()))
257                         .dyn_cast_or_null<ShapedType>();
258     if (!attrType)
259       return nullptr;
260 
261     if (type.isa<IntegerType>()) {
262       SmallVector<APInt, 8> values;
263       values.reserve(cd->getNumElements());
264       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
265         values.push_back(cd->getElementAsAPInt(i));
266       return DenseElementsAttr::get(attrType, values);
267     }
268 
269     if (type.isa<Float32Type, Float64Type>()) {
270       SmallVector<APFloat, 8> values;
271       values.reserve(cd->getNumElements());
272       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
273         values.push_back(cd->getElementAsAPFloat(i));
274       return DenseElementsAttr::get(attrType, values);
275     }
276 
277     return nullptr;
278   }
279 
280   // Unpack constant aggregates to create dense elements attribute whenever
281   // possible. Return nullptr (failure) otherwise.
282   if (isa<llvm::ConstantAggregate>(value)) {
283     auto outerType = getStdTypeForAttr(processType(value->getType()))
284                          .dyn_cast_or_null<ShapedType>();
285     if (!outerType)
286       return nullptr;
287 
288     SmallVector<Attribute, 8> values;
289     SmallVector<int64_t, 8> shape;
290 
291     for (unsigned i = 0, e = value->getNumOperands(); i < e; ++i) {
292       auto nested = getConstantAsAttr(value->getAggregateElement(i))
293                         .dyn_cast_or_null<DenseElementsAttr>();
294       if (!nested)
295         return nullptr;
296 
297       values.append(nested.attr_value_begin(), nested.attr_value_end());
298     }
299 
300     return DenseElementsAttr::get(outerType, values);
301   }
302 
303   return nullptr;
304 }
305 
306 GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
307   auto it = globals.find(GV);
308   if (it != globals.end())
309     return it->second;
310 
311   OpBuilder b(module.getBody(), getGlobalInsertPt());
312   Attribute valueAttr;
313   if (GV->hasInitializer())
314     valueAttr = getConstantAsAttr(GV->getInitializer());
315   Type type = processType(GV->getValueType());
316   if (!type)
317     return nullptr;
318 
319   uint64_t alignment = 0;
320   llvm::MaybeAlign maybeAlign = GV->getAlign();
321   if (maybeAlign.hasValue()) {
322     llvm::Align align = maybeAlign.getValue();
323     alignment = align.value();
324   }
325 
326   GlobalOp op =
327       b.create<GlobalOp>(UnknownLoc::get(context), type, GV->isConstant(),
328                          convertLinkageFromLLVM(GV->getLinkage()),
329                          GV->getName(), valueAttr, alignment);
330 
331   if (GV->hasInitializer() && !valueAttr) {
332     Region &r = op.getInitializerRegion();
333     currentEntryBlock = b.createBlock(&r);
334     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
335     Value v = processConstant(GV->getInitializer());
336     if (!v)
337       return nullptr;
338     b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v}));
339   }
340   if (GV->hasAtLeastLocalUnnamedAddr())
341     op.unnamed_addrAttr(UnnamedAddrAttr::get(
342         context, convertUnnamedAddrFromLLVM(GV->getUnnamedAddr())));
343   if (GV->hasSection())
344     op.sectionAttr(b.getStringAttr(GV->getSection()));
345 
346   return globals[GV] = op;
347 }
348 
349 Value Importer::processConstant(llvm::Constant *c) {
350   OpBuilder bEntry(currentEntryBlock, currentEntryBlock->begin());
351   if (Attribute attr = getConstantAsAttr(c)) {
352     // These constants can be represented as attributes.
353     OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
354     Type type = processType(c->getType());
355     if (!type)
356       return nullptr;
357     if (auto symbolRef = attr.dyn_cast<FlatSymbolRefAttr>())
358       return instMap[c] = bEntry.create<AddressOfOp>(unknownLoc, type,
359                                                      symbolRef.getValue());
360     return instMap[c] = bEntry.create<ConstantOp>(unknownLoc, type, attr);
361   }
362   if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) {
363     Type type = processType(cn->getType());
364     if (!type)
365       return nullptr;
366     return instMap[c] = bEntry.create<NullOp>(unknownLoc, type);
367   }
368   if (auto *GV = dyn_cast<llvm::GlobalVariable>(c))
369     return bEntry.create<AddressOfOp>(UnknownLoc::get(context),
370                                       processGlobal(GV));
371 
372   if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) {
373     llvm::Instruction *i = ce->getAsInstruction();
374     OpBuilder::InsertionGuard guard(b);
375     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
376     if (failed(processInstruction(i)))
377       return nullptr;
378     assert(instMap.count(i));
379 
380     // Remove this zombie LLVM instruction now, leaving us only with the MLIR
381     // op.
382     i->deleteValue();
383     return instMap[c] = instMap[i];
384   }
385   if (auto *ue = dyn_cast<llvm::UndefValue>(c)) {
386     Type type = processType(ue->getType());
387     if (!type)
388       return nullptr;
389     return instMap[c] = bEntry.create<UndefOp>(UnknownLoc::get(context), type);
390   }
391   emitError(unknownLoc) << "unhandled constant: " << diag(*c);
392   return nullptr;
393 }
394 
395 Value Importer::processValue(llvm::Value *value) {
396   auto it = instMap.find(value);
397   if (it != instMap.end())
398     return it->second;
399 
400   // We don't expect to see instructions in dominator order. If we haven't seen
401   // this instruction yet, create an unknown op and remap it later.
402   if (isa<llvm::Instruction>(value)) {
403     OperationState state(UnknownLoc::get(context), "llvm.unknown");
404     Type type = processType(value->getType());
405     if (!type)
406       return nullptr;
407     state.addTypes(type);
408     unknownInstMap[value] = b.createOperation(state);
409     return unknownInstMap[value]->getResult(0);
410   }
411 
412   if (auto *c = dyn_cast<llvm::Constant>(value))
413     return processConstant(c);
414 
415   emitError(unknownLoc) << "unhandled value: " << diag(*value);
416   return nullptr;
417 }
418 
419 /// Return the MLIR OperationName for the given LLVM opcode.
420 static StringRef lookupOperationNameFromOpcode(unsigned opcode) {
421 // Maps from LLVM opcode to MLIR OperationName. This is deliberately ordered
422 // as in llvm/IR/Instructions.def to aid comprehension and spot missing
423 // instructions.
424 #define INST(llvm_n, mlir_n)                                                   \
425   { llvm::Instruction::llvm_n, LLVM::mlir_n##Op::getOperationName() }
426   static const DenseMap<unsigned, StringRef> opcMap = {
427       // Ret is handled specially.
428       // Br is handled specially.
429       // FIXME: switch
430       // FIXME: indirectbr
431       // FIXME: invoke
432       INST(Resume, Resume),
433       // FIXME: unreachable
434       // FIXME: cleanupret
435       // FIXME: catchret
436       // FIXME: catchswitch
437       // FIXME: callbr
438       // FIXME: fneg
439       INST(Add, Add), INST(FAdd, FAdd), INST(Sub, Sub), INST(FSub, FSub),
440       INST(Mul, Mul), INST(FMul, FMul), INST(UDiv, UDiv), INST(SDiv, SDiv),
441       INST(FDiv, FDiv), INST(URem, URem), INST(SRem, SRem), INST(FRem, FRem),
442       INST(Shl, Shl), INST(LShr, LShr), INST(AShr, AShr), INST(And, And),
443       INST(Or, Or), INST(Xor, XOr), INST(Alloca, Alloca), INST(Load, Load),
444       INST(Store, Store),
445       // Getelementptr is handled specially.
446       INST(Ret, Return), INST(Fence, Fence),
447       // FIXME: atomiccmpxchg
448       // FIXME: atomicrmw
449       INST(Trunc, Trunc), INST(ZExt, ZExt), INST(SExt, SExt),
450       INST(FPToUI, FPToUI), INST(FPToSI, FPToSI), INST(UIToFP, UIToFP),
451       INST(SIToFP, SIToFP), INST(FPTrunc, FPTrunc), INST(FPExt, FPExt),
452       INST(PtrToInt, PtrToInt), INST(IntToPtr, IntToPtr),
453       INST(BitCast, Bitcast), INST(AddrSpaceCast, AddrSpaceCast),
454       // FIXME: cleanuppad
455       // FIXME: catchpad
456       // ICmp is handled specially.
457       // FIXME: fcmp
458       // PHI is handled specially.
459       INST(Freeze, Freeze), INST(Call, Call),
460       // FIXME: select
461       // FIXME: vaarg
462       // FIXME: extractelement
463       // FIXME: insertelement
464       // FIXME: shufflevector
465       // FIXME: extractvalue
466       // FIXME: insertvalue
467       // FIXME: landingpad
468   };
469 #undef INST
470 
471   return opcMap.lookup(opcode);
472 }
473 
474 static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) {
475   switch (p) {
476   default:
477     llvm_unreachable("incorrect comparison predicate");
478   case llvm::CmpInst::Predicate::ICMP_EQ:
479     return LLVM::ICmpPredicate::eq;
480   case llvm::CmpInst::Predicate::ICMP_NE:
481     return LLVM::ICmpPredicate::ne;
482   case llvm::CmpInst::Predicate::ICMP_SLT:
483     return LLVM::ICmpPredicate::slt;
484   case llvm::CmpInst::Predicate::ICMP_SLE:
485     return LLVM::ICmpPredicate::sle;
486   case llvm::CmpInst::Predicate::ICMP_SGT:
487     return LLVM::ICmpPredicate::sgt;
488   case llvm::CmpInst::Predicate::ICMP_SGE:
489     return LLVM::ICmpPredicate::sge;
490   case llvm::CmpInst::Predicate::ICMP_ULT:
491     return LLVM::ICmpPredicate::ult;
492   case llvm::CmpInst::Predicate::ICMP_ULE:
493     return LLVM::ICmpPredicate::ule;
494   case llvm::CmpInst::Predicate::ICMP_UGT:
495     return LLVM::ICmpPredicate::ugt;
496   case llvm::CmpInst::Predicate::ICMP_UGE:
497     return LLVM::ICmpPredicate::uge;
498   }
499   llvm_unreachable("incorrect comparison predicate");
500 }
501 
502 static AtomicOrdering getLLVMAtomicOrdering(llvm::AtomicOrdering ordering) {
503   switch (ordering) {
504   case llvm::AtomicOrdering::NotAtomic:
505     return LLVM::AtomicOrdering::not_atomic;
506   case llvm::AtomicOrdering::Unordered:
507     return LLVM::AtomicOrdering::unordered;
508   case llvm::AtomicOrdering::Monotonic:
509     return LLVM::AtomicOrdering::monotonic;
510   case llvm::AtomicOrdering::Acquire:
511     return LLVM::AtomicOrdering::acquire;
512   case llvm::AtomicOrdering::Release:
513     return LLVM::AtomicOrdering::release;
514   case llvm::AtomicOrdering::AcquireRelease:
515     return LLVM::AtomicOrdering::acq_rel;
516   case llvm::AtomicOrdering::SequentiallyConsistent:
517     return LLVM::AtomicOrdering::seq_cst;
518   }
519   llvm_unreachable("incorrect atomic ordering");
520 }
521 
522 // `br` branches to `target`. Return the branch arguments to `br`, in the
523 // same order of the PHIs in `target`.
524 LogicalResult
525 Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target,
526                             SmallVectorImpl<Value> &blockArguments) {
527   for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) {
528     auto *PN = cast<llvm::PHINode>(&*inst);
529     Value value = processValue(PN->getIncomingValueForBlock(br->getParent()));
530     if (!value)
531       return failure();
532     blockArguments.push_back(value);
533   }
534   return success();
535 }
536 
537 LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
538   // FIXME: Support uses of SubtargetData. Currently inbounds GEPs, fast-math
539   // flags and call / operand attributes are not supported.
540   Location loc = processDebugLoc(inst->getDebugLoc(), inst);
541   Value &v = instMap[inst];
542   assert(!v && "processInstruction must be called only once per instruction!");
543   switch (inst->getOpcode()) {
544   default:
545     return emitError(loc) << "unknown instruction: " << diag(*inst);
546   case llvm::Instruction::Add:
547   case llvm::Instruction::FAdd:
548   case llvm::Instruction::Sub:
549   case llvm::Instruction::FSub:
550   case llvm::Instruction::Mul:
551   case llvm::Instruction::FMul:
552   case llvm::Instruction::UDiv:
553   case llvm::Instruction::SDiv:
554   case llvm::Instruction::FDiv:
555   case llvm::Instruction::URem:
556   case llvm::Instruction::SRem:
557   case llvm::Instruction::FRem:
558   case llvm::Instruction::Shl:
559   case llvm::Instruction::LShr:
560   case llvm::Instruction::AShr:
561   case llvm::Instruction::And:
562   case llvm::Instruction::Or:
563   case llvm::Instruction::Xor:
564   case llvm::Instruction::Alloca:
565   case llvm::Instruction::Load:
566   case llvm::Instruction::Store:
567   case llvm::Instruction::Ret:
568   case llvm::Instruction::Resume:
569   case llvm::Instruction::Trunc:
570   case llvm::Instruction::ZExt:
571   case llvm::Instruction::SExt:
572   case llvm::Instruction::FPToUI:
573   case llvm::Instruction::FPToSI:
574   case llvm::Instruction::UIToFP:
575   case llvm::Instruction::SIToFP:
576   case llvm::Instruction::FPTrunc:
577   case llvm::Instruction::FPExt:
578   case llvm::Instruction::PtrToInt:
579   case llvm::Instruction::IntToPtr:
580   case llvm::Instruction::AddrSpaceCast:
581   case llvm::Instruction::Freeze:
582   case llvm::Instruction::BitCast: {
583     OperationState state(loc, lookupOperationNameFromOpcode(inst->getOpcode()));
584     SmallVector<Value, 4> ops;
585     ops.reserve(inst->getNumOperands());
586     for (auto *op : inst->operand_values()) {
587       Value value = processValue(op);
588       if (!value)
589         return failure();
590       ops.push_back(value);
591     }
592     state.addOperands(ops);
593     if (!inst->getType()->isVoidTy()) {
594       Type type = processType(inst->getType());
595       if (!type)
596         return failure();
597       state.addTypes(type);
598     }
599     Operation *op = b.createOperation(state);
600     if (!inst->getType()->isVoidTy())
601       v = op->getResult(0);
602     return success();
603   }
604   case llvm::Instruction::ICmp: {
605     Value lhs = processValue(inst->getOperand(0));
606     Value rhs = processValue(inst->getOperand(1));
607     if (!lhs || !rhs)
608       return failure();
609     v = b.create<ICmpOp>(
610         loc, getICmpPredicate(cast<llvm::ICmpInst>(inst)->getPredicate()), lhs,
611         rhs);
612     return success();
613   }
614   case llvm::Instruction::Br: {
615     auto *brInst = cast<llvm::BranchInst>(inst);
616     OperationState state(loc,
617                          brInst->isConditional() ? "llvm.cond_br" : "llvm.br");
618     if (brInst->isConditional()) {
619       Value condition = processValue(brInst->getCondition());
620       if (!condition)
621         return failure();
622       state.addOperands(condition);
623     }
624 
625     std::array<int32_t, 3> operandSegmentSizes = {1, 0, 0};
626     for (int i : llvm::seq<int>(0, brInst->getNumSuccessors())) {
627       auto *succ = brInst->getSuccessor(i);
628       SmallVector<Value, 4> blockArguments;
629       if (failed(processBranchArgs(brInst, succ, blockArguments)))
630         return failure();
631       state.addSuccessors(blocks[succ]);
632       state.addOperands(blockArguments);
633       operandSegmentSizes[i + 1] = blockArguments.size();
634     }
635 
636     if (brInst->isConditional()) {
637       state.addAttribute(LLVM::CondBrOp::getOperandSegmentSizeAttr(),
638                          b.getI32VectorAttr(operandSegmentSizes));
639     }
640 
641     b.createOperation(state);
642     return success();
643   }
644   case llvm::Instruction::PHI: {
645     Type type = processType(inst->getType());
646     if (!type)
647       return failure();
648     v = b.getInsertionBlock()->addArgument(type);
649     return success();
650   }
651   case llvm::Instruction::Call: {
652     llvm::CallInst *ci = cast<llvm::CallInst>(inst);
653     SmallVector<Value, 4> ops;
654     ops.reserve(inst->getNumOperands());
655     for (auto &op : ci->arg_operands()) {
656       Value arg = processValue(op.get());
657       if (!arg)
658         return failure();
659       ops.push_back(arg);
660     }
661 
662     SmallVector<Type, 2> tys;
663     if (!ci->getType()->isVoidTy()) {
664       Type type = processType(inst->getType());
665       if (!type)
666         return failure();
667       tys.push_back(type);
668     }
669     Operation *op;
670     if (llvm::Function *callee = ci->getCalledFunction()) {
671       op = b.create<CallOp>(
672           loc, tys, SymbolRefAttr::get(b.getContext(), callee->getName()), ops);
673     } else {
674       Value calledValue = processValue(ci->getCalledOperand());
675       if (!calledValue)
676         return failure();
677       ops.insert(ops.begin(), calledValue);
678       op = b.create<CallOp>(loc, tys, ops);
679     }
680     if (!ci->getType()->isVoidTy())
681       v = op->getResult(0);
682     return success();
683   }
684   case llvm::Instruction::LandingPad: {
685     llvm::LandingPadInst *lpi = cast<llvm::LandingPadInst>(inst);
686     SmallVector<Value, 4> ops;
687 
688     for (unsigned i = 0, ie = lpi->getNumClauses(); i < ie; i++)
689       ops.push_back(processConstant(lpi->getClause(i)));
690 
691     Type ty = processType(lpi->getType());
692     if (!ty)
693       return failure();
694 
695     v = b.create<LandingpadOp>(loc, ty, lpi->isCleanup(), ops);
696     return success();
697   }
698   case llvm::Instruction::Invoke: {
699     llvm::InvokeInst *ii = cast<llvm::InvokeInst>(inst);
700 
701     SmallVector<Type, 2> tys;
702     if (!ii->getType()->isVoidTy())
703       tys.push_back(processType(inst->getType()));
704 
705     SmallVector<Value, 4> ops;
706     ops.reserve(inst->getNumOperands() + 1);
707     for (auto &op : ii->arg_operands())
708       ops.push_back(processValue(op.get()));
709 
710     SmallVector<Value, 4> normalArgs, unwindArgs;
711     (void)processBranchArgs(ii, ii->getNormalDest(), normalArgs);
712     (void)processBranchArgs(ii, ii->getUnwindDest(), unwindArgs);
713 
714     Operation *op;
715     if (llvm::Function *callee = ii->getCalledFunction()) {
716       op = b.create<InvokeOp>(
717           loc, tys, SymbolRefAttr::get(b.getContext(), callee->getName()), ops,
718           blocks[ii->getNormalDest()], normalArgs, blocks[ii->getUnwindDest()],
719           unwindArgs);
720     } else {
721       ops.insert(ops.begin(), processValue(ii->getCalledOperand()));
722       op = b.create<InvokeOp>(loc, tys, ops, blocks[ii->getNormalDest()],
723                               normalArgs, blocks[ii->getUnwindDest()],
724                               unwindArgs);
725     }
726 
727     if (!ii->getType()->isVoidTy())
728       v = op->getResult(0);
729     return success();
730   }
731   case llvm::Instruction::Fence: {
732     StringRef syncscope;
733     SmallVector<StringRef, 4> ssNs;
734     llvm::LLVMContext &llvmContext = inst->getContext();
735     llvm::FenceInst *fence = cast<llvm::FenceInst>(inst);
736     llvmContext.getSyncScopeNames(ssNs);
737     int fenceSyncScopeID = fence->getSyncScopeID();
738     for (unsigned i = 0, e = ssNs.size(); i != e; i++) {
739       if (fenceSyncScopeID == llvmContext.getOrInsertSyncScopeID(ssNs[i])) {
740         syncscope = ssNs[i];
741         break;
742       }
743     }
744     b.create<FenceOp>(loc, getLLVMAtomicOrdering(fence->getOrdering()),
745                       syncscope);
746     return success();
747   }
748   case llvm::Instruction::GetElementPtr: {
749     // FIXME: Support inbounds GEPs.
750     llvm::GetElementPtrInst *gep = cast<llvm::GetElementPtrInst>(inst);
751     SmallVector<Value, 4> ops;
752     for (auto *op : gep->operand_values()) {
753       Value value = processValue(op);
754       if (!value)
755         return failure();
756       ops.push_back(value);
757     }
758     Type type = processType(inst->getType());
759     if (!type)
760       return failure();
761     v = b.create<GEPOp>(loc, type, ops);
762     return success();
763   }
764   }
765 }
766 
767 FlatSymbolRefAttr Importer::getPersonalityAsAttr(llvm::Function *f) {
768   if (!f->hasPersonalityFn())
769     return nullptr;
770 
771   llvm::Constant *pf = f->getPersonalityFn();
772 
773   // If it directly has a name, we can use it.
774   if (pf->hasName())
775     return SymbolRefAttr::get(b.getContext(), pf->getName());
776 
777   // If it doesn't have a name, currently, only function pointers that are
778   // bitcast to i8* are parsed.
779   if (auto ce = dyn_cast<llvm::ConstantExpr>(pf)) {
780     if (ce->getOpcode() == llvm::Instruction::BitCast &&
781         ce->getType() == llvm::Type::getInt8PtrTy(f->getContext())) {
782       if (auto func = dyn_cast<llvm::Function>(ce->getOperand(0)))
783         return SymbolRefAttr::get(b.getContext(), func->getName());
784     }
785   }
786   return FlatSymbolRefAttr();
787 }
788 
789 LogicalResult Importer::processFunction(llvm::Function *f) {
790   blocks.clear();
791   instMap.clear();
792   unknownInstMap.clear();
793 
794   auto functionType =
795       processType(f->getFunctionType()).dyn_cast<LLVMFunctionType>();
796   if (!functionType)
797     return failure();
798 
799   b.setInsertionPoint(module.getBody(), getFuncInsertPt());
800   LLVMFuncOp fop =
801       b.create<LLVMFuncOp>(UnknownLoc::get(context), f->getName(), functionType,
802                            convertLinkageFromLLVM(f->getLinkage()));
803 
804   if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f))
805     fop->setAttr(b.getIdentifier("personality"), personality);
806   else if (f->hasPersonalityFn())
807     emitWarning(UnknownLoc::get(context),
808                 "could not deduce personality, skipping it");
809 
810   if (f->isDeclaration())
811     return success();
812 
813   // Eagerly create all blocks.
814   SmallVector<Block *, 4> blockList;
815   for (llvm::BasicBlock &bb : *f) {
816     blockList.push_back(b.createBlock(&fop.body(), fop.body().end()));
817     blocks[&bb] = blockList.back();
818   }
819   currentEntryBlock = blockList[0];
820 
821   // Add function arguments to the entry block.
822   for (auto kv : llvm::enumerate(f->args()))
823     instMap[&kv.value()] =
824         blockList[0]->addArgument(functionType.getParamType(kv.index()));
825 
826   for (auto bbs : llvm::zip(*f, blockList)) {
827     if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs))))
828       return failure();
829   }
830 
831   // Now that all instructions are guaranteed to have been visited, ensure
832   // any unknown uses we encountered are remapped.
833   for (auto &llvmAndUnknown : unknownInstMap) {
834     assert(instMap.count(llvmAndUnknown.first));
835     Value newValue = instMap[llvmAndUnknown.first];
836     Value oldValue = llvmAndUnknown.second->getResult(0);
837     oldValue.replaceAllUsesWith(newValue);
838     llvmAndUnknown.second->erase();
839   }
840   return success();
841 }
842 
843 LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) {
844   b.setInsertionPointToStart(block);
845   for (llvm::Instruction &inst : *bb) {
846     if (failed(processInstruction(&inst)))
847       return failure();
848   }
849   return success();
850 }
851 
852 OwningModuleRef
853 mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
854                               MLIRContext *context) {
855   context->loadDialect<LLVMDialect>();
856   OwningModuleRef module(ModuleOp::create(
857       FileLineColLoc::get(context, "", /*line=*/0, /*column=*/0)));
858 
859   Importer deserializer(context, module.get());
860   for (llvm::GlobalVariable &gv : llvmModule->globals()) {
861     if (!deserializer.processGlobal(&gv))
862       return {};
863   }
864   for (llvm::Function &f : llvmModule->functions()) {
865     if (failed(deserializer.processFunction(&f)))
866       return {};
867   }
868 
869   return module;
870 }
871 
872 // Deserializes the LLVM bitcode stored in `input` into an MLIR module in the
873 // LLVM dialect.
874 OwningModuleRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr,
875                                         MLIRContext *context) {
876   llvm::SMDiagnostic err;
877   llvm::LLVMContext llvmContext;
878   std::unique_ptr<llvm::Module> llvmModule = llvm::parseIR(
879       *sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), err, llvmContext);
880   if (!llvmModule) {
881     std::string errStr;
882     llvm::raw_string_ostream errStream(errStr);
883     err.print(/*ProgName=*/"", errStream);
884     emitError(UnknownLoc::get(context)) << errStream.str();
885     return {};
886   }
887   return translateLLVMIRToModule(std::move(llvmModule), context);
888 }
889 
890 namespace mlir {
891 void registerFromLLVMIRTranslation() {
892   TranslateToMLIRRegistration fromLLVM(
893       "import-llvm", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
894         return ::translateLLVMIRToModule(sourceMgr, context);
895       });
896 }
897 } // namespace mlir
898