1 //===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 10 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/FunctionImplementation.h" 14 15 using namespace mlir; 16 using namespace mlir::pdl_interp; 17 18 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.cpp.inc" 19 20 //===----------------------------------------------------------------------===// 21 // PDLInterp Dialect 22 //===----------------------------------------------------------------------===// 23 24 void PDLInterpDialect::initialize() { 25 addOperations< 26 #define GET_OP_LIST 27 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc" 28 >(); 29 } 30 31 template <typename OpT> 32 static LogicalResult verifySwitchOp(OpT op) { 33 // Verify that the number of case destinations matches the number of case 34 // values. 35 size_t numDests = op.getCases().size(); 36 size_t numValues = op.getCaseValues().size(); 37 if (numDests != numValues) { 38 return op.emitOpError( 39 "expected number of cases to match the number of case " 40 "values, got ") 41 << numDests << " but expected " << numValues; 42 } 43 return success(); 44 } 45 46 //===----------------------------------------------------------------------===// 47 // pdl_interp::CreateOperationOp 48 //===----------------------------------------------------------------------===// 49 50 static ParseResult parseCreateOperationOpAttributes( 51 OpAsmParser &p, 52 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands, 53 ArrayAttr &attrNamesAttr) { 54 Builder &builder = p.getBuilder(); 55 SmallVector<Attribute, 4> attrNames; 56 if (succeeded(p.parseOptionalLBrace())) { 57 do { 58 StringAttr nameAttr; 59 OpAsmParser::UnresolvedOperand operand; 60 if (p.parseAttribute(nameAttr) || p.parseEqual() || 61 p.parseOperand(operand)) 62 return failure(); 63 attrNames.push_back(nameAttr); 64 attrOperands.push_back(operand); 65 } while (succeeded(p.parseOptionalComma())); 66 if (p.parseRBrace()) 67 return failure(); 68 } 69 attrNamesAttr = builder.getArrayAttr(attrNames); 70 return success(); 71 } 72 73 static void printCreateOperationOpAttributes(OpAsmPrinter &p, 74 CreateOperationOp op, 75 OperandRange attrArgs, 76 ArrayAttr attrNames) { 77 if (attrNames.empty()) 78 return; 79 p << " {"; 80 interleaveComma(llvm::seq<int>(0, attrNames.size()), p, 81 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; }); 82 p << '}'; 83 } 84 85 //===----------------------------------------------------------------------===// 86 // pdl_interp::ForEachOp 87 //===----------------------------------------------------------------------===// 88 89 void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, 90 Value range, Block *successor, bool initLoop) { 91 build(builder, state, range, successor); 92 if (initLoop) { 93 // Create the block and the loop variable. 94 // FIXME: Allow passing in a proper location for the loop variable. 95 auto rangeType = range.getType().cast<pdl::RangeType>(); 96 state.regions.front()->emplaceBlock(); 97 state.regions.front()->addArgument(rangeType.getElementType(), 98 state.location); 99 } 100 } 101 102 ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) { 103 // Parse the loop variable followed by type. 104 OpAsmParser::Argument loopVariable; 105 OpAsmParser::UnresolvedOperand operandInfo; 106 if (parser.parseArgument(loopVariable, /*allowType=*/true) || 107 parser.parseKeyword("in", " after loop variable") || 108 // Parse the operand (value range). 109 parser.parseOperand(operandInfo)) 110 return failure(); 111 112 // Resolve the operand. 113 Type rangeType = pdl::RangeType::get(loopVariable.type); 114 if (parser.resolveOperand(operandInfo, rangeType, result.operands)) 115 return failure(); 116 117 // Parse the body region. 118 Region *body = result.addRegion(); 119 Block *successor; 120 if (parser.parseRegion(*body, loopVariable) || 121 parser.parseOptionalAttrDict(result.attributes) || 122 // Parse the successor. 123 parser.parseArrow() || parser.parseSuccessor(successor)) 124 return failure(); 125 126 result.addSuccessors(successor); 127 return success(); 128 } 129 130 void ForEachOp::print(OpAsmPrinter &p) { 131 BlockArgument arg = getLoopVariable(); 132 p << ' ' << arg << " : " << arg.getType() << " in " << getValues() << ' '; 133 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 134 p.printOptionalAttrDict((*this)->getAttrs()); 135 p << " -> "; 136 p.printSuccessor(getSuccessor()); 137 } 138 139 LogicalResult ForEachOp::verify() { 140 // Verify that the operation has exactly one argument. 141 if (getRegion().getNumArguments() != 1) 142 return emitOpError("requires exactly one argument"); 143 144 // Verify that the loop variable and the operand (value range) 145 // have compatible types. 146 BlockArgument arg = getLoopVariable(); 147 Type rangeType = pdl::RangeType::get(arg.getType()); 148 if (rangeType != getValues().getType()) 149 return emitOpError("operand must be a range of loop variable type"); 150 151 return success(); 152 } 153 154 //===----------------------------------------------------------------------===// 155 // pdl_interp::FuncOp 156 //===----------------------------------------------------------------------===// 157 158 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, 159 FunctionType type, ArrayRef<NamedAttribute> attrs) { 160 buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); 161 } 162 163 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { 164 auto buildFuncType = 165 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 166 function_interface_impl::VariadicFlag, 167 std::string &) { return builder.getFunctionType(argTypes, results); }; 168 169 return function_interface_impl::parseFunctionOp( 170 parser, result, /*allowVariadic=*/false, buildFuncType); 171 } 172 173 void FuncOp::print(OpAsmPrinter &p) { 174 function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); 175 } 176 177 //===----------------------------------------------------------------------===// 178 // pdl_interp::GetValueTypeOp 179 //===----------------------------------------------------------------------===// 180 181 /// Given the result type of a `GetValueTypeOp`, return the expected input type. 182 static Type getGetValueTypeOpValueType(Type type) { 183 Type valueTy = pdl::ValueType::get(type.getContext()); 184 return type.isa<pdl::RangeType>() ? pdl::RangeType::get(valueTy) : valueTy; 185 } 186 187 //===----------------------------------------------------------------------===// 188 // pdl_interp::SwitchAttributeOp 189 //===----------------------------------------------------------------------===// 190 191 LogicalResult SwitchAttributeOp::verify() { return verifySwitchOp(*this); } 192 193 //===----------------------------------------------------------------------===// 194 // pdl_interp::SwitchOperandCountOp 195 //===----------------------------------------------------------------------===// 196 197 LogicalResult SwitchOperandCountOp::verify() { return verifySwitchOp(*this); } 198 199 //===----------------------------------------------------------------------===// 200 // pdl_interp::SwitchOperationNameOp 201 //===----------------------------------------------------------------------===// 202 203 LogicalResult SwitchOperationNameOp::verify() { return verifySwitchOp(*this); } 204 205 //===----------------------------------------------------------------------===// 206 // pdl_interp::SwitchResultCountOp 207 //===----------------------------------------------------------------------===// 208 209 LogicalResult SwitchResultCountOp::verify() { return verifySwitchOp(*this); } 210 211 //===----------------------------------------------------------------------===// 212 // pdl_interp::SwitchTypeOp 213 //===----------------------------------------------------------------------===// 214 215 LogicalResult SwitchTypeOp::verify() { return verifySwitchOp(*this); } 216 217 //===----------------------------------------------------------------------===// 218 // pdl_interp::SwitchTypesOp 219 //===----------------------------------------------------------------------===// 220 221 LogicalResult SwitchTypesOp::verify() { return verifySwitchOp(*this); } 222 223 //===----------------------------------------------------------------------===// 224 // TableGen Auto-Generated Op and Interface Definitions 225 //===----------------------------------------------------------------------===// 226 227 #define GET_OP_CLASSES 228 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc" 229