1 //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/FunctionImplementation.h"
12 
13 using namespace mlir;
14 using namespace mlir::ml_program;
15 
16 //===----------------------------------------------------------------------===//
17 // Custom asm helpers
18 //===----------------------------------------------------------------------===//
19 
20 /// Parse and print an ordering clause for a variadic of consuming tokens
21 /// and an producing token.
22 ///
23 /// Syntax:
24 ///   ordering(%0, %1 -> !ml_program.token)
25 ///   ordering(() -> !ml_program.token)
26 ///
27 /// If both the consuming and producing token are not present on the op, then
28 /// the clause prints nothing.
parseTokenOrdering(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & consumeTokens,Type & produceTokenType)29 static ParseResult parseTokenOrdering(
30     OpAsmParser &parser,
31     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &consumeTokens,
32     Type &produceTokenType) {
33   if (failed(parser.parseOptionalKeyword("ordering")) ||
34       failed(parser.parseLParen()))
35     return success();
36 
37   // Parse consuming token list. If there are no consuming tokens, the
38   // '()' null list represents this.
39   if (succeeded(parser.parseOptionalLParen())) {
40     if (failed(parser.parseRParen()))
41       return failure();
42   } else {
43     if (failed(parser.parseOperandList(consumeTokens,
44                                        /*requiredOperandCount=*/-1)))
45       return failure();
46   }
47 
48   // Parse producer token.
49   if (failed(parser.parseArrow()))
50     return failure();
51   if (failed(parser.parseType(produceTokenType)))
52     return failure();
53 
54   if (failed(parser.parseRParen()))
55     return failure();
56 
57   return success();
58 }
59 
printTokenOrdering(OpAsmPrinter & p,Operation * op,OperandRange consumeTokens,Type produceTokenType)60 static void printTokenOrdering(OpAsmPrinter &p, Operation *op,
61                                OperandRange consumeTokens,
62                                Type produceTokenType) {
63   if (consumeTokens.empty() && !produceTokenType)
64     return;
65 
66   p << " ordering(";
67   if (consumeTokens.empty())
68     p << "()";
69   else
70     p.printOperands(consumeTokens);
71   if (produceTokenType) {
72     p << " -> ";
73     p.printType(produceTokenType);
74   }
75   p << ")";
76 }
77 
78 /// some.op custom<TypeOrAttr>($type, $attr)
79 ///
80 /// Uninitialized:
81 ///   some.op : tensor<3xi32>
82 /// Initialized to narrower type than op:
83 ///   some.op (dense<0> : tensor<3xi32>) : tensor<?xi32>
parseTypedInitialValue(OpAsmParser & parser,TypeAttr & typeAttr,Attribute & attr)84 static ParseResult parseTypedInitialValue(OpAsmParser &parser,
85                                           TypeAttr &typeAttr, Attribute &attr) {
86   if (succeeded(parser.parseOptionalLParen())) {
87     if (failed(parser.parseAttribute(attr)))
88       return failure();
89     if (failed(parser.parseRParen()))
90       return failure();
91   }
92 
93   Type type;
94   if (failed(parser.parseColonType(type)))
95     return failure();
96   typeAttr = TypeAttr::get(type);
97   return success();
98 }
99 
printTypedInitialValue(OpAsmPrinter & p,Operation * op,TypeAttr type,Attribute attr)100 static void printTypedInitialValue(OpAsmPrinter &p, Operation *op,
101                                    TypeAttr type, Attribute attr) {
102   if (attr) {
103     p << "(";
104     p.printAttribute(attr);
105     p << ")";
106   }
107 
108   p << " : ";
109   p.printAttribute(type);
110 }
111 
112 /// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
113 /// ->
114 /// some.op public @foo
115 /// some.op private @foo
parseSymbolVisibility(OpAsmParser & parser,StringAttr & symVisibilityAttr)116 static ParseResult parseSymbolVisibility(OpAsmParser &parser,
117                                          StringAttr &symVisibilityAttr) {
118   StringRef symVisibility;
119   (void)parser.parseOptionalKeyword(&symVisibility,
120                                     {"public", "private", "nested"});
121   if (symVisibility.empty())
122     return parser.emitError(parser.getCurrentLocation())
123            << "expected 'public', 'private', or 'nested'";
124   if (!symVisibility.empty())
125     symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
126   return success();
127 }
128 
printSymbolVisibility(OpAsmPrinter & p,Operation * op,StringAttr symVisibilityAttr)129 static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
130                                   StringAttr symVisibilityAttr) {
131   if (!symVisibilityAttr)
132     p << "public";
133   else
134     p << symVisibilityAttr.getValue();
135 }
136 
137 //===----------------------------------------------------------------------===//
138 // TableGen'd op method definitions
139 //===----------------------------------------------------------------------===//
140 
141 #define GET_OP_CLASSES
142 #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
143 
144 //===----------------------------------------------------------------------===//
145 // FuncOp
146 //===----------------------------------------------------------------------===//
147 
parse(OpAsmParser & parser,OperationState & result)148 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
149   auto buildFuncType =
150       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
151          function_interface_impl::VariadicFlag,
152          std::string &) { return builder.getFunctionType(argTypes, results); };
153 
154   return function_interface_impl::parseFunctionOp(
155       parser, result, /*allowVariadic=*/false, buildFuncType);
156 }
157 
print(OpAsmPrinter & p)158 void FuncOp::print(OpAsmPrinter &p) {
159   function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
160 }
161 
162 //===----------------------------------------------------------------------===//
163 // GlobalOp
164 //===----------------------------------------------------------------------===//
165 
verify()166 LogicalResult GlobalOp::verify() {
167   if (!getIsMutable() && !getValue())
168     return emitOpError() << "immutable global must have an initial value";
169   return success();
170 }
171 
172 //===----------------------------------------------------------------------===//
173 // GlobalLoadOp
174 //===----------------------------------------------------------------------===//
175 
getGlobalOp(SymbolTableCollection & symbolTable)176 GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) {
177   return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
178       getOperation()->getParentOp(), getGlobalAttr());
179 }
180 
181 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)182 GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
183   GlobalOp referrent = getGlobalOp(symbolTable);
184   if (!referrent)
185     return emitOpError() << "undefined global: " << getGlobal();
186 
187   if (referrent.getType() != getResult().getType()) {
188     return emitOpError() << "cannot load from global typed "
189                          << referrent.getType() << " as "
190                          << getResult().getType();
191   }
192 
193   return success();
194 }
195 
196 //===----------------------------------------------------------------------===//
197 // GlobalLoadConstOp
198 //===----------------------------------------------------------------------===//
199 
getGlobalOp(SymbolTableCollection & symbolTable)200 GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) {
201   return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
202       getOperation()->getParentOp(), getGlobalAttr());
203 }
204 
205 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)206 GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
207   GlobalOp referrent = getGlobalOp(symbolTable);
208   if (!referrent)
209     return emitOpError() << "undefined global: " << getGlobal();
210 
211   if (referrent.getIsMutable())
212     return emitOpError() << "cannot load as const from mutable global "
213                          << getGlobal();
214 
215   if (referrent.getType() != getResult().getType())
216     return emitOpError() << "cannot load from global typed "
217                          << referrent.getType() << " as "
218                          << getResult().getType();
219 
220   return success();
221 }
222 
223 //===----------------------------------------------------------------------===//
224 // GlobalLoadGraphOp
225 //===----------------------------------------------------------------------===//
226 
getGlobalOp(SymbolTableCollection & symbolTable)227 GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
228   return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
229       getOperation()->getParentOp(), getGlobalAttr());
230 }
231 
232 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)233 GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
234   GlobalOp referrent = getGlobalOp(symbolTable);
235   if (!referrent)
236     return emitOpError() << "undefined global: " << getGlobal();
237 
238   if (referrent.getType() != getResult().getType()) {
239     return emitOpError() << "cannot load from global typed "
240                          << referrent.getType() << " as "
241                          << getResult().getType();
242   }
243 
244   return success();
245 }
246 
247 //===----------------------------------------------------------------------===//
248 // GlobalStoreOp
249 //===----------------------------------------------------------------------===//
250 
getGlobalOp(SymbolTableCollection & symbolTable)251 GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) {
252   return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
253       getOperation()->getParentOp(), getGlobalAttr());
254 }
255 
256 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)257 GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
258   GlobalOp referrent = getGlobalOp(symbolTable);
259   if (!referrent)
260     return emitOpError() << "undefined global: " << getGlobal();
261 
262   if (!referrent.getIsMutable()) {
263     return emitOpError() << "cannot store to an immutable global "
264                          << getGlobal();
265   }
266 
267   if (referrent.getType() != getValue().getType()) {
268     return emitOpError() << "cannot store to a global typed "
269                          << referrent.getType() << " from "
270                          << getValue().getType();
271   }
272 
273   return success();
274 }
275 
276 //===----------------------------------------------------------------------===//
277 // GlobalStoreGraphOp
278 //===----------------------------------------------------------------------===//
279 
getGlobalOp(SymbolTableCollection & symbolTable)280 GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
281   return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
282       getOperation()->getParentOp(), getGlobalAttr());
283 }
284 
285 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)286 GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
287   GlobalOp referrent = getGlobalOp(symbolTable);
288   if (!referrent)
289     return emitOpError() << "undefined global: " << getGlobal();
290 
291   if (!referrent.getIsMutable()) {
292     return emitOpError() << "cannot store to an immutable global "
293                          << getGlobal();
294   }
295 
296   if (referrent.getType() != getValue().getType()) {
297     return emitOpError() << "cannot store to a global typed "
298                          << referrent.getType() << " from "
299                          << getValue().getType();
300   }
301 
302   return success();
303 }
304 
305 //===----------------------------------------------------------------------===//
306 // SubgraphOp
307 //===----------------------------------------------------------------------===//
308 
parse(OpAsmParser & parser,OperationState & result)309 ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
310   auto buildFuncType =
311       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
312          function_interface_impl::VariadicFlag,
313          std::string &) { return builder.getFunctionType(argTypes, results); };
314 
315   return function_interface_impl::parseFunctionOp(
316       parser, result, /*allowVariadic=*/false, buildFuncType);
317 }
318 
print(OpAsmPrinter & p)319 void SubgraphOp::print(OpAsmPrinter &p) {
320   function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
321 }
322 
323 //===----------------------------------------------------------------------===//
324 // OutputOp
325 //===----------------------------------------------------------------------===//
326 
verify()327 LogicalResult OutputOp::verify() {
328   auto function = cast<SubgraphOp>((*this)->getParentOp());
329 
330   // The operand number and types must match the function signature.
331   const auto &results = function.getFunctionType().getResults();
332   if (getNumOperands() != results.size())
333     return emitOpError("has ")
334            << getNumOperands() << " operands, but enclosing function (@"
335            << function.getName() << ") outputs " << results.size();
336 
337   for (unsigned i = 0, e = results.size(); i != e; ++i)
338     if (getOperand(i).getType() != results[i])
339       return emitError() << "type of output operand " << i << " ("
340                          << getOperand(i).getType()
341                          << ") doesn't match function result type ("
342                          << results[i] << ")"
343                          << " in function @" << function.getName();
344 
345   return success();
346 }
347 
348 //===----------------------------------------------------------------------===//
349 // ReturnOp
350 //===----------------------------------------------------------------------===//
351 
verify()352 LogicalResult ReturnOp::verify() {
353   auto function = cast<FuncOp>((*this)->getParentOp());
354 
355   // The operand number and types must match the function signature.
356   const auto &results = function.getFunctionType().getResults();
357   if (getNumOperands() != results.size())
358     return emitOpError("has ")
359            << getNumOperands() << " operands, but enclosing function (@"
360            << function.getName() << ") returns " << results.size();
361 
362   for (unsigned i = 0, e = results.size(); i != e; ++i)
363     if (getOperand(i).getType() != results[i])
364       return emitError() << "type of return operand " << i << " ("
365                          << getOperand(i).getType()
366                          << ") doesn't match function result type ("
367                          << results[i] << ")"
368                          << " in function @" << function.getName();
369 
370   return success();
371 }
372