18a1ca2cdSRiver Riddle //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
28a1ca2cdSRiver Riddle //
38a1ca2cdSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
48a1ca2cdSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
58a1ca2cdSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68a1ca2cdSRiver Riddle //
78a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
88a1ca2cdSRiver Riddle
98a1ca2cdSRiver Riddle #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
108a1ca2cdSRiver Riddle #include "../PassDetail.h"
118a1ca2cdSRiver Riddle #include "PredicateTree.h"
128a1ca2cdSRiver Riddle #include "mlir/Dialect/PDL/IR/PDL.h"
138a1ca2cdSRiver Riddle #include "mlir/Dialect/PDL/IR/PDLTypes.h"
148a1ca2cdSRiver Riddle #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
158a1ca2cdSRiver Riddle #include "mlir/Pass/Pass.h"
168a1ca2cdSRiver Riddle #include "llvm/ADT/MapVector.h"
178a1ca2cdSRiver Riddle #include "llvm/ADT/ScopedHashTable.h"
181d49e535SGuillaume Chatelet #include "llvm/ADT/Sequence.h"
198a1ca2cdSRiver Riddle #include "llvm/ADT/SetVector.h"
201d49e535SGuillaume Chatelet #include "llvm/ADT/SmallVector.h"
218a1ca2cdSRiver Riddle #include "llvm/ADT/TypeSwitch.h"
228a1ca2cdSRiver Riddle
238a1ca2cdSRiver Riddle using namespace mlir;
248a1ca2cdSRiver Riddle using namespace mlir::pdl_to_pdl_interp;
258a1ca2cdSRiver Riddle
268a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
278a1ca2cdSRiver Riddle // PatternLowering
288a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
298a1ca2cdSRiver Riddle
308a1ca2cdSRiver Riddle namespace {
318a1ca2cdSRiver Riddle /// This class generators operations within the PDL Interpreter dialect from a
328a1ca2cdSRiver Riddle /// given module containing PDL pattern operations.
338a1ca2cdSRiver Riddle struct PatternLowering {
348a1ca2cdSRiver Riddle public:
35f96a8675SRiver Riddle PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule);
368a1ca2cdSRiver Riddle
378a1ca2cdSRiver Riddle /// Generate code for matching and rewriting based on the pattern operations
388a1ca2cdSRiver Riddle /// within the module.
398a1ca2cdSRiver Riddle void lower(ModuleOp module);
408a1ca2cdSRiver Riddle
418a1ca2cdSRiver Riddle private:
428a1ca2cdSRiver Riddle using ValueMap = llvm::ScopedHashTable<Position *, Value>;
438a1ca2cdSRiver Riddle using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
448a1ca2cdSRiver Riddle
458a1ca2cdSRiver Riddle /// Generate interpreter operations for the tree rooted at the given matcher
46a76ee58fSStanislav Funiak /// node, in the specified region.
47a76ee58fSStanislav Funiak Block *generateMatcher(MatcherNode &node, Region ®ion);
488a1ca2cdSRiver Riddle
49a76ee58fSStanislav Funiak /// Get or create an access to the provided positional value in the current
50a76ee58fSStanislav Funiak /// block. This operation may mutate the provided block pointer if nested
51a76ee58fSStanislav Funiak /// regions (i.e., pdl_interp.iterate) are required.
52a76ee58fSStanislav Funiak Value getValueAt(Block *¤tBlock, Position *pos);
538a1ca2cdSRiver Riddle
54a76ee58fSStanislav Funiak /// Create the interpreter predicate operations. This operation may mutate the
55a76ee58fSStanislav Funiak /// provided current block pointer if nested regions (iterates) are required.
56a76ee58fSStanislav Funiak void generate(BoolNode *boolNode, Block *¤tBlock, Value val);
578a1ca2cdSRiver Riddle
58a76ee58fSStanislav Funiak /// Create the interpreter switch / predicate operations, with several case
59a76ee58fSStanislav Funiak /// destinations. This operation never mutates the provided current block
60a76ee58fSStanislav Funiak /// pointer, because the switch operation does not need Values beyond `val`.
61a76ee58fSStanislav Funiak void generate(SwitchNode *switchNode, Block *currentBlock, Value val);
628a1ca2cdSRiver Riddle
63a76ee58fSStanislav Funiak /// Create the interpreter operations to record a successful pattern match
64a76ee58fSStanislav Funiak /// using the contained root operation. This operation may mutate the current
65a76ee58fSStanislav Funiak /// block pointer if nested regions (i.e., pdl_interp.iterate) are required.
66a76ee58fSStanislav Funiak void generate(SuccessNode *successNode, Block *¤tBlock);
678a1ca2cdSRiver Riddle
688a1ca2cdSRiver Riddle /// Generate a rewriter function for the given pattern operation, and returns
698a1ca2cdSRiver Riddle /// a reference to that function.
708a1ca2cdSRiver Riddle SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
718a1ca2cdSRiver Riddle SmallVectorImpl<Position *> &usedMatchValues);
728a1ca2cdSRiver Riddle
738a1ca2cdSRiver Riddle /// Generate the rewriter code for the given operation.
7402c4c0d5SRiver Riddle void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
7502c4c0d5SRiver Riddle DenseMap<Value, Value> &rewriteValues,
7602c4c0d5SRiver Riddle function_ref<Value(Value)> mapRewriteValue);
778a1ca2cdSRiver Riddle void generateRewriter(pdl::AttributeOp attrOp,
788a1ca2cdSRiver Riddle DenseMap<Value, Value> &rewriteValues,
798a1ca2cdSRiver Riddle function_ref<Value(Value)> mapRewriteValue);
808a1ca2cdSRiver Riddle void generateRewriter(pdl::EraseOp eraseOp,
818a1ca2cdSRiver Riddle DenseMap<Value, Value> &rewriteValues,
828a1ca2cdSRiver Riddle function_ref<Value(Value)> mapRewriteValue);
838a1ca2cdSRiver Riddle void generateRewriter(pdl::OperationOp operationOp,
848a1ca2cdSRiver Riddle DenseMap<Value, Value> &rewriteValues,
858a1ca2cdSRiver Riddle function_ref<Value(Value)> mapRewriteValue);
868a1ca2cdSRiver Riddle void generateRewriter(pdl::ReplaceOp replaceOp,
878a1ca2cdSRiver Riddle DenseMap<Value, Value> &rewriteValues,
888a1ca2cdSRiver Riddle function_ref<Value(Value)> mapRewriteValue);
89242762c9SRiver Riddle void generateRewriter(pdl::ResultOp resultOp,
90242762c9SRiver Riddle DenseMap<Value, Value> &rewriteValues,
91242762c9SRiver Riddle function_ref<Value(Value)> mapRewriteValue);
923a833a0eSRiver Riddle void generateRewriter(pdl::ResultsOp resultOp,
933a833a0eSRiver Riddle DenseMap<Value, Value> &rewriteValues,
943a833a0eSRiver Riddle function_ref<Value(Value)> mapRewriteValue);
958a1ca2cdSRiver Riddle void generateRewriter(pdl::TypeOp typeOp,
968a1ca2cdSRiver Riddle DenseMap<Value, Value> &rewriteValues,
978a1ca2cdSRiver Riddle function_ref<Value(Value)> mapRewriteValue);
983a833a0eSRiver Riddle void generateRewriter(pdl::TypesOp typeOp,
993a833a0eSRiver Riddle DenseMap<Value, Value> &rewriteValues,
1003a833a0eSRiver Riddle function_ref<Value(Value)> mapRewriteValue);
1018a1ca2cdSRiver Riddle
1028a1ca2cdSRiver Riddle /// Generate the values used for resolving the result types of an operation
103*3c752289SRiver Riddle /// created within a dag rewriter region. If the result types of the operation
104*3c752289SRiver Riddle /// should be inferred, `hasInferredResultTypes` is set to true.
1058a1ca2cdSRiver Riddle void generateOperationResultTypeRewriter(
106*3c752289SRiver Riddle pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
107*3c752289SRiver Riddle SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
108*3c752289SRiver Riddle bool &hasInferredResultTypes);
1098a1ca2cdSRiver Riddle
1108a1ca2cdSRiver Riddle /// A builder to use when generating interpreter operations.
1118a1ca2cdSRiver Riddle OpBuilder builder;
1128a1ca2cdSRiver Riddle
1138a1ca2cdSRiver Riddle /// The matcher function used for all match related logic within PDL patterns.
114f96a8675SRiver Riddle pdl_interp::FuncOp matcherFunc;
1158a1ca2cdSRiver Riddle
1168a1ca2cdSRiver Riddle /// The rewriter module containing the all rewrite related logic within PDL
1178a1ca2cdSRiver Riddle /// patterns.
1188a1ca2cdSRiver Riddle ModuleOp rewriterModule;
1198a1ca2cdSRiver Riddle
1208a1ca2cdSRiver Riddle /// The symbol table of the rewriter module used for insertion.
1218a1ca2cdSRiver Riddle SymbolTable rewriterSymbolTable;
1228a1ca2cdSRiver Riddle
1238a1ca2cdSRiver Riddle /// A scoped map connecting a position with the corresponding interpreter
1248a1ca2cdSRiver Riddle /// value.
1258a1ca2cdSRiver Riddle ValueMap values;
1268a1ca2cdSRiver Riddle
1278a1ca2cdSRiver Riddle /// A stack of blocks used as the failure destination for matcher nodes that
1288a1ca2cdSRiver Riddle /// don't have an explicit failure path.
1298a1ca2cdSRiver Riddle SmallVector<Block *, 8> failureBlockStack;
1308a1ca2cdSRiver Riddle
1318a1ca2cdSRiver Riddle /// A mapping between values defined in a pattern match, and the corresponding
1328a1ca2cdSRiver Riddle /// positional value.
1338a1ca2cdSRiver Riddle DenseMap<Value, Position *> valueToPosition;
1348a1ca2cdSRiver Riddle
1358a1ca2cdSRiver Riddle /// The set of operation values whose whose location will be used for newly
1368a1ca2cdSRiver Riddle /// generated operations.
1374efb7754SRiver Riddle SetVector<Value> locOps;
1388a1ca2cdSRiver Riddle };
139be0a7e9fSMehdi Amini } // namespace
1408a1ca2cdSRiver Riddle
PatternLowering(pdl_interp::FuncOp matcherFunc,ModuleOp rewriterModule)141f96a8675SRiver Riddle PatternLowering::PatternLowering(pdl_interp::FuncOp matcherFunc,
142f96a8675SRiver Riddle ModuleOp rewriterModule)
1438a1ca2cdSRiver Riddle : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
1448a1ca2cdSRiver Riddle rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {}
1458a1ca2cdSRiver Riddle
lower(ModuleOp module)1468a1ca2cdSRiver Riddle void PatternLowering::lower(ModuleOp module) {
1478a1ca2cdSRiver Riddle PredicateUniquer predicateUniquer;
1488a1ca2cdSRiver Riddle PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
1498a1ca2cdSRiver Riddle
1508a1ca2cdSRiver Riddle // Define top-level scope for the arguments to the matcher function.
1518a1ca2cdSRiver Riddle ValueMapScope topLevelValueScope(values);
1528a1ca2cdSRiver Riddle
1538a1ca2cdSRiver Riddle // Insert the root operation, i.e. argument to the matcher, at the root
1548a1ca2cdSRiver Riddle // position.
155f96a8675SRiver Riddle Block *matcherEntryBlock = &matcherFunc.front();
1568a1ca2cdSRiver Riddle values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
1578a1ca2cdSRiver Riddle
1588a1ca2cdSRiver Riddle // Generate a root matcher node from the provided PDL module.
1598a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
1608a1ca2cdSRiver Riddle module, predicateBuilder, valueToPosition);
161a76ee58fSStanislav Funiak Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody());
162a76ee58fSStanislav Funiak assert(failureBlockStack.empty() && "failed to empty the stack");
1638a1ca2cdSRiver Riddle
1648a1ca2cdSRiver Riddle // After generation, merged the first matched block into the entry.
1658a1ca2cdSRiver Riddle matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
1668a1ca2cdSRiver Riddle firstMatcherBlock->getOperations());
1678a1ca2cdSRiver Riddle firstMatcherBlock->erase();
1688a1ca2cdSRiver Riddle }
1698a1ca2cdSRiver Riddle
generateMatcher(MatcherNode & node,Region & region)170a76ee58fSStanislav Funiak Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion) {
1718a1ca2cdSRiver Riddle // Push a new scope for the values used by this matcher.
172a76ee58fSStanislav Funiak Block *block = ®ion.emplaceBlock();
1738a1ca2cdSRiver Riddle ValueMapScope scope(values);
1748a1ca2cdSRiver Riddle
1758a1ca2cdSRiver Riddle // If this is the return node, simply insert the corresponding interpreter
1768a1ca2cdSRiver Riddle // finalize.
1778a1ca2cdSRiver Riddle if (isa<ExitNode>(node)) {
1788a1ca2cdSRiver Riddle builder.setInsertionPointToEnd(block);
1798a1ca2cdSRiver Riddle builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
1808a1ca2cdSRiver Riddle return block;
1818a1ca2cdSRiver Riddle }
1828a1ca2cdSRiver Riddle
1838a1ca2cdSRiver Riddle // Get the next block in the match sequence.
184a76ee58fSStanislav Funiak // This is intentionally executed first, before we get the value for the
185a76ee58fSStanislav Funiak // position associated with the node, so that we preserve an "there exist"
186a76ee58fSStanislav Funiak // semantics: if getting a value requires an upward traversal (going from a
187a76ee58fSStanislav Funiak // value to its consumers), we want to perform the check on all the consumers
188a76ee58fSStanislav Funiak // before we pass control to the failure node.
1898a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
190a76ee58fSStanislav Funiak Block *failureBlock;
1918a1ca2cdSRiver Riddle if (failureNode) {
192a76ee58fSStanislav Funiak failureBlock = generateMatcher(*failureNode, region);
193a76ee58fSStanislav Funiak failureBlockStack.push_back(failureBlock);
1948a1ca2cdSRiver Riddle } else {
1958a1ca2cdSRiver Riddle assert(!failureBlockStack.empty() && "expected valid failure block");
196a76ee58fSStanislav Funiak failureBlock = failureBlockStack.back();
1978a1ca2cdSRiver Riddle }
1988a1ca2cdSRiver Riddle
199a76ee58fSStanislav Funiak // If this node contains a position, get the corresponding value for this
200a76ee58fSStanislav Funiak // block.
201a76ee58fSStanislav Funiak Block *currentBlock = block;
202a76ee58fSStanislav Funiak Position *position = node.getPosition();
203a76ee58fSStanislav Funiak Value val = position ? getValueAt(currentBlock, position) : Value();
204a76ee58fSStanislav Funiak
2058a1ca2cdSRiver Riddle // If this value corresponds to an operation, record that we are going to use
2068a1ca2cdSRiver Riddle // its location as part of a fused location.
2078a1ca2cdSRiver Riddle bool isOperationValue = val && val.getType().isa<pdl::OperationType>();
2088a1ca2cdSRiver Riddle if (isOperationValue)
2098a1ca2cdSRiver Riddle locOps.insert(val);
2108a1ca2cdSRiver Riddle
211a76ee58fSStanislav Funiak // Dispatch to the correct method based on derived node type.
212a76ee58fSStanislav Funiak TypeSwitch<MatcherNode *>(&node)
213a19e1635SStanislav Funiak .Case<BoolNode, SwitchNode>([&](auto *derivedNode) {
214a19e1635SStanislav Funiak this->generate(derivedNode, currentBlock, val);
215a19e1635SStanislav Funiak })
216a76ee58fSStanislav Funiak .Case([&](SuccessNode *successNode) {
217a76ee58fSStanislav Funiak generate(successNode, currentBlock);
218a76ee58fSStanislav Funiak });
2198a1ca2cdSRiver Riddle
220a76ee58fSStanislav Funiak // Pop all the failure blocks that were inserted due to nesting of
221a76ee58fSStanislav Funiak // pdl_interp.iterate.
222a76ee58fSStanislav Funiak while (failureBlockStack.back() != failureBlock) {
223a76ee58fSStanislav Funiak failureBlockStack.pop_back();
224a76ee58fSStanislav Funiak assert(!failureBlockStack.empty() && "unable to locate failure block");
2258a1ca2cdSRiver Riddle }
2268a1ca2cdSRiver Riddle
227a76ee58fSStanislav Funiak // Pop the new failure block.
2288a1ca2cdSRiver Riddle if (failureNode)
2298a1ca2cdSRiver Riddle failureBlockStack.pop_back();
230a76ee58fSStanislav Funiak
2318a1ca2cdSRiver Riddle if (isOperationValue)
2328a1ca2cdSRiver Riddle locOps.remove(val);
233a76ee58fSStanislav Funiak
2348a1ca2cdSRiver Riddle return block;
2358a1ca2cdSRiver Riddle }
2368a1ca2cdSRiver Riddle
getValueAt(Block * & currentBlock,Position * pos)237a76ee58fSStanislav Funiak Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) {
2388a1ca2cdSRiver Riddle if (Value val = values.lookup(pos))
2398a1ca2cdSRiver Riddle return val;
2408a1ca2cdSRiver Riddle
2418a1ca2cdSRiver Riddle // Get the value for the parent position.
242233e9476SRiver Riddle Value parentVal;
243233e9476SRiver Riddle if (Position *parent = pos->getParent())
24480b3f08eSUday Bondhugula parentVal = getValueAt(currentBlock, parent);
2458a1ca2cdSRiver Riddle
2468a1ca2cdSRiver Riddle // TODO: Use a location from the position.
247233e9476SRiver Riddle Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
248a76ee58fSStanislav Funiak builder.setInsertionPointToEnd(currentBlock);
2498a1ca2cdSRiver Riddle Value value;
2508a1ca2cdSRiver Riddle switch (pos->getKind()) {
251a76ee58fSStanislav Funiak case Predicates::OperationPos: {
252a76ee58fSStanislav Funiak auto *operationPos = cast<OperationPosition>(pos);
2532692eae5SStanislav Funiak if (operationPos->isOperandDefiningOp())
254a76ee58fSStanislav Funiak // Standard (downward) traversal which directly follows the defining op.
2558a1ca2cdSRiver Riddle value = builder.create<pdl_interp::GetDefiningOpOp>(
2568a1ca2cdSRiver Riddle loc, builder.getType<pdl::OperationType>(), parentVal);
2572692eae5SStanislav Funiak else
2582692eae5SStanislav Funiak // A passthrough operation position.
2592692eae5SStanislav Funiak value = parentVal;
2608a1ca2cdSRiver Riddle break;
261a76ee58fSStanislav Funiak }
2622692eae5SStanislav Funiak case Predicates::UsersPos: {
2632692eae5SStanislav Funiak auto *usersPos = cast<UsersPosition>(pos);
264a76ee58fSStanislav Funiak
265a76ee58fSStanislav Funiak // The first operation retrieves the representative value of a range.
2662692eae5SStanislav Funiak // This applies only when the parent is a range of values and we were
2672692eae5SStanislav Funiak // requested to use a representative value (e.g., upward traversal).
2682692eae5SStanislav Funiak if (parentVal.getType().isa<pdl::RangeType>() &&
2692692eae5SStanislav Funiak usersPos->useRepresentative())
270a76ee58fSStanislav Funiak value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
271a76ee58fSStanislav Funiak else
272a76ee58fSStanislav Funiak value = parentVal;
273a76ee58fSStanislav Funiak
274a76ee58fSStanislav Funiak // The second operation retrieves the users.
275a76ee58fSStanislav Funiak value = builder.create<pdl_interp::GetUsersOp>(loc, value);
2762692eae5SStanislav Funiak break;
2772692eae5SStanislav Funiak }
2782692eae5SStanislav Funiak case Predicates::ForEachPos: {
279a76ee58fSStanislav Funiak assert(!failureBlockStack.empty() && "expected valid failure block");
280a76ee58fSStanislav Funiak auto foreach = builder.create<pdl_interp::ForEachOp>(
2812692eae5SStanislav Funiak loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
282a76ee58fSStanislav Funiak value = foreach.getLoopVariable();
283a76ee58fSStanislav Funiak
2842692eae5SStanislav Funiak // Create the continuation block.
2853c405c3bSRiver Riddle Block *continueBlock = builder.createBlock(&foreach.getRegion());
286a76ee58fSStanislav Funiak builder.create<pdl_interp::ContinueOp>(loc);
287a76ee58fSStanislav Funiak failureBlockStack.push_back(continueBlock);
288a76ee58fSStanislav Funiak
2893c405c3bSRiver Riddle currentBlock = &foreach.getRegion().front();
290a76ee58fSStanislav Funiak break;
291a76ee58fSStanislav Funiak }
2928a1ca2cdSRiver Riddle case Predicates::OperandPos: {
2938a1ca2cdSRiver Riddle auto *operandPos = cast<OperandPosition>(pos);
2948a1ca2cdSRiver Riddle value = builder.create<pdl_interp::GetOperandOp>(
2958a1ca2cdSRiver Riddle loc, builder.getType<pdl::ValueType>(), parentVal,
2968a1ca2cdSRiver Riddle operandPos->getOperandNumber());
2978a1ca2cdSRiver Riddle break;
2988a1ca2cdSRiver Riddle }
2993a833a0eSRiver Riddle case Predicates::OperandGroupPos: {
3003a833a0eSRiver Riddle auto *operandPos = cast<OperandGroupPosition>(pos);
3013a833a0eSRiver Riddle Type valueTy = builder.getType<pdl::ValueType>();
3023a833a0eSRiver Riddle value = builder.create<pdl_interp::GetOperandsOp>(
3033a833a0eSRiver Riddle loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
3043a833a0eSRiver Riddle parentVal, operandPos->getOperandGroupNumber());
3053a833a0eSRiver Riddle break;
3063a833a0eSRiver Riddle }
3078a1ca2cdSRiver Riddle case Predicates::AttributePos: {
3088a1ca2cdSRiver Riddle auto *attrPos = cast<AttributePosition>(pos);
3098a1ca2cdSRiver Riddle value = builder.create<pdl_interp::GetAttributeOp>(
3108a1ca2cdSRiver Riddle loc, builder.getType<pdl::AttributeType>(), parentVal,
3118a1ca2cdSRiver Riddle attrPos->getName().strref());
3128a1ca2cdSRiver Riddle break;
3138a1ca2cdSRiver Riddle }
3148a1ca2cdSRiver Riddle case Predicates::TypePos: {
3153a833a0eSRiver Riddle if (parentVal.getType().isa<pdl::AttributeType>())
3168a1ca2cdSRiver Riddle value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
3173a833a0eSRiver Riddle else
3183a833a0eSRiver Riddle value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
3198a1ca2cdSRiver Riddle break;
3208a1ca2cdSRiver Riddle }
3218a1ca2cdSRiver Riddle case Predicates::ResultPos: {
3228a1ca2cdSRiver Riddle auto *resPos = cast<ResultPosition>(pos);
3238a1ca2cdSRiver Riddle value = builder.create<pdl_interp::GetResultOp>(
3248a1ca2cdSRiver Riddle loc, builder.getType<pdl::ValueType>(), parentVal,
3258a1ca2cdSRiver Riddle resPos->getResultNumber());
3268a1ca2cdSRiver Riddle break;
3278a1ca2cdSRiver Riddle }
3283a833a0eSRiver Riddle case Predicates::ResultGroupPos: {
3293a833a0eSRiver Riddle auto *resPos = cast<ResultGroupPosition>(pos);
3303a833a0eSRiver Riddle Type valueTy = builder.getType<pdl::ValueType>();
3313a833a0eSRiver Riddle value = builder.create<pdl_interp::GetResultsOp>(
3323a833a0eSRiver Riddle loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
3333a833a0eSRiver Riddle parentVal, resPos->getResultGroupNumber());
3343a833a0eSRiver Riddle break;
3353a833a0eSRiver Riddle }
336233e9476SRiver Riddle case Predicates::AttributeLiteralPos: {
337233e9476SRiver Riddle auto *attrPos = cast<AttributeLiteralPosition>(pos);
338233e9476SRiver Riddle value =
339233e9476SRiver Riddle builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
340233e9476SRiver Riddle break;
341233e9476SRiver Riddle }
342233e9476SRiver Riddle case Predicates::TypeLiteralPos: {
343233e9476SRiver Riddle auto *typePos = cast<TypeLiteralPosition>(pos);
344233e9476SRiver Riddle Attribute rawTypeAttr = typePos->getValue();
345233e9476SRiver Riddle if (TypeAttr typeAttr = rawTypeAttr.dyn_cast<TypeAttr>())
346233e9476SRiver Riddle value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
347233e9476SRiver Riddle else
348233e9476SRiver Riddle value = builder.create<pdl_interp::CreateTypesOp>(
349233e9476SRiver Riddle loc, rawTypeAttr.cast<ArrayAttr>());
350233e9476SRiver Riddle break;
351233e9476SRiver Riddle }
3528a1ca2cdSRiver Riddle default:
3538a1ca2cdSRiver Riddle llvm_unreachable("Generating unknown Position getter");
3548a1ca2cdSRiver Riddle break;
3558a1ca2cdSRiver Riddle }
356a76ee58fSStanislav Funiak
3578a1ca2cdSRiver Riddle values.insert(pos, value);
3588a1ca2cdSRiver Riddle return value;
3598a1ca2cdSRiver Riddle }
3608a1ca2cdSRiver Riddle
generate(BoolNode * boolNode,Block * & currentBlock,Value val)361a76ee58fSStanislav Funiak void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock,
362a76ee58fSStanislav Funiak Value val) {
3638a1ca2cdSRiver Riddle Location loc = val.getLoc();
364a76ee58fSStanislav Funiak Qualifier *question = boolNode->getQuestion();
365a76ee58fSStanislav Funiak Qualifier *answer = boolNode->getAnswer();
366a76ee58fSStanislav Funiak Region *region = currentBlock->getParent();
367a76ee58fSStanislav Funiak
368a76ee58fSStanislav Funiak // Execute the getValue queries first, so that we create success
369a76ee58fSStanislav Funiak // matcher in the correct (possibly nested) region.
370a76ee58fSStanislav Funiak SmallVector<Value> args;
371a76ee58fSStanislav Funiak if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
372a76ee58fSStanislav Funiak args = {getValueAt(currentBlock, equalToQuestion->getValue())};
373a76ee58fSStanislav Funiak } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
374233e9476SRiver Riddle for (Position *position : cstQuestion->getArgs())
375a76ee58fSStanislav Funiak args.push_back(getValueAt(currentBlock, position));
376a76ee58fSStanislav Funiak }
377a76ee58fSStanislav Funiak
378a76ee58fSStanislav Funiak // Generate the matcher in the current (potentially nested) region
379a76ee58fSStanislav Funiak // and get the failure successor.
380a76ee58fSStanislav Funiak Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
381a76ee58fSStanislav Funiak Block *failure = failureBlockStack.back();
382a76ee58fSStanislav Funiak
383a76ee58fSStanislav Funiak // Finally, create the predicate.
384a76ee58fSStanislav Funiak builder.setInsertionPointToEnd(currentBlock);
3853a833a0eSRiver Riddle Predicates::Kind kind = question->getKind();
3863a833a0eSRiver Riddle switch (kind) {
3878a1ca2cdSRiver Riddle case Predicates::IsNotNullQuestion:
388a76ee58fSStanislav Funiak builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure);
3898a1ca2cdSRiver Riddle break;
3908a1ca2cdSRiver Riddle case Predicates::OperationNameQuestion: {
3918a1ca2cdSRiver Riddle auto *opNameAnswer = cast<OperationNameAnswer>(answer);
3928a1ca2cdSRiver Riddle builder.create<pdl_interp::CheckOperationNameOp>(
393a76ee58fSStanislav Funiak loc, val, opNameAnswer->getValue().getStringRef(), success, failure);
3948a1ca2cdSRiver Riddle break;
3958a1ca2cdSRiver Riddle }
3968a1ca2cdSRiver Riddle case Predicates::TypeQuestion: {
3978a1ca2cdSRiver Riddle auto *ans = cast<TypeAnswer>(answer);
3983a833a0eSRiver Riddle if (val.getType().isa<pdl::RangeType>())
3993a833a0eSRiver Riddle builder.create<pdl_interp::CheckTypesOp>(
400a76ee58fSStanislav Funiak loc, val, ans->getValue().cast<ArrayAttr>(), success, failure);
4013a833a0eSRiver Riddle else
4028a1ca2cdSRiver Riddle builder.create<pdl_interp::CheckTypeOp>(
403a76ee58fSStanislav Funiak loc, val, ans->getValue().cast<TypeAttr>(), success, failure);
4048a1ca2cdSRiver Riddle break;
4058a1ca2cdSRiver Riddle }
4068a1ca2cdSRiver Riddle case Predicates::AttributeQuestion: {
4078a1ca2cdSRiver Riddle auto *ans = cast<AttributeAnswer>(answer);
4088a1ca2cdSRiver Riddle builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
409a76ee58fSStanislav Funiak success, failure);
4108a1ca2cdSRiver Riddle break;
4118a1ca2cdSRiver Riddle }
4123a833a0eSRiver Riddle case Predicates::OperandCountAtLeastQuestion:
4133a833a0eSRiver Riddle case Predicates::OperandCountQuestion:
4148a1ca2cdSRiver Riddle builder.create<pdl_interp::CheckOperandCountOp>(
4153a833a0eSRiver Riddle loc, val, cast<UnsignedAnswer>(answer)->getValue(),
4163a833a0eSRiver Riddle /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
417a76ee58fSStanislav Funiak success, failure);
4188a1ca2cdSRiver Riddle break;
4193a833a0eSRiver Riddle case Predicates::ResultCountAtLeastQuestion:
4203a833a0eSRiver Riddle case Predicates::ResultCountQuestion:
4218a1ca2cdSRiver Riddle builder.create<pdl_interp::CheckResultCountOp>(
4223a833a0eSRiver Riddle loc, val, cast<UnsignedAnswer>(answer)->getValue(),
4233a833a0eSRiver Riddle /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
424a76ee58fSStanislav Funiak success, failure);
4258a1ca2cdSRiver Riddle break;
4268a1ca2cdSRiver Riddle case Predicates::EqualToQuestion: {
427a76ee58fSStanislav Funiak bool trueAnswer = isa<TrueAnswer>(answer);
428a76ee58fSStanislav Funiak builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(),
429a76ee58fSStanislav Funiak trueAnswer ? success : failure,
430a76ee58fSStanislav Funiak trueAnswer ? failure : success);
4318a1ca2cdSRiver Riddle break;
4328a1ca2cdSRiver Riddle }
4338a1ca2cdSRiver Riddle case Predicates::ConstraintQuestion: {
434233e9476SRiver Riddle auto *cstQuestion = cast<ConstraintQuestion>(question);
4359595f356SRiver Riddle builder.create<pdl_interp::ApplyConstraintOp>(loc, cstQuestion->getName(),
4369595f356SRiver Riddle args, success, failure);
4378a1ca2cdSRiver Riddle break;
4388a1ca2cdSRiver Riddle }
4398a1ca2cdSRiver Riddle default:
4408a1ca2cdSRiver Riddle llvm_unreachable("Generating unknown Predicate operation");
4418a1ca2cdSRiver Riddle }
4428a1ca2cdSRiver Riddle }
4438a1ca2cdSRiver Riddle
4448a1ca2cdSRiver Riddle template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
createSwitchOp(Value val,Block * defaultDest,OpBuilder & builder,llvm::MapVector<Qualifier *,Block * > & dests)4458a1ca2cdSRiver Riddle static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
4463a833a0eSRiver Riddle llvm::MapVector<Qualifier *, Block *> &dests) {
4478a1ca2cdSRiver Riddle std::vector<ValT> values;
4488a1ca2cdSRiver Riddle std::vector<Block *> blocks;
4498a1ca2cdSRiver Riddle values.reserve(dests.size());
4508a1ca2cdSRiver Riddle blocks.reserve(dests.size());
4518a1ca2cdSRiver Riddle for (const auto &it : dests) {
4528a1ca2cdSRiver Riddle blocks.push_back(it.second);
4538a1ca2cdSRiver Riddle values.push_back(cast<PredT>(it.first)->getValue());
4548a1ca2cdSRiver Riddle }
4558a1ca2cdSRiver Riddle builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
4568a1ca2cdSRiver Riddle }
4578a1ca2cdSRiver Riddle
generate(SwitchNode * switchNode,Block * currentBlock,Value val)458a76ee58fSStanislav Funiak void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
459a76ee58fSStanislav Funiak Value val) {
460a76ee58fSStanislav Funiak Qualifier *question = switchNode->getQuestion();
461a76ee58fSStanislav Funiak Region *region = currentBlock->getParent();
462a76ee58fSStanislav Funiak Block *defaultDest = failureBlockStack.back();
463a76ee58fSStanislav Funiak
4643a833a0eSRiver Riddle // If the switch question is not an exact answer, i.e. for the `at_least`
4653a833a0eSRiver Riddle // cases, we generate a special block sequence.
4663a833a0eSRiver Riddle Predicates::Kind kind = question->getKind();
4673a833a0eSRiver Riddle if (kind == Predicates::OperandCountAtLeastQuestion ||
4683a833a0eSRiver Riddle kind == Predicates::ResultCountAtLeastQuestion) {
4693a833a0eSRiver Riddle // Order the children such that the cases are in reverse numerical order.
4701d49e535SGuillaume Chatelet SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
4711d49e535SGuillaume Chatelet llvm::seq<unsigned>(0, switchNode->getChildren().size()));
4723a833a0eSRiver Riddle llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
4733a833a0eSRiver Riddle return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
4743a833a0eSRiver Riddle cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
4753a833a0eSRiver Riddle });
4763a833a0eSRiver Riddle
4773a833a0eSRiver Riddle // Build the destination for each child using the next highest child as a
4783a833a0eSRiver Riddle // a failure destination. This essentially creates the following control
4793a833a0eSRiver Riddle // flow:
4803a833a0eSRiver Riddle //
4813a833a0eSRiver Riddle // if (operand_count < 1)
4823a833a0eSRiver Riddle // goto failure
4833a833a0eSRiver Riddle // if (child1.match())
4843a833a0eSRiver Riddle // ...
4853a833a0eSRiver Riddle //
4863a833a0eSRiver Riddle // if (operand_count < 2)
4873a833a0eSRiver Riddle // goto failure
4883a833a0eSRiver Riddle // if (child2.match())
4893a833a0eSRiver Riddle // ...
4903a833a0eSRiver Riddle //
4913a833a0eSRiver Riddle // failure:
4923a833a0eSRiver Riddle // ...
4933a833a0eSRiver Riddle //
4943a833a0eSRiver Riddle failureBlockStack.push_back(defaultDest);
495a76ee58fSStanislav Funiak Location loc = val.getLoc();
4963a833a0eSRiver Riddle for (unsigned idx : sortedChildren) {
4973a833a0eSRiver Riddle auto &child = switchNode->getChild(idx);
498a76ee58fSStanislav Funiak Block *childBlock = generateMatcher(*child.second, *region);
4993a833a0eSRiver Riddle Block *predicateBlock = builder.createBlock(childBlock);
500a76ee58fSStanislav Funiak builder.setInsertionPointToEnd(predicateBlock);
501a76ee58fSStanislav Funiak unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
502a76ee58fSStanislav Funiak switch (kind) {
503a76ee58fSStanislav Funiak case Predicates::OperandCountAtLeastQuestion:
504a76ee58fSStanislav Funiak builder.create<pdl_interp::CheckOperandCountOp>(
505a76ee58fSStanislav Funiak loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
506a76ee58fSStanislav Funiak break;
507a76ee58fSStanislav Funiak case Predicates::ResultCountAtLeastQuestion:
508a76ee58fSStanislav Funiak builder.create<pdl_interp::CheckResultCountOp>(
509a76ee58fSStanislav Funiak loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
510a76ee58fSStanislav Funiak break;
511a76ee58fSStanislav Funiak default:
512a76ee58fSStanislav Funiak llvm_unreachable("Generating invalid AtLeast operation");
513a76ee58fSStanislav Funiak }
5143a833a0eSRiver Riddle failureBlockStack.back() = predicateBlock;
5153a833a0eSRiver Riddle }
5163a833a0eSRiver Riddle Block *firstPredicateBlock = failureBlockStack.pop_back_val();
5173a833a0eSRiver Riddle currentBlock->getOperations().splice(currentBlock->end(),
5183a833a0eSRiver Riddle firstPredicateBlock->getOperations());
5193a833a0eSRiver Riddle firstPredicateBlock->erase();
5203a833a0eSRiver Riddle return;
5213a833a0eSRiver Riddle }
5223a833a0eSRiver Riddle
5233a833a0eSRiver Riddle // Otherwise, generate each of the children and generate an interpreter
5243a833a0eSRiver Riddle // switch.
5253a833a0eSRiver Riddle llvm::MapVector<Qualifier *, Block *> children;
5263a833a0eSRiver Riddle for (auto &it : switchNode->getChildren())
527a76ee58fSStanislav Funiak children.insert({it.first, generateMatcher(*it.second, *region)});
5288a1ca2cdSRiver Riddle builder.setInsertionPointToEnd(currentBlock);
5293a833a0eSRiver Riddle
5308a1ca2cdSRiver Riddle switch (question->getKind()) {
5318a1ca2cdSRiver Riddle case Predicates::OperandCountQuestion:
5328a1ca2cdSRiver Riddle return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
5333a833a0eSRiver Riddle int32_t>(val, defaultDest, builder, children);
5348a1ca2cdSRiver Riddle case Predicates::ResultCountQuestion:
5358a1ca2cdSRiver Riddle return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
5363a833a0eSRiver Riddle int32_t>(val, defaultDest, builder, children);
5378a1ca2cdSRiver Riddle case Predicates::OperationNameQuestion:
5388a1ca2cdSRiver Riddle return createSwitchOp<pdl_interp::SwitchOperationNameOp,
5398a1ca2cdSRiver Riddle OperationNameAnswer>(val, defaultDest, builder,
5403a833a0eSRiver Riddle children);
5418a1ca2cdSRiver Riddle case Predicates::TypeQuestion:
5423a833a0eSRiver Riddle if (val.getType().isa<pdl::RangeType>()) {
5433a833a0eSRiver Riddle return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
5443a833a0eSRiver Riddle val, defaultDest, builder, children);
5453a833a0eSRiver Riddle }
5468a1ca2cdSRiver Riddle return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
5473a833a0eSRiver Riddle val, defaultDest, builder, children);
5488a1ca2cdSRiver Riddle case Predicates::AttributeQuestion:
5498a1ca2cdSRiver Riddle return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
5503a833a0eSRiver Riddle val, defaultDest, builder, children);
5518a1ca2cdSRiver Riddle default:
5528a1ca2cdSRiver Riddle llvm_unreachable("Generating unknown switch predicate.");
5538a1ca2cdSRiver Riddle }
5548a1ca2cdSRiver Riddle }
5558a1ca2cdSRiver Riddle
generate(SuccessNode * successNode,Block * & currentBlock)556a76ee58fSStanislav Funiak void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) {
557a76ee58fSStanislav Funiak pdl::PatternOp pattern = successNode->getPattern();
558a76ee58fSStanislav Funiak Value root = successNode->getRoot();
559a76ee58fSStanislav Funiak
5608a1ca2cdSRiver Riddle // Generate a rewriter for the pattern this success node represents, and track
5618a1ca2cdSRiver Riddle // any values used from the match region.
5628a1ca2cdSRiver Riddle SmallVector<Position *, 8> usedMatchValues;
5638a1ca2cdSRiver Riddle SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
5648a1ca2cdSRiver Riddle
5658a1ca2cdSRiver Riddle // Process any values used in the rewrite that are defined in the match.
5668a1ca2cdSRiver Riddle std::vector<Value> mappedMatchValues;
5678a1ca2cdSRiver Riddle mappedMatchValues.reserve(usedMatchValues.size());
5688a1ca2cdSRiver Riddle for (Position *position : usedMatchValues)
5698a1ca2cdSRiver Riddle mappedMatchValues.push_back(getValueAt(currentBlock, position));
5708a1ca2cdSRiver Riddle
5718a1ca2cdSRiver Riddle // Collect the set of operations generated by the rewriter.
5728a1ca2cdSRiver Riddle SmallVector<StringRef, 4> generatedOps;
5738a1ca2cdSRiver Riddle for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>())
5748a1ca2cdSRiver Riddle generatedOps.push_back(*op.name());
5758a1ca2cdSRiver Riddle ArrayAttr generatedOpsAttr;
5768a1ca2cdSRiver Riddle if (!generatedOps.empty())
5778a1ca2cdSRiver Riddle generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
5788a1ca2cdSRiver Riddle
5798a1ca2cdSRiver Riddle // Grab the root kind if present.
5808a1ca2cdSRiver Riddle StringAttr rootKindAttr;
581a76ee58fSStanislav Funiak if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
582a76ee58fSStanislav Funiak if (Optional<StringRef> rootKind = rootOp.name())
5838a1ca2cdSRiver Riddle rootKindAttr = builder.getStringAttr(*rootKind);
5848a1ca2cdSRiver Riddle
5858a1ca2cdSRiver Riddle builder.setInsertionPointToEnd(currentBlock);
5868a1ca2cdSRiver Riddle builder.create<pdl_interp::RecordMatchOp>(
5878a1ca2cdSRiver Riddle pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
5888a1ca2cdSRiver Riddle rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(),
589a76ee58fSStanislav Funiak failureBlockStack.back());
5908a1ca2cdSRiver Riddle }
5918a1ca2cdSRiver Riddle
generateRewriter(pdl::PatternOp pattern,SmallVectorImpl<Position * > & usedMatchValues)5928a1ca2cdSRiver Riddle SymbolRefAttr PatternLowering::generateRewriter(
5938a1ca2cdSRiver Riddle pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
594f96a8675SRiver Riddle builder.setInsertionPointToEnd(rewriterModule.getBody());
595f96a8675SRiver Riddle auto rewriterFunc = builder.create<pdl_interp::FuncOp>(
596f96a8675SRiver Riddle pattern.getLoc(), "pdl_generated_rewriter",
5978a1ca2cdSRiver Riddle builder.getFunctionType(llvm::None, llvm::None));
5988a1ca2cdSRiver Riddle rewriterSymbolTable.insert(rewriterFunc);
5998a1ca2cdSRiver Riddle
6008a1ca2cdSRiver Riddle // Generate the rewriter function body.
601f96a8675SRiver Riddle builder.setInsertionPointToEnd(&rewriterFunc.front());
6028a1ca2cdSRiver Riddle
6038a1ca2cdSRiver Riddle // Map an input operand of the pattern to a generated interpreter value.
6048a1ca2cdSRiver Riddle DenseMap<Value, Value> rewriteValues;
6058a1ca2cdSRiver Riddle auto mapRewriteValue = [&](Value oldValue) {
6068a1ca2cdSRiver Riddle Value &newValue = rewriteValues[oldValue];
6078a1ca2cdSRiver Riddle if (newValue)
6088a1ca2cdSRiver Riddle return newValue;
6098a1ca2cdSRiver Riddle
6108a1ca2cdSRiver Riddle // Prefer materializing constants directly when possible.
6118a1ca2cdSRiver Riddle Operation *oldOp = oldValue.getDefiningOp();
6128a1ca2cdSRiver Riddle if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
6138a1ca2cdSRiver Riddle if (Attribute value = attrOp.valueAttr()) {
6148a1ca2cdSRiver Riddle return newValue = builder.create<pdl_interp::CreateAttributeOp>(
6158a1ca2cdSRiver Riddle attrOp.getLoc(), value);
6168a1ca2cdSRiver Riddle }
6178a1ca2cdSRiver Riddle } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
6188a1ca2cdSRiver Riddle if (TypeAttr type = typeOp.typeAttr()) {
6198a1ca2cdSRiver Riddle return newValue = builder.create<pdl_interp::CreateTypeOp>(
6208a1ca2cdSRiver Riddle typeOp.getLoc(), type);
6218a1ca2cdSRiver Riddle }
6223a833a0eSRiver Riddle } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
6233a833a0eSRiver Riddle if (ArrayAttr type = typeOp.typesAttr()) {
6243a833a0eSRiver Riddle return newValue = builder.create<pdl_interp::CreateTypesOp>(
6253a833a0eSRiver Riddle typeOp.getLoc(), typeOp.getType(), type);
6263a833a0eSRiver Riddle }
6278a1ca2cdSRiver Riddle }
6288a1ca2cdSRiver Riddle
6298a1ca2cdSRiver Riddle // Otherwise, add this as an input to the rewriter.
6308a1ca2cdSRiver Riddle Position *inputPos = valueToPosition.lookup(oldValue);
6318a1ca2cdSRiver Riddle assert(inputPos && "expected value to be a pattern input");
6328a1ca2cdSRiver Riddle usedMatchValues.push_back(inputPos);
633e084679fSRiver Riddle return newValue = rewriterFunc.front().addArgument(oldValue.getType(),
634e084679fSRiver Riddle oldValue.getLoc());
6358a1ca2cdSRiver Riddle };
6368a1ca2cdSRiver Riddle
6378a1ca2cdSRiver Riddle // If this is a custom rewriter, simply dispatch to the registered rewrite
6388a1ca2cdSRiver Riddle // method.
6398a1ca2cdSRiver Riddle pdl::RewriteOp rewriter = pattern.getRewriter();
6408a1ca2cdSRiver Riddle if (StringAttr rewriteName = rewriter.nameAttr()) {
641a76ee58fSStanislav Funiak SmallVector<Value> args;
642a76ee58fSStanislav Funiak if (rewriter.root())
643a76ee58fSStanislav Funiak args.push_back(mapRewriteValue(rewriter.root()));
64402c4c0d5SRiver Riddle auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue);
64502c4c0d5SRiver Riddle args.append(mappedArgs.begin(), mappedArgs.end());
6468a1ca2cdSRiver Riddle builder.create<pdl_interp::ApplyRewriteOp>(
6479595f356SRiver Riddle rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args);
6488a1ca2cdSRiver Riddle } else {
6498a1ca2cdSRiver Riddle // Otherwise this is a dag rewriter defined using PDL operations.
6508a1ca2cdSRiver Riddle for (Operation &rewriteOp : *rewriter.getBody()) {
6518a1ca2cdSRiver Riddle llvm::TypeSwitch<Operation *>(&rewriteOp)
65202c4c0d5SRiver Riddle .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
6533a833a0eSRiver Riddle pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp,
6543a833a0eSRiver Riddle pdl::TypeOp, pdl::TypesOp>([&](auto op) {
6558a1ca2cdSRiver Riddle this->generateRewriter(op, rewriteValues, mapRewriteValue);
6568a1ca2cdSRiver Riddle });
6578a1ca2cdSRiver Riddle }
6588a1ca2cdSRiver Riddle }
6598a1ca2cdSRiver Riddle
6608a1ca2cdSRiver Riddle // Update the signature of the rewrite function.
6618a1ca2cdSRiver Riddle rewriterFunc.setType(builder.getFunctionType(
6628a1ca2cdSRiver Riddle llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
6638a1ca2cdSRiver Riddle /*results=*/llvm::None));
6648a1ca2cdSRiver Riddle
6658a1ca2cdSRiver Riddle builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
666faf1c224SChris Lattner return SymbolRefAttr::get(
667faf1c224SChris Lattner builder.getContext(),
6688a1ca2cdSRiver Riddle pdl_interp::PDLInterpDialect::getRewriterModuleName(),
669faf1c224SChris Lattner SymbolRefAttr::get(rewriterFunc));
6708a1ca2cdSRiver Riddle }
6718a1ca2cdSRiver Riddle
generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)6728a1ca2cdSRiver Riddle void PatternLowering::generateRewriter(
67302c4c0d5SRiver Riddle pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
67402c4c0d5SRiver Riddle function_ref<Value(Value)> mapRewriteValue) {
67502c4c0d5SRiver Riddle SmallVector<Value, 2> arguments;
67602c4c0d5SRiver Riddle for (Value argument : rewriteOp.args())
67702c4c0d5SRiver Riddle arguments.push_back(mapRewriteValue(argument));
67802c4c0d5SRiver Riddle auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
67902c4c0d5SRiver Riddle rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(),
6809595f356SRiver Riddle arguments);
6819595f356SRiver Riddle for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
68202c4c0d5SRiver Riddle rewriteValues[std::get<0>(it)] = std::get<1>(it);
68302c4c0d5SRiver Riddle }
68402c4c0d5SRiver Riddle
generateRewriter(pdl::AttributeOp attrOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)68502c4c0d5SRiver Riddle void PatternLowering::generateRewriter(
6868a1ca2cdSRiver Riddle pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
6878a1ca2cdSRiver Riddle function_ref<Value(Value)> mapRewriteValue) {
6888a1ca2cdSRiver Riddle Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
6898a1ca2cdSRiver Riddle attrOp.getLoc(), attrOp.valueAttr());
6908a1ca2cdSRiver Riddle rewriteValues[attrOp] = newAttr;
6918a1ca2cdSRiver Riddle }
6928a1ca2cdSRiver Riddle
generateRewriter(pdl::EraseOp eraseOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)6938a1ca2cdSRiver Riddle void PatternLowering::generateRewriter(
6948a1ca2cdSRiver Riddle pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
6958a1ca2cdSRiver Riddle function_ref<Value(Value)> mapRewriteValue) {
6968a1ca2cdSRiver Riddle builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
6978a1ca2cdSRiver Riddle mapRewriteValue(eraseOp.operation()));
6988a1ca2cdSRiver Riddle }
6998a1ca2cdSRiver Riddle
generateRewriter(pdl::OperationOp operationOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)7008a1ca2cdSRiver Riddle void PatternLowering::generateRewriter(
7018a1ca2cdSRiver Riddle pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
7028a1ca2cdSRiver Riddle function_ref<Value(Value)> mapRewriteValue) {
7038a1ca2cdSRiver Riddle SmallVector<Value, 4> operands;
7048a1ca2cdSRiver Riddle for (Value operand : operationOp.operands())
7058a1ca2cdSRiver Riddle operands.push_back(mapRewriteValue(operand));
7068a1ca2cdSRiver Riddle
7078a1ca2cdSRiver Riddle SmallVector<Value, 4> attributes;
7088a1ca2cdSRiver Riddle for (Value attr : operationOp.attributes())
7098a1ca2cdSRiver Riddle attributes.push_back(mapRewriteValue(attr));
7108a1ca2cdSRiver Riddle
711*3c752289SRiver Riddle bool hasInferredResultTypes = false;
7128a1ca2cdSRiver Riddle SmallVector<Value, 2> types;
713*3c752289SRiver Riddle generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types,
714*3c752289SRiver Riddle rewriteValues, hasInferredResultTypes);
7158a1ca2cdSRiver Riddle
7168a1ca2cdSRiver Riddle // Create the new operation.
7178a1ca2cdSRiver Riddle Location loc = operationOp.getLoc();
7188a1ca2cdSRiver Riddle Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
719*3c752289SRiver Riddle loc, *operationOp.name(), types, hasInferredResultTypes, operands,
720*3c752289SRiver Riddle attributes, operationOp.attributeNames());
7218a1ca2cdSRiver Riddle rewriteValues[operationOp.op()] = createdOp;
7228a1ca2cdSRiver Riddle
723242762c9SRiver Riddle // Generate accesses for any results that have their types constrained.
7243a833a0eSRiver Riddle // Handle the case where there is a single range representing all of the
7253a833a0eSRiver Riddle // result types.
7263a833a0eSRiver Riddle OperandRange resultTys = operationOp.types();
7273a833a0eSRiver Riddle if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) {
7283a833a0eSRiver Riddle Value &type = rewriteValues[resultTys[0]];
7293a833a0eSRiver Riddle if (!type) {
7303a833a0eSRiver Riddle auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
7313a833a0eSRiver Riddle type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
7323a833a0eSRiver Riddle }
7333a833a0eSRiver Riddle return;
7343a833a0eSRiver Riddle }
7353a833a0eSRiver Riddle
7363a833a0eSRiver Riddle // Otherwise, populate the individual results.
7373a833a0eSRiver Riddle bool seenVariableLength = false;
7383a833a0eSRiver Riddle Type valueTy = builder.getType<pdl::ValueType>();
7393a833a0eSRiver Riddle Type valueRangeTy = pdl::RangeType::get(valueTy);
740e4853be2SMehdi Amini for (const auto &it : llvm::enumerate(resultTys)) {
741242762c9SRiver Riddle Value &type = rewriteValues[it.value()];
742242762c9SRiver Riddle if (type)
743242762c9SRiver Riddle continue;
7443a833a0eSRiver Riddle bool isVariadic = it.value().getType().isa<pdl::RangeType>();
7453a833a0eSRiver Riddle seenVariableLength |= isVariadic;
746242762c9SRiver Riddle
7473a833a0eSRiver Riddle // After a variable length result has been seen, we need to use result
7483a833a0eSRiver Riddle // groups because the exact index of the result is not statically known.
7493a833a0eSRiver Riddle Value resultVal;
7503a833a0eSRiver Riddle if (seenVariableLength)
7513a833a0eSRiver Riddle resultVal = builder.create<pdl_interp::GetResultsOp>(
7523a833a0eSRiver Riddle loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
7533a833a0eSRiver Riddle else
7543a833a0eSRiver Riddle resultVal = builder.create<pdl_interp::GetResultOp>(
7553a833a0eSRiver Riddle loc, valueTy, createdOp, it.index());
7563a833a0eSRiver Riddle type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
7578a1ca2cdSRiver Riddle }
7588a1ca2cdSRiver Riddle }
7598a1ca2cdSRiver Riddle
generateRewriter(pdl::ReplaceOp replaceOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)7608a1ca2cdSRiver Riddle void PatternLowering::generateRewriter(
7618a1ca2cdSRiver Riddle pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
7628a1ca2cdSRiver Riddle function_ref<Value(Value)> mapRewriteValue) {
763242762c9SRiver Riddle SmallVector<Value, 4> replOperands;
764242762c9SRiver Riddle
7658a1ca2cdSRiver Riddle // If the replacement was another operation, get its results. `pdl` allows
7668a1ca2cdSRiver Riddle // for using an operation for simplicitly, but the interpreter isn't as
7678a1ca2cdSRiver Riddle // user facing.
768242762c9SRiver Riddle if (Value replOp = replaceOp.replOperation()) {
7693a833a0eSRiver Riddle // Don't use replace if we know the replaced operation has no results.
7703a833a0eSRiver Riddle auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>();
7713a833a0eSRiver Riddle if (!opOp || !opOp.types().empty()) {
7723a833a0eSRiver Riddle replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
7733a833a0eSRiver Riddle replOp.getLoc(), mapRewriteValue(replOp)));
7743a833a0eSRiver Riddle }
775242762c9SRiver Riddle } else {
776242762c9SRiver Riddle for (Value operand : replaceOp.replValues())
777242762c9SRiver Riddle replOperands.push_back(mapRewriteValue(operand));
778242762c9SRiver Riddle }
7798a1ca2cdSRiver Riddle
7808a1ca2cdSRiver Riddle // If there are no replacement values, just create an erase instead.
781242762c9SRiver Riddle if (replOperands.empty()) {
7828a1ca2cdSRiver Riddle builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(),
7838a1ca2cdSRiver Riddle mapRewriteValue(replaceOp.operation()));
7848a1ca2cdSRiver Riddle return;
7858a1ca2cdSRiver Riddle }
7868a1ca2cdSRiver Riddle
7878a1ca2cdSRiver Riddle builder.create<pdl_interp::ReplaceOp>(
7888a1ca2cdSRiver Riddle replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands);
7898a1ca2cdSRiver Riddle }
7908a1ca2cdSRiver Riddle
generateRewriter(pdl::ResultOp resultOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)7918a1ca2cdSRiver Riddle void PatternLowering::generateRewriter(
792242762c9SRiver Riddle pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
793242762c9SRiver Riddle function_ref<Value(Value)> mapRewriteValue) {
794242762c9SRiver Riddle rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
795242762c9SRiver Riddle resultOp.getLoc(), builder.getType<pdl::ValueType>(),
796242762c9SRiver Riddle mapRewriteValue(resultOp.parent()), resultOp.index());
797242762c9SRiver Riddle }
798242762c9SRiver Riddle
generateRewriter(pdl::ResultsOp resultOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)799242762c9SRiver Riddle void PatternLowering::generateRewriter(
8003a833a0eSRiver Riddle pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
8013a833a0eSRiver Riddle function_ref<Value(Value)> mapRewriteValue) {
8023a833a0eSRiver Riddle rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
8033a833a0eSRiver Riddle resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()),
8043a833a0eSRiver Riddle resultOp.index());
8053a833a0eSRiver Riddle }
8063a833a0eSRiver Riddle
generateRewriter(pdl::TypeOp typeOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)8073a833a0eSRiver Riddle void PatternLowering::generateRewriter(
8088a1ca2cdSRiver Riddle pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
8098a1ca2cdSRiver Riddle function_ref<Value(Value)> mapRewriteValue) {
8108a1ca2cdSRiver Riddle // If the type isn't constant, the users (e.g. OperationOp) will resolve this
8118a1ca2cdSRiver Riddle // type.
8128a1ca2cdSRiver Riddle if (TypeAttr typeAttr = typeOp.typeAttr()) {
8133a833a0eSRiver Riddle rewriteValues[typeOp] =
8148a1ca2cdSRiver Riddle builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
8153a833a0eSRiver Riddle }
8163a833a0eSRiver Riddle }
8173a833a0eSRiver Riddle
generateRewriter(pdl::TypesOp typeOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)8183a833a0eSRiver Riddle void PatternLowering::generateRewriter(
8193a833a0eSRiver Riddle pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
8203a833a0eSRiver Riddle function_ref<Value(Value)> mapRewriteValue) {
8213a833a0eSRiver Riddle // If the type isn't constant, the users (e.g. OperationOp) will resolve this
8223a833a0eSRiver Riddle // type.
8233a833a0eSRiver Riddle if (ArrayAttr typeAttr = typeOp.typesAttr()) {
8243a833a0eSRiver Riddle rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
8253a833a0eSRiver Riddle typeOp.getLoc(), typeOp.getType(), typeAttr);
8268a1ca2cdSRiver Riddle }
8278a1ca2cdSRiver Riddle }
8288a1ca2cdSRiver Riddle
generateOperationResultTypeRewriter(pdl::OperationOp op,function_ref<Value (Value)> mapRewriteValue,SmallVectorImpl<Value> & types,DenseMap<Value,Value> & rewriteValues,bool & hasInferredResultTypes)8298a1ca2cdSRiver Riddle void PatternLowering::generateOperationResultTypeRewriter(
830*3c752289SRiver Riddle pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
831*3c752289SRiver Riddle SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
832*3c752289SRiver Riddle bool &hasInferredResultTypes) {
8333a833a0eSRiver Riddle // Look for an operation that was replaced by `op`. The result types will be
8343a833a0eSRiver Riddle // inferred from the results that were replaced.
835c4a04059SChristian Sigg Block *rewriterBlock = op->getBlock();
8363a833a0eSRiver Riddle for (OpOperand &use : op.op().getUses()) {
8378a1ca2cdSRiver Riddle // Check that the use corresponds to a ReplaceOp and that it is the
8388a1ca2cdSRiver Riddle // replacement value, not the operation being replaced.
8398a1ca2cdSRiver Riddle pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
8408a1ca2cdSRiver Riddle if (!replOpUser || use.getOperandNumber() == 0)
8413a833a0eSRiver Riddle continue;
8428a1ca2cdSRiver Riddle // Make sure the replaced operation was defined before this one.
8433a833a0eSRiver Riddle Value replOpVal = replOpUser.operation();
8443a833a0eSRiver Riddle Operation *replacedOp = replOpVal.getDefiningOp();
8453a833a0eSRiver Riddle if (replacedOp->getBlock() == rewriterBlock &&
8463a833a0eSRiver Riddle !replacedOp->isBeforeInBlock(op))
8473a833a0eSRiver Riddle continue;
8488a1ca2cdSRiver Riddle
8493a833a0eSRiver Riddle Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
8503a833a0eSRiver Riddle replacedOp->getLoc(), mapRewriteValue(replOpVal));
8513a833a0eSRiver Riddle types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
8523a833a0eSRiver Riddle replacedOp->getLoc(), replacedOpResults));
8533a833a0eSRiver Riddle return;
8543a833a0eSRiver Riddle }
8553a833a0eSRiver Riddle
856*3c752289SRiver Riddle // Try to handle resolution for each of the result types individually. This is
857*3c752289SRiver Riddle // preferred over type inferrence because it will allow for us to use existing
858*3c752289SRiver Riddle // types directly, as opposed to trying to rebuild the type list.
8593a833a0eSRiver Riddle OperandRange resultTypeValues = op.types();
860*3c752289SRiver Riddle auto tryResolveResultTypes = [&] {
8618a1ca2cdSRiver Riddle types.reserve(resultTypeValues.size());
862e4853be2SMehdi Amini for (const auto &it : llvm::enumerate(resultTypeValues)) {
863242762c9SRiver Riddle Value resultType = it.value();
8648a1ca2cdSRiver Riddle
8658a1ca2cdSRiver Riddle // Check for an already translated value.
8668a1ca2cdSRiver Riddle if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
8678a1ca2cdSRiver Riddle types.push_back(existingRewriteValue);
8688a1ca2cdSRiver Riddle continue;
8698a1ca2cdSRiver Riddle }
8708a1ca2cdSRiver Riddle
8718a1ca2cdSRiver Riddle // Check for an input from the matcher.
8728a1ca2cdSRiver Riddle if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
8738a1ca2cdSRiver Riddle types.push_back(mapRewriteValue(resultType));
8748a1ca2cdSRiver Riddle continue;
8758a1ca2cdSRiver Riddle }
8768a1ca2cdSRiver Riddle
877*3c752289SRiver Riddle // Otherwise, we couldn't infer the result types. Bail out here to see if
878*3c752289SRiver Riddle // we can infer the types for this operation from another way.
879*3c752289SRiver Riddle types.clear();
880*3c752289SRiver Riddle return failure();
881*3c752289SRiver Riddle }
882*3c752289SRiver Riddle return success();
883*3c752289SRiver Riddle };
884*3c752289SRiver Riddle if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes()))
885*3c752289SRiver Riddle return;
886*3c752289SRiver Riddle
887*3c752289SRiver Riddle // Otherwise, check if the operation has type inference support itself.
888*3c752289SRiver Riddle if (op.hasTypeInference()) {
889*3c752289SRiver Riddle hasInferredResultTypes = true;
890*3c752289SRiver Riddle return;
891*3c752289SRiver Riddle }
892*3c752289SRiver Riddle
893*3c752289SRiver Riddle // If the types could not be inferred from any context and there weren't any
894*3c752289SRiver Riddle // explicit result types, assume the user actually meant for the operation to
895*3c752289SRiver Riddle // have no results.
896*3c752289SRiver Riddle if (resultTypeValues.empty())
897*3c752289SRiver Riddle return;
898*3c752289SRiver Riddle
8993a833a0eSRiver Riddle // The verifier asserts that the result types of each pdl.operation can be
9003a833a0eSRiver Riddle // inferred. If we reach here, there is a bug either in the logic above or
9013a833a0eSRiver Riddle // in the verifier for pdl.operation.
9023a833a0eSRiver Riddle op->emitOpError() << "unable to infer result type for operation";
9033a833a0eSRiver Riddle llvm_unreachable("unable to infer result type for operation");
9048a1ca2cdSRiver Riddle }
9058a1ca2cdSRiver Riddle
9068a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
9078a1ca2cdSRiver Riddle // Conversion Pass
9088a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
9098a1ca2cdSRiver Riddle
9108a1ca2cdSRiver Riddle namespace {
9118a1ca2cdSRiver Riddle struct PDLToPDLInterpPass
9128a1ca2cdSRiver Riddle : public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
9138a1ca2cdSRiver Riddle void runOnOperation() final;
9148a1ca2cdSRiver Riddle };
9158a1ca2cdSRiver Riddle } // namespace
9168a1ca2cdSRiver Riddle
9178a1ca2cdSRiver Riddle /// Convert the given module containing PDL pattern operations into a PDL
9188a1ca2cdSRiver Riddle /// Interpreter operations.
runOnOperation()9198a1ca2cdSRiver Riddle void PDLToPDLInterpPass::runOnOperation() {
9208a1ca2cdSRiver Riddle ModuleOp module = getOperation();
9218a1ca2cdSRiver Riddle
9228a1ca2cdSRiver Riddle // Create the main matcher function This function contains all of the match
9238a1ca2cdSRiver Riddle // related functionality from patterns in the module.
9248a1ca2cdSRiver Riddle OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
925f96a8675SRiver Riddle auto matcherFunc = builder.create<pdl_interp::FuncOp>(
9268a1ca2cdSRiver Riddle module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
9278a1ca2cdSRiver Riddle builder.getFunctionType(builder.getType<pdl::OperationType>(),
9288a1ca2cdSRiver Riddle /*results=*/llvm::None),
9298a1ca2cdSRiver Riddle /*attrs=*/llvm::None);
9308a1ca2cdSRiver Riddle
9318a1ca2cdSRiver Riddle // Create a nested module to hold the functions invoked for rewriting the IR
9328a1ca2cdSRiver Riddle // after a successful match.
9338a1ca2cdSRiver Riddle ModuleOp rewriterModule = builder.create<ModuleOp>(
9348a1ca2cdSRiver Riddle module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
9358a1ca2cdSRiver Riddle
9368a1ca2cdSRiver Riddle // Generate the code for the patterns within the module.
9378a1ca2cdSRiver Riddle PatternLowering generator(matcherFunc, rewriterModule);
9388a1ca2cdSRiver Riddle generator.lower(module);
9398a1ca2cdSRiver Riddle
9408a1ca2cdSRiver Riddle // After generation, delete all of the pattern operations.
9418a1ca2cdSRiver Riddle for (pdl::PatternOp pattern :
9428a1ca2cdSRiver Riddle llvm::make_early_inc_range(module.getOps<pdl::PatternOp>()))
9438a1ca2cdSRiver Riddle pattern.erase();
9448a1ca2cdSRiver Riddle }
9458a1ca2cdSRiver Riddle
createPDLToPDLInterpPass()9468a1ca2cdSRiver Riddle std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
9478a1ca2cdSRiver Riddle return std::make_unique<PDLToPDLInterpPass>();
9488a1ca2cdSRiver Riddle }
949