1 //===- CodeGenHelpers.cpp - MLIR op definitions generator ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpDefinitionsGen uses the description of operations to generate C++
10 // definitions for ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Path.h"
#include "llvm/TableGen/Record.h"
21 
22 using namespace llvm;
23 using namespace mlir;
24 using namespace mlir::tblgen;
25 
26 StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
27     const llvm::RecordKeeper &records, raw_ostream &os)
28     : os(os), uniqueOutputLabel(getUniqueName(records)) {}
29 
30 void StaticVerifierFunctionEmitter::emitFunctionsFor(
31     StringRef signatureFormat, StringRef errorHandlerFormat,
32     StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
33   llvm::Optional<NamespaceEmitter> namespaceEmitter;
34   if (!emitDecl)
35     namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
36 
37   emitTypeConstraintMethods(signatureFormat, errorHandlerFormat, typeArgName,
38                             opDefs, emitDecl);
39 }
40 
41 StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn(
42     const Constraint &constraint) const {
43   auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
44   assert(it != localTypeConstraints.end() && "expected valid constraint fn");
45   return it->second;
46 }
47 
48 std::string StaticVerifierFunctionEmitter::getUniqueName(
49     const llvm::RecordKeeper &records) {
50   // Use the input file name when generating a unique name.
51   std::string inputFilename = records.getInputFilename();
52 
53   // Drop all but the base filename.
54   StringRef nameRef = llvm::sys::path::filename(inputFilename);
55   nameRef.consume_back(".td");
56 
57   // Sanitize any invalid characters.
58   std::string uniqueName;
59   for (char c : nameRef) {
60     if (llvm::isAlnum(c) || c == '_')
61       uniqueName.push_back(c);
62     else
63       uniqueName.append(llvm::utohexstr((unsigned char)c));
64   }
65   return uniqueName;
66 }
67 
68 void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
69     StringRef signatureFormat, StringRef errorHandlerFormat,
70     StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
71   // Collect a set of all of the used type constraints within the operation
72   // definitions.
73   llvm::SetVector<const void *> typeConstraints;
74   for (Record *def : opDefs) {
75     Operator op(*def);
76     for (NamedTypeConstraint &operand : op.getOperands())
77       if (operand.hasPredicate())
78         typeConstraints.insert(operand.constraint.getAsOpaquePointer());
79     for (NamedTypeConstraint &result : op.getResults())
80       if (result.hasPredicate())
81         typeConstraints.insert(result.constraint.getAsOpaquePointer());
82   }
83 
84   // Record the mapping from predicate to constraint. If two constraints has the
85   // same predicate and constraint summary, they can share the same verification
86   // function.
87   llvm::DenseMap<Pred, const void *> predToConstraint;
88   FmtContext fctx;
89   for (auto it : llvm::enumerate(typeConstraints)) {
90     std::string name;
91     Constraint constraint = Constraint::getFromOpaquePointer(it.value());
92     Pred pred = constraint.getPredicate();
93     auto iter = predToConstraint.find(pred);
94     if (iter != predToConstraint.end()) {
95       do {
96         Constraint built = Constraint::getFromOpaquePointer(iter->second);
97         // We may have the different constraints but have the same predicate,
98         // for example, ConstraintA and Variadic<ConstraintA>, note that
99         // Variadic<> doesn't introduce new predicate. In this case, we can
100         // share the same predicate function if they also have consistent
101         // summary, otherwise we may report the wrong message while verification
102         // fails.
103         if (constraint.getSummary() == built.getSummary()) {
104           name = getTypeConstraintFn(built).str();
105           break;
106         }
107         ++iter;
108       } while (iter != predToConstraint.end() && iter->first == pred);
109     }
110 
111     if (!name.empty()) {
112       localTypeConstraints.try_emplace(it.value(), name);
113       continue;
114     }
115 
116     // Generate an obscure and unique name for this type constraint.
117     name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel +
118             Twine(it.index()))
119                .str();
120     predToConstraint.insert(
121         std::make_pair(constraint.getPredicate(), it.value()));
122     localTypeConstraints.try_emplace(it.value(), name);
123 
124     // Only generate the methods if we are generating definitions.
125     if (emitDecl)
126       continue;
127 
128     os << formatv(signatureFormat.data(), name) << " {\n";
129     os.indent() << "if (!("
130                 << tgfmt(constraint.getConditionTemplate(),
131                          &fctx.withSelf(typeArgName))
132                 << ")) {\n";
133     os.indent() << "return "
134                 << formatv(errorHandlerFormat.data(), constraint.getSummary())
135                 << ";\n";
136     os.unindent() << "}\nreturn ::mlir::success();\n";
137     os.unindent() << "}\n\n";
138   }
139 }
140