18a1ca2cdSRiver Riddle //===- PredicateTree.cpp - Predicate tree merging -------------------------===//
28a1ca2cdSRiver Riddle //
38a1ca2cdSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
48a1ca2cdSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
58a1ca2cdSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68a1ca2cdSRiver Riddle //
78a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
88a1ca2cdSRiver Riddle 
98a1ca2cdSRiver Riddle #include "PredicateTree.h"
10a76ee58fSStanislav Funiak #include "RootOrdering.h"
11a76ee58fSStanislav Funiak 
128a1ca2cdSRiver Riddle #include "mlir/Dialect/PDL/IR/PDL.h"
138a1ca2cdSRiver Riddle #include "mlir/Dialect/PDL/IR/PDLTypes.h"
148a1ca2cdSRiver Riddle #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
1565fcddffSRiver Riddle #include "mlir/IR/BuiltinOps.h"
168a1ca2cdSRiver Riddle #include "mlir/Interfaces/InferTypeOpInterface.h"
17a76ee58fSStanislav Funiak #include "llvm/ADT/MapVector.h"
18242762c9SRiver Riddle #include "llvm/ADT/TypeSwitch.h"
19a76ee58fSStanislav Funiak #include "llvm/Support/Debug.h"
20a76ee58fSStanislav Funiak #include <queue>
21a76ee58fSStanislav Funiak 
22a76ee58fSStanislav Funiak #define DEBUG_TYPE "pdl-predicate-tree"
238a1ca2cdSRiver Riddle 
248a1ca2cdSRiver Riddle using namespace mlir;
258a1ca2cdSRiver Riddle using namespace mlir::pdl_to_pdl_interp;
268a1ca2cdSRiver Riddle 
278a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
288a1ca2cdSRiver Riddle // Predicate List Building
298a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
308a1ca2cdSRiver Riddle 
31242762c9SRiver Riddle static void getTreePredicates(std::vector<PositionalPredicate> &predList,
32242762c9SRiver Riddle                               Value val, PredicateBuilder &builder,
33242762c9SRiver Riddle                               DenseMap<Value, Position *> &inputs,
34242762c9SRiver Riddle                               Position *pos);
35242762c9SRiver Riddle 
368a1ca2cdSRiver Riddle /// Compares the depths of two positions.
comparePosDepth(Position * lhs,Position * rhs)378a1ca2cdSRiver Riddle static bool comparePosDepth(Position *lhs, Position *rhs) {
383a833a0eSRiver Riddle   return lhs->getOperationDepth() < rhs->getOperationDepth();
393a833a0eSRiver Riddle }
403a833a0eSRiver Riddle 
413a833a0eSRiver Riddle /// Returns the number of non-range elements within `values`.
getNumNonRangeValues(ValueRange values)423a833a0eSRiver Riddle static unsigned getNumNonRangeValues(ValueRange values) {
433a833a0eSRiver Riddle   return llvm::count_if(values.getTypes(),
443a833a0eSRiver Riddle                         [](Type type) { return !type.isa<pdl::RangeType>(); });
458a1ca2cdSRiver Riddle }
468a1ca2cdSRiver Riddle 
getTreePredicates(std::vector<PositionalPredicate> & predList,Value val,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs,AttributePosition * pos)478a1ca2cdSRiver Riddle static void getTreePredicates(std::vector<PositionalPredicate> &predList,
488a1ca2cdSRiver Riddle                               Value val, PredicateBuilder &builder,
498a1ca2cdSRiver Riddle                               DenseMap<Value, Position *> &inputs,
50242762c9SRiver Riddle                               AttributePosition *pos) {
51242762c9SRiver Riddle   assert(val.getType().isa<pdl::AttributeType>() && "expected attribute type");
528a1ca2cdSRiver Riddle   pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
538a1ca2cdSRiver Riddle   predList.emplace_back(pos, builder.getIsNotNull());
548a1ca2cdSRiver Riddle 
55242762c9SRiver Riddle   // If the attribute has a type or value, add a constraint.
56242762c9SRiver Riddle   if (Value type = attr.type())
578a1ca2cdSRiver Riddle     getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
58242762c9SRiver Riddle   else if (Attribute value = attr.valueAttr())
59242762c9SRiver Riddle     predList.emplace_back(pos, builder.getAttributeConstraint(value));
60242762c9SRiver Riddle }
618a1ca2cdSRiver Riddle 
623a833a0eSRiver Riddle /// Collect all of the predicates for the given operand position.
getOperandTreePredicates(std::vector<PositionalPredicate> & predList,Value val,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs,Position * pos)633a833a0eSRiver Riddle static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
64242762c9SRiver Riddle                                      Value val, PredicateBuilder &builder,
65242762c9SRiver Riddle                                      DenseMap<Value, Position *> &inputs,
663a833a0eSRiver Riddle                                      Position *pos) {
673a833a0eSRiver Riddle   Type valueType = val.getType();
683a833a0eSRiver Riddle   bool isVariadic = valueType.isa<pdl::RangeType>();
698a1ca2cdSRiver Riddle 
70e07c968aSRiver Riddle   // If this is a typed operand, add a type constraint.
713a833a0eSRiver Riddle   TypeSwitch<Operation *>(val.getDefiningOp())
723a833a0eSRiver Riddle       .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {
733a833a0eSRiver Riddle         // Prevent traversal into a null value if the operand has a proper
743a833a0eSRiver Riddle         // index.
753a833a0eSRiver Riddle         if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
763a833a0eSRiver Riddle             cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
773a833a0eSRiver Riddle           predList.emplace_back(pos, builder.getIsNotNull());
78242762c9SRiver Riddle 
793a833a0eSRiver Riddle         if (Value type = op.type())
803a833a0eSRiver Riddle           getTreePredicates(predList, type, builder, inputs,
813a833a0eSRiver Riddle                             builder.getType(pos));
823a833a0eSRiver Riddle       })
833a833a0eSRiver Riddle       .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
843a833a0eSRiver Riddle         Optional<unsigned> index = op.index();
853a833a0eSRiver Riddle 
863a833a0eSRiver Riddle         // Prevent traversal into a null value if the result has a proper index.
873a833a0eSRiver Riddle         if (index)
883a833a0eSRiver Riddle           predList.emplace_back(pos, builder.getIsNotNull());
893a833a0eSRiver Riddle 
903a833a0eSRiver Riddle         // Get the parent operation of this operand.
913a833a0eSRiver Riddle         OperationPosition *parentPos = builder.getOperandDefiningOp(pos);
92242762c9SRiver Riddle         predList.emplace_back(parentPos, builder.getIsNotNull());
933a833a0eSRiver Riddle 
943a833a0eSRiver Riddle         // Ensure that the operands match the corresponding results of the
953a833a0eSRiver Riddle         // parent operation.
963a833a0eSRiver Riddle         Position *resultPos = nullptr;
973a833a0eSRiver Riddle         if (std::is_same<pdl::ResultOp, decltype(op)>::value)
983a833a0eSRiver Riddle           resultPos = builder.getResult(parentPos, *index);
993a833a0eSRiver Riddle         else
1003a833a0eSRiver Riddle           resultPos = builder.getResultGroup(parentPos, index, isVariadic);
101242762c9SRiver Riddle         predList.emplace_back(resultPos, builder.getEqualTo(pos));
1023a833a0eSRiver Riddle 
1033a833a0eSRiver Riddle         // Collect the predicates of the parent operation.
1041f13963eSRiver Riddle         getTreePredicates(predList, op.parent(), builder, inputs,
1051f13963eSRiver Riddle                           (Position *)parentPos);
1063a833a0eSRiver Riddle       });
1078a1ca2cdSRiver Riddle }
1088a1ca2cdSRiver Riddle 
getTreePredicates(std::vector<PositionalPredicate> & predList,Value val,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs,OperationPosition * pos,Optional<unsigned> ignoreOperand=llvm::None)109242762c9SRiver Riddle static void getTreePredicates(std::vector<PositionalPredicate> &predList,
110242762c9SRiver Riddle                               Value val, PredicateBuilder &builder,
111242762c9SRiver Riddle                               DenseMap<Value, Position *> &inputs,
112a76ee58fSStanislav Funiak                               OperationPosition *pos,
113a76ee58fSStanislav Funiak                               Optional<unsigned> ignoreOperand = llvm::None) {
1148a1ca2cdSRiver Riddle   assert(val.getType().isa<pdl::OperationType>() && "expected operation");
1158a1ca2cdSRiver Riddle   pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
1168a1ca2cdSRiver Riddle   OperationPosition *opPos = cast<OperationPosition>(pos);
1178a1ca2cdSRiver Riddle 
1188a1ca2cdSRiver Riddle   // Ensure getDefiningOp returns a non-null operation.
1198a1ca2cdSRiver Riddle   if (!opPos->isRoot())
1208a1ca2cdSRiver Riddle     predList.emplace_back(pos, builder.getIsNotNull());
1218a1ca2cdSRiver Riddle 
1228a1ca2cdSRiver Riddle   // Check that this is the correct root operation.
1238a1ca2cdSRiver Riddle   if (Optional<StringRef> opName = op.name())
1248a1ca2cdSRiver Riddle     predList.emplace_back(pos, builder.getOperationName(*opName));
1258a1ca2cdSRiver Riddle 
1263a833a0eSRiver Riddle   // Check that the operation has the proper number of operands. If there are
1273a833a0eSRiver Riddle   // any variable length operands, we check a minimum instead of an exact count.
1288a1ca2cdSRiver Riddle   OperandRange operands = op.operands();
1293a833a0eSRiver Riddle   unsigned minOperands = getNumNonRangeValues(operands);
1303a833a0eSRiver Riddle   if (minOperands != operands.size()) {
1313a833a0eSRiver Riddle     if (minOperands)
1323a833a0eSRiver Riddle       predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands));
1333a833a0eSRiver Riddle   } else {
1343a833a0eSRiver Riddle     predList.emplace_back(pos, builder.getOperandCount(minOperands));
1353a833a0eSRiver Riddle   }
1363a833a0eSRiver Riddle 
1373a833a0eSRiver Riddle   // Check that the operation has the proper number of results. If there are
1383a833a0eSRiver Riddle   // any variable length results, we check a minimum instead of an exact count.
139242762c9SRiver Riddle   OperandRange types = op.types();
1403a833a0eSRiver Riddle   unsigned minResults = getNumNonRangeValues(types);
1413a833a0eSRiver Riddle   if (minResults == types.size())
142242762c9SRiver Riddle     predList.emplace_back(pos, builder.getResultCount(types.size()));
1433a833a0eSRiver Riddle   else if (minResults)
1443a833a0eSRiver Riddle     predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));
1458a1ca2cdSRiver Riddle 
1468a1ca2cdSRiver Riddle   // Recurse into any attributes, operands, or results.
1478a1ca2cdSRiver Riddle   for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
1488a1ca2cdSRiver Riddle     getTreePredicates(
1498a1ca2cdSRiver Riddle         predList, std::get<1>(it), builder, inputs,
1508a1ca2cdSRiver Riddle         builder.getAttribute(opPos,
1518a1ca2cdSRiver Riddle                              std::get<0>(it).cast<StringAttr>().getValue()));
1528a1ca2cdSRiver Riddle   }
1533a833a0eSRiver Riddle 
1543a833a0eSRiver Riddle   // Process the operands and results of the operation. For all values up to
1553a833a0eSRiver Riddle   // the first variable length value, we use the concrete operand/result
1563a833a0eSRiver Riddle   // number. After that, we use the "group" given that we can't know the
1573a833a0eSRiver Riddle   // concrete indices until runtime. If there is only one variadic operand
1583a833a0eSRiver Riddle   // group, we treat it as all of the operands/results of the operation.
1593a833a0eSRiver Riddle   /// Operands.
1603a833a0eSRiver Riddle   if (operands.size() == 1 && operands[0].getType().isa<pdl::RangeType>()) {
1612692eae5SStanislav Funiak     // Ignore the operands if we are performing an upward traversal (in that
1622692eae5SStanislav Funiak     // case, they have already been visited).
1632692eae5SStanislav Funiak     if (opPos->isRoot() || opPos->isOperandDefiningOp())
1643a833a0eSRiver Riddle       getTreePredicates(predList, operands.front(), builder, inputs,
1653a833a0eSRiver Riddle                         builder.getAllOperands(opPos));
1663a833a0eSRiver Riddle   } else {
1673a833a0eSRiver Riddle     bool foundVariableLength = false;
168e4853be2SMehdi Amini     for (const auto &operandIt : llvm::enumerate(operands)) {
1693a833a0eSRiver Riddle       bool isVariadic = operandIt.value().getType().isa<pdl::RangeType>();
1703a833a0eSRiver Riddle       foundVariableLength |= isVariadic;
1713a833a0eSRiver Riddle 
172a76ee58fSStanislav Funiak       // Ignore the specified operand, usually because this position was
173a76ee58fSStanislav Funiak       // visited in an upward traversal via an iterative choice.
174a76ee58fSStanislav Funiak       if (ignoreOperand && *ignoreOperand == operandIt.index())
175a76ee58fSStanislav Funiak         continue;
176a76ee58fSStanislav Funiak 
1773a833a0eSRiver Riddle       Position *pos =
1783a833a0eSRiver Riddle           foundVariableLength
1793a833a0eSRiver Riddle               ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)
1803a833a0eSRiver Riddle               : builder.getOperand(opPos, operandIt.index());
1813a833a0eSRiver Riddle       getTreePredicates(predList, operandIt.value(), builder, inputs, pos);
182242762c9SRiver Riddle     }
1833a833a0eSRiver Riddle   }
1843a833a0eSRiver Riddle   /// Results.
1853a833a0eSRiver Riddle   if (types.size() == 1 && types[0].getType().isa<pdl::RangeType>()) {
1863a833a0eSRiver Riddle     getTreePredicates(predList, types.front(), builder, inputs,
1873a833a0eSRiver Riddle                       builder.getType(builder.getAllResults(opPos)));
1883a833a0eSRiver Riddle   } else {
1893a833a0eSRiver Riddle     bool foundVariableLength = false;
190242762c9SRiver Riddle     for (auto &resultIt : llvm::enumerate(types)) {
1913a833a0eSRiver Riddle       bool isVariadic = resultIt.value().getType().isa<pdl::RangeType>();
1923a833a0eSRiver Riddle       foundVariableLength |= isVariadic;
1933a833a0eSRiver Riddle 
1943a833a0eSRiver Riddle       auto *resultPos =
1953a833a0eSRiver Riddle           foundVariableLength
1963a833a0eSRiver Riddle               ? builder.getResultGroup(pos, resultIt.index(), isVariadic)
1973a833a0eSRiver Riddle               : builder.getResult(pos, resultIt.index());
198242762c9SRiver Riddle       predList.emplace_back(resultPos, builder.getIsNotNull());
1998a1ca2cdSRiver Riddle       getTreePredicates(predList, resultIt.value(), builder, inputs,
200242762c9SRiver Riddle                         builder.getType(resultPos));
2018a1ca2cdSRiver Riddle     }
2028a1ca2cdSRiver Riddle   }
2033a833a0eSRiver Riddle }
2048a1ca2cdSRiver Riddle 
getTreePredicates(std::vector<PositionalPredicate> & predList,Value val,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs,TypePosition * pos)205242762c9SRiver Riddle static void getTreePredicates(std::vector<PositionalPredicate> &predList,
206242762c9SRiver Riddle                               Value val, PredicateBuilder &builder,
207242762c9SRiver Riddle                               DenseMap<Value, Position *> &inputs,
208242762c9SRiver Riddle                               TypePosition *pos) {
2098a1ca2cdSRiver Riddle   // Check for a constraint on a constant type.
2103a833a0eSRiver Riddle   if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
2113a833a0eSRiver Riddle     if (Attribute type = typeOp.typeAttr())
2123a833a0eSRiver Riddle       predList.emplace_back(pos, builder.getTypeConstraint(type));
2133a833a0eSRiver Riddle   } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
2143a833a0eSRiver Riddle     if (Attribute typeAttr = typeOp.typesAttr())
2153a833a0eSRiver Riddle       predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
2163a833a0eSRiver Riddle   }
2178a1ca2cdSRiver Riddle }
218242762c9SRiver Riddle 
219242762c9SRiver Riddle /// Collect the tree predicates anchored at the given value.
getTreePredicates(std::vector<PositionalPredicate> & predList,Value val,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs,Position * pos)220242762c9SRiver Riddle static void getTreePredicates(std::vector<PositionalPredicate> &predList,
221242762c9SRiver Riddle                               Value val, PredicateBuilder &builder,
222242762c9SRiver Riddle                               DenseMap<Value, Position *> &inputs,
223242762c9SRiver Riddle                               Position *pos) {
224242762c9SRiver Riddle   // Make sure this input value is accessible to the rewrite.
225242762c9SRiver Riddle   auto it = inputs.try_emplace(val, pos);
226242762c9SRiver Riddle   if (!it.second) {
227242762c9SRiver Riddle     // If this is an input value that has been visited in the tree, add a
228242762c9SRiver Riddle     // constraint to ensure that both instances refer to the same value.
2293a833a0eSRiver Riddle     if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
2303a833a0eSRiver Riddle             pdl::TypeOp>(val.getDefiningOp())) {
231242762c9SRiver Riddle       auto minMaxPositions =
232242762c9SRiver Riddle           std::minmax(pos, it.first->second, comparePosDepth);
233242762c9SRiver Riddle       predList.emplace_back(minMaxPositions.second,
234242762c9SRiver Riddle                             builder.getEqualTo(minMaxPositions.first));
2358a1ca2cdSRiver Riddle     }
236242762c9SRiver Riddle     return;
237242762c9SRiver Riddle   }
238242762c9SRiver Riddle 
239242762c9SRiver Riddle   TypeSwitch<Position *>(pos)
2403a833a0eSRiver Riddle       .Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) {
2413a833a0eSRiver Riddle         getTreePredicates(predList, val, builder, inputs, pos);
2423a833a0eSRiver Riddle       })
2433a833a0eSRiver Riddle       .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) {
2443a833a0eSRiver Riddle         getOperandTreePredicates(predList, val, builder, inputs, pos);
245242762c9SRiver Riddle       })
246242762c9SRiver Riddle       .Default([](auto *) { llvm_unreachable("unexpected position kind"); });
2478a1ca2cdSRiver Riddle }
2488a1ca2cdSRiver Riddle 
getAttributePredicates(pdl::AttributeOp op,std::vector<PositionalPredicate> & predList,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs)249233e9476SRiver Riddle static void getAttributePredicates(pdl::AttributeOp op,
250233e9476SRiver Riddle                                    std::vector<PositionalPredicate> &predList,
251233e9476SRiver Riddle                                    PredicateBuilder &builder,
252233e9476SRiver Riddle                                    DenseMap<Value, Position *> &inputs) {
253233e9476SRiver Riddle   Position *&attrPos = inputs[op];
254233e9476SRiver Riddle   if (attrPos)
255233e9476SRiver Riddle     return;
256233e9476SRiver Riddle   Attribute value = op.valueAttr();
257233e9476SRiver Riddle   assert(value && "expected non-tree `pdl.attribute` to contain a value");
258233e9476SRiver Riddle   attrPos = builder.getAttributeLiteral(value);
259233e9476SRiver Riddle }
260233e9476SRiver Riddle 
getConstraintPredicates(pdl::ApplyNativeConstraintOp op,std::vector<PositionalPredicate> & predList,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs)26102c4c0d5SRiver Riddle static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
262242762c9SRiver Riddle                                     std::vector<PositionalPredicate> &predList,
263242762c9SRiver Riddle                                     PredicateBuilder &builder,
264242762c9SRiver Riddle                                     DenseMap<Value, Position *> &inputs) {
2658a1ca2cdSRiver Riddle   OperandRange arguments = op.args();
2668a1ca2cdSRiver Riddle 
2678a1ca2cdSRiver Riddle   std::vector<Position *> allPositions;
2688a1ca2cdSRiver Riddle   allPositions.reserve(arguments.size());
2698a1ca2cdSRiver Riddle   for (Value arg : arguments)
2708a1ca2cdSRiver Riddle     allPositions.push_back(inputs.lookup(arg));
2718a1ca2cdSRiver Riddle 
2728a1ca2cdSRiver Riddle   // Push the constraint to the furthest position.
2738a1ca2cdSRiver Riddle   Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
2748a1ca2cdSRiver Riddle                                     comparePosDepth);
2758a1ca2cdSRiver Riddle   PredicateBuilder::Predicate pred =
276*9595f356SRiver Riddle       builder.getConstraint(op.name(), allPositions);
2778a1ca2cdSRiver Riddle   predList.emplace_back(pos, pred);
2788a1ca2cdSRiver Riddle }
279242762c9SRiver Riddle 
getResultPredicates(pdl::ResultOp op,std::vector<PositionalPredicate> & predList,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs)280242762c9SRiver Riddle static void getResultPredicates(pdl::ResultOp op,
281242762c9SRiver Riddle                                 std::vector<PositionalPredicate> &predList,
282242762c9SRiver Riddle                                 PredicateBuilder &builder,
283242762c9SRiver Riddle                                 DenseMap<Value, Position *> &inputs) {
284242762c9SRiver Riddle   Position *&resultPos = inputs[op];
285242762c9SRiver Riddle   if (resultPos)
286242762c9SRiver Riddle     return;
2873a833a0eSRiver Riddle 
2883a833a0eSRiver Riddle   // Ensure that the result isn't null.
289242762c9SRiver Riddle   auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
290242762c9SRiver Riddle   resultPos = builder.getResult(parentPos, op.index());
291242762c9SRiver Riddle   predList.emplace_back(resultPos, builder.getIsNotNull());
292242762c9SRiver Riddle }
293242762c9SRiver Riddle 
getResultPredicates(pdl::ResultsOp op,std::vector<PositionalPredicate> & predList,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs)2943a833a0eSRiver Riddle static void getResultPredicates(pdl::ResultsOp op,
2953a833a0eSRiver Riddle                                 std::vector<PositionalPredicate> &predList,
2963a833a0eSRiver Riddle                                 PredicateBuilder &builder,
2973a833a0eSRiver Riddle                                 DenseMap<Value, Position *> &inputs) {
2983a833a0eSRiver Riddle   Position *&resultPos = inputs[op];
2993a833a0eSRiver Riddle   if (resultPos)
3003a833a0eSRiver Riddle     return;
3013a833a0eSRiver Riddle 
3023a833a0eSRiver Riddle   // Ensure that the result isn't null if the result has an index.
3033a833a0eSRiver Riddle   auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
3043a833a0eSRiver Riddle   bool isVariadic = op.getType().isa<pdl::RangeType>();
3053a833a0eSRiver Riddle   Optional<unsigned> index = op.index();
3063a833a0eSRiver Riddle   resultPos = builder.getResultGroup(parentPos, index, isVariadic);
3073a833a0eSRiver Riddle   if (index)
3083a833a0eSRiver Riddle     predList.emplace_back(resultPos, builder.getIsNotNull());
3093a833a0eSRiver Riddle }
3103a833a0eSRiver Riddle 
getTypePredicates(Value typeValue,function_ref<Attribute ()> typeAttrFn,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs)311233e9476SRiver Riddle static void getTypePredicates(Value typeValue,
312233e9476SRiver Riddle                               function_ref<Attribute()> typeAttrFn,
313233e9476SRiver Riddle                               PredicateBuilder &builder,
314233e9476SRiver Riddle                               DenseMap<Value, Position *> &inputs) {
315233e9476SRiver Riddle   Position *&typePos = inputs[typeValue];
316233e9476SRiver Riddle   if (typePos)
317233e9476SRiver Riddle     return;
318233e9476SRiver Riddle   Attribute typeAttr = typeAttrFn();
319233e9476SRiver Riddle   assert(typeAttr &&
320233e9476SRiver Riddle          "expected non-tree `pdl.type`/`pdl.types` to contain a value");
321233e9476SRiver Riddle   typePos = builder.getTypeLiteral(typeAttr);
322233e9476SRiver Riddle }
323233e9476SRiver Riddle 
324242762c9SRiver Riddle /// Collect all of the predicates that cannot be determined via walking the
325242762c9SRiver Riddle /// tree.
getNonTreePredicates(pdl::PatternOp pattern,std::vector<PositionalPredicate> & predList,PredicateBuilder & builder,DenseMap<Value,Position * > & inputs)326242762c9SRiver Riddle static void getNonTreePredicates(pdl::PatternOp pattern,
327242762c9SRiver Riddle                                  std::vector<PositionalPredicate> &predList,
328242762c9SRiver Riddle                                  PredicateBuilder &builder,
329242762c9SRiver Riddle                                  DenseMap<Value, Position *> &inputs) {
330242762c9SRiver Riddle   for (Operation &op : pattern.body().getOps()) {
3313a833a0eSRiver Riddle     TypeSwitch<Operation *>(&op)
332233e9476SRiver Riddle         .Case([&](pdl::AttributeOp attrOp) {
333233e9476SRiver Riddle           getAttributePredicates(attrOp, predList, builder, inputs);
334233e9476SRiver Riddle         })
3353a833a0eSRiver Riddle         .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
336242762c9SRiver Riddle           getConstraintPredicates(constraintOp, predList, builder, inputs);
3373a833a0eSRiver Riddle         })
3383a833a0eSRiver Riddle         .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
339242762c9SRiver Riddle           getResultPredicates(resultOp, predList, builder, inputs);
340233e9476SRiver Riddle         })
341233e9476SRiver Riddle         .Case([&](pdl::TypeOp typeOp) {
342233e9476SRiver Riddle           getTypePredicates(
343233e9476SRiver Riddle               typeOp, [&] { return typeOp.typeAttr(); }, builder, inputs);
344233e9476SRiver Riddle         })
345233e9476SRiver Riddle         .Case([&](pdl::TypesOp typeOp) {
346233e9476SRiver Riddle           getTypePredicates(
347233e9476SRiver Riddle               typeOp, [&] { return typeOp.typesAttr(); }, builder, inputs);
3483a833a0eSRiver Riddle         });
349242762c9SRiver Riddle   }
3508a1ca2cdSRiver Riddle }
3518a1ca2cdSRiver Riddle 
352a76ee58fSStanislav Funiak namespace {
353a76ee58fSStanislav Funiak 
354a76ee58fSStanislav Funiak /// An op accepting a value at an optional index.
355a76ee58fSStanislav Funiak struct OpIndex {
356a76ee58fSStanislav Funiak   Value parent;
357a76ee58fSStanislav Funiak   Optional<unsigned> index;
358a76ee58fSStanislav Funiak };
359a76ee58fSStanislav Funiak 
360a76ee58fSStanislav Funiak /// The parent and operand index of each operation for each root, stored
361a76ee58fSStanislav Funiak /// as a nested map [root][operation].
362a76ee58fSStanislav Funiak using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>;
363a76ee58fSStanislav Funiak 
364a76ee58fSStanislav Funiak } // namespace
365a76ee58fSStanislav Funiak 
366a76ee58fSStanislav Funiak /// Given a pattern, determines the set of roots present in this pattern.
367a76ee58fSStanislav Funiak /// These are the operations whose results are not consumed by other operations.
detectRoots(pdl::PatternOp pattern)368a76ee58fSStanislav Funiak static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
369a76ee58fSStanislav Funiak   // First, collect all the operations that are used as operands
370a76ee58fSStanislav Funiak   // to other operations. These are not roots by default.
371a76ee58fSStanislav Funiak   DenseSet<Value> used;
372a76ee58fSStanislav Funiak   for (auto operationOp : pattern.body().getOps<pdl::OperationOp>()) {
373a76ee58fSStanislav Funiak     for (Value operand : operationOp.operands())
374a76ee58fSStanislav Funiak       TypeSwitch<Operation *>(operand.getDefiningOp())
375a76ee58fSStanislav Funiak           .Case<pdl::ResultOp, pdl::ResultsOp>(
376a76ee58fSStanislav Funiak               [&used](auto resultOp) { used.insert(resultOp.parent()); });
377a76ee58fSStanislav Funiak   }
378a76ee58fSStanislav Funiak 
379a76ee58fSStanislav Funiak   // Remove the specified root from the use set, so that we can
380a76ee58fSStanislav Funiak   // always select it as a root, even if it is used by other operations.
381a76ee58fSStanislav Funiak   if (Value root = pattern.getRewriter().root())
382a76ee58fSStanislav Funiak     used.erase(root);
383a76ee58fSStanislav Funiak 
384a76ee58fSStanislav Funiak   // Finally, collect all the unused operations.
385a76ee58fSStanislav Funiak   SmallVector<Value> roots;
386a76ee58fSStanislav Funiak   for (Value operationOp : pattern.body().getOps<pdl::OperationOp>())
387a76ee58fSStanislav Funiak     if (!used.contains(operationOp))
388a76ee58fSStanislav Funiak       roots.push_back(operationOp);
389a76ee58fSStanislav Funiak 
390a76ee58fSStanislav Funiak   return roots;
391a76ee58fSStanislav Funiak }
392a76ee58fSStanislav Funiak 
393a76ee58fSStanislav Funiak /// Given a list of candidate roots, builds the cost graph for connecting them.
394a76ee58fSStanislav Funiak /// The graph is formed by traversing the DAG of operations starting from each
395a76ee58fSStanislav Funiak /// root and marking the depth of each connector value (operand). Then we join
396a76ee58fSStanislav Funiak /// the candidate roots based on the common connector values, taking the one
397a76ee58fSStanislav Funiak /// with the minimum depth. Along the way, we compute, for each candidate root,
398a76ee58fSStanislav Funiak /// a mapping from each operation (in the DAG underneath this root) to its
399a76ee58fSStanislav Funiak /// parent operation and the corresponding operand index.
buildCostGraph(ArrayRef<Value> roots,RootOrderingGraph & graph,ParentMaps & parentMaps)400a76ee58fSStanislav Funiak static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
401a76ee58fSStanislav Funiak                            ParentMaps &parentMaps) {
402a76ee58fSStanislav Funiak 
403a76ee58fSStanislav Funiak   // The entry of a queue. The entry consists of the following items:
404a76ee58fSStanislav Funiak   // * the value in the DAG underneath the root;
405a76ee58fSStanislav Funiak   // * the parent of the value;
406a76ee58fSStanislav Funiak   // * the operand index of the value in its parent;
407a76ee58fSStanislav Funiak   // * the depth of the visited value.
408a76ee58fSStanislav Funiak   struct Entry {
409a76ee58fSStanislav Funiak     Entry(Value value, Value parent, Optional<unsigned> index, unsigned depth)
410a76ee58fSStanislav Funiak         : value(value), parent(parent), index(index), depth(depth) {}
411a76ee58fSStanislav Funiak 
412a76ee58fSStanislav Funiak     Value value;
413a76ee58fSStanislav Funiak     Value parent;
414a76ee58fSStanislav Funiak     Optional<unsigned> index;
415a76ee58fSStanislav Funiak     unsigned depth;
416a76ee58fSStanislav Funiak   };
417a76ee58fSStanislav Funiak 
418a76ee58fSStanislav Funiak   // A root of a value and its depth (distance from root to the value).
419a76ee58fSStanislav Funiak   struct RootDepth {
420a76ee58fSStanislav Funiak     Value root;
421a76ee58fSStanislav Funiak     unsigned depth = 0;
422a76ee58fSStanislav Funiak   };
423a76ee58fSStanislav Funiak 
424a76ee58fSStanislav Funiak   // Map from candidate connector values to their roots and depths. Using a
425a76ee58fSStanislav Funiak   // small vector with 1 entry because most values belong to a single root.
426a76ee58fSStanislav Funiak   llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;
427a76ee58fSStanislav Funiak 
428a76ee58fSStanislav Funiak   // Perform a breadth-first traversal of the op DAG rooted at each root.
429a76ee58fSStanislav Funiak   for (Value root : roots) {
430a76ee58fSStanislav Funiak     // The queue of visited values. A value may be present multiple times in
431a76ee58fSStanislav Funiak     // the queue, for multiple parents. We only accept the first occurrence,
432a76ee58fSStanislav Funiak     // which is guaranteed to have the lowest depth.
433a76ee58fSStanislav Funiak     std::queue<Entry> toVisit;
434a76ee58fSStanislav Funiak     toVisit.emplace(root, Value(), 0, 0);
435a76ee58fSStanislav Funiak 
436a76ee58fSStanislav Funiak     // The map from value to its parent for the current root.
437a76ee58fSStanislav Funiak     DenseMap<Value, OpIndex> &parentMap = parentMaps[root];
438a76ee58fSStanislav Funiak 
439a76ee58fSStanislav Funiak     while (!toVisit.empty()) {
440a76ee58fSStanislav Funiak       Entry entry = toVisit.front();
441a76ee58fSStanislav Funiak       toVisit.pop();
442a76ee58fSStanislav Funiak       // Skip if already visited.
443a76ee58fSStanislav Funiak       if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second)
444a76ee58fSStanislav Funiak         continue;
445a76ee58fSStanislav Funiak 
446a76ee58fSStanislav Funiak       // Mark the root and depth of the value.
447a76ee58fSStanislav Funiak       connectorsRootsDepths[entry.value].push_back({root, entry.depth});
448a76ee58fSStanislav Funiak 
449a76ee58fSStanislav Funiak       // Traverse the operands of an operation and result ops.
450a76ee58fSStanislav Funiak       // We intentionally do not traverse attributes and types, because those
451a76ee58fSStanislav Funiak       // are expensive to join on.
452a76ee58fSStanislav Funiak       TypeSwitch<Operation *>(entry.value.getDefiningOp())
453a76ee58fSStanislav Funiak           .Case<pdl::OperationOp>([&](auto operationOp) {
454a76ee58fSStanislav Funiak             OperandRange operands = operationOp.operands();
455a76ee58fSStanislav Funiak             // Special case when we pass all the operands in one range.
456a76ee58fSStanislav Funiak             // For those, the index is empty.
457a76ee58fSStanislav Funiak             if (operands.size() == 1 &&
458a76ee58fSStanislav Funiak                 operands[0].getType().isa<pdl::RangeType>()) {
459a76ee58fSStanislav Funiak               toVisit.emplace(operands[0], entry.value, llvm::None,
460a76ee58fSStanislav Funiak                               entry.depth + 1);
461a76ee58fSStanislav Funiak               return;
462a76ee58fSStanislav Funiak             }
463a76ee58fSStanislav Funiak 
464a76ee58fSStanislav Funiak             // Default case: visit all the operands.
465e4853be2SMehdi Amini             for (const auto &p : llvm::enumerate(operationOp.operands()))
466a76ee58fSStanislav Funiak               toVisit.emplace(p.value(), entry.value, p.index(),
467a76ee58fSStanislav Funiak                               entry.depth + 1);
468a76ee58fSStanislav Funiak           })
469a76ee58fSStanislav Funiak           .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
470a76ee58fSStanislav Funiak             toVisit.emplace(resultOp.parent(), entry.value, resultOp.index(),
471a76ee58fSStanislav Funiak                             entry.depth);
472a76ee58fSStanislav Funiak           });
473a76ee58fSStanislav Funiak     }
474a76ee58fSStanislav Funiak   }
475a76ee58fSStanislav Funiak 
476a76ee58fSStanislav Funiak   // Now build the cost graph.
477a76ee58fSStanislav Funiak   // This is simply a minimum over all depths for the target root.
478a76ee58fSStanislav Funiak   unsigned nextID = 0;
479a76ee58fSStanislav Funiak   for (const auto &connectorRootsDepths : connectorsRootsDepths) {
480a76ee58fSStanislav Funiak     Value value = connectorRootsDepths.first;
481a76ee58fSStanislav Funiak     ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second;
482a76ee58fSStanislav Funiak     // If there is only one root for this value, this will not trigger
483a76ee58fSStanislav Funiak     // any edges in the cost graph (a perf optimization).
484a76ee58fSStanislav Funiak     if (rootsDepths.size() == 1)
485a76ee58fSStanislav Funiak       continue;
486a76ee58fSStanislav Funiak 
487a76ee58fSStanislav Funiak     for (const RootDepth &p : rootsDepths) {
488a76ee58fSStanislav Funiak       for (const RootDepth &q : rootsDepths) {
489a76ee58fSStanislav Funiak         if (&p == &q)
490a76ee58fSStanislav Funiak           continue;
491a76ee58fSStanislav Funiak         // Insert or retrieve the property of edge from p to q.
4929eb8e7b1SStanislav Funiak         RootOrderingEntry &entry = graph[q.root][p.root];
4939eb8e7b1SStanislav Funiak         if (!entry.connector /* new edge */ || entry.cost.first > q.depth) {
4949eb8e7b1SStanislav Funiak           if (!entry.connector)
4959eb8e7b1SStanislav Funiak             entry.cost.second = nextID++;
4969eb8e7b1SStanislav Funiak           entry.cost.first = q.depth;
4979eb8e7b1SStanislav Funiak           entry.connector = value;
498a76ee58fSStanislav Funiak         }
499a76ee58fSStanislav Funiak       }
500a76ee58fSStanislav Funiak     }
501a76ee58fSStanislav Funiak   }
502a76ee58fSStanislav Funiak 
503a76ee58fSStanislav Funiak   assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
504a76ee58fSStanislav Funiak          "the pattern contains a candidate root disconnected from the others");
505a76ee58fSStanislav Funiak }
506a76ee58fSStanislav Funiak 
5072692eae5SStanislav Funiak /// Returns true if the operand at the given index needs to be queried using an
5082692eae5SStanislav Funiak /// operand group, i.e., if it is variadic itself or follows a variadic operand.
useOperandGroup(pdl::OperationOp op,unsigned index)5092692eae5SStanislav Funiak static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
5102692eae5SStanislav Funiak   OperandRange operands = op.operands();
5112692eae5SStanislav Funiak   assert(index < operands.size() && "operand index out of range");
5122692eae5SStanislav Funiak   for (unsigned i = 0; i <= index; ++i)
5132692eae5SStanislav Funiak     if (operands[i].getType().isa<pdl::RangeType>())
5142692eae5SStanislav Funiak       return true;
5152692eae5SStanislav Funiak   return false;
5162692eae5SStanislav Funiak }
5172692eae5SStanislav Funiak 
518a76ee58fSStanislav Funiak /// Visit a node during upward traversal.
visitUpward(std::vector<PositionalPredicate> & predList,OpIndex opIndex,PredicateBuilder & builder,DenseMap<Value,Position * > & valueToPosition,Position * & pos,unsigned rootID)5192692eae5SStanislav Funiak static void visitUpward(std::vector<PositionalPredicate> &predList,
5202692eae5SStanislav Funiak                         OpIndex opIndex, PredicateBuilder &builder,
5212692eae5SStanislav Funiak                         DenseMap<Value, Position *> &valueToPosition,
5222692eae5SStanislav Funiak                         Position *&pos, unsigned rootID) {
523a76ee58fSStanislav Funiak   Value value = opIndex.parent;
524a76ee58fSStanislav Funiak   TypeSwitch<Operation *>(value.getDefiningOp())
525a76ee58fSStanislav Funiak       .Case<pdl::OperationOp>([&](auto operationOp) {
526a76ee58fSStanislav Funiak         LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
527a76ee58fSStanislav Funiak 
5282692eae5SStanislav Funiak         // Get users and iterate over them.
5292692eae5SStanislav Funiak         Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true);
5302692eae5SStanislav Funiak         Position *foreachPos = builder.getForEach(usersPos, rootID);
5312692eae5SStanislav Funiak         OperationPosition *opPos = builder.getPassthroughOp(foreachPos);
5322692eae5SStanislav Funiak 
5332692eae5SStanislav Funiak         // Compare the operand(s) of the user against the input value(s).
5342692eae5SStanislav Funiak         Position *operandPos;
5352692eae5SStanislav Funiak         if (!opIndex.index) {
5362692eae5SStanislav Funiak           // We are querying all the operands of the operation.
5372692eae5SStanislav Funiak           operandPos = builder.getAllOperands(opPos);
5382692eae5SStanislav Funiak         } else if (useOperandGroup(operationOp, *opIndex.index)) {
5392692eae5SStanislav Funiak           // We are querying an operand group.
5402692eae5SStanislav Funiak           Type type = operationOp.operands()[*opIndex.index].getType();
5412692eae5SStanislav Funiak           bool variadic = type.isa<pdl::RangeType>();
5422692eae5SStanislav Funiak           operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
5432692eae5SStanislav Funiak         } else {
5442692eae5SStanislav Funiak           // We are querying an individual operand.
5452692eae5SStanislav Funiak           operandPos = builder.getOperand(opPos, *opIndex.index);
546a76ee58fSStanislav Funiak         }
5472692eae5SStanislav Funiak         predList.emplace_back(operandPos, builder.getEqualTo(pos));
548a76ee58fSStanislav Funiak 
549a76ee58fSStanislav Funiak         // Guard against duplicate upward visits. These are not possible,
550a76ee58fSStanislav Funiak         // because if this value was already visited, it would have been
551a76ee58fSStanislav Funiak         // cheaper to start the traversal at this value rather than at the
552a76ee58fSStanislav Funiak         // `connector`, violating the optimality of our spanning tree.
553a76ee58fSStanislav Funiak         bool inserted = valueToPosition.try_emplace(value, opPos).second;
5541b0312d2SBenjamin Kramer         (void)inserted;
555a76ee58fSStanislav Funiak         assert(inserted && "duplicate upward visit");
556a76ee58fSStanislav Funiak 
557a76ee58fSStanislav Funiak         // Obtain the tree predicates at the current value.
558a76ee58fSStanislav Funiak         getTreePredicates(predList, value, builder, valueToPosition, opPos,
559a76ee58fSStanislav Funiak                           opIndex.index);
560a76ee58fSStanislav Funiak 
561a76ee58fSStanislav Funiak         // Update the position
562a76ee58fSStanislav Funiak         pos = opPos;
563a76ee58fSStanislav Funiak       })
564a76ee58fSStanislav Funiak       .Case<pdl::ResultOp>([&](auto resultOp) {
565a76ee58fSStanislav Funiak         // Traverse up an individual result.
566a76ee58fSStanislav Funiak         auto *opPos = dyn_cast<OperationPosition>(pos);
567a76ee58fSStanislav Funiak         assert(opPos && "operations and results must be interleaved");
568a76ee58fSStanislav Funiak         pos = builder.getResult(opPos, *opIndex.index);
5692692eae5SStanislav Funiak 
5702692eae5SStanislav Funiak         // Insert the result position in case we have not visited it yet.
5712692eae5SStanislav Funiak         valueToPosition.try_emplace(value, pos);
572a76ee58fSStanislav Funiak       })
573a76ee58fSStanislav Funiak       .Case<pdl::ResultsOp>([&](auto resultOp) {
574a76ee58fSStanislav Funiak         // Traverse up a group of results.
575a76ee58fSStanislav Funiak         auto *opPos = dyn_cast<OperationPosition>(pos);
576a76ee58fSStanislav Funiak         assert(opPos && "operations and results must be interleaved");
577a76ee58fSStanislav Funiak         bool isVariadic = value.getType().isa<pdl::RangeType>();
578a76ee58fSStanislav Funiak         if (opIndex.index)
579a76ee58fSStanislav Funiak           pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
580a76ee58fSStanislav Funiak         else
581a76ee58fSStanislav Funiak           pos = builder.getAllResults(opPos);
5822692eae5SStanislav Funiak 
5832692eae5SStanislav Funiak         // Insert the result position in case we have not visited it yet.
5842692eae5SStanislav Funiak         valueToPosition.try_emplace(value, pos);
585a76ee58fSStanislav Funiak       });
586a76ee58fSStanislav Funiak }
587a76ee58fSStanislav Funiak 
5888a1ca2cdSRiver Riddle /// Given a pattern operation, build the set of matcher predicates necessary to
5898a1ca2cdSRiver Riddle /// match this pattern.
buildPredicateList(pdl::PatternOp pattern,PredicateBuilder & builder,std::vector<PositionalPredicate> & predList,DenseMap<Value,Position * > & valueToPosition)590a76ee58fSStanislav Funiak static Value buildPredicateList(pdl::PatternOp pattern,
5918a1ca2cdSRiver Riddle                                 PredicateBuilder &builder,
5928a1ca2cdSRiver Riddle                                 std::vector<PositionalPredicate> &predList,
5938a1ca2cdSRiver Riddle                                 DenseMap<Value, Position *> &valueToPosition) {
594a76ee58fSStanislav Funiak   SmallVector<Value> roots = detectRoots(pattern);
595a76ee58fSStanislav Funiak 
596a76ee58fSStanislav Funiak   // Build the root ordering graph and compute the parent maps.
597a76ee58fSStanislav Funiak   RootOrderingGraph graph;
598a76ee58fSStanislav Funiak   ParentMaps parentMaps;
599a76ee58fSStanislav Funiak   buildCostGraph(roots, graph, parentMaps);
600a76ee58fSStanislav Funiak   LLVM_DEBUG({
601a76ee58fSStanislav Funiak     llvm::dbgs() << "Graph:\n";
602a76ee58fSStanislav Funiak     for (auto &target : graph) {
6032692eae5SStanislav Funiak       llvm::dbgs() << "  * " << target.first.getLoc() << " " << target.first
6042692eae5SStanislav Funiak                    << "\n";
605a76ee58fSStanislav Funiak       for (auto &source : target.second) {
6069eb8e7b1SStanislav Funiak         RootOrderingEntry &entry = source.second;
6079eb8e7b1SStanislav Funiak         llvm::dbgs() << "      <- " << source.first << ": " << entry.cost.first
6089eb8e7b1SStanislav Funiak                      << ":" << entry.cost.second << " via "
6099eb8e7b1SStanislav Funiak                      << entry.connector.getLoc() << "\n";
610a76ee58fSStanislav Funiak       }
611a76ee58fSStanislav Funiak     }
612a76ee58fSStanislav Funiak   });
613a76ee58fSStanislav Funiak 
614a76ee58fSStanislav Funiak   // Solve the optimal branching problem for each candidate root, or use the
615a76ee58fSStanislav Funiak   // provided one.
616a76ee58fSStanislav Funiak   Value bestRoot = pattern.getRewriter().root();
617a76ee58fSStanislav Funiak   OptimalBranching::EdgeList bestEdges;
618a76ee58fSStanislav Funiak   if (!bestRoot) {
619a76ee58fSStanislav Funiak     unsigned bestCost = 0;
620a76ee58fSStanislav Funiak     LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n");
621a76ee58fSStanislav Funiak     for (Value root : roots) {
622a76ee58fSStanislav Funiak       OptimalBranching solver(graph, root);
623a76ee58fSStanislav Funiak       unsigned cost = solver.solve();
624a76ee58fSStanislav Funiak       LLVM_DEBUG(llvm::dbgs() << "  * " << root << ": " << cost << "\n");
625a76ee58fSStanislav Funiak       if (!bestRoot || bestCost > cost) {
626a76ee58fSStanislav Funiak         bestCost = cost;
627a76ee58fSStanislav Funiak         bestRoot = root;
628a76ee58fSStanislav Funiak         bestEdges = solver.preOrderTraversal(roots);
629a76ee58fSStanislav Funiak       }
630a76ee58fSStanislav Funiak     }
631a76ee58fSStanislav Funiak   } else {
632a76ee58fSStanislav Funiak     OptimalBranching solver(graph, bestRoot);
633a76ee58fSStanislav Funiak     solver.solve();
634a76ee58fSStanislav Funiak     bestEdges = solver.preOrderTraversal(roots);
635a76ee58fSStanislav Funiak   }
636a76ee58fSStanislav Funiak 
6372692eae5SStanislav Funiak   // Print the best solution.
6382692eae5SStanislav Funiak   LLVM_DEBUG({
6392692eae5SStanislav Funiak     llvm::dbgs() << "Best tree:\n";
6402692eae5SStanislav Funiak     for (const std::pair<Value, Value> &edge : bestEdges) {
6412692eae5SStanislav Funiak       llvm::dbgs() << "  * " << edge.first;
6422692eae5SStanislav Funiak       if (edge.second)
6432692eae5SStanislav Funiak         llvm::dbgs() << " <- " << edge.second;
6442692eae5SStanislav Funiak       llvm::dbgs() << "\n";
6452692eae5SStanislav Funiak     }
6462692eae5SStanislav Funiak   });
6472692eae5SStanislav Funiak 
648a76ee58fSStanislav Funiak   LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
649a76ee58fSStanislav Funiak   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << bestRoot << "\n");
650a76ee58fSStanislav Funiak 
651a76ee58fSStanislav Funiak   // The best root is the starting point for the traversal. Get the tree
652a76ee58fSStanislav Funiak   // predicates for the DAG rooted at bestRoot.
653a76ee58fSStanislav Funiak   getTreePredicates(predList, bestRoot, builder, valueToPosition,
654a76ee58fSStanislav Funiak                     builder.getRoot());
655a76ee58fSStanislav Funiak 
656a76ee58fSStanislav Funiak   // Traverse the selected optimal branching. For all edges in order, traverse
657a76ee58fSStanislav Funiak   // up starting from the connector, until the candidate root is reached, and
658a76ee58fSStanislav Funiak   // call getTreePredicates at every node along the way.
65950da0134SAdrian Kuegel   for (const auto &it : llvm::enumerate(bestEdges)) {
6602692eae5SStanislav Funiak     Value target = it.value().first;
6612692eae5SStanislav Funiak     Value source = it.value().second;
662a76ee58fSStanislav Funiak 
663a76ee58fSStanislav Funiak     // Check if we already visited the target root. This happens in two cases:
664a76ee58fSStanislav Funiak     // 1) the initial root (bestRoot);
665a76ee58fSStanislav Funiak     // 2) a root that is dominated by (contained in the subtree rooted at) an
666a76ee58fSStanislav Funiak     //    already visited root.
667a76ee58fSStanislav Funiak     if (valueToPosition.count(target))
668a76ee58fSStanislav Funiak       continue;
669a76ee58fSStanislav Funiak 
670a76ee58fSStanislav Funiak     // Determine the connector.
671a76ee58fSStanislav Funiak     Value connector = graph[target][source].connector;
672a76ee58fSStanislav Funiak     assert(connector && "invalid edge");
673a76ee58fSStanislav Funiak     LLVM_DEBUG(llvm::dbgs() << "  * Connector: " << connector.getLoc() << "\n");
674a76ee58fSStanislav Funiak     DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target);
675a76ee58fSStanislav Funiak     Position *pos = valueToPosition.lookup(connector);
6762692eae5SStanislav Funiak     assert(pos && "connector has not been traversed yet");
677a76ee58fSStanislav Funiak 
678a76ee58fSStanislav Funiak     // Traverse from the connector upwards towards the target root.
679a76ee58fSStanislav Funiak     for (Value value = connector; value != target;) {
680a76ee58fSStanislav Funiak       OpIndex opIndex = parentMap.lookup(value);
681a76ee58fSStanislav Funiak       assert(opIndex.parent && "missing parent");
6822692eae5SStanislav Funiak       visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index());
683a76ee58fSStanislav Funiak       value = opIndex.parent;
684a76ee58fSStanislav Funiak     }
685a76ee58fSStanislav Funiak   }
686a76ee58fSStanislav Funiak 
687242762c9SRiver Riddle   getNonTreePredicates(pattern, predList, builder, valueToPosition);
688a76ee58fSStanislav Funiak 
689a76ee58fSStanislav Funiak   return bestRoot;
6908a1ca2cdSRiver Riddle }
6918a1ca2cdSRiver Riddle 
6928a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
6938a1ca2cdSRiver Riddle // Pattern Predicate Tree Merging
6948a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
6958a1ca2cdSRiver Riddle 
6968a1ca2cdSRiver Riddle namespace {
6978a1ca2cdSRiver Riddle 
6988a1ca2cdSRiver Riddle /// This class represents a specific predicate applied to a position, and
6998a1ca2cdSRiver Riddle /// provides hashing and ordering operators. This class allows for computing a
7008a1ca2cdSRiver Riddle /// frequence sum and ordering predicates based on a cost model.
7018a1ca2cdSRiver Riddle struct OrderedPredicate {
OrderedPredicate__anon12da54731511::OrderedPredicate7028a1ca2cdSRiver Riddle   OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
7038a1ca2cdSRiver Riddle       : position(ip.first), question(ip.second) {}
OrderedPredicate__anon12da54731511::OrderedPredicate7048a1ca2cdSRiver Riddle   OrderedPredicate(const PositionalPredicate &ip)
7058a1ca2cdSRiver Riddle       : position(ip.position), question(ip.question) {}
7068a1ca2cdSRiver Riddle 
7078a1ca2cdSRiver Riddle   /// The position this predicate is applied to.
7088a1ca2cdSRiver Riddle   Position *position;
7098a1ca2cdSRiver Riddle 
7108a1ca2cdSRiver Riddle   /// The question that is applied by this predicate onto the position.
7118a1ca2cdSRiver Riddle   Qualifier *question;
7128a1ca2cdSRiver Riddle 
7138a1ca2cdSRiver Riddle   /// The first and second order benefit sums.
7148a1ca2cdSRiver Riddle   /// The primary sum is the number of occurrences of this predicate among all
7158a1ca2cdSRiver Riddle   /// of the patterns.
7168a1ca2cdSRiver Riddle   unsigned primary = 0;
7178a1ca2cdSRiver Riddle   /// The secondary sum is a squared summation of the primary sum of all of the
7188a1ca2cdSRiver Riddle   /// predicates within each pattern that contains this predicate. This allows
7198a1ca2cdSRiver Riddle   /// for favoring predicates that are more commonly shared within a pattern, as
7208a1ca2cdSRiver Riddle   /// opposed to those shared across patterns.
7218a1ca2cdSRiver Riddle   unsigned secondary = 0;
7228a1ca2cdSRiver Riddle 
723138803e0SStanislav Funiak   /// The tie breaking ID, used to preserve a deterministic (insertion) order
724138803e0SStanislav Funiak   /// among all the predicates with the same priority, depth, and position /
725138803e0SStanislav Funiak   /// predicate dependency.
726138803e0SStanislav Funiak   unsigned id = 0;
727138803e0SStanislav Funiak 
7288a1ca2cdSRiver Riddle   /// A map between a pattern operation and the answer to the predicate question
7298a1ca2cdSRiver Riddle   /// within that pattern.
7308a1ca2cdSRiver Riddle   DenseMap<Operation *, Qualifier *> patternToAnswer;
7318a1ca2cdSRiver Riddle 
732ddd556f1SRiver Riddle   /// Returns true if this predicate is ordered before `rhs`, based on the cost
733ddd556f1SRiver Riddle   /// model.
operator <__anon12da54731511::OrderedPredicate734ddd556f1SRiver Riddle   bool operator<(const OrderedPredicate &rhs) const {
7358a1ca2cdSRiver Riddle     // Sort by:
736ddd556f1SRiver Riddle     // * higher first and secondary order sums
7378a1ca2cdSRiver Riddle     // * lower depth
738ddd556f1SRiver Riddle     // * lower position dependency
739ddd556f1SRiver Riddle     // * lower predicate dependency
740138803e0SStanislav Funiak     // * lower tie breaking ID
741ddd556f1SRiver Riddle     auto *rhsPos = rhs.position;
7423a833a0eSRiver Riddle     return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
743138803e0SStanislav Funiak                            rhsPos->getKind(), rhs.question->getKind(), rhs.id) >
744ddd556f1SRiver Riddle            std::make_tuple(rhs.primary, rhs.secondary,
7453a833a0eSRiver Riddle                            position->getOperationDepth(), position->getKind(),
746138803e0SStanislav Funiak                            question->getKind(), id);
7478a1ca2cdSRiver Riddle   }
7488a1ca2cdSRiver Riddle };
7498a1ca2cdSRiver Riddle 
7508a1ca2cdSRiver Riddle /// A DenseMapInfo for OrderedPredicate based solely on the position and
7518a1ca2cdSRiver Riddle /// question.
7528a1ca2cdSRiver Riddle struct OrderedPredicateDenseInfo {
7538a1ca2cdSRiver Riddle   using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>;
7548a1ca2cdSRiver Riddle 
getEmptyKey__anon12da54731511::OrderedPredicateDenseInfo7558a1ca2cdSRiver Riddle   static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
getTombstoneKey__anon12da54731511::OrderedPredicateDenseInfo7568a1ca2cdSRiver Riddle   static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
isEqual__anon12da54731511::OrderedPredicateDenseInfo7578a1ca2cdSRiver Riddle   static bool isEqual(const OrderedPredicate &lhs,
7588a1ca2cdSRiver Riddle                       const OrderedPredicate &rhs) {
7598a1ca2cdSRiver Riddle     return lhs.position == rhs.position && lhs.question == rhs.question;
7608a1ca2cdSRiver Riddle   }
getHashValue__anon12da54731511::OrderedPredicateDenseInfo7618a1ca2cdSRiver Riddle   static unsigned getHashValue(const OrderedPredicate &p) {
7628a1ca2cdSRiver Riddle     return llvm::hash_combine(p.position, p.question);
7638a1ca2cdSRiver Riddle   }
7648a1ca2cdSRiver Riddle };
7658a1ca2cdSRiver Riddle 
7668a1ca2cdSRiver Riddle /// This class wraps a set of ordered predicates that are used within a specific
7678a1ca2cdSRiver Riddle /// pattern operation.
7688a1ca2cdSRiver Riddle struct OrderedPredicateList {
OrderedPredicateList__anon12da54731511::OrderedPredicateList769a76ee58fSStanislav Funiak   OrderedPredicateList(pdl::PatternOp pattern, Value root)
770a76ee58fSStanislav Funiak       : pattern(pattern), root(root) {}
7718a1ca2cdSRiver Riddle 
7728a1ca2cdSRiver Riddle   pdl::PatternOp pattern;
773a76ee58fSStanislav Funiak   Value root;
7748a1ca2cdSRiver Riddle   DenseSet<OrderedPredicate *> predicates;
7758a1ca2cdSRiver Riddle };
776be0a7e9fSMehdi Amini } // namespace
7778a1ca2cdSRiver Riddle 
7788a1ca2cdSRiver Riddle /// Returns true if the given matcher refers to the same predicate as the given
7798a1ca2cdSRiver Riddle /// ordered predicate. This means that the position and questions of the two
7808a1ca2cdSRiver Riddle /// match.
isSamePredicate(MatcherNode * node,OrderedPredicate * predicate)7818a1ca2cdSRiver Riddle static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
7828a1ca2cdSRiver Riddle   return node->getPosition() == predicate->position &&
7838a1ca2cdSRiver Riddle          node->getQuestion() == predicate->question;
7848a1ca2cdSRiver Riddle }
7858a1ca2cdSRiver Riddle 
7868a1ca2cdSRiver Riddle /// Get or insert a child matcher for the given parent switch node, given a
7878a1ca2cdSRiver Riddle /// predicate and parent pattern.
getOrCreateChild(SwitchNode * node,OrderedPredicate * predicate,pdl::PatternOp pattern)7888a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
7898a1ca2cdSRiver Riddle                                                OrderedPredicate *predicate,
7908a1ca2cdSRiver Riddle                                                pdl::PatternOp pattern) {
7918a1ca2cdSRiver Riddle   assert(isSamePredicate(node, predicate) &&
7928a1ca2cdSRiver Riddle          "expected matcher to equal the given predicate");
7938a1ca2cdSRiver Riddle 
7948a1ca2cdSRiver Riddle   auto it = predicate->patternToAnswer.find(pattern);
7958a1ca2cdSRiver Riddle   assert(it != predicate->patternToAnswer.end() &&
7968a1ca2cdSRiver Riddle          "expected pattern to exist in predicate");
7978a1ca2cdSRiver Riddle   return node->getChildren().insert({it->second, nullptr}).first->second;
7988a1ca2cdSRiver Riddle }
7998a1ca2cdSRiver Riddle 
8008a1ca2cdSRiver Riddle /// Build the matcher CFG by "pushing" patterns through by sorted predicate
8018a1ca2cdSRiver Riddle /// order. A pattern will traverse as far as possible using common predicates
8028a1ca2cdSRiver Riddle /// and then either diverge from the CFG or reach the end of a branch and start
8038a1ca2cdSRiver Riddle /// creating new nodes.
propagatePattern(std::unique_ptr<MatcherNode> & node,OrderedPredicateList & list,std::vector<OrderedPredicate * >::iterator current,std::vector<OrderedPredicate * >::iterator end)8048a1ca2cdSRiver Riddle static void propagatePattern(std::unique_ptr<MatcherNode> &node,
8058a1ca2cdSRiver Riddle                              OrderedPredicateList &list,
8068a1ca2cdSRiver Riddle                              std::vector<OrderedPredicate *>::iterator current,
8078a1ca2cdSRiver Riddle                              std::vector<OrderedPredicate *>::iterator end) {
8088a1ca2cdSRiver Riddle   if (current == end) {
8098a1ca2cdSRiver Riddle     // We've hit the end of a pattern, so create a successful result node.
810a76ee58fSStanislav Funiak     node =
811a76ee58fSStanislav Funiak         std::make_unique<SuccessNode>(list.pattern, list.root, std::move(node));
8128a1ca2cdSRiver Riddle 
8138a1ca2cdSRiver Riddle     // If the pattern doesn't contain this predicate, ignore it.
8148a1ca2cdSRiver Riddle   } else if (list.predicates.find(*current) == list.predicates.end()) {
8158a1ca2cdSRiver Riddle     propagatePattern(node, list, std::next(current), end);
8168a1ca2cdSRiver Riddle 
8178a1ca2cdSRiver Riddle     // If the current matcher node is invalid, create a new one for this
8188a1ca2cdSRiver Riddle     // position and continue propagation.
8198a1ca2cdSRiver Riddle   } else if (!node) {
8208a1ca2cdSRiver Riddle     // Create a new node at this position and continue
8218a1ca2cdSRiver Riddle     node = std::make_unique<SwitchNode>((*current)->position,
8228a1ca2cdSRiver Riddle                                         (*current)->question);
8238a1ca2cdSRiver Riddle     propagatePattern(
8248a1ca2cdSRiver Riddle         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
8258a1ca2cdSRiver Riddle         list, std::next(current), end);
8268a1ca2cdSRiver Riddle 
8278a1ca2cdSRiver Riddle     // If the matcher has already been created, and it is for this predicate we
8288a1ca2cdSRiver Riddle     // continue propagation to the child.
8298a1ca2cdSRiver Riddle   } else if (isSamePredicate(node.get(), *current)) {
8308a1ca2cdSRiver Riddle     propagatePattern(
8318a1ca2cdSRiver Riddle         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
8328a1ca2cdSRiver Riddle         list, std::next(current), end);
8338a1ca2cdSRiver Riddle 
8348a1ca2cdSRiver Riddle     // If the matcher doesn't match the current predicate, insert a branch as
8358a1ca2cdSRiver Riddle     // the common set of matchers has diverged.
8368a1ca2cdSRiver Riddle   } else {
8378a1ca2cdSRiver Riddle     propagatePattern(node->getFailureNode(), list, current, end);
8388a1ca2cdSRiver Riddle   }
8398a1ca2cdSRiver Riddle }
8408a1ca2cdSRiver Riddle 
8418a1ca2cdSRiver Riddle /// Fold any switch nodes nested under `node` to boolean nodes when possible.
8428a1ca2cdSRiver Riddle /// `node` is updated in-place if it is a switch.
foldSwitchToBool(std::unique_ptr<MatcherNode> & node)8438a1ca2cdSRiver Riddle static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
8448a1ca2cdSRiver Riddle   if (!node)
8458a1ca2cdSRiver Riddle     return;
8468a1ca2cdSRiver Riddle 
8478a1ca2cdSRiver Riddle   if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
8488a1ca2cdSRiver Riddle     SwitchNode::ChildMapT &children = switchNode->getChildren();
8498a1ca2cdSRiver Riddle     for (auto &it : children)
8508a1ca2cdSRiver Riddle       foldSwitchToBool(it.second);
8518a1ca2cdSRiver Riddle 
8528a1ca2cdSRiver Riddle     // If the node only contains one child, collapse it into a boolean predicate
8538a1ca2cdSRiver Riddle     // node.
8548a1ca2cdSRiver Riddle     if (children.size() == 1) {
8558a1ca2cdSRiver Riddle       auto childIt = children.begin();
8568a1ca2cdSRiver Riddle       node = std::make_unique<BoolNode>(
8578a1ca2cdSRiver Riddle           node->getPosition(), node->getQuestion(), childIt->first,
8588a1ca2cdSRiver Riddle           std::move(childIt->second), std::move(node->getFailureNode()));
8598a1ca2cdSRiver Riddle     }
8608a1ca2cdSRiver Riddle   } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
8618a1ca2cdSRiver Riddle     foldSwitchToBool(boolNode->getSuccessNode());
8628a1ca2cdSRiver Riddle   }
8638a1ca2cdSRiver Riddle 
8648a1ca2cdSRiver Riddle   foldSwitchToBool(node->getFailureNode());
8658a1ca2cdSRiver Riddle }
8668a1ca2cdSRiver Riddle 
8678a1ca2cdSRiver Riddle /// Insert an exit node at the end of the failure path of the `root`.
insertExitNode(std::unique_ptr<MatcherNode> * root)8688a1ca2cdSRiver Riddle static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
8698a1ca2cdSRiver Riddle   while (*root)
8708a1ca2cdSRiver Riddle     root = &(*root)->getFailureNode();
8718a1ca2cdSRiver Riddle   *root = std::make_unique<ExitNode>();
8728a1ca2cdSRiver Riddle }
8738a1ca2cdSRiver Riddle 
8748a1ca2cdSRiver Riddle /// Given a module containing PDL pattern operations, generate a matcher tree
8758a1ca2cdSRiver Riddle /// using the patterns within the given module and return the root matcher node.
8768a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode>
generateMatcherTree(ModuleOp module,PredicateBuilder & builder,DenseMap<Value,Position * > & valueToPosition)8778a1ca2cdSRiver Riddle MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
8788a1ca2cdSRiver Riddle                                  DenseMap<Value, Position *> &valueToPosition) {
879a76ee58fSStanislav Funiak   // The set of predicates contained within the pattern operations of the
880a76ee58fSStanislav Funiak   // module.
881a76ee58fSStanislav Funiak   struct PatternPredicates {
882a76ee58fSStanislav Funiak     PatternPredicates(pdl::PatternOp pattern, Value root,
883a76ee58fSStanislav Funiak                       std::vector<PositionalPredicate> predicates)
884a76ee58fSStanislav Funiak         : pattern(pattern), root(root), predicates(std::move(predicates)) {}
885a76ee58fSStanislav Funiak 
886a76ee58fSStanislav Funiak     /// A pattern.
887a76ee58fSStanislav Funiak     pdl::PatternOp pattern;
888a76ee58fSStanislav Funiak 
889a76ee58fSStanislav Funiak     /// A root of the pattern chosen among the candidate roots in pdl.rewrite.
890a76ee58fSStanislav Funiak     Value root;
891a76ee58fSStanislav Funiak 
892a76ee58fSStanislav Funiak     /// The extracted predicates for this pattern and root.
893a76ee58fSStanislav Funiak     std::vector<PositionalPredicate> predicates;
894a76ee58fSStanislav Funiak   };
895a76ee58fSStanislav Funiak 
896a76ee58fSStanislav Funiak   SmallVector<PatternPredicates, 16> patternsAndPredicates;
8978a1ca2cdSRiver Riddle   for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
8988a1ca2cdSRiver Riddle     std::vector<PositionalPredicate> predicateList;
899a76ee58fSStanislav Funiak     Value root =
9008a1ca2cdSRiver Riddle         buildPredicateList(pattern, builder, predicateList, valueToPosition);
901a76ee58fSStanislav Funiak     patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));
9028a1ca2cdSRiver Riddle   }
9038a1ca2cdSRiver Riddle 
9048a1ca2cdSRiver Riddle   // Associate a pattern result with each unique predicate.
9058a1ca2cdSRiver Riddle   DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued;
9068a1ca2cdSRiver Riddle   for (auto &patternAndPredList : patternsAndPredicates) {
907a76ee58fSStanislav Funiak     for (auto &predicate : patternAndPredList.predicates) {
9088a1ca2cdSRiver Riddle       auto it = uniqued.insert(predicate);
909a76ee58fSStanislav Funiak       it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
9108a1ca2cdSRiver Riddle                                             predicate.answer);
911138803e0SStanislav Funiak       // Mark the insertion order (0-based indexing).
912138803e0SStanislav Funiak       if (it.second)
913138803e0SStanislav Funiak         it.first->id = uniqued.size() - 1;
9148a1ca2cdSRiver Riddle     }
9158a1ca2cdSRiver Riddle   }
9168a1ca2cdSRiver Riddle 
9178a1ca2cdSRiver Riddle   // Associate each pattern to a set of its ordered predicates for later lookup.
9188a1ca2cdSRiver Riddle   std::vector<OrderedPredicateList> lists;
9198a1ca2cdSRiver Riddle   lists.reserve(patternsAndPredicates.size());
9208a1ca2cdSRiver Riddle   for (auto &patternAndPredList : patternsAndPredicates) {
921a76ee58fSStanislav Funiak     OrderedPredicateList list(patternAndPredList.pattern,
922a76ee58fSStanislav Funiak                               patternAndPredList.root);
923a76ee58fSStanislav Funiak     for (auto &predicate : patternAndPredList.predicates) {
9248a1ca2cdSRiver Riddle       OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
9258a1ca2cdSRiver Riddle       list.predicates.insert(orderedPredicate);
9268a1ca2cdSRiver Riddle 
9278a1ca2cdSRiver Riddle       // Increment the primary sum for each reference to a particular predicate.
9288a1ca2cdSRiver Riddle       ++orderedPredicate->primary;
9298a1ca2cdSRiver Riddle     }
9308a1ca2cdSRiver Riddle     lists.push_back(std::move(list));
9318a1ca2cdSRiver Riddle   }
9328a1ca2cdSRiver Riddle 
9338a1ca2cdSRiver Riddle   // For a particular pattern, get the total primary sum and add it to the
9348a1ca2cdSRiver Riddle   // secondary sum of each predicate. Square the primary sums to emphasize
9358a1ca2cdSRiver Riddle   // shared predicates within rather than across patterns.
9368a1ca2cdSRiver Riddle   for (auto &list : lists) {
9378a1ca2cdSRiver Riddle     unsigned total = 0;
9388a1ca2cdSRiver Riddle     for (auto *predicate : list.predicates)
9398a1ca2cdSRiver Riddle       total += predicate->primary * predicate->primary;
9408a1ca2cdSRiver Riddle     for (auto *predicate : list.predicates)
9418a1ca2cdSRiver Riddle       predicate->secondary += total;
9428a1ca2cdSRiver Riddle   }
9438a1ca2cdSRiver Riddle 
9448a1ca2cdSRiver Riddle   // Sort the set of predicates now that the cost primary and secondary sums
9458a1ca2cdSRiver Riddle   // have been computed.
9468a1ca2cdSRiver Riddle   std::vector<OrderedPredicate *> ordered;
9478a1ca2cdSRiver Riddle   ordered.reserve(uniqued.size());
9488a1ca2cdSRiver Riddle   for (auto &ip : uniqued)
9498a1ca2cdSRiver Riddle     ordered.push_back(&ip);
950138803e0SStanislav Funiak   llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
951138803e0SStanislav Funiak     return *lhs < *rhs;
952138803e0SStanislav Funiak   });
9538a1ca2cdSRiver Riddle 
9548a1ca2cdSRiver Riddle   // Build the matchers for each of the pattern predicate lists.
9558a1ca2cdSRiver Riddle   std::unique_ptr<MatcherNode> root;
9568a1ca2cdSRiver Riddle   for (OrderedPredicateList &list : lists)
9578a1ca2cdSRiver Riddle     propagatePattern(root, list, ordered.begin(), ordered.end());
9588a1ca2cdSRiver Riddle 
9598a1ca2cdSRiver Riddle   // Collapse the graph and insert the exit node.
9608a1ca2cdSRiver Riddle   foldSwitchToBool(root);
9618a1ca2cdSRiver Riddle   insertExitNode(&root);
9628a1ca2cdSRiver Riddle   return root;
9638a1ca2cdSRiver Riddle }
9648a1ca2cdSRiver Riddle 
9658a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
9668a1ca2cdSRiver Riddle // MatcherNode
9678a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
9688a1ca2cdSRiver Riddle 
MatcherNode(TypeID matcherTypeID,Position * p,Qualifier * q,std::unique_ptr<MatcherNode> failureNode)9698a1ca2cdSRiver Riddle MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
9708a1ca2cdSRiver Riddle                          std::unique_ptr<MatcherNode> failureNode)
9718a1ca2cdSRiver Riddle     : position(p), question(q), failureNode(std::move(failureNode)),
9728a1ca2cdSRiver Riddle       matcherTypeID(matcherTypeID) {}
9738a1ca2cdSRiver Riddle 
9748a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
9758a1ca2cdSRiver Riddle // BoolNode
9768a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
9778a1ca2cdSRiver Riddle 
BoolNode(Position * position,Qualifier * question,Qualifier * answer,std::unique_ptr<MatcherNode> successNode,std::unique_ptr<MatcherNode> failureNode)9788a1ca2cdSRiver Riddle BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
9798a1ca2cdSRiver Riddle                    std::unique_ptr<MatcherNode> successNode,
9808a1ca2cdSRiver Riddle                    std::unique_ptr<MatcherNode> failureNode)
9818a1ca2cdSRiver Riddle     : MatcherNode(TypeID::get<BoolNode>(), position, question,
9828a1ca2cdSRiver Riddle                   std::move(failureNode)),
9838a1ca2cdSRiver Riddle       answer(answer), successNode(std::move(successNode)) {}
9848a1ca2cdSRiver Riddle 
9858a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
9868a1ca2cdSRiver Riddle // SuccessNode
9878a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
9888a1ca2cdSRiver Riddle 
SuccessNode(pdl::PatternOp pattern,Value root,std::unique_ptr<MatcherNode> failureNode)989a76ee58fSStanislav Funiak SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root,
9908a1ca2cdSRiver Riddle                          std::unique_ptr<MatcherNode> failureNode)
9918a1ca2cdSRiver Riddle     : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
9928a1ca2cdSRiver Riddle                   /*question=*/nullptr, std::move(failureNode)),
993a76ee58fSStanislav Funiak       pattern(pattern), root(root) {}
9948a1ca2cdSRiver Riddle 
9958a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
9968a1ca2cdSRiver Riddle // SwitchNode
9978a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
9988a1ca2cdSRiver Riddle 
SwitchNode(Position * position,Qualifier * question)9998a1ca2cdSRiver Riddle SwitchNode::SwitchNode(Position *position, Qualifier *question)
10008a1ca2cdSRiver Riddle     : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
1001