1 //===- TypeFromLLVM.cpp - type translation from LLVM to MLIR IR -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Target/LLVMIR/TypeFromLLVM.h"
10 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/MLIRContext.h"
13
14 #include "llvm/ADT/TypeSwitch.h"
15 #include "llvm/IR/DataLayout.h"
16 #include "llvm/IR/DerivedTypes.h"
17 #include "llvm/IR/Type.h"
18
19 using namespace mlir;
20
21 namespace mlir {
22 namespace LLVM {
23 namespace detail {
24 /// Support for translating LLVM IR types to MLIR LLVM dialect types.
25 class TypeFromLLVMIRTranslatorImpl {
26 public:
27 /// Constructs a class creating types in the given MLIR context.
TypeFromLLVMIRTranslatorImpl(MLIRContext & context)28 TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {}
29
30 /// Translates the given type.
translateType(llvm::Type * type)31 Type translateType(llvm::Type *type) {
32 if (knownTranslations.count(type))
33 return knownTranslations.lookup(type);
34
35 Type translated =
36 llvm::TypeSwitch<llvm::Type *, Type>(type)
37 .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
38 llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
39 llvm::ScalableVectorType>(
40 [this](auto *type) { return this->translate(type); })
41 .Default([this](llvm::Type *type) {
42 return translatePrimitiveType(type);
43 });
44 knownTranslations.try_emplace(type, translated);
45 return translated;
46 }
47
48 private:
49 /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
50 /// type.
translatePrimitiveType(llvm::Type * type)51 Type translatePrimitiveType(llvm::Type *type) {
52 if (type->isVoidTy())
53 return LLVM::LLVMVoidType::get(&context);
54 if (type->isHalfTy())
55 return Float16Type::get(&context);
56 if (type->isBFloatTy())
57 return BFloat16Type::get(&context);
58 if (type->isFloatTy())
59 return Float32Type::get(&context);
60 if (type->isDoubleTy())
61 return Float64Type::get(&context);
62 if (type->isFP128Ty())
63 return Float128Type::get(&context);
64 if (type->isX86_FP80Ty())
65 return Float80Type::get(&context);
66 if (type->isPPC_FP128Ty())
67 return LLVM::LLVMPPCFP128Type::get(&context);
68 if (type->isX86_MMXTy())
69 return LLVM::LLVMX86MMXType::get(&context);
70 if (type->isLabelTy())
71 return LLVM::LLVMLabelType::get(&context);
72 if (type->isMetadataTy())
73 return LLVM::LLVMMetadataType::get(&context);
74 if (type->isTokenTy())
75 return LLVM::LLVMTokenType::get(&context);
76 llvm_unreachable("not a primitive type");
77 }
78
79 /// Translates the given array type.
translate(llvm::ArrayType * type)80 Type translate(llvm::ArrayType *type) {
81 return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
82 type->getNumElements());
83 }
84
85 /// Translates the given function type.
translate(llvm::FunctionType * type)86 Type translate(llvm::FunctionType *type) {
87 SmallVector<Type, 8> paramTypes;
88 translateTypes(type->params(), paramTypes);
89 return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
90 paramTypes, type->isVarArg());
91 }
92
93 /// Translates the given integer type.
translate(llvm::IntegerType * type)94 Type translate(llvm::IntegerType *type) {
95 return IntegerType::get(&context, type->getBitWidth());
96 }
97
98 /// Translates the given pointer type.
translate(llvm::PointerType * type)99 Type translate(llvm::PointerType *type) {
100 if (type->isOpaque())
101 return LLVM::LLVMPointerType::get(&context, type->getAddressSpace());
102
103 return LLVM::LLVMPointerType::get(
104 translateType(type->getNonOpaquePointerElementType()),
105 type->getAddressSpace());
106 }
107
108 /// Translates the given structure type.
translate(llvm::StructType * type)109 Type translate(llvm::StructType *type) {
110 SmallVector<Type, 8> subtypes;
111 if (type->isLiteral()) {
112 translateTypes(type->subtypes(), subtypes);
113 return LLVM::LLVMStructType::getLiteral(&context, subtypes,
114 type->isPacked());
115 }
116
117 if (type->isOpaque())
118 return LLVM::LLVMStructType::getOpaque(type->getName(), &context);
119
120 LLVM::LLVMStructType translated =
121 LLVM::LLVMStructType::getIdentified(&context, type->getName());
122 knownTranslations.try_emplace(type, translated);
123 translateTypes(type->subtypes(), subtypes);
124 LogicalResult bodySet = translated.setBody(subtypes, type->isPacked());
125 assert(succeeded(bodySet) &&
126 "could not set the body of an identified struct");
127 (void)bodySet;
128 return translated;
129 }
130
131 /// Translates the given fixed-vector type.
translate(llvm::FixedVectorType * type)132 Type translate(llvm::FixedVectorType *type) {
133 return LLVM::getFixedVectorType(translateType(type->getElementType()),
134 type->getNumElements());
135 }
136
137 /// Translates the given scalable-vector type.
translate(llvm::ScalableVectorType * type)138 Type translate(llvm::ScalableVectorType *type) {
139 return LLVM::LLVMScalableVectorType::get(
140 translateType(type->getElementType()), type->getMinNumElements());
141 }
142
143 /// Translates a list of types.
translateTypes(ArrayRef<llvm::Type * > types,SmallVectorImpl<Type> & result)144 void translateTypes(ArrayRef<llvm::Type *> types,
145 SmallVectorImpl<Type> &result) {
146 result.reserve(result.size() + types.size());
147 for (llvm::Type *type : types)
148 result.push_back(translateType(type));
149 }
150
151 /// Map of known translations. Serves as a cache and as recursion stopper for
152 /// translating recursive structs.
153 llvm::DenseMap<llvm::Type *, Type> knownTranslations;
154
155 /// The context in which MLIR types are created.
156 MLIRContext &context;
157 };
158
159 } // namespace detail
160 } // namespace LLVM
161 } // namespace mlir
162
TypeFromLLVMIRTranslator(MLIRContext & context)163 LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context)
164 : impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {}
165
166 LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() = default;
167
translateType(llvm::Type * type)168 Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
169 return impl->translateType(type);
170 }
171