1abfd1a8bSRiver Riddle //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// 2abfd1a8bSRiver Riddle // 3abfd1a8bSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4abfd1a8bSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 5abfd1a8bSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6abfd1a8bSRiver Riddle // 7abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 8abfd1a8bSRiver Riddle // 9abfd1a8bSRiver Riddle // This file implements MLIR to byte-code generation and the interpreter. 10abfd1a8bSRiver Riddle // 11abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 12abfd1a8bSRiver Riddle 13abfd1a8bSRiver Riddle #include "ByteCode.h" 14abfd1a8bSRiver Riddle #include "mlir/Analysis/Liveness.h" 15abfd1a8bSRiver Riddle #include "mlir/Dialect/PDL/IR/PDLTypes.h" 16abfd1a8bSRiver Riddle #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 17abfd1a8bSRiver Riddle #include "mlir/IR/Function.h" 18abfd1a8bSRiver Riddle #include "mlir/IR/RegionGraphTraits.h" 19abfd1a8bSRiver Riddle #include "llvm/ADT/IntervalMap.h" 20abfd1a8bSRiver Riddle #include "llvm/ADT/PostOrderIterator.h" 21abfd1a8bSRiver Riddle #include "llvm/ADT/TypeSwitch.h" 22abfd1a8bSRiver Riddle #include "llvm/Support/Debug.h" 23abfd1a8bSRiver Riddle 24abfd1a8bSRiver Riddle #define DEBUG_TYPE "pdl-bytecode" 25abfd1a8bSRiver Riddle 26abfd1a8bSRiver Riddle using namespace mlir; 27abfd1a8bSRiver Riddle using namespace mlir::detail; 28abfd1a8bSRiver Riddle 29abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 30abfd1a8bSRiver Riddle // PDLByteCodePattern 31abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 32abfd1a8bSRiver Riddle 33abfd1a8bSRiver Riddle PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, 34abfd1a8bSRiver Riddle ByteCodeAddr rewriterAddr) { 35abfd1a8bSRiver Riddle SmallVector<StringRef, 8> generatedOps; 36abfd1a8bSRiver Riddle if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr()) 37abfd1a8bSRiver Riddle generatedOps = 38abfd1a8bSRiver Riddle llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); 39abfd1a8bSRiver Riddle 40abfd1a8bSRiver Riddle PatternBenefit benefit = matchOp.benefit(); 41abfd1a8bSRiver Riddle MLIRContext *ctx = matchOp.getContext(); 42abfd1a8bSRiver Riddle 43abfd1a8bSRiver Riddle // Check to see if this is pattern matches a specific operation type. 44abfd1a8bSRiver Riddle if (Optional<StringRef> rootKind = matchOp.rootKind()) 45abfd1a8bSRiver Riddle return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit, 46abfd1a8bSRiver Riddle ctx); 47abfd1a8bSRiver Riddle return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx, 48abfd1a8bSRiver Riddle MatchAnyOpTypeTag()); 49abfd1a8bSRiver Riddle } 50abfd1a8bSRiver Riddle 51abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 52abfd1a8bSRiver Riddle // PDLByteCodeMutableState 53abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 54abfd1a8bSRiver Riddle 55abfd1a8bSRiver Riddle /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 56abfd1a8bSRiver Riddle /// to the position of the pattern within the range returned by 57abfd1a8bSRiver Riddle /// `PDLByteCode::getPatterns`. 58abfd1a8bSRiver Riddle void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, 59abfd1a8bSRiver Riddle PatternBenefit benefit) { 60abfd1a8bSRiver Riddle currentPatternBenefits[patternIndex] = benefit; 61abfd1a8bSRiver Riddle } 62abfd1a8bSRiver Riddle 63abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 64abfd1a8bSRiver Riddle // Bytecode OpCodes 65abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 66abfd1a8bSRiver Riddle 67abfd1a8bSRiver Riddle namespace { 68abfd1a8bSRiver Riddle enum OpCode : ByteCodeField { 69abfd1a8bSRiver Riddle /// Apply an externally registered constraint. 70abfd1a8bSRiver Riddle ApplyConstraint, 71abfd1a8bSRiver Riddle /// Apply an externally registered rewrite. 72abfd1a8bSRiver Riddle ApplyRewrite, 73abfd1a8bSRiver Riddle /// Check if two generic values are equal. 74abfd1a8bSRiver Riddle AreEqual, 75abfd1a8bSRiver Riddle /// Unconditional branch. 76abfd1a8bSRiver Riddle Branch, 77abfd1a8bSRiver Riddle /// Compare the operand count of an operation with a constant. 78abfd1a8bSRiver Riddle CheckOperandCount, 79abfd1a8bSRiver Riddle /// Compare the name of an operation with a constant. 80abfd1a8bSRiver Riddle CheckOperationName, 81abfd1a8bSRiver Riddle /// Compare the result count of an operation with a constant. 82abfd1a8bSRiver Riddle CheckResultCount, 83abfd1a8bSRiver Riddle /// Invoke a native creation method. 84abfd1a8bSRiver Riddle CreateNative, 85abfd1a8bSRiver Riddle /// Create an operation. 86abfd1a8bSRiver Riddle CreateOperation, 87abfd1a8bSRiver Riddle /// Erase an operation. 88abfd1a8bSRiver Riddle EraseOp, 89abfd1a8bSRiver Riddle /// Terminate a matcher or rewrite sequence. 90abfd1a8bSRiver Riddle Finalize, 91abfd1a8bSRiver Riddle /// Get a specific attribute of an operation. 92abfd1a8bSRiver Riddle GetAttribute, 93abfd1a8bSRiver Riddle /// Get the type of an attribute. 94abfd1a8bSRiver Riddle GetAttributeType, 95abfd1a8bSRiver Riddle /// Get the defining operation of a value. 96abfd1a8bSRiver Riddle GetDefiningOp, 97abfd1a8bSRiver Riddle /// Get a specific operand of an operation. 98abfd1a8bSRiver Riddle GetOperand0, 99abfd1a8bSRiver Riddle GetOperand1, 100abfd1a8bSRiver Riddle GetOperand2, 101abfd1a8bSRiver Riddle GetOperand3, 102abfd1a8bSRiver Riddle GetOperandN, 103abfd1a8bSRiver Riddle /// Get a specific result of an operation. 104abfd1a8bSRiver Riddle GetResult0, 105abfd1a8bSRiver Riddle GetResult1, 106abfd1a8bSRiver Riddle GetResult2, 107abfd1a8bSRiver Riddle GetResult3, 108abfd1a8bSRiver Riddle GetResultN, 109abfd1a8bSRiver Riddle /// Get the type of a value. 110abfd1a8bSRiver Riddle GetValueType, 111abfd1a8bSRiver Riddle /// Check if a generic value is not null. 112abfd1a8bSRiver Riddle IsNotNull, 113abfd1a8bSRiver Riddle /// Record a successful pattern match. 114abfd1a8bSRiver Riddle RecordMatch, 115abfd1a8bSRiver Riddle /// Replace an operation. 116abfd1a8bSRiver Riddle ReplaceOp, 117abfd1a8bSRiver Riddle /// Compare an attribute with a set of constants. 118abfd1a8bSRiver Riddle SwitchAttribute, 119abfd1a8bSRiver Riddle /// Compare the operand count of an operation with a set of constants. 120abfd1a8bSRiver Riddle SwitchOperandCount, 121abfd1a8bSRiver Riddle /// Compare the name of an operation with a set of constants. 122abfd1a8bSRiver Riddle SwitchOperationName, 123abfd1a8bSRiver Riddle /// Compare the result count of an operation with a set of constants. 124abfd1a8bSRiver Riddle SwitchResultCount, 125abfd1a8bSRiver Riddle /// Compare a type with a set of constants. 126abfd1a8bSRiver Riddle SwitchType, 127abfd1a8bSRiver Riddle }; 128abfd1a8bSRiver Riddle 129abfd1a8bSRiver Riddle enum class PDLValueKind { Attribute, Operation, Type, Value }; 130abfd1a8bSRiver Riddle } // end anonymous namespace 131abfd1a8bSRiver Riddle 132abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 133abfd1a8bSRiver Riddle // ByteCode Generation 134abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 135abfd1a8bSRiver Riddle 136abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 137abfd1a8bSRiver Riddle // Generator 138abfd1a8bSRiver Riddle 139abfd1a8bSRiver Riddle namespace { 140abfd1a8bSRiver Riddle struct ByteCodeWriter; 141abfd1a8bSRiver Riddle 142abfd1a8bSRiver Riddle /// This class represents the main generator for the pattern bytecode. 143abfd1a8bSRiver Riddle class Generator { 144abfd1a8bSRiver Riddle public: 145abfd1a8bSRiver Riddle Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, 146abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &matcherByteCode, 147abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &rewriterByteCode, 148abfd1a8bSRiver Riddle SmallVectorImpl<PDLByteCodePattern> &patterns, 149abfd1a8bSRiver Riddle ByteCodeField &maxValueMemoryIndex, 150abfd1a8bSRiver Riddle llvm::StringMap<PDLConstraintFunction> &constraintFns, 151abfd1a8bSRiver Riddle llvm::StringMap<PDLCreateFunction> &createFns, 152abfd1a8bSRiver Riddle llvm::StringMap<PDLRewriteFunction> &rewriteFns) 153abfd1a8bSRiver Riddle : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), 154abfd1a8bSRiver Riddle rewriterByteCode(rewriterByteCode), patterns(patterns), 155abfd1a8bSRiver Riddle maxValueMemoryIndex(maxValueMemoryIndex) { 156abfd1a8bSRiver Riddle for (auto it : llvm::enumerate(constraintFns)) 157abfd1a8bSRiver Riddle constraintToMemIndex.try_emplace(it.value().first(), it.index()); 158abfd1a8bSRiver Riddle for (auto it : llvm::enumerate(createFns)) 159abfd1a8bSRiver Riddle nativeCreateToMemIndex.try_emplace(it.value().first(), it.index()); 160abfd1a8bSRiver Riddle for (auto it : llvm::enumerate(rewriteFns)) 161abfd1a8bSRiver Riddle externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); 162abfd1a8bSRiver Riddle } 163abfd1a8bSRiver Riddle 164abfd1a8bSRiver Riddle /// Generate the bytecode for the given PDL interpreter module. 165abfd1a8bSRiver Riddle void generate(ModuleOp module); 166abfd1a8bSRiver Riddle 167abfd1a8bSRiver Riddle /// Return the memory index to use for the given value. 168abfd1a8bSRiver Riddle ByteCodeField &getMemIndex(Value value) { 169abfd1a8bSRiver Riddle assert(valueToMemIndex.count(value) && 170abfd1a8bSRiver Riddle "expected memory index to be assigned"); 171abfd1a8bSRiver Riddle return valueToMemIndex[value]; 172abfd1a8bSRiver Riddle } 173abfd1a8bSRiver Riddle 174abfd1a8bSRiver Riddle /// Return an index to use when referring to the given data that is uniqued in 175abfd1a8bSRiver Riddle /// the MLIR context. 176abfd1a8bSRiver Riddle template <typename T> 177abfd1a8bSRiver Riddle std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> 178abfd1a8bSRiver Riddle getMemIndex(T val) { 179abfd1a8bSRiver Riddle const void *opaqueVal = val.getAsOpaquePointer(); 180abfd1a8bSRiver Riddle 181abfd1a8bSRiver Riddle // Get or insert a reference to this value. 182abfd1a8bSRiver Riddle auto it = uniquedDataToMemIndex.try_emplace( 183abfd1a8bSRiver Riddle opaqueVal, maxValueMemoryIndex + uniquedData.size()); 184abfd1a8bSRiver Riddle if (it.second) 185abfd1a8bSRiver Riddle uniquedData.push_back(opaqueVal); 186abfd1a8bSRiver Riddle return it.first->second; 187abfd1a8bSRiver Riddle } 188abfd1a8bSRiver Riddle 189abfd1a8bSRiver Riddle private: 190abfd1a8bSRiver Riddle /// Allocate memory indices for the results of operations within the matcher 191abfd1a8bSRiver Riddle /// and rewriters. 192abfd1a8bSRiver Riddle void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); 193abfd1a8bSRiver Riddle 194abfd1a8bSRiver Riddle /// Generate the bytecode for the given operation. 195abfd1a8bSRiver Riddle void generate(Operation *op, ByteCodeWriter &writer); 196abfd1a8bSRiver Riddle void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); 197abfd1a8bSRiver Riddle void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); 198abfd1a8bSRiver Riddle void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); 199abfd1a8bSRiver Riddle void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); 200abfd1a8bSRiver Riddle void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); 201abfd1a8bSRiver Riddle void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); 202abfd1a8bSRiver Riddle void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); 203abfd1a8bSRiver Riddle void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); 204abfd1a8bSRiver Riddle void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); 205abfd1a8bSRiver Riddle void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); 206abfd1a8bSRiver Riddle void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer); 207abfd1a8bSRiver Riddle void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); 208abfd1a8bSRiver Riddle void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); 209abfd1a8bSRiver Riddle void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); 210abfd1a8bSRiver Riddle void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); 211abfd1a8bSRiver Riddle void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); 212abfd1a8bSRiver Riddle void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); 213abfd1a8bSRiver Riddle void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); 214abfd1a8bSRiver Riddle void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); 215abfd1a8bSRiver Riddle void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); 216abfd1a8bSRiver Riddle void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); 217abfd1a8bSRiver Riddle void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer); 218abfd1a8bSRiver Riddle void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); 219abfd1a8bSRiver Riddle void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); 220abfd1a8bSRiver Riddle void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); 221abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); 222abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); 223abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); 224abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); 225abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); 226abfd1a8bSRiver Riddle 227abfd1a8bSRiver Riddle /// Mapping from value to its corresponding memory index. 228abfd1a8bSRiver Riddle DenseMap<Value, ByteCodeField> valueToMemIndex; 229abfd1a8bSRiver Riddle 230abfd1a8bSRiver Riddle /// Mapping from the name of an externally registered rewrite to its index in 231abfd1a8bSRiver Riddle /// the bytecode registry. 232abfd1a8bSRiver Riddle llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; 233abfd1a8bSRiver Riddle 234abfd1a8bSRiver Riddle /// Mapping from the name of an externally registered constraint to its index 235abfd1a8bSRiver Riddle /// in the bytecode registry. 236abfd1a8bSRiver Riddle llvm::StringMap<ByteCodeField> constraintToMemIndex; 237abfd1a8bSRiver Riddle 238abfd1a8bSRiver Riddle /// Mapping from the name of an externally registered creation method to its 239abfd1a8bSRiver Riddle /// index in the bytecode registry. 240abfd1a8bSRiver Riddle llvm::StringMap<ByteCodeField> nativeCreateToMemIndex; 241abfd1a8bSRiver Riddle 242abfd1a8bSRiver Riddle /// Mapping from rewriter function name to the bytecode address of the 243abfd1a8bSRiver Riddle /// rewriter function in byte. 244abfd1a8bSRiver Riddle llvm::StringMap<ByteCodeAddr> rewriterToAddr; 245abfd1a8bSRiver Riddle 246abfd1a8bSRiver Riddle /// Mapping from a uniqued storage object to its memory index within 247abfd1a8bSRiver Riddle /// `uniquedData`. 248abfd1a8bSRiver Riddle DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; 249abfd1a8bSRiver Riddle 250abfd1a8bSRiver Riddle /// The current MLIR context. 251abfd1a8bSRiver Riddle MLIRContext *ctx; 252abfd1a8bSRiver Riddle 253abfd1a8bSRiver Riddle /// Data of the ByteCode class to be populated. 254abfd1a8bSRiver Riddle std::vector<const void *> &uniquedData; 255abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &matcherByteCode; 256abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &rewriterByteCode; 257abfd1a8bSRiver Riddle SmallVectorImpl<PDLByteCodePattern> &patterns; 258abfd1a8bSRiver Riddle ByteCodeField &maxValueMemoryIndex; 259abfd1a8bSRiver Riddle }; 260abfd1a8bSRiver Riddle 261abfd1a8bSRiver Riddle /// This class provides utilities for writing a bytecode stream. 262abfd1a8bSRiver Riddle struct ByteCodeWriter { 263abfd1a8bSRiver Riddle ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) 264abfd1a8bSRiver Riddle : bytecode(bytecode), generator(generator) {} 265abfd1a8bSRiver Riddle 266abfd1a8bSRiver Riddle /// Append a field to the bytecode. 267abfd1a8bSRiver Riddle void append(ByteCodeField field) { bytecode.push_back(field); } 268*fa20ab7bSRiver Riddle void append(OpCode opCode) { bytecode.push_back(opCode); } 269abfd1a8bSRiver Riddle 270abfd1a8bSRiver Riddle /// Append an address to the bytecode. 271abfd1a8bSRiver Riddle void append(ByteCodeAddr field) { 272abfd1a8bSRiver Riddle static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 273abfd1a8bSRiver Riddle "unexpected ByteCode address size"); 274abfd1a8bSRiver Riddle 275abfd1a8bSRiver Riddle ByteCodeField fieldParts[2]; 276abfd1a8bSRiver Riddle std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); 277abfd1a8bSRiver Riddle bytecode.append({fieldParts[0], fieldParts[1]}); 278abfd1a8bSRiver Riddle } 279abfd1a8bSRiver Riddle 280abfd1a8bSRiver Riddle /// Append a successor range to the bytecode, the exact address will need to 281abfd1a8bSRiver Riddle /// be resolved later. 282abfd1a8bSRiver Riddle void append(SuccessorRange successors) { 283abfd1a8bSRiver Riddle // Add back references to the any successors so that the address can be 284abfd1a8bSRiver Riddle // resolved later. 285abfd1a8bSRiver Riddle for (Block *successor : successors) { 286abfd1a8bSRiver Riddle unresolvedSuccessorRefs[successor].push_back(bytecode.size()); 287abfd1a8bSRiver Riddle append(ByteCodeAddr(0)); 288abfd1a8bSRiver Riddle } 289abfd1a8bSRiver Riddle } 290abfd1a8bSRiver Riddle 291abfd1a8bSRiver Riddle /// Append a range of values that will be read as generic PDLValues. 292abfd1a8bSRiver Riddle void appendPDLValueList(OperandRange values) { 293abfd1a8bSRiver Riddle bytecode.push_back(values.size()); 294abfd1a8bSRiver Riddle for (Value value : values) { 295abfd1a8bSRiver Riddle // Append the type of the value in addition to the value itself. 296abfd1a8bSRiver Riddle PDLValueKind kind = 297abfd1a8bSRiver Riddle TypeSwitch<Type, PDLValueKind>(value.getType()) 298abfd1a8bSRiver Riddle .Case<pdl::AttributeType>( 299abfd1a8bSRiver Riddle [](Type) { return PDLValueKind::Attribute; }) 300abfd1a8bSRiver Riddle .Case<pdl::OperationType>( 301abfd1a8bSRiver Riddle [](Type) { return PDLValueKind::Operation; }) 302abfd1a8bSRiver Riddle .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; }) 303abfd1a8bSRiver Riddle .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; }); 304abfd1a8bSRiver Riddle bytecode.push_back(static_cast<ByteCodeField>(kind)); 305abfd1a8bSRiver Riddle append(value); 306abfd1a8bSRiver Riddle } 307abfd1a8bSRiver Riddle } 308abfd1a8bSRiver Riddle 309abfd1a8bSRiver Riddle /// Check if the given class `T` has an iterator type. 310abfd1a8bSRiver Riddle template <typename T, typename... Args> 311abfd1a8bSRiver Riddle using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); 312abfd1a8bSRiver Riddle 313abfd1a8bSRiver Riddle /// Append a value that will be stored in a memory slot and not inline within 314abfd1a8bSRiver Riddle /// the bytecode. 315abfd1a8bSRiver Riddle template <typename T> 316abfd1a8bSRiver Riddle std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value || 317abfd1a8bSRiver Riddle std::is_pointer<T>::value> 318abfd1a8bSRiver Riddle append(T value) { 319abfd1a8bSRiver Riddle bytecode.push_back(generator.getMemIndex(value)); 320abfd1a8bSRiver Riddle } 321abfd1a8bSRiver Riddle 322abfd1a8bSRiver Riddle /// Append a range of values. 323abfd1a8bSRiver Riddle template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>> 324abfd1a8bSRiver Riddle std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value> 325abfd1a8bSRiver Riddle append(T range) { 326abfd1a8bSRiver Riddle bytecode.push_back(llvm::size(range)); 327abfd1a8bSRiver Riddle for (auto it : range) 328abfd1a8bSRiver Riddle append(it); 329abfd1a8bSRiver Riddle } 330abfd1a8bSRiver Riddle 331abfd1a8bSRiver Riddle /// Append a variadic number of fields to the bytecode. 332abfd1a8bSRiver Riddle template <typename FieldTy, typename Field2Ty, typename... FieldTys> 333abfd1a8bSRiver Riddle void append(FieldTy field, Field2Ty field2, FieldTys... fields) { 334abfd1a8bSRiver Riddle append(field); 335abfd1a8bSRiver Riddle append(field2, fields...); 336abfd1a8bSRiver Riddle } 337abfd1a8bSRiver Riddle 338abfd1a8bSRiver Riddle /// Successor references in the bytecode that have yet to be resolved. 339abfd1a8bSRiver Riddle DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; 340abfd1a8bSRiver Riddle 341abfd1a8bSRiver Riddle /// The underlying bytecode buffer. 342abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &bytecode; 343abfd1a8bSRiver Riddle 344abfd1a8bSRiver Riddle /// The main generator producing PDL. 345abfd1a8bSRiver Riddle Generator &generator; 346abfd1a8bSRiver Riddle }; 347abfd1a8bSRiver Riddle } // end anonymous namespace 348abfd1a8bSRiver Riddle 349abfd1a8bSRiver Riddle void Generator::generate(ModuleOp module) { 350abfd1a8bSRiver Riddle FuncOp matcherFunc = module.lookupSymbol<FuncOp>( 351abfd1a8bSRiver Riddle pdl_interp::PDLInterpDialect::getMatcherFunctionName()); 352abfd1a8bSRiver Riddle ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( 353abfd1a8bSRiver Riddle pdl_interp::PDLInterpDialect::getRewriterModuleName()); 354abfd1a8bSRiver Riddle assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); 355abfd1a8bSRiver Riddle 356abfd1a8bSRiver Riddle // Allocate memory indices for the results of operations within the matcher 357abfd1a8bSRiver Riddle // and rewriters. 358abfd1a8bSRiver Riddle allocateMemoryIndices(matcherFunc, rewriterModule); 359abfd1a8bSRiver Riddle 360abfd1a8bSRiver Riddle // Generate code for the rewriter functions. 361abfd1a8bSRiver Riddle ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); 362abfd1a8bSRiver Riddle for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 363abfd1a8bSRiver Riddle rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); 364abfd1a8bSRiver Riddle for (Operation &op : rewriterFunc.getOps()) 365abfd1a8bSRiver Riddle generate(&op, rewriterByteCodeWriter); 366abfd1a8bSRiver Riddle } 367abfd1a8bSRiver Riddle assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && 368abfd1a8bSRiver Riddle "unexpected branches in rewriter function"); 369abfd1a8bSRiver Riddle 370abfd1a8bSRiver Riddle // Generate code for the matcher function. 371abfd1a8bSRiver Riddle DenseMap<Block *, ByteCodeAddr> blockToAddr; 372abfd1a8bSRiver Riddle llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody()); 373abfd1a8bSRiver Riddle ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); 374abfd1a8bSRiver Riddle for (Block *block : rpot) { 375abfd1a8bSRiver Riddle // Keep track of where this block begins within the matcher function. 376abfd1a8bSRiver Riddle blockToAddr.try_emplace(block, matcherByteCode.size()); 377abfd1a8bSRiver Riddle for (Operation &op : *block) 378abfd1a8bSRiver Riddle generate(&op, matcherByteCodeWriter); 379abfd1a8bSRiver Riddle } 380abfd1a8bSRiver Riddle 381abfd1a8bSRiver Riddle // Resolve successor references in the matcher. 382abfd1a8bSRiver Riddle for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { 383abfd1a8bSRiver Riddle ByteCodeAddr addr = blockToAddr[it.first]; 384abfd1a8bSRiver Riddle for (unsigned offsetToFix : it.second) 385abfd1a8bSRiver Riddle std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); 386abfd1a8bSRiver Riddle } 387abfd1a8bSRiver Riddle } 388abfd1a8bSRiver Riddle 389abfd1a8bSRiver Riddle void Generator::allocateMemoryIndices(FuncOp matcherFunc, 390abfd1a8bSRiver Riddle ModuleOp rewriterModule) { 391abfd1a8bSRiver Riddle // Rewriters use simplistic allocation scheme that simply assigns an index to 392abfd1a8bSRiver Riddle // each result. 393abfd1a8bSRiver Riddle for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 394abfd1a8bSRiver Riddle ByteCodeField index = 0; 395abfd1a8bSRiver Riddle for (BlockArgument arg : rewriterFunc.getArguments()) 396abfd1a8bSRiver Riddle valueToMemIndex.try_emplace(arg, index++); 397abfd1a8bSRiver Riddle rewriterFunc.getBody().walk([&](Operation *op) { 398abfd1a8bSRiver Riddle for (Value result : op->getResults()) 399abfd1a8bSRiver Riddle valueToMemIndex.try_emplace(result, index++); 400abfd1a8bSRiver Riddle }); 401abfd1a8bSRiver Riddle if (index > maxValueMemoryIndex) 402abfd1a8bSRiver Riddle maxValueMemoryIndex = index; 403abfd1a8bSRiver Riddle } 404abfd1a8bSRiver Riddle 405abfd1a8bSRiver Riddle // The matcher function uses a more sophisticated numbering that tries to 406abfd1a8bSRiver Riddle // minimize the number of memory indices assigned. This is done by determining 407abfd1a8bSRiver Riddle // a live range of the values within the matcher, then the allocation is just 408abfd1a8bSRiver Riddle // finding the minimal number of overlapping live ranges. This is essentially 409abfd1a8bSRiver Riddle // a simplified form of register allocation where we don't necessarily have a 410abfd1a8bSRiver Riddle // limited number of registers, but we still want to minimize the number used. 411abfd1a8bSRiver Riddle DenseMap<Operation *, ByteCodeField> opToIndex; 412abfd1a8bSRiver Riddle matcherFunc.getBody().walk([&](Operation *op) { 413abfd1a8bSRiver Riddle opToIndex.insert(std::make_pair(op, opToIndex.size())); 414abfd1a8bSRiver Riddle }); 415abfd1a8bSRiver Riddle 416abfd1a8bSRiver Riddle // Liveness info for each of the defs within the matcher. 417abfd1a8bSRiver Riddle using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>; 418abfd1a8bSRiver Riddle LivenessSet::Allocator allocator; 419abfd1a8bSRiver Riddle DenseMap<Value, LivenessSet> valueDefRanges; 420abfd1a8bSRiver Riddle 421abfd1a8bSRiver Riddle // Assign the root operation being matched to slot 0. 422abfd1a8bSRiver Riddle BlockArgument rootOpArg = matcherFunc.getArgument(0); 423abfd1a8bSRiver Riddle valueToMemIndex[rootOpArg] = 0; 424abfd1a8bSRiver Riddle 425abfd1a8bSRiver Riddle // Walk each of the blocks, computing the def interval that the value is used. 426abfd1a8bSRiver Riddle Liveness matcherLiveness(matcherFunc); 427abfd1a8bSRiver Riddle for (Block &block : matcherFunc.getBody()) { 428abfd1a8bSRiver Riddle const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block); 429abfd1a8bSRiver Riddle assert(info && "expected liveness info for block"); 430abfd1a8bSRiver Riddle auto processValue = [&](Value value, Operation *firstUseOrDef) { 431abfd1a8bSRiver Riddle // We don't need to process the root op argument, this value is always 432abfd1a8bSRiver Riddle // assigned to the first memory slot. 433abfd1a8bSRiver Riddle if (value == rootOpArg) 434abfd1a8bSRiver Riddle return; 435abfd1a8bSRiver Riddle 436abfd1a8bSRiver Riddle // Set indices for the range of this block that the value is used. 437abfd1a8bSRiver Riddle auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; 438abfd1a8bSRiver Riddle defRangeIt->second.insert( 439abfd1a8bSRiver Riddle opToIndex[firstUseOrDef], 440abfd1a8bSRiver Riddle opToIndex[info->getEndOperation(value, firstUseOrDef)], 441abfd1a8bSRiver Riddle /*dummyValue*/ 0); 442abfd1a8bSRiver Riddle }; 443abfd1a8bSRiver Riddle 444abfd1a8bSRiver Riddle // Process the live-ins of this block. 445abfd1a8bSRiver Riddle for (Value liveIn : info->in()) 446abfd1a8bSRiver Riddle processValue(liveIn, &block.front()); 447abfd1a8bSRiver Riddle 448abfd1a8bSRiver Riddle // Process any new defs within this block. 449abfd1a8bSRiver Riddle for (Operation &op : block) 450abfd1a8bSRiver Riddle for (Value result : op.getResults()) 451abfd1a8bSRiver Riddle processValue(result, &op); 452abfd1a8bSRiver Riddle } 453abfd1a8bSRiver Riddle 454abfd1a8bSRiver Riddle // Greedily allocate memory slots using the computed def live ranges. 455abfd1a8bSRiver Riddle std::vector<LivenessSet> allocatedIndices; 456abfd1a8bSRiver Riddle for (auto &defIt : valueDefRanges) { 457abfd1a8bSRiver Riddle ByteCodeField &memIndex = valueToMemIndex[defIt.first]; 458abfd1a8bSRiver Riddle LivenessSet &defSet = defIt.second; 459abfd1a8bSRiver Riddle 460abfd1a8bSRiver Riddle // Try to allocate to an existing index. 461abfd1a8bSRiver Riddle for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) { 462abfd1a8bSRiver Riddle LivenessSet &existingIndex = existingIndexIt.value(); 463abfd1a8bSRiver Riddle llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps( 464abfd1a8bSRiver Riddle defIt.second, existingIndex); 465abfd1a8bSRiver Riddle if (overlaps.valid()) 466abfd1a8bSRiver Riddle continue; 467abfd1a8bSRiver Riddle // Union the range of the def within the existing index. 468abfd1a8bSRiver Riddle for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) 469abfd1a8bSRiver Riddle existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0); 470abfd1a8bSRiver Riddle memIndex = existingIndexIt.index() + 1; 471abfd1a8bSRiver Riddle } 472abfd1a8bSRiver Riddle 473abfd1a8bSRiver Riddle // If no existing index could be used, add a new one. 474abfd1a8bSRiver Riddle if (memIndex == 0) { 475abfd1a8bSRiver Riddle allocatedIndices.emplace_back(allocator); 476abfd1a8bSRiver Riddle for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) 477abfd1a8bSRiver Riddle allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0); 478abfd1a8bSRiver Riddle memIndex = allocatedIndices.size(); 479abfd1a8bSRiver Riddle } 480abfd1a8bSRiver Riddle } 481abfd1a8bSRiver Riddle 482abfd1a8bSRiver Riddle // Update the max number of indices. 483abfd1a8bSRiver Riddle ByteCodeField numMatcherIndices = allocatedIndices.size() + 1; 484abfd1a8bSRiver Riddle if (numMatcherIndices > maxValueMemoryIndex) 485abfd1a8bSRiver Riddle maxValueMemoryIndex = numMatcherIndices; 486abfd1a8bSRiver Riddle } 487abfd1a8bSRiver Riddle 488abfd1a8bSRiver Riddle void Generator::generate(Operation *op, ByteCodeWriter &writer) { 489abfd1a8bSRiver Riddle TypeSwitch<Operation *>(op) 490abfd1a8bSRiver Riddle .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp, 491abfd1a8bSRiver Riddle pdl_interp::AreEqualOp, pdl_interp::BranchOp, 492abfd1a8bSRiver Riddle pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, 493abfd1a8bSRiver Riddle pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, 494abfd1a8bSRiver Riddle pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp, 495abfd1a8bSRiver Riddle pdl_interp::CreateNativeOp, pdl_interp::CreateOperationOp, 496abfd1a8bSRiver Riddle pdl_interp::CreateTypeOp, pdl_interp::EraseOp, 497abfd1a8bSRiver Riddle pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp, 498abfd1a8bSRiver Riddle pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp, 499abfd1a8bSRiver Riddle pdl_interp::GetOperandOp, pdl_interp::GetResultOp, 500abfd1a8bSRiver Riddle pdl_interp::GetValueTypeOp, pdl_interp::InferredTypeOp, 501abfd1a8bSRiver Riddle pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp, 502abfd1a8bSRiver Riddle pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp, 503abfd1a8bSRiver Riddle pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp, 504abfd1a8bSRiver Riddle pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>( 505abfd1a8bSRiver Riddle [&](auto interpOp) { this->generate(interpOp, writer); }) 506abfd1a8bSRiver Riddle .Default([](Operation *) { 507abfd1a8bSRiver Riddle llvm_unreachable("unknown `pdl_interp` operation"); 508abfd1a8bSRiver Riddle }); 509abfd1a8bSRiver Riddle } 510abfd1a8bSRiver Riddle 511abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ApplyConstraintOp op, 512abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 513abfd1a8bSRiver Riddle assert(constraintToMemIndex.count(op.name()) && 514abfd1a8bSRiver Riddle "expected index for constraint function"); 515abfd1a8bSRiver Riddle writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()], 516abfd1a8bSRiver Riddle op.constParamsAttr()); 517abfd1a8bSRiver Riddle writer.appendPDLValueList(op.args()); 518abfd1a8bSRiver Riddle writer.append(op.getSuccessors()); 519abfd1a8bSRiver Riddle } 520abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ApplyRewriteOp op, 521abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 522abfd1a8bSRiver Riddle assert(externalRewriterToMemIndex.count(op.name()) && 523abfd1a8bSRiver Riddle "expected index for rewrite function"); 524abfd1a8bSRiver Riddle writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()], 525abfd1a8bSRiver Riddle op.constParamsAttr(), op.root()); 526abfd1a8bSRiver Riddle writer.appendPDLValueList(op.args()); 527abfd1a8bSRiver Riddle } 528abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { 529abfd1a8bSRiver Riddle writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors()); 530abfd1a8bSRiver Riddle } 531abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { 532abfd1a8bSRiver Riddle writer.append(OpCode::Branch, SuccessorRange(op)); 533abfd1a8bSRiver Riddle } 534abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckAttributeOp op, 535abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 536abfd1a8bSRiver Riddle writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(), 537abfd1a8bSRiver Riddle op.getSuccessors()); 538abfd1a8bSRiver Riddle } 539abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckOperandCountOp op, 540abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 541abfd1a8bSRiver Riddle writer.append(OpCode::CheckOperandCount, op.operation(), op.count(), 542abfd1a8bSRiver Riddle op.getSuccessors()); 543abfd1a8bSRiver Riddle } 544abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckOperationNameOp op, 545abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 546abfd1a8bSRiver Riddle writer.append(OpCode::CheckOperationName, op.operation(), 547abfd1a8bSRiver Riddle OperationName(op.name(), ctx), op.getSuccessors()); 548abfd1a8bSRiver Riddle } 549abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckResultCountOp op, 550abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 551abfd1a8bSRiver Riddle writer.append(OpCode::CheckResultCount, op.operation(), op.count(), 552abfd1a8bSRiver Riddle op.getSuccessors()); 553abfd1a8bSRiver Riddle } 554abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { 555abfd1a8bSRiver Riddle writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors()); 556abfd1a8bSRiver Riddle } 557abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateAttributeOp op, 558abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 559abfd1a8bSRiver Riddle // Simply repoint the memory index of the result to the constant. 560abfd1a8bSRiver Riddle getMemIndex(op.attribute()) = getMemIndex(op.value()); 561abfd1a8bSRiver Riddle } 562abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateNativeOp op, 563abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 564abfd1a8bSRiver Riddle assert(nativeCreateToMemIndex.count(op.name()) && 565abfd1a8bSRiver Riddle "expected index for creation function"); 566abfd1a8bSRiver Riddle writer.append(OpCode::CreateNative, nativeCreateToMemIndex[op.name()], 567abfd1a8bSRiver Riddle op.result(), op.constParamsAttr()); 568abfd1a8bSRiver Riddle writer.appendPDLValueList(op.args()); 569abfd1a8bSRiver Riddle } 570abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateOperationOp op, 571abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 572abfd1a8bSRiver Riddle writer.append(OpCode::CreateOperation, op.operation(), 573abfd1a8bSRiver Riddle OperationName(op.name(), ctx), op.operands()); 574abfd1a8bSRiver Riddle 575abfd1a8bSRiver Riddle // Add the attributes. 576abfd1a8bSRiver Riddle OperandRange attributes = op.attributes(); 577abfd1a8bSRiver Riddle writer.append(static_cast<ByteCodeField>(attributes.size())); 578abfd1a8bSRiver Riddle for (auto it : llvm::zip(op.attributeNames(), op.attributes())) { 579abfd1a8bSRiver Riddle writer.append( 580abfd1a8bSRiver Riddle Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx), 581abfd1a8bSRiver Riddle std::get<1>(it)); 582abfd1a8bSRiver Riddle } 583abfd1a8bSRiver Riddle writer.append(op.types()); 584abfd1a8bSRiver Riddle } 585abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { 586abfd1a8bSRiver Riddle // Simply repoint the memory index of the result to the constant. 587abfd1a8bSRiver Riddle getMemIndex(op.result()) = getMemIndex(op.value()); 588abfd1a8bSRiver Riddle } 589abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { 590abfd1a8bSRiver Riddle writer.append(OpCode::EraseOp, op.operation()); 591abfd1a8bSRiver Riddle } 592abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) { 593abfd1a8bSRiver Riddle writer.append(OpCode::Finalize); 594abfd1a8bSRiver Riddle } 595abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetAttributeOp op, 596abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 597abfd1a8bSRiver Riddle writer.append(OpCode::GetAttribute, op.attribute(), op.operation(), 598abfd1a8bSRiver Riddle Identifier::get(op.name(), ctx)); 599abfd1a8bSRiver Riddle } 600abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetAttributeTypeOp op, 601abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 602abfd1a8bSRiver Riddle writer.append(OpCode::GetAttributeType, op.result(), op.value()); 603abfd1a8bSRiver Riddle } 604abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetDefiningOpOp op, 605abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 606abfd1a8bSRiver Riddle writer.append(OpCode::GetDefiningOp, op.operation(), op.value()); 607abfd1a8bSRiver Riddle } 608abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { 609abfd1a8bSRiver Riddle uint32_t index = op.index(); 610abfd1a8bSRiver Riddle if (index < 4) 611abfd1a8bSRiver Riddle writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index)); 612abfd1a8bSRiver Riddle else 613abfd1a8bSRiver Riddle writer.append(OpCode::GetOperandN, index); 614abfd1a8bSRiver Riddle writer.append(op.operation(), op.value()); 615abfd1a8bSRiver Riddle } 616abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { 617abfd1a8bSRiver Riddle uint32_t index = op.index(); 618abfd1a8bSRiver Riddle if (index < 4) 619abfd1a8bSRiver Riddle writer.append(static_cast<OpCode>(OpCode::GetResult0 + index)); 620abfd1a8bSRiver Riddle else 621abfd1a8bSRiver Riddle writer.append(OpCode::GetResultN, index); 622abfd1a8bSRiver Riddle writer.append(op.operation(), op.value()); 623abfd1a8bSRiver Riddle } 624abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetValueTypeOp op, 625abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 626abfd1a8bSRiver Riddle writer.append(OpCode::GetValueType, op.result(), op.value()); 627abfd1a8bSRiver Riddle } 628abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::InferredTypeOp op, 629abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 630abfd1a8bSRiver Riddle // InferType maps to a null type as a marker for inferring a result type. 631abfd1a8bSRiver Riddle getMemIndex(op.type()) = getMemIndex(Type()); 632abfd1a8bSRiver Riddle } 633abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { 634abfd1a8bSRiver Riddle writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors()); 635abfd1a8bSRiver Riddle } 636abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { 637abfd1a8bSRiver Riddle ByteCodeField patternIndex = patterns.size(); 638abfd1a8bSRiver Riddle patterns.emplace_back(PDLByteCodePattern::create( 639abfd1a8bSRiver Riddle op, rewriterToAddr[op.rewriter().getLeafReference()])); 640abfd1a8bSRiver Riddle writer.append(OpCode::RecordMatch, patternIndex, SuccessorRange(op), 641abfd1a8bSRiver Riddle op.matchedOps(), op.inputs()); 642abfd1a8bSRiver Riddle } 643abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) { 644abfd1a8bSRiver Riddle writer.append(OpCode::ReplaceOp, op.operation(), op.replValues()); 645abfd1a8bSRiver Riddle } 646abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchAttributeOp op, 647abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 648abfd1a8bSRiver Riddle writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(), 649abfd1a8bSRiver Riddle op.getSuccessors()); 650abfd1a8bSRiver Riddle } 651abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchOperandCountOp op, 652abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 653abfd1a8bSRiver Riddle writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(), 654abfd1a8bSRiver Riddle op.getSuccessors()); 655abfd1a8bSRiver Riddle } 656abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchOperationNameOp op, 657abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 658abfd1a8bSRiver Riddle auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) { 659abfd1a8bSRiver Riddle return OperationName(attr.cast<StringAttr>().getValue(), ctx); 660abfd1a8bSRiver Riddle }); 661abfd1a8bSRiver Riddle writer.append(OpCode::SwitchOperationName, op.operation(), cases, 662abfd1a8bSRiver Riddle op.getSuccessors()); 663abfd1a8bSRiver Riddle } 664abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchResultCountOp op, 665abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 666abfd1a8bSRiver Riddle writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(), 667abfd1a8bSRiver Riddle op.getSuccessors()); 668abfd1a8bSRiver Riddle } 669abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) { 670abfd1a8bSRiver Riddle writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(), 671abfd1a8bSRiver Riddle op.getSuccessors()); 672abfd1a8bSRiver Riddle } 673abfd1a8bSRiver Riddle 674abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 675abfd1a8bSRiver Riddle // PDLByteCode 676abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 677abfd1a8bSRiver Riddle 678abfd1a8bSRiver Riddle PDLByteCode::PDLByteCode(ModuleOp module, 679abfd1a8bSRiver Riddle llvm::StringMap<PDLConstraintFunction> constraintFns, 680abfd1a8bSRiver Riddle llvm::StringMap<PDLCreateFunction> createFns, 681abfd1a8bSRiver Riddle llvm::StringMap<PDLRewriteFunction> rewriteFns) { 682abfd1a8bSRiver Riddle Generator generator(module.getContext(), uniquedData, matcherByteCode, 683abfd1a8bSRiver Riddle rewriterByteCode, patterns, maxValueMemoryIndex, 684abfd1a8bSRiver Riddle constraintFns, createFns, rewriteFns); 685abfd1a8bSRiver Riddle generator.generate(module); 686abfd1a8bSRiver Riddle 687abfd1a8bSRiver Riddle // Initialize the external functions. 688abfd1a8bSRiver Riddle for (auto &it : constraintFns) 689abfd1a8bSRiver Riddle constraintFunctions.push_back(std::move(it.second)); 690abfd1a8bSRiver Riddle for (auto &it : createFns) 691abfd1a8bSRiver Riddle createFunctions.push_back(std::move(it.second)); 692abfd1a8bSRiver Riddle for (auto &it : rewriteFns) 693abfd1a8bSRiver Riddle rewriteFunctions.push_back(std::move(it.second)); 694abfd1a8bSRiver Riddle } 695abfd1a8bSRiver Riddle 696abfd1a8bSRiver Riddle /// Initialize the given state such that it can be used to execute the current 697abfd1a8bSRiver Riddle /// bytecode. 698abfd1a8bSRiver Riddle void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { 699abfd1a8bSRiver Riddle state.memory.resize(maxValueMemoryIndex, nullptr); 700abfd1a8bSRiver Riddle state.currentPatternBenefits.reserve(patterns.size()); 701abfd1a8bSRiver Riddle for (const PDLByteCodePattern &pattern : patterns) 702abfd1a8bSRiver Riddle state.currentPatternBenefits.push_back(pattern.getBenefit()); 703abfd1a8bSRiver Riddle } 704abfd1a8bSRiver Riddle 705abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 706abfd1a8bSRiver Riddle // ByteCode Execution 707abfd1a8bSRiver Riddle 708abfd1a8bSRiver Riddle namespace { 709abfd1a8bSRiver Riddle /// This class provides support for executing a bytecode stream. 710abfd1a8bSRiver Riddle class ByteCodeExecutor { 711abfd1a8bSRiver Riddle public: 712abfd1a8bSRiver Riddle ByteCodeExecutor(const ByteCodeField *curCodeIt, 713abfd1a8bSRiver Riddle MutableArrayRef<const void *> memory, 714abfd1a8bSRiver Riddle ArrayRef<const void *> uniquedMemory, 715abfd1a8bSRiver Riddle ArrayRef<ByteCodeField> code, 716abfd1a8bSRiver Riddle ArrayRef<PatternBenefit> currentPatternBenefits, 717abfd1a8bSRiver Riddle ArrayRef<PDLByteCodePattern> patterns, 718abfd1a8bSRiver Riddle ArrayRef<PDLConstraintFunction> constraintFunctions, 719abfd1a8bSRiver Riddle ArrayRef<PDLCreateFunction> createFunctions, 720abfd1a8bSRiver Riddle ArrayRef<PDLRewriteFunction> rewriteFunctions) 721abfd1a8bSRiver Riddle : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory), 722abfd1a8bSRiver Riddle code(code), currentPatternBenefits(currentPatternBenefits), 723abfd1a8bSRiver Riddle patterns(patterns), constraintFunctions(constraintFunctions), 724abfd1a8bSRiver Riddle createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {} 725abfd1a8bSRiver Riddle 726abfd1a8bSRiver Riddle /// Start executing the code at the current bytecode index. `matches` is an 727abfd1a8bSRiver Riddle /// optional field provided when this function is executed in a matching 728abfd1a8bSRiver Riddle /// context. 729abfd1a8bSRiver Riddle void execute(PatternRewriter &rewriter, 730abfd1a8bSRiver Riddle SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, 731abfd1a8bSRiver Riddle Optional<Location> mainRewriteLoc = {}); 732abfd1a8bSRiver Riddle 733abfd1a8bSRiver Riddle private: 734abfd1a8bSRiver Riddle /// Read a value from the bytecode buffer, optionally skipping a certain 735abfd1a8bSRiver Riddle /// number of prefix values. These methods always update the buffer to point 736abfd1a8bSRiver Riddle /// to the next field after the read data. 737abfd1a8bSRiver Riddle template <typename T = ByteCodeField> 738abfd1a8bSRiver Riddle T read(size_t skipN = 0) { 739abfd1a8bSRiver Riddle curCodeIt += skipN; 740abfd1a8bSRiver Riddle return readImpl<T>(); 741abfd1a8bSRiver Riddle } 742abfd1a8bSRiver Riddle ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } 743abfd1a8bSRiver Riddle 744abfd1a8bSRiver Riddle /// Read a list of values from the bytecode buffer. 745abfd1a8bSRiver Riddle template <typename ValueT, typename T> 746abfd1a8bSRiver Riddle void readList(SmallVectorImpl<T> &list) { 747abfd1a8bSRiver Riddle list.clear(); 748abfd1a8bSRiver Riddle for (unsigned i = 0, e = read(); i != e; ++i) 749abfd1a8bSRiver Riddle list.push_back(read<ValueT>()); 750abfd1a8bSRiver Riddle } 751abfd1a8bSRiver Riddle 752abfd1a8bSRiver Riddle /// Jump to a specific successor based on a predicate value. 753abfd1a8bSRiver Riddle void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } 754abfd1a8bSRiver Riddle /// Jump to a specific successor based on a destination index. 755abfd1a8bSRiver Riddle void selectJump(size_t destIndex) { 756abfd1a8bSRiver Riddle curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; 757abfd1a8bSRiver Riddle } 758abfd1a8bSRiver Riddle 759abfd1a8bSRiver Riddle /// Handle a switch operation with the provided value and cases. 760abfd1a8bSRiver Riddle template <typename T, typename RangeT> 761abfd1a8bSRiver Riddle void handleSwitch(const T &value, RangeT &&cases) { 762abfd1a8bSRiver Riddle LLVM_DEBUG({ 763abfd1a8bSRiver Riddle llvm::dbgs() << " * Value: " << value << "\n" 764abfd1a8bSRiver Riddle << " * Cases: "; 765abfd1a8bSRiver Riddle llvm::interleaveComma(cases, llvm::dbgs()); 766abfd1a8bSRiver Riddle llvm::dbgs() << "\n\n"; 767abfd1a8bSRiver Riddle }); 768abfd1a8bSRiver Riddle 769abfd1a8bSRiver Riddle // Check to see if the attribute value is within the case list. Jump to 770abfd1a8bSRiver Riddle // the correct successor index based on the result. 771abfd1a8bSRiver Riddle auto it = llvm::find(cases, value); 772abfd1a8bSRiver Riddle selectJump(it == cases.end() ? size_t(0) : ((it - cases.begin()) + 1)); 773abfd1a8bSRiver Riddle } 774abfd1a8bSRiver Riddle 775abfd1a8bSRiver Riddle /// Internal implementation of reading various data types from the bytecode 776abfd1a8bSRiver Riddle /// stream. 777abfd1a8bSRiver Riddle template <typename T> 778abfd1a8bSRiver Riddle const void *readFromMemory() { 779abfd1a8bSRiver Riddle size_t index = *curCodeIt++; 780abfd1a8bSRiver Riddle 781abfd1a8bSRiver Riddle // If this type is an SSA value, it can only be stored in non-const memory. 782abfd1a8bSRiver Riddle if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size()) 783abfd1a8bSRiver Riddle return memory[index]; 784abfd1a8bSRiver Riddle 785abfd1a8bSRiver Riddle // Otherwise, if this index is not inbounds it is uniqued. 786abfd1a8bSRiver Riddle return uniquedMemory[index - memory.size()]; 787abfd1a8bSRiver Riddle } 788abfd1a8bSRiver Riddle template <typename T> 789abfd1a8bSRiver Riddle std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { 790abfd1a8bSRiver Riddle return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); 791abfd1a8bSRiver Riddle } 792abfd1a8bSRiver Riddle template <typename T> 793abfd1a8bSRiver Riddle std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, 794abfd1a8bSRiver Riddle T> 795abfd1a8bSRiver Riddle readImpl() { 796abfd1a8bSRiver Riddle return T(T::getFromOpaquePointer(readFromMemory<T>())); 797abfd1a8bSRiver Riddle } 798abfd1a8bSRiver Riddle template <typename T> 799abfd1a8bSRiver Riddle std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { 800abfd1a8bSRiver Riddle switch (static_cast<PDLValueKind>(read())) { 801abfd1a8bSRiver Riddle case PDLValueKind::Attribute: 802abfd1a8bSRiver Riddle return read<Attribute>(); 803abfd1a8bSRiver Riddle case PDLValueKind::Operation: 804abfd1a8bSRiver Riddle return read<Operation *>(); 805abfd1a8bSRiver Riddle case PDLValueKind::Type: 806abfd1a8bSRiver Riddle return read<Type>(); 807abfd1a8bSRiver Riddle case PDLValueKind::Value: 808abfd1a8bSRiver Riddle return read<Value>(); 809abfd1a8bSRiver Riddle } 810abfd1a8bSRiver Riddle } 811abfd1a8bSRiver Riddle template <typename T> 812abfd1a8bSRiver Riddle std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { 813abfd1a8bSRiver Riddle static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 814abfd1a8bSRiver Riddle "unexpected ByteCode address size"); 815abfd1a8bSRiver Riddle ByteCodeAddr result; 816abfd1a8bSRiver Riddle std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); 817abfd1a8bSRiver Riddle curCodeIt += 2; 818abfd1a8bSRiver Riddle return result; 819abfd1a8bSRiver Riddle } 820abfd1a8bSRiver Riddle template <typename T> 821abfd1a8bSRiver Riddle std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { 822abfd1a8bSRiver Riddle return *curCodeIt++; 823abfd1a8bSRiver Riddle } 824abfd1a8bSRiver Riddle 825abfd1a8bSRiver Riddle /// The underlying bytecode buffer. 826abfd1a8bSRiver Riddle const ByteCodeField *curCodeIt; 827abfd1a8bSRiver Riddle 828abfd1a8bSRiver Riddle /// The current execution memory. 829abfd1a8bSRiver Riddle MutableArrayRef<const void *> memory; 830abfd1a8bSRiver Riddle 831abfd1a8bSRiver Riddle /// References to ByteCode data necessary for execution. 832abfd1a8bSRiver Riddle ArrayRef<const void *> uniquedMemory; 833abfd1a8bSRiver Riddle ArrayRef<ByteCodeField> code; 834abfd1a8bSRiver Riddle ArrayRef<PatternBenefit> currentPatternBenefits; 835abfd1a8bSRiver Riddle ArrayRef<PDLByteCodePattern> patterns; 836abfd1a8bSRiver Riddle ArrayRef<PDLConstraintFunction> constraintFunctions; 837abfd1a8bSRiver Riddle ArrayRef<PDLCreateFunction> createFunctions; 838abfd1a8bSRiver Riddle ArrayRef<PDLRewriteFunction> rewriteFunctions; 839abfd1a8bSRiver Riddle }; 840abfd1a8bSRiver Riddle } // end anonymous namespace 841abfd1a8bSRiver Riddle 842abfd1a8bSRiver Riddle void ByteCodeExecutor::execute( 843abfd1a8bSRiver Riddle PatternRewriter &rewriter, 844abfd1a8bSRiver Riddle SmallVectorImpl<PDLByteCode::MatchResult> *matches, 845abfd1a8bSRiver Riddle Optional<Location> mainRewriteLoc) { 846abfd1a8bSRiver Riddle while (true) { 847abfd1a8bSRiver Riddle OpCode opCode = static_cast<OpCode>(read()); 848abfd1a8bSRiver Riddle switch (opCode) { 849abfd1a8bSRiver Riddle case ApplyConstraint: { 850abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); 851abfd1a8bSRiver Riddle const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; 852abfd1a8bSRiver Riddle ArrayAttr constParams = read<ArrayAttr>(); 853abfd1a8bSRiver Riddle SmallVector<PDLValue, 16> args; 854abfd1a8bSRiver Riddle readList<PDLValue>(args); 855abfd1a8bSRiver Riddle LLVM_DEBUG({ 856abfd1a8bSRiver Riddle llvm::dbgs() << " * Arguments: "; 857abfd1a8bSRiver Riddle llvm::interleaveComma(args, llvm::dbgs()); 858abfd1a8bSRiver Riddle llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n"; 859abfd1a8bSRiver Riddle }); 860abfd1a8bSRiver Riddle 861abfd1a8bSRiver Riddle // Invoke the constraint and jump to the proper destination. 862abfd1a8bSRiver Riddle selectJump(succeeded(constraintFn(args, constParams, rewriter))); 863abfd1a8bSRiver Riddle break; 864abfd1a8bSRiver Riddle } 865abfd1a8bSRiver Riddle case ApplyRewrite: { 866abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); 867abfd1a8bSRiver Riddle const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; 868abfd1a8bSRiver Riddle ArrayAttr constParams = read<ArrayAttr>(); 869abfd1a8bSRiver Riddle Operation *root = read<Operation *>(); 870abfd1a8bSRiver Riddle SmallVector<PDLValue, 16> args; 871abfd1a8bSRiver Riddle readList<PDLValue>(args); 872abfd1a8bSRiver Riddle 873abfd1a8bSRiver Riddle LLVM_DEBUG({ 874abfd1a8bSRiver Riddle llvm::dbgs() << " * Root: " << *root << "\n" 875abfd1a8bSRiver Riddle << " * Arguments: "; 876abfd1a8bSRiver Riddle llvm::interleaveComma(args, llvm::dbgs()); 877abfd1a8bSRiver Riddle llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n"; 878abfd1a8bSRiver Riddle }); 879abfd1a8bSRiver Riddle rewriteFn(root, args, constParams, rewriter); 880abfd1a8bSRiver Riddle break; 881abfd1a8bSRiver Riddle } 882abfd1a8bSRiver Riddle case AreEqual: { 883abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 884abfd1a8bSRiver Riddle const void *lhs = read<const void *>(); 885abfd1a8bSRiver Riddle const void *rhs = read<const void *>(); 886abfd1a8bSRiver Riddle 887abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); 888abfd1a8bSRiver Riddle selectJump(lhs == rhs); 889abfd1a8bSRiver Riddle break; 890abfd1a8bSRiver Riddle } 891abfd1a8bSRiver Riddle case Branch: { 892abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n\n"); 893abfd1a8bSRiver Riddle curCodeIt = &code[read<ByteCodeAddr>()]; 894abfd1a8bSRiver Riddle break; 895abfd1a8bSRiver Riddle } 896abfd1a8bSRiver Riddle case CheckOperandCount: { 897abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); 898abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 899abfd1a8bSRiver Riddle uint32_t expectedCount = read<uint32_t>(); 900abfd1a8bSRiver Riddle 901abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" 902abfd1a8bSRiver Riddle << " * Expected: " << expectedCount << "\n\n"); 903abfd1a8bSRiver Riddle selectJump(op->getNumOperands() == expectedCount); 904abfd1a8bSRiver Riddle break; 905abfd1a8bSRiver Riddle } 906abfd1a8bSRiver Riddle case CheckOperationName: { 907abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); 908abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 909abfd1a8bSRiver Riddle OperationName expectedName = read<OperationName>(); 910abfd1a8bSRiver Riddle 911abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() 912abfd1a8bSRiver Riddle << " * Found: \"" << op->getName() << "\"\n" 913abfd1a8bSRiver Riddle << " * Expected: \"" << expectedName << "\"\n\n"); 914abfd1a8bSRiver Riddle selectJump(op->getName() == expectedName); 915abfd1a8bSRiver Riddle break; 916abfd1a8bSRiver Riddle } 917abfd1a8bSRiver Riddle case CheckResultCount: { 918abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); 919abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 920abfd1a8bSRiver Riddle uint32_t expectedCount = read<uint32_t>(); 921abfd1a8bSRiver Riddle 922abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" 923abfd1a8bSRiver Riddle << " * Expected: " << expectedCount << "\n\n"); 924abfd1a8bSRiver Riddle selectJump(op->getNumResults() == expectedCount); 925abfd1a8bSRiver Riddle break; 926abfd1a8bSRiver Riddle } 927abfd1a8bSRiver Riddle case CreateNative: { 928abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n"); 929abfd1a8bSRiver Riddle const PDLCreateFunction &createFn = createFunctions[read()]; 930abfd1a8bSRiver Riddle ByteCodeField resultIndex = read(); 931abfd1a8bSRiver Riddle ArrayAttr constParams = read<ArrayAttr>(); 932abfd1a8bSRiver Riddle SmallVector<PDLValue, 16> args; 933abfd1a8bSRiver Riddle readList<PDLValue>(args); 934abfd1a8bSRiver Riddle 935abfd1a8bSRiver Riddle LLVM_DEBUG({ 936abfd1a8bSRiver Riddle llvm::dbgs() << " * Arguments: "; 937abfd1a8bSRiver Riddle llvm::interleaveComma(args, llvm::dbgs()); 938abfd1a8bSRiver Riddle llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 939abfd1a8bSRiver Riddle }); 940abfd1a8bSRiver Riddle 941abfd1a8bSRiver Riddle PDLValue result = createFn(args, constParams, rewriter); 942abfd1a8bSRiver Riddle memory[resultIndex] = result.getAsOpaquePointer(); 943abfd1a8bSRiver Riddle 944abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n\n"); 945abfd1a8bSRiver Riddle break; 946abfd1a8bSRiver Riddle } 947abfd1a8bSRiver Riddle case CreateOperation: { 948abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); 949abfd1a8bSRiver Riddle assert(mainRewriteLoc && "expected rewrite loc to be provided when " 950abfd1a8bSRiver Riddle "executing the rewriter bytecode"); 951abfd1a8bSRiver Riddle 952abfd1a8bSRiver Riddle unsigned memIndex = read(); 953abfd1a8bSRiver Riddle OperationState state(*mainRewriteLoc, read<OperationName>()); 954abfd1a8bSRiver Riddle readList<Value>(state.operands); 955abfd1a8bSRiver Riddle for (unsigned i = 0, e = read(); i != e; ++i) { 956abfd1a8bSRiver Riddle Identifier name = read<Identifier>(); 957abfd1a8bSRiver Riddle if (Attribute attr = read<Attribute>()) 958abfd1a8bSRiver Riddle state.addAttribute(name, attr); 959abfd1a8bSRiver Riddle } 960abfd1a8bSRiver Riddle 961abfd1a8bSRiver Riddle bool hasInferredTypes = false; 962abfd1a8bSRiver Riddle for (unsigned i = 0, e = read(); i != e; ++i) { 963abfd1a8bSRiver Riddle Type resultType = read<Type>(); 964abfd1a8bSRiver Riddle hasInferredTypes |= !resultType; 965abfd1a8bSRiver Riddle state.types.push_back(resultType); 966abfd1a8bSRiver Riddle } 967abfd1a8bSRiver Riddle 968abfd1a8bSRiver Riddle // Handle the case where the operation has inferred types. 969abfd1a8bSRiver Riddle if (hasInferredTypes) { 970abfd1a8bSRiver Riddle InferTypeOpInterface::Concept *concept = 971abfd1a8bSRiver Riddle state.name.getAbstractOperation() 972abfd1a8bSRiver Riddle ->getInterface<InferTypeOpInterface>(); 973abfd1a8bSRiver Riddle 974abfd1a8bSRiver Riddle // TODO: Handle failure. 975abfd1a8bSRiver Riddle SmallVector<Type, 2> inferredTypes; 976abfd1a8bSRiver Riddle if (failed(concept->inferReturnTypes( 977abfd1a8bSRiver Riddle state.getContext(), state.location, state.operands, 978abfd1a8bSRiver Riddle state.attributes.getDictionary(state.getContext()), 979abfd1a8bSRiver Riddle state.regions, inferredTypes))) 980abfd1a8bSRiver Riddle return; 981abfd1a8bSRiver Riddle 982abfd1a8bSRiver Riddle for (unsigned i = 0, e = state.types.size(); i != e; ++i) 983abfd1a8bSRiver Riddle if (!state.types[i]) 984abfd1a8bSRiver Riddle state.types[i] = inferredTypes[i]; 985abfd1a8bSRiver Riddle } 986abfd1a8bSRiver Riddle Operation *resultOp = rewriter.createOperation(state); 987abfd1a8bSRiver Riddle memory[memIndex] = resultOp; 988abfd1a8bSRiver Riddle 989abfd1a8bSRiver Riddle LLVM_DEBUG({ 990abfd1a8bSRiver Riddle llvm::dbgs() << " * Attributes: " 991abfd1a8bSRiver Riddle << state.attributes.getDictionary(state.getContext()) 992abfd1a8bSRiver Riddle << "\n * Operands: "; 993abfd1a8bSRiver Riddle llvm::interleaveComma(state.operands, llvm::dbgs()); 994abfd1a8bSRiver Riddle llvm::dbgs() << "\n * Result Types: "; 995abfd1a8bSRiver Riddle llvm::interleaveComma(state.types, llvm::dbgs()); 996abfd1a8bSRiver Riddle llvm::dbgs() << "\n * Result: " << *resultOp << "\n\n"; 997abfd1a8bSRiver Riddle }); 998abfd1a8bSRiver Riddle break; 999abfd1a8bSRiver Riddle } 1000abfd1a8bSRiver Riddle case EraseOp: { 1001abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); 1002abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1003abfd1a8bSRiver Riddle 1004abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n\n"); 1005abfd1a8bSRiver Riddle rewriter.eraseOp(op); 1006abfd1a8bSRiver Riddle break; 1007abfd1a8bSRiver Riddle } 1008abfd1a8bSRiver Riddle case Finalize: { 1009abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); 1010abfd1a8bSRiver Riddle return; 1011abfd1a8bSRiver Riddle } 1012abfd1a8bSRiver Riddle case GetAttribute: { 1013abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); 1014abfd1a8bSRiver Riddle unsigned memIndex = read(); 1015abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1016abfd1a8bSRiver Riddle Identifier attrName = read<Identifier>(); 1017abfd1a8bSRiver Riddle Attribute attr = op->getAttr(attrName); 1018abfd1a8bSRiver Riddle 1019abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1020abfd1a8bSRiver Riddle << " * Attribute: " << attrName << "\n" 1021abfd1a8bSRiver Riddle << " * Result: " << attr << "\n\n"); 1022abfd1a8bSRiver Riddle memory[memIndex] = attr.getAsOpaquePointer(); 1023abfd1a8bSRiver Riddle break; 1024abfd1a8bSRiver Riddle } 1025abfd1a8bSRiver Riddle case GetAttributeType: { 1026abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); 1027abfd1a8bSRiver Riddle unsigned memIndex = read(); 1028abfd1a8bSRiver Riddle Attribute attr = read<Attribute>(); 1029abfd1a8bSRiver Riddle 1030abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" 1031abfd1a8bSRiver Riddle << " * Result: " << attr.getType() << "\n\n"); 1032abfd1a8bSRiver Riddle memory[memIndex] = attr.getType().getAsOpaquePointer(); 1033abfd1a8bSRiver Riddle break; 1034abfd1a8bSRiver Riddle } 1035abfd1a8bSRiver Riddle case GetDefiningOp: { 1036abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); 1037abfd1a8bSRiver Riddle unsigned memIndex = read(); 1038abfd1a8bSRiver Riddle Value value = read<Value>(); 1039abfd1a8bSRiver Riddle Operation *op = value ? value.getDefiningOp() : nullptr; 1040abfd1a8bSRiver Riddle 1041abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1042abfd1a8bSRiver Riddle << " * Result: " << *op << "\n\n"); 1043abfd1a8bSRiver Riddle memory[memIndex] = op; 1044abfd1a8bSRiver Riddle break; 1045abfd1a8bSRiver Riddle } 1046abfd1a8bSRiver Riddle case GetOperand0: 1047abfd1a8bSRiver Riddle case GetOperand1: 1048abfd1a8bSRiver Riddle case GetOperand2: 1049abfd1a8bSRiver Riddle case GetOperand3: 1050abfd1a8bSRiver Riddle case GetOperandN: { 1051abfd1a8bSRiver Riddle LLVM_DEBUG({ 1052abfd1a8bSRiver Riddle llvm::dbgs() << "Executing GetOperand" 1053abfd1a8bSRiver Riddle << (opCode == GetOperandN ? Twine("N") 1054abfd1a8bSRiver Riddle : Twine(opCode - GetOperand0)) 1055abfd1a8bSRiver Riddle << ":\n"; 1056abfd1a8bSRiver Riddle }); 1057abfd1a8bSRiver Riddle unsigned index = 1058abfd1a8bSRiver Riddle opCode == GetOperandN ? read<uint32_t>() : (opCode - GetOperand0); 1059abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1060abfd1a8bSRiver Riddle unsigned memIndex = read(); 1061abfd1a8bSRiver Riddle Value operand = 1062abfd1a8bSRiver Riddle index < op->getNumOperands() ? op->getOperand(index) : Value(); 1063abfd1a8bSRiver Riddle 1064abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1065abfd1a8bSRiver Riddle << " * Index: " << index << "\n" 1066abfd1a8bSRiver Riddle << " * Result: " << operand << "\n\n"); 1067abfd1a8bSRiver Riddle memory[memIndex] = operand.getAsOpaquePointer(); 1068abfd1a8bSRiver Riddle break; 1069abfd1a8bSRiver Riddle } 1070abfd1a8bSRiver Riddle case GetResult0: 1071abfd1a8bSRiver Riddle case GetResult1: 1072abfd1a8bSRiver Riddle case GetResult2: 1073abfd1a8bSRiver Riddle case GetResult3: 1074abfd1a8bSRiver Riddle case GetResultN: { 1075abfd1a8bSRiver Riddle LLVM_DEBUG({ 1076abfd1a8bSRiver Riddle llvm::dbgs() << "Executing GetResult" 1077abfd1a8bSRiver Riddle << (opCode == GetResultN ? Twine("N") 1078abfd1a8bSRiver Riddle : Twine(opCode - GetResult0)) 1079abfd1a8bSRiver Riddle << ":\n"; 1080abfd1a8bSRiver Riddle }); 1081abfd1a8bSRiver Riddle unsigned index = 1082abfd1a8bSRiver Riddle opCode == GetResultN ? read<uint32_t>() : (opCode - GetResult0); 1083abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1084abfd1a8bSRiver Riddle unsigned memIndex = read(); 1085abfd1a8bSRiver Riddle OpResult result = 1086abfd1a8bSRiver Riddle index < op->getNumResults() ? op->getResult(index) : OpResult(); 1087abfd1a8bSRiver Riddle 1088abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1089abfd1a8bSRiver Riddle << " * Index: " << index << "\n" 1090abfd1a8bSRiver Riddle << " * Result: " << result << "\n\n"); 1091abfd1a8bSRiver Riddle memory[memIndex] = result.getAsOpaquePointer(); 1092abfd1a8bSRiver Riddle break; 1093abfd1a8bSRiver Riddle } 1094abfd1a8bSRiver Riddle case GetValueType: { 1095abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); 1096abfd1a8bSRiver Riddle unsigned memIndex = read(); 1097abfd1a8bSRiver Riddle Value value = read<Value>(); 1098abfd1a8bSRiver Riddle 1099abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1100abfd1a8bSRiver Riddle << " * Result: " << value.getType() << "\n\n"); 1101abfd1a8bSRiver Riddle memory[memIndex] = value.getType().getAsOpaquePointer(); 1102abfd1a8bSRiver Riddle break; 1103abfd1a8bSRiver Riddle } 1104abfd1a8bSRiver Riddle case IsNotNull: { 1105abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); 1106abfd1a8bSRiver Riddle const void *value = read<const void *>(); 1107abfd1a8bSRiver Riddle 1108abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n\n"); 1109abfd1a8bSRiver Riddle selectJump(value != nullptr); 1110abfd1a8bSRiver Riddle break; 1111abfd1a8bSRiver Riddle } 1112abfd1a8bSRiver Riddle case RecordMatch: { 1113abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); 1114abfd1a8bSRiver Riddle assert(matches && 1115abfd1a8bSRiver Riddle "expected matches to be provided when executing the matcher"); 1116abfd1a8bSRiver Riddle unsigned patternIndex = read(); 1117abfd1a8bSRiver Riddle PatternBenefit benefit = currentPatternBenefits[patternIndex]; 1118abfd1a8bSRiver Riddle const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; 1119abfd1a8bSRiver Riddle 1120abfd1a8bSRiver Riddle // If the benefit of the pattern is impossible, skip the processing of the 1121abfd1a8bSRiver Riddle // rest of the pattern. 1122abfd1a8bSRiver Riddle if (benefit.isImpossibleToMatch()) { 1123abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n\n"); 1124abfd1a8bSRiver Riddle curCodeIt = dest; 1125abfd1a8bSRiver Riddle break; 1126abfd1a8bSRiver Riddle } 1127abfd1a8bSRiver Riddle 1128abfd1a8bSRiver Riddle // Create a fused location containing the locations of each of the 1129abfd1a8bSRiver Riddle // operations used in the match. This will be used as the location for 1130abfd1a8bSRiver Riddle // created operations during the rewrite that don't already have an 1131abfd1a8bSRiver Riddle // explicit location set. 1132abfd1a8bSRiver Riddle unsigned numMatchLocs = read(); 1133abfd1a8bSRiver Riddle SmallVector<Location, 4> matchLocs; 1134abfd1a8bSRiver Riddle matchLocs.reserve(numMatchLocs); 1135abfd1a8bSRiver Riddle for (unsigned i = 0; i != numMatchLocs; ++i) 1136abfd1a8bSRiver Riddle matchLocs.push_back(read<Operation *>()->getLoc()); 1137abfd1a8bSRiver Riddle Location matchLoc = rewriter.getFusedLoc(matchLocs); 1138abfd1a8bSRiver Riddle 1139abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" 1140abfd1a8bSRiver Riddle << " * Location: " << matchLoc << "\n\n"); 1141abfd1a8bSRiver Riddle matches->emplace_back(matchLoc, patterns[patternIndex], benefit); 1142abfd1a8bSRiver Riddle readList<const void *>(matches->back().values); 1143abfd1a8bSRiver Riddle curCodeIt = dest; 1144abfd1a8bSRiver Riddle break; 1145abfd1a8bSRiver Riddle } 1146abfd1a8bSRiver Riddle case ReplaceOp: { 1147abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); 1148abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1149abfd1a8bSRiver Riddle SmallVector<Value, 16> args; 1150abfd1a8bSRiver Riddle readList<Value>(args); 1151abfd1a8bSRiver Riddle 1152abfd1a8bSRiver Riddle LLVM_DEBUG({ 1153abfd1a8bSRiver Riddle llvm::dbgs() << " * Operation: " << *op << "\n" 1154abfd1a8bSRiver Riddle << " * Values: "; 1155abfd1a8bSRiver Riddle llvm::interleaveComma(args, llvm::dbgs()); 1156abfd1a8bSRiver Riddle llvm::dbgs() << "\n\n"; 1157abfd1a8bSRiver Riddle }); 1158abfd1a8bSRiver Riddle rewriter.replaceOp(op, args); 1159abfd1a8bSRiver Riddle break; 1160abfd1a8bSRiver Riddle } 1161abfd1a8bSRiver Riddle case SwitchAttribute: { 1162abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); 1163abfd1a8bSRiver Riddle Attribute value = read<Attribute>(); 1164abfd1a8bSRiver Riddle ArrayAttr cases = read<ArrayAttr>(); 1165abfd1a8bSRiver Riddle handleSwitch(value, cases); 1166abfd1a8bSRiver Riddle break; 1167abfd1a8bSRiver Riddle } 1168abfd1a8bSRiver Riddle case SwitchOperandCount: { 1169abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); 1170abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1171abfd1a8bSRiver Riddle auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1172abfd1a8bSRiver Riddle 1173abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1174abfd1a8bSRiver Riddle handleSwitch(op->getNumOperands(), cases); 1175abfd1a8bSRiver Riddle break; 1176abfd1a8bSRiver Riddle } 1177abfd1a8bSRiver Riddle case SwitchOperationName: { 1178abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); 1179abfd1a8bSRiver Riddle OperationName value = read<Operation *>()->getName(); 1180abfd1a8bSRiver Riddle size_t caseCount = read(); 1181abfd1a8bSRiver Riddle 1182abfd1a8bSRiver Riddle // The operation names are stored in-line, so to print them out for 1183abfd1a8bSRiver Riddle // debugging purposes we need to read the array before executing the 1184abfd1a8bSRiver Riddle // switch so that we can display all of the possible values. 1185abfd1a8bSRiver Riddle LLVM_DEBUG({ 1186abfd1a8bSRiver Riddle const ByteCodeField *prevCodeIt = curCodeIt; 1187abfd1a8bSRiver Riddle llvm::dbgs() << " * Value: " << value << "\n" 1188abfd1a8bSRiver Riddle << " * Cases: "; 1189abfd1a8bSRiver Riddle llvm::interleaveComma( 1190abfd1a8bSRiver Riddle llvm::map_range(llvm::seq<size_t>(0, caseCount), 1191abfd1a8bSRiver Riddle [&](size_t i) { return read<OperationName>(); }), 1192abfd1a8bSRiver Riddle llvm::dbgs()); 1193abfd1a8bSRiver Riddle llvm::dbgs() << "\n\n"; 1194abfd1a8bSRiver Riddle curCodeIt = prevCodeIt; 1195abfd1a8bSRiver Riddle }); 1196abfd1a8bSRiver Riddle 1197abfd1a8bSRiver Riddle // Try to find the switch value within any of the cases. 1198abfd1a8bSRiver Riddle size_t jumpDest = 0; 1199abfd1a8bSRiver Riddle for (size_t i = 0; i != caseCount; ++i) { 1200abfd1a8bSRiver Riddle if (read<OperationName>() == value) { 1201abfd1a8bSRiver Riddle curCodeIt += (caseCount - i - 1); 1202abfd1a8bSRiver Riddle jumpDest = i + 1; 1203abfd1a8bSRiver Riddle break; 1204abfd1a8bSRiver Riddle } 1205abfd1a8bSRiver Riddle } 1206abfd1a8bSRiver Riddle selectJump(jumpDest); 1207abfd1a8bSRiver Riddle break; 1208abfd1a8bSRiver Riddle } 1209abfd1a8bSRiver Riddle case SwitchResultCount: { 1210abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); 1211abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1212abfd1a8bSRiver Riddle auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1213abfd1a8bSRiver Riddle 1214abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1215abfd1a8bSRiver Riddle handleSwitch(op->getNumResults(), cases); 1216abfd1a8bSRiver Riddle break; 1217abfd1a8bSRiver Riddle } 1218abfd1a8bSRiver Riddle case SwitchType: { 1219abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); 1220abfd1a8bSRiver Riddle Type value = read<Type>(); 1221abfd1a8bSRiver Riddle auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); 1222abfd1a8bSRiver Riddle handleSwitch(value, cases); 1223abfd1a8bSRiver Riddle break; 1224abfd1a8bSRiver Riddle } 1225abfd1a8bSRiver Riddle } 1226abfd1a8bSRiver Riddle } 1227abfd1a8bSRiver Riddle } 1228abfd1a8bSRiver Riddle 1229abfd1a8bSRiver Riddle /// Run the pattern matcher on the given root operation, collecting the matched 1230abfd1a8bSRiver Riddle /// patterns in `matches`. 1231abfd1a8bSRiver Riddle void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, 1232abfd1a8bSRiver Riddle SmallVectorImpl<MatchResult> &matches, 1233abfd1a8bSRiver Riddle PDLByteCodeMutableState &state) const { 1234abfd1a8bSRiver Riddle // The first memory slot is always the root operation. 1235abfd1a8bSRiver Riddle state.memory[0] = op; 1236abfd1a8bSRiver Riddle 1237abfd1a8bSRiver Riddle // The matcher function always starts at code address 0. 1238abfd1a8bSRiver Riddle ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData, 1239abfd1a8bSRiver Riddle matcherByteCode, state.currentPatternBenefits, 1240abfd1a8bSRiver Riddle patterns, constraintFunctions, createFunctions, 1241abfd1a8bSRiver Riddle rewriteFunctions); 1242abfd1a8bSRiver Riddle executor.execute(rewriter, &matches); 1243abfd1a8bSRiver Riddle 1244abfd1a8bSRiver Riddle // Order the found matches by benefit. 1245abfd1a8bSRiver Riddle std::stable_sort(matches.begin(), matches.end(), 1246abfd1a8bSRiver Riddle [](const MatchResult &lhs, const MatchResult &rhs) { 1247abfd1a8bSRiver Riddle return lhs.benefit > rhs.benefit; 1248abfd1a8bSRiver Riddle }); 1249abfd1a8bSRiver Riddle } 1250abfd1a8bSRiver Riddle 1251abfd1a8bSRiver Riddle /// Run the rewriter of the given pattern on the root operation `op`. 1252abfd1a8bSRiver Riddle void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, 1253abfd1a8bSRiver Riddle PDLByteCodeMutableState &state) const { 1254abfd1a8bSRiver Riddle // The arguments of the rewrite function are stored at the start of the 1255abfd1a8bSRiver Riddle // memory buffer. 1256abfd1a8bSRiver Riddle llvm::copy(match.values, state.memory.begin()); 1257abfd1a8bSRiver Riddle 1258abfd1a8bSRiver Riddle ByteCodeExecutor executor( 1259abfd1a8bSRiver Riddle &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, 1260abfd1a8bSRiver Riddle uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns, 1261abfd1a8bSRiver Riddle constraintFunctions, createFunctions, rewriteFunctions); 1262abfd1a8bSRiver Riddle executor.execute(rewriter, /*matches=*/nullptr, match.location); 1263abfd1a8bSRiver Riddle } 1264