1 //===- DataLayoutInterfaces.cpp - Data Layout Interface Implementation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Interfaces/DataLayoutInterfaces.h"
10 #include "mlir/IR/BuiltinDialect.h"
11 #include "mlir/IR/BuiltinOps.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/Operation.h"
14 
15 #include "llvm/ADT/TypeSwitch.h"
16 #include "llvm/Support/MathExtras.h"
17 
18 using namespace mlir;
19 
20 //===----------------------------------------------------------------------===//
21 // Default implementations
22 //===----------------------------------------------------------------------===//
23 
24 /// Reports that the given type is missing the data layout information and
25 /// exits.
26 [[noreturn]] static void reportMissingDataLayout(Type type) {
27   std::string message;
28   llvm::raw_string_ostream os(message);
29   os << "neither the scoping op nor the type class provide data layout "
30         "information for "
31      << type;
32   llvm::report_fatal_error(Twine(os.str()));
33 }
34 
35 /// Returns the bitwidth of the index type if specified in the param list.
36 /// Assumes 64-bit index otherwise.
37 static unsigned getIndexBitwidth(DataLayoutEntryListRef params) {
38   if (params.empty())
39     return 64;
40   auto attr = params.front().getValue().cast<IntegerAttr>();
41   return attr.getValue().getZExtValue();
42 }
43 
44 unsigned
45 mlir::detail::getDefaultTypeSize(Type type, const DataLayout &dataLayout,
46                                  ArrayRef<DataLayoutEntryInterface> params) {
47   unsigned bits = getDefaultTypeSizeInBits(type, dataLayout, params);
48   return llvm::divideCeil(bits, 8);
49 }
50 
51 unsigned mlir::detail::getDefaultTypeSizeInBits(Type type,
52                                                 const DataLayout &dataLayout,
53                                                 DataLayoutEntryListRef params) {
54   if (type.isa<IntegerType, FloatType>())
55     return type.getIntOrFloatBitWidth();
56 
57   if (auto ctype = type.dyn_cast<ComplexType>()) {
58     auto et = ctype.getElementType();
59     auto innerAlignment =
60         getDefaultPreferredAlignment(et, dataLayout, params) * 8;
61     auto innerSize = getDefaultTypeSizeInBits(et, dataLayout, params);
62 
63     // Include padding required to align the imaginary value in the complex
64     // type.
65     return llvm::alignTo(innerSize, innerAlignment) + innerSize;
66   }
67 
68   // Index is an integer of some bitwidth.
69   if (type.isa<IndexType>())
70     return dataLayout.getTypeSizeInBits(
71         IntegerType::get(type.getContext(), getIndexBitwidth(params)));
72 
73   // Sizes of vector types are rounded up to those of types with closest
74   // power-of-two number of elements in the innermost dimension. We also assume
75   // there is no bit-packing at the moment element sizes are taken in bytes and
76   // multiplied with 8 bits.
77   // TODO: make this extensible.
78   if (auto vecType = type.dyn_cast<VectorType>())
79     return vecType.getNumElements() / vecType.getShape().back() *
80            llvm::PowerOf2Ceil(vecType.getShape().back()) *
81            dataLayout.getTypeSize(vecType.getElementType()) * 8;
82 
83   if (auto typeInterface = type.dyn_cast<DataLayoutTypeInterface>())
84     return typeInterface.getTypeSizeInBits(dataLayout, params);
85 
86   reportMissingDataLayout(type);
87 }
88 
89 unsigned mlir::detail::getDefaultABIAlignment(
90     Type type, const DataLayout &dataLayout,
91     ArrayRef<DataLayoutEntryInterface> params) {
92   // Natural alignment is the closest power-of-two number above.
93   if (type.isa<FloatType, VectorType>())
94     return llvm::PowerOf2Ceil(dataLayout.getTypeSize(type));
95 
96   // Index is an integer of some bitwidth.
97   if (type.isa<IndexType>())
98     return dataLayout.getTypeABIAlignment(
99         IntegerType::get(type.getContext(), getIndexBitwidth(params)));
100 
101   if (auto intType = type.dyn_cast<IntegerType>()) {
102     return intType.getWidth() < 64
103                ? llvm::PowerOf2Ceil(llvm::divideCeil(intType.getWidth(), 8))
104                : 4;
105   }
106 
107   if (auto ctype = type.dyn_cast<ComplexType>())
108     return getDefaultABIAlignment(ctype.getElementType(), dataLayout, params);
109 
110   if (auto typeInterface = type.dyn_cast<DataLayoutTypeInterface>())
111     return typeInterface.getABIAlignment(dataLayout, params);
112 
113   reportMissingDataLayout(type);
114 }
115 
116 unsigned mlir::detail::getDefaultPreferredAlignment(
117     Type type, const DataLayout &dataLayout,
118     ArrayRef<DataLayoutEntryInterface> params) {
119   // Preferred alignment is same as natural for floats and vectors.
120   if (type.isa<FloatType, VectorType>())
121     return dataLayout.getTypeABIAlignment(type);
122 
123   // Preferred alignment is the cloest power-of-two number above for integers
124   // (ABI alignment may be smaller).
125   if (type.isa<IntegerType, IndexType>())
126     return llvm::PowerOf2Ceil(dataLayout.getTypeSize(type));
127 
128   if (auto ctype = type.dyn_cast<ComplexType>())
129     return getDefaultPreferredAlignment(ctype.getElementType(), dataLayout,
130                                         params);
131 
132   if (auto typeInterface = type.dyn_cast<DataLayoutTypeInterface>())
133     return typeInterface.getPreferredAlignment(dataLayout, params);
134 
135   reportMissingDataLayout(type);
136 }
137 
138 DataLayoutEntryList
139 mlir::detail::filterEntriesForType(DataLayoutEntryListRef entries,
140                                    TypeID typeID) {
141   return llvm::to_vector<4>(llvm::make_filter_range(
142       entries, [typeID](DataLayoutEntryInterface entry) {
143         auto type = entry.getKey().dyn_cast<Type>();
144         return type && type.getTypeID() == typeID;
145       }));
146 }
147 
148 DataLayoutEntryInterface
149 mlir::detail::filterEntryForIdentifier(DataLayoutEntryListRef entries,
150                                        Identifier id) {
151   const auto *it = llvm::find_if(entries, [id](DataLayoutEntryInterface entry) {
152     if (!entry.getKey().is<Identifier>())
153       return false;
154     return entry.getKey().get<Identifier>() == id;
155   });
156   return it == entries.end() ? DataLayoutEntryInterface() : *it;
157 }
158 
159 static DataLayoutSpecInterface getSpec(Operation *operation) {
160   return llvm::TypeSwitch<Operation *, DataLayoutSpecInterface>(operation)
161       .Case<ModuleOp, DataLayoutOpInterface>(
162           [&](auto op) { return op.getDataLayoutSpec(); })
163       .Default([](Operation *) {
164         llvm_unreachable("expected an op with data layout spec");
165         return DataLayoutSpecInterface();
166       });
167 }
168 
169 /// Populates `opsWithLayout` with the list of proper ancestors of `leaf` that
170 /// are either modules or implement the `DataLayoutOpInterface`.
171 static void
172 collectParentLayouts(Operation *leaf,
173                      SmallVectorImpl<DataLayoutSpecInterface> &specs,
174                      SmallVectorImpl<Location> *opLocations = nullptr) {
175   if (!leaf)
176     return;
177 
178   for (Operation *parent = leaf->getParentOp(); parent != nullptr;
179        parent = parent->getParentOp()) {
180     llvm::TypeSwitch<Operation *>(parent)
181         .Case<ModuleOp>([&](ModuleOp op) {
182           // Skip top-level module op unless it has a layout. Top-level module
183           // without layout is most likely the one implicitly added by the
184           // parser and it doesn't have location. Top-level null specification
185           // would have had the same effect as not having a specification at all
186           // (using type defaults).
187           if (!op->getParentOp() && !op.getDataLayoutSpec())
188             return;
189           specs.push_back(op.getDataLayoutSpec());
190           if (opLocations)
191             opLocations->push_back(op.getLoc());
192         })
193         .Case<DataLayoutOpInterface>([&](DataLayoutOpInterface op) {
194           specs.push_back(op.getDataLayoutSpec());
195           if (opLocations)
196             opLocations->push_back(op.getLoc());
197         });
198   }
199 }
200 
201 /// Returns a layout spec that is a combination of the layout specs attached
202 /// to the given operation and all its ancestors.
203 static DataLayoutSpecInterface getCombinedDataLayout(Operation *leaf) {
204   if (!leaf)
205     return {};
206 
207   assert((isa<ModuleOp, DataLayoutOpInterface>(leaf)) &&
208          "expected an op with data layout spec");
209 
210   SmallVector<DataLayoutOpInterface> opsWithLayout;
211   SmallVector<DataLayoutSpecInterface> specs;
212   collectParentLayouts(leaf, specs);
213 
214   // Fast track if there are no ancestors.
215   if (specs.empty())
216     return getSpec(leaf);
217 
218   // Create the list of non-null specs (null/missing specs can be safely
219   // ignored) from the outermost to the innermost.
220   auto nonNullSpecs = llvm::to_vector<2>(llvm::make_filter_range(
221       llvm::reverse(specs),
222       [](DataLayoutSpecInterface iface) { return iface != nullptr; }));
223 
224   // Combine the specs using the innermost as anchor.
225   if (DataLayoutSpecInterface current = getSpec(leaf))
226     return current.combineWith(nonNullSpecs);
227   if (nonNullSpecs.empty())
228     return {};
229   return nonNullSpecs.back().combineWith(
230       llvm::makeArrayRef(nonNullSpecs).drop_back());
231 }
232 
233 LogicalResult mlir::detail::verifyDataLayoutOp(Operation *op) {
234   DataLayoutSpecInterface spec = getSpec(op);
235   // The layout specification may be missing and it's fine.
236   if (!spec)
237     return success();
238 
239   if (failed(spec.verifySpec(op->getLoc())))
240     return failure();
241   if (!getCombinedDataLayout(op)) {
242     InFlightDiagnostic diag =
243         op->emitError()
244         << "data layout does not combine with layouts of enclosing ops";
245     SmallVector<DataLayoutSpecInterface> specs;
246     SmallVector<Location> opLocations;
247     collectParentLayouts(op, specs, &opLocations);
248     for (Location loc : opLocations)
249       diag.attachNote(loc) << "enclosing op with data layout";
250     return diag;
251   }
252   return success();
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // DataLayout
257 //===----------------------------------------------------------------------===//
258 
259 template <typename OpTy>
260 void checkMissingLayout(DataLayoutSpecInterface originalLayout, OpTy op) {
261   if (!originalLayout) {
262     assert((!op || !op.getDataLayoutSpec()) &&
263            "could not compute layout information for an op (failed to "
264            "combine attributes?)");
265   }
266 }
267 
268 mlir::DataLayout::DataLayout() : DataLayout(ModuleOp()) {}
269 
270 mlir::DataLayout::DataLayout(DataLayoutOpInterface op)
271     : originalLayout(getCombinedDataLayout(op)), scope(op) {
272 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
273   checkMissingLayout(originalLayout, op);
274   collectParentLayouts(op, layoutStack);
275 #endif
276 }
277 
278 mlir::DataLayout::DataLayout(ModuleOp op)
279     : originalLayout(getCombinedDataLayout(op)), scope(op) {
280 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
281   checkMissingLayout(originalLayout, op);
282   collectParentLayouts(op, layoutStack);
283 #endif
284 }
285 
286 mlir::DataLayout mlir::DataLayout::closest(Operation *op) {
287   // Search the closest parent either being a module operation or implementing
288   // the data layout interface.
289   while (op) {
290     if (auto module = dyn_cast<ModuleOp>(op))
291       return DataLayout(module);
292     if (auto iface = dyn_cast<DataLayoutOpInterface>(op))
293       return DataLayout(iface);
294     op = op->getParentOp();
295   }
296   return DataLayout();
297 }
298 
299 void mlir::DataLayout::checkValid() const {
300 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
301   SmallVector<DataLayoutSpecInterface> specs;
302   collectParentLayouts(scope, specs);
303   assert(specs.size() == layoutStack.size() &&
304          "data layout object used, but no longer valid due to the change in "
305          "number of nested layouts");
306   for (auto pair : llvm::zip(specs, layoutStack)) {
307     Attribute newLayout = std::get<0>(pair);
308     Attribute origLayout = std::get<1>(pair);
309     assert(newLayout == origLayout &&
310            "data layout object used, but no longer valid "
311            "due to the change in layout attributes");
312   }
313 #endif
314   assert(((!scope && !this->originalLayout) ||
315           (scope && this->originalLayout == getCombinedDataLayout(scope))) &&
316          "data layout object used, but no longer valid due to the change in "
317          "layout spec");
318 }
319 
320 /// Looks up the value for the given type key in the given cache. If there is no
321 /// such value in the cache, compute it using the given callback and put it in
322 /// the cache before returning.
323 static unsigned cachedLookup(Type t, DenseMap<Type, unsigned> &cache,
324                              function_ref<unsigned(Type)> compute) {
325   auto it = cache.find(t);
326   if (it != cache.end())
327     return it->second;
328 
329   auto result = cache.try_emplace(t, compute(t));
330   return result.first->second;
331 }
332 
333 unsigned mlir::DataLayout::getTypeSize(Type t) const {
334   checkValid();
335   return cachedLookup(t, sizes, [&](Type ty) {
336     DataLayoutEntryList list;
337     if (originalLayout)
338       list = originalLayout.getSpecForType(ty.getTypeID());
339     if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
340       return iface.getTypeSize(ty, *this, list);
341     return detail::getDefaultTypeSize(ty, *this, list);
342   });
343 }
344 
345 unsigned mlir::DataLayout::getTypeSizeInBits(Type t) const {
346   checkValid();
347   return cachedLookup(t, bitsizes, [&](Type ty) {
348     DataLayoutEntryList list;
349     if (originalLayout)
350       list = originalLayout.getSpecForType(ty.getTypeID());
351     if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
352       return iface.getTypeSizeInBits(ty, *this, list);
353     return detail::getDefaultTypeSizeInBits(ty, *this, list);
354   });
355 }
356 
357 unsigned mlir::DataLayout::getTypeABIAlignment(Type t) const {
358   checkValid();
359   return cachedLookup(t, abiAlignments, [&](Type ty) {
360     DataLayoutEntryList list;
361     if (originalLayout)
362       list = originalLayout.getSpecForType(ty.getTypeID());
363     if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
364       return iface.getTypeABIAlignment(ty, *this, list);
365     return detail::getDefaultABIAlignment(ty, *this, list);
366   });
367 }
368 
369 unsigned mlir::DataLayout::getTypePreferredAlignment(Type t) const {
370   checkValid();
371   return cachedLookup(t, preferredAlignments, [&](Type ty) {
372     DataLayoutEntryList list;
373     if (originalLayout)
374       list = originalLayout.getSpecForType(ty.getTypeID());
375     if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
376       return iface.getTypePreferredAlignment(ty, *this, list);
377     return detail::getDefaultPreferredAlignment(ty, *this, list);
378   });
379 }
380 
381 //===----------------------------------------------------------------------===//
382 // DataLayoutSpecInterface
383 //===----------------------------------------------------------------------===//
384 
385 void DataLayoutSpecInterface::bucketEntriesByType(
386     DenseMap<TypeID, DataLayoutEntryList> &types,
387     DenseMap<Identifier, DataLayoutEntryInterface> &ids) {
388   for (DataLayoutEntryInterface entry : getEntries()) {
389     if (auto type = entry.getKey().dyn_cast<Type>())
390       types[type.getTypeID()].push_back(entry);
391     else
392       ids[entry.getKey().get<Identifier>()] = entry;
393   }
394 }
395 
396 LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
397                                                  Location loc) {
398   // First, verify individual entries.
399   for (DataLayoutEntryInterface entry : spec.getEntries())
400     if (failed(entry.verifyEntry(loc)))
401       return failure();
402 
403   // Second, dispatch verifications of entry groups to types or dialects they
404   // are are associated with.
405   DenseMap<TypeID, DataLayoutEntryList> types;
406   DenseMap<Identifier, DataLayoutEntryInterface> ids;
407   spec.bucketEntriesByType(types, ids);
408 
409   for (const auto &kvp : types) {
410     auto sampleType = kvp.second.front().getKey().get<Type>();
411     if (sampleType.isa<IndexType>()) {
412       assert(kvp.second.size() == 1 &&
413              "expected one data layout entry for non-parametric 'index' type");
414       if (!kvp.second.front().getValue().isa<IntegerAttr>())
415         return emitError(loc)
416                << "expected integer attribute in the data layout entry for "
417                << sampleType;
418       continue;
419     }
420 
421     if (isa<BuiltinDialect>(&sampleType.getDialect()))
422       return emitError(loc) << "unexpected data layout for a built-in type";
423 
424     auto dlType = sampleType.dyn_cast<DataLayoutTypeInterface>();
425     if (!dlType)
426       return emitError(loc)
427              << "data layout specified for a type that does not support it";
428     if (failed(dlType.verifyEntries(kvp.second, loc)))
429       return failure();
430   }
431 
432   for (const auto &kvp : ids) {
433     Identifier identifier = kvp.second.getKey().get<Identifier>();
434     Dialect *dialect = identifier.getReferencedDialect();
435 
436     // Ignore attributes that belong to an unknown dialect, the dialect may
437     // actually implement the relevant interface but we don't know about that.
438     if (!dialect)
439       continue;
440 
441     const auto *iface =
442         dialect->getRegisteredInterface<DataLayoutDialectInterface>();
443     if (!iface) {
444       return emitError(loc)
445              << "the '" << dialect->getNamespace()
446              << "' dialect does not support identifier data layout entries";
447     }
448     if (failed(iface->verifyEntry(kvp.second, loc)))
449       return failure();
450   }
451 
452   return success();
453 }
454 
455 #include "mlir/Interfaces/DataLayoutAttrInterface.cpp.inc"
456 #include "mlir/Interfaces/DataLayoutOpInterface.cpp.inc"
457 #include "mlir/Interfaces/DataLayoutTypeInterface.cpp.inc"
458