1 //===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 10 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/FunctionImplementation.h" 14 15 using namespace mlir; 16 using namespace mlir::pdl_interp; 17 18 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.cpp.inc" 19 20 //===----------------------------------------------------------------------===// 21 // PDLInterp Dialect 22 //===----------------------------------------------------------------------===// 23 24 void PDLInterpDialect::initialize() { 25 addOperations< 26 #define GET_OP_LIST 27 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc" 28 >(); 29 } 30 31 template <typename OpT> 32 static LogicalResult verifySwitchOp(OpT op) { 33 // Verify that the number of case destinations matches the number of case 34 // values. 35 size_t numDests = op.getCases().size(); 36 size_t numValues = op.getCaseValues().size(); 37 if (numDests != numValues) { 38 return op.emitOpError( 39 "expected number of cases to match the number of case " 40 "values, got ") 41 << numDests << " but expected " << numValues; 42 } 43 return success(); 44 } 45 46 //===----------------------------------------------------------------------===// 47 // pdl_interp::CreateOperationOp 48 //===----------------------------------------------------------------------===// 49 50 static ParseResult parseCreateOperationOpAttributes( 51 OpAsmParser &p, SmallVectorImpl<OpAsmParser::OperandType> &attrOperands, 52 ArrayAttr &attrNamesAttr) { 53 Builder &builder = p.getBuilder(); 54 SmallVector<Attribute, 4> attrNames; 55 if (succeeded(p.parseOptionalLBrace())) { 56 do { 57 StringAttr nameAttr; 58 OpAsmParser::OperandType operand; 59 if (p.parseAttribute(nameAttr) || p.parseEqual() || 60 p.parseOperand(operand)) 61 return failure(); 62 attrNames.push_back(nameAttr); 63 attrOperands.push_back(operand); 64 } while (succeeded(p.parseOptionalComma())); 65 if (p.parseRBrace()) 66 return failure(); 67 } 68 attrNamesAttr = builder.getArrayAttr(attrNames); 69 return success(); 70 } 71 72 static void printCreateOperationOpAttributes(OpAsmPrinter &p, 73 CreateOperationOp op, 74 OperandRange attrArgs, 75 ArrayAttr attrNames) { 76 if (attrNames.empty()) 77 return; 78 p << " {"; 79 interleaveComma(llvm::seq<int>(0, attrNames.size()), p, 80 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; }); 81 p << '}'; 82 } 83 84 //===----------------------------------------------------------------------===// 85 // pdl_interp::ForEachOp 86 //===----------------------------------------------------------------------===// 87 88 void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, 89 Value range, Block *successor, bool initLoop) { 90 build(builder, state, range, successor); 91 if (initLoop) { 92 // Create the block and the loop variable. 93 // FIXME: Allow passing in a proper location for the loop variable. 94 auto rangeType = range.getType().cast<pdl::RangeType>(); 95 state.regions.front()->emplaceBlock(); 96 state.regions.front()->addArgument(rangeType.getElementType(), 97 state.location); 98 } 99 } 100 101 ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) { 102 // Parse the loop variable followed by type. 103 OpAsmParser::OperandType loopVariable; 104 Type loopVariableType; 105 if (parser.parseRegionArgument(loopVariable) || 106 parser.parseColonType(loopVariableType)) 107 return failure(); 108 109 // Parse the "in" keyword. 110 if (parser.parseKeyword("in", " after loop variable")) 111 return failure(); 112 113 // Parse the operand (value range). 114 OpAsmParser::OperandType operandInfo; 115 if (parser.parseOperand(operandInfo)) 116 return failure(); 117 118 // Resolve the operand. 119 Type rangeType = pdl::RangeType::get(loopVariableType); 120 if (parser.resolveOperand(operandInfo, rangeType, result.operands)) 121 return failure(); 122 123 // Parse the body region. 124 Region *body = result.addRegion(); 125 if (parser.parseRegion(*body, {loopVariable}, {loopVariableType})) 126 return failure(); 127 128 // Parse the attribute dictionary. 129 if (parser.parseOptionalAttrDict(result.attributes)) 130 return failure(); 131 132 // Parse the successor. 133 Block *successor; 134 if (parser.parseArrow() || parser.parseSuccessor(successor)) 135 return failure(); 136 result.addSuccessors(successor); 137 138 return success(); 139 } 140 141 void ForEachOp::print(OpAsmPrinter &p) { 142 BlockArgument arg = getLoopVariable(); 143 p << ' ' << arg << " : " << arg.getType() << " in " << getValues() << ' '; 144 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 145 p.printOptionalAttrDict((*this)->getAttrs()); 146 p << " -> "; 147 p.printSuccessor(getSuccessor()); 148 } 149 150 LogicalResult ForEachOp::verify() { 151 // Verify that the operation has exactly one argument. 152 if (getRegion().getNumArguments() != 1) 153 return emitOpError("requires exactly one argument"); 154 155 // Verify that the loop variable and the operand (value range) 156 // have compatible types. 157 BlockArgument arg = getLoopVariable(); 158 Type rangeType = pdl::RangeType::get(arg.getType()); 159 if (rangeType != getValues().getType()) 160 return emitOpError("operand must be a range of loop variable type"); 161 162 return success(); 163 } 164 165 //===----------------------------------------------------------------------===// 166 // pdl_interp::FuncOp 167 //===----------------------------------------------------------------------===// 168 169 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, 170 FunctionType type, ArrayRef<NamedAttribute> attrs) { 171 buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); 172 } 173 174 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { 175 auto buildFuncType = 176 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 177 function_interface_impl::VariadicFlag, 178 std::string &) { return builder.getFunctionType(argTypes, results); }; 179 180 return function_interface_impl::parseFunctionOp( 181 parser, result, /*allowVariadic=*/false, buildFuncType); 182 } 183 184 void FuncOp::print(OpAsmPrinter &p) { 185 function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); 186 } 187 188 //===----------------------------------------------------------------------===// 189 // pdl_interp::GetValueTypeOp 190 //===----------------------------------------------------------------------===// 191 192 /// Given the result type of a `GetValueTypeOp`, return the expected input type. 193 static Type getGetValueTypeOpValueType(Type type) { 194 Type valueTy = pdl::ValueType::get(type.getContext()); 195 return type.isa<pdl::RangeType>() ? pdl::RangeType::get(valueTy) : valueTy; 196 } 197 198 //===----------------------------------------------------------------------===// 199 // pdl_interp::SwitchAttributeOp 200 //===----------------------------------------------------------------------===// 201 202 LogicalResult SwitchAttributeOp::verify() { return verifySwitchOp(*this); } 203 204 //===----------------------------------------------------------------------===// 205 // pdl_interp::SwitchOperandCountOp 206 //===----------------------------------------------------------------------===// 207 208 LogicalResult SwitchOperandCountOp::verify() { return verifySwitchOp(*this); } 209 210 //===----------------------------------------------------------------------===// 211 // pdl_interp::SwitchOperationNameOp 212 //===----------------------------------------------------------------------===// 213 214 LogicalResult SwitchOperationNameOp::verify() { return verifySwitchOp(*this); } 215 216 //===----------------------------------------------------------------------===// 217 // pdl_interp::SwitchResultCountOp 218 //===----------------------------------------------------------------------===// 219 220 LogicalResult SwitchResultCountOp::verify() { return verifySwitchOp(*this); } 221 222 //===----------------------------------------------------------------------===// 223 // pdl_interp::SwitchTypeOp 224 //===----------------------------------------------------------------------===// 225 226 LogicalResult SwitchTypeOp::verify() { return verifySwitchOp(*this); } 227 228 //===----------------------------------------------------------------------===// 229 // pdl_interp::SwitchTypesOp 230 //===----------------------------------------------------------------------===// 231 232 LogicalResult SwitchTypesOp::verify() { return verifySwitchOp(*this); } 233 234 //===----------------------------------------------------------------------===// 235 // TableGen Auto-Generated Op and Interface Definitions 236 //===----------------------------------------------------------------------===// 237 238 #define GET_OP_CLASSES 239 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc" 240