1 //
2 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
3 // See https://llvm.org/LICENSE.txt for license information.
4 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5 //
6 //===----------------------------------------------------------------------===//
7 //
8 // This file implements the types for the LLVM dialect in MLIR. These MLIR types
9 // correspond to the LLVM IR type system.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "TypeDetail.h"
14
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/TypeSupport.h"
20
21 #include "llvm/ADT/ScopeExit.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/TypeSize.h"
24
25 using namespace mlir;
26 using namespace mlir::LLVM;
27
/// Number of bits in a byte. Used to convert between the byte-based sizes and
/// alignments reported by the data layout and bit-based quantities.
static constexpr unsigned kBitsInByte = 8;
29
30 //===----------------------------------------------------------------------===//
31 // Array type.
32 //===----------------------------------------------------------------------===//
33
isValidElementType(Type type)34 bool LLVMArrayType::isValidElementType(Type type) {
35 return !type.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
36 LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>();
37 }
38
get(Type elementType,unsigned numElements)39 LLVMArrayType LLVMArrayType::get(Type elementType, unsigned numElements) {
40 assert(elementType && "expected non-null subtype");
41 return Base::get(elementType.getContext(), elementType, numElements);
42 }
43
44 LLVMArrayType
getChecked(function_ref<InFlightDiagnostic ()> emitError,Type elementType,unsigned numElements)45 LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError,
46 Type elementType, unsigned numElements) {
47 assert(elementType && "expected non-null subtype");
48 return Base::getChecked(emitError, elementType.getContext(), elementType,
49 numElements);
50 }
51
getElementType() const52 Type LLVMArrayType::getElementType() const { return getImpl()->elementType; }
53
getNumElements() const54 unsigned LLVMArrayType::getNumElements() const {
55 return getImpl()->numElements;
56 }
57
58 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type elementType,unsigned numElements)59 LLVMArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
60 Type elementType, unsigned numElements) {
61 if (!isValidElementType(elementType))
62 return emitError() << "invalid array element type: " << elementType;
63 return success();
64 }
65
getTypeSizeInBits(const DataLayout & dataLayout,DataLayoutEntryListRef params) const66 unsigned LLVMArrayType::getTypeSizeInBits(const DataLayout &dataLayout,
67 DataLayoutEntryListRef params) const {
68 return kBitsInByte * getTypeSize(dataLayout, params);
69 }
70
getTypeSize(const DataLayout & dataLayout,DataLayoutEntryListRef params) const71 unsigned LLVMArrayType::getTypeSize(const DataLayout &dataLayout,
72 DataLayoutEntryListRef params) const {
73 return llvm::alignTo(dataLayout.getTypeSize(getElementType()),
74 dataLayout.getTypeABIAlignment(getElementType())) *
75 getNumElements();
76 }
77
getABIAlignment(const DataLayout & dataLayout,DataLayoutEntryListRef params) const78 unsigned LLVMArrayType::getABIAlignment(const DataLayout &dataLayout,
79 DataLayoutEntryListRef params) const {
80 return dataLayout.getTypeABIAlignment(getElementType());
81 }
82
83 unsigned
getPreferredAlignment(const DataLayout & dataLayout,DataLayoutEntryListRef params) const84 LLVMArrayType::getPreferredAlignment(const DataLayout &dataLayout,
85 DataLayoutEntryListRef params) const {
86 return dataLayout.getTypePreferredAlignment(getElementType());
87 }
88
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const89 void LLVMArrayType::walkImmediateSubElements(
90 function_ref<void(Attribute)> walkAttrsFn,
91 function_ref<void(Type)> walkTypesFn) const {
92 walkTypesFn(getElementType());
93 }
94
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const95 Type LLVMArrayType::replaceImmediateSubElements(
96 ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
97 return get(replTypes.front(), getNumElements());
98 }
99
100 //===----------------------------------------------------------------------===//
101 // Function type.
102 //===----------------------------------------------------------------------===//
103
isValidArgumentType(Type type)104 bool LLVMFunctionType::isValidArgumentType(Type type) {
105 return !type.isa<LLVMVoidType, LLVMFunctionType>();
106 }
107
isValidResultType(Type type)108 bool LLVMFunctionType::isValidResultType(Type type) {
109 return !type.isa<LLVMFunctionType, LLVMMetadataType, LLVMLabelType>();
110 }
111
get(Type result,ArrayRef<Type> arguments,bool isVarArg)112 LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef<Type> arguments,
113 bool isVarArg) {
114 assert(result && "expected non-null result");
115 return Base::get(result.getContext(), result, arguments, isVarArg);
116 }
117
118 LLVMFunctionType
getChecked(function_ref<InFlightDiagnostic ()> emitError,Type result,ArrayRef<Type> arguments,bool isVarArg)119 LLVMFunctionType::getChecked(function_ref<InFlightDiagnostic()> emitError,
120 Type result, ArrayRef<Type> arguments,
121 bool isVarArg) {
122 assert(result && "expected non-null result");
123 return Base::getChecked(emitError, result.getContext(), result, arguments,
124 isVarArg);
125 }
126
clone(TypeRange inputs,TypeRange results) const127 LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs,
128 TypeRange results) const {
129 assert(results.size() == 1 && "expected a single result type");
130 return get(results[0], llvm::to_vector(inputs), isVarArg());
131 }
132
getReturnType() const133 Type LLVMFunctionType::getReturnType() const {
134 return getImpl()->getReturnType();
135 }
getReturnTypes() const136 ArrayRef<Type> LLVMFunctionType::getReturnTypes() const {
137 return getImpl()->getReturnType();
138 }
139
getNumParams()140 unsigned LLVMFunctionType::getNumParams() {
141 return getImpl()->getArgumentTypes().size();
142 }
143
getParamType(unsigned i)144 Type LLVMFunctionType::getParamType(unsigned i) {
145 return getImpl()->getArgumentTypes()[i];
146 }
147
isVarArg() const148 bool LLVMFunctionType::isVarArg() const { return getImpl()->isVariadic(); }
149
getParams() const150 ArrayRef<Type> LLVMFunctionType::getParams() const {
151 return getImpl()->getArgumentTypes();
152 }
153
154 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type result,ArrayRef<Type> arguments,bool)155 LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
156 Type result, ArrayRef<Type> arguments, bool) {
157 if (!isValidResultType(result))
158 return emitError() << "invalid function result type: " << result;
159
160 for (Type arg : arguments)
161 if (!isValidArgumentType(arg))
162 return emitError() << "invalid function argument type: " << arg;
163
164 return success();
165 }
166
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const167 void LLVMFunctionType::walkImmediateSubElements(
168 function_ref<void(Attribute)> walkAttrsFn,
169 function_ref<void(Type)> walkTypesFn) const {
170 for (Type type : llvm::concat<const Type>(getReturnTypes(), getParams()))
171 walkTypesFn(type);
172 }
173
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const174 Type LLVMFunctionType::replaceImmediateSubElements(
175 ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
176 return get(replTypes.front(), replTypes.drop_front(), isVarArg());
177 }
178
179 //===----------------------------------------------------------------------===//
180 // Pointer type.
181 //===----------------------------------------------------------------------===//
182
isValidElementType(Type type)183 bool LLVMPointerType::isValidElementType(Type type) {
184 if (!type)
185 return true;
186 return isCompatibleOuterType(type)
187 ? !type.isa<LLVMVoidType, LLVMTokenType, LLVMMetadataType,
188 LLVMLabelType>()
189 : type.isa<PointerElementTypeInterface>();
190 }
191
get(Type pointee,unsigned addressSpace)192 LLVMPointerType LLVMPointerType::get(Type pointee, unsigned addressSpace) {
193 assert(pointee && "expected non-null subtype, pass the context instead if "
194 "the opaque pointer type is desired");
195 return Base::get(pointee.getContext(), pointee, addressSpace);
196 }
197
get(MLIRContext * context,unsigned addressSpace)198 LLVMPointerType LLVMPointerType::get(MLIRContext *context,
199 unsigned addressSpace) {
200 return Base::get(context, Type(), addressSpace);
201 }
202
203 LLVMPointerType
getChecked(function_ref<InFlightDiagnostic ()> emitError,Type pointee,unsigned addressSpace)204 LLVMPointerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
205 Type pointee, unsigned addressSpace) {
206 return Base::getChecked(emitError, pointee.getContext(), pointee,
207 addressSpace);
208 }
209
210 LLVMPointerType
getChecked(function_ref<InFlightDiagnostic ()> emitError,MLIRContext * context,unsigned addressSpace)211 LLVMPointerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
212 MLIRContext *context, unsigned addressSpace) {
213 return Base::getChecked(emitError, context, Type(), addressSpace);
214 }
215
getElementType() const216 Type LLVMPointerType::getElementType() const { return getImpl()->pointeeType; }
217
isOpaque() const218 bool LLVMPointerType::isOpaque() const { return !getImpl()->pointeeType; }
219
getAddressSpace() const220 unsigned LLVMPointerType::getAddressSpace() const {
221 return getImpl()->addressSpace;
222 }
223
224 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type pointee,unsigned)225 LLVMPointerType::verify(function_ref<InFlightDiagnostic()> emitError,
226 Type pointee, unsigned) {
227 if (!isValidElementType(pointee))
228 return emitError() << "invalid pointer element type: " << pointee;
229 return success();
230 }
231
namespace {
/// Positions of the individual values inside a pointer data layout entry. The
/// entry is a dense integer elements attribute of the form
/// [size, abi, preferred(, address)].
enum class DLEntryPos {
  Size = 0,
  Abi = 1,
  Preferred = 2,
  Address = 3,
};
} // namespace

/// Pointer size, in bits, assumed for address space 0 when the data layout
/// has no entry for it.
static constexpr unsigned kDefaultPointerSizeBits = 64;
/// Pointer alignment, in bytes, assumed for address space 0 when the data
/// layout has no entry for it.
static constexpr unsigned kDefaultPointerAlignment = 8;
239
240 /// Returns the value that corresponds to named position `pos` from the
241 /// attribute `attr` assuming it's a dense integer elements attribute.
extractPointerSpecValue(Attribute attr,DLEntryPos pos)242 static unsigned extractPointerSpecValue(Attribute attr, DLEntryPos pos) {
243 return attr.cast<DenseIntElementsAttr>()
244 .getValues<unsigned>()[static_cast<unsigned>(pos)];
245 }
246
247 /// Returns the part of the data layout entry that corresponds to `pos` for the
248 /// given `type` by interpreting the list of entries `params`. For the pointer
249 /// type in the default address space, returns the default value if the entries
250 /// do not provide a custom one, for other address spaces returns None.
251 static Optional<unsigned>
getPointerDataLayoutEntry(DataLayoutEntryListRef params,LLVMPointerType type,DLEntryPos pos)252 getPointerDataLayoutEntry(DataLayoutEntryListRef params, LLVMPointerType type,
253 DLEntryPos pos) {
254 // First, look for the entry for the pointer in the current address space.
255 Attribute currentEntry;
256 for (DataLayoutEntryInterface entry : params) {
257 if (!entry.isTypeEntry())
258 continue;
259 if (entry.getKey().get<Type>().cast<LLVMPointerType>().getAddressSpace() ==
260 type.getAddressSpace()) {
261 currentEntry = entry.getValue();
262 break;
263 }
264 }
265 if (currentEntry) {
266 return extractPointerSpecValue(currentEntry, pos) /
267 (pos == DLEntryPos::Size ? 1 : kBitsInByte);
268 }
269
270 // If not found, and this is the pointer to the default memory space, assume
271 // 64-bit pointers.
272 if (type.getAddressSpace() == 0) {
273 return pos == DLEntryPos::Size ? kDefaultPointerSizeBits
274 : kDefaultPointerAlignment;
275 }
276
277 return llvm::None;
278 }
279
280 unsigned
getTypeSizeInBits(const DataLayout & dataLayout,DataLayoutEntryListRef params) const281 LLVMPointerType::getTypeSizeInBits(const DataLayout &dataLayout,
282 DataLayoutEntryListRef params) const {
283 if (Optional<unsigned> size =
284 getPointerDataLayoutEntry(params, *this, DLEntryPos::Size))
285 return *size;
286
287 // For other memory spaces, use the size of the pointer to the default memory
288 // space.
289 if (isOpaque())
290 return dataLayout.getTypeSizeInBits(get(getContext()));
291 return dataLayout.getTypeSizeInBits(get(getElementType()));
292 }
293
getABIAlignment(const DataLayout & dataLayout,DataLayoutEntryListRef params) const294 unsigned LLVMPointerType::getABIAlignment(const DataLayout &dataLayout,
295 DataLayoutEntryListRef params) const {
296 if (Optional<unsigned> alignment =
297 getPointerDataLayoutEntry(params, *this, DLEntryPos::Abi))
298 return *alignment;
299
300 if (isOpaque())
301 return dataLayout.getTypeABIAlignment(get(getContext()));
302 return dataLayout.getTypeABIAlignment(get(getElementType()));
303 }
304
305 unsigned
getPreferredAlignment(const DataLayout & dataLayout,DataLayoutEntryListRef params) const306 LLVMPointerType::getPreferredAlignment(const DataLayout &dataLayout,
307 DataLayoutEntryListRef params) const {
308 if (Optional<unsigned> alignment =
309 getPointerDataLayoutEntry(params, *this, DLEntryPos::Preferred))
310 return *alignment;
311
312 if (isOpaque())
313 return dataLayout.getTypePreferredAlignment(get(getContext()));
314 return dataLayout.getTypePreferredAlignment(get(getElementType()));
315 }
316
areCompatible(DataLayoutEntryListRef oldLayout,DataLayoutEntryListRef newLayout) const317 bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
318 DataLayoutEntryListRef newLayout) const {
319 for (DataLayoutEntryInterface newEntry : newLayout) {
320 if (!newEntry.isTypeEntry())
321 continue;
322 unsigned size = kDefaultPointerSizeBits;
323 unsigned abi = kDefaultPointerAlignment;
324 auto newType = newEntry.getKey().get<Type>().cast<LLVMPointerType>();
325 const auto *it =
326 llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
327 if (auto type = entry.getKey().dyn_cast<Type>()) {
328 return type.cast<LLVMPointerType>().getAddressSpace() ==
329 newType.getAddressSpace();
330 }
331 return false;
332 });
333 if (it == oldLayout.end()) {
334 llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
335 if (auto type = entry.getKey().dyn_cast<Type>()) {
336 return type.cast<LLVMPointerType>().getAddressSpace() == 0;
337 }
338 return false;
339 });
340 }
341 if (it != oldLayout.end()) {
342 size = extractPointerSpecValue(*it, DLEntryPos::Size);
343 abi = extractPointerSpecValue(*it, DLEntryPos::Abi);
344 }
345
346 Attribute newSpec = newEntry.getValue().cast<DenseIntElementsAttr>();
347 unsigned newSize = extractPointerSpecValue(newSpec, DLEntryPos::Size);
348 unsigned newAbi = extractPointerSpecValue(newSpec, DLEntryPos::Abi);
349 if (size != newSize || abi < newAbi || abi % newAbi != 0)
350 return false;
351 }
352 return true;
353 }
354
verifyEntries(DataLayoutEntryListRef entries,Location loc) const355 LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
356 Location loc) const {
357 for (DataLayoutEntryInterface entry : entries) {
358 if (!entry.isTypeEntry())
359 continue;
360 auto key = entry.getKey().get<Type>().cast<LLVMPointerType>();
361 auto values = entry.getValue().dyn_cast<DenseIntElementsAttr>();
362 if (!values || (values.size() != 3 && values.size() != 4)) {
363 return emitError(loc)
364 << "expected layout attribute for " << entry.getKey().get<Type>()
365 << " to be a dense integer elements attribute with 3 or 4 "
366 "elements";
367 }
368 if (key.getElementType() && !key.getElementType().isInteger(8)) {
369 return emitError(loc) << "unexpected layout attribute for pointer to "
370 << key.getElementType();
371 }
372 if (extractPointerSpecValue(values, DLEntryPos::Abi) >
373 extractPointerSpecValue(values, DLEntryPos::Preferred)) {
374 return emitError(loc) << "preferred alignment is expected to be at least "
375 "as large as ABI alignment";
376 }
377 }
378 return success();
379 }
380
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const381 void LLVMPointerType::walkImmediateSubElements(
382 function_ref<void(Attribute)> walkAttrsFn,
383 function_ref<void(Type)> walkTypesFn) const {
384 walkTypesFn(getElementType());
385 }
386
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const387 Type LLVMPointerType::replaceImmediateSubElements(
388 ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
389 return get(replTypes.front(), getAddressSpace());
390 }
391
392 //===----------------------------------------------------------------------===//
393 // Struct type.
394 //===----------------------------------------------------------------------===//
395
isValidElementType(Type type)396 bool LLVMStructType::isValidElementType(Type type) {
397 return !type.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
398 LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>();
399 }
400
getIdentified(MLIRContext * context,StringRef name)401 LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
402 StringRef name) {
403 return Base::get(context, name, /*opaque=*/false);
404 }
405
getIdentifiedChecked(function_ref<InFlightDiagnostic ()> emitError,MLIRContext * context,StringRef name)406 LLVMStructType LLVMStructType::getIdentifiedChecked(
407 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
408 StringRef name) {
409 return Base::getChecked(emitError, context, name, /*opaque=*/false);
410 }
411
getNewIdentified(MLIRContext * context,StringRef name,ArrayRef<Type> elements,bool isPacked)412 LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context,
413 StringRef name,
414 ArrayRef<Type> elements,
415 bool isPacked) {
416 std::string stringName = name.str();
417 unsigned counter = 0;
418 do {
419 auto type = LLVMStructType::getIdentified(context, stringName);
420 if (type.isInitialized() || failed(type.setBody(elements, isPacked))) {
421 counter += 1;
422 stringName = (Twine(name) + "." + std::to_string(counter)).str();
423 continue;
424 }
425 return type;
426 } while (true);
427 }
428
getLiteral(MLIRContext * context,ArrayRef<Type> types,bool isPacked)429 LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
430 ArrayRef<Type> types, bool isPacked) {
431 return Base::get(context, types, isPacked);
432 }
433
434 LLVMStructType
getLiteralChecked(function_ref<InFlightDiagnostic ()> emitError,MLIRContext * context,ArrayRef<Type> types,bool isPacked)435 LLVMStructType::getLiteralChecked(function_ref<InFlightDiagnostic()> emitError,
436 MLIRContext *context, ArrayRef<Type> types,
437 bool isPacked) {
438 return Base::getChecked(emitError, context, types, isPacked);
439 }
440
getOpaque(StringRef name,MLIRContext * context)441 LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
442 return Base::get(context, name, /*opaque=*/true);
443 }
444
445 LLVMStructType
getOpaqueChecked(function_ref<InFlightDiagnostic ()> emitError,MLIRContext * context,StringRef name)446 LLVMStructType::getOpaqueChecked(function_ref<InFlightDiagnostic()> emitError,
447 MLIRContext *context, StringRef name) {
448 return Base::getChecked(emitError, context, name, /*opaque=*/true);
449 }
450
setBody(ArrayRef<Type> types,bool isPacked)451 LogicalResult LLVMStructType::setBody(ArrayRef<Type> types, bool isPacked) {
452 assert(isIdentified() && "can only set bodies of identified structs");
453 assert(llvm::all_of(types, LLVMStructType::isValidElementType) &&
454 "expected valid body types");
455 return Base::mutate(types, isPacked);
456 }
457
isPacked() const458 bool LLVMStructType::isPacked() const { return getImpl()->isPacked(); }
isIdentified() const459 bool LLVMStructType::isIdentified() const { return getImpl()->isIdentified(); }
isOpaque()460 bool LLVMStructType::isOpaque() {
461 return getImpl()->isIdentified() &&
462 (getImpl()->isOpaque() || !getImpl()->isInitialized());
463 }
isInitialized()464 bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); }
getName()465 StringRef LLVMStructType::getName() { return getImpl()->getIdentifier(); }
getBody() const466 ArrayRef<Type> LLVMStructType::getBody() const {
467 return isIdentified() ? getImpl()->getIdentifiedStructBody()
468 : getImpl()->getTypeList();
469 }
470
verify(function_ref<InFlightDiagnostic ()>,StringRef,bool)471 LogicalResult LLVMStructType::verify(function_ref<InFlightDiagnostic()>,
472 StringRef, bool) {
473 return success();
474 }
475
476 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<Type> types,bool)477 LLVMStructType::verify(function_ref<InFlightDiagnostic()> emitError,
478 ArrayRef<Type> types, bool) {
479 for (Type t : types)
480 if (!isValidElementType(t))
481 return emitError() << "invalid LLVM structure element type: " << t;
482
483 return success();
484 }
485
486 unsigned
getTypeSizeInBits(const DataLayout & dataLayout,DataLayoutEntryListRef params) const487 LLVMStructType::getTypeSizeInBits(const DataLayout &dataLayout,
488 DataLayoutEntryListRef params) const {
489 unsigned structSize = 0;
490 unsigned structAlignment = 1;
491 for (Type element : getBody()) {
492 unsigned elementAlignment =
493 isPacked() ? 1 : dataLayout.getTypeABIAlignment(element);
494 // Add padding to the struct size to align it to the abi alignment of the
495 // element type before than adding the size of the element
496 structSize = llvm::alignTo(structSize, elementAlignment);
497 structSize += dataLayout.getTypeSize(element);
498
499 // The alignment requirement of a struct is equal to the strictest alignment
500 // requirement of its elements.
501 structAlignment = std::max(elementAlignment, structAlignment);
502 }
503 // At the end, add padding to the struct to satisfy its own alignment
504 // requirement. Otherwise structs inside of arrays would be misaligned.
505 structSize = llvm::alignTo(structSize, structAlignment);
506 return structSize * kBitsInByte;
507 }
508
namespace {
/// Positions of the values inside a struct data layout entry: the ABI
/// alignment, optionally followed by the preferred alignment.
enum class StructDLEntryPos {
  Abi = 0,
  Preferred = 1,
};
} // namespace
512
513 static Optional<unsigned>
getStructDataLayoutEntry(DataLayoutEntryListRef params,LLVMStructType type,StructDLEntryPos pos)514 getStructDataLayoutEntry(DataLayoutEntryListRef params, LLVMStructType type,
515 StructDLEntryPos pos) {
516 const auto *currentEntry =
517 llvm::find_if(params, [](DataLayoutEntryInterface entry) {
518 return entry.isTypeEntry();
519 });
520 if (currentEntry == params.end())
521 return llvm::None;
522
523 auto attr = currentEntry->getValue().cast<DenseIntElementsAttr>();
524 if (pos == StructDLEntryPos::Preferred &&
525 attr.size() <= static_cast<unsigned>(StructDLEntryPos::Preferred))
526 // If no preferred was specified, fall back to abi alignment
527 pos = StructDLEntryPos::Abi;
528
529 return attr.getValues<unsigned>()[static_cast<unsigned>(pos)];
530 }
531
calculateStructAlignment(const DataLayout & dataLayout,DataLayoutEntryListRef params,LLVMStructType type,StructDLEntryPos pos)532 static unsigned calculateStructAlignment(const DataLayout &dataLayout,
533 DataLayoutEntryListRef params,
534 LLVMStructType type,
535 StructDLEntryPos pos) {
536 // Packed structs always have an abi alignment of 1
537 if (pos == StructDLEntryPos::Abi && type.isPacked()) {
538 return 1;
539 }
540
541 // The alignment requirement of a struct is equal to the strictest alignment
542 // requirement of its elements.
543 unsigned structAlignment = 1;
544 for (Type iter : type.getBody()) {
545 structAlignment =
546 std::max(dataLayout.getTypeABIAlignment(iter), structAlignment);
547 }
548
549 // Entries are only allowed to be stricter than the required alignment
550 if (Optional<unsigned> entryResult =
551 getStructDataLayoutEntry(params, type, pos))
552 return std::max(*entryResult / kBitsInByte, structAlignment);
553
554 return structAlignment;
555 }
556
getABIAlignment(const DataLayout & dataLayout,DataLayoutEntryListRef params) const557 unsigned LLVMStructType::getABIAlignment(const DataLayout &dataLayout,
558 DataLayoutEntryListRef params) const {
559 return calculateStructAlignment(dataLayout, params, *this,
560 StructDLEntryPos::Abi);
561 }
562
563 unsigned
getPreferredAlignment(const DataLayout & dataLayout,DataLayoutEntryListRef params) const564 LLVMStructType::getPreferredAlignment(const DataLayout &dataLayout,
565 DataLayoutEntryListRef params) const {
566 return calculateStructAlignment(dataLayout, params, *this,
567 StructDLEntryPos::Preferred);
568 }
569
extractStructSpecValue(Attribute attr,StructDLEntryPos pos)570 static unsigned extractStructSpecValue(Attribute attr, StructDLEntryPos pos) {
571 return attr.cast<DenseIntElementsAttr>()
572 .getValues<unsigned>()[static_cast<unsigned>(pos)];
573 }
574
areCompatible(DataLayoutEntryListRef oldLayout,DataLayoutEntryListRef newLayout) const575 bool LLVMStructType::areCompatible(DataLayoutEntryListRef oldLayout,
576 DataLayoutEntryListRef newLayout) const {
577 for (DataLayoutEntryInterface newEntry : newLayout) {
578 if (!newEntry.isTypeEntry())
579 continue;
580
581 const auto *previousEntry =
582 llvm::find_if(oldLayout, [](DataLayoutEntryInterface entry) {
583 return entry.isTypeEntry();
584 });
585 if (previousEntry == oldLayout.end())
586 continue;
587
588 unsigned abi = extractStructSpecValue(previousEntry->getValue(),
589 StructDLEntryPos::Abi);
590 unsigned newAbi =
591 extractStructSpecValue(newEntry.getValue(), StructDLEntryPos::Abi);
592 if (abi < newAbi || abi % newAbi != 0)
593 return false;
594 }
595 return true;
596 }
597
verifyEntries(DataLayoutEntryListRef entries,Location loc) const598 LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries,
599 Location loc) const {
600 for (DataLayoutEntryInterface entry : entries) {
601 if (!entry.isTypeEntry())
602 continue;
603
604 auto key = entry.getKey().get<Type>().cast<LLVMStructType>();
605 auto values = entry.getValue().dyn_cast<DenseIntElementsAttr>();
606 if (!values || (values.size() != 2 && values.size() != 1)) {
607 return emitError(loc)
608 << "expected layout attribute for " << entry.getKey().get<Type>()
609 << " to be a dense integer elements attribute of 1 or 2 elements";
610 }
611
612 if (key.isIdentified() || !key.getBody().empty()) {
613 return emitError(loc) << "unexpected layout attribute for struct " << key;
614 }
615
616 if (values.size() == 1)
617 continue;
618
619 if (extractStructSpecValue(values, StructDLEntryPos::Abi) >
620 extractStructSpecValue(values, StructDLEntryPos::Preferred)) {
621 return emitError(loc) << "preferred alignment is expected to be at least "
622 "as large as ABI alignment";
623 }
624 }
625 return mlir::success();
626 }
627
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const628 void LLVMStructType::walkImmediateSubElements(
629 function_ref<void(Attribute)> walkAttrsFn,
630 function_ref<void(Type)> walkTypesFn) const {
631 for (Type type : getBody())
632 walkTypesFn(type);
633 }
634
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const635 Type LLVMStructType::replaceImmediateSubElements(
636 ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
637 // TODO: It's not clear how we support replacing sub-elements of mutable
638 // types.
639 return nullptr;
640 }
641
642 //===----------------------------------------------------------------------===//
643 // Vector types.
644 //===----------------------------------------------------------------------===//
645
646 /// Verifies that the type about to be constructed is well-formed.
647 template <typename VecTy>
648 static LogicalResult
verifyVectorConstructionInvariants(function_ref<InFlightDiagnostic ()> emitError,Type elementType,unsigned numElements)649 verifyVectorConstructionInvariants(function_ref<InFlightDiagnostic()> emitError,
650 Type elementType, unsigned numElements) {
651 if (numElements == 0)
652 return emitError() << "the number of vector elements must be positive";
653
654 if (!VecTy::isValidElementType(elementType))
655 return emitError() << "invalid vector element type";
656
657 return success();
658 }
659
get(Type elementType,unsigned numElements)660 LLVMFixedVectorType LLVMFixedVectorType::get(Type elementType,
661 unsigned numElements) {
662 assert(elementType && "expected non-null subtype");
663 return Base::get(elementType.getContext(), elementType, numElements);
664 }
665
666 LLVMFixedVectorType
getChecked(function_ref<InFlightDiagnostic ()> emitError,Type elementType,unsigned numElements)667 LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
668 Type elementType, unsigned numElements) {
669 assert(elementType && "expected non-null subtype");
670 return Base::getChecked(emitError, elementType.getContext(), elementType,
671 numElements);
672 }
673
getElementType() const674 Type LLVMFixedVectorType::getElementType() const {
675 return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
676 }
677
getNumElements() const678 unsigned LLVMFixedVectorType::getNumElements() const {
679 return getImpl()->numElements;
680 }
681
isValidElementType(Type type)682 bool LLVMFixedVectorType::isValidElementType(Type type) {
683 return type.isa<LLVMPointerType, LLVMPPCFP128Type>();
684 }
685
686 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type elementType,unsigned numElements)687 LLVMFixedVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
688 Type elementType, unsigned numElements) {
689 return verifyVectorConstructionInvariants<LLVMFixedVectorType>(
690 emitError, elementType, numElements);
691 }
692
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const693 void LLVMFixedVectorType::walkImmediateSubElements(
694 function_ref<void(Attribute)> walkAttrsFn,
695 function_ref<void(Type)> walkTypesFn) const {
696 walkTypesFn(getElementType());
697 }
698
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const699 Type LLVMFixedVectorType::replaceImmediateSubElements(
700 ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
701 return get(replTypes[0], getNumElements());
702 }
703
704 //===----------------------------------------------------------------------===//
705 // LLVMScalableVectorType.
706 //===----------------------------------------------------------------------===//
707
get(Type elementType,unsigned minNumElements)708 LLVMScalableVectorType LLVMScalableVectorType::get(Type elementType,
709 unsigned minNumElements) {
710 assert(elementType && "expected non-null subtype");
711 return Base::get(elementType.getContext(), elementType, minNumElements);
712 }
713
714 LLVMScalableVectorType
getChecked(function_ref<InFlightDiagnostic ()> emitError,Type elementType,unsigned minNumElements)715 LLVMScalableVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
716 Type elementType, unsigned minNumElements) {
717 assert(elementType && "expected non-null subtype");
718 return Base::getChecked(emitError, elementType.getContext(), elementType,
719 minNumElements);
720 }
721
getElementType() const722 Type LLVMScalableVectorType::getElementType() const {
723 return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
724 }
725
getMinNumElements() const726 unsigned LLVMScalableVectorType::getMinNumElements() const {
727 return getImpl()->numElements;
728 }
729
isValidElementType(Type type)730 bool LLVMScalableVectorType::isValidElementType(Type type) {
731 if (auto intType = type.dyn_cast<IntegerType>())
732 return intType.isSignless();
733
734 return isCompatibleFloatingPointType(type) || type.isa<LLVMPointerType>();
735 }
736
737 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type elementType,unsigned numElements)738 LLVMScalableVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
739 Type elementType, unsigned numElements) {
740 return verifyVectorConstructionInvariants<LLVMScalableVectorType>(
741 emitError, elementType, numElements);
742 }
743
walkImmediateSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn) const744 void LLVMScalableVectorType::walkImmediateSubElements(
745 function_ref<void(Attribute)> walkAttrsFn,
746 function_ref<void(Type)> walkTypesFn) const {
747 walkTypesFn(getElementType());
748 }
749
replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,ArrayRef<Type> replTypes) const750 Type LLVMScalableVectorType::replaceImmediateSubElements(
751 ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
752 return get(replTypes[0], getMinNumElements());
753 }
754
755 //===----------------------------------------------------------------------===//
756 // Utility functions.
757 //===----------------------------------------------------------------------===//
758
isCompatibleOuterType(Type type)759 bool mlir::LLVM::isCompatibleOuterType(Type type) {
760 // clang-format off
761 if (type.isa<
762 BFloat16Type,
763 Float16Type,
764 Float32Type,
765 Float64Type,
766 Float80Type,
767 Float128Type,
768 LLVMArrayType,
769 LLVMFunctionType,
770 LLVMLabelType,
771 LLVMMetadataType,
772 LLVMPPCFP128Type,
773 LLVMPointerType,
774 LLVMStructType,
775 LLVMTokenType,
776 LLVMFixedVectorType,
777 LLVMScalableVectorType,
778 LLVMVoidType,
779 LLVMX86MMXType
780 >()) {
781 // clang-format on
782 return true;
783 }
784
785 // Only signless integers are compatible.
786 if (auto intType = type.dyn_cast<IntegerType>())
787 return intType.isSignless();
788
789 // 1D vector types are compatible.
790 if (auto vecType = type.dyn_cast<VectorType>())
791 return vecType.getRank() == 1;
792
793 return false;
794 }
795
isCompatibleImpl(Type type,DenseSet<Type> & compatibleTypes)796 static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
797 if (!compatibleTypes.insert(type).second)
798 return true;
799
800 auto isCompatible = [&](Type type) {
801 return isCompatibleImpl(type, compatibleTypes);
802 };
803
804 bool result =
805 llvm::TypeSwitch<Type, bool>(type)
806 .Case<LLVMStructType>([&](auto structType) {
807 return llvm::all_of(structType.getBody(), isCompatible);
808 })
809 .Case<LLVMFunctionType>([&](auto funcType) {
810 return isCompatible(funcType.getReturnType()) &&
811 llvm::all_of(funcType.getParams(), isCompatible);
812 })
813 .Case<IntegerType>([](auto intType) { return intType.isSignless(); })
814 .Case<VectorType>([&](auto vecType) {
815 return vecType.getRank() == 1 &&
816 isCompatible(vecType.getElementType());
817 })
818 .Case<LLVMPointerType>([&](auto pointerType) {
819 if (pointerType.isOpaque())
820 return true;
821 return isCompatible(pointerType.getElementType());
822 })
823 // clang-format off
824 .Case<
825 LLVMFixedVectorType,
826 LLVMScalableVectorType,
827 LLVMArrayType
828 >([&](auto containerType) {
829 return isCompatible(containerType.getElementType());
830 })
831 .Case<
832 BFloat16Type,
833 Float16Type,
834 Float32Type,
835 Float64Type,
836 Float80Type,
837 Float128Type,
838 LLVMLabelType,
839 LLVMMetadataType,
840 LLVMPPCFP128Type,
841 LLVMTokenType,
842 LLVMVoidType,
843 LLVMX86MMXType
844 >([](Type) { return true; })
845 // clang-format on
846 .Default([](Type) { return false; });
847
848 if (!result)
849 compatibleTypes.erase(type);
850
851 return result;
852 }
853
isCompatibleType(Type type)854 bool LLVMDialect::isCompatibleType(Type type) {
855 if (auto *llvmDialect =
856 type.getContext()->getLoadedDialect<LLVM::LLVMDialect>())
857 return isCompatibleImpl(type, llvmDialect->compatibleTypes.get());
858
859 DenseSet<Type> localCompatibleTypes;
860 return isCompatibleImpl(type, localCompatibleTypes);
861 }
862
isCompatibleType(Type type)863 bool mlir::LLVM::isCompatibleType(Type type) {
864 return LLVMDialect::isCompatibleType(type);
865 }
866
isCompatibleFloatingPointType(Type type)867 bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {
868 return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
869 Float80Type, Float128Type, LLVMPPCFP128Type>();
870 }
871
isCompatibleVectorType(Type type)872 bool mlir::LLVM::isCompatibleVectorType(Type type) {
873 if (type.isa<LLVMFixedVectorType, LLVMScalableVectorType>())
874 return true;
875
876 if (auto vecType = type.dyn_cast<VectorType>()) {
877 if (vecType.getRank() != 1)
878 return false;
879 Type elementType = vecType.getElementType();
880 if (auto intType = elementType.dyn_cast<IntegerType>())
881 return intType.isSignless();
882 return elementType.isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
883 Float80Type, Float128Type>();
884 }
885 return false;
886 }
887
getVectorElementType(Type type)888 Type mlir::LLVM::getVectorElementType(Type type) {
889 return llvm::TypeSwitch<Type, Type>(type)
890 .Case<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>(
891 [](auto ty) { return ty.getElementType(); })
892 .Default([](Type) -> Type {
893 llvm_unreachable("incompatible with LLVM vector type");
894 });
895 }
896
getVectorNumElements(Type type)897 llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
898 return llvm::TypeSwitch<Type, llvm::ElementCount>(type)
899 .Case([](VectorType ty) {
900 if (ty.isScalable())
901 return llvm::ElementCount::getScalable(ty.getNumElements());
902 return llvm::ElementCount::getFixed(ty.getNumElements());
903 })
904 .Case([](LLVMFixedVectorType ty) {
905 return llvm::ElementCount::getFixed(ty.getNumElements());
906 })
907 .Case([](LLVMScalableVectorType ty) {
908 return llvm::ElementCount::getScalable(ty.getMinNumElements());
909 })
910 .Default([](Type) -> llvm::ElementCount {
911 llvm_unreachable("incompatible with LLVM vector type");
912 });
913 }
914
isScalableVectorType(Type vectorType)915 bool mlir::LLVM::isScalableVectorType(Type vectorType) {
916 assert(
917 (vectorType
918 .isa<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>()) &&
919 "expected LLVM-compatible vector type");
920 return !vectorType.isa<LLVMFixedVectorType>() &&
921 (vectorType.isa<LLVMScalableVectorType>() ||
922 vectorType.cast<VectorType>().isScalable());
923 }
924
getVectorType(Type elementType,unsigned numElements,bool isScalable)925 Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
926 bool isScalable) {
927 bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType);
928 bool useBuiltIn = VectorType::isValidElementType(elementType);
929 (void)useBuiltIn;
930 assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type "
931 "to be either builtin or LLVM dialect type");
932 if (useLLVM) {
933 if (isScalable)
934 return LLVMScalableVectorType::get(elementType, numElements);
935 return LLVMFixedVectorType::get(elementType, numElements);
936 }
937 return VectorType::get(numElements, elementType, (unsigned)isScalable);
938 }
939
getFixedVectorType(Type elementType,unsigned numElements)940 Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
941 bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType);
942 bool useBuiltIn = VectorType::isValidElementType(elementType);
943 (void)useBuiltIn;
944 assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type "
945 "to be either builtin or LLVM dialect type");
946 if (useLLVM)
947 return LLVMFixedVectorType::get(elementType, numElements);
948 return VectorType::get(numElements, elementType);
949 }
950
getScalableVectorType(Type elementType,unsigned numElements)951 Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
952 bool useLLVM = LLVMScalableVectorType::isValidElementType(elementType);
953 bool useBuiltIn = VectorType::isValidElementType(elementType);
954 (void)useBuiltIn;
955 assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible scalable-vector "
956 "type to be either builtin or LLVM dialect "
957 "type");
958 if (useLLVM)
959 return LLVMScalableVectorType::get(elementType, numElements);
960 return VectorType::get(numElements, elementType, /*numScalableDims=*/1);
961 }
962
getPrimitiveTypeSizeInBits(Type type)963 llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
964 assert(isCompatibleType(type) &&
965 "expected a type compatible with the LLVM dialect");
966
967 return llvm::TypeSwitch<Type, llvm::TypeSize>(type)
968 .Case<BFloat16Type, Float16Type>(
969 [](Type) { return llvm::TypeSize::Fixed(16); })
970 .Case<Float32Type>([](Type) { return llvm::TypeSize::Fixed(32); })
971 .Case<Float64Type, LLVMX86MMXType>(
972 [](Type) { return llvm::TypeSize::Fixed(64); })
973 .Case<Float80Type>([](Type) { return llvm::TypeSize::Fixed(80); })
974 .Case<Float128Type>([](Type) { return llvm::TypeSize::Fixed(128); })
975 .Case<IntegerType>([](IntegerType intTy) {
976 return llvm::TypeSize::Fixed(intTy.getWidth());
977 })
978 .Case<LLVMPPCFP128Type>([](Type) { return llvm::TypeSize::Fixed(128); })
979 .Case<LLVMFixedVectorType>([](LLVMFixedVectorType t) {
980 llvm::TypeSize elementSize =
981 getPrimitiveTypeSizeInBits(t.getElementType());
982 return llvm::TypeSize(elementSize.getFixedSize() * t.getNumElements(),
983 elementSize.isScalable());
984 })
985 .Case<VectorType>([](VectorType t) {
986 assert(isCompatibleVectorType(t) &&
987 "unexpected incompatible with LLVM vector type");
988 llvm::TypeSize elementSize =
989 getPrimitiveTypeSizeInBits(t.getElementType());
990 return llvm::TypeSize(elementSize.getFixedSize() * t.getNumElements(),
991 elementSize.isScalable());
992 })
993 .Default([](Type ty) {
994 assert((ty.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
995 LLVMTokenType, LLVMStructType, LLVMArrayType,
996 LLVMPointerType, LLVMFunctionType>()) &&
997 "unexpected missing support for primitive type");
998 return llvm::TypeSize::Fixed(0);
999 });
1000 }
1001
1002 #include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"
1003