1 //===- Types.cpp ----------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Tools/PDLL/AST/Types.h"
10 #include "TypeDetail.h"
11 #include "mlir/Tools/PDLL/AST/Context.h"
12
13 using namespace mlir;
14 using namespace mlir::pdll;
15 using namespace mlir::pdll::ast;
16
17 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::AttributeTypeStorage)
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::ConstraintTypeStorage)18 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::ConstraintTypeStorage)
19 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::OperationTypeStorage)
20 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::RangeTypeStorage)
21 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::RewriteTypeStorage)
22 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::TupleTypeStorage)
23 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::TypeTypeStorage)
24 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::ValueTypeStorage)
25
26 //===----------------------------------------------------------------------===//
27 // Type
28 //===----------------------------------------------------------------------===//
29
30 TypeID Type::getTypeID() const { return impl->typeID; }
31
refineWith(Type other) const32 Type Type::refineWith(Type other) const {
33 if (*this == other)
34 return *this;
35
36 // Operation types are compatible if the operation names don't conflict.
37 if (auto opTy = dyn_cast<OperationType>()) {
38 auto otherOpTy = other.dyn_cast<ast::OperationType>();
39 if (!otherOpTy)
40 return nullptr;
41 if (!otherOpTy.getName())
42 return *this;
43 if (!opTy.getName())
44 return other;
45
46 return nullptr;
47 }
48
49 return nullptr;
50 }
51
52 //===----------------------------------------------------------------------===//
53 // AttributeType
54 //===----------------------------------------------------------------------===//
55
get(Context & context)56 AttributeType AttributeType::get(Context &context) {
57 return context.getTypeUniquer().get<ImplTy>();
58 }
59
60 //===----------------------------------------------------------------------===//
61 // ConstraintType
62 //===----------------------------------------------------------------------===//
63
get(Context & context)64 ConstraintType ConstraintType::get(Context &context) {
65 return context.getTypeUniquer().get<ImplTy>();
66 }
67
68 //===----------------------------------------------------------------------===//
69 // OperationType
70 //===----------------------------------------------------------------------===//
71
get(Context & context,Optional<StringRef> name,const ods::Operation * odsOp)72 OperationType OperationType::get(Context &context, Optional<StringRef> name,
73 const ods::Operation *odsOp) {
74 return context.getTypeUniquer().get<ImplTy>(
75 /*initFn=*/function_ref<void(ImplTy *)>(),
76 std::make_pair(name.value_or(""), odsOp));
77 }
78
getName() const79 Optional<StringRef> OperationType::getName() const {
80 StringRef name = getImplAs<ImplTy>()->getValue().first;
81 return name.empty() ? Optional<StringRef>() : Optional<StringRef>(name);
82 }
83
getODSOperation() const84 const ods::Operation *OperationType::getODSOperation() const {
85 return getImplAs<ImplTy>()->getValue().second;
86 }
87
88 //===----------------------------------------------------------------------===//
89 // RangeType
90 //===----------------------------------------------------------------------===//
91
get(Context & context,Type elementType)92 RangeType RangeType::get(Context &context, Type elementType) {
93 return context.getTypeUniquer().get<ImplTy>(
94 /*initFn=*/function_ref<void(ImplTy *)>(), elementType);
95 }
96
getElementType() const97 Type RangeType::getElementType() const {
98 return getImplAs<ImplTy>()->getValue();
99 }
100
101 //===----------------------------------------------------------------------===//
102 // TypeRangeType
103
classof(Type type)104 bool TypeRangeType::classof(Type type) {
105 RangeType range = type.dyn_cast<RangeType>();
106 return range && range.getElementType().isa<TypeType>();
107 }
108
get(Context & context)109 TypeRangeType TypeRangeType::get(Context &context) {
110 return RangeType::get(context, TypeType::get(context)).cast<TypeRangeType>();
111 }
112
113 //===----------------------------------------------------------------------===//
114 // ValueRangeType
115
classof(Type type)116 bool ValueRangeType::classof(Type type) {
117 RangeType range = type.dyn_cast<RangeType>();
118 return range && range.getElementType().isa<ValueType>();
119 }
120
get(Context & context)121 ValueRangeType ValueRangeType::get(Context &context) {
122 return RangeType::get(context, ValueType::get(context))
123 .cast<ValueRangeType>();
124 }
125
126 //===----------------------------------------------------------------------===//
127 // RewriteType
128 //===----------------------------------------------------------------------===//
129
get(Context & context)130 RewriteType RewriteType::get(Context &context) {
131 return context.getTypeUniquer().get<ImplTy>();
132 }
133
134 //===----------------------------------------------------------------------===//
135 // TupleType
136 //===----------------------------------------------------------------------===//
137
get(Context & context,ArrayRef<Type> elementTypes,ArrayRef<StringRef> elementNames)138 TupleType TupleType::get(Context &context, ArrayRef<Type> elementTypes,
139 ArrayRef<StringRef> elementNames) {
140 assert(elementTypes.size() == elementNames.size());
141 return context.getTypeUniquer().get<ImplTy>(
142 /*initFn=*/function_ref<void(ImplTy *)>(), elementTypes, elementNames);
143 }
get(Context & context,ArrayRef<Type> elementTypes)144 TupleType TupleType::get(Context &context, ArrayRef<Type> elementTypes) {
145 SmallVector<StringRef> elementNames(elementTypes.size());
146 return get(context, elementTypes, elementNames);
147 }
148
getElementTypes() const149 ArrayRef<Type> TupleType::getElementTypes() const {
150 return getImplAs<ImplTy>()->getValue().first;
151 }
152
getElementNames() const153 ArrayRef<StringRef> TupleType::getElementNames() const {
154 return getImplAs<ImplTy>()->getValue().second;
155 }
156
157 //===----------------------------------------------------------------------===//
158 // TypeType
159 //===----------------------------------------------------------------------===//
160
get(Context & context)161 TypeType TypeType::get(Context &context) {
162 return context.getTypeUniquer().get<ImplTy>();
163 }
164
165 //===----------------------------------------------------------------------===//
166 // ValueType
167 //===----------------------------------------------------------------------===//
168
get(Context & context)169 ValueType ValueType::get(Context &context) {
170 return context.getTypeUniquer().get<ImplTy>();
171 }
172