1 //===- CodeGenHelpers.cpp - MLIR op definitions generator ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpDefinitionsGen uses the description of operations to generate C++
10 // definitions for ops.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/TableGen/CodeGenHelpers.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "mlir/TableGen/Pattern.h"
17 #include "llvm/ADT/SetVector.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/Support/Path.h"
20 #include "llvm/TableGen/Record.h"
21
22 using namespace llvm;
23 using namespace mlir;
24 using namespace mlir::tblgen;
25
26 /// Generate a unique label based on the current file name to prevent name
27 /// collisions if multiple generated files are included at once.
getUniqueOutputLabel(const llvm::RecordKeeper & records)28 static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
29 // Use the input file name when generating a unique name.
30 std::string inputFilename = records.getInputFilename();
31
32 // Drop all but the base filename.
33 StringRef nameRef = llvm::sys::path::filename(inputFilename);
34 nameRef.consume_back(".td");
35
36 // Sanitize any invalid characters.
37 std::string uniqueName;
38 for (char c : nameRef) {
39 if (llvm::isAlnum(c) || c == '_')
40 uniqueName.push_back(c);
41 else
42 uniqueName.append(llvm::utohexstr((unsigned char)c));
43 }
44 return uniqueName;
45 }
46
StaticVerifierFunctionEmitter(raw_ostream & os,const llvm::RecordKeeper & records)47 StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
48 raw_ostream &os, const llvm::RecordKeeper &records)
49 : os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {}
50
emitOpConstraints(ArrayRef<llvm::Record * > opDefs,bool emitDecl)51 void StaticVerifierFunctionEmitter::emitOpConstraints(
52 ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
53 collectOpConstraints(opDefs);
54 if (emitDecl)
55 return;
56
57 NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
58 emitTypeConstraints();
59 emitAttrConstraints();
60 emitSuccessorConstraints();
61 emitRegionConstraints();
62 }
63
emitPatternConstraints(const llvm::ArrayRef<DagLeaf> constraints)64 void StaticVerifierFunctionEmitter::emitPatternConstraints(
65 const llvm::ArrayRef<DagLeaf> constraints) {
66 collectPatternConstraints(constraints);
67 emitPatternConstraints();
68 }
69
70 //===----------------------------------------------------------------------===//
71 // Constraint Getters
72
getTypeConstraintFn(const Constraint & constraint) const73 StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn(
74 const Constraint &constraint) const {
75 auto it = typeConstraints.find(constraint);
76 assert(it != typeConstraints.end() && "expected to find a type constraint");
77 return it->second;
78 }
79
80 // Find a uniqued attribute constraint. Since not all attribute constraints can
81 // be uniqued, return None if one was not found.
getAttrConstraintFn(const Constraint & constraint) const82 Optional<StringRef> StaticVerifierFunctionEmitter::getAttrConstraintFn(
83 const Constraint &constraint) const {
84 auto it = attrConstraints.find(constraint);
85 return it == attrConstraints.end() ? Optional<StringRef>()
86 : StringRef(it->second);
87 }
88
getSuccessorConstraintFn(const Constraint & constraint) const89 StringRef StaticVerifierFunctionEmitter::getSuccessorConstraintFn(
90 const Constraint &constraint) const {
91 auto it = successorConstraints.find(constraint);
92 assert(it != successorConstraints.end() &&
93 "expected to find a sucessor constraint");
94 return it->second;
95 }
96
getRegionConstraintFn(const Constraint & constraint) const97 StringRef StaticVerifierFunctionEmitter::getRegionConstraintFn(
98 const Constraint &constraint) const {
99 auto it = regionConstraints.find(constraint);
100 assert(it != regionConstraints.end() &&
101 "expected to find a region constraint");
102 return it->second;
103 }
104
105 //===----------------------------------------------------------------------===//
106 // Constraint Emission
107
/// Code templates for emitting type, attribute, successor, and region
/// constraints. Each of these templates requires the following arguments:
///
/// {0}: The unique constraint name.
/// {1}: The constraint code (the substituted C++ condition expression).
/// {2}: The constraint description (the escaped constraint summary).
114
/// Code for a type constraint. These may be called on the type of either
/// operands or results.
///
/// The generated function checks `type` against the condition. `valueKind`
/// (presumably "operand" or "result"; confirm at call sites) and `valueIndex`
/// are used only to build the error message.
static const char *const typeConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind,
    unsigned valueIndex) {
  if (!({1})) {
    return op->emitOpError(valueKind) << " #" << valueIndex
        << " must be {2}, but got " << type;
  }
  return ::mlir::success();
}
)";
128
/// Code for an attribute constraint. These may be called from ops only.
/// Attribute constraints cannot reference anything other than `$_self` and
/// `$_op`.
///
/// The generated function treats a null `attr` as satisfying the constraint:
/// the condition is only evaluated when the attribute is present.
/// `attrName` is used only in the error message.
///
/// TODO: Unique constraints for adaptors. However, most Adaptor::verify
/// functions are stripped anyway.
static const char *const attrConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) {
  if (attr && !({1})) {
    return op->emitOpError("attribute '") << attrName
        << "' failed to satisfy constraint: {2}";
  }
  return ::mlir::success();
}
)";
145
/// Code for a successor constraint.
///
/// The generated function checks `successor` against the condition and, on
/// failure, reports `successorIndex` and `successorName` in the diagnostic as
/// "successor #<index> ('<name>') failed to verify constraint: ...".
/// This matches the "('<name>')" formatting used by the region template.
static const char *const successorConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Block *successor,
    ::llvm::StringRef successorName, unsigned successorIndex) {
  if (!({1})) {
    return op->emitOpError("successor #") << successorIndex << " ('"
        << successorName << "') failed to verify constraint: {2}";
  }
  return ::mlir::success();
}
)";
158
/// Code for a region constraint. Callers will need to pass in the region's name
/// for emitting an error message. When `regionName` is empty, the
/// parenthesized name is omitted from the diagnostic.
///
/// The region parameter is taken by reference (`::mlir::Region &region`).
static const char *const regionConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName,
    unsigned regionIndex) {
  if (!({1})) {
    return op->emitOpError("region #") << regionIndex
        << (regionName.empty() ? " " : " ('" + regionName + "') ")
        << "failed to verify constraint: {2}";
  }
  return ::mlir::success();
}
)";
173
/// Code for a pattern type or attribute constraint.
///
/// {3}: "Type type" or "Attribute attr".
///
/// Instead of emitting an op error, the generated function reports a failed
/// match through `rewriter.notifyMatchFailure`. The diagnostic is
/// `failureStr` followed by the constraint description.
static const char *const patternAttrOrTypeConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::PatternRewriter &rewriter, ::mlir::Operation *op, ::mlir::{3},
    ::llvm::StringRef failureStr) {
  if (!({1})) {
    return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
      diag << failureStr << ": {2}";
    });
  }
  return ::mlir::success();
}
)";
189
emitConstraints(const ConstraintMap & constraints,StringRef selfName,const char * const codeTemplate)190 void StaticVerifierFunctionEmitter::emitConstraints(
191 const ConstraintMap &constraints, StringRef selfName,
192 const char *const codeTemplate) {
193 FmtContext ctx;
194 ctx.withOp("*op").withSelf(selfName);
195 for (auto &it : constraints) {
196 os << formatv(codeTemplate, it.second,
197 tgfmt(it.first.getConditionTemplate(), &ctx),
198 escapeString(it.first.getSummary()));
199 }
200 }
201
emitTypeConstraints()202 void StaticVerifierFunctionEmitter::emitTypeConstraints() {
203 emitConstraints(typeConstraints, "type", typeConstraintCode);
204 }
205
emitAttrConstraints()206 void StaticVerifierFunctionEmitter::emitAttrConstraints() {
207 emitConstraints(attrConstraints, "attr", attrConstraintCode);
208 }
209
emitSuccessorConstraints()210 void StaticVerifierFunctionEmitter::emitSuccessorConstraints() {
211 emitConstraints(successorConstraints, "successor", successorConstraintCode);
212 }
213
emitRegionConstraints()214 void StaticVerifierFunctionEmitter::emitRegionConstraints() {
215 emitConstraints(regionConstraints, "region", regionConstraintCode);
216 }
217
emitPatternConstraints()218 void StaticVerifierFunctionEmitter::emitPatternConstraints() {
219 FmtContext ctx;
220 ctx.withOp("*op").withBuilder("rewriter").withSelf("type");
221 for (auto &it : typeConstraints) {
222 os << formatv(patternAttrOrTypeConstraintCode, it.second,
223 tgfmt(it.first.getConditionTemplate(), &ctx),
224 escapeString(it.first.getSummary()), "Type type");
225 }
226 ctx.withSelf("attr");
227 for (auto &it : attrConstraints) {
228 os << formatv(patternAttrOrTypeConstraintCode, it.second,
229 tgfmt(it.first.getConditionTemplate(), &ctx),
230 escapeString(it.first.getSummary()), "Attribute attr");
231 }
232 }
233
234 //===----------------------------------------------------------------------===//
235 // Constraint Uniquing
236
237 /// An attribute constraint that references anything other than itself and the
238 /// current op cannot be generically extracted into a function. Most
239 /// prohibitive are operands and results, which require calls to
240 /// `getODSOperands` or `getODSResults`. Attribute references are tricky too
241 /// because ops use cached identifiers.
canUniqueAttrConstraint(Attribute attr)242 static bool canUniqueAttrConstraint(Attribute attr) {
243 FmtContext ctx;
244 auto test =
245 tgfmt(attr.getConditionTemplate(), &ctx.withSelf("attr").withOp("*op"))
246 .str();
247 return !StringRef(test).contains("<no-subst-found>");
248 }
249
getUniqueName(StringRef kind,unsigned index)250 std::string StaticVerifierFunctionEmitter::getUniqueName(StringRef kind,
251 unsigned index) {
252 return ("__mlir_ods_local_" + kind + "_constraint_" + uniqueOutputLabel +
253 Twine(index))
254 .str();
255 }
256
collectConstraint(ConstraintMap & map,StringRef kind,Constraint constraint)257 void StaticVerifierFunctionEmitter::collectConstraint(ConstraintMap &map,
258 StringRef kind,
259 Constraint constraint) {
260 auto it = map.find(constraint);
261 if (it == map.end())
262 map.insert({constraint, getUniqueName(kind, map.size())});
263 }
264
collectOpConstraints(ArrayRef<Record * > opDefs)265 void StaticVerifierFunctionEmitter::collectOpConstraints(
266 ArrayRef<Record *> opDefs) {
267 const auto collectTypeConstraints = [&](Operator::const_value_range values) {
268 for (const NamedTypeConstraint &value : values)
269 if (value.hasPredicate())
270 collectConstraint(typeConstraints, "type", value.constraint);
271 };
272
273 for (Record *def : opDefs) {
274 Operator op(*def);
275 /// Collect type constraints.
276 collectTypeConstraints(op.getOperands());
277 collectTypeConstraints(op.getResults());
278 /// Collect attribute constraints.
279 for (const NamedAttribute &namedAttr : op.getAttributes()) {
280 if (!namedAttr.attr.getPredicate().isNull() &&
281 !namedAttr.attr.isDerivedAttr() &&
282 canUniqueAttrConstraint(namedAttr.attr))
283 collectConstraint(attrConstraints, "attr", namedAttr.attr);
284 }
285 /// Collect successor constraints.
286 for (const NamedSuccessor &successor : op.getSuccessors()) {
287 if (!successor.constraint.getPredicate().isNull()) {
288 collectConstraint(successorConstraints, "successor",
289 successor.constraint);
290 }
291 }
292 /// Collect region constraints.
293 for (const NamedRegion ®ion : op.getRegions())
294 if (!region.constraint.getPredicate().isNull())
295 collectConstraint(regionConstraints, "region", region.constraint);
296 }
297 }
298
collectPatternConstraints(const llvm::ArrayRef<DagLeaf> constraints)299 void StaticVerifierFunctionEmitter::collectPatternConstraints(
300 const llvm::ArrayRef<DagLeaf> constraints) {
301 for (auto &leaf : constraints) {
302 assert(leaf.isOperandMatcher() || leaf.isAttrMatcher());
303 collectConstraint(
304 leaf.isOperandMatcher() ? typeConstraints : attrConstraints,
305 leaf.isOperandMatcher() ? "type" : "attr", leaf.getAsConstraint());
306 }
307 }
308
309 //===----------------------------------------------------------------------===//
310 // Public Utility Functions
311 //===----------------------------------------------------------------------===//
312
escapeString(StringRef value)313 std::string mlir::tblgen::escapeString(StringRef value) {
314 std::string ret;
315 llvm::raw_string_ostream os(ret);
316 os.write_escaped(value);
317 return os.str();
318 }
319