1 //===- polly/ScheduleTreeTransform.cpp --------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Make changes to isl's schedule tree data structure.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "polly/ScheduleTreeTransform.h"
14 #include "polly/Support/GICHelper.h"
15 #include "polly/Support/ISLTools.h"
16 #include "polly/Support/ScopHelper.h"
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/IR/Constants.h"
21 #include "llvm/IR/Metadata.h"
22 #include "llvm/Transforms/Utils/UnrollLoop.h"
23 
24 #define DEBUG_TYPE "polly-opt-isl"
25 
26 using namespace polly;
27 using namespace llvm;
28 
29 namespace {
30 
31 /// Copy the band member attributes (coincidence, loop type, isolate ast loop
32 /// type) from one band to another.
33 static isl::schedule_node_band
applyBandMemberAttributes(isl::schedule_node_band Target,int TargetIdx,const isl::schedule_node_band & Source,int SourceIdx)34 applyBandMemberAttributes(isl::schedule_node_band Target, int TargetIdx,
35                           const isl::schedule_node_band &Source,
36                           int SourceIdx) {
37   bool Coincident = Source.member_get_coincident(SourceIdx).release();
38   Target = Target.member_set_coincident(TargetIdx, Coincident);
39 
40   isl_ast_loop_type LoopType =
41       isl_schedule_node_band_member_get_ast_loop_type(Source.get(), SourceIdx);
42   Target = isl::manage(isl_schedule_node_band_member_set_ast_loop_type(
43                            Target.release(), TargetIdx, LoopType))
44                .as<isl::schedule_node_band>();
45 
46   isl_ast_loop_type IsolateType =
47       isl_schedule_node_band_member_get_isolate_ast_loop_type(Source.get(),
48                                                               SourceIdx);
49   Target = isl::manage(isl_schedule_node_band_member_set_isolate_ast_loop_type(
50                            Target.release(), TargetIdx, IsolateType))
51                .as<isl::schedule_node_band>();
52 
53   return Target;
54 }
55 
56 /// Create a new band by copying members from another @p Band. @p IncludeCb
57 /// decides which band indices are copied to the result.
58 template <typename CbTy>
rebuildBand(isl::schedule_node_band OldBand,isl::schedule Body,CbTy IncludeCb)59 static isl::schedule rebuildBand(isl::schedule_node_band OldBand,
60                                  isl::schedule Body, CbTy IncludeCb) {
61   int NumBandDims = unsignedFromIslSize(OldBand.n_member());
62 
63   bool ExcludeAny = false;
64   bool IncludeAny = false;
65   for (auto OldIdx : seq<int>(0, NumBandDims)) {
66     if (IncludeCb(OldIdx))
67       IncludeAny = true;
68     else
69       ExcludeAny = true;
70   }
71 
72   // Instead of creating a zero-member band, don't create a band at all.
73   if (!IncludeAny)
74     return Body;
75 
76   isl::multi_union_pw_aff PartialSched = OldBand.get_partial_schedule();
77   isl::multi_union_pw_aff NewPartialSched;
78   if (ExcludeAny) {
79     // Select the included partial scatter functions.
80     isl::union_pw_aff_list List = PartialSched.list();
81     int NewIdx = 0;
82     for (auto OldIdx : seq<int>(0, NumBandDims)) {
83       if (IncludeCb(OldIdx))
84         NewIdx += 1;
85       else
86         List = List.drop(NewIdx, 1);
87     }
88     isl::space ParamSpace = PartialSched.get_space().params();
89     isl::space NewScatterSpace = ParamSpace.add_unnamed_tuple(NewIdx);
90     NewPartialSched = isl::multi_union_pw_aff(NewScatterSpace, List);
91   } else {
92     // Just reuse original scatter function of copying all of them.
93     NewPartialSched = PartialSched;
94   }
95 
96   // Create the new band node.
97   isl::schedule_node_band NewBand =
98       Body.insert_partial_schedule(NewPartialSched)
99           .get_root()
100           .child(0)
101           .as<isl::schedule_node_band>();
102 
103   // If OldBand was permutable, so is the new one, even if some dimensions are
104   // missing.
105   bool IsPermutable = OldBand.permutable().release();
106   NewBand = NewBand.set_permutable(IsPermutable);
107 
108   // Reapply member attributes.
109   int NewIdx = 0;
110   for (auto OldIdx : seq<int>(0, NumBandDims)) {
111     if (!IncludeCb(OldIdx))
112       continue;
113     NewBand =
114         applyBandMemberAttributes(std::move(NewBand), NewIdx, OldBand, OldIdx);
115     NewIdx += 1;
116   }
117 
118   return NewBand.get_schedule();
119 }
120 
121 /// Rewrite a schedule tree by reconstructing it bottom-up.
122 ///
123 /// By default, the original schedule tree is reconstructed. To build a
124 /// different tree, redefine visitor methods in a derived class (CRTP).
125 ///
126 /// Note that AST build options are not applied; Setting the isolate[] option
127 /// makes the schedule tree 'anchored' and cannot be modified afterwards. Hence,
128 /// AST build options must be set after the tree has been constructed.
129 template <typename Derived, typename... Args>
130 struct ScheduleTreeRewriter
131     : RecursiveScheduleTreeVisitor<Derived, isl::schedule, Args...> {
getDerived__anon9ace89990111::ScheduleTreeRewriter132   Derived &getDerived() { return *static_cast<Derived *>(this); }
getDerived__anon9ace89990111::ScheduleTreeRewriter133   const Derived &getDerived() const {
134     return *static_cast<const Derived *>(this);
135   }
136 
visitDomain__anon9ace89990111::ScheduleTreeRewriter137   isl::schedule visitDomain(isl::schedule_node_domain Node, Args... args) {
138     // Every schedule_tree already has a domain node, no need to add one.
139     return getDerived().visit(Node.first_child(), std::forward<Args>(args)...);
140   }
141 
visitBand__anon9ace89990111::ScheduleTreeRewriter142   isl::schedule visitBand(isl::schedule_node_band Band, Args... args) {
143     isl::schedule NewChild =
144         getDerived().visit(Band.child(0), std::forward<Args>(args)...);
145     return rebuildBand(Band, NewChild, [](int) { return true; });
146   }
147 
visitSequence__anon9ace89990111::ScheduleTreeRewriter148   isl::schedule visitSequence(isl::schedule_node_sequence Sequence,
149                               Args... args) {
150     int NumChildren = isl_schedule_node_n_children(Sequence.get());
151     isl::schedule Result =
152         getDerived().visit(Sequence.child(0), std::forward<Args>(args)...);
153     for (int i = 1; i < NumChildren; i += 1)
154       Result = Result.sequence(
155           getDerived().visit(Sequence.child(i), std::forward<Args>(args)...));
156     return Result;
157   }
158 
visitSet__anon9ace89990111::ScheduleTreeRewriter159   isl::schedule visitSet(isl::schedule_node_set Set, Args... args) {
160     int NumChildren = isl_schedule_node_n_children(Set.get());
161     isl::schedule Result =
162         getDerived().visit(Set.child(0), std::forward<Args>(args)...);
163     for (int i = 1; i < NumChildren; i += 1)
164       Result = isl::manage(
165           isl_schedule_set(Result.release(),
166                            getDerived()
167                                .visit(Set.child(i), std::forward<Args>(args)...)
168                                .release()));
169     return Result;
170   }
171 
visitLeaf__anon9ace89990111::ScheduleTreeRewriter172   isl::schedule visitLeaf(isl::schedule_node_leaf Leaf, Args... args) {
173     return isl::schedule::from_domain(Leaf.get_domain());
174   }
175 
visitMark__anon9ace89990111::ScheduleTreeRewriter176   isl::schedule visitMark(const isl::schedule_node &Mark, Args... args) {
177 
178     isl::id TheMark = Mark.as<isl::schedule_node_mark>().get_id();
179     isl::schedule_node NewChild =
180         getDerived()
181             .visit(Mark.first_child(), std::forward<Args>(args)...)
182             .get_root()
183             .first_child();
184     return NewChild.insert_mark(TheMark).get_schedule();
185   }
186 
visitExtension__anon9ace89990111::ScheduleTreeRewriter187   isl::schedule visitExtension(isl::schedule_node_extension Extension,
188                                Args... args) {
189     isl::union_map TheExtension =
190         Extension.as<isl::schedule_node_extension>().get_extension();
191     isl::schedule_node NewChild = getDerived()
192                                       .visit(Extension.child(0), args...)
193                                       .get_root()
194                                       .first_child();
195     isl::schedule_node NewExtension =
196         isl::schedule_node::from_extension(TheExtension);
197     return NewChild.graft_before(NewExtension).get_schedule();
198   }
199 
visitFilter__anon9ace89990111::ScheduleTreeRewriter200   isl::schedule visitFilter(isl::schedule_node_filter Filter, Args... args) {
201     isl::union_set FilterDomain =
202         Filter.as<isl::schedule_node_filter>().get_filter();
203     isl::schedule NewSchedule =
204         getDerived().visit(Filter.child(0), std::forward<Args>(args)...);
205     return NewSchedule.intersect_domain(FilterDomain);
206   }
207 
visitNode__anon9ace89990111::ScheduleTreeRewriter208   isl::schedule visitNode(isl::schedule_node Node, Args... args) {
209     llvm_unreachable("Not implemented");
210   }
211 };
212 
213 /// Rewrite the schedule tree without any changes. Useful to copy a subtree into
214 /// a new schedule, discarding everything but.
215 struct IdentityRewriter : ScheduleTreeRewriter<IdentityRewriter> {};
216 
217 /// Rewrite a schedule tree to an equivalent one without extension nodes.
218 ///
219 /// Each visit method takes two additional arguments:
220 ///
221 ///  * The new domain the node, which is the inherited domain plus any domains
222 ///    added by extension nodes.
223 ///
224 ///  * A map of extension domains of all children is returned; it is required by
225 ///    band nodes to schedule the additional domains at the same position as the
226 ///    extension node would.
227 ///
228 struct ExtensionNodeRewriter final
229     : ScheduleTreeRewriter<ExtensionNodeRewriter, const isl::union_set &,
230                            isl::union_map &> {
231   using BaseTy = ScheduleTreeRewriter<ExtensionNodeRewriter,
232                                       const isl::union_set &, isl::union_map &>;
getBase__anon9ace89990111::ExtensionNodeRewriter233   BaseTy &getBase() { return *this; }
getBase__anon9ace89990111::ExtensionNodeRewriter234   const BaseTy &getBase() const { return *this; }
235 
visitSchedule__anon9ace89990111::ExtensionNodeRewriter236   isl::schedule visitSchedule(isl::schedule Schedule) {
237     isl::union_map Extensions;
238     isl::schedule Result =
239         visit(Schedule.get_root(), Schedule.get_domain(), Extensions);
240     assert(!Extensions.is_null() && Extensions.is_empty());
241     return Result;
242   }
243 
visitSequence__anon9ace89990111::ExtensionNodeRewriter244   isl::schedule visitSequence(isl::schedule_node_sequence Sequence,
245                               const isl::union_set &Domain,
246                               isl::union_map &Extensions) {
247     int NumChildren = isl_schedule_node_n_children(Sequence.get());
248     isl::schedule NewNode = visit(Sequence.first_child(), Domain, Extensions);
249     for (int i = 1; i < NumChildren; i += 1) {
250       isl::schedule_node OldChild = Sequence.child(i);
251       isl::union_map NewChildExtensions;
252       isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions);
253       NewNode = NewNode.sequence(NewChildNode);
254       Extensions = Extensions.unite(NewChildExtensions);
255     }
256     return NewNode;
257   }
258 
visitSet__anon9ace89990111::ExtensionNodeRewriter259   isl::schedule visitSet(isl::schedule_node_set Set,
260                          const isl::union_set &Domain,
261                          isl::union_map &Extensions) {
262     int NumChildren = isl_schedule_node_n_children(Set.get());
263     isl::schedule NewNode = visit(Set.first_child(), Domain, Extensions);
264     for (int i = 1; i < NumChildren; i += 1) {
265       isl::schedule_node OldChild = Set.child(i);
266       isl::union_map NewChildExtensions;
267       isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions);
268       NewNode = isl::manage(
269           isl_schedule_set(NewNode.release(), NewChildNode.release()));
270       Extensions = Extensions.unite(NewChildExtensions);
271     }
272     return NewNode;
273   }
274 
visitLeaf__anon9ace89990111::ExtensionNodeRewriter275   isl::schedule visitLeaf(isl::schedule_node_leaf Leaf,
276                           const isl::union_set &Domain,
277                           isl::union_map &Extensions) {
278     Extensions = isl::union_map::empty(Leaf.ctx());
279     return isl::schedule::from_domain(Domain);
280   }
281 
visitBand__anon9ace89990111::ExtensionNodeRewriter282   isl::schedule visitBand(isl::schedule_node_band OldNode,
283                           const isl::union_set &Domain,
284                           isl::union_map &OuterExtensions) {
285     isl::schedule_node OldChild = OldNode.first_child();
286     isl::multi_union_pw_aff PartialSched =
287         isl::manage(isl_schedule_node_band_get_partial_schedule(OldNode.get()));
288 
289     isl::union_map NewChildExtensions;
290     isl::schedule NewChild = visit(OldChild, Domain, NewChildExtensions);
291 
292     // Add the extensions to the partial schedule.
293     OuterExtensions = isl::union_map::empty(NewChildExtensions.ctx());
294     isl::union_map NewPartialSchedMap = isl::union_map::from(PartialSched);
295     unsigned BandDims = isl_schedule_node_band_n_member(OldNode.get());
296     for (isl::map Ext : NewChildExtensions.get_map_list()) {
297       unsigned ExtDims = unsignedFromIslSize(Ext.domain_tuple_dim());
298       assert(ExtDims >= BandDims);
299       unsigned OuterDims = ExtDims - BandDims;
300 
301       isl::map BandSched =
302           Ext.project_out(isl::dim::in, 0, OuterDims).reverse();
303       NewPartialSchedMap = NewPartialSchedMap.unite(BandSched);
304 
305       // There might be more outer bands that have to schedule the extensions.
306       if (OuterDims > 0) {
307         isl::map OuterSched =
308             Ext.project_out(isl::dim::in, OuterDims, BandDims);
309         OuterExtensions = OuterExtensions.unite(OuterSched);
310       }
311     }
312     isl::multi_union_pw_aff NewPartialSchedAsAsMultiUnionPwAff =
313         isl::multi_union_pw_aff::from_union_map(NewPartialSchedMap);
314     isl::schedule_node NewNode =
315         NewChild.insert_partial_schedule(NewPartialSchedAsAsMultiUnionPwAff)
316             .get_root()
317             .child(0);
318 
319     // Reapply permutability and coincidence attributes.
320     NewNode = isl::manage(isl_schedule_node_band_set_permutable(
321         NewNode.release(),
322         isl_schedule_node_band_get_permutable(OldNode.get())));
323     for (unsigned i = 0; i < BandDims; i += 1)
324       NewNode = applyBandMemberAttributes(NewNode.as<isl::schedule_node_band>(),
325                                           i, OldNode, i);
326 
327     return NewNode.get_schedule();
328   }
329 
visitFilter__anon9ace89990111::ExtensionNodeRewriter330   isl::schedule visitFilter(isl::schedule_node_filter Filter,
331                             const isl::union_set &Domain,
332                             isl::union_map &Extensions) {
333     isl::union_set FilterDomain =
334         Filter.as<isl::schedule_node_filter>().get_filter();
335     isl::union_set NewDomain = Domain.intersect(FilterDomain);
336 
337     // A filter is added implicitly if necessary when joining schedule trees.
338     return visit(Filter.first_child(), NewDomain, Extensions);
339   }
340 
visitExtension__anon9ace89990111::ExtensionNodeRewriter341   isl::schedule visitExtension(isl::schedule_node_extension Extension,
342                                const isl::union_set &Domain,
343                                isl::union_map &Extensions) {
344     isl::union_map ExtDomain =
345         Extension.as<isl::schedule_node_extension>().get_extension();
346     isl::union_set NewDomain = Domain.unite(ExtDomain.range());
347     isl::union_map ChildExtensions;
348     isl::schedule NewChild =
349         visit(Extension.first_child(), NewDomain, ChildExtensions);
350     Extensions = ChildExtensions.unite(ExtDomain);
351     return NewChild;
352   }
353 };
354 
355 /// Collect all AST build options in any schedule tree band.
356 ///
357 /// ScheduleTreeRewriter cannot apply the schedule tree options. This class
358 /// collects these options to apply them later.
359 struct CollectASTBuildOptions final
360     : RecursiveScheduleTreeVisitor<CollectASTBuildOptions> {
361   using BaseTy = RecursiveScheduleTreeVisitor<CollectASTBuildOptions>;
getBase__anon9ace89990111::CollectASTBuildOptions362   BaseTy &getBase() { return *this; }
getBase__anon9ace89990111::CollectASTBuildOptions363   const BaseTy &getBase() const { return *this; }
364 
365   llvm::SmallVector<isl::union_set, 8> ASTBuildOptions;
366 
visitBand__anon9ace89990111::CollectASTBuildOptions367   void visitBand(isl::schedule_node_band Band) {
368     ASTBuildOptions.push_back(
369         isl::manage(isl_schedule_node_band_get_ast_build_options(Band.get())));
370     return getBase().visitBand(Band);
371   }
372 };
373 
374 /// Apply AST build options to the bands in a schedule tree.
375 ///
376 /// This rewrites a schedule tree with the AST build options applied. We assume
377 /// that the band nodes are visited in the same order as they were when the
378 /// build options were collected, typically by CollectASTBuildOptions.
379 struct ApplyASTBuildOptions final : ScheduleNodeRewriter<ApplyASTBuildOptions> {
380   using BaseTy = ScheduleNodeRewriter<ApplyASTBuildOptions>;
getBase__anon9ace89990111::ApplyASTBuildOptions381   BaseTy &getBase() { return *this; }
getBase__anon9ace89990111::ApplyASTBuildOptions382   const BaseTy &getBase() const { return *this; }
383 
384   size_t Pos;
385   llvm::ArrayRef<isl::union_set> ASTBuildOptions;
386 
ApplyASTBuildOptions__anon9ace89990111::ApplyASTBuildOptions387   ApplyASTBuildOptions(llvm::ArrayRef<isl::union_set> ASTBuildOptions)
388       : ASTBuildOptions(ASTBuildOptions) {}
389 
visitSchedule__anon9ace89990111::ApplyASTBuildOptions390   isl::schedule visitSchedule(isl::schedule Schedule) {
391     Pos = 0;
392     isl::schedule Result = visit(Schedule).get_schedule();
393     assert(Pos == ASTBuildOptions.size() &&
394            "AST build options must match to band nodes");
395     return Result;
396   }
397 
visitBand__anon9ace89990111::ApplyASTBuildOptions398   isl::schedule_node visitBand(isl::schedule_node_band Band) {
399     isl::schedule_node_band Result =
400         Band.set_ast_build_options(ASTBuildOptions[Pos]);
401     Pos += 1;
402     return getBase().visitBand(Result);
403   }
404 };
405 
406 /// Return whether the schedule contains an extension node.
containsExtensionNode(isl::schedule Schedule)407 static bool containsExtensionNode(isl::schedule Schedule) {
408   assert(!Schedule.is_null());
409 
410   auto Callback = [](__isl_keep isl_schedule_node *Node,
411                      void *User) -> isl_bool {
412     if (isl_schedule_node_get_type(Node) == isl_schedule_node_extension) {
413       // Stop walking the schedule tree.
414       return isl_bool_error;
415     }
416 
417     // Continue searching the subtree.
418     return isl_bool_true;
419   };
420   isl_stat RetVal = isl_schedule_foreach_schedule_node_top_down(
421       Schedule.get(), Callback, nullptr);
422 
423   // We assume that the traversal itself does not fail, i.e. the only reason to
424   // return isl_stat_error is that an extension node was found.
425   return RetVal == isl_stat_error;
426 }
427 
428 /// Find a named MDNode property in a LoopID.
findOptionalNodeOperand(MDNode * LoopMD,StringRef Name)429 static MDNode *findOptionalNodeOperand(MDNode *LoopMD, StringRef Name) {
430   return dyn_cast_or_null<MDNode>(
431       findMetadataOperand(LoopMD, Name).value_or(nullptr));
432 }
433 
434 /// Is this node of type mark?
isMark(const isl::schedule_node & Node)435 static bool isMark(const isl::schedule_node &Node) {
436   return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_mark;
437 }
438 
439 /// Is this node of type band?
isBand(const isl::schedule_node & Node)440 static bool isBand(const isl::schedule_node &Node) {
441   return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band;
442 }
443 
444 #ifndef NDEBUG
445 /// Is this node a band of a single dimension (i.e. could represent a loop)?
isBandWithSingleLoop(const isl::schedule_node & Node)446 static bool isBandWithSingleLoop(const isl::schedule_node &Node) {
447   return isBand(Node) && isl_schedule_node_band_n_member(Node.get()) == 1;
448 }
449 #endif
450 
isLeaf(const isl::schedule_node & Node)451 static bool isLeaf(const isl::schedule_node &Node) {
452   return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_leaf;
453 }
454 
455 /// Create an isl::id representing the output loop after a transformation.
createGeneratedLoopAttr(isl::ctx Ctx,MDNode * FollowupLoopMD)456 static isl::id createGeneratedLoopAttr(isl::ctx Ctx, MDNode *FollowupLoopMD) {
457   // Don't need to id the followup.
458   // TODO: Append llvm.loop.disable_heustistics metadata unless overridden by
459   //       user followup-MD
460   if (!FollowupLoopMD)
461     return {};
462 
463   BandAttr *Attr = new BandAttr();
464   Attr->Metadata = FollowupLoopMD;
465   return getIslLoopAttr(Ctx, Attr);
466 }
467 
468 /// A loop consists of a band and an optional marker that wraps it. Return the
469 /// outermost of the two.
470 
471 /// That is, either the mark or, if there is not mark, the loop itself. Can
472 /// start with either the mark or the band.
moveToBandMark(isl::schedule_node BandOrMark)473 static isl::schedule_node moveToBandMark(isl::schedule_node BandOrMark) {
474   if (isBandMark(BandOrMark)) {
475     assert(isBandWithSingleLoop(BandOrMark.child(0)));
476     return BandOrMark;
477   }
478   assert(isBandWithSingleLoop(BandOrMark));
479 
480   isl::schedule_node Mark = BandOrMark.parent();
481   if (isBandMark(Mark))
482     return Mark;
483 
484   // Band has no loop marker.
485   return BandOrMark;
486 }
487 
removeMark(isl::schedule_node MarkOrBand,BandAttr * & Attr)488 static isl::schedule_node removeMark(isl::schedule_node MarkOrBand,
489                                      BandAttr *&Attr) {
490   MarkOrBand = moveToBandMark(MarkOrBand);
491 
492   isl::schedule_node Band;
493   if (isMark(MarkOrBand)) {
494     Attr = getLoopAttr(MarkOrBand.as<isl::schedule_node_mark>().get_id());
495     Band = isl::manage(isl_schedule_node_delete(MarkOrBand.release()));
496   } else {
497     Attr = nullptr;
498     Band = MarkOrBand;
499   }
500 
501   assert(isBandWithSingleLoop(Band));
502   return Band;
503 }
504 
505 /// Remove the mark that wraps a loop. Return the band representing the loop.
removeMark(isl::schedule_node MarkOrBand)506 static isl::schedule_node removeMark(isl::schedule_node MarkOrBand) {
507   BandAttr *Attr;
508   return removeMark(MarkOrBand, Attr);
509 }
510 
insertMark(isl::schedule_node Band,isl::id Mark)511 static isl::schedule_node insertMark(isl::schedule_node Band, isl::id Mark) {
512   assert(isBand(Band));
513   assert(moveToBandMark(Band).is_equal(Band) &&
514          "Don't add a two marks for a band");
515 
516   return Band.insert_mark(Mark).child(0);
517 }
518 
519 /// Return the (one-dimensional) set of numbers that are divisible by @p Factor
520 /// with remainder @p Offset.
521 ///
522 ///  isDivisibleBySet(Ctx, 4, 0) = { [i] : floord(i,4) = 0 }
523 ///  isDivisibleBySet(Ctx, 4, 1) = { [i] : floord(i,4) = 1 }
524 ///
isDivisibleBySet(isl::ctx & Ctx,long Factor,long Offset)525 static isl::basic_set isDivisibleBySet(isl::ctx &Ctx, long Factor,
526                                        long Offset) {
527   isl::val ValFactor{Ctx, Factor};
528   isl::val ValOffset{Ctx, Offset};
529 
530   isl::space Unispace{Ctx, 0, 1};
531   isl::local_space LUnispace{Unispace};
532   isl::aff AffFactor{LUnispace, ValFactor};
533   isl::aff AffOffset{LUnispace, ValOffset};
534 
535   isl::aff Id = isl::aff::var_on_domain(LUnispace, isl::dim::out, 0);
536   isl::aff DivMul = Id.mod(ValFactor);
537   isl::basic_map Divisible = isl::basic_map::from_aff(DivMul);
538   isl::basic_map Modulo = Divisible.fix_val(isl::dim::out, 0, ValOffset);
539   return Modulo.domain();
540 }
541 
542 /// Make the last dimension of Set to take values from 0 to VectorWidth - 1.
543 ///
544 /// @param Set         A set, which should be modified.
545 /// @param VectorWidth A parameter, which determines the constraint.
addExtentConstraints(isl::set Set,int VectorWidth)546 static isl::set addExtentConstraints(isl::set Set, int VectorWidth) {
547   unsigned Dims = unsignedFromIslSize(Set.tuple_dim());
548   assert(Dims >= 1);
549   isl::space Space = Set.get_space();
550   isl::local_space LocalSpace = isl::local_space(Space);
551   isl::constraint ExtConstr = isl::constraint::alloc_inequality(LocalSpace);
552   ExtConstr = ExtConstr.set_constant_si(0);
553   ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, 1);
554   Set = Set.add_constraint(ExtConstr);
555   ExtConstr = isl::constraint::alloc_inequality(LocalSpace);
556   ExtConstr = ExtConstr.set_constant_si(VectorWidth - 1);
557   ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, -1);
558   return Set.add_constraint(ExtConstr);
559 }
560 
561 /// Collapse perfectly nested bands into a single band.
562 class BandCollapseRewriter final
563     : public ScheduleTreeRewriter<BandCollapseRewriter> {
564 private:
565   using BaseTy = ScheduleTreeRewriter<BandCollapseRewriter>;
getBase()566   BaseTy &getBase() { return *this; }
getBase() const567   const BaseTy &getBase() const { return *this; }
568 
569 public:
visitBand(isl::schedule_node_band RootBand)570   isl::schedule visitBand(isl::schedule_node_band RootBand) {
571     isl::schedule_node_band Band = RootBand;
572     isl::ctx Ctx = Band.ctx();
573 
574     // Do not merge permutable band to avoid loosing the permutability property.
575     // Cannot collapse even two permutable loops, they might be permutable
576     // individually, but not necassarily accross.
577     if (unsignedFromIslSize(Band.n_member()) > 1u && Band.permutable())
578       return getBase().visitBand(Band);
579 
580     // Find collapsable bands.
581     SmallVector<isl::schedule_node_band> Nest;
582     int NumTotalLoops = 0;
583     isl::schedule_node Body;
584     while (true) {
585       Nest.push_back(Band);
586       NumTotalLoops += unsignedFromIslSize(Band.n_member());
587       Body = Band.first_child();
588       if (!Body.isa<isl::schedule_node_band>())
589         break;
590       Band = Body.as<isl::schedule_node_band>();
591 
592       // Do not include next band if it is permutable to not lose its
593       // permutability property.
594       if (unsignedFromIslSize(Band.n_member()) > 1u && Band.permutable())
595         break;
596     }
597 
598     // Nothing to collapse, preserve permutability.
599     if (Nest.size() <= 1)
600       return getBase().visitBand(Band);
601 
602     LLVM_DEBUG({
603       dbgs() << "Found loops to collapse between\n";
604       dumpIslObj(RootBand, dbgs());
605       dbgs() << "and\n";
606       dumpIslObj(Body, dbgs());
607       dbgs() << "\n";
608     });
609 
610     isl::schedule NewBody = visit(Body);
611 
612     // Collect partial schedules from all members.
613     isl::union_pw_aff_list PartScheds{Ctx, NumTotalLoops};
614     for (isl::schedule_node_band Band : Nest) {
615       int NumLoops = unsignedFromIslSize(Band.n_member());
616       isl::multi_union_pw_aff BandScheds = Band.get_partial_schedule();
617       for (auto j : seq<int>(0, NumLoops))
618         PartScheds = PartScheds.add(BandScheds.at(j));
619     }
620     isl::space ScatterSpace = isl::space(Ctx, 0, NumTotalLoops);
621     isl::multi_union_pw_aff PartSchedsMulti{ScatterSpace, PartScheds};
622 
623     isl::schedule_node_band CollapsedBand =
624         NewBody.insert_partial_schedule(PartSchedsMulti)
625             .get_root()
626             .first_child()
627             .as<isl::schedule_node_band>();
628 
629     // Copy over loop attributes form original bands.
630     int LoopIdx = 0;
631     for (isl::schedule_node_band Band : Nest) {
632       int NumLoops = unsignedFromIslSize(Band.n_member());
633       for (int i : seq<int>(0, NumLoops)) {
634         CollapsedBand = applyBandMemberAttributes(std::move(CollapsedBand),
635                                                   LoopIdx, Band, i);
636         LoopIdx += 1;
637       }
638     }
639     assert(LoopIdx == NumTotalLoops &&
640            "Expect the same number of loops to add up again");
641 
642     return CollapsedBand.get_schedule();
643   }
644 };
645 
collapseBands(isl::schedule Sched)646 static isl::schedule collapseBands(isl::schedule Sched) {
647   LLVM_DEBUG(dbgs() << "Collapse bands in schedule\n");
648   BandCollapseRewriter Rewriter;
649   return Rewriter.visit(Sched);
650 }
651 
652 /// Collect sequentially executed bands (or anything else), even if nested in a
653 /// mark or other nodes whose child is executed just once. If we can
654 /// successfully fuse the bands, we allow them to be removed.
collectPotentiallyFusableBands(isl::schedule_node Node,SmallVectorImpl<std::pair<isl::schedule_node,isl::schedule_node>> & ScheduleBands,const isl::schedule_node & DirectChild)655 static void collectPotentiallyFusableBands(
656     isl::schedule_node Node,
657     SmallVectorImpl<std::pair<isl::schedule_node, isl::schedule_node>>
658         &ScheduleBands,
659     const isl::schedule_node &DirectChild) {
660   switch (isl_schedule_node_get_type(Node.get())) {
661   case isl_schedule_node_sequence:
662   case isl_schedule_node_set:
663   case isl_schedule_node_mark:
664   case isl_schedule_node_domain:
665   case isl_schedule_node_filter:
666     if (Node.has_children()) {
667       isl::schedule_node C = Node.first_child();
668       while (true) {
669         collectPotentiallyFusableBands(C, ScheduleBands, DirectChild);
670         if (!C.has_next_sibling())
671           break;
672         C = C.next_sibling();
673       }
674     }
675     break;
676 
677   default:
678     // Something that does not execute suquentially (e.g. a band)
679     ScheduleBands.push_back({Node, DirectChild});
680     break;
681   }
682 }
683 
684 /// Remove dependencies that are resolved by @p PartSched. That is, remove
685 /// everything that we already know is executed in-order.
remainingDepsFromPartialSchedule(isl::union_map PartSched,isl::union_map Deps)686 static isl::union_map remainingDepsFromPartialSchedule(isl::union_map PartSched,
687                                                        isl::union_map Deps) {
688   unsigned NumDims = getNumScatterDims(PartSched);
689   auto ParamSpace = PartSched.get_space().params();
690 
691   // { Scatter[] }
692   isl::space ScatterSpace =
693       ParamSpace.set_from_params().add_dims(isl::dim::set, NumDims);
694 
695   // { Scatter[] -> Domain[] }
696   isl::union_map PartSchedRev = PartSched.reverse();
697 
698   // { Scatter[] -> Scatter[] }
699   isl::map MaybeBefore = isl::map::lex_le(ScatterSpace);
700 
701   // { Domain[] -> Domain[] }
702   isl::union_map DomMaybeBefore =
703       MaybeBefore.apply_domain(PartSchedRev).apply_range(PartSchedRev);
704 
705   // { Domain[] -> Domain[] }
706   isl::union_map ChildRemainingDeps = Deps.intersect(DomMaybeBefore);
707 
708   return ChildRemainingDeps;
709 }
710 
711 /// Remove dependencies that are resolved by executing them in the order
712 /// specified by @p Domains;
remainigDepsFromSequence(ArrayRef<isl::union_set> Domains,isl::union_map Deps)713 static isl::union_map remainigDepsFromSequence(ArrayRef<isl::union_set> Domains,
714                                                isl::union_map Deps) {
715   isl::ctx Ctx = Deps.ctx();
716   isl::space ParamSpace = Deps.get_space().params();
717 
718   // Create a partial schedule mapping to constants that reflect the execution
719   // order.
720   isl::union_map PartialSchedules = isl::union_map::empty(Ctx);
721   for (auto P : enumerate(Domains)) {
722     isl::val ExecTime = isl::val(Ctx, P.index());
723     isl::union_pw_aff DomSched{P.value(), ExecTime};
724     PartialSchedules = PartialSchedules.unite(DomSched.as_union_map());
725   }
726 
727   return remainingDepsFromPartialSchedule(PartialSchedules, Deps);
728 }
729 
730 /// Determine whether the outermost loop of to bands can be fused while
731 /// respecting validity dependencies.
canFuseOutermost(const isl::schedule_node_band & LHS,const isl::schedule_node_band & RHS,const isl::union_map & Deps)732 static bool canFuseOutermost(const isl::schedule_node_band &LHS,
733                              const isl::schedule_node_band &RHS,
734                              const isl::union_map &Deps) {
735   // { LHSDomain[] -> Scatter[] }
736   isl::union_map LHSPartSched =
737       LHS.get_partial_schedule().get_at(0).as_union_map();
738 
739   // { Domain[] -> Scatter[] }
740   isl::union_map RHSPartSched =
741       RHS.get_partial_schedule().get_at(0).as_union_map();
742 
743   // Dependencies that are already resolved because LHS executes before RHS, but
744   // will not be anymore after fusion. { DefDomain[] -> UseDomain[] }
745   isl::union_map OrderedBySequence =
746       Deps.intersect_domain(LHSPartSched.domain())
747           .intersect_range(RHSPartSched.domain());
748 
749   isl::space ParamSpace = OrderedBySequence.get_space().params();
750   isl::space NewScatterSpace = ParamSpace.add_unnamed_tuple(1);
751 
752   // { Scatter[] -> Scatter[] }
753   isl::map After = isl::map::lex_gt(NewScatterSpace);
754 
755   // After fusion, instances with smaller (or equal, which means they will be
756   // executed in the same iteration, but the LHS instance is still sequenced
757   // before RHS) scatter value will still be executed before. This are the
758   // orderings where this is not necessarily the case.
759   // { LHSDomain[] -> RHSDomain[] }
760   isl::union_map MightBeAfterDoms = After.apply_domain(LHSPartSched.reverse())
761                                         .apply_range(RHSPartSched.reverse());
762 
763   // Dependencies that are not resolved by the new execution order.
764   isl::union_map WithBefore = OrderedBySequence.intersect(MightBeAfterDoms);
765 
766   return WithBefore.is_empty();
767 }
768 
769 /// Fuse @p LHS and @p RHS if possible while preserving validity dependenvies.
tryGreedyFuse(isl::schedule_node_band LHS,isl::schedule_node_band RHS,const isl::union_map & Deps)770 static isl::schedule tryGreedyFuse(isl::schedule_node_band LHS,
771                                    isl::schedule_node_band RHS,
772                                    const isl::union_map &Deps) {
773   if (!canFuseOutermost(LHS, RHS, Deps))
774     return {};
775 
776   LLVM_DEBUG({
777     dbgs() << "Found loops for greedy fusion:\n";
778     dumpIslObj(LHS, dbgs());
779     dbgs() << "and\n";
780     dumpIslObj(RHS, dbgs());
781     dbgs() << "\n";
782   });
783 
784   // The partial schedule of the bands outermost loop that we need to combine
785   // for the fusion.
786   isl::union_pw_aff LHSPartOuterSched = LHS.get_partial_schedule().get_at(0);
787   isl::union_pw_aff RHSPartOuterSched = RHS.get_partial_schedule().get_at(0);
788 
789   // Isolate band bodies as roots of their own schedule trees.
790   IdentityRewriter Rewriter;
791   isl::schedule LHSBody = Rewriter.visit(LHS.first_child());
792   isl::schedule RHSBody = Rewriter.visit(RHS.first_child());
793 
794   // Reconstruct the non-outermost (not going to be fused) loops from both
795   // bands.
796   // TODO: Maybe it is possibly to transfer the 'permutability' property from
797   // LHS+RHS. At minimum we need merge multiple band members at once, otherwise
798   // permutability has no meaning.
799   isl::schedule LHSNewBody =
800       rebuildBand(LHS, LHSBody, [](int i) { return i > 0; });
801   isl::schedule RHSNewBody =
802       rebuildBand(RHS, RHSBody, [](int i) { return i > 0; });
803 
804   // The loop body of the fused loop.
805   isl::schedule NewCommonBody = LHSNewBody.sequence(RHSNewBody);
806 
807   // Combine the partial schedules of both loops to a new one. Instances with
808   // the same scatter value are put together.
809   isl::union_map NewCommonPartialSched =
810       LHSPartOuterSched.as_union_map().unite(RHSPartOuterSched.as_union_map());
811   isl::schedule NewCommonSchedule = NewCommonBody.insert_partial_schedule(
812       NewCommonPartialSched.as_multi_union_pw_aff());
813 
814   return NewCommonSchedule;
815 }
816 
tryGreedyFuse(isl::schedule_node LHS,isl::schedule_node RHS,const isl::union_map & Deps)817 static isl::schedule tryGreedyFuse(isl::schedule_node LHS,
818                                    isl::schedule_node RHS,
819                                    const isl::union_map &Deps) {
820   // TODO: Non-bands could be interpreted as a band with just as single
821   // iteration. However, this is only useful if both ends of a fused loop were
822   // originally loops themselves.
823   if (!LHS.isa<isl::schedule_node_band>())
824     return {};
825   if (!RHS.isa<isl::schedule_node_band>())
826     return {};
827 
828   return tryGreedyFuse(LHS.as<isl::schedule_node_band>(),
829                        RHS.as<isl::schedule_node_band>(), Deps);
830 }
831 
832 /// Fuse all fusable loop top-down in a schedule tree.
833 ///
834 /// The isl::union_map parameters is the set of validity dependencies that have
835 /// not been resolved/carried by a parent schedule node.
836 class GreedyFusionRewriter final
837     : public ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map> {
838 private:
839   using BaseTy = ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map>;
getBase()840   BaseTy &getBase() { return *this; }
getBase() const841   const BaseTy &getBase() const { return *this; }
842 
843 public:
844   /// Is set to true if anything has been fused.
845   bool AnyChange = false;
846 
visitBand(isl::schedule_node_band Band,isl::union_map Deps)847   isl::schedule visitBand(isl::schedule_node_band Band, isl::union_map Deps) {
848     // { Domain[] -> Scatter[] }
849     isl::union_map PartSched =
850         isl::union_map::from(Band.get_partial_schedule());
851     assert(getNumScatterDims(PartSched) ==
852            unsignedFromIslSize(Band.n_member()));
853     isl::space ParamSpace = PartSched.get_space().params();
854 
855     // { Scatter[] -> Domain[] }
856     isl::union_map PartSchedRev = PartSched.reverse();
857 
858     // Possible within the same iteration. Dependencies with smaller scatter
859     // value are carried by this loop and therefore have been resolved by the
860     // in-order execution if the loop iteration. A dependency with small scatter
861     // value would be a dependency violation that we assume did not happen. {
862     // Domain[] -> Domain[] }
863     isl::union_map Unsequenced = PartSchedRev.apply_domain(PartSchedRev);
864 
865     // Actual dependencies within the same iteration.
866     // { DefDomain[] -> UseDomain[] }
867     isl::union_map RemDeps = Deps.intersect(Unsequenced);
868 
869     return getBase().visitBand(Band, RemDeps);
870   }
871 
visitSequence(isl::schedule_node_sequence Sequence,isl::union_map Deps)872   isl::schedule visitSequence(isl::schedule_node_sequence Sequence,
873                               isl::union_map Deps) {
874     int NumChildren = isl_schedule_node_n_children(Sequence.get());
875 
876     // List of fusion candidates. The first element is the fusion candidate, the
877     // second is candidate's ancestor that is the sequence's direct child. It is
878     // preferable to use the direct child if not if its non-direct children is
879     // fused to preserve its structure such as mark nodes.
880     SmallVector<std::pair<isl::schedule_node, isl::schedule_node>> Bands;
881     for (auto i : seq<int>(0, NumChildren)) {
882       isl::schedule_node Child = Sequence.child(i);
883       collectPotentiallyFusableBands(Child, Bands, Child);
884     }
885 
886     // Direct children that had at least one of its decendants fused.
887     SmallDenseSet<isl_schedule_node *, 4> ChangedDirectChildren;
888 
889     // Fuse neigboring bands until reaching the end of candidates.
890     int i = 0;
891     while (i + 1 < (int)Bands.size()) {
892       isl::schedule Fused =
893           tryGreedyFuse(Bands[i].first, Bands[i + 1].first, Deps);
894       if (Fused.is_null()) {
895         // Cannot merge this node with the next; look at next pair.
896         i += 1;
897         continue;
898       }
899 
900       // Mark the direct children as (partially) fused.
901       if (!Bands[i].second.is_null())
902         ChangedDirectChildren.insert(Bands[i].second.get());
903       if (!Bands[i + 1].second.is_null())
904         ChangedDirectChildren.insert(Bands[i + 1].second.get());
905 
906       // Collapse the neigbros to a single new candidate that could be fused
907       // with the next candidate.
908       Bands[i] = {Fused.get_root(), {}};
909       Bands.erase(Bands.begin() + i + 1);
910 
911       AnyChange = true;
912     }
913 
914     // By construction equal if done with collectPotentiallyFusableBands's
915     // output.
916     SmallVector<isl::union_set> SubDomains;
917     SubDomains.reserve(NumChildren);
918     for (int i = 0; i < NumChildren; i += 1)
919       SubDomains.push_back(Sequence.child(i).domain());
920     auto SubRemainingDeps = remainigDepsFromSequence(SubDomains, Deps);
921 
922     // We may iterate over direct children multiple times, be sure to add each
923     // at most once.
924     SmallDenseSet<isl_schedule_node *, 4> AlreadyAdded;
925 
926     isl::schedule Result;
927     for (auto &P : Bands) {
928       isl::schedule_node MaybeFused = P.first;
929       isl::schedule_node DirectChild = P.second;
930 
931       // If not modified, use the direct child.
932       if (!DirectChild.is_null() &&
933           !ChangedDirectChildren.count(DirectChild.get())) {
934         if (AlreadyAdded.count(DirectChild.get()))
935           continue;
936         AlreadyAdded.insert(DirectChild.get());
937         MaybeFused = DirectChild;
938       } else {
939         assert(AnyChange &&
940                "Need changed flag for be consistent with actual change");
941       }
942 
943       // Top-down recursion: If the outermost loop has been fused, their nested
944       // bands might be fusable now as well.
945       isl::schedule InnerFused = visit(MaybeFused, SubRemainingDeps);
946 
947       // Reconstruct the sequence, with some of the children fused.
948       if (Result.is_null())
949         Result = InnerFused;
950       else
951         Result = Result.sequence(InnerFused);
952     }
953 
954     return Result;
955   }
956 };
957 
958 } // namespace
959 
isBandMark(const isl::schedule_node & Node)960 bool polly::isBandMark(const isl::schedule_node &Node) {
961   return isMark(Node) &&
962          isLoopAttr(Node.as<isl::schedule_node_mark>().get_id());
963 }
964 
getBandAttr(isl::schedule_node MarkOrBand)965 BandAttr *polly::getBandAttr(isl::schedule_node MarkOrBand) {
966   MarkOrBand = moveToBandMark(MarkOrBand);
967   if (!isMark(MarkOrBand))
968     return nullptr;
969 
970   return getLoopAttr(MarkOrBand.as<isl::schedule_node_mark>().get_id());
971 }
972 
hoistExtensionNodes(isl::schedule Sched)973 isl::schedule polly::hoistExtensionNodes(isl::schedule Sched) {
974   // If there is no extension node in the first place, return the original
975   // schedule tree.
976   if (!containsExtensionNode(Sched))
977     return Sched;
978 
979   // Build options can anchor schedule nodes, such that the schedule tree cannot
980   // be modified anymore. Therefore, apply build options after the tree has been
981   // created.
982   CollectASTBuildOptions Collector;
983   Collector.visit(Sched);
984 
985   // Rewrite the schedule tree without extension nodes.
986   ExtensionNodeRewriter Rewriter;
987   isl::schedule NewSched = Rewriter.visitSchedule(Sched);
988 
989   // Reapply the AST build options. The rewriter must not change the iteration
990   // order of bands. Any other node type is ignored.
991   ApplyASTBuildOptions Applicator(Collector.ASTBuildOptions);
992   NewSched = Applicator.visitSchedule(NewSched);
993 
994   return NewSched;
995 }
996 
applyFullUnroll(isl::schedule_node BandToUnroll)997 isl::schedule polly::applyFullUnroll(isl::schedule_node BandToUnroll) {
998   isl::ctx Ctx = BandToUnroll.ctx();
999 
1000   // Remove the loop's mark, the loop will disappear anyway.
1001   BandToUnroll = removeMark(BandToUnroll);
1002   assert(isBandWithSingleLoop(BandToUnroll));
1003 
1004   isl::multi_union_pw_aff PartialSched = isl::manage(
1005       isl_schedule_node_band_get_partial_schedule(BandToUnroll.get()));
1006   assert(unsignedFromIslSize(PartialSched.dim(isl::dim::out)) == 1u &&
1007          "Can only unroll a single dimension");
1008   isl::union_pw_aff PartialSchedUAff = PartialSched.at(0);
1009 
1010   isl::union_set Domain = BandToUnroll.get_domain();
1011   PartialSchedUAff = PartialSchedUAff.intersect_domain(Domain);
1012   isl::union_map PartialSchedUMap =
1013       isl::union_map::from(isl::union_pw_multi_aff(PartialSchedUAff));
1014 
1015   // Enumerator only the scatter elements.
1016   isl::union_set ScatterList = PartialSchedUMap.range();
1017 
1018   // Enumerate all loop iterations.
1019   // TODO: Diagnose if not enumerable or depends on a parameter.
1020   SmallVector<isl::point, 16> Elts;
1021   ScatterList.foreach_point([&Elts](isl::point P) -> isl::stat {
1022     Elts.push_back(P);
1023     return isl::stat::ok();
1024   });
1025 
1026   // Don't assume that foreach_point returns in execution order.
1027   llvm::sort(Elts, [](isl::point P1, isl::point P2) -> bool {
1028     isl::val C1 = P1.get_coordinate_val(isl::dim::set, 0);
1029     isl::val C2 = P2.get_coordinate_val(isl::dim::set, 0);
1030     return C1.lt(C2);
1031   });
1032 
1033   // Convert the points to a sequence of filters.
1034   isl::union_set_list List = isl::union_set_list(Ctx, Elts.size());
1035   for (isl::point P : Elts) {
1036     // Determine the domains that map this scatter element.
1037     isl::union_set DomainFilter = PartialSchedUMap.intersect_range(P).domain();
1038 
1039     List = List.add(DomainFilter);
1040   }
1041 
1042   // Replace original band with unrolled sequence.
1043   isl::schedule_node Body =
1044       isl::manage(isl_schedule_node_delete(BandToUnroll.release()));
1045   Body = Body.insert_sequence(List);
1046   return Body.get_schedule();
1047 }
1048 
applyPartialUnroll(isl::schedule_node BandToUnroll,int Factor)1049 isl::schedule polly::applyPartialUnroll(isl::schedule_node BandToUnroll,
1050                                         int Factor) {
1051   assert(Factor > 0 && "Positive unroll factor required");
1052   isl::ctx Ctx = BandToUnroll.ctx();
1053 
1054   // Remove the mark, save the attribute for later use.
1055   BandAttr *Attr;
1056   BandToUnroll = removeMark(BandToUnroll, Attr);
1057   assert(isBandWithSingleLoop(BandToUnroll));
1058 
1059   isl::multi_union_pw_aff PartialSched = isl::manage(
1060       isl_schedule_node_band_get_partial_schedule(BandToUnroll.get()));
1061 
1062   // { Stmt[] -> [x] }
1063   isl::union_pw_aff PartialSchedUAff = PartialSched.at(0);
1064 
1065   // Here we assume the schedule stride is one and starts with 0, which is not
1066   // necessarily the case.
1067   isl::union_pw_aff StridedPartialSchedUAff =
1068       isl::union_pw_aff::empty(PartialSchedUAff.get_space());
1069   isl::val ValFactor{Ctx, Factor};
1070   PartialSchedUAff.foreach_pw_aff([&StridedPartialSchedUAff,
1071                                    &ValFactor](isl::pw_aff PwAff) -> isl::stat {
1072     isl::space Space = PwAff.get_space();
1073     isl::set Universe = isl::set::universe(Space.domain());
1074     isl::pw_aff AffFactor{Universe, ValFactor};
1075     isl::pw_aff DivSchedAff = PwAff.div(AffFactor).floor().mul(AffFactor);
1076     StridedPartialSchedUAff = StridedPartialSchedUAff.union_add(DivSchedAff);
1077     return isl::stat::ok();
1078   });
1079 
1080   isl::union_set_list List = isl::union_set_list(Ctx, Factor);
1081   for (auto i : seq<int>(0, Factor)) {
1082     // { Stmt[] -> [x] }
1083     isl::union_map UMap =
1084         isl::union_map::from(isl::union_pw_multi_aff(PartialSchedUAff));
1085 
1086     // { [x] }
1087     isl::basic_set Divisible = isDivisibleBySet(Ctx, Factor, i);
1088 
1089     // { Stmt[] }
1090     isl::union_set UnrolledDomain = UMap.intersect_range(Divisible).domain();
1091 
1092     List = List.add(UnrolledDomain);
1093   }
1094 
1095   isl::schedule_node Body =
1096       isl::manage(isl_schedule_node_delete(BandToUnroll.copy()));
1097   Body = Body.insert_sequence(List);
1098   isl::schedule_node NewLoop =
1099       Body.insert_partial_schedule(StridedPartialSchedUAff);
1100 
1101   MDNode *FollowupMD = nullptr;
1102   if (Attr && Attr->Metadata)
1103     FollowupMD =
1104         findOptionalNodeOperand(Attr->Metadata, LLVMLoopUnrollFollowupUnrolled);
1105 
1106   isl::id NewBandId = createGeneratedLoopAttr(Ctx, FollowupMD);
1107   if (!NewBandId.is_null())
1108     NewLoop = insertMark(NewLoop, NewBandId);
1109 
1110   return NewLoop.get_schedule();
1111 }
1112 
getPartialTilePrefixes(isl::set ScheduleRange,int VectorWidth)1113 isl::set polly::getPartialTilePrefixes(isl::set ScheduleRange,
1114                                        int VectorWidth) {
1115   unsigned Dims = unsignedFromIslSize(ScheduleRange.tuple_dim());
1116   assert(Dims >= 1);
1117   isl::set LoopPrefixes =
1118       ScheduleRange.drop_constraints_involving_dims(isl::dim::set, Dims - 1, 1);
1119   auto ExtentPrefixes = addExtentConstraints(LoopPrefixes, VectorWidth);
1120   isl::set BadPrefixes = ExtentPrefixes.subtract(ScheduleRange);
1121   BadPrefixes = BadPrefixes.project_out(isl::dim::set, Dims - 1, 1);
1122   LoopPrefixes = LoopPrefixes.project_out(isl::dim::set, Dims - 1, 1);
1123   return LoopPrefixes.subtract(BadPrefixes);
1124 }
1125 
getIsolateOptions(isl::set IsolateDomain,unsigned OutDimsNum)1126 isl::union_set polly::getIsolateOptions(isl::set IsolateDomain,
1127                                         unsigned OutDimsNum) {
1128   unsigned Dims = unsignedFromIslSize(IsolateDomain.tuple_dim());
1129   assert(OutDimsNum <= Dims &&
1130          "The isl::set IsolateDomain is used to describe the range of schedule "
1131          "dimensions values, which should be isolated. Consequently, the "
1132          "number of its dimensions should be greater than or equal to the "
1133          "number of the schedule dimensions.");
1134   isl::map IsolateRelation = isl::map::from_domain(IsolateDomain);
1135   IsolateRelation = IsolateRelation.move_dims(isl::dim::out, 0, isl::dim::in,
1136                                               Dims - OutDimsNum, OutDimsNum);
1137   isl::set IsolateOption = IsolateRelation.wrap();
1138   isl::id Id = isl::id::alloc(IsolateOption.ctx(), "isolate", nullptr);
1139   IsolateOption = IsolateOption.set_tuple_id(Id);
1140   return isl::union_set(IsolateOption);
1141 }
1142 
getDimOptions(isl::ctx Ctx,const char * Option)1143 isl::union_set polly::getDimOptions(isl::ctx Ctx, const char *Option) {
1144   isl::space Space(Ctx, 0, 1);
1145   auto DimOption = isl::set::universe(Space);
1146   auto Id = isl::id::alloc(Ctx, Option, nullptr);
1147   DimOption = DimOption.set_tuple_id(Id);
1148   return isl::union_set(DimOption);
1149 }
1150 
tileNode(isl::schedule_node Node,const char * Identifier,ArrayRef<int> TileSizes,int DefaultTileSize)1151 isl::schedule_node polly::tileNode(isl::schedule_node Node,
1152                                    const char *Identifier,
1153                                    ArrayRef<int> TileSizes,
1154                                    int DefaultTileSize) {
1155   auto Space = isl::manage(isl_schedule_node_band_get_space(Node.get()));
1156   auto Dims = Space.dim(isl::dim::set);
1157   auto Sizes = isl::multi_val::zero(Space);
1158   std::string IdentifierString(Identifier);
1159   for (unsigned i : rangeIslSize(0, Dims)) {
1160     unsigned tileSize = i < TileSizes.size() ? TileSizes[i] : DefaultTileSize;
1161     Sizes = Sizes.set_val(i, isl::val(Node.ctx(), tileSize));
1162   }
1163   auto TileLoopMarkerStr = IdentifierString + " - Tiles";
1164   auto TileLoopMarker = isl::id::alloc(Node.ctx(), TileLoopMarkerStr, nullptr);
1165   Node = Node.insert_mark(TileLoopMarker);
1166   Node = Node.child(0);
1167   Node =
1168       isl::manage(isl_schedule_node_band_tile(Node.release(), Sizes.release()));
1169   Node = Node.child(0);
1170   auto PointLoopMarkerStr = IdentifierString + " - Points";
1171   auto PointLoopMarker =
1172       isl::id::alloc(Node.ctx(), PointLoopMarkerStr, nullptr);
1173   Node = Node.insert_mark(PointLoopMarker);
1174   return Node.child(0);
1175 }
1176 
applyRegisterTiling(isl::schedule_node Node,ArrayRef<int> TileSizes,int DefaultTileSize)1177 isl::schedule_node polly::applyRegisterTiling(isl::schedule_node Node,
1178                                               ArrayRef<int> TileSizes,
1179                                               int DefaultTileSize) {
1180   Node = tileNode(Node, "Register tiling", TileSizes, DefaultTileSize);
1181   auto Ctx = Node.ctx();
1182   return Node.as<isl::schedule_node_band>().set_ast_build_options(
1183       isl::union_set(Ctx, "{unroll[x]}"));
1184 }
1185 
1186 /// Find statements and sub-loops in (possibly nested) sequences.
1187 static void
collectFissionableStmts(isl::schedule_node Node,SmallVectorImpl<isl::schedule_node> & ScheduleStmts)1188 collectFissionableStmts(isl::schedule_node Node,
1189                         SmallVectorImpl<isl::schedule_node> &ScheduleStmts) {
1190   if (isBand(Node) || isLeaf(Node)) {
1191     ScheduleStmts.push_back(Node);
1192     return;
1193   }
1194 
1195   if (Node.has_children()) {
1196     isl::schedule_node C = Node.first_child();
1197     while (true) {
1198       collectFissionableStmts(C, ScheduleStmts);
1199       if (!C.has_next_sibling())
1200         break;
1201       C = C.next_sibling();
1202     }
1203   }
1204 }
1205 
applyMaxFission(isl::schedule_node BandToFission)1206 isl::schedule polly::applyMaxFission(isl::schedule_node BandToFission) {
1207   isl::ctx Ctx = BandToFission.ctx();
1208   BandToFission = removeMark(BandToFission);
1209   isl::schedule_node BandBody = BandToFission.child(0);
1210 
1211   SmallVector<isl::schedule_node> FissionableStmts;
1212   collectFissionableStmts(BandBody, FissionableStmts);
1213   size_t N = FissionableStmts.size();
1214 
1215   // Collect the domain for each of the statements that will get their own loop.
1216   isl::union_set_list DomList = isl::union_set_list(Ctx, N);
1217   for (size_t i = 0; i < N; ++i) {
1218     isl::schedule_node BodyPart = FissionableStmts[i];
1219     DomList = DomList.add(BodyPart.get_domain());
1220   }
1221 
1222   // Apply the fission by copying the entire loop, but inserting a filter for
1223   // the statement domains for each fissioned loop.
1224   isl::schedule_node Fissioned = BandToFission.insert_sequence(DomList);
1225 
1226   return Fissioned.get_schedule();
1227 }
1228 
applyGreedyFusion(isl::schedule Sched,const isl::union_map & Deps)1229 isl::schedule polly::applyGreedyFusion(isl::schedule Sched,
1230                                        const isl::union_map &Deps) {
1231   LLVM_DEBUG(dbgs() << "Greedy loop fusion\n");
1232 
1233   GreedyFusionRewriter Rewriter;
1234   isl::schedule Result = Rewriter.visit(Sched, Deps);
1235   if (!Rewriter.AnyChange) {
1236     LLVM_DEBUG(dbgs() << "Found nothing to fuse\n");
1237     return Sched;
1238   }
1239 
1240   // GreedyFusionRewriter due to working loop-by-loop, bands with multiple loops
1241   // may have been split into multiple bands.
1242   return collapseBands(Result);
1243 }
1244