1eb753f4aSLei Zhang //===- Pattern.cpp - Pattern wrapper class ----------------------*- C++ -*-===//
2eb753f4aSLei Zhang //
3eb753f4aSLei Zhang // Copyright 2019 The MLIR Authors.
4eb753f4aSLei Zhang //
5eb753f4aSLei Zhang // Licensed under the Apache License, Version 2.0 (the "License");
6eb753f4aSLei Zhang // you may not use this file except in compliance with the License.
7eb753f4aSLei Zhang // You may obtain a copy of the License at
8eb753f4aSLei Zhang //
9eb753f4aSLei Zhang //   http://www.apache.org/licenses/LICENSE-2.0
10eb753f4aSLei Zhang //
11eb753f4aSLei Zhang // Unless required by applicable law or agreed to in writing, software
12eb753f4aSLei Zhang // distributed under the License is distributed on an "AS IS" BASIS,
13eb753f4aSLei Zhang // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14eb753f4aSLei Zhang // See the License for the specific language governing permissions and
15eb753f4aSLei Zhang // limitations under the License.
16eb753f4aSLei Zhang // =============================================================================
17eb753f4aSLei Zhang //
18eb753f4aSLei Zhang // Pattern wrapper class to simplify using TableGen Record defining a MLIR
19eb753f4aSLei Zhang // Pattern.
20eb753f4aSLei Zhang //
21eb753f4aSLei Zhang //===----------------------------------------------------------------------===//
22eb753f4aSLei Zhang 
23eb753f4aSLei Zhang #include "mlir/TableGen/Pattern.h"
24eb753f4aSLei Zhang #include "llvm/ADT/Twine.h"
2504b6d2f3SLei Zhang #include "llvm/Support/FormatVariadic.h"
268f5fa566SLei Zhang #include "llvm/TableGen/Error.h"
27eb753f4aSLei Zhang #include "llvm/TableGen/Record.h"
28eb753f4aSLei Zhang 
29eb753f4aSLei Zhang using namespace mlir;
30eb753f4aSLei Zhang 
31eb753f4aSLei Zhang using mlir::tblgen::Operator;
32eb753f4aSLei Zhang 
33e0774c00SLei Zhang bool tblgen::DagLeaf::isUnspecified() const {
34b9e38a79SLei Zhang   return dyn_cast_or_null<llvm::UnsetInit>(def);
35e0774c00SLei Zhang }
36e0774c00SLei Zhang 
37e0774c00SLei Zhang bool tblgen::DagLeaf::isOperandMatcher() const {
38e0774c00SLei Zhang   // Operand matchers specify a type constraint.
39b9e38a79SLei Zhang   return isSubClassOf("TypeConstraint");
40e0774c00SLei Zhang }
41e0774c00SLei Zhang 
42e0774c00SLei Zhang bool tblgen::DagLeaf::isAttrMatcher() const {
43c52a8127SFeng Liu   // Attribute matchers specify an attribute constraint.
44b9e38a79SLei Zhang   return isSubClassOf("AttrConstraint");
45e0774c00SLei Zhang }
46e0774c00SLei Zhang 
47e0774c00SLei Zhang bool tblgen::DagLeaf::isAttrTransformer() const {
48b9e38a79SLei Zhang   return isSubClassOf("tAttr");
49e0774c00SLei Zhang }
50e0774c00SLei Zhang 
51e0774c00SLei Zhang bool tblgen::DagLeaf::isConstantAttr() const {
52b9e38a79SLei Zhang   return isSubClassOf("ConstantAttr");
53b9e38a79SLei Zhang }
54b9e38a79SLei Zhang 
55b9e38a79SLei Zhang bool tblgen::DagLeaf::isEnumAttrCase() const {
56b9e38a79SLei Zhang   return isSubClassOf("EnumAttrCase");
57e0774c00SLei Zhang }
58e0774c00SLei Zhang 
598f5fa566SLei Zhang tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const {
608f5fa566SLei Zhang   assert((isOperandMatcher() || isAttrMatcher()) &&
618f5fa566SLei Zhang          "the DAG leaf must be operand or attribute");
628f5fa566SLei Zhang   return Constraint(cast<llvm::DefInit>(def)->getDef());
63e0774c00SLei Zhang }
64e0774c00SLei Zhang 
65e0774c00SLei Zhang tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const {
66e0774c00SLei Zhang   assert(isConstantAttr() && "the DAG leaf must be constant attribute");
67e0774c00SLei Zhang   return ConstantAttr(cast<llvm::DefInit>(def));
68e0774c00SLei Zhang }
69e0774c00SLei Zhang 
70b9e38a79SLei Zhang tblgen::EnumAttrCase tblgen::DagLeaf::getAsEnumAttrCase() const {
71b9e38a79SLei Zhang   assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
72b9e38a79SLei Zhang   return EnumAttrCase(cast<llvm::DefInit>(def));
73b9e38a79SLei Zhang }
74b9e38a79SLei Zhang 
75e0774c00SLei Zhang std::string tblgen::DagLeaf::getConditionTemplate() const {
768f5fa566SLei Zhang   return getAsConstraint().getConditionTemplate();
77e0774c00SLei Zhang }
78e0774c00SLei Zhang 
79e0774c00SLei Zhang std::string tblgen::DagLeaf::getTransformationTemplate() const {
80e0774c00SLei Zhang   assert(isAttrTransformer() && "the DAG leaf must be attribute transformer");
81e0774c00SLei Zhang   return cast<llvm::DefInit>(def)
82e0774c00SLei Zhang       ->getDef()
83e0774c00SLei Zhang       ->getValueAsString("attrTransform")
84e0774c00SLei Zhang       .str();
85eb753f4aSLei Zhang }
86eb753f4aSLei Zhang 
87b9e38a79SLei Zhang bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const {
88b9e38a79SLei Zhang   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
89b9e38a79SLei Zhang     return defInit->getDef()->isSubClassOf(superclass);
90b9e38a79SLei Zhang   return false;
91b9e38a79SLei Zhang }
92b9e38a79SLei Zhang 
93c52a8127SFeng Liu bool tblgen::DagNode::isAttrTransformer() const {
94c52a8127SFeng Liu   auto op = node->getOperator();
95c52a8127SFeng Liu   if (!op || !isa<llvm::DefInit>(op))
96c52a8127SFeng Liu     return false;
97c52a8127SFeng Liu   return cast<llvm::DefInit>(op)->getDef()->isSubClassOf("tAttr");
98c52a8127SFeng Liu }
99c52a8127SFeng Liu 
100c52a8127SFeng Liu std::string tblgen::DagNode::getTransformationTemplate() const {
101c52a8127SFeng Liu   assert(isAttrTransformer() && "the DAG leaf must be attribute transformer");
102c52a8127SFeng Liu   return cast<llvm::DefInit>(node->getOperator())
103c52a8127SFeng Liu       ->getDef()
104c52a8127SFeng Liu       ->getValueAsString("attrTransform")
105c52a8127SFeng Liu       .str();
106c52a8127SFeng Liu }
107c52a8127SFeng Liu 
108388fb375SJacques Pienaar llvm::StringRef tblgen::DagNode::getOpName() const {
109388fb375SJacques Pienaar   return node->getNameStr();
110388fb375SJacques Pienaar }
111388fb375SJacques Pienaar 
112eb753f4aSLei Zhang Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const {
113eb753f4aSLei Zhang   llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
1142dc6d205SLei Zhang   auto it = mapper->find(opDef);
1152dc6d205SLei Zhang   if (it != mapper->end())
1162dc6d205SLei Zhang     return *it->second;
1172dc6d205SLei Zhang   return *mapper->try_emplace(opDef, llvm::make_unique<Operator>(opDef))
1182dc6d205SLei Zhang               .first->second;
119eb753f4aSLei Zhang }
120eb753f4aSLei Zhang 
121eb753f4aSLei Zhang unsigned tblgen::DagNode::getNumOps() const {
122eb753f4aSLei Zhang   unsigned count = isReplaceWithValue() ? 0 : 1;
123eb753f4aSLei Zhang   for (unsigned i = 0, e = getNumArgs(); i != e; ++i) {
124eb753f4aSLei Zhang     if (auto child = getArgAsNestedDag(i))
125eb753f4aSLei Zhang       count += child.getNumOps();
126eb753f4aSLei Zhang   }
127eb753f4aSLei Zhang   return count;
128eb753f4aSLei Zhang }
129eb753f4aSLei Zhang 
130eb753f4aSLei Zhang unsigned tblgen::DagNode::getNumArgs() const { return node->getNumArgs(); }
131eb753f4aSLei Zhang 
132eb753f4aSLei Zhang bool tblgen::DagNode::isNestedDagArg(unsigned index) const {
133eb753f4aSLei Zhang   return isa<llvm::DagInit>(node->getArg(index));
134eb753f4aSLei Zhang }
135eb753f4aSLei Zhang 
136eb753f4aSLei Zhang tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const {
137eb753f4aSLei Zhang   return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
138eb753f4aSLei Zhang }
139eb753f4aSLei Zhang 
140e0774c00SLei Zhang tblgen::DagLeaf tblgen::DagNode::getArgAsLeaf(unsigned index) const {
141e0774c00SLei Zhang   assert(!isNestedDagArg(index));
142e0774c00SLei Zhang   return DagLeaf(node->getArg(index));
143eb753f4aSLei Zhang }
144eb753f4aSLei Zhang 
145eb753f4aSLei Zhang StringRef tblgen::DagNode::getArgName(unsigned index) const {
146eb753f4aSLei Zhang   return node->getArgNameStr(index);
147eb753f4aSLei Zhang }
148eb753f4aSLei Zhang 
149eb753f4aSLei Zhang bool tblgen::DagNode::isReplaceWithValue() const {
150eb753f4aSLei Zhang   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
151eb753f4aSLei Zhang   return dagOpDef->getName() == "replaceWithValue";
152eb753f4aSLei Zhang }
153eb753f4aSLei Zhang 
15418fde7c9SLei Zhang bool tblgen::DagNode::isVerifyUnusedValue() const {
15518fde7c9SLei Zhang   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
15618fde7c9SLei Zhang   return dagOpDef->getName() == "verifyUnusedValue";
15718fde7c9SLei Zhang }
15818fde7c9SLei Zhang 
15982dc6a87SJacques Pienaar bool tblgen::DagNode::isNativeCodeBuilder() const {
16082dc6a87SJacques Pienaar   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
16182dc6a87SJacques Pienaar   return dagOpDef->isSubClassOf("cOp");
16282dc6a87SJacques Pienaar }
16382dc6a87SJacques Pienaar 
16482dc6a87SJacques Pienaar llvm::StringRef tblgen::DagNode::getNativeCodeBuilder() const {
16582dc6a87SJacques Pienaar   assert(isNativeCodeBuilder());
16682dc6a87SJacques Pienaar   auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
16782dc6a87SJacques Pienaar   return dagOpDef->getValueAsString("function");
16882dc6a87SJacques Pienaar }
16982dc6a87SJacques Pienaar 
170eb753f4aSLei Zhang tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
171eb753f4aSLei Zhang     : def(*def), recordOpMap(mapper) {
17204b6d2f3SLei Zhang   collectBoundArguments(getSourcePattern());
173eb753f4aSLei Zhang }
174eb753f4aSLei Zhang 
175eb753f4aSLei Zhang tblgen::DagNode tblgen::Pattern::getSourcePattern() const {
1768f5fa566SLei Zhang   return tblgen::DagNode(def.getValueAsDag("sourcePattern"));
177eb753f4aSLei Zhang }
178eb753f4aSLei Zhang 
179eb753f4aSLei Zhang unsigned tblgen::Pattern::getNumResults() const {
1808f5fa566SLei Zhang   auto *results = def.getValueAsListInit("resultPatterns");
181eb753f4aSLei Zhang   return results->size();
182eb753f4aSLei Zhang }
183eb753f4aSLei Zhang 
184eb753f4aSLei Zhang tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const {
1858f5fa566SLei Zhang   auto *results = def.getValueAsListInit("resultPatterns");
186eb753f4aSLei Zhang   return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index)));
187eb753f4aSLei Zhang }
188eb753f4aSLei Zhang 
189*09b623aaSLei Zhang void tblgen::Pattern::ensureBoundInSourcePattern(llvm::StringRef name) const {
190*09b623aaSLei Zhang   if (boundArguments.find(name) == boundArguments.end() &&
191*09b623aaSLei Zhang       boundOps.find(name) == boundOps.end())
192eb753f4aSLei Zhang     PrintFatalError(def.getLoc(),
193eb753f4aSLei Zhang                     Twine("referencing unbound variable '") + name + "'");
194eb753f4aSLei Zhang }
195eb753f4aSLei Zhang 
196e0774c00SLei Zhang llvm::StringMap<tblgen::Argument> &
197e0774c00SLei Zhang tblgen::Pattern::getSourcePatternBoundArgs() {
198eb753f4aSLei Zhang   return boundArguments;
199eb753f4aSLei Zhang }
200eb753f4aSLei Zhang 
201*09b623aaSLei Zhang llvm::StringSet<> &tblgen::Pattern::getSourcePatternBoundOps() {
202*09b623aaSLei Zhang   return boundOps;
203388fb375SJacques Pienaar }
204388fb375SJacques Pienaar 
205eb753f4aSLei Zhang const tblgen::Operator &tblgen::Pattern::getSourceRootOp() {
206eb753f4aSLei Zhang   return getSourcePattern().getDialectOp(recordOpMap);
207eb753f4aSLei Zhang }
208eb753f4aSLei Zhang 
209eb753f4aSLei Zhang tblgen::Operator &tblgen::Pattern::getDialectOp(DagNode node) {
210eb753f4aSLei Zhang   return node.getDialectOp(recordOpMap);
211eb753f4aSLei Zhang }
212388fb375SJacques Pienaar 
2138f5fa566SLei Zhang std::vector<tblgen::AppliedConstraint> tblgen::Pattern::getConstraints() const {
214388fb375SJacques Pienaar   auto *listInit = def.getValueAsListInit("constraints");
2158f5fa566SLei Zhang   std::vector<tblgen::AppliedConstraint> ret;
216388fb375SJacques Pienaar   ret.reserve(listInit->size());
2178f5fa566SLei Zhang 
218388fb375SJacques Pienaar   for (auto it : *listInit) {
2198f5fa566SLei Zhang     auto *dagInit = dyn_cast<llvm::DagInit>(it);
2208f5fa566SLei Zhang     if (!dagInit)
2218f5fa566SLei Zhang       PrintFatalError(def.getLoc(), "all elemements in Pattern multi-entity "
2228f5fa566SLei Zhang                                     "constraints should be DAG nodes");
2238f5fa566SLei Zhang 
2248f5fa566SLei Zhang     std::vector<std::string> entities;
2258f5fa566SLei Zhang     entities.reserve(dagInit->arg_size());
2268f5fa566SLei Zhang     for (auto *argName : dagInit->getArgNames())
2278f5fa566SLei Zhang       entities.push_back(argName->getValue());
2288f5fa566SLei Zhang 
2298f5fa566SLei Zhang     ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
2308f5fa566SLei Zhang                      std::move(entities));
231388fb375SJacques Pienaar   }
232388fb375SJacques Pienaar   return ret;
233388fb375SJacques Pienaar }
23453035874SFeng Liu 
23553035874SFeng Liu int tblgen::Pattern::getBenefit() const {
236a0606ca7SFeng Liu   // The initial benefit value is a heuristic with number of ops in the source
23753035874SFeng Liu   // pattern.
238a0606ca7SFeng Liu   int initBenefit = getSourcePattern().getNumOps();
23953035874SFeng Liu   llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
240a0606ca7SFeng Liu   if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
241a0606ca7SFeng Liu     PrintFatalError(def.getLoc(),
242a0606ca7SFeng Liu                     "The 'addBenefit' takes and only takes one integer value");
243a0606ca7SFeng Liu   }
244a0606ca7SFeng Liu   return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
24553035874SFeng Liu }
24604b6d2f3SLei Zhang 
24704b6d2f3SLei Zhang void tblgen::Pattern::collectBoundArguments(DagNode tree) {
24804b6d2f3SLei Zhang   auto &op = getDialectOp(tree);
24904b6d2f3SLei Zhang   auto numOpArgs = op.getNumArgs();
25004b6d2f3SLei Zhang   auto numTreeArgs = tree.getNumArgs();
25104b6d2f3SLei Zhang 
25204b6d2f3SLei Zhang   if (numOpArgs != numTreeArgs) {
25304b6d2f3SLei Zhang     PrintFatalError(def.getLoc(),
25404b6d2f3SLei Zhang                     formatv("op '{0}' argument number mismatch: "
25504b6d2f3SLei Zhang                             "{1} in pattern vs. {2} in definition",
25604b6d2f3SLei Zhang                             op.getOperationName(), numTreeArgs, numOpArgs));
25704b6d2f3SLei Zhang   }
25804b6d2f3SLei Zhang 
25904b6d2f3SLei Zhang   // The name attached to the DAG node's operator is for representing the
26004b6d2f3SLei Zhang   // results generated from this op. It should be remembered as bound results.
26104b6d2f3SLei Zhang   auto treeName = tree.getOpName();
26204b6d2f3SLei Zhang   if (!treeName.empty())
263*09b623aaSLei Zhang     boundOps.insert(treeName);
26404b6d2f3SLei Zhang 
26504b6d2f3SLei Zhang   // TODO(jpienaar): Expand to multiple matches.
26604b6d2f3SLei Zhang   for (unsigned i = 0; i != numTreeArgs; ++i) {
26704b6d2f3SLei Zhang     if (auto treeArg = tree.getArgAsNestedDag(i)) {
26804b6d2f3SLei Zhang       // This DAG node argument is a DAG node itself. Go inside recursively.
26904b6d2f3SLei Zhang       collectBoundArguments(treeArg);
27004b6d2f3SLei Zhang     } else {
27104b6d2f3SLei Zhang       auto treeArgName = tree.getArgName(i);
27204b6d2f3SLei Zhang       if (!treeArgName.empty())
27304b6d2f3SLei Zhang         boundArguments.try_emplace(treeArgName, op.getArg(i));
27404b6d2f3SLei Zhang     }
27504b6d2f3SLei Zhang   }
27604b6d2f3SLei Zhang }
277