//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the dialect for the Toy IR: custom type parsing and
// operation verification.
//
//===----------------------------------------------------------------------===//

#include "toy/Dialect.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Transforms/InliningUtils.h"

using namespace mlir;
using namespace mlir::toy;

#include "toy/Dialect.cpp.inc"
//===----------------------------------------------------------------------===//
// ToyInlinerInterface
//===----------------------------------------------------------------------===//

/// This class defines the interface for handling inlining with Toy
/// operations.
struct ToyInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  /// All call operations within toy can be inlined.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }

  /// All operations within toy can be inlined.
  bool isLegalToInline(Operation *, Region *, bool,
                       BlockAndValueMapping &) const final {
    return true;
  }

  // All functions within toy can be inlined.
  bool isLegalToInline(Region *, Region *, bool,
                       BlockAndValueMapping &) const final {
    return true;
  }

  //===--------------------------------------------------------------------===//
  // Transformation Hooks
  //===--------------------------------------------------------------------===//

  /// Handle the given inlined terminator (toy.return) by replacing it with a
  /// new operation as necessary.
  void handleTerminator(Operation *op,
                        ArrayRef<Value> valuesToRepl) const final {
    // Only "toy.return" needs to be handled here.
    auto returnOp = cast<ReturnOp>(op);

    // Replace the values directly with the return operands.
    assert(returnOp.getNumOperands() == valuesToRepl.size());
    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
      valuesToRepl[it.index()].replaceAllUsesWith(it.value());
  }

  /// Attempts to materialize a conversion for a type mismatch between a call
  /// from this dialect, and a callable region. This method should generate an
  /// operation that takes 'input' as the only operand, and produces a single
  /// result of 'resultType'. If a conversion can not be generated, nullptr
  /// should be returned.
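  /// For Toy this materializes a `toy.cast`: for example (illustrative
  /// scenario), when a call site passes a tensor<*xf64> value into a callee
  /// that expects a ranked tensor, or vice versa, the inliner uses this hook
  /// to insert a cast that reconciles the two types.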
  Operation *materializeCallConversion(OpBuilder &builder, Value input,
                                       Type resultType,
                                       Location conversionLoc) const final {
    return builder.create<CastOp>(conversionLoc, resultType, input);
  }
};

//===----------------------------------------------------------------------===//
// Toy Operations
//===----------------------------------------------------------------------===//

/// A generalized parser for binary operations. This parses the different forms
/// printed by 'printBinaryOp' below.
static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
                                       mlir::OperationState &result) {
  SmallVector<mlir::OpAsmParser::UnresolvedOperand, 2> operands;
  SMLoc operandsLoc = parser.getCurrentLocation();
  Type type;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type))
    return mlir::failure();

  // If the type is a function type, it contains the input and result types of
  // this operation.
  if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
    if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
                               result.operands))
      return mlir::failure();
    result.addTypes(funcType.getResults());
    return mlir::success();
  }

  // Otherwise, the parsed type is the type of both operands and results.
  if (parser.resolveOperands(operands, type, result.operands))
    return mlir::failure();
  result.addTypes(type);
  return mlir::success();
}

/// A generalized printer for binary operations. It prints in two different
/// forms depending on whether all of the types match.
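/// For example (illustrative IR):
///   %2 = toy.add %0, %1 : tensor<2x3xf64>
///   %2 = toy.add %0, %1 : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>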
static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) {
  printer << " " << op->getOperands();
  printer.printOptionalAttrDict(op->getAttrs());
  printer << " : ";

  // If all of the types are the same, print the type directly.
  Type resultType = *op->result_type_begin();
  if (llvm::all_of(op->getOperandTypes(),
                   [=](Type type) { return type == resultType; })) {
    printer << resultType;
    return;
  }

  // Otherwise, print a functional type.
  printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes());
}

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//

/// Build a constant operation.
/// The builder is passed as an argument, as is the state that this method is
/// expected to fill in order to build the operation.
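/// For example (illustrative), `builder.create<ConstantOp>(loc, 2.0)` routes
/// through this overload and produces a rank-0 f64 constant.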
void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                       double value) {
  auto dataType = RankedTensorType::get({}, builder.getF64Type());
  auto dataAttribute = DenseElementsAttr::get(dataType, value);
  ConstantOp::build(builder, state, dataType, dataAttribute);
}

/// The 'OpAsmParser' class provides a collection of methods for parsing
/// various punctuation, as well as attributes, operands, types, etc. Each of
/// these methods returns a `ParseResult`. This class is a wrapper around
/// `LogicalResult` that can be converted to a boolean `true` value on failure,
/// or `false` on success. This allows for easily chaining together a set of
/// parser rules. These rules are used to populate an `mlir::OperationState`
/// similarly to the `build` methods described above.
mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser,
                                    mlir::OperationState &result) {
  mlir::DenseElementsAttr value;
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseAttribute(value, "value", result.attributes))
    return failure();

  result.addTypes(value.getType());
  return success();
}

/// The 'OpAsmPrinter' class is a stream that allows for formatting
/// strings, attributes, operands, types, etc.
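/// For example, a constant might print as (illustrative; the exact float
/// formatting may differ):
///   %0 = toy.constant dense<[1.0, 2.0]> : tensor<2xf64>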
void ConstantOp::print(mlir::OpAsmPrinter &printer) {
  printer << " ";
  printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
  printer << getValue();
}

/// Verify that the given attribute value is valid for the given type.
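/// Tensor-typed constants must be initialized with a DenseFPElementsAttr whose
/// shape matches any ranked result type; struct-typed constants must be
/// initialized with an ArrayAttr whose elements are verified recursively
/// against the struct's element types.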
static mlir::LogicalResult verifyConstantForType(mlir::Type type,
                                                 mlir::Attribute opaqueValue,
                                                 mlir::Operation *op) {
  if (type.isa<mlir::TensorType>()) {
    // Check that the value is an elements attribute.
    auto attrValue = opaqueValue.dyn_cast<mlir::DenseFPElementsAttr>();
    if (!attrValue)
      return op->emitError("constant of TensorType must be initialized by "
                           "a DenseFPElementsAttr, got ")
             << opaqueValue;

    // If the return type of the constant is not an unranked tensor, the shape
    // must match the shape of the attribute holding the data.
    auto resultType = type.dyn_cast<mlir::RankedTensorType>();
    if (!resultType)
      return success();

    // Check that the rank of the attribute type matches the rank of the
    // constant result type.
    auto attrType = attrValue.getType().cast<mlir::TensorType>();
    if (attrType.getRank() != resultType.getRank()) {
      return op->emitOpError("return type must match the one of the attached "
                             "value attribute: ")
             << attrType.getRank() << " != " << resultType.getRank();
    }

    // Check that each of the dimensions match between the two types.
    for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
      if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
        return op->emitOpError(
                   "return type shape mismatches its attribute at dimension ")
               << dim << ": " << attrType.getShape()[dim]
               << " != " << resultType.getShape()[dim];
      }
    }
    return mlir::success();
  }
  auto resultType = type.cast<StructType>();
  llvm::ArrayRef<mlir::Type> resultElementTypes = resultType.getElementTypes();

  // Verify that the initializer is an Array.
  auto attrValue = opaqueValue.dyn_cast<ArrayAttr>();
  if (!attrValue || attrValue.getValue().size() != resultElementTypes.size())
    return op->emitError("constant of StructType must be initialized by an "
                         "ArrayAttr with the same number of elements, got ")
           << opaqueValue;

  // Check that each of the elements are valid.
  llvm::ArrayRef<mlir::Attribute> attrElementValues = attrValue.getValue();
  for (const auto it : llvm::zip(resultElementTypes, attrElementValues))
    if (failed(verifyConstantForType(std::get<0>(it), std::get<1>(it), op)))
      return mlir::failure();
  return mlir::success();
}

/// Verifier for the constant operation. This corresponds to the `::verify(...)`
/// in the op definition.
mlir::LogicalResult ConstantOp::verify() {
  return verifyConstantForType(getResult().getType(), getValue(), *this);
}

mlir::LogicalResult StructConstantOp::verify() {
  return verifyConstantForType(getResult().getType(), getValue(), *this);
}

/// Infer the output shape of the ConstantOp; this is required by the shape
/// inference interface.
void ConstantOp::inferShapes() { getResult().setType(getValue().getType()); }

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                  mlir::Value lhs, mlir::Value rhs) {
  state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
  state.addOperands({lhs, rhs});
}

mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser,
                               mlir::OperationState &result) {
  return parseBinaryOp(parser, result);
}

void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }

/// Infer the output shape of the AddOp; this is required by the shape inference
/// interface.
void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Infer the output shape of the CastOp; this is required by the shape
/// inference interface.
void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }

/// Returns true if the given set of input and result types are compatible with
/// this cast operation. This is required by the `CastOpInterface` to verify
/// this operation and provide other additional utilities.
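/// For example, tensor<*xf64> is cast-compatible with tensor<2x3xf64> (and
/// vice versa), while two ranked tensors with different shapes are not.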
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  // The inputs must be Tensors with the same element type.
  TensorType input = inputs.front().dyn_cast<TensorType>();
  TensorType output = outputs.front().dyn_cast<TensorType>();
  if (!input || !output || input.getElementType() != output.getElementType())
    return false;
  // The shape is required to match if both types are ranked.
  return !input.hasRank() || !output.hasRank() || input == output;
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                   llvm::StringRef name, mlir::FunctionType type,
                   llvm::ArrayRef<mlir::NamedAttribute> attrs) {
  // FunctionOpInterface provides a convenient `build` method that will populate
  // the state of our FuncOp, and create an entry block.
  buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
}

mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
                                mlir::OperationState &result) {
  // Dispatch to the FunctionOpInterface provided utility method that parses the
  // function operation.
  auto buildFuncType =
      [](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes,
         llvm::ArrayRef<mlir::Type> results,
         mlir::function_interface_impl::VariadicFlag,
         std::string &) { return builder.getFunctionType(argTypes, results); };

  return mlir::function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, buildFuncType);
}

void FuncOp::print(mlir::OpAsmPrinter &p) {
  // Dispatch to the FunctionOpInterface provided utility method that prints the
  // function operation.
  mlir::function_interface_impl::printFunctionOp(p, *this,
                                                 /*isVariadic=*/false);
}

/// Returns the region on the function operation that is callable.
mlir::Region *FuncOp::getCallableRegion() { return &getBody(); }

/// Returns the result types that the callable region produces when
/// executed.
llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
  return getFunctionType().getResults();
}

//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//

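/// Build a generic call operation, e.g. (illustrative IR; the textual form is
/// defined by the op's assembly format in Ops.td):
///   %0 = toy.generic_call @multiply_transpose(%arg0, %arg1)
///        : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>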
void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                          StringRef callee, ArrayRef<mlir::Value> arguments) {
  // Generic call always returns an unranked Tensor initially.
  state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
  state.addOperands(arguments);
  state.addAttribute("callee",
                     mlir::SymbolRefAttr::get(builder.getContext(), callee));
}

/// Return the callee of the generic call operation; this is required by the
/// call interface.
CallInterfaceCallable GenericCallOp::getCallableForCallee() {
  return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}

/// Get the argument operands to the called function; this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                  mlir::Value lhs, mlir::Value rhs) {
  state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
  state.addOperands({lhs, rhs});
}

mlir::ParseResult MulOp::parse(mlir::OpAsmParser &parser,
                               mlir::OperationState &result) {
  return parseBinaryOp(parser, result);
}

void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }

/// Infer the output shape of the MulOp; this is required by the shape inference
/// interface.
void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }

//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//

mlir::LogicalResult ReturnOp::verify() {
  // We know that the parent operation is a function, because of the 'HasParent'
  // trait attached to the operation definition.
  auto function = cast<FuncOp>((*this)->getParentOp());

  /// ReturnOps can only have a single optional operand.
  if (getNumOperands() > 1)
    return emitOpError() << "expects at most 1 return operand";

  // The operand number and types must match the function signature.
  const auto &results = function.getFunctionType().getResults();
  if (getNumOperands() != results.size())
    return emitOpError() << "does not return the same number of values ("
                         << getNumOperands() << ") as the enclosing function ("
                         << results.size() << ")";

  // If the operation does not have an input, we are done.
  if (!hasOperand())
    return mlir::success();

  auto inputType = *operand_type_begin();
  auto resultType = results.front();

  // Check that the result type of the function matches the operand type.
  if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
      resultType.isa<mlir::UnrankedTensorType>())
    return mlir::success();

  return emitError() << "type of return operand (" << inputType
                     << ") doesn't match function result type (" << resultType
                     << ")";
}

//===----------------------------------------------------------------------===//
// StructAccessOp
//===----------------------------------------------------------------------===//

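/// Build a struct_access operation that extracts element `index` from a struct
/// value, e.g. (illustrative IR):
///   %1 = toy.struct_access %0[0] : !toy.struct<tensor<*xf64>> -> tensor<*xf64>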
void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state,
                           mlir::Value input, size_t index) {
  // Extract the result type from the input type.
  StructType structTy = input.getType().cast<StructType>();
  assert(index < structTy.getNumElementTypes());
  mlir::Type resultType = structTy.getElementTypes()[index];

  // Call into the auto-generated build method.
  build(b, state, resultType, input, b.getI64IntegerAttr(index));
}

mlir::LogicalResult StructAccessOp::verify() {
  StructType structTy = getInput().getType().cast<StructType>();
  size_t indexValue = getIndex();
  if (indexValue >= structTy.getNumElementTypes())
    return emitOpError()
           << "index should be within the range of the input struct type";
  mlir::Type resultType = getResult().getType();
  if (resultType != structTy.getElementTypes()[indexValue])
    return emitOpError() << "must have the same result type as the struct "
                            "element referred to by the index";
  return mlir::success();
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                        mlir::Value value) {
  state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
  state.addOperands(value);
}

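/// Infer the output shape of the TransposeOp by reversing the operand's shape;
/// e.g. a tensor<2x3xf64> operand yields a tensor<3x2xf64> result.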
void TransposeOp::inferShapes() {
  auto arrayTy = getOperand().getType().cast<RankedTensorType>();
  SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
  getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}

mlir::LogicalResult TransposeOp::verify() {
  auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
  auto resultType = getType().dyn_cast<RankedTensorType>();
  if (!inputType || !resultType)
    return mlir::success();

  auto inputShape = inputType.getShape();
  if (!std::equal(inputShape.begin(), inputShape.end(),
                  resultType.getShape().rbegin())) {
    return emitError()
           << "expected result shape to be a transpose of the input";
  }
  return mlir::success();
}

//===----------------------------------------------------------------------===//
// Toy Types
//===----------------------------------------------------------------------===//

namespace mlir {
namespace toy {
namespace detail {
/// This class represents the internal storage of the Toy `StructType`.
struct StructTypeStorage : public mlir::TypeStorage {
  /// The `KeyTy` is a required type that provides an interface for the storage
  /// instance. This type will be used when uniquing an instance of the type
  /// storage. For our struct type, we will unique each instance structurally on
  /// the elements that it contains.
  using KeyTy = llvm::ArrayRef<mlir::Type>;

  /// A constructor for the type storage instance.
  StructTypeStorage(llvm::ArrayRef<mlir::Type> elementTypes)
      : elementTypes(elementTypes) {}

  /// Define the comparison function for the key type with the current storage
  /// instance. This is used when constructing a new instance to ensure that we
  /// haven't already uniqued an instance of the given key.
  bool operator==(const KeyTy &key) const { return key == elementTypes; }

  /// Define a hash function for the key type. This is used when uniquing
  /// instances of the storage, see the `StructType::get` method.
  /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type
  /// have hash functions available, so we could just omit this entirely.
  static llvm::hash_code hashKey(const KeyTy &key) {
    return llvm::hash_value(key);
  }

  /// Define a construction function for the key type from a set of parameters.
  /// These parameters will be provided when constructing the storage instance
  /// itself.
  /// Note: This method isn't necessary because KeyTy can be directly
  /// constructed with the given parameters.
  static KeyTy getKey(llvm::ArrayRef<mlir::Type> elementTypes) {
    return KeyTy(elementTypes);
  }

  /// Define a construction method for creating a new instance of this storage.
  /// This method takes an instance of a storage allocator, and an instance of a
  /// `KeyTy`. The given allocator must be used for *all* necessary dynamic
  /// allocations used to create the type storage and its internals.
  static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
                                      const KeyTy &key) {
    // Copy the elements from the provided `KeyTy` into the allocator.
    llvm::ArrayRef<mlir::Type> elementTypes = allocator.copyInto(key);

    // Allocate the storage instance and construct it.
    return new (allocator.allocate<StructTypeStorage>())
        StructTypeStorage(elementTypes);
  }

  /// The following field contains the element types of the struct.
  llvm::ArrayRef<mlir::Type> elementTypes;
};
} // namespace detail
} // namespace toy
} // namespace mlir

/// Create an instance of a `StructType` with the given element types. There
/// *must* be at least one element type.
StructType StructType::get(llvm::ArrayRef<mlir::Type> elementTypes) {
  assert(!elementTypes.empty() && "expected at least 1 element type");

  // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
  // of this type. The first parameter is the context to unique in. The
  // parameters after the context are forwarded to the storage instance.
  mlir::MLIRContext *ctx = elementTypes.front().getContext();
  return Base::get(ctx, elementTypes);
}

/// Returns the element types of this struct type.
llvm::ArrayRef<mlir::Type> StructType::getElementTypes() {
  // 'getImpl' returns a pointer to the internal storage instance.
  return getImpl()->elementTypes;
}

/// Parse an instance of a type registered to the toy dialect.
mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
  // Parse a struct type in the following form:
  //   struct-type ::= `struct` `<` type (`,` type)* `>`
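  // With the dialect namespace prepended by the framework, such a type
  // appears in the IR as, e.g.: !toy.struct<tensor<*xf64>, tensor<2x2xf64>>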

  // NOTE: All MLIR parser functions return a ParseResult. This is a
  // specialization of LogicalResult that auto-converts to a `true` boolean
  // value on failure to allow for chaining, but may be used with explicit
  // `mlir::failed/mlir::succeeded` as desired.

  // Parse: `struct` `<`
  if (parser.parseKeyword("struct") || parser.parseLess())
    return Type();

  // Parse the element types of the struct.
  SmallVector<mlir::Type, 1> elementTypes;
  do {
    // Parse the current element type.
    SMLoc typeLoc = parser.getCurrentLocation();
    mlir::Type elementType;
    if (parser.parseType(elementType))
      return nullptr;

    // Check that the type is either a TensorType or another StructType.
    if (!elementType.isa<mlir::TensorType, StructType>()) {
      parser.emitError(typeLoc, "element type for a struct must either "
                                "be a TensorType or a StructType, got: ")
          << elementType;
      return Type();
    }
    elementTypes.push_back(elementType);

    // Parse the optional: `,`
  } while (succeeded(parser.parseOptionalComma()));

  // Parse: `>`
  if (parser.parseGreater())
    return Type();
  return StructType::get(elementTypes);
}

/// Print an instance of a type registered to the toy dialect.
void ToyDialect::printType(mlir::Type type,
                           mlir::DialectAsmPrinter &printer) const {
  // Currently the only toy type is a struct type.
  StructType structType = type.cast<StructType>();

  // Print the struct type according to the parser format.
  printer << "struct<";
  llvm::interleaveComma(structType.getElementTypes(), printer);
  printer << '>';
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "toy/Ops.cpp.inc"

//===----------------------------------------------------------------------===//
// ToyDialect
//===----------------------------------------------------------------------===//

/// Dialect initialization; the instance will be owned by the context. This is
/// the point of registration of types and operations for the dialect.
void ToyDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
      >();
  addInterfaces<ToyInlinerInterface>();
  addTypes<StructType>();
}

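/// Materialize a single constant operation from a given attribute value and
/// the desired result type; this hook is used when folding produces a constant
/// value that needs to be turned back into an operation.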
mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
                                                 mlir::Attribute value,
                                                 mlir::Type type,
                                                 mlir::Location loc) {
  if (type.isa<StructType>())
    return builder.create<StructConstantOp>(loc, type,
                                            value.cast<mlir::ArrayAttr>());
  return builder.create<ConstantOp>(loc, type,
                                    value.cast<mlir::DenseElementsAttr>());
}