1 //
2 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
3 // See https://llvm.org/LICENSE.txt for license information.
4 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5 //
6 //===----------------------------------------------------------------------===//
7 //
8 // This file implements the types for the LLVM dialect in MLIR. These MLIR types
9 // correspond to the LLVM IR type system.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "TypeDetail.h"
14 
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/TypeSupport.h"
20 
21 #include "llvm/ADT/ScopeExit.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/TypeSize.h"
24 
25 using namespace mlir;
26 using namespace mlir::LLVM;
27 
/// Number of bits in a byte; used to convert between byte and bit quantities.
static constexpr unsigned kBitsInByte = 8;
29 
30 //===----------------------------------------------------------------------===//
31 // Array type.
32 //===----------------------------------------------------------------------===//
33 
isValidElementType(Type type)34 bool LLVMArrayType::isValidElementType(Type type) {
35   return !type.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
36                    LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>();
37 }
38 
get(Type elementType,unsigned numElements)39 LLVMArrayType LLVMArrayType::get(Type elementType, unsigned numElements) {
40   assert(elementType && "expected non-null subtype");
41   return Base::get(elementType.getContext(), elementType, numElements);
42 }
43 
44 LLVMArrayType
getChecked(function_ref<InFlightDiagnostic ()> emitError,Type elementType,unsigned numElements)45 LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError,
46                           Type elementType, unsigned numElements) {
47   assert(elementType && "expected non-null subtype");
48   return Base::getChecked(emitError, elementType.getContext(), elementType,
49                           numElements);
50 }
51 
getElementType() const52 Type LLVMArrayType::getElementType() const { return getImpl()->elementType; }
53 
getNumElements() const54 unsigned LLVMArrayType::getNumElements() const {
55   return getImpl()->numElements;
56 }
57 
58 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type elementType,unsigned numElements)59 LLVMArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
60                       Type elementType, unsigned numElements) {
61   if (!isValidElementType(elementType))
62     return emitError() << "invalid array element type: " << elementType;
63   return success();
64 }
65 
getTypeSizeInBits(const DataLayout & dataLayout,DataLayoutEntryListRef params) const66 unsigned LLVMArrayType::getTypeSizeInBits(const DataLayout &dataLayout,
67                                           DataLayoutEntryListRef params) const {
68   return kBitsInByte * getTypeSize(dataLayout, params);
69 }
70 
getTypeSize(const DataLayout & dataLayout,DataLayoutEntryListRef params) const71 unsigned LLVMArrayType::getTypeSize(const DataLayout &dataLayout,
72                                     DataLayoutEntryListRef params) const {
73   return llvm::alignTo(dataLayout.getTypeSize(getElementType()),
74                        dataLayout.getTypeABIAlignment(getElementType())) *
75          getNumElements();
76 }
77 
getABIAlignment(const DataLayout & dataLayout,DataLayoutEntryListRef params) const78 unsigned LLVMArrayType::getABIAlignment(const DataLayout &dataLayout,
79                                         DataLayoutEntryListRef params) const {
80   return dataLayout.getTypeABIAlignment(getElementType());
81 }
82 
83 unsigned
getPreferredAlignment(const DataLayout & dataLayout,DataLayoutEntryListRef params) const84 LLVMArrayType::getPreferredAlignment(const DataLayout &dataLayout,
85                                      DataLayoutEntryListRef params) const {
86   return dataLayout.getTypePreferredAlignment(getElementType());
87 }
88 
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const89 void LLVMArrayType::walkImmediateSubElements(
90     function_ref<void(Attribute)> walkAttrsFn,
91     function_ref<void(Type)> walkTypesFn) const {
92   walkTypesFn(getElementType());
93 }
94 
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const95 Type LLVMArrayType::replaceImmediateSubElements(
96     ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
97   return get(replTypes.front(), getNumElements());
98 }
99 
100 //===----------------------------------------------------------------------===//
101 // Function type.
102 //===----------------------------------------------------------------------===//
103 
isValidArgumentType(Type type)104 bool LLVMFunctionType::isValidArgumentType(Type type) {
105   return !type.isa<LLVMVoidType, LLVMFunctionType>();
106 }
107 
isValidResultType(Type type)108 bool LLVMFunctionType::isValidResultType(Type type) {
109   return !type.isa<LLVMFunctionType, LLVMMetadataType, LLVMLabelType>();
110 }
111 
get(Type result,ArrayRef<Type> arguments,bool isVarArg)112 LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef<Type> arguments,
113                                        bool isVarArg) {
114   assert(result && "expected non-null result");
115   return Base::get(result.getContext(), result, arguments, isVarArg);
116 }
117 
118 LLVMFunctionType
getChecked(function_ref<InFlightDiagnostic ()> emitError,Type result,ArrayRef<Type> arguments,bool isVarArg)119 LLVMFunctionType::getChecked(function_ref<InFlightDiagnostic()> emitError,
120                              Type result, ArrayRef<Type> arguments,
121                              bool isVarArg) {
122   assert(result && "expected non-null result");
123   return Base::getChecked(emitError, result.getContext(), result, arguments,
124                           isVarArg);
125 }
126 
clone(TypeRange inputs,TypeRange results) const127 LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs,
128                                          TypeRange results) const {
129   assert(results.size() == 1 && "expected a single result type");
130   return get(results[0], llvm::to_vector(inputs), isVarArg());
131 }
132 
getReturnType() const133 Type LLVMFunctionType::getReturnType() const {
134   return getImpl()->getReturnType();
135 }
getReturnTypes() const136 ArrayRef<Type> LLVMFunctionType::getReturnTypes() const {
137   return getImpl()->getReturnType();
138 }
139 
getNumParams()140 unsigned LLVMFunctionType::getNumParams() {
141   return getImpl()->getArgumentTypes().size();
142 }
143 
getParamType(unsigned i)144 Type LLVMFunctionType::getParamType(unsigned i) {
145   return getImpl()->getArgumentTypes()[i];
146 }
147 
isVarArg() const148 bool LLVMFunctionType::isVarArg() const { return getImpl()->isVariadic(); }
149 
getParams() const150 ArrayRef<Type> LLVMFunctionType::getParams() const {
151   return getImpl()->getArgumentTypes();
152 }
153 
154 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type result,ArrayRef<Type> arguments,bool)155 LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
156                          Type result, ArrayRef<Type> arguments, bool) {
157   if (!isValidResultType(result))
158     return emitError() << "invalid function result type: " << result;
159 
160   for (Type arg : arguments)
161     if (!isValidArgumentType(arg))
162       return emitError() << "invalid function argument type: " << arg;
163 
164   return success();
165 }
166 
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const167 void LLVMFunctionType::walkImmediateSubElements(
168     function_ref<void(Attribute)> walkAttrsFn,
169     function_ref<void(Type)> walkTypesFn) const {
170   for (Type type : llvm::concat<const Type>(getReturnTypes(), getParams()))
171     walkTypesFn(type);
172 }
173 
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const174 Type LLVMFunctionType::replaceImmediateSubElements(
175     ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
176   return get(replTypes.front(), replTypes.drop_front(), isVarArg());
177 }
178 
179 //===----------------------------------------------------------------------===//
180 // Pointer type.
181 //===----------------------------------------------------------------------===//
182 
isValidElementType(Type type)183 bool LLVMPointerType::isValidElementType(Type type) {
184   if (!type)
185     return true;
186   return isCompatibleOuterType(type)
187              ? !type.isa<LLVMVoidType, LLVMTokenType, LLVMMetadataType,
188                          LLVMLabelType>()
189              : type.isa<PointerElementTypeInterface>();
190 }
191 
get(Type pointee,unsigned addressSpace)192 LLVMPointerType LLVMPointerType::get(Type pointee, unsigned addressSpace) {
193   assert(pointee && "expected non-null subtype, pass the context instead if "
194                     "the opaque pointer type is desired");
195   return Base::get(pointee.getContext(), pointee, addressSpace);
196 }
197 
get(MLIRContext * context,unsigned addressSpace)198 LLVMPointerType LLVMPointerType::get(MLIRContext *context,
199                                      unsigned addressSpace) {
200   return Base::get(context, Type(), addressSpace);
201 }
202 
203 LLVMPointerType
getChecked(function_ref<InFlightDiagnostic ()> emitError,Type pointee,unsigned addressSpace)204 LLVMPointerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
205                             Type pointee, unsigned addressSpace) {
206   return Base::getChecked(emitError, pointee.getContext(), pointee,
207                           addressSpace);
208 }
209 
210 LLVMPointerType
getChecked(function_ref<InFlightDiagnostic ()> emitError,MLIRContext * context,unsigned addressSpace)211 LLVMPointerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
212                             MLIRContext *context, unsigned addressSpace) {
213   return Base::getChecked(emitError, context, Type(), addressSpace);
214 }
215 
getElementType() const216 Type LLVMPointerType::getElementType() const { return getImpl()->pointeeType; }
217 
isOpaque() const218 bool LLVMPointerType::isOpaque() const { return !getImpl()->pointeeType; }
219 
getAddressSpace() const220 unsigned LLVMPointerType::getAddressSpace() const {
221   return getImpl()->addressSpace;
222 }
223 
224 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type pointee,unsigned)225 LLVMPointerType::verify(function_ref<InFlightDiagnostic()> emitError,
226                         Type pointee, unsigned) {
227   if (!isValidElementType(pointee))
228     return emitError() << "invalid pointer element type: " << pointee;
229   return success();
230 }
231 
namespace {
/// Indices of the individual values stored in a pointer data layout entry.
enum class DLEntryPos {
  Size = 0,
  Abi = 1,
  Preferred = 2,
  Address = 3
};
} // namespace
236 
/// Pointer size, in bits, assumed for the default address space when the data
/// layout has no entry for it.
static constexpr unsigned kDefaultPointerSizeBits = 64;
/// Pointer alignment assumed for the default address space when the data
/// layout has no entry for it.
static constexpr unsigned kDefaultPointerAlignment = 8;
239 
240 /// Returns the value that corresponds to named position `pos` from the
241 /// attribute `attr` assuming it's a dense integer elements attribute.
extractPointerSpecValue(Attribute attr,DLEntryPos pos)242 static unsigned extractPointerSpecValue(Attribute attr, DLEntryPos pos) {
243   return attr.cast<DenseIntElementsAttr>()
244       .getValues<unsigned>()[static_cast<unsigned>(pos)];
245 }
246 
247 /// Returns the part of the data layout entry that corresponds to `pos` for the
248 /// given `type` by interpreting the list of entries `params`. For the pointer
249 /// type in the default address space, returns the default value if the entries
250 /// do not provide a custom one, for other address spaces returns None.
251 static Optional<unsigned>
getPointerDataLayoutEntry(DataLayoutEntryListRef params,LLVMPointerType type,DLEntryPos pos)252 getPointerDataLayoutEntry(DataLayoutEntryListRef params, LLVMPointerType type,
253                           DLEntryPos pos) {
254   // First, look for the entry for the pointer in the current address space.
255   Attribute currentEntry;
256   for (DataLayoutEntryInterface entry : params) {
257     if (!entry.isTypeEntry())
258       continue;
259     if (entry.getKey().get<Type>().cast<LLVMPointerType>().getAddressSpace() ==
260         type.getAddressSpace()) {
261       currentEntry = entry.getValue();
262       break;
263     }
264   }
265   if (currentEntry) {
266     return extractPointerSpecValue(currentEntry, pos) /
267            (pos == DLEntryPos::Size ? 1 : kBitsInByte);
268   }
269 
270   // If not found, and this is the pointer to the default memory space, assume
271   // 64-bit pointers.
272   if (type.getAddressSpace() == 0) {
273     return pos == DLEntryPos::Size ? kDefaultPointerSizeBits
274                                    : kDefaultPointerAlignment;
275   }
276 
277   return llvm::None;
278 }
279 
280 unsigned
getTypeSizeInBits(const DataLayout & dataLayout,DataLayoutEntryListRef params) const281 LLVMPointerType::getTypeSizeInBits(const DataLayout &dataLayout,
282                                    DataLayoutEntryListRef params) const {
283   if (Optional<unsigned> size =
284           getPointerDataLayoutEntry(params, *this, DLEntryPos::Size))
285     return *size;
286 
287   // For other memory spaces, use the size of the pointer to the default memory
288   // space.
289   if (isOpaque())
290     return dataLayout.getTypeSizeInBits(get(getContext()));
291   return dataLayout.getTypeSizeInBits(get(getElementType()));
292 }
293 
getABIAlignment(const DataLayout & dataLayout,DataLayoutEntryListRef params) const294 unsigned LLVMPointerType::getABIAlignment(const DataLayout &dataLayout,
295                                           DataLayoutEntryListRef params) const {
296   if (Optional<unsigned> alignment =
297           getPointerDataLayoutEntry(params, *this, DLEntryPos::Abi))
298     return *alignment;
299 
300   if (isOpaque())
301     return dataLayout.getTypeABIAlignment(get(getContext()));
302   return dataLayout.getTypeABIAlignment(get(getElementType()));
303 }
304 
305 unsigned
getPreferredAlignment(const DataLayout & dataLayout,DataLayoutEntryListRef params) const306 LLVMPointerType::getPreferredAlignment(const DataLayout &dataLayout,
307                                        DataLayoutEntryListRef params) const {
308   if (Optional<unsigned> alignment =
309           getPointerDataLayoutEntry(params, *this, DLEntryPos::Preferred))
310     return *alignment;
311 
312   if (isOpaque())
313     return dataLayout.getTypePreferredAlignment(get(getContext()));
314   return dataLayout.getTypePreferredAlignment(get(getElementType()));
315 }
316 
areCompatible(DataLayoutEntryListRef oldLayout,DataLayoutEntryListRef newLayout) const317 bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
318                                     DataLayoutEntryListRef newLayout) const {
319   for (DataLayoutEntryInterface newEntry : newLayout) {
320     if (!newEntry.isTypeEntry())
321       continue;
322     unsigned size = kDefaultPointerSizeBits;
323     unsigned abi = kDefaultPointerAlignment;
324     auto newType = newEntry.getKey().get<Type>().cast<LLVMPointerType>();
325     const auto *it =
326         llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
327           if (auto type = entry.getKey().dyn_cast<Type>()) {
328             return type.cast<LLVMPointerType>().getAddressSpace() ==
329                    newType.getAddressSpace();
330           }
331           return false;
332         });
333     if (it == oldLayout.end()) {
334       llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
335         if (auto type = entry.getKey().dyn_cast<Type>()) {
336           return type.cast<LLVMPointerType>().getAddressSpace() == 0;
337         }
338         return false;
339       });
340     }
341     if (it != oldLayout.end()) {
342       size = extractPointerSpecValue(*it, DLEntryPos::Size);
343       abi = extractPointerSpecValue(*it, DLEntryPos::Abi);
344     }
345 
346     Attribute newSpec = newEntry.getValue().cast<DenseIntElementsAttr>();
347     unsigned newSize = extractPointerSpecValue(newSpec, DLEntryPos::Size);
348     unsigned newAbi = extractPointerSpecValue(newSpec, DLEntryPos::Abi);
349     if (size != newSize || abi < newAbi || abi % newAbi != 0)
350       return false;
351   }
352   return true;
353 }
354 
verifyEntries(DataLayoutEntryListRef entries,Location loc) const355 LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
356                                              Location loc) const {
357   for (DataLayoutEntryInterface entry : entries) {
358     if (!entry.isTypeEntry())
359       continue;
360     auto key = entry.getKey().get<Type>().cast<LLVMPointerType>();
361     auto values = entry.getValue().dyn_cast<DenseIntElementsAttr>();
362     if (!values || (values.size() != 3 && values.size() != 4)) {
363       return emitError(loc)
364              << "expected layout attribute for " << entry.getKey().get<Type>()
365              << " to be a dense integer elements attribute with 3 or 4 "
366                 "elements";
367     }
368     if (key.getElementType() && !key.getElementType().isInteger(8)) {
369       return emitError(loc) << "unexpected layout attribute for pointer to "
370                             << key.getElementType();
371     }
372     if (extractPointerSpecValue(values, DLEntryPos::Abi) >
373         extractPointerSpecValue(values, DLEntryPos::Preferred)) {
374       return emitError(loc) << "preferred alignment is expected to be at least "
375                                "as large as ABI alignment";
376     }
377   }
378   return success();
379 }
380 
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const381 void LLVMPointerType::walkImmediateSubElements(
382     function_ref<void(Attribute)> walkAttrsFn,
383     function_ref<void(Type)> walkTypesFn) const {
384   walkTypesFn(getElementType());
385 }
386 
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const387 Type LLVMPointerType::replaceImmediateSubElements(
388     ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
389   return get(replTypes.front(), getAddressSpace());
390 }
391 
392 //===----------------------------------------------------------------------===//
393 // Struct type.
394 //===----------------------------------------------------------------------===//
395 
isValidElementType(Type type)396 bool LLVMStructType::isValidElementType(Type type) {
397   return !type.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
398                    LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>();
399 }
400 
getIdentified(MLIRContext * context,StringRef name)401 LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
402                                              StringRef name) {
403   return Base::get(context, name, /*opaque=*/false);
404 }
405 
getIdentifiedChecked(function_ref<InFlightDiagnostic ()> emitError,MLIRContext * context,StringRef name)406 LLVMStructType LLVMStructType::getIdentifiedChecked(
407     function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
408     StringRef name) {
409   return Base::getChecked(emitError, context, name, /*opaque=*/false);
410 }
411 
getNewIdentified(MLIRContext * context,StringRef name,ArrayRef<Type> elements,bool isPacked)412 LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context,
413                                                 StringRef name,
414                                                 ArrayRef<Type> elements,
415                                                 bool isPacked) {
416   std::string stringName = name.str();
417   unsigned counter = 0;
418   do {
419     auto type = LLVMStructType::getIdentified(context, stringName);
420     if (type.isInitialized() || failed(type.setBody(elements, isPacked))) {
421       counter += 1;
422       stringName = (Twine(name) + "." + std::to_string(counter)).str();
423       continue;
424     }
425     return type;
426   } while (true);
427 }
428 
getLiteral(MLIRContext * context,ArrayRef<Type> types,bool isPacked)429 LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
430                                           ArrayRef<Type> types, bool isPacked) {
431   return Base::get(context, types, isPacked);
432 }
433 
434 LLVMStructType
getLiteralChecked(function_ref<InFlightDiagnostic ()> emitError,MLIRContext * context,ArrayRef<Type> types,bool isPacked)435 LLVMStructType::getLiteralChecked(function_ref<InFlightDiagnostic()> emitError,
436                                   MLIRContext *context, ArrayRef<Type> types,
437                                   bool isPacked) {
438   return Base::getChecked(emitError, context, types, isPacked);
439 }
440 
getOpaque(StringRef name,MLIRContext * context)441 LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
442   return Base::get(context, name, /*opaque=*/true);
443 }
444 
445 LLVMStructType
getOpaqueChecked(function_ref<InFlightDiagnostic ()> emitError,MLIRContext * context,StringRef name)446 LLVMStructType::getOpaqueChecked(function_ref<InFlightDiagnostic()> emitError,
447                                  MLIRContext *context, StringRef name) {
448   return Base::getChecked(emitError, context, name, /*opaque=*/true);
449 }
450 
setBody(ArrayRef<Type> types,bool isPacked)451 LogicalResult LLVMStructType::setBody(ArrayRef<Type> types, bool isPacked) {
452   assert(isIdentified() && "can only set bodies of identified structs");
453   assert(llvm::all_of(types, LLVMStructType::isValidElementType) &&
454          "expected valid body types");
455   return Base::mutate(types, isPacked);
456 }
457 
isPacked() const458 bool LLVMStructType::isPacked() const { return getImpl()->isPacked(); }
isIdentified() const459 bool LLVMStructType::isIdentified() const { return getImpl()->isIdentified(); }
isOpaque()460 bool LLVMStructType::isOpaque() {
461   return getImpl()->isIdentified() &&
462          (getImpl()->isOpaque() || !getImpl()->isInitialized());
463 }
isInitialized()464 bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); }
getName()465 StringRef LLVMStructType::getName() { return getImpl()->getIdentifier(); }
getBody() const466 ArrayRef<Type> LLVMStructType::getBody() const {
467   return isIdentified() ? getImpl()->getIdentifiedStructBody()
468                         : getImpl()->getTypeList();
469 }
470 
verify(function_ref<InFlightDiagnostic ()>,StringRef,bool)471 LogicalResult LLVMStructType::verify(function_ref<InFlightDiagnostic()>,
472                                      StringRef, bool) {
473   return success();
474 }
475 
476 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<Type> types,bool)477 LLVMStructType::verify(function_ref<InFlightDiagnostic()> emitError,
478                        ArrayRef<Type> types, bool) {
479   for (Type t : types)
480     if (!isValidElementType(t))
481       return emitError() << "invalid LLVM structure element type: " << t;
482 
483   return success();
484 }
485 
486 unsigned
getTypeSizeInBits(const DataLayout & dataLayout,DataLayoutEntryListRef params) const487 LLVMStructType::getTypeSizeInBits(const DataLayout &dataLayout,
488                                   DataLayoutEntryListRef params) const {
489   unsigned structSize = 0;
490   unsigned structAlignment = 1;
491   for (Type element : getBody()) {
492     unsigned elementAlignment =
493         isPacked() ? 1 : dataLayout.getTypeABIAlignment(element);
494     // Add padding to the struct size to align it to the abi alignment of the
495     // element type before than adding the size of the element
496     structSize = llvm::alignTo(structSize, elementAlignment);
497     structSize += dataLayout.getTypeSize(element);
498 
499     // The alignment requirement of a struct is equal to the strictest alignment
500     // requirement of its elements.
501     structAlignment = std::max(elementAlignment, structAlignment);
502   }
503   // At the end, add padding to the struct to satisfy its own alignment
504   // requirement. Otherwise structs inside of arrays would be misaligned.
505   structSize = llvm::alignTo(structSize, structAlignment);
506   return structSize * kBitsInByte;
507 }
508 
namespace {
/// Indices of the individual values stored in a struct data layout entry.
enum class StructDLEntryPos {
  Abi = 0,
  Preferred = 1
};
} // namespace
512 
513 static Optional<unsigned>
getStructDataLayoutEntry(DataLayoutEntryListRef params,LLVMStructType type,StructDLEntryPos pos)514 getStructDataLayoutEntry(DataLayoutEntryListRef params, LLVMStructType type,
515                          StructDLEntryPos pos) {
516   const auto *currentEntry =
517       llvm::find_if(params, [](DataLayoutEntryInterface entry) {
518         return entry.isTypeEntry();
519       });
520   if (currentEntry == params.end())
521     return llvm::None;
522 
523   auto attr = currentEntry->getValue().cast<DenseIntElementsAttr>();
524   if (pos == StructDLEntryPos::Preferred &&
525       attr.size() <= static_cast<unsigned>(StructDLEntryPos::Preferred))
526     // If no preferred was specified, fall back to abi alignment
527     pos = StructDLEntryPos::Abi;
528 
529   return attr.getValues<unsigned>()[static_cast<unsigned>(pos)];
530 }
531 
calculateStructAlignment(const DataLayout & dataLayout,DataLayoutEntryListRef params,LLVMStructType type,StructDLEntryPos pos)532 static unsigned calculateStructAlignment(const DataLayout &dataLayout,
533                                          DataLayoutEntryListRef params,
534                                          LLVMStructType type,
535                                          StructDLEntryPos pos) {
536   // Packed structs always have an abi alignment of 1
537   if (pos == StructDLEntryPos::Abi && type.isPacked()) {
538     return 1;
539   }
540 
541   // The alignment requirement of a struct is equal to the strictest alignment
542   // requirement of its elements.
543   unsigned structAlignment = 1;
544   for (Type iter : type.getBody()) {
545     structAlignment =
546         std::max(dataLayout.getTypeABIAlignment(iter), structAlignment);
547   }
548 
549   // Entries are only allowed to be stricter than the required alignment
550   if (Optional<unsigned> entryResult =
551           getStructDataLayoutEntry(params, type, pos))
552     return std::max(*entryResult / kBitsInByte, structAlignment);
553 
554   return structAlignment;
555 }
556 
getABIAlignment(const DataLayout & dataLayout,DataLayoutEntryListRef params) const557 unsigned LLVMStructType::getABIAlignment(const DataLayout &dataLayout,
558                                          DataLayoutEntryListRef params) const {
559   return calculateStructAlignment(dataLayout, params, *this,
560                                   StructDLEntryPos::Abi);
561 }
562 
563 unsigned
getPreferredAlignment(const DataLayout & dataLayout,DataLayoutEntryListRef params) const564 LLVMStructType::getPreferredAlignment(const DataLayout &dataLayout,
565                                       DataLayoutEntryListRef params) const {
566   return calculateStructAlignment(dataLayout, params, *this,
567                                   StructDLEntryPos::Preferred);
568 }
569 
extractStructSpecValue(Attribute attr,StructDLEntryPos pos)570 static unsigned extractStructSpecValue(Attribute attr, StructDLEntryPos pos) {
571   return attr.cast<DenseIntElementsAttr>()
572       .getValues<unsigned>()[static_cast<unsigned>(pos)];
573 }
574 
areCompatible(DataLayoutEntryListRef oldLayout,DataLayoutEntryListRef newLayout) const575 bool LLVMStructType::areCompatible(DataLayoutEntryListRef oldLayout,
576                                    DataLayoutEntryListRef newLayout) const {
577   for (DataLayoutEntryInterface newEntry : newLayout) {
578     if (!newEntry.isTypeEntry())
579       continue;
580 
581     const auto *previousEntry =
582         llvm::find_if(oldLayout, [](DataLayoutEntryInterface entry) {
583           return entry.isTypeEntry();
584         });
585     if (previousEntry == oldLayout.end())
586       continue;
587 
588     unsigned abi = extractStructSpecValue(previousEntry->getValue(),
589                                           StructDLEntryPos::Abi);
590     unsigned newAbi =
591         extractStructSpecValue(newEntry.getValue(), StructDLEntryPos::Abi);
592     if (abi < newAbi || abi % newAbi != 0)
593       return false;
594   }
595   return true;
596 }
597 
verifyEntries(DataLayoutEntryListRef entries,Location loc) const598 LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries,
599                                             Location loc) const {
600   for (DataLayoutEntryInterface entry : entries) {
601     if (!entry.isTypeEntry())
602       continue;
603 
604     auto key = entry.getKey().get<Type>().cast<LLVMStructType>();
605     auto values = entry.getValue().dyn_cast<DenseIntElementsAttr>();
606     if (!values || (values.size() != 2 && values.size() != 1)) {
607       return emitError(loc)
608              << "expected layout attribute for " << entry.getKey().get<Type>()
609              << " to be a dense integer elements attribute of 1 or 2 elements";
610     }
611 
612     if (key.isIdentified() || !key.getBody().empty()) {
613       return emitError(loc) << "unexpected layout attribute for struct " << key;
614     }
615 
616     if (values.size() == 1)
617       continue;
618 
619     if (extractStructSpecValue(values, StructDLEntryPos::Abi) >
620         extractStructSpecValue(values, StructDLEntryPos::Preferred)) {
621       return emitError(loc) << "preferred alignment is expected to be at least "
622                                "as large as ABI alignment";
623     }
624   }
625   return mlir::success();
626 }
627 
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const628 void LLVMStructType::walkImmediateSubElements(
629     function_ref<void(Attribute)> walkAttrsFn,
630     function_ref<void(Type)> walkTypesFn) const {
631   for (Type type : getBody())
632     walkTypesFn(type);
633 }
634 
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const635 Type LLVMStructType::replaceImmediateSubElements(
636     ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
637   // TODO: It's not clear how we support replacing sub-elements of mutable
638   // types.
639   return nullptr;
640 }
641 
642 //===----------------------------------------------------------------------===//
643 // Vector types.
644 //===----------------------------------------------------------------------===//
645 
646 /// Verifies that the type about to be constructed is well-formed.
647 template <typename VecTy>
648 static LogicalResult
verifyVectorConstructionInvariants(function_ref<InFlightDiagnostic ()> emitError,Type elementType,unsigned numElements)649 verifyVectorConstructionInvariants(function_ref<InFlightDiagnostic()> emitError,
650                                    Type elementType, unsigned numElements) {
651   if (numElements == 0)
652     return emitError() << "the number of vector elements must be positive";
653 
654   if (!VecTy::isValidElementType(elementType))
655     return emitError() << "invalid vector element type";
656 
657   return success();
658 }
659 
get(Type elementType,unsigned numElements)660 LLVMFixedVectorType LLVMFixedVectorType::get(Type elementType,
661                                              unsigned numElements) {
662   assert(elementType && "expected non-null subtype");
663   return Base::get(elementType.getContext(), elementType, numElements);
664 }
665 
666 LLVMFixedVectorType
getChecked(function_ref<InFlightDiagnostic ()> emitError,Type elementType,unsigned numElements)667 LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
668                                 Type elementType, unsigned numElements) {
669   assert(elementType && "expected non-null subtype");
670   return Base::getChecked(emitError, elementType.getContext(), elementType,
671                           numElements);
672 }
673 
getElementType() const674 Type LLVMFixedVectorType::getElementType() const {
675   return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
676 }
677 
getNumElements() const678 unsigned LLVMFixedVectorType::getNumElements() const {
679   return getImpl()->numElements;
680 }
681 
isValidElementType(Type type)682 bool LLVMFixedVectorType::isValidElementType(Type type) {
683   return type.isa<LLVMPointerType, LLVMPPCFP128Type>();
684 }
685 
686 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type elementType,unsigned numElements)687 LLVMFixedVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
688                             Type elementType, unsigned numElements) {
689   return verifyVectorConstructionInvariants<LLVMFixedVectorType>(
690       emitError, elementType, numElements);
691 }
692 
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const693 void LLVMFixedVectorType::walkImmediateSubElements(
694     function_ref<void(Attribute)> walkAttrsFn,
695     function_ref<void(Type)> walkTypesFn) const {
696   walkTypesFn(getElementType());
697 }
698 
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const699 Type LLVMFixedVectorType::replaceImmediateSubElements(
700     ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
701   return get(replTypes[0], getNumElements());
702 }
703 
704 //===----------------------------------------------------------------------===//
705 // LLVMScalableVectorType.
706 //===----------------------------------------------------------------------===//
707 
get(Type elementType,unsigned minNumElements)708 LLVMScalableVectorType LLVMScalableVectorType::get(Type elementType,
709                                                    unsigned minNumElements) {
710   assert(elementType && "expected non-null subtype");
711   return Base::get(elementType.getContext(), elementType, minNumElements);
712 }
713 
714 LLVMScalableVectorType
getChecked(function_ref<InFlightDiagnostic ()> emitError,Type elementType,unsigned minNumElements)715 LLVMScalableVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
716                                    Type elementType, unsigned minNumElements) {
717   assert(elementType && "expected non-null subtype");
718   return Base::getChecked(emitError, elementType.getContext(), elementType,
719                           minNumElements);
720 }
721 
getElementType() const722 Type LLVMScalableVectorType::getElementType() const {
723   return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
724 }
725 
getMinNumElements() const726 unsigned LLVMScalableVectorType::getMinNumElements() const {
727   return getImpl()->numElements;
728 }
729 
isValidElementType(Type type)730 bool LLVMScalableVectorType::isValidElementType(Type type) {
731   if (auto intType = type.dyn_cast<IntegerType>())
732     return intType.isSignless();
733 
734   return isCompatibleFloatingPointType(type) || type.isa<LLVMPointerType>();
735 }
736 
737 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type elementType,unsigned numElements)738 LLVMScalableVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
739                                Type elementType, unsigned numElements) {
740   return verifyVectorConstructionInvariants<LLVMScalableVectorType>(
741       emitError, elementType, numElements);
742 }
743 
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const744 void LLVMScalableVectorType::walkImmediateSubElements(
745     function_ref<void(Attribute)> walkAttrsFn,
746     function_ref<void(Type)> walkTypesFn) const {
747   walkTypesFn(getElementType());
748 }
749 
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const750 Type LLVMScalableVectorType::replaceImmediateSubElements(
751     ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
752   return get(replTypes[0], getMinNumElements());
753 }
754 
755 //===----------------------------------------------------------------------===//
756 // Utility functions.
757 //===----------------------------------------------------------------------===//
758 
isCompatibleOuterType(Type type)759 bool mlir::LLVM::isCompatibleOuterType(Type type) {
760   // clang-format off
761   if (type.isa<
762       BFloat16Type,
763       Float16Type,
764       Float32Type,
765       Float64Type,
766       Float80Type,
767       Float128Type,
768       LLVMArrayType,
769       LLVMFunctionType,
770       LLVMLabelType,
771       LLVMMetadataType,
772       LLVMPPCFP128Type,
773       LLVMPointerType,
774       LLVMStructType,
775       LLVMTokenType,
776       LLVMFixedVectorType,
777       LLVMScalableVectorType,
778       LLVMVoidType,
779       LLVMX86MMXType
780     >()) {
781     // clang-format on
782     return true;
783   }
784 
785   // Only signless integers are compatible.
786   if (auto intType = type.dyn_cast<IntegerType>())
787     return intType.isSignless();
788 
789   // 1D vector types are compatible.
790   if (auto vecType = type.dyn_cast<VectorType>())
791     return vecType.getRank() == 1;
792 
793   return false;
794 }
795 
isCompatibleImpl(Type type,DenseSet<Type> & compatibleTypes)796 static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
797   if (!compatibleTypes.insert(type).second)
798     return true;
799 
800   auto isCompatible = [&](Type type) {
801     return isCompatibleImpl(type, compatibleTypes);
802   };
803 
804   bool result =
805       llvm::TypeSwitch<Type, bool>(type)
806           .Case<LLVMStructType>([&](auto structType) {
807             return llvm::all_of(structType.getBody(), isCompatible);
808           })
809           .Case<LLVMFunctionType>([&](auto funcType) {
810             return isCompatible(funcType.getReturnType()) &&
811                    llvm::all_of(funcType.getParams(), isCompatible);
812           })
813           .Case<IntegerType>([](auto intType) { return intType.isSignless(); })
814           .Case<VectorType>([&](auto vecType) {
815             return vecType.getRank() == 1 &&
816                    isCompatible(vecType.getElementType());
817           })
818           .Case<LLVMPointerType>([&](auto pointerType) {
819             if (pointerType.isOpaque())
820               return true;
821             return isCompatible(pointerType.getElementType());
822           })
823           // clang-format off
824           .Case<
825               LLVMFixedVectorType,
826               LLVMScalableVectorType,
827               LLVMArrayType
828           >([&](auto containerType) {
829             return isCompatible(containerType.getElementType());
830           })
831           .Case<
832             BFloat16Type,
833             Float16Type,
834             Float32Type,
835             Float64Type,
836             Float80Type,
837             Float128Type,
838             LLVMLabelType,
839             LLVMMetadataType,
840             LLVMPPCFP128Type,
841             LLVMTokenType,
842             LLVMVoidType,
843             LLVMX86MMXType
844           >([](Type) { return true; })
845           // clang-format on
846           .Default([](Type) { return false; });
847 
848   if (!result)
849     compatibleTypes.erase(type);
850 
851   return result;
852 }
853 
isCompatibleType(Type type)854 bool LLVMDialect::isCompatibleType(Type type) {
855   if (auto *llvmDialect =
856           type.getContext()->getLoadedDialect<LLVM::LLVMDialect>())
857     return isCompatibleImpl(type, llvmDialect->compatibleTypes.get());
858 
859   DenseSet<Type> localCompatibleTypes;
860   return isCompatibleImpl(type, localCompatibleTypes);
861 }
862 
isCompatibleType(Type type)863 bool mlir::LLVM::isCompatibleType(Type type) {
864   return LLVMDialect::isCompatibleType(type);
865 }
866 
isCompatibleFloatingPointType(Type type)867 bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {
868   return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
869                   Float80Type, Float128Type, LLVMPPCFP128Type>();
870 }
871 
isCompatibleVectorType(Type type)872 bool mlir::LLVM::isCompatibleVectorType(Type type) {
873   if (type.isa<LLVMFixedVectorType, LLVMScalableVectorType>())
874     return true;
875 
876   if (auto vecType = type.dyn_cast<VectorType>()) {
877     if (vecType.getRank() != 1)
878       return false;
879     Type elementType = vecType.getElementType();
880     if (auto intType = elementType.dyn_cast<IntegerType>())
881       return intType.isSignless();
882     return elementType.isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
883                            Float80Type, Float128Type>();
884   }
885   return false;
886 }
887 
getVectorElementType(Type type)888 Type mlir::LLVM::getVectorElementType(Type type) {
889   return llvm::TypeSwitch<Type, Type>(type)
890       .Case<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>(
891           [](auto ty) { return ty.getElementType(); })
892       .Default([](Type) -> Type {
893         llvm_unreachable("incompatible with LLVM vector type");
894       });
895 }
896 
getVectorNumElements(Type type)897 llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
898   return llvm::TypeSwitch<Type, llvm::ElementCount>(type)
899       .Case([](VectorType ty) {
900         if (ty.isScalable())
901           return llvm::ElementCount::getScalable(ty.getNumElements());
902         return llvm::ElementCount::getFixed(ty.getNumElements());
903       })
904       .Case([](LLVMFixedVectorType ty) {
905         return llvm::ElementCount::getFixed(ty.getNumElements());
906       })
907       .Case([](LLVMScalableVectorType ty) {
908         return llvm::ElementCount::getScalable(ty.getMinNumElements());
909       })
910       .Default([](Type) -> llvm::ElementCount {
911         llvm_unreachable("incompatible with LLVM vector type");
912       });
913 }
914 
isScalableVectorType(Type vectorType)915 bool mlir::LLVM::isScalableVectorType(Type vectorType) {
916   assert(
917       (vectorType
918            .isa<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>()) &&
919       "expected LLVM-compatible vector type");
920   return !vectorType.isa<LLVMFixedVectorType>() &&
921          (vectorType.isa<LLVMScalableVectorType>() ||
922           vectorType.cast<VectorType>().isScalable());
923 }
924 
getVectorType(Type elementType,unsigned numElements,bool isScalable)925 Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
926                                bool isScalable) {
927   bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType);
928   bool useBuiltIn = VectorType::isValidElementType(elementType);
929   (void)useBuiltIn;
930   assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type "
931                                    "to be either builtin or LLVM dialect type");
932   if (useLLVM) {
933     if (isScalable)
934       return LLVMScalableVectorType::get(elementType, numElements);
935     return LLVMFixedVectorType::get(elementType, numElements);
936   }
937   return VectorType::get(numElements, elementType, (unsigned)isScalable);
938 }
939 
getFixedVectorType(Type elementType,unsigned numElements)940 Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
941   bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType);
942   bool useBuiltIn = VectorType::isValidElementType(elementType);
943   (void)useBuiltIn;
944   assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type "
945                                    "to be either builtin or LLVM dialect type");
946   if (useLLVM)
947     return LLVMFixedVectorType::get(elementType, numElements);
948   return VectorType::get(numElements, elementType);
949 }
950 
getScalableVectorType(Type elementType,unsigned numElements)951 Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
952   bool useLLVM = LLVMScalableVectorType::isValidElementType(elementType);
953   bool useBuiltIn = VectorType::isValidElementType(elementType);
954   (void)useBuiltIn;
955   assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible scalable-vector "
956                                    "type to be either builtin or LLVM dialect "
957                                    "type");
958   if (useLLVM)
959     return LLVMScalableVectorType::get(elementType, numElements);
960   return VectorType::get(numElements, elementType, /*numScalableDims=*/1);
961 }
962 
getPrimitiveTypeSizeInBits(Type type)963 llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
964   assert(isCompatibleType(type) &&
965          "expected a type compatible with the LLVM dialect");
966 
967   return llvm::TypeSwitch<Type, llvm::TypeSize>(type)
968       .Case<BFloat16Type, Float16Type>(
969           [](Type) { return llvm::TypeSize::Fixed(16); })
970       .Case<Float32Type>([](Type) { return llvm::TypeSize::Fixed(32); })
971       .Case<Float64Type, LLVMX86MMXType>(
972           [](Type) { return llvm::TypeSize::Fixed(64); })
973       .Case<Float80Type>([](Type) { return llvm::TypeSize::Fixed(80); })
974       .Case<Float128Type>([](Type) { return llvm::TypeSize::Fixed(128); })
975       .Case<IntegerType>([](IntegerType intTy) {
976         return llvm::TypeSize::Fixed(intTy.getWidth());
977       })
978       .Case<LLVMPPCFP128Type>([](Type) { return llvm::TypeSize::Fixed(128); })
979       .Case<LLVMFixedVectorType>([](LLVMFixedVectorType t) {
980         llvm::TypeSize elementSize =
981             getPrimitiveTypeSizeInBits(t.getElementType());
982         return llvm::TypeSize(elementSize.getFixedSize() * t.getNumElements(),
983                               elementSize.isScalable());
984       })
985       .Case<VectorType>([](VectorType t) {
986         assert(isCompatibleVectorType(t) &&
987                "unexpected incompatible with LLVM vector type");
988         llvm::TypeSize elementSize =
989             getPrimitiveTypeSizeInBits(t.getElementType());
990         return llvm::TypeSize(elementSize.getFixedSize() * t.getNumElements(),
991                               elementSize.isScalable());
992       })
993       .Default([](Type ty) {
994         assert((ty.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
995                        LLVMTokenType, LLVMStructType, LLVMArrayType,
996                        LLVMPointerType, LLVMFunctionType>()) &&
997                "unexpected missing support for primitive type");
998         return llvm::TypeSize::Fixed(0);
999       });
1000 }
1001 
1002 #include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"
1003