1 //===- TypeToLLVM.cpp - type translation from MLIR to LLVM IR -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
10 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/MLIRContext.h"
13 
14 #include "llvm/ADT/TypeSwitch.h"
15 #include "llvm/IR/DataLayout.h"
16 #include "llvm/IR/DerivedTypes.h"
17 #include "llvm/IR/Type.h"
18 
19 using namespace mlir;
20 
21 namespace mlir {
22 namespace LLVM {
23 namespace detail {
24 /// Support for translating MLIR LLVM dialect types to LLVM IR.
25 class TypeToLLVMIRTranslatorImpl {
26 public:
27   /// Constructs a class creating types in the given LLVM context.
TypeToLLVMIRTranslatorImpl(llvm::LLVMContext & context)28   TypeToLLVMIRTranslatorImpl(llvm::LLVMContext &context) : context(context) {}
29 
30   /// Translates a single type.
translateType(Type type)31   llvm::Type *translateType(Type type) {
32     // If the conversion is already known, just return it.
33     if (knownTranslations.count(type))
34       return knownTranslations.lookup(type);
35 
36     // Dispatch to an appropriate function.
37     llvm::Type *translated =
38         llvm::TypeSwitch<Type, llvm::Type *>(type)
39             .Case([this](LLVM::LLVMVoidType) {
40               return llvm::Type::getVoidTy(context);
41             })
42             .Case(
43                 [this](Float16Type) { return llvm::Type::getHalfTy(context); })
44             .Case([this](BFloat16Type) {
45               return llvm::Type::getBFloatTy(context);
46             })
47             .Case(
48                 [this](Float32Type) { return llvm::Type::getFloatTy(context); })
49             .Case([this](Float64Type) {
50               return llvm::Type::getDoubleTy(context);
51             })
52             .Case([this](Float80Type) {
53               return llvm::Type::getX86_FP80Ty(context);
54             })
55             .Case([this](Float128Type) {
56               return llvm::Type::getFP128Ty(context);
57             })
58             .Case([this](LLVM::LLVMPPCFP128Type) {
59               return llvm::Type::getPPC_FP128Ty(context);
60             })
61             .Case([this](LLVM::LLVMX86MMXType) {
62               return llvm::Type::getX86_MMXTy(context);
63             })
64             .Case([this](LLVM::LLVMTokenType) {
65               return llvm::Type::getTokenTy(context);
66             })
67             .Case([this](LLVM::LLVMLabelType) {
68               return llvm::Type::getLabelTy(context);
69             })
70             .Case([this](LLVM::LLVMMetadataType) {
71               return llvm::Type::getMetadataTy(context);
72             })
73             .Case<LLVM::LLVMArrayType, IntegerType, LLVM::LLVMFunctionType,
74                   LLVM::LLVMPointerType, LLVM::LLVMStructType,
75                   LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType,
76                   VectorType>(
77                 [this](auto type) { return this->translate(type); })
78             .Default([](Type t) -> llvm::Type * {
79               llvm_unreachable("unknown LLVM dialect type");
80             });
81 
82     // Cache the result of the conversion and return.
83     knownTranslations.try_emplace(type, translated);
84     return translated;
85   }
86 
87 private:
88   /// Translates the given array type.
translate(LLVM::LLVMArrayType type)89   llvm::Type *translate(LLVM::LLVMArrayType type) {
90     return llvm::ArrayType::get(translateType(type.getElementType()),
91                                 type.getNumElements());
92   }
93 
94   /// Translates the given function type.
translate(LLVM::LLVMFunctionType type)95   llvm::Type *translate(LLVM::LLVMFunctionType type) {
96     SmallVector<llvm::Type *, 8> paramTypes;
97     translateTypes(type.getParams(), paramTypes);
98     return llvm::FunctionType::get(translateType(type.getReturnType()),
99                                    paramTypes, type.isVarArg());
100   }
101 
102   /// Translates the given integer type.
translate(IntegerType type)103   llvm::Type *translate(IntegerType type) {
104     return llvm::IntegerType::get(context, type.getWidth());
105   }
106 
107   /// Translates the given pointer type.
translate(LLVM::LLVMPointerType type)108   llvm::Type *translate(LLVM::LLVMPointerType type) {
109     if (type.isOpaque())
110       return llvm::PointerType::get(context, type.getAddressSpace());
111     return llvm::PointerType::get(translateType(type.getElementType()),
112                                   type.getAddressSpace());
113   }
114 
115   /// Translates the given structure type, supports both identified and literal
116   /// structs. This will _create_ a new identified structure every time, use
117   /// `convertType` if a structure with the same name must be looked up instead.
translate(LLVM::LLVMStructType type)118   llvm::Type *translate(LLVM::LLVMStructType type) {
119     SmallVector<llvm::Type *, 8> subtypes;
120     if (!type.isIdentified()) {
121       translateTypes(type.getBody(), subtypes);
122       return llvm::StructType::get(context, subtypes, type.isPacked());
123     }
124 
125     llvm::StructType *structType =
126         llvm::StructType::create(context, type.getName());
127     // Mark the type we just created as known so that recursive calls can pick
128     // it up and use directly.
129     knownTranslations.try_emplace(type, structType);
130     if (type.isOpaque())
131       return structType;
132 
133     translateTypes(type.getBody(), subtypes);
134     structType->setBody(subtypes, type.isPacked());
135     return structType;
136   }
137 
138   /// Translates the given built-in vector type compatible with LLVM.
translate(VectorType type)139   llvm::Type *translate(VectorType type) {
140     assert(LLVM::isCompatibleVectorType(type) &&
141            "expected compatible with LLVM vector type");
142     if (type.isScalable())
143       return llvm::ScalableVectorType::get(translateType(type.getElementType()),
144                                            type.getNumElements());
145     return llvm::FixedVectorType::get(translateType(type.getElementType()),
146                                       type.getNumElements());
147   }
148 
149   /// Translates the given fixed-vector type.
translate(LLVM::LLVMFixedVectorType type)150   llvm::Type *translate(LLVM::LLVMFixedVectorType type) {
151     return llvm::FixedVectorType::get(translateType(type.getElementType()),
152                                       type.getNumElements());
153   }
154 
155   /// Translates the given scalable-vector type.
translate(LLVM::LLVMScalableVectorType type)156   llvm::Type *translate(LLVM::LLVMScalableVectorType type) {
157     return llvm::ScalableVectorType::get(translateType(type.getElementType()),
158                                          type.getMinNumElements());
159   }
160 
161   /// Translates a list of types.
translateTypes(ArrayRef<Type> types,SmallVectorImpl<llvm::Type * > & result)162   void translateTypes(ArrayRef<Type> types,
163                       SmallVectorImpl<llvm::Type *> &result) {
164     result.reserve(result.size() + types.size());
165     for (auto type : types)
166       result.push_back(translateType(type));
167   }
168 
169   /// Reference to the context in which the LLVM IR types are created.
170   llvm::LLVMContext &context;
171 
172   /// Map of known translation. This serves a double purpose: caches translation
173   /// results to avoid repeated recursive calls and makes sure identified
174   /// structs with the same name (that is, equal) are resolved to an existing
175   /// type instead of creating a new type.
176   llvm::DenseMap<Type, llvm::Type *> knownTranslations;
177 };
178 } // namespace detail
179 } // namespace LLVM
180 } // namespace mlir
181 
TypeToLLVMIRTranslator(llvm::LLVMContext & context)182 LLVM::TypeToLLVMIRTranslator::TypeToLLVMIRTranslator(llvm::LLVMContext &context)
183     : impl(new detail::TypeToLLVMIRTranslatorImpl(context)) {}
184 
185 LLVM::TypeToLLVMIRTranslator::~TypeToLLVMIRTranslator() = default;
186 
translateType(Type type)187 llvm::Type *LLVM::TypeToLLVMIRTranslator::translateType(Type type) {
188   return impl->translateType(type);
189 }
190 
getPreferredAlignment(Type type,const llvm::DataLayout & layout)191 unsigned LLVM::TypeToLLVMIRTranslator::getPreferredAlignment(
192     Type type, const llvm::DataLayout &layout) {
193   return layout.getPrefTypeAlignment(translateType(type));
194 }
195