1 //===- Types.cpp ----------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/AST/Types.h"
10 #include "TypeDetail.h"
11 #include "mlir/Tools/PDLL/AST/Context.h"
12 
13 using namespace mlir;
14 using namespace mlir::pdll;
15 using namespace mlir::pdll::ast;
16 
17 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::AttributeTypeStorage)
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::ConstraintTypeStorage)18 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::ConstraintTypeStorage)
19 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::OperationTypeStorage)
20 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::RangeTypeStorage)
21 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::RewriteTypeStorage)
22 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::TupleTypeStorage)
23 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::TypeTypeStorage)
24 MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::ValueTypeStorage)
25 
26 //===----------------------------------------------------------------------===//
27 // Type
28 //===----------------------------------------------------------------------===//
29 
30 TypeID Type::getTypeID() const { return impl->typeID; }
31 
refineWith(Type other) const32 Type Type::refineWith(Type other) const {
33   if (*this == other)
34     return *this;
35 
36   // Operation types are compatible if the operation names don't conflict.
37   if (auto opTy = dyn_cast<OperationType>()) {
38     auto otherOpTy = other.dyn_cast<ast::OperationType>();
39     if (!otherOpTy)
40       return nullptr;
41     if (!otherOpTy.getName())
42       return *this;
43     if (!opTy.getName())
44       return other;
45 
46     return nullptr;
47   }
48 
49   return nullptr;
50 }
51 
52 //===----------------------------------------------------------------------===//
53 // AttributeType
54 //===----------------------------------------------------------------------===//
55 
get(Context & context)56 AttributeType AttributeType::get(Context &context) {
57   return context.getTypeUniquer().get<ImplTy>();
58 }
59 
60 //===----------------------------------------------------------------------===//
61 // ConstraintType
62 //===----------------------------------------------------------------------===//
63 
get(Context & context)64 ConstraintType ConstraintType::get(Context &context) {
65   return context.getTypeUniquer().get<ImplTy>();
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // OperationType
70 //===----------------------------------------------------------------------===//
71 
get(Context & context,Optional<StringRef> name,const ods::Operation * odsOp)72 OperationType OperationType::get(Context &context, Optional<StringRef> name,
73                                  const ods::Operation *odsOp) {
74   return context.getTypeUniquer().get<ImplTy>(
75       /*initFn=*/function_ref<void(ImplTy *)>(),
76       std::make_pair(name.value_or(""), odsOp));
77 }
78 
getName() const79 Optional<StringRef> OperationType::getName() const {
80   StringRef name = getImplAs<ImplTy>()->getValue().first;
81   return name.empty() ? Optional<StringRef>() : Optional<StringRef>(name);
82 }
83 
getODSOperation() const84 const ods::Operation *OperationType::getODSOperation() const {
85   return getImplAs<ImplTy>()->getValue().second;
86 }
87 
88 //===----------------------------------------------------------------------===//
89 // RangeType
90 //===----------------------------------------------------------------------===//
91 
get(Context & context,Type elementType)92 RangeType RangeType::get(Context &context, Type elementType) {
93   return context.getTypeUniquer().get<ImplTy>(
94       /*initFn=*/function_ref<void(ImplTy *)>(), elementType);
95 }
96 
getElementType() const97 Type RangeType::getElementType() const {
98   return getImplAs<ImplTy>()->getValue();
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // TypeRangeType
103 
classof(Type type)104 bool TypeRangeType::classof(Type type) {
105   RangeType range = type.dyn_cast<RangeType>();
106   return range && range.getElementType().isa<TypeType>();
107 }
108 
get(Context & context)109 TypeRangeType TypeRangeType::get(Context &context) {
110   return RangeType::get(context, TypeType::get(context)).cast<TypeRangeType>();
111 }
112 
113 //===----------------------------------------------------------------------===//
114 // ValueRangeType
115 
classof(Type type)116 bool ValueRangeType::classof(Type type) {
117   RangeType range = type.dyn_cast<RangeType>();
118   return range && range.getElementType().isa<ValueType>();
119 }
120 
get(Context & context)121 ValueRangeType ValueRangeType::get(Context &context) {
122   return RangeType::get(context, ValueType::get(context))
123       .cast<ValueRangeType>();
124 }
125 
126 //===----------------------------------------------------------------------===//
127 // RewriteType
128 //===----------------------------------------------------------------------===//
129 
get(Context & context)130 RewriteType RewriteType::get(Context &context) {
131   return context.getTypeUniquer().get<ImplTy>();
132 }
133 
134 //===----------------------------------------------------------------------===//
135 // TupleType
136 //===----------------------------------------------------------------------===//
137 
get(Context & context,ArrayRef<Type> elementTypes,ArrayRef<StringRef> elementNames)138 TupleType TupleType::get(Context &context, ArrayRef<Type> elementTypes,
139                          ArrayRef<StringRef> elementNames) {
140   assert(elementTypes.size() == elementNames.size());
141   return context.getTypeUniquer().get<ImplTy>(
142       /*initFn=*/function_ref<void(ImplTy *)>(), elementTypes, elementNames);
143 }
get(Context & context,ArrayRef<Type> elementTypes)144 TupleType TupleType::get(Context &context, ArrayRef<Type> elementTypes) {
145   SmallVector<StringRef> elementNames(elementTypes.size());
146   return get(context, elementTypes, elementNames);
147 }
148 
getElementTypes() const149 ArrayRef<Type> TupleType::getElementTypes() const {
150   return getImplAs<ImplTy>()->getValue().first;
151 }
152 
getElementNames() const153 ArrayRef<StringRef> TupleType::getElementNames() const {
154   return getImplAs<ImplTy>()->getValue().second;
155 }
156 
157 //===----------------------------------------------------------------------===//
158 // TypeType
159 //===----------------------------------------------------------------------===//
160 
get(Context & context)161 TypeType TypeType::get(Context &context) {
162   return context.getTypeUniquer().get<ImplTy>();
163 }
164 
165 //===----------------------------------------------------------------------===//
166 // ValueType
167 //===----------------------------------------------------------------------===//
168 
get(Context & context)169 ValueType ValueType::get(Context &context) {
170   return context.getTypeUniquer().get<ImplTy>();
171 }
172