1 //===- polly/ScheduleTreeTransform.cpp --------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Make changes to isl's schedule tree data structure.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "polly/ScheduleTreeTransform.h"
14 #include "polly/Support/ISLTools.h"
15 #include "polly/Support/ScopHelper.h"
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/IR/Constants.h"
20 #include "llvm/IR/Metadata.h"
21 #include "llvm/Transforms/Utils/UnrollLoop.h"
22 
23 using namespace polly;
24 using namespace llvm;
25 
26 namespace {
27 /// Recursively visit all nodes of a schedule tree while allowing changes.
28 ///
29 /// The visit methods return an isl::schedule_node that is used to continue
30 /// visiting the tree. Structural changes such as returning a different node
31 /// will confuse the visitor.
32 template <typename Derived, typename... Args>
33 struct ScheduleNodeRewriter
34     : public RecursiveScheduleTreeVisitor<Derived, isl::schedule_node,
35                                           Args...> {
36   Derived &getDerived() { return *static_cast<Derived *>(this); }
37   const Derived &getDerived() const {
38     return *static_cast<const Derived *>(this);
39   }
40 
41   isl::schedule_node visitNode(const isl::schedule_node &Node, Args... args) {
42     if (!Node.has_children())
43       return Node;
44 
45     isl::schedule_node It = Node.first_child();
46     while (true) {
47       It = getDerived().visit(It, std::forward<Args>(args)...);
48       if (!It.has_next_sibling())
49         break;
50       It = It.next_sibling();
51     }
52     return It.parent();
53   }
54 };
55 
56 /// Rewrite a schedule tree by reconstructing it bottom-up.
57 ///
58 /// By default, the original schedule tree is reconstructed. To build a
59 /// different tree, redefine visitor methods in a derived class (CRTP).
60 ///
61 /// Note that AST build options are not applied; Setting the isolate[] option
62 /// makes the schedule tree 'anchored' and cannot be modified afterwards. Hence,
63 /// AST build options must be set after the tree has been constructed.
64 template <typename Derived, typename... Args>
65 struct ScheduleTreeRewriter
66     : public RecursiveScheduleTreeVisitor<Derived, isl::schedule, Args...> {
67   Derived &getDerived() { return *static_cast<Derived *>(this); }
68   const Derived &getDerived() const {
69     return *static_cast<const Derived *>(this);
70   }
71 
72   isl::schedule visitDomain(isl::schedule_node_domain Node, Args... args) {
73     // Every schedule_tree already has a domain node, no need to add one.
74     return getDerived().visit(Node.first_child(), std::forward<Args>(args)...);
75   }
76 
77   isl::schedule visitBand(isl::schedule_node_band Band, Args... args) {
78     isl::multi_union_pw_aff PartialSched =
79         isl::manage(isl_schedule_node_band_get_partial_schedule(Band.get()));
80     isl::schedule NewChild =
81         getDerived().visit(Band.child(0), std::forward<Args>(args)...);
82     isl::schedule_node NewNode =
83         NewChild.insert_partial_schedule(PartialSched).get_root().child(0);
84 
85     // Reapply permutability and coincidence attributes.
86     NewNode = isl::manage(isl_schedule_node_band_set_permutable(
87         NewNode.release(), isl_schedule_node_band_get_permutable(Band.get())));
88     unsigned BandDims = isl_schedule_node_band_n_member(Band.get());
89     for (unsigned i = 0; i < BandDims; i += 1)
90       NewNode = isl::manage(isl_schedule_node_band_member_set_coincident(
91           NewNode.release(), i,
92           isl_schedule_node_band_member_get_coincident(Band.get(), i)));
93 
94     return NewNode.get_schedule();
95   }
96 
97   isl::schedule visitSequence(isl::schedule_node_sequence Sequence,
98                               Args... args) {
99     int NumChildren = isl_schedule_node_n_children(Sequence.get());
100     isl::schedule Result =
101         getDerived().visit(Sequence.child(0), std::forward<Args>(args)...);
102     for (int i = 1; i < NumChildren; i += 1)
103       Result = Result.sequence(
104           getDerived().visit(Sequence.child(i), std::forward<Args>(args)...));
105     return Result;
106   }
107 
108   isl::schedule visitSet(isl::schedule_node_set Set, Args... args) {
109     int NumChildren = isl_schedule_node_n_children(Set.get());
110     isl::schedule Result =
111         getDerived().visit(Set.child(0), std::forward<Args>(args)...);
112     for (int i = 1; i < NumChildren; i += 1)
113       Result = isl::manage(
114           isl_schedule_set(Result.release(),
115                            getDerived()
116                                .visit(Set.child(i), std::forward<Args>(args)...)
117                                .release()));
118     return Result;
119   }
120 
121   isl::schedule visitLeaf(isl::schedule_node_leaf Leaf, Args... args) {
122     return isl::schedule::from_domain(Leaf.get_domain());
123   }
124 
125   isl::schedule visitMark(const isl::schedule_node &Mark, Args... args) {
126 
127     isl::id TheMark = Mark.as<isl::schedule_node_mark>().get_id();
128     isl::schedule_node NewChild =
129         getDerived()
130             .visit(Mark.first_child(), std::forward<Args>(args)...)
131             .get_root()
132             .first_child();
133     return NewChild.insert_mark(TheMark).get_schedule();
134   }
135 
136   isl::schedule visitExtension(isl::schedule_node_extension Extension,
137                                Args... args) {
138     isl::union_map TheExtension =
139         Extension.as<isl::schedule_node_extension>().get_extension();
140     isl::schedule_node NewChild = getDerived()
141                                       .visit(Extension.child(0), args...)
142                                       .get_root()
143                                       .first_child();
144     isl::schedule_node NewExtension =
145         isl::schedule_node::from_extension(TheExtension);
146     return NewChild.graft_before(NewExtension).get_schedule();
147   }
148 
149   isl::schedule visitFilter(isl::schedule_node_filter Filter, Args... args) {
150     isl::union_set FilterDomain =
151         Filter.as<isl::schedule_node_filter>().get_filter();
152     isl::schedule NewSchedule =
153         getDerived().visit(Filter.child(0), std::forward<Args>(args)...);
154     return NewSchedule.intersect_domain(FilterDomain);
155   }
156 
157   isl::schedule visitNode(isl::schedule_node Node, Args... args) {
158     llvm_unreachable("Not implemented");
159   }
160 };
161 
/// Rewrite a schedule tree to an equivalent one without extension nodes.
///
/// Each visit method takes two additional arguments:
///
///  * The new domain of the node, which is the inherited domain plus any
///    domains added by extension nodes.
///
///  * A map of extension domains of all children is returned through the
///    second argument; it is required by band nodes to schedule the additional
///    domains at the same position as the extension node would.
///
struct ExtensionNodeRewriter
    : public ScheduleTreeRewriter<ExtensionNodeRewriter, const isl::union_set &,
                                  isl::union_map &> {
  using BaseTy = ScheduleTreeRewriter<ExtensionNodeRewriter,
                                      const isl::union_set &, isl::union_map &>;
  BaseTy &getBase() { return *this; }
  const BaseTy &getBase() const { return *this; }

  /// Entry point: rewrite @p Schedule starting at its root with the
  /// schedule's own domain.
  ///
  /// After the traversal the collected extension map is asserted to be
  /// non-null and empty.
  isl::schedule visitSchedule(isl::schedule Schedule) {
    isl::union_map Extensions;
    isl::schedule Result =
        visit(Schedule.get_root(), Schedule.get_domain(), Extensions);
    assert(!Extensions.is_null() && Extensions.is_empty());
    return Result;
  }

  /// Rewrite all children with the same @p Domain and chain the results into
  /// a sequence.
  ///
  /// @param Extensions [out] Written by the visit of the first child, then
  ///                   united with the extensions reported by each further
  ///                   child.
  isl::schedule visitSequence(isl::schedule_node_sequence Sequence,
                              const isl::union_set &Domain,
                              isl::union_map &Extensions) {
    int NumChildren = isl_schedule_node_n_children(Sequence.get());
    isl::schedule NewNode = visit(Sequence.first_child(), Domain, Extensions);
    for (int i = 1; i < NumChildren; i += 1) {
      isl::schedule_node OldChild = Sequence.child(i);
      // Each further child reports into its own map, merged below.
      isl::union_map NewChildExtensions;
      isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions);
      NewNode = NewNode.sequence(NewChildNode);
      Extensions = Extensions.unite(NewChildExtensions);
    }
    return NewNode;
  }

  /// Same as visitSequence, but the rewritten children are combined with
  /// isl_schedule_set instead of isl::schedule::sequence.
  isl::schedule visitSet(isl::schedule_node_set Set,
                         const isl::union_set &Domain,
                         isl::union_map &Extensions) {
    int NumChildren = isl_schedule_node_n_children(Set.get());
    isl::schedule NewNode = visit(Set.first_child(), Domain, Extensions);
    for (int i = 1; i < NumChildren; i += 1) {
      isl::schedule_node OldChild = Set.child(i);
      isl::union_map NewChildExtensions;
      isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions);
      NewNode = isl::manage(
          isl_schedule_set(NewNode.release(), NewChildNode.release()));
      Extensions = Extensions.unite(NewChildExtensions);
    }
    return NewNode;
  }

  /// A leaf yields a schedule created from @p Domain (the domain passed in,
  /// not the leaf's own one) and reports an empty extension map.
  isl::schedule visitLeaf(isl::schedule_node_leaf Leaf,
                          const isl::union_set &Domain,
                          isl::union_map &Extensions) {
    Extensions = isl::union_map::empty(Leaf.ctx());
    return isl::schedule::from_domain(Domain);
  }

  /// Rewrite a band so that it also schedules the extensions reported by its
  /// child.
  ///
  /// @param OuterExtensions [out] The parts of the child's extensions that
  ///                        still have input dimensions left after this
  ///                        band's dimensions were removed.
  isl::schedule visitBand(isl::schedule_node_band OldNode,
                          const isl::union_set &Domain,
                          isl::union_map &OuterExtensions) {
    isl::schedule_node OldChild = OldNode.first_child();
    isl::multi_union_pw_aff PartialSched =
        isl::manage(isl_schedule_node_band_get_partial_schedule(OldNode.get()));

    isl::union_map NewChildExtensions;
    isl::schedule NewChild = visit(OldChild, Domain, NewChildExtensions);

    // Add the extensions to the partial schedule.
    OuterExtensions = isl::union_map::empty(NewChildExtensions.ctx());
    isl::union_map NewPartialSchedMap = isl::union_map::from(PartialSched);
    unsigned BandDims = isl_schedule_node_band_n_member(OldNode.get());
    for (isl::map Ext : NewChildExtensions.get_map_list()) {
      // The trailing BandDims input dimensions of the extension belong to this
      // band; the leading OuterDims are left for enclosing bands.
      unsigned ExtDims = Ext.domain_tuple_dim().release();
      assert(ExtDims >= BandDims);
      unsigned OuterDims = ExtDims - BandDims;

      // Drop the outer dimensions and reverse the map so that it goes from
      // the extension's range to this band's dimensions.
      isl::map BandSched =
          Ext.project_out(isl::dim::in, 0, OuterDims).reverse();
      NewPartialSchedMap = NewPartialSchedMap.unite(BandSched);

      // There might be more outer bands that have to schedule the extensions.
      if (OuterDims > 0) {
        isl::map OuterSched =
            Ext.project_out(isl::dim::in, OuterDims, BandDims);
        OuterExtensions = OuterExtensions.unite(OuterSched);
      }
    }
    isl::multi_union_pw_aff NewPartialSchedAsAsMultiUnionPwAff =
        isl::multi_union_pw_aff::from_union_map(NewPartialSchedMap);
    // The inserted band is the only child of the new schedule's root.
    isl::schedule_node NewNode =
        NewChild.insert_partial_schedule(NewPartialSchedAsAsMultiUnionPwAff)
            .get_root()
            .child(0);

    // Reapply permutability and coincidence attributes.
    NewNode = isl::manage(isl_schedule_node_band_set_permutable(
        NewNode.release(),
        isl_schedule_node_band_get_permutable(OldNode.get())));
    for (unsigned i = 0; i < BandDims; i += 1) {
      NewNode = isl::manage(isl_schedule_node_band_member_set_coincident(
          NewNode.release(), i,
          isl_schedule_node_band_member_get_coincident(OldNode.get(), i)));
    }

    return NewNode.get_schedule();
  }

  /// No filter node is created here; the child is rewritten with @p Domain
  /// intersected with the filter set.
  isl::schedule visitFilter(isl::schedule_node_filter Filter,
                            const isl::union_set &Domain,
                            isl::union_map &Extensions) {
    isl::union_set FilterDomain =
        Filter.as<isl::schedule_node_filter>().get_filter();
    isl::union_set NewDomain = Domain.intersect(FilterDomain);

    // A filter is added implicitly if necessary when joining schedule trees.
    return visit(Filter.first_child(), NewDomain, Extensions);
  }

  /// No extension node is created here; the child is rewritten with @p Domain
  /// enlarged by the range of the extension map.
  ///
  /// @param Extensions [out] The child's extensions united with this node's
  ///                   extension map.
  isl::schedule visitExtension(isl::schedule_node_extension Extension,
                               const isl::union_set &Domain,
                               isl::union_map &Extensions) {
    isl::union_map ExtDomain =
        Extension.as<isl::schedule_node_extension>().get_extension();
    isl::union_set NewDomain = Domain.unite(ExtDomain.range());
    isl::union_map ChildExtensions;
    isl::schedule NewChild =
        visit(Extension.first_child(), NewDomain, ChildExtensions);
    Extensions = ChildExtensions.unite(ExtDomain);
    return NewChild;
  }
};
301 
302 /// Collect all AST build options in any schedule tree band.
303 ///
304 /// ScheduleTreeRewriter cannot apply the schedule tree options. This class
305 /// collects these options to apply them later.
306 struct CollectASTBuildOptions
307     : public RecursiveScheduleTreeVisitor<CollectASTBuildOptions> {
308   using BaseTy = RecursiveScheduleTreeVisitor<CollectASTBuildOptions>;
309   BaseTy &getBase() { return *this; }
310   const BaseTy &getBase() const { return *this; }
311 
312   llvm::SmallVector<isl::union_set, 8> ASTBuildOptions;
313 
314   void visitBand(isl::schedule_node_band Band) {
315     ASTBuildOptions.push_back(
316         isl::manage(isl_schedule_node_band_get_ast_build_options(Band.get())));
317     return getBase().visitBand(Band);
318   }
319 };
320 
321 /// Apply AST build options to the bands in a schedule tree.
322 ///
323 /// This rewrites a schedule tree with the AST build options applied. We assume
324 /// that the band nodes are visited in the same order as they were when the
325 /// build options were collected, typically by CollectASTBuildOptions.
326 struct ApplyASTBuildOptions
327     : public ScheduleNodeRewriter<ApplyASTBuildOptions> {
328   using BaseTy = ScheduleNodeRewriter<ApplyASTBuildOptions>;
329   BaseTy &getBase() { return *this; }
330   const BaseTy &getBase() const { return *this; }
331 
332   size_t Pos;
333   llvm::ArrayRef<isl::union_set> ASTBuildOptions;
334 
335   ApplyASTBuildOptions(llvm::ArrayRef<isl::union_set> ASTBuildOptions)
336       : ASTBuildOptions(ASTBuildOptions) {}
337 
338   isl::schedule visitSchedule(isl::schedule Schedule) {
339     Pos = 0;
340     isl::schedule Result = visit(Schedule).get_schedule();
341     assert(Pos == ASTBuildOptions.size() &&
342            "AST build options must match to band nodes");
343     return Result;
344   }
345 
346   isl::schedule_node visitBand(isl::schedule_node_band Band) {
347     isl::schedule_node_band Result =
348         Band.set_ast_build_options(ASTBuildOptions[Pos]);
349     Pos += 1;
350     return getBase().visitBand(Result);
351   }
352 };
353 
354 /// Return whether the schedule contains an extension node.
355 static bool containsExtensionNode(isl::schedule Schedule) {
356   assert(!Schedule.is_null());
357 
358   auto Callback = [](__isl_keep isl_schedule_node *Node,
359                      void *User) -> isl_bool {
360     if (isl_schedule_node_get_type(Node) == isl_schedule_node_extension) {
361       // Stop walking the schedule tree.
362       return isl_bool_error;
363     }
364 
365     // Continue searching the subtree.
366     return isl_bool_true;
367   };
368   isl_stat RetVal = isl_schedule_foreach_schedule_node_top_down(
369       Schedule.get(), Callback, nullptr);
370 
371   // We assume that the traversal itself does not fail, i.e. the only reason to
372   // return isl_stat_error is that an extension node was found.
373   return RetVal == isl_stat_error;
374 }
375 
376 /// Find a named MDNode property in a LoopID.
377 static MDNode *findOptionalNodeOperand(MDNode *LoopMD, StringRef Name) {
378   return dyn_cast_or_null<MDNode>(
379       findMetadataOperand(LoopMD, Name).getValueOr(nullptr));
380 }
381 
382 /// Is this node of type mark?
383 static bool isMark(const isl::schedule_node &Node) {
384   return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_mark;
385 }
386 
387 #ifndef NDEBUG
388 /// Is this node of type band?
389 static bool isBand(const isl::schedule_node &Node) {
390   return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band;
391 }
392 
393 /// Is this node a band of a single dimension (i.e. could represent a loop)?
394 static bool isBandWithSingleLoop(const isl::schedule_node &Node) {
395 
396   return isBand(Node) && isl_schedule_node_band_n_member(Node.get()) == 1;
397 }
398 #endif
399 
400 static bool isLeaf(const isl::schedule_node &Node) {
401   return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_leaf;
402 }
403 
404 /// Create an isl::id representing the output loop after a transformation.
405 static isl::id createGeneratedLoopAttr(isl::ctx Ctx, MDNode *FollowupLoopMD) {
406   // Don't need to id the followup.
407   // TODO: Append llvm.loop.disable_heustistics metadata unless overridden by
408   //       user followup-MD
409   if (!FollowupLoopMD)
410     return {};
411 
412   BandAttr *Attr = new BandAttr();
413   Attr->Metadata = FollowupLoopMD;
414   return getIslLoopAttr(Ctx, Attr);
415 }
416 
417 /// A loop consists of a band and an optional marker that wraps it. Return the
418 /// outermost of the two.
419 
420 /// That is, either the mark or, if there is not mark, the loop itself. Can
421 /// start with either the mark or the band.
422 static isl::schedule_node moveToBandMark(isl::schedule_node BandOrMark) {
423   if (isBandMark(BandOrMark)) {
424     assert(isBandWithSingleLoop(BandOrMark.child(0)));
425     return BandOrMark;
426   }
427   assert(isBandWithSingleLoop(BandOrMark));
428 
429   isl::schedule_node Mark = BandOrMark.parent();
430   if (isBandMark(Mark))
431     return Mark;
432 
433   // Band has no loop marker.
434   return BandOrMark;
435 }
436 
437 static isl::schedule_node removeMark(isl::schedule_node MarkOrBand,
438                                      BandAttr *&Attr) {
439   MarkOrBand = moveToBandMark(MarkOrBand);
440 
441   isl::schedule_node Band;
442   if (isMark(MarkOrBand)) {
443     Attr = getLoopAttr(MarkOrBand.as<isl::schedule_node_mark>().get_id());
444     Band = isl::manage(isl_schedule_node_delete(MarkOrBand.release()));
445   } else {
446     Attr = nullptr;
447     Band = MarkOrBand;
448   }
449 
450   assert(isBandWithSingleLoop(Band));
451   return Band;
452 }
453 
454 /// Remove the mark that wraps a loop. Return the band representing the loop.
455 static isl::schedule_node removeMark(isl::schedule_node MarkOrBand) {
456   BandAttr *Attr;
457   return removeMark(MarkOrBand, Attr);
458 }
459 
460 static isl::schedule_node insertMark(isl::schedule_node Band, isl::id Mark) {
461   assert(isBand(Band));
462   assert(moveToBandMark(Band).is_equal(Band) &&
463          "Don't add a two marks for a band");
464 
465   return Band.insert_mark(Mark).child(0);
466 }
467 
468 /// Return the (one-dimensional) set of numbers that are divisible by @p Factor
469 /// with remainder @p Offset.
470 ///
471 ///  isDivisibleBySet(Ctx, 4, 0) = { [i] : floord(i,4) = 0 }
472 ///  isDivisibleBySet(Ctx, 4, 1) = { [i] : floord(i,4) = 1 }
473 ///
474 static isl::basic_set isDivisibleBySet(isl::ctx &Ctx, long Factor,
475                                        long Offset) {
476   isl::val ValFactor{Ctx, Factor};
477   isl::val ValOffset{Ctx, Offset};
478 
479   isl::space Unispace{Ctx, 0, 1};
480   isl::local_space LUnispace{Unispace};
481   isl::aff AffFactor{LUnispace, ValFactor};
482   isl::aff AffOffset{LUnispace, ValOffset};
483 
484   isl::aff Id = isl::aff::var_on_domain(LUnispace, isl::dim::out, 0);
485   isl::aff DivMul = Id.mod(ValFactor);
486   isl::basic_map Divisible = isl::basic_map::from_aff(DivMul);
487   isl::basic_map Modulo = Divisible.fix_val(isl::dim::out, 0, ValOffset);
488   return Modulo.domain();
489 }
490 
491 /// Make the last dimension of Set to take values from 0 to VectorWidth - 1.
492 ///
493 /// @param Set         A set, which should be modified.
494 /// @param VectorWidth A parameter, which determines the constraint.
495 static isl::set addExtentConstraints(isl::set Set, int VectorWidth) {
496   unsigned Dims = Set.tuple_dim().release();
497   isl::space Space = Set.get_space();
498   isl::local_space LocalSpace = isl::local_space(Space);
499   isl::constraint ExtConstr = isl::constraint::alloc_inequality(LocalSpace);
500   ExtConstr = ExtConstr.set_constant_si(0);
501   ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, 1);
502   Set = Set.add_constraint(ExtConstr);
503   ExtConstr = isl::constraint::alloc_inequality(LocalSpace);
504   ExtConstr = ExtConstr.set_constant_si(VectorWidth - 1);
505   ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, -1);
506   return Set.add_constraint(ExtConstr);
507 }
508 } // namespace
509 
510 bool polly::isBandMark(const isl::schedule_node &Node) {
511   return isMark(Node) &&
512          isLoopAttr(Node.as<isl::schedule_node_mark>().get_id());
513 }
514 
515 BandAttr *polly::getBandAttr(isl::schedule_node MarkOrBand) {
516   MarkOrBand = moveToBandMark(MarkOrBand);
517   if (!isMark(MarkOrBand))
518     return nullptr;
519 
520   return getLoopAttr(MarkOrBand.as<isl::schedule_node_mark>().get_id());
521 }
522 
523 isl::schedule polly::hoistExtensionNodes(isl::schedule Sched) {
524   // If there is no extension node in the first place, return the original
525   // schedule tree.
526   if (!containsExtensionNode(Sched))
527     return Sched;
528 
529   // Build options can anchor schedule nodes, such that the schedule tree cannot
530   // be modified anymore. Therefore, apply build options after the tree has been
531   // created.
532   CollectASTBuildOptions Collector;
533   Collector.visit(Sched);
534 
535   // Rewrite the schedule tree without extension nodes.
536   ExtensionNodeRewriter Rewriter;
537   isl::schedule NewSched = Rewriter.visitSchedule(Sched);
538 
539   // Reapply the AST build options. The rewriter must not change the iteration
540   // order of bands. Any other node type is ignored.
541   ApplyASTBuildOptions Applicator(Collector.ASTBuildOptions);
542   NewSched = Applicator.visitSchedule(NewSched);
543 
544   return NewSched;
545 }
546 
547 isl::schedule polly::applyFullUnroll(isl::schedule_node BandToUnroll) {
548   isl::ctx Ctx = BandToUnroll.ctx();
549 
550   // Remove the loop's mark, the loop will disappear anyway.
551   BandToUnroll = removeMark(BandToUnroll);
552   assert(isBandWithSingleLoop(BandToUnroll));
553 
554   isl::multi_union_pw_aff PartialSched = isl::manage(
555       isl_schedule_node_band_get_partial_schedule(BandToUnroll.get()));
556   assert(PartialSched.dim(isl::dim::out).release() == 1 &&
557          "Can only unroll a single dimension");
558   isl::union_pw_aff PartialSchedUAff = PartialSched.at(0);
559 
560   isl::union_set Domain = BandToUnroll.get_domain();
561   PartialSchedUAff = PartialSchedUAff.intersect_domain(Domain);
562   isl::union_map PartialSchedUMap =
563       isl::union_map::from(isl::union_pw_multi_aff(PartialSchedUAff));
564 
565   // Enumerator only the scatter elements.
566   isl::union_set ScatterList = PartialSchedUMap.range();
567 
568   // Enumerate all loop iterations.
569   // TODO: Diagnose if not enumerable or depends on a parameter.
570   SmallVector<isl::point, 16> Elts;
571   ScatterList.foreach_point([&Elts](isl::point P) -> isl::stat {
572     Elts.push_back(P);
573     return isl::stat::ok();
574   });
575 
576   // Don't assume that foreach_point returns in execution order.
577   llvm::sort(Elts, [](isl::point P1, isl::point P2) -> bool {
578     isl::val C1 = P1.get_coordinate_val(isl::dim::set, 0);
579     isl::val C2 = P2.get_coordinate_val(isl::dim::set, 0);
580     return C1.lt(C2);
581   });
582 
583   // Convert the points to a sequence of filters.
584   isl::union_set_list List = isl::union_set_list(Ctx, Elts.size());
585   for (isl::point P : Elts) {
586     // Determine the domains that map this scatter element.
587     isl::union_set DomainFilter = PartialSchedUMap.intersect_range(P).domain();
588 
589     List = List.add(DomainFilter);
590   }
591 
592   // Replace original band with unrolled sequence.
593   isl::schedule_node Body =
594       isl::manage(isl_schedule_node_delete(BandToUnroll.release()));
595   Body = Body.insert_sequence(List);
596   return Body.get_schedule();
597 }
598 
599 isl::schedule polly::applyPartialUnroll(isl::schedule_node BandToUnroll,
600                                         int Factor) {
601   assert(Factor > 0 && "Positive unroll factor required");
602   isl::ctx Ctx = BandToUnroll.ctx();
603 
604   // Remove the mark, save the attribute for later use.
605   BandAttr *Attr;
606   BandToUnroll = removeMark(BandToUnroll, Attr);
607   assert(isBandWithSingleLoop(BandToUnroll));
608 
609   isl::multi_union_pw_aff PartialSched = isl::manage(
610       isl_schedule_node_band_get_partial_schedule(BandToUnroll.get()));
611 
612   // { Stmt[] -> [x] }
613   isl::union_pw_aff PartialSchedUAff = PartialSched.at(0);
614 
615   // Here we assume the schedule stride is one and starts with 0, which is not
616   // necessarily the case.
617   isl::union_pw_aff StridedPartialSchedUAff =
618       isl::union_pw_aff::empty(PartialSchedUAff.get_space());
619   isl::val ValFactor{Ctx, Factor};
620   PartialSchedUAff.foreach_pw_aff([&StridedPartialSchedUAff,
621                                    &ValFactor](isl::pw_aff PwAff) -> isl::stat {
622     isl::space Space = PwAff.get_space();
623     isl::set Universe = isl::set::universe(Space.domain());
624     isl::pw_aff AffFactor{Universe, ValFactor};
625     isl::pw_aff DivSchedAff = PwAff.div(AffFactor).floor().mul(AffFactor);
626     StridedPartialSchedUAff = StridedPartialSchedUAff.union_add(DivSchedAff);
627     return isl::stat::ok();
628   });
629 
630   isl::union_set_list List = isl::union_set_list(Ctx, Factor);
631   for (auto i : seq<int>(0, Factor)) {
632     // { Stmt[] -> [x] }
633     isl::union_map UMap =
634         isl::union_map::from(isl::union_pw_multi_aff(PartialSchedUAff));
635 
636     // { [x] }
637     isl::basic_set Divisible = isDivisibleBySet(Ctx, Factor, i);
638 
639     // { Stmt[] }
640     isl::union_set UnrolledDomain = UMap.intersect_range(Divisible).domain();
641 
642     List = List.add(UnrolledDomain);
643   }
644 
645   isl::schedule_node Body =
646       isl::manage(isl_schedule_node_delete(BandToUnroll.copy()));
647   Body = Body.insert_sequence(List);
648   isl::schedule_node NewLoop =
649       Body.insert_partial_schedule(StridedPartialSchedUAff);
650 
651   MDNode *FollowupMD = nullptr;
652   if (Attr && Attr->Metadata)
653     FollowupMD =
654         findOptionalNodeOperand(Attr->Metadata, LLVMLoopUnrollFollowupUnrolled);
655 
656   isl::id NewBandId = createGeneratedLoopAttr(Ctx, FollowupMD);
657   if (!NewBandId.is_null())
658     NewLoop = insertMark(NewLoop, NewBandId);
659 
660   return NewLoop.get_schedule();
661 }
662 
663 isl::set polly::getPartialTilePrefixes(isl::set ScheduleRange,
664                                        int VectorWidth) {
665   isl_size Dims = ScheduleRange.tuple_dim().release();
666   isl::set LoopPrefixes =
667       ScheduleRange.drop_constraints_involving_dims(isl::dim::set, Dims - 1, 1);
668   auto ExtentPrefixes = addExtentConstraints(LoopPrefixes, VectorWidth);
669   isl::set BadPrefixes = ExtentPrefixes.subtract(ScheduleRange);
670   BadPrefixes = BadPrefixes.project_out(isl::dim::set, Dims - 1, 1);
671   LoopPrefixes = LoopPrefixes.project_out(isl::dim::set, Dims - 1, 1);
672   return LoopPrefixes.subtract(BadPrefixes);
673 }
674 
675 isl::union_set polly::getIsolateOptions(isl::set IsolateDomain,
676                                         isl_size OutDimsNum) {
677   isl_size Dims = IsolateDomain.tuple_dim().release();
678   assert(OutDimsNum <= Dims &&
679          "The isl::set IsolateDomain is used to describe the range of schedule "
680          "dimensions values, which should be isolated. Consequently, the "
681          "number of its dimensions should be greater than or equal to the "
682          "number of the schedule dimensions.");
683   isl::map IsolateRelation = isl::map::from_domain(IsolateDomain);
684   IsolateRelation = IsolateRelation.move_dims(isl::dim::out, 0, isl::dim::in,
685                                               Dims - OutDimsNum, OutDimsNum);
686   isl::set IsolateOption = IsolateRelation.wrap();
687   isl::id Id = isl::id::alloc(IsolateOption.ctx(), "isolate", nullptr);
688   IsolateOption = IsolateOption.set_tuple_id(Id);
689   return isl::union_set(IsolateOption);
690 }
691 
692 isl::union_set polly::getDimOptions(isl::ctx Ctx, const char *Option) {
693   isl::space Space(Ctx, 0, 1);
694   auto DimOption = isl::set::universe(Space);
695   auto Id = isl::id::alloc(Ctx, Option, nullptr);
696   DimOption = DimOption.set_tuple_id(Id);
697   return isl::union_set(DimOption);
698 }
699 
700 isl::schedule_node polly::tileNode(isl::schedule_node Node,
701                                    const char *Identifier,
702                                    ArrayRef<int> TileSizes,
703                                    int DefaultTileSize) {
704   auto Space = isl::manage(isl_schedule_node_band_get_space(Node.get()));
705   auto Dims = Space.dim(isl::dim::set);
706   auto Sizes = isl::multi_val::zero(Space);
707   std::string IdentifierString(Identifier);
708   for (auto i : seq<isl_size>(0, Dims.release())) {
709     auto tileSize =
710         i < (isl_size)TileSizes.size() ? TileSizes[i] : DefaultTileSize;
711     Sizes = Sizes.set_val(i, isl::val(Node.ctx(), tileSize));
712   }
713   auto TileLoopMarkerStr = IdentifierString + " - Tiles";
714   auto TileLoopMarker = isl::id::alloc(Node.ctx(), TileLoopMarkerStr, nullptr);
715   Node = Node.insert_mark(TileLoopMarker);
716   Node = Node.child(0);
717   Node =
718       isl::manage(isl_schedule_node_band_tile(Node.release(), Sizes.release()));
719   Node = Node.child(0);
720   auto PointLoopMarkerStr = IdentifierString + " - Points";
721   auto PointLoopMarker =
722       isl::id::alloc(Node.ctx(), PointLoopMarkerStr, nullptr);
723   Node = Node.insert_mark(PointLoopMarker);
724   return Node.child(0);
725 }
726 
727 isl::schedule_node polly::applyRegisterTiling(isl::schedule_node Node,
728                                               ArrayRef<int> TileSizes,
729                                               int DefaultTileSize) {
730   Node = tileNode(Node, "Register tiling", TileSizes, DefaultTileSize);
731   auto Ctx = Node.ctx();
732   return Node.as<isl::schedule_node_band>().set_ast_build_options(
733       isl::union_set(Ctx, "{unroll[x]}"));
734 }
735 
736 /// Find statements and sub-loops in (possibly nested) sequences.
737 static void
738 collectFussionableStmts(isl::schedule_node Node,
739                         SmallVectorImpl<isl::schedule_node> &ScheduleStmts) {
740   if (isBand(Node) || isLeaf(Node)) {
741     ScheduleStmts.push_back(Node);
742     return;
743   }
744 
745   if (Node.has_children()) {
746     isl::schedule_node C = Node.first_child();
747     while (true) {
748       collectFussionableStmts(C, ScheduleStmts);
749       if (!C.has_next_sibling())
750         break;
751       C = C.next_sibling();
752     }
753   }
754 }
755 
756 isl::schedule polly::applyMaxFission(isl::schedule_node BandToFission) {
757   isl::ctx Ctx = BandToFission.ctx();
758   BandToFission = removeMark(BandToFission);
759   isl::schedule_node BandBody = BandToFission.child(0);
760 
761   SmallVector<isl::schedule_node> FissionableStmts;
762   collectFussionableStmts(BandBody, FissionableStmts);
763   size_t N = FissionableStmts.size();
764 
765   // Collect the domain for each of the statements that will get their own loop.
766   isl::union_set_list DomList = isl::union_set_list(Ctx, N);
767   for (size_t i = 0; i < N; ++i) {
768     isl::schedule_node BodyPart = FissionableStmts[i];
769     DomList = DomList.add(BodyPart.get_domain());
770   }
771 
772   // Apply the fission by copying the entire loop, but inserting a filter for
773   // the statement domains for each fissioned loop.
774   isl::schedule_node Fissioned = BandToFission.insert_sequence(DomList);
775 
776   return Fissioned.get_schedule();
777 }
778