1cde4d5a6SJacques Pienaar //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2eb753f4aSLei Zhang //
3eb753f4aSLei Zhang // Copyright 2019 The MLIR Authors.
4eb753f4aSLei Zhang //
5eb753f4aSLei Zhang // Licensed under the Apache License, Version 2.0 (the "License");
6eb753f4aSLei Zhang // you may not use this file except in compliance with the License.
7eb753f4aSLei Zhang // You may obtain a copy of the License at
8eb753f4aSLei Zhang //
9eb753f4aSLei Zhang //   http://www.apache.org/licenses/LICENSE-2.0
10eb753f4aSLei Zhang //
11eb753f4aSLei Zhang // Unless required by applicable law or agreed to in writing, software
12eb753f4aSLei Zhang // distributed under the License is distributed on an "AS IS" BASIS,
13eb753f4aSLei Zhang // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14eb753f4aSLei Zhang // See the License for the specific language governing permissions and
15eb753f4aSLei Zhang // limitations under the License.
16eb753f4aSLei Zhang // =============================================================================
17eb753f4aSLei Zhang //
// Pattern wrapper class to simplify using a TableGen Record that defines an
// MLIR Pattern.
20eb753f4aSLei Zhang //
21eb753f4aSLei Zhang //===----------------------------------------------------------------------===//
22eb753f4aSLei Zhang 
23eb753f4aSLei Zhang #include "mlir/TableGen/Pattern.h"
24eb753f4aSLei Zhang #include "llvm/ADT/Twine.h"
2504b6d2f3SLei Zhang #include "llvm/Support/FormatVariadic.h"
268f5fa566SLei Zhang #include "llvm/TableGen/Error.h"
27eb753f4aSLei Zhang #include "llvm/TableGen/Record.h"
28eb753f4aSLei Zhang 
29eb753f4aSLei Zhang using namespace mlir;
30eb753f4aSLei Zhang 
310ea6154bSJacques Pienaar using llvm::formatv;
32eb753f4aSLei Zhang using mlir::tblgen::Operator;
33eb753f4aSLei Zhang 
34e0774c00SLei Zhang bool tblgen::DagLeaf::isUnspecified() const {
35b9e38a79SLei Zhang   return dyn_cast_or_null<llvm::UnsetInit>(def);
36e0774c00SLei Zhang }
37e0774c00SLei Zhang 
38e0774c00SLei Zhang bool tblgen::DagLeaf::isOperandMatcher() const {
39e0774c00SLei Zhang   // Operand matchers specify a type constraint.
40b9e38a79SLei Zhang   return isSubClassOf("TypeConstraint");
41e0774c00SLei Zhang }
42e0774c00SLei Zhang 
43e0774c00SLei Zhang bool tblgen::DagLeaf::isAttrMatcher() const {
44c52a8127SFeng Liu   // Attribute matchers specify an attribute constraint.
45b9e38a79SLei Zhang   return isSubClassOf("AttrConstraint");
46e0774c00SLei Zhang }
47e0774c00SLei Zhang 
48d0e2019dSLei Zhang bool tblgen::DagLeaf::isNativeCodeCall() const {
49d0e2019dSLei Zhang   return isSubClassOf("NativeCodeCall");
50e0774c00SLei Zhang }
51e0774c00SLei Zhang 
52e0774c00SLei Zhang bool tblgen::DagLeaf::isConstantAttr() const {
53b9e38a79SLei Zhang   return isSubClassOf("ConstantAttr");
54b9e38a79SLei Zhang }
55b9e38a79SLei Zhang 
56b9e38a79SLei Zhang bool tblgen::DagLeaf::isEnumAttrCase() const {
579dd182e0SLei Zhang   return isSubClassOf("EnumAttrCaseInfo");
58e0774c00SLei Zhang }
59e0774c00SLei Zhang 
608f5fa566SLei Zhang tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const {
618f5fa566SLei Zhang   assert((isOperandMatcher() || isAttrMatcher()) &&
628f5fa566SLei Zhang          "the DAG leaf must be operand or attribute");
638f5fa566SLei Zhang   return Constraint(cast<llvm::DefInit>(def)->getDef());
64e0774c00SLei Zhang }
65e0774c00SLei Zhang 
66e0774c00SLei Zhang tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const {
67e0774c00SLei Zhang   assert(isConstantAttr() && "the DAG leaf must be constant attribute");
68e0774c00SLei Zhang   return ConstantAttr(cast<llvm::DefInit>(def));
69e0774c00SLei Zhang }
70e0774c00SLei Zhang 
71b9e38a79SLei Zhang tblgen::EnumAttrCase tblgen::DagLeaf::getAsEnumAttrCase() const {
72b9e38a79SLei Zhang   assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
73b9e38a79SLei Zhang   return EnumAttrCase(cast<llvm::DefInit>(def));
74b9e38a79SLei Zhang }
75b9e38a79SLei Zhang 
76e0774c00SLei Zhang std::string tblgen::DagLeaf::getConditionTemplate() const {
778f5fa566SLei Zhang   return getAsConstraint().getConditionTemplate();
78e0774c00SLei Zhang }
79e0774c00SLei Zhang 
80d0e2019dSLei Zhang llvm::StringRef tblgen::DagLeaf::getNativeCodeTemplate() const {
81d0e2019dSLei Zhang   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
82d0e2019dSLei Zhang   return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
83eb753f4aSLei Zhang }
84eb753f4aSLei Zhang 
85b9e38a79SLei Zhang bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const {
86b9e38a79SLei Zhang   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
87b9e38a79SLei Zhang     return defInit->getDef()->isSubClassOf(superclass);
88b9e38a79SLei Zhang   return false;
89b9e38a79SLei Zhang }
90b9e38a79SLei Zhang 
91d0e2019dSLei Zhang bool tblgen::DagNode::isNativeCodeCall() const {
92d0e2019dSLei Zhang   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
93d0e2019dSLei Zhang     return defInit->getDef()->isSubClassOf("NativeCodeCall");
94c52a8127SFeng Liu   return false;
95c52a8127SFeng Liu }
96c52a8127SFeng Liu 
97647f8cabSRiver Riddle bool tblgen::DagNode::isOperation() const {
98647f8cabSRiver Riddle   return !(isNativeCodeCall() || isVerifyUnusedValue() || isReplaceWithValue());
99647f8cabSRiver Riddle }
100647f8cabSRiver Riddle 
101d0e2019dSLei Zhang llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const {
102d0e2019dSLei Zhang   assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
103c52a8127SFeng Liu   return cast<llvm::DefInit>(node->getOperator())
104c52a8127SFeng Liu       ->getDef()
105d0e2019dSLei Zhang       ->getValueAsString("expression");
106c52a8127SFeng Liu }
107c52a8127SFeng Liu 
108*e032d0dcSLei Zhang llvm::StringRef tblgen::DagNode::getSymbol() const {
109388fb375SJacques Pienaar   return node->getNameStr();
110388fb375SJacques Pienaar }
111388fb375SJacques Pienaar 
112eb753f4aSLei Zhang Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const {
113eb753f4aSLei Zhang   llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
1142dc6d205SLei Zhang   auto it = mapper->find(opDef);
1152dc6d205SLei Zhang   if (it != mapper->end())
1162dc6d205SLei Zhang     return *it->second;
1172dc6d205SLei Zhang   return *mapper->try_emplace(opDef, llvm::make_unique<Operator>(opDef))
1182dc6d205SLei Zhang               .first->second;
119eb753f4aSLei Zhang }
120eb753f4aSLei Zhang 
1212fe8ae4fSJacques Pienaar int tblgen::DagNode::getNumOps() const {
1222fe8ae4fSJacques Pienaar   int count = isReplaceWithValue() ? 0 : 1;
1232fe8ae4fSJacques Pienaar   for (int i = 0, e = getNumArgs(); i != e; ++i) {
124eb753f4aSLei Zhang     if (auto child = getArgAsNestedDag(i))
125eb753f4aSLei Zhang       count += child.getNumOps();
126eb753f4aSLei Zhang   }
127eb753f4aSLei Zhang   return count;
128eb753f4aSLei Zhang }
129eb753f4aSLei Zhang 
1302fe8ae4fSJacques Pienaar int tblgen::DagNode::getNumArgs() const { return node->getNumArgs(); }
131eb753f4aSLei Zhang 
132eb753f4aSLei Zhang bool tblgen::DagNode::isNestedDagArg(unsigned index) const {
133eb753f4aSLei Zhang   return isa<llvm::DagInit>(node->getArg(index));
134eb753f4aSLei Zhang }
135eb753f4aSLei Zhang 
136eb753f4aSLei Zhang tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const {
137eb753f4aSLei Zhang   return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
138eb753f4aSLei Zhang }
139eb753f4aSLei Zhang 
140e0774c00SLei Zhang tblgen::DagLeaf tblgen::DagNode::getArgAsLeaf(unsigned index) const {
141e0774c00SLei Zhang   assert(!isNestedDagArg(index));
142e0774c00SLei Zhang   return DagLeaf(node->getArg(index));
143eb753f4aSLei Zhang }
144eb753f4aSLei Zhang 
145eb753f4aSLei Zhang StringRef tblgen::DagNode::getArgName(unsigned index) const {
146eb753f4aSLei Zhang   return node->getArgNameStr(index);
147eb753f4aSLei Zhang }
148eb753f4aSLei Zhang 
149eb753f4aSLei Zhang bool tblgen::DagNode::isReplaceWithValue() const {
150eb753f4aSLei Zhang   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
151eb753f4aSLei Zhang   return dagOpDef->getName() == "replaceWithValue";
152eb753f4aSLei Zhang }
153eb753f4aSLei Zhang 
15418fde7c9SLei Zhang bool tblgen::DagNode::isVerifyUnusedValue() const {
15518fde7c9SLei Zhang   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
15618fde7c9SLei Zhang   return dagOpDef->getName() == "verifyUnusedValue";
15718fde7c9SLei Zhang }
15818fde7c9SLei Zhang 
159eb753f4aSLei Zhang tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
160eb753f4aSLei Zhang     : def(*def), recordOpMap(mapper) {
161*e032d0dcSLei Zhang   collectBoundSymbols(getSourcePattern(), srcBoundOps, /*isSrcPattern=*/true);
162*e032d0dcSLei Zhang   for (int i = 0, e = getNumResultPatterns(); i < e; ++i)
163*e032d0dcSLei Zhang     collectBoundSymbols(getResultPattern(i), resBoundOps,
164*e032d0dcSLei Zhang                         /*isSrcPattern=*/false);
165eb753f4aSLei Zhang }
166eb753f4aSLei Zhang 
167eb753f4aSLei Zhang tblgen::DagNode tblgen::Pattern::getSourcePattern() const {
1688f5fa566SLei Zhang   return tblgen::DagNode(def.getValueAsDag("sourcePattern"));
169eb753f4aSLei Zhang }
170eb753f4aSLei Zhang 
171*e032d0dcSLei Zhang int tblgen::Pattern::getNumResultPatterns() const {
1728f5fa566SLei Zhang   auto *results = def.getValueAsListInit("resultPatterns");
173eb753f4aSLei Zhang   return results->size();
174eb753f4aSLei Zhang }
175eb753f4aSLei Zhang 
176eb753f4aSLei Zhang tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const {
1778f5fa566SLei Zhang   auto *results = def.getValueAsListInit("resultPatterns");
178eb753f4aSLei Zhang   return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index)));
179eb753f4aSLei Zhang }
180eb753f4aSLei Zhang 
18109b623aaSLei Zhang void tblgen::Pattern::ensureBoundInSourcePattern(llvm::StringRef name) const {
182*e032d0dcSLei Zhang   if (srcBoundArguments.find(name) == srcBoundArguments.end() &&
183*e032d0dcSLei Zhang       srcBoundOps.find(name) == srcBoundOps.end())
184eb753f4aSLei Zhang     PrintFatalError(def.getLoc(),
185eb753f4aSLei Zhang                     Twine("referencing unbound variable '") + name + "'");
186eb753f4aSLei Zhang }
187eb753f4aSLei Zhang 
188e0774c00SLei Zhang llvm::StringMap<tblgen::Argument> &
189e0774c00SLei Zhang tblgen::Pattern::getSourcePatternBoundArgs() {
190*e032d0dcSLei Zhang   return srcBoundArguments;
191eb753f4aSLei Zhang }
192eb753f4aSLei Zhang 
1939f02e889SLei Zhang llvm::StringMap<const tblgen::Operator *> &
1949f02e889SLei Zhang tblgen::Pattern::getSourcePatternBoundOps() {
195*e032d0dcSLei Zhang   return srcBoundOps;
196*e032d0dcSLei Zhang }
197*e032d0dcSLei Zhang 
198*e032d0dcSLei Zhang llvm::StringMap<const tblgen::Operator *> &
199*e032d0dcSLei Zhang tblgen::Pattern::getResultPatternBoundOps() {
200*e032d0dcSLei Zhang   return resBoundOps;
201388fb375SJacques Pienaar }
202388fb375SJacques Pienaar 
203eb753f4aSLei Zhang const tblgen::Operator &tblgen::Pattern::getSourceRootOp() {
204eb753f4aSLei Zhang   return getSourcePattern().getDialectOp(recordOpMap);
205eb753f4aSLei Zhang }
206eb753f4aSLei Zhang 
207eb753f4aSLei Zhang tblgen::Operator &tblgen::Pattern::getDialectOp(DagNode node) {
208eb753f4aSLei Zhang   return node.getDialectOp(recordOpMap);
209eb753f4aSLei Zhang }
210388fb375SJacques Pienaar 
2118f5fa566SLei Zhang std::vector<tblgen::AppliedConstraint> tblgen::Pattern::getConstraints() const {
212388fb375SJacques Pienaar   auto *listInit = def.getValueAsListInit("constraints");
2138f5fa566SLei Zhang   std::vector<tblgen::AppliedConstraint> ret;
214388fb375SJacques Pienaar   ret.reserve(listInit->size());
2158f5fa566SLei Zhang 
216388fb375SJacques Pienaar   for (auto it : *listInit) {
2178f5fa566SLei Zhang     auto *dagInit = dyn_cast<llvm::DagInit>(it);
2188f5fa566SLei Zhang     if (!dagInit)
2198f5fa566SLei Zhang       PrintFatalError(def.getLoc(), "all elemements in Pattern multi-entity "
2208f5fa566SLei Zhang                                     "constraints should be DAG nodes");
2218f5fa566SLei Zhang 
2228f5fa566SLei Zhang     std::vector<std::string> entities;
2238f5fa566SLei Zhang     entities.reserve(dagInit->arg_size());
2248f5fa566SLei Zhang     for (auto *argName : dagInit->getArgNames())
2258f5fa566SLei Zhang       entities.push_back(argName->getValue());
2268f5fa566SLei Zhang 
2278f5fa566SLei Zhang     ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
2288f5fa566SLei Zhang                      std::move(entities));
229388fb375SJacques Pienaar   }
230388fb375SJacques Pienaar   return ret;
231388fb375SJacques Pienaar }
23253035874SFeng Liu 
23353035874SFeng Liu int tblgen::Pattern::getBenefit() const {
234a0606ca7SFeng Liu   // The initial benefit value is a heuristic with number of ops in the source
23553035874SFeng Liu   // pattern.
236a0606ca7SFeng Liu   int initBenefit = getSourcePattern().getNumOps();
23753035874SFeng Liu   llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
238a0606ca7SFeng Liu   if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
239a0606ca7SFeng Liu     PrintFatalError(def.getLoc(),
240a0606ca7SFeng Liu                     "The 'addBenefit' takes and only takes one integer value");
241a0606ca7SFeng Liu   }
242a0606ca7SFeng Liu   return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
24353035874SFeng Liu }
24404b6d2f3SLei Zhang 
2454165885aSJacques Pienaar std::vector<tblgen::Pattern::IdentifierLine>
2464165885aSJacques Pienaar tblgen::Pattern::getLocation() const {
2474165885aSJacques Pienaar   std::vector<std::pair<StringRef, unsigned>> result;
2484165885aSJacques Pienaar   result.reserve(def.getLoc().size());
2494165885aSJacques Pienaar   for (auto loc : def.getLoc()) {
2504165885aSJacques Pienaar     unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
2514165885aSJacques Pienaar     assert(buf && "invalid source location");
2524165885aSJacques Pienaar     result.emplace_back(
2534165885aSJacques Pienaar         llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
2544165885aSJacques Pienaar         llvm::SrcMgr.getLineAndColumn(loc, buf).first);
2554165885aSJacques Pienaar   }
2564165885aSJacques Pienaar   return result;
2574165885aSJacques Pienaar }
2584165885aSJacques Pienaar 
259*e032d0dcSLei Zhang void tblgen::Pattern::collectBoundSymbols(DagNode tree,
260*e032d0dcSLei Zhang                                           SymbolOperatorMap &symOpMap,
261*e032d0dcSLei Zhang                                           bool isSrcPattern) {
262*e032d0dcSLei Zhang   auto treeName = tree.getSymbol();
263*e032d0dcSLei Zhang   if (!tree.isOperation()) {
264*e032d0dcSLei Zhang     if (!treeName.empty()) {
265*e032d0dcSLei Zhang       PrintFatalError(
266*e032d0dcSLei Zhang           def.getLoc(),
267*e032d0dcSLei Zhang           formatv("binding symbol '{0}' to non-operation unsupported right now",
268*e032d0dcSLei Zhang                   treeName));
269*e032d0dcSLei Zhang     }
270*e032d0dcSLei Zhang     return;
271*e032d0dcSLei Zhang   }
272*e032d0dcSLei Zhang 
27304b6d2f3SLei Zhang   auto &op = getDialectOp(tree);
27404b6d2f3SLei Zhang   auto numOpArgs = op.getNumArgs();
27504b6d2f3SLei Zhang   auto numTreeArgs = tree.getNumArgs();
27604b6d2f3SLei Zhang 
27704b6d2f3SLei Zhang   if (numOpArgs != numTreeArgs) {
27804b6d2f3SLei Zhang     PrintFatalError(def.getLoc(),
27904b6d2f3SLei Zhang                     formatv("op '{0}' argument number mismatch: "
28004b6d2f3SLei Zhang                             "{1} in pattern vs. {2} in definition",
28104b6d2f3SLei Zhang                             op.getOperationName(), numTreeArgs, numOpArgs));
28204b6d2f3SLei Zhang   }
28304b6d2f3SLei Zhang 
28404b6d2f3SLei Zhang   // The name attached to the DAG node's operator is for representing the
28504b6d2f3SLei Zhang   // results generated from this op. It should be remembered as bound results.
28604b6d2f3SLei Zhang   if (!treeName.empty())
287*e032d0dcSLei Zhang     symOpMap.try_emplace(treeName, &op);
28804b6d2f3SLei Zhang 
2892fe8ae4fSJacques Pienaar   for (int i = 0; i != numTreeArgs; ++i) {
29004b6d2f3SLei Zhang     if (auto treeArg = tree.getArgAsNestedDag(i)) {
29104b6d2f3SLei Zhang       // This DAG node argument is a DAG node itself. Go inside recursively.
292*e032d0dcSLei Zhang       collectBoundSymbols(treeArg, symOpMap, isSrcPattern);
293*e032d0dcSLei Zhang     } else if (isSrcPattern) {
294*e032d0dcSLei Zhang       // We can only bind symbols to op arguments in source pattern. Those
295*e032d0dcSLei Zhang       // symbols are referenced in result patterns.
29604b6d2f3SLei Zhang       auto treeArgName = tree.getArgName(i);
29704b6d2f3SLei Zhang       if (!treeArgName.empty())
298*e032d0dcSLei Zhang         srcBoundArguments.try_emplace(treeArgName, op.getArg(i));
29904b6d2f3SLei Zhang     }
30004b6d2f3SLei Zhang   }
30104b6d2f3SLei Zhang }
302