1 //===- TypeFromLLVM.cpp - type translation from LLVM to MLIR IR -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Target/LLVMIR/TypeFromLLVM.h"
10 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/MLIRContext.h"
13 
14 #include "llvm/ADT/TypeSwitch.h"
15 #include "llvm/IR/DataLayout.h"
16 #include "llvm/IR/DerivedTypes.h"
17 #include "llvm/IR/Type.h"
18 
19 using namespace mlir;
20 
21 namespace mlir {
22 namespace LLVM {
23 namespace detail {
24 /// Support for translating LLVM IR types to MLIR LLVM dialect types.
25 class TypeFromLLVMIRTranslatorImpl {
26 public:
27   /// Constructs a class creating types in the given MLIR context.
TypeFromLLVMIRTranslatorImpl(MLIRContext & context)28   TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {}
29 
30   /// Translates the given type.
translateType(llvm::Type * type)31   Type translateType(llvm::Type *type) {
32     if (knownTranslations.count(type))
33       return knownTranslations.lookup(type);
34 
35     Type translated =
36         llvm::TypeSwitch<llvm::Type *, Type>(type)
37             .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
38                   llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
39                   llvm::ScalableVectorType>(
40                 [this](auto *type) { return this->translate(type); })
41             .Default([this](llvm::Type *type) {
42               return translatePrimitiveType(type);
43             });
44     knownTranslations.try_emplace(type, translated);
45     return translated;
46   }
47 
48 private:
49   /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
50   /// type.
translatePrimitiveType(llvm::Type * type)51   Type translatePrimitiveType(llvm::Type *type) {
52     if (type->isVoidTy())
53       return LLVM::LLVMVoidType::get(&context);
54     if (type->isHalfTy())
55       return Float16Type::get(&context);
56     if (type->isBFloatTy())
57       return BFloat16Type::get(&context);
58     if (type->isFloatTy())
59       return Float32Type::get(&context);
60     if (type->isDoubleTy())
61       return Float64Type::get(&context);
62     if (type->isFP128Ty())
63       return Float128Type::get(&context);
64     if (type->isX86_FP80Ty())
65       return Float80Type::get(&context);
66     if (type->isPPC_FP128Ty())
67       return LLVM::LLVMPPCFP128Type::get(&context);
68     if (type->isX86_MMXTy())
69       return LLVM::LLVMX86MMXType::get(&context);
70     if (type->isLabelTy())
71       return LLVM::LLVMLabelType::get(&context);
72     if (type->isMetadataTy())
73       return LLVM::LLVMMetadataType::get(&context);
74     if (type->isTokenTy())
75       return LLVM::LLVMTokenType::get(&context);
76     llvm_unreachable("not a primitive type");
77   }
78 
79   /// Translates the given array type.
translate(llvm::ArrayType * type)80   Type translate(llvm::ArrayType *type) {
81     return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
82                                     type->getNumElements());
83   }
84 
85   /// Translates the given function type.
translate(llvm::FunctionType * type)86   Type translate(llvm::FunctionType *type) {
87     SmallVector<Type, 8> paramTypes;
88     translateTypes(type->params(), paramTypes);
89     return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
90                                        paramTypes, type->isVarArg());
91   }
92 
93   /// Translates the given integer type.
translate(llvm::IntegerType * type)94   Type translate(llvm::IntegerType *type) {
95     return IntegerType::get(&context, type->getBitWidth());
96   }
97 
98   /// Translates the given pointer type.
translate(llvm::PointerType * type)99   Type translate(llvm::PointerType *type) {
100     if (type->isOpaque())
101       return LLVM::LLVMPointerType::get(&context, type->getAddressSpace());
102 
103     return LLVM::LLVMPointerType::get(
104         translateType(type->getNonOpaquePointerElementType()),
105         type->getAddressSpace());
106   }
107 
108   /// Translates the given structure type.
translate(llvm::StructType * type)109   Type translate(llvm::StructType *type) {
110     SmallVector<Type, 8> subtypes;
111     if (type->isLiteral()) {
112       translateTypes(type->subtypes(), subtypes);
113       return LLVM::LLVMStructType::getLiteral(&context, subtypes,
114                                               type->isPacked());
115     }
116 
117     if (type->isOpaque())
118       return LLVM::LLVMStructType::getOpaque(type->getName(), &context);
119 
120     LLVM::LLVMStructType translated =
121         LLVM::LLVMStructType::getIdentified(&context, type->getName());
122     knownTranslations.try_emplace(type, translated);
123     translateTypes(type->subtypes(), subtypes);
124     LogicalResult bodySet = translated.setBody(subtypes, type->isPacked());
125     assert(succeeded(bodySet) &&
126            "could not set the body of an identified struct");
127     (void)bodySet;
128     return translated;
129   }
130 
131   /// Translates the given fixed-vector type.
translate(llvm::FixedVectorType * type)132   Type translate(llvm::FixedVectorType *type) {
133     return LLVM::getFixedVectorType(translateType(type->getElementType()),
134                                     type->getNumElements());
135   }
136 
137   /// Translates the given scalable-vector type.
translate(llvm::ScalableVectorType * type)138   Type translate(llvm::ScalableVectorType *type) {
139     return LLVM::LLVMScalableVectorType::get(
140         translateType(type->getElementType()), type->getMinNumElements());
141   }
142 
143   /// Translates a list of types.
translateTypes(ArrayRef<llvm::Type * > types,SmallVectorImpl<Type> & result)144   void translateTypes(ArrayRef<llvm::Type *> types,
145                       SmallVectorImpl<Type> &result) {
146     result.reserve(result.size() + types.size());
147     for (llvm::Type *type : types)
148       result.push_back(translateType(type));
149   }
150 
151   /// Map of known translations. Serves as a cache and as recursion stopper for
152   /// translating recursive structs.
153   llvm::DenseMap<llvm::Type *, Type> knownTranslations;
154 
155   /// The context in which MLIR types are created.
156   MLIRContext &context;
157 };
158 
159 } // namespace detail
160 } // namespace LLVM
161 } // namespace mlir
162 
/// Creates a translator backed by a newly allocated implementation object
/// that builds MLIR types in `context`.
LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context)
    : impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {}
165 
/// Defaulted in this file, below the definition of the implementation class.
/// NOTE(review): presumably required so that `impl` is destroyed where
/// `detail::TypeFromLLVMIRTranslatorImpl` is a complete type; confirm against
/// the declaration of `impl` in the header.
LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() = default;
167 
/// Translates the given LLVM IR type by forwarding to the implementation
/// object and returns its result.
Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
  return impl->translateType(type);
}
171