1 //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/FunctionImplementation.h"
12 
13 using namespace mlir;
14 using namespace mlir::ml_program;
15 
16 //===----------------------------------------------------------------------===//
17 // Custom asm helpers
18 //===----------------------------------------------------------------------===//
19 
20 /// Parse and print an ordering clause for a variadic of consuming tokens
21 /// and an optional producing token.
22 ///
23 /// Syntax:
24 ///   ordering(%0, %1 -> !ml_program.token)
25 ///   ordering(() -> !ml_program.token)
26 ///   ordering(%0, %1)
27 ///
28 /// If both the consuming and producing token are not present on the op, then
29 /// the clause prints nothing.
30 static ParseResult parseTokenOrdering(
31     OpAsmParser &parser,
32     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &consumeTokens,
33     Type &produceTokenType) {
34   if (failed(parser.parseOptionalKeyword("ordering")) ||
35       failed(parser.parseLParen()))
36     return success();
37 
38   // Parse consuming token list. If there are no consuming tokens, the
39   // '()' null list represents this.
40   if (succeeded(parser.parseOptionalLParen())) {
41     if (failed(parser.parseRParen()))
42       return failure();
43   } else {
44     if (failed(parser.parseOperandList(consumeTokens,
45                                        /*requiredOperandCount=*/-1)))
46       return failure();
47   }
48 
49   // Parse optional producer token.
50   if (succeeded(parser.parseOptionalArrow()))
51     if (failed(parser.parseType(produceTokenType)))
52       return failure();
53 
54   if (failed(parser.parseRParen()))
55     return failure();
56 
57   return success();
58 }
59 
60 static void printTokenOrdering(OpAsmPrinter &p, Operation *op,
61                                OperandRange consumeTokens,
62                                Type produceTokenType) {
63   if (consumeTokens.empty() && !produceTokenType)
64     return;
65 
66   p << " ordering(";
67   if (consumeTokens.empty())
68     p << "()";
69   else
70     p.printOperands(consumeTokens);
71   if (produceTokenType) {
72     p << " -> ";
73     p.printType(produceTokenType);
74   }
75   p << ")";
76 }
77 
78 /// some.op custom<TypeOrAttr>($type, $attr)
79 ///
80 /// Uninitialized:
81 ///   some.op : tensor<3xi32>
82 /// Initialized to narrower type than op:
83 ///   some.op (dense<0> : tensor<3xi32>) : tensor<?xi32>
84 static ParseResult parseTypedInitialValue(OpAsmParser &parser,
85                                           TypeAttr &typeAttr, Attribute &attr) {
86   if (succeeded(parser.parseOptionalLParen())) {
87     if (failed(parser.parseAttribute(attr)))
88       return failure();
89     if (failed(parser.parseRParen()))
90       return failure();
91   }
92 
93   Type type;
94   if (failed(parser.parseColonType(type)))
95     return failure();
96   typeAttr = TypeAttr::get(type);
97   return success();
98 }
99 
100 static void printTypedInitialValue(OpAsmPrinter &p, Operation *op,
101                                    TypeAttr type, Attribute attr) {
102   if (attr) {
103     p << "(";
104     p.printAttribute(attr);
105     p << ")";
106   }
107 
108   p << " : ";
109   p.printAttribute(type);
110 }
111 
112 /// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
113 /// ->
114 /// some.op public @foo
115 /// some.op private @foo
116 static ParseResult parseSymbolVisibility(OpAsmParser &parser,
117                                          StringAttr &symVisibilityAttr) {
118   StringRef symVisibility;
119   (void)parser.parseOptionalKeyword(&symVisibility,
120                                     {"public", "private", "nested"});
121   if (symVisibility.empty())
122     return parser.emitError(parser.getCurrentLocation())
123            << "expected 'public', 'private', or 'nested'";
124   if (!symVisibility.empty())
125     symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
126   return success();
127 }
128 
129 static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
130                                   StringAttr symVisibilityAttr) {
131   if (!symVisibilityAttr)
132     p << "public";
133   else
134     p << symVisibilityAttr.getValue();
135 }
136 
137 //===----------------------------------------------------------------------===//
138 // TableGen'd op method definitions
139 //===----------------------------------------------------------------------===//
140 
141 #define GET_OP_CLASSES
142 #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
143 
144 //===----------------------------------------------------------------------===//
145 // FuncOp
146 //===----------------------------------------------------------------------===//
147 
148 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
149   auto buildFuncType =
150       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
151          function_interface_impl::VariadicFlag,
152          std::string &) { return builder.getFunctionType(argTypes, results); };
153 
154   return function_interface_impl::parseFunctionOp(
155       parser, result, /*allowVariadic=*/false, buildFuncType);
156 }
157 
158 void FuncOp::print(OpAsmPrinter &p) {
159   function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
160 }
161 
162 //===----------------------------------------------------------------------===//
163 // GlobalOp
164 //===----------------------------------------------------------------------===//
165 
166 LogicalResult GlobalOp::verify() {
167   if (!getIsMutable() && !getValue())
168     return emitOpError() << "immutable global must have an initial value";
169   return success();
170 }
171 
172 //===----------------------------------------------------------------------===//
173 // GlobalLoadOp
174 //===----------------------------------------------------------------------===//
175 
176 GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) {
177   return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
178       getOperation()->getParentOp(), getGlobalAttr());
179 }
180 
181 LogicalResult
182 GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
183   GlobalOp referrent = getGlobalOp(symbolTable);
184   if (!referrent)
185     return emitOpError() << "undefined global: " << getGlobal();
186 
187   if (referrent.getType() != getResult().getType()) {
188     return emitOpError() << "cannot load from global typed "
189                          << referrent.getType() << " as "
190                          << getResult().getType();
191   }
192 
193   return success();
194 }
195 
196 //===----------------------------------------------------------------------===//
197 // GlobalLoadConstOp
198 //===----------------------------------------------------------------------===//
199 
200 GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) {
201   return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
202       getOperation()->getParentOp(), getGlobalAttr());
203 }
204 
205 LogicalResult
206 GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
207   GlobalOp referrent = getGlobalOp(symbolTable);
208   if (!referrent)
209     return emitOpError() << "undefined global: " << getGlobal();
210 
211   if (referrent.getIsMutable())
212     return emitOpError() << "cannot load as const from mutable global "
213                          << getGlobal();
214 
215   if (referrent.getType() != getResult().getType())
216     return emitOpError() << "cannot load from global typed "
217                          << referrent.getType() << " as "
218                          << getResult().getType();
219 
220   return success();
221 }
222 
223 //===----------------------------------------------------------------------===//
224 // GlobalStoreOp
225 //===----------------------------------------------------------------------===//
226 
227 GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) {
228   return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
229       getOperation()->getParentOp(), getGlobalAttr());
230 }
231 
232 LogicalResult
233 GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
234   GlobalOp referrent = getGlobalOp(symbolTable);
235   if (!referrent)
236     return emitOpError() << "undefined global: " << getGlobal();
237 
238   if (!referrent.getIsMutable()) {
239     return emitOpError() << "cannot store to an immutable global "
240                          << getGlobal();
241   }
242 
243   if (referrent.getType() != getValue().getType()) {
244     return emitOpError() << "cannot store to a global typed "
245                          << referrent.getType() << " from "
246                          << getValue().getType();
247   }
248 
249   return success();
250 }
251 
252 //===----------------------------------------------------------------------===//
253 // SubgraphOp
254 //===----------------------------------------------------------------------===//
255 
256 ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
257   auto buildFuncType =
258       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
259          function_interface_impl::VariadicFlag,
260          std::string &) { return builder.getFunctionType(argTypes, results); };
261 
262   return function_interface_impl::parseFunctionOp(
263       parser, result, /*allowVariadic=*/false, buildFuncType);
264 }
265 
266 void SubgraphOp::print(OpAsmPrinter &p) {
267   function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
268 }
269 
270 //===----------------------------------------------------------------------===//
271 // OutputOp
272 //===----------------------------------------------------------------------===//
273 
274 LogicalResult OutputOp::verify() {
275   auto function = cast<SubgraphOp>((*this)->getParentOp());
276 
277   // The operand number and types must match the function signature.
278   const auto &results = function.getFunctionType().getResults();
279   if (getNumOperands() != results.size())
280     return emitOpError("has ")
281            << getNumOperands() << " operands, but enclosing function (@"
282            << function.getName() << ") outputs " << results.size();
283 
284   for (unsigned i = 0, e = results.size(); i != e; ++i)
285     if (getOperand(i).getType() != results[i])
286       return emitError() << "type of output operand " << i << " ("
287                          << getOperand(i).getType()
288                          << ") doesn't match function result type ("
289                          << results[i] << ")"
290                          << " in function @" << function.getName();
291 
292   return success();
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // ReturnOp
297 //===----------------------------------------------------------------------===//
298 
299 LogicalResult ReturnOp::verify() {
300   auto function = cast<FuncOp>((*this)->getParentOp());
301 
302   // The operand number and types must match the function signature.
303   const auto &results = function.getFunctionType().getResults();
304   if (getNumOperands() != results.size())
305     return emitOpError("has ")
306            << getNumOperands() << " operands, but enclosing function (@"
307            << function.getName() << ") returns " << results.size();
308 
309   for (unsigned i = 0, e = results.size(); i != e; ++i)
310     if (getOperand(i).getType() != results[i])
311       return emitError() << "type of return operand " << i << " ("
312                          << getOperand(i).getType()
313                          << ") doesn't match function result type ("
314                          << results[i] << ")"
315                          << " in function @" << function.getName();
316 
317   return success();
318 }
319