1 //===- TestTypes.h - MLIR Test Dialect Types --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains types defined by the TestDialect for testing various
10 // features of MLIR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_TESTTYPES_H
15 #define MLIR_TESTTYPES_H
16 
17 #include <tuple>
18 
19 #include "TestTraits.h"
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/Dialect.h"
22 #include "mlir/IR/DialectImplementation.h"
23 #include "mlir/IR/Operation.h"
24 #include "mlir/IR/SubElementInterfaces.h"
25 #include "mlir/IR/Types.h"
26 #include "mlir/Interfaces/DataLayoutInterfaces.h"
27 
28 namespace test {
29 class TestAttrWithFormatAttr;
30 
31 /// FieldInfo represents a field in the StructType data type. It is used as a
32 /// parameter in TestTypeDefs.td.
33 struct FieldInfo {
34   ::llvm::StringRef name;
35   ::mlir::Type type;
36 
37   // Custom allocation called from generated constructor code
allocateIntoFieldInfo38   FieldInfo allocateInto(::mlir::TypeStorageAllocator &alloc) const {
39     return FieldInfo{alloc.copyInto(name), type};
40   }
41 };
42 
/// A custom type used as a test type parameter; it simply wraps an integer.
struct CustomParam {
  /// The wrapped integer payload.
  int value;

  /// Two parameters compare equal exactly when their payloads are equal.
  bool operator==(const CustomParam &rhs) const { return value == rhs.value; }
};
51 
hash_value(const test::CustomParam & param)52 inline llvm::hash_code hash_value(const test::CustomParam &param) {
53   return llvm::hash_value(param.value);
54 }
55 
56 } // namespace test
57 
58 namespace mlir {
59 template <>
60 struct FieldParser<test::CustomParam> {
61   static FailureOr<test::CustomParam> parse(AsmParser &parser) {
62     auto value = FieldParser<int>::parse(parser);
63     if (failed(value))
64       return failure();
65     return test::CustomParam{*value};
66   }
67 };
68 
69 inline mlir::AsmPrinter &operator<<(mlir::AsmPrinter &printer,
70                                     test::CustomParam param) {
71   return printer << param.value;
72 }
73 
74 /// Overload the attribute parameter parser for optional integers.
75 template <>
76 struct FieldParser<Optional<int>> {
77   static FailureOr<Optional<int>> parse(AsmParser &parser) {
78     Optional<int> value;
79     value.emplace();
80     OptionalParseResult result = parser.parseOptionalInteger(*value);
81     if (result.hasValue()) {
82       if (succeeded(*result))
83         return value;
84       return failure();
85     }
86     value.reset();
87     return value;
88   }
89 };
90 } // namespace mlir
91 
92 #include "TestTypeInterfaces.h.inc"
93 
94 #define GET_TYPEDEF_CLASSES
95 #include "TestTypeDefs.h.inc"
96 
97 namespace test {
98 
99 /// Storage for simple named recursive types, where the type is identified by
100 /// its name and can "contain" another type, including itself.
101 struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
102   using KeyTy = ::llvm::StringRef;
103 
104   explicit TestRecursiveTypeStorage(::llvm::StringRef key)
105       : name(key), body(::mlir::Type()) {}
106 
107   bool operator==(const KeyTy &other) const { return name == other; }
108 
109   static TestRecursiveTypeStorage *
110   construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) {
111     return new (allocator.allocate<TestRecursiveTypeStorage>())
112         TestRecursiveTypeStorage(allocator.copyInto(key));
113   }
114 
115   ::mlir::LogicalResult mutate(::mlir::TypeStorageAllocator &allocator,
116                                ::mlir::Type newBody) {
117     // Cannot set a different body than before.
118     if (body && body != newBody)
119       return ::mlir::failure();
120 
121     body = newBody;
122     return ::mlir::success();
123   }
124 
125   ::llvm::StringRef name;
126   ::mlir::Type body;
127 };
128 
129 /// Simple recursive type identified by its name and pointing to another named
130 /// type, potentially itself. This requires the body to be mutated separately
131 /// from type creation.
132 class TestRecursiveType
133     : public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type,
134                                     TestRecursiveTypeStorage,
135                                     ::mlir::SubElementTypeInterface::Trait,
136                                     ::mlir::TypeTrait::IsMutable> {
137 public:
138   using Base::Base;
139 
140   static TestRecursiveType get(::mlir::MLIRContext *ctx,
141                                ::llvm::StringRef name) {
142     return Base::get(ctx, name);
143   }
144 
145   /// Body getter and setter.
146   ::mlir::LogicalResult setBody(Type body) { return Base::mutate(body); }
147   ::mlir::Type getBody() const { return getImpl()->body; }
148 
149   /// Name/key getter.
150   ::llvm::StringRef getName() { return getImpl()->name; }
151 
152   void walkImmediateSubElements(
153       ::llvm::function_ref<void(::mlir::Attribute)> walkAttrsFn,
154       ::llvm::function_ref<void(::mlir::Type)> walkTypesFn) const {
155     walkTypesFn(getBody());
156   }
157   Type replaceImmediateSubElements(llvm::ArrayRef<mlir::Attribute> replAttrs,
158                                    llvm::ArrayRef<mlir::Type> replTypes) const {
159     // TODO: It's not clear how we support replacing sub-elements of mutable
160     // types.
161     return nullptr;
162   }
163 };
164 
165 } // namespace test
166 
167 #endif // MLIR_TESTTYPES_H
168