1abfd1a8bSRiver Riddle //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2abfd1a8bSRiver Riddle //
3abfd1a8bSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4abfd1a8bSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5abfd1a8bSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6abfd1a8bSRiver Riddle //
7abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
8abfd1a8bSRiver Riddle //
9abfd1a8bSRiver Riddle // This file implements MLIR to byte-code generation and the interpreter.
10abfd1a8bSRiver Riddle //
11abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
12abfd1a8bSRiver Riddle 
13abfd1a8bSRiver Riddle #include "ByteCode.h"
14abfd1a8bSRiver Riddle #include "mlir/Analysis/Liveness.h"
15abfd1a8bSRiver Riddle #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16abfd1a8bSRiver Riddle #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17e66c2e25SRiver Riddle #include "mlir/IR/BuiltinOps.h"
18abfd1a8bSRiver Riddle #include "mlir/IR/RegionGraphTraits.h"
19abfd1a8bSRiver Riddle #include "llvm/ADT/IntervalMap.h"
20abfd1a8bSRiver Riddle #include "llvm/ADT/PostOrderIterator.h"
21abfd1a8bSRiver Riddle #include "llvm/ADT/TypeSwitch.h"
22abfd1a8bSRiver Riddle #include "llvm/Support/Debug.h"
2385ab413bSRiver Riddle #include "llvm/Support/Format.h"
2485ab413bSRiver Riddle #include "llvm/Support/FormatVariadic.h"
2585ab413bSRiver Riddle #include <numeric>
26abfd1a8bSRiver Riddle 
27abfd1a8bSRiver Riddle #define DEBUG_TYPE "pdl-bytecode"
28abfd1a8bSRiver Riddle 
29abfd1a8bSRiver Riddle using namespace mlir;
30abfd1a8bSRiver Riddle using namespace mlir::detail;
31abfd1a8bSRiver Riddle 
32abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
33abfd1a8bSRiver Riddle // PDLByteCodePattern
34abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
35abfd1a8bSRiver Riddle 
36abfd1a8bSRiver Riddle PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
37abfd1a8bSRiver Riddle                                               ByteCodeAddr rewriterAddr) {
38abfd1a8bSRiver Riddle   SmallVector<StringRef, 8> generatedOps;
39abfd1a8bSRiver Riddle   if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr())
40abfd1a8bSRiver Riddle     generatedOps =
41abfd1a8bSRiver Riddle         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
42abfd1a8bSRiver Riddle 
43abfd1a8bSRiver Riddle   PatternBenefit benefit = matchOp.benefit();
44abfd1a8bSRiver Riddle   MLIRContext *ctx = matchOp.getContext();
45abfd1a8bSRiver Riddle 
46abfd1a8bSRiver Riddle   // Check to see if this is pattern matches a specific operation type.
47abfd1a8bSRiver Riddle   if (Optional<StringRef> rootKind = matchOp.rootKind())
4876f3c2f3SRiver Riddle     return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
4976f3c2f3SRiver Riddle                               generatedOps);
5076f3c2f3SRiver Riddle   return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
5176f3c2f3SRiver Riddle                             generatedOps);
52abfd1a8bSRiver Riddle }
53abfd1a8bSRiver Riddle 
54abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
55abfd1a8bSRiver Riddle // PDLByteCodeMutableState
56abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
57abfd1a8bSRiver Riddle 
58abfd1a8bSRiver Riddle /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
59abfd1a8bSRiver Riddle /// to the position of the pattern within the range returned by
60abfd1a8bSRiver Riddle /// `PDLByteCode::getPatterns`.
61abfd1a8bSRiver Riddle void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
62abfd1a8bSRiver Riddle                                                    PatternBenefit benefit) {
63abfd1a8bSRiver Riddle   currentPatternBenefits[patternIndex] = benefit;
64abfd1a8bSRiver Riddle }
65abfd1a8bSRiver Riddle 
6685ab413bSRiver Riddle /// Cleanup any allocated state after a full match/rewrite has been completed.
6785ab413bSRiver Riddle /// This method should be called irregardless of whether the match+rewrite was a
6885ab413bSRiver Riddle /// success or not.
6985ab413bSRiver Riddle void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
7085ab413bSRiver Riddle   allocatedTypeRangeMemory.clear();
7185ab413bSRiver Riddle   allocatedValueRangeMemory.clear();
7285ab413bSRiver Riddle }
7385ab413bSRiver Riddle 
74abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
75abfd1a8bSRiver Riddle // Bytecode OpCodes
76abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
77abfd1a8bSRiver Riddle 
78abfd1a8bSRiver Riddle namespace {
79abfd1a8bSRiver Riddle enum OpCode : ByteCodeField {
80abfd1a8bSRiver Riddle   /// Apply an externally registered constraint.
81abfd1a8bSRiver Riddle   ApplyConstraint,
82abfd1a8bSRiver Riddle   /// Apply an externally registered rewrite.
83abfd1a8bSRiver Riddle   ApplyRewrite,
84abfd1a8bSRiver Riddle   /// Check if two generic values are equal.
85abfd1a8bSRiver Riddle   AreEqual,
8685ab413bSRiver Riddle   /// Check if two ranges are equal.
8785ab413bSRiver Riddle   AreRangesEqual,
88abfd1a8bSRiver Riddle   /// Unconditional branch.
89abfd1a8bSRiver Riddle   Branch,
90abfd1a8bSRiver Riddle   /// Compare the operand count of an operation with a constant.
91abfd1a8bSRiver Riddle   CheckOperandCount,
92abfd1a8bSRiver Riddle   /// Compare the name of an operation with a constant.
93abfd1a8bSRiver Riddle   CheckOperationName,
94abfd1a8bSRiver Riddle   /// Compare the result count of an operation with a constant.
95abfd1a8bSRiver Riddle   CheckResultCount,
9685ab413bSRiver Riddle   /// Compare a range of types to a constant range of types.
9785ab413bSRiver Riddle   CheckTypes,
98abfd1a8bSRiver Riddle   /// Create an operation.
99abfd1a8bSRiver Riddle   CreateOperation,
10085ab413bSRiver Riddle   /// Create a range of types.
10185ab413bSRiver Riddle   CreateTypes,
102abfd1a8bSRiver Riddle   /// Erase an operation.
103abfd1a8bSRiver Riddle   EraseOp,
104abfd1a8bSRiver Riddle   /// Terminate a matcher or rewrite sequence.
105abfd1a8bSRiver Riddle   Finalize,
106abfd1a8bSRiver Riddle   /// Get a specific attribute of an operation.
107abfd1a8bSRiver Riddle   GetAttribute,
108abfd1a8bSRiver Riddle   /// Get the type of an attribute.
109abfd1a8bSRiver Riddle   GetAttributeType,
110abfd1a8bSRiver Riddle   /// Get the defining operation of a value.
111abfd1a8bSRiver Riddle   GetDefiningOp,
112abfd1a8bSRiver Riddle   /// Get a specific operand of an operation.
113abfd1a8bSRiver Riddle   GetOperand0,
114abfd1a8bSRiver Riddle   GetOperand1,
115abfd1a8bSRiver Riddle   GetOperand2,
116abfd1a8bSRiver Riddle   GetOperand3,
117abfd1a8bSRiver Riddle   GetOperandN,
11885ab413bSRiver Riddle   /// Get a specific operand group of an operation.
11985ab413bSRiver Riddle   GetOperands,
120abfd1a8bSRiver Riddle   /// Get a specific result of an operation.
121abfd1a8bSRiver Riddle   GetResult0,
122abfd1a8bSRiver Riddle   GetResult1,
123abfd1a8bSRiver Riddle   GetResult2,
124abfd1a8bSRiver Riddle   GetResult3,
125abfd1a8bSRiver Riddle   GetResultN,
12685ab413bSRiver Riddle   /// Get a specific result group of an operation.
12785ab413bSRiver Riddle   GetResults,
128abfd1a8bSRiver Riddle   /// Get the type of a value.
129abfd1a8bSRiver Riddle   GetValueType,
13085ab413bSRiver Riddle   /// Get the types of a value range.
13185ab413bSRiver Riddle   GetValueRangeTypes,
132abfd1a8bSRiver Riddle   /// Check if a generic value is not null.
133abfd1a8bSRiver Riddle   IsNotNull,
134abfd1a8bSRiver Riddle   /// Record a successful pattern match.
135abfd1a8bSRiver Riddle   RecordMatch,
136abfd1a8bSRiver Riddle   /// Replace an operation.
137abfd1a8bSRiver Riddle   ReplaceOp,
138abfd1a8bSRiver Riddle   /// Compare an attribute with a set of constants.
139abfd1a8bSRiver Riddle   SwitchAttribute,
140abfd1a8bSRiver Riddle   /// Compare the operand count of an operation with a set of constants.
141abfd1a8bSRiver Riddle   SwitchOperandCount,
142abfd1a8bSRiver Riddle   /// Compare the name of an operation with a set of constants.
143abfd1a8bSRiver Riddle   SwitchOperationName,
144abfd1a8bSRiver Riddle   /// Compare the result count of an operation with a set of constants.
145abfd1a8bSRiver Riddle   SwitchResultCount,
146abfd1a8bSRiver Riddle   /// Compare a type with a set of constants.
147abfd1a8bSRiver Riddle   SwitchType,
14885ab413bSRiver Riddle   /// Compare a range of types with a set of constants.
14985ab413bSRiver Riddle   SwitchTypes,
150abfd1a8bSRiver Riddle };
151abfd1a8bSRiver Riddle } // end anonymous namespace
152abfd1a8bSRiver Riddle 
153abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
154abfd1a8bSRiver Riddle // ByteCode Generation
155abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
156abfd1a8bSRiver Riddle 
157abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
158abfd1a8bSRiver Riddle // Generator
159abfd1a8bSRiver Riddle 
160abfd1a8bSRiver Riddle namespace {
161abfd1a8bSRiver Riddle struct ByteCodeWriter;
162abfd1a8bSRiver Riddle 
163abfd1a8bSRiver Riddle /// This class represents the main generator for the pattern bytecode.
164abfd1a8bSRiver Riddle class Generator {
165abfd1a8bSRiver Riddle public:
166abfd1a8bSRiver Riddle   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
167abfd1a8bSRiver Riddle             SmallVectorImpl<ByteCodeField> &matcherByteCode,
168abfd1a8bSRiver Riddle             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
169abfd1a8bSRiver Riddle             SmallVectorImpl<PDLByteCodePattern> &patterns,
170abfd1a8bSRiver Riddle             ByteCodeField &maxValueMemoryIndex,
17185ab413bSRiver Riddle             ByteCodeField &maxTypeRangeMemoryIndex,
17285ab413bSRiver Riddle             ByteCodeField &maxValueRangeMemoryIndex,
173abfd1a8bSRiver Riddle             llvm::StringMap<PDLConstraintFunction> &constraintFns,
174abfd1a8bSRiver Riddle             llvm::StringMap<PDLRewriteFunction> &rewriteFns)
175abfd1a8bSRiver Riddle       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
176abfd1a8bSRiver Riddle         rewriterByteCode(rewriterByteCode), patterns(patterns),
17785ab413bSRiver Riddle         maxValueMemoryIndex(maxValueMemoryIndex),
17885ab413bSRiver Riddle         maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
17985ab413bSRiver Riddle         maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) {
180abfd1a8bSRiver Riddle     for (auto it : llvm::enumerate(constraintFns))
181abfd1a8bSRiver Riddle       constraintToMemIndex.try_emplace(it.value().first(), it.index());
182abfd1a8bSRiver Riddle     for (auto it : llvm::enumerate(rewriteFns))
183abfd1a8bSRiver Riddle       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
184abfd1a8bSRiver Riddle   }
185abfd1a8bSRiver Riddle 
186abfd1a8bSRiver Riddle   /// Generate the bytecode for the given PDL interpreter module.
187abfd1a8bSRiver Riddle   void generate(ModuleOp module);
188abfd1a8bSRiver Riddle 
189abfd1a8bSRiver Riddle   /// Return the memory index to use for the given value.
190abfd1a8bSRiver Riddle   ByteCodeField &getMemIndex(Value value) {
191abfd1a8bSRiver Riddle     assert(valueToMemIndex.count(value) &&
192abfd1a8bSRiver Riddle            "expected memory index to be assigned");
193abfd1a8bSRiver Riddle     return valueToMemIndex[value];
194abfd1a8bSRiver Riddle   }
195abfd1a8bSRiver Riddle 
19685ab413bSRiver Riddle   /// Return the range memory index used to store the given range value.
19785ab413bSRiver Riddle   ByteCodeField &getRangeStorageIndex(Value value) {
19885ab413bSRiver Riddle     assert(valueToRangeIndex.count(value) &&
19985ab413bSRiver Riddle            "expected range index to be assigned");
20085ab413bSRiver Riddle     return valueToRangeIndex[value];
20185ab413bSRiver Riddle   }
20285ab413bSRiver Riddle 
203abfd1a8bSRiver Riddle   /// Return an index to use when referring to the given data that is uniqued in
204abfd1a8bSRiver Riddle   /// the MLIR context.
205abfd1a8bSRiver Riddle   template <typename T>
206abfd1a8bSRiver Riddle   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
207abfd1a8bSRiver Riddle   getMemIndex(T val) {
208abfd1a8bSRiver Riddle     const void *opaqueVal = val.getAsOpaquePointer();
209abfd1a8bSRiver Riddle 
210abfd1a8bSRiver Riddle     // Get or insert a reference to this value.
211abfd1a8bSRiver Riddle     auto it = uniquedDataToMemIndex.try_emplace(
212abfd1a8bSRiver Riddle         opaqueVal, maxValueMemoryIndex + uniquedData.size());
213abfd1a8bSRiver Riddle     if (it.second)
214abfd1a8bSRiver Riddle       uniquedData.push_back(opaqueVal);
215abfd1a8bSRiver Riddle     return it.first->second;
216abfd1a8bSRiver Riddle   }
217abfd1a8bSRiver Riddle 
218abfd1a8bSRiver Riddle private:
219abfd1a8bSRiver Riddle   /// Allocate memory indices for the results of operations within the matcher
220abfd1a8bSRiver Riddle   /// and rewriters.
221abfd1a8bSRiver Riddle   void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
222abfd1a8bSRiver Riddle 
223abfd1a8bSRiver Riddle   /// Generate the bytecode for the given operation.
224abfd1a8bSRiver Riddle   void generate(Operation *op, ByteCodeWriter &writer);
225abfd1a8bSRiver Riddle   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
226abfd1a8bSRiver Riddle   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
227abfd1a8bSRiver Riddle   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
228abfd1a8bSRiver Riddle   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
229abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
230abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
231abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
232abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
233abfd1a8bSRiver Riddle   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
23485ab413bSRiver Riddle   void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
235abfd1a8bSRiver Riddle   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
236abfd1a8bSRiver Riddle   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
237abfd1a8bSRiver Riddle   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
23885ab413bSRiver Riddle   void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
239abfd1a8bSRiver Riddle   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
240abfd1a8bSRiver Riddle   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
241abfd1a8bSRiver Riddle   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
242abfd1a8bSRiver Riddle   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
243abfd1a8bSRiver Riddle   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
244abfd1a8bSRiver Riddle   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
24585ab413bSRiver Riddle   void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
246abfd1a8bSRiver Riddle   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
24785ab413bSRiver Riddle   void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
248abfd1a8bSRiver Riddle   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
2493a833a0eSRiver Riddle   void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
250abfd1a8bSRiver Riddle   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
251abfd1a8bSRiver Riddle   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
252abfd1a8bSRiver Riddle   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
253abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
254abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
25585ab413bSRiver Riddle   void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
256abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
257abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
258abfd1a8bSRiver Riddle   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
259abfd1a8bSRiver Riddle 
260abfd1a8bSRiver Riddle   /// Mapping from value to its corresponding memory index.
261abfd1a8bSRiver Riddle   DenseMap<Value, ByteCodeField> valueToMemIndex;
262abfd1a8bSRiver Riddle 
26385ab413bSRiver Riddle   /// Mapping from a range value to its corresponding range storage index.
26485ab413bSRiver Riddle   DenseMap<Value, ByteCodeField> valueToRangeIndex;
26585ab413bSRiver Riddle 
266abfd1a8bSRiver Riddle   /// Mapping from the name of an externally registered rewrite to its index in
267abfd1a8bSRiver Riddle   /// the bytecode registry.
268abfd1a8bSRiver Riddle   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
269abfd1a8bSRiver Riddle 
270abfd1a8bSRiver Riddle   /// Mapping from the name of an externally registered constraint to its index
271abfd1a8bSRiver Riddle   /// in the bytecode registry.
272abfd1a8bSRiver Riddle   llvm::StringMap<ByteCodeField> constraintToMemIndex;
273abfd1a8bSRiver Riddle 
274abfd1a8bSRiver Riddle   /// Mapping from rewriter function name to the bytecode address of the
275abfd1a8bSRiver Riddle   /// rewriter function in byte.
276abfd1a8bSRiver Riddle   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
277abfd1a8bSRiver Riddle 
278abfd1a8bSRiver Riddle   /// Mapping from a uniqued storage object to its memory index within
279abfd1a8bSRiver Riddle   /// `uniquedData`.
280abfd1a8bSRiver Riddle   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
281abfd1a8bSRiver Riddle 
282abfd1a8bSRiver Riddle   /// The current MLIR context.
283abfd1a8bSRiver Riddle   MLIRContext *ctx;
284abfd1a8bSRiver Riddle 
285abfd1a8bSRiver Riddle   /// Data of the ByteCode class to be populated.
286abfd1a8bSRiver Riddle   std::vector<const void *> &uniquedData;
287abfd1a8bSRiver Riddle   SmallVectorImpl<ByteCodeField> &matcherByteCode;
288abfd1a8bSRiver Riddle   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
289abfd1a8bSRiver Riddle   SmallVectorImpl<PDLByteCodePattern> &patterns;
290abfd1a8bSRiver Riddle   ByteCodeField &maxValueMemoryIndex;
29185ab413bSRiver Riddle   ByteCodeField &maxTypeRangeMemoryIndex;
29285ab413bSRiver Riddle   ByteCodeField &maxValueRangeMemoryIndex;
293abfd1a8bSRiver Riddle };
294abfd1a8bSRiver Riddle 
295abfd1a8bSRiver Riddle /// This class provides utilities for writing a bytecode stream.
296abfd1a8bSRiver Riddle struct ByteCodeWriter {
297abfd1a8bSRiver Riddle   ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
298abfd1a8bSRiver Riddle       : bytecode(bytecode), generator(generator) {}
299abfd1a8bSRiver Riddle 
300abfd1a8bSRiver Riddle   /// Append a field to the bytecode.
301abfd1a8bSRiver Riddle   void append(ByteCodeField field) { bytecode.push_back(field); }
302fa20ab7bSRiver Riddle   void append(OpCode opCode) { bytecode.push_back(opCode); }
303abfd1a8bSRiver Riddle 
304abfd1a8bSRiver Riddle   /// Append an address to the bytecode.
305abfd1a8bSRiver Riddle   void append(ByteCodeAddr field) {
306abfd1a8bSRiver Riddle     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
307abfd1a8bSRiver Riddle                   "unexpected ByteCode address size");
308abfd1a8bSRiver Riddle 
309abfd1a8bSRiver Riddle     ByteCodeField fieldParts[2];
310abfd1a8bSRiver Riddle     std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
311abfd1a8bSRiver Riddle     bytecode.append({fieldParts[0], fieldParts[1]});
312abfd1a8bSRiver Riddle   }
313abfd1a8bSRiver Riddle 
314abfd1a8bSRiver Riddle   /// Append a successor range to the bytecode, the exact address will need to
315abfd1a8bSRiver Riddle   /// be resolved later.
316abfd1a8bSRiver Riddle   void append(SuccessorRange successors) {
317abfd1a8bSRiver Riddle     // Add back references to the any successors so that the address can be
318abfd1a8bSRiver Riddle     // resolved later.
319abfd1a8bSRiver Riddle     for (Block *successor : successors) {
320abfd1a8bSRiver Riddle       unresolvedSuccessorRefs[successor].push_back(bytecode.size());
321abfd1a8bSRiver Riddle       append(ByteCodeAddr(0));
322abfd1a8bSRiver Riddle     }
323abfd1a8bSRiver Riddle   }
324abfd1a8bSRiver Riddle 
325abfd1a8bSRiver Riddle   /// Append a range of values that will be read as generic PDLValues.
326abfd1a8bSRiver Riddle   void appendPDLValueList(OperandRange values) {
327abfd1a8bSRiver Riddle     bytecode.push_back(values.size());
32885ab413bSRiver Riddle     for (Value value : values)
32985ab413bSRiver Riddle       appendPDLValue(value);
33085ab413bSRiver Riddle   }
33185ab413bSRiver Riddle 
33285ab413bSRiver Riddle   /// Append a value as a PDLValue.
33385ab413bSRiver Riddle   void appendPDLValue(Value value) {
33485ab413bSRiver Riddle     appendPDLValueKind(value);
335abfd1a8bSRiver Riddle     append(value);
336abfd1a8bSRiver Riddle   }
33785ab413bSRiver Riddle 
33885ab413bSRiver Riddle   /// Append the PDLValue::Kind of the given value.
33985ab413bSRiver Riddle   void appendPDLValueKind(Value value) {
34085ab413bSRiver Riddle     // Append the type of the value in addition to the value itself.
34185ab413bSRiver Riddle     PDLValue::Kind kind =
34285ab413bSRiver Riddle         TypeSwitch<Type, PDLValue::Kind>(value.getType())
34385ab413bSRiver Riddle             .Case<pdl::AttributeType>(
34485ab413bSRiver Riddle                 [](Type) { return PDLValue::Kind::Attribute; })
34585ab413bSRiver Riddle             .Case<pdl::OperationType>(
34685ab413bSRiver Riddle                 [](Type) { return PDLValue::Kind::Operation; })
34785ab413bSRiver Riddle             .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
34885ab413bSRiver Riddle               if (rangeTy.getElementType().isa<pdl::TypeType>())
34985ab413bSRiver Riddle                 return PDLValue::Kind::TypeRange;
35085ab413bSRiver Riddle               return PDLValue::Kind::ValueRange;
35185ab413bSRiver Riddle             })
35285ab413bSRiver Riddle             .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
35385ab413bSRiver Riddle             .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
35485ab413bSRiver Riddle     bytecode.push_back(static_cast<ByteCodeField>(kind));
355abfd1a8bSRiver Riddle   }
356abfd1a8bSRiver Riddle 
357abfd1a8bSRiver Riddle   /// Check if the given class `T` has an iterator type.
358abfd1a8bSRiver Riddle   template <typename T, typename... Args>
359abfd1a8bSRiver Riddle   using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
360abfd1a8bSRiver Riddle 
361abfd1a8bSRiver Riddle   /// Append a value that will be stored in a memory slot and not inline within
362abfd1a8bSRiver Riddle   /// the bytecode.
363abfd1a8bSRiver Riddle   template <typename T>
364abfd1a8bSRiver Riddle   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
365abfd1a8bSRiver Riddle                    std::is_pointer<T>::value>
366abfd1a8bSRiver Riddle   append(T value) {
367abfd1a8bSRiver Riddle     bytecode.push_back(generator.getMemIndex(value));
368abfd1a8bSRiver Riddle   }
369abfd1a8bSRiver Riddle 
370abfd1a8bSRiver Riddle   /// Append a range of values.
371abfd1a8bSRiver Riddle   template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
372abfd1a8bSRiver Riddle   std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
373abfd1a8bSRiver Riddle   append(T range) {
374abfd1a8bSRiver Riddle     bytecode.push_back(llvm::size(range));
375abfd1a8bSRiver Riddle     for (auto it : range)
376abfd1a8bSRiver Riddle       append(it);
377abfd1a8bSRiver Riddle   }
378abfd1a8bSRiver Riddle 
379abfd1a8bSRiver Riddle   /// Append a variadic number of fields to the bytecode.
380abfd1a8bSRiver Riddle   template <typename FieldTy, typename Field2Ty, typename... FieldTys>
381abfd1a8bSRiver Riddle   void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
382abfd1a8bSRiver Riddle     append(field);
383abfd1a8bSRiver Riddle     append(field2, fields...);
384abfd1a8bSRiver Riddle   }
385abfd1a8bSRiver Riddle 
386abfd1a8bSRiver Riddle   /// Successor references in the bytecode that have yet to be resolved.
387abfd1a8bSRiver Riddle   DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
388abfd1a8bSRiver Riddle 
389abfd1a8bSRiver Riddle   /// The underlying bytecode buffer.
390abfd1a8bSRiver Riddle   SmallVectorImpl<ByteCodeField> &bytecode;
391abfd1a8bSRiver Riddle 
392abfd1a8bSRiver Riddle   /// The main generator producing PDL.
393abfd1a8bSRiver Riddle   Generator &generator;
394abfd1a8bSRiver Riddle };
39585ab413bSRiver Riddle 
39685ab413bSRiver Riddle /// This class represents a live range of PDL Interpreter values, containing
39785ab413bSRiver Riddle /// information about when values are live within a match/rewrite.
39885ab413bSRiver Riddle struct ByteCodeLiveRange {
39985ab413bSRiver Riddle   using Set = llvm::IntervalMap<ByteCodeField, char, 16>;
40085ab413bSRiver Riddle   using Allocator = Set::Allocator;
40185ab413bSRiver Riddle 
40285ab413bSRiver Riddle   ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {}
40385ab413bSRiver Riddle 
40485ab413bSRiver Riddle   /// Union this live range with the one provided.
40585ab413bSRiver Riddle   void unionWith(const ByteCodeLiveRange &rhs) {
40685ab413bSRiver Riddle     for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it)
40785ab413bSRiver Riddle       liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0);
40885ab413bSRiver Riddle   }
40985ab413bSRiver Riddle 
41085ab413bSRiver Riddle   /// Returns true if this range overlaps with the one provided.
41185ab413bSRiver Riddle   bool overlaps(const ByteCodeLiveRange &rhs) const {
41285ab413bSRiver Riddle     return llvm::IntervalMapOverlaps<Set, Set>(liveness, rhs.liveness).valid();
41385ab413bSRiver Riddle   }
41485ab413bSRiver Riddle 
41585ab413bSRiver Riddle   /// A map representing the ranges of the match/rewrite that a value is live in
41685ab413bSRiver Riddle   /// the interpreter.
41785ab413bSRiver Riddle   llvm::IntervalMap<ByteCodeField, char, 16> liveness;
41885ab413bSRiver Riddle 
41985ab413bSRiver Riddle   /// The type range storage index for this range.
42085ab413bSRiver Riddle   Optional<unsigned> typeRangeIndex;
42185ab413bSRiver Riddle 
42285ab413bSRiver Riddle   /// The value range storage index for this range.
42385ab413bSRiver Riddle   Optional<unsigned> valueRangeIndex;
42485ab413bSRiver Riddle };
425abfd1a8bSRiver Riddle } // end anonymous namespace
426abfd1a8bSRiver Riddle 
427abfd1a8bSRiver Riddle void Generator::generate(ModuleOp module) {
428abfd1a8bSRiver Riddle   FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
429abfd1a8bSRiver Riddle       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
430abfd1a8bSRiver Riddle   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
431abfd1a8bSRiver Riddle       pdl_interp::PDLInterpDialect::getRewriterModuleName());
432abfd1a8bSRiver Riddle   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
433abfd1a8bSRiver Riddle 
434abfd1a8bSRiver Riddle   // Allocate memory indices for the results of operations within the matcher
435abfd1a8bSRiver Riddle   // and rewriters.
436abfd1a8bSRiver Riddle   allocateMemoryIndices(matcherFunc, rewriterModule);
437abfd1a8bSRiver Riddle 
438abfd1a8bSRiver Riddle   // Generate code for the rewriter functions.
439abfd1a8bSRiver Riddle   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
440abfd1a8bSRiver Riddle   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
441abfd1a8bSRiver Riddle     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
442abfd1a8bSRiver Riddle     for (Operation &op : rewriterFunc.getOps())
443abfd1a8bSRiver Riddle       generate(&op, rewriterByteCodeWriter);
444abfd1a8bSRiver Riddle   }
445abfd1a8bSRiver Riddle   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
446abfd1a8bSRiver Riddle          "unexpected branches in rewriter function");
447abfd1a8bSRiver Riddle 
448abfd1a8bSRiver Riddle   // Generate code for the matcher function.
449abfd1a8bSRiver Riddle   DenseMap<Block *, ByteCodeAddr> blockToAddr;
450abfd1a8bSRiver Riddle   llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
451abfd1a8bSRiver Riddle   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
452abfd1a8bSRiver Riddle   for (Block *block : rpot) {
453abfd1a8bSRiver Riddle     // Keep track of where this block begins within the matcher function.
454abfd1a8bSRiver Riddle     blockToAddr.try_emplace(block, matcherByteCode.size());
455abfd1a8bSRiver Riddle     for (Operation &op : *block)
456abfd1a8bSRiver Riddle       generate(&op, matcherByteCodeWriter);
457abfd1a8bSRiver Riddle   }
458abfd1a8bSRiver Riddle 
459abfd1a8bSRiver Riddle   // Resolve successor references in the matcher.
460abfd1a8bSRiver Riddle   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
461abfd1a8bSRiver Riddle     ByteCodeAddr addr = blockToAddr[it.first];
462abfd1a8bSRiver Riddle     for (unsigned offsetToFix : it.second)
463abfd1a8bSRiver Riddle       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
464abfd1a8bSRiver Riddle   }
465abfd1a8bSRiver Riddle }
466abfd1a8bSRiver Riddle 
467abfd1a8bSRiver Riddle void Generator::allocateMemoryIndices(FuncOp matcherFunc,
468abfd1a8bSRiver Riddle                                       ModuleOp rewriterModule) {
469abfd1a8bSRiver Riddle   // Rewriters use simplistic allocation scheme that simply assigns an index to
470abfd1a8bSRiver Riddle   // each result.
471abfd1a8bSRiver Riddle   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
47285ab413bSRiver Riddle     ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
47385ab413bSRiver Riddle     auto processRewriterValue = [&](Value val) {
47485ab413bSRiver Riddle       valueToMemIndex.try_emplace(val, index++);
47585ab413bSRiver Riddle       if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
47685ab413bSRiver Riddle         Type elementTy = rangeType.getElementType();
47785ab413bSRiver Riddle         if (elementTy.isa<pdl::TypeType>())
47885ab413bSRiver Riddle           valueToRangeIndex.try_emplace(val, typeRangeIndex++);
47985ab413bSRiver Riddle         else if (elementTy.isa<pdl::ValueType>())
48085ab413bSRiver Riddle           valueToRangeIndex.try_emplace(val, valueRangeIndex++);
48185ab413bSRiver Riddle       }
48285ab413bSRiver Riddle     };
48385ab413bSRiver Riddle 
484abfd1a8bSRiver Riddle     for (BlockArgument arg : rewriterFunc.getArguments())
48585ab413bSRiver Riddle       processRewriterValue(arg);
486abfd1a8bSRiver Riddle     rewriterFunc.getBody().walk([&](Operation *op) {
487abfd1a8bSRiver Riddle       for (Value result : op->getResults())
48885ab413bSRiver Riddle         processRewriterValue(result);
489abfd1a8bSRiver Riddle     });
490abfd1a8bSRiver Riddle     if (index > maxValueMemoryIndex)
491abfd1a8bSRiver Riddle       maxValueMemoryIndex = index;
49285ab413bSRiver Riddle     if (typeRangeIndex > maxTypeRangeMemoryIndex)
49385ab413bSRiver Riddle       maxTypeRangeMemoryIndex = typeRangeIndex;
49485ab413bSRiver Riddle     if (valueRangeIndex > maxValueRangeMemoryIndex)
49585ab413bSRiver Riddle       maxValueRangeMemoryIndex = valueRangeIndex;
496abfd1a8bSRiver Riddle   }
497abfd1a8bSRiver Riddle 
498abfd1a8bSRiver Riddle   // The matcher function uses a more sophisticated numbering that tries to
499abfd1a8bSRiver Riddle   // minimize the number of memory indices assigned. This is done by determining
500abfd1a8bSRiver Riddle   // a live range of the values within the matcher, then the allocation is just
501abfd1a8bSRiver Riddle   // finding the minimal number of overlapping live ranges. This is essentially
502abfd1a8bSRiver Riddle   // a simplified form of register allocation where we don't necessarily have a
503abfd1a8bSRiver Riddle   // limited number of registers, but we still want to minimize the number used.
504abfd1a8bSRiver Riddle   DenseMap<Operation *, ByteCodeField> opToIndex;
505abfd1a8bSRiver Riddle   matcherFunc.getBody().walk([&](Operation *op) {
506abfd1a8bSRiver Riddle     opToIndex.insert(std::make_pair(op, opToIndex.size()));
507abfd1a8bSRiver Riddle   });
508abfd1a8bSRiver Riddle 
509abfd1a8bSRiver Riddle   // Liveness info for each of the defs within the matcher.
51085ab413bSRiver Riddle   ByteCodeLiveRange::Allocator allocator;
51185ab413bSRiver Riddle   DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
512abfd1a8bSRiver Riddle 
513abfd1a8bSRiver Riddle   // Assign the root operation being matched to slot 0.
514abfd1a8bSRiver Riddle   BlockArgument rootOpArg = matcherFunc.getArgument(0);
515abfd1a8bSRiver Riddle   valueToMemIndex[rootOpArg] = 0;
516abfd1a8bSRiver Riddle 
517abfd1a8bSRiver Riddle   // Walk each of the blocks, computing the def interval that the value is used.
518abfd1a8bSRiver Riddle   Liveness matcherLiveness(matcherFunc);
519abfd1a8bSRiver Riddle   for (Block &block : matcherFunc.getBody()) {
520abfd1a8bSRiver Riddle     const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
521abfd1a8bSRiver Riddle     assert(info && "expected liveness info for block");
522abfd1a8bSRiver Riddle     auto processValue = [&](Value value, Operation *firstUseOrDef) {
523abfd1a8bSRiver Riddle       // We don't need to process the root op argument, this value is always
524abfd1a8bSRiver Riddle       // assigned to the first memory slot.
525abfd1a8bSRiver Riddle       if (value == rootOpArg)
526abfd1a8bSRiver Riddle         return;
527abfd1a8bSRiver Riddle 
528abfd1a8bSRiver Riddle       // Set indices for the range of this block that the value is used.
529abfd1a8bSRiver Riddle       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
53085ab413bSRiver Riddle       defRangeIt->second.liveness.insert(
531abfd1a8bSRiver Riddle           opToIndex[firstUseOrDef],
532abfd1a8bSRiver Riddle           opToIndex[info->getEndOperation(value, firstUseOrDef)],
533abfd1a8bSRiver Riddle           /*dummyValue*/ 0);
53485ab413bSRiver Riddle 
53585ab413bSRiver Riddle       // Check to see if this value is a range type.
53685ab413bSRiver Riddle       if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
53785ab413bSRiver Riddle         Type eleType = rangeTy.getElementType();
53885ab413bSRiver Riddle         if (eleType.isa<pdl::TypeType>())
53985ab413bSRiver Riddle           defRangeIt->second.typeRangeIndex = 0;
54085ab413bSRiver Riddle         else if (eleType.isa<pdl::ValueType>())
54185ab413bSRiver Riddle           defRangeIt->second.valueRangeIndex = 0;
54285ab413bSRiver Riddle       }
543abfd1a8bSRiver Riddle     };
544abfd1a8bSRiver Riddle 
545abfd1a8bSRiver Riddle     // Process the live-ins of this block.
546abfd1a8bSRiver Riddle     for (Value liveIn : info->in())
547abfd1a8bSRiver Riddle       processValue(liveIn, &block.front());
548abfd1a8bSRiver Riddle 
549abfd1a8bSRiver Riddle     // Process any new defs within this block.
550abfd1a8bSRiver Riddle     for (Operation &op : block)
551abfd1a8bSRiver Riddle       for (Value result : op.getResults())
552abfd1a8bSRiver Riddle         processValue(result, &op);
553abfd1a8bSRiver Riddle   }
554abfd1a8bSRiver Riddle 
555abfd1a8bSRiver Riddle   // Greedily allocate memory slots using the computed def live ranges.
55685ab413bSRiver Riddle   std::vector<ByteCodeLiveRange> allocatedIndices;
55785ab413bSRiver Riddle   ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0;
558abfd1a8bSRiver Riddle   for (auto &defIt : valueDefRanges) {
559abfd1a8bSRiver Riddle     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
56085ab413bSRiver Riddle     ByteCodeLiveRange &defRange = defIt.second;
561abfd1a8bSRiver Riddle 
562abfd1a8bSRiver Riddle     // Try to allocate to an existing index.
563abfd1a8bSRiver Riddle     for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
56485ab413bSRiver Riddle       ByteCodeLiveRange &existingRange = existingIndexIt.value();
56585ab413bSRiver Riddle       if (!defRange.overlaps(existingRange)) {
56685ab413bSRiver Riddle         existingRange.unionWith(defRange);
567abfd1a8bSRiver Riddle         memIndex = existingIndexIt.index() + 1;
56885ab413bSRiver Riddle 
56985ab413bSRiver Riddle         if (defRange.typeRangeIndex) {
57085ab413bSRiver Riddle           if (!existingRange.typeRangeIndex)
57185ab413bSRiver Riddle             existingRange.typeRangeIndex = numTypeRanges++;
57285ab413bSRiver Riddle           valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
57385ab413bSRiver Riddle         } else if (defRange.valueRangeIndex) {
57485ab413bSRiver Riddle           if (!existingRange.valueRangeIndex)
57585ab413bSRiver Riddle             existingRange.valueRangeIndex = numValueRanges++;
57685ab413bSRiver Riddle           valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
57785ab413bSRiver Riddle         }
57885ab413bSRiver Riddle         break;
57985ab413bSRiver Riddle       }
580abfd1a8bSRiver Riddle     }
581abfd1a8bSRiver Riddle 
582abfd1a8bSRiver Riddle     // If no existing index could be used, add a new one.
583abfd1a8bSRiver Riddle     if (memIndex == 0) {
584abfd1a8bSRiver Riddle       allocatedIndices.emplace_back(allocator);
58585ab413bSRiver Riddle       ByteCodeLiveRange &newRange = allocatedIndices.back();
58685ab413bSRiver Riddle       newRange.unionWith(defRange);
58785ab413bSRiver Riddle 
58885ab413bSRiver Riddle       // Allocate an index for type/value ranges.
58985ab413bSRiver Riddle       if (defRange.typeRangeIndex) {
59085ab413bSRiver Riddle         newRange.typeRangeIndex = numTypeRanges;
59185ab413bSRiver Riddle         valueToRangeIndex[defIt.first] = numTypeRanges++;
59285ab413bSRiver Riddle       } else if (defRange.valueRangeIndex) {
59385ab413bSRiver Riddle         newRange.valueRangeIndex = numValueRanges;
59485ab413bSRiver Riddle         valueToRangeIndex[defIt.first] = numValueRanges++;
59585ab413bSRiver Riddle       }
59685ab413bSRiver Riddle 
597abfd1a8bSRiver Riddle       memIndex = allocatedIndices.size();
59885ab413bSRiver Riddle       ++numIndices;
599abfd1a8bSRiver Riddle     }
600abfd1a8bSRiver Riddle   }
601abfd1a8bSRiver Riddle 
602abfd1a8bSRiver Riddle   // Update the max number of indices.
60385ab413bSRiver Riddle   if (numIndices > maxValueMemoryIndex)
60485ab413bSRiver Riddle     maxValueMemoryIndex = numIndices;
60585ab413bSRiver Riddle   if (numTypeRanges > maxTypeRangeMemoryIndex)
60685ab413bSRiver Riddle     maxTypeRangeMemoryIndex = numTypeRanges;
60785ab413bSRiver Riddle   if (numValueRanges > maxValueRangeMemoryIndex)
60885ab413bSRiver Riddle     maxValueRangeMemoryIndex = numValueRanges;
609abfd1a8bSRiver Riddle }
610abfd1a8bSRiver Riddle 
611abfd1a8bSRiver Riddle void Generator::generate(Operation *op, ByteCodeWriter &writer) {
612abfd1a8bSRiver Riddle   TypeSwitch<Operation *>(op)
613abfd1a8bSRiver Riddle       .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
614abfd1a8bSRiver Riddle             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
615abfd1a8bSRiver Riddle             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
616abfd1a8bSRiver Riddle             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
61785ab413bSRiver Riddle             pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
61885ab413bSRiver Riddle             pdl_interp::CreateAttributeOp, pdl_interp::CreateOperationOp,
61985ab413bSRiver Riddle             pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
62002c4c0d5SRiver Riddle             pdl_interp::EraseOp, pdl_interp::FinalizeOp,
62102c4c0d5SRiver Riddle             pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp,
62202c4c0d5SRiver Riddle             pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp,
62385ab413bSRiver Riddle             pdl_interp::GetOperandsOp, pdl_interp::GetResultOp,
62485ab413bSRiver Riddle             pdl_interp::GetResultsOp, pdl_interp::GetValueTypeOp,
6253a833a0eSRiver Riddle             pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
62602c4c0d5SRiver Riddle             pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
62702c4c0d5SRiver Riddle             pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
62885ab413bSRiver Riddle             pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
62985ab413bSRiver Riddle             pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
630abfd1a8bSRiver Riddle           [&](auto interpOp) { this->generate(interpOp, writer); })
631abfd1a8bSRiver Riddle       .Default([](Operation *) {
632abfd1a8bSRiver Riddle         llvm_unreachable("unknown `pdl_interp` operation");
633abfd1a8bSRiver Riddle       });
634abfd1a8bSRiver Riddle }
635abfd1a8bSRiver Riddle 
636abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ApplyConstraintOp op,
637abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
638abfd1a8bSRiver Riddle   assert(constraintToMemIndex.count(op.name()) &&
639abfd1a8bSRiver Riddle          "expected index for constraint function");
640abfd1a8bSRiver Riddle   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
641abfd1a8bSRiver Riddle                 op.constParamsAttr());
642abfd1a8bSRiver Riddle   writer.appendPDLValueList(op.args());
643abfd1a8bSRiver Riddle   writer.append(op.getSuccessors());
644abfd1a8bSRiver Riddle }
645abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ApplyRewriteOp op,
646abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
647abfd1a8bSRiver Riddle   assert(externalRewriterToMemIndex.count(op.name()) &&
648abfd1a8bSRiver Riddle          "expected index for rewrite function");
649abfd1a8bSRiver Riddle   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
65002c4c0d5SRiver Riddle                 op.constParamsAttr());
651abfd1a8bSRiver Riddle   writer.appendPDLValueList(op.args());
65202c4c0d5SRiver Riddle 
65385ab413bSRiver Riddle   ResultRange results = op.results();
65485ab413bSRiver Riddle   writer.append(ByteCodeField(results.size()));
65585ab413bSRiver Riddle   for (Value result : results) {
65685ab413bSRiver Riddle     // In debug mode we also record the expected kind of the result, so that we
65785ab413bSRiver Riddle     // can provide extra verification of the native rewrite function.
65802c4c0d5SRiver Riddle #ifndef NDEBUG
65985ab413bSRiver Riddle     writer.appendPDLValueKind(result);
66002c4c0d5SRiver Riddle #endif
66185ab413bSRiver Riddle 
66285ab413bSRiver Riddle     // Range results also need to append the range storage index.
66385ab413bSRiver Riddle     if (result.getType().isa<pdl::RangeType>())
66485ab413bSRiver Riddle       writer.append(getRangeStorageIndex(result));
66502c4c0d5SRiver Riddle     writer.append(result);
666abfd1a8bSRiver Riddle   }
66785ab413bSRiver Riddle }
668abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
66985ab413bSRiver Riddle   Value lhs = op.lhs();
67085ab413bSRiver Riddle   if (lhs.getType().isa<pdl::RangeType>()) {
67185ab413bSRiver Riddle     writer.append(OpCode::AreRangesEqual);
67285ab413bSRiver Riddle     writer.appendPDLValueKind(lhs);
67385ab413bSRiver Riddle     writer.append(op.lhs(), op.rhs(), op.getSuccessors());
67485ab413bSRiver Riddle     return;
67585ab413bSRiver Riddle   }
67685ab413bSRiver Riddle 
67785ab413bSRiver Riddle   writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors());
678abfd1a8bSRiver Riddle }
679abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
6808affe881SRiver Riddle   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
681abfd1a8bSRiver Riddle }
682abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckAttributeOp op,
683abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
684abfd1a8bSRiver Riddle   writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
685abfd1a8bSRiver Riddle                 op.getSuccessors());
686abfd1a8bSRiver Riddle }
687abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckOperandCountOp op,
688abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
689abfd1a8bSRiver Riddle   writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
69085ab413bSRiver Riddle                 static_cast<ByteCodeField>(op.compareAtLeast()),
691abfd1a8bSRiver Riddle                 op.getSuccessors());
692abfd1a8bSRiver Riddle }
693abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckOperationNameOp op,
694abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
695abfd1a8bSRiver Riddle   writer.append(OpCode::CheckOperationName, op.operation(),
696abfd1a8bSRiver Riddle                 OperationName(op.name(), ctx), op.getSuccessors());
697abfd1a8bSRiver Riddle }
698abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckResultCountOp op,
699abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
700abfd1a8bSRiver Riddle   writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
70185ab413bSRiver Riddle                 static_cast<ByteCodeField>(op.compareAtLeast()),
702abfd1a8bSRiver Riddle                 op.getSuccessors());
703abfd1a8bSRiver Riddle }
704abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
705abfd1a8bSRiver Riddle   writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
706abfd1a8bSRiver Riddle }
70785ab413bSRiver Riddle void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
70885ab413bSRiver Riddle   writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors());
70985ab413bSRiver Riddle }
710abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateAttributeOp op,
711abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
712abfd1a8bSRiver Riddle   // Simply repoint the memory index of the result to the constant.
713abfd1a8bSRiver Riddle   getMemIndex(op.attribute()) = getMemIndex(op.value());
714abfd1a8bSRiver Riddle }
715abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateOperationOp op,
716abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
717abfd1a8bSRiver Riddle   writer.append(OpCode::CreateOperation, op.operation(),
71885ab413bSRiver Riddle                 OperationName(op.name(), ctx));
71985ab413bSRiver Riddle   writer.appendPDLValueList(op.operands());
720abfd1a8bSRiver Riddle 
721abfd1a8bSRiver Riddle   // Add the attributes.
722abfd1a8bSRiver Riddle   OperandRange attributes = op.attributes();
723abfd1a8bSRiver Riddle   writer.append(static_cast<ByteCodeField>(attributes.size()));
724abfd1a8bSRiver Riddle   for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
725abfd1a8bSRiver Riddle     writer.append(
726abfd1a8bSRiver Riddle         Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx),
727abfd1a8bSRiver Riddle         std::get<1>(it));
728abfd1a8bSRiver Riddle   }
72985ab413bSRiver Riddle   writer.appendPDLValueList(op.types());
730abfd1a8bSRiver Riddle }
731abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
732abfd1a8bSRiver Riddle   // Simply repoint the memory index of the result to the constant.
733abfd1a8bSRiver Riddle   getMemIndex(op.result()) = getMemIndex(op.value());
734abfd1a8bSRiver Riddle }
73585ab413bSRiver Riddle void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
73685ab413bSRiver Riddle   writer.append(OpCode::CreateTypes, op.result(),
73785ab413bSRiver Riddle                 getRangeStorageIndex(op.result()), op.value());
73885ab413bSRiver Riddle }
739abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
740abfd1a8bSRiver Riddle   writer.append(OpCode::EraseOp, op.operation());
741abfd1a8bSRiver Riddle }
742abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
743abfd1a8bSRiver Riddle   writer.append(OpCode::Finalize);
744abfd1a8bSRiver Riddle }
745abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetAttributeOp op,
746abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
747abfd1a8bSRiver Riddle   writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
748abfd1a8bSRiver Riddle                 Identifier::get(op.name(), ctx));
749abfd1a8bSRiver Riddle }
750abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetAttributeTypeOp op,
751abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
752abfd1a8bSRiver Riddle   writer.append(OpCode::GetAttributeType, op.result(), op.value());
753abfd1a8bSRiver Riddle }
754abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetDefiningOpOp op,
755abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
75685ab413bSRiver Riddle   writer.append(OpCode::GetDefiningOp, op.operation());
75785ab413bSRiver Riddle   writer.appendPDLValue(op.value());
758abfd1a8bSRiver Riddle }
759abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
760abfd1a8bSRiver Riddle   uint32_t index = op.index();
761abfd1a8bSRiver Riddle   if (index < 4)
762abfd1a8bSRiver Riddle     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
763abfd1a8bSRiver Riddle   else
764abfd1a8bSRiver Riddle     writer.append(OpCode::GetOperandN, index);
765abfd1a8bSRiver Riddle   writer.append(op.operation(), op.value());
766abfd1a8bSRiver Riddle }
76785ab413bSRiver Riddle void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
76885ab413bSRiver Riddle   Value result = op.value();
76985ab413bSRiver Riddle   Optional<uint32_t> index = op.index();
77085ab413bSRiver Riddle   writer.append(OpCode::GetOperands,
77185ab413bSRiver Riddle                 index.getValueOr(std::numeric_limits<uint32_t>::max()),
77285ab413bSRiver Riddle                 op.operation());
77385ab413bSRiver Riddle   if (result.getType().isa<pdl::RangeType>())
77485ab413bSRiver Riddle     writer.append(getRangeStorageIndex(result));
77585ab413bSRiver Riddle   else
77685ab413bSRiver Riddle     writer.append(std::numeric_limits<ByteCodeField>::max());
77785ab413bSRiver Riddle   writer.append(result);
77885ab413bSRiver Riddle }
779abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
780abfd1a8bSRiver Riddle   uint32_t index = op.index();
781abfd1a8bSRiver Riddle   if (index < 4)
782abfd1a8bSRiver Riddle     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
783abfd1a8bSRiver Riddle   else
784abfd1a8bSRiver Riddle     writer.append(OpCode::GetResultN, index);
785abfd1a8bSRiver Riddle   writer.append(op.operation(), op.value());
786abfd1a8bSRiver Riddle }
78785ab413bSRiver Riddle void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
78885ab413bSRiver Riddle   Value result = op.value();
78985ab413bSRiver Riddle   Optional<uint32_t> index = op.index();
79085ab413bSRiver Riddle   writer.append(OpCode::GetResults,
79185ab413bSRiver Riddle                 index.getValueOr(std::numeric_limits<uint32_t>::max()),
79285ab413bSRiver Riddle                 op.operation());
79385ab413bSRiver Riddle   if (result.getType().isa<pdl::RangeType>())
79485ab413bSRiver Riddle     writer.append(getRangeStorageIndex(result));
79585ab413bSRiver Riddle   else
79685ab413bSRiver Riddle     writer.append(std::numeric_limits<ByteCodeField>::max());
79785ab413bSRiver Riddle   writer.append(result);
79885ab413bSRiver Riddle }
799abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetValueTypeOp op,
800abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
80185ab413bSRiver Riddle   if (op.getType().isa<pdl::RangeType>()) {
80285ab413bSRiver Riddle     Value result = op.result();
80385ab413bSRiver Riddle     writer.append(OpCode::GetValueRangeTypes, result,
80485ab413bSRiver Riddle                   getRangeStorageIndex(result), op.value());
80585ab413bSRiver Riddle   } else {
806abfd1a8bSRiver Riddle     writer.append(OpCode::GetValueType, op.result(), op.value());
807abfd1a8bSRiver Riddle   }
80885ab413bSRiver Riddle }
80985ab413bSRiver Riddle 
8103a833a0eSRiver Riddle void Generator::generate(pdl_interp::InferredTypesOp op,
811abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
8123a833a0eSRiver Riddle   // InferType maps to a null type as a marker for inferring result types.
813abfd1a8bSRiver Riddle   getMemIndex(op.type()) = getMemIndex(Type());
814abfd1a8bSRiver Riddle }
815abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
816abfd1a8bSRiver Riddle   writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
817abfd1a8bSRiver Riddle }
818abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
819abfd1a8bSRiver Riddle   ByteCodeField patternIndex = patterns.size();
820abfd1a8bSRiver Riddle   patterns.emplace_back(PDLByteCodePattern::create(
821*41d4aa7dSChris Lattner       op, rewriterToAddr[op.rewriter().getLeafReference().getValue()]));
8228affe881SRiver Riddle   writer.append(OpCode::RecordMatch, patternIndex,
82385ab413bSRiver Riddle                 SuccessorRange(op.getOperation()), op.matchedOps());
82485ab413bSRiver Riddle   writer.appendPDLValueList(op.inputs());
825abfd1a8bSRiver Riddle }
826abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
82785ab413bSRiver Riddle   writer.append(OpCode::ReplaceOp, op.operation());
82885ab413bSRiver Riddle   writer.appendPDLValueList(op.replValues());
829abfd1a8bSRiver Riddle }
830abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchAttributeOp op,
831abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
832abfd1a8bSRiver Riddle   writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
833abfd1a8bSRiver Riddle                 op.getSuccessors());
834abfd1a8bSRiver Riddle }
835abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchOperandCountOp op,
836abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
837abfd1a8bSRiver Riddle   writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
838abfd1a8bSRiver Riddle                 op.getSuccessors());
839abfd1a8bSRiver Riddle }
840abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchOperationNameOp op,
841abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
842abfd1a8bSRiver Riddle   auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
843abfd1a8bSRiver Riddle     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
844abfd1a8bSRiver Riddle   });
845abfd1a8bSRiver Riddle   writer.append(OpCode::SwitchOperationName, op.operation(), cases,
846abfd1a8bSRiver Riddle                 op.getSuccessors());
847abfd1a8bSRiver Riddle }
848abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchResultCountOp op,
849abfd1a8bSRiver Riddle                          ByteCodeWriter &writer) {
850abfd1a8bSRiver Riddle   writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
851abfd1a8bSRiver Riddle                 op.getSuccessors());
852abfd1a8bSRiver Riddle }
853abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
854abfd1a8bSRiver Riddle   writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
855abfd1a8bSRiver Riddle                 op.getSuccessors());
856abfd1a8bSRiver Riddle }
85785ab413bSRiver Riddle void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
85885ab413bSRiver Riddle   writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(),
85985ab413bSRiver Riddle                 op.getSuccessors());
86085ab413bSRiver Riddle }
861abfd1a8bSRiver Riddle 
862abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
863abfd1a8bSRiver Riddle // PDLByteCode
864abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===//
865abfd1a8bSRiver Riddle 
866abfd1a8bSRiver Riddle PDLByteCode::PDLByteCode(ModuleOp module,
867abfd1a8bSRiver Riddle                          llvm::StringMap<PDLConstraintFunction> constraintFns,
868abfd1a8bSRiver Riddle                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
869abfd1a8bSRiver Riddle   Generator generator(module.getContext(), uniquedData, matcherByteCode,
870abfd1a8bSRiver Riddle                       rewriterByteCode, patterns, maxValueMemoryIndex,
87185ab413bSRiver Riddle                       maxTypeRangeCount, maxValueRangeCount, constraintFns,
87285ab413bSRiver Riddle                       rewriteFns);
873abfd1a8bSRiver Riddle   generator.generate(module);
874abfd1a8bSRiver Riddle 
875abfd1a8bSRiver Riddle   // Initialize the external functions.
876abfd1a8bSRiver Riddle   for (auto &it : constraintFns)
877abfd1a8bSRiver Riddle     constraintFunctions.push_back(std::move(it.second));
878abfd1a8bSRiver Riddle   for (auto &it : rewriteFns)
879abfd1a8bSRiver Riddle     rewriteFunctions.push_back(std::move(it.second));
880abfd1a8bSRiver Riddle }
881abfd1a8bSRiver Riddle 
882abfd1a8bSRiver Riddle /// Initialize the given state such that it can be used to execute the current
883abfd1a8bSRiver Riddle /// bytecode.
884abfd1a8bSRiver Riddle void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
885abfd1a8bSRiver Riddle   state.memory.resize(maxValueMemoryIndex, nullptr);
88685ab413bSRiver Riddle   state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
88785ab413bSRiver Riddle   state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
888abfd1a8bSRiver Riddle   state.currentPatternBenefits.reserve(patterns.size());
889abfd1a8bSRiver Riddle   for (const PDLByteCodePattern &pattern : patterns)
890abfd1a8bSRiver Riddle     state.currentPatternBenefits.push_back(pattern.getBenefit());
891abfd1a8bSRiver Riddle }
892abfd1a8bSRiver Riddle 
//===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//
895abfd1a8bSRiver Riddle 
896abfd1a8bSRiver Riddle namespace {
897abfd1a8bSRiver Riddle /// This class provides support for executing a bytecode stream.
898abfd1a8bSRiver Riddle class ByteCodeExecutor {
899abfd1a8bSRiver Riddle public:
90085ab413bSRiver Riddle   ByteCodeExecutor(
90185ab413bSRiver Riddle       const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
90285ab413bSRiver Riddle       MutableArrayRef<TypeRange> typeRangeMemory,
90385ab413bSRiver Riddle       std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
90485ab413bSRiver Riddle       MutableArrayRef<ValueRange> valueRangeMemory,
90585ab413bSRiver Riddle       std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
90685ab413bSRiver Riddle       ArrayRef<const void *> uniquedMemory, ArrayRef<ByteCodeField> code,
907abfd1a8bSRiver Riddle       ArrayRef<PatternBenefit> currentPatternBenefits,
908abfd1a8bSRiver Riddle       ArrayRef<PDLByteCodePattern> patterns,
909abfd1a8bSRiver Riddle       ArrayRef<PDLConstraintFunction> constraintFunctions,
910abfd1a8bSRiver Riddle       ArrayRef<PDLRewriteFunction> rewriteFunctions)
91185ab413bSRiver Riddle       : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory),
91285ab413bSRiver Riddle         allocatedTypeRangeMemory(allocatedTypeRangeMemory),
91385ab413bSRiver Riddle         valueRangeMemory(valueRangeMemory),
91485ab413bSRiver Riddle         allocatedValueRangeMemory(allocatedValueRangeMemory),
91585ab413bSRiver Riddle         uniquedMemory(uniquedMemory), code(code),
91685ab413bSRiver Riddle         currentPatternBenefits(currentPatternBenefits), patterns(patterns),
91785ab413bSRiver Riddle         constraintFunctions(constraintFunctions),
91802c4c0d5SRiver Riddle         rewriteFunctions(rewriteFunctions) {}
919abfd1a8bSRiver Riddle 
920abfd1a8bSRiver Riddle   /// Start executing the code at the current bytecode index. `matches` is an
921abfd1a8bSRiver Riddle   /// optional field provided when this function is executed in a matching
922abfd1a8bSRiver Riddle   /// context.
923abfd1a8bSRiver Riddle   void execute(PatternRewriter &rewriter,
924abfd1a8bSRiver Riddle                SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
925abfd1a8bSRiver Riddle                Optional<Location> mainRewriteLoc = {});
926abfd1a8bSRiver Riddle 
927abfd1a8bSRiver Riddle private:
928154cabe7SRiver Riddle   /// Internal implementation of executing each of the bytecode commands.
929154cabe7SRiver Riddle   void executeApplyConstraint(PatternRewriter &rewriter);
930154cabe7SRiver Riddle   void executeApplyRewrite(PatternRewriter &rewriter);
931154cabe7SRiver Riddle   void executeAreEqual();
93285ab413bSRiver Riddle   void executeAreRangesEqual();
933154cabe7SRiver Riddle   void executeBranch();
934154cabe7SRiver Riddle   void executeCheckOperandCount();
935154cabe7SRiver Riddle   void executeCheckOperationName();
936154cabe7SRiver Riddle   void executeCheckResultCount();
93785ab413bSRiver Riddle   void executeCheckTypes();
938154cabe7SRiver Riddle   void executeCreateOperation(PatternRewriter &rewriter,
939154cabe7SRiver Riddle                               Location mainRewriteLoc);
94085ab413bSRiver Riddle   void executeCreateTypes();
941154cabe7SRiver Riddle   void executeEraseOp(PatternRewriter &rewriter);
942154cabe7SRiver Riddle   void executeGetAttribute();
943154cabe7SRiver Riddle   void executeGetAttributeType();
944154cabe7SRiver Riddle   void executeGetDefiningOp();
945154cabe7SRiver Riddle   void executeGetOperand(unsigned index);
94685ab413bSRiver Riddle   void executeGetOperands();
947154cabe7SRiver Riddle   void executeGetResult(unsigned index);
94885ab413bSRiver Riddle   void executeGetResults();
949154cabe7SRiver Riddle   void executeGetValueType();
95085ab413bSRiver Riddle   void executeGetValueRangeTypes();
951154cabe7SRiver Riddle   void executeIsNotNull();
952154cabe7SRiver Riddle   void executeRecordMatch(PatternRewriter &rewriter,
953154cabe7SRiver Riddle                           SmallVectorImpl<PDLByteCode::MatchResult> &matches);
954154cabe7SRiver Riddle   void executeReplaceOp(PatternRewriter &rewriter);
955154cabe7SRiver Riddle   void executeSwitchAttribute();
956154cabe7SRiver Riddle   void executeSwitchOperandCount();
957154cabe7SRiver Riddle   void executeSwitchOperationName();
958154cabe7SRiver Riddle   void executeSwitchResultCount();
959154cabe7SRiver Riddle   void executeSwitchType();
96085ab413bSRiver Riddle   void executeSwitchTypes();
961154cabe7SRiver Riddle 
962abfd1a8bSRiver Riddle   /// Read a value from the bytecode buffer, optionally skipping a certain
963abfd1a8bSRiver Riddle   /// number of prefix values. These methods always update the buffer to point
964abfd1a8bSRiver Riddle   /// to the next field after the read data.
965abfd1a8bSRiver Riddle   template <typename T = ByteCodeField>
966abfd1a8bSRiver Riddle   T read(size_t skipN = 0) {
967abfd1a8bSRiver Riddle     curCodeIt += skipN;
968abfd1a8bSRiver Riddle     return readImpl<T>();
969abfd1a8bSRiver Riddle   }
970abfd1a8bSRiver Riddle   ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
971abfd1a8bSRiver Riddle 
972abfd1a8bSRiver Riddle   /// Read a list of values from the bytecode buffer.
973abfd1a8bSRiver Riddle   template <typename ValueT, typename T>
974abfd1a8bSRiver Riddle   void readList(SmallVectorImpl<T> &list) {
975abfd1a8bSRiver Riddle     list.clear();
976abfd1a8bSRiver Riddle     for (unsigned i = 0, e = read(); i != e; ++i)
977abfd1a8bSRiver Riddle       list.push_back(read<ValueT>());
978abfd1a8bSRiver Riddle   }
979abfd1a8bSRiver Riddle 
98085ab413bSRiver Riddle   /// Read a list of values from the bytecode buffer. The values may be encoded
98185ab413bSRiver Riddle   /// as either Value or ValueRange elements.
98285ab413bSRiver Riddle   void readValueList(SmallVectorImpl<Value> &list) {
98385ab413bSRiver Riddle     for (unsigned i = 0, e = read(); i != e; ++i) {
98485ab413bSRiver Riddle       if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
98585ab413bSRiver Riddle         list.push_back(read<Value>());
98685ab413bSRiver Riddle       } else {
98785ab413bSRiver Riddle         ValueRange *values = read<ValueRange *>();
98885ab413bSRiver Riddle         list.append(values->begin(), values->end());
98985ab413bSRiver Riddle       }
99085ab413bSRiver Riddle     }
99185ab413bSRiver Riddle   }
99285ab413bSRiver Riddle 
993abfd1a8bSRiver Riddle   /// Jump to a specific successor based on a predicate value.
994abfd1a8bSRiver Riddle   void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
995abfd1a8bSRiver Riddle   /// Jump to a specific successor based on a destination index.
996abfd1a8bSRiver Riddle   void selectJump(size_t destIndex) {
997abfd1a8bSRiver Riddle     curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
998abfd1a8bSRiver Riddle   }
999abfd1a8bSRiver Riddle 
1000abfd1a8bSRiver Riddle   /// Handle a switch operation with the provided value and cases.
100185ab413bSRiver Riddle   template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
100285ab413bSRiver Riddle   void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1003abfd1a8bSRiver Riddle     LLVM_DEBUG({
1004abfd1a8bSRiver Riddle       llvm::dbgs() << "  * Value: " << value << "\n"
1005abfd1a8bSRiver Riddle                    << "  * Cases: ";
1006abfd1a8bSRiver Riddle       llvm::interleaveComma(cases, llvm::dbgs());
1007154cabe7SRiver Riddle       llvm::dbgs() << "\n";
1008abfd1a8bSRiver Riddle     });
1009abfd1a8bSRiver Riddle 
1010abfd1a8bSRiver Riddle     // Check to see if the attribute value is within the case list. Jump to
1011abfd1a8bSRiver Riddle     // the correct successor index based on the result.
1012f80b6304SRiver Riddle     for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
101385ab413bSRiver Riddle       if (cmp(*it, value))
1014f80b6304SRiver Riddle         return selectJump(size_t((it - cases.begin()) + 1));
1015f80b6304SRiver Riddle     selectJump(size_t(0));
1016abfd1a8bSRiver Riddle   }
1017abfd1a8bSRiver Riddle 
  /// Internal implementation of reading various data types from the bytecode
  /// stream.
  ///
  /// Consumes one field from the stream and interprets it as a memory index.
  /// Indices below `memory.size()` address the mutable execution memory;
  /// larger indices address `uniquedMemory`, offset by `memory.size()`.
  template <typename T>
  const void *readFromMemory() {
    size_t index = *curCodeIt++;

    // If this type is an SSA value, it can only be stored in non-const memory.
    // For Operation *, TypeRange *, ValueRange * and Value the index is always
    // used to address `memory` directly, without a bounds check here.
    if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
                        Value>::value ||
        index < memory.size())
      return memory[index];

    // Otherwise, if this index is not inbounds it is uniqued.
    return uniquedMemory[index - memory.size()];
  }
  /// Pointer kinds (e.g. Operation *, TypeRange *, ValueRange *,
  /// const void *): the opaque memory entry is reinterpreted as `T`.
  template <typename T>
  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
    return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
  }
  /// Class kinds other than PDLValue (e.g. Attribute, Type, Value,
  /// OperationName, Identifier): rebuilt from the opaque memory entry via
  /// `T::getFromOpaquePointer`.
  template <typename T>
  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
                   T>
  readImpl() {
    return T(T::getFromOpaquePointer(readFromMemory<T>()));
  }
  /// PDLValue: a PDLValue::Kind tag is read first, then the payload is read
  /// with the matching concrete type.
  template <typename T>
  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
    switch (read<PDLValue::Kind>()) {
    case PDLValue::Kind::Attribute:
      return read<Attribute>();
    case PDLValue::Kind::Operation:
      return read<Operation *>();
    case PDLValue::Kind::Type:
      return read<Type>();
    case PDLValue::Kind::Value:
      return read<Value>();
    case PDLValue::Kind::TypeRange:
      return read<TypeRange *>();
    case PDLValue::Kind::ValueRange:
      return read<ValueRange *>();
    }
    llvm_unreachable("unhandled PDLValue::Kind");
  }
  /// ByteCodeAddr: stored inline in the stream across two ByteCodeFields.
  /// memcpy is used so the read is correct regardless of the buffer's
  /// alignment relative to ByteCodeAddr.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");
    ByteCodeAddr result;
    std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
    curCodeIt += 2;
    return result;
  }
  /// ByteCodeField: the raw field at the cursor.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
    return *curCodeIt++;
  }
  /// PDLValue::Kind: a single raw field cast to the enum.
  template <typename T>
  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
    return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
  }
1078abfd1a8bSRiver Riddle 
  /// The underlying bytecode buffer. This is a cursor: it always points at
  /// the next field to be read and is advanced by the `read*` helpers.
  const ByteCodeField *curCodeIt;

  /// The current execution memory.
  /// `memory` holds opaque pointers addressed by indices read from the stream.
  MutableArrayRef<const void *> memory;
  /// Slots holding TypeRange objects; `memory` entries for type ranges point
  /// into this array.
  MutableArrayRef<TypeRange> typeRangeMemory;
  /// Owns the backing storage of type ranges created during execution, so
  /// that the TypeRange slots above stay valid.
  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
  /// Slots holding ValueRange objects; `memory` entries for value ranges point
  /// into this array.
  MutableArrayRef<ValueRange> valueRangeMemory;
  /// Owns the backing storage of value ranges created during execution.
  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;

  /// References to ByteCode data necessary for execution.
  /// `uniquedMemory` is addressed by memory indices >= `memory.size()`.
  ArrayRef<const void *> uniquedMemory;
  /// The full bytecode; jump addresses index into it.
  ArrayRef<ByteCodeField> code;
  ArrayRef<PatternBenefit> currentPatternBenefits;
  ArrayRef<PDLByteCodePattern> patterns;
  /// Native functions indexed by ApplyConstraint / ApplyRewrite operands.
  ArrayRef<PDLConstraintFunction> constraintFunctions;
  ArrayRef<PDLRewriteFunction> rewriteFunctions;
1096abfd1a8bSRiver Riddle };
109702c4c0d5SRiver Riddle 
109802c4c0d5SRiver Riddle /// This class is an instantiation of the PDLResultList that provides access to
109902c4c0d5SRiver Riddle /// the returned results. This API is not on `PDLResultList` to avoid
110002c4c0d5SRiver Riddle /// overexposing access to information specific solely to the ByteCode.
110102c4c0d5SRiver Riddle class ByteCodeRewriteResultList : public PDLResultList {
110202c4c0d5SRiver Riddle public:
110385ab413bSRiver Riddle   ByteCodeRewriteResultList(unsigned maxNumResults)
110485ab413bSRiver Riddle       : PDLResultList(maxNumResults) {}
110585ab413bSRiver Riddle 
110602c4c0d5SRiver Riddle   /// Return the list of PDL results.
110702c4c0d5SRiver Riddle   MutableArrayRef<PDLValue> getResults() { return results; }
110885ab413bSRiver Riddle 
110985ab413bSRiver Riddle   /// Return the type ranges allocated by this list.
111085ab413bSRiver Riddle   MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
111185ab413bSRiver Riddle     return allocatedTypeRanges;
111285ab413bSRiver Riddle   }
111385ab413bSRiver Riddle 
111485ab413bSRiver Riddle   /// Return the value ranges allocated by this list.
111585ab413bSRiver Riddle   MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
111685ab413bSRiver Riddle     return allocatedValueRanges;
111785ab413bSRiver Riddle   }
111802c4c0d5SRiver Riddle };
1119abfd1a8bSRiver Riddle } // end anonymous namespace
1120abfd1a8bSRiver Riddle 
1121154cabe7SRiver Riddle void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1122abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1123abfd1a8bSRiver Riddle   const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1124abfd1a8bSRiver Riddle   ArrayAttr constParams = read<ArrayAttr>();
1125abfd1a8bSRiver Riddle   SmallVector<PDLValue, 16> args;
1126abfd1a8bSRiver Riddle   readList<PDLValue>(args);
1127154cabe7SRiver Riddle 
1128abfd1a8bSRiver Riddle   LLVM_DEBUG({
1129abfd1a8bSRiver Riddle     llvm::dbgs() << "  * Arguments: ";
1130abfd1a8bSRiver Riddle     llvm::interleaveComma(args, llvm::dbgs());
1131154cabe7SRiver Riddle     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
1132abfd1a8bSRiver Riddle   });
1133abfd1a8bSRiver Riddle 
1134abfd1a8bSRiver Riddle   // Invoke the constraint and jump to the proper destination.
1135abfd1a8bSRiver Riddle   selectJump(succeeded(constraintFn(args, constParams, rewriter)));
1136abfd1a8bSRiver Riddle }
1137154cabe7SRiver Riddle 
/// Invoke a native rewrite function and store its results in bytecode memory.
/// Operands, in order: rewrite function index, constant parameters, argument
/// list, number of results, then per result: (debug builds only) its expected
/// PDLValue::Kind, an optional range slot index for range results, and the
/// memory index that receives it.
void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
  LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
  const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
  ArrayAttr constParams = read<ArrayAttr>();
  SmallVector<PDLValue, 16> args;
  readList<PDLValue>(args);

  LLVM_DEBUG({
    llvm::dbgs() << "  * Arguments: ";
    llvm::interleaveComma(args, llvm::dbgs());
    llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
  });

  // Execute the rewrite function.
  ByteCodeField numResults = read();
  ByteCodeRewriteResultList results(numResults);
  rewriteFn(args, constParams, rewriter, results);

  assert(results.getResults().size() == numResults &&
         "native PDL rewrite function returned unexpected number of results");

  // Store the results in the bytecode memory.
  for (PDLValue &result : results.getResults()) {
    LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");

// In debug mode we also verify the expected kind of the result.
// NOTE: the read below advances the cursor, so the kind field is consumed only
// when NDEBUG is not defined. Presumably the bytecode generator emits it under
// the same condition -- confirm against the writer before changing this.
#ifndef NDEBUG
    assert(result.getKind() == read<PDLValue::Kind>() &&
           "native PDL rewrite function returned an unexpected type of result");
#endif

    // If the result is a range, we need to copy it over to the bytecodes
    // range memory.
    // Range results use two fields: the range slot index, then the memory
    // index that is pointed at that slot. Other results use one memory index.
    if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
      unsigned rangeIndex = read();
      typeRangeMemory[rangeIndex] = *typeRange;
      memory[read()] = &typeRangeMemory[rangeIndex];
    } else if (Optional<ValueRange> valueRange =
                   result.dyn_cast<ValueRange>()) {
      unsigned rangeIndex = read();
      valueRangeMemory[rangeIndex] = *valueRange;
      memory[read()] = &valueRangeMemory[rangeIndex];
    } else {
      memory[read()] = result.getAsOpaquePointer();
    }
  }

  // Copy over any underlying storage allocated for result ranges.
  // Moving the owning buffers into the executor keeps the ranges stored above
  // valid after `results` is destroyed at the end of this function.
  for (auto &it : results.getAllocatedTypeRanges())
    allocatedTypeRangeMemory.push_back(std::move(it));
  for (auto &it : results.getAllocatedValueRanges())
    allocatedValueRangeMemory.push_back(std::move(it));
}
119185ab413bSRiver Riddle 
1192154cabe7SRiver Riddle void ByteCodeExecutor::executeAreEqual() {
1193abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1194abfd1a8bSRiver Riddle   const void *lhs = read<const void *>();
1195abfd1a8bSRiver Riddle   const void *rhs = read<const void *>();
1196abfd1a8bSRiver Riddle 
1197154cabe7SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
1198abfd1a8bSRiver Riddle   selectJump(lhs == rhs);
1199abfd1a8bSRiver Riddle }
1200154cabe7SRiver Riddle 
120185ab413bSRiver Riddle void ByteCodeExecutor::executeAreRangesEqual() {
120285ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
120385ab413bSRiver Riddle   PDLValue::Kind valueKind = read<PDLValue::Kind>();
120485ab413bSRiver Riddle   const void *lhs = read<const void *>();
120585ab413bSRiver Riddle   const void *rhs = read<const void *>();
120685ab413bSRiver Riddle 
120785ab413bSRiver Riddle   switch (valueKind) {
120885ab413bSRiver Riddle   case PDLValue::Kind::TypeRange: {
120985ab413bSRiver Riddle     const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
121085ab413bSRiver Riddle     const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
121185ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
121285ab413bSRiver Riddle     selectJump(*lhsRange == *rhsRange);
121385ab413bSRiver Riddle     break;
121485ab413bSRiver Riddle   }
121585ab413bSRiver Riddle   case PDLValue::Kind::ValueRange: {
121685ab413bSRiver Riddle     const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
121785ab413bSRiver Riddle     const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
121885ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
121985ab413bSRiver Riddle     selectJump(*lhsRange == *rhsRange);
122085ab413bSRiver Riddle     break;
122185ab413bSRiver Riddle   }
122285ab413bSRiver Riddle   default:
122385ab413bSRiver Riddle     llvm_unreachable("unexpected `AreRangesEqual` value kind");
122485ab413bSRiver Riddle   }
122585ab413bSRiver Riddle }
122685ab413bSRiver Riddle 
1227154cabe7SRiver Riddle void ByteCodeExecutor::executeBranch() {
1228154cabe7SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1229abfd1a8bSRiver Riddle   curCodeIt = &code[read<ByteCodeAddr>()];
1230abfd1a8bSRiver Riddle }
1231154cabe7SRiver Riddle 
1232154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckOperandCount() {
1233abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1234abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1235abfd1a8bSRiver Riddle   uint32_t expectedCount = read<uint32_t>();
123685ab413bSRiver Riddle   bool compareAtLeast = read();
1237abfd1a8bSRiver Riddle 
1238abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
123985ab413bSRiver Riddle                           << "  * Expected: " << expectedCount << "\n"
124085ab413bSRiver Riddle                           << "  * Comparator: "
124185ab413bSRiver Riddle                           << (compareAtLeast ? ">=" : "==") << "\n");
124285ab413bSRiver Riddle   if (compareAtLeast)
124385ab413bSRiver Riddle     selectJump(op->getNumOperands() >= expectedCount);
124485ab413bSRiver Riddle   else
1245abfd1a8bSRiver Riddle     selectJump(op->getNumOperands() == expectedCount);
1246abfd1a8bSRiver Riddle }
1247154cabe7SRiver Riddle 
1248154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckOperationName() {
1249abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1250abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1251abfd1a8bSRiver Riddle   OperationName expectedName = read<OperationName>();
1252abfd1a8bSRiver Riddle 
1253154cabe7SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
1254154cabe7SRiver Riddle                           << "  * Expected: \"" << expectedName << "\"\n");
1255abfd1a8bSRiver Riddle   selectJump(op->getName() == expectedName);
1256abfd1a8bSRiver Riddle }
1257154cabe7SRiver Riddle 
1258154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckResultCount() {
1259abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1260abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1261abfd1a8bSRiver Riddle   uint32_t expectedCount = read<uint32_t>();
126285ab413bSRiver Riddle   bool compareAtLeast = read();
1263abfd1a8bSRiver Riddle 
1264abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
126585ab413bSRiver Riddle                           << "  * Expected: " << expectedCount << "\n"
126685ab413bSRiver Riddle                           << "  * Comparator: "
126785ab413bSRiver Riddle                           << (compareAtLeast ? ">=" : "==") << "\n");
126885ab413bSRiver Riddle   if (compareAtLeast)
126985ab413bSRiver Riddle     selectJump(op->getNumResults() >= expectedCount);
127085ab413bSRiver Riddle   else
1271abfd1a8bSRiver Riddle     selectJump(op->getNumResults() == expectedCount);
1272abfd1a8bSRiver Riddle }
1273154cabe7SRiver Riddle 
127485ab413bSRiver Riddle void ByteCodeExecutor::executeCheckTypes() {
127585ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
127685ab413bSRiver Riddle   TypeRange *lhs = read<TypeRange *>();
127785ab413bSRiver Riddle   Attribute rhs = read<Attribute>();
127885ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
127985ab413bSRiver Riddle 
128085ab413bSRiver Riddle   selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
128185ab413bSRiver Riddle }
128285ab413bSRiver Riddle 
128385ab413bSRiver Riddle void ByteCodeExecutor::executeCreateTypes() {
128485ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
128585ab413bSRiver Riddle   unsigned memIndex = read();
128685ab413bSRiver Riddle   unsigned rangeIndex = read();
128785ab413bSRiver Riddle   ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
128885ab413bSRiver Riddle 
128985ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Types: " << typesAttr << "\n\n");
129085ab413bSRiver Riddle 
129185ab413bSRiver Riddle   // Allocate a buffer for this type range.
129285ab413bSRiver Riddle   llvm::OwningArrayRef<Type> storage(typesAttr.size());
129385ab413bSRiver Riddle   llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
129485ab413bSRiver Riddle   allocatedTypeRangeMemory.emplace_back(std::move(storage));
129585ab413bSRiver Riddle 
129685ab413bSRiver Riddle   // Assign this to the range slot and use the range as the value for the
129785ab413bSRiver Riddle   // memory index.
129885ab413bSRiver Riddle   typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
129985ab413bSRiver Riddle   memory[memIndex] = &typeRangeMemory[rangeIndex];
130085ab413bSRiver Riddle }
130185ab413bSRiver Riddle 
1302154cabe7SRiver Riddle void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1303154cabe7SRiver Riddle                                               Location mainRewriteLoc) {
1304abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1305abfd1a8bSRiver Riddle 
1306abfd1a8bSRiver Riddle   unsigned memIndex = read();
1307154cabe7SRiver Riddle   OperationState state(mainRewriteLoc, read<OperationName>());
130885ab413bSRiver Riddle   readValueList(state.operands);
1309abfd1a8bSRiver Riddle   for (unsigned i = 0, e = read(); i != e; ++i) {
1310abfd1a8bSRiver Riddle     Identifier name = read<Identifier>();
1311abfd1a8bSRiver Riddle     if (Attribute attr = read<Attribute>())
1312abfd1a8bSRiver Riddle       state.addAttribute(name, attr);
1313abfd1a8bSRiver Riddle   }
1314abfd1a8bSRiver Riddle 
1315abfd1a8bSRiver Riddle   for (unsigned i = 0, e = read(); i != e; ++i) {
131685ab413bSRiver Riddle     if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
131785ab413bSRiver Riddle       state.types.push_back(read<Type>());
131885ab413bSRiver Riddle       continue;
131985ab413bSRiver Riddle     }
132085ab413bSRiver Riddle 
132185ab413bSRiver Riddle     // If we find a null range, this signals that the types are infered.
132285ab413bSRiver Riddle     if (TypeRange *resultTypes = read<TypeRange *>()) {
132385ab413bSRiver Riddle       state.types.append(resultTypes->begin(), resultTypes->end());
132485ab413bSRiver Riddle       continue;
1325abfd1a8bSRiver Riddle     }
1326abfd1a8bSRiver Riddle 
1327abfd1a8bSRiver Riddle     // Handle the case where the operation has inferred types.
1328abfd1a8bSRiver Riddle     InferTypeOpInterface::Concept *concept =
1329154cabe7SRiver Riddle         state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>();
1330abfd1a8bSRiver Riddle 
1331abfd1a8bSRiver Riddle     // TODO: Handle failure.
13323a833a0eSRiver Riddle     state.types.clear();
1333abfd1a8bSRiver Riddle     if (failed(concept->inferReturnTypes(
1334abfd1a8bSRiver Riddle             state.getContext(), state.location, state.operands,
1335154cabe7SRiver Riddle             state.attributes.getDictionary(state.getContext()), state.regions,
13363a833a0eSRiver Riddle             state.types)))
1337abfd1a8bSRiver Riddle       return;
133885ab413bSRiver Riddle     break;
1339abfd1a8bSRiver Riddle   }
134085ab413bSRiver Riddle 
1341abfd1a8bSRiver Riddle   Operation *resultOp = rewriter.createOperation(state);
1342abfd1a8bSRiver Riddle   memory[memIndex] = resultOp;
1343abfd1a8bSRiver Riddle 
1344abfd1a8bSRiver Riddle   LLVM_DEBUG({
1345abfd1a8bSRiver Riddle     llvm::dbgs() << "  * Attributes: "
1346abfd1a8bSRiver Riddle                  << state.attributes.getDictionary(state.getContext())
1347abfd1a8bSRiver Riddle                  << "\n  * Operands: ";
1348abfd1a8bSRiver Riddle     llvm::interleaveComma(state.operands, llvm::dbgs());
1349abfd1a8bSRiver Riddle     llvm::dbgs() << "\n  * Result Types: ";
1350abfd1a8bSRiver Riddle     llvm::interleaveComma(state.types, llvm::dbgs());
1351154cabe7SRiver Riddle     llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
1352abfd1a8bSRiver Riddle   });
1353abfd1a8bSRiver Riddle }
1354154cabe7SRiver Riddle 
1355154cabe7SRiver Riddle void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1356abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1357abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1358abfd1a8bSRiver Riddle 
1359154cabe7SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1360abfd1a8bSRiver Riddle   rewriter.eraseOp(op);
1361abfd1a8bSRiver Riddle }
1362154cabe7SRiver Riddle 
1363154cabe7SRiver Riddle void ByteCodeExecutor::executeGetAttribute() {
1364abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1365abfd1a8bSRiver Riddle   unsigned memIndex = read();
1366abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1367abfd1a8bSRiver Riddle   Identifier attrName = read<Identifier>();
1368abfd1a8bSRiver Riddle   Attribute attr = op->getAttr(attrName);
1369abfd1a8bSRiver Riddle 
1370abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1371abfd1a8bSRiver Riddle                           << "  * Attribute: " << attrName << "\n"
1372154cabe7SRiver Riddle                           << "  * Result: " << attr << "\n");
1373abfd1a8bSRiver Riddle   memory[memIndex] = attr.getAsOpaquePointer();
1374abfd1a8bSRiver Riddle }
1375154cabe7SRiver Riddle 
1376154cabe7SRiver Riddle void ByteCodeExecutor::executeGetAttributeType() {
1377abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1378abfd1a8bSRiver Riddle   unsigned memIndex = read();
1379abfd1a8bSRiver Riddle   Attribute attr = read<Attribute>();
1380154cabe7SRiver Riddle   Type type = attr ? attr.getType() : Type();
1381abfd1a8bSRiver Riddle 
1382abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1383154cabe7SRiver Riddle                           << "  * Result: " << type << "\n");
1384154cabe7SRiver Riddle   memory[memIndex] = type.getAsOpaquePointer();
1385abfd1a8bSRiver Riddle }
1386154cabe7SRiver Riddle 
1387154cabe7SRiver Riddle void ByteCodeExecutor::executeGetDefiningOp() {
1388abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1389abfd1a8bSRiver Riddle   unsigned memIndex = read();
139085ab413bSRiver Riddle   Operation *op = nullptr;
139185ab413bSRiver Riddle   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1392abfd1a8bSRiver Riddle     Value value = read<Value>();
139385ab413bSRiver Riddle     if (value)
139485ab413bSRiver Riddle       op = value.getDefiningOp();
139585ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
139685ab413bSRiver Riddle   } else {
139785ab413bSRiver Riddle     ValueRange *values = read<ValueRange *>();
139885ab413bSRiver Riddle     if (values && !values->empty()) {
139985ab413bSRiver Riddle       op = values->front().getDefiningOp();
140085ab413bSRiver Riddle     }
140185ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Values: " << values << "\n");
140285ab413bSRiver Riddle   }
1403abfd1a8bSRiver Riddle 
140485ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << op << "\n");
1405abfd1a8bSRiver Riddle   memory[memIndex] = op;
1406abfd1a8bSRiver Riddle }
1407154cabe7SRiver Riddle 
1408154cabe7SRiver Riddle void ByteCodeExecutor::executeGetOperand(unsigned index) {
1409abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1410abfd1a8bSRiver Riddle   unsigned memIndex = read();
1411abfd1a8bSRiver Riddle   Value operand =
1412abfd1a8bSRiver Riddle       index < op->getNumOperands() ? op->getOperand(index) : Value();
1413abfd1a8bSRiver Riddle 
1414abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1415abfd1a8bSRiver Riddle                           << "  * Index: " << index << "\n"
1416154cabe7SRiver Riddle                           << "  * Result: " << operand << "\n");
1417abfd1a8bSRiver Riddle   memory[memIndex] = operand.getAsOpaquePointer();
1418abfd1a8bSRiver Riddle }
1419154cabe7SRiver Riddle 
142085ab413bSRiver Riddle /// This function is the internal implementation of `GetResults` and
142185ab413bSRiver Riddle /// `GetOperands` that provides support for extracting a value range from the
142285ab413bSRiver Riddle /// given operation.
142385ab413bSRiver Riddle template <template <typename> class AttrSizedSegmentsT, typename RangeT>
142485ab413bSRiver Riddle static void *
142585ab413bSRiver Riddle executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
142685ab413bSRiver Riddle                           ByteCodeField rangeIndex, StringRef attrSizedSegments,
142785ab413bSRiver Riddle                           MutableArrayRef<ValueRange> &valueRangeMemory) {
142885ab413bSRiver Riddle   // Check for the sentinel index that signals that all values should be
142985ab413bSRiver Riddle   // returned.
143085ab413bSRiver Riddle   if (index == std::numeric_limits<uint32_t>::max()) {
143185ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Getting all values\n");
143285ab413bSRiver Riddle     // `values` is already the full value range.
143385ab413bSRiver Riddle 
143485ab413bSRiver Riddle     // Otherwise, check to see if this operation uses AttrSizedSegments.
143585ab413bSRiver Riddle   } else if (op->hasTrait<AttrSizedSegmentsT>()) {
143685ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs()
143785ab413bSRiver Riddle                << "  * Extracting values from `" << attrSizedSegments << "`\n");
143885ab413bSRiver Riddle 
143985ab413bSRiver Riddle     auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments);
144085ab413bSRiver Riddle     if (!segmentAttr || segmentAttr.getNumElements() <= index)
144185ab413bSRiver Riddle       return nullptr;
144285ab413bSRiver Riddle 
144385ab413bSRiver Riddle     auto segments = segmentAttr.getValues<int32_t>();
144485ab413bSRiver Riddle     unsigned startIndex =
144585ab413bSRiver Riddle         std::accumulate(segments.begin(), segments.begin() + index, 0);
144685ab413bSRiver Riddle     values = values.slice(startIndex, *std::next(segments.begin(), index));
144785ab413bSRiver Riddle 
144885ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Extracting range[" << startIndex << ", "
144985ab413bSRiver Riddle                             << *std::next(segments.begin(), index) << "]\n");
145085ab413bSRiver Riddle 
145185ab413bSRiver Riddle     // Otherwise, assume this is the last operand group of the operation.
145285ab413bSRiver Riddle     // FIXME: We currently don't support operations with
145385ab413bSRiver Riddle     // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
145485ab413bSRiver Riddle     // have a way to detect it's presence.
145585ab413bSRiver Riddle   } else if (values.size() >= index) {
145685ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs()
145785ab413bSRiver Riddle                << "  * Treating values as trailing variadic range\n");
145885ab413bSRiver Riddle     values = values.drop_front(index);
145985ab413bSRiver Riddle 
146085ab413bSRiver Riddle     // If we couldn't detect a way to compute the values, bail out.
146185ab413bSRiver Riddle   } else {
146285ab413bSRiver Riddle     return nullptr;
146385ab413bSRiver Riddle   }
146485ab413bSRiver Riddle 
146585ab413bSRiver Riddle   // If the range index is valid, we are returning a range.
146685ab413bSRiver Riddle   if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
146785ab413bSRiver Riddle     valueRangeMemory[rangeIndex] = values;
146885ab413bSRiver Riddle     return &valueRangeMemory[rangeIndex];
146985ab413bSRiver Riddle   }
147085ab413bSRiver Riddle 
147185ab413bSRiver Riddle   // If a range index wasn't provided, the range is required to be non-variadic.
147285ab413bSRiver Riddle   return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
147385ab413bSRiver Riddle }
147485ab413bSRiver Riddle 
147585ab413bSRiver Riddle void ByteCodeExecutor::executeGetOperands() {
147685ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
147785ab413bSRiver Riddle   unsigned index = read<uint32_t>();
147885ab413bSRiver Riddle   Operation *op = read<Operation *>();
147985ab413bSRiver Riddle   ByteCodeField rangeIndex = read();
148085ab413bSRiver Riddle 
148185ab413bSRiver Riddle   void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
148285ab413bSRiver Riddle       op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
148385ab413bSRiver Riddle       valueRangeMemory);
148485ab413bSRiver Riddle   if (!result)
148585ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Invalid operand range\n");
148685ab413bSRiver Riddle   memory[read()] = result;
148785ab413bSRiver Riddle }
148885ab413bSRiver Riddle 
1489154cabe7SRiver Riddle void ByteCodeExecutor::executeGetResult(unsigned index) {
1490abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1491abfd1a8bSRiver Riddle   unsigned memIndex = read();
1492abfd1a8bSRiver Riddle   OpResult result =
1493abfd1a8bSRiver Riddle       index < op->getNumResults() ? op->getResult(index) : OpResult();
1494abfd1a8bSRiver Riddle 
1495abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1496abfd1a8bSRiver Riddle                           << "  * Index: " << index << "\n"
1497154cabe7SRiver Riddle                           << "  * Result: " << result << "\n");
1498abfd1a8bSRiver Riddle   memory[memIndex] = result.getAsOpaquePointer();
1499abfd1a8bSRiver Riddle }
1500154cabe7SRiver Riddle 
150185ab413bSRiver Riddle void ByteCodeExecutor::executeGetResults() {
150285ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
150385ab413bSRiver Riddle   unsigned index = read<uint32_t>();
150485ab413bSRiver Riddle   Operation *op = read<Operation *>();
150585ab413bSRiver Riddle   ByteCodeField rangeIndex = read();
150685ab413bSRiver Riddle 
150785ab413bSRiver Riddle   void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
150885ab413bSRiver Riddle       op->getResults(), op, index, rangeIndex, "result_segment_sizes",
150985ab413bSRiver Riddle       valueRangeMemory);
151085ab413bSRiver Riddle   if (!result)
151185ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Invalid result range\n");
151285ab413bSRiver Riddle   memory[read()] = result;
151385ab413bSRiver Riddle }
151485ab413bSRiver Riddle 
1515154cabe7SRiver Riddle void ByteCodeExecutor::executeGetValueType() {
1516abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1517abfd1a8bSRiver Riddle   unsigned memIndex = read();
1518abfd1a8bSRiver Riddle   Value value = read<Value>();
1519154cabe7SRiver Riddle   Type type = value ? value.getType() : Type();
1520abfd1a8bSRiver Riddle 
1521abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1522154cabe7SRiver Riddle                           << "  * Result: " << type << "\n");
1523154cabe7SRiver Riddle   memory[memIndex] = type.getAsOpaquePointer();
1524abfd1a8bSRiver Riddle }
1525154cabe7SRiver Riddle 
152685ab413bSRiver Riddle void ByteCodeExecutor::executeGetValueRangeTypes() {
152785ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
152885ab413bSRiver Riddle   unsigned memIndex = read();
152985ab413bSRiver Riddle   unsigned rangeIndex = read();
153085ab413bSRiver Riddle   ValueRange *values = read<ValueRange *>();
153185ab413bSRiver Riddle   if (!values) {
153285ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Values: <NULL>\n\n");
153385ab413bSRiver Riddle     memory[memIndex] = nullptr;
153485ab413bSRiver Riddle     return;
153585ab413bSRiver Riddle   }
153685ab413bSRiver Riddle 
153785ab413bSRiver Riddle   LLVM_DEBUG({
153885ab413bSRiver Riddle     llvm::dbgs() << "  * Values (" << values->size() << "): ";
153985ab413bSRiver Riddle     llvm::interleaveComma(*values, llvm::dbgs());
154085ab413bSRiver Riddle     llvm::dbgs() << "\n  * Result: ";
154185ab413bSRiver Riddle     llvm::interleaveComma(values->getType(), llvm::dbgs());
154285ab413bSRiver Riddle     llvm::dbgs() << "\n";
154385ab413bSRiver Riddle   });
154485ab413bSRiver Riddle   typeRangeMemory[rangeIndex] = values->getType();
154585ab413bSRiver Riddle   memory[memIndex] = &typeRangeMemory[rangeIndex];
154685ab413bSRiver Riddle }
154785ab413bSRiver Riddle 
1548154cabe7SRiver Riddle void ByteCodeExecutor::executeIsNotNull() {
1549abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1550abfd1a8bSRiver Riddle   const void *value = read<const void *>();
1551abfd1a8bSRiver Riddle 
1552154cabe7SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1553abfd1a8bSRiver Riddle   selectJump(value != nullptr);
1554abfd1a8bSRiver Riddle }
1555154cabe7SRiver Riddle 
1556154cabe7SRiver Riddle void ByteCodeExecutor::executeRecordMatch(
1557154cabe7SRiver Riddle     PatternRewriter &rewriter,
1558154cabe7SRiver Riddle     SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1559abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1560abfd1a8bSRiver Riddle   unsigned patternIndex = read();
1561abfd1a8bSRiver Riddle   PatternBenefit benefit = currentPatternBenefits[patternIndex];
1562abfd1a8bSRiver Riddle   const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1563abfd1a8bSRiver Riddle 
1564abfd1a8bSRiver Riddle   // If the benefit of the pattern is impossible, skip the processing of the
1565abfd1a8bSRiver Riddle   // rest of the pattern.
1566abfd1a8bSRiver Riddle   if (benefit.isImpossibleToMatch()) {
1567154cabe7SRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
1568abfd1a8bSRiver Riddle     curCodeIt = dest;
1569154cabe7SRiver Riddle     return;
1570abfd1a8bSRiver Riddle   }
1571abfd1a8bSRiver Riddle 
1572abfd1a8bSRiver Riddle   // Create a fused location containing the locations of each of the
1573abfd1a8bSRiver Riddle   // operations used in the match. This will be used as the location for
1574abfd1a8bSRiver Riddle   // created operations during the rewrite that don't already have an
1575abfd1a8bSRiver Riddle   // explicit location set.
1576abfd1a8bSRiver Riddle   unsigned numMatchLocs = read();
1577abfd1a8bSRiver Riddle   SmallVector<Location, 4> matchLocs;
1578abfd1a8bSRiver Riddle   matchLocs.reserve(numMatchLocs);
1579abfd1a8bSRiver Riddle   for (unsigned i = 0; i != numMatchLocs; ++i)
1580abfd1a8bSRiver Riddle     matchLocs.push_back(read<Operation *>()->getLoc());
1581abfd1a8bSRiver Riddle   Location matchLoc = rewriter.getFusedLoc(matchLocs);
1582abfd1a8bSRiver Riddle 
1583abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
1584154cabe7SRiver Riddle                           << "  * Location: " << matchLoc << "\n");
1585154cabe7SRiver Riddle   matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
158685ab413bSRiver Riddle   PDLByteCode::MatchResult &match = matches.back();
158785ab413bSRiver Riddle 
158885ab413bSRiver Riddle   // Record all of the inputs to the match. If any of the inputs are ranges, we
158985ab413bSRiver Riddle   // will also need to remap the range pointer to memory stored in the match
159085ab413bSRiver Riddle   // state.
159185ab413bSRiver Riddle   unsigned numInputs = read();
159285ab413bSRiver Riddle   match.values.reserve(numInputs);
159385ab413bSRiver Riddle   match.typeRangeValues.reserve(numInputs);
159485ab413bSRiver Riddle   match.valueRangeValues.reserve(numInputs);
159585ab413bSRiver Riddle   for (unsigned i = 0; i < numInputs; ++i) {
159685ab413bSRiver Riddle     switch (read<PDLValue::Kind>()) {
159785ab413bSRiver Riddle     case PDLValue::Kind::TypeRange:
159885ab413bSRiver Riddle       match.typeRangeValues.push_back(*read<TypeRange *>());
159985ab413bSRiver Riddle       match.values.push_back(&match.typeRangeValues.back());
160085ab413bSRiver Riddle       break;
160185ab413bSRiver Riddle     case PDLValue::Kind::ValueRange:
160285ab413bSRiver Riddle       match.valueRangeValues.push_back(*read<ValueRange *>());
160385ab413bSRiver Riddle       match.values.push_back(&match.valueRangeValues.back());
160485ab413bSRiver Riddle       break;
160585ab413bSRiver Riddle     default:
160685ab413bSRiver Riddle       match.values.push_back(read<const void *>());
160785ab413bSRiver Riddle       break;
160885ab413bSRiver Riddle     }
160985ab413bSRiver Riddle   }
1610abfd1a8bSRiver Riddle   curCodeIt = dest;
1611abfd1a8bSRiver Riddle }
1612154cabe7SRiver Riddle 
1613154cabe7SRiver Riddle void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
1614abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1615abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1616abfd1a8bSRiver Riddle   SmallVector<Value, 16> args;
161785ab413bSRiver Riddle   readValueList(args);
1618abfd1a8bSRiver Riddle 
1619abfd1a8bSRiver Riddle   LLVM_DEBUG({
1620abfd1a8bSRiver Riddle     llvm::dbgs() << "  * Operation: " << *op << "\n"
1621abfd1a8bSRiver Riddle                  << "  * Values: ";
1622abfd1a8bSRiver Riddle     llvm::interleaveComma(args, llvm::dbgs());
1623154cabe7SRiver Riddle     llvm::dbgs() << "\n";
1624abfd1a8bSRiver Riddle   });
1625abfd1a8bSRiver Riddle   rewriter.replaceOp(op, args);
1626abfd1a8bSRiver Riddle }
1627154cabe7SRiver Riddle 
1628154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchAttribute() {
1629abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1630abfd1a8bSRiver Riddle   Attribute value = read<Attribute>();
1631abfd1a8bSRiver Riddle   ArrayAttr cases = read<ArrayAttr>();
1632abfd1a8bSRiver Riddle   handleSwitch(value, cases);
1633abfd1a8bSRiver Riddle }
1634154cabe7SRiver Riddle 
1635154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchOperandCount() {
1636abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1637abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1638abfd1a8bSRiver Riddle   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1639abfd1a8bSRiver Riddle 
1640abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1641abfd1a8bSRiver Riddle   handleSwitch(op->getNumOperands(), cases);
1642abfd1a8bSRiver Riddle }
1643154cabe7SRiver Riddle 
1644154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchOperationName() {
1645abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
1646abfd1a8bSRiver Riddle   OperationName value = read<Operation *>()->getName();
1647abfd1a8bSRiver Riddle   size_t caseCount = read();
1648abfd1a8bSRiver Riddle 
1649abfd1a8bSRiver Riddle   // The operation names are stored in-line, so to print them out for
1650abfd1a8bSRiver Riddle   // debugging purposes we need to read the array before executing the
1651abfd1a8bSRiver Riddle   // switch so that we can display all of the possible values.
1652abfd1a8bSRiver Riddle   LLVM_DEBUG({
1653abfd1a8bSRiver Riddle     const ByteCodeField *prevCodeIt = curCodeIt;
1654abfd1a8bSRiver Riddle     llvm::dbgs() << "  * Value: " << value << "\n"
1655abfd1a8bSRiver Riddle                  << "  * Cases: ";
1656abfd1a8bSRiver Riddle     llvm::interleaveComma(
1657abfd1a8bSRiver Riddle         llvm::map_range(llvm::seq<size_t>(0, caseCount),
1658154cabe7SRiver Riddle                         [&](size_t) { return read<OperationName>(); }),
1659abfd1a8bSRiver Riddle         llvm::dbgs());
1660154cabe7SRiver Riddle     llvm::dbgs() << "\n";
1661abfd1a8bSRiver Riddle     curCodeIt = prevCodeIt;
1662abfd1a8bSRiver Riddle   });
1663abfd1a8bSRiver Riddle 
1664abfd1a8bSRiver Riddle   // Try to find the switch value within any of the cases.
1665abfd1a8bSRiver Riddle   for (size_t i = 0; i != caseCount; ++i) {
1666abfd1a8bSRiver Riddle     if (read<OperationName>() == value) {
1667abfd1a8bSRiver Riddle       curCodeIt += (caseCount - i - 1);
1668154cabe7SRiver Riddle       return selectJump(i + 1);
1669abfd1a8bSRiver Riddle     }
1670abfd1a8bSRiver Riddle   }
1671154cabe7SRiver Riddle   selectJump(size_t(0));
1672abfd1a8bSRiver Riddle }
1673154cabe7SRiver Riddle 
1674154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchResultCount() {
1675abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1676abfd1a8bSRiver Riddle   Operation *op = read<Operation *>();
1677abfd1a8bSRiver Riddle   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1678abfd1a8bSRiver Riddle 
1679abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1680abfd1a8bSRiver Riddle   handleSwitch(op->getNumResults(), cases);
1681abfd1a8bSRiver Riddle }
1682154cabe7SRiver Riddle 
1683154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchType() {
1684abfd1a8bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
1685abfd1a8bSRiver Riddle   Type value = read<Type>();
1686abfd1a8bSRiver Riddle   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
1687abfd1a8bSRiver Riddle   handleSwitch(value, cases);
1688154cabe7SRiver Riddle }
1689154cabe7SRiver Riddle 
169085ab413bSRiver Riddle void ByteCodeExecutor::executeSwitchTypes() {
169185ab413bSRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
169285ab413bSRiver Riddle   TypeRange *value = read<TypeRange *>();
169385ab413bSRiver Riddle   auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
169485ab413bSRiver Riddle   if (!value) {
169585ab413bSRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
169685ab413bSRiver Riddle     return selectJump(size_t(0));
169785ab413bSRiver Riddle   }
169885ab413bSRiver Riddle   handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
169985ab413bSRiver Riddle     return value == caseValue.getAsValueRange<TypeAttr>();
170085ab413bSRiver Riddle   });
170185ab413bSRiver Riddle }
170285ab413bSRiver Riddle 
1703154cabe7SRiver Riddle void ByteCodeExecutor::execute(
1704154cabe7SRiver Riddle     PatternRewriter &rewriter,
1705154cabe7SRiver Riddle     SmallVectorImpl<PDLByteCode::MatchResult> *matches,
1706154cabe7SRiver Riddle     Optional<Location> mainRewriteLoc) {
1707154cabe7SRiver Riddle   while (true) {
1708154cabe7SRiver Riddle     OpCode opCode = static_cast<OpCode>(read());
1709154cabe7SRiver Riddle     switch (opCode) {
1710154cabe7SRiver Riddle     case ApplyConstraint:
1711154cabe7SRiver Riddle       executeApplyConstraint(rewriter);
1712154cabe7SRiver Riddle       break;
1713154cabe7SRiver Riddle     case ApplyRewrite:
1714154cabe7SRiver Riddle       executeApplyRewrite(rewriter);
1715154cabe7SRiver Riddle       break;
1716154cabe7SRiver Riddle     case AreEqual:
1717154cabe7SRiver Riddle       executeAreEqual();
1718154cabe7SRiver Riddle       break;
171985ab413bSRiver Riddle     case AreRangesEqual:
172085ab413bSRiver Riddle       executeAreRangesEqual();
172185ab413bSRiver Riddle       break;
1722154cabe7SRiver Riddle     case Branch:
1723154cabe7SRiver Riddle       executeBranch();
1724154cabe7SRiver Riddle       break;
1725154cabe7SRiver Riddle     case CheckOperandCount:
1726154cabe7SRiver Riddle       executeCheckOperandCount();
1727154cabe7SRiver Riddle       break;
1728154cabe7SRiver Riddle     case CheckOperationName:
1729154cabe7SRiver Riddle       executeCheckOperationName();
1730154cabe7SRiver Riddle       break;
1731154cabe7SRiver Riddle     case CheckResultCount:
1732154cabe7SRiver Riddle       executeCheckResultCount();
1733154cabe7SRiver Riddle       break;
173485ab413bSRiver Riddle     case CheckTypes:
173585ab413bSRiver Riddle       executeCheckTypes();
173685ab413bSRiver Riddle       break;
1737154cabe7SRiver Riddle     case CreateOperation:
1738154cabe7SRiver Riddle       executeCreateOperation(rewriter, *mainRewriteLoc);
1739154cabe7SRiver Riddle       break;
174085ab413bSRiver Riddle     case CreateTypes:
174185ab413bSRiver Riddle       executeCreateTypes();
174285ab413bSRiver Riddle       break;
1743154cabe7SRiver Riddle     case EraseOp:
1744154cabe7SRiver Riddle       executeEraseOp(rewriter);
1745154cabe7SRiver Riddle       break;
1746154cabe7SRiver Riddle     case Finalize:
1747154cabe7SRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n");
1748154cabe7SRiver Riddle       return;
1749154cabe7SRiver Riddle     case GetAttribute:
1750154cabe7SRiver Riddle       executeGetAttribute();
1751154cabe7SRiver Riddle       break;
1752154cabe7SRiver Riddle     case GetAttributeType:
1753154cabe7SRiver Riddle       executeGetAttributeType();
1754154cabe7SRiver Riddle       break;
1755154cabe7SRiver Riddle     case GetDefiningOp:
1756154cabe7SRiver Riddle       executeGetDefiningOp();
1757154cabe7SRiver Riddle       break;
1758154cabe7SRiver Riddle     case GetOperand0:
1759154cabe7SRiver Riddle     case GetOperand1:
1760154cabe7SRiver Riddle     case GetOperand2:
1761154cabe7SRiver Riddle     case GetOperand3: {
1762154cabe7SRiver Riddle       unsigned index = opCode - GetOperand0;
1763154cabe7SRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
17641fff7c89SFrederik Gossen       executeGetOperand(index);
1765abfd1a8bSRiver Riddle       break;
1766abfd1a8bSRiver Riddle     }
1767154cabe7SRiver Riddle     case GetOperandN:
1768154cabe7SRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
1769154cabe7SRiver Riddle       executeGetOperand(read<uint32_t>());
1770154cabe7SRiver Riddle       break;
177185ab413bSRiver Riddle     case GetOperands:
177285ab413bSRiver Riddle       executeGetOperands();
177385ab413bSRiver Riddle       break;
1774154cabe7SRiver Riddle     case GetResult0:
1775154cabe7SRiver Riddle     case GetResult1:
1776154cabe7SRiver Riddle     case GetResult2:
1777154cabe7SRiver Riddle     case GetResult3: {
1778154cabe7SRiver Riddle       unsigned index = opCode - GetResult0;
1779154cabe7SRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
17801fff7c89SFrederik Gossen       executeGetResult(index);
1781154cabe7SRiver Riddle       break;
1782abfd1a8bSRiver Riddle     }
1783154cabe7SRiver Riddle     case GetResultN:
1784154cabe7SRiver Riddle       LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
1785154cabe7SRiver Riddle       executeGetResult(read<uint32_t>());
1786154cabe7SRiver Riddle       break;
178785ab413bSRiver Riddle     case GetResults:
178885ab413bSRiver Riddle       executeGetResults();
178985ab413bSRiver Riddle       break;
1790154cabe7SRiver Riddle     case GetValueType:
1791154cabe7SRiver Riddle       executeGetValueType();
1792154cabe7SRiver Riddle       break;
179385ab413bSRiver Riddle     case GetValueRangeTypes:
179485ab413bSRiver Riddle       executeGetValueRangeTypes();
179585ab413bSRiver Riddle       break;
1796154cabe7SRiver Riddle     case IsNotNull:
1797154cabe7SRiver Riddle       executeIsNotNull();
1798154cabe7SRiver Riddle       break;
1799154cabe7SRiver Riddle     case RecordMatch:
1800154cabe7SRiver Riddle       assert(matches &&
1801154cabe7SRiver Riddle              "expected matches to be provided when executing the matcher");
1802154cabe7SRiver Riddle       executeRecordMatch(rewriter, *matches);
1803154cabe7SRiver Riddle       break;
1804154cabe7SRiver Riddle     case ReplaceOp:
1805154cabe7SRiver Riddle       executeReplaceOp(rewriter);
1806154cabe7SRiver Riddle       break;
1807154cabe7SRiver Riddle     case SwitchAttribute:
1808154cabe7SRiver Riddle       executeSwitchAttribute();
1809154cabe7SRiver Riddle       break;
1810154cabe7SRiver Riddle     case SwitchOperandCount:
1811154cabe7SRiver Riddle       executeSwitchOperandCount();
1812154cabe7SRiver Riddle       break;
1813154cabe7SRiver Riddle     case SwitchOperationName:
1814154cabe7SRiver Riddle       executeSwitchOperationName();
1815154cabe7SRiver Riddle       break;
1816154cabe7SRiver Riddle     case SwitchResultCount:
1817154cabe7SRiver Riddle       executeSwitchResultCount();
1818154cabe7SRiver Riddle       break;
1819154cabe7SRiver Riddle     case SwitchType:
1820154cabe7SRiver Riddle       executeSwitchType();
1821154cabe7SRiver Riddle       break;
182285ab413bSRiver Riddle     case SwitchTypes:
182385ab413bSRiver Riddle       executeSwitchTypes();
182485ab413bSRiver Riddle       break;
1825154cabe7SRiver Riddle     }
1826154cabe7SRiver Riddle     LLVM_DEBUG(llvm::dbgs() << "\n");
1827abfd1a8bSRiver Riddle   }
1828abfd1a8bSRiver Riddle }
1829abfd1a8bSRiver Riddle 
1830abfd1a8bSRiver Riddle /// Run the pattern matcher on the given root operation, collecting the matched
1831abfd1a8bSRiver Riddle /// patterns in `matches`.
1832abfd1a8bSRiver Riddle void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
1833abfd1a8bSRiver Riddle                         SmallVectorImpl<MatchResult> &matches,
1834abfd1a8bSRiver Riddle                         PDLByteCodeMutableState &state) const {
1835abfd1a8bSRiver Riddle   // The first memory slot is always the root operation.
1836abfd1a8bSRiver Riddle   state.memory[0] = op;
1837abfd1a8bSRiver Riddle 
1838abfd1a8bSRiver Riddle   // The matcher function always starts at code address 0.
183985ab413bSRiver Riddle   ByteCodeExecutor executor(
184085ab413bSRiver Riddle       matcherByteCode.data(), state.memory, state.typeRangeMemory,
184185ab413bSRiver Riddle       state.allocatedTypeRangeMemory, state.valueRangeMemory,
184285ab413bSRiver Riddle       state.allocatedValueRangeMemory, uniquedData, matcherByteCode,
184385ab413bSRiver Riddle       state.currentPatternBenefits, patterns, constraintFunctions,
184485ab413bSRiver Riddle       rewriteFunctions);
1845abfd1a8bSRiver Riddle   executor.execute(rewriter, &matches);
1846abfd1a8bSRiver Riddle 
1847abfd1a8bSRiver Riddle   // Order the found matches by benefit.
1848abfd1a8bSRiver Riddle   std::stable_sort(matches.begin(), matches.end(),
1849abfd1a8bSRiver Riddle                    [](const MatchResult &lhs, const MatchResult &rhs) {
1850abfd1a8bSRiver Riddle                      return lhs.benefit > rhs.benefit;
1851abfd1a8bSRiver Riddle                    });
1852abfd1a8bSRiver Riddle }
1853abfd1a8bSRiver Riddle 
1854abfd1a8bSRiver Riddle /// Run the rewriter of the given pattern on the root operation `op`.
1855abfd1a8bSRiver Riddle void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
1856abfd1a8bSRiver Riddle                           PDLByteCodeMutableState &state) const {
1857abfd1a8bSRiver Riddle   // The arguments of the rewrite function are stored at the start of the
1858abfd1a8bSRiver Riddle   // memory buffer.
1859abfd1a8bSRiver Riddle   llvm::copy(match.values, state.memory.begin());
1860abfd1a8bSRiver Riddle 
186185ab413bSRiver Riddle   ByteCodeExecutor executor(
186285ab413bSRiver Riddle       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
186385ab413bSRiver Riddle       state.typeRangeMemory, state.allocatedTypeRangeMemory,
186485ab413bSRiver Riddle       state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData,
186585ab413bSRiver Riddle       rewriterByteCode, state.currentPatternBenefits, patterns,
186602c4c0d5SRiver Riddle       constraintFunctions, rewriteFunctions);
1867abfd1a8bSRiver Riddle   executor.execute(rewriter, /*matches=*/nullptr, match.location);
1868abfd1a8bSRiver Riddle }
1869