1 //===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
10 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/DialectImplementation.h"
13 
14 using namespace mlir;
15 using namespace mlir::pdl_interp;
16 
17 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.cpp.inc"
18 
19 //===----------------------------------------------------------------------===//
20 // PDLInterp Dialect
21 //===----------------------------------------------------------------------===//
22 
/// Registers every PDLInterp operation with this dialect.
///
/// The operation list is the `GET_OP_LIST` section of the TableGen-generated
/// `PDLInterpOps.cpp.inc`. It is expanded directly into the `addOperations`
/// template argument list, which is why the preprocessor lines sit between
/// the angle brackets.
void PDLInterpDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
      >();
}
29 
30 //===----------------------------------------------------------------------===//
31 // pdl_interp::CreateOperationOp
32 //===----------------------------------------------------------------------===//
33 
34 static ParseResult parseCreateOperationOpAttributes(
35     OpAsmParser &p, SmallVectorImpl<OpAsmParser::OperandType> &attrOperands,
36     ArrayAttr &attrNamesAttr) {
37   Builder &builder = p.getBuilder();
38   SmallVector<Attribute, 4> attrNames;
39   if (succeeded(p.parseOptionalLBrace())) {
40     do {
41       StringAttr nameAttr;
42       OpAsmParser::OperandType operand;
43       if (p.parseAttribute(nameAttr) || p.parseEqual() ||
44           p.parseOperand(operand))
45         return failure();
46       attrNames.push_back(nameAttr);
47       attrOperands.push_back(operand);
48     } while (succeeded(p.parseOptionalComma()));
49     if (p.parseRBrace())
50       return failure();
51   }
52   attrNamesAttr = builder.getArrayAttr(attrNames);
53   return success();
54 }
55 
56 static void printCreateOperationOpAttributes(OpAsmPrinter &p,
57                                              CreateOperationOp op,
58                                              OperandRange attrArgs,
59                                              ArrayAttr attrNames) {
60   if (attrNames.empty())
61     return;
62   p << " {";
63   interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
64                   [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
65   p << '}';
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // pdl_interp::ForEachOp
70 //===----------------------------------------------------------------------===//
71 
72 void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
73                       Value range, Block *successor, bool initLoop) {
74   build(builder, state, range, successor);
75   if (initLoop) {
76     // Create the block and the loop variable.
77     auto rangeType = range.getType().cast<pdl::RangeType>();
78     state.regions.front()->emplaceBlock();
79     state.regions.front()->addArgument(rangeType.getElementType());
80   }
81 }
82 
83 static ParseResult parseForEachOp(OpAsmParser &parser, OperationState &result) {
84   // Parse the loop variable followed by type.
85   OpAsmParser::OperandType loopVariable;
86   Type loopVariableType;
87   if (parser.parseRegionArgument(loopVariable) ||
88       parser.parseColonType(loopVariableType))
89     return failure();
90 
91   // Parse the "in" keyword.
92   if (parser.parseKeyword("in", " after loop variable"))
93     return failure();
94 
95   // Parse the operand (value range).
96   OpAsmParser::OperandType operandInfo;
97   if (parser.parseOperand(operandInfo))
98     return failure();
99 
100   // Resolve the operand.
101   Type rangeType = pdl::RangeType::get(loopVariableType);
102   if (parser.resolveOperand(operandInfo, rangeType, result.operands))
103     return failure();
104 
105   // Parse the body region.
106   Region *body = result.addRegion();
107   if (parser.parseRegion(*body, {loopVariable}, {loopVariableType}))
108     return failure();
109 
110   // Parse the attribute dictionary.
111   if (parser.parseOptionalAttrDict(result.attributes))
112     return failure();
113 
114   // Parse the successor.
115   Block *successor;
116   if (parser.parseArrow() || parser.parseSuccessor(successor))
117     return failure();
118   result.addSuccessors(successor);
119 
120   return success();
121 }
122 
123 static void print(OpAsmPrinter &p, ForEachOp op) {
124   BlockArgument arg = op.getLoopVariable();
125   p << ' ' << arg << " : " << arg.getType() << " in " << op.values() << ' ';
126   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
127   p.printOptionalAttrDict(op->getAttrs());
128   p << " -> ";
129   p.printSuccessor(op.successor());
130 }
131 
132 static LogicalResult verify(ForEachOp op) {
133   // Verify that the operation has exactly one argument.
134   if (op.region().getNumArguments() != 1)
135     return op.emitOpError("requires exactly one argument");
136 
137   // Verify that the loop variable and the operand (value range)
138   // have compatible types.
139   BlockArgument arg = op.getLoopVariable();
140   Type rangeType = pdl::RangeType::get(arg.getType());
141   if (rangeType != op.values().getType())
142     return op.emitOpError("operand must be a range of loop variable type");
143 
144   return success();
145 }
146 
147 //===----------------------------------------------------------------------===//
148 // pdl_interp::GetValueTypeOp
149 //===----------------------------------------------------------------------===//
150 
151 /// Given the result type of a `GetValueTypeOp`, return the expected input type.
152 static Type getGetValueTypeOpValueType(Type type) {
153   Type valueTy = pdl::ValueType::get(type.getContext());
154   return type.isa<pdl::RangeType>() ? pdl::RangeType::get(valueTy) : valueTy;
155 }
156 
157 //===----------------------------------------------------------------------===//
158 // TableGen Auto-Generated Op and Interface Definitions
159 //===----------------------------------------------------------------------===//
160 
161 #define GET_OP_CLASSES
162 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
163