1 //===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
10 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/FunctionImplementation.h"
14 
15 using namespace mlir;
16 using namespace mlir::pdl_interp;
17 
18 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.cpp.inc"
19 
20 //===----------------------------------------------------------------------===//
21 // PDLInterp Dialect
22 //===----------------------------------------------------------------------===//
23 
/// Registers every pdl_interp operation with the dialect.
///
/// The operation list is produced by TableGen: defining GET_OP_LIST before
/// including PDLInterpOps.cpp.inc expands to a comma-separated list of op
/// class names, which becomes the template argument list of addOperations.
void PDLInterpDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
      >();
}
30 
31 template <typename OpT>
32 static LogicalResult verifySwitchOp(OpT op) {
33   // Verify that the number of case destinations matches the number of case
34   // values.
35   size_t numDests = op.getCases().size();
36   size_t numValues = op.getCaseValues().size();
37   if (numDests != numValues) {
38     return op.emitOpError(
39                "expected number of cases to match the number of case "
40                "values, got ")
41            << numDests << " but expected " << numValues;
42   }
43   return success();
44 }
45 
46 //===----------------------------------------------------------------------===//
47 // pdl_interp::CreateOperationOp
48 //===----------------------------------------------------------------------===//
49 
50 static ParseResult parseCreateOperationOpAttributes(
51     OpAsmParser &p,
52     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
53     ArrayAttr &attrNamesAttr) {
54   Builder &builder = p.getBuilder();
55   SmallVector<Attribute, 4> attrNames;
56   if (succeeded(p.parseOptionalLBrace())) {
57     do {
58       StringAttr nameAttr;
59       OpAsmParser::UnresolvedOperand operand;
60       if (p.parseAttribute(nameAttr) || p.parseEqual() ||
61           p.parseOperand(operand))
62         return failure();
63       attrNames.push_back(nameAttr);
64       attrOperands.push_back(operand);
65     } while (succeeded(p.parseOptionalComma()));
66     if (p.parseRBrace())
67       return failure();
68   }
69   attrNamesAttr = builder.getArrayAttr(attrNames);
70   return success();
71 }
72 
73 static void printCreateOperationOpAttributes(OpAsmPrinter &p,
74                                              CreateOperationOp op,
75                                              OperandRange attrArgs,
76                                              ArrayAttr attrNames) {
77   if (attrNames.empty())
78     return;
79   p << " {";
80   interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
81                   [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
82   p << '}';
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // pdl_interp::ForEachOp
87 //===----------------------------------------------------------------------===//
88 
89 void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
90                       Value range, Block *successor, bool initLoop) {
91   build(builder, state, range, successor);
92   if (initLoop) {
93     // Create the block and the loop variable.
94     // FIXME: Allow passing in a proper location for the loop variable.
95     auto rangeType = range.getType().cast<pdl::RangeType>();
96     state.regions.front()->emplaceBlock();
97     state.regions.front()->addArgument(rangeType.getElementType(),
98                                        state.location);
99   }
100 }
101 
102 ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) {
103   // Parse the loop variable followed by type.
104   OpAsmParser::UnresolvedOperand loopVariable;
105   Type loopVariableType;
106   if (parser.parseOperand(loopVariable, /*allowResultNumber=*/false) ||
107       parser.parseColonType(loopVariableType))
108     return failure();
109 
110   // Parse the "in" keyword.
111   if (parser.parseKeyword("in", " after loop variable"))
112     return failure();
113 
114   // Parse the operand (value range).
115   OpAsmParser::UnresolvedOperand operandInfo;
116   if (parser.parseOperand(operandInfo))
117     return failure();
118 
119   // Resolve the operand.
120   Type rangeType = pdl::RangeType::get(loopVariableType);
121   if (parser.resolveOperand(operandInfo, rangeType, result.operands))
122     return failure();
123 
124   // Parse the body region.
125   Region *body = result.addRegion();
126   if (parser.parseRegion(*body, {loopVariable}, {loopVariableType}))
127     return failure();
128 
129   // Parse the attribute dictionary.
130   if (parser.parseOptionalAttrDict(result.attributes))
131     return failure();
132 
133   // Parse the successor.
134   Block *successor;
135   if (parser.parseArrow() || parser.parseSuccessor(successor))
136     return failure();
137   result.addSuccessors(successor);
138 
139   return success();
140 }
141 
142 void ForEachOp::print(OpAsmPrinter &p) {
143   BlockArgument arg = getLoopVariable();
144   p << ' ' << arg << " : " << arg.getType() << " in " << getValues() << ' ';
145   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
146   p.printOptionalAttrDict((*this)->getAttrs());
147   p << " -> ";
148   p.printSuccessor(getSuccessor());
149 }
150 
151 LogicalResult ForEachOp::verify() {
152   // Verify that the operation has exactly one argument.
153   if (getRegion().getNumArguments() != 1)
154     return emitOpError("requires exactly one argument");
155 
156   // Verify that the loop variable and the operand (value range)
157   // have compatible types.
158   BlockArgument arg = getLoopVariable();
159   Type rangeType = pdl::RangeType::get(arg.getType());
160   if (rangeType != getValues().getType())
161     return emitOpError("operand must be a range of loop variable type");
162 
163   return success();
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // pdl_interp::FuncOp
168 //===----------------------------------------------------------------------===//
169 
170 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
171                    FunctionType type, ArrayRef<NamedAttribute> attrs) {
172   buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
173 }
174 
175 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
176   auto buildFuncType =
177       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
178          function_interface_impl::VariadicFlag,
179          std::string &) { return builder.getFunctionType(argTypes, results); };
180 
181   return function_interface_impl::parseFunctionOp(
182       parser, result, /*allowVariadic=*/false, buildFuncType);
183 }
184 
185 void FuncOp::print(OpAsmPrinter &p) {
186   function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // pdl_interp::GetValueTypeOp
191 //===----------------------------------------------------------------------===//
192 
193 /// Given the result type of a `GetValueTypeOp`, return the expected input type.
194 static Type getGetValueTypeOpValueType(Type type) {
195   Type valueTy = pdl::ValueType::get(type.getContext());
196   return type.isa<pdl::RangeType>() ? pdl::RangeType::get(valueTy) : valueTy;
197 }
198 
199 //===----------------------------------------------------------------------===//
200 // pdl_interp::SwitchAttributeOp
201 //===----------------------------------------------------------------------===//
202 
203 LogicalResult SwitchAttributeOp::verify() { return verifySwitchOp(*this); }
204 
205 //===----------------------------------------------------------------------===//
206 // pdl_interp::SwitchOperandCountOp
207 //===----------------------------------------------------------------------===//
208 
209 LogicalResult SwitchOperandCountOp::verify() { return verifySwitchOp(*this); }
210 
211 //===----------------------------------------------------------------------===//
212 // pdl_interp::SwitchOperationNameOp
213 //===----------------------------------------------------------------------===//
214 
215 LogicalResult SwitchOperationNameOp::verify() { return verifySwitchOp(*this); }
216 
217 //===----------------------------------------------------------------------===//
218 // pdl_interp::SwitchResultCountOp
219 //===----------------------------------------------------------------------===//
220 
221 LogicalResult SwitchResultCountOp::verify() { return verifySwitchOp(*this); }
222 
223 //===----------------------------------------------------------------------===//
224 // pdl_interp::SwitchTypeOp
225 //===----------------------------------------------------------------------===//
226 
227 LogicalResult SwitchTypeOp::verify() { return verifySwitchOp(*this); }
228 
229 //===----------------------------------------------------------------------===//
230 // pdl_interp::SwitchTypesOp
231 //===----------------------------------------------------------------------===//
232 
233 LogicalResult SwitchTypesOp::verify() { return verifySwitchOp(*this); }
234 
235 //===----------------------------------------------------------------------===//
236 // TableGen Auto-Generated Op and Interface Definitions
237 //===----------------------------------------------------------------------===//
238 
239 #define GET_OP_CLASSES
240 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
241