1 //===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
10 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/FunctionImplementation.h"
14 
15 using namespace mlir;
16 using namespace mlir::pdl_interp;
17 
18 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.cpp.inc"
19 
20 //===----------------------------------------------------------------------===//
21 // PDLInterp Dialect
22 //===----------------------------------------------------------------------===//
23 
/// Register every operation of the pdl_interp dialect with the dialect
/// instance. The operation list itself is not written by hand: it is expanded
/// from the generated `PDLInterpOps.cpp.inc` file when `GET_OP_LIST` is
/// defined, producing a comma-separated list of op class names for
/// `addOperations<...>`.
void PDLInterpDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
      >();
}
30 
31 template <typename OpT>
32 static LogicalResult verifySwitchOp(OpT op) {
33   // Verify that the number of case destinations matches the number of case
34   // values.
35   size_t numDests = op.getCases().size();
36   size_t numValues = op.getCaseValues().size();
37   if (numDests != numValues) {
38     return op.emitOpError(
39                "expected number of cases to match the number of case "
40                "values, got ")
41            << numDests << " but expected " << numValues;
42   }
43   return success();
44 }
45 
46 //===----------------------------------------------------------------------===//
47 // pdl_interp::CreateOperationOp
48 //===----------------------------------------------------------------------===//
49 
50 static ParseResult parseCreateOperationOpAttributes(
51     OpAsmParser &p, SmallVectorImpl<OpAsmParser::OperandType> &attrOperands,
52     ArrayAttr &attrNamesAttr) {
53   Builder &builder = p.getBuilder();
54   SmallVector<Attribute, 4> attrNames;
55   if (succeeded(p.parseOptionalLBrace())) {
56     do {
57       StringAttr nameAttr;
58       OpAsmParser::OperandType operand;
59       if (p.parseAttribute(nameAttr) || p.parseEqual() ||
60           p.parseOperand(operand))
61         return failure();
62       attrNames.push_back(nameAttr);
63       attrOperands.push_back(operand);
64     } while (succeeded(p.parseOptionalComma()));
65     if (p.parseRBrace())
66       return failure();
67   }
68   attrNamesAttr = builder.getArrayAttr(attrNames);
69   return success();
70 }
71 
72 static void printCreateOperationOpAttributes(OpAsmPrinter &p,
73                                              CreateOperationOp op,
74                                              OperandRange attrArgs,
75                                              ArrayAttr attrNames) {
76   if (attrNames.empty())
77     return;
78   p << " {";
79   interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
80                   [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
81   p << '}';
82 }
83 
84 //===----------------------------------------------------------------------===//
85 // pdl_interp::ForEachOp
86 //===----------------------------------------------------------------------===//
87 
88 void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
89                       Value range, Block *successor, bool initLoop) {
90   build(builder, state, range, successor);
91   if (initLoop) {
92     // Create the block and the loop variable.
93     // FIXME: Allow passing in a proper location for the loop variable.
94     auto rangeType = range.getType().cast<pdl::RangeType>();
95     state.regions.front()->emplaceBlock();
96     state.regions.front()->addArgument(rangeType.getElementType(),
97                                        state.location);
98   }
99 }
100 
101 ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) {
102   // Parse the loop variable followed by type.
103   OpAsmParser::OperandType loopVariable;
104   Type loopVariableType;
105   if (parser.parseRegionArgument(loopVariable) ||
106       parser.parseColonType(loopVariableType))
107     return failure();
108 
109   // Parse the "in" keyword.
110   if (parser.parseKeyword("in", " after loop variable"))
111     return failure();
112 
113   // Parse the operand (value range).
114   OpAsmParser::OperandType operandInfo;
115   if (parser.parseOperand(operandInfo))
116     return failure();
117 
118   // Resolve the operand.
119   Type rangeType = pdl::RangeType::get(loopVariableType);
120   if (parser.resolveOperand(operandInfo, rangeType, result.operands))
121     return failure();
122 
123   // Parse the body region.
124   Region *body = result.addRegion();
125   if (parser.parseRegion(*body, {loopVariable}, {loopVariableType}))
126     return failure();
127 
128   // Parse the attribute dictionary.
129   if (parser.parseOptionalAttrDict(result.attributes))
130     return failure();
131 
132   // Parse the successor.
133   Block *successor;
134   if (parser.parseArrow() || parser.parseSuccessor(successor))
135     return failure();
136   result.addSuccessors(successor);
137 
138   return success();
139 }
140 
141 void ForEachOp::print(OpAsmPrinter &p) {
142   BlockArgument arg = getLoopVariable();
143   p << ' ' << arg << " : " << arg.getType() << " in " << getValues() << ' ';
144   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
145   p.printOptionalAttrDict((*this)->getAttrs());
146   p << " -> ";
147   p.printSuccessor(getSuccessor());
148 }
149 
150 LogicalResult ForEachOp::verify() {
151   // Verify that the operation has exactly one argument.
152   if (getRegion().getNumArguments() != 1)
153     return emitOpError("requires exactly one argument");
154 
155   // Verify that the loop variable and the operand (value range)
156   // have compatible types.
157   BlockArgument arg = getLoopVariable();
158   Type rangeType = pdl::RangeType::get(arg.getType());
159   if (rangeType != getValues().getType())
160     return emitOpError("operand must be a range of loop variable type");
161 
162   return success();
163 }
164 
165 //===----------------------------------------------------------------------===//
166 // pdl_interp::FuncOp
167 //===----------------------------------------------------------------------===//
168 
169 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
170                    FunctionType type, ArrayRef<NamedAttribute> attrs) {
171   buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
172 }
173 
174 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
175   auto buildFuncType =
176       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
177          function_interface_impl::VariadicFlag,
178          std::string &) { return builder.getFunctionType(argTypes, results); };
179 
180   return function_interface_impl::parseFunctionOp(
181       parser, result, /*allowVariadic=*/false, buildFuncType);
182 }
183 
184 void FuncOp::print(OpAsmPrinter &p) {
185   function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // pdl_interp::GetValueTypeOp
190 //===----------------------------------------------------------------------===//
191 
192 /// Given the result type of a `GetValueTypeOp`, return the expected input type.
193 static Type getGetValueTypeOpValueType(Type type) {
194   Type valueTy = pdl::ValueType::get(type.getContext());
195   return type.isa<pdl::RangeType>() ? pdl::RangeType::get(valueTy) : valueTy;
196 }
197 
198 //===----------------------------------------------------------------------===//
199 // pdl_interp::SwitchAttributeOp
200 //===----------------------------------------------------------------------===//
201 
202 LogicalResult SwitchAttributeOp::verify() { return verifySwitchOp(*this); }
203 
204 //===----------------------------------------------------------------------===//
205 // pdl_interp::SwitchOperandCountOp
206 //===----------------------------------------------------------------------===//
207 
208 LogicalResult SwitchOperandCountOp::verify() { return verifySwitchOp(*this); }
209 
210 //===----------------------------------------------------------------------===//
211 // pdl_interp::SwitchOperationNameOp
212 //===----------------------------------------------------------------------===//
213 
214 LogicalResult SwitchOperationNameOp::verify() { return verifySwitchOp(*this); }
215 
216 //===----------------------------------------------------------------------===//
217 // pdl_interp::SwitchResultCountOp
218 //===----------------------------------------------------------------------===//
219 
220 LogicalResult SwitchResultCountOp::verify() { return verifySwitchOp(*this); }
221 
222 //===----------------------------------------------------------------------===//
223 // pdl_interp::SwitchTypeOp
224 //===----------------------------------------------------------------------===//
225 
226 LogicalResult SwitchTypeOp::verify() { return verifySwitchOp(*this); }
227 
228 //===----------------------------------------------------------------------===//
229 // pdl_interp::SwitchTypesOp
230 //===----------------------------------------------------------------------===//
231 
232 LogicalResult SwitchTypesOp::verify() { return verifySwitchOp(*this); }
233 
234 //===----------------------------------------------------------------------===//
235 // TableGen Auto-Generated Op and Interface Definitions
236 //===----------------------------------------------------------------------===//
237 
238 #define GET_OP_CLASSES
239 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
240