1 //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/FunctionImplementation.h"
12
13 using namespace mlir;
14 using namespace mlir::ml_program;
15
16 //===----------------------------------------------------------------------===//
17 // Custom asm helpers
18 //===----------------------------------------------------------------------===//
19
20 /// Parse and print an ordering clause for a variadic of consuming tokens
21 /// and an producing token.
22 ///
23 /// Syntax:
24 /// ordering(%0, %1 -> !ml_program.token)
25 /// ordering(() -> !ml_program.token)
26 ///
27 /// If both the consuming and producing token are not present on the op, then
28 /// the clause prints nothing.
parseTokenOrdering(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & consumeTokens,Type & produceTokenType)29 static ParseResult parseTokenOrdering(
30 OpAsmParser &parser,
31 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &consumeTokens,
32 Type &produceTokenType) {
33 if (failed(parser.parseOptionalKeyword("ordering")) ||
34 failed(parser.parseLParen()))
35 return success();
36
37 // Parse consuming token list. If there are no consuming tokens, the
38 // '()' null list represents this.
39 if (succeeded(parser.parseOptionalLParen())) {
40 if (failed(parser.parseRParen()))
41 return failure();
42 } else {
43 if (failed(parser.parseOperandList(consumeTokens,
44 /*requiredOperandCount=*/-1)))
45 return failure();
46 }
47
48 // Parse producer token.
49 if (failed(parser.parseArrow()))
50 return failure();
51 if (failed(parser.parseType(produceTokenType)))
52 return failure();
53
54 if (failed(parser.parseRParen()))
55 return failure();
56
57 return success();
58 }
59
printTokenOrdering(OpAsmPrinter & p,Operation * op,OperandRange consumeTokens,Type produceTokenType)60 static void printTokenOrdering(OpAsmPrinter &p, Operation *op,
61 OperandRange consumeTokens,
62 Type produceTokenType) {
63 if (consumeTokens.empty() && !produceTokenType)
64 return;
65
66 p << " ordering(";
67 if (consumeTokens.empty())
68 p << "()";
69 else
70 p.printOperands(consumeTokens);
71 if (produceTokenType) {
72 p << " -> ";
73 p.printType(produceTokenType);
74 }
75 p << ")";
76 }
77
78 /// some.op custom<TypeOrAttr>($type, $attr)
79 ///
80 /// Uninitialized:
81 /// some.op : tensor<3xi32>
82 /// Initialized to narrower type than op:
83 /// some.op (dense<0> : tensor<3xi32>) : tensor<?xi32>
parseTypedInitialValue(OpAsmParser & parser,TypeAttr & typeAttr,Attribute & attr)84 static ParseResult parseTypedInitialValue(OpAsmParser &parser,
85 TypeAttr &typeAttr, Attribute &attr) {
86 if (succeeded(parser.parseOptionalLParen())) {
87 if (failed(parser.parseAttribute(attr)))
88 return failure();
89 if (failed(parser.parseRParen()))
90 return failure();
91 }
92
93 Type type;
94 if (failed(parser.parseColonType(type)))
95 return failure();
96 typeAttr = TypeAttr::get(type);
97 return success();
98 }
99
printTypedInitialValue(OpAsmPrinter & p,Operation * op,TypeAttr type,Attribute attr)100 static void printTypedInitialValue(OpAsmPrinter &p, Operation *op,
101 TypeAttr type, Attribute attr) {
102 if (attr) {
103 p << "(";
104 p.printAttribute(attr);
105 p << ")";
106 }
107
108 p << " : ";
109 p.printAttribute(type);
110 }
111
112 /// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
113 /// ->
114 /// some.op public @foo
115 /// some.op private @foo
parseSymbolVisibility(OpAsmParser & parser,StringAttr & symVisibilityAttr)116 static ParseResult parseSymbolVisibility(OpAsmParser &parser,
117 StringAttr &symVisibilityAttr) {
118 StringRef symVisibility;
119 (void)parser.parseOptionalKeyword(&symVisibility,
120 {"public", "private", "nested"});
121 if (symVisibility.empty())
122 return parser.emitError(parser.getCurrentLocation())
123 << "expected 'public', 'private', or 'nested'";
124 if (!symVisibility.empty())
125 symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
126 return success();
127 }
128
printSymbolVisibility(OpAsmPrinter & p,Operation * op,StringAttr symVisibilityAttr)129 static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
130 StringAttr symVisibilityAttr) {
131 if (!symVisibilityAttr)
132 p << "public";
133 else
134 p << symVisibilityAttr.getValue();
135 }
136
137 //===----------------------------------------------------------------------===//
138 // TableGen'd op method definitions
139 //===----------------------------------------------------------------------===//
140
141 #define GET_OP_CLASSES
142 #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
143
144 //===----------------------------------------------------------------------===//
145 // FuncOp
146 //===----------------------------------------------------------------------===//
147
parse(OpAsmParser & parser,OperationState & result)148 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
149 auto buildFuncType =
150 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
151 function_interface_impl::VariadicFlag,
152 std::string &) { return builder.getFunctionType(argTypes, results); };
153
154 return function_interface_impl::parseFunctionOp(
155 parser, result, /*allowVariadic=*/false, buildFuncType);
156 }
157
print(OpAsmPrinter & p)158 void FuncOp::print(OpAsmPrinter &p) {
159 function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
160 }
161
162 //===----------------------------------------------------------------------===//
163 // GlobalOp
164 //===----------------------------------------------------------------------===//
165
verify()166 LogicalResult GlobalOp::verify() {
167 if (!getIsMutable() && !getValue())
168 return emitOpError() << "immutable global must have an initial value";
169 return success();
170 }
171
172 //===----------------------------------------------------------------------===//
173 // GlobalLoadOp
174 //===----------------------------------------------------------------------===//
175
getGlobalOp(SymbolTableCollection & symbolTable)176 GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) {
177 return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
178 getOperation()->getParentOp(), getGlobalAttr());
179 }
180
181 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)182 GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
183 GlobalOp referrent = getGlobalOp(symbolTable);
184 if (!referrent)
185 return emitOpError() << "undefined global: " << getGlobal();
186
187 if (referrent.getType() != getResult().getType()) {
188 return emitOpError() << "cannot load from global typed "
189 << referrent.getType() << " as "
190 << getResult().getType();
191 }
192
193 return success();
194 }
195
196 //===----------------------------------------------------------------------===//
197 // GlobalLoadConstOp
198 //===----------------------------------------------------------------------===//
199
getGlobalOp(SymbolTableCollection & symbolTable)200 GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) {
201 return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
202 getOperation()->getParentOp(), getGlobalAttr());
203 }
204
205 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)206 GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
207 GlobalOp referrent = getGlobalOp(symbolTable);
208 if (!referrent)
209 return emitOpError() << "undefined global: " << getGlobal();
210
211 if (referrent.getIsMutable())
212 return emitOpError() << "cannot load as const from mutable global "
213 << getGlobal();
214
215 if (referrent.getType() != getResult().getType())
216 return emitOpError() << "cannot load from global typed "
217 << referrent.getType() << " as "
218 << getResult().getType();
219
220 return success();
221 }
222
223 //===----------------------------------------------------------------------===//
224 // GlobalLoadGraphOp
225 //===----------------------------------------------------------------------===//
226
getGlobalOp(SymbolTableCollection & symbolTable)227 GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
228 return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
229 getOperation()->getParentOp(), getGlobalAttr());
230 }
231
232 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)233 GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
234 GlobalOp referrent = getGlobalOp(symbolTable);
235 if (!referrent)
236 return emitOpError() << "undefined global: " << getGlobal();
237
238 if (referrent.getType() != getResult().getType()) {
239 return emitOpError() << "cannot load from global typed "
240 << referrent.getType() << " as "
241 << getResult().getType();
242 }
243
244 return success();
245 }
246
247 //===----------------------------------------------------------------------===//
248 // GlobalStoreOp
249 //===----------------------------------------------------------------------===//
250
getGlobalOp(SymbolTableCollection & symbolTable)251 GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) {
252 return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
253 getOperation()->getParentOp(), getGlobalAttr());
254 }
255
256 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)257 GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
258 GlobalOp referrent = getGlobalOp(symbolTable);
259 if (!referrent)
260 return emitOpError() << "undefined global: " << getGlobal();
261
262 if (!referrent.getIsMutable()) {
263 return emitOpError() << "cannot store to an immutable global "
264 << getGlobal();
265 }
266
267 if (referrent.getType() != getValue().getType()) {
268 return emitOpError() << "cannot store to a global typed "
269 << referrent.getType() << " from "
270 << getValue().getType();
271 }
272
273 return success();
274 }
275
276 //===----------------------------------------------------------------------===//
277 // GlobalStoreGraphOp
278 //===----------------------------------------------------------------------===//
279
getGlobalOp(SymbolTableCollection & symbolTable)280 GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
281 return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
282 getOperation()->getParentOp(), getGlobalAttr());
283 }
284
285 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)286 GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
287 GlobalOp referrent = getGlobalOp(symbolTable);
288 if (!referrent)
289 return emitOpError() << "undefined global: " << getGlobal();
290
291 if (!referrent.getIsMutable()) {
292 return emitOpError() << "cannot store to an immutable global "
293 << getGlobal();
294 }
295
296 if (referrent.getType() != getValue().getType()) {
297 return emitOpError() << "cannot store to a global typed "
298 << referrent.getType() << " from "
299 << getValue().getType();
300 }
301
302 return success();
303 }
304
305 //===----------------------------------------------------------------------===//
306 // SubgraphOp
307 //===----------------------------------------------------------------------===//
308
parse(OpAsmParser & parser,OperationState & result)309 ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
310 auto buildFuncType =
311 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
312 function_interface_impl::VariadicFlag,
313 std::string &) { return builder.getFunctionType(argTypes, results); };
314
315 return function_interface_impl::parseFunctionOp(
316 parser, result, /*allowVariadic=*/false, buildFuncType);
317 }
318
print(OpAsmPrinter & p)319 void SubgraphOp::print(OpAsmPrinter &p) {
320 function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
321 }
322
323 //===----------------------------------------------------------------------===//
324 // OutputOp
325 //===----------------------------------------------------------------------===//
326
verify()327 LogicalResult OutputOp::verify() {
328 auto function = cast<SubgraphOp>((*this)->getParentOp());
329
330 // The operand number and types must match the function signature.
331 const auto &results = function.getFunctionType().getResults();
332 if (getNumOperands() != results.size())
333 return emitOpError("has ")
334 << getNumOperands() << " operands, but enclosing function (@"
335 << function.getName() << ") outputs " << results.size();
336
337 for (unsigned i = 0, e = results.size(); i != e; ++i)
338 if (getOperand(i).getType() != results[i])
339 return emitError() << "type of output operand " << i << " ("
340 << getOperand(i).getType()
341 << ") doesn't match function result type ("
342 << results[i] << ")"
343 << " in function @" << function.getName();
344
345 return success();
346 }
347
348 //===----------------------------------------------------------------------===//
349 // ReturnOp
350 //===----------------------------------------------------------------------===//
351
verify()352 LogicalResult ReturnOp::verify() {
353 auto function = cast<FuncOp>((*this)->getParentOp());
354
355 // The operand number and types must match the function signature.
356 const auto &results = function.getFunctionType().getResults();
357 if (getNumOperands() != results.size())
358 return emitOpError("has ")
359 << getNumOperands() << " operands, but enclosing function (@"
360 << function.getName() << ") returns " << results.size();
361
362 for (unsigned i = 0, e = results.size(); i != e; ++i)
363 if (getOperand(i).getType() != results[i])
364 return emitError() << "type of return operand " << i << " ("
365 << getOperand(i).getType()
366 << ") doesn't match function result type ("
367 << results[i] << ")"
368 << " in function @" << function.getName();
369
370 return success();
371 }
372