1 //===- CodeGenHelpers.cpp - MLIR op definitions generator ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpDefinitionsGen uses the description of operations to generate C++
10 // definitions for ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/CodeGenHelpers.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "mlir/TableGen/Pattern.h"
17 #include "llvm/ADT/SetVector.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/Support/Path.h"
20 #include "llvm/TableGen/Record.h"
21 
22 using namespace llvm;
23 using namespace mlir;
24 using namespace mlir::tblgen;
25 
26 /// Generate a unique label based on the current file name to prevent name
27 /// collisions if multiple generated files are included at once.
28 static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
29   // Use the input file name when generating a unique name.
30   std::string inputFilename = records.getInputFilename();
31 
32   // Drop all but the base filename.
33   StringRef nameRef = llvm::sys::path::filename(inputFilename);
34   nameRef.consume_back(".td");
35 
36   // Sanitize any invalid characters.
37   std::string uniqueName;
38   for (char c : nameRef) {
39     if (llvm::isAlnum(c) || c == '_')
40       uniqueName.push_back(c);
41     else
42       uniqueName.append(llvm::utohexstr((unsigned char)c));
43   }
44   return uniqueName;
45 }
46 
47 StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
48     raw_ostream &os, const llvm::RecordKeeper &records)
49     : os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {}
50 
51 void StaticVerifierFunctionEmitter::emitOpConstraints(
52     ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
53   collectOpConstraints(opDefs);
54   if (emitDecl)
55     return;
56 
57   NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
58   emitTypeConstraints();
59   emitAttrConstraints();
60   emitSuccessorConstraints();
61   emitRegionConstraints();
62 }
63 
64 void StaticVerifierFunctionEmitter::emitPatternConstraints(
65     const llvm::DenseSet<DagLeaf> &constraints) {
66   collectPatternConstraints(constraints);
67   emitPatternConstraints();
68 }
69 
70 //===----------------------------------------------------------------------===//
71 // Constraint Getters
72 
73 StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn(
74     const Constraint &constraint) const {
75   auto it = typeConstraints.find(constraint);
76   assert(it != typeConstraints.end() && "expected to find a type constraint");
77   return it->second;
78 }
79 
80 // Find a uniqued attribute constraint. Since not all attribute constraints can
81 // be uniqued, return None if one was not found.
82 Optional<StringRef> StaticVerifierFunctionEmitter::getAttrConstraintFn(
83     const Constraint &constraint) const {
84   auto it = attrConstraints.find(constraint);
85   return it == attrConstraints.end() ? Optional<StringRef>()
86                                      : StringRef(it->second);
87 }
88 
89 StringRef StaticVerifierFunctionEmitter::getSuccessorConstraintFn(
90     const Constraint &constraint) const {
91   auto it = successorConstraints.find(constraint);
92   assert(it != successorConstraints.end() &&
93          "expected to find a sucessor constraint");
94   return it->second;
95 }
96 
97 StringRef StaticVerifierFunctionEmitter::getRegionConstraintFn(
98     const Constraint &constraint) const {
99   auto it = regionConstraints.find(constraint);
100   assert(it != regionConstraints.end() &&
101          "expected to find a region constraint");
102   return it->second;
103 }
104 
105 //===----------------------------------------------------------------------===//
106 // Constraint Emission
107 
108 /// Code templates for emitting type, attribute, successor, and region
/// constraints. Each of these templates requires the following arguments:
110 ///
111 /// {0}: The unique constraint name.
112 /// {1}: The constraint code.
113 /// {2}: The constraint description.
114 
/// Code for a type constraint. These may be called on the type of either
/// operands or results.
///
/// The generated function receives the op being verified, the `type` to
/// check, a `valueKind` string used as the start of the op error, and the
/// index of the checked value. If condition {1} does not hold, it emits an op
/// error naming the expected summary {2} and the actual type, and returns
/// failure. {2} is spliced into a C++ string literal, so it must already be
/// escaped.
static const char *const typeConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind,
    unsigned valueIndex) {
  if (!({1})) {
    return op->emitOpError(valueKind) << " #" << valueIndex
        << " must be {2}, but got " << type;
  }
  return ::mlir::success();
}
)";
128 
/// Code for an attribute constraint. These may be called from ops only.
/// Attribute constraints cannot reference anything other than `$_self` and
/// `$_op`.
///
/// A null `attr` is treated as satisfying the constraint (the condition is
/// only evaluated when `attr` is non-null). On failure, the op error names
/// the attribute via `attrName` and quotes the summary {2}.
///
/// TODO: Unique constraints for adaptors. However, most Adaptor::verify
/// functions are stripped anyways.
static const char *const attrConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) {
  if (attr && !({1})) {
    return op->emitOpError("attribute '") << attrName
        << "' failed to satisfy constraint: {2}";
  }
  return ::mlir::success();
}
)";
145 
/// Code for a successor constraint.
///
/// The generated function checks condition {1} against `successor`. On
/// failure it emits an op error of the form
///   successor #<index> ('<name>') failed to verify constraint: <summary>
/// Note the quote closes before the parenthesis; previously the two were
/// swapped, producing `('<name>)'`.
static const char *const successorConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Block *successor,
    ::llvm::StringRef successorName, unsigned successorIndex) {
  if (!({1})) {
    return op->emitOpError("successor #") << successorIndex << " ('"
        << successorName << "') failed to verify constraint: {2}";
  }
  return ::mlir::success();
}
)";
158 
/// Code for a region constraint. Callers will need to pass in the region's name
/// for emitting an error message.
///
/// On failure the op error always includes the region index. The quoted
/// region name is included only when `regionName` is non-empty, followed by
/// the constraint summary {2}.
static const char *const regionConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName,
    unsigned regionIndex) {
  if (!({1})) {
    return op->emitOpError("region #") << regionIndex
        << (regionName.empty() ? " " : " ('" + regionName + "') ")
        << "failed to verify constraint: {2}";
  }
  return ::mlir::success();
}
)";
173 
/// Code for a pattern type or attribute constraint.
///
/// Unlike the op verifier templates, a failure here is not an error. It is
/// reported through `rewriter.notifyMatchFailure` with a diagnostic made of
/// the caller-supplied `failureStr`, a colon, and the summary {2}.
///
/// {3}: "Type type" or "Attribute attr".
static const char *const patternAttrOrTypeConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::PatternRewriter &rewriter, ::mlir::Operation *op, ::mlir::{3},
    ::llvm::StringRef failureStr) {
  if (!({1})) {
    return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
      diag << failureStr << ": {2}";
    });
  }
  return ::mlir::success();
}
)";
189 
190 void StaticVerifierFunctionEmitter::emitConstraints(
191     const ConstraintMap &constraints, StringRef selfName,
192     const char *const codeTemplate) {
193   FmtContext ctx;
194   ctx.withOp("*op").withSelf(selfName);
195   for (auto &it : constraints) {
196     os << formatv(codeTemplate, it.second,
197                   tgfmt(it.first.getConditionTemplate(), &ctx),
198                   escapeString(it.first.getSummary()));
199   }
200 }
201 
202 void StaticVerifierFunctionEmitter::emitTypeConstraints() {
203   emitConstraints(typeConstraints, "type", typeConstraintCode);
204 }
205 
206 void StaticVerifierFunctionEmitter::emitAttrConstraints() {
207   emitConstraints(attrConstraints, "attr", attrConstraintCode);
208 }
209 
210 void StaticVerifierFunctionEmitter::emitSuccessorConstraints() {
211   emitConstraints(successorConstraints, "successor", successorConstraintCode);
212 }
213 
214 void StaticVerifierFunctionEmitter::emitRegionConstraints() {
215   emitConstraints(regionConstraints, "region", regionConstraintCode);
216 }
217 
218 void StaticVerifierFunctionEmitter::emitPatternConstraints() {
219   FmtContext ctx;
220   ctx.withOp("*op").withBuilder("rewriter").withSelf("type");
221   for (auto &it : typeConstraints) {
222     os << formatv(patternAttrOrTypeConstraintCode, it.second,
223                   tgfmt(it.first.getConditionTemplate(), &ctx),
224                   escapeString(it.first.getSummary()), "Type type");
225   }
226   ctx.withSelf("attr");
227   for (auto &it : attrConstraints) {
228     os << formatv(patternAttrOrTypeConstraintCode, it.second,
229                   tgfmt(it.first.getConditionTemplate(), &ctx),
230                   escapeString(it.first.getSummary()), "Attribute attr");
231   }
232 }
233 
234 //===----------------------------------------------------------------------===//
235 // Constraint Uniquing
236 
237 using RecordDenseMapInfo = llvm::DenseMapInfo<const llvm::Record *>;
238 
239 Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getEmptyKey() {
240   return Constraint(RecordDenseMapInfo::getEmptyKey(),
241                     Constraint::CK_Uncategorized);
242 }
243 
244 Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getTombstoneKey() {
245   return Constraint(RecordDenseMapInfo::getTombstoneKey(),
246                     Constraint::CK_Uncategorized);
247 }
248 
249 unsigned StaticVerifierFunctionEmitter::ConstraintUniquer::getHashValue(
250     Constraint constraint) {
251   if (constraint == getEmptyKey())
252     return RecordDenseMapInfo::getHashValue(RecordDenseMapInfo::getEmptyKey());
253   if (constraint == getTombstoneKey()) {
254     return RecordDenseMapInfo::getHashValue(
255         RecordDenseMapInfo::getTombstoneKey());
256   }
257   return llvm::hash_combine(constraint.getPredicate(), constraint.getSummary());
258 }
259 
260 bool StaticVerifierFunctionEmitter::ConstraintUniquer::isEqual(Constraint lhs,
261                                                                Constraint rhs) {
262   if (lhs == rhs)
263     return true;
264   if (lhs == getEmptyKey() || lhs == getTombstoneKey())
265     return false;
266   if (rhs == getEmptyKey() || rhs == getTombstoneKey())
267     return false;
268   return lhs.getPredicate() == rhs.getPredicate() &&
269          lhs.getSummary() == rhs.getSummary();
270 }
271 
272 /// An attribute constraint that references anything other than itself and the
273 /// current op cannot be generically extracted into a function. Most
274 /// prohibitive are operands and results, which require calls to
275 /// `getODSOperands` or `getODSResults`. Attribute references are tricky too
276 /// because ops use cached identifiers.
277 static bool canUniqueAttrConstraint(Attribute attr) {
278   FmtContext ctx;
279   auto test =
280       tgfmt(attr.getConditionTemplate(), &ctx.withSelf("attr").withOp("*op"))
281           .str();
282   return !StringRef(test).contains("<no-subst-found>");
283 }
284 
285 std::string StaticVerifierFunctionEmitter::getUniqueName(StringRef kind,
286                                                          unsigned index) {
287   return ("__mlir_ods_local_" + kind + "_constraint_" + uniqueOutputLabel +
288           Twine(index))
289       .str();
290 }
291 
292 void StaticVerifierFunctionEmitter::collectConstraint(ConstraintMap &map,
293                                                       StringRef kind,
294                                                       Constraint constraint) {
295   auto it = map.find(constraint);
296   if (it == map.end())
297     map.insert({constraint, getUniqueName(kind, map.size())});
298 }
299 
300 void StaticVerifierFunctionEmitter::collectOpConstraints(
301     ArrayRef<Record *> opDefs) {
302   const auto collectTypeConstraints = [&](Operator::value_range values) {
303     for (const NamedTypeConstraint &value : values)
304       if (value.hasPredicate())
305         collectConstraint(typeConstraints, "type", value.constraint);
306   };
307 
308   for (Record *def : opDefs) {
309     Operator op(*def);
310     /// Collect type constraints.
311     collectTypeConstraints(op.getOperands());
312     collectTypeConstraints(op.getResults());
313     /// Collect attribute constraints.
314     for (const NamedAttribute &namedAttr : op.getAttributes()) {
315       if (!namedAttr.attr.getPredicate().isNull() &&
316           !namedAttr.attr.isDerivedAttr() &&
317           canUniqueAttrConstraint(namedAttr.attr))
318         collectConstraint(attrConstraints, "attr", namedAttr.attr);
319     }
320     /// Collect successor constraints.
321     for (const NamedSuccessor &successor : op.getSuccessors()) {
322       if (!successor.constraint.getPredicate().isNull()) {
323         collectConstraint(successorConstraints, "successor",
324                           successor.constraint);
325       }
326     }
327     /// Collect region constraints.
328     for (const NamedRegion &region : op.getRegions())
329       if (!region.constraint.getPredicate().isNull())
330         collectConstraint(regionConstraints, "region", region.constraint);
331   }
332 }
333 
334 void StaticVerifierFunctionEmitter::collectPatternConstraints(
335     const llvm::DenseSet<DagLeaf> &constraints) {
336   for (auto &leaf : constraints) {
337     assert(leaf.isOperandMatcher() || leaf.isAttrMatcher());
338     collectConstraint(
339         leaf.isOperandMatcher() ? typeConstraints : attrConstraints,
340         leaf.isOperandMatcher() ? "type" : "attr", leaf.getAsConstraint());
341   }
342 }
343 
344 //===----------------------------------------------------------------------===//
345 // Public Utility Functions
346 //===----------------------------------------------------------------------===//
347 
348 std::string mlir::tblgen::escapeString(StringRef value) {
349   std::string ret;
350   llvm::raw_string_ostream os(ret);
351   os.write_escaped(value);
352   return os.str();
353 }
354