1 //===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the dialect for the Toy IR: custom type parsing and
10 // operation verification.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "toy/Dialect.h"
15
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/FunctionImplementation.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/Transforms/InliningUtils.h"
22
23 using namespace mlir;
24 using namespace mlir::toy;
25
26 #include "toy/Dialect.cpp.inc"
27
28 //===----------------------------------------------------------------------===//
29 // ToyInlinerInterface
30 //===----------------------------------------------------------------------===//
31
32 /// This class defines the interface for handling inlining with Toy
33 /// operations.
34 struct ToyInlinerInterface : public DialectInlinerInterface {
35 using DialectInlinerInterface::DialectInlinerInterface;
36
37 //===--------------------------------------------------------------------===//
38 // Analysis Hooks
39 //===--------------------------------------------------------------------===//
40
41 /// All call operations within toy can be inlined.
isLegalToInlineToyInlinerInterface42 bool isLegalToInline(Operation *call, Operation *callable,
43 bool wouldBeCloned) const final {
44 return true;
45 }
46
47 /// All operations within toy can be inlined.
isLegalToInlineToyInlinerInterface48 bool isLegalToInline(Operation *, Region *, bool,
49 BlockAndValueMapping &) const final {
50 return true;
51 }
52
53 // All functions within toy can be inlined.
isLegalToInlineToyInlinerInterface54 bool isLegalToInline(Region *, Region *, bool,
55 BlockAndValueMapping &) const final {
56 return true;
57 }
58
59 //===--------------------------------------------------------------------===//
60 // Transformation Hooks
61 //===--------------------------------------------------------------------===//
62
63 /// Handle the given inlined terminator(toy.return) by replacing it with a new
64 /// operation as necessary.
handleTerminatorToyInlinerInterface65 void handleTerminator(Operation *op,
66 ArrayRef<Value> valuesToRepl) const final {
67 // Only "toy.return" needs to be handled here.
68 auto returnOp = cast<ReturnOp>(op);
69
70 // Replace the values directly with the return operands.
71 assert(returnOp.getNumOperands() == valuesToRepl.size());
72 for (const auto &it : llvm::enumerate(returnOp.getOperands()))
73 valuesToRepl[it.index()].replaceAllUsesWith(it.value());
74 }
75
76 /// Attempts to materialize a conversion for a type mismatch between a call
77 /// from this dialect, and a callable region. This method should generate an
78 /// operation that takes 'input' as the only operand, and produces a single
79 /// result of 'resultType'. If a conversion can not be generated, nullptr
80 /// should be returned.
materializeCallConversionToyInlinerInterface81 Operation *materializeCallConversion(OpBuilder &builder, Value input,
82 Type resultType,
83 Location conversionLoc) const final {
84 return builder.create<CastOp>(conversionLoc, resultType, input);
85 }
86 };
87
88 //===----------------------------------------------------------------------===//
89 // Toy Operations
90 //===----------------------------------------------------------------------===//
91
92 /// A generalized parser for binary operations. This parses the different forms
93 /// of 'printBinaryOp' below.
parseBinaryOp(mlir::OpAsmParser & parser,mlir::OperationState & result)94 static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
95 mlir::OperationState &result) {
96 SmallVector<mlir::OpAsmParser::UnresolvedOperand, 2> operands;
97 SMLoc operandsLoc = parser.getCurrentLocation();
98 Type type;
99 if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) ||
100 parser.parseOptionalAttrDict(result.attributes) ||
101 parser.parseColonType(type))
102 return mlir::failure();
103
104 // If the type is a function type, it contains the input and result types of
105 // this operation.
106 if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
107 if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
108 result.operands))
109 return mlir::failure();
110 result.addTypes(funcType.getResults());
111 return mlir::success();
112 }
113
114 // Otherwise, the parsed type is the type of both operands and results.
115 if (parser.resolveOperands(operands, type, result.operands))
116 return mlir::failure();
117 result.addTypes(type);
118 return mlir::success();
119 }
120
121 /// A generalized printer for binary operations. It prints in two different
122 /// forms depending on if all of the types match.
printBinaryOp(mlir::OpAsmPrinter & printer,mlir::Operation * op)123 static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) {
124 printer << " " << op->getOperands();
125 printer.printOptionalAttrDict(op->getAttrs());
126 printer << " : ";
127
128 // If all of the types are the same, print the type directly.
129 Type resultType = *op->result_type_begin();
130 if (llvm::all_of(op->getOperandTypes(),
131 [=](Type type) { return type == resultType; })) {
132 printer << resultType;
133 return;
134 }
135
136 // Otherwise, print a functional type.
137 printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes());
138 }
139
140 //===----------------------------------------------------------------------===//
141 // ConstantOp
142 //===----------------------------------------------------------------------===//
143
144 /// Build a constant operation.
145 /// The builder is passed as an argument, so is the state that this method is
146 /// expected to fill in order to build the operation.
build(mlir::OpBuilder & builder,mlir::OperationState & state,double value)147 void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
148 double value) {
149 auto dataType = RankedTensorType::get({}, builder.getF64Type());
150 auto dataAttribute = DenseElementsAttr::get(dataType, value);
151 ConstantOp::build(builder, state, dataType, dataAttribute);
152 }
153
154 /// The 'OpAsmParser' class provides a collection of methods for parsing
155 /// various punctuation, as well as attributes, operands, types, etc. Each of
156 /// these methods returns a `ParseResult`. This class is a wrapper around
157 /// `LogicalResult` that can be converted to a boolean `true` value on failure,
158 /// or `false` on success. This allows for easily chaining together a set of
159 /// parser rules. These rules are used to populate an `mlir::OperationState`
160 /// similarly to the `build` methods described above.
parse(mlir::OpAsmParser & parser,mlir::OperationState & result)161 mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser,
162 mlir::OperationState &result) {
163 mlir::DenseElementsAttr value;
164 if (parser.parseOptionalAttrDict(result.attributes) ||
165 parser.parseAttribute(value, "value", result.attributes))
166 return failure();
167
168 result.addTypes(value.getType());
169 return success();
170 }
171
172 /// The 'OpAsmPrinter' class is a stream that allows for formatting
173 /// strings, attributes, operands, types, etc.
print(mlir::OpAsmPrinter & printer)174 void ConstantOp::print(mlir::OpAsmPrinter &printer) {
175 printer << " ";
176 printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
177 printer << getValue();
178 }
179
180 /// Verify that the given attribute value is valid for the given type.
verifyConstantForType(mlir::Type type,mlir::Attribute opaqueValue,mlir::Operation * op)181 static mlir::LogicalResult verifyConstantForType(mlir::Type type,
182 mlir::Attribute opaqueValue,
183 mlir::Operation *op) {
184 if (type.isa<mlir::TensorType>()) {
185 // Check that the value is an elements attribute.
186 auto attrValue = opaqueValue.dyn_cast<mlir::DenseFPElementsAttr>();
187 if (!attrValue)
188 return op->emitError("constant of TensorType must be initialized by "
189 "a DenseFPElementsAttr, got ")
190 << opaqueValue;
191
192 // If the return type of the constant is not an unranked tensor, the shape
193 // must match the shape of the attribute holding the data.
194 auto resultType = type.dyn_cast<mlir::RankedTensorType>();
195 if (!resultType)
196 return success();
197
198 // Check that the rank of the attribute type matches the rank of the
199 // constant result type.
200 auto attrType = attrValue.getType().cast<mlir::TensorType>();
201 if (attrType.getRank() != resultType.getRank()) {
202 return op->emitOpError("return type must match the one of the attached "
203 "value attribute: ")
204 << attrType.getRank() << " != " << resultType.getRank();
205 }
206
207 // Check that each of the dimensions match between the two types.
208 for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
209 if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
210 return op->emitOpError(
211 "return type shape mismatches its attribute at dimension ")
212 << dim << ": " << attrType.getShape()[dim]
213 << " != " << resultType.getShape()[dim];
214 }
215 }
216 return mlir::success();
217 }
218 auto resultType = type.cast<StructType>();
219 llvm::ArrayRef<mlir::Type> resultElementTypes = resultType.getElementTypes();
220
221 // Verify that the initializer is an Array.
222 auto attrValue = opaqueValue.dyn_cast<ArrayAttr>();
223 if (!attrValue || attrValue.getValue().size() != resultElementTypes.size())
224 return op->emitError("constant of StructType must be initialized by an "
225 "ArrayAttr with the same number of elements, got ")
226 << opaqueValue;
227
228 // Check that each of the elements are valid.
229 llvm::ArrayRef<mlir::Attribute> attrElementValues = attrValue.getValue();
230 for (const auto it : llvm::zip(resultElementTypes, attrElementValues))
231 if (failed(verifyConstantForType(std::get<0>(it), std::get<1>(it), op)))
232 return mlir::failure();
233 return mlir::success();
234 }
235
236 /// Verifier for the constant operation. This corresponds to the `::verify(...)`
237 /// in the op definition.
verify()238 mlir::LogicalResult ConstantOp::verify() {
239 return verifyConstantForType(getResult().getType(), getValue(), *this);
240 }
241
verify()242 mlir::LogicalResult StructConstantOp::verify() {
243 return verifyConstantForType(getResult().getType(), getValue(), *this);
244 }
245
246 /// Infer the output shape of the ConstantOp, this is required by the shape
247 /// inference interface.
inferShapes()248 void ConstantOp::inferShapes() { getResult().setType(getValue().getType()); }
249
250 //===----------------------------------------------------------------------===//
251 // AddOp
252 //===----------------------------------------------------------------------===//
253
build(mlir::OpBuilder & builder,mlir::OperationState & state,mlir::Value lhs,mlir::Value rhs)254 void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
255 mlir::Value lhs, mlir::Value rhs) {
256 state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
257 state.addOperands({lhs, rhs});
258 }
259
parse(mlir::OpAsmParser & parser,mlir::OperationState & result)260 mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser,
261 mlir::OperationState &result) {
262 return parseBinaryOp(parser, result);
263 }
264
print(mlir::OpAsmPrinter & p)265 void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
266
267 /// Infer the output shape of the AddOp, this is required by the shape inference
268 /// interface.
inferShapes()269 void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
270
271 //===----------------------------------------------------------------------===//
272 // CastOp
273 //===----------------------------------------------------------------------===//
274
275 /// Infer the output shape of the CastOp, this is required by the shape
276 /// inference interface.
inferShapes()277 void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
278
279 /// Returns true if the given set of input and result types are compatible with
280 /// this cast operation. This is required by the `CastOpInterface` to verify
281 /// this operation and provide other additional utilities.
areCastCompatible(TypeRange inputs,TypeRange outputs)282 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
283 if (inputs.size() != 1 || outputs.size() != 1)
284 return false;
285 // The inputs must be Tensors with the same element type.
286 TensorType input = inputs.front().dyn_cast<TensorType>();
287 TensorType output = outputs.front().dyn_cast<TensorType>();
288 if (!input || !output || input.getElementType() != output.getElementType())
289 return false;
290 // The shape is required to match if both types are ranked.
291 return !input.hasRank() || !output.hasRank() || input == output;
292 }
293
294 //===----------------------------------------------------------------------===//
295 // FuncOp
296 //===----------------------------------------------------------------------===//
297
build(mlir::OpBuilder & builder,mlir::OperationState & state,llvm::StringRef name,mlir::FunctionType type,llvm::ArrayRef<mlir::NamedAttribute> attrs)298 void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
299 llvm::StringRef name, mlir::FunctionType type,
300 llvm::ArrayRef<mlir::NamedAttribute> attrs) {
301 // FunctionOpInterface provides a convenient `build` method that will populate
302 // the state of our FuncOp, and create an entry block.
303 buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
304 }
305
parse(mlir::OpAsmParser & parser,mlir::OperationState & result)306 mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
307 mlir::OperationState &result) {
308 // Dispatch to the FunctionOpInterface provided utility method that parses the
309 // function operation.
310 auto buildFuncType =
311 [](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes,
312 llvm::ArrayRef<mlir::Type> results,
313 mlir::function_interface_impl::VariadicFlag,
314 std::string &) { return builder.getFunctionType(argTypes, results); };
315
316 return mlir::function_interface_impl::parseFunctionOp(
317 parser, result, /*allowVariadic=*/false, buildFuncType);
318 }
319
print(mlir::OpAsmPrinter & p)320 void FuncOp::print(mlir::OpAsmPrinter &p) {
321 // Dispatch to the FunctionOpInterface provided utility method that prints the
322 // function operation.
323 mlir::function_interface_impl::printFunctionOp(p, *this,
324 /*isVariadic=*/false);
325 }
326
327 /// Returns the region on the function operation that is callable.
getCallableRegion()328 mlir::Region *FuncOp::getCallableRegion() { return &getBody(); }
329
330 /// Returns the results types that the callable region produces when
331 /// executed.
getCallableResults()332 llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
333 return getFunctionType().getResults();
334 }
335
336 //===----------------------------------------------------------------------===//
337 // GenericCallOp
338 //===----------------------------------------------------------------------===//
339
build(mlir::OpBuilder & builder,mlir::OperationState & state,StringRef callee,ArrayRef<mlir::Value> arguments)340 void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
341 StringRef callee, ArrayRef<mlir::Value> arguments) {
342 // Generic call always returns an unranked Tensor initially.
343 state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
344 state.addOperands(arguments);
345 state.addAttribute("callee",
346 mlir::SymbolRefAttr::get(builder.getContext(), callee));
347 }
348
349 /// Return the callee of the generic call operation, this is required by the
350 /// call interface.
getCallableForCallee()351 CallInterfaceCallable GenericCallOp::getCallableForCallee() {
352 return (*this)->getAttrOfType<SymbolRefAttr>("callee");
353 }
354
355 /// Get the argument operands to the called function, this is required by the
356 /// call interface.
getArgOperands()357 Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
358
359 //===----------------------------------------------------------------------===//
360 // MulOp
361 //===----------------------------------------------------------------------===//
362
build(mlir::OpBuilder & builder,mlir::OperationState & state,mlir::Value lhs,mlir::Value rhs)363 void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
364 mlir::Value lhs, mlir::Value rhs) {
365 state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
366 state.addOperands({lhs, rhs});
367 }
368
parse(mlir::OpAsmParser & parser,mlir::OperationState & result)369 mlir::ParseResult MulOp::parse(mlir::OpAsmParser &parser,
370 mlir::OperationState &result) {
371 return parseBinaryOp(parser, result);
372 }
373
print(mlir::OpAsmPrinter & p)374 void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
375
376 /// Infer the output shape of the MulOp, this is required by the shape inference
377 /// interface.
inferShapes()378 void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
379
380 //===----------------------------------------------------------------------===//
381 // ReturnOp
382 //===----------------------------------------------------------------------===//
383
verify()384 mlir::LogicalResult ReturnOp::verify() {
385 // We know that the parent operation is a function, because of the 'HasParent'
386 // trait attached to the operation definition.
387 auto function = cast<FuncOp>((*this)->getParentOp());
388
389 /// ReturnOps can only have a single optional operand.
390 if (getNumOperands() > 1)
391 return emitOpError() << "expects at most 1 return operand";
392
393 // The operand number and types must match the function signature.
394 const auto &results = function.getFunctionType().getResults();
395 if (getNumOperands() != results.size())
396 return emitOpError() << "does not return the same number of values ("
397 << getNumOperands() << ") as the enclosing function ("
398 << results.size() << ")";
399
400 // If the operation does not have an input, we are done.
401 if (!hasOperand())
402 return mlir::success();
403
404 auto inputType = *operand_type_begin();
405 auto resultType = results.front();
406
407 // Check that the result type of the function matches the operand type.
408 if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
409 resultType.isa<mlir::UnrankedTensorType>())
410 return mlir::success();
411
412 return emitError() << "type of return operand (" << inputType
413 << ") doesn't match function result type (" << resultType
414 << ")";
415 }
416
417 //===----------------------------------------------------------------------===//
418 // StructAccessOp
419 //===----------------------------------------------------------------------===//
420
build(mlir::OpBuilder & b,mlir::OperationState & state,mlir::Value input,size_t index)421 void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state,
422 mlir::Value input, size_t index) {
423 // Extract the result type from the input type.
424 StructType structTy = input.getType().cast<StructType>();
425 assert(index < structTy.getNumElementTypes());
426 mlir::Type resultType = structTy.getElementTypes()[index];
427
428 // Call into the auto-generated build method.
429 build(b, state, resultType, input, b.getI64IntegerAttr(index));
430 }
431
verify()432 mlir::LogicalResult StructAccessOp::verify() {
433 StructType structTy = getInput().getType().cast<StructType>();
434 size_t indexValue = getIndex();
435 if (indexValue >= structTy.getNumElementTypes())
436 return emitOpError()
437 << "index should be within the range of the input struct type";
438 mlir::Type resultType = getResult().getType();
439 if (resultType != structTy.getElementTypes()[indexValue])
440 return emitOpError() << "must have the same result type as the struct "
441 "element referred to by the index";
442 return mlir::success();
443 }
444
445 //===----------------------------------------------------------------------===//
446 // TransposeOp
447 //===----------------------------------------------------------------------===//
448
build(mlir::OpBuilder & builder,mlir::OperationState & state,mlir::Value value)449 void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
450 mlir::Value value) {
451 state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
452 state.addOperands(value);
453 }
454
inferShapes()455 void TransposeOp::inferShapes() {
456 auto arrayTy = getOperand().getType().cast<RankedTensorType>();
457 SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
458 getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
459 }
460
verify()461 mlir::LogicalResult TransposeOp::verify() {
462 auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
463 auto resultType = getType().dyn_cast<RankedTensorType>();
464 if (!inputType || !resultType)
465 return mlir::success();
466
467 auto inputShape = inputType.getShape();
468 if (!std::equal(inputShape.begin(), inputShape.end(),
469 resultType.getShape().rbegin())) {
470 return emitError()
471 << "expected result shape to be a transpose of the input";
472 }
473 return mlir::success();
474 }
475
476 //===----------------------------------------------------------------------===//
477 // Toy Types
478 //===----------------------------------------------------------------------===//
479
480 namespace mlir {
481 namespace toy {
482 namespace detail {
483 /// This class represents the internal storage of the Toy `StructType`.
484 struct StructTypeStorage : public mlir::TypeStorage {
485 /// The `KeyTy` is a required type that provides an interface for the storage
486 /// instance. This type will be used when uniquing an instance of the type
487 /// storage. For our struct type, we will unique each instance structurally on
488 /// the elements that it contains.
489 using KeyTy = llvm::ArrayRef<mlir::Type>;
490
491 /// A constructor for the type storage instance.
StructTypeStoragemlir::toy::detail::StructTypeStorage492 StructTypeStorage(llvm::ArrayRef<mlir::Type> elementTypes)
493 : elementTypes(elementTypes) {}
494
495 /// Define the comparison function for the key type with the current storage
496 /// instance. This is used when constructing a new instance to ensure that we
497 /// haven't already uniqued an instance of the given key.
operator ==mlir::toy::detail::StructTypeStorage498 bool operator==(const KeyTy &key) const { return key == elementTypes; }
499
500 /// Define a hash function for the key type. This is used when uniquing
501 /// instances of the storage, see the `StructType::get` method.
502 /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type
503 /// have hash functions available, so we could just omit this entirely.
hashKeymlir::toy::detail::StructTypeStorage504 static llvm::hash_code hashKey(const KeyTy &key) {
505 return llvm::hash_value(key);
506 }
507
508 /// Define a construction function for the key type from a set of parameters.
509 /// These parameters will be provided when constructing the storage instance
510 /// itself.
511 /// Note: This method isn't necessary because KeyTy can be directly
512 /// constructed with the given parameters.
getKeymlir::toy::detail::StructTypeStorage513 static KeyTy getKey(llvm::ArrayRef<mlir::Type> elementTypes) {
514 return KeyTy(elementTypes);
515 }
516
517 /// Define a construction method for creating a new instance of this storage.
518 /// This method takes an instance of a storage allocator, and an instance of a
519 /// `KeyTy`. The given allocator must be used for *all* necessary dynamic
520 /// allocations used to create the type storage and its internal.
constructmlir::toy::detail::StructTypeStorage521 static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
522 const KeyTy &key) {
523 // Copy the elements from the provided `KeyTy` into the allocator.
524 llvm::ArrayRef<mlir::Type> elementTypes = allocator.copyInto(key);
525
526 // Allocate the storage instance and construct it.
527 return new (allocator.allocate<StructTypeStorage>())
528 StructTypeStorage(elementTypes);
529 }
530
531 /// The following field contains the element types of the struct.
532 llvm::ArrayRef<mlir::Type> elementTypes;
533 };
534 } // namespace detail
535 } // namespace toy
536 } // namespace mlir
537
538 /// Create an instance of a `StructType` with the given element types. There
539 /// *must* be at least one element type.
get(llvm::ArrayRef<mlir::Type> elementTypes)540 StructType StructType::get(llvm::ArrayRef<mlir::Type> elementTypes) {
541 assert(!elementTypes.empty() && "expected at least 1 element type");
542
543 // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
544 // of this type. The first parameter is the context to unique in. The
545 // parameters after the context are forwarded to the storage instance.
546 mlir::MLIRContext *ctx = elementTypes.front().getContext();
547 return Base::get(ctx, elementTypes);
548 }
549
550 /// Returns the element types of this struct type.
getElementTypes()551 llvm::ArrayRef<mlir::Type> StructType::getElementTypes() {
552 // 'getImpl' returns a pointer to the internal storage instance.
553 return getImpl()->elementTypes;
554 }
555
556 /// Parse an instance of a type registered to the toy dialect.
parseType(mlir::DialectAsmParser & parser) const557 mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
558 // Parse a struct type in the following form:
559 // struct-type ::= `struct` `<` type (`,` type)* `>`
560
561 // NOTE: All MLIR parser function return a ParseResult. This is a
562 // specialization of LogicalResult that auto-converts to a `true` boolean
563 // value on failure to allow for chaining, but may be used with explicit
564 // `mlir::failed/mlir::succeeded` as desired.
565
566 // Parse: `struct` `<`
567 if (parser.parseKeyword("struct") || parser.parseLess())
568 return Type();
569
570 // Parse the element types of the struct.
571 SmallVector<mlir::Type, 1> elementTypes;
572 do {
573 // Parse the current element type.
574 SMLoc typeLoc = parser.getCurrentLocation();
575 mlir::Type elementType;
576 if (parser.parseType(elementType))
577 return nullptr;
578
579 // Check that the type is either a TensorType or another StructType.
580 if (!elementType.isa<mlir::TensorType, StructType>()) {
581 parser.emitError(typeLoc, "element type for a struct must either "
582 "be a TensorType or a StructType, got: ")
583 << elementType;
584 return Type();
585 }
586 elementTypes.push_back(elementType);
587
588 // Parse the optional: `,`
589 } while (succeeded(parser.parseOptionalComma()));
590
591 // Parse: `>`
592 if (parser.parseGreater())
593 return Type();
594 return StructType::get(elementTypes);
595 }
596
597 /// Print an instance of a type registered to the toy dialect.
printType(mlir::Type type,mlir::DialectAsmPrinter & printer) const598 void ToyDialect::printType(mlir::Type type,
599 mlir::DialectAsmPrinter &printer) const {
600 // Currently the only toy type is a struct type.
601 StructType structType = type.cast<StructType>();
602
603 // Print the struct type according to the parser format.
604 printer << "struct<";
605 llvm::interleaveComma(structType.getElementTypes(), printer);
606 printer << '>';
607 }
608
609 //===----------------------------------------------------------------------===//
610 // TableGen'd op method definitions
611 //===----------------------------------------------------------------------===//
612
613 #define GET_OP_CLASSES
614 #include "toy/Ops.cpp.inc"
615
616 //===----------------------------------------------------------------------===//
617 // ToyDialect
618 //===----------------------------------------------------------------------===//
619
620 /// Dialect initialization, the instance will be owned by the context. This is
621 /// the point of registration of types and operations for the dialect.
initialize()622 void ToyDialect::initialize() {
623 addOperations<
624 #define GET_OP_LIST
625 #include "toy/Ops.cpp.inc"
626 >();
627 addInterfaces<ToyInlinerInterface>();
628 addTypes<StructType>();
629 }
630
materializeConstant(mlir::OpBuilder & builder,mlir::Attribute value,mlir::Type type,mlir::Location loc)631 mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
632 mlir::Attribute value,
633 mlir::Type type,
634 mlir::Location loc) {
635 if (type.isa<StructType>())
636 return builder.create<StructConstantOp>(loc, type,
637 value.cast<mlir::ArrayAttr>());
638 return builder.create<ConstantOp>(loc, type,
639 value.cast<mlir::DenseElementsAttr>());
640 }
641