1aa8a9761SMichael Kruse //===- polly/ScheduleTreeTransform.cpp --------------------------*- C++ -*-===//
2aa8a9761SMichael Kruse //
3aa8a9761SMichael Kruse // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4aa8a9761SMichael Kruse // See https://llvm.org/LICENSE.txt for license information.
5aa8a9761SMichael Kruse // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6aa8a9761SMichael Kruse //
7aa8a9761SMichael Kruse //===----------------------------------------------------------------------===//
8aa8a9761SMichael Kruse //
9aa8a9761SMichael Kruse // Make changes to isl's schedule tree data structure.
10aa8a9761SMichael Kruse //
11aa8a9761SMichael Kruse //===----------------------------------------------------------------------===//
12aa8a9761SMichael Kruse
13aa8a9761SMichael Kruse #include "polly/ScheduleTreeTransform.h"
1464489255SMichael Kruse #include "polly/Support/GICHelper.h"
15aa8a9761SMichael Kruse #include "polly/Support/ISLTools.h"
163f170eb1SMichael Kruse #include "polly/Support/ScopHelper.h"
17aa8a9761SMichael Kruse #include "llvm/ADT/ArrayRef.h"
183f170eb1SMichael Kruse #include "llvm/ADT/Sequence.h"
19aa8a9761SMichael Kruse #include "llvm/ADT/SmallVector.h"
203f170eb1SMichael Kruse #include "llvm/IR/Constants.h"
213f170eb1SMichael Kruse #include "llvm/IR/Metadata.h"
223f170eb1SMichael Kruse #include "llvm/Transforms/Utils/UnrollLoop.h"
23aa8a9761SMichael Kruse
2464489255SMichael Kruse #define DEBUG_TYPE "polly-opt-isl"
2564489255SMichael Kruse
26aa8a9761SMichael Kruse using namespace polly;
273f170eb1SMichael Kruse using namespace llvm;
28aa8a9761SMichael Kruse
29aa8a9761SMichael Kruse namespace {
3064489255SMichael Kruse
3164489255SMichael Kruse /// Copy the band member attributes (coincidence, loop type, isolate ast loop
3264489255SMichael Kruse /// type) from one band to another.
3364489255SMichael Kruse static isl::schedule_node_band
applyBandMemberAttributes(isl::schedule_node_band Target,int TargetIdx,const isl::schedule_node_band & Source,int SourceIdx)3464489255SMichael Kruse applyBandMemberAttributes(isl::schedule_node_band Target, int TargetIdx,
3564489255SMichael Kruse const isl::schedule_node_band &Source,
3664489255SMichael Kruse int SourceIdx) {
3764489255SMichael Kruse bool Coincident = Source.member_get_coincident(SourceIdx).release();
3864489255SMichael Kruse Target = Target.member_set_coincident(TargetIdx, Coincident);
3964489255SMichael Kruse
4064489255SMichael Kruse isl_ast_loop_type LoopType =
4164489255SMichael Kruse isl_schedule_node_band_member_get_ast_loop_type(Source.get(), SourceIdx);
4264489255SMichael Kruse Target = isl::manage(isl_schedule_node_band_member_set_ast_loop_type(
4364489255SMichael Kruse Target.release(), TargetIdx, LoopType))
4464489255SMichael Kruse .as<isl::schedule_node_band>();
4564489255SMichael Kruse
4664489255SMichael Kruse isl_ast_loop_type IsolateType =
4764489255SMichael Kruse isl_schedule_node_band_member_get_isolate_ast_loop_type(Source.get(),
4864489255SMichael Kruse SourceIdx);
4964489255SMichael Kruse Target = isl::manage(isl_schedule_node_band_member_set_isolate_ast_loop_type(
5064489255SMichael Kruse Target.release(), TargetIdx, IsolateType))
5164489255SMichael Kruse .as<isl::schedule_node_band>();
5264489255SMichael Kruse
5364489255SMichael Kruse return Target;
5464489255SMichael Kruse }
5564489255SMichael Kruse
5664489255SMichael Kruse /// Create a new band by copying members from another @p Band. @p IncludeCb
5764489255SMichael Kruse /// decides which band indices are copied to the result.
5864489255SMichael Kruse template <typename CbTy>
rebuildBand(isl::schedule_node_band OldBand,isl::schedule Body,CbTy IncludeCb)5964489255SMichael Kruse static isl::schedule rebuildBand(isl::schedule_node_band OldBand,
6064489255SMichael Kruse isl::schedule Body, CbTy IncludeCb) {
6144596fe6SRiccardo Mori int NumBandDims = unsignedFromIslSize(OldBand.n_member());
6264489255SMichael Kruse
6364489255SMichael Kruse bool ExcludeAny = false;
6464489255SMichael Kruse bool IncludeAny = false;
6564489255SMichael Kruse for (auto OldIdx : seq<int>(0, NumBandDims)) {
6664489255SMichael Kruse if (IncludeCb(OldIdx))
6764489255SMichael Kruse IncludeAny = true;
6864489255SMichael Kruse else
6964489255SMichael Kruse ExcludeAny = true;
7064489255SMichael Kruse }
7164489255SMichael Kruse
7264489255SMichael Kruse // Instead of creating a zero-member band, don't create a band at all.
7364489255SMichael Kruse if (!IncludeAny)
7464489255SMichael Kruse return Body;
7564489255SMichael Kruse
7664489255SMichael Kruse isl::multi_union_pw_aff PartialSched = OldBand.get_partial_schedule();
7764489255SMichael Kruse isl::multi_union_pw_aff NewPartialSched;
7864489255SMichael Kruse if (ExcludeAny) {
7964489255SMichael Kruse // Select the included partial scatter functions.
8064489255SMichael Kruse isl::union_pw_aff_list List = PartialSched.list();
8164489255SMichael Kruse int NewIdx = 0;
8264489255SMichael Kruse for (auto OldIdx : seq<int>(0, NumBandDims)) {
8364489255SMichael Kruse if (IncludeCb(OldIdx))
8464489255SMichael Kruse NewIdx += 1;
8564489255SMichael Kruse else
8664489255SMichael Kruse List = List.drop(NewIdx, 1);
8764489255SMichael Kruse }
8864489255SMichael Kruse isl::space ParamSpace = PartialSched.get_space().params();
8964489255SMichael Kruse isl::space NewScatterSpace = ParamSpace.add_unnamed_tuple(NewIdx);
9064489255SMichael Kruse NewPartialSched = isl::multi_union_pw_aff(NewScatterSpace, List);
9164489255SMichael Kruse } else {
9264489255SMichael Kruse // Just reuse original scatter function of copying all of them.
9364489255SMichael Kruse NewPartialSched = PartialSched;
9464489255SMichael Kruse }
9564489255SMichael Kruse
9664489255SMichael Kruse // Create the new band node.
9764489255SMichael Kruse isl::schedule_node_band NewBand =
9864489255SMichael Kruse Body.insert_partial_schedule(NewPartialSched)
9964489255SMichael Kruse .get_root()
10064489255SMichael Kruse .child(0)
10164489255SMichael Kruse .as<isl::schedule_node_band>();
10264489255SMichael Kruse
10364489255SMichael Kruse // If OldBand was permutable, so is the new one, even if some dimensions are
10464489255SMichael Kruse // missing.
10564489255SMichael Kruse bool IsPermutable = OldBand.permutable().release();
10664489255SMichael Kruse NewBand = NewBand.set_permutable(IsPermutable);
10764489255SMichael Kruse
10864489255SMichael Kruse // Reapply member attributes.
10964489255SMichael Kruse int NewIdx = 0;
11064489255SMichael Kruse for (auto OldIdx : seq<int>(0, NumBandDims)) {
11164489255SMichael Kruse if (!IncludeCb(OldIdx))
11264489255SMichael Kruse continue;
11364489255SMichael Kruse NewBand =
11464489255SMichael Kruse applyBandMemberAttributes(std::move(NewBand), NewIdx, OldBand, OldIdx);
11564489255SMichael Kruse NewIdx += 1;
11664489255SMichael Kruse }
11764489255SMichael Kruse
11864489255SMichael Kruse return NewBand.get_schedule();
11964489255SMichael Kruse }
12064489255SMichael Kruse
121aa8a9761SMichael Kruse /// Rewrite a schedule tree by reconstructing it bottom-up.
122aa8a9761SMichael Kruse ///
123aa8a9761SMichael Kruse /// By default, the original schedule tree is reconstructed. To build a
124aa8a9761SMichael Kruse /// different tree, redefine visitor methods in a derived class (CRTP).
125aa8a9761SMichael Kruse ///
126aa8a9761SMichael Kruse /// Note that AST build options are not applied; Setting the isolate[] option
127aa8a9761SMichael Kruse /// makes the schedule tree 'anchored' and cannot be modified afterwards. Hence,
128aa8a9761SMichael Kruse /// AST build options must be set after the tree has been constructed.
129aa8a9761SMichael Kruse template <typename Derived, typename... Args>
130aa8a9761SMichael Kruse struct ScheduleTreeRewriter
131bd93df93SMichael Kruse : RecursiveScheduleTreeVisitor<Derived, isl::schedule, Args...> {
getDerived__anon9ace89990111::ScheduleTreeRewriter132aa8a9761SMichael Kruse Derived &getDerived() { return *static_cast<Derived *>(this); }
getDerived__anon9ace89990111::ScheduleTreeRewriter133aa8a9761SMichael Kruse const Derived &getDerived() const {
134aa8a9761SMichael Kruse return *static_cast<const Derived *>(this);
135aa8a9761SMichael Kruse }
136aa8a9761SMichael Kruse
visitDomain__anon9ace89990111::ScheduleTreeRewriter137c62d9a5cSMichael Kruse isl::schedule visitDomain(isl::schedule_node_domain Node, Args... args) {
138aa8a9761SMichael Kruse // Every schedule_tree already has a domain node, no need to add one.
139aa8a9761SMichael Kruse return getDerived().visit(Node.first_child(), std::forward<Args>(args)...);
140aa8a9761SMichael Kruse }
141aa8a9761SMichael Kruse
visitBand__anon9ace89990111::ScheduleTreeRewriter142c62d9a5cSMichael Kruse isl::schedule visitBand(isl::schedule_node_band Band, Args... args) {
143aa8a9761SMichael Kruse isl::schedule NewChild =
144aa8a9761SMichael Kruse getDerived().visit(Band.child(0), std::forward<Args>(args)...);
14564489255SMichael Kruse return rebuildBand(Band, NewChild, [](int) { return true; });
146aa8a9761SMichael Kruse }
147aa8a9761SMichael Kruse
visitSequence__anon9ace89990111::ScheduleTreeRewriter148c62d9a5cSMichael Kruse isl::schedule visitSequence(isl::schedule_node_sequence Sequence,
149aa8a9761SMichael Kruse Args... args) {
150aa8a9761SMichael Kruse int NumChildren = isl_schedule_node_n_children(Sequence.get());
151aa8a9761SMichael Kruse isl::schedule Result =
152aa8a9761SMichael Kruse getDerived().visit(Sequence.child(0), std::forward<Args>(args)...);
153aa8a9761SMichael Kruse for (int i = 1; i < NumChildren; i += 1)
154aa8a9761SMichael Kruse Result = Result.sequence(
155aa8a9761SMichael Kruse getDerived().visit(Sequence.child(i), std::forward<Args>(args)...));
156aa8a9761SMichael Kruse return Result;
157aa8a9761SMichael Kruse }
158aa8a9761SMichael Kruse
visitSet__anon9ace89990111::ScheduleTreeRewriter159c62d9a5cSMichael Kruse isl::schedule visitSet(isl::schedule_node_set Set, Args... args) {
160aa8a9761SMichael Kruse int NumChildren = isl_schedule_node_n_children(Set.get());
161aa8a9761SMichael Kruse isl::schedule Result =
162aa8a9761SMichael Kruse getDerived().visit(Set.child(0), std::forward<Args>(args)...);
163aa8a9761SMichael Kruse for (int i = 1; i < NumChildren; i += 1)
164aa8a9761SMichael Kruse Result = isl::manage(
165aa8a9761SMichael Kruse isl_schedule_set(Result.release(),
166aa8a9761SMichael Kruse getDerived()
167aa8a9761SMichael Kruse .visit(Set.child(i), std::forward<Args>(args)...)
168aa8a9761SMichael Kruse .release()));
169aa8a9761SMichael Kruse return Result;
170aa8a9761SMichael Kruse }
171aa8a9761SMichael Kruse
visitLeaf__anon9ace89990111::ScheduleTreeRewriter172c62d9a5cSMichael Kruse isl::schedule visitLeaf(isl::schedule_node_leaf Leaf, Args... args) {
173aa8a9761SMichael Kruse return isl::schedule::from_domain(Leaf.get_domain());
174aa8a9761SMichael Kruse }
175aa8a9761SMichael Kruse
visitMark__anon9ace89990111::ScheduleTreeRewriter176aa8a9761SMichael Kruse isl::schedule visitMark(const isl::schedule_node &Mark, Args... args) {
177d3fdbda6SRiccardo Mori
178d3fdbda6SRiccardo Mori isl::id TheMark = Mark.as<isl::schedule_node_mark>().get_id();
179aa8a9761SMichael Kruse isl::schedule_node NewChild =
180aa8a9761SMichael Kruse getDerived()
181aa8a9761SMichael Kruse .visit(Mark.first_child(), std::forward<Args>(args)...)
182aa8a9761SMichael Kruse .get_root()
183aa8a9761SMichael Kruse .first_child();
184aa8a9761SMichael Kruse return NewChild.insert_mark(TheMark).get_schedule();
185aa8a9761SMichael Kruse }
186aa8a9761SMichael Kruse
visitExtension__anon9ace89990111::ScheduleTreeRewriter187c62d9a5cSMichael Kruse isl::schedule visitExtension(isl::schedule_node_extension Extension,
188aa8a9761SMichael Kruse Args... args) {
189d3fdbda6SRiccardo Mori isl::union_map TheExtension =
190d3fdbda6SRiccardo Mori Extension.as<isl::schedule_node_extension>().get_extension();
191aa8a9761SMichael Kruse isl::schedule_node NewChild = getDerived()
192aa8a9761SMichael Kruse .visit(Extension.child(0), args...)
193aa8a9761SMichael Kruse .get_root()
194aa8a9761SMichael Kruse .first_child();
195aa8a9761SMichael Kruse isl::schedule_node NewExtension =
196aa8a9761SMichael Kruse isl::schedule_node::from_extension(TheExtension);
197aa8a9761SMichael Kruse return NewChild.graft_before(NewExtension).get_schedule();
198aa8a9761SMichael Kruse }
199aa8a9761SMichael Kruse
visitFilter__anon9ace89990111::ScheduleTreeRewriter200c62d9a5cSMichael Kruse isl::schedule visitFilter(isl::schedule_node_filter Filter, Args... args) {
201d3fdbda6SRiccardo Mori isl::union_set FilterDomain =
202d3fdbda6SRiccardo Mori Filter.as<isl::schedule_node_filter>().get_filter();
203aa8a9761SMichael Kruse isl::schedule NewSchedule =
204aa8a9761SMichael Kruse getDerived().visit(Filter.child(0), std::forward<Args>(args)...);
205aa8a9761SMichael Kruse return NewSchedule.intersect_domain(FilterDomain);
206aa8a9761SMichael Kruse }
207aa8a9761SMichael Kruse
visitNode__anon9ace89990111::ScheduleTreeRewriter208c62d9a5cSMichael Kruse isl::schedule visitNode(isl::schedule_node Node, Args... args) {
209aa8a9761SMichael Kruse llvm_unreachable("Not implemented");
210aa8a9761SMichael Kruse }
211aa8a9761SMichael Kruse };
212aa8a9761SMichael Kruse
21364489255SMichael Kruse /// Rewrite the schedule tree without any changes. Useful to copy a subtree into
21464489255SMichael Kruse /// a new schedule, discarding everything but.
215bd93df93SMichael Kruse struct IdentityRewriter : ScheduleTreeRewriter<IdentityRewriter> {};
21664489255SMichael Kruse
217aa8a9761SMichael Kruse /// Rewrite a schedule tree to an equivalent one without extension nodes.
218aa8a9761SMichael Kruse ///
219aa8a9761SMichael Kruse /// Each visit method takes two additional arguments:
220aa8a9761SMichael Kruse ///
221aa8a9761SMichael Kruse /// * The new domain the node, which is the inherited domain plus any domains
222aa8a9761SMichael Kruse /// added by extension nodes.
223aa8a9761SMichael Kruse ///
224aa8a9761SMichael Kruse /// * A map of extension domains of all children is returned; it is required by
225aa8a9761SMichael Kruse /// band nodes to schedule the additional domains at the same position as the
226aa8a9761SMichael Kruse /// extension node would.
227aa8a9761SMichael Kruse ///
228bd93df93SMichael Kruse struct ExtensionNodeRewriter final
229bd93df93SMichael Kruse : ScheduleTreeRewriter<ExtensionNodeRewriter, const isl::union_set &,
230aa8a9761SMichael Kruse isl::union_map &> {
231aa8a9761SMichael Kruse using BaseTy = ScheduleTreeRewriter<ExtensionNodeRewriter,
232aa8a9761SMichael Kruse const isl::union_set &, isl::union_map &>;
getBase__anon9ace89990111::ExtensionNodeRewriter233aa8a9761SMichael Kruse BaseTy &getBase() { return *this; }
getBase__anon9ace89990111::ExtensionNodeRewriter234aa8a9761SMichael Kruse const BaseTy &getBase() const { return *this; }
235aa8a9761SMichael Kruse
visitSchedule__anon9ace89990111::ExtensionNodeRewriter236c62d9a5cSMichael Kruse isl::schedule visitSchedule(isl::schedule Schedule) {
237aa8a9761SMichael Kruse isl::union_map Extensions;
238aa8a9761SMichael Kruse isl::schedule Result =
239aa8a9761SMichael Kruse visit(Schedule.get_root(), Schedule.get_domain(), Extensions);
2407c7978a1Spatacca assert(!Extensions.is_null() && Extensions.is_empty());
241aa8a9761SMichael Kruse return Result;
242aa8a9761SMichael Kruse }
243aa8a9761SMichael Kruse
visitSequence__anon9ace89990111::ExtensionNodeRewriter244c62d9a5cSMichael Kruse isl::schedule visitSequence(isl::schedule_node_sequence Sequence,
245aa8a9761SMichael Kruse const isl::union_set &Domain,
246aa8a9761SMichael Kruse isl::union_map &Extensions) {
247aa8a9761SMichael Kruse int NumChildren = isl_schedule_node_n_children(Sequence.get());
248aa8a9761SMichael Kruse isl::schedule NewNode = visit(Sequence.first_child(), Domain, Extensions);
249aa8a9761SMichael Kruse for (int i = 1; i < NumChildren; i += 1) {
250aa8a9761SMichael Kruse isl::schedule_node OldChild = Sequence.child(i);
251aa8a9761SMichael Kruse isl::union_map NewChildExtensions;
252aa8a9761SMichael Kruse isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions);
253aa8a9761SMichael Kruse NewNode = NewNode.sequence(NewChildNode);
254aa8a9761SMichael Kruse Extensions = Extensions.unite(NewChildExtensions);
255aa8a9761SMichael Kruse }
256aa8a9761SMichael Kruse return NewNode;
257aa8a9761SMichael Kruse }
258aa8a9761SMichael Kruse
visitSet__anon9ace89990111::ExtensionNodeRewriter259c62d9a5cSMichael Kruse isl::schedule visitSet(isl::schedule_node_set Set,
260aa8a9761SMichael Kruse const isl::union_set &Domain,
261aa8a9761SMichael Kruse isl::union_map &Extensions) {
262aa8a9761SMichael Kruse int NumChildren = isl_schedule_node_n_children(Set.get());
263aa8a9761SMichael Kruse isl::schedule NewNode = visit(Set.first_child(), Domain, Extensions);
264aa8a9761SMichael Kruse for (int i = 1; i < NumChildren; i += 1) {
265aa8a9761SMichael Kruse isl::schedule_node OldChild = Set.child(i);
266aa8a9761SMichael Kruse isl::union_map NewChildExtensions;
267aa8a9761SMichael Kruse isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions);
268aa8a9761SMichael Kruse NewNode = isl::manage(
269aa8a9761SMichael Kruse isl_schedule_set(NewNode.release(), NewChildNode.release()));
270aa8a9761SMichael Kruse Extensions = Extensions.unite(NewChildExtensions);
271aa8a9761SMichael Kruse }
272aa8a9761SMichael Kruse return NewNode;
273aa8a9761SMichael Kruse }
274aa8a9761SMichael Kruse
visitLeaf__anon9ace89990111::ExtensionNodeRewriter275c62d9a5cSMichael Kruse isl::schedule visitLeaf(isl::schedule_node_leaf Leaf,
276aa8a9761SMichael Kruse const isl::union_set &Domain,
277aa8a9761SMichael Kruse isl::union_map &Extensions) {
278bad3ebbaSRiccardo Mori Extensions = isl::union_map::empty(Leaf.ctx());
279aa8a9761SMichael Kruse return isl::schedule::from_domain(Domain);
280aa8a9761SMichael Kruse }
281aa8a9761SMichael Kruse
visitBand__anon9ace89990111::ExtensionNodeRewriter282c62d9a5cSMichael Kruse isl::schedule visitBand(isl::schedule_node_band OldNode,
283aa8a9761SMichael Kruse const isl::union_set &Domain,
284aa8a9761SMichael Kruse isl::union_map &OuterExtensions) {
285aa8a9761SMichael Kruse isl::schedule_node OldChild = OldNode.first_child();
286aa8a9761SMichael Kruse isl::multi_union_pw_aff PartialSched =
287aa8a9761SMichael Kruse isl::manage(isl_schedule_node_band_get_partial_schedule(OldNode.get()));
288aa8a9761SMichael Kruse
289aa8a9761SMichael Kruse isl::union_map NewChildExtensions;
290aa8a9761SMichael Kruse isl::schedule NewChild = visit(OldChild, Domain, NewChildExtensions);
291aa8a9761SMichael Kruse
292aa8a9761SMichael Kruse // Add the extensions to the partial schedule.
293bad3ebbaSRiccardo Mori OuterExtensions = isl::union_map::empty(NewChildExtensions.ctx());
294aa8a9761SMichael Kruse isl::union_map NewPartialSchedMap = isl::union_map::from(PartialSched);
295aa8a9761SMichael Kruse unsigned BandDims = isl_schedule_node_band_n_member(OldNode.get());
296aa8a9761SMichael Kruse for (isl::map Ext : NewChildExtensions.get_map_list()) {
29744596fe6SRiccardo Mori unsigned ExtDims = unsignedFromIslSize(Ext.domain_tuple_dim());
298aa8a9761SMichael Kruse assert(ExtDims >= BandDims);
299aa8a9761SMichael Kruse unsigned OuterDims = ExtDims - BandDims;
300aa8a9761SMichael Kruse
301aa8a9761SMichael Kruse isl::map BandSched =
302aa8a9761SMichael Kruse Ext.project_out(isl::dim::in, 0, OuterDims).reverse();
303aa8a9761SMichael Kruse NewPartialSchedMap = NewPartialSchedMap.unite(BandSched);
304aa8a9761SMichael Kruse
305aa8a9761SMichael Kruse // There might be more outer bands that have to schedule the extensions.
306aa8a9761SMichael Kruse if (OuterDims > 0) {
307aa8a9761SMichael Kruse isl::map OuterSched =
308aa8a9761SMichael Kruse Ext.project_out(isl::dim::in, OuterDims, BandDims);
309d5ee355fSRiccardo Mori OuterExtensions = OuterExtensions.unite(OuterSched);
310aa8a9761SMichael Kruse }
311aa8a9761SMichael Kruse }
312aa8a9761SMichael Kruse isl::multi_union_pw_aff NewPartialSchedAsAsMultiUnionPwAff =
313aa8a9761SMichael Kruse isl::multi_union_pw_aff::from_union_map(NewPartialSchedMap);
314aa8a9761SMichael Kruse isl::schedule_node NewNode =
315aa8a9761SMichael Kruse NewChild.insert_partial_schedule(NewPartialSchedAsAsMultiUnionPwAff)
316aa8a9761SMichael Kruse .get_root()
317d3fdbda6SRiccardo Mori .child(0);
318aa8a9761SMichael Kruse
319aa8a9761SMichael Kruse // Reapply permutability and coincidence attributes.
320aa8a9761SMichael Kruse NewNode = isl::manage(isl_schedule_node_band_set_permutable(
321aa8a9761SMichael Kruse NewNode.release(),
322aa8a9761SMichael Kruse isl_schedule_node_band_get_permutable(OldNode.get())));
32364489255SMichael Kruse for (unsigned i = 0; i < BandDims; i += 1)
32464489255SMichael Kruse NewNode = applyBandMemberAttributes(NewNode.as<isl::schedule_node_band>(),
32564489255SMichael Kruse i, OldNode, i);
326aa8a9761SMichael Kruse
327aa8a9761SMichael Kruse return NewNode.get_schedule();
328aa8a9761SMichael Kruse }
329aa8a9761SMichael Kruse
visitFilter__anon9ace89990111::ExtensionNodeRewriter330c62d9a5cSMichael Kruse isl::schedule visitFilter(isl::schedule_node_filter Filter,
331aa8a9761SMichael Kruse const isl::union_set &Domain,
332aa8a9761SMichael Kruse isl::union_map &Extensions) {
333d3fdbda6SRiccardo Mori isl::union_set FilterDomain =
334d3fdbda6SRiccardo Mori Filter.as<isl::schedule_node_filter>().get_filter();
335aa8a9761SMichael Kruse isl::union_set NewDomain = Domain.intersect(FilterDomain);
336aa8a9761SMichael Kruse
337aa8a9761SMichael Kruse // A filter is added implicitly if necessary when joining schedule trees.
338aa8a9761SMichael Kruse return visit(Filter.first_child(), NewDomain, Extensions);
339aa8a9761SMichael Kruse }
340aa8a9761SMichael Kruse
visitExtension__anon9ace89990111::ExtensionNodeRewriter341c62d9a5cSMichael Kruse isl::schedule visitExtension(isl::schedule_node_extension Extension,
342aa8a9761SMichael Kruse const isl::union_set &Domain,
343aa8a9761SMichael Kruse isl::union_map &Extensions) {
344d3fdbda6SRiccardo Mori isl::union_map ExtDomain =
345d3fdbda6SRiccardo Mori Extension.as<isl::schedule_node_extension>().get_extension();
346aa8a9761SMichael Kruse isl::union_set NewDomain = Domain.unite(ExtDomain.range());
347aa8a9761SMichael Kruse isl::union_map ChildExtensions;
348aa8a9761SMichael Kruse isl::schedule NewChild =
349aa8a9761SMichael Kruse visit(Extension.first_child(), NewDomain, ChildExtensions);
350aa8a9761SMichael Kruse Extensions = ChildExtensions.unite(ExtDomain);
351aa8a9761SMichael Kruse return NewChild;
352aa8a9761SMichael Kruse }
353aa8a9761SMichael Kruse };
354aa8a9761SMichael Kruse
355aa8a9761SMichael Kruse /// Collect all AST build options in any schedule tree band.
356aa8a9761SMichael Kruse ///
357aa8a9761SMichael Kruse /// ScheduleTreeRewriter cannot apply the schedule tree options. This class
358aa8a9761SMichael Kruse /// collects these options to apply them later.
359bd93df93SMichael Kruse struct CollectASTBuildOptions final
360bd93df93SMichael Kruse : RecursiveScheduleTreeVisitor<CollectASTBuildOptions> {
361aa8a9761SMichael Kruse using BaseTy = RecursiveScheduleTreeVisitor<CollectASTBuildOptions>;
getBase__anon9ace89990111::CollectASTBuildOptions362aa8a9761SMichael Kruse BaseTy &getBase() { return *this; }
getBase__anon9ace89990111::CollectASTBuildOptions363aa8a9761SMichael Kruse const BaseTy &getBase() const { return *this; }
364aa8a9761SMichael Kruse
365aa8a9761SMichael Kruse llvm::SmallVector<isl::union_set, 8> ASTBuildOptions;
366aa8a9761SMichael Kruse
visitBand__anon9ace89990111::CollectASTBuildOptions367c62d9a5cSMichael Kruse void visitBand(isl::schedule_node_band Band) {
368aa8a9761SMichael Kruse ASTBuildOptions.push_back(
369aa8a9761SMichael Kruse isl::manage(isl_schedule_node_band_get_ast_build_options(Band.get())));
370aa8a9761SMichael Kruse return getBase().visitBand(Band);
371aa8a9761SMichael Kruse }
372aa8a9761SMichael Kruse };
373aa8a9761SMichael Kruse
374aa8a9761SMichael Kruse /// Apply AST build options to the bands in a schedule tree.
375aa8a9761SMichael Kruse ///
376aa8a9761SMichael Kruse /// This rewrites a schedule tree with the AST build options applied. We assume
377aa8a9761SMichael Kruse /// that the band nodes are visited in the same order as they were when the
378aa8a9761SMichael Kruse /// build options were collected, typically by CollectASTBuildOptions.
379bd93df93SMichael Kruse struct ApplyASTBuildOptions final : ScheduleNodeRewriter<ApplyASTBuildOptions> {
380aa8a9761SMichael Kruse using BaseTy = ScheduleNodeRewriter<ApplyASTBuildOptions>;
getBase__anon9ace89990111::ApplyASTBuildOptions381aa8a9761SMichael Kruse BaseTy &getBase() { return *this; }
getBase__anon9ace89990111::ApplyASTBuildOptions382aa8a9761SMichael Kruse const BaseTy &getBase() const { return *this; }
383aa8a9761SMichael Kruse
384bd9e810bSMichael Kruse size_t Pos;
385aa8a9761SMichael Kruse llvm::ArrayRef<isl::union_set> ASTBuildOptions;
386aa8a9761SMichael Kruse
ApplyASTBuildOptions__anon9ace89990111::ApplyASTBuildOptions387aa8a9761SMichael Kruse ApplyASTBuildOptions(llvm::ArrayRef<isl::union_set> ASTBuildOptions)
388aa8a9761SMichael Kruse : ASTBuildOptions(ASTBuildOptions) {}
389aa8a9761SMichael Kruse
visitSchedule__anon9ace89990111::ApplyASTBuildOptions390c62d9a5cSMichael Kruse isl::schedule visitSchedule(isl::schedule Schedule) {
391aa8a9761SMichael Kruse Pos = 0;
392aa8a9761SMichael Kruse isl::schedule Result = visit(Schedule).get_schedule();
393aa8a9761SMichael Kruse assert(Pos == ASTBuildOptions.size() &&
394aa8a9761SMichael Kruse "AST build options must match to band nodes");
395aa8a9761SMichael Kruse return Result;
396aa8a9761SMichael Kruse }
397aa8a9761SMichael Kruse
visitBand__anon9ace89990111::ApplyASTBuildOptions398c62d9a5cSMichael Kruse isl::schedule_node visitBand(isl::schedule_node_band Band) {
399c62d9a5cSMichael Kruse isl::schedule_node_band Result =
400c62d9a5cSMichael Kruse Band.set_ast_build_options(ASTBuildOptions[Pos]);
401aa8a9761SMichael Kruse Pos += 1;
402aa8a9761SMichael Kruse return getBase().visitBand(Result);
403aa8a9761SMichael Kruse }
404aa8a9761SMichael Kruse };
405aa8a9761SMichael Kruse
406aa8a9761SMichael Kruse /// Return whether the schedule contains an extension node.
containsExtensionNode(isl::schedule Schedule)407aa8a9761SMichael Kruse static bool containsExtensionNode(isl::schedule Schedule) {
408aa8a9761SMichael Kruse assert(!Schedule.is_null());
409aa8a9761SMichael Kruse
410aa8a9761SMichael Kruse auto Callback = [](__isl_keep isl_schedule_node *Node,
411aa8a9761SMichael Kruse void *User) -> isl_bool {
412aa8a9761SMichael Kruse if (isl_schedule_node_get_type(Node) == isl_schedule_node_extension) {
413aa8a9761SMichael Kruse // Stop walking the schedule tree.
414aa8a9761SMichael Kruse return isl_bool_error;
415aa8a9761SMichael Kruse }
416aa8a9761SMichael Kruse
417aa8a9761SMichael Kruse // Continue searching the subtree.
418aa8a9761SMichael Kruse return isl_bool_true;
419aa8a9761SMichael Kruse };
420aa8a9761SMichael Kruse isl_stat RetVal = isl_schedule_foreach_schedule_node_top_down(
421aa8a9761SMichael Kruse Schedule.get(), Callback, nullptr);
422aa8a9761SMichael Kruse
423aa8a9761SMichael Kruse // We assume that the traversal itself does not fail, i.e. the only reason to
424aa8a9761SMichael Kruse // return isl_stat_error is that an extension node was found.
425aa8a9761SMichael Kruse return RetVal == isl_stat_error;
426aa8a9761SMichael Kruse }
427aa8a9761SMichael Kruse
4283f170eb1SMichael Kruse /// Find a named MDNode property in a LoopID.
findOptionalNodeOperand(MDNode * LoopMD,StringRef Name)4293f170eb1SMichael Kruse static MDNode *findOptionalNodeOperand(MDNode *LoopMD, StringRef Name) {
4303f170eb1SMichael Kruse return dyn_cast_or_null<MDNode>(
431*30c67587SKazu Hirata findMetadataOperand(LoopMD, Name).value_or(nullptr));
4323f170eb1SMichael Kruse }
4333f170eb1SMichael Kruse
4343f170eb1SMichael Kruse /// Is this node of type mark?
isMark(const isl::schedule_node & Node)4353f170eb1SMichael Kruse static bool isMark(const isl::schedule_node &Node) {
4363f170eb1SMichael Kruse return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_mark;
4373f170eb1SMichael Kruse }
4383f170eb1SMichael Kruse
43930df6d5dSDavid Blaikie /// Is this node of type band?
isBand(const isl::schedule_node & Node)44030df6d5dSDavid Blaikie static bool isBand(const isl::schedule_node &Node) {
44130df6d5dSDavid Blaikie return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band;
44230df6d5dSDavid Blaikie }
44330df6d5dSDavid Blaikie
444e470f926SMichael Kruse #ifndef NDEBUG
4453f170eb1SMichael Kruse /// Is this node a band of a single dimension (i.e. could represent a loop)?
isBandWithSingleLoop(const isl::schedule_node & Node)4463f170eb1SMichael Kruse static bool isBandWithSingleLoop(const isl::schedule_node &Node) {
4473f170eb1SMichael Kruse return isBand(Node) && isl_schedule_node_band_n_member(Node.get()) == 1;
4483f170eb1SMichael Kruse }
44930df6d5dSDavid Blaikie #endif
4503f170eb1SMichael Kruse
isLeaf(const isl::schedule_node & Node)451e470f926SMichael Kruse static bool isLeaf(const isl::schedule_node &Node) {
452e470f926SMichael Kruse return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_leaf;
453e470f926SMichael Kruse }
454e470f926SMichael Kruse
4553f170eb1SMichael Kruse /// Create an isl::id representing the output loop after a transformation.
createGeneratedLoopAttr(isl::ctx Ctx,MDNode * FollowupLoopMD)4563f170eb1SMichael Kruse static isl::id createGeneratedLoopAttr(isl::ctx Ctx, MDNode *FollowupLoopMD) {
4573f170eb1SMichael Kruse // Don't need to id the followup.
4583f170eb1SMichael Kruse // TODO: Append llvm.loop.disable_heustistics metadata unless overridden by
4593f170eb1SMichael Kruse // user followup-MD
4603f170eb1SMichael Kruse if (!FollowupLoopMD)
4613f170eb1SMichael Kruse return {};
4623f170eb1SMichael Kruse
4633f170eb1SMichael Kruse BandAttr *Attr = new BandAttr();
4643f170eb1SMichael Kruse Attr->Metadata = FollowupLoopMD;
4653f170eb1SMichael Kruse return getIslLoopAttr(Ctx, Attr);
4663f170eb1SMichael Kruse }
4673f170eb1SMichael Kruse
4683f170eb1SMichael Kruse /// A loop consists of a band and an optional marker that wraps it. Return the
4693f170eb1SMichael Kruse /// outermost of the two.
4703f170eb1SMichael Kruse
4713f170eb1SMichael Kruse /// That is, either the mark or, if there is not mark, the loop itself. Can
4723f170eb1SMichael Kruse /// start with either the mark or the band.
moveToBandMark(isl::schedule_node BandOrMark)4733f170eb1SMichael Kruse static isl::schedule_node moveToBandMark(isl::schedule_node BandOrMark) {
4743f170eb1SMichael Kruse if (isBandMark(BandOrMark)) {
475d3fdbda6SRiccardo Mori assert(isBandWithSingleLoop(BandOrMark.child(0)));
4763f170eb1SMichael Kruse return BandOrMark;
4773f170eb1SMichael Kruse }
4783f170eb1SMichael Kruse assert(isBandWithSingleLoop(BandOrMark));
4793f170eb1SMichael Kruse
4803f170eb1SMichael Kruse isl::schedule_node Mark = BandOrMark.parent();
4813f170eb1SMichael Kruse if (isBandMark(Mark))
4823f170eb1SMichael Kruse return Mark;
4833f170eb1SMichael Kruse
4843f170eb1SMichael Kruse // Band has no loop marker.
4853f170eb1SMichael Kruse return BandOrMark;
4863f170eb1SMichael Kruse }
4873f170eb1SMichael Kruse
removeMark(isl::schedule_node MarkOrBand,BandAttr * & Attr)4883f170eb1SMichael Kruse static isl::schedule_node removeMark(isl::schedule_node MarkOrBand,
4893f170eb1SMichael Kruse BandAttr *&Attr) {
4903f170eb1SMichael Kruse MarkOrBand = moveToBandMark(MarkOrBand);
4913f170eb1SMichael Kruse
4923f170eb1SMichael Kruse isl::schedule_node Band;
4933f170eb1SMichael Kruse if (isMark(MarkOrBand)) {
494d3fdbda6SRiccardo Mori Attr = getLoopAttr(MarkOrBand.as<isl::schedule_node_mark>().get_id());
4953f170eb1SMichael Kruse Band = isl::manage(isl_schedule_node_delete(MarkOrBand.release()));
4963f170eb1SMichael Kruse } else {
4973f170eb1SMichael Kruse Attr = nullptr;
4983f170eb1SMichael Kruse Band = MarkOrBand;
4993f170eb1SMichael Kruse }
5003f170eb1SMichael Kruse
5013f170eb1SMichael Kruse assert(isBandWithSingleLoop(Band));
5023f170eb1SMichael Kruse return Band;
5033f170eb1SMichael Kruse }
5043f170eb1SMichael Kruse
5053f170eb1SMichael Kruse /// Remove the mark that wraps a loop. Return the band representing the loop.
removeMark(isl::schedule_node MarkOrBand)5063f170eb1SMichael Kruse static isl::schedule_node removeMark(isl::schedule_node MarkOrBand) {
5073f170eb1SMichael Kruse BandAttr *Attr;
5083f170eb1SMichael Kruse return removeMark(MarkOrBand, Attr);
5093f170eb1SMichael Kruse }
5103f170eb1SMichael Kruse
insertMark(isl::schedule_node Band,isl::id Mark)5113f170eb1SMichael Kruse static isl::schedule_node insertMark(isl::schedule_node Band, isl::id Mark) {
5123f170eb1SMichael Kruse assert(isBand(Band));
5133f170eb1SMichael Kruse assert(moveToBandMark(Band).is_equal(Band) &&
5143f170eb1SMichael Kruse "Don't add a two marks for a band");
5153f170eb1SMichael Kruse
516d3fdbda6SRiccardo Mori return Band.insert_mark(Mark).child(0);
5173f170eb1SMichael Kruse }
5183f170eb1SMichael Kruse
5193f170eb1SMichael Kruse /// Return the (one-dimensional) set of numbers that are divisible by @p Factor
5203f170eb1SMichael Kruse /// with remainder @p Offset.
5213f170eb1SMichael Kruse ///
5223f170eb1SMichael Kruse /// isDivisibleBySet(Ctx, 4, 0) = { [i] : floord(i,4) = 0 }
5233f170eb1SMichael Kruse /// isDivisibleBySet(Ctx, 4, 1) = { [i] : floord(i,4) = 1 }
5243f170eb1SMichael Kruse ///
isDivisibleBySet(isl::ctx & Ctx,long Factor,long Offset)5253f170eb1SMichael Kruse static isl::basic_set isDivisibleBySet(isl::ctx &Ctx, long Factor,
5263f170eb1SMichael Kruse long Offset) {
5273f170eb1SMichael Kruse isl::val ValFactor{Ctx, Factor};
5283f170eb1SMichael Kruse isl::val ValOffset{Ctx, Offset};
5293f170eb1SMichael Kruse
5303f170eb1SMichael Kruse isl::space Unispace{Ctx, 0, 1};
5313f170eb1SMichael Kruse isl::local_space LUnispace{Unispace};
5323f170eb1SMichael Kruse isl::aff AffFactor{LUnispace, ValFactor};
5333f170eb1SMichael Kruse isl::aff AffOffset{LUnispace, ValOffset};
5343f170eb1SMichael Kruse
5353f170eb1SMichael Kruse isl::aff Id = isl::aff::var_on_domain(LUnispace, isl::dim::out, 0);
5363f170eb1SMichael Kruse isl::aff DivMul = Id.mod(ValFactor);
5373f170eb1SMichael Kruse isl::basic_map Divisible = isl::basic_map::from_aff(DivMul);
5383f170eb1SMichael Kruse isl::basic_map Modulo = Divisible.fix_val(isl::dim::out, 0, ValOffset);
5393f170eb1SMichael Kruse return Modulo.domain();
5403f170eb1SMichael Kruse }
5413f170eb1SMichael Kruse
542d123e983SMichael Kruse /// Make the last dimension of Set to take values from 0 to VectorWidth - 1.
543d123e983SMichael Kruse ///
544d123e983SMichael Kruse /// @param Set A set, which should be modified.
545d123e983SMichael Kruse /// @param VectorWidth A parameter, which determines the constraint.
addExtentConstraints(isl::set Set,int VectorWidth)546d123e983SMichael Kruse static isl::set addExtentConstraints(isl::set Set, int VectorWidth) {
54744596fe6SRiccardo Mori unsigned Dims = unsignedFromIslSize(Set.tuple_dim());
54844596fe6SRiccardo Mori assert(Dims >= 1);
549d123e983SMichael Kruse isl::space Space = Set.get_space();
550d123e983SMichael Kruse isl::local_space LocalSpace = isl::local_space(Space);
551d123e983SMichael Kruse isl::constraint ExtConstr = isl::constraint::alloc_inequality(LocalSpace);
552d123e983SMichael Kruse ExtConstr = ExtConstr.set_constant_si(0);
553d123e983SMichael Kruse ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, 1);
554d123e983SMichael Kruse Set = Set.add_constraint(ExtConstr);
555d123e983SMichael Kruse ExtConstr = isl::constraint::alloc_inequality(LocalSpace);
556d123e983SMichael Kruse ExtConstr = ExtConstr.set_constant_si(VectorWidth - 1);
557d123e983SMichael Kruse ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, -1);
558d123e983SMichael Kruse return Set.add_constraint(ExtConstr);
559d123e983SMichael Kruse }
56064489255SMichael Kruse
56164489255SMichael Kruse /// Collapse perfectly nested bands into a single band.
562bd93df93SMichael Kruse class BandCollapseRewriter final
563bd93df93SMichael Kruse : public ScheduleTreeRewriter<BandCollapseRewriter> {
56464489255SMichael Kruse private:
56564489255SMichael Kruse using BaseTy = ScheduleTreeRewriter<BandCollapseRewriter>;
getBase()56664489255SMichael Kruse BaseTy &getBase() { return *this; }
getBase() const56764489255SMichael Kruse const BaseTy &getBase() const { return *this; }
56864489255SMichael Kruse
56964489255SMichael Kruse public:
visitBand(isl::schedule_node_band RootBand)57064489255SMichael Kruse isl::schedule visitBand(isl::schedule_node_band RootBand) {
57164489255SMichael Kruse isl::schedule_node_band Band = RootBand;
57264489255SMichael Kruse isl::ctx Ctx = Band.ctx();
57364489255SMichael Kruse
57464489255SMichael Kruse // Do not merge permutable band to avoid loosing the permutability property.
57564489255SMichael Kruse // Cannot collapse even two permutable loops, they might be permutable
57664489255SMichael Kruse // individually, but not necassarily accross.
57744596fe6SRiccardo Mori if (unsignedFromIslSize(Band.n_member()) > 1u && Band.permutable())
57864489255SMichael Kruse return getBase().visitBand(Band);
57964489255SMichael Kruse
58064489255SMichael Kruse // Find collapsable bands.
58164489255SMichael Kruse SmallVector<isl::schedule_node_band> Nest;
58264489255SMichael Kruse int NumTotalLoops = 0;
58364489255SMichael Kruse isl::schedule_node Body;
58464489255SMichael Kruse while (true) {
58564489255SMichael Kruse Nest.push_back(Band);
58644596fe6SRiccardo Mori NumTotalLoops += unsignedFromIslSize(Band.n_member());
58764489255SMichael Kruse Body = Band.first_child();
58864489255SMichael Kruse if (!Body.isa<isl::schedule_node_band>())
58964489255SMichael Kruse break;
59064489255SMichael Kruse Band = Body.as<isl::schedule_node_band>();
59164489255SMichael Kruse
59264489255SMichael Kruse // Do not include next band if it is permutable to not lose its
59364489255SMichael Kruse // permutability property.
59444596fe6SRiccardo Mori if (unsignedFromIslSize(Band.n_member()) > 1u && Band.permutable())
59564489255SMichael Kruse break;
59664489255SMichael Kruse }
59764489255SMichael Kruse
59864489255SMichael Kruse // Nothing to collapse, preserve permutability.
59964489255SMichael Kruse if (Nest.size() <= 1)
60064489255SMichael Kruse return getBase().visitBand(Band);
60164489255SMichael Kruse
60264489255SMichael Kruse LLVM_DEBUG({
60364489255SMichael Kruse dbgs() << "Found loops to collapse between\n";
60464489255SMichael Kruse dumpIslObj(RootBand, dbgs());
60564489255SMichael Kruse dbgs() << "and\n";
60664489255SMichael Kruse dumpIslObj(Body, dbgs());
60764489255SMichael Kruse dbgs() << "\n";
60864489255SMichael Kruse });
60964489255SMichael Kruse
61064489255SMichael Kruse isl::schedule NewBody = visit(Body);
61164489255SMichael Kruse
61264489255SMichael Kruse // Collect partial schedules from all members.
61364489255SMichael Kruse isl::union_pw_aff_list PartScheds{Ctx, NumTotalLoops};
61464489255SMichael Kruse for (isl::schedule_node_band Band : Nest) {
61544596fe6SRiccardo Mori int NumLoops = unsignedFromIslSize(Band.n_member());
61664489255SMichael Kruse isl::multi_union_pw_aff BandScheds = Band.get_partial_schedule();
61764489255SMichael Kruse for (auto j : seq<int>(0, NumLoops))
61864489255SMichael Kruse PartScheds = PartScheds.add(BandScheds.at(j));
61964489255SMichael Kruse }
62064489255SMichael Kruse isl::space ScatterSpace = isl::space(Ctx, 0, NumTotalLoops);
62164489255SMichael Kruse isl::multi_union_pw_aff PartSchedsMulti{ScatterSpace, PartScheds};
62264489255SMichael Kruse
62364489255SMichael Kruse isl::schedule_node_band CollapsedBand =
62464489255SMichael Kruse NewBody.insert_partial_schedule(PartSchedsMulti)
62564489255SMichael Kruse .get_root()
62664489255SMichael Kruse .first_child()
62764489255SMichael Kruse .as<isl::schedule_node_band>();
62864489255SMichael Kruse
62964489255SMichael Kruse // Copy over loop attributes form original bands.
63064489255SMichael Kruse int LoopIdx = 0;
63164489255SMichael Kruse for (isl::schedule_node_band Band : Nest) {
63244596fe6SRiccardo Mori int NumLoops = unsignedFromIslSize(Band.n_member());
63364489255SMichael Kruse for (int i : seq<int>(0, NumLoops)) {
63464489255SMichael Kruse CollapsedBand = applyBandMemberAttributes(std::move(CollapsedBand),
63564489255SMichael Kruse LoopIdx, Band, i);
63664489255SMichael Kruse LoopIdx += 1;
63764489255SMichael Kruse }
63864489255SMichael Kruse }
63964489255SMichael Kruse assert(LoopIdx == NumTotalLoops &&
64064489255SMichael Kruse "Expect the same number of loops to add up again");
64164489255SMichael Kruse
64264489255SMichael Kruse return CollapsedBand.get_schedule();
64364489255SMichael Kruse }
64464489255SMichael Kruse };
64564489255SMichael Kruse
collapseBands(isl::schedule Sched)64664489255SMichael Kruse static isl::schedule collapseBands(isl::schedule Sched) {
64764489255SMichael Kruse LLVM_DEBUG(dbgs() << "Collapse bands in schedule\n");
64864489255SMichael Kruse BandCollapseRewriter Rewriter;
64964489255SMichael Kruse return Rewriter.visit(Sched);
65064489255SMichael Kruse }
65164489255SMichael Kruse
65264489255SMichael Kruse /// Collect sequentially executed bands (or anything else), even if nested in a
65364489255SMichael Kruse /// mark or other nodes whose child is executed just once. If we can
65464489255SMichael Kruse /// successfully fuse the bands, we allow them to be removed.
collectPotentiallyFusableBands(isl::schedule_node Node,SmallVectorImpl<std::pair<isl::schedule_node,isl::schedule_node>> & ScheduleBands,const isl::schedule_node & DirectChild)65564489255SMichael Kruse static void collectPotentiallyFusableBands(
65664489255SMichael Kruse isl::schedule_node Node,
65764489255SMichael Kruse SmallVectorImpl<std::pair<isl::schedule_node, isl::schedule_node>>
65864489255SMichael Kruse &ScheduleBands,
65964489255SMichael Kruse const isl::schedule_node &DirectChild) {
66064489255SMichael Kruse switch (isl_schedule_node_get_type(Node.get())) {
66164489255SMichael Kruse case isl_schedule_node_sequence:
66264489255SMichael Kruse case isl_schedule_node_set:
66364489255SMichael Kruse case isl_schedule_node_mark:
66464489255SMichael Kruse case isl_schedule_node_domain:
66564489255SMichael Kruse case isl_schedule_node_filter:
66664489255SMichael Kruse if (Node.has_children()) {
66764489255SMichael Kruse isl::schedule_node C = Node.first_child();
66864489255SMichael Kruse while (true) {
66964489255SMichael Kruse collectPotentiallyFusableBands(C, ScheduleBands, DirectChild);
67064489255SMichael Kruse if (!C.has_next_sibling())
67164489255SMichael Kruse break;
67264489255SMichael Kruse C = C.next_sibling();
67364489255SMichael Kruse }
67464489255SMichael Kruse }
67564489255SMichael Kruse break;
67664489255SMichael Kruse
67764489255SMichael Kruse default:
67864489255SMichael Kruse // Something that does not execute suquentially (e.g. a band)
67964489255SMichael Kruse ScheduleBands.push_back({Node, DirectChild});
68064489255SMichael Kruse break;
68164489255SMichael Kruse }
68264489255SMichael Kruse }
68364489255SMichael Kruse
68464489255SMichael Kruse /// Remove dependencies that are resolved by @p PartSched. That is, remove
68564489255SMichael Kruse /// everything that we already know is executed in-order.
remainingDepsFromPartialSchedule(isl::union_map PartSched,isl::union_map Deps)68664489255SMichael Kruse static isl::union_map remainingDepsFromPartialSchedule(isl::union_map PartSched,
68764489255SMichael Kruse isl::union_map Deps) {
68844596fe6SRiccardo Mori unsigned NumDims = getNumScatterDims(PartSched);
68964489255SMichael Kruse auto ParamSpace = PartSched.get_space().params();
69064489255SMichael Kruse
69164489255SMichael Kruse // { Scatter[] }
69264489255SMichael Kruse isl::space ScatterSpace =
69364489255SMichael Kruse ParamSpace.set_from_params().add_dims(isl::dim::set, NumDims);
69464489255SMichael Kruse
69564489255SMichael Kruse // { Scatter[] -> Domain[] }
69664489255SMichael Kruse isl::union_map PartSchedRev = PartSched.reverse();
69764489255SMichael Kruse
69864489255SMichael Kruse // { Scatter[] -> Scatter[] }
69964489255SMichael Kruse isl::map MaybeBefore = isl::map::lex_le(ScatterSpace);
70064489255SMichael Kruse
70164489255SMichael Kruse // { Domain[] -> Domain[] }
70264489255SMichael Kruse isl::union_map DomMaybeBefore =
70364489255SMichael Kruse MaybeBefore.apply_domain(PartSchedRev).apply_range(PartSchedRev);
70464489255SMichael Kruse
70564489255SMichael Kruse // { Domain[] -> Domain[] }
70664489255SMichael Kruse isl::union_map ChildRemainingDeps = Deps.intersect(DomMaybeBefore);
70764489255SMichael Kruse
70864489255SMichael Kruse return ChildRemainingDeps;
70964489255SMichael Kruse }
71064489255SMichael Kruse
71164489255SMichael Kruse /// Remove dependencies that are resolved by executing them in the order
71264489255SMichael Kruse /// specified by @p Domains;
remainigDepsFromSequence(ArrayRef<isl::union_set> Domains,isl::union_map Deps)71364489255SMichael Kruse static isl::union_map remainigDepsFromSequence(ArrayRef<isl::union_set> Domains,
71464489255SMichael Kruse isl::union_map Deps) {
71564489255SMichael Kruse isl::ctx Ctx = Deps.ctx();
71664489255SMichael Kruse isl::space ParamSpace = Deps.get_space().params();
71764489255SMichael Kruse
71864489255SMichael Kruse // Create a partial schedule mapping to constants that reflect the execution
71964489255SMichael Kruse // order.
72064489255SMichael Kruse isl::union_map PartialSchedules = isl::union_map::empty(Ctx);
72164489255SMichael Kruse for (auto P : enumerate(Domains)) {
72264489255SMichael Kruse isl::val ExecTime = isl::val(Ctx, P.index());
72364489255SMichael Kruse isl::union_pw_aff DomSched{P.value(), ExecTime};
72464489255SMichael Kruse PartialSchedules = PartialSchedules.unite(DomSched.as_union_map());
72564489255SMichael Kruse }
72664489255SMichael Kruse
72764489255SMichael Kruse return remainingDepsFromPartialSchedule(PartialSchedules, Deps);
72864489255SMichael Kruse }
72964489255SMichael Kruse
73064489255SMichael Kruse /// Determine whether the outermost loop of to bands can be fused while
73164489255SMichael Kruse /// respecting validity dependencies.
canFuseOutermost(const isl::schedule_node_band & LHS,const isl::schedule_node_band & RHS,const isl::union_map & Deps)73264489255SMichael Kruse static bool canFuseOutermost(const isl::schedule_node_band &LHS,
73364489255SMichael Kruse const isl::schedule_node_band &RHS,
73464489255SMichael Kruse const isl::union_map &Deps) {
73564489255SMichael Kruse // { LHSDomain[] -> Scatter[] }
73664489255SMichael Kruse isl::union_map LHSPartSched =
73764489255SMichael Kruse LHS.get_partial_schedule().get_at(0).as_union_map();
73864489255SMichael Kruse
73964489255SMichael Kruse // { Domain[] -> Scatter[] }
74064489255SMichael Kruse isl::union_map RHSPartSched =
74164489255SMichael Kruse RHS.get_partial_schedule().get_at(0).as_union_map();
74264489255SMichael Kruse
74364489255SMichael Kruse // Dependencies that are already resolved because LHS executes before RHS, but
74464489255SMichael Kruse // will not be anymore after fusion. { DefDomain[] -> UseDomain[] }
74564489255SMichael Kruse isl::union_map OrderedBySequence =
74664489255SMichael Kruse Deps.intersect_domain(LHSPartSched.domain())
74764489255SMichael Kruse .intersect_range(RHSPartSched.domain());
74864489255SMichael Kruse
74964489255SMichael Kruse isl::space ParamSpace = OrderedBySequence.get_space().params();
75064489255SMichael Kruse isl::space NewScatterSpace = ParamSpace.add_unnamed_tuple(1);
75164489255SMichael Kruse
75264489255SMichael Kruse // { Scatter[] -> Scatter[] }
75364489255SMichael Kruse isl::map After = isl::map::lex_gt(NewScatterSpace);
75464489255SMichael Kruse
75564489255SMichael Kruse // After fusion, instances with smaller (or equal, which means they will be
75664489255SMichael Kruse // executed in the same iteration, but the LHS instance is still sequenced
75764489255SMichael Kruse // before RHS) scatter value will still be executed before. This are the
75864489255SMichael Kruse // orderings where this is not necessarily the case.
75964489255SMichael Kruse // { LHSDomain[] -> RHSDomain[] }
76064489255SMichael Kruse isl::union_map MightBeAfterDoms = After.apply_domain(LHSPartSched.reverse())
76164489255SMichael Kruse .apply_range(RHSPartSched.reverse());
76264489255SMichael Kruse
76364489255SMichael Kruse // Dependencies that are not resolved by the new execution order.
76464489255SMichael Kruse isl::union_map WithBefore = OrderedBySequence.intersect(MightBeAfterDoms);
76564489255SMichael Kruse
76664489255SMichael Kruse return WithBefore.is_empty();
76764489255SMichael Kruse }
76864489255SMichael Kruse
76964489255SMichael Kruse /// Fuse @p LHS and @p RHS if possible while preserving validity dependenvies.
tryGreedyFuse(isl::schedule_node_band LHS,isl::schedule_node_band RHS,const isl::union_map & Deps)77064489255SMichael Kruse static isl::schedule tryGreedyFuse(isl::schedule_node_band LHS,
77164489255SMichael Kruse isl::schedule_node_band RHS,
77264489255SMichael Kruse const isl::union_map &Deps) {
77364489255SMichael Kruse if (!canFuseOutermost(LHS, RHS, Deps))
77464489255SMichael Kruse return {};
77564489255SMichael Kruse
77664489255SMichael Kruse LLVM_DEBUG({
77764489255SMichael Kruse dbgs() << "Found loops for greedy fusion:\n";
77864489255SMichael Kruse dumpIslObj(LHS, dbgs());
77964489255SMichael Kruse dbgs() << "and\n";
78064489255SMichael Kruse dumpIslObj(RHS, dbgs());
78164489255SMichael Kruse dbgs() << "\n";
78264489255SMichael Kruse });
78364489255SMichael Kruse
78464489255SMichael Kruse // The partial schedule of the bands outermost loop that we need to combine
78564489255SMichael Kruse // for the fusion.
78664489255SMichael Kruse isl::union_pw_aff LHSPartOuterSched = LHS.get_partial_schedule().get_at(0);
78764489255SMichael Kruse isl::union_pw_aff RHSPartOuterSched = RHS.get_partial_schedule().get_at(0);
78864489255SMichael Kruse
78964489255SMichael Kruse // Isolate band bodies as roots of their own schedule trees.
79064489255SMichael Kruse IdentityRewriter Rewriter;
79164489255SMichael Kruse isl::schedule LHSBody = Rewriter.visit(LHS.first_child());
79264489255SMichael Kruse isl::schedule RHSBody = Rewriter.visit(RHS.first_child());
79364489255SMichael Kruse
79464489255SMichael Kruse // Reconstruct the non-outermost (not going to be fused) loops from both
79564489255SMichael Kruse // bands.
79664489255SMichael Kruse // TODO: Maybe it is possibly to transfer the 'permutability' property from
79764489255SMichael Kruse // LHS+RHS. At minimum we need merge multiple band members at once, otherwise
79864489255SMichael Kruse // permutability has no meaning.
79964489255SMichael Kruse isl::schedule LHSNewBody =
80064489255SMichael Kruse rebuildBand(LHS, LHSBody, [](int i) { return i > 0; });
80164489255SMichael Kruse isl::schedule RHSNewBody =
80264489255SMichael Kruse rebuildBand(RHS, RHSBody, [](int i) { return i > 0; });
80364489255SMichael Kruse
80464489255SMichael Kruse // The loop body of the fused loop.
80564489255SMichael Kruse isl::schedule NewCommonBody = LHSNewBody.sequence(RHSNewBody);
80664489255SMichael Kruse
80764489255SMichael Kruse // Combine the partial schedules of both loops to a new one. Instances with
80864489255SMichael Kruse // the same scatter value are put together.
80964489255SMichael Kruse isl::union_map NewCommonPartialSched =
81064489255SMichael Kruse LHSPartOuterSched.as_union_map().unite(RHSPartOuterSched.as_union_map());
81164489255SMichael Kruse isl::schedule NewCommonSchedule = NewCommonBody.insert_partial_schedule(
81264489255SMichael Kruse NewCommonPartialSched.as_multi_union_pw_aff());
81364489255SMichael Kruse
81464489255SMichael Kruse return NewCommonSchedule;
81564489255SMichael Kruse }
81664489255SMichael Kruse
tryGreedyFuse(isl::schedule_node LHS,isl::schedule_node RHS,const isl::union_map & Deps)81764489255SMichael Kruse static isl::schedule tryGreedyFuse(isl::schedule_node LHS,
81864489255SMichael Kruse isl::schedule_node RHS,
81964489255SMichael Kruse const isl::union_map &Deps) {
82064489255SMichael Kruse // TODO: Non-bands could be interpreted as a band with just as single
82164489255SMichael Kruse // iteration. However, this is only useful if both ends of a fused loop were
82264489255SMichael Kruse // originally loops themselves.
82364489255SMichael Kruse if (!LHS.isa<isl::schedule_node_band>())
82464489255SMichael Kruse return {};
82564489255SMichael Kruse if (!RHS.isa<isl::schedule_node_band>())
82664489255SMichael Kruse return {};
82764489255SMichael Kruse
82864489255SMichael Kruse return tryGreedyFuse(LHS.as<isl::schedule_node_band>(),
82964489255SMichael Kruse RHS.as<isl::schedule_node_band>(), Deps);
83064489255SMichael Kruse }
83164489255SMichael Kruse
83264489255SMichael Kruse /// Fuse all fusable loop top-down in a schedule tree.
83364489255SMichael Kruse ///
83464489255SMichael Kruse /// The isl::union_map parameters is the set of validity dependencies that have
83564489255SMichael Kruse /// not been resolved/carried by a parent schedule node.
836bd93df93SMichael Kruse class GreedyFusionRewriter final
83764489255SMichael Kruse : public ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map> {
83864489255SMichael Kruse private:
83964489255SMichael Kruse using BaseTy = ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map>;
getBase()84064489255SMichael Kruse BaseTy &getBase() { return *this; }
getBase() const84164489255SMichael Kruse const BaseTy &getBase() const { return *this; }
84264489255SMichael Kruse
84364489255SMichael Kruse public:
84464489255SMichael Kruse /// Is set to true if anything has been fused.
84564489255SMichael Kruse bool AnyChange = false;
84664489255SMichael Kruse
visitBand(isl::schedule_node_band Band,isl::union_map Deps)84764489255SMichael Kruse isl::schedule visitBand(isl::schedule_node_band Band, isl::union_map Deps) {
84864489255SMichael Kruse // { Domain[] -> Scatter[] }
84964489255SMichael Kruse isl::union_map PartSched =
85064489255SMichael Kruse isl::union_map::from(Band.get_partial_schedule());
85144596fe6SRiccardo Mori assert(getNumScatterDims(PartSched) ==
85244596fe6SRiccardo Mori unsignedFromIslSize(Band.n_member()));
85364489255SMichael Kruse isl::space ParamSpace = PartSched.get_space().params();
85464489255SMichael Kruse
85564489255SMichael Kruse // { Scatter[] -> Domain[] }
85664489255SMichael Kruse isl::union_map PartSchedRev = PartSched.reverse();
85764489255SMichael Kruse
85864489255SMichael Kruse // Possible within the same iteration. Dependencies with smaller scatter
85964489255SMichael Kruse // value are carried by this loop and therefore have been resolved by the
86064489255SMichael Kruse // in-order execution if the loop iteration. A dependency with small scatter
86164489255SMichael Kruse // value would be a dependency violation that we assume did not happen. {
86264489255SMichael Kruse // Domain[] -> Domain[] }
86364489255SMichael Kruse isl::union_map Unsequenced = PartSchedRev.apply_domain(PartSchedRev);
86464489255SMichael Kruse
86564489255SMichael Kruse // Actual dependencies within the same iteration.
86664489255SMichael Kruse // { DefDomain[] -> UseDomain[] }
86764489255SMichael Kruse isl::union_map RemDeps = Deps.intersect(Unsequenced);
86864489255SMichael Kruse
86964489255SMichael Kruse return getBase().visitBand(Band, RemDeps);
87064489255SMichael Kruse }
87164489255SMichael Kruse
visitSequence(isl::schedule_node_sequence Sequence,isl::union_map Deps)87264489255SMichael Kruse isl::schedule visitSequence(isl::schedule_node_sequence Sequence,
87364489255SMichael Kruse isl::union_map Deps) {
87464489255SMichael Kruse int NumChildren = isl_schedule_node_n_children(Sequence.get());
87564489255SMichael Kruse
87664489255SMichael Kruse // List of fusion candidates. The first element is the fusion candidate, the
87764489255SMichael Kruse // second is candidate's ancestor that is the sequence's direct child. It is
87864489255SMichael Kruse // preferable to use the direct child if not if its non-direct children is
87964489255SMichael Kruse // fused to preserve its structure such as mark nodes.
88064489255SMichael Kruse SmallVector<std::pair<isl::schedule_node, isl::schedule_node>> Bands;
88164489255SMichael Kruse for (auto i : seq<int>(0, NumChildren)) {
88264489255SMichael Kruse isl::schedule_node Child = Sequence.child(i);
88364489255SMichael Kruse collectPotentiallyFusableBands(Child, Bands, Child);
88464489255SMichael Kruse }
88564489255SMichael Kruse
88664489255SMichael Kruse // Direct children that had at least one of its decendants fused.
88764489255SMichael Kruse SmallDenseSet<isl_schedule_node *, 4> ChangedDirectChildren;
88864489255SMichael Kruse
88964489255SMichael Kruse // Fuse neigboring bands until reaching the end of candidates.
89064489255SMichael Kruse int i = 0;
89164489255SMichael Kruse while (i + 1 < (int)Bands.size()) {
89264489255SMichael Kruse isl::schedule Fused =
89364489255SMichael Kruse tryGreedyFuse(Bands[i].first, Bands[i + 1].first, Deps);
89464489255SMichael Kruse if (Fused.is_null()) {
89564489255SMichael Kruse // Cannot merge this node with the next; look at next pair.
89664489255SMichael Kruse i += 1;
89764489255SMichael Kruse continue;
89864489255SMichael Kruse }
89964489255SMichael Kruse
90064489255SMichael Kruse // Mark the direct children as (partially) fused.
90164489255SMichael Kruse if (!Bands[i].second.is_null())
90264489255SMichael Kruse ChangedDirectChildren.insert(Bands[i].second.get());
90364489255SMichael Kruse if (!Bands[i + 1].second.is_null())
90464489255SMichael Kruse ChangedDirectChildren.insert(Bands[i + 1].second.get());
90564489255SMichael Kruse
90664489255SMichael Kruse // Collapse the neigbros to a single new candidate that could be fused
90764489255SMichael Kruse // with the next candidate.
90864489255SMichael Kruse Bands[i] = {Fused.get_root(), {}};
90964489255SMichael Kruse Bands.erase(Bands.begin() + i + 1);
91064489255SMichael Kruse
91164489255SMichael Kruse AnyChange = true;
91264489255SMichael Kruse }
91364489255SMichael Kruse
91464489255SMichael Kruse // By construction equal if done with collectPotentiallyFusableBands's
91564489255SMichael Kruse // output.
91664489255SMichael Kruse SmallVector<isl::union_set> SubDomains;
91764489255SMichael Kruse SubDomains.reserve(NumChildren);
91864489255SMichael Kruse for (int i = 0; i < NumChildren; i += 1)
91964489255SMichael Kruse SubDomains.push_back(Sequence.child(i).domain());
92064489255SMichael Kruse auto SubRemainingDeps = remainigDepsFromSequence(SubDomains, Deps);
92164489255SMichael Kruse
92264489255SMichael Kruse // We may iterate over direct children multiple times, be sure to add each
92364489255SMichael Kruse // at most once.
92464489255SMichael Kruse SmallDenseSet<isl_schedule_node *, 4> AlreadyAdded;
92564489255SMichael Kruse
92664489255SMichael Kruse isl::schedule Result;
92764489255SMichael Kruse for (auto &P : Bands) {
92864489255SMichael Kruse isl::schedule_node MaybeFused = P.first;
92964489255SMichael Kruse isl::schedule_node DirectChild = P.second;
93064489255SMichael Kruse
93164489255SMichael Kruse // If not modified, use the direct child.
93264489255SMichael Kruse if (!DirectChild.is_null() &&
93364489255SMichael Kruse !ChangedDirectChildren.count(DirectChild.get())) {
93464489255SMichael Kruse if (AlreadyAdded.count(DirectChild.get()))
93564489255SMichael Kruse continue;
93664489255SMichael Kruse AlreadyAdded.insert(DirectChild.get());
93764489255SMichael Kruse MaybeFused = DirectChild;
93864489255SMichael Kruse } else {
93964489255SMichael Kruse assert(AnyChange &&
94064489255SMichael Kruse "Need changed flag for be consistent with actual change");
94164489255SMichael Kruse }
94264489255SMichael Kruse
94364489255SMichael Kruse // Top-down recursion: If the outermost loop has been fused, their nested
94464489255SMichael Kruse // bands might be fusable now as well.
94564489255SMichael Kruse isl::schedule InnerFused = visit(MaybeFused, SubRemainingDeps);
94664489255SMichael Kruse
94764489255SMichael Kruse // Reconstruct the sequence, with some of the children fused.
94864489255SMichael Kruse if (Result.is_null())
94964489255SMichael Kruse Result = InnerFused;
95064489255SMichael Kruse else
95164489255SMichael Kruse Result = Result.sequence(InnerFused);
95264489255SMichael Kruse }
95364489255SMichael Kruse
95464489255SMichael Kruse return Result;
95564489255SMichael Kruse }
95664489255SMichael Kruse };
95764489255SMichael Kruse
9583f170eb1SMichael Kruse } // namespace
9593f170eb1SMichael Kruse
isBandMark(const isl::schedule_node & Node)9603f170eb1SMichael Kruse bool polly::isBandMark(const isl::schedule_node &Node) {
961d3fdbda6SRiccardo Mori return isMark(Node) &&
962d3fdbda6SRiccardo Mori isLoopAttr(Node.as<isl::schedule_node_mark>().get_id());
9633f170eb1SMichael Kruse }
9643f170eb1SMichael Kruse
getBandAttr(isl::schedule_node MarkOrBand)9653f170eb1SMichael Kruse BandAttr *polly::getBandAttr(isl::schedule_node MarkOrBand) {
9663f170eb1SMichael Kruse MarkOrBand = moveToBandMark(MarkOrBand);
9673f170eb1SMichael Kruse if (!isMark(MarkOrBand))
9683f170eb1SMichael Kruse return nullptr;
9693f170eb1SMichael Kruse
970d3fdbda6SRiccardo Mori return getLoopAttr(MarkOrBand.as<isl::schedule_node_mark>().get_id());
9713f170eb1SMichael Kruse }
9723f170eb1SMichael Kruse
hoistExtensionNodes(isl::schedule Sched)973aa8a9761SMichael Kruse isl::schedule polly::hoistExtensionNodes(isl::schedule Sched) {
974aa8a9761SMichael Kruse // If there is no extension node in the first place, return the original
975aa8a9761SMichael Kruse // schedule tree.
976aa8a9761SMichael Kruse if (!containsExtensionNode(Sched))
977aa8a9761SMichael Kruse return Sched;
978aa8a9761SMichael Kruse
979aa8a9761SMichael Kruse // Build options can anchor schedule nodes, such that the schedule tree cannot
980aa8a9761SMichael Kruse // be modified anymore. Therefore, apply build options after the tree has been
981aa8a9761SMichael Kruse // created.
982aa8a9761SMichael Kruse CollectASTBuildOptions Collector;
983aa8a9761SMichael Kruse Collector.visit(Sched);
984aa8a9761SMichael Kruse
985aa8a9761SMichael Kruse // Rewrite the schedule tree without extension nodes.
986aa8a9761SMichael Kruse ExtensionNodeRewriter Rewriter;
987aa8a9761SMichael Kruse isl::schedule NewSched = Rewriter.visitSchedule(Sched);
988aa8a9761SMichael Kruse
989aa8a9761SMichael Kruse // Reapply the AST build options. The rewriter must not change the iteration
990aa8a9761SMichael Kruse // order of bands. Any other node type is ignored.
991aa8a9761SMichael Kruse ApplyASTBuildOptions Applicator(Collector.ASTBuildOptions);
992aa8a9761SMichael Kruse NewSched = Applicator.visitSchedule(NewSched);
993aa8a9761SMichael Kruse
994aa8a9761SMichael Kruse return NewSched;
995aa8a9761SMichael Kruse }
9963f170eb1SMichael Kruse
applyFullUnroll(isl::schedule_node BandToUnroll)9973f170eb1SMichael Kruse isl::schedule polly::applyFullUnroll(isl::schedule_node BandToUnroll) {
9980813bd16SRiccardo Mori isl::ctx Ctx = BandToUnroll.ctx();
9993f170eb1SMichael Kruse
10003f170eb1SMichael Kruse // Remove the loop's mark, the loop will disappear anyway.
10013f170eb1SMichael Kruse BandToUnroll = removeMark(BandToUnroll);
10023f170eb1SMichael Kruse assert(isBandWithSingleLoop(BandToUnroll));
10033f170eb1SMichael Kruse
10043f170eb1SMichael Kruse isl::multi_union_pw_aff PartialSched = isl::manage(
10053f170eb1SMichael Kruse isl_schedule_node_band_get_partial_schedule(BandToUnroll.get()));
100644596fe6SRiccardo Mori assert(unsignedFromIslSize(PartialSched.dim(isl::dim::out)) == 1u &&
10073f170eb1SMichael Kruse "Can only unroll a single dimension");
1008d3fdbda6SRiccardo Mori isl::union_pw_aff PartialSchedUAff = PartialSched.at(0);
10093f170eb1SMichael Kruse
10103f170eb1SMichael Kruse isl::union_set Domain = BandToUnroll.get_domain();
10113f170eb1SMichael Kruse PartialSchedUAff = PartialSchedUAff.intersect_domain(Domain);
1012d3fdbda6SRiccardo Mori isl::union_map PartialSchedUMap =
1013d3fdbda6SRiccardo Mori isl::union_map::from(isl::union_pw_multi_aff(PartialSchedUAff));
10143f170eb1SMichael Kruse
1015f51427afSMichael Kruse // Enumerator only the scatter elements.
1016f51427afSMichael Kruse isl::union_set ScatterList = PartialSchedUMap.range();
10173f170eb1SMichael Kruse
1018f51427afSMichael Kruse // Enumerate all loop iterations.
10193f170eb1SMichael Kruse // TODO: Diagnose if not enumerable or depends on a parameter.
1020f51427afSMichael Kruse SmallVector<isl::point, 16> Elts;
1021f51427afSMichael Kruse ScatterList.foreach_point([&Elts](isl::point P) -> isl::stat {
10223f170eb1SMichael Kruse Elts.push_back(P);
10233f170eb1SMichael Kruse return isl::stat::ok();
10243f170eb1SMichael Kruse });
10253f170eb1SMichael Kruse
10263f170eb1SMichael Kruse // Don't assume that foreach_point returns in execution order.
10273f170eb1SMichael Kruse llvm::sort(Elts, [](isl::point P1, isl::point P2) -> bool {
10283f170eb1SMichael Kruse isl::val C1 = P1.get_coordinate_val(isl::dim::set, 0);
10293f170eb1SMichael Kruse isl::val C2 = P2.get_coordinate_val(isl::dim::set, 0);
10303f170eb1SMichael Kruse return C1.lt(C2);
10313f170eb1SMichael Kruse });
10323f170eb1SMichael Kruse
10333f170eb1SMichael Kruse // Convert the points to a sequence of filters.
1034d3fdbda6SRiccardo Mori isl::union_set_list List = isl::union_set_list(Ctx, Elts.size());
10353f170eb1SMichael Kruse for (isl::point P : Elts) {
1036f51427afSMichael Kruse // Determine the domains that map this scatter element.
1037f51427afSMichael Kruse isl::union_set DomainFilter = PartialSchedUMap.intersect_range(P).domain();
10383f170eb1SMichael Kruse
1039f51427afSMichael Kruse List = List.add(DomainFilter);
10403f170eb1SMichael Kruse }
10413f170eb1SMichael Kruse
10423f170eb1SMichael Kruse // Replace original band with unrolled sequence.
10433f170eb1SMichael Kruse isl::schedule_node Body =
10443f170eb1SMichael Kruse isl::manage(isl_schedule_node_delete(BandToUnroll.release()));
10453f170eb1SMichael Kruse Body = Body.insert_sequence(List);
10463f170eb1SMichael Kruse return Body.get_schedule();
10473f170eb1SMichael Kruse }
10483f170eb1SMichael Kruse
applyPartialUnroll(isl::schedule_node BandToUnroll,int Factor)10493f170eb1SMichael Kruse isl::schedule polly::applyPartialUnroll(isl::schedule_node BandToUnroll,
10503f170eb1SMichael Kruse int Factor) {
10513f170eb1SMichael Kruse assert(Factor > 0 && "Positive unroll factor required");
10520813bd16SRiccardo Mori isl::ctx Ctx = BandToUnroll.ctx();
10533f170eb1SMichael Kruse
10543f170eb1SMichael Kruse // Remove the mark, save the attribute for later use.
10553f170eb1SMichael Kruse BandAttr *Attr;
10563f170eb1SMichael Kruse BandToUnroll = removeMark(BandToUnroll, Attr);
10573f170eb1SMichael Kruse assert(isBandWithSingleLoop(BandToUnroll));
10583f170eb1SMichael Kruse
10593f170eb1SMichael Kruse isl::multi_union_pw_aff PartialSched = isl::manage(
10603f170eb1SMichael Kruse isl_schedule_node_band_get_partial_schedule(BandToUnroll.get()));
10613f170eb1SMichael Kruse
10623f170eb1SMichael Kruse // { Stmt[] -> [x] }
1063d3fdbda6SRiccardo Mori isl::union_pw_aff PartialSchedUAff = PartialSched.at(0);
10643f170eb1SMichael Kruse
10653f170eb1SMichael Kruse // Here we assume the schedule stride is one and starts with 0, which is not
10663f170eb1SMichael Kruse // necessarily the case.
10673f170eb1SMichael Kruse isl::union_pw_aff StridedPartialSchedUAff =
10683f170eb1SMichael Kruse isl::union_pw_aff::empty(PartialSchedUAff.get_space());
10693f170eb1SMichael Kruse isl::val ValFactor{Ctx, Factor};
10703f170eb1SMichael Kruse PartialSchedUAff.foreach_pw_aff([&StridedPartialSchedUAff,
10713f170eb1SMichael Kruse &ValFactor](isl::pw_aff PwAff) -> isl::stat {
10723f170eb1SMichael Kruse isl::space Space = PwAff.get_space();
10733f170eb1SMichael Kruse isl::set Universe = isl::set::universe(Space.domain());
10743f170eb1SMichael Kruse isl::pw_aff AffFactor{Universe, ValFactor};
10753f170eb1SMichael Kruse isl::pw_aff DivSchedAff = PwAff.div(AffFactor).floor().mul(AffFactor);
10763f170eb1SMichael Kruse StridedPartialSchedUAff = StridedPartialSchedUAff.union_add(DivSchedAff);
10773f170eb1SMichael Kruse return isl::stat::ok();
10783f170eb1SMichael Kruse });
10793f170eb1SMichael Kruse
1080d3fdbda6SRiccardo Mori isl::union_set_list List = isl::union_set_list(Ctx, Factor);
10813f170eb1SMichael Kruse for (auto i : seq<int>(0, Factor)) {
10823f170eb1SMichael Kruse // { Stmt[] -> [x] }
1083d3fdbda6SRiccardo Mori isl::union_map UMap =
1084d3fdbda6SRiccardo Mori isl::union_map::from(isl::union_pw_multi_aff(PartialSchedUAff));
10853f170eb1SMichael Kruse
10863f170eb1SMichael Kruse // { [x] }
10873f170eb1SMichael Kruse isl::basic_set Divisible = isDivisibleBySet(Ctx, Factor, i);
10883f170eb1SMichael Kruse
10893f170eb1SMichael Kruse // { Stmt[] }
10903f170eb1SMichael Kruse isl::union_set UnrolledDomain = UMap.intersect_range(Divisible).domain();
10913f170eb1SMichael Kruse
10923f170eb1SMichael Kruse List = List.add(UnrolledDomain);
10933f170eb1SMichael Kruse }
10943f170eb1SMichael Kruse
10953f170eb1SMichael Kruse isl::schedule_node Body =
10963f170eb1SMichael Kruse isl::manage(isl_schedule_node_delete(BandToUnroll.copy()));
10973f170eb1SMichael Kruse Body = Body.insert_sequence(List);
10983f170eb1SMichael Kruse isl::schedule_node NewLoop =
10993f170eb1SMichael Kruse Body.insert_partial_schedule(StridedPartialSchedUAff);
11003f170eb1SMichael Kruse
11013f170eb1SMichael Kruse MDNode *FollowupMD = nullptr;
11023f170eb1SMichael Kruse if (Attr && Attr->Metadata)
11033f170eb1SMichael Kruse FollowupMD =
11043f170eb1SMichael Kruse findOptionalNodeOperand(Attr->Metadata, LLVMLoopUnrollFollowupUnrolled);
11053f170eb1SMichael Kruse
11063f170eb1SMichael Kruse isl::id NewBandId = createGeneratedLoopAttr(Ctx, FollowupMD);
11077c7978a1Spatacca if (!NewBandId.is_null())
11083f170eb1SMichael Kruse NewLoop = insertMark(NewLoop, NewBandId);
11093f170eb1SMichael Kruse
11103f170eb1SMichael Kruse return NewLoop.get_schedule();
11113f170eb1SMichael Kruse }
1112d123e983SMichael Kruse
getPartialTilePrefixes(isl::set ScheduleRange,int VectorWidth)1113d123e983SMichael Kruse isl::set polly::getPartialTilePrefixes(isl::set ScheduleRange,
1114d123e983SMichael Kruse int VectorWidth) {
111544596fe6SRiccardo Mori unsigned Dims = unsignedFromIslSize(ScheduleRange.tuple_dim());
111644596fe6SRiccardo Mori assert(Dims >= 1);
1117d123e983SMichael Kruse isl::set LoopPrefixes =
1118d123e983SMichael Kruse ScheduleRange.drop_constraints_involving_dims(isl::dim::set, Dims - 1, 1);
1119d123e983SMichael Kruse auto ExtentPrefixes = addExtentConstraints(LoopPrefixes, VectorWidth);
1120d123e983SMichael Kruse isl::set BadPrefixes = ExtentPrefixes.subtract(ScheduleRange);
1121d123e983SMichael Kruse BadPrefixes = BadPrefixes.project_out(isl::dim::set, Dims - 1, 1);
1122d123e983SMichael Kruse LoopPrefixes = LoopPrefixes.project_out(isl::dim::set, Dims - 1, 1);
1123d123e983SMichael Kruse return LoopPrefixes.subtract(BadPrefixes);
1124d123e983SMichael Kruse }
1125d123e983SMichael Kruse
getIsolateOptions(isl::set IsolateDomain,unsigned OutDimsNum)1126d123e983SMichael Kruse isl::union_set polly::getIsolateOptions(isl::set IsolateDomain,
112744596fe6SRiccardo Mori unsigned OutDimsNum) {
112844596fe6SRiccardo Mori unsigned Dims = unsignedFromIslSize(IsolateDomain.tuple_dim());
1129d123e983SMichael Kruse assert(OutDimsNum <= Dims &&
1130d123e983SMichael Kruse "The isl::set IsolateDomain is used to describe the range of schedule "
1131d123e983SMichael Kruse "dimensions values, which should be isolated. Consequently, the "
1132d123e983SMichael Kruse "number of its dimensions should be greater than or equal to the "
1133d123e983SMichael Kruse "number of the schedule dimensions.");
1134d123e983SMichael Kruse isl::map IsolateRelation = isl::map::from_domain(IsolateDomain);
1135d123e983SMichael Kruse IsolateRelation = IsolateRelation.move_dims(isl::dim::out, 0, isl::dim::in,
1136d123e983SMichael Kruse Dims - OutDimsNum, OutDimsNum);
1137d123e983SMichael Kruse isl::set IsolateOption = IsolateRelation.wrap();
11380813bd16SRiccardo Mori isl::id Id = isl::id::alloc(IsolateOption.ctx(), "isolate", nullptr);
1139d123e983SMichael Kruse IsolateOption = IsolateOption.set_tuple_id(Id);
1140d123e983SMichael Kruse return isl::union_set(IsolateOption);
1141d123e983SMichael Kruse }
1142d123e983SMichael Kruse
getDimOptions(isl::ctx Ctx,const char * Option)1143d123e983SMichael Kruse isl::union_set polly::getDimOptions(isl::ctx Ctx, const char *Option) {
1144d123e983SMichael Kruse isl::space Space(Ctx, 0, 1);
1145d123e983SMichael Kruse auto DimOption = isl::set::universe(Space);
1146d123e983SMichael Kruse auto Id = isl::id::alloc(Ctx, Option, nullptr);
1147d123e983SMichael Kruse DimOption = DimOption.set_tuple_id(Id);
1148d123e983SMichael Kruse return isl::union_set(DimOption);
1149d123e983SMichael Kruse }
1150d123e983SMichael Kruse
tileNode(isl::schedule_node Node,const char * Identifier,ArrayRef<int> TileSizes,int DefaultTileSize)1151d123e983SMichael Kruse isl::schedule_node polly::tileNode(isl::schedule_node Node,
1152d123e983SMichael Kruse const char *Identifier,
1153d123e983SMichael Kruse ArrayRef<int> TileSizes,
1154d123e983SMichael Kruse int DefaultTileSize) {
1155d123e983SMichael Kruse auto Space = isl::manage(isl_schedule_node_band_get_space(Node.get()));
1156d123e983SMichael Kruse auto Dims = Space.dim(isl::dim::set);
1157d123e983SMichael Kruse auto Sizes = isl::multi_val::zero(Space);
1158d123e983SMichael Kruse std::string IdentifierString(Identifier);
115944596fe6SRiccardo Mori for (unsigned i : rangeIslSize(0, Dims)) {
116044596fe6SRiccardo Mori unsigned tileSize = i < TileSizes.size() ? TileSizes[i] : DefaultTileSize;
11610813bd16SRiccardo Mori Sizes = Sizes.set_val(i, isl::val(Node.ctx(), tileSize));
1162d123e983SMichael Kruse }
1163d123e983SMichael Kruse auto TileLoopMarkerStr = IdentifierString + " - Tiles";
11640813bd16SRiccardo Mori auto TileLoopMarker = isl::id::alloc(Node.ctx(), TileLoopMarkerStr, nullptr);
1165d123e983SMichael Kruse Node = Node.insert_mark(TileLoopMarker);
1166d123e983SMichael Kruse Node = Node.child(0);
1167d123e983SMichael Kruse Node =
1168d123e983SMichael Kruse isl::manage(isl_schedule_node_band_tile(Node.release(), Sizes.release()));
1169d123e983SMichael Kruse Node = Node.child(0);
1170d123e983SMichael Kruse auto PointLoopMarkerStr = IdentifierString + " - Points";
1171d123e983SMichael Kruse auto PointLoopMarker =
11720813bd16SRiccardo Mori isl::id::alloc(Node.ctx(), PointLoopMarkerStr, nullptr);
1173d123e983SMichael Kruse Node = Node.insert_mark(PointLoopMarker);
1174d123e983SMichael Kruse return Node.child(0);
1175d123e983SMichael Kruse }
1176d123e983SMichael Kruse
applyRegisterTiling(isl::schedule_node Node,ArrayRef<int> TileSizes,int DefaultTileSize)1177d123e983SMichael Kruse isl::schedule_node polly::applyRegisterTiling(isl::schedule_node Node,
1178d123e983SMichael Kruse ArrayRef<int> TileSizes,
1179d123e983SMichael Kruse int DefaultTileSize) {
1180d123e983SMichael Kruse Node = tileNode(Node, "Register tiling", TileSizes, DefaultTileSize);
11810813bd16SRiccardo Mori auto Ctx = Node.ctx();
1182d3fdbda6SRiccardo Mori return Node.as<isl::schedule_node_band>().set_ast_build_options(
1183d3fdbda6SRiccardo Mori isl::union_set(Ctx, "{unroll[x]}"));
1184d123e983SMichael Kruse }
1185e470f926SMichael Kruse
1186e470f926SMichael Kruse /// Find statements and sub-loops in (possibly nested) sequences.
1187e470f926SMichael Kruse static void
collectFissionableStmts(isl::schedule_node Node,SmallVectorImpl<isl::schedule_node> & ScheduleStmts)1188b554c643SMichael Kruse collectFissionableStmts(isl::schedule_node Node,
1189e470f926SMichael Kruse SmallVectorImpl<isl::schedule_node> &ScheduleStmts) {
1190e470f926SMichael Kruse if (isBand(Node) || isLeaf(Node)) {
1191e470f926SMichael Kruse ScheduleStmts.push_back(Node);
1192e470f926SMichael Kruse return;
1193e470f926SMichael Kruse }
1194e470f926SMichael Kruse
1195e470f926SMichael Kruse if (Node.has_children()) {
1196e470f926SMichael Kruse isl::schedule_node C = Node.first_child();
1197e470f926SMichael Kruse while (true) {
1198b554c643SMichael Kruse collectFissionableStmts(C, ScheduleStmts);
1199e470f926SMichael Kruse if (!C.has_next_sibling())
1200e470f926SMichael Kruse break;
1201e470f926SMichael Kruse C = C.next_sibling();
1202e470f926SMichael Kruse }
1203e470f926SMichael Kruse }
1204e470f926SMichael Kruse }
1205e470f926SMichael Kruse
applyMaxFission(isl::schedule_node BandToFission)1206e470f926SMichael Kruse isl::schedule polly::applyMaxFission(isl::schedule_node BandToFission) {
1207e470f926SMichael Kruse isl::ctx Ctx = BandToFission.ctx();
1208e470f926SMichael Kruse BandToFission = removeMark(BandToFission);
1209e470f926SMichael Kruse isl::schedule_node BandBody = BandToFission.child(0);
1210e470f926SMichael Kruse
1211e470f926SMichael Kruse SmallVector<isl::schedule_node> FissionableStmts;
1212b554c643SMichael Kruse collectFissionableStmts(BandBody, FissionableStmts);
1213e470f926SMichael Kruse size_t N = FissionableStmts.size();
1214e470f926SMichael Kruse
1215e470f926SMichael Kruse // Collect the domain for each of the statements that will get their own loop.
1216e470f926SMichael Kruse isl::union_set_list DomList = isl::union_set_list(Ctx, N);
1217e470f926SMichael Kruse for (size_t i = 0; i < N; ++i) {
1218e470f926SMichael Kruse isl::schedule_node BodyPart = FissionableStmts[i];
1219e470f926SMichael Kruse DomList = DomList.add(BodyPart.get_domain());
1220e470f926SMichael Kruse }
1221e470f926SMichael Kruse
1222e470f926SMichael Kruse // Apply the fission by copying the entire loop, but inserting a filter for
1223e470f926SMichael Kruse // the statement domains for each fissioned loop.
1224e470f926SMichael Kruse isl::schedule_node Fissioned = BandToFission.insert_sequence(DomList);
1225e470f926SMichael Kruse
1226e470f926SMichael Kruse return Fissioned.get_schedule();
1227e470f926SMichael Kruse }
122864489255SMichael Kruse
applyGreedyFusion(isl::schedule Sched,const isl::union_map & Deps)122964489255SMichael Kruse isl::schedule polly::applyGreedyFusion(isl::schedule Sched,
123064489255SMichael Kruse const isl::union_map &Deps) {
123164489255SMichael Kruse LLVM_DEBUG(dbgs() << "Greedy loop fusion\n");
123264489255SMichael Kruse
123364489255SMichael Kruse GreedyFusionRewriter Rewriter;
123464489255SMichael Kruse isl::schedule Result = Rewriter.visit(Sched, Deps);
123564489255SMichael Kruse if (!Rewriter.AnyChange) {
123664489255SMichael Kruse LLVM_DEBUG(dbgs() << "Found nothing to fuse\n");
123764489255SMichael Kruse return Sched;
123864489255SMichael Kruse }
123964489255SMichael Kruse
124064489255SMichael Kruse // GreedyFusionRewriter due to working loop-by-loop, bands with multiple loops
124164489255SMichael Kruse // may have been split into multiple bands.
124264489255SMichael Kruse return collapseBands(Result);
124364489255SMichael Kruse }
1244