1 //===------ FlattenAlgo.cpp ------------------------------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Main algorithm of the FlattenSchedulePass. This is a separate file to avoid
11 // the unittest for this requiring linking against LLVM.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "polly/FlattenAlgo.h"
16 #include "polly/Support/ISLOStream.h"
17 #include "polly/Support/ISLTools.h"
18 #include "llvm/Support/Debug.h"
19 #define DEBUG_TYPE "polly-flatten-algo"
20 
21 using namespace polly;
22 using namespace llvm;
23 
24 namespace {
25 
26 /// Whether a dimension of a set is bounded (lower and upper) by a constant,
27 /// i.e. there are two constants Min and Max, such that every value x of the
28 /// chosen dimensions is Min <= x <= Max.
29 bool isDimBoundedByConstant(isl::set Set, unsigned dim) {
30   auto ParamDims = Set.dim(isl::dim::param);
31   Set = Set.project_out(isl::dim::param, 0, ParamDims);
32   Set = Set.project_out(isl::dim::set, 0, dim);
33   auto SetDims = Set.dim(isl::dim::set);
34   Set = Set.project_out(isl::dim::set, 1, SetDims - 1);
35   return bool(Set.is_bounded());
36 }
37 
38 /// Whether a dimension of a set is (lower and upper) bounded by a constant or
39 /// parameters, i.e. there are two expressions Min_p and Max_p of the parameters
40 /// p, such that every value x of the chosen dimensions is
41 /// Min_p <= x <= Max_p.
42 bool isDimBoundedByParameter(isl::set Set, unsigned dim) {
43   Set = Set.project_out(isl::dim::set, 0, dim);
44   auto SetDims = Set.dim(isl::dim::set);
45   Set = Set.project_out(isl::dim::set, 1, SetDims - 1);
46   return bool(Set.is_bounded());
47 }
48 
49 /// Whether BMap's first out-dimension is not a constant.
50 bool isVariableDim(const isl::basic_map &BMap) {
51   auto FixedVal = BMap.plain_get_val_if_fixed(isl::dim::out, 0);
52   return !FixedVal || FixedVal.is_nan();
53 }
54 
55 /// Whether Map's first out dimension is no constant nor piecewise constant.
56 bool isVariableDim(const isl::map &Map) {
57   for (isl::basic_map BMap : Map.get_basic_map_list())
58     if (isVariableDim(BMap))
59       return false;
60 
61   return true;
62 }
63 
64 /// Whether UMap's first out dimension is no (piecewise) constant.
65 bool isVariableDim(const isl::union_map &UMap) {
66   for (isl::map Map : UMap.get_map_list())
67     if (isVariableDim(Map))
68       return false;
69   return true;
70 }
71 
72 /// Compute @p UPwAff - @p Val.
73 isl::union_pw_aff subtract(isl::union_pw_aff UPwAff, isl::val Val) {
74   if (Val.is_zero())
75     return UPwAff;
76 
77   auto Result = isl::union_pw_aff::empty(UPwAff.get_space());
78   isl::stat Stat =
79       UPwAff.foreach_pw_aff([=, &Result](isl::pw_aff PwAff) -> isl::stat {
80         auto ValAff =
81             isl::pw_aff(isl::set::universe(PwAff.get_space().domain()), Val);
82         auto Subtracted = PwAff.sub(ValAff);
83         Result = Result.union_add(isl::union_pw_aff(Subtracted));
84         return isl::stat::ok();
85       });
86   if (Stat.is_error())
87     return {};
88   return Result;
89 }
90 
91 /// Compute @UPwAff * @p Val.
92 isl::union_pw_aff multiply(isl::union_pw_aff UPwAff, isl::val Val) {
93   if (Val.is_one())
94     return UPwAff;
95 
96   auto Result = isl::union_pw_aff::empty(UPwAff.get_space());
97   isl::stat Stat =
98       UPwAff.foreach_pw_aff([=, &Result](isl::pw_aff PwAff) -> isl::stat {
99         auto ValAff =
100             isl::pw_aff(isl::set::universe(PwAff.get_space().domain()), Val);
101         auto Multiplied = PwAff.mul(ValAff);
102         Result = Result.union_add(Multiplied);
103         return isl::stat::ok();
104       });
105   if (Stat.is_error())
106     return {};
107   return Result;
108 }
109 
110 /// Remove @p n dimensions from @p UMap's range, starting at @p first.
111 ///
112 /// It is assumed that all maps in the maps have at least the necessary number
113 /// of out dimensions.
114 isl::union_map scheduleProjectOut(const isl::union_map &UMap, unsigned first,
115                                   unsigned n) {
116   if (n == 0)
117     return UMap; /* isl_map_project_out would also reset the tuple, which should
118                     have no effect on schedule ranges */
119 
120   auto Result = isl::union_map::empty(UMap.get_space());
121   for (isl::map Map : UMap.get_map_list()) {
122     auto Outprojected = Map.project_out(isl::dim::out, first, n);
123     Result = Result.add_map(Outprojected);
124   }
125   return Result;
126 }
127 
128 /// Return the number of dimensions in the input map's range.
129 ///
130 /// Because this function takes an isl_union_map, the out dimensions could be
131 /// different. We return the maximum number in this case. However, a different
132 /// number of dimensions is not supported by the other code in this file.
133 size_t scheduleScatterDims(const isl::union_map &Schedule) {
134   unsigned Dims = 0;
135   for (isl::map Map : Schedule.get_map_list())
136     Dims = std::max(Dims, Map.dim(isl::dim::out));
137   return Dims;
138 }
139 
140 /// Return the @p pos' range dimension, converted to an isl_union_pw_aff.
141 isl::union_pw_aff scheduleExtractDimAff(isl::union_map UMap, unsigned pos) {
142   auto SingleUMap = isl::union_map::empty(UMap.get_space());
143   for (isl::map Map : UMap.get_map_list()) {
144     unsigned MapDims = Map.dim(isl::dim::out);
145     isl::map SingleMap = Map.project_out(isl::dim::out, 0, pos);
146     SingleMap = SingleMap.project_out(isl::dim::out, 1, MapDims - pos - 1);
147     SingleUMap = SingleUMap.add_map(SingleMap);
148   };
149 
150   auto UAff = isl::union_pw_multi_aff(SingleUMap);
151   auto FirstMAff = isl::multi_union_pw_aff(UAff);
152   return FirstMAff.get_union_pw_aff(0);
153 }
154 
/// Flatten a sequence-like first dimension.
///
/// A sequence-like scatter dimension is constant, or at least has only small
/// variation; it is typically the result of ordering a sequence of different
/// statements. An example would be:
///   { Stmt_A[] -> [0, X, ...]; Stmt_B[] -> [1, Y, ...] }
/// to schedule all instances of Stmt_A before any instance of Stmt_B.
///
/// To flatten, begin with an offset of zero. Then determine the lowest
/// remaining value of the first dimension, call it "i" [in the example, 0].
/// Considering only the instances scheduled at that value, recursively flatten
/// the remaining dimensions and determine the extent of the next dimension.
/// Let l_X(i) and u_X(i) be its (possibly parametric) minimum (lower bound) and
/// maximum (upper bound) value. Map those instances to "Offset + X - l_X(i)"
/// in the new schedule, then add "u_X(i) - l_X(i) + 1" to Offset and remove all
/// i-instances from the old schedule. Repeat with the remaining lowest value i'
/// until there are no instances in the old schedule left.
/// The example schedule would be transformed to:
///   { Stmt_A[] -> [X - l_X, ...]; Stmt_B[] -> [u_X - l_X + 1 + Y - l_Y, ...] }
///
/// Returns a null union_map if the first dimension is not bounded by constants,
/// or if the next dimension of some step is not bounded by parameters.
isl::union_map tryFlattenSequence(isl::union_map Schedule) {
  auto IslCtx = Schedule.get_ctx();
  // NOTE(review): converting the union range to a single isl::set presumably
  // requires all statements to share one scatter space — confirm.
  auto ScatterSet = isl::set(Schedule.range());

  auto ParamSpace = Schedule.get_space().params();
  auto Dims = ScatterSet.dim(isl::dim::set);
  assert(Dims >= 2);

  // Would cause an infinite loop: the loop below peels off one value of the
  // first dimension per iteration, so that dimension must have finitely many
  // values.
  if (!isDimBoundedByConstant(ScatterSet, 0)) {
    LLVM_DEBUG(dbgs() << "Abort; dimension is not of fixed size\n");
    return nullptr;
  }

  // Presumably maps every statement instance to the zero-dimensional
  // (parameter-only) space; used below to pull parameter-only expressions
  // back onto all statement domains.
  auto AllDomains = Schedule.domain();
  auto AllDomainsToNull = isl::union_pw_multi_aff(AllDomains);

  auto NewSchedule = isl::union_map::empty(ParamSpace);
  // Running offset ("Offset" in the description above), a function of the
  // parameters that starts at zero.
  auto Counter = isl::pw_aff(isl::local_space(ParamSpace.set_from_params()));

  while (!ScatterSet.is_empty()) {
    LLVM_DEBUG(dbgs() << "Next counter:\n  " << Counter << "\n");
    LLVM_DEBUG(dbgs() << "Remaining scatter set:\n  " << ScatterSet << "\n");
    // Smallest remaining value of the first dimension, extended back to a
    // full scatter-space set (all other dimensions unconstrained).
    auto ThisSet = ScatterSet.project_out(isl::dim::set, 1, Dims - 1);
    auto ThisFirst = ThisSet.lexmin();
    auto ScatterFirst = ThisFirst.add_dims(isl::dim::set, Dims - 1);

    // Instances scheduled at that value, without the first dimension,
    // recursively flattened.
    auto SubSchedule = Schedule.intersect_range(ScatterFirst);
    SubSchedule = scheduleProjectOut(SubSchedule, 0, 1);
    SubSchedule = flattenSchedule(SubSchedule);

    // Split into the leading (flattened) dimension and everything after it.
    auto SubDims = scheduleScatterDims(SubSchedule);
    auto FirstSubSchedule = scheduleProjectOut(SubSchedule, 1, SubDims - 1);
    auto FirstScheduleAff = scheduleExtractDimAff(FirstSubSchedule, 0);
    auto RemainingSubSchedule = scheduleProjectOut(SubSchedule, 0, 1);

    auto FirstSubScatter = isl::set(FirstSubSchedule.range());
    LLVM_DEBUG(dbgs() << "Next step in sequence is:\n  " << FirstSubScatter
                      << "\n");

    if (!isDimBoundedByParameter(FirstSubScatter, 0)) {
      LLVM_DEBUG(dbgs() << "Abort; sequence step is not bounded\n");
      return nullptr;
    }

    auto FirstSubScatterMap = isl::map::from_range(FirstSubScatter);

    // isl_set_dim_max returns a strange isl_pw_aff with domain tuple_id of
    // 'none'. It doesn't match with any space including a 0-dimensional
    // anonymous tuple.
    // Interesting, one can create such a set using
    // isl_set_universe(ParamSpace). Bug?
    // (That is why the bounds are taken via a map with an empty domain.)
    auto PartMin = FirstSubScatterMap.dim_min(0);
    auto PartMax = FirstSubScatterMap.dim_max(0);
    auto One = isl::pw_aff(isl::set::universe(ParamSpace.set_from_params()),
                           isl::val::one(IslCtx));
    // Length of this step: u_X(i) - l_X(i) + 1.
    auto PartLen = PartMax.add(PartMin.neg()).add(One);

    // New leading dimension: Counter + X - l_X(i), with the parameter-only
    // terms pulled back onto every statement domain.
    auto AllPartMin = isl::union_pw_aff(PartMin).pullback(AllDomainsToNull);
    auto FirstScheduleAffNormalized = FirstScheduleAff.sub(AllPartMin);
    auto AllCounter = isl::union_pw_aff(Counter).pullback(AllDomainsToNull);
    auto FirstScheduleAffWithOffset =
        FirstScheduleAffNormalized.add(AllCounter);

    auto ScheduleWithOffset = isl::union_map(FirstScheduleAffWithOffset)
                                  .flat_range_product(RemainingSubSchedule);
    NewSchedule = NewSchedule.unite(ScheduleWithOffset);

    // Remove the processed slice and advance the offset past it.
    ScatterSet = ScatterSet.subtract(ScatterFirst);
    Counter = Counter.add(PartLen);
  }

  LLVM_DEBUG(dbgs() << "Sequence-flatten result is:\n  " << NewSchedule
                    << "\n");
  return NewSchedule;
}
250 
/// Flatten a loop-like first dimension.
///
/// A loop-like dimension is one that depends on a variable (usually a loop's
/// induction variable). Let the input schedule look like this:
///   { Stmt[i] -> [i, X, ...] }
///
/// To flatten, first recursively flatten the dimensions after the first one.
/// Then determine the extent of (the flattened) X over all instances, so that
/// it does not depend on the actual value of i. Parameters are projected out,
/// so the extent must be bounded by constants. Let l_X() be the smallest
/// possible value of X and u_X() its largest value. Then, construct a new
/// schedule
///   { Stmt[i] -> [i * (u_X() - l_X() + 1) + X - l_X(), ...] }
///
/// Returns a null union_map if X is not bounded by constants or if its bounds
/// cannot be extracted as constant values.
isl::union_map tryFlattenLoop(isl::union_map Schedule) {
  assert(scheduleScatterDims(Schedule) >= 2);

  // Drop the loop dimension and recursively flatten what remains.
  auto Remaining = scheduleProjectOut(Schedule, 0, 1);
  auto SubSchedule = flattenSchedule(Remaining);
  auto SubDims = scheduleScatterDims(SubSchedule);

  // Values taken by the first flattened sub-dimension over all instances,
  // with the parameters projected out.
  auto SubExtent = isl::set(SubSchedule.range());
  auto SubExtentDims = SubExtent.dim(isl::dim::param);
  SubExtent = SubExtent.project_out(isl::dim::param, 0, SubExtentDims);
  SubExtent = SubExtent.project_out(isl::dim::set, 1, SubDims - 1);

  if (!isDimBoundedByConstant(SubExtent, 0)) {
    LLVM_DEBUG(dbgs() << "Abort; dimension not bounded by constant\n");
    return nullptr;
  }

  // getConstant (ISLTools) presumably selects the minimum/maximum constant of
  // the pw_aff depending on its flags; a null or NaN result is rejected below.
  auto Min = SubExtent.dim_min(0);
  LLVM_DEBUG(dbgs() << "Min bound:\n  " << Min << "\n");
  auto MinVal = getConstant(Min, false, true);
  auto Max = SubExtent.dim_max(0);
  LLVM_DEBUG(dbgs() << "Max bound:\n  " << Max << "\n");
  auto MaxVal = getConstant(Max, true, false);

  if (!MinVal || !MaxVal || MinVal.is_nan() || MaxVal.is_nan()) {
    LLVM_DEBUG(dbgs() << "Abort; dimension bounds could not be determined\n");
    return nullptr;
  }

  auto FirstSubScheduleAff = scheduleExtractDimAff(SubSchedule, 0);
  // NOTE(review): scheduleProjectOut takes a const reference, so this
  // std::move has no effect; SubSchedule is simply not used afterwards.
  auto RemainingSubSchedule = scheduleProjectOut(std::move(SubSchedule), 0, 1);

  // Stride of the loop dimension: u_X - l_X + 1 values of X per iteration.
  auto LenVal = MaxVal.sub(MinVal).add_ui(1);
  // X - l_X, so that each iteration's block starts at offset zero.
  auto FirstSubScheduleNormalized = subtract(FirstSubScheduleAff, MinVal);

  // TODO: Normalize FirstAff to zero (convert to isl_map, determine minimum,
  // subtract it)
  auto FirstAff = scheduleExtractDimAff(Schedule, 0);
  auto Offset = multiply(FirstAff, LenVal);
  auto Index = FirstSubScheduleNormalized.add(Offset);
  auto IndexMap = isl::union_map(Index);

  // Prepend the combined index to the remaining flattened sub-dimensions.
  auto Result = IndexMap.flat_range_product(RemainingSubSchedule);
  LLVM_DEBUG(dbgs() << "Loop-flatten result is:\n  " << Result << "\n");
  return Result;
}
307 } // anonymous namespace
308 
309 isl::union_map polly::flattenSchedule(isl::union_map Schedule) {
310   auto Dims = scheduleScatterDims(Schedule);
311   LLVM_DEBUG(dbgs() << "Recursive schedule to process:\n  " << Schedule
312                     << "\n");
313 
314   // Base case; no dimensions left
315   if (Dims == 0) {
316     // TODO: Add one dimension?
317     return Schedule;
318   }
319 
320   // Base case; already one-dimensional
321   if (Dims == 1)
322     return Schedule;
323 
324   // Fixed dimension; no need to preserve variabledness.
325   if (!isVariableDim(Schedule)) {
326     LLVM_DEBUG(dbgs() << "Fixed dimension; try sequence flattening\n");
327     auto NewScheduleSequence = tryFlattenSequence(Schedule);
328     if (NewScheduleSequence)
329       return NewScheduleSequence;
330   }
331 
332   // Constant stride
333   LLVM_DEBUG(dbgs() << "Try loop flattening\n");
334   auto NewScheduleLoop = tryFlattenLoop(Schedule);
335   if (NewScheduleLoop)
336     return NewScheduleLoop;
337 
338   // Try again without loop condition (may blow up the number of pieces!!)
339   LLVM_DEBUG(dbgs() << "Try sequence flattening again\n");
340   auto NewScheduleSequence = tryFlattenSequence(Schedule);
341   if (NewScheduleSequence)
342     return NewScheduleSequence;
343 
344   // Cannot flatten
345   return Schedule;
346 }
347