1 //===- DLTI.cpp - Data Layout And Target Info MLIR Dialect Implementation -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
16 
17 using namespace mlir;
18 
19 #include "mlir/Dialect/DLTI/DLTIDialect.cpp.inc"
20 
21 //===----------------------------------------------------------------------===//
22 // DataLayoutEntryAttr
23 //===----------------------------------------------------------------------===//
24 //
// Out-of-line definition of the static constexpr keyword member declared in
// the class. Under C++14 rules an ODR-used static constexpr data member needs
// exactly one such definition; under C++17 it is implicitly inline and this
// line is redundant but harmless.
constexpr const StringLiteral mlir::DataLayoutEntryAttr::kAttrKeyword;
26 
27 namespace mlir {
28 namespace impl {
29 class DataLayoutEntryStorage : public AttributeStorage {
30 public:
31   using KeyTy = std::pair<DataLayoutEntryKey, Attribute>;
32 
DataLayoutEntryStorage(DataLayoutEntryKey entryKey,Attribute value)33   DataLayoutEntryStorage(DataLayoutEntryKey entryKey, Attribute value)
34       : entryKey(entryKey), value(value) {}
35 
construct(AttributeStorageAllocator & allocator,const KeyTy & key)36   static DataLayoutEntryStorage *construct(AttributeStorageAllocator &allocator,
37                                            const KeyTy &key) {
38     return new (allocator.allocate<DataLayoutEntryStorage>())
39         DataLayoutEntryStorage(key.first, key.second);
40   }
41 
operator ==(const KeyTy & other) const42   bool operator==(const KeyTy &other) const {
43     return other.first == entryKey && other.second == value;
44   }
45 
46   DataLayoutEntryKey entryKey;
47   Attribute value;
48 };
49 } // namespace impl
50 } // namespace mlir
51 
get(StringAttr key,Attribute value)52 DataLayoutEntryAttr DataLayoutEntryAttr::get(StringAttr key, Attribute value) {
53   return Base::get(key.getContext(), key, value);
54 }
55 
get(Type key,Attribute value)56 DataLayoutEntryAttr DataLayoutEntryAttr::get(Type key, Attribute value) {
57   return Base::get(key.getContext(), key, value);
58 }
59 
getKey() const60 DataLayoutEntryKey DataLayoutEntryAttr::getKey() const {
61   return getImpl()->entryKey;
62 }
63 
getValue() const64 Attribute DataLayoutEntryAttr::getValue() const { return getImpl()->value; }
65 
66 /// Parses an attribute with syntax:
67 ///   attr ::= `#target.` `dl_entry` `<` (type | quoted-string) `,` attr `>`
parse(AsmParser & parser)68 DataLayoutEntryAttr DataLayoutEntryAttr::parse(AsmParser &parser) {
69   if (failed(parser.parseLess()))
70     return {};
71 
72   Type type = nullptr;
73   std::string identifier;
74   SMLoc idLoc = parser.getCurrentLocation();
75   OptionalParseResult parsedType = parser.parseOptionalType(type);
76   if (parsedType.hasValue() && failed(parsedType.getValue()))
77     return {};
78   if (!parsedType.hasValue()) {
79     OptionalParseResult parsedString = parser.parseOptionalString(&identifier);
80     if (!parsedString.hasValue() || failed(parsedString.getValue())) {
81       parser.emitError(idLoc) << "expected a type or a quoted string";
82       return {};
83     }
84   }
85 
86   Attribute value;
87   if (failed(parser.parseComma()) || failed(parser.parseAttribute(value)) ||
88       failed(parser.parseGreater()))
89     return {};
90 
91   return type ? get(type, value)
92               : get(parser.getBuilder().getStringAttr(identifier), value);
93 }
94 
print(AsmPrinter & os) const95 void DataLayoutEntryAttr::print(AsmPrinter &os) const {
96   os << DataLayoutEntryAttr::kAttrKeyword << "<";
97   if (auto type = getKey().dyn_cast<Type>())
98     os << type;
99   else
100     os << "\"" << getKey().get<StringAttr>().strref() << "\"";
101   os << ", " << getValue() << ">";
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // DataLayoutSpecAttr
106 //===----------------------------------------------------------------------===//
107 //
// Out-of-line definition of the static constexpr keyword member declared in
// the class. Needed for ODR-use under C++14 rules; redundant (but harmless)
// under C++17, where such members are implicitly inline.
constexpr const StringLiteral mlir::DataLayoutSpecAttr::kAttrKeyword;
109 
110 namespace mlir {
111 namespace impl {
112 class DataLayoutSpecStorage : public AttributeStorage {
113 public:
114   using KeyTy = ArrayRef<DataLayoutEntryInterface>;
115 
DataLayoutSpecStorage(ArrayRef<DataLayoutEntryInterface> entries)116   DataLayoutSpecStorage(ArrayRef<DataLayoutEntryInterface> entries)
117       : entries(entries) {}
118 
operator ==(const KeyTy & key) const119   bool operator==(const KeyTy &key) const { return key == entries; }
120 
construct(AttributeStorageAllocator & allocator,const KeyTy & key)121   static DataLayoutSpecStorage *construct(AttributeStorageAllocator &allocator,
122                                           const KeyTy &key) {
123     return new (allocator.allocate<DataLayoutSpecStorage>())
124         DataLayoutSpecStorage(allocator.copyInto(key));
125   }
126 
127   ArrayRef<DataLayoutEntryInterface> entries;
128 };
129 } // namespace impl
130 } // namespace mlir
131 
132 DataLayoutSpecAttr
get(MLIRContext * ctx,ArrayRef<DataLayoutEntryInterface> entries)133 DataLayoutSpecAttr::get(MLIRContext *ctx,
134                         ArrayRef<DataLayoutEntryInterface> entries) {
135   return Base::get(ctx, entries);
136 }
137 
138 DataLayoutSpecAttr
getChecked(function_ref<InFlightDiagnostic ()> emitError,MLIRContext * context,ArrayRef<DataLayoutEntryInterface> entries)139 DataLayoutSpecAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
140                                MLIRContext *context,
141                                ArrayRef<DataLayoutEntryInterface> entries) {
142   return Base::getChecked(emitError, context, entries);
143 }
144 
145 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<DataLayoutEntryInterface> entries)146 DataLayoutSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
147                            ArrayRef<DataLayoutEntryInterface> entries) {
148   DenseSet<Type> types;
149   DenseSet<StringAttr> ids;
150   for (DataLayoutEntryInterface entry : entries) {
151     if (auto type = entry.getKey().dyn_cast<Type>()) {
152       if (!types.insert(type).second)
153         return emitError() << "repeated layout entry key: " << type;
154     } else {
155       auto id = entry.getKey().get<StringAttr>();
156       if (!ids.insert(id).second)
157         return emitError() << "repeated layout entry key: " << id.getValue();
158     }
159   }
160   return success();
161 }
162 
163 /// Given a list of old and a list of new entries, overwrites old entries with
164 /// new ones if they have matching keys, appends new entries to the old entry
165 /// list otherwise.
166 static void
overwriteDuplicateEntries(SmallVectorImpl<DataLayoutEntryInterface> & oldEntries,ArrayRef<DataLayoutEntryInterface> newEntries)167 overwriteDuplicateEntries(SmallVectorImpl<DataLayoutEntryInterface> &oldEntries,
168                           ArrayRef<DataLayoutEntryInterface> newEntries) {
169   unsigned oldEntriesSize = oldEntries.size();
170   for (DataLayoutEntryInterface entry : newEntries) {
171     // We expect a small (dozens) number of entries, so it is practically
172     // cheaper to iterate over the list linearly rather than to create an
173     // auxiliary hashmap to avoid duplication. Also note that we never need to
174     // check for duplicate keys the values that were added from `newEntries`.
175     bool replaced = false;
176     for (unsigned i = 0; i < oldEntriesSize; ++i) {
177       if (oldEntries[i].getKey() == entry.getKey()) {
178         oldEntries[i] = entry;
179         replaced = true;
180         break;
181       }
182     }
183     if (!replaced)
184       oldEntries.push_back(entry);
185   }
186 }
187 
188 /// Combines a data layout spec into the given lists of entries organized by
189 /// type class and identifier, overwriting them if necessary. Fails to combine
190 /// if the two entries with identical keys are not compatible.
191 static LogicalResult
combineOneSpec(DataLayoutSpecInterface spec,DenseMap<TypeID,DataLayoutEntryList> & entriesForType,DenseMap<StringAttr,DataLayoutEntryInterface> & entriesForID)192 combineOneSpec(DataLayoutSpecInterface spec,
193                DenseMap<TypeID, DataLayoutEntryList> &entriesForType,
194                DenseMap<StringAttr, DataLayoutEntryInterface> &entriesForID) {
195   // A missing spec should be fine.
196   if (!spec)
197     return success();
198 
199   DenseMap<TypeID, DataLayoutEntryList> newEntriesForType;
200   DenseMap<StringAttr, DataLayoutEntryInterface> newEntriesForID;
201   spec.bucketEntriesByType(newEntriesForType, newEntriesForID);
202 
203   // Try overwriting the old entries with the new ones.
204   for (auto &kvp : newEntriesForType) {
205     if (!entriesForType.count(kvp.first)) {
206       entriesForType[kvp.first] = std::move(kvp.second);
207       continue;
208     }
209 
210     Type typeSample = kvp.second.front().getKey().get<Type>();
211     assert(&typeSample.getDialect() !=
212                typeSample.getContext()->getLoadedDialect<BuiltinDialect>() &&
213            "unexpected data layout entry for built-in type");
214 
215     auto interface = typeSample.cast<DataLayoutTypeInterface>();
216     if (!interface.areCompatible(entriesForType.lookup(kvp.first), kvp.second))
217       return failure();
218 
219     overwriteDuplicateEntries(entriesForType[kvp.first], kvp.second);
220   }
221 
222   for (const auto &kvp : newEntriesForID) {
223     StringAttr id = kvp.second.getKey().get<StringAttr>();
224     Dialect *dialect = id.getReferencedDialect();
225     if (!entriesForID.count(id)) {
226       entriesForID[id] = kvp.second;
227       continue;
228     }
229 
230     // Attempt to combine the enties using the dialect interface. If the
231     // dialect is not loaded for some reason, use the default combinator
232     // that conservatively accepts identical entries only.
233     entriesForID[id] =
234         dialect ? cast<DataLayoutDialectInterface>(dialect)->combine(
235                       entriesForID[id], kvp.second)
236                 : DataLayoutDialectInterface::defaultCombine(entriesForID[id],
237                                                              kvp.second);
238     if (!entriesForID[id])
239       return failure();
240   }
241 
242   return success();
243 }
244 
245 DataLayoutSpecAttr
combineWith(ArrayRef<DataLayoutSpecInterface> specs) const246 DataLayoutSpecAttr::combineWith(ArrayRef<DataLayoutSpecInterface> specs) const {
247   // Only combine with attributes of the same kind.
248   // TODO: reconsider this when the need arises.
249   if (llvm::any_of(specs, [](DataLayoutSpecInterface spec) {
250         return !spec.isa<DataLayoutSpecAttr>();
251       }))
252     return {};
253 
254   // Combine all specs in order, with `this` being the last one.
255   DenseMap<TypeID, DataLayoutEntryList> entriesForType;
256   DenseMap<StringAttr, DataLayoutEntryInterface> entriesForID;
257   for (DataLayoutSpecInterface spec : specs)
258     if (failed(combineOneSpec(spec, entriesForType, entriesForID)))
259       return nullptr;
260   if (failed(combineOneSpec(*this, entriesForType, entriesForID)))
261     return nullptr;
262 
263   // Rebuild the linear list of entries.
264   SmallVector<DataLayoutEntryInterface> entries;
265   llvm::append_range(entries, llvm::make_second_range(entriesForID));
266   for (const auto &kvp : entriesForType)
267     llvm::append_range(entries, kvp.getSecond());
268 
269   return DataLayoutSpecAttr::get(getContext(), entries);
270 }
271 
getEntries() const272 DataLayoutEntryListRef DataLayoutSpecAttr::getEntries() const {
273   return getImpl()->entries;
274 }
275 
276 /// Parses an attribute with syntax
277 ///   attr ::= `#target.` `dl_spec` `<` attr-list? `>`
278 ///   attr-list ::= attr
279 ///               | attr `,` attr-list
parse(AsmParser & parser)280 DataLayoutSpecAttr DataLayoutSpecAttr::parse(AsmParser &parser) {
281   if (failed(parser.parseLess()))
282     return {};
283 
284   // Empty spec.
285   if (succeeded(parser.parseOptionalGreater()))
286     return get(parser.getContext(), {});
287 
288   SmallVector<DataLayoutEntryInterface> entries;
289   if (parser.parseCommaSeparatedList(
290           [&]() { return parser.parseAttribute(entries.emplace_back()); }) ||
291       parser.parseGreater())
292     return {};
293 
294   return getChecked([&] { return parser.emitError(parser.getNameLoc()); },
295                     parser.getContext(), entries);
296 }
297 
print(AsmPrinter & os) const298 void DataLayoutSpecAttr::print(AsmPrinter &os) const {
299   os << DataLayoutSpecAttr::kAttrKeyword << "<";
300   llvm::interleaveComma(getEntries(), os);
301   os << ">";
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // DLTIDialect
306 //===----------------------------------------------------------------------===//
307 
// Out-of-line definitions of the DLTIDialect static constexpr string members:
// the name of the data layout spec operation attribute, and the endianness
// entry key together with its two accepted values. Required for ODR-use under
// C++14 rules; redundant but harmless under C++17.
constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutAttrName;
constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessKey;
constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessBig;
constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessLittle;
312 
313 namespace {
314 class TargetDataLayoutInterface : public DataLayoutDialectInterface {
315 public:
316   using DataLayoutDialectInterface::DataLayoutDialectInterface;
317 
verifyEntry(DataLayoutEntryInterface entry,Location loc) const318   LogicalResult verifyEntry(DataLayoutEntryInterface entry,
319                             Location loc) const final {
320     StringRef entryName = entry.getKey().get<StringAttr>().strref();
321     if (entryName == DLTIDialect::kDataLayoutEndiannessKey) {
322       auto value = entry.getValue().dyn_cast<StringAttr>();
323       if (value &&
324           (value.getValue() == DLTIDialect::kDataLayoutEndiannessBig ||
325            value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle))
326         return success();
327       return emitError(loc) << "'" << entryName
328                             << "' data layout entry is expected to be either '"
329                             << DLTIDialect::kDataLayoutEndiannessBig << "' or '"
330                             << DLTIDialect::kDataLayoutEndiannessLittle << "'";
331     }
332     return emitError(loc) << "unknown data layout entry name: " << entryName;
333   }
334 };
335 } // namespace
336 
initialize()337 void DLTIDialect::initialize() {
338   addAttributes<DataLayoutEntryAttr, DataLayoutSpecAttr>();
339   addInterfaces<TargetDataLayoutInterface>();
340 }
341 
parseAttribute(DialectAsmParser & parser,Type type) const342 Attribute DLTIDialect::parseAttribute(DialectAsmParser &parser,
343                                       Type type) const {
344   StringRef attrKind;
345   if (parser.parseKeyword(&attrKind))
346     return {};
347 
348   if (attrKind == DataLayoutEntryAttr::kAttrKeyword)
349     return DataLayoutEntryAttr::parse(parser);
350   if (attrKind == DataLayoutSpecAttr::kAttrKeyword)
351     return DataLayoutSpecAttr::parse(parser);
352 
353   parser.emitError(parser.getNameLoc(), "unknown attrribute type: ")
354       << attrKind;
355   return {};
356 }
357 
printAttribute(Attribute attr,DialectAsmPrinter & os) const358 void DLTIDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
359   llvm::TypeSwitch<Attribute>(attr)
360       .Case<DataLayoutEntryAttr, DataLayoutSpecAttr>(
361           [&](auto a) { a.print(os); })
362       .Default([](Attribute) { llvm_unreachable("unknown attribute kind"); });
363 }
364 
verifyOperationAttribute(Operation * op,NamedAttribute attr)365 LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op,
366                                                     NamedAttribute attr) {
367   if (attr.getName() == DLTIDialect::kDataLayoutAttrName) {
368     if (!attr.getValue().isa<DataLayoutSpecAttr>()) {
369       return op->emitError() << "'" << DLTIDialect::kDataLayoutAttrName
370                              << "' is expected to be a #dlti.dl_spec attribute";
371     }
372     if (isa<ModuleOp>(op))
373       return detail::verifyDataLayoutOp(op);
374     return success();
375   }
376 
377   return op->emitError() << "attribute '" << attr.getName().getValue()
378                          << "' not supported by dialect";
379 }
380