//===- ConvertFromLLVMIR.cpp - LLVM IR to MLIR conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between LLVM IR and the MLIR LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/MLIRContext.h"
16 #include "mlir/IR/Module.h"
17 #include "mlir/IR/StandardTypes.h"
18 #include "mlir/Target/LLVMIR.h"
19 #include "mlir/Translation.h"
20 
21 #include "llvm/IR/Attributes.h"
22 #include "llvm/IR/Constants.h"
23 #include "llvm/IR/Function.h"
24 #include "llvm/IR/Instructions.h"
25 #include "llvm/IR/Type.h"
26 #include "llvm/IRReader/IRReader.h"
27 #include "llvm/Support/Error.h"
28 #include "llvm/Support/SourceMgr.h"
29 
30 using namespace mlir;
31 using namespace mlir::LLVM;
32 
33 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
34 
35 // Utility to print an LLVM value as a string for passing to emitError().
36 // FIXME: Diagnostic should be able to natively handle types that have
37 // operator << (raw_ostream&) defined.
38 static std::string diag(llvm::Value &v) {
39   std::string s;
40   llvm::raw_string_ostream os(s);
41   os << v;
42   return os.str();
43 }
44 
45 // Handles importing globals and functions from an LLVM module.
46 namespace {
47 class Importer {
48 public:
49   Importer(MLIRContext *context, ModuleOp module)
50       : b(context), context(context), module(module),
51         unknownLoc(FileLineColLoc::get("imported-bitcode", 0, 0, context)) {
52     b.setInsertionPointToStart(module.getBody());
53     dialect = context->getRegisteredDialect<LLVMDialect>();
54   }
55 
56   /// Imports `f` into the current module.
57   LogicalResult processFunction(llvm::Function *f);
58 
59   /// Imports GV as a GlobalOp, creating it if it doesn't exist.
60   GlobalOp processGlobal(llvm::GlobalVariable *GV);
61 
62 private:
63   /// Returns personality of `f` as a FlatSymbolRefAttr.
64   FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *f);
65   /// Imports `bb` into `block`, which must be initially empty.
66   LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block);
67   /// Imports `inst` and populates instMap[inst] with the imported Value.
68   LogicalResult processInstruction(llvm::Instruction *inst);
69   /// Creates an LLVMType for `type`.
70   LLVMType processType(llvm::Type *type);
71   /// `value` is an SSA-use. Return the remapped version of `value` or a
72   /// placeholder that will be remapped later if this is an instruction that
73   /// has not yet been visited.
74   Value processValue(llvm::Value *value);
75   /// Create the most accurate Location possible using a llvm::DebugLoc and
76   /// possibly an llvm::Instruction to narrow the Location if debug information
77   /// is unavailable.
78   Location processDebugLoc(const llvm::DebugLoc &loc,
79                            llvm::Instruction *inst = nullptr);
80   /// `br` branches to `target`. Append the block arguments to attach to the
81   /// generated branch op to `blockArguments`. These should be in the same order
82   /// as the PHIs in `target`.
83   LogicalResult processBranchArgs(llvm::Instruction *br,
84                                   llvm::BasicBlock *target,
85                                   SmallVectorImpl<Value> &blockArguments);
86   /// Returns the standard type equivalent to be used in attributes for the
87   /// given LLVM IR dialect type.
88   Type getStdTypeForAttr(LLVMType type);
89   /// Return `value` as an attribute to attach to a GlobalOp.
90   Attribute getConstantAsAttr(llvm::Constant *value);
91   /// Return `c` as an MLIR Value. This could either be a ConstantOp, or
92   /// an expanded sequence of ops in the current function's entry block (for
93   /// ConstantExprs or ConstantGEPs).
94   Value processConstant(llvm::Constant *c);
95 
96   /// The current builder, pointing at where the next Instruction should be
97   /// generated.
98   OpBuilder b;
99   /// The current context.
100   MLIRContext *context;
101   /// The current module being created.
102   ModuleOp module;
103   /// The entry block of the current function being processed.
104   Block *currentEntryBlock;
105 
106   /// Globals are inserted before the first function, if any.
107   Block::iterator getGlobalInsertPt() {
108     auto i = module.getBody()->begin();
109     while (!isa<LLVMFuncOp, ModuleTerminatorOp>(i))
110       ++i;
111     return i;
112   }
113 
114   /// Functions are always inserted before the module terminator.
115   Block::iterator getFuncInsertPt() {
116     return std::prev(module.getBody()->end());
117   }
118 
119   /// Remapped blocks, for the current function.
120   DenseMap<llvm::BasicBlock *, Block *> blocks;
121   /// Remapped values. These are function-local.
122   DenseMap<llvm::Value *, Value> instMap;
123   /// Instructions that had not been defined when first encountered as a use.
124   /// Maps to the dummy Operation that was created in processValue().
125   DenseMap<llvm::Value *, Operation *> unknownInstMap;
126   /// Uniquing map of GlobalVariables.
127   DenseMap<llvm::GlobalVariable *, GlobalOp> globals;
128   /// Cached FileLineColLoc::get("imported-bitcode", 0, 0).
129   Location unknownLoc;
130   /// Cached dialect.
131   LLVMDialect *dialect;
132 };
133 } // namespace
134 
135 Location Importer::processDebugLoc(const llvm::DebugLoc &loc,
136                                    llvm::Instruction *inst) {
137   if (!loc && inst) {
138     std::string s;
139     llvm::raw_string_ostream os(s);
140     os << "llvm-imported-inst-%";
141     inst->printAsOperand(os, /*PrintType=*/false);
142     return FileLineColLoc::get(os.str(), 0, 0, context);
143   } else if (!loc) {
144     return unknownLoc;
145   }
146   // FIXME: Obtain the filename from DILocationInfo.
147   return FileLineColLoc::get("imported-bitcode", loc.getLine(), loc.getCol(),
148                              context);
149 }
150 
151 LLVMType Importer::processType(llvm::Type *type) {
152   switch (type->getTypeID()) {
153   case llvm::Type::FloatTyID:
154     return LLVMType::getFloatTy(dialect);
155   case llvm::Type::DoubleTyID:
156     return LLVMType::getDoubleTy(dialect);
157   case llvm::Type::IntegerTyID:
158     return LLVMType::getIntNTy(dialect, type->getIntegerBitWidth());
159   case llvm::Type::PointerTyID: {
160     LLVMType elementType = processType(type->getPointerElementType());
161     if (!elementType)
162       return nullptr;
163     return elementType.getPointerTo(type->getPointerAddressSpace());
164   }
165   case llvm::Type::ArrayTyID: {
166     LLVMType elementType = processType(type->getArrayElementType());
167     if (!elementType)
168       return nullptr;
169     return LLVMType::getArrayTy(elementType, type->getArrayNumElements());
170   }
171   case llvm::Type::ScalableVectorTyID: {
172     emitError(unknownLoc) << "scalable vector types not supported";
173     return nullptr;
174   }
175   case llvm::Type::FixedVectorTyID: {
176     auto *typeVTy = llvm::cast<llvm::FixedVectorType>(type);
177     LLVMType elementType = processType(typeVTy->getElementType());
178     if (!elementType)
179       return nullptr;
180     return LLVMType::getVectorTy(elementType, typeVTy->getNumElements());
181   }
182   case llvm::Type::VoidTyID:
183     return LLVMType::getVoidTy(dialect);
184   case llvm::Type::FP128TyID:
185     return LLVMType::getFP128Ty(dialect);
186   case llvm::Type::X86_FP80TyID:
187     return LLVMType::getX86_FP80Ty(dialect);
188   case llvm::Type::StructTyID: {
189     SmallVector<LLVMType, 4> elementTypes;
190     elementTypes.reserve(type->getStructNumElements());
191     for (unsigned i = 0, e = type->getStructNumElements(); i != e; ++i) {
192       LLVMType ty = processType(type->getStructElementType(i));
193       if (!ty)
194         return nullptr;
195       elementTypes.push_back(ty);
196     }
197     return LLVMType::getStructTy(dialect, elementTypes,
198                                  cast<llvm::StructType>(type)->isPacked());
199   }
200   case llvm::Type::FunctionTyID: {
201     llvm::FunctionType *fty = cast<llvm::FunctionType>(type);
202     SmallVector<LLVMType, 4> paramTypes;
203     for (unsigned i = 0, e = fty->getNumParams(); i != e; ++i) {
204       LLVMType ty = processType(fty->getParamType(i));
205       if (!ty)
206         return nullptr;
207       paramTypes.push_back(ty);
208     }
209     LLVMType result = processType(fty->getReturnType());
210     if (!result)
211       return nullptr;
212 
213     return LLVMType::getFunctionTy(result, paramTypes, fty->isVarArg());
214   }
215   default: {
216     // FIXME: Diagnostic should be able to natively handle types that have
217     // operator<<(raw_ostream&) defined.
218     std::string s;
219     llvm::raw_string_ostream os(s);
220     os << *type;
221     emitError(unknownLoc) << "unhandled type: " << os.str();
222     return nullptr;
223   }
224   }
225 }
226 
227 // We only need integers, floats, doubles, and vectors and tensors thereof for
228 // attributes. Scalar and vector types are converted to the standard
229 // equivalents. Array types are converted to ranked tensors; nested array types
230 // are converted to multi-dimensional tensors or vectors, depending on the
231 // innermost type being a scalar or a vector.
232 Type Importer::getStdTypeForAttr(LLVMType type) {
233   if (!type)
234     return nullptr;
235 
236   if (type.isIntegerTy())
237     return b.getIntegerType(type.getUnderlyingType()->getIntegerBitWidth());
238 
239   if (type.getUnderlyingType()->isFloatTy())
240     return b.getF32Type();
241 
242   if (type.getUnderlyingType()->isDoubleTy())
243     return b.getF64Type();
244 
245   // LLVM vectors can only contain scalars.
246   if (type.isVectorTy()) {
247     auto numElements = llvm::cast<llvm::VectorType>(type.getUnderlyingType())
248                            ->getElementCount();
249     if (numElements.Scalable) {
250       emitError(unknownLoc) << "scalable vectors not supported";
251       return nullptr;
252     }
253     Type elementType = getStdTypeForAttr(type.getVectorElementType());
254     if (!elementType)
255       return nullptr;
256     return VectorType::get(numElements.Min, elementType);
257   }
258 
259   // LLVM arrays can contain other arrays or vectors.
260   if (type.isArrayTy()) {
261     // Recover the nested array shape.
262     SmallVector<int64_t, 4> shape;
263     shape.push_back(type.getArrayNumElements());
264     while (type.getArrayElementType().isArrayTy()) {
265       type = type.getArrayElementType();
266       shape.push_back(type.getArrayNumElements());
267     }
268 
269     // If the innermost type is a vector, use the multi-dimensional vector as
270     // attribute type.
271     if (type.getArrayElementType().isVectorTy()) {
272       LLVMType vectorType = type.getArrayElementType();
273       auto numElements =
274           llvm::cast<llvm::VectorType>(vectorType.getUnderlyingType())
275               ->getElementCount();
276       if (numElements.Scalable) {
277         emitError(unknownLoc) << "scalable vectors not supported";
278         return nullptr;
279       }
280       shape.push_back(numElements.Min);
281 
282       Type elementType = getStdTypeForAttr(vectorType.getVectorElementType());
283       if (!elementType)
284         return nullptr;
285       return VectorType::get(shape, elementType);
286     }
287 
288     // Otherwise use a tensor.
289     Type elementType = getStdTypeForAttr(type.getArrayElementType());
290     if (!elementType)
291       return nullptr;
292     return RankedTensorType::get(shape, elementType);
293   }
294 
295   return nullptr;
296 }
297 
298 // Get the given constant as an attribute. Not all constants can be represented
299 // as attributes.
300 Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
301   if (auto *ci = dyn_cast<llvm::ConstantInt>(value))
302     return b.getIntegerAttr(
303         IntegerType::get(ci->getType()->getBitWidth(), context),
304         ci->getValue());
305   if (auto *c = dyn_cast<llvm::ConstantDataArray>(value))
306     if (c->isString())
307       return b.getStringAttr(c->getAsString());
308   if (auto *c = dyn_cast<llvm::ConstantFP>(value)) {
309     if (c->getType()->isDoubleTy())
310       return b.getFloatAttr(FloatType::getF64(context), c->getValueAPF());
311     else if (c->getType()->isFloatingPointTy())
312       return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF());
313   }
314   if (auto *f = dyn_cast<llvm::Function>(value))
315     return b.getSymbolRefAttr(f->getName());
316 
317   // Convert constant data to a dense elements attribute.
318   if (auto *cd = dyn_cast<llvm::ConstantDataSequential>(value)) {
319     LLVMType type = processType(cd->getElementType());
320     if (!type)
321       return nullptr;
322 
323     auto attrType = getStdTypeForAttr(processType(cd->getType()))
324                         .dyn_cast_or_null<ShapedType>();
325     if (!attrType)
326       return nullptr;
327 
328     if (type.isIntegerTy()) {
329       SmallVector<APInt, 8> values;
330       values.reserve(cd->getNumElements());
331       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
332         values.push_back(cd->getElementAsAPInt(i));
333       return DenseElementsAttr::get(attrType, values);
334     }
335 
336     if (type.isFloatTy() || type.isDoubleTy()) {
337       SmallVector<APFloat, 8> values;
338       values.reserve(cd->getNumElements());
339       for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i)
340         values.push_back(cd->getElementAsAPFloat(i));
341       return DenseElementsAttr::get(attrType, values);
342     }
343 
344     return nullptr;
345   }
346 
347   // Unpack constant aggregates to create dense elements attribute whenever
348   // possible. Return nullptr (failure) otherwise.
349   if (isa<llvm::ConstantAggregate>(value)) {
350     auto outerType = getStdTypeForAttr(processType(value->getType()))
351                          .dyn_cast_or_null<ShapedType>();
352     if (!outerType)
353       return nullptr;
354 
355     SmallVector<Attribute, 8> values;
356     SmallVector<int64_t, 8> shape;
357 
358     for (unsigned i = 0, e = value->getNumOperands(); i < e; ++i) {
359       auto nested = getConstantAsAttr(value->getAggregateElement(i))
360                         .dyn_cast_or_null<DenseElementsAttr>();
361       if (!nested)
362         return nullptr;
363 
364       values.append(nested.attr_value_begin(), nested.attr_value_end());
365     }
366 
367     return DenseElementsAttr::get(outerType, values);
368   }
369 
370   return nullptr;
371 }
372 
373 GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
374   auto it = globals.find(GV);
375   if (it != globals.end())
376     return it->second;
377 
378   OpBuilder b(module.getBody(), getGlobalInsertPt());
379   Attribute valueAttr;
380   if (GV->hasInitializer())
381     valueAttr = getConstantAsAttr(GV->getInitializer());
382   LLVMType type = processType(GV->getValueType());
383   if (!type)
384     return nullptr;
385   GlobalOp op = b.create<GlobalOp>(
386       UnknownLoc::get(context), type, GV->isConstant(),
387       convertLinkageFromLLVM(GV->getLinkage()), GV->getName(), valueAttr);
388   if (GV->hasInitializer() && !valueAttr) {
389     Region &r = op.getInitializerRegion();
390     currentEntryBlock = b.createBlock(&r);
391     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
392     Value v = processConstant(GV->getInitializer());
393     if (!v)
394       return nullptr;
395     b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v}));
396   }
397   return globals[GV] = op;
398 }
399 
400 Value Importer::processConstant(llvm::Constant *c) {
401   OpBuilder bEntry(currentEntryBlock, currentEntryBlock->begin());
402   if (Attribute attr = getConstantAsAttr(c)) {
403     // These constants can be represented as attributes.
404     OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
405     LLVMType type = processType(c->getType());
406     if (!type)
407       return nullptr;
408     if (auto symbolRef = attr.dyn_cast<FlatSymbolRefAttr>())
409       return instMap[c] = bEntry.create<AddressOfOp>(unknownLoc, type,
410                                                      symbolRef.getValue());
411     return instMap[c] = bEntry.create<ConstantOp>(unknownLoc, type, attr);
412   }
413   if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) {
414     LLVMType type = processType(cn->getType());
415     if (!type)
416       return nullptr;
417     return instMap[c] = bEntry.create<NullOp>(unknownLoc, type);
418   }
419   if (auto *GV = dyn_cast<llvm::GlobalVariable>(c))
420     return bEntry.create<AddressOfOp>(UnknownLoc::get(context),
421                                       processGlobal(GV),
422                                       ArrayRef<NamedAttribute>());
423 
424   if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) {
425     llvm::Instruction *i = ce->getAsInstruction();
426     OpBuilder::InsertionGuard guard(b);
427     b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
428     if (failed(processInstruction(i)))
429       return nullptr;
430     assert(instMap.count(i));
431 
432     // Remove this zombie LLVM instruction now, leaving us only with the MLIR
433     // op.
434     i->deleteValue();
435     return instMap[c] = instMap[i];
436   }
437   if (auto *ue = dyn_cast<llvm::UndefValue>(c)) {
438     LLVMType type = processType(ue->getType());
439     if (!type)
440       return nullptr;
441     return instMap[c] = bEntry.create<UndefOp>(UnknownLoc::get(context), type);
442   }
443   emitError(unknownLoc) << "unhandled constant: " << diag(*c);
444   return nullptr;
445 }
446 
447 Value Importer::processValue(llvm::Value *value) {
448   auto it = instMap.find(value);
449   if (it != instMap.end())
450     return it->second;
451 
452   // We don't expect to see instructions in dominator order. If we haven't seen
453   // this instruction yet, create an unknown op and remap it later.
454   if (isa<llvm::Instruction>(value)) {
455     OperationState state(UnknownLoc::get(context), "llvm.unknown");
456     LLVMType type = processType(value->getType());
457     if (!type)
458       return nullptr;
459     state.addTypes(type);
460     unknownInstMap[value] = b.createOperation(state);
461     return unknownInstMap[value]->getResult(0);
462   }
463 
464   if (auto *c = dyn_cast<llvm::Constant>(value))
465     return processConstant(c);
466 
467   emitError(unknownLoc) << "unhandled value: " << diag(*value);
468   return nullptr;
469 }
470 
471 /// Return the MLIR OperationName for the given LLVM opcode.
472 static StringRef lookupOperationNameFromOpcode(unsigned opcode) {
473 // Maps from LLVM opcode to MLIR OperationName. This is deliberately ordered
474 // as in llvm/IR/Instructions.def to aid comprehension and spot missing
475 // instructions.
476 #define INST(llvm_n, mlir_n)                                                   \
477   { llvm::Instruction::llvm_n, LLVM::mlir_n##Op::getOperationName() }
478   static const DenseMap<unsigned, StringRef> opcMap = {
479       // Ret is handled specially.
480       // Br is handled specially.
481       // FIXME: switch
482       // FIXME: indirectbr
483       // FIXME: invoke
484       INST(Resume, Resume),
485       // FIXME: unreachable
486       // FIXME: cleanupret
487       // FIXME: catchret
488       // FIXME: catchswitch
489       // FIXME: callbr
490       // FIXME: fneg
491       INST(Add, Add), INST(FAdd, FAdd), INST(Sub, Sub), INST(FSub, FSub),
492       INST(Mul, Mul), INST(FMul, FMul), INST(UDiv, UDiv), INST(SDiv, SDiv),
493       INST(FDiv, FDiv), INST(URem, URem), INST(SRem, SRem), INST(FRem, FRem),
494       INST(Shl, Shl), INST(LShr, LShr), INST(AShr, AShr), INST(And, And),
495       INST(Or, Or), INST(Xor, XOr), INST(Alloca, Alloca), INST(Load, Load),
496       INST(Store, Store),
497       // Getelementptr is handled specially.
498       INST(Ret, Return), INST(Fence, Fence),
499       // FIXME: atomiccmpxchg
500       // FIXME: atomicrmw
501       INST(Trunc, Trunc), INST(ZExt, ZExt), INST(SExt, SExt),
502       INST(FPToUI, FPToUI), INST(FPToSI, FPToSI), INST(UIToFP, UIToFP),
503       INST(SIToFP, SIToFP), INST(FPTrunc, FPTrunc), INST(FPExt, FPExt),
504       INST(PtrToInt, PtrToInt), INST(IntToPtr, IntToPtr),
505       INST(BitCast, Bitcast), INST(AddrSpaceCast, AddrSpaceCast),
506       // FIXME: cleanuppad
507       // FIXME: catchpad
508       // ICmp is handled specially.
509       // FIXME: fcmp
510       // PHI is handled specially.
511       INST(Freeze, Freeze), INST(Call, Call),
512       // FIXME: select
513       // FIXME: vaarg
514       // FIXME: extractelement
515       // FIXME: insertelement
516       // FIXME: shufflevector
517       // FIXME: extractvalue
518       // FIXME: insertvalue
519       // FIXME: landingpad
520   };
521 #undef INST
522 
523   return opcMap.lookup(opcode);
524 }
525 
526 static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) {
527   switch (p) {
528   default:
529     llvm_unreachable("incorrect comparison predicate");
530   case llvm::CmpInst::Predicate::ICMP_EQ:
531     return LLVM::ICmpPredicate::eq;
532   case llvm::CmpInst::Predicate::ICMP_NE:
533     return LLVM::ICmpPredicate::ne;
534   case llvm::CmpInst::Predicate::ICMP_SLT:
535     return LLVM::ICmpPredicate::slt;
536   case llvm::CmpInst::Predicate::ICMP_SLE:
537     return LLVM::ICmpPredicate::sle;
538   case llvm::CmpInst::Predicate::ICMP_SGT:
539     return LLVM::ICmpPredicate::sgt;
540   case llvm::CmpInst::Predicate::ICMP_SGE:
541     return LLVM::ICmpPredicate::sge;
542   case llvm::CmpInst::Predicate::ICMP_ULT:
543     return LLVM::ICmpPredicate::ult;
544   case llvm::CmpInst::Predicate::ICMP_ULE:
545     return LLVM::ICmpPredicate::ule;
546   case llvm::CmpInst::Predicate::ICMP_UGT:
547     return LLVM::ICmpPredicate::ugt;
548   case llvm::CmpInst::Predicate::ICMP_UGE:
549     return LLVM::ICmpPredicate::uge;
550   }
551   llvm_unreachable("incorrect comparison predicate");
552 }
553 
554 static AtomicOrdering getLLVMAtomicOrdering(llvm::AtomicOrdering ordering) {
555   switch (ordering) {
556   case llvm::AtomicOrdering::NotAtomic:
557     return LLVM::AtomicOrdering::not_atomic;
558   case llvm::AtomicOrdering::Unordered:
559     return LLVM::AtomicOrdering::unordered;
560   case llvm::AtomicOrdering::Monotonic:
561     return LLVM::AtomicOrdering::monotonic;
562   case llvm::AtomicOrdering::Acquire:
563     return LLVM::AtomicOrdering::acquire;
564   case llvm::AtomicOrdering::Release:
565     return LLVM::AtomicOrdering::release;
566   case llvm::AtomicOrdering::AcquireRelease:
567     return LLVM::AtomicOrdering::acq_rel;
568   case llvm::AtomicOrdering::SequentiallyConsistent:
569     return LLVM::AtomicOrdering::seq_cst;
570   }
571   llvm_unreachable("incorrect atomic ordering");
572 }
573 
574 // `br` branches to `target`. Return the branch arguments to `br`, in the
575 // same order of the PHIs in `target`.
576 LogicalResult
577 Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target,
578                             SmallVectorImpl<Value> &blockArguments) {
579   for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) {
580     auto *PN = cast<llvm::PHINode>(&*inst);
581     Value value = processValue(PN->getIncomingValueForBlock(br->getParent()));
582     if (!value)
583       return failure();
584     blockArguments.push_back(value);
585   }
586   return success();
587 }
588 
589 LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
590   // FIXME: Support uses of SubtargetData. Currently inbounds GEPs, fast-math
591   // flags and call / operand attributes are not supported.
592   Location loc = processDebugLoc(inst->getDebugLoc(), inst);
593   Value &v = instMap[inst];
594   assert(!v && "processInstruction must be called only once per instruction!");
595   switch (inst->getOpcode()) {
596   default:
597     return emitError(loc) << "unknown instruction: " << diag(*inst);
598   case llvm::Instruction::Add:
599   case llvm::Instruction::FAdd:
600   case llvm::Instruction::Sub:
601   case llvm::Instruction::FSub:
602   case llvm::Instruction::Mul:
603   case llvm::Instruction::FMul:
604   case llvm::Instruction::UDiv:
605   case llvm::Instruction::SDiv:
606   case llvm::Instruction::FDiv:
607   case llvm::Instruction::URem:
608   case llvm::Instruction::SRem:
609   case llvm::Instruction::FRem:
610   case llvm::Instruction::Shl:
611   case llvm::Instruction::LShr:
612   case llvm::Instruction::AShr:
613   case llvm::Instruction::And:
614   case llvm::Instruction::Or:
615   case llvm::Instruction::Xor:
616   case llvm::Instruction::Alloca:
617   case llvm::Instruction::Load:
618   case llvm::Instruction::Store:
619   case llvm::Instruction::Ret:
620   case llvm::Instruction::Resume:
621   case llvm::Instruction::Trunc:
622   case llvm::Instruction::ZExt:
623   case llvm::Instruction::SExt:
624   case llvm::Instruction::FPToUI:
625   case llvm::Instruction::FPToSI:
626   case llvm::Instruction::UIToFP:
627   case llvm::Instruction::SIToFP:
628   case llvm::Instruction::FPTrunc:
629   case llvm::Instruction::FPExt:
630   case llvm::Instruction::PtrToInt:
631   case llvm::Instruction::IntToPtr:
632   case llvm::Instruction::AddrSpaceCast:
633   case llvm::Instruction::Freeze:
634   case llvm::Instruction::BitCast: {
635     OperationState state(loc, lookupOperationNameFromOpcode(inst->getOpcode()));
636     SmallVector<Value, 4> ops;
637     ops.reserve(inst->getNumOperands());
638     for (auto *op : inst->operand_values()) {
639       Value value = processValue(op);
640       if (!value)
641         return failure();
642       ops.push_back(value);
643     }
644     state.addOperands(ops);
645     if (!inst->getType()->isVoidTy()) {
646       LLVMType type = processType(inst->getType());
647       if (!type)
648         return failure();
649       state.addTypes(type);
650     }
651     Operation *op = b.createOperation(state);
652     if (!inst->getType()->isVoidTy())
653       v = op->getResult(0);
654     return success();
655   }
656   case llvm::Instruction::ICmp: {
657     Value lhs = processValue(inst->getOperand(0));
658     Value rhs = processValue(inst->getOperand(1));
659     if (!lhs || !rhs)
660       return failure();
661     v = b.create<ICmpOp>(
662         loc, getICmpPredicate(cast<llvm::ICmpInst>(inst)->getPredicate()), lhs,
663         rhs);
664     return success();
665   }
666   case llvm::Instruction::Br: {
667     auto *brInst = cast<llvm::BranchInst>(inst);
668     OperationState state(loc,
669                          brInst->isConditional() ? "llvm.cond_br" : "llvm.br");
670     if (brInst->isConditional()) {
671       Value condition = processValue(brInst->getCondition());
672       if (!condition)
673         return failure();
674       state.addOperands(condition);
675     }
676 
677     std::array<int32_t, 3> operandSegmentSizes = {1, 0, 0};
678     for (int i : llvm::seq<int>(0, brInst->getNumSuccessors())) {
679       auto *succ = brInst->getSuccessor(i);
680       SmallVector<Value, 4> blockArguments;
681       if (failed(processBranchArgs(brInst, succ, blockArguments)))
682         return failure();
683       state.addSuccessors(blocks[succ]);
684       state.addOperands(blockArguments);
685       operandSegmentSizes[i + 1] = blockArguments.size();
686     }
687 
688     if (brInst->isConditional()) {
689       state.addAttribute(LLVM::CondBrOp::getOperandSegmentSizeAttr(),
690                          b.getI32VectorAttr(operandSegmentSizes));
691     }
692 
693     b.createOperation(state);
694     return success();
695   }
696   case llvm::Instruction::PHI: {
697     LLVMType type = processType(inst->getType());
698     if (!type)
699       return failure();
700     v = b.getInsertionBlock()->addArgument(type);
701     return success();
702   }
703   case llvm::Instruction::Call: {
704     llvm::CallInst *ci = cast<llvm::CallInst>(inst);
705     SmallVector<Value, 4> ops;
706     ops.reserve(inst->getNumOperands());
707     for (auto &op : ci->arg_operands()) {
708       Value arg = processValue(op.get());
709       if (!arg)
710         return failure();
711       ops.push_back(arg);
712     }
713 
714     SmallVector<Type, 2> tys;
715     if (!ci->getType()->isVoidTy()) {
716       LLVMType type = processType(inst->getType());
717       if (!type)
718         return failure();
719       tys.push_back(type);
720     }
721     Operation *op;
722     if (llvm::Function *callee = ci->getCalledFunction()) {
723       op = b.create<CallOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
724                             ops);
725     } else {
726       Value calledValue = processValue(ci->getCalledOperand());
727       if (!calledValue)
728         return failure();
729       ops.insert(ops.begin(), calledValue);
730       op = b.create<CallOp>(loc, tys, ops, ArrayRef<NamedAttribute>());
731     }
732     if (!ci->getType()->isVoidTy())
733       v = op->getResult(0);
734     return success();
735   }
736   case llvm::Instruction::LandingPad: {
737     llvm::LandingPadInst *lpi = cast<llvm::LandingPadInst>(inst);
738     SmallVector<Value, 4> ops;
739 
740     for (unsigned i = 0, ie = lpi->getNumClauses(); i < ie; i++)
741       ops.push_back(processConstant(lpi->getClause(i)));
742 
743     Type ty = processType(lpi->getType());
744     if (!ty)
745       return failure();
746 
747     v = b.create<LandingpadOp>(loc, ty, lpi->isCleanup(), ops);
748     return success();
749   }
750   case llvm::Instruction::Invoke: {
751     llvm::InvokeInst *ii = cast<llvm::InvokeInst>(inst);
752 
753     SmallVector<Type, 2> tys;
754     if (!ii->getType()->isVoidTy())
755       tys.push_back(processType(inst->getType()));
756 
757     SmallVector<Value, 4> ops;
758     ops.reserve(inst->getNumOperands() + 1);
759     for (auto &op : ii->arg_operands())
760       ops.push_back(processValue(op.get()));
761 
762     SmallVector<Value, 4> normalArgs, unwindArgs;
763     processBranchArgs(ii, ii->getNormalDest(), normalArgs);
764     processBranchArgs(ii, ii->getUnwindDest(), unwindArgs);
765 
766     Operation *op;
767     if (llvm::Function *callee = ii->getCalledFunction()) {
768       op = b.create<InvokeOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
769                               ops, blocks[ii->getNormalDest()], normalArgs,
770                               blocks[ii->getUnwindDest()], unwindArgs);
771     } else {
772       ops.insert(ops.begin(), processValue(ii->getCalledOperand()));
773       op = b.create<InvokeOp>(loc, tys, ops, blocks[ii->getNormalDest()],
774                               normalArgs, blocks[ii->getUnwindDest()],
775                               unwindArgs);
776     }
777 
778     if (!ii->getType()->isVoidTy())
779       v = op->getResult(0);
780     return success();
781   }
782   case llvm::Instruction::Fence: {
783     StringRef syncscope;
784     SmallVector<StringRef, 4> ssNs;
785     llvm::LLVMContext &llvmContext = dialect->getLLVMContext();
786     llvm::FenceInst *fence = cast<llvm::FenceInst>(inst);
787     llvmContext.getSyncScopeNames(ssNs);
788     int fenceSyncScopeID = fence->getSyncScopeID();
789     for (unsigned i = 0, e = ssNs.size(); i != e; i++) {
790       if (fenceSyncScopeID == llvmContext.getOrInsertSyncScopeID(ssNs[i])) {
791         syncscope = ssNs[i];
792         break;
793       }
794     }
795     b.create<FenceOp>(loc, getLLVMAtomicOrdering(fence->getOrdering()),
796                       syncscope);
797     return success();
798   }
799   case llvm::Instruction::GetElementPtr: {
800     // FIXME: Support inbounds GEPs.
801     llvm::GetElementPtrInst *gep = cast<llvm::GetElementPtrInst>(inst);
802     SmallVector<Value, 4> ops;
803     for (auto *op : gep->operand_values()) {
804       Value value = processValue(op);
805       if (!value)
806         return failure();
807       ops.push_back(value);
808     }
809     Type type = processType(inst->getType());
810     if (!type)
811       return failure();
812     v = b.create<GEPOp>(loc, type, ops, ArrayRef<NamedAttribute>());
813     return success();
814   }
815   }
816 }
817 
818 FlatSymbolRefAttr Importer::getPersonalityAsAttr(llvm::Function *f) {
819   if (!f->hasPersonalityFn())
820     return nullptr;
821 
822   llvm::Constant *pf = f->getPersonalityFn();
823 
824   // If it directly has a name, we can use it.
825   if (pf->hasName())
826     return b.getSymbolRefAttr(pf->getName());
827 
828   // If it doesn't have a name, currently, only function pointers that are
829   // bitcast to i8* are parsed.
830   if (auto ce = dyn_cast<llvm::ConstantExpr>(pf)) {
831     if (ce->getOpcode() == llvm::Instruction::BitCast &&
832         ce->getType() == llvm::Type::getInt8PtrTy(dialect->getLLVMContext())) {
833       if (auto func = dyn_cast<llvm::Function>(ce->getOperand(0)))
834         return b.getSymbolRefAttr(func->getName());
835     }
836   }
837   return FlatSymbolRefAttr();
838 }
839 
840 LogicalResult Importer::processFunction(llvm::Function *f) {
841   blocks.clear();
842   instMap.clear();
843   unknownInstMap.clear();
844 
845   LLVMType functionType = processType(f->getFunctionType());
846   if (!functionType)
847     return failure();
848 
849   b.setInsertionPoint(module.getBody(), getFuncInsertPt());
850   LLVMFuncOp fop = b.create<LLVMFuncOp>(UnknownLoc::get(context), f->getName(),
851                                         functionType);
852 
853   if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f))
854     fop.setAttr(b.getIdentifier("personality"), personality);
855   else if (f->hasPersonalityFn())
856     emitWarning(UnknownLoc::get(context),
857                 "could not deduce personality, skipping it");
858 
859   if (f->isDeclaration())
860     return success();
861 
862   // Eagerly create all blocks.
863   SmallVector<Block *, 4> blockList;
864   for (llvm::BasicBlock &bb : *f) {
865     blockList.push_back(b.createBlock(&fop.body(), fop.body().end()));
866     blocks[&bb] = blockList.back();
867   }
868   currentEntryBlock = blockList[0];
869 
870   // Add function arguments to the entry block.
871   for (auto kv : llvm::enumerate(f->args()))
872     instMap[&kv.value()] = blockList[0]->addArgument(
873         functionType.getFunctionParamType(kv.index()));
874 
875   for (auto bbs : llvm::zip(*f, blockList)) {
876     if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs))))
877       return failure();
878   }
879 
880   // Now that all instructions are guaranteed to have been visited, ensure
881   // any unknown uses we encountered are remapped.
882   for (auto &llvmAndUnknown : unknownInstMap) {
883     assert(instMap.count(llvmAndUnknown.first));
884     Value newValue = instMap[llvmAndUnknown.first];
885     Value oldValue = llvmAndUnknown.second->getResult(0);
886     oldValue.replaceAllUsesWith(newValue);
887     llvmAndUnknown.second->erase();
888   }
889   return success();
890 }
891 
892 LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) {
893   b.setInsertionPointToStart(block);
894   for (llvm::Instruction &inst : *bb) {
895     if (failed(processInstruction(&inst)))
896       return failure();
897   }
898   return success();
899 }
900 
901 OwningModuleRef
902 mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
903                               MLIRContext *context) {
904   OwningModuleRef module(ModuleOp::create(
905       FileLineColLoc::get("", /*line=*/0, /*column=*/0, context)));
906 
907   Importer deserializer(context, module.get());
908   for (llvm::GlobalVariable &gv : llvmModule->globals()) {
909     if (!deserializer.processGlobal(&gv))
910       return {};
911   }
912   for (llvm::Function &f : llvmModule->functions()) {
913     if (failed(deserializer.processFunction(&f)))
914       return {};
915   }
916 
917   return module;
918 }
919 
920 // Deserializes the LLVM bitcode stored in `input` into an MLIR module in the
921 // LLVM dialect.
922 OwningModuleRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr,
923                                         MLIRContext *context) {
924   LLVMDialect *dialect = context->getRegisteredDialect<LLVMDialect>();
925   assert(dialect && "Could not find LLVMDialect?");
926 
927   llvm::SMDiagnostic err;
928   std::unique_ptr<llvm::Module> llvmModule =
929       llvm::parseIR(*sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), err,
930                     dialect->getLLVMContext());
931   if (!llvmModule) {
932     std::string errStr;
933     llvm::raw_string_ostream errStream(errStr);
934     err.print(/*ProgName=*/"", errStream);
935     emitError(UnknownLoc::get(context)) << errStream.str();
936     return {};
937   }
938   return translateLLVMIRToModule(std::move(llvmModule), context);
939 }
940 
941 namespace mlir {
942 void registerFromLLVMIRTranslation() {
943   TranslateToMLIRRegistration fromLLVM(
944       "import-llvm", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
945         return ::translateLLVMIRToModule(sourceMgr, context);
946       });
947 }
948 } // namespace mlir
949