1abfd1a8bSRiver Riddle //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// 2abfd1a8bSRiver Riddle // 3abfd1a8bSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4abfd1a8bSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 5abfd1a8bSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6abfd1a8bSRiver Riddle // 7abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 8abfd1a8bSRiver Riddle // 9abfd1a8bSRiver Riddle // This file implements MLIR to byte-code generation and the interpreter. 10abfd1a8bSRiver Riddle // 11abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 12abfd1a8bSRiver Riddle 13abfd1a8bSRiver Riddle #include "ByteCode.h" 14abfd1a8bSRiver Riddle #include "mlir/Analysis/Liveness.h" 15abfd1a8bSRiver Riddle #include "mlir/Dialect/PDL/IR/PDLTypes.h" 16abfd1a8bSRiver Riddle #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 17e66c2e25SRiver Riddle #include "mlir/IR/BuiltinOps.h" 18abfd1a8bSRiver Riddle #include "mlir/IR/RegionGraphTraits.h" 19abfd1a8bSRiver Riddle #include "llvm/ADT/IntervalMap.h" 20abfd1a8bSRiver Riddle #include "llvm/ADT/PostOrderIterator.h" 21abfd1a8bSRiver Riddle #include "llvm/ADT/TypeSwitch.h" 22abfd1a8bSRiver Riddle #include "llvm/Support/Debug.h" 23abfd1a8bSRiver Riddle 24abfd1a8bSRiver Riddle #define DEBUG_TYPE "pdl-bytecode" 25abfd1a8bSRiver Riddle 26abfd1a8bSRiver Riddle using namespace mlir; 27abfd1a8bSRiver Riddle using namespace mlir::detail; 28abfd1a8bSRiver Riddle 29abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 30abfd1a8bSRiver Riddle // PDLByteCodePattern 31abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 32abfd1a8bSRiver Riddle 33abfd1a8bSRiver Riddle PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, 34abfd1a8bSRiver Riddle ByteCodeAddr rewriterAddr) { 35abfd1a8bSRiver Riddle SmallVector<StringRef, 8> generatedOps; 36abfd1a8bSRiver Riddle if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr()) 37abfd1a8bSRiver Riddle generatedOps = 38abfd1a8bSRiver Riddle llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); 39abfd1a8bSRiver Riddle 40abfd1a8bSRiver Riddle PatternBenefit benefit = matchOp.benefit(); 41abfd1a8bSRiver Riddle MLIRContext *ctx = matchOp.getContext(); 42abfd1a8bSRiver Riddle 43abfd1a8bSRiver Riddle // Check to see if this is pattern matches a specific operation type. 44abfd1a8bSRiver Riddle if (Optional<StringRef> rootKind = matchOp.rootKind()) 45abfd1a8bSRiver Riddle return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit, 46abfd1a8bSRiver Riddle ctx); 47abfd1a8bSRiver Riddle return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx, 48abfd1a8bSRiver Riddle MatchAnyOpTypeTag()); 49abfd1a8bSRiver Riddle } 50abfd1a8bSRiver Riddle 51abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 52abfd1a8bSRiver Riddle // PDLByteCodeMutableState 53abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 54abfd1a8bSRiver Riddle 55abfd1a8bSRiver Riddle /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 56abfd1a8bSRiver Riddle /// to the position of the pattern within the range returned by 57abfd1a8bSRiver Riddle /// `PDLByteCode::getPatterns`. 58abfd1a8bSRiver Riddle void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, 59abfd1a8bSRiver Riddle PatternBenefit benefit) { 60abfd1a8bSRiver Riddle currentPatternBenefits[patternIndex] = benefit; 61abfd1a8bSRiver Riddle } 62abfd1a8bSRiver Riddle 63abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 64abfd1a8bSRiver Riddle // Bytecode OpCodes 65abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 66abfd1a8bSRiver Riddle 67abfd1a8bSRiver Riddle namespace { 68abfd1a8bSRiver Riddle enum OpCode : ByteCodeField { 69abfd1a8bSRiver Riddle /// Apply an externally registered constraint. 70abfd1a8bSRiver Riddle ApplyConstraint, 71abfd1a8bSRiver Riddle /// Apply an externally registered rewrite. 72abfd1a8bSRiver Riddle ApplyRewrite, 73abfd1a8bSRiver Riddle /// Check if two generic values are equal. 74abfd1a8bSRiver Riddle AreEqual, 75abfd1a8bSRiver Riddle /// Unconditional branch. 76abfd1a8bSRiver Riddle Branch, 77abfd1a8bSRiver Riddle /// Compare the operand count of an operation with a constant. 78abfd1a8bSRiver Riddle CheckOperandCount, 79abfd1a8bSRiver Riddle /// Compare the name of an operation with a constant. 80abfd1a8bSRiver Riddle CheckOperationName, 81abfd1a8bSRiver Riddle /// Compare the result count of an operation with a constant. 82abfd1a8bSRiver Riddle CheckResultCount, 83abfd1a8bSRiver Riddle /// Invoke a native creation method. 84abfd1a8bSRiver Riddle CreateNative, 85abfd1a8bSRiver Riddle /// Create an operation. 86abfd1a8bSRiver Riddle CreateOperation, 87abfd1a8bSRiver Riddle /// Erase an operation. 88abfd1a8bSRiver Riddle EraseOp, 89abfd1a8bSRiver Riddle /// Terminate a matcher or rewrite sequence. 90abfd1a8bSRiver Riddle Finalize, 91abfd1a8bSRiver Riddle /// Get a specific attribute of an operation. 92abfd1a8bSRiver Riddle GetAttribute, 93abfd1a8bSRiver Riddle /// Get the type of an attribute. 94abfd1a8bSRiver Riddle GetAttributeType, 95abfd1a8bSRiver Riddle /// Get the defining operation of a value. 96abfd1a8bSRiver Riddle GetDefiningOp, 97abfd1a8bSRiver Riddle /// Get a specific operand of an operation. 98abfd1a8bSRiver Riddle GetOperand0, 99abfd1a8bSRiver Riddle GetOperand1, 100abfd1a8bSRiver Riddle GetOperand2, 101abfd1a8bSRiver Riddle GetOperand3, 102abfd1a8bSRiver Riddle GetOperandN, 103abfd1a8bSRiver Riddle /// Get a specific result of an operation. 104abfd1a8bSRiver Riddle GetResult0, 105abfd1a8bSRiver Riddle GetResult1, 106abfd1a8bSRiver Riddle GetResult2, 107abfd1a8bSRiver Riddle GetResult3, 108abfd1a8bSRiver Riddle GetResultN, 109abfd1a8bSRiver Riddle /// Get the type of a value. 110abfd1a8bSRiver Riddle GetValueType, 111abfd1a8bSRiver Riddle /// Check if a generic value is not null. 112abfd1a8bSRiver Riddle IsNotNull, 113abfd1a8bSRiver Riddle /// Record a successful pattern match. 114abfd1a8bSRiver Riddle RecordMatch, 115abfd1a8bSRiver Riddle /// Replace an operation. 116abfd1a8bSRiver Riddle ReplaceOp, 117abfd1a8bSRiver Riddle /// Compare an attribute with a set of constants. 118abfd1a8bSRiver Riddle SwitchAttribute, 119abfd1a8bSRiver Riddle /// Compare the operand count of an operation with a set of constants. 120abfd1a8bSRiver Riddle SwitchOperandCount, 121abfd1a8bSRiver Riddle /// Compare the name of an operation with a set of constants. 122abfd1a8bSRiver Riddle SwitchOperationName, 123abfd1a8bSRiver Riddle /// Compare the result count of an operation with a set of constants. 124abfd1a8bSRiver Riddle SwitchResultCount, 125abfd1a8bSRiver Riddle /// Compare a type with a set of constants. 126abfd1a8bSRiver Riddle SwitchType, 127abfd1a8bSRiver Riddle }; 128abfd1a8bSRiver Riddle 129abfd1a8bSRiver Riddle enum class PDLValueKind { Attribute, Operation, Type, Value }; 130abfd1a8bSRiver Riddle } // end anonymous namespace 131abfd1a8bSRiver Riddle 132abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 133abfd1a8bSRiver Riddle // ByteCode Generation 134abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 135abfd1a8bSRiver Riddle 136abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 137abfd1a8bSRiver Riddle // Generator 138abfd1a8bSRiver Riddle 139abfd1a8bSRiver Riddle namespace { 140abfd1a8bSRiver Riddle struct ByteCodeWriter; 141abfd1a8bSRiver Riddle 142abfd1a8bSRiver Riddle /// This class represents the main generator for the pattern bytecode. 143abfd1a8bSRiver Riddle class Generator { 144abfd1a8bSRiver Riddle public: 145abfd1a8bSRiver Riddle Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, 146abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &matcherByteCode, 147abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &rewriterByteCode, 148abfd1a8bSRiver Riddle SmallVectorImpl<PDLByteCodePattern> &patterns, 149abfd1a8bSRiver Riddle ByteCodeField &maxValueMemoryIndex, 150abfd1a8bSRiver Riddle llvm::StringMap<PDLConstraintFunction> &constraintFns, 151abfd1a8bSRiver Riddle llvm::StringMap<PDLCreateFunction> &createFns, 152abfd1a8bSRiver Riddle llvm::StringMap<PDLRewriteFunction> &rewriteFns) 153abfd1a8bSRiver Riddle : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), 154abfd1a8bSRiver Riddle rewriterByteCode(rewriterByteCode), patterns(patterns), 155abfd1a8bSRiver Riddle maxValueMemoryIndex(maxValueMemoryIndex) { 156abfd1a8bSRiver Riddle for (auto it : llvm::enumerate(constraintFns)) 157abfd1a8bSRiver Riddle constraintToMemIndex.try_emplace(it.value().first(), it.index()); 158abfd1a8bSRiver Riddle for (auto it : llvm::enumerate(createFns)) 159abfd1a8bSRiver Riddle nativeCreateToMemIndex.try_emplace(it.value().first(), it.index()); 160abfd1a8bSRiver Riddle for (auto it : llvm::enumerate(rewriteFns)) 161abfd1a8bSRiver Riddle externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); 162abfd1a8bSRiver Riddle } 163abfd1a8bSRiver Riddle 164abfd1a8bSRiver Riddle /// Generate the bytecode for the given PDL interpreter module. 165abfd1a8bSRiver Riddle void generate(ModuleOp module); 166abfd1a8bSRiver Riddle 167abfd1a8bSRiver Riddle /// Return the memory index to use for the given value. 168abfd1a8bSRiver Riddle ByteCodeField &getMemIndex(Value value) { 169abfd1a8bSRiver Riddle assert(valueToMemIndex.count(value) && 170abfd1a8bSRiver Riddle "expected memory index to be assigned"); 171abfd1a8bSRiver Riddle return valueToMemIndex[value]; 172abfd1a8bSRiver Riddle } 173abfd1a8bSRiver Riddle 174abfd1a8bSRiver Riddle /// Return an index to use when referring to the given data that is uniqued in 175abfd1a8bSRiver Riddle /// the MLIR context. 176abfd1a8bSRiver Riddle template <typename T> 177abfd1a8bSRiver Riddle std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> 178abfd1a8bSRiver Riddle getMemIndex(T val) { 179abfd1a8bSRiver Riddle const void *opaqueVal = val.getAsOpaquePointer(); 180abfd1a8bSRiver Riddle 181abfd1a8bSRiver Riddle // Get or insert a reference to this value. 182abfd1a8bSRiver Riddle auto it = uniquedDataToMemIndex.try_emplace( 183abfd1a8bSRiver Riddle opaqueVal, maxValueMemoryIndex + uniquedData.size()); 184abfd1a8bSRiver Riddle if (it.second) 185abfd1a8bSRiver Riddle uniquedData.push_back(opaqueVal); 186abfd1a8bSRiver Riddle return it.first->second; 187abfd1a8bSRiver Riddle } 188abfd1a8bSRiver Riddle 189abfd1a8bSRiver Riddle private: 190abfd1a8bSRiver Riddle /// Allocate memory indices for the results of operations within the matcher 191abfd1a8bSRiver Riddle /// and rewriters. 192abfd1a8bSRiver Riddle void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); 193abfd1a8bSRiver Riddle 194abfd1a8bSRiver Riddle /// Generate the bytecode for the given operation. 195abfd1a8bSRiver Riddle void generate(Operation *op, ByteCodeWriter &writer); 196abfd1a8bSRiver Riddle void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); 197abfd1a8bSRiver Riddle void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); 198abfd1a8bSRiver Riddle void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); 199abfd1a8bSRiver Riddle void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); 200abfd1a8bSRiver Riddle void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); 201abfd1a8bSRiver Riddle void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); 202abfd1a8bSRiver Riddle void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); 203abfd1a8bSRiver Riddle void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); 204abfd1a8bSRiver Riddle void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); 205abfd1a8bSRiver Riddle void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); 206abfd1a8bSRiver Riddle void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer); 207abfd1a8bSRiver Riddle void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); 208abfd1a8bSRiver Riddle void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); 209abfd1a8bSRiver Riddle void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); 210abfd1a8bSRiver Riddle void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); 211abfd1a8bSRiver Riddle void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); 212abfd1a8bSRiver Riddle void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); 213abfd1a8bSRiver Riddle void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); 214abfd1a8bSRiver Riddle void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); 215abfd1a8bSRiver Riddle void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); 216abfd1a8bSRiver Riddle void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); 217abfd1a8bSRiver Riddle void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer); 218abfd1a8bSRiver Riddle void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); 219abfd1a8bSRiver Riddle void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); 220abfd1a8bSRiver Riddle void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); 221abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); 222abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); 223abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); 224abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); 225abfd1a8bSRiver Riddle void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); 226abfd1a8bSRiver Riddle 227abfd1a8bSRiver Riddle /// Mapping from value to its corresponding memory index. 228abfd1a8bSRiver Riddle DenseMap<Value, ByteCodeField> valueToMemIndex; 229abfd1a8bSRiver Riddle 230abfd1a8bSRiver Riddle /// Mapping from the name of an externally registered rewrite to its index in 231abfd1a8bSRiver Riddle /// the bytecode registry. 232abfd1a8bSRiver Riddle llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; 233abfd1a8bSRiver Riddle 234abfd1a8bSRiver Riddle /// Mapping from the name of an externally registered constraint to its index 235abfd1a8bSRiver Riddle /// in the bytecode registry. 236abfd1a8bSRiver Riddle llvm::StringMap<ByteCodeField> constraintToMemIndex; 237abfd1a8bSRiver Riddle 238abfd1a8bSRiver Riddle /// Mapping from the name of an externally registered creation method to its 239abfd1a8bSRiver Riddle /// index in the bytecode registry. 240abfd1a8bSRiver Riddle llvm::StringMap<ByteCodeField> nativeCreateToMemIndex; 241abfd1a8bSRiver Riddle 242abfd1a8bSRiver Riddle /// Mapping from rewriter function name to the bytecode address of the 243abfd1a8bSRiver Riddle /// rewriter function in byte. 244abfd1a8bSRiver Riddle llvm::StringMap<ByteCodeAddr> rewriterToAddr; 245abfd1a8bSRiver Riddle 246abfd1a8bSRiver Riddle /// Mapping from a uniqued storage object to its memory index within 247abfd1a8bSRiver Riddle /// `uniquedData`. 248abfd1a8bSRiver Riddle DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; 249abfd1a8bSRiver Riddle 250abfd1a8bSRiver Riddle /// The current MLIR context. 251abfd1a8bSRiver Riddle MLIRContext *ctx; 252abfd1a8bSRiver Riddle 253abfd1a8bSRiver Riddle /// Data of the ByteCode class to be populated. 254abfd1a8bSRiver Riddle std::vector<const void *> &uniquedData; 255abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &matcherByteCode; 256abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &rewriterByteCode; 257abfd1a8bSRiver Riddle SmallVectorImpl<PDLByteCodePattern> &patterns; 258abfd1a8bSRiver Riddle ByteCodeField &maxValueMemoryIndex; 259abfd1a8bSRiver Riddle }; 260abfd1a8bSRiver Riddle 261abfd1a8bSRiver Riddle /// This class provides utilities for writing a bytecode stream. 262abfd1a8bSRiver Riddle struct ByteCodeWriter { 263abfd1a8bSRiver Riddle ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) 264abfd1a8bSRiver Riddle : bytecode(bytecode), generator(generator) {} 265abfd1a8bSRiver Riddle 266abfd1a8bSRiver Riddle /// Append a field to the bytecode. 267abfd1a8bSRiver Riddle void append(ByteCodeField field) { bytecode.push_back(field); } 268fa20ab7bSRiver Riddle void append(OpCode opCode) { bytecode.push_back(opCode); } 269abfd1a8bSRiver Riddle 270abfd1a8bSRiver Riddle /// Append an address to the bytecode. 271abfd1a8bSRiver Riddle void append(ByteCodeAddr field) { 272abfd1a8bSRiver Riddle static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 273abfd1a8bSRiver Riddle "unexpected ByteCode address size"); 274abfd1a8bSRiver Riddle 275abfd1a8bSRiver Riddle ByteCodeField fieldParts[2]; 276abfd1a8bSRiver Riddle std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); 277abfd1a8bSRiver Riddle bytecode.append({fieldParts[0], fieldParts[1]}); 278abfd1a8bSRiver Riddle } 279abfd1a8bSRiver Riddle 280abfd1a8bSRiver Riddle /// Append a successor range to the bytecode, the exact address will need to 281abfd1a8bSRiver Riddle /// be resolved later. 282abfd1a8bSRiver Riddle void append(SuccessorRange successors) { 283abfd1a8bSRiver Riddle // Add back references to the any successors so that the address can be 284abfd1a8bSRiver Riddle // resolved later. 285abfd1a8bSRiver Riddle for (Block *successor : successors) { 286abfd1a8bSRiver Riddle unresolvedSuccessorRefs[successor].push_back(bytecode.size()); 287abfd1a8bSRiver Riddle append(ByteCodeAddr(0)); 288abfd1a8bSRiver Riddle } 289abfd1a8bSRiver Riddle } 290abfd1a8bSRiver Riddle 291abfd1a8bSRiver Riddle /// Append a range of values that will be read as generic PDLValues. 292abfd1a8bSRiver Riddle void appendPDLValueList(OperandRange values) { 293abfd1a8bSRiver Riddle bytecode.push_back(values.size()); 294abfd1a8bSRiver Riddle for (Value value : values) { 295abfd1a8bSRiver Riddle // Append the type of the value in addition to the value itself. 296abfd1a8bSRiver Riddle PDLValueKind kind = 297abfd1a8bSRiver Riddle TypeSwitch<Type, PDLValueKind>(value.getType()) 298abfd1a8bSRiver Riddle .Case<pdl::AttributeType>( 299abfd1a8bSRiver Riddle [](Type) { return PDLValueKind::Attribute; }) 300abfd1a8bSRiver Riddle .Case<pdl::OperationType>( 301abfd1a8bSRiver Riddle [](Type) { return PDLValueKind::Operation; }) 302abfd1a8bSRiver Riddle .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; }) 303abfd1a8bSRiver Riddle .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; }); 304abfd1a8bSRiver Riddle bytecode.push_back(static_cast<ByteCodeField>(kind)); 305abfd1a8bSRiver Riddle append(value); 306abfd1a8bSRiver Riddle } 307abfd1a8bSRiver Riddle } 308abfd1a8bSRiver Riddle 309abfd1a8bSRiver Riddle /// Check if the given class `T` has an iterator type. 310abfd1a8bSRiver Riddle template <typename T, typename... Args> 311abfd1a8bSRiver Riddle using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); 312abfd1a8bSRiver Riddle 313abfd1a8bSRiver Riddle /// Append a value that will be stored in a memory slot and not inline within 314abfd1a8bSRiver Riddle /// the bytecode. 315abfd1a8bSRiver Riddle template <typename T> 316abfd1a8bSRiver Riddle std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value || 317abfd1a8bSRiver Riddle std::is_pointer<T>::value> 318abfd1a8bSRiver Riddle append(T value) { 319abfd1a8bSRiver Riddle bytecode.push_back(generator.getMemIndex(value)); 320abfd1a8bSRiver Riddle } 321abfd1a8bSRiver Riddle 322abfd1a8bSRiver Riddle /// Append a range of values. 323abfd1a8bSRiver Riddle template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>> 324abfd1a8bSRiver Riddle std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value> 325abfd1a8bSRiver Riddle append(T range) { 326abfd1a8bSRiver Riddle bytecode.push_back(llvm::size(range)); 327abfd1a8bSRiver Riddle for (auto it : range) 328abfd1a8bSRiver Riddle append(it); 329abfd1a8bSRiver Riddle } 330abfd1a8bSRiver Riddle 331abfd1a8bSRiver Riddle /// Append a variadic number of fields to the bytecode. 332abfd1a8bSRiver Riddle template <typename FieldTy, typename Field2Ty, typename... FieldTys> 333abfd1a8bSRiver Riddle void append(FieldTy field, Field2Ty field2, FieldTys... fields) { 334abfd1a8bSRiver Riddle append(field); 335abfd1a8bSRiver Riddle append(field2, fields...); 336abfd1a8bSRiver Riddle } 337abfd1a8bSRiver Riddle 338abfd1a8bSRiver Riddle /// Successor references in the bytecode that have yet to be resolved. 339abfd1a8bSRiver Riddle DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; 340abfd1a8bSRiver Riddle 341abfd1a8bSRiver Riddle /// The underlying bytecode buffer. 342abfd1a8bSRiver Riddle SmallVectorImpl<ByteCodeField> &bytecode; 343abfd1a8bSRiver Riddle 344abfd1a8bSRiver Riddle /// The main generator producing PDL. 345abfd1a8bSRiver Riddle Generator &generator; 346abfd1a8bSRiver Riddle }; 347abfd1a8bSRiver Riddle } // end anonymous namespace 348abfd1a8bSRiver Riddle 349abfd1a8bSRiver Riddle void Generator::generate(ModuleOp module) { 350abfd1a8bSRiver Riddle FuncOp matcherFunc = module.lookupSymbol<FuncOp>( 351abfd1a8bSRiver Riddle pdl_interp::PDLInterpDialect::getMatcherFunctionName()); 352abfd1a8bSRiver Riddle ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( 353abfd1a8bSRiver Riddle pdl_interp::PDLInterpDialect::getRewriterModuleName()); 354abfd1a8bSRiver Riddle assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); 355abfd1a8bSRiver Riddle 356abfd1a8bSRiver Riddle // Allocate memory indices for the results of operations within the matcher 357abfd1a8bSRiver Riddle // and rewriters. 358abfd1a8bSRiver Riddle allocateMemoryIndices(matcherFunc, rewriterModule); 359abfd1a8bSRiver Riddle 360abfd1a8bSRiver Riddle // Generate code for the rewriter functions. 361abfd1a8bSRiver Riddle ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); 362abfd1a8bSRiver Riddle for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 363abfd1a8bSRiver Riddle rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); 364abfd1a8bSRiver Riddle for (Operation &op : rewriterFunc.getOps()) 365abfd1a8bSRiver Riddle generate(&op, rewriterByteCodeWriter); 366abfd1a8bSRiver Riddle } 367abfd1a8bSRiver Riddle assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && 368abfd1a8bSRiver Riddle "unexpected branches in rewriter function"); 369abfd1a8bSRiver Riddle 370abfd1a8bSRiver Riddle // Generate code for the matcher function. 371abfd1a8bSRiver Riddle DenseMap<Block *, ByteCodeAddr> blockToAddr; 372abfd1a8bSRiver Riddle llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody()); 373abfd1a8bSRiver Riddle ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); 374abfd1a8bSRiver Riddle for (Block *block : rpot) { 375abfd1a8bSRiver Riddle // Keep track of where this block begins within the matcher function. 376abfd1a8bSRiver Riddle blockToAddr.try_emplace(block, matcherByteCode.size()); 377abfd1a8bSRiver Riddle for (Operation &op : *block) 378abfd1a8bSRiver Riddle generate(&op, matcherByteCodeWriter); 379abfd1a8bSRiver Riddle } 380abfd1a8bSRiver Riddle 381abfd1a8bSRiver Riddle // Resolve successor references in the matcher. 382abfd1a8bSRiver Riddle for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { 383abfd1a8bSRiver Riddle ByteCodeAddr addr = blockToAddr[it.first]; 384abfd1a8bSRiver Riddle for (unsigned offsetToFix : it.second) 385abfd1a8bSRiver Riddle std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); 386abfd1a8bSRiver Riddle } 387abfd1a8bSRiver Riddle } 388abfd1a8bSRiver Riddle 389abfd1a8bSRiver Riddle void Generator::allocateMemoryIndices(FuncOp matcherFunc, 390abfd1a8bSRiver Riddle ModuleOp rewriterModule) { 391abfd1a8bSRiver Riddle // Rewriters use simplistic allocation scheme that simply assigns an index to 392abfd1a8bSRiver Riddle // each result. 393abfd1a8bSRiver Riddle for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { 394abfd1a8bSRiver Riddle ByteCodeField index = 0; 395abfd1a8bSRiver Riddle for (BlockArgument arg : rewriterFunc.getArguments()) 396abfd1a8bSRiver Riddle valueToMemIndex.try_emplace(arg, index++); 397abfd1a8bSRiver Riddle rewriterFunc.getBody().walk([&](Operation *op) { 398abfd1a8bSRiver Riddle for (Value result : op->getResults()) 399abfd1a8bSRiver Riddle valueToMemIndex.try_emplace(result, index++); 400abfd1a8bSRiver Riddle }); 401abfd1a8bSRiver Riddle if (index > maxValueMemoryIndex) 402abfd1a8bSRiver Riddle maxValueMemoryIndex = index; 403abfd1a8bSRiver Riddle } 404abfd1a8bSRiver Riddle 405abfd1a8bSRiver Riddle // The matcher function uses a more sophisticated numbering that tries to 406abfd1a8bSRiver Riddle // minimize the number of memory indices assigned. This is done by determining 407abfd1a8bSRiver Riddle // a live range of the values within the matcher, then the allocation is just 408abfd1a8bSRiver Riddle // finding the minimal number of overlapping live ranges. This is essentially 409abfd1a8bSRiver Riddle // a simplified form of register allocation where we don't necessarily have a 410abfd1a8bSRiver Riddle // limited number of registers, but we still want to minimize the number used. 411abfd1a8bSRiver Riddle DenseMap<Operation *, ByteCodeField> opToIndex; 412abfd1a8bSRiver Riddle matcherFunc.getBody().walk([&](Operation *op) { 413abfd1a8bSRiver Riddle opToIndex.insert(std::make_pair(op, opToIndex.size())); 414abfd1a8bSRiver Riddle }); 415abfd1a8bSRiver Riddle 416abfd1a8bSRiver Riddle // Liveness info for each of the defs within the matcher. 417abfd1a8bSRiver Riddle using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>; 418abfd1a8bSRiver Riddle LivenessSet::Allocator allocator; 419abfd1a8bSRiver Riddle DenseMap<Value, LivenessSet> valueDefRanges; 420abfd1a8bSRiver Riddle 421abfd1a8bSRiver Riddle // Assign the root operation being matched to slot 0. 422abfd1a8bSRiver Riddle BlockArgument rootOpArg = matcherFunc.getArgument(0); 423abfd1a8bSRiver Riddle valueToMemIndex[rootOpArg] = 0; 424abfd1a8bSRiver Riddle 425abfd1a8bSRiver Riddle // Walk each of the blocks, computing the def interval that the value is used. 426abfd1a8bSRiver Riddle Liveness matcherLiveness(matcherFunc); 427abfd1a8bSRiver Riddle for (Block &block : matcherFunc.getBody()) { 428abfd1a8bSRiver Riddle const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block); 429abfd1a8bSRiver Riddle assert(info && "expected liveness info for block"); 430abfd1a8bSRiver Riddle auto processValue = [&](Value value, Operation *firstUseOrDef) { 431abfd1a8bSRiver Riddle // We don't need to process the root op argument, this value is always 432abfd1a8bSRiver Riddle // assigned to the first memory slot. 433abfd1a8bSRiver Riddle if (value == rootOpArg) 434abfd1a8bSRiver Riddle return; 435abfd1a8bSRiver Riddle 436abfd1a8bSRiver Riddle // Set indices for the range of this block that the value is used. 437abfd1a8bSRiver Riddle auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; 438abfd1a8bSRiver Riddle defRangeIt->second.insert( 439abfd1a8bSRiver Riddle opToIndex[firstUseOrDef], 440abfd1a8bSRiver Riddle opToIndex[info->getEndOperation(value, firstUseOrDef)], 441abfd1a8bSRiver Riddle /*dummyValue*/ 0); 442abfd1a8bSRiver Riddle }; 443abfd1a8bSRiver Riddle 444abfd1a8bSRiver Riddle // Process the live-ins of this block. 445abfd1a8bSRiver Riddle for (Value liveIn : info->in()) 446abfd1a8bSRiver Riddle processValue(liveIn, &block.front()); 447abfd1a8bSRiver Riddle 448abfd1a8bSRiver Riddle // Process any new defs within this block. 449abfd1a8bSRiver Riddle for (Operation &op : block) 450abfd1a8bSRiver Riddle for (Value result : op.getResults()) 451abfd1a8bSRiver Riddle processValue(result, &op); 452abfd1a8bSRiver Riddle } 453abfd1a8bSRiver Riddle 454abfd1a8bSRiver Riddle // Greedily allocate memory slots using the computed def live ranges. 455abfd1a8bSRiver Riddle std::vector<LivenessSet> allocatedIndices; 456abfd1a8bSRiver Riddle for (auto &defIt : valueDefRanges) { 457abfd1a8bSRiver Riddle ByteCodeField &memIndex = valueToMemIndex[defIt.first]; 458abfd1a8bSRiver Riddle LivenessSet &defSet = defIt.second; 459abfd1a8bSRiver Riddle 460abfd1a8bSRiver Riddle // Try to allocate to an existing index. 461abfd1a8bSRiver Riddle for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) { 462abfd1a8bSRiver Riddle LivenessSet &existingIndex = existingIndexIt.value(); 463abfd1a8bSRiver Riddle llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps( 464abfd1a8bSRiver Riddle defIt.second, existingIndex); 465abfd1a8bSRiver Riddle if (overlaps.valid()) 466abfd1a8bSRiver Riddle continue; 467abfd1a8bSRiver Riddle // Union the range of the def within the existing index. 468abfd1a8bSRiver Riddle for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) 469abfd1a8bSRiver Riddle existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0); 470abfd1a8bSRiver Riddle memIndex = existingIndexIt.index() + 1; 471abfd1a8bSRiver Riddle } 472abfd1a8bSRiver Riddle 473abfd1a8bSRiver Riddle // If no existing index could be used, add a new one. 474abfd1a8bSRiver Riddle if (memIndex == 0) { 475abfd1a8bSRiver Riddle allocatedIndices.emplace_back(allocator); 476abfd1a8bSRiver Riddle for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) 477abfd1a8bSRiver Riddle allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0); 478abfd1a8bSRiver Riddle memIndex = allocatedIndices.size(); 479abfd1a8bSRiver Riddle } 480abfd1a8bSRiver Riddle } 481abfd1a8bSRiver Riddle 482abfd1a8bSRiver Riddle // Update the max number of indices. 483abfd1a8bSRiver Riddle ByteCodeField numMatcherIndices = allocatedIndices.size() + 1; 484abfd1a8bSRiver Riddle if (numMatcherIndices > maxValueMemoryIndex) 485abfd1a8bSRiver Riddle maxValueMemoryIndex = numMatcherIndices; 486abfd1a8bSRiver Riddle } 487abfd1a8bSRiver Riddle 488abfd1a8bSRiver Riddle void Generator::generate(Operation *op, ByteCodeWriter &writer) { 489abfd1a8bSRiver Riddle TypeSwitch<Operation *>(op) 490abfd1a8bSRiver Riddle .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp, 491abfd1a8bSRiver Riddle pdl_interp::AreEqualOp, pdl_interp::BranchOp, 492abfd1a8bSRiver Riddle pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, 493abfd1a8bSRiver Riddle pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, 494abfd1a8bSRiver Riddle pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp, 495abfd1a8bSRiver Riddle pdl_interp::CreateNativeOp, pdl_interp::CreateOperationOp, 496abfd1a8bSRiver Riddle pdl_interp::CreateTypeOp, pdl_interp::EraseOp, 497abfd1a8bSRiver Riddle pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp, 498abfd1a8bSRiver Riddle pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp, 499abfd1a8bSRiver Riddle pdl_interp::GetOperandOp, pdl_interp::GetResultOp, 500abfd1a8bSRiver Riddle pdl_interp::GetValueTypeOp, pdl_interp::InferredTypeOp, 501abfd1a8bSRiver Riddle pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp, 502abfd1a8bSRiver Riddle pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp, 503abfd1a8bSRiver Riddle pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp, 504abfd1a8bSRiver Riddle pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>( 505abfd1a8bSRiver Riddle [&](auto interpOp) { this->generate(interpOp, writer); }) 506abfd1a8bSRiver Riddle .Default([](Operation *) { 507abfd1a8bSRiver Riddle llvm_unreachable("unknown `pdl_interp` operation"); 508abfd1a8bSRiver Riddle }); 509abfd1a8bSRiver Riddle } 510abfd1a8bSRiver Riddle 511abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ApplyConstraintOp op, 512abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 513abfd1a8bSRiver Riddle assert(constraintToMemIndex.count(op.name()) && 514abfd1a8bSRiver Riddle "expected index for constraint function"); 515abfd1a8bSRiver Riddle writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()], 516abfd1a8bSRiver Riddle op.constParamsAttr()); 517abfd1a8bSRiver Riddle writer.appendPDLValueList(op.args()); 518abfd1a8bSRiver Riddle writer.append(op.getSuccessors()); 519abfd1a8bSRiver Riddle } 520abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ApplyRewriteOp op, 521abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 522abfd1a8bSRiver Riddle assert(externalRewriterToMemIndex.count(op.name()) && 523abfd1a8bSRiver Riddle "expected index for rewrite function"); 524abfd1a8bSRiver Riddle writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()], 525abfd1a8bSRiver Riddle op.constParamsAttr(), op.root()); 526abfd1a8bSRiver Riddle writer.appendPDLValueList(op.args()); 527abfd1a8bSRiver Riddle } 528abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { 529abfd1a8bSRiver Riddle writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors()); 530abfd1a8bSRiver Riddle } 531abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { 5328affe881SRiver Riddle writer.append(OpCode::Branch, SuccessorRange(op.getOperation())); 533abfd1a8bSRiver Riddle } 534abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckAttributeOp op, 535abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 536abfd1a8bSRiver Riddle writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(), 537abfd1a8bSRiver Riddle op.getSuccessors()); 538abfd1a8bSRiver Riddle } 539abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckOperandCountOp op, 540abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 541abfd1a8bSRiver Riddle writer.append(OpCode::CheckOperandCount, op.operation(), op.count(), 542abfd1a8bSRiver Riddle op.getSuccessors()); 543abfd1a8bSRiver Riddle } 544abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckOperationNameOp op, 545abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 546abfd1a8bSRiver Riddle writer.append(OpCode::CheckOperationName, op.operation(), 547abfd1a8bSRiver Riddle OperationName(op.name(), ctx), op.getSuccessors()); 548abfd1a8bSRiver Riddle } 549abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckResultCountOp op, 550abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 551abfd1a8bSRiver Riddle writer.append(OpCode::CheckResultCount, op.operation(), op.count(), 552abfd1a8bSRiver Riddle op.getSuccessors()); 553abfd1a8bSRiver Riddle } 554abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { 555abfd1a8bSRiver Riddle writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors()); 556abfd1a8bSRiver Riddle } 557abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateAttributeOp op, 558abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 559abfd1a8bSRiver Riddle // Simply repoint the memory index of the result to the constant. 560abfd1a8bSRiver Riddle getMemIndex(op.attribute()) = getMemIndex(op.value()); 561abfd1a8bSRiver Riddle } 562abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateNativeOp op, 563abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 564abfd1a8bSRiver Riddle assert(nativeCreateToMemIndex.count(op.name()) && 565abfd1a8bSRiver Riddle "expected index for creation function"); 566abfd1a8bSRiver Riddle writer.append(OpCode::CreateNative, nativeCreateToMemIndex[op.name()], 567abfd1a8bSRiver Riddle op.result(), op.constParamsAttr()); 568abfd1a8bSRiver Riddle writer.appendPDLValueList(op.args()); 569abfd1a8bSRiver Riddle } 570abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateOperationOp op, 571abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 572abfd1a8bSRiver Riddle writer.append(OpCode::CreateOperation, op.operation(), 573abfd1a8bSRiver Riddle OperationName(op.name(), ctx), op.operands()); 574abfd1a8bSRiver Riddle 575abfd1a8bSRiver Riddle // Add the attributes. 576abfd1a8bSRiver Riddle OperandRange attributes = op.attributes(); 577abfd1a8bSRiver Riddle writer.append(static_cast<ByteCodeField>(attributes.size())); 578abfd1a8bSRiver Riddle for (auto it : llvm::zip(op.attributeNames(), op.attributes())) { 579abfd1a8bSRiver Riddle writer.append( 580abfd1a8bSRiver Riddle Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx), 581abfd1a8bSRiver Riddle std::get<1>(it)); 582abfd1a8bSRiver Riddle } 583abfd1a8bSRiver Riddle writer.append(op.types()); 584abfd1a8bSRiver Riddle } 585abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { 586abfd1a8bSRiver Riddle // Simply repoint the memory index of the result to the constant. 587abfd1a8bSRiver Riddle getMemIndex(op.result()) = getMemIndex(op.value()); 588abfd1a8bSRiver Riddle } 589abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { 590abfd1a8bSRiver Riddle writer.append(OpCode::EraseOp, op.operation()); 591abfd1a8bSRiver Riddle } 592abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) { 593abfd1a8bSRiver Riddle writer.append(OpCode::Finalize); 594abfd1a8bSRiver Riddle } 595abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetAttributeOp op, 596abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 597abfd1a8bSRiver Riddle writer.append(OpCode::GetAttribute, op.attribute(), op.operation(), 598abfd1a8bSRiver Riddle Identifier::get(op.name(), ctx)); 599abfd1a8bSRiver Riddle } 600abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetAttributeTypeOp op, 601abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 602abfd1a8bSRiver Riddle writer.append(OpCode::GetAttributeType, op.result(), op.value()); 603abfd1a8bSRiver Riddle } 604abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetDefiningOpOp op, 605abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 606abfd1a8bSRiver Riddle writer.append(OpCode::GetDefiningOp, op.operation(), op.value()); 607abfd1a8bSRiver Riddle } 608abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { 609abfd1a8bSRiver Riddle uint32_t index = op.index(); 610abfd1a8bSRiver Riddle if (index < 4) 611abfd1a8bSRiver Riddle writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index)); 612abfd1a8bSRiver Riddle else 613abfd1a8bSRiver Riddle writer.append(OpCode::GetOperandN, index); 614abfd1a8bSRiver Riddle writer.append(op.operation(), op.value()); 615abfd1a8bSRiver Riddle } 616abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { 617abfd1a8bSRiver Riddle uint32_t index = op.index(); 618abfd1a8bSRiver Riddle if (index < 4) 619abfd1a8bSRiver Riddle writer.append(static_cast<OpCode>(OpCode::GetResult0 + index)); 620abfd1a8bSRiver Riddle else 621abfd1a8bSRiver Riddle writer.append(OpCode::GetResultN, index); 622abfd1a8bSRiver Riddle writer.append(op.operation(), op.value()); 623abfd1a8bSRiver Riddle } 624abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::GetValueTypeOp op, 625abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 626abfd1a8bSRiver Riddle writer.append(OpCode::GetValueType, op.result(), op.value()); 627abfd1a8bSRiver Riddle } 628abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::InferredTypeOp op, 629abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 630abfd1a8bSRiver Riddle // InferType maps to a null type as a marker for inferring a result type. 631abfd1a8bSRiver Riddle getMemIndex(op.type()) = getMemIndex(Type()); 632abfd1a8bSRiver Riddle } 633abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { 634abfd1a8bSRiver Riddle writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors()); 635abfd1a8bSRiver Riddle } 636abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { 637abfd1a8bSRiver Riddle ByteCodeField patternIndex = patterns.size(); 638abfd1a8bSRiver Riddle patterns.emplace_back(PDLByteCodePattern::create( 639abfd1a8bSRiver Riddle op, rewriterToAddr[op.rewriter().getLeafReference()])); 6408affe881SRiver Riddle writer.append(OpCode::RecordMatch, patternIndex, 6418affe881SRiver Riddle SuccessorRange(op.getOperation()), op.matchedOps(), 6428affe881SRiver Riddle op.inputs()); 643abfd1a8bSRiver Riddle } 644abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) { 645abfd1a8bSRiver Riddle writer.append(OpCode::ReplaceOp, op.operation(), op.replValues()); 646abfd1a8bSRiver Riddle } 647abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchAttributeOp op, 648abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 649abfd1a8bSRiver Riddle writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(), 650abfd1a8bSRiver Riddle op.getSuccessors()); 651abfd1a8bSRiver Riddle } 652abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchOperandCountOp op, 653abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 654abfd1a8bSRiver Riddle writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(), 655abfd1a8bSRiver Riddle op.getSuccessors()); 656abfd1a8bSRiver Riddle } 657abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchOperationNameOp op, 658abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 659abfd1a8bSRiver Riddle auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) { 660abfd1a8bSRiver Riddle return OperationName(attr.cast<StringAttr>().getValue(), ctx); 661abfd1a8bSRiver Riddle }); 662abfd1a8bSRiver Riddle writer.append(OpCode::SwitchOperationName, op.operation(), cases, 663abfd1a8bSRiver Riddle op.getSuccessors()); 664abfd1a8bSRiver Riddle } 665abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchResultCountOp op, 666abfd1a8bSRiver Riddle ByteCodeWriter &writer) { 667abfd1a8bSRiver Riddle writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(), 668abfd1a8bSRiver Riddle op.getSuccessors()); 669abfd1a8bSRiver Riddle } 670abfd1a8bSRiver Riddle void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) { 671abfd1a8bSRiver Riddle writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(), 672abfd1a8bSRiver Riddle op.getSuccessors()); 673abfd1a8bSRiver Riddle } 674abfd1a8bSRiver Riddle 675abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 676abfd1a8bSRiver Riddle // PDLByteCode 677abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 678abfd1a8bSRiver Riddle 679abfd1a8bSRiver Riddle PDLByteCode::PDLByteCode(ModuleOp module, 680abfd1a8bSRiver Riddle llvm::StringMap<PDLConstraintFunction> constraintFns, 681abfd1a8bSRiver Riddle llvm::StringMap<PDLCreateFunction> createFns, 682abfd1a8bSRiver Riddle llvm::StringMap<PDLRewriteFunction> rewriteFns) { 683abfd1a8bSRiver Riddle Generator generator(module.getContext(), uniquedData, matcherByteCode, 684abfd1a8bSRiver Riddle rewriterByteCode, patterns, maxValueMemoryIndex, 685abfd1a8bSRiver Riddle constraintFns, createFns, rewriteFns); 686abfd1a8bSRiver Riddle generator.generate(module); 687abfd1a8bSRiver Riddle 688abfd1a8bSRiver Riddle // Initialize the external functions. 689abfd1a8bSRiver Riddle for (auto &it : constraintFns) 690abfd1a8bSRiver Riddle constraintFunctions.push_back(std::move(it.second)); 691abfd1a8bSRiver Riddle for (auto &it : createFns) 692abfd1a8bSRiver Riddle createFunctions.push_back(std::move(it.second)); 693abfd1a8bSRiver Riddle for (auto &it : rewriteFns) 694abfd1a8bSRiver Riddle rewriteFunctions.push_back(std::move(it.second)); 695abfd1a8bSRiver Riddle } 696abfd1a8bSRiver Riddle 697abfd1a8bSRiver Riddle /// Initialize the given state such that it can be used to execute the current 698abfd1a8bSRiver Riddle /// bytecode. 699abfd1a8bSRiver Riddle void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { 700abfd1a8bSRiver Riddle state.memory.resize(maxValueMemoryIndex, nullptr); 701abfd1a8bSRiver Riddle state.currentPatternBenefits.reserve(patterns.size()); 702abfd1a8bSRiver Riddle for (const PDLByteCodePattern &pattern : patterns) 703abfd1a8bSRiver Riddle state.currentPatternBenefits.push_back(pattern.getBenefit()); 704abfd1a8bSRiver Riddle } 705abfd1a8bSRiver Riddle 706abfd1a8bSRiver Riddle //===----------------------------------------------------------------------===// 707abfd1a8bSRiver Riddle // ByteCode Execution 708abfd1a8bSRiver Riddle 709abfd1a8bSRiver Riddle namespace { 710abfd1a8bSRiver Riddle /// This class provides support for executing a bytecode stream. 711abfd1a8bSRiver Riddle class ByteCodeExecutor { 712abfd1a8bSRiver Riddle public: 713abfd1a8bSRiver Riddle ByteCodeExecutor(const ByteCodeField *curCodeIt, 714abfd1a8bSRiver Riddle MutableArrayRef<const void *> memory, 715abfd1a8bSRiver Riddle ArrayRef<const void *> uniquedMemory, 716abfd1a8bSRiver Riddle ArrayRef<ByteCodeField> code, 717abfd1a8bSRiver Riddle ArrayRef<PatternBenefit> currentPatternBenefits, 718abfd1a8bSRiver Riddle ArrayRef<PDLByteCodePattern> patterns, 719abfd1a8bSRiver Riddle ArrayRef<PDLConstraintFunction> constraintFunctions, 720abfd1a8bSRiver Riddle ArrayRef<PDLCreateFunction> createFunctions, 721abfd1a8bSRiver Riddle ArrayRef<PDLRewriteFunction> rewriteFunctions) 722abfd1a8bSRiver Riddle : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory), 723abfd1a8bSRiver Riddle code(code), currentPatternBenefits(currentPatternBenefits), 724abfd1a8bSRiver Riddle patterns(patterns), constraintFunctions(constraintFunctions), 725abfd1a8bSRiver Riddle createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {} 726abfd1a8bSRiver Riddle 727abfd1a8bSRiver Riddle /// Start executing the code at the current bytecode index. `matches` is an 728abfd1a8bSRiver Riddle /// optional field provided when this function is executed in a matching 729abfd1a8bSRiver Riddle /// context. 730abfd1a8bSRiver Riddle void execute(PatternRewriter &rewriter, 731abfd1a8bSRiver Riddle SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, 732abfd1a8bSRiver Riddle Optional<Location> mainRewriteLoc = {}); 733abfd1a8bSRiver Riddle 734abfd1a8bSRiver Riddle private: 735*154cabe7SRiver Riddle /// Internal implementation of executing each of the bytecode commands. 736*154cabe7SRiver Riddle void executeApplyConstraint(PatternRewriter &rewriter); 737*154cabe7SRiver Riddle void executeApplyRewrite(PatternRewriter &rewriter); 738*154cabe7SRiver Riddle void executeAreEqual(); 739*154cabe7SRiver Riddle void executeBranch(); 740*154cabe7SRiver Riddle void executeCheckOperandCount(); 741*154cabe7SRiver Riddle void executeCheckOperationName(); 742*154cabe7SRiver Riddle void executeCheckResultCount(); 743*154cabe7SRiver Riddle void executeCreateNative(PatternRewriter &rewriter); 744*154cabe7SRiver Riddle void executeCreateOperation(PatternRewriter &rewriter, 745*154cabe7SRiver Riddle Location mainRewriteLoc); 746*154cabe7SRiver Riddle void executeEraseOp(PatternRewriter &rewriter); 747*154cabe7SRiver Riddle void executeGetAttribute(); 748*154cabe7SRiver Riddle void executeGetAttributeType(); 749*154cabe7SRiver Riddle void executeGetDefiningOp(); 750*154cabe7SRiver Riddle void executeGetOperand(unsigned index); 751*154cabe7SRiver Riddle void executeGetResult(unsigned index); 752*154cabe7SRiver Riddle void executeGetValueType(); 753*154cabe7SRiver Riddle void executeIsNotNull(); 754*154cabe7SRiver Riddle void executeRecordMatch(PatternRewriter &rewriter, 755*154cabe7SRiver Riddle SmallVectorImpl<PDLByteCode::MatchResult> &matches); 756*154cabe7SRiver Riddle void executeReplaceOp(PatternRewriter &rewriter); 757*154cabe7SRiver Riddle void executeSwitchAttribute(); 758*154cabe7SRiver Riddle void executeSwitchOperandCount(); 759*154cabe7SRiver Riddle void executeSwitchOperationName(); 760*154cabe7SRiver Riddle void executeSwitchResultCount(); 761*154cabe7SRiver Riddle void executeSwitchType(); 762*154cabe7SRiver Riddle 763abfd1a8bSRiver Riddle /// Read a value from the bytecode buffer, optionally skipping a certain 764abfd1a8bSRiver Riddle /// number of prefix values. These methods always update the buffer to point 765abfd1a8bSRiver Riddle /// to the next field after the read data. 766abfd1a8bSRiver Riddle template <typename T = ByteCodeField> 767abfd1a8bSRiver Riddle T read(size_t skipN = 0) { 768abfd1a8bSRiver Riddle curCodeIt += skipN; 769abfd1a8bSRiver Riddle return readImpl<T>(); 770abfd1a8bSRiver Riddle } 771abfd1a8bSRiver Riddle ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } 772abfd1a8bSRiver Riddle 773abfd1a8bSRiver Riddle /// Read a list of values from the bytecode buffer. 774abfd1a8bSRiver Riddle template <typename ValueT, typename T> 775abfd1a8bSRiver Riddle void readList(SmallVectorImpl<T> &list) { 776abfd1a8bSRiver Riddle list.clear(); 777abfd1a8bSRiver Riddle for (unsigned i = 0, e = read(); i != e; ++i) 778abfd1a8bSRiver Riddle list.push_back(read<ValueT>()); 779abfd1a8bSRiver Riddle } 780abfd1a8bSRiver Riddle 781abfd1a8bSRiver Riddle /// Jump to a specific successor based on a predicate value. 782abfd1a8bSRiver Riddle void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } 783abfd1a8bSRiver Riddle /// Jump to a specific successor based on a destination index. 784abfd1a8bSRiver Riddle void selectJump(size_t destIndex) { 785abfd1a8bSRiver Riddle curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; 786abfd1a8bSRiver Riddle } 787abfd1a8bSRiver Riddle 788abfd1a8bSRiver Riddle /// Handle a switch operation with the provided value and cases. 789abfd1a8bSRiver Riddle template <typename T, typename RangeT> 790abfd1a8bSRiver Riddle void handleSwitch(const T &value, RangeT &&cases) { 791abfd1a8bSRiver Riddle LLVM_DEBUG({ 792abfd1a8bSRiver Riddle llvm::dbgs() << " * Value: " << value << "\n" 793abfd1a8bSRiver Riddle << " * Cases: "; 794abfd1a8bSRiver Riddle llvm::interleaveComma(cases, llvm::dbgs()); 795*154cabe7SRiver Riddle llvm::dbgs() << "\n"; 796abfd1a8bSRiver Riddle }); 797abfd1a8bSRiver Riddle 798abfd1a8bSRiver Riddle // Check to see if the attribute value is within the case list. Jump to 799abfd1a8bSRiver Riddle // the correct successor index based on the result. 800f80b6304SRiver Riddle for (auto it = cases.begin(), e = cases.end(); it != e; ++it) 801f80b6304SRiver Riddle if (*it == value) 802f80b6304SRiver Riddle return selectJump(size_t((it - cases.begin()) + 1)); 803f80b6304SRiver Riddle selectJump(size_t(0)); 804abfd1a8bSRiver Riddle } 805abfd1a8bSRiver Riddle 806abfd1a8bSRiver Riddle /// Internal implementation of reading various data types from the bytecode 807abfd1a8bSRiver Riddle /// stream. 808abfd1a8bSRiver Riddle template <typename T> 809abfd1a8bSRiver Riddle const void *readFromMemory() { 810abfd1a8bSRiver Riddle size_t index = *curCodeIt++; 811abfd1a8bSRiver Riddle 812abfd1a8bSRiver Riddle // If this type is an SSA value, it can only be stored in non-const memory. 813abfd1a8bSRiver Riddle if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size()) 814abfd1a8bSRiver Riddle return memory[index]; 815abfd1a8bSRiver Riddle 816abfd1a8bSRiver Riddle // Otherwise, if this index is not inbounds it is uniqued. 817abfd1a8bSRiver Riddle return uniquedMemory[index - memory.size()]; 818abfd1a8bSRiver Riddle } 819abfd1a8bSRiver Riddle template <typename T> 820abfd1a8bSRiver Riddle std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { 821abfd1a8bSRiver Riddle return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); 822abfd1a8bSRiver Riddle } 823abfd1a8bSRiver Riddle template <typename T> 824abfd1a8bSRiver Riddle std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, 825abfd1a8bSRiver Riddle T> 826abfd1a8bSRiver Riddle readImpl() { 827abfd1a8bSRiver Riddle return T(T::getFromOpaquePointer(readFromMemory<T>())); 828abfd1a8bSRiver Riddle } 829abfd1a8bSRiver Riddle template <typename T> 830abfd1a8bSRiver Riddle std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { 831abfd1a8bSRiver Riddle switch (static_cast<PDLValueKind>(read())) { 832abfd1a8bSRiver Riddle case PDLValueKind::Attribute: 833abfd1a8bSRiver Riddle return read<Attribute>(); 834abfd1a8bSRiver Riddle case PDLValueKind::Operation: 835abfd1a8bSRiver Riddle return read<Operation *>(); 836abfd1a8bSRiver Riddle case PDLValueKind::Type: 837abfd1a8bSRiver Riddle return read<Type>(); 838abfd1a8bSRiver Riddle case PDLValueKind::Value: 839abfd1a8bSRiver Riddle return read<Value>(); 840abfd1a8bSRiver Riddle } 8417dadcd02SMehdi Amini llvm_unreachable("unhandled PDLValueKind"); 842abfd1a8bSRiver Riddle } 843abfd1a8bSRiver Riddle template <typename T> 844abfd1a8bSRiver Riddle std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { 845abfd1a8bSRiver Riddle static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, 846abfd1a8bSRiver Riddle "unexpected ByteCode address size"); 847abfd1a8bSRiver Riddle ByteCodeAddr result; 848abfd1a8bSRiver Riddle std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); 849abfd1a8bSRiver Riddle curCodeIt += 2; 850abfd1a8bSRiver Riddle return result; 851abfd1a8bSRiver Riddle } 852abfd1a8bSRiver Riddle template <typename T> 853abfd1a8bSRiver Riddle std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { 854abfd1a8bSRiver Riddle return *curCodeIt++; 855abfd1a8bSRiver Riddle } 856abfd1a8bSRiver Riddle 857abfd1a8bSRiver Riddle /// The underlying bytecode buffer. 858abfd1a8bSRiver Riddle const ByteCodeField *curCodeIt; 859abfd1a8bSRiver Riddle 860abfd1a8bSRiver Riddle /// The current execution memory. 861abfd1a8bSRiver Riddle MutableArrayRef<const void *> memory; 862abfd1a8bSRiver Riddle 863abfd1a8bSRiver Riddle /// References to ByteCode data necessary for execution. 864abfd1a8bSRiver Riddle ArrayRef<const void *> uniquedMemory; 865abfd1a8bSRiver Riddle ArrayRef<ByteCodeField> code; 866abfd1a8bSRiver Riddle ArrayRef<PatternBenefit> currentPatternBenefits; 867abfd1a8bSRiver Riddle ArrayRef<PDLByteCodePattern> patterns; 868abfd1a8bSRiver Riddle ArrayRef<PDLConstraintFunction> constraintFunctions; 869abfd1a8bSRiver Riddle ArrayRef<PDLCreateFunction> createFunctions; 870abfd1a8bSRiver Riddle ArrayRef<PDLRewriteFunction> rewriteFunctions; 871abfd1a8bSRiver Riddle }; 872abfd1a8bSRiver Riddle } // end anonymous namespace 873abfd1a8bSRiver Riddle 874*154cabe7SRiver Riddle void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { 875abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); 876abfd1a8bSRiver Riddle const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; 877abfd1a8bSRiver Riddle ArrayAttr constParams = read<ArrayAttr>(); 878abfd1a8bSRiver Riddle SmallVector<PDLValue, 16> args; 879abfd1a8bSRiver Riddle readList<PDLValue>(args); 880*154cabe7SRiver Riddle 881abfd1a8bSRiver Riddle LLVM_DEBUG({ 882abfd1a8bSRiver Riddle llvm::dbgs() << " * Arguments: "; 883abfd1a8bSRiver Riddle llvm::interleaveComma(args, llvm::dbgs()); 884*154cabe7SRiver Riddle llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 885abfd1a8bSRiver Riddle }); 886abfd1a8bSRiver Riddle 887abfd1a8bSRiver Riddle // Invoke the constraint and jump to the proper destination. 888abfd1a8bSRiver Riddle selectJump(succeeded(constraintFn(args, constParams, rewriter))); 889abfd1a8bSRiver Riddle } 890*154cabe7SRiver Riddle 891*154cabe7SRiver Riddle void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { 892abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); 893abfd1a8bSRiver Riddle const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; 894abfd1a8bSRiver Riddle ArrayAttr constParams = read<ArrayAttr>(); 895abfd1a8bSRiver Riddle Operation *root = read<Operation *>(); 896abfd1a8bSRiver Riddle SmallVector<PDLValue, 16> args; 897abfd1a8bSRiver Riddle readList<PDLValue>(args); 898abfd1a8bSRiver Riddle 899abfd1a8bSRiver Riddle LLVM_DEBUG({ 900*154cabe7SRiver Riddle llvm::dbgs() << " * Root: " << *root << "\n * Arguments: "; 901abfd1a8bSRiver Riddle llvm::interleaveComma(args, llvm::dbgs()); 902*154cabe7SRiver Riddle llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 903abfd1a8bSRiver Riddle }); 904*154cabe7SRiver Riddle 905*154cabe7SRiver Riddle // Invoke the native rewrite function. 906abfd1a8bSRiver Riddle rewriteFn(root, args, constParams, rewriter); 907abfd1a8bSRiver Riddle } 908*154cabe7SRiver Riddle 909*154cabe7SRiver Riddle void ByteCodeExecutor::executeAreEqual() { 910abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); 911abfd1a8bSRiver Riddle const void *lhs = read<const void *>(); 912abfd1a8bSRiver Riddle const void *rhs = read<const void *>(); 913abfd1a8bSRiver Riddle 914*154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); 915abfd1a8bSRiver Riddle selectJump(lhs == rhs); 916abfd1a8bSRiver Riddle } 917*154cabe7SRiver Riddle 918*154cabe7SRiver Riddle void ByteCodeExecutor::executeBranch() { 919*154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); 920abfd1a8bSRiver Riddle curCodeIt = &code[read<ByteCodeAddr>()]; 921abfd1a8bSRiver Riddle } 922*154cabe7SRiver Riddle 923*154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckOperandCount() { 924abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); 925abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 926abfd1a8bSRiver Riddle uint32_t expectedCount = read<uint32_t>(); 927abfd1a8bSRiver Riddle 928abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" 929*154cabe7SRiver Riddle << " * Expected: " << expectedCount << "\n"); 930abfd1a8bSRiver Riddle selectJump(op->getNumOperands() == expectedCount); 931abfd1a8bSRiver Riddle } 932*154cabe7SRiver Riddle 933*154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckOperationName() { 934abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); 935abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 936abfd1a8bSRiver Riddle OperationName expectedName = read<OperationName>(); 937abfd1a8bSRiver Riddle 938*154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" 939*154cabe7SRiver Riddle << " * Expected: \"" << expectedName << "\"\n"); 940abfd1a8bSRiver Riddle selectJump(op->getName() == expectedName); 941abfd1a8bSRiver Riddle } 942*154cabe7SRiver Riddle 943*154cabe7SRiver Riddle void ByteCodeExecutor::executeCheckResultCount() { 944abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); 945abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 946abfd1a8bSRiver Riddle uint32_t expectedCount = read<uint32_t>(); 947abfd1a8bSRiver Riddle 948abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" 949*154cabe7SRiver Riddle << " * Expected: " << expectedCount << "\n"); 950abfd1a8bSRiver Riddle selectJump(op->getNumResults() == expectedCount); 951abfd1a8bSRiver Riddle } 952*154cabe7SRiver Riddle 953*154cabe7SRiver Riddle void ByteCodeExecutor::executeCreateNative(PatternRewriter &rewriter) { 954abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n"); 955abfd1a8bSRiver Riddle const PDLCreateFunction &createFn = createFunctions[read()]; 956abfd1a8bSRiver Riddle ByteCodeField resultIndex = read(); 957abfd1a8bSRiver Riddle ArrayAttr constParams = read<ArrayAttr>(); 958abfd1a8bSRiver Riddle SmallVector<PDLValue, 16> args; 959abfd1a8bSRiver Riddle readList<PDLValue>(args); 960abfd1a8bSRiver Riddle 961abfd1a8bSRiver Riddle LLVM_DEBUG({ 962abfd1a8bSRiver Riddle llvm::dbgs() << " * Arguments: "; 963abfd1a8bSRiver Riddle llvm::interleaveComma(args, llvm::dbgs()); 964abfd1a8bSRiver Riddle llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; 965abfd1a8bSRiver Riddle }); 966abfd1a8bSRiver Riddle 967abfd1a8bSRiver Riddle PDLValue result = createFn(args, constParams, rewriter); 968abfd1a8bSRiver Riddle memory[resultIndex] = result.getAsOpaquePointer(); 969abfd1a8bSRiver Riddle 970*154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); 971abfd1a8bSRiver Riddle } 972*154cabe7SRiver Riddle 973*154cabe7SRiver Riddle void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, 974*154cabe7SRiver Riddle Location mainRewriteLoc) { 975abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); 976abfd1a8bSRiver Riddle 977abfd1a8bSRiver Riddle unsigned memIndex = read(); 978*154cabe7SRiver Riddle OperationState state(mainRewriteLoc, read<OperationName>()); 979abfd1a8bSRiver Riddle readList<Value>(state.operands); 980abfd1a8bSRiver Riddle for (unsigned i = 0, e = read(); i != e; ++i) { 981abfd1a8bSRiver Riddle Identifier name = read<Identifier>(); 982abfd1a8bSRiver Riddle if (Attribute attr = read<Attribute>()) 983abfd1a8bSRiver Riddle state.addAttribute(name, attr); 984abfd1a8bSRiver Riddle } 985abfd1a8bSRiver Riddle 986abfd1a8bSRiver Riddle bool hasInferredTypes = false; 987abfd1a8bSRiver Riddle for (unsigned i = 0, e = read(); i != e; ++i) { 988abfd1a8bSRiver Riddle Type resultType = read<Type>(); 989abfd1a8bSRiver Riddle hasInferredTypes |= !resultType; 990abfd1a8bSRiver Riddle state.types.push_back(resultType); 991abfd1a8bSRiver Riddle } 992abfd1a8bSRiver Riddle 993abfd1a8bSRiver Riddle // Handle the case where the operation has inferred types. 994abfd1a8bSRiver Riddle if (hasInferredTypes) { 995abfd1a8bSRiver Riddle InferTypeOpInterface::Concept *concept = 996*154cabe7SRiver Riddle state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>(); 997abfd1a8bSRiver Riddle 998abfd1a8bSRiver Riddle // TODO: Handle failure. 999abfd1a8bSRiver Riddle SmallVector<Type, 2> inferredTypes; 1000abfd1a8bSRiver Riddle if (failed(concept->inferReturnTypes( 1001abfd1a8bSRiver Riddle state.getContext(), state.location, state.operands, 1002*154cabe7SRiver Riddle state.attributes.getDictionary(state.getContext()), state.regions, 1003*154cabe7SRiver Riddle inferredTypes))) 1004abfd1a8bSRiver Riddle return; 1005abfd1a8bSRiver Riddle 1006abfd1a8bSRiver Riddle for (unsigned i = 0, e = state.types.size(); i != e; ++i) 1007abfd1a8bSRiver Riddle if (!state.types[i]) 1008abfd1a8bSRiver Riddle state.types[i] = inferredTypes[i]; 1009abfd1a8bSRiver Riddle } 1010abfd1a8bSRiver Riddle Operation *resultOp = rewriter.createOperation(state); 1011abfd1a8bSRiver Riddle memory[memIndex] = resultOp; 1012abfd1a8bSRiver Riddle 1013abfd1a8bSRiver Riddle LLVM_DEBUG({ 1014abfd1a8bSRiver Riddle llvm::dbgs() << " * Attributes: " 1015abfd1a8bSRiver Riddle << state.attributes.getDictionary(state.getContext()) 1016abfd1a8bSRiver Riddle << "\n * Operands: "; 1017abfd1a8bSRiver Riddle llvm::interleaveComma(state.operands, llvm::dbgs()); 1018abfd1a8bSRiver Riddle llvm::dbgs() << "\n * Result Types: "; 1019abfd1a8bSRiver Riddle llvm::interleaveComma(state.types, llvm::dbgs()); 1020*154cabe7SRiver Riddle llvm::dbgs() << "\n * Result: " << *resultOp << "\n"; 1021abfd1a8bSRiver Riddle }); 1022abfd1a8bSRiver Riddle } 1023*154cabe7SRiver Riddle 1024*154cabe7SRiver Riddle void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { 1025abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); 1026abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1027abfd1a8bSRiver Riddle 1028*154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1029abfd1a8bSRiver Riddle rewriter.eraseOp(op); 1030abfd1a8bSRiver Riddle } 1031*154cabe7SRiver Riddle 1032*154cabe7SRiver Riddle void ByteCodeExecutor::executeGetAttribute() { 1033abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); 1034abfd1a8bSRiver Riddle unsigned memIndex = read(); 1035abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1036abfd1a8bSRiver Riddle Identifier attrName = read<Identifier>(); 1037abfd1a8bSRiver Riddle Attribute attr = op->getAttr(attrName); 1038abfd1a8bSRiver Riddle 1039abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1040abfd1a8bSRiver Riddle << " * Attribute: " << attrName << "\n" 1041*154cabe7SRiver Riddle << " * Result: " << attr << "\n"); 1042abfd1a8bSRiver Riddle memory[memIndex] = attr.getAsOpaquePointer(); 1043abfd1a8bSRiver Riddle } 1044*154cabe7SRiver Riddle 1045*154cabe7SRiver Riddle void ByteCodeExecutor::executeGetAttributeType() { 1046abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); 1047abfd1a8bSRiver Riddle unsigned memIndex = read(); 1048abfd1a8bSRiver Riddle Attribute attr = read<Attribute>(); 1049*154cabe7SRiver Riddle Type type = attr ? attr.getType() : Type(); 1050abfd1a8bSRiver Riddle 1051abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" 1052*154cabe7SRiver Riddle << " * Result: " << type << "\n"); 1053*154cabe7SRiver Riddle memory[memIndex] = type.getAsOpaquePointer(); 1054abfd1a8bSRiver Riddle } 1055*154cabe7SRiver Riddle 1056*154cabe7SRiver Riddle void ByteCodeExecutor::executeGetDefiningOp() { 1057abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); 1058abfd1a8bSRiver Riddle unsigned memIndex = read(); 1059abfd1a8bSRiver Riddle Value value = read<Value>(); 1060abfd1a8bSRiver Riddle Operation *op = value ? value.getDefiningOp() : nullptr; 1061abfd1a8bSRiver Riddle 1062abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1063*154cabe7SRiver Riddle << " * Result: " << *op << "\n"); 1064abfd1a8bSRiver Riddle memory[memIndex] = op; 1065abfd1a8bSRiver Riddle } 1066*154cabe7SRiver Riddle 1067*154cabe7SRiver Riddle void ByteCodeExecutor::executeGetOperand(unsigned index) { 1068abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1069abfd1a8bSRiver Riddle unsigned memIndex = read(); 1070abfd1a8bSRiver Riddle Value operand = 1071abfd1a8bSRiver Riddle index < op->getNumOperands() ? op->getOperand(index) : Value(); 1072abfd1a8bSRiver Riddle 1073abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1074abfd1a8bSRiver Riddle << " * Index: " << index << "\n" 1075*154cabe7SRiver Riddle << " * Result: " << operand << "\n"); 1076abfd1a8bSRiver Riddle memory[memIndex] = operand.getAsOpaquePointer(); 1077abfd1a8bSRiver Riddle } 1078*154cabe7SRiver Riddle 1079*154cabe7SRiver Riddle void ByteCodeExecutor::executeGetResult(unsigned index) { 1080abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1081abfd1a8bSRiver Riddle unsigned memIndex = read(); 1082abfd1a8bSRiver Riddle OpResult result = 1083abfd1a8bSRiver Riddle index < op->getNumResults() ? op->getResult(index) : OpResult(); 1084abfd1a8bSRiver Riddle 1085abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" 1086abfd1a8bSRiver Riddle << " * Index: " << index << "\n" 1087*154cabe7SRiver Riddle << " * Result: " << result << "\n"); 1088abfd1a8bSRiver Riddle memory[memIndex] = result.getAsOpaquePointer(); 1089abfd1a8bSRiver Riddle } 1090*154cabe7SRiver Riddle 1091*154cabe7SRiver Riddle void ByteCodeExecutor::executeGetValueType() { 1092abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); 1093abfd1a8bSRiver Riddle unsigned memIndex = read(); 1094abfd1a8bSRiver Riddle Value value = read<Value>(); 1095*154cabe7SRiver Riddle Type type = value ? value.getType() : Type(); 1096abfd1a8bSRiver Riddle 1097abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" 1098*154cabe7SRiver Riddle << " * Result: " << type << "\n"); 1099*154cabe7SRiver Riddle memory[memIndex] = type.getAsOpaquePointer(); 1100abfd1a8bSRiver Riddle } 1101*154cabe7SRiver Riddle 1102*154cabe7SRiver Riddle void ByteCodeExecutor::executeIsNotNull() { 1103abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); 1104abfd1a8bSRiver Riddle const void *value = read<const void *>(); 1105abfd1a8bSRiver Riddle 1106*154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); 1107abfd1a8bSRiver Riddle selectJump(value != nullptr); 1108abfd1a8bSRiver Riddle } 1109*154cabe7SRiver Riddle 1110*154cabe7SRiver Riddle void ByteCodeExecutor::executeRecordMatch( 1111*154cabe7SRiver Riddle PatternRewriter &rewriter, 1112*154cabe7SRiver Riddle SmallVectorImpl<PDLByteCode::MatchResult> &matches) { 1113abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); 1114abfd1a8bSRiver Riddle unsigned patternIndex = read(); 1115abfd1a8bSRiver Riddle PatternBenefit benefit = currentPatternBenefits[patternIndex]; 1116abfd1a8bSRiver Riddle const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; 1117abfd1a8bSRiver Riddle 1118abfd1a8bSRiver Riddle // If the benefit of the pattern is impossible, skip the processing of the 1119abfd1a8bSRiver Riddle // rest of the pattern. 1120abfd1a8bSRiver Riddle if (benefit.isImpossibleToMatch()) { 1121*154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n"); 1122abfd1a8bSRiver Riddle curCodeIt = dest; 1123*154cabe7SRiver Riddle return; 1124abfd1a8bSRiver Riddle } 1125abfd1a8bSRiver Riddle 1126abfd1a8bSRiver Riddle // Create a fused location containing the locations of each of the 1127abfd1a8bSRiver Riddle // operations used in the match. This will be used as the location for 1128abfd1a8bSRiver Riddle // created operations during the rewrite that don't already have an 1129abfd1a8bSRiver Riddle // explicit location set. 1130abfd1a8bSRiver Riddle unsigned numMatchLocs = read(); 1131abfd1a8bSRiver Riddle SmallVector<Location, 4> matchLocs; 1132abfd1a8bSRiver Riddle matchLocs.reserve(numMatchLocs); 1133abfd1a8bSRiver Riddle for (unsigned i = 0; i != numMatchLocs; ++i) 1134abfd1a8bSRiver Riddle matchLocs.push_back(read<Operation *>()->getLoc()); 1135abfd1a8bSRiver Riddle Location matchLoc = rewriter.getFusedLoc(matchLocs); 1136abfd1a8bSRiver Riddle 1137abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" 1138*154cabe7SRiver Riddle << " * Location: " << matchLoc << "\n"); 1139*154cabe7SRiver Riddle matches.emplace_back(matchLoc, patterns[patternIndex], benefit); 1140*154cabe7SRiver Riddle readList<const void *>(matches.back().values); 1141abfd1a8bSRiver Riddle curCodeIt = dest; 1142abfd1a8bSRiver Riddle } 1143*154cabe7SRiver Riddle 1144*154cabe7SRiver Riddle void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { 1145abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); 1146abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1147abfd1a8bSRiver Riddle SmallVector<Value, 16> args; 1148abfd1a8bSRiver Riddle readList<Value>(args); 1149abfd1a8bSRiver Riddle 1150abfd1a8bSRiver Riddle LLVM_DEBUG({ 1151abfd1a8bSRiver Riddle llvm::dbgs() << " * Operation: " << *op << "\n" 1152abfd1a8bSRiver Riddle << " * Values: "; 1153abfd1a8bSRiver Riddle llvm::interleaveComma(args, llvm::dbgs()); 1154*154cabe7SRiver Riddle llvm::dbgs() << "\n"; 1155abfd1a8bSRiver Riddle }); 1156abfd1a8bSRiver Riddle rewriter.replaceOp(op, args); 1157abfd1a8bSRiver Riddle } 1158*154cabe7SRiver Riddle 1159*154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchAttribute() { 1160abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); 1161abfd1a8bSRiver Riddle Attribute value = read<Attribute>(); 1162abfd1a8bSRiver Riddle ArrayAttr cases = read<ArrayAttr>(); 1163abfd1a8bSRiver Riddle handleSwitch(value, cases); 1164abfd1a8bSRiver Riddle } 1165*154cabe7SRiver Riddle 1166*154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchOperandCount() { 1167abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); 1168abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1169abfd1a8bSRiver Riddle auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1170abfd1a8bSRiver Riddle 1171abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1172abfd1a8bSRiver Riddle handleSwitch(op->getNumOperands(), cases); 1173abfd1a8bSRiver Riddle } 1174*154cabe7SRiver Riddle 1175*154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchOperationName() { 1176abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); 1177abfd1a8bSRiver Riddle OperationName value = read<Operation *>()->getName(); 1178abfd1a8bSRiver Riddle size_t caseCount = read(); 1179abfd1a8bSRiver Riddle 1180abfd1a8bSRiver Riddle // The operation names are stored in-line, so to print them out for 1181abfd1a8bSRiver Riddle // debugging purposes we need to read the array before executing the 1182abfd1a8bSRiver Riddle // switch so that we can display all of the possible values. 1183abfd1a8bSRiver Riddle LLVM_DEBUG({ 1184abfd1a8bSRiver Riddle const ByteCodeField *prevCodeIt = curCodeIt; 1185abfd1a8bSRiver Riddle llvm::dbgs() << " * Value: " << value << "\n" 1186abfd1a8bSRiver Riddle << " * Cases: "; 1187abfd1a8bSRiver Riddle llvm::interleaveComma( 1188abfd1a8bSRiver Riddle llvm::map_range(llvm::seq<size_t>(0, caseCount), 1189*154cabe7SRiver Riddle [&](size_t) { return read<OperationName>(); }), 1190abfd1a8bSRiver Riddle llvm::dbgs()); 1191*154cabe7SRiver Riddle llvm::dbgs() << "\n"; 1192abfd1a8bSRiver Riddle curCodeIt = prevCodeIt; 1193abfd1a8bSRiver Riddle }); 1194abfd1a8bSRiver Riddle 1195abfd1a8bSRiver Riddle // Try to find the switch value within any of the cases. 1196abfd1a8bSRiver Riddle for (size_t i = 0; i != caseCount; ++i) { 1197abfd1a8bSRiver Riddle if (read<OperationName>() == value) { 1198abfd1a8bSRiver Riddle curCodeIt += (caseCount - i - 1); 1199*154cabe7SRiver Riddle return selectJump(i + 1); 1200abfd1a8bSRiver Riddle } 1201abfd1a8bSRiver Riddle } 1202*154cabe7SRiver Riddle selectJump(size_t(0)); 1203abfd1a8bSRiver Riddle } 1204*154cabe7SRiver Riddle 1205*154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchResultCount() { 1206abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); 1207abfd1a8bSRiver Riddle Operation *op = read<Operation *>(); 1208abfd1a8bSRiver Riddle auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); 1209abfd1a8bSRiver Riddle 1210abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); 1211abfd1a8bSRiver Riddle handleSwitch(op->getNumResults(), cases); 1212abfd1a8bSRiver Riddle } 1213*154cabe7SRiver Riddle 1214*154cabe7SRiver Riddle void ByteCodeExecutor::executeSwitchType() { 1215abfd1a8bSRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); 1216abfd1a8bSRiver Riddle Type value = read<Type>(); 1217abfd1a8bSRiver Riddle auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); 1218abfd1a8bSRiver Riddle handleSwitch(value, cases); 1219*154cabe7SRiver Riddle } 1220*154cabe7SRiver Riddle 1221*154cabe7SRiver Riddle void ByteCodeExecutor::execute( 1222*154cabe7SRiver Riddle PatternRewriter &rewriter, 1223*154cabe7SRiver Riddle SmallVectorImpl<PDLByteCode::MatchResult> *matches, 1224*154cabe7SRiver Riddle Optional<Location> mainRewriteLoc) { 1225*154cabe7SRiver Riddle while (true) { 1226*154cabe7SRiver Riddle OpCode opCode = static_cast<OpCode>(read()); 1227*154cabe7SRiver Riddle switch (opCode) { 1228*154cabe7SRiver Riddle case ApplyConstraint: 1229*154cabe7SRiver Riddle executeApplyConstraint(rewriter); 1230*154cabe7SRiver Riddle break; 1231*154cabe7SRiver Riddle case ApplyRewrite: 1232*154cabe7SRiver Riddle executeApplyRewrite(rewriter); 1233*154cabe7SRiver Riddle break; 1234*154cabe7SRiver Riddle case AreEqual: 1235*154cabe7SRiver Riddle executeAreEqual(); 1236*154cabe7SRiver Riddle break; 1237*154cabe7SRiver Riddle case Branch: 1238*154cabe7SRiver Riddle executeBranch(); 1239*154cabe7SRiver Riddle break; 1240*154cabe7SRiver Riddle case CheckOperandCount: 1241*154cabe7SRiver Riddle executeCheckOperandCount(); 1242*154cabe7SRiver Riddle break; 1243*154cabe7SRiver Riddle case CheckOperationName: 1244*154cabe7SRiver Riddle executeCheckOperationName(); 1245*154cabe7SRiver Riddle break; 1246*154cabe7SRiver Riddle case CheckResultCount: 1247*154cabe7SRiver Riddle executeCheckResultCount(); 1248*154cabe7SRiver Riddle break; 1249*154cabe7SRiver Riddle case CreateNative: 1250*154cabe7SRiver Riddle executeCreateNative(rewriter); 1251*154cabe7SRiver Riddle break; 1252*154cabe7SRiver Riddle case CreateOperation: 1253*154cabe7SRiver Riddle executeCreateOperation(rewriter, *mainRewriteLoc); 1254*154cabe7SRiver Riddle break; 1255*154cabe7SRiver Riddle case EraseOp: 1256*154cabe7SRiver Riddle executeEraseOp(rewriter); 1257*154cabe7SRiver Riddle break; 1258*154cabe7SRiver Riddle case Finalize: 1259*154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); 1260*154cabe7SRiver Riddle return; 1261*154cabe7SRiver Riddle case GetAttribute: 1262*154cabe7SRiver Riddle executeGetAttribute(); 1263*154cabe7SRiver Riddle break; 1264*154cabe7SRiver Riddle case GetAttributeType: 1265*154cabe7SRiver Riddle executeGetAttributeType(); 1266*154cabe7SRiver Riddle break; 1267*154cabe7SRiver Riddle case GetDefiningOp: 1268*154cabe7SRiver Riddle executeGetDefiningOp(); 1269*154cabe7SRiver Riddle break; 1270*154cabe7SRiver Riddle case GetOperand0: 1271*154cabe7SRiver Riddle case GetOperand1: 1272*154cabe7SRiver Riddle case GetOperand2: 1273*154cabe7SRiver Riddle case GetOperand3: { 1274*154cabe7SRiver Riddle unsigned index = opCode - GetOperand0; 1275*154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); 1276*154cabe7SRiver Riddle executeGetOperand(opCode - GetOperand0); 1277abfd1a8bSRiver Riddle break; 1278abfd1a8bSRiver Riddle } 1279*154cabe7SRiver Riddle case GetOperandN: 1280*154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); 1281*154cabe7SRiver Riddle executeGetOperand(read<uint32_t>()); 1282*154cabe7SRiver Riddle break; 1283*154cabe7SRiver Riddle case GetResult0: 1284*154cabe7SRiver Riddle case GetResult1: 1285*154cabe7SRiver Riddle case GetResult2: 1286*154cabe7SRiver Riddle case GetResult3: { 1287*154cabe7SRiver Riddle unsigned index = opCode - GetResult0; 1288*154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); 1289*154cabe7SRiver Riddle executeGetResult(opCode - GetResult0); 1290*154cabe7SRiver Riddle break; 1291abfd1a8bSRiver Riddle } 1292*154cabe7SRiver Riddle case GetResultN: 1293*154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); 1294*154cabe7SRiver Riddle executeGetResult(read<uint32_t>()); 1295*154cabe7SRiver Riddle break; 1296*154cabe7SRiver Riddle case GetValueType: 1297*154cabe7SRiver Riddle executeGetValueType(); 1298*154cabe7SRiver Riddle break; 1299*154cabe7SRiver Riddle case IsNotNull: 1300*154cabe7SRiver Riddle executeIsNotNull(); 1301*154cabe7SRiver Riddle break; 1302*154cabe7SRiver Riddle case RecordMatch: 1303*154cabe7SRiver Riddle assert(matches && 1304*154cabe7SRiver Riddle "expected matches to be provided when executing the matcher"); 1305*154cabe7SRiver Riddle executeRecordMatch(rewriter, *matches); 1306*154cabe7SRiver Riddle break; 1307*154cabe7SRiver Riddle case ReplaceOp: 1308*154cabe7SRiver Riddle executeReplaceOp(rewriter); 1309*154cabe7SRiver Riddle break; 1310*154cabe7SRiver Riddle case SwitchAttribute: 1311*154cabe7SRiver Riddle executeSwitchAttribute(); 1312*154cabe7SRiver Riddle break; 1313*154cabe7SRiver Riddle case SwitchOperandCount: 1314*154cabe7SRiver Riddle executeSwitchOperandCount(); 1315*154cabe7SRiver Riddle break; 1316*154cabe7SRiver Riddle case SwitchOperationName: 1317*154cabe7SRiver Riddle executeSwitchOperationName(); 1318*154cabe7SRiver Riddle break; 1319*154cabe7SRiver Riddle case SwitchResultCount: 1320*154cabe7SRiver Riddle executeSwitchResultCount(); 1321*154cabe7SRiver Riddle break; 1322*154cabe7SRiver Riddle case SwitchType: 1323*154cabe7SRiver Riddle executeSwitchType(); 1324*154cabe7SRiver Riddle break; 1325*154cabe7SRiver Riddle } 1326*154cabe7SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "\n"); 1327abfd1a8bSRiver Riddle } 1328abfd1a8bSRiver Riddle } 1329abfd1a8bSRiver Riddle 1330abfd1a8bSRiver Riddle /// Run the pattern matcher on the given root operation, collecting the matched 1331abfd1a8bSRiver Riddle /// patterns in `matches`. 1332abfd1a8bSRiver Riddle void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, 1333abfd1a8bSRiver Riddle SmallVectorImpl<MatchResult> &matches, 1334abfd1a8bSRiver Riddle PDLByteCodeMutableState &state) const { 1335abfd1a8bSRiver Riddle // The first memory slot is always the root operation. 1336abfd1a8bSRiver Riddle state.memory[0] = op; 1337abfd1a8bSRiver Riddle 1338abfd1a8bSRiver Riddle // The matcher function always starts at code address 0. 1339abfd1a8bSRiver Riddle ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData, 1340abfd1a8bSRiver Riddle matcherByteCode, state.currentPatternBenefits, 1341abfd1a8bSRiver Riddle patterns, constraintFunctions, createFunctions, 1342abfd1a8bSRiver Riddle rewriteFunctions); 1343abfd1a8bSRiver Riddle executor.execute(rewriter, &matches); 1344abfd1a8bSRiver Riddle 1345abfd1a8bSRiver Riddle // Order the found matches by benefit. 1346abfd1a8bSRiver Riddle std::stable_sort(matches.begin(), matches.end(), 1347abfd1a8bSRiver Riddle [](const MatchResult &lhs, const MatchResult &rhs) { 1348abfd1a8bSRiver Riddle return lhs.benefit > rhs.benefit; 1349abfd1a8bSRiver Riddle }); 1350abfd1a8bSRiver Riddle } 1351abfd1a8bSRiver Riddle 1352abfd1a8bSRiver Riddle /// Run the rewriter of the given pattern on the root operation `op`. 1353abfd1a8bSRiver Riddle void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, 1354abfd1a8bSRiver Riddle PDLByteCodeMutableState &state) const { 1355abfd1a8bSRiver Riddle // The arguments of the rewrite function are stored at the start of the 1356abfd1a8bSRiver Riddle // memory buffer. 1357abfd1a8bSRiver Riddle llvm::copy(match.values, state.memory.begin()); 1358abfd1a8bSRiver Riddle 1359abfd1a8bSRiver Riddle ByteCodeExecutor executor( 1360abfd1a8bSRiver Riddle &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, 1361abfd1a8bSRiver Riddle uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns, 1362abfd1a8bSRiver Riddle constraintFunctions, createFunctions, rewriteFunctions); 1363abfd1a8bSRiver Riddle executor.execute(rewriter, /*matches=*/nullptr, match.location); 1364abfd1a8bSRiver Riddle } 1365