1 //===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 10 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 11 #include "mlir/IR/BuiltinTypes.h" 12 #include "mlir/IR/DialectImplementation.h" 13 14 using namespace mlir; 15 using namespace mlir::pdl_interp; 16 17 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.cpp.inc" 18 19 //===----------------------------------------------------------------------===// 20 // PDLInterp Dialect 21 //===----------------------------------------------------------------------===// 22 23 void PDLInterpDialect::initialize() { 24 addOperations< 25 #define GET_OP_LIST 26 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc" 27 >(); 28 } 29 30 //===----------------------------------------------------------------------===// 31 // pdl_interp::CreateOperationOp 32 //===----------------------------------------------------------------------===// 33 34 static ParseResult parseCreateOperationOpAttributes( 35 OpAsmParser &p, SmallVectorImpl<OpAsmParser::OperandType> &attrOperands, 36 ArrayAttr &attrNamesAttr) { 37 Builder &builder = p.getBuilder(); 38 SmallVector<Attribute, 4> attrNames; 39 if (succeeded(p.parseOptionalLBrace())) { 40 do { 41 StringAttr nameAttr; 42 OpAsmParser::OperandType operand; 43 if (p.parseAttribute(nameAttr) || p.parseEqual() || 44 p.parseOperand(operand)) 45 return failure(); 46 attrNames.push_back(nameAttr); 47 attrOperands.push_back(operand); 48 } while (succeeded(p.parseOptionalComma())); 49 if (p.parseRBrace()) 50 return failure(); 51 } 52 attrNamesAttr = builder.getArrayAttr(attrNames); 53 return success(); 54 } 55 56 static void printCreateOperationOpAttributes(OpAsmPrinter &p, 57 CreateOperationOp op, 58 OperandRange attrArgs, 59 ArrayAttr attrNames) { 60 if (attrNames.empty()) 61 return; 62 p << " {"; 63 interleaveComma(llvm::seq<int>(0, attrNames.size()), p, 64 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; }); 65 p << '}'; 66 } 67 68 //===----------------------------------------------------------------------===// 69 // pdl_interp::ForEachOp 70 //===----------------------------------------------------------------------===// 71 72 void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, 73 Value range, Block *successor, bool initLoop) { 74 build(builder, state, range, successor); 75 if (initLoop) { 76 // Create the block and the loop variable. 77 auto rangeType = range.getType().cast<pdl::RangeType>(); 78 state.regions.front()->emplaceBlock(); 79 state.regions.front()->addArgument(rangeType.getElementType()); 80 } 81 } 82 83 static ParseResult parseForEachOp(OpAsmParser &parser, OperationState &result) { 84 // Parse the loop variable followed by type. 85 OpAsmParser::OperandType loopVariable; 86 Type loopVariableType; 87 if (parser.parseRegionArgument(loopVariable) || 88 parser.parseColonType(loopVariableType)) 89 return failure(); 90 91 // Parse the "in" keyword. 92 if (parser.parseKeyword("in", " after loop variable")) 93 return failure(); 94 95 // Parse the operand (value range). 96 OpAsmParser::OperandType operandInfo; 97 if (parser.parseOperand(operandInfo)) 98 return failure(); 99 100 // Resolve the operand. 101 Type rangeType = pdl::RangeType::get(loopVariableType); 102 if (parser.resolveOperand(operandInfo, rangeType, result.operands)) 103 return failure(); 104 105 // Parse the body region. 106 Region *body = result.addRegion(); 107 if (parser.parseRegion(*body, {loopVariable}, {loopVariableType})) 108 return failure(); 109 110 // Parse the attribute dictionary. 111 if (parser.parseOptionalAttrDict(result.attributes)) 112 return failure(); 113 114 // Parse the successor. 115 Block *successor; 116 if (parser.parseArrow() || parser.parseSuccessor(successor)) 117 return failure(); 118 result.addSuccessors(successor); 119 120 return success(); 121 } 122 123 static void print(OpAsmPrinter &p, ForEachOp op) { 124 BlockArgument arg = op.getLoopVariable(); 125 p << ' ' << arg << " : " << arg.getType() << " in " << op.values() << ' '; 126 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 127 p.printOptionalAttrDict(op->getAttrs()); 128 p << " -> "; 129 p.printSuccessor(op.successor()); 130 } 131 132 static LogicalResult verify(ForEachOp op) { 133 // Verify that the operation has exactly one argument. 134 if (op.region().getNumArguments() != 1) 135 return op.emitOpError("requires exactly one argument"); 136 137 // Verify that the loop variable and the operand (value range) 138 // have compatible types. 139 BlockArgument arg = op.getLoopVariable(); 140 Type rangeType = pdl::RangeType::get(arg.getType()); 141 if (rangeType != op.values().getType()) 142 return op.emitOpError("operand must be a range of loop variable type"); 143 144 return success(); 145 } 146 147 //===----------------------------------------------------------------------===// 148 // pdl_interp::GetValueTypeOp 149 //===----------------------------------------------------------------------===// 150 151 /// Given the result type of a `GetValueTypeOp`, return the expected input type. 152 static Type getGetValueTypeOpValueType(Type type) { 153 Type valueTy = pdl::ValueType::get(type.getContext()); 154 return type.isa<pdl::RangeType>() ? pdl::RangeType::get(valueTy) : valueTy; 155 } 156 157 //===----------------------------------------------------------------------===// 158 // TableGen Auto-Generated Op and Interface Definitions 159 //===----------------------------------------------------------------------===// 160 161 #define GET_OP_CLASSES 162 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc" 163