1 //===- CodeGenHelpers.cpp - MLIR op definitions generator ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // OpDefinitionsGen uses the description of operations to generate C++ 10 // definitions for ops. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/TableGen/CodeGenHelpers.h" 15 #include "mlir/TableGen/Format.h" 16 #include "mlir/TableGen/Operator.h" 17 #include "llvm/ADT/SetVector.h" 18 #include "llvm/Support/FormatVariadic.h" 19 #include "llvm/Support/Path.h" 20 #include "llvm/TableGen/Record.h" 21 22 using namespace llvm; 23 using namespace mlir; 24 using namespace mlir::tblgen; 25 26 StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter( 27 const llvm::RecordKeeper &records, raw_ostream &os) 28 : os(os), uniqueOutputLabel(getUniqueName(records)) {} 29 30 void StaticVerifierFunctionEmitter::emitFunctionsFor( 31 StringRef signatureFormat, StringRef errorHandlerFormat, 32 StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) { 33 llvm::Optional<NamespaceEmitter> namespaceEmitter; 34 if (!emitDecl) 35 namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace()); 36 37 emitTypeConstraintMethods(signatureFormat, errorHandlerFormat, typeArgName, 38 opDefs, emitDecl); 39 } 40 41 StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn( 42 const Constraint &constraint) const { 43 auto it = localTypeConstraints.find(constraint.getAsOpaquePointer()); 44 assert(it != localTypeConstraints.end() && "expected valid constraint fn"); 45 return it->second; 46 } 47 48 std::string StaticVerifierFunctionEmitter::getUniqueName( 49 const llvm::RecordKeeper &records) { 50 // Use the input file name when generating a unique name. 51 std::string inputFilename = records.getInputFilename(); 52 53 // Drop all but the base filename. 54 StringRef nameRef = llvm::sys::path::filename(inputFilename); 55 nameRef.consume_back(".td"); 56 57 // Sanitize any invalid characters. 58 std::string uniqueName; 59 for (char c : nameRef) { 60 if (llvm::isAlnum(c) || c == '_') 61 uniqueName.push_back(c); 62 else 63 uniqueName.append(llvm::utohexstr((unsigned char)c)); 64 } 65 return uniqueName; 66 } 67 68 void StaticVerifierFunctionEmitter::emitTypeConstraintMethods( 69 StringRef signatureFormat, StringRef errorHandlerFormat, 70 StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) { 71 // Collect a set of all of the used type constraints within the operation 72 // definitions. 73 llvm::SetVector<const void *> typeConstraints; 74 for (Record *def : opDefs) { 75 Operator op(*def); 76 for (NamedTypeConstraint &operand : op.getOperands()) 77 if (operand.hasPredicate()) 78 typeConstraints.insert(operand.constraint.getAsOpaquePointer()); 79 for (NamedTypeConstraint &result : op.getResults()) 80 if (result.hasPredicate()) 81 typeConstraints.insert(result.constraint.getAsOpaquePointer()); 82 } 83 84 // Record the mapping from predicate to constraint. If two constraints has the 85 // same predicate and constraint summary, they can share the same verification 86 // function. 87 llvm::DenseMap<Pred, const void *> predToConstraint; 88 FmtContext fctx; 89 for (auto it : llvm::enumerate(typeConstraints)) { 90 std::string name; 91 Constraint constraint = Constraint::getFromOpaquePointer(it.value()); 92 Pred pred = constraint.getPredicate(); 93 auto iter = predToConstraint.find(pred); 94 if (iter != predToConstraint.end()) { 95 do { 96 Constraint built = Constraint::getFromOpaquePointer(iter->second); 97 // We may have the different constraints but have the same predicate, 98 // for example, ConstraintA and Variadic<ConstraintA>, note that 99 // Variadic<> doesn't introduce new predicate. In this case, we can 100 // share the same predicate function if they also have consistent 101 // summary, otherwise we may report the wrong message while verification 102 // fails. 103 if (constraint.getSummary() == built.getSummary()) { 104 name = getTypeConstraintFn(built).str(); 105 break; 106 } 107 ++iter; 108 } while (iter != predToConstraint.end() && iter->first == pred); 109 } 110 111 if (!name.empty()) { 112 localTypeConstraints.try_emplace(it.value(), name); 113 continue; 114 } 115 116 // Generate an obscure and unique name for this type constraint. 117 name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel + 118 Twine(it.index())) 119 .str(); 120 predToConstraint.insert( 121 std::make_pair(constraint.getPredicate(), it.value())); 122 localTypeConstraints.try_emplace(it.value(), name); 123 124 // Only generate the methods if we are generating definitions. 125 if (emitDecl) 126 continue; 127 128 os << formatv(signatureFormat.data(), name) << " {\n"; 129 os.indent() << "if (!(" 130 << tgfmt(constraint.getConditionTemplate(), 131 &fctx.withSelf(typeArgName)) 132 << ")) {\n"; 133 os.indent() << "return " 134 << formatv(errorHandlerFormat.data(), constraint.getSummary()) 135 << ";\n"; 136 os.unindent() << "}\nreturn ::mlir::success();\n"; 137 os.unindent() << "}\n\n"; 138 } 139 } 140