1 //===- CodeGenHelpers.cpp - MLIR op definitions generator ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // OpDefinitionsGen uses the description of operations to generate C++ 10 // definitions for ops. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/TableGen/CodeGenHelpers.h" 15 #include "mlir/TableGen/Operator.h" 16 #include "mlir/TableGen/Pattern.h" 17 #include "llvm/ADT/SetVector.h" 18 #include "llvm/Support/FormatVariadic.h" 19 #include "llvm/Support/Path.h" 20 #include "llvm/TableGen/Record.h" 21 22 using namespace llvm; 23 using namespace mlir; 24 using namespace mlir::tblgen; 25 26 /// Generate a unique label based on the current file name to prevent name 27 /// collisions if multiple generated files are included at once. 28 static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) { 29 // Use the input file name when generating a unique name. 30 std::string inputFilename = records.getInputFilename(); 31 32 // Drop all but the base filename. 33 StringRef nameRef = llvm::sys::path::filename(inputFilename); 34 nameRef.consume_back(".td"); 35 36 // Sanitize any invalid characters. 
37 std::string uniqueName; 38 for (char c : nameRef) { 39 if (llvm::isAlnum(c) || c == '_') 40 uniqueName.push_back(c); 41 else 42 uniqueName.append(llvm::utohexstr((unsigned char)c)); 43 } 44 return uniqueName; 45 } 46 47 StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter( 48 raw_ostream &os, const llvm::RecordKeeper &records) 49 : os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {} 50 51 void StaticVerifierFunctionEmitter::emitOpConstraints( 52 ArrayRef<llvm::Record *> opDefs, bool emitDecl) { 53 collectOpConstraints(opDefs); 54 if (emitDecl) 55 return; 56 57 NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace()); 58 emitTypeConstraints(); 59 emitAttrConstraints(); 60 emitSuccessorConstraints(); 61 emitRegionConstraints(); 62 } 63 64 void StaticVerifierFunctionEmitter::emitPatternConstraints( 65 const llvm::ArrayRef<DagLeaf> constraints) { 66 collectPatternConstraints(constraints); 67 emitPatternConstraints(); 68 } 69 70 //===----------------------------------------------------------------------===// 71 // Constraint Getters 72 73 StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn( 74 const Constraint &constraint) const { 75 auto it = typeConstraints.find(constraint); 76 assert(it != typeConstraints.end() && "expected to find a type constraint"); 77 return it->second; 78 } 79 80 // Find a uniqued attribute constraint. Since not all attribute constraints can 81 // be uniqued, return None if one was not found. 82 Optional<StringRef> StaticVerifierFunctionEmitter::getAttrConstraintFn( 83 const Constraint &constraint) const { 84 auto it = attrConstraints.find(constraint); 85 return it == attrConstraints.end() ? 
Optional<StringRef>() 86 : StringRef(it->second); 87 } 88 89 StringRef StaticVerifierFunctionEmitter::getSuccessorConstraintFn( 90 const Constraint &constraint) const { 91 auto it = successorConstraints.find(constraint); 92 assert(it != successorConstraints.end() && 93 "expected to find a sucessor constraint"); 94 return it->second; 95 } 96 97 StringRef StaticVerifierFunctionEmitter::getRegionConstraintFn( 98 const Constraint &constraint) const { 99 auto it = regionConstraints.find(constraint); 100 assert(it != regionConstraints.end() && 101 "expected to find a region constraint"); 102 return it->second; 103 } 104 105 //===----------------------------------------------------------------------===// 106 // Constraint Emission 107 108 /// Code templates for emitting type, attribute, successor, and region 109 /// constraints. Each of these templates require the following arguments: 110 /// 111 /// {0}: The unique constraint name. 112 /// {1}: The constraint code. 113 /// {2}: The constraint description. 114 115 /// Code for a type constraint. These may be called on the type of either 116 /// operands or results. 117 static const char *const typeConstraintCode = R"( 118 static ::mlir::LogicalResult {0}( 119 ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, 120 unsigned valueIndex) { 121 if (!({1})) { 122 return op->emitOpError(valueKind) << " #" << valueIndex 123 << " must be {2}, but got " << type; 124 } 125 return ::mlir::success(); 126 } 127 )"; 128 129 /// Code for an attribute constraint. These may be called from ops only. 130 /// Attribute constraints cannot reference anything other than `$_self` and 131 /// `$_op`. 132 /// 133 /// TODO: Unique constraints for adaptors. However, most Adaptor::verify 134 /// functions are stripped anyways. 
static const char *const attrConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) {
  if (attr && !({1})) {
    return op->emitOpError("attribute '") << attrName
        << "' failed to satisfy constraint: {2}";
  }
  return ::mlir::success();
}
)";

/// Code for a successor constraint. Callers pass in the successor's name and
/// index for emitting an error message. The name is printed as `('name')`.
static const char *const successorConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Block *successor,
    ::llvm::StringRef successorName, unsigned successorIndex) {
  if (!({1})) {
    return op->emitOpError("successor #") << successorIndex << " ('"
        << successorName << "') failed to verify constraint: {2}";
  }
  return ::mlir::success();
}
)";

/// Code for a region constraint. Callers will need to pass in the region's name
/// for emitting an error message. An empty name is omitted from the message.
static const char *const regionConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName,
    unsigned regionIndex) {
  if (!({1})) {
    return op->emitOpError("region #") << regionIndex
        << (regionName.empty() ? " " : " ('" + regionName + "') ")
        << "failed to verify constraint: {2}";
  }
  return ::mlir::success();
}
)";

/// Code for a pattern type or attribute constraint.
///
/// {3}: "Type type" or "Attribute attr".
177 static const char *const patternAttrOrTypeConstraintCode = R"( 178 static ::mlir::LogicalResult {0}( 179 ::mlir::PatternRewriter &rewriter, ::mlir::Operation *op, ::mlir::{3}, 180 ::llvm::StringRef failureStr) { 181 if (!({1})) { 182 return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) { 183 diag << failureStr << ": {2}"; 184 }); 185 } 186 return ::mlir::success(); 187 } 188 )"; 189 190 void StaticVerifierFunctionEmitter::emitConstraints( 191 const ConstraintMap &constraints, StringRef selfName, 192 const char *const codeTemplate) { 193 FmtContext ctx; 194 ctx.withOp("*op").withSelf(selfName); 195 for (auto &it : constraints) { 196 os << formatv(codeTemplate, it.second, 197 tgfmt(it.first.getConditionTemplate(), &ctx), 198 escapeString(it.first.getSummary())); 199 } 200 } 201 202 void StaticVerifierFunctionEmitter::emitTypeConstraints() { 203 emitConstraints(typeConstraints, "type", typeConstraintCode); 204 } 205 206 void StaticVerifierFunctionEmitter::emitAttrConstraints() { 207 emitConstraints(attrConstraints, "attr", attrConstraintCode); 208 } 209 210 void StaticVerifierFunctionEmitter::emitSuccessorConstraints() { 211 emitConstraints(successorConstraints, "successor", successorConstraintCode); 212 } 213 214 void StaticVerifierFunctionEmitter::emitRegionConstraints() { 215 emitConstraints(regionConstraints, "region", regionConstraintCode); 216 } 217 218 void StaticVerifierFunctionEmitter::emitPatternConstraints() { 219 FmtContext ctx; 220 ctx.withOp("*op").withBuilder("rewriter").withSelf("type"); 221 for (auto &it : typeConstraints) { 222 os << formatv(patternAttrOrTypeConstraintCode, it.second, 223 tgfmt(it.first.getConditionTemplate(), &ctx), 224 escapeString(it.first.getSummary()), "Type type"); 225 } 226 ctx.withSelf("attr"); 227 for (auto &it : attrConstraints) { 228 os << formatv(patternAttrOrTypeConstraintCode, it.second, 229 tgfmt(it.first.getConditionTemplate(), &ctx), 230 escapeString(it.first.getSummary()), "Attribute attr"); 231 
} 232 } 233 234 //===----------------------------------------------------------------------===// 235 // Constraint Uniquing 236 237 /// An attribute constraint that references anything other than itself and the 238 /// current op cannot be generically extracted into a function. Most 239 /// prohibitive are operands and results, which require calls to 240 /// `getODSOperands` or `getODSResults`. Attribute references are tricky too 241 /// because ops use cached identifiers. 242 static bool canUniqueAttrConstraint(Attribute attr) { 243 FmtContext ctx; 244 auto test = 245 tgfmt(attr.getConditionTemplate(), &ctx.withSelf("attr").withOp("*op")) 246 .str(); 247 return !StringRef(test).contains("<no-subst-found>"); 248 } 249 250 std::string StaticVerifierFunctionEmitter::getUniqueName(StringRef kind, 251 unsigned index) { 252 return ("__mlir_ods_local_" + kind + "_constraint_" + uniqueOutputLabel + 253 Twine(index)) 254 .str(); 255 } 256 257 void StaticVerifierFunctionEmitter::collectConstraint(ConstraintMap &map, 258 StringRef kind, 259 Constraint constraint) { 260 auto it = map.find(constraint); 261 if (it == map.end()) 262 map.insert({constraint, getUniqueName(kind, map.size())}); 263 } 264 265 void StaticVerifierFunctionEmitter::collectOpConstraints( 266 ArrayRef<Record *> opDefs) { 267 const auto collectTypeConstraints = [&](Operator::const_value_range values) { 268 for (const NamedTypeConstraint &value : values) 269 if (value.hasPredicate()) 270 collectConstraint(typeConstraints, "type", value.constraint); 271 }; 272 273 for (Record *def : opDefs) { 274 Operator op(*def); 275 /// Collect type constraints. 276 collectTypeConstraints(op.getOperands()); 277 collectTypeConstraints(op.getResults()); 278 /// Collect attribute constraints. 
279 for (const NamedAttribute &namedAttr : op.getAttributes()) { 280 if (!namedAttr.attr.getPredicate().isNull() && 281 !namedAttr.attr.isDerivedAttr() && 282 canUniqueAttrConstraint(namedAttr.attr)) 283 collectConstraint(attrConstraints, "attr", namedAttr.attr); 284 } 285 /// Collect successor constraints. 286 for (const NamedSuccessor &successor : op.getSuccessors()) { 287 if (!successor.constraint.getPredicate().isNull()) { 288 collectConstraint(successorConstraints, "successor", 289 successor.constraint); 290 } 291 } 292 /// Collect region constraints. 293 for (const NamedRegion ®ion : op.getRegions()) 294 if (!region.constraint.getPredicate().isNull()) 295 collectConstraint(regionConstraints, "region", region.constraint); 296 } 297 } 298 299 void StaticVerifierFunctionEmitter::collectPatternConstraints( 300 const llvm::ArrayRef<DagLeaf> constraints) { 301 for (auto &leaf : constraints) { 302 assert(leaf.isOperandMatcher() || leaf.isAttrMatcher()); 303 collectConstraint( 304 leaf.isOperandMatcher() ? typeConstraints : attrConstraints, 305 leaf.isOperandMatcher() ? "type" : "attr", leaf.getAsConstraint()); 306 } 307 } 308 309 //===----------------------------------------------------------------------===// 310 // Public Utility Functions 311 //===----------------------------------------------------------------------===// 312 313 std::string mlir::tblgen::escapeString(StringRef value) { 314 std::string ret; 315 llvm::raw_string_ostream os(ret); 316 os.write_escaped(value); 317 return os.str(); 318 } 319