1 //===- DataLayoutInterfaces.cpp - Data Layout Interface Implementation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/DataLayoutInterfaces.h"
10 #include "mlir/IR/BuiltinDialect.h"
11 #include "mlir/IR/BuiltinOps.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/Operation.h"
14 
15 #include "llvm/ADT/TypeSwitch.h"
16 #include "llvm/Support/MathExtras.h"
17 
18 using namespace mlir;
19 
20 //===----------------------------------------------------------------------===//
21 // Default implementations
22 //===----------------------------------------------------------------------===//
23 
24 /// Reports that the given type is missing the data layout information and
25 /// exits.
26 [[noreturn]] static void reportMissingDataLayout(Type type) {
27   std::string message;
28   llvm::raw_string_ostream os(message);
29   os << "neither the scoping op nor the type class provide data layout "
30         "information for "
31      << type;
32   llvm::report_fatal_error(Twine(os.str()));
33 }
34 
35 /// Returns the bitwidth of the index type if specified in the param list.
36 /// Assumes 64-bit index otherwise.
37 static unsigned getIndexBitwidth(DataLayoutEntryListRef params) {
38   if (params.empty())
39     return 64;
40   auto attr = params.front().getValue().cast<IntegerAttr>();
41   return attr.getValue().getZExtValue();
42 }
43 
44 unsigned
45 mlir::detail::getDefaultTypeSize(Type type, const DataLayout &dataLayout,
46                                  ArrayRef<DataLayoutEntryInterface> params) {
47   unsigned bits = getDefaultTypeSizeInBits(type, dataLayout, params);
48   return llvm::divideCeil(bits, 8);
49 }
50 
51 unsigned mlir::detail::getDefaultTypeSizeInBits(Type type,
52                                                 const DataLayout &dataLayout,
53                                                 DataLayoutEntryListRef params) {
54   if (type.isa<IntegerType, FloatType>())
55     return type.getIntOrFloatBitWidth();
56 
57   if (auto ctype = type.dyn_cast<ComplexType>()) {
58     auto et = ctype.getElementType();
59     auto innerAlignment =
60         getDefaultPreferredAlignment(et, dataLayout, params) * 8;
61     auto innerSize = getDefaultTypeSizeInBits(et, dataLayout, params);
62 
63     // Include padding required to align the imaginary value in the complex
64     // type.
65     return llvm::alignTo(innerSize, innerAlignment) + innerSize;
66   }
67 
68   // Index is an integer of some bitwidth.
69   if (type.isa<IndexType>())
70     return dataLayout.getTypeSizeInBits(
71         IntegerType::get(type.getContext(), getIndexBitwidth(params)));
72 
73   // Sizes of vector types are rounded up to those of types with closest
74   // power-of-two number of elements in the innermost dimension. We also assume
75   // there is no bit-packing at the moment element sizes are taken in bytes and
76   // multiplied with 8 bits.
77   // TODO: make this extensible.
78   if (auto vecType = type.dyn_cast<VectorType>())
79     return vecType.getNumElements() / vecType.getShape().back() *
80            llvm::PowerOf2Ceil(vecType.getShape().back()) *
81            dataLayout.getTypeSize(vecType.getElementType()) * 8;
82 
83   if (auto typeInterface = type.dyn_cast<DataLayoutTypeInterface>())
84     return typeInterface.getTypeSizeInBits(dataLayout, params);
85 
86   reportMissingDataLayout(type);
87 }
88 
89 static DataLayoutEntryInterface
90 findEntryForIntegerType(IntegerType intType,
91                         ArrayRef<DataLayoutEntryInterface> params) {
92   assert(!params.empty() && "expected non-empty parameter list");
93   std::map<unsigned, DataLayoutEntryInterface> sortedParams;
94   for (DataLayoutEntryInterface entry : params) {
95     sortedParams.insert(std::make_pair(
96         entry.getKey().get<Type>().getIntOrFloatBitWidth(), entry));
97   }
98   auto iter = sortedParams.lower_bound(intType.getWidth());
99   if (iter == sortedParams.end())
100     iter = std::prev(iter);
101 
102   return iter->second;
103 }
104 
105 static unsigned extractABIAlignment(DataLayoutEntryInterface entry) {
106   auto values =
107       entry.getValue().cast<DenseIntElementsAttr>().getValues<int32_t>();
108   return *values.begin() / 8u;
109 }
110 
111 static unsigned
112 getIntegerTypeABIAlignment(IntegerType intType,
113                            ArrayRef<DataLayoutEntryInterface> params) {
114   if (params.empty()) {
115     return intType.getWidth() < 64
116                ? llvm::PowerOf2Ceil(llvm::divideCeil(intType.getWidth(), 8))
117                : 4;
118   }
119 
120   return extractABIAlignment(findEntryForIntegerType(intType, params));
121 }
122 
123 static unsigned
124 getFloatTypeABIAlignment(FloatType fltType, const DataLayout &dataLayout,
125                          ArrayRef<DataLayoutEntryInterface> params) {
126   assert(params.size() <= 1 && "at most one data layout entry is expected for "
127                                "the singleton floating-point type");
128   if (params.empty())
129     return llvm::PowerOf2Ceil(dataLayout.getTypeSize(fltType));
130   return extractABIAlignment(params[0]);
131 }
132 
133 unsigned mlir::detail::getDefaultABIAlignment(
134     Type type, const DataLayout &dataLayout,
135     ArrayRef<DataLayoutEntryInterface> params) {
136   // Natural alignment is the closest power-of-two number above.
137   if (type.isa<VectorType>())
138     return llvm::PowerOf2Ceil(dataLayout.getTypeSize(type));
139 
140   if (auto fltType = type.dyn_cast<FloatType>())
141     return getFloatTypeABIAlignment(fltType, dataLayout, params);
142 
143   // Index is an integer of some bitwidth.
144   if (type.isa<IndexType>())
145     return dataLayout.getTypeABIAlignment(
146         IntegerType::get(type.getContext(), getIndexBitwidth(params)));
147 
148   if (auto intType = type.dyn_cast<IntegerType>())
149     return getIntegerTypeABIAlignment(intType, params);
150 
151   if (auto ctype = type.dyn_cast<ComplexType>())
152     return getDefaultABIAlignment(ctype.getElementType(), dataLayout, params);
153 
154   if (auto typeInterface = type.dyn_cast<DataLayoutTypeInterface>())
155     return typeInterface.getABIAlignment(dataLayout, params);
156 
157   reportMissingDataLayout(type);
158 }
159 
160 static unsigned extractPreferredAlignment(DataLayoutEntryInterface entry) {
161   auto values =
162       entry.getValue().cast<DenseIntElementsAttr>().getValues<int32_t>();
163   return *std::next(values.begin(), values.size() - 1) / 8u;
164 }
165 
166 static unsigned
167 getIntegerTypePreferredAlignment(IntegerType intType,
168                                  const DataLayout &dataLayout,
169                                  ArrayRef<DataLayoutEntryInterface> params) {
170   if (params.empty())
171     return llvm::PowerOf2Ceil(dataLayout.getTypeSize(intType));
172 
173   return extractPreferredAlignment(findEntryForIntegerType(intType, params));
174 }
175 
176 static unsigned
177 getFloatTypePreferredAlignment(FloatType fltType, const DataLayout &dataLayout,
178                                ArrayRef<DataLayoutEntryInterface> params) {
179   assert(params.size() <= 1 && "at most one data layout entry is expected for "
180                                "the singleton floating-point type");
181   if (params.empty())
182     return dataLayout.getTypeABIAlignment(fltType);
183   return extractPreferredAlignment(params[0]);
184 }
185 
186 unsigned mlir::detail::getDefaultPreferredAlignment(
187     Type type, const DataLayout &dataLayout,
188     ArrayRef<DataLayoutEntryInterface> params) {
189   // Preferred alignment is same as natural for floats and vectors.
190   if (type.isa<VectorType>())
191     return dataLayout.getTypeABIAlignment(type);
192 
193   if (auto fltType = type.dyn_cast<FloatType>())
194     return getFloatTypePreferredAlignment(fltType, dataLayout, params);
195 
196   // Preferred alignment is the closest power-of-two number above for integers
197   // (ABI alignment may be smaller).
198   if (auto intType = type.dyn_cast<IntegerType>())
199     return getIntegerTypePreferredAlignment(intType, dataLayout, params);
200 
201   if (type.isa<IndexType>()) {
202     return dataLayout.getTypePreferredAlignment(
203         IntegerType::get(type.getContext(), getIndexBitwidth(params)));
204   }
205 
206   if (auto ctype = type.dyn_cast<ComplexType>())
207     return getDefaultPreferredAlignment(ctype.getElementType(), dataLayout,
208                                         params);
209 
210   if (auto typeInterface = type.dyn_cast<DataLayoutTypeInterface>())
211     return typeInterface.getPreferredAlignment(dataLayout, params);
212 
213   reportMissingDataLayout(type);
214 }
215 
216 DataLayoutEntryList
217 mlir::detail::filterEntriesForType(DataLayoutEntryListRef entries,
218                                    TypeID typeID) {
219   return llvm::to_vector<4>(llvm::make_filter_range(
220       entries, [typeID](DataLayoutEntryInterface entry) {
221         auto type = entry.getKey().dyn_cast<Type>();
222         return type && type.getTypeID() == typeID;
223       }));
224 }
225 
226 DataLayoutEntryInterface
227 mlir::detail::filterEntryForIdentifier(DataLayoutEntryListRef entries,
228                                        StringAttr id) {
229   const auto *it = llvm::find_if(entries, [id](DataLayoutEntryInterface entry) {
230     if (!entry.getKey().is<StringAttr>())
231       return false;
232     return entry.getKey().get<StringAttr>() == id;
233   });
234   return it == entries.end() ? DataLayoutEntryInterface() : *it;
235 }
236 
237 static DataLayoutSpecInterface getSpec(Operation *operation) {
238   return llvm::TypeSwitch<Operation *, DataLayoutSpecInterface>(operation)
239       .Case<ModuleOp, DataLayoutOpInterface>(
240           [&](auto op) { return op.getDataLayoutSpec(); })
241       .Default([](Operation *) {
242         llvm_unreachable("expected an op with data layout spec");
243         return DataLayoutSpecInterface();
244       });
245 }
246 
247 /// Populates `opsWithLayout` with the list of proper ancestors of `leaf` that
248 /// are either modules or implement the `DataLayoutOpInterface`.
249 static void
250 collectParentLayouts(Operation *leaf,
251                      SmallVectorImpl<DataLayoutSpecInterface> &specs,
252                      SmallVectorImpl<Location> *opLocations = nullptr) {
253   if (!leaf)
254     return;
255 
256   for (Operation *parent = leaf->getParentOp(); parent != nullptr;
257        parent = parent->getParentOp()) {
258     llvm::TypeSwitch<Operation *>(parent)
259         .Case<ModuleOp>([&](ModuleOp op) {
260           // Skip top-level module op unless it has a layout. Top-level module
261           // without layout is most likely the one implicitly added by the
262           // parser and it doesn't have location. Top-level null specification
263           // would have had the same effect as not having a specification at all
264           // (using type defaults).
265           if (!op->getParentOp() && !op.getDataLayoutSpec())
266             return;
267           specs.push_back(op.getDataLayoutSpec());
268           if (opLocations)
269             opLocations->push_back(op.getLoc());
270         })
271         .Case<DataLayoutOpInterface>([&](DataLayoutOpInterface op) {
272           specs.push_back(op.getDataLayoutSpec());
273           if (opLocations)
274             opLocations->push_back(op.getLoc());
275         });
276   }
277 }
278 
279 /// Returns a layout spec that is a combination of the layout specs attached
280 /// to the given operation and all its ancestors.
281 static DataLayoutSpecInterface getCombinedDataLayout(Operation *leaf) {
282   if (!leaf)
283     return {};
284 
285   assert((isa<ModuleOp, DataLayoutOpInterface>(leaf)) &&
286          "expected an op with data layout spec");
287 
288   SmallVector<DataLayoutOpInterface> opsWithLayout;
289   SmallVector<DataLayoutSpecInterface> specs;
290   collectParentLayouts(leaf, specs);
291 
292   // Fast track if there are no ancestors.
293   if (specs.empty())
294     return getSpec(leaf);
295 
296   // Create the list of non-null specs (null/missing specs can be safely
297   // ignored) from the outermost to the innermost.
298   auto nonNullSpecs = llvm::to_vector<2>(llvm::make_filter_range(
299       llvm::reverse(specs),
300       [](DataLayoutSpecInterface iface) { return iface != nullptr; }));
301 
302   // Combine the specs using the innermost as anchor.
303   if (DataLayoutSpecInterface current = getSpec(leaf))
304     return current.combineWith(nonNullSpecs);
305   if (nonNullSpecs.empty())
306     return {};
307   return nonNullSpecs.back().combineWith(
308       llvm::makeArrayRef(nonNullSpecs).drop_back());
309 }
310 
311 LogicalResult mlir::detail::verifyDataLayoutOp(Operation *op) {
312   DataLayoutSpecInterface spec = getSpec(op);
313   // The layout specification may be missing and it's fine.
314   if (!spec)
315     return success();
316 
317   if (failed(spec.verifySpec(op->getLoc())))
318     return failure();
319   if (!getCombinedDataLayout(op)) {
320     InFlightDiagnostic diag =
321         op->emitError()
322         << "data layout does not combine with layouts of enclosing ops";
323     SmallVector<DataLayoutSpecInterface> specs;
324     SmallVector<Location> opLocations;
325     collectParentLayouts(op, specs, &opLocations);
326     for (Location loc : opLocations)
327       diag.attachNote(loc) << "enclosing op with data layout";
328     return diag;
329   }
330   return success();
331 }
332 
333 //===----------------------------------------------------------------------===//
334 // DataLayout
335 //===----------------------------------------------------------------------===//
336 
337 template <typename OpTy>
338 void checkMissingLayout(DataLayoutSpecInterface originalLayout, OpTy op) {
339   if (!originalLayout) {
340     assert((!op || !op.getDataLayoutSpec()) &&
341            "could not compute layout information for an op (failed to "
342            "combine attributes?)");
343   }
344 }
345 
346 mlir::DataLayout::DataLayout() : DataLayout(ModuleOp()) {}
347 
348 mlir::DataLayout::DataLayout(DataLayoutOpInterface op)
349     : originalLayout(getCombinedDataLayout(op)), scope(op) {
350 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
351   checkMissingLayout(originalLayout, op);
352   collectParentLayouts(op, layoutStack);
353 #endif
354 }
355 
356 mlir::DataLayout::DataLayout(ModuleOp op)
357     : originalLayout(getCombinedDataLayout(op)), scope(op) {
358 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
359   checkMissingLayout(originalLayout, op);
360   collectParentLayouts(op, layoutStack);
361 #endif
362 }
363 
364 mlir::DataLayout mlir::DataLayout::closest(Operation *op) {
365   // Search the closest parent either being a module operation or implementing
366   // the data layout interface.
367   while (op) {
368     if (auto module = dyn_cast<ModuleOp>(op))
369       return DataLayout(module);
370     if (auto iface = dyn_cast<DataLayoutOpInterface>(op))
371       return DataLayout(iface);
372     op = op->getParentOp();
373   }
374   return DataLayout();
375 }
376 
/// Asserts that this DataLayout object is still consistent with the IR it was
/// created from.
///
/// With ABI-breaking checks enabled, the specs of the layout-carrying
/// ancestors of `scope` are re-collected. They are then compared, both in
/// count and one by one, against `layoutStack` as recorded at construction.
/// Independently of that, the combined layout of `scope` is recomputed and
/// compared with `originalLayout`. All checks are `assert`s, so they only fire
/// when assertions are enabled.
void mlir::DataLayout::checkValid() const {
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
  // Re-collect the ancestor specs as they currently are in the IR.
  SmallVector<DataLayoutSpecInterface> specs;
  collectParentLayouts(scope, specs);
  assert(specs.size() == layoutStack.size() &&
         "data layout object used, but no longer valid due to the change in "
         "number of nested layouts");
  for (auto pair : llvm::zip(specs, layoutStack)) {
    Attribute newLayout = std::get<0>(pair);
    Attribute origLayout = std::get<1>(pair);
    assert(newLayout == origLayout &&
           "data layout object used, but no longer valid "
           "due to the change in layout attributes");
  }
#endif
  // A scope-less layout must have no spec. A scoped one must match the
  // freshly recomputed combination of its scope's specs.
  assert(((!scope && !this->originalLayout) ||
          (scope && this->originalLayout == getCombinedDataLayout(scope))) &&
         "data layout object used, but no longer valid due to the change in "
         "layout spec");
}
397 
398 /// Looks up the value for the given type key in the given cache. If there is no
399 /// such value in the cache, compute it using the given callback and put it in
400 /// the cache before returning.
401 static unsigned cachedLookup(Type t, DenseMap<Type, unsigned> &cache,
402                              function_ref<unsigned(Type)> compute) {
403   auto it = cache.find(t);
404   if (it != cache.end())
405     return it->second;
406 
407   auto result = cache.try_emplace(t, compute(t));
408   return result.first->second;
409 }
410 
411 unsigned mlir::DataLayout::getTypeSize(Type t) const {
412   checkValid();
413   return cachedLookup(t, sizes, [&](Type ty) {
414     DataLayoutEntryList list;
415     if (originalLayout)
416       list = originalLayout.getSpecForType(ty.getTypeID());
417     if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
418       return iface.getTypeSize(ty, *this, list);
419     return detail::getDefaultTypeSize(ty, *this, list);
420   });
421 }
422 
423 unsigned mlir::DataLayout::getTypeSizeInBits(Type t) const {
424   checkValid();
425   return cachedLookup(t, bitsizes, [&](Type ty) {
426     DataLayoutEntryList list;
427     if (originalLayout)
428       list = originalLayout.getSpecForType(ty.getTypeID());
429     if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
430       return iface.getTypeSizeInBits(ty, *this, list);
431     return detail::getDefaultTypeSizeInBits(ty, *this, list);
432   });
433 }
434 
435 unsigned mlir::DataLayout::getTypeABIAlignment(Type t) const {
436   checkValid();
437   return cachedLookup(t, abiAlignments, [&](Type ty) {
438     DataLayoutEntryList list;
439     if (originalLayout)
440       list = originalLayout.getSpecForType(ty.getTypeID());
441     if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
442       return iface.getTypeABIAlignment(ty, *this, list);
443     return detail::getDefaultABIAlignment(ty, *this, list);
444   });
445 }
446 
447 unsigned mlir::DataLayout::getTypePreferredAlignment(Type t) const {
448   checkValid();
449   return cachedLookup(t, preferredAlignments, [&](Type ty) {
450     DataLayoutEntryList list;
451     if (originalLayout)
452       list = originalLayout.getSpecForType(ty.getTypeID());
453     if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
454       return iface.getTypePreferredAlignment(ty, *this, list);
455     return detail::getDefaultPreferredAlignment(ty, *this, list);
456   });
457 }
458 
459 //===----------------------------------------------------------------------===//
460 // DataLayoutSpecInterface
461 //===----------------------------------------------------------------------===//
462 
463 void DataLayoutSpecInterface::bucketEntriesByType(
464     DenseMap<TypeID, DataLayoutEntryList> &types,
465     DenseMap<StringAttr, DataLayoutEntryInterface> &ids) {
466   for (DataLayoutEntryInterface entry : getEntries()) {
467     if (auto type = entry.getKey().dyn_cast<Type>())
468       types[type.getTypeID()].push_back(entry);
469     else
470       ids[entry.getKey().get<StringAttr>()] = entry;
471   }
472 }
473 
474 LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
475                                                  Location loc) {
476   // First, verify individual entries.
477   for (DataLayoutEntryInterface entry : spec.getEntries())
478     if (failed(entry.verifyEntry(loc)))
479       return failure();
480 
481   // Second, dispatch verifications of entry groups to types or dialects they
482   // are are associated with.
483   DenseMap<TypeID, DataLayoutEntryList> types;
484   DenseMap<StringAttr, DataLayoutEntryInterface> ids;
485   spec.bucketEntriesByType(types, ids);
486 
487   for (const auto &kvp : types) {
488     auto sampleType = kvp.second.front().getKey().get<Type>();
489     if (sampleType.isa<IndexType>()) {
490       assert(kvp.second.size() == 1 &&
491              "expected one data layout entry for non-parametric 'index' type");
492       if (!kvp.second.front().getValue().isa<IntegerAttr>())
493         return emitError(loc)
494                << "expected integer attribute in the data layout entry for "
495                << sampleType;
496       continue;
497     }
498 
499     if (sampleType.isa<IntegerType, FloatType>()) {
500       for (DataLayoutEntryInterface entry : kvp.second) {
501         auto value = entry.getValue().dyn_cast<DenseIntElementsAttr>();
502         if (!value || !value.getElementType().isSignlessInteger(32)) {
503           emitError(loc) << "expected a dense i32 elements attribute in the "
504                             "data layout entry "
505                          << entry;
506           return failure();
507         }
508 
509         auto elements = llvm::to_vector<2>(value.getValues<int32_t>());
510         unsigned numElements = elements.size();
511         if (numElements < 1 || numElements > 2) {
512           emitError(loc) << "expected 1 or 2 elements in the data layout entry "
513                          << entry;
514           return failure();
515         }
516 
517         int32_t abi = elements[0];
518         int32_t preferred = numElements == 2 ? elements[1] : abi;
519         if (preferred < abi) {
520           emitError(loc)
521               << "preferred alignment is expected to be greater than or equal "
522                  "to the abi alignment in data layout entry "
523               << entry;
524           return failure();
525         }
526       }
527       continue;
528     }
529 
530     if (isa<BuiltinDialect>(&sampleType.getDialect()))
531       return emitError(loc) << "unexpected data layout for a built-in type";
532 
533     auto dlType = sampleType.dyn_cast<DataLayoutTypeInterface>();
534     if (!dlType)
535       return emitError(loc)
536              << "data layout specified for a type that does not support it";
537     if (failed(dlType.verifyEntries(kvp.second, loc)))
538       return failure();
539   }
540 
541   for (const auto &kvp : ids) {
542     StringAttr identifier = kvp.second.getKey().get<StringAttr>();
543     Dialect *dialect = identifier.getReferencedDialect();
544 
545     // Ignore attributes that belong to an unknown dialect, the dialect may
546     // actually implement the relevant interface but we don't know about that.
547     if (!dialect)
548       continue;
549 
550     const auto *iface = dyn_cast<DataLayoutDialectInterface>(dialect);
551     if (!iface) {
552       return emitError(loc)
553              << "the '" << dialect->getNamespace()
554              << "' dialect does not support identifier data layout entries";
555     }
556     if (failed(iface->verifyEntry(kvp.second, loc)))
557       return failure();
558   }
559 
560   return success();
561 }
562 
563 #include "mlir/Interfaces/DataLayoutAttrInterface.cpp.inc"
564 #include "mlir/Interfaces/DataLayoutOpInterface.cpp.inc"
565 #include "mlir/Interfaces/DataLayoutTypeInterface.cpp.inc"
566