1abfd1a8bSRiver Riddle //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2abfd1a8bSRiver Riddle //
3abfd1a8bSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4abfd1a8bSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5abfd1a8bSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6abfd1a8bSRiver Riddle //
7abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
8abfd1a8bSRiver Riddle //
9abfd1a8bSRiver Riddle // This file implements MLIR to byte-code generation and the interpreter.
10abfd1a8bSRiver Riddle //
11abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
12abfd1a8bSRiver Riddle 
13abfd1a8bSRiver Riddle #include "ByteCode.h"
14abfd1a8bSRiver Riddle #include "mlir/Analysis/Liveness.h"
15abfd1a8bSRiver Riddle #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16abfd1a8bSRiver Riddle #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17e66c2e25SRiver Riddle #include "mlir/IR/BuiltinOps.h"
18abfd1a8bSRiver Riddle #include "mlir/IR/RegionGraphTraits.h"
19abfd1a8bSRiver Riddle #include "llvm/ADT/IntervalMap.h"
20abfd1a8bSRiver Riddle #include "llvm/ADT/PostOrderIterator.h"
21abfd1a8bSRiver Riddle #include "llvm/ADT/TypeSwitch.h"
22abfd1a8bSRiver Riddle #include "llvm/Support/Debug.h"
2385ab413bSRiver Riddle #include "llvm/Support/Format.h"
2485ab413bSRiver Riddle #include "llvm/Support/FormatVariadic.h"
2585ab413bSRiver Riddle #include <numeric>
26abfd1a8bSRiver Riddle 
27abfd1a8bSRiver Riddle #define DEBUG_TYPE "pdl-bytecode"
28abfd1a8bSRiver Riddle 
29abfd1a8bSRiver Riddle using namespace mlir;
30abfd1a8bSRiver Riddle using namespace mlir::detail;
31abfd1a8bSRiver Riddle 
32abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
33abfd1a8bSRiver Riddle // PDLByteCodePattern
34abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
35abfd1a8bSRiver Riddle 
create(pdl_interp::RecordMatchOp matchOp,ByteCodeAddr rewriterAddr)36abfd1a8bSRiver Riddle PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
37abfd1a8bSRiver Riddle                                               ByteCodeAddr rewriterAddr) {
38abfd1a8bSRiver Riddle   SmallVector<StringRef, 8> generatedOps;
393c405c3bSRiver Riddle   if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
40abfd1a8bSRiver Riddle     generatedOps =
41abfd1a8bSRiver Riddle         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
42abfd1a8bSRiver Riddle 
433c405c3bSRiver Riddle   PatternBenefit benefit = matchOp.getBenefit();
44abfd1a8bSRiver Riddle   MLIRContext *ctx = matchOp.getContext();
45abfd1a8bSRiver Riddle 
46abfd1a8bSRiver Riddle   // Check to see if this is pattern matches a specific operation type.
473c405c3bSRiver Riddle   if (Optional<StringRef> rootKind = matchOp.getRootKind())
4876f3c2f3SRiver Riddle     return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
4976f3c2f3SRiver Riddle                               generatedOps);
5076f3c2f3SRiver Riddle   return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
5176f3c2f3SRiver Riddle                             generatedOps);
52abfd1a8bSRiver Riddle }
53abfd1a8bSRiver Riddle 
54abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
55abfd1a8bSRiver Riddle // PDLByteCodeMutableState
56abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
57abfd1a8bSRiver Riddle 
58abfd1a8bSRiver Riddle /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
59abfd1a8bSRiver Riddle /// to the position of the pattern within the range returned by
60abfd1a8bSRiver Riddle /// `PDLByteCode::getPatterns`.
updatePatternBenefit(unsigned patternIndex,PatternBenefit benefit)61abfd1a8bSRiver Riddle void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
62abfd1a8bSRiver Riddle                                                    PatternBenefit benefit) {
63abfd1a8bSRiver Riddle   currentPatternBenefits[patternIndex] = benefit;
64abfd1a8bSRiver Riddle }
65abfd1a8bSRiver Riddle 
6685ab413bSRiver Riddle /// Cleanup any allocated state after a full match/rewrite has been completed.
6785ab413bSRiver Riddle /// This method should be called irregardless of whether the match+rewrite was a
6885ab413bSRiver Riddle /// success or not.
cleanupAfterMatchAndRewrite()6985ab413bSRiver Riddle void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
7085ab413bSRiver Riddle   allocatedTypeRangeMemory.clear();
7185ab413bSRiver Riddle   allocatedValueRangeMemory.clear();
7285ab413bSRiver Riddle }
7385ab413bSRiver Riddle 
74abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
75abfd1a8bSRiver Riddle // Bytecode OpCodes
76abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
77abfd1a8bSRiver Riddle 
78abfd1a8bSRiver Riddle namespace {
79abfd1a8bSRiver Riddle enum OpCode : ByteCodeField {
80abfd1a8bSRiver Riddle   /// Apply an externally registered constraint.
81abfd1a8bSRiver Riddle   ApplyConstraint,
82abfd1a8bSRiver Riddle   /// Apply an externally registered rewrite.
83abfd1a8bSRiver Riddle   ApplyRewrite,
84abfd1a8bSRiver Riddle   /// Check if two generic values are equal.
85abfd1a8bSRiver Riddle   AreEqual,
8685ab413bSRiver Riddle   /// Check if two ranges are equal.
8785ab413bSRiver Riddle   AreRangesEqual,
88abfd1a8bSRiver Riddle   /// Unconditional branch.
89abfd1a8bSRiver Riddle   Branch,
90abfd1a8bSRiver Riddle   /// Compare the operand count of an operation with a constant.
91abfd1a8bSRiver Riddle   CheckOperandCount,
92abfd1a8bSRiver Riddle   /// Compare the name of an operation with a constant.
93abfd1a8bSRiver Riddle   CheckOperationName,
94abfd1a8bSRiver Riddle   /// Compare the result count of an operation with a constant.
95abfd1a8bSRiver Riddle   CheckResultCount,
9685ab413bSRiver Riddle   /// Compare a range of types to a constant range of types.
9785ab413bSRiver Riddle   CheckTypes,
983eb1647aSStanislav Funiak   /// Continue to the next iteration of a loop.
993eb1647aSStanislav Funiak   Continue,
100abfd1a8bSRiver Riddle   /// Create an operation.
101abfd1a8bSRiver Riddle   CreateOperation,
10285ab413bSRiver Riddle   /// Create a range of types.
10385ab413bSRiver Riddle   CreateTypes,
104abfd1a8bSRiver Riddle   /// Erase an operation.
105abfd1a8bSRiver Riddle   EraseOp,
1063eb1647aSStanislav Funiak   /// Extract the op from a range at the specified index.
1073eb1647aSStanislav Funiak   ExtractOp,
1083eb1647aSStanislav Funiak   /// Extract the type from a range at the specified index.
1093eb1647aSStanislav Funiak   ExtractType,
1103eb1647aSStanislav Funiak   /// Extract the value from a range at the specified index.
1113eb1647aSStanislav Funiak   ExtractValue,
112abfd1a8bSRiver Riddle   /// Terminate a matcher or rewrite sequence.
113abfd1a8bSRiver Riddle   Finalize,
1143eb1647aSStanislav Funiak   /// Iterate over a range of values.
1153eb1647aSStanislav Funiak   ForEach,
116abfd1a8bSRiver Riddle   /// Get a specific attribute of an operation.
117abfd1a8bSRiver Riddle   GetAttribute,
118abfd1a8bSRiver Riddle   /// Get the type of an attribute.
119abfd1a8bSRiver Riddle   GetAttributeType,
120abfd1a8bSRiver Riddle   /// Get the defining operation of a value.
121abfd1a8bSRiver Riddle   GetDefiningOp,
122abfd1a8bSRiver Riddle   /// Get a specific operand of an operation.
123abfd1a8bSRiver Riddle   GetOperand0,
124abfd1a8bSRiver Riddle   GetOperand1,
125abfd1a8bSRiver Riddle   GetOperand2,
126abfd1a8bSRiver Riddle   GetOperand3,
127abfd1a8bSRiver Riddle   GetOperandN,
12885ab413bSRiver Riddle   /// Get a specific operand group of an operation.
12985ab413bSRiver Riddle   GetOperands,
130abfd1a8bSRiver Riddle   /// Get a specific result of an operation.
131abfd1a8bSRiver Riddle   GetResult0,
132abfd1a8bSRiver Riddle   GetResult1,
133abfd1a8bSRiver Riddle   GetResult2,
134abfd1a8bSRiver Riddle   GetResult3,
135abfd1a8bSRiver Riddle   GetResultN,
13685ab413bSRiver Riddle   /// Get a specific result group of an operation.
13785ab413bSRiver Riddle   GetResults,
1383eb1647aSStanislav Funiak   /// Get the users of a value or a range of values.
1393eb1647aSStanislav Funiak   GetUsers,
140abfd1a8bSRiver Riddle   /// Get the type of a value.
141abfd1a8bSRiver Riddle   GetValueType,
14285ab413bSRiver Riddle   /// Get the types of a value range.
14385ab413bSRiver Riddle   GetValueRangeTypes,
144abfd1a8bSRiver Riddle   /// Check if a generic value is not null.
145abfd1a8bSRiver Riddle   IsNotNull,
146abfd1a8bSRiver Riddle   /// Record a successful pattern match.
147abfd1a8bSRiver Riddle   RecordMatch,
148abfd1a8bSRiver Riddle   /// Replace an operation.
149abfd1a8bSRiver Riddle   ReplaceOp,
150abfd1a8bSRiver Riddle   /// Compare an attribute with a set of constants.
151abfd1a8bSRiver Riddle   SwitchAttribute,
152abfd1a8bSRiver Riddle   /// Compare the operand count of an operation with a set of constants.
153abfd1a8bSRiver Riddle   SwitchOperandCount,
154abfd1a8bSRiver Riddle   /// Compare the name of an operation with a set of constants.
155abfd1a8bSRiver Riddle   SwitchOperationName,
156abfd1a8bSRiver Riddle   /// Compare the result count of an operation with a set of constants.
157abfd1a8bSRiver Riddle   SwitchResultCount,
158abfd1a8bSRiver Riddle   /// Compare a type with a set of constants.
159abfd1a8bSRiver Riddle   SwitchType,
16085ab413bSRiver Riddle   /// Compare a range of types with a set of constants.
16185ab413bSRiver Riddle   SwitchTypes,
162abfd1a8bSRiver Riddle };
163be0a7e9fSMehdi Amini } // namespace
164abfd1a8bSRiver Riddle 
1653c752289SRiver Riddle /// A marker used to indicate if an operation should infer types.
1663c752289SRiver Riddle static constexpr ByteCodeField kInferTypesMarker =
1673c752289SRiver Riddle     std::numeric_limits<ByteCodeField>::max();
1683c752289SRiver Riddle 
169abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
170abfd1a8bSRiver Riddle // ByteCode Generation
171abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
172abfd1a8bSRiver Riddle 
173abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
174abfd1a8bSRiver Riddle // Generator
175abfd1a8bSRiver Riddle 
176abfd1a8bSRiver Riddle namespace {
struct ByteCodeWriter;
struct ByteCodeLiveRange;

/// Detection alias naming the return type of `ValueT::getAsOpaquePointer()`.
/// It is only well-formed when `ValueT` has such a method, which makes it
/// usable with `llvm::is_detected` to ask whether a value can be handled as an
/// opaque pointer. Any trailing template arguments are ignored.
template <typename ValueT, typename... UnusedTs>
using has_pointer_traits =
    decltype(std::declval<ValueT>().getAsOpaquePointer());
1833eb1647aSStanislav Funiak 
184abfd1a8bSRiver Riddle /// This class represents the main generator for the pattern bytecode.
185abfd1a8bSRiver Riddle class Generator {
186abfd1a8bSRiver Riddle public:
Generator(MLIRContext * ctx,std::vector<const void * > & uniquedData,SmallVectorImpl<ByteCodeField> & matcherByteCode,SmallVectorImpl<ByteCodeField> & rewriterByteCode,SmallVectorImpl<PDLByteCodePattern> & patterns,ByteCodeField & maxValueMemoryIndex,ByteCodeField & maxOpRangeMemoryIndex,ByteCodeField & maxTypeRangeMemoryIndex,ByteCodeField & maxValueRangeMemoryIndex,ByteCodeField & maxLoopLevel,llvm::StringMap<PDLConstraintFunction> & constraintFns,llvm::StringMap<PDLRewriteFunction> & rewriteFns)187abfd1a8bSRiver Riddle   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
188abfd1a8bSRiver Riddle             SmallVectorImpl<ByteCodeField> &matcherByteCode,
189abfd1a8bSRiver Riddle             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
190abfd1a8bSRiver Riddle             SmallVectorImpl<PDLByteCodePattern> &patterns,
191abfd1a8bSRiver Riddle             ByteCodeField &maxValueMemoryIndex,
1923eb1647aSStanislav Funiak             ByteCodeField &maxOpRangeMemoryIndex,
19385ab413bSRiver Riddle             ByteCodeField &maxTypeRangeMemoryIndex,
19485ab413bSRiver Riddle             ByteCodeField &maxValueRangeMemoryIndex,
1953eb1647aSStanislav Funiak             ByteCodeField &maxLoopLevel,
196abfd1a8bSRiver Riddle             llvm::StringMap<PDLConstraintFunction> &constraintFns,
197abfd1a8bSRiver Riddle             llvm::StringMap<PDLRewriteFunction> &rewriteFns)
198abfd1a8bSRiver Riddle       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
199abfd1a8bSRiver Riddle         rewriterByteCode(rewriterByteCode), patterns(patterns),
20085ab413bSRiver Riddle         maxValueMemoryIndex(maxValueMemoryIndex),
2013eb1647aSStanislav Funiak         maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
20285ab413bSRiver Riddle         maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
2033eb1647aSStanislav Funiak         maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
2043eb1647aSStanislav Funiak         maxLoopLevel(maxLoopLevel) {
205e4853be2SMehdi Amini     for (const auto &it : llvm::enumerate(constraintFns))
206abfd1a8bSRiver Riddle       constraintToMemIndex.try_emplace(it.value().first(), it.index());
207e4853be2SMehdi Amini     for (const auto &it : llvm::enumerate(rewriteFns))
208abfd1a8bSRiver Riddle       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
209abfd1a8bSRiver Riddle   }
210abfd1a8bSRiver Riddle 
211abfd1a8bSRiver Riddle   /// Generate the bytecode for the given PDL interpreter module.
212abfd1a8bSRiver Riddle   void generate(ModuleOp module);
213abfd1a8bSRiver Riddle 
214abfd1a8bSRiver Riddle   /// Return the memory index to use for the given value.
getMemIndex(Value value)215abfd1a8bSRiver Riddle   ByteCodeField &getMemIndex(Value value) {
216abfd1a8bSRiver Riddle     assert(valueToMemIndex.count(value) &&
217abfd1a8bSRiver Riddle            "expected memory index to be assigned");
218abfd1a8bSRiver Riddle     return valueToMemIndex[value];
219abfd1a8bSRiver Riddle   }
220abfd1a8bSRiver Riddle 
22185ab413bSRiver Riddle   /// Return the range memory index used to store the given range value.
getRangeStorageIndex(Value value)22285ab413bSRiver Riddle   ByteCodeField &getRangeStorageIndex(Value value) {
22385ab413bSRiver Riddle     assert(valueToRangeIndex.count(value) &&
22485ab413bSRiver Riddle            "expected range index to be assigned");
22585ab413bSRiver Riddle     return valueToRangeIndex[value];
22685ab413bSRiver Riddle   }
22785ab413bSRiver Riddle 
228abfd1a8bSRiver Riddle   /// Return an index to use when referring to the given data that is uniqued in
229abfd1a8bSRiver Riddle   /// the MLIR context.
230abfd1a8bSRiver Riddle   template <typename T>
231abfd1a8bSRiver Riddle   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
getMemIndex(T val)232abfd1a8bSRiver Riddle   getMemIndex(T val) {
233abfd1a8bSRiver Riddle     const void *opaqueVal = val.getAsOpaquePointer();
234abfd1a8bSRiver Riddle 
235abfd1a8bSRiver Riddle     // Get or insert a reference to this value.
236abfd1a8bSRiver Riddle     auto it = uniquedDataToMemIndex.try_emplace(
237abfd1a8bSRiver Riddle         opaqueVal, maxValueMemoryIndex + uniquedData.size());
238abfd1a8bSRiver Riddle     if (it.second)
239abfd1a8bSRiver Riddle       uniquedData.push_back(opaqueVal);
240abfd1a8bSRiver Riddle     return it.first->second;
241abfd1a8bSRiver Riddle   }
242abfd1a8bSRiver Riddle 
243abfd1a8bSRiver Riddle private:
244abfd1a8bSRiver Riddle   /// Allocate memory indices for the results of operations within the matcher
245abfd1a8bSRiver Riddle   /// and rewriters.
246f96a8675SRiver Riddle   void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
247f96a8675SRiver Riddle                              ModuleOp rewriterModule);
248abfd1a8bSRiver Riddle 
249abfd1a8bSRiver Riddle   /// Generate the bytecode for the given operation.
2503eb1647aSStanislav Funiak   void generate(Region *region, ByteCodeWriter &writer);
251abfd1a8bSRiver Riddle   void generate(Operation *op, ByteCodeWriter &writer);
252abfd1a8bSRiver Riddle   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
253abfd1a8bSRiver Riddle   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
254abfd1a8bSRiver Riddle   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
255abfd1a8bSRiver Riddle   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
256abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
257abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
258abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
259abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
260abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
26185ab413bSRiver Riddle   void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
2623eb1647aSStanislav Funiak   void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
263abfd1a8bSRiver Riddle   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
264abfd1a8bSRiver Riddle   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
265abfd1a8bSRiver Riddle   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
26685ab413bSRiver Riddle   void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
267abfd1a8bSRiver Riddle   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
2683eb1647aSStanislav Funiak   void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
269abfd1a8bSRiver Riddle   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
2703eb1647aSStanislav Funiak   void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
271abfd1a8bSRiver Riddle   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
272abfd1a8bSRiver Riddle   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
273abfd1a8bSRiver Riddle   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
274abfd1a8bSRiver Riddle   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
27585ab413bSRiver Riddle   void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
276abfd1a8bSRiver Riddle   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
27785ab413bSRiver Riddle   void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
2783eb1647aSStanislav Funiak   void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
279abfd1a8bSRiver Riddle   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
280abfd1a8bSRiver Riddle   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
281abfd1a8bSRiver Riddle   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
282abfd1a8bSRiver Riddle   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
283abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
284abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
28585ab413bSRiver Riddle   void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
286abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
287abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
288abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
289abfd1a8bSRiver Riddle 
290abfd1a8bSRiver Riddle   /// Mapping from value to its corresponding memory index.
291abfd1a8bSRiver Riddle   DenseMap<Value, ByteCodeField> valueToMemIndex;
292abfd1a8bSRiver Riddle 
29385ab413bSRiver Riddle   /// Mapping from a range value to its corresponding range storage index.
29485ab413bSRiver Riddle   DenseMap<Value, ByteCodeField> valueToRangeIndex;
29585ab413bSRiver Riddle 
296abfd1a8bSRiver Riddle   /// Mapping from the name of an externally registered rewrite to its index in
297abfd1a8bSRiver Riddle   /// the bytecode registry.
298abfd1a8bSRiver Riddle   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
299abfd1a8bSRiver Riddle 
300abfd1a8bSRiver Riddle   /// Mapping from the name of an externally registered constraint to its index
301abfd1a8bSRiver Riddle   /// in the bytecode registry.
302abfd1a8bSRiver Riddle   llvm::StringMap<ByteCodeField> constraintToMemIndex;
303abfd1a8bSRiver Riddle 
304abfd1a8bSRiver Riddle   /// Mapping from rewriter function name to the bytecode address of the
305abfd1a8bSRiver Riddle   /// rewriter function in byte.
306abfd1a8bSRiver Riddle   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
307abfd1a8bSRiver Riddle 
308abfd1a8bSRiver Riddle   /// Mapping from a uniqued storage object to its memory index within
309abfd1a8bSRiver Riddle   /// `uniquedData`.
310abfd1a8bSRiver Riddle   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
311abfd1a8bSRiver Riddle 
3123eb1647aSStanislav Funiak   /// The current level of the foreach loop.
3133eb1647aSStanislav Funiak   ByteCodeField curLoopLevel = 0;
3143eb1647aSStanislav Funiak 
315abfd1a8bSRiver Riddle   /// The current MLIR context.
316abfd1a8bSRiver Riddle   MLIRContext *ctx;
317abfd1a8bSRiver Riddle 
3183eb1647aSStanislav Funiak   /// Mapping from block to its address.
3193eb1647aSStanislav Funiak   DenseMap<Block *, ByteCodeAddr> blockToAddr;
3203eb1647aSStanislav Funiak 
321abfd1a8bSRiver Riddle   /// Data of the ByteCode class to be populated.
322abfd1a8bSRiver Riddle   std::vector<const void *> &uniquedData;
323abfd1a8bSRiver Riddle   SmallVectorImpl<ByteCodeField> &matcherByteCode;
324abfd1a8bSRiver Riddle   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
325abfd1a8bSRiver Riddle   SmallVectorImpl<PDLByteCodePattern> &patterns;
326abfd1a8bSRiver Riddle   ByteCodeField &maxValueMemoryIndex;
3273eb1647aSStanislav Funiak   ByteCodeField &maxOpRangeMemoryIndex;
32885ab413bSRiver Riddle   ByteCodeField &maxTypeRangeMemoryIndex;
32985ab413bSRiver Riddle   ByteCodeField &maxValueRangeMemoryIndex;
3303eb1647aSStanislav Funiak   ByteCodeField &maxLoopLevel;
331abfd1a8bSRiver Riddle };
332abfd1a8bSRiver Riddle 
333abfd1a8bSRiver Riddle /// This class provides utilities for writing a bytecode stream.
334abfd1a8bSRiver Riddle struct ByteCodeWriter {
ByteCodeWriter__anonaa7cf1d90211::ByteCodeWriter335abfd1a8bSRiver Riddle   ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
336abfd1a8bSRiver Riddle       : bytecode(bytecode), generator(generator) {}
337abfd1a8bSRiver Riddle 
338abfd1a8bSRiver Riddle   /// Append a field to the bytecode.
append__anonaa7cf1d90211::ByteCodeWriter339abfd1a8bSRiver Riddle   void append(ByteCodeField field) { bytecode.push_back(field); }
append__anonaa7cf1d90211::ByteCodeWriter340fa20ab7bSRiver Riddle   void append(OpCode opCode) { bytecode.push_back(opCode); }
341abfd1a8bSRiver Riddle 
342abfd1a8bSRiver Riddle   /// Append an address to the bytecode.
append__anonaa7cf1d90211::ByteCodeWriter343abfd1a8bSRiver Riddle   void append(ByteCodeAddr field) {
344abfd1a8bSRiver Riddle     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
345abfd1a8bSRiver Riddle                   "unexpected ByteCode address size");
346abfd1a8bSRiver Riddle 
347abfd1a8bSRiver Riddle     ByteCodeField fieldParts[2];
348abfd1a8bSRiver Riddle     std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
349abfd1a8bSRiver Riddle     bytecode.append({fieldParts[0], fieldParts[1]});
350abfd1a8bSRiver Riddle   }
351abfd1a8bSRiver Riddle 
3523eb1647aSStanislav Funiak   /// Append a single successor to the bytecode, the exact address will need to
353abfd1a8bSRiver Riddle   /// be resolved later.
append__anonaa7cf1d90211::ByteCodeWriter3543eb1647aSStanislav Funiak   void append(Block *successor) {
3553eb1647aSStanislav Funiak     // Add back a reference to the successor so that the address can be resolved
3563eb1647aSStanislav Funiak     // later.
357abfd1a8bSRiver Riddle     unresolvedSuccessorRefs[successor].push_back(bytecode.size());
358abfd1a8bSRiver Riddle     append(ByteCodeAddr(0));
359abfd1a8bSRiver Riddle   }
3603eb1647aSStanislav Funiak 
3613eb1647aSStanislav Funiak   /// Append a successor range to the bytecode, the exact address will need to
3623eb1647aSStanislav Funiak   /// be resolved later.
append__anonaa7cf1d90211::ByteCodeWriter3633eb1647aSStanislav Funiak   void append(SuccessorRange successors) {
3643eb1647aSStanislav Funiak     for (Block *successor : successors)
3653eb1647aSStanislav Funiak       append(successor);
366abfd1a8bSRiver Riddle   }
367abfd1a8bSRiver Riddle 
368abfd1a8bSRiver Riddle   /// Append a range of values that will be read as generic PDLValues.
appendPDLValueList__anonaa7cf1d90211::ByteCodeWriter369abfd1a8bSRiver Riddle   void appendPDLValueList(OperandRange values) {
370abfd1a8bSRiver Riddle     bytecode.push_back(values.size());
37185ab413bSRiver Riddle     for (Value value : values)
37285ab413bSRiver Riddle       appendPDLValue(value);
37385ab413bSRiver Riddle   }
37485ab413bSRiver Riddle 
37585ab413bSRiver Riddle   /// Append a value as a PDLValue.
appendPDLValue__anonaa7cf1d90211::ByteCodeWriter37685ab413bSRiver Riddle   void appendPDLValue(Value value) {
37785ab413bSRiver Riddle     appendPDLValueKind(value);
378abfd1a8bSRiver Riddle     append(value);
379abfd1a8bSRiver Riddle   }
38085ab413bSRiver Riddle 
38185ab413bSRiver Riddle   /// Append the PDLValue::Kind of the given value.
appendPDLValueKind__anonaa7cf1d90211::ByteCodeWriter3823eb1647aSStanislav Funiak   void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
3833eb1647aSStanislav Funiak 
3843eb1647aSStanislav Funiak   /// Append the PDLValue::Kind of the given type.
appendPDLValueKind__anonaa7cf1d90211::ByteCodeWriter3853eb1647aSStanislav Funiak   void appendPDLValueKind(Type type) {
38685ab413bSRiver Riddle     PDLValue::Kind kind =
3873eb1647aSStanislav Funiak         TypeSwitch<Type, PDLValue::Kind>(type)
38885ab413bSRiver Riddle             .Case<pdl::AttributeType>(
38985ab413bSRiver Riddle                 [](Type) { return PDLValue::Kind::Attribute; })
39085ab413bSRiver Riddle             .Case<pdl::OperationType>(
39185ab413bSRiver Riddle                 [](Type) { return PDLValue::Kind::Operation; })
39285ab413bSRiver Riddle             .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
39385ab413bSRiver Riddle               if (rangeTy.getElementType().isa<pdl::TypeType>())
39485ab413bSRiver Riddle                 return PDLValue::Kind::TypeRange;
39585ab413bSRiver Riddle               return PDLValue::Kind::ValueRange;
39685ab413bSRiver Riddle             })
39785ab413bSRiver Riddle             .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
39885ab413bSRiver Riddle             .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
39985ab413bSRiver Riddle     bytecode.push_back(static_cast<ByteCodeField>(kind));
400abfd1a8bSRiver Riddle   }
401abfd1a8bSRiver Riddle 
402abfd1a8bSRiver Riddle   /// Append a value that will be stored in a memory slot and not inline within
403abfd1a8bSRiver Riddle   /// the bytecode.
404abfd1a8bSRiver Riddle   template <typename T>
405abfd1a8bSRiver Riddle   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
406abfd1a8bSRiver Riddle                    std::is_pointer<T>::value>
append__anonaa7cf1d90211::ByteCodeWriter407abfd1a8bSRiver Riddle   append(T value) {
408abfd1a8bSRiver Riddle     bytecode.push_back(generator.getMemIndex(value));
409abfd1a8bSRiver Riddle   }
410abfd1a8bSRiver Riddle 
411abfd1a8bSRiver Riddle   /// Append a range of values.
412abfd1a8bSRiver Riddle   template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
413abfd1a8bSRiver Riddle   std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
append__anonaa7cf1d90211::ByteCodeWriter414abfd1a8bSRiver Riddle   append(T range) {
415abfd1a8bSRiver Riddle     bytecode.push_back(llvm::size(range));
416abfd1a8bSRiver Riddle     for (auto it : range)
417abfd1a8bSRiver Riddle       append(it);
418abfd1a8bSRiver Riddle   }
419abfd1a8bSRiver Riddle 
420abfd1a8bSRiver Riddle   /// Append a variadic number of fields to the bytecode.
421abfd1a8bSRiver Riddle   template <typename FieldTy, typename Field2Ty, typename... FieldTys>
append__anonaa7cf1d90211::ByteCodeWriter422abfd1a8bSRiver Riddle   void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
423abfd1a8bSRiver Riddle     append(field);
424abfd1a8bSRiver Riddle     append(field2, fields...);
425abfd1a8bSRiver Riddle   }
426abfd1a8bSRiver Riddle 
427d35f1190SStanislav Funiak   /// Appends a value as a pointer, stored inline within the bytecode.
428d35f1190SStanislav Funiak   template <typename T>
429d35f1190SStanislav Funiak   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
appendInline__anonaa7cf1d90211::ByteCodeWriter430d35f1190SStanislav Funiak   appendInline(T value) {
431d35f1190SStanislav Funiak     constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
432d35f1190SStanislav Funiak     const void *pointer = value.getAsOpaquePointer();
433d35f1190SStanislav Funiak     ByteCodeField fieldParts[numParts];
434d35f1190SStanislav Funiak     std::memcpy(fieldParts, &pointer, sizeof(const void *));
435d35f1190SStanislav Funiak     bytecode.append(fieldParts, fieldParts + numParts);
436d35f1190SStanislav Funiak   }
437d35f1190SStanislav Funiak 
438abfd1a8bSRiver Riddle   /// Successor references in the bytecode that have yet to be resolved.
439abfd1a8bSRiver Riddle   DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
440abfd1a8bSRiver Riddle 
441abfd1a8bSRiver Riddle   /// The underlying bytecode buffer.
442abfd1a8bSRiver Riddle   SmallVectorImpl<ByteCodeField> &bytecode;
443abfd1a8bSRiver Riddle 
444abfd1a8bSRiver Riddle   /// The main generator producing PDL.
445abfd1a8bSRiver Riddle   Generator &generator;
446abfd1a8bSRiver Riddle };
44785ab413bSRiver Riddle 
44885ab413bSRiver Riddle /// This class represents a live range of PDL Interpreter values, containing
44985ab413bSRiver Riddle /// information about when values are live within a match/rewrite.
45085ab413bSRiver Riddle struct ByteCodeLiveRange {
4513eb1647aSStanislav Funiak   using Set = llvm::IntervalMap<uint64_t, char, 16>;
45285ab413bSRiver Riddle   using Allocator = Set::Allocator;
45385ab413bSRiver Riddle 
ByteCodeLiveRange__anonaa7cf1d90211::ByteCodeLiveRange4543eb1647aSStanislav Funiak   ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
45585ab413bSRiver Riddle 
45685ab413bSRiver Riddle   /// Union this live range with the one provided.
unionWith__anonaa7cf1d90211::ByteCodeLiveRange45785ab413bSRiver Riddle   void unionWith(const ByteCodeLiveRange &rhs) {
4583eb1647aSStanislav Funiak     for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
4593eb1647aSStanislav Funiak          ++it)
4603eb1647aSStanislav Funiak       liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
46185ab413bSRiver Riddle   }
46285ab413bSRiver Riddle 
46385ab413bSRiver Riddle   /// Returns true if this range overlaps with the one provided.
overlaps__anonaa7cf1d90211::ByteCodeLiveRange46485ab413bSRiver Riddle   bool overlaps(const ByteCodeLiveRange &rhs) const {
4653eb1647aSStanislav Funiak     return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
4663eb1647aSStanislav Funiak         .valid();
46785ab413bSRiver Riddle   }
46885ab413bSRiver Riddle 
46985ab413bSRiver Riddle   /// A map representing the ranges of the match/rewrite that a value is live in
47085ab413bSRiver Riddle   /// the interpreter.
4713eb1647aSStanislav Funiak   ///
4723eb1647aSStanislav Funiak   /// We use std::unique_ptr here, because IntervalMap does not provide a
4733eb1647aSStanislav Funiak   /// correct copy or move constructor. We can eliminate the pointer once
4743eb1647aSStanislav Funiak   /// https://reviews.llvm.org/D113240 lands.
4753eb1647aSStanislav Funiak   std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
4763eb1647aSStanislav Funiak 
4773eb1647aSStanislav Funiak   /// The operation range storage index for this range.
4783eb1647aSStanislav Funiak   Optional<unsigned> opRangeIndex;
47985ab413bSRiver Riddle 
48085ab413bSRiver Riddle   /// The type range storage index for this range.
48185ab413bSRiver Riddle   Optional<unsigned> typeRangeIndex;
48285ab413bSRiver Riddle 
48385ab413bSRiver Riddle   /// The value range storage index for this range.
48485ab413bSRiver Riddle   Optional<unsigned> valueRangeIndex;
48585ab413bSRiver Riddle };
486be0a7e9fSMehdi Amini } // namespace
487abfd1a8bSRiver Riddle 
generate(ModuleOp module)488abfd1a8bSRiver Riddle void Generator::generate(ModuleOp module) {
489f96a8675SRiver Riddle   auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
490abfd1a8bSRiver Riddle       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
491abfd1a8bSRiver Riddle   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
492abfd1a8bSRiver Riddle       pdl_interp::PDLInterpDialect::getRewriterModuleName());
493abfd1a8bSRiver Riddle   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
494abfd1a8bSRiver Riddle 
495abfd1a8bSRiver Riddle   // Allocate memory indices for the results of operations within the matcher
496abfd1a8bSRiver Riddle   // and rewriters.
497abfd1a8bSRiver Riddle   allocateMemoryIndices(matcherFunc, rewriterModule);
498abfd1a8bSRiver Riddle 
499abfd1a8bSRiver Riddle   // Generate code for the rewriter functions.
500abfd1a8bSRiver Riddle   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
501f96a8675SRiver Riddle   for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
502abfd1a8bSRiver Riddle     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
503abfd1a8bSRiver Riddle     for (Operation &op : rewriterFunc.getOps())
504abfd1a8bSRiver Riddle       generate(&op, rewriterByteCodeWriter);
505abfd1a8bSRiver Riddle   }
506abfd1a8bSRiver Riddle   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
507abfd1a8bSRiver Riddle          "unexpected branches in rewriter function");
508abfd1a8bSRiver Riddle 
509abfd1a8bSRiver Riddle   // Generate code for the matcher function.
510abfd1a8bSRiver Riddle   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
5113eb1647aSStanislav Funiak   generate(&matcherFunc.getBody(), matcherByteCodeWriter);
512abfd1a8bSRiver Riddle 
513abfd1a8bSRiver Riddle   // Resolve successor references in the matcher.
514abfd1a8bSRiver Riddle   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
515abfd1a8bSRiver Riddle     ByteCodeAddr addr = blockToAddr[it.first];
516abfd1a8bSRiver Riddle     for (unsigned offsetToFix : it.second)
517abfd1a8bSRiver Riddle       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
518abfd1a8bSRiver Riddle   }
519abfd1a8bSRiver Riddle }
520abfd1a8bSRiver Riddle 
allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,ModuleOp rewriterModule)521f96a8675SRiver Riddle void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
522abfd1a8bSRiver Riddle                                       ModuleOp rewriterModule) {
523abfd1a8bSRiver Riddle   // Rewriters use simplistic allocation scheme that simply assigns an index to
524abfd1a8bSRiver Riddle   // each result.
525f96a8675SRiver Riddle   for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
52685ab413bSRiver Riddle     ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
52785ab413bSRiver Riddle     auto processRewriterValue = [&](Value val) {
52885ab413bSRiver Riddle       valueToMemIndex.try_emplace(val, index++);
52985ab413bSRiver Riddle       if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
53085ab413bSRiver Riddle         Type elementTy = rangeType.getElementType();
53185ab413bSRiver Riddle         if (elementTy.isa<pdl::TypeType>())
53285ab413bSRiver Riddle           valueToRangeIndex.try_emplace(val, typeRangeIndex++);
53385ab413bSRiver Riddle         else if (elementTy.isa<pdl::ValueType>())
53485ab413bSRiver Riddle           valueToRangeIndex.try_emplace(val, valueRangeIndex++);
53585ab413bSRiver Riddle       }
53685ab413bSRiver Riddle     };
53785ab413bSRiver Riddle 
538abfd1a8bSRiver Riddle     for (BlockArgument arg : rewriterFunc.getArguments())
53985ab413bSRiver Riddle       processRewriterValue(arg);
540abfd1a8bSRiver Riddle     rewriterFunc.getBody().walk([&](Operation *op) {
541abfd1a8bSRiver Riddle       for (Value result : op->getResults())
54285ab413bSRiver Riddle         processRewriterValue(result);
543abfd1a8bSRiver Riddle     });
544abfd1a8bSRiver Riddle     if (index > maxValueMemoryIndex)
545abfd1a8bSRiver Riddle       maxValueMemoryIndex = index;
54685ab413bSRiver Riddle     if (typeRangeIndex > maxTypeRangeMemoryIndex)
54785ab413bSRiver Riddle       maxTypeRangeMemoryIndex = typeRangeIndex;
54885ab413bSRiver Riddle     if (valueRangeIndex > maxValueRangeMemoryIndex)
54985ab413bSRiver Riddle       maxValueRangeMemoryIndex = valueRangeIndex;
550abfd1a8bSRiver Riddle   }
551abfd1a8bSRiver Riddle 
552abfd1a8bSRiver Riddle   // The matcher function uses a more sophisticated numbering that tries to
553abfd1a8bSRiver Riddle   // minimize the number of memory indices assigned. This is done by determining
554abfd1a8bSRiver Riddle   // a live range of the values within the matcher, then the allocation is just
555abfd1a8bSRiver Riddle   // finding the minimal number of overlapping live ranges. This is essentially
556abfd1a8bSRiver Riddle   // a simplified form of register allocation where we don't necessarily have a
557abfd1a8bSRiver Riddle   // limited number of registers, but we still want to minimize the number used.
558b4130e9eSStanislav Funiak   DenseMap<Operation *, unsigned> opToFirstIndex;
559b4130e9eSStanislav Funiak   DenseMap<Operation *, unsigned> opToLastIndex;
560b4130e9eSStanislav Funiak 
561b4130e9eSStanislav Funiak   // A custom walk that marks the first and the last index of each operation.
562b4130e9eSStanislav Funiak   // The entry marks the beginning of the liveness range for this operation,
563b4130e9eSStanislav Funiak   // followed by nested operations, followed by the end of the liveness range.
564b4130e9eSStanislav Funiak   unsigned index = 0;
565b4130e9eSStanislav Funiak   llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
566b4130e9eSStanislav Funiak     opToFirstIndex.try_emplace(op, index++);
567b4130e9eSStanislav Funiak     for (Region &region : op->getRegions())
568b4130e9eSStanislav Funiak       for (Block &block : region.getBlocks())
569b4130e9eSStanislav Funiak         for (Operation &nested : block)
570b4130e9eSStanislav Funiak           walk(&nested);
571b4130e9eSStanislav Funiak     opToLastIndex.try_emplace(op, index++);
572b4130e9eSStanislav Funiak   };
573b4130e9eSStanislav Funiak   walk(matcherFunc);
574abfd1a8bSRiver Riddle 
575abfd1a8bSRiver Riddle   // Liveness info for each of the defs within the matcher.
57685ab413bSRiver Riddle   ByteCodeLiveRange::Allocator allocator;
57785ab413bSRiver Riddle   DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
578abfd1a8bSRiver Riddle 
579abfd1a8bSRiver Riddle   // Assign the root operation being matched to slot 0.
580abfd1a8bSRiver Riddle   BlockArgument rootOpArg = matcherFunc.getArgument(0);
581abfd1a8bSRiver Riddle   valueToMemIndex[rootOpArg] = 0;
582abfd1a8bSRiver Riddle 
583abfd1a8bSRiver Riddle   // Walk each of the blocks, computing the def interval that the value is used.
584abfd1a8bSRiver Riddle   Liveness matcherLiveness(matcherFunc);
5853eb1647aSStanislav Funiak   matcherFunc->walk([&](Block *block) {
5863eb1647aSStanislav Funiak     const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
587abfd1a8bSRiver Riddle     assert(info && "expected liveness info for block");
588abfd1a8bSRiver Riddle     auto processValue = [&](Value value, Operation *firstUseOrDef) {
589abfd1a8bSRiver Riddle       // We don't need to process the root op argument, this value is always
590abfd1a8bSRiver Riddle       // assigned to the first memory slot.
591abfd1a8bSRiver Riddle       if (value == rootOpArg)
592abfd1a8bSRiver Riddle         return;
593abfd1a8bSRiver Riddle 
594abfd1a8bSRiver Riddle       // Set indices for the range of this block that the value is used.
595abfd1a8bSRiver Riddle       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
5963eb1647aSStanislav Funiak       defRangeIt->second.liveness->insert(
597b4130e9eSStanislav Funiak           opToFirstIndex[firstUseOrDef],
598b4130e9eSStanislav Funiak           opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
599abfd1a8bSRiver Riddle           /*dummyValue*/ 0);
60085ab413bSRiver Riddle 
60185ab413bSRiver Riddle       // Check to see if this value is a range type.
60285ab413bSRiver Riddle       if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
60385ab413bSRiver Riddle         Type eleType = rangeTy.getElementType();
6043eb1647aSStanislav Funiak         if (eleType.isa<pdl::OperationType>())
6053eb1647aSStanislav Funiak           defRangeIt->second.opRangeIndex = 0;
6063eb1647aSStanislav Funiak         else if (eleType.isa<pdl::TypeType>())
60785ab413bSRiver Riddle           defRangeIt->second.typeRangeIndex = 0;
60885ab413bSRiver Riddle         else if (eleType.isa<pdl::ValueType>())
60985ab413bSRiver Riddle           defRangeIt->second.valueRangeIndex = 0;
61085ab413bSRiver Riddle       }
611abfd1a8bSRiver Riddle     };
612abfd1a8bSRiver Riddle 
613abfd1a8bSRiver Riddle     // Process the live-ins of this block.
6143eb1647aSStanislav Funiak     for (Value liveIn : info->in()) {
6153eb1647aSStanislav Funiak       // Only process the value if it has been defined in the current region.
6163eb1647aSStanislav Funiak       // Other values that span across pdl_interp.foreach will be added higher
6173eb1647aSStanislav Funiak       // up. This ensures that the we keep them alive for the entire duration
6183eb1647aSStanislav Funiak       // of the loop.
6193eb1647aSStanislav Funiak       if (liveIn.getParentRegion() == block->getParent())
6203eb1647aSStanislav Funiak         processValue(liveIn, &block->front());
6213eb1647aSStanislav Funiak     }
6223eb1647aSStanislav Funiak 
6233eb1647aSStanislav Funiak     // Process the block arguments for the entry block (those are not live-in).
6243eb1647aSStanislav Funiak     if (block->isEntryBlock()) {
6253eb1647aSStanislav Funiak       for (Value argument : block->getArguments())
6263eb1647aSStanislav Funiak         processValue(argument, &block->front());
6273eb1647aSStanislav Funiak     }
628abfd1a8bSRiver Riddle 
629abfd1a8bSRiver Riddle     // Process any new defs within this block.
6303eb1647aSStanislav Funiak     for (Operation &op : *block)
631abfd1a8bSRiver Riddle       for (Value result : op.getResults())
632abfd1a8bSRiver Riddle         processValue(result, &op);
6333eb1647aSStanislav Funiak   });
634abfd1a8bSRiver Riddle 
635abfd1a8bSRiver Riddle   // Greedily allocate memory slots using the computed def live ranges.
63685ab413bSRiver Riddle   std::vector<ByteCodeLiveRange> allocatedIndices;
6373eb1647aSStanislav Funiak 
6383eb1647aSStanislav Funiak   // The number of memory indices currently allocated (and its next value).
6393eb1647aSStanislav Funiak   // Recall that the root gets allocated memory index 0.
6403eb1647aSStanislav Funiak   ByteCodeField numIndices = 1;
6413eb1647aSStanislav Funiak 
6423eb1647aSStanislav Funiak   // The number of memory ranges of various types (and their next values).
6433eb1647aSStanislav Funiak   ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
6443eb1647aSStanislav Funiak 
645abfd1a8bSRiver Riddle   for (auto &defIt : valueDefRanges) {
646abfd1a8bSRiver Riddle     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
64785ab413bSRiver Riddle     ByteCodeLiveRange &defRange = defIt.second;
648abfd1a8bSRiver Riddle 
649abfd1a8bSRiver Riddle     // Try to allocate to an existing index.
650e4853be2SMehdi Amini     for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
65185ab413bSRiver Riddle       ByteCodeLiveRange &existingRange = existingIndexIt.value();
65285ab413bSRiver Riddle       if (!defRange.overlaps(existingRange)) {
65385ab413bSRiver Riddle         existingRange.unionWith(defRange);
654abfd1a8bSRiver Riddle         memIndex = existingIndexIt.index() + 1;
65585ab413bSRiver Riddle 
6563eb1647aSStanislav Funiak         if (defRange.opRangeIndex) {
6573eb1647aSStanislav Funiak           if (!existingRange.opRangeIndex)
6583eb1647aSStanislav Funiak             existingRange.opRangeIndex = numOpRanges++;
6593eb1647aSStanislav Funiak           valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
6603eb1647aSStanislav Funiak         } else if (defRange.typeRangeIndex) {
66185ab413bSRiver Riddle           if (!existingRange.typeRangeIndex)
66285ab413bSRiver Riddle             existingRange.typeRangeIndex = numTypeRanges++;
66385ab413bSRiver Riddle           valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
66485ab413bSRiver Riddle         } else if (defRange.valueRangeIndex) {
66585ab413bSRiver Riddle           if (!existingRange.valueRangeIndex)
66685ab413bSRiver Riddle             existingRange.valueRangeIndex = numValueRanges++;
66785ab413bSRiver Riddle           valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
66885ab413bSRiver Riddle         }
66985ab413bSRiver Riddle         break;
67085ab413bSRiver Riddle       }
671abfd1a8bSRiver Riddle     }
672abfd1a8bSRiver Riddle 
673abfd1a8bSRiver Riddle     // If no existing index could be used, add a new one.
674abfd1a8bSRiver Riddle     if (memIndex == 0) {
675abfd1a8bSRiver Riddle       allocatedIndices.emplace_back(allocator);
67685ab413bSRiver Riddle       ByteCodeLiveRange &newRange = allocatedIndices.back();
67785ab413bSRiver Riddle       newRange.unionWith(defRange);
67885ab413bSRiver Riddle 
6793eb1647aSStanislav Funiak       // Allocate an index for op/type/value ranges.
6803eb1647aSStanislav Funiak       if (defRange.opRangeIndex) {
6813eb1647aSStanislav Funiak         newRange.opRangeIndex = numOpRanges;
6823eb1647aSStanislav Funiak         valueToRangeIndex[defIt.first] = numOpRanges++;
6833eb1647aSStanislav Funiak       } else if (defRange.typeRangeIndex) {
68485ab413bSRiver Riddle         newRange.typeRangeIndex = numTypeRanges;
68585ab413bSRiver Riddle         valueToRangeIndex[defIt.first] = numTypeRanges++;
68685ab413bSRiver Riddle       } else if (defRange.valueRangeIndex) {
68785ab413bSRiver Riddle         newRange.valueRangeIndex = numValueRanges;
68885ab413bSRiver Riddle         valueToRangeIndex[defIt.first] = numValueRanges++;
68985ab413bSRiver Riddle       }
69085ab413bSRiver Riddle 
691abfd1a8bSRiver Riddle       memIndex = allocatedIndices.size();
69285ab413bSRiver Riddle       ++numIndices;
693abfd1a8bSRiver Riddle     }
694abfd1a8bSRiver Riddle   }
695abfd1a8bSRiver Riddle 
6963eb1647aSStanislav Funiak   // Print the index usage and ensure that we did not run out of index space.
6973eb1647aSStanislav Funiak   LLVM_DEBUG({
6983eb1647aSStanislav Funiak     llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
6993eb1647aSStanislav Funiak                  << "(down from initial " << valueDefRanges.size() << ").\n";
7003eb1647aSStanislav Funiak   });
7013eb1647aSStanislav Funiak   assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
7023eb1647aSStanislav Funiak          "Ran out of memory for allocated indices");
7033eb1647aSStanislav Funiak 
704abfd1a8bSRiver Riddle   // Update the max number of indices.
70585ab413bSRiver Riddle   if (numIndices > maxValueMemoryIndex)
70685ab413bSRiver Riddle     maxValueMemoryIndex = numIndices;
7073eb1647aSStanislav Funiak   if (numOpRanges > maxOpRangeMemoryIndex)
7083eb1647aSStanislav Funiak     maxOpRangeMemoryIndex = numOpRanges;
70985ab413bSRiver Riddle   if (numTypeRanges > maxTypeRangeMemoryIndex)
71085ab413bSRiver Riddle     maxTypeRangeMemoryIndex = numTypeRanges;
71185ab413bSRiver Riddle   if (numValueRanges > maxValueRangeMemoryIndex)
71285ab413bSRiver Riddle     maxValueRangeMemoryIndex = numValueRanges;
713abfd1a8bSRiver Riddle }
714abfd1a8bSRiver Riddle 
generate(Region * region,ByteCodeWriter & writer)7153eb1647aSStanislav Funiak void Generator::generate(Region *region, ByteCodeWriter &writer) {
7163eb1647aSStanislav Funiak   llvm::ReversePostOrderTraversal<Region *> rpot(region);
7173eb1647aSStanislav Funiak   for (Block *block : rpot) {
7183eb1647aSStanislav Funiak     // Keep track of where this block begins within the matcher function.
7193eb1647aSStanislav Funiak     blockToAddr.try_emplace(block, matcherByteCode.size());
7203eb1647aSStanislav Funiak     for (Operation &op : *block)
7213eb1647aSStanislav Funiak       generate(&op, writer);
7223eb1647aSStanislav Funiak   }
7233eb1647aSStanislav Funiak }
7243eb1647aSStanislav Funiak 
generate(Operation * op,ByteCodeWriter & writer)725abfd1a8bSRiver Riddle void Generator::generate(Operation *op, ByteCodeWriter &writer) {
726d35f1190SStanislav Funiak   LLVM_DEBUG({
727d35f1190SStanislav Funiak     // The following list must contain all the operations that do not
728d35f1190SStanislav Funiak     // produce any bytecode.
7293c752289SRiver Riddle     if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
730d35f1190SStanislav Funiak       writer.appendInline(op->getLoc());
731d35f1190SStanislav Funiak   });
732abfd1a8bSRiver Riddle   TypeSwitch<Operation *>(op)
733abfd1a8bSRiver Riddle       .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
734abfd1a8bSRiver Riddle             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
735abfd1a8bSRiver Riddle             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
736abfd1a8bSRiver Riddle             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
73785ab413bSRiver Riddle             pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
7383eb1647aSStanislav Funiak             pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
7393eb1647aSStanislav Funiak             pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
7403eb1647aSStanislav Funiak             pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
7413eb1647aSStanislav Funiak             pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
7423eb1647aSStanislav Funiak             pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
7433eb1647aSStanislav Funiak             pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
7443eb1647aSStanislav Funiak             pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
7453eb1647aSStanislav Funiak             pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
7463eb1647aSStanislav Funiak             pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
7473c752289SRiver Riddle             pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
7483c752289SRiver Riddle             pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
7493c752289SRiver Riddle             pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
7503c752289SRiver Riddle             pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
7513c752289SRiver Riddle             pdl_interp::SwitchResultCountOp>(
752abfd1a8bSRiver Riddle           [&](auto interpOp) { this->generate(interpOp, writer); })
753abfd1a8bSRiver Riddle       .Default([](Operation *) {
754abfd1a8bSRiver Riddle         llvm_unreachable("unknown `pdl_interp` operation");
755abfd1a8bSRiver Riddle       });
756abfd1a8bSRiver Riddle }
757abfd1a8bSRiver Riddle 
generate(pdl_interp::ApplyConstraintOp op,ByteCodeWriter & writer)758abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ApplyConstraintOp op,
759abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
7603c405c3bSRiver Riddle   assert(constraintToMemIndex.count(op.getName()) &&
761abfd1a8bSRiver Riddle          "expected index for constraint function");
7629595f356SRiver Riddle   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
7633c405c3bSRiver Riddle   writer.appendPDLValueList(op.getArgs());
764abfd1a8bSRiver Riddle   writer.append(op.getSuccessors());
765abfd1a8bSRiver Riddle }
generate(pdl_interp::ApplyRewriteOp op,ByteCodeWriter & writer)766abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ApplyRewriteOp op,
767abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
7683c405c3bSRiver Riddle   assert(externalRewriterToMemIndex.count(op.getName()) &&
769abfd1a8bSRiver Riddle          "expected index for rewrite function");
7709595f356SRiver Riddle   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
7713c405c3bSRiver Riddle   writer.appendPDLValueList(op.getArgs());
77202c4c0d5SRiver Riddle 
7733c405c3bSRiver Riddle   ResultRange results = op.getResults();
77485ab413bSRiver Riddle   writer.append(ByteCodeField(results.size()));
77585ab413bSRiver Riddle   for (Value result : results) {
77685ab413bSRiver Riddle     // In debug mode we also record the expected kind of the result, so that we
77785ab413bSRiver Riddle     // can provide extra verification of the native rewrite function.
77802c4c0d5SRiver Riddle #ifndef NDEBUG
77985ab413bSRiver Riddle     writer.appendPDLValueKind(result);
78002c4c0d5SRiver Riddle #endif
78185ab413bSRiver Riddle 
78285ab413bSRiver Riddle     // Range results also need to append the range storage index.
78385ab413bSRiver Riddle     if (result.getType().isa<pdl::RangeType>())
78485ab413bSRiver Riddle       writer.append(getRangeStorageIndex(result));
78502c4c0d5SRiver Riddle     writer.append(result);
786abfd1a8bSRiver Riddle   }
78785ab413bSRiver Riddle }
generate(pdl_interp::AreEqualOp op,ByteCodeWriter & writer)788abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
7893c405c3bSRiver Riddle   Value lhs = op.getLhs();
79085ab413bSRiver Riddle   if (lhs.getType().isa<pdl::RangeType>()) {
79185ab413bSRiver Riddle     writer.append(OpCode::AreRangesEqual);
79285ab413bSRiver Riddle     writer.appendPDLValueKind(lhs);
7933c405c3bSRiver Riddle     writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
79485ab413bSRiver Riddle     return;
79585ab413bSRiver Riddle   }
79685ab413bSRiver Riddle 
7973c405c3bSRiver Riddle   writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
798abfd1a8bSRiver Riddle }
generate(pdl_interp::BranchOp op,ByteCodeWriter & writer)799abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
8008affe881SRiver Riddle   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
801abfd1a8bSRiver Riddle }
generate(pdl_interp::CheckAttributeOp op,ByteCodeWriter & writer)802abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckAttributeOp op,
803abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
8043c405c3bSRiver Riddle   writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
805abfd1a8bSRiver Riddle                 op.getSuccessors());
806abfd1a8bSRiver Riddle }
generate(pdl_interp::CheckOperandCountOp op,ByteCodeWriter & writer)807abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckOperandCountOp op,
808abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
8093c405c3bSRiver Riddle   writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
8103c405c3bSRiver Riddle                 static_cast<ByteCodeField>(op.getCompareAtLeast()),
811abfd1a8bSRiver Riddle                 op.getSuccessors());
812abfd1a8bSRiver Riddle }
generate(pdl_interp::CheckOperationNameOp op,ByteCodeWriter & writer)813abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckOperationNameOp op,
814abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
8153c405c3bSRiver Riddle   writer.append(OpCode::CheckOperationName, op.getInputOp(),
8163c405c3bSRiver Riddle                 OperationName(op.getName(), ctx), op.getSuccessors());
817abfd1a8bSRiver Riddle }
generate(pdl_interp::CheckResultCountOp op,ByteCodeWriter & writer)818abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckResultCountOp op,
819abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
8203c405c3bSRiver Riddle   writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
8213c405c3bSRiver Riddle                 static_cast<ByteCodeField>(op.getCompareAtLeast()),
822abfd1a8bSRiver Riddle                 op.getSuccessors());
823abfd1a8bSRiver Riddle }
generate(pdl_interp::CheckTypeOp op,ByteCodeWriter & writer)824abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
8253c405c3bSRiver Riddle   writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
8263c405c3bSRiver Riddle                 op.getSuccessors());
827abfd1a8bSRiver Riddle }
generate(pdl_interp::CheckTypesOp op,ByteCodeWriter & writer)82885ab413bSRiver Riddle void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
8293c405c3bSRiver Riddle   writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
8303c405c3bSRiver Riddle                 op.getSuccessors());
83185ab413bSRiver Riddle }
generate(pdl_interp::ContinueOp op,ByteCodeWriter & writer)8323eb1647aSStanislav Funiak void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
8333eb1647aSStanislav Funiak   assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
8343eb1647aSStanislav Funiak   writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
8353eb1647aSStanislav Funiak }
generate(pdl_interp::CreateAttributeOp op,ByteCodeWriter & writer)836abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateAttributeOp op,
837abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
838abfd1a8bSRiver Riddle   // Simply repoint the memory index of the result to the constant.
8393c405c3bSRiver Riddle   getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
840abfd1a8bSRiver Riddle }
generate(pdl_interp::CreateOperationOp op,ByteCodeWriter & writer)841abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateOperationOp op,
842abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
8433c405c3bSRiver Riddle   writer.append(OpCode::CreateOperation, op.getResultOp(),
8443c405c3bSRiver Riddle                 OperationName(op.getName(), ctx));
8453c405c3bSRiver Riddle   writer.appendPDLValueList(op.getInputOperands());
846abfd1a8bSRiver Riddle 
847abfd1a8bSRiver Riddle   // Add the attributes.
8483c405c3bSRiver Riddle   OperandRange attributes = op.getInputAttributes();
849abfd1a8bSRiver Riddle   writer.append(static_cast<ByteCodeField>(attributes.size()));
8503c405c3bSRiver Riddle   for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
851195730a6SRiver Riddle     writer.append(std::get<0>(it), std::get<1>(it));
8523c752289SRiver Riddle 
8533c752289SRiver Riddle   // Add the result types. If the operation has inferred results, we use a
8543c752289SRiver Riddle   // marker "size" value. Otherwise, we add the list of explicit result types.
8553c752289SRiver Riddle   if (op.getInferredResultTypes())
8563c752289SRiver Riddle     writer.append(kInferTypesMarker);
8573c752289SRiver Riddle   else
8583c405c3bSRiver Riddle     writer.appendPDLValueList(op.getInputResultTypes());
859abfd1a8bSRiver Riddle }
generate(pdl_interp::CreateTypeOp op,ByteCodeWriter & writer)860abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
861abfd1a8bSRiver Riddle   // Simply repoint the memory index of the result to the constant.
8623c405c3bSRiver Riddle   getMemIndex(op.getResult()) = getMemIndex(op.getValue());
863abfd1a8bSRiver Riddle }
generate(pdl_interp::CreateTypesOp op,ByteCodeWriter & writer)86485ab413bSRiver Riddle void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
8653c405c3bSRiver Riddle   writer.append(OpCode::CreateTypes, op.getResult(),
8663c405c3bSRiver Riddle                 getRangeStorageIndex(op.getResult()), op.getValue());
86785ab413bSRiver Riddle }
generate(pdl_interp::EraseOp op,ByteCodeWriter & writer)868abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
8693c405c3bSRiver Riddle   writer.append(OpCode::EraseOp, op.getInputOp());
870abfd1a8bSRiver Riddle }
generate(pdl_interp::ExtractOp op,ByteCodeWriter & writer)8713eb1647aSStanislav Funiak void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
8723eb1647aSStanislav Funiak   OpCode opCode =
8733c405c3bSRiver Riddle       TypeSwitch<Type, OpCode>(op.getResult().getType())
8743eb1647aSStanislav Funiak           .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
8753eb1647aSStanislav Funiak           .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
8763eb1647aSStanislav Funiak           .Case([](pdl::TypeType) { return OpCode::ExtractType; })
8773eb1647aSStanislav Funiak           .Default([](Type) -> OpCode {
8783eb1647aSStanislav Funiak             llvm_unreachable("unsupported element type");
8793eb1647aSStanislav Funiak           });
8803c405c3bSRiver Riddle   writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
8813eb1647aSStanislav Funiak }
generate(pdl_interp::FinalizeOp op,ByteCodeWriter & writer)882abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
883abfd1a8bSRiver Riddle   writer.append(OpCode::Finalize);
884abfd1a8bSRiver Riddle }
generate(pdl_interp::ForEachOp op,ByteCodeWriter & writer)8853eb1647aSStanislav Funiak void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
8863eb1647aSStanislav Funiak   BlockArgument arg = op.getLoopVariable();
8873c405c3bSRiver Riddle   writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
8883eb1647aSStanislav Funiak   writer.appendPDLValueKind(arg.getType());
8893c405c3bSRiver Riddle   writer.append(curLoopLevel, op.getSuccessor());
8903eb1647aSStanislav Funiak   ++curLoopLevel;
8913eb1647aSStanislav Funiak   if (curLoopLevel > maxLoopLevel)
8923eb1647aSStanislav Funiak     maxLoopLevel = curLoopLevel;
8933c405c3bSRiver Riddle   generate(&op.getRegion(), writer);
8943eb1647aSStanislav Funiak   --curLoopLevel;
8953eb1647aSStanislav Funiak }
generate(pdl_interp::GetAttributeOp op,ByteCodeWriter & writer)896abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetAttributeOp op,
897abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
8983c405c3bSRiver Riddle   writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
8993c405c3bSRiver Riddle                 op.getNameAttr());
900abfd1a8bSRiver Riddle }
generate(pdl_interp::GetAttributeTypeOp op,ByteCodeWriter & writer)901abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetAttributeTypeOp op,
902abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
9033c405c3bSRiver Riddle   writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
904abfd1a8bSRiver Riddle }
generate(pdl_interp::GetDefiningOpOp op,ByteCodeWriter & writer)905abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetDefiningOpOp op,
906abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
9073c405c3bSRiver Riddle   writer.append(OpCode::GetDefiningOp, op.getInputOp());
9083c405c3bSRiver Riddle   writer.appendPDLValue(op.getValue());
909abfd1a8bSRiver Riddle }
generate(pdl_interp::GetOperandOp op,ByteCodeWriter & writer)910abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
9113c405c3bSRiver Riddle   uint32_t index = op.getIndex();
912abfd1a8bSRiver Riddle   if (index < 4)
913abfd1a8bSRiver Riddle     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
914abfd1a8bSRiver Riddle   else
915abfd1a8bSRiver Riddle     writer.append(OpCode::GetOperandN, index);
9163c405c3bSRiver Riddle   writer.append(op.getInputOp(), op.getValue());
917abfd1a8bSRiver Riddle }
generate(pdl_interp::GetOperandsOp op,ByteCodeWriter & writer)91885ab413bSRiver Riddle void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
9193c405c3bSRiver Riddle   Value result = op.getValue();
9203c405c3bSRiver Riddle   Optional<uint32_t> index = op.getIndex();
92185ab413bSRiver Riddle   writer.append(OpCode::GetOperands,
922*30c67587SKazu Hirata                 index.value_or(std::numeric_limits<uint32_t>::max()),
9233c405c3bSRiver Riddle                 op.getInputOp());
92485ab413bSRiver Riddle   if (result.getType().isa<pdl::RangeType>())
92585ab413bSRiver Riddle     writer.append(getRangeStorageIndex(result));
92685ab413bSRiver Riddle   else
92785ab413bSRiver Riddle     writer.append(std::numeric_limits<ByteCodeField>::max());
92885ab413bSRiver Riddle   writer.append(result);
92985ab413bSRiver Riddle }
generate(pdl_interp::GetResultOp op,ByteCodeWriter & writer)930abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
9313c405c3bSRiver Riddle   uint32_t index = op.getIndex();
932abfd1a8bSRiver Riddle   if (index < 4)
933abfd1a8bSRiver Riddle     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
934abfd1a8bSRiver Riddle   else
935abfd1a8bSRiver Riddle     writer.append(OpCode::GetResultN, index);
9363c405c3bSRiver Riddle   writer.append(op.getInputOp(), op.getValue());
937abfd1a8bSRiver Riddle }
generate(pdl_interp::GetResultsOp op,ByteCodeWriter & writer)93885ab413bSRiver Riddle void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
9393c405c3bSRiver Riddle   Value result = op.getValue();
9403c405c3bSRiver Riddle   Optional<uint32_t> index = op.getIndex();
94185ab413bSRiver Riddle   writer.append(OpCode::GetResults,
942*30c67587SKazu Hirata                 index.value_or(std::numeric_limits<uint32_t>::max()),
9433c405c3bSRiver Riddle                 op.getInputOp());
94485ab413bSRiver Riddle   if (result.getType().isa<pdl::RangeType>())
94585ab413bSRiver Riddle     writer.append(getRangeStorageIndex(result));
94685ab413bSRiver Riddle   else
94785ab413bSRiver Riddle     writer.append(std::numeric_limits<ByteCodeField>::max());
94885ab413bSRiver Riddle   writer.append(result);
94985ab413bSRiver Riddle }
generate(pdl_interp::GetUsersOp op,ByteCodeWriter & writer)9503eb1647aSStanislav Funiak void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
9513c405c3bSRiver Riddle   Value operations = op.getOperations();
9523eb1647aSStanislav Funiak   ByteCodeField rangeIndex = getRangeStorageIndex(operations);
9533eb1647aSStanislav Funiak   writer.append(OpCode::GetUsers, operations, rangeIndex);
9543c405c3bSRiver Riddle   writer.appendPDLValue(op.getValue());
9553eb1647aSStanislav Funiak }
generate(pdl_interp::GetValueTypeOp op,ByteCodeWriter & writer)956abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetValueTypeOp op,
957abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
95885ab413bSRiver Riddle   if (op.getType().isa<pdl::RangeType>()) {
9593c405c3bSRiver Riddle     Value result = op.getResult();
96085ab413bSRiver Riddle     writer.append(OpCode::GetValueRangeTypes, result,
9613c405c3bSRiver Riddle                   getRangeStorageIndex(result), op.getValue());
96285ab413bSRiver Riddle   } else {
9633c405c3bSRiver Riddle     writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
964abfd1a8bSRiver Riddle   }
96585ab413bSRiver Riddle }
generate(pdl_interp::IsNotNullOp op,ByteCodeWriter & writer)966abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
9673c405c3bSRiver Riddle   writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
968abfd1a8bSRiver Riddle }
generate(pdl_interp::RecordMatchOp op,ByteCodeWriter & writer)969abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
970abfd1a8bSRiver Riddle   ByteCodeField patternIndex = patterns.size();
971abfd1a8bSRiver Riddle   patterns.emplace_back(PDLByteCodePattern::create(
9723c405c3bSRiver Riddle       op, rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
9738affe881SRiver Riddle   writer.append(OpCode::RecordMatch, patternIndex,
9743c405c3bSRiver Riddle                 SuccessorRange(op.getOperation()), op.getMatchedOps());
9753c405c3bSRiver Riddle   writer.appendPDLValueList(op.getInputs());
976abfd1a8bSRiver Riddle }
generate(pdl_interp::ReplaceOp op,ByteCodeWriter & writer)977abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
9783c405c3bSRiver Riddle   writer.append(OpCode::ReplaceOp, op.getInputOp());
9793c405c3bSRiver Riddle   writer.appendPDLValueList(op.getReplValues());
980abfd1a8bSRiver Riddle }
generate(pdl_interp::SwitchAttributeOp op,ByteCodeWriter & writer)981abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchAttributeOp op,
982abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
9833c405c3bSRiver Riddle   writer.append(OpCode::SwitchAttribute, op.getAttribute(),
9843c405c3bSRiver Riddle                 op.getCaseValuesAttr(), op.getSuccessors());
985abfd1a8bSRiver Riddle }
generate(pdl_interp::SwitchOperandCountOp op,ByteCodeWriter & writer)986abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchOperandCountOp op,
987abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
9883c405c3bSRiver Riddle   writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
9893c405c3bSRiver Riddle                 op.getCaseValuesAttr(), op.getSuccessors());
990abfd1a8bSRiver Riddle }
generate(pdl_interp::SwitchOperationNameOp op,ByteCodeWriter & writer)991abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchOperationNameOp op,
992abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
9933c405c3bSRiver Riddle   auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
994abfd1a8bSRiver Riddle     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
995abfd1a8bSRiver Riddle   });
9963c405c3bSRiver Riddle   writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
997abfd1a8bSRiver Riddle                 op.getSuccessors());
998abfd1a8bSRiver Riddle }
generate(pdl_interp::SwitchResultCountOp op,ByteCodeWriter & writer)999abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchResultCountOp op,
1000abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
10013c405c3bSRiver Riddle   writer.append(OpCode::SwitchResultCount, op.getInputOp(),
10023c405c3bSRiver Riddle                 op.getCaseValuesAttr(), op.getSuccessors());
1003abfd1a8bSRiver Riddle }
generate(pdl_interp::SwitchTypeOp op,ByteCodeWriter & writer)1004abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
10053c405c3bSRiver Riddle   writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1006abfd1a8bSRiver Riddle                 op.getSuccessors());
1007abfd1a8bSRiver Riddle }
generate(pdl_interp::SwitchTypesOp op,ByteCodeWriter & writer)100885ab413bSRiver Riddle void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
10093c405c3bSRiver Riddle   writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
101085ab413bSRiver Riddle                 op.getSuccessors());
101185ab413bSRiver Riddle }
1012abfd1a8bSRiver Riddle 
1013abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
1014abfd1a8bSRiver Riddle // PDLByteCode
1015abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
1016abfd1a8bSRiver Riddle 
PDLByteCode(ModuleOp module,llvm::StringMap<PDLConstraintFunction> constraintFns,llvm::StringMap<PDLRewriteFunction> rewriteFns)1017abfd1a8bSRiver Riddle PDLByteCode::PDLByteCode(ModuleOp module,
1018abfd1a8bSRiver Riddle                          llvm::StringMap<PDLConstraintFunction> constraintFns,
1019abfd1a8bSRiver Riddle                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
1020abfd1a8bSRiver Riddle   Generator generator(module.getContext(), uniquedData, matcherByteCode,
1021abfd1a8bSRiver Riddle                       rewriterByteCode, patterns, maxValueMemoryIndex,
10223eb1647aSStanislav Funiak                       maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
10233eb1647aSStanislav Funiak                       maxLoopLevel, constraintFns, rewriteFns);
1024abfd1a8bSRiver Riddle   generator.generate(module);
1025abfd1a8bSRiver Riddle 
1026abfd1a8bSRiver Riddle   // Initialize the external functions.
1027abfd1a8bSRiver Riddle   for (auto &it : constraintFns)
1028abfd1a8bSRiver Riddle     constraintFunctions.push_back(std::move(it.second));
1029abfd1a8bSRiver Riddle   for (auto &it : rewriteFns)
1030abfd1a8bSRiver Riddle     rewriteFunctions.push_back(std::move(it.second));
1031abfd1a8bSRiver Riddle }
1032abfd1a8bSRiver Riddle 
1033abfd1a8bSRiver Riddle /// Initialize the given state such that it can be used to execute the current
1034abfd1a8bSRiver Riddle /// bytecode.
initializeMutableState(PDLByteCodeMutableState & state) const1035abfd1a8bSRiver Riddle void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1036abfd1a8bSRiver Riddle   state.memory.resize(maxValueMemoryIndex, nullptr);
10373eb1647aSStanislav Funiak   state.opRangeMemory.resize(maxOpRangeCount);
103885ab413bSRiver Riddle   state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
103985ab413bSRiver Riddle   state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
10403eb1647aSStanislav Funiak   state.loopIndex.resize(maxLoopLevel, 0);
1041abfd1a8bSRiver Riddle   state.currentPatternBenefits.reserve(patterns.size());
1042abfd1a8bSRiver Riddle   for (const PDLByteCodePattern &pattern : patterns)
1043abfd1a8bSRiver Riddle     state.currentPatternBenefits.push_back(pattern.getBenefit());
1044abfd1a8bSRiver Riddle }
1045abfd1a8bSRiver Riddle 
//===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//
1048abfd1a8bSRiver Riddle 
1049abfd1a8bSRiver Riddle namespace {
1050abfd1a8bSRiver Riddle /// This class provides support for executing a bytecode stream.
1051abfd1a8bSRiver Riddle class ByteCodeExecutor {
1052abfd1a8bSRiver Riddle public:
ByteCodeExecutor(const ByteCodeField * curCodeIt,MutableArrayRef<const void * > memory,MutableArrayRef<llvm::OwningArrayRef<Operation * >> opRangeMemory,MutableArrayRef<TypeRange> typeRangeMemory,std::vector<llvm::OwningArrayRef<Type>> & allocatedTypeRangeMemory,MutableArrayRef<ValueRange> valueRangeMemory,std::vector<llvm::OwningArrayRef<Value>> & allocatedValueRangeMemory,MutableArrayRef<unsigned> loopIndex,ArrayRef<const void * > uniquedMemory,ArrayRef<ByteCodeField> code,ArrayRef<PatternBenefit> currentPatternBenefits,ArrayRef<PDLByteCodePattern> patterns,ArrayRef<PDLConstraintFunction> constraintFunctions,ArrayRef<PDLRewriteFunction> rewriteFunctions)105385ab413bSRiver Riddle   ByteCodeExecutor(
105485ab413bSRiver Riddle       const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
10553eb1647aSStanislav Funiak       MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
105685ab413bSRiver Riddle       MutableArrayRef<TypeRange> typeRangeMemory,
105785ab413bSRiver Riddle       std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
105885ab413bSRiver Riddle       MutableArrayRef<ValueRange> valueRangeMemory,
105985ab413bSRiver Riddle       std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
10603eb1647aSStanislav Funiak       MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
10613eb1647aSStanislav Funiak       ArrayRef<ByteCodeField> code,
1062abfd1a8bSRiver Riddle       ArrayRef<PatternBenefit> currentPatternBenefits,
1063abfd1a8bSRiver Riddle       ArrayRef<PDLByteCodePattern> patterns,
1064abfd1a8bSRiver Riddle       ArrayRef<PDLConstraintFunction> constraintFunctions,
1065abfd1a8bSRiver Riddle       ArrayRef<PDLRewriteFunction> rewriteFunctions)
10663eb1647aSStanislav Funiak       : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
10673eb1647aSStanislav Funiak         typeRangeMemory(typeRangeMemory),
106885ab413bSRiver Riddle         allocatedTypeRangeMemory(allocatedTypeRangeMemory),
106985ab413bSRiver Riddle         valueRangeMemory(valueRangeMemory),
107085ab413bSRiver Riddle         allocatedValueRangeMemory(allocatedValueRangeMemory),
10713eb1647aSStanislav Funiak         loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
107285ab413bSRiver Riddle         currentPatternBenefits(currentPatternBenefits), patterns(patterns),
107385ab413bSRiver Riddle         constraintFunctions(constraintFunctions),
107402c4c0d5SRiver Riddle         rewriteFunctions(rewriteFunctions) {}
1075abfd1a8bSRiver Riddle 
1076abfd1a8bSRiver Riddle   /// Start executing the code at the current bytecode index. `matches` is an
1077abfd1a8bSRiver Riddle   /// optional field provided when this function is executed in a matching
1078abfd1a8bSRiver Riddle   /// context.
1079abfd1a8bSRiver Riddle   void execute(PatternRewriter &rewriter,
1080abfd1a8bSRiver Riddle                SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
1081abfd1a8bSRiver Riddle                Optional<Location> mainRewriteLoc = {});
1082abfd1a8bSRiver Riddle 
1083abfd1a8bSRiver Riddle private:
1084154cabe7SRiver Riddle   /// Internal implementation of executing each of the bytecode commands.
1085154cabe7SRiver Riddle   void executeApplyConstraint(PatternRewriter &rewriter);
1086154cabe7SRiver Riddle   void executeApplyRewrite(PatternRewriter &rewriter);
1087154cabe7SRiver Riddle   void executeAreEqual();
108885ab413bSRiver Riddle   void executeAreRangesEqual();
1089154cabe7SRiver Riddle   void executeBranch();
1090154cabe7SRiver Riddle   void executeCheckOperandCount();
1091154cabe7SRiver Riddle   void executeCheckOperationName();
1092154cabe7SRiver Riddle   void executeCheckResultCount();
109385ab413bSRiver Riddle   void executeCheckTypes();
10943eb1647aSStanislav Funiak   void executeContinue();
1095154cabe7SRiver Riddle   void executeCreateOperation(PatternRewriter &rewriter,
1096154cabe7SRiver Riddle                               Location mainRewriteLoc);
109785ab413bSRiver Riddle   void executeCreateTypes();
1098154cabe7SRiver Riddle   void executeEraseOp(PatternRewriter &rewriter);
10993eb1647aSStanislav Funiak   template <typename T, typename Range, PDLValue::Kind kind>
11003eb1647aSStanislav Funiak   void executeExtract();
11013eb1647aSStanislav Funiak   void executeFinalize();
11023eb1647aSStanislav Funiak   void executeForEach();
1103154cabe7SRiver Riddle   void executeGetAttribute();
1104154cabe7SRiver Riddle   void executeGetAttributeType();
1105154cabe7SRiver Riddle   void executeGetDefiningOp();
1106154cabe7SRiver Riddle   void executeGetOperand(unsigned index);
110785ab413bSRiver Riddle   void executeGetOperands();
1108154cabe7SRiver Riddle   void executeGetResult(unsigned index);
110985ab413bSRiver Riddle   void executeGetResults();
11103eb1647aSStanislav Funiak   void executeGetUsers();
1111154cabe7SRiver Riddle   void executeGetValueType();
111285ab413bSRiver Riddle   void executeGetValueRangeTypes();
1113154cabe7SRiver Riddle   void executeIsNotNull();
1114154cabe7SRiver Riddle   void executeRecordMatch(PatternRewriter &rewriter,
1115154cabe7SRiver Riddle                           SmallVectorImpl<PDLByteCode::MatchResult> &matches);
1116154cabe7SRiver Riddle   void executeReplaceOp(PatternRewriter &rewriter);
1117154cabe7SRiver Riddle   void executeSwitchAttribute();
1118154cabe7SRiver Riddle   void executeSwitchOperandCount();
1119154cabe7SRiver Riddle   void executeSwitchOperationName();
1120154cabe7SRiver Riddle   void executeSwitchResultCount();
1121154cabe7SRiver Riddle   void executeSwitchType();
112285ab413bSRiver Riddle   void executeSwitchTypes();
1123154cabe7SRiver Riddle 
11243eb1647aSStanislav Funiak   /// Pushes a code iterator to the stack.
pushCodeIt(const ByteCodeField * it)11253eb1647aSStanislav Funiak   void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
11263eb1647aSStanislav Funiak 
11273eb1647aSStanislav Funiak   /// Pops a code iterator from the stack, returning true on success.
popCodeIt()11283eb1647aSStanislav Funiak   void popCodeIt() {
11293eb1647aSStanislav Funiak     assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
11303eb1647aSStanislav Funiak     curCodeIt = resumeCodeIt.back();
11313eb1647aSStanislav Funiak     resumeCodeIt.pop_back();
11323eb1647aSStanislav Funiak   }
11333eb1647aSStanislav Funiak 
1134d35f1190SStanislav Funiak   /// Return the bytecode iterator at the start of the current op code.
getPrevCodeIt() const1135d35f1190SStanislav Funiak   const ByteCodeField *getPrevCodeIt() const {
1136d35f1190SStanislav Funiak     LLVM_DEBUG({
1137d35f1190SStanislav Funiak       // Account for the op code and the Location stored inline.
1138d35f1190SStanislav Funiak       return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
1139d35f1190SStanislav Funiak     });
1140d35f1190SStanislav Funiak 
1141d35f1190SStanislav Funiak     // Account for the op code only.
1142d35f1190SStanislav Funiak     return curCodeIt - 1;
1143d35f1190SStanislav Funiak   }
1144d35f1190SStanislav Funiak 
1145abfd1a8bSRiver Riddle   /// Read a value from the bytecode buffer, optionally skipping a certain
1146abfd1a8bSRiver Riddle   /// number of prefix values. These methods always update the buffer to point
1147abfd1a8bSRiver Riddle   /// to the next field after the read data.
1148abfd1a8bSRiver Riddle   template <typename T = ByteCodeField>
read(size_t skipN=0)1149abfd1a8bSRiver Riddle   T read(size_t skipN = 0) {
1150abfd1a8bSRiver Riddle     curCodeIt += skipN;
1151abfd1a8bSRiver Riddle     return readImpl<T>();
1152abfd1a8bSRiver Riddle   }
read(size_t skipN=0)1153abfd1a8bSRiver Riddle   ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
1154abfd1a8bSRiver Riddle 
1155abfd1a8bSRiver Riddle   /// Read a list of values from the bytecode buffer.
1156abfd1a8bSRiver Riddle   template <typename ValueT, typename T>
readList(SmallVectorImpl<T> & list)1157abfd1a8bSRiver Riddle   void readList(SmallVectorImpl<T> &list) {
1158abfd1a8bSRiver Riddle     list.clear();
1159abfd1a8bSRiver Riddle     for (unsigned i = 0, e = read(); i != e; ++i)
1160abfd1a8bSRiver Riddle       list.push_back(read<ValueT>());
1161abfd1a8bSRiver Riddle   }
1162abfd1a8bSRiver Riddle 
116385ab413bSRiver Riddle   /// Read a list of values from the bytecode buffer. The values may be encoded
116485ab413bSRiver Riddle   /// as either Value or ValueRange elements.
readValueList(SmallVectorImpl<Value> & list)116585ab413bSRiver Riddle   void readValueList(SmallVectorImpl<Value> &list) {
116685ab413bSRiver Riddle     for (unsigned i = 0, e = read(); i != e; ++i) {
116785ab413bSRiver Riddle       if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
116885ab413bSRiver Riddle         list.push_back(read<Value>());
116985ab413bSRiver Riddle       } else {
117085ab413bSRiver Riddle         ValueRange *values = read<ValueRange *>();
117185ab413bSRiver Riddle         list.append(values->begin(), values->end());
117285ab413bSRiver Riddle       }
117385ab413bSRiver Riddle     }
117485ab413bSRiver Riddle   }
117585ab413bSRiver Riddle 
1176d35f1190SStanislav Funiak   /// Read a value stored inline as a pointer.
1177d35f1190SStanislav Funiak   template <typename T>
1178d35f1190SStanislav Funiak   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
readInline()1179d35f1190SStanislav Funiak   readInline() {
1180d35f1190SStanislav Funiak     const void *pointer;
1181d35f1190SStanislav Funiak     std::memcpy(&pointer, curCodeIt, sizeof(const void *));
1182d35f1190SStanislav Funiak     curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
1183d35f1190SStanislav Funiak     return T::getFromOpaquePointer(pointer);
1184d35f1190SStanislav Funiak   }
1185d35f1190SStanislav Funiak 
1186abfd1a8bSRiver Riddle   /// Jump to a specific successor based on a predicate value.
selectJump(bool isTrue)1187abfd1a8bSRiver Riddle   void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
1188abfd1a8bSRiver Riddle   /// Jump to a specific successor based on a destination index.
selectJump(size_t destIndex)1189abfd1a8bSRiver Riddle   void selectJump(size_t destIndex) {
1190abfd1a8bSRiver Riddle     curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1191abfd1a8bSRiver Riddle   }
1192abfd1a8bSRiver Riddle 
1193abfd1a8bSRiver Riddle   /// Handle a switch operation with the provided value and cases.
119485ab413bSRiver Riddle   template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
handleSwitch(const T & value,RangeT && cases,Comparator cmp={})119585ab413bSRiver Riddle   void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1196abfd1a8bSRiver Riddle     LLVM_DEBUG({
1197abfd1a8bSRiver Riddle       llvm::dbgs() << "  * Value: " << value << "\n"
1198abfd1a8bSRiver Riddle                    << "  * Cases: ";
1199abfd1a8bSRiver Riddle       llvm::interleaveComma(cases, llvm::dbgs());
1200154cabe7SRiver Riddle       llvm::dbgs() << "\n";
1201abfd1a8bSRiver Riddle     });
1202abfd1a8bSRiver Riddle 
1203abfd1a8bSRiver Riddle     // Check to see if the attribute value is within the case list. Jump to
1204abfd1a8bSRiver Riddle     // the correct successor index based on the result.
1205f80b6304SRiver Riddle     for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
120685ab413bSRiver Riddle       if (cmp(*it, value))
1207f80b6304SRiver Riddle         return selectJump(size_t((it - cases.begin()) + 1));
1208f80b6304SRiver Riddle     selectJump(size_t(0));
1209abfd1a8bSRiver Riddle   }
1210abfd1a8bSRiver Riddle 
12113eb1647aSStanislav Funiak   /// Store a pointer to memory.
storeToMemory(unsigned index,const void * value)12123eb1647aSStanislav Funiak   void storeToMemory(unsigned index, const void *value) {
12133eb1647aSStanislav Funiak     memory[index] = value;
12143eb1647aSStanislav Funiak   }
12153eb1647aSStanislav Funiak 
12163eb1647aSStanislav Funiak   /// Store a value to memory as an opaque pointer.
12173eb1647aSStanislav Funiak   template <typename T>
12183eb1647aSStanislav Funiak   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
storeToMemory(unsigned index,T value)12193eb1647aSStanislav Funiak   storeToMemory(unsigned index, T value) {
12203eb1647aSStanislav Funiak     memory[index] = value.getAsOpaquePointer();
12213eb1647aSStanislav Funiak   }
12223eb1647aSStanislav Funiak 
1223abfd1a8bSRiver Riddle   /// Internal implementation of reading various data types from the bytecode
1224abfd1a8bSRiver Riddle   /// stream.
1225abfd1a8bSRiver Riddle   template <typename T>
readFromMemory()1226abfd1a8bSRiver Riddle   const void *readFromMemory() {
1227abfd1a8bSRiver Riddle     size_t index = *curCodeIt++;
1228abfd1a8bSRiver Riddle 
1229abfd1a8bSRiver Riddle     // If this type is an SSA value, it can only be stored in non-const memory.
123085ab413bSRiver Riddle     if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
123185ab413bSRiver Riddle                         Value>::value ||
123285ab413bSRiver Riddle         index < memory.size())
1233abfd1a8bSRiver Riddle       return memory[index];
1234abfd1a8bSRiver Riddle 
1235abfd1a8bSRiver Riddle     // Otherwise, if this index is not inbounds it is uniqued.
1236abfd1a8bSRiver Riddle     return uniquedMemory[index - memory.size()];
1237abfd1a8bSRiver Riddle   }
1238abfd1a8bSRiver Riddle   template <typename T>
readImpl()1239abfd1a8bSRiver Riddle   std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1240abfd1a8bSRiver Riddle     return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1241abfd1a8bSRiver Riddle   }
1242abfd1a8bSRiver Riddle   template <typename T>
1243abfd1a8bSRiver Riddle   std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1244abfd1a8bSRiver Riddle                    T>
readImpl()1245abfd1a8bSRiver Riddle   readImpl() {
1246abfd1a8bSRiver Riddle     return T(T::getFromOpaquePointer(readFromMemory<T>()));
1247abfd1a8bSRiver Riddle   }
1248abfd1a8bSRiver Riddle   template <typename T>
readImpl()1249abfd1a8bSRiver Riddle   std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
125085ab413bSRiver Riddle     switch (read<PDLValue::Kind>()) {
125185ab413bSRiver Riddle     case PDLValue::Kind::Attribute:
1252abfd1a8bSRiver Riddle       return read<Attribute>();
125385ab413bSRiver Riddle     case PDLValue::Kind::Operation:
1254abfd1a8bSRiver Riddle       return read<Operation *>();
125585ab413bSRiver Riddle     case PDLValue::Kind::Type:
1256abfd1a8bSRiver Riddle       return read<Type>();
125785ab413bSRiver Riddle     case PDLValue::Kind::Value:
1258abfd1a8bSRiver Riddle       return read<Value>();
125985ab413bSRiver Riddle     case PDLValue::Kind::TypeRange:
126085ab413bSRiver Riddle       return read<TypeRange *>();
126185ab413bSRiver Riddle     case PDLValue::Kind::ValueRange:
126285ab413bSRiver Riddle       return read<ValueRange *>();
1263abfd1a8bSRiver Riddle     }
126485ab413bSRiver Riddle     llvm_unreachable("unhandled PDLValue::Kind");
1265abfd1a8bSRiver Riddle   }
1266abfd1a8bSRiver Riddle   template <typename T>
readImpl()1267abfd1a8bSRiver Riddle   std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1268abfd1a8bSRiver Riddle     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1269abfd1a8bSRiver Riddle                   "unexpected ByteCode address size");
1270abfd1a8bSRiver Riddle     ByteCodeAddr result;
1271abfd1a8bSRiver Riddle     std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1272abfd1a8bSRiver Riddle     curCodeIt += 2;
1273abfd1a8bSRiver Riddle     return result;
1274abfd1a8bSRiver Riddle   }
1275abfd1a8bSRiver Riddle   template <typename T>
readImpl()1276abfd1a8bSRiver Riddle   std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1277abfd1a8bSRiver Riddle     return *curCodeIt++;
1278abfd1a8bSRiver Riddle   }
127985ab413bSRiver Riddle   template <typename T>
readImpl()128085ab413bSRiver Riddle   std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
128185ab413bSRiver Riddle     return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
128285ab413bSRiver Riddle   }
1283abfd1a8bSRiver Riddle 
1284abfd1a8bSRiver Riddle   /// The underlying bytecode buffer.
1285abfd1a8bSRiver Riddle   const ByteCodeField *curCodeIt;
1286abfd1a8bSRiver Riddle 
12873eb1647aSStanislav Funiak   /// The stack of bytecode positions at which to resume operation.
12883eb1647aSStanislav Funiak   SmallVector<const ByteCodeField *> resumeCodeIt;
12893eb1647aSStanislav Funiak 
1290abfd1a8bSRiver Riddle   /// The current execution memory.
1291abfd1a8bSRiver Riddle   MutableArrayRef<const void *> memory;
12923eb1647aSStanislav Funiak   MutableArrayRef<OwningOpRange> opRangeMemory;
129385ab413bSRiver Riddle   MutableArrayRef<TypeRange> typeRangeMemory;
129485ab413bSRiver Riddle   std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
129585ab413bSRiver Riddle   MutableArrayRef<ValueRange> valueRangeMemory;
129685ab413bSRiver Riddle   std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1297abfd1a8bSRiver Riddle 
12983eb1647aSStanislav Funiak   /// The current loop indices.
12993eb1647aSStanislav Funiak   MutableArrayRef<unsigned> loopIndex;
13003eb1647aSStanislav Funiak 
1301abfd1a8bSRiver Riddle   /// References to ByteCode data necessary for execution.
1302abfd1a8bSRiver Riddle   ArrayRef<const void *> uniquedMemory;
1303abfd1a8bSRiver Riddle   ArrayRef<ByteCodeField> code;
1304abfd1a8bSRiver Riddle   ArrayRef<PatternBenefit> currentPatternBenefits;
1305abfd1a8bSRiver Riddle   ArrayRef<PDLByteCodePattern> patterns;
1306abfd1a8bSRiver Riddle   ArrayRef<PDLConstraintFunction> constraintFunctions;
1307abfd1a8bSRiver Riddle   ArrayRef<PDLRewriteFunction> rewriteFunctions;
1308abfd1a8bSRiver Riddle };
130902c4c0d5SRiver Riddle 
131002c4c0d5SRiver Riddle /// This class is an instantiation of the PDLResultList that provides access to
131102c4c0d5SRiver Riddle /// the returned results. This API is not on `PDLResultList` to avoid
131202c4c0d5SRiver Riddle /// overexposing access to information specific solely to the ByteCode.
131302c4c0d5SRiver Riddle class ByteCodeRewriteResultList : public PDLResultList {
131402c4c0d5SRiver Riddle public:
ByteCodeRewriteResultList(unsigned maxNumResults)131585ab413bSRiver Riddle   ByteCodeRewriteResultList(unsigned maxNumResults)
131685ab413bSRiver Riddle       : PDLResultList(maxNumResults) {}
131785ab413bSRiver Riddle 
131802c4c0d5SRiver Riddle   /// Return the list of PDL results.
getResults()131902c4c0d5SRiver Riddle   MutableArrayRef<PDLValue> getResults() { return results; }
132085ab413bSRiver Riddle 
132185ab413bSRiver Riddle   /// Return the type ranges allocated by this list.
getAllocatedTypeRanges()132285ab413bSRiver Riddle   MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
132385ab413bSRiver Riddle     return allocatedTypeRanges;
132485ab413bSRiver Riddle   }
132585ab413bSRiver Riddle 
132685ab413bSRiver Riddle   /// Return the value ranges allocated by this list.
getAllocatedValueRanges()132785ab413bSRiver Riddle   MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
132885ab413bSRiver Riddle     return allocatedValueRanges;
132985ab413bSRiver Riddle   }
133002c4c0d5SRiver Riddle };
1331be0a7e9fSMehdi Amini } // namespace
1332abfd1a8bSRiver Riddle 
executeApplyConstraint(PatternRewriter & rewriter)1333154cabe7SRiver Riddle void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1334abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1335abfd1a8bSRiver Riddle   const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1336abfd1a8bSRiver Riddle   SmallVector<PDLValue, 16> args;
1337abfd1a8bSRiver Riddle   readList<PDLValue>(args);
1338154cabe7SRiver Riddle 
1339abfd1a8bSRiver Riddle   LLVM_DEBUG({
1340abfd1a8bSRiver Riddle     llvm::dbgs() << "  * Arguments: ";
1341abfd1a8bSRiver Riddle     llvm::interleaveComma(args, llvm::dbgs());
1342abfd1a8bSRiver Riddle   });
1343abfd1a8bSRiver Riddle 
1344abfd1a8bSRiver Riddle   // Invoke the constraint and jump to the proper destination.
1345ea64828aSRiver Riddle   selectJump(succeeded(constraintFn(rewriter, args)));
1346abfd1a8bSRiver Riddle }
1347154cabe7SRiver Riddle 
executeApplyRewrite(PatternRewriter & rewriter)1348154cabe7SRiver Riddle void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1349abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1350abfd1a8bSRiver Riddle   const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1351abfd1a8bSRiver Riddle   SmallVector<PDLValue, 16> args;
1352abfd1a8bSRiver Riddle   readList<PDLValue>(args);
1353abfd1a8bSRiver Riddle 
1354abfd1a8bSRiver Riddle   LLVM_DEBUG({
135502c4c0d5SRiver Riddle     llvm::dbgs() << "  * Arguments: ";
1356abfd1a8bSRiver Riddle     llvm::interleaveComma(args, llvm::dbgs());
1357abfd1a8bSRiver Riddle   });
135885ab413bSRiver Riddle 
135985ab413bSRiver Riddle   // Execute the rewrite function.
136085ab413bSRiver Riddle   ByteCodeField numResults = read();
136185ab413bSRiver Riddle   ByteCodeRewriteResultList results(numResults);
1362ea64828aSRiver Riddle   rewriteFn(rewriter, results, args);
1363154cabe7SRiver Riddle 
136485ab413bSRiver Riddle   assert(results.getResults().size() == numResults &&
136502c4c0d5SRiver Riddle          "native PDL rewrite function returned unexpected number of results");
136602c4c0d5SRiver Riddle 
136702c4c0d5SRiver Riddle   // Store the results in the bytecode memory.
136802c4c0d5SRiver Riddle   for (PDLValue &result : results.getResults()) {
136902c4c0d5SRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
137085ab413bSRiver Riddle 
137185ab413bSRiver Riddle // In debug mode we also verify the expected kind of the result.
137285ab413bSRiver Riddle #ifndef NDEBUG
137385ab413bSRiver Riddle     assert(result.getKind() == read<PDLValue::Kind>() &&
137485ab413bSRiver Riddle            "native PDL rewrite function returned an unexpected type of result");
137585ab413bSRiver Riddle #endif
137685ab413bSRiver Riddle 
137785ab413bSRiver Riddle     // If the result is a range, we need to copy it over to the bytecodes
137885ab413bSRiver Riddle     // range memory.
137985ab413bSRiver Riddle     if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
138085ab413bSRiver Riddle       unsigned rangeIndex = read();
138185ab413bSRiver Riddle       typeRangeMemory[rangeIndex] = *typeRange;
138285ab413bSRiver Riddle       memory[read()] = &typeRangeMemory[rangeIndex];
138385ab413bSRiver Riddle     } else if (Optional<ValueRange> valueRange =
138485ab413bSRiver Riddle                    result.dyn_cast<ValueRange>()) {
138585ab413bSRiver Riddle       unsigned rangeIndex = read();
138685ab413bSRiver Riddle       valueRangeMemory[rangeIndex] = *valueRange;
138785ab413bSRiver Riddle       memory[read()] = &valueRangeMemory[rangeIndex];
138885ab413bSRiver Riddle     } else {
138902c4c0d5SRiver Riddle       memory[read()] = result.getAsOpaquePointer();
139002c4c0d5SRiver Riddle     }
1391abfd1a8bSRiver Riddle   }
1392154cabe7SRiver Riddle 
139385ab413bSRiver Riddle   // Copy over any underlying storage allocated for result ranges.
139485ab413bSRiver Riddle   for (auto &it : results.getAllocatedTypeRanges())
139585ab413bSRiver Riddle     allocatedTypeRangeMemory.push_back(std::move(it));
139685ab413bSRiver Riddle   for (auto &it : results.getAllocatedValueRanges())
139785ab413bSRiver Riddle     allocatedValueRangeMemory.push_back(std::move(it));
139885ab413bSRiver Riddle }
139985ab413bSRiver Riddle 
executeAreEqual()1400154cabe7SRiver Riddle void ByteCodeExecutor::executeAreEqual() {
1401abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1402abfd1a8bSRiver Riddle   const void *lhs = read<const void *>();
1403abfd1a8bSRiver Riddle   const void *rhs = read<const void *>();
1404abfd1a8bSRiver Riddle 
1405154cabe7SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
1406abfd1a8bSRiver Riddle   selectJump(lhs == rhs);
1407abfd1a8bSRiver Riddle }
1408154cabe7SRiver Riddle 
executeAreRangesEqual()140985ab413bSRiver Riddle void ByteCodeExecutor::executeAreRangesEqual() {
141085ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
141185ab413bSRiver Riddle   PDLValue::Kind valueKind = read<PDLValue::Kind>();
141285ab413bSRiver Riddle   const void *lhs = read<const void *>();
141385ab413bSRiver Riddle   const void *rhs = read<const void *>();
141485ab413bSRiver Riddle 
141585ab413bSRiver Riddle   switch (valueKind) {
141685ab413bSRiver Riddle   case PDLValue::Kind::TypeRange: {
141785ab413bSRiver Riddle     const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
141885ab413bSRiver Riddle     const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
141985ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
142085ab413bSRiver Riddle     selectJump(*lhsRange == *rhsRange);
142185ab413bSRiver Riddle     break;
142285ab413bSRiver Riddle   }
142385ab413bSRiver Riddle   case PDLValue::Kind::ValueRange: {
142485ab413bSRiver Riddle     const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
142585ab413bSRiver Riddle     const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
142685ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
142785ab413bSRiver Riddle     selectJump(*lhsRange == *rhsRange);
142885ab413bSRiver Riddle     break;
142985ab413bSRiver Riddle   }
143085ab413bSRiver Riddle   default:
143185ab413bSRiver Riddle     llvm_unreachable("unexpected `AreRangesEqual` value kind");
143285ab413bSRiver Riddle   }
143385ab413bSRiver Riddle }
143485ab413bSRiver Riddle 
executeBranch()1435154cabe7SRiver Riddle void ByteCodeExecutor::executeBranch() {
1436154cabe7SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1437abfd1a8bSRiver Riddle   curCodeIt = &code[read<ByteCodeAddr>()];
1438abfd1a8bSRiver Riddle }
1439154cabe7SRiver Riddle 
executeCheckOperandCount()1440154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckOperandCount() {
1441abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1442abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1443abfd1a8bSRiver Riddle   uint32_t expectedCount = read<uint32_t>();
144485ab413bSRiver Riddle   bool compareAtLeast = read();
1445abfd1a8bSRiver Riddle 
1446abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
144785ab413bSRiver Riddle                           << "  * Expected: " << expectedCount << "\n"
144885ab413bSRiver Riddle                           << "  * Comparator: "
144985ab413bSRiver Riddle                           << (compareAtLeast ? ">=" : "==") << "\n");
145085ab413bSRiver Riddle   if (compareAtLeast)
145185ab413bSRiver Riddle     selectJump(op->getNumOperands() >= expectedCount);
145285ab413bSRiver Riddle   else
1453abfd1a8bSRiver Riddle     selectJump(op->getNumOperands() == expectedCount);
1454abfd1a8bSRiver Riddle }
1455154cabe7SRiver Riddle 
executeCheckOperationName()1456154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckOperationName() {
1457abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1458abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1459abfd1a8bSRiver Riddle   OperationName expectedName = read<OperationName>();
1460abfd1a8bSRiver Riddle 
1461154cabe7SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
1462154cabe7SRiver Riddle                           << "  * Expected: \"" << expectedName << "\"\n");
1463abfd1a8bSRiver Riddle   selectJump(op->getName() == expectedName);
1464abfd1a8bSRiver Riddle }
1465154cabe7SRiver Riddle 
executeCheckResultCount()1466154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckResultCount() {
1467abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1468abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1469abfd1a8bSRiver Riddle   uint32_t expectedCount = read<uint32_t>();
147085ab413bSRiver Riddle   bool compareAtLeast = read();
1471abfd1a8bSRiver Riddle 
1472abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
147385ab413bSRiver Riddle                           << "  * Expected: " << expectedCount << "\n"
147485ab413bSRiver Riddle                           << "  * Comparator: "
147585ab413bSRiver Riddle                           << (compareAtLeast ? ">=" : "==") << "\n");
147685ab413bSRiver Riddle   if (compareAtLeast)
147785ab413bSRiver Riddle     selectJump(op->getNumResults() >= expectedCount);
147885ab413bSRiver Riddle   else
1479abfd1a8bSRiver Riddle     selectJump(op->getNumResults() == expectedCount);
1480abfd1a8bSRiver Riddle }
1481154cabe7SRiver Riddle 
executeCheckTypes()148285ab413bSRiver Riddle void ByteCodeExecutor::executeCheckTypes() {
148385ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
148485ab413bSRiver Riddle   TypeRange *lhs = read<TypeRange *>();
148585ab413bSRiver Riddle   Attribute rhs = read<Attribute>();
148685ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
148785ab413bSRiver Riddle 
148885ab413bSRiver Riddle   selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
148985ab413bSRiver Riddle }
149085ab413bSRiver Riddle 
executeContinue()14913eb1647aSStanislav Funiak void ByteCodeExecutor::executeContinue() {
14923eb1647aSStanislav Funiak   ByteCodeField level = read();
14933eb1647aSStanislav Funiak   LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
14943eb1647aSStanislav Funiak                           << "  * Level: " << level << "\n");
14953eb1647aSStanislav Funiak   ++loopIndex[level];
14963eb1647aSStanislav Funiak   popCodeIt();
14973eb1647aSStanislav Funiak }
14983eb1647aSStanislav Funiak 
executeCreateTypes()149985ab413bSRiver Riddle void ByteCodeExecutor::executeCreateTypes() {
150085ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
150185ab413bSRiver Riddle   unsigned memIndex = read();
150285ab413bSRiver Riddle   unsigned rangeIndex = read();
150385ab413bSRiver Riddle   ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
150485ab413bSRiver Riddle 
150585ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Types: " << typesAttr << "\n\n");
150685ab413bSRiver Riddle 
150785ab413bSRiver Riddle   // Allocate a buffer for this type range.
150885ab413bSRiver Riddle   llvm::OwningArrayRef<Type> storage(typesAttr.size());
150985ab413bSRiver Riddle   llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
151085ab413bSRiver Riddle   allocatedTypeRangeMemory.emplace_back(std::move(storage));
151185ab413bSRiver Riddle 
151285ab413bSRiver Riddle   // Assign this to the range slot and use the range as the value for the
151385ab413bSRiver Riddle   // memory index.
151485ab413bSRiver Riddle   typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
151585ab413bSRiver Riddle   memory[memIndex] = &typeRangeMemory[rangeIndex];
151685ab413bSRiver Riddle }
151785ab413bSRiver Riddle 
executeCreateOperation(PatternRewriter & rewriter,Location mainRewriteLoc)1518154cabe7SRiver Riddle void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1519154cabe7SRiver Riddle                                               Location mainRewriteLoc) {
1520abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1521abfd1a8bSRiver Riddle 
1522abfd1a8bSRiver Riddle   unsigned memIndex = read();
1523154cabe7SRiver Riddle   OperationState state(mainRewriteLoc, read<OperationName>());
152485ab413bSRiver Riddle   readValueList(state.operands);
1525abfd1a8bSRiver Riddle   for (unsigned i = 0, e = read(); i != e; ++i) {
1526195730a6SRiver Riddle     StringAttr name = read<StringAttr>();
1527abfd1a8bSRiver Riddle     if (Attribute attr = read<Attribute>())
1528abfd1a8bSRiver Riddle       state.addAttribute(name, attr);
1529abfd1a8bSRiver Riddle   }
1530abfd1a8bSRiver Riddle 
15313c752289SRiver Riddle   // Read in the result types. If the "size" is the sentinel value, this
15323c752289SRiver Riddle   // indicates that the result types should be inferred.
15333c752289SRiver Riddle   unsigned numResults = read();
15343c752289SRiver Riddle   if (numResults == kInferTypesMarker) {
1535ea7be7e3SBenjamin Kramer     InferTypeOpInterface::Concept *inferInterface =
1536edc6c0ecSRiver Riddle         state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
15373c752289SRiver Riddle     assert(inferInterface &&
15383c752289SRiver Riddle            "expected operation to provide InferTypeOpInterface");
1539abfd1a8bSRiver Riddle 
1540abfd1a8bSRiver Riddle     // TODO: Handle failure.
1541ea7be7e3SBenjamin Kramer     if (failed(inferInterface->inferReturnTypes(
1542abfd1a8bSRiver Riddle             state.getContext(), state.location, state.operands,
1543154cabe7SRiver Riddle             state.attributes.getDictionary(state.getContext()), state.regions,
15443a833a0eSRiver Riddle             state.types)))
1545abfd1a8bSRiver Riddle       return;
15463c752289SRiver Riddle   } else {
15473c752289SRiver Riddle     // Otherwise, this is a fixed number of results.
15483c752289SRiver Riddle     for (unsigned i = 0; i != numResults; ++i) {
15493c752289SRiver Riddle       if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
15503c752289SRiver Riddle         state.types.push_back(read<Type>());
15513c752289SRiver Riddle       } else {
15523c752289SRiver Riddle         TypeRange *resultTypes = read<TypeRange *>();
15533c752289SRiver Riddle         state.types.append(resultTypes->begin(), resultTypes->end());
15543c752289SRiver Riddle       }
15553c752289SRiver Riddle     }
1556abfd1a8bSRiver Riddle   }
155785ab413bSRiver Riddle 
155814ecafd0SChia-hung Duan   Operation *resultOp = rewriter.create(state);
1559abfd1a8bSRiver Riddle   memory[memIndex] = resultOp;
1560abfd1a8bSRiver Riddle 
1561abfd1a8bSRiver Riddle   LLVM_DEBUG({
1562abfd1a8bSRiver Riddle     llvm::dbgs() << "  * Attributes: "
1563abfd1a8bSRiver Riddle                  << state.attributes.getDictionary(state.getContext())
1564abfd1a8bSRiver Riddle                  << "\n  * Operands: ";
1565abfd1a8bSRiver Riddle     llvm::interleaveComma(state.operands, llvm::dbgs());
1566abfd1a8bSRiver Riddle     llvm::dbgs() << "\n  * Result Types: ";
1567abfd1a8bSRiver Riddle     llvm::interleaveComma(state.types, llvm::dbgs());
1568154cabe7SRiver Riddle     llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
1569abfd1a8bSRiver Riddle   });
1570abfd1a8bSRiver Riddle }
1571154cabe7SRiver Riddle 
executeEraseOp(PatternRewriter & rewriter)1572154cabe7SRiver Riddle void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1573abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1574abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1575abfd1a8bSRiver Riddle 
1576154cabe7SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1577abfd1a8bSRiver Riddle   rewriter.eraseOp(op);
1578abfd1a8bSRiver Riddle }
1579154cabe7SRiver Riddle 
15803eb1647aSStanislav Funiak template <typename T, typename Range, PDLValue::Kind kind>
executeExtract()15813eb1647aSStanislav Funiak void ByteCodeExecutor::executeExtract() {
15823eb1647aSStanislav Funiak   LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
15833eb1647aSStanislav Funiak   Range *range = read<Range *>();
15843eb1647aSStanislav Funiak   unsigned index = read<uint32_t>();
15853eb1647aSStanislav Funiak   unsigned memIndex = read();
15863eb1647aSStanislav Funiak 
15873eb1647aSStanislav Funiak   if (!range) {
15883eb1647aSStanislav Funiak     memory[memIndex] = nullptr;
15893eb1647aSStanislav Funiak     return;
15903eb1647aSStanislav Funiak   }
15913eb1647aSStanislav Funiak 
15923eb1647aSStanislav Funiak   T result = index < range->size() ? (*range)[index] : T();
15933eb1647aSStanislav Funiak   LLVM_DEBUG(llvm::dbgs() << "  * " << kind << "s(" << range->size() << ")\n"
15943eb1647aSStanislav Funiak                           << "  * Index: " << index << "\n"
15953eb1647aSStanislav Funiak                           << "  * Result: " << result << "\n");
15963eb1647aSStanislav Funiak   storeToMemory(memIndex, result);
15973eb1647aSStanislav Funiak }
15983eb1647aSStanislav Funiak 
executeFinalize()15993eb1647aSStanislav Funiak void ByteCodeExecutor::executeFinalize() {
16003eb1647aSStanislav Funiak   LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
16013eb1647aSStanislav Funiak }
16023eb1647aSStanislav Funiak 
executeForEach()16033eb1647aSStanislav Funiak void ByteCodeExecutor::executeForEach() {
16043eb1647aSStanislav Funiak   LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1605d35f1190SStanislav Funiak   const ByteCodeField *prevCodeIt = getPrevCodeIt();
16063eb1647aSStanislav Funiak   unsigned rangeIndex = read();
16073eb1647aSStanislav Funiak   unsigned memIndex = read();
16083eb1647aSStanislav Funiak   const void *value = nullptr;
16093eb1647aSStanislav Funiak 
16103eb1647aSStanislav Funiak   switch (read<PDLValue::Kind>()) {
16113eb1647aSStanislav Funiak   case PDLValue::Kind::Operation: {
16123eb1647aSStanislav Funiak     unsigned &index = loopIndex[read()];
16133eb1647aSStanislav Funiak     ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
16143eb1647aSStanislav Funiak     assert(index <= array.size() && "iterated past the end");
16153eb1647aSStanislav Funiak     if (index < array.size()) {
16163eb1647aSStanislav Funiak       LLVM_DEBUG(llvm::dbgs() << "  * Result: " << array[index] << "\n");
16173eb1647aSStanislav Funiak       value = array[index];
16183eb1647aSStanislav Funiak       break;
16193eb1647aSStanislav Funiak     }
16203eb1647aSStanislav Funiak 
16213eb1647aSStanislav Funiak     LLVM_DEBUG(llvm::dbgs() << "  * Done\n");
16223eb1647aSStanislav Funiak     index = 0;
16233eb1647aSStanislav Funiak     selectJump(size_t(0));
16243eb1647aSStanislav Funiak     return;
16253eb1647aSStanislav Funiak   }
16263eb1647aSStanislav Funiak   default:
16273eb1647aSStanislav Funiak     llvm_unreachable("unexpected `ForEach` value kind");
16283eb1647aSStanislav Funiak   }
16293eb1647aSStanislav Funiak 
16303eb1647aSStanislav Funiak   // Store the iterate value and the stack address.
16313eb1647aSStanislav Funiak   memory[memIndex] = value;
1632d35f1190SStanislav Funiak   pushCodeIt(prevCodeIt);
16333eb1647aSStanislav Funiak 
16343eb1647aSStanislav Funiak   // Skip over the successor (we will enter the body of the loop).
16353eb1647aSStanislav Funiak   read<ByteCodeAddr>();
16363eb1647aSStanislav Funiak }
16373eb1647aSStanislav Funiak 
executeGetAttribute()1638154cabe7SRiver Riddle void ByteCodeExecutor::executeGetAttribute() {
1639abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1640abfd1a8bSRiver Riddle   unsigned memIndex = read();
1641abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1642195730a6SRiver Riddle   StringAttr attrName = read<StringAttr>();
1643abfd1a8bSRiver Riddle   Attribute attr = op->getAttr(attrName);
1644abfd1a8bSRiver Riddle 
1645abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1646abfd1a8bSRiver Riddle                           << "  * Attribute: " << attrName << "\n"
1647154cabe7SRiver Riddle                           << "  * Result: " << attr << "\n");
1648abfd1a8bSRiver Riddle   memory[memIndex] = attr.getAsOpaquePointer();
1649abfd1a8bSRiver Riddle }
1650154cabe7SRiver Riddle 
executeGetAttributeType()1651154cabe7SRiver Riddle void ByteCodeExecutor::executeGetAttributeType() {
1652abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1653abfd1a8bSRiver Riddle   unsigned memIndex = read();
1654abfd1a8bSRiver Riddle   Attribute attr = read<Attribute>();
1655154cabe7SRiver Riddle   Type type = attr ? attr.getType() : Type();
1656abfd1a8bSRiver Riddle 
1657abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1658154cabe7SRiver Riddle                           << "  * Result: " << type << "\n");
1659154cabe7SRiver Riddle   memory[memIndex] = type.getAsOpaquePointer();
1660abfd1a8bSRiver Riddle }
1661154cabe7SRiver Riddle 
executeGetDefiningOp()1662154cabe7SRiver Riddle void ByteCodeExecutor::executeGetDefiningOp() {
1663abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1664abfd1a8bSRiver Riddle   unsigned memIndex = read();
166585ab413bSRiver Riddle   Operation *op = nullptr;
166685ab413bSRiver Riddle   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1667abfd1a8bSRiver Riddle     Value value = read<Value>();
166885ab413bSRiver Riddle     if (value)
166985ab413bSRiver Riddle       op = value.getDefiningOp();
167085ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
167185ab413bSRiver Riddle   } else {
167285ab413bSRiver Riddle     ValueRange *values = read<ValueRange *>();
167385ab413bSRiver Riddle     if (values && !values->empty()) {
167485ab413bSRiver Riddle       op = values->front().getDefiningOp();
167585ab413bSRiver Riddle     }
167685ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Values: " << values << "\n");
167785ab413bSRiver Riddle   }
1678abfd1a8bSRiver Riddle 
167985ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << op << "\n");
1680abfd1a8bSRiver Riddle   memory[memIndex] = op;
1681abfd1a8bSRiver Riddle }
1682154cabe7SRiver Riddle 
executeGetOperand(unsigned index)1683154cabe7SRiver Riddle void ByteCodeExecutor::executeGetOperand(unsigned index) {
1684abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1685abfd1a8bSRiver Riddle   unsigned memIndex = read();
1686abfd1a8bSRiver Riddle   Value operand =
1687abfd1a8bSRiver Riddle       index < op->getNumOperands() ? op->getOperand(index) : Value();
1688abfd1a8bSRiver Riddle 
1689abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1690abfd1a8bSRiver Riddle                           << "  * Index: " << index << "\n"
1691154cabe7SRiver Riddle                           << "  * Result: " << operand << "\n");
1692abfd1a8bSRiver Riddle   memory[memIndex] = operand.getAsOpaquePointer();
1693abfd1a8bSRiver Riddle }
1694154cabe7SRiver Riddle 
169585ab413bSRiver Riddle /// This function is the internal implementation of `GetResults` and
169685ab413bSRiver Riddle /// `GetOperands` that provides support for extracting a value range from the
169785ab413bSRiver Riddle /// given operation.
169885ab413bSRiver Riddle template <template <typename> class AttrSizedSegmentsT, typename RangeT>
169985ab413bSRiver Riddle static void *
executeGetOperandsResults(RangeT values,Operation * op,unsigned index,ByteCodeField rangeIndex,StringRef attrSizedSegments,MutableArrayRef<ValueRange> valueRangeMemory)170085ab413bSRiver Riddle executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
170185ab413bSRiver Riddle                           ByteCodeField rangeIndex, StringRef attrSizedSegments,
17023eb1647aSStanislav Funiak                           MutableArrayRef<ValueRange> valueRangeMemory) {
170385ab413bSRiver Riddle   // Check for the sentinel index that signals that all values should be
170485ab413bSRiver Riddle   // returned.
170585ab413bSRiver Riddle   if (index == std::numeric_limits<uint32_t>::max()) {
170685ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Getting all values\n");
170785ab413bSRiver Riddle     // `values` is already the full value range.
170885ab413bSRiver Riddle 
170985ab413bSRiver Riddle     // Otherwise, check to see if this operation uses AttrSizedSegments.
171085ab413bSRiver Riddle   } else if (op->hasTrait<AttrSizedSegmentsT>()) {
171185ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs()
171285ab413bSRiver Riddle                << "  * Extracting values from `" << attrSizedSegments << "`\n");
171385ab413bSRiver Riddle 
171485ab413bSRiver Riddle     auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments);
171585ab413bSRiver Riddle     if (!segmentAttr || segmentAttr.getNumElements() <= index)
171685ab413bSRiver Riddle       return nullptr;
171785ab413bSRiver Riddle 
171885ab413bSRiver Riddle     auto segments = segmentAttr.getValues<int32_t>();
171985ab413bSRiver Riddle     unsigned startIndex =
172085ab413bSRiver Riddle         std::accumulate(segments.begin(), segments.begin() + index, 0);
172185ab413bSRiver Riddle     values = values.slice(startIndex, *std::next(segments.begin(), index));
172285ab413bSRiver Riddle 
172385ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Extracting range[" << startIndex << ", "
172485ab413bSRiver Riddle                             << *std::next(segments.begin(), index) << "]\n");
172585ab413bSRiver Riddle 
172685ab413bSRiver Riddle     // Otherwise, assume this is the last operand group of the operation.
172785ab413bSRiver Riddle     // FIXME: We currently don't support operations with
172885ab413bSRiver Riddle     // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
172985ab413bSRiver Riddle     // have a way to detect it's presence.
173085ab413bSRiver Riddle   } else if (values.size() >= index) {
173185ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs()
173285ab413bSRiver Riddle                << "  * Treating values as trailing variadic range\n");
173385ab413bSRiver Riddle     values = values.drop_front(index);
173485ab413bSRiver Riddle 
173585ab413bSRiver Riddle     // If we couldn't detect a way to compute the values, bail out.
173685ab413bSRiver Riddle   } else {
173785ab413bSRiver Riddle     return nullptr;
173885ab413bSRiver Riddle   }
173985ab413bSRiver Riddle 
174085ab413bSRiver Riddle   // If the range index is valid, we are returning a range.
174185ab413bSRiver Riddle   if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
174285ab413bSRiver Riddle     valueRangeMemory[rangeIndex] = values;
174385ab413bSRiver Riddle     return &valueRangeMemory[rangeIndex];
174485ab413bSRiver Riddle   }
174585ab413bSRiver Riddle 
174685ab413bSRiver Riddle   // If a range index wasn't provided, the range is required to be non-variadic.
174785ab413bSRiver Riddle   return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
174885ab413bSRiver Riddle }
174985ab413bSRiver Riddle 
executeGetOperands()175085ab413bSRiver Riddle void ByteCodeExecutor::executeGetOperands() {
175185ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
175285ab413bSRiver Riddle   unsigned index = read<uint32_t>();
175385ab413bSRiver Riddle   Operation *op = read<Operation *>();
175485ab413bSRiver Riddle   ByteCodeField rangeIndex = read();
175585ab413bSRiver Riddle 
175685ab413bSRiver Riddle   void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
175785ab413bSRiver Riddle       op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
175885ab413bSRiver Riddle       valueRangeMemory);
175985ab413bSRiver Riddle   if (!result)
176085ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Invalid operand range\n");
176185ab413bSRiver Riddle   memory[read()] = result;
176285ab413bSRiver Riddle }
176385ab413bSRiver Riddle 
executeGetResult(unsigned index)1764154cabe7SRiver Riddle void ByteCodeExecutor::executeGetResult(unsigned index) {
1765abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1766abfd1a8bSRiver Riddle   unsigned memIndex = read();
1767abfd1a8bSRiver Riddle   OpResult result =
1768abfd1a8bSRiver Riddle       index < op->getNumResults() ? op->getResult(index) : OpResult();
1769abfd1a8bSRiver Riddle 
1770abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1771abfd1a8bSRiver Riddle                           << "  * Index: " << index << "\n"
1772154cabe7SRiver Riddle                           << "  * Result: " << result << "\n");
1773abfd1a8bSRiver Riddle   memory[memIndex] = result.getAsOpaquePointer();
1774abfd1a8bSRiver Riddle }
1775154cabe7SRiver Riddle 
executeGetResults()177685ab413bSRiver Riddle void ByteCodeExecutor::executeGetResults() {
177785ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
177885ab413bSRiver Riddle   unsigned index = read<uint32_t>();
177985ab413bSRiver Riddle   Operation *op = read<Operation *>();
178085ab413bSRiver Riddle   ByteCodeField rangeIndex = read();
178185ab413bSRiver Riddle 
178285ab413bSRiver Riddle   void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
178385ab413bSRiver Riddle       op->getResults(), op, index, rangeIndex, "result_segment_sizes",
178485ab413bSRiver Riddle       valueRangeMemory);
178585ab413bSRiver Riddle   if (!result)
178685ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Invalid result range\n");
178785ab413bSRiver Riddle   memory[read()] = result;
178885ab413bSRiver Riddle }
178985ab413bSRiver Riddle 
executeGetUsers()17903eb1647aSStanislav Funiak void ByteCodeExecutor::executeGetUsers() {
17913eb1647aSStanislav Funiak   LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
17923eb1647aSStanislav Funiak   unsigned memIndex = read();
17933eb1647aSStanislav Funiak   unsigned rangeIndex = read();
17943eb1647aSStanislav Funiak   OwningOpRange &range = opRangeMemory[rangeIndex];
17953eb1647aSStanislav Funiak   memory[memIndex] = &range;
17963eb1647aSStanislav Funiak 
17973eb1647aSStanislav Funiak   range = OwningOpRange();
17983eb1647aSStanislav Funiak   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
17993eb1647aSStanislav Funiak     // Read the value.
18003eb1647aSStanislav Funiak     Value value = read<Value>();
18013eb1647aSStanislav Funiak     if (!value)
18023eb1647aSStanislav Funiak       return;
18033eb1647aSStanislav Funiak     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
18043eb1647aSStanislav Funiak 
18053eb1647aSStanislav Funiak     // Extract the users of a single value.
18063eb1647aSStanislav Funiak     range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
18073eb1647aSStanislav Funiak     llvm::copy(value.getUsers(), range.begin());
18083eb1647aSStanislav Funiak   } else {
18093eb1647aSStanislav Funiak     // Read a range of values.
18103eb1647aSStanislav Funiak     ValueRange *values = read<ValueRange *>();
18113eb1647aSStanislav Funiak     if (!values)
18123eb1647aSStanislav Funiak       return;
18133eb1647aSStanislav Funiak     LLVM_DEBUG({
18143eb1647aSStanislav Funiak       llvm::dbgs() << "  * Values (" << values->size() << "): ";
18153eb1647aSStanislav Funiak       llvm::interleaveComma(*values, llvm::dbgs());
18163eb1647aSStanislav Funiak       llvm::dbgs() << "\n";
18173eb1647aSStanislav Funiak     });
18183eb1647aSStanislav Funiak 
18193eb1647aSStanislav Funiak     // Extract all the users of a range of values.
18203eb1647aSStanislav Funiak     SmallVector<Operation *> users;
18213eb1647aSStanislav Funiak     for (Value value : *values)
18223eb1647aSStanislav Funiak       users.append(value.user_begin(), value.user_end());
18233eb1647aSStanislav Funiak     range = OwningOpRange(users.size());
18243eb1647aSStanislav Funiak     llvm::copy(users, range.begin());
18253eb1647aSStanislav Funiak   }
18263eb1647aSStanislav Funiak 
18273eb1647aSStanislav Funiak   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << range.size() << " operations\n");
18283eb1647aSStanislav Funiak }
18293eb1647aSStanislav Funiak 
executeGetValueType()1830154cabe7SRiver Riddle void ByteCodeExecutor::executeGetValueType() {
1831abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1832abfd1a8bSRiver Riddle   unsigned memIndex = read();
1833abfd1a8bSRiver Riddle   Value value = read<Value>();
1834154cabe7SRiver Riddle   Type type = value ? value.getType() : Type();
1835abfd1a8bSRiver Riddle 
1836abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1837154cabe7SRiver Riddle                           << "  * Result: " << type << "\n");
1838154cabe7SRiver Riddle   memory[memIndex] = type.getAsOpaquePointer();
1839abfd1a8bSRiver Riddle }
1840154cabe7SRiver Riddle 
executeGetValueRangeTypes()184185ab413bSRiver Riddle void ByteCodeExecutor::executeGetValueRangeTypes() {
184285ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
184385ab413bSRiver Riddle   unsigned memIndex = read();
184485ab413bSRiver Riddle   unsigned rangeIndex = read();
184585ab413bSRiver Riddle   ValueRange *values = read<ValueRange *>();
184685ab413bSRiver Riddle   if (!values) {
184785ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Values: <NULL>\n\n");
184885ab413bSRiver Riddle     memory[memIndex] = nullptr;
184985ab413bSRiver Riddle     return;
185085ab413bSRiver Riddle   }
185185ab413bSRiver Riddle 
185285ab413bSRiver Riddle   LLVM_DEBUG({
185385ab413bSRiver Riddle     llvm::dbgs() << "  * Values (" << values->size() << "): ";
185485ab413bSRiver Riddle     llvm::interleaveComma(*values, llvm::dbgs());
185585ab413bSRiver Riddle     llvm::dbgs() << "\n  * Result: ";
185685ab413bSRiver Riddle     llvm::interleaveComma(values->getType(), llvm::dbgs());
185785ab413bSRiver Riddle     llvm::dbgs() << "\n";
185885ab413bSRiver Riddle   });
185985ab413bSRiver Riddle   typeRangeMemory[rangeIndex] = values->getType();
186085ab413bSRiver Riddle   memory[memIndex] = &typeRangeMemory[rangeIndex];
186185ab413bSRiver Riddle }
186285ab413bSRiver Riddle 
executeIsNotNull()1863154cabe7SRiver Riddle void ByteCodeExecutor::executeIsNotNull() {
1864abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1865abfd1a8bSRiver Riddle   const void *value = read<const void *>();
1866abfd1a8bSRiver Riddle 
1867154cabe7SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1868abfd1a8bSRiver Riddle   selectJump(value != nullptr);
1869abfd1a8bSRiver Riddle }
1870154cabe7SRiver Riddle 
executeRecordMatch(PatternRewriter & rewriter,SmallVectorImpl<PDLByteCode::MatchResult> & matches)1871154cabe7SRiver Riddle void ByteCodeExecutor::executeRecordMatch(
1872154cabe7SRiver Riddle     PatternRewriter &rewriter,
1873154cabe7SRiver Riddle     SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1874abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1875abfd1a8bSRiver Riddle   unsigned patternIndex = read();
1876abfd1a8bSRiver Riddle   PatternBenefit benefit = currentPatternBenefits[patternIndex];
1877abfd1a8bSRiver Riddle   const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1878abfd1a8bSRiver Riddle 
1879abfd1a8bSRiver Riddle   // If the benefit of the pattern is impossible, skip the processing of the
1880abfd1a8bSRiver Riddle   // rest of the pattern.
1881abfd1a8bSRiver Riddle   if (benefit.isImpossibleToMatch()) {
1882154cabe7SRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
1883abfd1a8bSRiver Riddle     curCodeIt = dest;
1884154cabe7SRiver Riddle     return;
1885abfd1a8bSRiver Riddle   }
1886abfd1a8bSRiver Riddle 
1887abfd1a8bSRiver Riddle   // Create a fused location containing the locations of each of the
1888abfd1a8bSRiver Riddle   // operations used in the match. This will be used as the location for
1889abfd1a8bSRiver Riddle   // created operations during the rewrite that don't already have an
1890abfd1a8bSRiver Riddle   // explicit location set.
1891abfd1a8bSRiver Riddle   unsigned numMatchLocs = read();
1892abfd1a8bSRiver Riddle   SmallVector<Location, 4> matchLocs;
1893abfd1a8bSRiver Riddle   matchLocs.reserve(numMatchLocs);
1894abfd1a8bSRiver Riddle   for (unsigned i = 0; i != numMatchLocs; ++i)
1895abfd1a8bSRiver Riddle     matchLocs.push_back(read<Operation *>()->getLoc());
1896abfd1a8bSRiver Riddle   Location matchLoc = rewriter.getFusedLoc(matchLocs);
1897abfd1a8bSRiver Riddle 
1898abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
1899154cabe7SRiver Riddle                           << "  * Location: " << matchLoc << "\n");
1900154cabe7SRiver Riddle   matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
190185ab413bSRiver Riddle   PDLByteCode::MatchResult &match = matches.back();
190285ab413bSRiver Riddle 
190385ab413bSRiver Riddle   // Record all of the inputs to the match. If any of the inputs are ranges, we
190485ab413bSRiver Riddle   // will also need to remap the range pointer to memory stored in the match
190585ab413bSRiver Riddle   // state.
190685ab413bSRiver Riddle   unsigned numInputs = read();
190785ab413bSRiver Riddle   match.values.reserve(numInputs);
190885ab413bSRiver Riddle   match.typeRangeValues.reserve(numInputs);
190985ab413bSRiver Riddle   match.valueRangeValues.reserve(numInputs);
191085ab413bSRiver Riddle   for (unsigned i = 0; i < numInputs; ++i) {
191185ab413bSRiver Riddle     switch (read<PDLValue::Kind>()) {
191285ab413bSRiver Riddle     case PDLValue::Kind::TypeRange:
191385ab413bSRiver Riddle       match.typeRangeValues.push_back(*read<TypeRange *>());
191485ab413bSRiver Riddle       match.values.push_back(&match.typeRangeValues.back());
191585ab413bSRiver Riddle       break;
191685ab413bSRiver Riddle     case PDLValue::Kind::ValueRange:
191785ab413bSRiver Riddle       match.valueRangeValues.push_back(*read<ValueRange *>());
191885ab413bSRiver Riddle       match.values.push_back(&match.valueRangeValues.back());
191985ab413bSRiver Riddle       break;
192085ab413bSRiver Riddle     default:
192185ab413bSRiver Riddle       match.values.push_back(read<const void *>());
192285ab413bSRiver Riddle       break;
192385ab413bSRiver Riddle     }
192485ab413bSRiver Riddle   }
1925abfd1a8bSRiver Riddle   curCodeIt = dest;
1926abfd1a8bSRiver Riddle }
1927154cabe7SRiver Riddle 
executeReplaceOp(PatternRewriter & rewriter)1928154cabe7SRiver Riddle void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
1929abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1930abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1931abfd1a8bSRiver Riddle   SmallVector<Value, 16> args;
193285ab413bSRiver Riddle   readValueList(args);
1933abfd1a8bSRiver Riddle 
1934abfd1a8bSRiver Riddle   LLVM_DEBUG({
1935abfd1a8bSRiver Riddle     llvm::dbgs() << "  * Operation: " << *op << "\n"
1936abfd1a8bSRiver Riddle                  << "  * Values: ";
1937abfd1a8bSRiver Riddle     llvm::interleaveComma(args, llvm::dbgs());
1938154cabe7SRiver Riddle     llvm::dbgs() << "\n";
1939abfd1a8bSRiver Riddle   });
1940abfd1a8bSRiver Riddle   rewriter.replaceOp(op, args);
1941abfd1a8bSRiver Riddle }
1942154cabe7SRiver Riddle 
executeSwitchAttribute()1943154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchAttribute() {
1944abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1945abfd1a8bSRiver Riddle   Attribute value = read<Attribute>();
1946abfd1a8bSRiver Riddle   ArrayAttr cases = read<ArrayAttr>();
1947abfd1a8bSRiver Riddle   handleSwitch(value, cases);
1948abfd1a8bSRiver Riddle }
1949154cabe7SRiver Riddle 
executeSwitchOperandCount()1950154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchOperandCount() {
1951abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1952abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1953abfd1a8bSRiver Riddle   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1954abfd1a8bSRiver Riddle 
1955abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1956abfd1a8bSRiver Riddle   handleSwitch(op->getNumOperands(), cases);
1957abfd1a8bSRiver Riddle }
1958154cabe7SRiver Riddle 
executeSwitchOperationName()1959154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchOperationName() {
1960abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
1961abfd1a8bSRiver Riddle   OperationName value = read<Operation *>()->getName();
1962abfd1a8bSRiver Riddle   size_t caseCount = read();
1963abfd1a8bSRiver Riddle 
1964abfd1a8bSRiver Riddle   // The operation names are stored in-line, so to print them out for
1965abfd1a8bSRiver Riddle   // debugging purposes we need to read the array before executing the
1966abfd1a8bSRiver Riddle   // switch so that we can display all of the possible values.
1967abfd1a8bSRiver Riddle   LLVM_DEBUG({
1968abfd1a8bSRiver Riddle     const ByteCodeField *prevCodeIt = curCodeIt;
1969abfd1a8bSRiver Riddle     llvm::dbgs() << "  * Value: " << value << "\n"
1970abfd1a8bSRiver Riddle                  << "  * Cases: ";
1971abfd1a8bSRiver Riddle     llvm::interleaveComma(
1972abfd1a8bSRiver Riddle         llvm::map_range(llvm::seq<size_t>(0, caseCount),
1973154cabe7SRiver Riddle                         [&](size_t) { return read<OperationName>(); }),
1974abfd1a8bSRiver Riddle         llvm::dbgs());
1975154cabe7SRiver Riddle     llvm::dbgs() << "\n";
1976abfd1a8bSRiver Riddle     curCodeIt = prevCodeIt;
1977abfd1a8bSRiver Riddle   });
1978abfd1a8bSRiver Riddle 
1979abfd1a8bSRiver Riddle   // Try to find the switch value within any of the cases.
1980abfd1a8bSRiver Riddle   for (size_t i = 0; i != caseCount; ++i) {
1981abfd1a8bSRiver Riddle     if (read<OperationName>() == value) {
1982abfd1a8bSRiver Riddle       curCodeIt += (caseCount - i - 1);
1983154cabe7SRiver Riddle       return selectJump(i + 1);
1984abfd1a8bSRiver Riddle     }
1985abfd1a8bSRiver Riddle   }
1986154cabe7SRiver Riddle   selectJump(size_t(0));
1987abfd1a8bSRiver Riddle }
1988154cabe7SRiver Riddle 
executeSwitchResultCount()1989154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchResultCount() {
1990abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1991abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1992abfd1a8bSRiver Riddle   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1993abfd1a8bSRiver Riddle 
1994abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1995abfd1a8bSRiver Riddle   handleSwitch(op->getNumResults(), cases);
1996abfd1a8bSRiver Riddle }
1997154cabe7SRiver Riddle 
executeSwitchType()1998154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchType() {
1999abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2000abfd1a8bSRiver Riddle   Type value = read<Type>();
2001abfd1a8bSRiver Riddle   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2002abfd1a8bSRiver Riddle   handleSwitch(value, cases);
2003154cabe7SRiver Riddle }
2004154cabe7SRiver Riddle 
executeSwitchTypes()200585ab413bSRiver Riddle void ByteCodeExecutor::executeSwitchTypes() {
200685ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
200785ab413bSRiver Riddle   TypeRange *value = read<TypeRange *>();
200885ab413bSRiver Riddle   auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
200985ab413bSRiver Riddle   if (!value) {
201085ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
201185ab413bSRiver Riddle     return selectJump(size_t(0));
201285ab413bSRiver Riddle   }
201385ab413bSRiver Riddle   handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
201485ab413bSRiver Riddle     return value == caseValue.getAsValueRange<TypeAttr>();
201585ab413bSRiver Riddle   });
201685ab413bSRiver Riddle }
201785ab413bSRiver Riddle 
execute(PatternRewriter & rewriter,SmallVectorImpl<PDLByteCode::MatchResult> * matches,Optional<Location> mainRewriteLoc)2018154cabe7SRiver Riddle void ByteCodeExecutor::execute(
2019154cabe7SRiver Riddle     PatternRewriter &rewriter,
2020154cabe7SRiver Riddle     SmallVectorImpl<PDLByteCode::MatchResult> *matches,
2021154cabe7SRiver Riddle     Optional<Location> mainRewriteLoc) {
2022154cabe7SRiver Riddle   while (true) {
2023d35f1190SStanislav Funiak     // Print the location of the operation being executed.
2024d35f1190SStanislav Funiak     LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2025d35f1190SStanislav Funiak 
2026154cabe7SRiver Riddle     OpCode opCode = static_cast<OpCode>(read());
2027154cabe7SRiver Riddle     switch (opCode) {
2028154cabe7SRiver Riddle     case ApplyConstraint:
2029154cabe7SRiver Riddle       executeApplyConstraint(rewriter);
2030154cabe7SRiver Riddle       break;
2031154cabe7SRiver Riddle     case ApplyRewrite:
2032154cabe7SRiver Riddle       executeApplyRewrite(rewriter);
2033154cabe7SRiver Riddle       break;
2034154cabe7SRiver Riddle     case AreEqual:
2035154cabe7SRiver Riddle       executeAreEqual();
2036154cabe7SRiver Riddle       break;
203785ab413bSRiver Riddle     case AreRangesEqual:
203885ab413bSRiver Riddle       executeAreRangesEqual();
203985ab413bSRiver Riddle       break;
2040154cabe7SRiver Riddle     case Branch:
2041154cabe7SRiver Riddle       executeBranch();
2042154cabe7SRiver Riddle       break;
2043154cabe7SRiver Riddle     case CheckOperandCount:
2044154cabe7SRiver Riddle       executeCheckOperandCount();
2045154cabe7SRiver Riddle       break;
2046154cabe7SRiver Riddle     case CheckOperationName:
2047154cabe7SRiver Riddle       executeCheckOperationName();
2048154cabe7SRiver Riddle       break;
2049154cabe7SRiver Riddle     case CheckResultCount:
2050154cabe7SRiver Riddle       executeCheckResultCount();
2051154cabe7SRiver Riddle       break;
205285ab413bSRiver Riddle     case CheckTypes:
205385ab413bSRiver Riddle       executeCheckTypes();
205485ab413bSRiver Riddle       break;
20553eb1647aSStanislav Funiak     case Continue:
20563eb1647aSStanislav Funiak       executeContinue();
20573eb1647aSStanislav Funiak       break;
2058154cabe7SRiver Riddle     case CreateOperation:
2059154cabe7SRiver Riddle       executeCreateOperation(rewriter, *mainRewriteLoc);
2060154cabe7SRiver Riddle       break;
206185ab413bSRiver Riddle     case CreateTypes:
206285ab413bSRiver Riddle       executeCreateTypes();
206385ab413bSRiver Riddle       break;
2064154cabe7SRiver Riddle     case EraseOp:
2065154cabe7SRiver Riddle       executeEraseOp(rewriter);
2066154cabe7SRiver Riddle       break;
20673eb1647aSStanislav Funiak     case ExtractOp:
20683eb1647aSStanislav Funiak       executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
20693eb1647aSStanislav Funiak       break;
20703eb1647aSStanislav Funiak     case ExtractType:
20713eb1647aSStanislav Funiak       executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
20723eb1647aSStanislav Funiak       break;
20733eb1647aSStanislav Funiak     case ExtractValue:
20743eb1647aSStanislav Funiak       executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
20753eb1647aSStanislav Funiak       break;
2076154cabe7SRiver Riddle     case Finalize:
20773eb1647aSStanislav Funiak       executeFinalize();
20783eb1647aSStanislav Funiak       LLVM_DEBUG(llvm::dbgs() << "\n");
2079154cabe7SRiver Riddle       return;
20803eb1647aSStanislav Funiak     case ForEach:
20813eb1647aSStanislav Funiak       executeForEach();
20823eb1647aSStanislav Funiak       break;
2083154cabe7SRiver Riddle     case GetAttribute:
2084154cabe7SRiver Riddle       executeGetAttribute();
2085154cabe7SRiver Riddle       break;
2086154cabe7SRiver Riddle     case GetAttributeType:
2087154cabe7SRiver Riddle       executeGetAttributeType();
2088154cabe7SRiver Riddle       break;
2089154cabe7SRiver Riddle     case GetDefiningOp:
2090154cabe7SRiver Riddle       executeGetDefiningOp();
2091154cabe7SRiver Riddle       break;
2092154cabe7SRiver Riddle     case GetOperand0:
2093154cabe7SRiver Riddle     case GetOperand1:
2094154cabe7SRiver Riddle     case GetOperand2:
2095154cabe7SRiver Riddle     case GetOperand3: {
2096154cabe7SRiver Riddle       unsigned index = opCode - GetOperand0;
2097154cabe7SRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
20981fff7c89SFrederik Gossen       executeGetOperand(index);
2099abfd1a8bSRiver Riddle       break;
2100abfd1a8bSRiver Riddle     }
2101154cabe7SRiver Riddle     case GetOperandN:
2102154cabe7SRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2103154cabe7SRiver Riddle       executeGetOperand(read<uint32_t>());
2104154cabe7SRiver Riddle       break;
210585ab413bSRiver Riddle     case GetOperands:
210685ab413bSRiver Riddle       executeGetOperands();
210785ab413bSRiver Riddle       break;
2108154cabe7SRiver Riddle     case GetResult0:
2109154cabe7SRiver Riddle     case GetResult1:
2110154cabe7SRiver Riddle     case GetResult2:
2111154cabe7SRiver Riddle     case GetResult3: {
2112154cabe7SRiver Riddle       unsigned index = opCode - GetResult0;
2113154cabe7SRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
21141fff7c89SFrederik Gossen       executeGetResult(index);
2115154cabe7SRiver Riddle       break;
2116abfd1a8bSRiver Riddle     }
2117154cabe7SRiver Riddle     case GetResultN:
2118154cabe7SRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2119154cabe7SRiver Riddle       executeGetResult(read<uint32_t>());
2120154cabe7SRiver Riddle       break;
212185ab413bSRiver Riddle     case GetResults:
212285ab413bSRiver Riddle       executeGetResults();
212385ab413bSRiver Riddle       break;
21243eb1647aSStanislav Funiak     case GetUsers:
21253eb1647aSStanislav Funiak       executeGetUsers();
21263eb1647aSStanislav Funiak       break;
2127154cabe7SRiver Riddle     case GetValueType:
2128154cabe7SRiver Riddle       executeGetValueType();
2129154cabe7SRiver Riddle       break;
213085ab413bSRiver Riddle     case GetValueRangeTypes:
213185ab413bSRiver Riddle       executeGetValueRangeTypes();
213285ab413bSRiver Riddle       break;
2133154cabe7SRiver Riddle     case IsNotNull:
2134154cabe7SRiver Riddle       executeIsNotNull();
2135154cabe7SRiver Riddle       break;
2136154cabe7SRiver Riddle     case RecordMatch:
2137154cabe7SRiver Riddle       assert(matches &&
2138154cabe7SRiver Riddle              "expected matches to be provided when executing the matcher");
2139154cabe7SRiver Riddle       executeRecordMatch(rewriter, *matches);
2140154cabe7SRiver Riddle       break;
2141154cabe7SRiver Riddle     case ReplaceOp:
2142154cabe7SRiver Riddle       executeReplaceOp(rewriter);
2143154cabe7SRiver Riddle       break;
2144154cabe7SRiver Riddle     case SwitchAttribute:
2145154cabe7SRiver Riddle       executeSwitchAttribute();
2146154cabe7SRiver Riddle       break;
2147154cabe7SRiver Riddle     case SwitchOperandCount:
2148154cabe7SRiver Riddle       executeSwitchOperandCount();
2149154cabe7SRiver Riddle       break;
2150154cabe7SRiver Riddle     case SwitchOperationName:
2151154cabe7SRiver Riddle       executeSwitchOperationName();
2152154cabe7SRiver Riddle       break;
2153154cabe7SRiver Riddle     case SwitchResultCount:
2154154cabe7SRiver Riddle       executeSwitchResultCount();
2155154cabe7SRiver Riddle       break;
2156154cabe7SRiver Riddle     case SwitchType:
2157154cabe7SRiver Riddle       executeSwitchType();
2158154cabe7SRiver Riddle       break;
215985ab413bSRiver Riddle     case SwitchTypes:
216085ab413bSRiver Riddle       executeSwitchTypes();
216185ab413bSRiver Riddle       break;
2162154cabe7SRiver Riddle     }
2163154cabe7SRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "\n");
2164abfd1a8bSRiver Riddle   }
2165abfd1a8bSRiver Riddle }
2166abfd1a8bSRiver Riddle 
2167abfd1a8bSRiver Riddle /// Run the pattern matcher on the given root operation, collecting the matched
2168abfd1a8bSRiver Riddle /// patterns in `matches`.
match(Operation * op,PatternRewriter & rewriter,SmallVectorImpl<MatchResult> & matches,PDLByteCodeMutableState & state) const2169abfd1a8bSRiver Riddle void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2170abfd1a8bSRiver Riddle                         SmallVectorImpl<MatchResult> &matches,
2171abfd1a8bSRiver Riddle                         PDLByteCodeMutableState &state) const {
2172abfd1a8bSRiver Riddle   // The first memory slot is always the root operation.
2173abfd1a8bSRiver Riddle   state.memory[0] = op;
2174abfd1a8bSRiver Riddle 
2175abfd1a8bSRiver Riddle   // The matcher function always starts at code address 0.
217685ab413bSRiver Riddle   ByteCodeExecutor executor(
21773eb1647aSStanislav Funiak       matcherByteCode.data(), state.memory, state.opRangeMemory,
21783eb1647aSStanislav Funiak       state.typeRangeMemory, state.allocatedTypeRangeMemory,
21793eb1647aSStanislav Funiak       state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
21803eb1647aSStanislav Funiak       uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
21813eb1647aSStanislav Funiak       constraintFunctions, rewriteFunctions);
2182abfd1a8bSRiver Riddle   executor.execute(rewriter, &matches);
2183abfd1a8bSRiver Riddle 
2184abfd1a8bSRiver Riddle   // Order the found matches by benefit.
2185abfd1a8bSRiver Riddle   std::stable_sort(matches.begin(), matches.end(),
2186abfd1a8bSRiver Riddle                    [](const MatchResult &lhs, const MatchResult &rhs) {
2187abfd1a8bSRiver Riddle                      return lhs.benefit > rhs.benefit;
2188abfd1a8bSRiver Riddle                    });
2189abfd1a8bSRiver Riddle }
2190abfd1a8bSRiver Riddle 
2191abfd1a8bSRiver Riddle /// Run the rewriter of the given pattern on the root operation `op`.
rewrite(PatternRewriter & rewriter,const MatchResult & match,PDLByteCodeMutableState & state) const2192abfd1a8bSRiver Riddle void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
2193abfd1a8bSRiver Riddle                           PDLByteCodeMutableState &state) const {
2194abfd1a8bSRiver Riddle   // The arguments of the rewrite function are stored at the start of the
2195abfd1a8bSRiver Riddle   // memory buffer.
2196abfd1a8bSRiver Riddle   llvm::copy(match.values, state.memory.begin());
2197abfd1a8bSRiver Riddle 
219885ab413bSRiver Riddle   ByteCodeExecutor executor(
219985ab413bSRiver Riddle       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
22003eb1647aSStanislav Funiak       state.opRangeMemory, state.typeRangeMemory,
22013eb1647aSStanislav Funiak       state.allocatedTypeRangeMemory, state.valueRangeMemory,
22023eb1647aSStanislav Funiak       state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
220385ab413bSRiver Riddle       rewriterByteCode, state.currentPatternBenefits, patterns,
220402c4c0d5SRiver Riddle       constraintFunctions, rewriteFunctions);
2205abfd1a8bSRiver Riddle   executor.execute(rewriter, /*matches=*/nullptr, match.location);
2206abfd1a8bSRiver Riddle }
2207