1 //===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 10 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/FunctionImplementation.h" 14 15 using namespace mlir; 16 using namespace mlir::pdl_interp; 17 18 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.cpp.inc" 19 20 //===----------------------------------------------------------------------===// 21 // PDLInterp Dialect 22 //===----------------------------------------------------------------------===// 23 24 void PDLInterpDialect::initialize() { 25 addOperations< 26 #define GET_OP_LIST 27 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc" 28 >(); 29 } 30 31 template <typename OpT> 32 static LogicalResult verifySwitchOp(OpT op) { 33 // Verify that the number of case destinations matches the number of case 34 // values. 35 size_t numDests = op.getCases().size(); 36 size_t numValues = op.getCaseValues().size(); 37 if (numDests != numValues) { 38 return op.emitOpError( 39 "expected number of cases to match the number of case " 40 "values, got ") 41 << numDests << " but expected " << numValues; 42 } 43 return success(); 44 } 45 46 //===----------------------------------------------------------------------===// 47 // pdl_interp::CreateOperationOp 48 //===----------------------------------------------------------------------===// 49 50 LogicalResult CreateOperationOp::verify() { 51 if (!getInferredResultTypes()) 52 return success(); 53 if (!getInputResultTypes().empty()) { 54 return emitOpError("with inferred results cannot also have " 55 "explicit result types"); 56 } 57 OperationName opName(getName(), getContext()); 58 if (!opName.hasInterface<InferTypeOpInterface>()) { 59 return emitOpError() 60 << "has inferred results, but the created operation '" << opName 61 << "' does not support result type inference (or is not " 62 "registered)"; 63 } 64 return success(); 65 } 66 67 static ParseResult parseCreateOperationOpAttributes( 68 OpAsmParser &p, 69 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands, 70 ArrayAttr &attrNamesAttr) { 71 Builder &builder = p.getBuilder(); 72 SmallVector<Attribute, 4> attrNames; 73 if (succeeded(p.parseOptionalLBrace())) { 74 auto parseOperands = [&]() { 75 StringAttr nameAttr; 76 OpAsmParser::UnresolvedOperand operand; 77 if (p.parseAttribute(nameAttr) || p.parseEqual() || 78 p.parseOperand(operand)) 79 return failure(); 80 attrNames.push_back(nameAttr); 81 attrOperands.push_back(operand); 82 return success(); 83 }; 84 if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace()) 85 return failure(); 86 } 87 attrNamesAttr = builder.getArrayAttr(attrNames); 88 return success(); 89 } 90 91 static void printCreateOperationOpAttributes(OpAsmPrinter &p, 92 CreateOperationOp op, 93 OperandRange attrArgs, 94 ArrayAttr attrNames) { 95 if (attrNames.empty()) 96 return; 97 p << " {"; 98 interleaveComma(llvm::seq<int>(0, attrNames.size()), p, 99 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; }); 100 p << '}'; 101 } 102 103 static ParseResult parseCreateOperationOpResults( 104 OpAsmParser &p, 105 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultOperands, 106 SmallVectorImpl<Type> &resultTypes, UnitAttr &inferredResultTypes) { 107 if (failed(p.parseOptionalArrow())) 108 return success(); 109 110 // Handle the case of inferred results. 111 if (succeeded(p.parseOptionalLess())) { 112 if (p.parseKeyword("inferred") || p.parseGreater()) 113 return failure(); 114 inferredResultTypes = p.getBuilder().getUnitAttr(); 115 return success(); 116 } 117 118 // Otherwise, parse the explicit results. 119 return failure(p.parseLParen() || p.parseOperandList(resultOperands) || 120 p.parseColonTypeList(resultTypes) || p.parseRParen()); 121 } 122 123 static void printCreateOperationOpResults(OpAsmPrinter &p, CreateOperationOp op, 124 OperandRange resultOperands, 125 TypeRange resultTypes, 126 UnitAttr inferredResultTypes) { 127 // Handle the case of inferred results. 128 if (inferredResultTypes) { 129 p << " -> <inferred>"; 130 return; 131 } 132 133 // Otherwise, handle the explicit results. 134 if (!resultTypes.empty()) 135 p << " -> (" << resultOperands << " : " << resultTypes << ")"; 136 } 137 138 //===----------------------------------------------------------------------===// 139 // pdl_interp::ForEachOp 140 //===----------------------------------------------------------------------===// 141 142 void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, 143 Value range, Block *successor, bool initLoop) { 144 build(builder, state, range, successor); 145 if (initLoop) { 146 // Create the block and the loop variable. 147 // FIXME: Allow passing in a proper location for the loop variable. 148 auto rangeType = range.getType().cast<pdl::RangeType>(); 149 state.regions.front()->emplaceBlock(); 150 state.regions.front()->addArgument(rangeType.getElementType(), 151 state.location); 152 } 153 } 154 155 ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) { 156 // Parse the loop variable followed by type. 157 OpAsmParser::Argument loopVariable; 158 OpAsmParser::UnresolvedOperand operandInfo; 159 if (parser.parseArgument(loopVariable, /*allowType=*/true) || 160 parser.parseKeyword("in", " after loop variable") || 161 // Parse the operand (value range). 162 parser.parseOperand(operandInfo)) 163 return failure(); 164 165 // Resolve the operand. 166 Type rangeType = pdl::RangeType::get(loopVariable.type); 167 if (parser.resolveOperand(operandInfo, rangeType, result.operands)) 168 return failure(); 169 170 // Parse the body region. 171 Region *body = result.addRegion(); 172 Block *successor; 173 if (parser.parseRegion(*body, loopVariable) || 174 parser.parseOptionalAttrDict(result.attributes) || 175 // Parse the successor. 176 parser.parseArrow() || parser.parseSuccessor(successor)) 177 return failure(); 178 179 result.addSuccessors(successor); 180 return success(); 181 } 182 183 void ForEachOp::print(OpAsmPrinter &p) { 184 BlockArgument arg = getLoopVariable(); 185 p << ' ' << arg << " : " << arg.getType() << " in " << getValues() << ' '; 186 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 187 p.printOptionalAttrDict((*this)->getAttrs()); 188 p << " -> "; 189 p.printSuccessor(getSuccessor()); 190 } 191 192 LogicalResult ForEachOp::verify() { 193 // Verify that the operation has exactly one argument. 194 if (getRegion().getNumArguments() != 1) 195 return emitOpError("requires exactly one argument"); 196 197 // Verify that the loop variable and the operand (value range) 198 // have compatible types. 199 BlockArgument arg = getLoopVariable(); 200 Type rangeType = pdl::RangeType::get(arg.getType()); 201 if (rangeType != getValues().getType()) 202 return emitOpError("operand must be a range of loop variable type"); 203 204 return success(); 205 } 206 207 //===----------------------------------------------------------------------===// 208 // pdl_interp::FuncOp 209 //===----------------------------------------------------------------------===// 210 211 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, 212 FunctionType type, ArrayRef<NamedAttribute> attrs) { 213 buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); 214 } 215 216 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { 217 auto buildFuncType = 218 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 219 function_interface_impl::VariadicFlag, 220 std::string &) { return builder.getFunctionType(argTypes, results); }; 221 222 return function_interface_impl::parseFunctionOp( 223 parser, result, /*allowVariadic=*/false, buildFuncType); 224 } 225 226 void FuncOp::print(OpAsmPrinter &p) { 227 function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); 228 } 229 230 //===----------------------------------------------------------------------===// 231 // pdl_interp::GetValueTypeOp 232 //===----------------------------------------------------------------------===// 233 234 /// Given the result type of a `GetValueTypeOp`, return the expected input type. 235 static Type getGetValueTypeOpValueType(Type type) { 236 Type valueTy = pdl::ValueType::get(type.getContext()); 237 return type.isa<pdl::RangeType>() ? pdl::RangeType::get(valueTy) : valueTy; 238 } 239 240 //===----------------------------------------------------------------------===// 241 // pdl_interp::SwitchAttributeOp 242 //===----------------------------------------------------------------------===// 243 244 LogicalResult SwitchAttributeOp::verify() { return verifySwitchOp(*this); } 245 246 //===----------------------------------------------------------------------===// 247 // pdl_interp::SwitchOperandCountOp 248 //===----------------------------------------------------------------------===// 249 250 LogicalResult SwitchOperandCountOp::verify() { return verifySwitchOp(*this); } 251 252 //===----------------------------------------------------------------------===// 253 // pdl_interp::SwitchOperationNameOp 254 //===----------------------------------------------------------------------===// 255 256 LogicalResult SwitchOperationNameOp::verify() { return verifySwitchOp(*this); } 257 258 //===----------------------------------------------------------------------===// 259 // pdl_interp::SwitchResultCountOp 260 //===----------------------------------------------------------------------===// 261 262 LogicalResult SwitchResultCountOp::verify() { return verifySwitchOp(*this); } 263 264 //===----------------------------------------------------------------------===// 265 // pdl_interp::SwitchTypeOp 266 //===----------------------------------------------------------------------===// 267 268 LogicalResult SwitchTypeOp::verify() { return verifySwitchOp(*this); } 269 270 //===----------------------------------------------------------------------===// 271 // pdl_interp::SwitchTypesOp 272 //===----------------------------------------------------------------------===// 273 274 LogicalResult SwitchTypesOp::verify() { return verifySwitchOp(*this); } 275 276 //===----------------------------------------------------------------------===// 277 // TableGen Auto-Generated Op and Interface Definitions 278 //===----------------------------------------------------------------------===// 279 280 #define GET_OP_CLASSES 281 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc" 282