1 //===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 10 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/FunctionImplementation.h" 14 15 using namespace mlir; 16 using namespace mlir::pdl_interp; 17 18 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.cpp.inc" 19 20 //===----------------------------------------------------------------------===// 21 // PDLInterp Dialect 22 //===----------------------------------------------------------------------===// 23 24 void PDLInterpDialect::initialize() { 25 addOperations< 26 #define GET_OP_LIST 27 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc" 28 >(); 29 } 30 31 template <typename OpT> 32 static LogicalResult verifySwitchOp(OpT op) { 33 // Verify that the number of case destinations matches the number of case 34 // values. 35 size_t numDests = op.getCases().size(); 36 size_t numValues = op.getCaseValues().size(); 37 if (numDests != numValues) { 38 return op.emitOpError( 39 "expected number of cases to match the number of case " 40 "values, got ") 41 << numDests << " but expected " << numValues; 42 } 43 return success(); 44 } 45 46 //===----------------------------------------------------------------------===// 47 // pdl_interp::CreateOperationOp 48 //===----------------------------------------------------------------------===// 49 50 static ParseResult parseCreateOperationOpAttributes( 51 OpAsmParser &p, 52 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands, 53 ArrayAttr &attrNamesAttr) { 54 Builder &builder = p.getBuilder(); 55 SmallVector<Attribute, 4> attrNames; 56 if (succeeded(p.parseOptionalLBrace())) { 57 do { 58 StringAttr nameAttr; 59 OpAsmParser::UnresolvedOperand operand; 60 if (p.parseAttribute(nameAttr) || p.parseEqual() || 61 p.parseOperand(operand)) 62 return failure(); 63 attrNames.push_back(nameAttr); 64 attrOperands.push_back(operand); 65 } while (succeeded(p.parseOptionalComma())); 66 if (p.parseRBrace()) 67 return failure(); 68 } 69 attrNamesAttr = builder.getArrayAttr(attrNames); 70 return success(); 71 } 72 73 static void printCreateOperationOpAttributes(OpAsmPrinter &p, 74 CreateOperationOp op, 75 OperandRange attrArgs, 76 ArrayAttr attrNames) { 77 if (attrNames.empty()) 78 return; 79 p << " {"; 80 interleaveComma(llvm::seq<int>(0, attrNames.size()), p, 81 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; }); 82 p << '}'; 83 } 84 85 //===----------------------------------------------------------------------===// 86 // pdl_interp::ForEachOp 87 //===----------------------------------------------------------------------===// 88 89 void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, 90 Value range, Block *successor, bool initLoop) { 91 build(builder, state, range, successor); 92 if (initLoop) { 93 // Create the block and the loop variable. 94 // FIXME: Allow passing in a proper location for the loop variable. 95 auto rangeType = range.getType().cast<pdl::RangeType>(); 96 state.regions.front()->emplaceBlock(); 97 state.regions.front()->addArgument(rangeType.getElementType(), 98 state.location); 99 } 100 } 101 102 ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) { 103 // Parse the loop variable followed by type. 104 OpAsmParser::UnresolvedOperand loopVariable; 105 Type loopVariableType; 106 if (parser.parseOperand(loopVariable, /*allowResultNumber=*/false) || 107 parser.parseColonType(loopVariableType)) 108 return failure(); 109 110 // Parse the "in" keyword. 111 if (parser.parseKeyword("in", " after loop variable")) 112 return failure(); 113 114 // Parse the operand (value range). 115 OpAsmParser::UnresolvedOperand operandInfo; 116 if (parser.parseOperand(operandInfo)) 117 return failure(); 118 119 // Resolve the operand. 120 Type rangeType = pdl::RangeType::get(loopVariableType); 121 if (parser.resolveOperand(operandInfo, rangeType, result.operands)) 122 return failure(); 123 124 // Parse the body region. 125 Region *body = result.addRegion(); 126 if (parser.parseRegion(*body, {loopVariable}, {loopVariableType})) 127 return failure(); 128 129 // Parse the attribute dictionary. 130 if (parser.parseOptionalAttrDict(result.attributes)) 131 return failure(); 132 133 // Parse the successor. 134 Block *successor; 135 if (parser.parseArrow() || parser.parseSuccessor(successor)) 136 return failure(); 137 result.addSuccessors(successor); 138 139 return success(); 140 } 141 142 void ForEachOp::print(OpAsmPrinter &p) { 143 BlockArgument arg = getLoopVariable(); 144 p << ' ' << arg << " : " << arg.getType() << " in " << getValues() << ' '; 145 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 146 p.printOptionalAttrDict((*this)->getAttrs()); 147 p << " -> "; 148 p.printSuccessor(getSuccessor()); 149 } 150 151 LogicalResult ForEachOp::verify() { 152 // Verify that the operation has exactly one argument. 153 if (getRegion().getNumArguments() != 1) 154 return emitOpError("requires exactly one argument"); 155 156 // Verify that the loop variable and the operand (value range) 157 // have compatible types. 158 BlockArgument arg = getLoopVariable(); 159 Type rangeType = pdl::RangeType::get(arg.getType()); 160 if (rangeType != getValues().getType()) 161 return emitOpError("operand must be a range of loop variable type"); 162 163 return success(); 164 } 165 166 //===----------------------------------------------------------------------===// 167 // pdl_interp::FuncOp 168 //===----------------------------------------------------------------------===// 169 170 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, 171 FunctionType type, ArrayRef<NamedAttribute> attrs) { 172 buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); 173 } 174 175 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { 176 auto buildFuncType = 177 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 178 function_interface_impl::VariadicFlag, 179 std::string &) { return builder.getFunctionType(argTypes, results); }; 180 181 return function_interface_impl::parseFunctionOp( 182 parser, result, /*allowVariadic=*/false, buildFuncType); 183 } 184 185 void FuncOp::print(OpAsmPrinter &p) { 186 function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); 187 } 188 189 //===----------------------------------------------------------------------===// 190 // pdl_interp::GetValueTypeOp 191 //===----------------------------------------------------------------------===// 192 193 /// Given the result type of a `GetValueTypeOp`, return the expected input type. 194 static Type getGetValueTypeOpValueType(Type type) { 195 Type valueTy = pdl::ValueType::get(type.getContext()); 196 return type.isa<pdl::RangeType>() ? pdl::RangeType::get(valueTy) : valueTy; 197 } 198 199 //===----------------------------------------------------------------------===// 200 // pdl_interp::SwitchAttributeOp 201 //===----------------------------------------------------------------------===// 202 203 LogicalResult SwitchAttributeOp::verify() { return verifySwitchOp(*this); } 204 205 //===----------------------------------------------------------------------===// 206 // pdl_interp::SwitchOperandCountOp 207 //===----------------------------------------------------------------------===// 208 209 LogicalResult SwitchOperandCountOp::verify() { return verifySwitchOp(*this); } 210 211 //===----------------------------------------------------------------------===// 212 // pdl_interp::SwitchOperationNameOp 213 //===----------------------------------------------------------------------===// 214 215 LogicalResult SwitchOperationNameOp::verify() { return verifySwitchOp(*this); } 216 217 //===----------------------------------------------------------------------===// 218 // pdl_interp::SwitchResultCountOp 219 //===----------------------------------------------------------------------===// 220 221 LogicalResult SwitchResultCountOp::verify() { return verifySwitchOp(*this); } 222 223 //===----------------------------------------------------------------------===// 224 // pdl_interp::SwitchTypeOp 225 //===----------------------------------------------------------------------===// 226 227 LogicalResult SwitchTypeOp::verify() { return verifySwitchOp(*this); } 228 229 //===----------------------------------------------------------------------===// 230 // pdl_interp::SwitchTypesOp 231 //===----------------------------------------------------------------------===// 232 233 LogicalResult SwitchTypesOp::verify() { return verifySwitchOp(*this); } 234 235 //===----------------------------------------------------------------------===// 236 // TableGen Auto-Generated Op and Interface Definitions 237 //===----------------------------------------------------------------------===// 238 239 #define GET_OP_CLASSES 240 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc" 241