1 //===------ FlattenAlgo.cpp ------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Main algorithm of the FlattenSchedulePass. This is a separate file to avoid
10 // the unittest for this requiring linking against LLVM.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "polly/FlattenAlgo.h"
15 #include "polly/Support/ISLOStream.h"
16 #include "polly/Support/ISLTools.h"
17 #include "llvm/Support/Debug.h"
18 #define DEBUG_TYPE "polly-flatten-algo"
19 
20 using namespace polly;
21 using namespace llvm;
22 
23 namespace {
24 
25 /// Whether a dimension of a set is bounded (lower and upper) by a constant,
26 /// i.e. there are two constants Min and Max, such that every value x of the
27 /// chosen dimensions is Min <= x <= Max.
isDimBoundedByConstant(isl::set Set,unsigned dim)28 bool isDimBoundedByConstant(isl::set Set, unsigned dim) {
29   auto ParamDims = unsignedFromIslSize(Set.dim(isl::dim::param));
30   Set = Set.project_out(isl::dim::param, 0, ParamDims);
31   Set = Set.project_out(isl::dim::set, 0, dim);
32   auto SetDims = unsignedFromIslSize(Set.tuple_dim());
33   assert(SetDims >= 1);
34   Set = Set.project_out(isl::dim::set, 1, SetDims - 1);
35   return bool(Set.is_bounded());
36 }
37 
38 /// Whether a dimension of a set is (lower and upper) bounded by a constant or
39 /// parameters, i.e. there are two expressions Min_p and Max_p of the parameters
40 /// p, such that every value x of the chosen dimensions is
41 /// Min_p <= x <= Max_p.
isDimBoundedByParameter(isl::set Set,unsigned dim)42 bool isDimBoundedByParameter(isl::set Set, unsigned dim) {
43   Set = Set.project_out(isl::dim::set, 0, dim);
44   auto SetDims = unsignedFromIslSize(Set.tuple_dim());
45   assert(SetDims >= 1);
46   Set = Set.project_out(isl::dim::set, 1, SetDims - 1);
47   return bool(Set.is_bounded());
48 }
49 
50 /// Whether BMap's first out-dimension is not a constant.
isVariableDim(const isl::basic_map & BMap)51 bool isVariableDim(const isl::basic_map &BMap) {
52   auto FixedVal = BMap.plain_get_val_if_fixed(isl::dim::out, 0);
53   return FixedVal.is_null() || FixedVal.is_nan();
54 }
55 
56 /// Whether Map's first out dimension is no constant nor piecewise constant.
isVariableDim(const isl::map & Map)57 bool isVariableDim(const isl::map &Map) {
58   for (isl::basic_map BMap : Map.get_basic_map_list())
59     if (isVariableDim(BMap))
60       return false;
61 
62   return true;
63 }
64 
65 /// Whether UMap's first out dimension is no (piecewise) constant.
isVariableDim(const isl::union_map & UMap)66 bool isVariableDim(const isl::union_map &UMap) {
67   for (isl::map Map : UMap.get_map_list())
68     if (isVariableDim(Map))
69       return false;
70   return true;
71 }
72 
73 /// Compute @p UPwAff - @p Val.
subtract(isl::union_pw_aff UPwAff,isl::val Val)74 isl::union_pw_aff subtract(isl::union_pw_aff UPwAff, isl::val Val) {
75   if (Val.is_zero())
76     return UPwAff;
77 
78   auto Result = isl::union_pw_aff::empty(UPwAff.get_space());
79   isl::stat Stat =
80       UPwAff.foreach_pw_aff([=, &Result](isl::pw_aff PwAff) -> isl::stat {
81         auto ValAff =
82             isl::pw_aff(isl::set::universe(PwAff.get_space().domain()), Val);
83         auto Subtracted = PwAff.sub(ValAff);
84         Result = Result.union_add(isl::union_pw_aff(Subtracted));
85         return isl::stat::ok();
86       });
87   if (Stat.is_error())
88     return {};
89   return Result;
90 }
91 
92 /// Compute @UPwAff * @p Val.
multiply(isl::union_pw_aff UPwAff,isl::val Val)93 isl::union_pw_aff multiply(isl::union_pw_aff UPwAff, isl::val Val) {
94   if (Val.is_one())
95     return UPwAff;
96 
97   auto Result = isl::union_pw_aff::empty(UPwAff.get_space());
98   isl::stat Stat =
99       UPwAff.foreach_pw_aff([=, &Result](isl::pw_aff PwAff) -> isl::stat {
100         auto ValAff =
101             isl::pw_aff(isl::set::universe(PwAff.get_space().domain()), Val);
102         auto Multiplied = PwAff.mul(ValAff);
103         Result = Result.union_add(Multiplied);
104         return isl::stat::ok();
105       });
106   if (Stat.is_error())
107     return {};
108   return Result;
109 }
110 
111 /// Remove @p n dimensions from @p UMap's range, starting at @p first.
112 ///
113 /// It is assumed that all maps in the maps have at least the necessary number
114 /// of out dimensions.
scheduleProjectOut(const isl::union_map & UMap,unsigned first,unsigned n)115 isl::union_map scheduleProjectOut(const isl::union_map &UMap, unsigned first,
116                                   unsigned n) {
117   if (n == 0)
118     return UMap; /* isl_map_project_out would also reset the tuple, which should
119                     have no effect on schedule ranges */
120 
121   auto Result = isl::union_map::empty(UMap.ctx());
122   for (isl::map Map : UMap.get_map_list()) {
123     auto Outprojected = Map.project_out(isl::dim::out, first, n);
124     Result = Result.unite(Outprojected);
125   }
126   return Result;
127 }
128 
129 /// Return the @p pos' range dimension, converted to an isl_union_pw_aff.
scheduleExtractDimAff(isl::union_map UMap,unsigned pos)130 isl::union_pw_aff scheduleExtractDimAff(isl::union_map UMap, unsigned pos) {
131   auto SingleUMap = isl::union_map::empty(UMap.ctx());
132   for (isl::map Map : UMap.get_map_list()) {
133     unsigned MapDims = unsignedFromIslSize(Map.range_tuple_dim());
134     assert(MapDims > pos);
135     isl::map SingleMap = Map.project_out(isl::dim::out, 0, pos);
136     SingleMap = SingleMap.project_out(isl::dim::out, 1, MapDims - pos - 1);
137     SingleUMap = SingleUMap.unite(SingleMap);
138   };
139 
140   auto UAff = isl::union_pw_multi_aff(SingleUMap);
141   auto FirstMAff = isl::multi_union_pw_aff(UAff);
142   return FirstMAff.at(0);
143 }
144 
145 /// Flatten a sequence-like first dimension.
146 ///
147 /// A sequence-like scatter dimension is constant, or at least only small
148 /// variation, typically the result of ordering a sequence of different
149 /// statements. An example would be:
150 ///   { Stmt_A[] -> [0, X, ...]; Stmt_B[] -> [1, Y, ...] }
151 /// to schedule all instances of Stmt_A before any instance of Stmt_B.
152 ///
153 /// To flatten, first begin with an offset of zero. Then determine the lowest
154 /// possible value of the dimension, call it "i" [In the example we start at 0].
155 /// Considering only schedules with that value, consider only instances with
156 /// that value and determine the extent of the next dimension. Let l_X(i) and
157 /// u_X(i) its minimum (lower bound) and maximum (upper bound) value. Add them
158 /// as "Offset + X - l_X(i)" to the new schedule, then add "u_X(i) - l_X(i) + 1"
159 /// to Offset and remove all i-instances from the old schedule. Repeat with the
160 /// remaining lowest value i' until there are no instances in the old schedule
161 /// left.
162 /// The example schedule would be transformed to:
163 ///   { Stmt_X[] -> [X - l_X, ...]; Stmt_B -> [l_X - u_X + 1 + Y - l_Y, ...] }
tryFlattenSequence(isl::union_map Schedule)164 isl::union_map tryFlattenSequence(isl::union_map Schedule) {
165   auto IslCtx = Schedule.ctx();
166   auto ScatterSet = isl::set(Schedule.range());
167 
168   auto ParamSpace = Schedule.get_space().params();
169   auto Dims = unsignedFromIslSize(ScatterSet.tuple_dim());
170   assert(Dims >= 2u);
171 
172   // Would cause an infinite loop.
173   if (!isDimBoundedByConstant(ScatterSet, 0)) {
174     LLVM_DEBUG(dbgs() << "Abort; dimension is not of fixed size\n");
175     return {};
176   }
177 
178   auto AllDomains = Schedule.domain();
179   auto AllDomainsToNull = isl::union_pw_multi_aff(AllDomains);
180 
181   auto NewSchedule = isl::union_map::empty(ParamSpace.ctx());
182   auto Counter = isl::pw_aff(isl::local_space(ParamSpace.set_from_params()));
183 
184   while (!ScatterSet.is_empty()) {
185     LLVM_DEBUG(dbgs() << "Next counter:\n  " << Counter << "\n");
186     LLVM_DEBUG(dbgs() << "Remaining scatter set:\n  " << ScatterSet << "\n");
187     auto ThisSet = ScatterSet.project_out(isl::dim::set, 1, Dims - 1);
188     auto ThisFirst = ThisSet.lexmin();
189     auto ScatterFirst = ThisFirst.add_dims(isl::dim::set, Dims - 1);
190 
191     auto SubSchedule = Schedule.intersect_range(ScatterFirst);
192     SubSchedule = scheduleProjectOut(SubSchedule, 0, 1);
193     SubSchedule = flattenSchedule(SubSchedule);
194 
195     unsigned SubDims = getNumScatterDims(SubSchedule);
196     assert(SubDims >= 1);
197     auto FirstSubSchedule = scheduleProjectOut(SubSchedule, 1, SubDims - 1);
198     auto FirstScheduleAff = scheduleExtractDimAff(FirstSubSchedule, 0);
199     auto RemainingSubSchedule = scheduleProjectOut(SubSchedule, 0, 1);
200 
201     auto FirstSubScatter = isl::set(FirstSubSchedule.range());
202     LLVM_DEBUG(dbgs() << "Next step in sequence is:\n  " << FirstSubScatter
203                       << "\n");
204 
205     if (!isDimBoundedByParameter(FirstSubScatter, 0)) {
206       LLVM_DEBUG(dbgs() << "Abort; sequence step is not bounded\n");
207       return {};
208     }
209 
210     auto FirstSubScatterMap = isl::map::from_range(FirstSubScatter);
211 
212     // isl_set_dim_max returns a strange isl_pw_aff with domain tuple_id of
213     // 'none'. It doesn't match with any space including a 0-dimensional
214     // anonymous tuple.
215     // Interesting, one can create such a set using
216     // isl_set_universe(ParamSpace). Bug?
217     auto PartMin = FirstSubScatterMap.dim_min(0);
218     auto PartMax = FirstSubScatterMap.dim_max(0);
219     auto One = isl::pw_aff(isl::set::universe(ParamSpace.set_from_params()),
220                            isl::val::one(IslCtx));
221     auto PartLen = PartMax.add(PartMin.neg()).add(One);
222 
223     auto AllPartMin = isl::union_pw_aff(PartMin).pullback(AllDomainsToNull);
224     auto FirstScheduleAffNormalized = FirstScheduleAff.sub(AllPartMin);
225     auto AllCounter = isl::union_pw_aff(Counter).pullback(AllDomainsToNull);
226     auto FirstScheduleAffWithOffset =
227         FirstScheduleAffNormalized.add(AllCounter);
228 
229     auto ScheduleWithOffset =
230         isl::union_map::from(
231             isl::union_pw_multi_aff(FirstScheduleAffWithOffset))
232             .flat_range_product(RemainingSubSchedule);
233     NewSchedule = NewSchedule.unite(ScheduleWithOffset);
234 
235     ScatterSet = ScatterSet.subtract(ScatterFirst);
236     Counter = Counter.add(PartLen);
237   }
238 
239   LLVM_DEBUG(dbgs() << "Sequence-flatten result is:\n  " << NewSchedule
240                     << "\n");
241   return NewSchedule;
242 }
243 
244 /// Flatten a loop-like first dimension.
245 ///
246 /// A loop-like dimension is one that depends on a variable (usually a loop's
247 /// induction variable). Let the input schedule look like this:
248 ///   { Stmt[i] -> [i, X, ...] }
249 ///
250 /// To flatten, we determine the largest extent of X which may not depend on the
251 /// actual value of i. Let l_X() the smallest possible value of X and u_X() its
252 /// largest value. Then, construct a new schedule
253 ///   { Stmt[i] -> [i * (u_X() - l_X() + 1), ...] }
tryFlattenLoop(isl::union_map Schedule)254 isl::union_map tryFlattenLoop(isl::union_map Schedule) {
255   assert(getNumScatterDims(Schedule) >= 2);
256 
257   auto Remaining = scheduleProjectOut(Schedule, 0, 1);
258   auto SubSchedule = flattenSchedule(Remaining);
259   unsigned SubDims = getNumScatterDims(SubSchedule);
260 
261   assert(SubDims >= 1);
262 
263   auto SubExtent = isl::set(SubSchedule.range());
264   auto SubExtentDims = unsignedFromIslSize(SubExtent.dim(isl::dim::param));
265   SubExtent = SubExtent.project_out(isl::dim::param, 0, SubExtentDims);
266   SubExtent = SubExtent.project_out(isl::dim::set, 1, SubDims - 1);
267 
268   if (!isDimBoundedByConstant(SubExtent, 0)) {
269     LLVM_DEBUG(dbgs() << "Abort; dimension not bounded by constant\n");
270     return {};
271   }
272 
273   auto Min = SubExtent.dim_min(0);
274   LLVM_DEBUG(dbgs() << "Min bound:\n  " << Min << "\n");
275   auto MinVal = getConstant(Min, false, true);
276   auto Max = SubExtent.dim_max(0);
277   LLVM_DEBUG(dbgs() << "Max bound:\n  " << Max << "\n");
278   auto MaxVal = getConstant(Max, true, false);
279 
280   if (MinVal.is_null() || MaxVal.is_null() || MinVal.is_nan() ||
281       MaxVal.is_nan()) {
282     LLVM_DEBUG(dbgs() << "Abort; dimension bounds could not be determined\n");
283     return {};
284   }
285 
286   auto FirstSubScheduleAff = scheduleExtractDimAff(SubSchedule, 0);
287   auto RemainingSubSchedule = scheduleProjectOut(std::move(SubSchedule), 0, 1);
288 
289   auto LenVal = MaxVal.sub(MinVal).add(1);
290   auto FirstSubScheduleNormalized = subtract(FirstSubScheduleAff, MinVal);
291 
292   // TODO: Normalize FirstAff to zero (convert to isl_map, determine minimum,
293   // subtract it)
294   auto FirstAff = scheduleExtractDimAff(Schedule, 0);
295   auto Offset = multiply(FirstAff, LenVal);
296   isl::union_pw_multi_aff Index = FirstSubScheduleNormalized.add(Offset);
297   auto IndexMap = isl::union_map::from(Index);
298 
299   auto Result = IndexMap.flat_range_product(RemainingSubSchedule);
300   LLVM_DEBUG(dbgs() << "Loop-flatten result is:\n  " << Result << "\n");
301   return Result;
302 }
303 } // anonymous namespace
304 
flattenSchedule(isl::union_map Schedule)305 isl::union_map polly::flattenSchedule(isl::union_map Schedule) {
306   unsigned Dims = getNumScatterDims(Schedule);
307   LLVM_DEBUG(dbgs() << "Recursive schedule to process:\n  " << Schedule
308                     << "\n");
309 
310   // Base case; no dimensions left
311   if (Dims == 0) {
312     // TODO: Add one dimension?
313     return Schedule;
314   }
315 
316   // Base case; already one-dimensional
317   if (Dims == 1)
318     return Schedule;
319 
320   // Fixed dimension; no need to preserve variabledness.
321   if (!isVariableDim(Schedule)) {
322     LLVM_DEBUG(dbgs() << "Fixed dimension; try sequence flattening\n");
323     auto NewScheduleSequence = tryFlattenSequence(Schedule);
324     if (!NewScheduleSequence.is_null())
325       return NewScheduleSequence;
326   }
327 
328   // Constant stride
329   LLVM_DEBUG(dbgs() << "Try loop flattening\n");
330   auto NewScheduleLoop = tryFlattenLoop(Schedule);
331   if (!NewScheduleLoop.is_null())
332     return NewScheduleLoop;
333 
334   // Try again without loop condition (may blow up the number of pieces!!)
335   LLVM_DEBUG(dbgs() << "Try sequence flattening again\n");
336   auto NewScheduleSequence = tryFlattenSequence(Schedule);
337   if (!NewScheduleSequence.is_null())
338     return NewScheduleSequence;
339 
340   // Cannot flatten
341   return Schedule;
342 }
343