1 //===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
10 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/FunctionImplementation.h"
14 
15 using namespace mlir;
16 using namespace mlir::pdl_interp;
17 
18 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.cpp.inc"
19 
20 //===----------------------------------------------------------------------===//
21 // PDLInterp Dialect
22 //===----------------------------------------------------------------------===//
23 
/// Register all operations of the PDLInterp dialect.
///
/// The list of operation classes comes from TableGen. Defining GET_OP_LIST
/// before including the generated .cpp.inc makes it expand to the op classes.
/// Those classes are used directly as the template argument list of
/// addOperations<...>().
void PDLInterpDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
      >();
}
30 
31 template <typename OpT>
32 static LogicalResult verifySwitchOp(OpT op) {
33   // Verify that the number of case destinations matches the number of case
34   // values.
35   size_t numDests = op.getCases().size();
36   size_t numValues = op.getCaseValues().size();
37   if (numDests != numValues) {
38     return op.emitOpError(
39                "expected number of cases to match the number of case "
40                "values, got ")
41            << numDests << " but expected " << numValues;
42   }
43   return success();
44 }
45 
46 //===----------------------------------------------------------------------===//
47 // pdl_interp::CreateOperationOp
48 //===----------------------------------------------------------------------===//
49 
50 static ParseResult parseCreateOperationOpAttributes(
51     OpAsmParser &p,
52     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
53     ArrayAttr &attrNamesAttr) {
54   Builder &builder = p.getBuilder();
55   SmallVector<Attribute, 4> attrNames;
56   if (succeeded(p.parseOptionalLBrace())) {
57     do {
58       StringAttr nameAttr;
59       OpAsmParser::UnresolvedOperand operand;
60       if (p.parseAttribute(nameAttr) || p.parseEqual() ||
61           p.parseOperand(operand))
62         return failure();
63       attrNames.push_back(nameAttr);
64       attrOperands.push_back(operand);
65     } while (succeeded(p.parseOptionalComma()));
66     if (p.parseRBrace())
67       return failure();
68   }
69   attrNamesAttr = builder.getArrayAttr(attrNames);
70   return success();
71 }
72 
73 static void printCreateOperationOpAttributes(OpAsmPrinter &p,
74                                              CreateOperationOp op,
75                                              OperandRange attrArgs,
76                                              ArrayAttr attrNames) {
77   if (attrNames.empty())
78     return;
79   p << " {";
80   interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
81                   [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
82   p << '}';
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // pdl_interp::ForEachOp
87 //===----------------------------------------------------------------------===//
88 
89 void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
90                       Value range, Block *successor, bool initLoop) {
91   build(builder, state, range, successor);
92   if (initLoop) {
93     // Create the block and the loop variable.
94     // FIXME: Allow passing in a proper location for the loop variable.
95     auto rangeType = range.getType().cast<pdl::RangeType>();
96     state.regions.front()->emplaceBlock();
97     state.regions.front()->addArgument(rangeType.getElementType(),
98                                        state.location);
99   }
100 }
101 
102 ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) {
103   // Parse the loop variable followed by type.
104   OpAsmParser::Argument loopVariable;
105   OpAsmParser::UnresolvedOperand operandInfo;
106   if (parser.parseArgument(loopVariable, /*allowType=*/true) ||
107       parser.parseKeyword("in", " after loop variable") ||
108       // Parse the operand (value range).
109       parser.parseOperand(operandInfo))
110     return failure();
111 
112   // Resolve the operand.
113   Type rangeType = pdl::RangeType::get(loopVariable.type);
114   if (parser.resolveOperand(operandInfo, rangeType, result.operands))
115     return failure();
116 
117   // Parse the body region.
118   Region *body = result.addRegion();
119   Block *successor;
120   if (parser.parseRegion(*body, loopVariable) ||
121       parser.parseOptionalAttrDict(result.attributes) ||
122       // Parse the successor.
123       parser.parseArrow() || parser.parseSuccessor(successor))
124     return failure();
125 
126   result.addSuccessors(successor);
127   return success();
128 }
129 
130 void ForEachOp::print(OpAsmPrinter &p) {
131   BlockArgument arg = getLoopVariable();
132   p << ' ' << arg << " : " << arg.getType() << " in " << getValues() << ' ';
133   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
134   p.printOptionalAttrDict((*this)->getAttrs());
135   p << " -> ";
136   p.printSuccessor(getSuccessor());
137 }
138 
139 LogicalResult ForEachOp::verify() {
140   // Verify that the operation has exactly one argument.
141   if (getRegion().getNumArguments() != 1)
142     return emitOpError("requires exactly one argument");
143 
144   // Verify that the loop variable and the operand (value range)
145   // have compatible types.
146   BlockArgument arg = getLoopVariable();
147   Type rangeType = pdl::RangeType::get(arg.getType());
148   if (rangeType != getValues().getType())
149     return emitOpError("operand must be a range of loop variable type");
150 
151   return success();
152 }
153 
154 //===----------------------------------------------------------------------===//
155 // pdl_interp::FuncOp
156 //===----------------------------------------------------------------------===//
157 
158 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
159                    FunctionType type, ArrayRef<NamedAttribute> attrs) {
160   buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
161 }
162 
163 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
164   auto buildFuncType =
165       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
166          function_interface_impl::VariadicFlag,
167          std::string &) { return builder.getFunctionType(argTypes, results); };
168 
169   return function_interface_impl::parseFunctionOp(
170       parser, result, /*allowVariadic=*/false, buildFuncType);
171 }
172 
173 void FuncOp::print(OpAsmPrinter &p) {
174   function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
175 }
176 
177 //===----------------------------------------------------------------------===//
178 // pdl_interp::GetValueTypeOp
179 //===----------------------------------------------------------------------===//
180 
181 /// Given the result type of a `GetValueTypeOp`, return the expected input type.
182 static Type getGetValueTypeOpValueType(Type type) {
183   Type valueTy = pdl::ValueType::get(type.getContext());
184   return type.isa<pdl::RangeType>() ? pdl::RangeType::get(valueTy) : valueTy;
185 }
186 
187 //===----------------------------------------------------------------------===//
188 // pdl_interp::SwitchAttributeOp
189 //===----------------------------------------------------------------------===//
190 
191 LogicalResult SwitchAttributeOp::verify() { return verifySwitchOp(*this); }
192 
193 //===----------------------------------------------------------------------===//
194 // pdl_interp::SwitchOperandCountOp
195 //===----------------------------------------------------------------------===//
196 
197 LogicalResult SwitchOperandCountOp::verify() { return verifySwitchOp(*this); }
198 
199 //===----------------------------------------------------------------------===//
200 // pdl_interp::SwitchOperationNameOp
201 //===----------------------------------------------------------------------===//
202 
203 LogicalResult SwitchOperationNameOp::verify() { return verifySwitchOp(*this); }
204 
205 //===----------------------------------------------------------------------===//
206 // pdl_interp::SwitchResultCountOp
207 //===----------------------------------------------------------------------===//
208 
209 LogicalResult SwitchResultCountOp::verify() { return verifySwitchOp(*this); }
210 
211 //===----------------------------------------------------------------------===//
212 // pdl_interp::SwitchTypeOp
213 //===----------------------------------------------------------------------===//
214 
215 LogicalResult SwitchTypeOp::verify() { return verifySwitchOp(*this); }
216 
217 //===----------------------------------------------------------------------===//
218 // pdl_interp::SwitchTypesOp
219 //===----------------------------------------------------------------------===//
220 
221 LogicalResult SwitchTypesOp::verify() { return verifySwitchOp(*this); }
222 
223 //===----------------------------------------------------------------------===//
224 // TableGen Auto-Generated Op and Interface Definitions
225 //===----------------------------------------------------------------------===//
226 
227 #define GET_OP_CLASSES
228 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
229