//===- Pattern.cpp - Pattern wrapper class --------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Pattern wrapper class that simplifies using a TableGen Record defining an
// MLIR Pattern.
20eb753f4aSLei Zhang // 21eb753f4aSLei Zhang //===----------------------------------------------------------------------===// 22eb753f4aSLei Zhang 23eb753f4aSLei Zhang #include "mlir/TableGen/Pattern.h" 24eb753f4aSLei Zhang #include "llvm/ADT/Twine.h" 2504b6d2f3SLei Zhang #include "llvm/Support/FormatVariadic.h" 268f5fa566SLei Zhang #include "llvm/TableGen/Error.h" 27eb753f4aSLei Zhang #include "llvm/TableGen/Record.h" 28eb753f4aSLei Zhang 29eb753f4aSLei Zhang using namespace mlir; 30eb753f4aSLei Zhang 310ea6154bSJacques Pienaar using llvm::formatv; 32eb753f4aSLei Zhang using mlir::tblgen::Operator; 33eb753f4aSLei Zhang 34e0774c00SLei Zhang bool tblgen::DagLeaf::isUnspecified() const { 35b9e38a79SLei Zhang return dyn_cast_or_null<llvm::UnsetInit>(def); 36e0774c00SLei Zhang } 37e0774c00SLei Zhang 38e0774c00SLei Zhang bool tblgen::DagLeaf::isOperandMatcher() const { 39e0774c00SLei Zhang // Operand matchers specify a type constraint. 40b9e38a79SLei Zhang return isSubClassOf("TypeConstraint"); 41e0774c00SLei Zhang } 42e0774c00SLei Zhang 43e0774c00SLei Zhang bool tblgen::DagLeaf::isAttrMatcher() const { 44c52a8127SFeng Liu // Attribute matchers specify an attribute constraint. 
45b9e38a79SLei Zhang return isSubClassOf("AttrConstraint"); 46e0774c00SLei Zhang } 47e0774c00SLei Zhang 48d0e2019dSLei Zhang bool tblgen::DagLeaf::isNativeCodeCall() const { 49d0e2019dSLei Zhang return isSubClassOf("NativeCodeCall"); 50e0774c00SLei Zhang } 51e0774c00SLei Zhang 52e0774c00SLei Zhang bool tblgen::DagLeaf::isConstantAttr() const { 53b9e38a79SLei Zhang return isSubClassOf("ConstantAttr"); 54b9e38a79SLei Zhang } 55b9e38a79SLei Zhang 56b9e38a79SLei Zhang bool tblgen::DagLeaf::isEnumAttrCase() const { 579dd182e0SLei Zhang return isSubClassOf("EnumAttrCaseInfo"); 58e0774c00SLei Zhang } 59e0774c00SLei Zhang 608f5fa566SLei Zhang tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const { 618f5fa566SLei Zhang assert((isOperandMatcher() || isAttrMatcher()) && 628f5fa566SLei Zhang "the DAG leaf must be operand or attribute"); 638f5fa566SLei Zhang return Constraint(cast<llvm::DefInit>(def)->getDef()); 64e0774c00SLei Zhang } 65e0774c00SLei Zhang 66e0774c00SLei Zhang tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const { 67e0774c00SLei Zhang assert(isConstantAttr() && "the DAG leaf must be constant attribute"); 68e0774c00SLei Zhang return ConstantAttr(cast<llvm::DefInit>(def)); 69e0774c00SLei Zhang } 70e0774c00SLei Zhang 71b9e38a79SLei Zhang tblgen::EnumAttrCase tblgen::DagLeaf::getAsEnumAttrCase() const { 72b9e38a79SLei Zhang assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case"); 73b9e38a79SLei Zhang return EnumAttrCase(cast<llvm::DefInit>(def)); 74b9e38a79SLei Zhang } 75b9e38a79SLei Zhang 76e0774c00SLei Zhang std::string tblgen::DagLeaf::getConditionTemplate() const { 778f5fa566SLei Zhang return getAsConstraint().getConditionTemplate(); 78e0774c00SLei Zhang } 79e0774c00SLei Zhang 80d0e2019dSLei Zhang llvm::StringRef tblgen::DagLeaf::getNativeCodeTemplate() const { 81d0e2019dSLei Zhang assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 82d0e2019dSLei Zhang return 
cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression"); 83eb753f4aSLei Zhang } 84eb753f4aSLei Zhang 85b9e38a79SLei Zhang bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const { 86b9e38a79SLei Zhang if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def)) 87b9e38a79SLei Zhang return defInit->getDef()->isSubClassOf(superclass); 88b9e38a79SLei Zhang return false; 89b9e38a79SLei Zhang } 90b9e38a79SLei Zhang 91d0e2019dSLei Zhang bool tblgen::DagNode::isNativeCodeCall() const { 92d0e2019dSLei Zhang if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator())) 93d0e2019dSLei Zhang return defInit->getDef()->isSubClassOf("NativeCodeCall"); 94c52a8127SFeng Liu return false; 95c52a8127SFeng Liu } 96c52a8127SFeng Liu 97647f8cabSRiver Riddle bool tblgen::DagNode::isOperation() const { 98647f8cabSRiver Riddle return !(isNativeCodeCall() || isVerifyUnusedValue() || isReplaceWithValue()); 99647f8cabSRiver Riddle } 100647f8cabSRiver Riddle 101d0e2019dSLei Zhang llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const { 102d0e2019dSLei Zhang assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 103c52a8127SFeng Liu return cast<llvm::DefInit>(node->getOperator()) 104c52a8127SFeng Liu ->getDef() 105d0e2019dSLei Zhang ->getValueAsString("expression"); 106c52a8127SFeng Liu } 107c52a8127SFeng Liu 108*e032d0dcSLei Zhang llvm::StringRef tblgen::DagNode::getSymbol() const { 109388fb375SJacques Pienaar return node->getNameStr(); 110388fb375SJacques Pienaar } 111388fb375SJacques Pienaar 112eb753f4aSLei Zhang Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const { 113eb753f4aSLei Zhang llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 1142dc6d205SLei Zhang auto it = mapper->find(opDef); 1152dc6d205SLei Zhang if (it != mapper->end()) 1162dc6d205SLei Zhang return *it->second; 1172dc6d205SLei Zhang return *mapper->try_emplace(opDef, llvm::make_unique<Operator>(opDef)) 1182dc6d205SLei Zhang 
.first->second; 119eb753f4aSLei Zhang } 120eb753f4aSLei Zhang 1212fe8ae4fSJacques Pienaar int tblgen::DagNode::getNumOps() const { 1222fe8ae4fSJacques Pienaar int count = isReplaceWithValue() ? 0 : 1; 1232fe8ae4fSJacques Pienaar for (int i = 0, e = getNumArgs(); i != e; ++i) { 124eb753f4aSLei Zhang if (auto child = getArgAsNestedDag(i)) 125eb753f4aSLei Zhang count += child.getNumOps(); 126eb753f4aSLei Zhang } 127eb753f4aSLei Zhang return count; 128eb753f4aSLei Zhang } 129eb753f4aSLei Zhang 1302fe8ae4fSJacques Pienaar int tblgen::DagNode::getNumArgs() const { return node->getNumArgs(); } 131eb753f4aSLei Zhang 132eb753f4aSLei Zhang bool tblgen::DagNode::isNestedDagArg(unsigned index) const { 133eb753f4aSLei Zhang return isa<llvm::DagInit>(node->getArg(index)); 134eb753f4aSLei Zhang } 135eb753f4aSLei Zhang 136eb753f4aSLei Zhang tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const { 137eb753f4aSLei Zhang return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index))); 138eb753f4aSLei Zhang } 139eb753f4aSLei Zhang 140e0774c00SLei Zhang tblgen::DagLeaf tblgen::DagNode::getArgAsLeaf(unsigned index) const { 141e0774c00SLei Zhang assert(!isNestedDagArg(index)); 142e0774c00SLei Zhang return DagLeaf(node->getArg(index)); 143eb753f4aSLei Zhang } 144eb753f4aSLei Zhang 145eb753f4aSLei Zhang StringRef tblgen::DagNode::getArgName(unsigned index) const { 146eb753f4aSLei Zhang return node->getArgNameStr(index); 147eb753f4aSLei Zhang } 148eb753f4aSLei Zhang 149eb753f4aSLei Zhang bool tblgen::DagNode::isReplaceWithValue() const { 150eb753f4aSLei Zhang auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 151eb753f4aSLei Zhang return dagOpDef->getName() == "replaceWithValue"; 152eb753f4aSLei Zhang } 153eb753f4aSLei Zhang 15418fde7c9SLei Zhang bool tblgen::DagNode::isVerifyUnusedValue() const { 15518fde7c9SLei Zhang auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 15618fde7c9SLei Zhang return dagOpDef->getName() == 
"verifyUnusedValue"; 15718fde7c9SLei Zhang } 15818fde7c9SLei Zhang 159eb753f4aSLei Zhang tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) 160eb753f4aSLei Zhang : def(*def), recordOpMap(mapper) { 161*e032d0dcSLei Zhang collectBoundSymbols(getSourcePattern(), srcBoundOps, /*isSrcPattern=*/true); 162*e032d0dcSLei Zhang for (int i = 0, e = getNumResultPatterns(); i < e; ++i) 163*e032d0dcSLei Zhang collectBoundSymbols(getResultPattern(i), resBoundOps, 164*e032d0dcSLei Zhang /*isSrcPattern=*/false); 165eb753f4aSLei Zhang } 166eb753f4aSLei Zhang 167eb753f4aSLei Zhang tblgen::DagNode tblgen::Pattern::getSourcePattern() const { 1688f5fa566SLei Zhang return tblgen::DagNode(def.getValueAsDag("sourcePattern")); 169eb753f4aSLei Zhang } 170eb753f4aSLei Zhang 171*e032d0dcSLei Zhang int tblgen::Pattern::getNumResultPatterns() const { 1728f5fa566SLei Zhang auto *results = def.getValueAsListInit("resultPatterns"); 173eb753f4aSLei Zhang return results->size(); 174eb753f4aSLei Zhang } 175eb753f4aSLei Zhang 176eb753f4aSLei Zhang tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const { 1778f5fa566SLei Zhang auto *results = def.getValueAsListInit("resultPatterns"); 178eb753f4aSLei Zhang return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index))); 179eb753f4aSLei Zhang } 180eb753f4aSLei Zhang 18109b623aaSLei Zhang void tblgen::Pattern::ensureBoundInSourcePattern(llvm::StringRef name) const { 182*e032d0dcSLei Zhang if (srcBoundArguments.find(name) == srcBoundArguments.end() && 183*e032d0dcSLei Zhang srcBoundOps.find(name) == srcBoundOps.end()) 184eb753f4aSLei Zhang PrintFatalError(def.getLoc(), 185eb753f4aSLei Zhang Twine("referencing unbound variable '") + name + "'"); 186eb753f4aSLei Zhang } 187eb753f4aSLei Zhang 188e0774c00SLei Zhang llvm::StringMap<tblgen::Argument> & 189e0774c00SLei Zhang tblgen::Pattern::getSourcePatternBoundArgs() { 190*e032d0dcSLei Zhang return srcBoundArguments; 191eb753f4aSLei Zhang } 192eb753f4aSLei Zhang 
1939f02e889SLei Zhang llvm::StringMap<const tblgen::Operator *> & 1949f02e889SLei Zhang tblgen::Pattern::getSourcePatternBoundOps() { 195*e032d0dcSLei Zhang return srcBoundOps; 196*e032d0dcSLei Zhang } 197*e032d0dcSLei Zhang 198*e032d0dcSLei Zhang llvm::StringMap<const tblgen::Operator *> & 199*e032d0dcSLei Zhang tblgen::Pattern::getResultPatternBoundOps() { 200*e032d0dcSLei Zhang return resBoundOps; 201388fb375SJacques Pienaar } 202388fb375SJacques Pienaar 203eb753f4aSLei Zhang const tblgen::Operator &tblgen::Pattern::getSourceRootOp() { 204eb753f4aSLei Zhang return getSourcePattern().getDialectOp(recordOpMap); 205eb753f4aSLei Zhang } 206eb753f4aSLei Zhang 207eb753f4aSLei Zhang tblgen::Operator &tblgen::Pattern::getDialectOp(DagNode node) { 208eb753f4aSLei Zhang return node.getDialectOp(recordOpMap); 209eb753f4aSLei Zhang } 210388fb375SJacques Pienaar 2118f5fa566SLei Zhang std::vector<tblgen::AppliedConstraint> tblgen::Pattern::getConstraints() const { 212388fb375SJacques Pienaar auto *listInit = def.getValueAsListInit("constraints"); 2138f5fa566SLei Zhang std::vector<tblgen::AppliedConstraint> ret; 214388fb375SJacques Pienaar ret.reserve(listInit->size()); 2158f5fa566SLei Zhang 216388fb375SJacques Pienaar for (auto it : *listInit) { 2178f5fa566SLei Zhang auto *dagInit = dyn_cast<llvm::DagInit>(it); 2188f5fa566SLei Zhang if (!dagInit) 2198f5fa566SLei Zhang PrintFatalError(def.getLoc(), "all elemements in Pattern multi-entity " 2208f5fa566SLei Zhang "constraints should be DAG nodes"); 2218f5fa566SLei Zhang 2228f5fa566SLei Zhang std::vector<std::string> entities; 2238f5fa566SLei Zhang entities.reserve(dagInit->arg_size()); 2248f5fa566SLei Zhang for (auto *argName : dagInit->getArgNames()) 2258f5fa566SLei Zhang entities.push_back(argName->getValue()); 2268f5fa566SLei Zhang 2278f5fa566SLei Zhang ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(), 2288f5fa566SLei Zhang std::move(entities)); 229388fb375SJacques Pienaar } 230388fb375SJacques Pienaar 
return ret; 231388fb375SJacques Pienaar } 23253035874SFeng Liu 23353035874SFeng Liu int tblgen::Pattern::getBenefit() const { 234a0606ca7SFeng Liu // The initial benefit value is a heuristic with number of ops in the source 23553035874SFeng Liu // pattern. 236a0606ca7SFeng Liu int initBenefit = getSourcePattern().getNumOps(); 23753035874SFeng Liu llvm::DagInit *delta = def.getValueAsDag("benefitDelta"); 238a0606ca7SFeng Liu if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) { 239a0606ca7SFeng Liu PrintFatalError(def.getLoc(), 240a0606ca7SFeng Liu "The 'addBenefit' takes and only takes one integer value"); 241a0606ca7SFeng Liu } 242a0606ca7SFeng Liu return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue(); 24353035874SFeng Liu } 24404b6d2f3SLei Zhang 2454165885aSJacques Pienaar std::vector<tblgen::Pattern::IdentifierLine> 2464165885aSJacques Pienaar tblgen::Pattern::getLocation() const { 2474165885aSJacques Pienaar std::vector<std::pair<StringRef, unsigned>> result; 2484165885aSJacques Pienaar result.reserve(def.getLoc().size()); 2494165885aSJacques Pienaar for (auto loc : def.getLoc()) { 2504165885aSJacques Pienaar unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc); 2514165885aSJacques Pienaar assert(buf && "invalid source location"); 2524165885aSJacques Pienaar result.emplace_back( 2534165885aSJacques Pienaar llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(), 2544165885aSJacques Pienaar llvm::SrcMgr.getLineAndColumn(loc, buf).first); 2554165885aSJacques Pienaar } 2564165885aSJacques Pienaar return result; 2574165885aSJacques Pienaar } 2584165885aSJacques Pienaar 259*e032d0dcSLei Zhang void tblgen::Pattern::collectBoundSymbols(DagNode tree, 260*e032d0dcSLei Zhang SymbolOperatorMap &symOpMap, 261*e032d0dcSLei Zhang bool isSrcPattern) { 262*e032d0dcSLei Zhang auto treeName = tree.getSymbol(); 263*e032d0dcSLei Zhang if (!tree.isOperation()) { 264*e032d0dcSLei Zhang if (!treeName.empty()) { 265*e032d0dcSLei Zhang 
PrintFatalError( 266*e032d0dcSLei Zhang def.getLoc(), 267*e032d0dcSLei Zhang formatv("binding symbol '{0}' to non-operation unsupported right now", 268*e032d0dcSLei Zhang treeName)); 269*e032d0dcSLei Zhang } 270*e032d0dcSLei Zhang return; 271*e032d0dcSLei Zhang } 272*e032d0dcSLei Zhang 27304b6d2f3SLei Zhang auto &op = getDialectOp(tree); 27404b6d2f3SLei Zhang auto numOpArgs = op.getNumArgs(); 27504b6d2f3SLei Zhang auto numTreeArgs = tree.getNumArgs(); 27604b6d2f3SLei Zhang 27704b6d2f3SLei Zhang if (numOpArgs != numTreeArgs) { 27804b6d2f3SLei Zhang PrintFatalError(def.getLoc(), 27904b6d2f3SLei Zhang formatv("op '{0}' argument number mismatch: " 28004b6d2f3SLei Zhang "{1} in pattern vs. {2} in definition", 28104b6d2f3SLei Zhang op.getOperationName(), numTreeArgs, numOpArgs)); 28204b6d2f3SLei Zhang } 28304b6d2f3SLei Zhang 28404b6d2f3SLei Zhang // The name attached to the DAG node's operator is for representing the 28504b6d2f3SLei Zhang // results generated from this op. It should be remembered as bound results. 28604b6d2f3SLei Zhang if (!treeName.empty()) 287*e032d0dcSLei Zhang symOpMap.try_emplace(treeName, &op); 28804b6d2f3SLei Zhang 2892fe8ae4fSJacques Pienaar for (int i = 0; i != numTreeArgs; ++i) { 29004b6d2f3SLei Zhang if (auto treeArg = tree.getArgAsNestedDag(i)) { 29104b6d2f3SLei Zhang // This DAG node argument is a DAG node itself. Go inside recursively. 292*e032d0dcSLei Zhang collectBoundSymbols(treeArg, symOpMap, isSrcPattern); 293*e032d0dcSLei Zhang } else if (isSrcPattern) { 294*e032d0dcSLei Zhang // We can only bind symbols to op arguments in source pattern. Those 295*e032d0dcSLei Zhang // symbols are referenced in result patterns. 29604b6d2f3SLei Zhang auto treeArgName = tree.getArgName(i); 29704b6d2f3SLei Zhang if (!treeArgName.empty()) 298*e032d0dcSLei Zhang srcBoundArguments.try_emplace(treeArgName, op.getArg(i)); 29904b6d2f3SLei Zhang } 30004b6d2f3SLei Zhang } 30104b6d2f3SLei Zhang } 302