1 //===------ FlattenAlgo.cpp ------------------------------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Main algorithm of the FlattenSchedulePass. This is a separate file to avoid
11 // the unittest for this requiring linking against LLVM.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "polly/FlattenAlgo.h"
16 #include "llvm/Support/Debug.h"
17 #define DEBUG_TYPE "polly-flatten-algo"
18 
19 using namespace polly;
20 using namespace llvm;
21 
22 namespace {
23 
24 /// Whether a dimension of a set is bounded (lower and upper) by a constant,
25 /// i.e. there are two constants Min and Max, such that every value x of the
26 /// chosen dimensions is Min <= x <= Max.
27 bool isDimBoundedByConstant(IslPtr<isl_set> Set, unsigned dim) {
28   auto ParamDims = isl_set_dim(Set.keep(), isl_dim_param);
29   Set = give(isl_set_project_out(Set.take(), isl_dim_param, 0, ParamDims));
30   Set = give(isl_set_project_out(Set.take(), isl_dim_set, 0, dim));
31   auto SetDims = isl_set_dim(Set.keep(), isl_dim_set);
32   Set = give(isl_set_project_out(Set.take(), isl_dim_set, 1, SetDims - 1));
33   return isl_set_is_bounded(Set.keep());
34 }
35 
36 /// Whether a dimension of a set is (lower and upper) bounded by a constant or
37 /// parameters, i.e. there are two expressions Min_p and Max_p of the parameters
38 /// p, such that every value x of the chosen dimensions is
39 /// Min_p <= x <= Max_p.
40 bool isDimBoundedByParameter(IslPtr<isl_set> Set, unsigned dim) {
41   Set = give(isl_set_project_out(Set.take(), isl_dim_set, 0, dim));
42   auto SetDims = isl_set_dim(Set.keep(), isl_dim_set);
43   Set = give(isl_set_project_out(Set.take(), isl_dim_set, 1, SetDims - 1));
44   return isl_set_is_bounded(Set.keep());
45 }
46 
47 /// Whether BMap's first out-dimension is not a constant.
48 bool isVariableDim(NonowningIslPtr<isl_basic_map> BMap) {
49   auto FixedVal =
50       give(isl_basic_map_plain_get_val_if_fixed(BMap.keep(), isl_dim_out, 0));
51   return !FixedVal || isl_val_is_nan(FixedVal.keep());
52 }
53 
54 /// Whether Map's first out dimension is no constant nor piecewise constant.
55 bool isVariableDim(NonowningIslPtr<isl_map> Map) {
56   return foreachEltWithBreak(Map, [](IslPtr<isl_basic_map> BMap) -> isl_stat {
57     if (isVariableDim(BMap))
58       return isl_stat_error;
59     return isl_stat_ok;
60   });
61 }
62 
63 /// Whether UMap's first out dimension is no (piecewise) constant.
64 bool isVariableDim(NonowningIslPtr<isl_union_map> UMap) {
65   return foreachEltWithBreak(UMap, [](IslPtr<isl_map> Map) -> isl_stat {
66     if (isVariableDim(Map))
67       return isl_stat_error;
68     return isl_stat_ok;
69   });
70 }
71 
72 /// If @p PwAff maps to a constant, return said constant. If @p Max/@p Min, it
73 /// can also be a piecewise constant and it would return the minimum/maximum
74 /// value. Otherwise, return NaN.
75 IslPtr<isl_val> getConstant(IslPtr<isl_pw_aff> PwAff, bool Max, bool Min) {
76   assert(!Max || !Min);
77   IslPtr<isl_val> Result;
78   foreachPieceWithBreak(
79       PwAff, [=, &Result](IslPtr<isl_set> Set, IslPtr<isl_aff> Aff) {
80         if (Result && isl_val_is_nan(Result.keep()))
81           return isl_stat_ok;
82 
83         // TODO: If Min/Max, we can also determine a minimum/maximum value if
84         // Set is constant-bounded.
85         if (!isl_aff_is_cst(Aff.keep())) {
86           Result = give(isl_val_nan(Aff.getCtx()));
87           return isl_stat_error;
88         }
89 
90         auto ThisVal = give(isl_aff_get_constant_val(Aff.keep()));
91         if (!Result) {
92           Result = ThisVal;
93           return isl_stat_ok;
94         }
95 
96         if (isl_val_eq(Result.keep(), ThisVal.keep()))
97           return isl_stat_ok;
98 
99         if (Max && isl_val_gt(ThisVal.keep(), Result.keep())) {
100           Result = ThisVal;
101           return isl_stat_ok;
102         }
103 
104         if (Min && isl_val_lt(ThisVal.keep(), Result.keep())) {
105           Result = ThisVal;
106           return isl_stat_ok;
107         }
108 
109         // Not compatible
110         Result = give(isl_val_nan(Aff.getCtx()));
111         return isl_stat_error;
112       });
113   return Result;
114 }
115 
116 /// Compute @p UPwAff - @p Val.
117 IslPtr<isl_union_pw_aff> subtract(IslPtr<isl_union_pw_aff> UPwAff,
118                                   IslPtr<isl_val> Val) {
119   if (isl_val_is_zero(Val.keep()))
120     return UPwAff;
121 
122   auto Result =
123       give(isl_union_pw_aff_empty(isl_union_pw_aff_get_space(UPwAff.keep())));
124   foreachElt(UPwAff, [=, &Result](IslPtr<isl_pw_aff> PwAff) {
125     auto ValAff = give(isl_pw_aff_val_on_domain(
126         isl_set_universe(isl_space_domain(isl_pw_aff_get_space(PwAff.keep()))),
127         Val.copy()));
128     auto Subtracted = give(isl_pw_aff_sub(PwAff.copy(), ValAff.take()));
129     Result = give(isl_union_pw_aff_union_add(
130         Result.take(), isl_union_pw_aff_from_pw_aff(Subtracted.take())));
131   });
132   return Result;
133 }
134 
135 /// Compute @UPwAff * @p Val.
136 IslPtr<isl_union_pw_aff> multiply(IslPtr<isl_union_pw_aff> UPwAff,
137                                   IslPtr<isl_val> Val) {
138   if (isl_val_is_one(Val.keep()))
139     return UPwAff;
140 
141   auto Result =
142       give(isl_union_pw_aff_empty(isl_union_pw_aff_get_space(UPwAff.keep())));
143   foreachElt(UPwAff, [=, &Result](IslPtr<isl_pw_aff> PwAff) {
144     auto ValAff = give(isl_pw_aff_val_on_domain(
145         isl_set_universe(isl_space_domain(isl_pw_aff_get_space(PwAff.keep()))),
146         Val.copy()));
147     auto Multiplied = give(isl_pw_aff_mul(PwAff.copy(), ValAff.take()));
148     Result = give(isl_union_pw_aff_union_add(
149         Result.take(), isl_union_pw_aff_from_pw_aff(Multiplied.take())));
150   });
151   return Result;
152 }
153 
154 /// Remove @p n dimensions from @p UMap's range, starting at @p first.
155 ///
156 /// It is assumed that all maps in the maps have at least the necessary number
157 /// of out dimensions.
158 IslPtr<isl_union_map> scheduleProjectOut(NonowningIslPtr<isl_union_map> UMap,
159                                          unsigned first, unsigned n) {
160   if (n == 0)
161     return UMap; /* isl_map_project_out would also reset the tuple, which should
162                     have no effect on schedule ranges */
163 
164   auto Result = give(isl_union_map_empty(isl_union_map_get_space(UMap.keep())));
165   foreachElt(UMap, [=, &Result](IslPtr<isl_map> Map) {
166     auto Outprojected =
167         give(isl_map_project_out(Map.take(), isl_dim_out, first, n));
168     Result = give(isl_union_map_add_map(Result.take(), Outprojected.take()));
169   });
170   return Result;
171 }
172 
173 /// Return the number of dimensions in the input map's range.
174 ///
175 /// Because this function takes an isl_union_map, the out dimensions could be
176 /// different. We return the maximum number in this case. However, a different
177 /// number of dimensions is not supported by the other code in this file.
178 size_t scheduleScatterDims(NonowningIslPtr<isl_union_map> Schedule) {
179   unsigned Dims = 0;
180   foreachElt(Schedule, [&Dims](IslPtr<isl_map> Map) {
181     Dims = std::max(Dims, isl_map_dim(Map.keep(), isl_dim_out));
182   });
183   return Dims;
184 }
185 
186 /// Return the @p pos' range dimension, converted to an isl_union_pw_aff.
187 IslPtr<isl_union_pw_aff> scheduleExtractDimAff(IslPtr<isl_union_map> UMap,
188                                                unsigned pos) {
189   auto SingleUMap =
190       give(isl_union_map_empty(isl_union_map_get_space(UMap.keep())));
191   foreachElt(UMap, [=, &SingleUMap](IslPtr<isl_map> Map) {
192     auto MapDims = isl_map_dim(Map.keep(), isl_dim_out);
193     auto SingleMap = give(isl_map_project_out(Map.take(), isl_dim_out, 0, pos));
194     SingleMap = give(isl_map_project_out(SingleMap.take(), isl_dim_out, 1,
195                                          MapDims - pos - 1));
196     SingleUMap =
197         give(isl_union_map_add_map(SingleUMap.take(), SingleMap.take()));
198   });
199 
200   auto UAff = give(isl_union_pw_multi_aff_from_union_map(SingleUMap.take()));
201   auto FirstMAff =
202       give(isl_multi_union_pw_aff_from_union_pw_multi_aff(UAff.take()));
203   return give(isl_multi_union_pw_aff_get_union_pw_aff(FirstMAff.keep(), 0));
204 }
205 
206 /// Flatten a sequence-like first dimension.
207 ///
208 /// A sequence-like scatter dimension is constant, or at least only small
209 /// variation, typically the result of ordering a sequence of different
210 /// statements. An example would be:
211 ///   { Stmt_A[] -> [0, X, ...]; Stmt_B[] -> [1, Y, ...] }
212 /// to schedule all instances of Stmt_A before any instance of Stmt_B.
213 ///
214 /// To flatten, first begin with an offset of zero. Then determine the lowest
215 /// possible value of the dimension, call it "i" [In the example we start at 0].
216 /// Considering only schedules with that value, consider only instances with
217 /// that value and determine the extent of the next dimension. Let l_X(i) and
218 /// u_X(i) its minimum (lower bound) and maximum (upper bound) value. Add them
219 /// as "Offset + X - l_X(i)" to the new schedule, then add "u_X(i) - l_X(i) + 1"
220 /// to Offset and remove all i-instances from the old schedule. Repeat with the
221 /// remaining lowest value i' until there are no instances in the old schedule
222 /// left.
223 /// The example schedule would be transformed to:
224 ///   { Stmt_X[] -> [X - l_X, ...]; Stmt_B -> [l_X - u_X + 1 + Y - l_Y, ...] }
225 IslPtr<isl_union_map> tryFlattenSequence(IslPtr<isl_union_map> Schedule) {
226   auto IslCtx = Schedule.getCtx();
227   auto ScatterSet =
228       give(isl_set_from_union_set(isl_union_map_range(Schedule.copy())));
229 
230   auto ParamSpace =
231       give(isl_space_params(isl_union_map_get_space(Schedule.keep())));
232   auto Dims = isl_set_dim(ScatterSet.keep(), isl_dim_set);
233   assert(Dims >= 2);
234 
235   // Would cause an infinite loop.
236   if (!isDimBoundedByConstant(ScatterSet, 0)) {
237     DEBUG(dbgs() << "Abort; dimension is not of fixed size\n");
238     return nullptr;
239   }
240 
241   auto AllDomains = give(isl_union_map_domain(Schedule.copy()));
242   auto AllDomainsToNull =
243       give(isl_union_pw_multi_aff_from_domain(AllDomains.take()));
244 
245   auto NewSchedule = give(isl_union_map_empty(ParamSpace.copy()));
246   auto Counter = give(isl_pw_aff_zero_on_domain(isl_local_space_from_space(
247       isl_space_set_from_params(ParamSpace.copy()))));
248 
249   while (!isl_set_is_empty(ScatterSet.keep())) {
250     DEBUG(dbgs() << "Next counter:\n  " << Counter << "\n");
251     DEBUG(dbgs() << "Remaining scatter set:\n  " << ScatterSet << "\n");
252     auto ThisSet =
253         give(isl_set_project_out(ScatterSet.copy(), isl_dim_set, 1, Dims - 1));
254     auto ThisFirst = give(isl_set_lexmin(ThisSet.take()));
255     auto ScatterFirst =
256         give(isl_set_add_dims(ThisFirst.take(), isl_dim_set, Dims - 1));
257 
258     auto SubSchedule = give(isl_union_map_intersect_range(
259         Schedule.copy(), isl_union_set_from_set(ScatterFirst.copy())));
260     SubSchedule = scheduleProjectOut(std::move(SubSchedule), 0, 1);
261     SubSchedule = flattenSchedule(std::move(SubSchedule));
262 
263     auto SubDims = scheduleScatterDims(SubSchedule);
264     auto FirstSubSchedule = scheduleProjectOut(SubSchedule, 1, SubDims - 1);
265     auto FirstScheduleAff = scheduleExtractDimAff(FirstSubSchedule, 0);
266     auto RemainingSubSchedule =
267         scheduleProjectOut(std::move(SubSchedule), 0, 1);
268 
269     auto FirstSubScatter = give(
270         isl_set_from_union_set(isl_union_map_range(FirstSubSchedule.take())));
271     DEBUG(dbgs() << "Next step in sequence is:\n  " << FirstSubScatter << "\n");
272 
273     if (!isDimBoundedByParameter(FirstSubScatter, 0)) {
274       DEBUG(dbgs() << "Abort; sequence step is not bounded\n");
275       return nullptr;
276     }
277 
278     auto FirstSubScatterMap = give(isl_map_from_range(FirstSubScatter.take()));
279 
280     // isl_set_dim_max returns a strange isl_pw_aff with domain tuple_id of
281     // 'none'. It doesn't match with any space including a 0-dimensional
282     // anonymous tuple.
283     // Interesting, one can create such a set using
284     // isl_set_universe(ParamSpace). Bug?
285     auto PartMin = give(isl_map_dim_min(FirstSubScatterMap.copy(), 0));
286     auto PartMax = give(isl_map_dim_max(FirstSubScatterMap.take(), 0));
287     auto One = give(isl_pw_aff_val_on_domain(
288         isl_set_universe(isl_space_set_from_params(ParamSpace.copy())),
289         isl_val_one(IslCtx)));
290     auto PartLen = give(isl_pw_aff_add(
291         isl_pw_aff_add(PartMax.take(), isl_pw_aff_neg(PartMin.copy())),
292         One.take()));
293 
294     auto AllPartMin = give(isl_union_pw_aff_pullback_union_pw_multi_aff(
295         isl_union_pw_aff_from_pw_aff(PartMin.take()), AllDomainsToNull.copy()));
296     auto FirstScheduleAffNormalized =
297         give(isl_union_pw_aff_sub(FirstScheduleAff.take(), AllPartMin.take()));
298     auto AllCounter = give(isl_union_pw_aff_pullback_union_pw_multi_aff(
299         isl_union_pw_aff_from_pw_aff(Counter.copy()), AllDomainsToNull.copy()));
300     auto FirstScheduleAffWithOffset = give(isl_union_pw_aff_add(
301         FirstScheduleAffNormalized.take(), AllCounter.take()));
302 
303     auto ScheduleWithOffset = give(isl_union_map_flat_range_product(
304         isl_union_map_from_union_pw_aff(FirstScheduleAffWithOffset.take()),
305         RemainingSubSchedule.take()));
306     NewSchedule = give(
307         isl_union_map_union(NewSchedule.take(), ScheduleWithOffset.take()));
308 
309     ScatterSet = give(isl_set_subtract(ScatterSet.take(), ScatterFirst.take()));
310     Counter = give(isl_pw_aff_add(Counter.take(), PartLen.take()));
311   }
312 
313   DEBUG(dbgs() << "Sequence-flatten result is:\n  " << NewSchedule << "\n");
314   return NewSchedule;
315 }
316 
317 /// Flatten a loop-like first dimension.
318 ///
319 /// A loop-like dimension is one that depends on a variable (usually a loop's
320 /// induction variable). Let the input schedule look like this:
321 ///   { Stmt[i] -> [i, X, ...] }
322 ///
323 /// To flatten, we determine the largest extent of X which may not depend on the
324 /// actual value of i. Let l_X() the smallest possible value of X and u_X() its
325 /// largest value. Then, construct a new schedule
326 ///   { Stmt[i] -> [i * (u_X() - l_X() + 1), ...] }
327 IslPtr<isl_union_map> tryFlattenLoop(IslPtr<isl_union_map> Schedule) {
328   assert(scheduleScatterDims(Schedule) >= 2);
329 
330   auto Remaining = scheduleProjectOut(Schedule, 0, 1);
331   auto SubSchedule = flattenSchedule(Remaining);
332   auto SubDims = scheduleScatterDims(SubSchedule);
333 
334   auto SubExtent =
335       give(isl_set_from_union_set(isl_union_map_range(SubSchedule.copy())));
336   auto SubExtentDims = isl_set_dim(SubExtent.keep(), isl_dim_param);
337   SubExtent = give(
338       isl_set_project_out(SubExtent.take(), isl_dim_param, 0, SubExtentDims));
339   SubExtent =
340       give(isl_set_project_out(SubExtent.take(), isl_dim_set, 1, SubDims - 1));
341 
342   if (!isDimBoundedByConstant(SubExtent, 0)) {
343     DEBUG(dbgs() << "Abort; dimension not bounded by constant\n");
344     return nullptr;
345   }
346 
347   auto Min = give(isl_set_dim_min(SubExtent.copy(), 0));
348   DEBUG(dbgs() << "Min bound:\n  " << Min << "\n");
349   auto MinVal = getConstant(Min, false, true);
350   auto Max = give(isl_set_dim_max(SubExtent.take(), 0));
351   DEBUG(dbgs() << "Max bound:\n  " << Max << "\n");
352   auto MaxVal = getConstant(Max, true, false);
353 
354   if (!MinVal || !MaxVal || isl_val_is_nan(MinVal.keep()) ||
355       isl_val_is_nan(MaxVal.keep())) {
356     DEBUG(dbgs() << "Abort; dimension bounds could not be determined\n");
357     return nullptr;
358   }
359 
360   auto FirstSubScheduleAff = scheduleExtractDimAff(SubSchedule, 0);
361   auto RemainingSubSchedule = scheduleProjectOut(std::move(SubSchedule), 0, 1);
362 
363   auto LenVal =
364       give(isl_val_add_ui(isl_val_sub(MaxVal.take(), MinVal.copy()), 1));
365   auto FirstSubScheduleNormalized = subtract(FirstSubScheduleAff, MinVal);
366 
367   // TODO: Normalize FirstAff to zero (convert to isl_map, determine minimum,
368   // subtract it)
369   auto FirstAff = scheduleExtractDimAff(Schedule, 0);
370   auto Offset = multiply(FirstAff, LenVal);
371   auto Index = give(
372       isl_union_pw_aff_add(FirstSubScheduleNormalized.take(), Offset.take()));
373   auto IndexMap = give(isl_union_map_from_union_pw_aff(Index.take()));
374 
375   auto Result = give(isl_union_map_flat_range_product(
376       IndexMap.take(), RemainingSubSchedule.take()));
377   DEBUG(dbgs() << "Loop-flatten result is:\n  " << Result << "\n");
378   return Result;
379 }
380 } // anonymous namespace
381 
382 IslPtr<isl_union_map> polly::flattenSchedule(IslPtr<isl_union_map> Schedule) {
383   auto Dims = scheduleScatterDims(Schedule);
384   DEBUG(dbgs() << "Recursive schedule to process:\n  " << Schedule << "\n");
385 
386   // Base case; no dimensions left
387   if (Dims == 0) {
388     // TODO: Add one dimension?
389     return Schedule;
390   }
391 
392   // Base case; already one-dimensional
393   if (Dims == 1)
394     return Schedule;
395 
396   // Fixed dimension; no need to preserve variabledness.
397   if (!isVariableDim(Schedule)) {
398     DEBUG(dbgs() << "Fixed dimension; try sequence flattening\n");
399     auto NewScheduleSequence = tryFlattenSequence(Schedule);
400     if (NewScheduleSequence)
401       return NewScheduleSequence;
402   }
403 
404   // Constant stride
405   DEBUG(dbgs() << "Try loop flattening\n");
406   auto NewScheduleLoop = tryFlattenLoop(Schedule);
407   if (NewScheduleLoop)
408     return NewScheduleLoop;
409 
410   // Try again without loop condition (may blow up the number of pieces!!)
411   DEBUG(dbgs() << "Try sequence flattening again\n");
412   auto NewScheduleSequence = tryFlattenSequence(Schedule);
413   if (NewScheduleSequence)
414     return NewScheduleSequence;
415 
416   // Cannot flatten
417   return Schedule;
418 }
419