1c42dd5dbSRiver Riddle //===- SubElementInterfaces.cpp - Attr and Type SubElement Interfaces -----===//
2c42dd5dbSRiver Riddle //
3c42dd5dbSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c42dd5dbSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5c42dd5dbSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c42dd5dbSRiver Riddle //
7c42dd5dbSRiver Riddle //===----------------------------------------------------------------------===//
8c42dd5dbSRiver Riddle 
9c42dd5dbSRiver Riddle #include "mlir/IR/SubElementInterfaces.h"
10c42dd5dbSRiver Riddle 
11d4102861SMin-Yih Hsu #include "llvm/ADT/DenseSet.h"
12d4102861SMin-Yih Hsu 
13c42dd5dbSRiver Riddle using namespace mlir;
14c42dd5dbSRiver Riddle 
15*01eedbc7SRiver Riddle //===----------------------------------------------------------------------===//
16*01eedbc7SRiver Riddle // SubElementInterface
17*01eedbc7SRiver Riddle //===----------------------------------------------------------------------===//
18*01eedbc7SRiver Riddle 
19*01eedbc7SRiver Riddle //===----------------------------------------------------------------------===//
20*01eedbc7SRiver Riddle // WalkSubElements
21*01eedbc7SRiver Riddle 
22c42dd5dbSRiver Riddle template <typename InterfaceT>
walkSubElementsImpl(InterfaceT interface,function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn,DenseSet<Attribute> & visitedAttrs,DenseSet<Type> & visitedTypes)23c42dd5dbSRiver Riddle static void walkSubElementsImpl(InterfaceT interface,
24c42dd5dbSRiver Riddle                                 function_ref<void(Attribute)> walkAttrsFn,
25d4102861SMin-Yih Hsu                                 function_ref<void(Type)> walkTypesFn,
26d4102861SMin-Yih Hsu                                 DenseSet<Attribute> &visitedAttrs,
27d4102861SMin-Yih Hsu                                 DenseSet<Type> &visitedTypes) {
28c42dd5dbSRiver Riddle   interface.walkImmediateSubElements(
29c42dd5dbSRiver Riddle       [&](Attribute attr) {
30c42dd5dbSRiver Riddle         // Guard against potentially null inputs. This removes the need for the
31c42dd5dbSRiver Riddle         // derived attribute/type to do it.
32c42dd5dbSRiver Riddle         if (!attr)
33c42dd5dbSRiver Riddle           return;
34c42dd5dbSRiver Riddle 
35d4102861SMin-Yih Hsu         // Avoid infinite recursion when visiting sub attributes later, if this
36d4102861SMin-Yih Hsu         // is a mutable attribute.
37d4102861SMin-Yih Hsu         if (LLVM_UNLIKELY(attr.hasTrait<AttributeTrait::IsMutable>())) {
38d4102861SMin-Yih Hsu           if (!visitedAttrs.insert(attr).second)
39d4102861SMin-Yih Hsu             return;
40d4102861SMin-Yih Hsu         }
41d4102861SMin-Yih Hsu 
42c42dd5dbSRiver Riddle         // Walk any sub elements first.
43c42dd5dbSRiver Riddle         if (auto interface = attr.dyn_cast<SubElementAttrInterface>())
44d4102861SMin-Yih Hsu           walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
45d4102861SMin-Yih Hsu                               visitedTypes);
46c42dd5dbSRiver Riddle 
47c42dd5dbSRiver Riddle         // Walk this attribute.
48c42dd5dbSRiver Riddle         walkAttrsFn(attr);
49c42dd5dbSRiver Riddle       },
50c42dd5dbSRiver Riddle       [&](Type type) {
51c42dd5dbSRiver Riddle         // Guard against potentially null inputs. This removes the need for the
52c42dd5dbSRiver Riddle         // derived attribute/type to do it.
53c42dd5dbSRiver Riddle         if (!type)
54c42dd5dbSRiver Riddle           return;
55c42dd5dbSRiver Riddle 
56d4102861SMin-Yih Hsu         // Avoid infinite recursion when visiting sub types later, if this
57d4102861SMin-Yih Hsu         // is a mutable type.
58d4102861SMin-Yih Hsu         if (LLVM_UNLIKELY(type.hasTrait<TypeTrait::IsMutable>())) {
59d4102861SMin-Yih Hsu           if (!visitedTypes.insert(type).second)
60d4102861SMin-Yih Hsu             return;
61d4102861SMin-Yih Hsu         }
62d4102861SMin-Yih Hsu 
63c42dd5dbSRiver Riddle         // Walk any sub elements first.
64c42dd5dbSRiver Riddle         if (auto interface = type.dyn_cast<SubElementTypeInterface>())
65d4102861SMin-Yih Hsu           walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
66d4102861SMin-Yih Hsu                               visitedTypes);
67c42dd5dbSRiver Riddle 
68c42dd5dbSRiver Riddle         // Walk this type.
69c42dd5dbSRiver Riddle         walkTypesFn(type);
70c42dd5dbSRiver Riddle       });
71c42dd5dbSRiver Riddle }
72c42dd5dbSRiver Riddle 
walkSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn)73c42dd5dbSRiver Riddle void SubElementAttrInterface::walkSubElements(
74c42dd5dbSRiver Riddle     function_ref<void(Attribute)> walkAttrsFn,
75c42dd5dbSRiver Riddle     function_ref<void(Type)> walkTypesFn) {
76c42dd5dbSRiver Riddle   assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
77d4102861SMin-Yih Hsu   DenseSet<Attribute> visitedAttrs;
78d4102861SMin-Yih Hsu   DenseSet<Type> visitedTypes;
79d4102861SMin-Yih Hsu   walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
80d4102861SMin-Yih Hsu                       visitedTypes);
81c42dd5dbSRiver Riddle }
82c42dd5dbSRiver Riddle 
walkSubElements(function_ref<void (Attribute)> walkAttrsFn,function_ref<void (Type)> walkTypesFn)83c42dd5dbSRiver Riddle void SubElementTypeInterface::walkSubElements(
84c42dd5dbSRiver Riddle     function_ref<void(Attribute)> walkAttrsFn,
85c42dd5dbSRiver Riddle     function_ref<void(Type)> walkTypesFn) {
86c42dd5dbSRiver Riddle   assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
87d4102861SMin-Yih Hsu   DenseSet<Attribute> visitedAttrs;
88d4102861SMin-Yih Hsu   DenseSet<Type> visitedTypes;
89d4102861SMin-Yih Hsu   walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
90d4102861SMin-Yih Hsu                       visitedTypes);
91c42dd5dbSRiver Riddle }
92c42dd5dbSRiver Riddle 
93c42dd5dbSRiver Riddle //===----------------------------------------------------------------------===//
94*01eedbc7SRiver Riddle // ReplaceSubElements
95*01eedbc7SRiver Riddle 
96*01eedbc7SRiver Riddle /// Return if the given element is mutable.
isMutable(Attribute attr)97*01eedbc7SRiver Riddle static bool isMutable(Attribute attr) {
98*01eedbc7SRiver Riddle   return attr.hasTrait<AttributeTrait::IsMutable>();
99*01eedbc7SRiver Riddle }
isMutable(Type type)100*01eedbc7SRiver Riddle static bool isMutable(Type type) {
101*01eedbc7SRiver Riddle   return type.hasTrait<TypeTrait::IsMutable>();
102*01eedbc7SRiver Riddle }
103*01eedbc7SRiver Riddle 
104*01eedbc7SRiver Riddle template <typename InterfaceT, typename T, typename ReplaceSubElementFnT>
updateSubElementImpl(T element,function_ref<T (T)> walkFn,DenseMap<T,T> & visited,SmallVectorImpl<T> & newElements,FailureOr<bool> & changed,ReplaceSubElementFnT && replaceSubElementFn)105*01eedbc7SRiver Riddle static void updateSubElementImpl(T element, function_ref<T(T)> walkFn,
106*01eedbc7SRiver Riddle                                  DenseMap<T, T> &visited,
107*01eedbc7SRiver Riddle                                  SmallVectorImpl<T> &newElements,
108*01eedbc7SRiver Riddle                                  FailureOr<bool> &changed,
109*01eedbc7SRiver Riddle                                  ReplaceSubElementFnT &&replaceSubElementFn) {
110*01eedbc7SRiver Riddle   // Bail early if we failed at any point.
111*01eedbc7SRiver Riddle   if (failed(changed))
112*01eedbc7SRiver Riddle     return;
113*01eedbc7SRiver Riddle   newElements.push_back(element);
114*01eedbc7SRiver Riddle 
115*01eedbc7SRiver Riddle   // Guard against potentially null inputs. We always map null to null.
116*01eedbc7SRiver Riddle   if (!element)
117*01eedbc7SRiver Riddle     return;
118*01eedbc7SRiver Riddle 
119*01eedbc7SRiver Riddle   // Check for an existing mapping for this element, and walk it if we haven't
120*01eedbc7SRiver Riddle   // yet.
121*01eedbc7SRiver Riddle   T &mappedElement = visited[element];
122*01eedbc7SRiver Riddle   if (!mappedElement) {
123*01eedbc7SRiver Riddle     // Try walking this element.
124*01eedbc7SRiver Riddle     if (!(mappedElement = walkFn(element))) {
125*01eedbc7SRiver Riddle       changed = failure();
126*01eedbc7SRiver Riddle       return;
127*01eedbc7SRiver Riddle     }
128*01eedbc7SRiver Riddle 
129*01eedbc7SRiver Riddle     // Handle replacing sub-elements if this element is also a container.
130*01eedbc7SRiver Riddle     if (auto interface = mappedElement.template dyn_cast<InterfaceT>()) {
131*01eedbc7SRiver Riddle       if (!(mappedElement = replaceSubElementFn(interface))) {
132*01eedbc7SRiver Riddle         changed = failure();
133*01eedbc7SRiver Riddle         return;
134*01eedbc7SRiver Riddle       }
135*01eedbc7SRiver Riddle     }
136*01eedbc7SRiver Riddle   }
137*01eedbc7SRiver Riddle 
138*01eedbc7SRiver Riddle   // Update to the mapped element.
139*01eedbc7SRiver Riddle   if (mappedElement != element) {
140*01eedbc7SRiver Riddle     newElements.back() = mappedElement;
141*01eedbc7SRiver Riddle     changed = true;
142*01eedbc7SRiver Riddle   }
143*01eedbc7SRiver Riddle }
144*01eedbc7SRiver Riddle 
145*01eedbc7SRiver Riddle template <typename InterfaceT>
146*01eedbc7SRiver Riddle static typename InterfaceT::ValueType
replaceSubElementsImpl(InterfaceT interface,function_ref<Attribute (Attribute)> walkAttrsFn,function_ref<Type (Type)> walkTypesFn,DenseMap<Attribute,Attribute> & visitedAttrs,DenseMap<Type,Type> & visitedTypes)147*01eedbc7SRiver Riddle replaceSubElementsImpl(InterfaceT interface,
148*01eedbc7SRiver Riddle                        function_ref<Attribute(Attribute)> walkAttrsFn,
149*01eedbc7SRiver Riddle                        function_ref<Type(Type)> walkTypesFn,
150*01eedbc7SRiver Riddle                        DenseMap<Attribute, Attribute> &visitedAttrs,
151*01eedbc7SRiver Riddle                        DenseMap<Type, Type> &visitedTypes) {
152*01eedbc7SRiver Riddle   // Walk the current sub-elements, replacing them as necessary.
153*01eedbc7SRiver Riddle   SmallVector<Attribute, 16> newAttrs;
154*01eedbc7SRiver Riddle   SmallVector<Type, 16> newTypes;
155*01eedbc7SRiver Riddle   FailureOr<bool> changed = false;
156*01eedbc7SRiver Riddle   auto replaceSubElementFn = [&](auto subInterface) {
157*01eedbc7SRiver Riddle     return replaceSubElementsImpl(subInterface, walkAttrsFn, walkTypesFn,
158*01eedbc7SRiver Riddle                                   visitedAttrs, visitedTypes);
159*01eedbc7SRiver Riddle   };
160*01eedbc7SRiver Riddle   interface.walkImmediateSubElements(
161*01eedbc7SRiver Riddle       [&](Attribute element) {
162*01eedbc7SRiver Riddle         updateSubElementImpl<SubElementAttrInterface>(
163*01eedbc7SRiver Riddle             element, walkAttrsFn, visitedAttrs, newAttrs, changed,
164*01eedbc7SRiver Riddle             replaceSubElementFn);
165*01eedbc7SRiver Riddle       },
166*01eedbc7SRiver Riddle       [&](Type element) {
167*01eedbc7SRiver Riddle         updateSubElementImpl<SubElementTypeInterface>(
168*01eedbc7SRiver Riddle             element, walkTypesFn, visitedTypes, newTypes, changed,
169*01eedbc7SRiver Riddle             replaceSubElementFn);
170*01eedbc7SRiver Riddle       });
171*01eedbc7SRiver Riddle   if (failed(changed))
172*01eedbc7SRiver Riddle     return {};
173*01eedbc7SRiver Riddle 
174*01eedbc7SRiver Riddle   // If the sub-elements didn't change, just return the original value.
175*01eedbc7SRiver Riddle   if (!*changed)
176*01eedbc7SRiver Riddle     return interface;
177*01eedbc7SRiver Riddle 
178*01eedbc7SRiver Riddle   // If this element is mutable, we don't support changing its sub elements, the
179*01eedbc7SRiver Riddle   // sub element walk doesn't give us a valid ordering for what we need here. If
180*01eedbc7SRiver Riddle   // we want to support mutable elements, we'll need something more.
181*01eedbc7SRiver Riddle   if (isMutable(interface))
182*01eedbc7SRiver Riddle     return {};
183*01eedbc7SRiver Riddle 
184*01eedbc7SRiver Riddle   // Use the new elements during the replacement.
185*01eedbc7SRiver Riddle   return interface.replaceImmediateSubElements(newAttrs, newTypes);
186*01eedbc7SRiver Riddle }
187*01eedbc7SRiver Riddle 
replaceSubElements(function_ref<Attribute (Attribute)> replaceAttrFn,function_ref<Type (Type)> replaceTypeFn)188*01eedbc7SRiver Riddle Attribute SubElementAttrInterface::replaceSubElements(
189*01eedbc7SRiver Riddle     function_ref<Attribute(Attribute)> replaceAttrFn,
190*01eedbc7SRiver Riddle     function_ref<Type(Type)> replaceTypeFn) {
191*01eedbc7SRiver Riddle   assert(replaceAttrFn && replaceTypeFn && "expected valid replace functions");
192*01eedbc7SRiver Riddle   DenseMap<Attribute, Attribute> visitedAttrs;
193*01eedbc7SRiver Riddle   DenseMap<Type, Type> visitedTypes;
194*01eedbc7SRiver Riddle   return replaceSubElementsImpl(*this, replaceAttrFn, replaceTypeFn,
195*01eedbc7SRiver Riddle                                 visitedAttrs, visitedTypes);
196*01eedbc7SRiver Riddle }
197*01eedbc7SRiver Riddle 
replaceSubElements(function_ref<Attribute (Attribute)> replaceAttrFn,function_ref<Type (Type)> replaceTypeFn)198*01eedbc7SRiver Riddle Type SubElementTypeInterface::replaceSubElements(
199*01eedbc7SRiver Riddle     function_ref<Attribute(Attribute)> replaceAttrFn,
200*01eedbc7SRiver Riddle     function_ref<Type(Type)> replaceTypeFn) {
201*01eedbc7SRiver Riddle   assert(replaceAttrFn && replaceTypeFn && "expected valid replace functions");
202*01eedbc7SRiver Riddle   DenseMap<Attribute, Attribute> visitedAttrs;
203*01eedbc7SRiver Riddle   DenseMap<Type, Type> visitedTypes;
204*01eedbc7SRiver Riddle   return replaceSubElementsImpl(*this, replaceAttrFn, replaceTypeFn,
205*01eedbc7SRiver Riddle                                 visitedAttrs, visitedTypes);
206*01eedbc7SRiver Riddle }
207*01eedbc7SRiver Riddle 
208*01eedbc7SRiver Riddle //===----------------------------------------------------------------------===//
209c42dd5dbSRiver Riddle // SubElementInterface Tablegen definitions
210c42dd5dbSRiver Riddle //===----------------------------------------------------------------------===//
211c42dd5dbSRiver Riddle 
212c42dd5dbSRiver Riddle #include "mlir/IR/SubElementAttrInterfaces.cpp.inc"
213c42dd5dbSRiver Riddle #include "mlir/IR/SubElementTypeInterfaces.cpp.inc"
214