1 //===- IslAst.cpp - isl code generator interface --------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
// The isl code generator interface takes a Scop and generates an isl_ast. This
// isl_ast can either be returned directly or it can be pretty printed to an
// output stream.
13 //
14 // A typical isl_ast output looks like this:
15 //
// for (c2 = max(0, ceild(n + m, 2)); c2 <= min(511, floord(5 * n, 3)); c2++) {
17 //   bb2(c2);
18 // }
19 //
20 //===----------------------------------------------------------------------===//
21 
#include "polly/CodeGen/CodeGeneration.h"
#include "polly/CodeGen/IslAst.h"
#include "polly/Dependences.h"
#include "polly/LinkAllPasses.h"
#include "polly/Options.h"
#include "polly/ScopInfo.h"
#include "polly/Support/GICHelper.h"
#include "llvm/Support/Debug.h"

#include "isl/union_map.h"
#include "isl/list.h"
#include "isl/ast_build.h"
#include "isl/set.h"
#include "isl/map.h"
#include "isl/aff.h"

#include <cstdlib>
37 
38 #define DEBUG_TYPE "polly-ast"
39 
40 using namespace llvm;
41 using namespace polly;
42 
43 using IslAstUserPayload = IslAstInfo::IslAstUserPayload;
44 
45 static cl::opt<bool> UseContext("polly-ast-use-context",
46                                 cl::desc("Use context"), cl::Hidden,
47                                 cl::init(false), cl::ZeroOrMore,
48                                 cl::cat(PollyCategory));
49 
50 static cl::opt<bool> DetectParallel("polly-ast-detect-parallel",
51                                     cl::desc("Detect parallelism"), cl::Hidden,
52                                     cl::init(false), cl::ZeroOrMore,
53                                     cl::cat(PollyCategory));
54 
55 namespace polly {
56 class IslAst {
57 public:
58   IslAst(Scop *Scop, Dependences &D);
59 
60   ~IslAst();
61 
62   /// Print a source code representation of the program.
63   void pprint(llvm::raw_ostream &OS);
64 
65   __isl_give isl_ast_node *getAst();
66 
67   /// @brief Get the run-time conditions for the Scop.
68   __isl_give isl_ast_expr *getRunCondition();
69 
70 private:
71   Scop *S;
72   isl_ast_node *Root;
73   isl_ast_expr *RunCondition;
74 
75   void buildRunCondition(__isl_keep isl_ast_build *Build);
76 };
77 } // End namespace polly.
78 
79 /// @brief Free an IslAstUserPayload object pointed to by @p Ptr
80 static void freeIslAstUserPayload(void *Ptr) {
81   delete ((IslAstInfo::IslAstUserPayload *)Ptr);
82 }
83 
84 IslAstInfo::IslAstUserPayload::~IslAstUserPayload() {
85   isl_ast_build_free(Build);
86   isl_pw_aff_free(MinimalDependenceDistance);
87 }
88 
89 /// @brief Temporary information used when building the ast.
90 struct AstBuildUserInfo {
91   /// @brief Construct and initialize the helper struct for AST creation.
92   AstBuildUserInfo()
93       : Deps(nullptr), InParallelFor(false), LastForNodeId(nullptr) {}
94 
95   /// @brief The dependence information used for the parallelism check.
96   Dependences *Deps;
97 
98   /// @brief Flag to indicate that we are inside a parallel for node.
99   bool InParallelFor;
100 
101   /// @brief The last iterator id created for the current SCoP.
102   isl_id *LastForNodeId;
103 };
104 
105 /// @brief Print a string @p str in a single line using @p Printer.
106 static isl_printer *printLine(__isl_take isl_printer *Printer,
107                               const std::string &str,
108                               __isl_keep isl_pw_aff *PWA = nullptr) {
109   Printer = isl_printer_start_line(Printer);
110   Printer = isl_printer_print_str(Printer, str.c_str());
111   if (PWA)
112     Printer = isl_printer_print_pw_aff(Printer, PWA);
113   return isl_printer_end_line(Printer);
114 }
115 
116 /// @brief Return all broken reductions as a string of clauses (OpenMP style).
117 static const std::string getBrokenReductionsStr(__isl_keep isl_ast_node *Node) {
118   IslAstInfo::MemoryAccessSet *BrokenReductions;
119   std::string str;
120 
121   BrokenReductions = IslAstInfo::getBrokenReductions(Node);
122   if (!BrokenReductions || BrokenReductions->empty())
123     return "";
124 
125   // Map each type of reduction to a comma separated list of the base addresses.
126   std::map<MemoryAccess::ReductionType, std::string> Clauses;
127   for (MemoryAccess *MA : *BrokenReductions)
128     if (MA->isWrite())
129       Clauses[MA->getReductionType()] +=
130           ", " + MA->getBaseAddr()->getName().str();
131 
132   // Now print the reductions sorted by type. Each type will cause a clause
133   // like:  reduction (+ : sum0, sum1, sum2)
134   for (const auto &ReductionClause : Clauses) {
135     str += " reduction (";
136     str += MemoryAccess::getReductionOperatorStr(ReductionClause.first);
137     // Remove the first two symbols (", ") to make the output look pretty.
138     str += " : " + ReductionClause.second.substr(2) + ")";
139   }
140 
141   return str;
142 }
143 
144 /// @brief Callback executed for each for node in the ast in order to print it.
145 static isl_printer *cbPrintFor(__isl_take isl_printer *Printer,
146                                __isl_take isl_ast_print_options *Options,
147                                __isl_keep isl_ast_node *Node, void *) {
148 
149   isl_pw_aff *DD = IslAstInfo::getMinimalDependenceDistance(Node);
150   const std::string BrokenReductionsStr = getBrokenReductionsStr(Node);
151   const std::string DepDisPragmaStr = "#pragma minimal dependence distance: ";
152   const std::string SimdPragmaStr = "#pragma simd";
153   const std::string OmpPragmaStr = "#pragma omp parallel for";
154 
155   if (DD)
156     Printer = printLine(Printer, DepDisPragmaStr, DD);
157 
158   if (IslAstInfo::isInnermostParallel(Node))
159     Printer = printLine(Printer, SimdPragmaStr + BrokenReductionsStr);
160 
161   if (IslAstInfo::isOutermostParallel(Node))
162     Printer = printLine(Printer, OmpPragmaStr + BrokenReductionsStr);
163 
164   isl_pw_aff_free(DD);
165   return isl_ast_node_for_print(Node, Printer, Options);
166 }
167 
168 /// @brief Check if the current scheduling dimension is parallel
169 ///
170 /// In case the dimension is parallel we also check if any reduction
171 /// dependences is broken when we exploit this parallelism. If so,
172 /// @p IsReductionParallel will be set to true. The reduction dependences we use
173 /// to check are actually the union of the transitive closure of the initial
174 /// reduction dependences together with their reveresal. Even though these
175 /// dependences connect all iterations with each other (thus they are cyclic)
176 /// we can perform the parallelism check as we are only interested in a zero
177 /// (or non-zero) dependence distance on the dimension in question.
178 static bool astScheduleDimIsParallel(__isl_keep isl_ast_build *Build,
179                                      Dependences *D,
180                                      IslAstUserPayload *NodeInfo) {
181   if (!D->hasValidDependences())
182     return false;
183 
184   isl_union_map *Schedule = isl_ast_build_get_schedule(Build);
185   isl_union_map *Deps = D->getDependences(
186       Dependences::TYPE_RAW | Dependences::TYPE_WAW | Dependences::TYPE_WAR);
187 
188   if (!D->isParallel(Schedule, Deps, &NodeInfo->MinimalDependenceDistance) &&
189       !isl_union_map_free(Schedule))
190     return false;
191 
192   isl_union_map *RedDeps = D->getDependences(Dependences::TYPE_TC_RED);
193   if (!D->isParallel(Schedule, RedDeps))
194     NodeInfo->IsReductionParallel = true;
195 
196   if (!NodeInfo->IsReductionParallel && !isl_union_map_free(Schedule))
197     return true;
198 
199   // Annotate reduction parallel nodes with the memory accesses which caused the
200   // reduction dependences parallel execution of the node conflicts with.
201   for (const auto &MaRedPair : D->getReductionDependences()) {
202     if (!MaRedPair.second)
203       continue;
204     RedDeps = isl_union_map_from_map(isl_map_copy(MaRedPair.second));
205     if (!D->isParallel(Schedule, RedDeps))
206       NodeInfo->BrokenReductions.insert(MaRedPair.first);
207   }
208 
209   isl_union_map_free(Schedule);
210   return true;
211 }
212 
213 // This method is executed before the construction of a for node. It creates
214 // an isl_id that is used to annotate the subsequently generated ast for nodes.
215 //
216 // In this function we also run the following analyses:
217 //
218 // - Detection of openmp parallel loops
219 //
220 static __isl_give isl_id *astBuildBeforeFor(__isl_keep isl_ast_build *Build,
221                                             void *User) {
222   AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User;
223   IslAstUserPayload *Payload = new IslAstUserPayload();
224   isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", Payload);
225   Id = isl_id_set_free_user(Id, freeIslAstUserPayload);
226   BuildInfo->LastForNodeId = Id;
227 
228   // Test for parallelism only if we are not already inside a parallel loop
229   if (!BuildInfo->InParallelFor)
230     BuildInfo->InParallelFor = Payload->IsOutermostParallel =
231         astScheduleDimIsParallel(Build, BuildInfo->Deps, Payload);
232 
233   return Id;
234 }
235 
236 // This method is executed after the construction of a for node.
237 //
238 // It performs the following actions:
239 //
240 // - Reset the 'InParallelFor' flag, as soon as we leave a for node,
241 //   that is marked as openmp parallel.
242 //
243 static __isl_give isl_ast_node *
244 astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build,
245                  void *User) {
246   isl_id *Id = isl_ast_node_get_annotation(Node);
247   assert(Id && "Post order visit assumes annotated for nodes");
248   IslAstUserPayload *Payload = (IslAstUserPayload *)isl_id_get_user(Id);
249   assert(Payload && "Post order visit assumes annotated for nodes");
250 
251   AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User;
252   assert(!Payload->Build && "Build environment already set");
253   Payload->Build = isl_ast_build_copy(Build);
254   Payload->IsInnermost = (Id == BuildInfo->LastForNodeId);
255 
256   // Innermost loops that are surrounded by parallel loops have not yet been
257   // tested for parallelism. Test them here to ensure we check all innermost
258   // loops for parallelism.
259   if (Payload->IsInnermost && BuildInfo->InParallelFor) {
260     if (Payload->IsOutermostParallel)
261       Payload->IsInnermostParallel = true;
262     else
263       Payload->IsInnermostParallel =
264           astScheduleDimIsParallel(Build, BuildInfo->Deps, Payload);
265   }
266   if (Payload->IsOutermostParallel)
267     BuildInfo->InParallelFor = false;
268 
269   isl_id_free(Id);
270   return Node;
271 }
272 
273 static __isl_give isl_ast_node *AtEachDomain(__isl_take isl_ast_node *Node,
274                                              __isl_keep isl_ast_build *Build,
275                                              void *User) {
276   assert(!isl_ast_node_get_annotation(Node) && "Node already annotated");
277 
278   IslAstUserPayload *Payload = new IslAstUserPayload();
279   isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", Payload);
280   Id = isl_id_set_free_user(Id, freeIslAstUserPayload);
281 
282   Payload->Build = isl_ast_build_copy(Build);
283 
284   return isl_ast_node_set_annotation(Node, Id);
285 }
286 
287 void IslAst::buildRunCondition(__isl_keep isl_ast_build *Build) {
288   // The conditions that need to be checked at run-time for this scop are
289   // available as an isl_set in the AssumedContext. We generate code for this
290   // check as follows. First, we generate an isl_pw_aff that is 1, if a certain
291   // combination of parameter values fulfills the conditions in the assumed
292   // context, and that is 0 otherwise. We then translate this isl_pw_aff into
293   // an isl_ast_expr. At run-time this expression can be evaluated and the
294   // optimized scop can be executed conditionally according to the result of the
295   // run-time check.
296 
297   isl_aff *Zero =
298       isl_aff_zero_on_domain(isl_local_space_from_space(S->getParamSpace()));
299   isl_aff *One =
300       isl_aff_zero_on_domain(isl_local_space_from_space(S->getParamSpace()));
301 
302   One = isl_aff_add_constant_si(One, 1);
303 
304   isl_pw_aff *PwZero = isl_pw_aff_from_aff(Zero);
305   isl_pw_aff *PwOne = isl_pw_aff_from_aff(One);
306 
307   PwOne = isl_pw_aff_intersect_domain(PwOne, S->getAssumedContext());
308   PwZero = isl_pw_aff_intersect_domain(
309       PwZero, isl_set_complement(S->getAssumedContext()));
310 
311   isl_pw_aff *Cond = isl_pw_aff_union_max(PwOne, PwZero);
312 
313   RunCondition = isl_ast_build_expr_from_pw_aff(Build, Cond);
314 
315   // Create the alias checks from the minimal/maximal accesses in each alias
316   // group. This operation is by construction quadratic in the number of
317   // elements in each alias group.
318   isl_ast_expr *NonAliasGroup, *MinExpr, *MaxExpr;
319   for (const Scop::MinMaxVectorTy *MinMaxAccesses : S->getAliasGroups()) {
320     auto AccEnd = MinMaxAccesses->end();
321     for (auto AccIt0 = MinMaxAccesses->begin(); AccIt0 != AccEnd; ++AccIt0) {
322       for (auto AccIt1 = AccIt0 + 1; AccIt1 != AccEnd; ++AccIt1) {
323         MinExpr =
324             isl_ast_expr_address_of(isl_ast_build_access_from_pw_multi_aff(
325                 Build, isl_pw_multi_aff_copy(AccIt0->first)));
326         MaxExpr =
327             isl_ast_expr_address_of(isl_ast_build_access_from_pw_multi_aff(
328                 Build, isl_pw_multi_aff_copy(AccIt1->second)));
329         NonAliasGroup = isl_ast_expr_le(MaxExpr, MinExpr);
330         MinExpr =
331             isl_ast_expr_address_of(isl_ast_build_access_from_pw_multi_aff(
332                 Build, isl_pw_multi_aff_copy(AccIt1->first)));
333         MaxExpr =
334             isl_ast_expr_address_of(isl_ast_build_access_from_pw_multi_aff(
335                 Build, isl_pw_multi_aff_copy(AccIt0->second)));
336         NonAliasGroup =
337             isl_ast_expr_or(NonAliasGroup, isl_ast_expr_le(MaxExpr, MinExpr));
338         RunCondition = isl_ast_expr_and(RunCondition, NonAliasGroup);
339       }
340     }
341   }
342 }
343 
344 IslAst::IslAst(Scop *Scop, Dependences &D) : S(Scop) {
345   isl_ctx *Ctx = S->getIslCtx();
346   isl_options_set_ast_build_atomic_upper_bound(Ctx, true);
347   isl_ast_build *Build;
348   AstBuildUserInfo BuildInfo;
349 
350   if (UseContext)
351     Build = isl_ast_build_from_context(S->getContext());
352   else
353     Build = isl_ast_build_from_context(isl_set_universe(S->getParamSpace()));
354 
355   Build = isl_ast_build_set_at_each_domain(Build, AtEachDomain, nullptr);
356 
357   isl_union_map *Schedule =
358       isl_union_map_intersect_domain(S->getSchedule(), S->getDomains());
359 
360   if (DetectParallel || PollyVectorizerChoice != VECTORIZER_NONE) {
361     BuildInfo.Deps = &D;
362     BuildInfo.InParallelFor = 0;
363 
364     Build = isl_ast_build_set_before_each_for(Build, &astBuildBeforeFor,
365                                               &BuildInfo);
366     Build =
367         isl_ast_build_set_after_each_for(Build, &astBuildAfterFor, &BuildInfo);
368   }
369 
370   buildRunCondition(Build);
371 
372   Root = isl_ast_build_ast_from_schedule(Build, Schedule);
373 
374   isl_ast_build_free(Build);
375 }
376 
377 IslAst::~IslAst() {
378   isl_ast_node_free(Root);
379   isl_ast_expr_free(RunCondition);
380 }
381 
382 __isl_give isl_ast_node *IslAst::getAst() { return isl_ast_node_copy(Root); }
383 __isl_give isl_ast_expr *IslAst::getRunCondition() {
384   return isl_ast_expr_copy(RunCondition);
385 }
386 
387 void IslAstInfo::releaseMemory() {
388   if (Ast) {
389     delete Ast;
390     Ast = 0;
391   }
392 }
393 
394 bool IslAstInfo::runOnScop(Scop &Scop) {
395   if (Ast)
396     delete Ast;
397 
398   S = &Scop;
399 
400   Dependences &D = getAnalysis<Dependences>();
401 
402   Ast = new IslAst(&Scop, D);
403 
404   DEBUG(printScop(dbgs()));
405   return false;
406 }
407 
408 __isl_give isl_ast_node *IslAstInfo::getAst() const { return Ast->getAst(); }
409 __isl_give isl_ast_expr *IslAstInfo::getRunCondition() const {
410   return Ast->getRunCondition();
411 }
412 
413 IslAstUserPayload *IslAstInfo::getNodePayload(__isl_keep isl_ast_node *Node) {
414   isl_id *Id = isl_ast_node_get_annotation(Node);
415   if (!Id)
416     return nullptr;
417   IslAstUserPayload *Payload = (IslAstUserPayload *)isl_id_get_user(Id);
418   isl_id_free(Id);
419   return Payload;
420 }
421 
422 bool IslAstInfo::isInnermost(__isl_keep isl_ast_node *Node) {
423   IslAstUserPayload *Payload = getNodePayload(Node);
424   return Payload && Payload->IsInnermost;
425 }
426 
427 bool IslAstInfo::isParallel(__isl_keep isl_ast_node *Node) {
428   return IslAstInfo::isInnermostParallel(Node) ||
429          IslAstInfo::isOutermostParallel(Node);
430 }
431 
432 bool IslAstInfo::isInnermostParallel(__isl_keep isl_ast_node *Node) {
433   IslAstUserPayload *Payload = getNodePayload(Node);
434   return Payload && Payload->IsInnermostParallel;
435 }
436 
437 bool IslAstInfo::isOutermostParallel(__isl_keep isl_ast_node *Node) {
438   IslAstUserPayload *Payload = getNodePayload(Node);
439   return Payload && Payload->IsOutermostParallel;
440 }
441 
442 bool IslAstInfo::isReductionParallel(__isl_keep isl_ast_node *Node) {
443   IslAstUserPayload *Payload = getNodePayload(Node);
444   return Payload && Payload->IsReductionParallel;
445 }
446 
447 isl_union_map *IslAstInfo::getSchedule(__isl_keep isl_ast_node *Node) {
448   IslAstUserPayload *Payload = getNodePayload(Node);
449   return Payload ? isl_ast_build_get_schedule(Payload->Build) : nullptr;
450 }
451 
452 isl_pw_aff *
453 IslAstInfo::getMinimalDependenceDistance(__isl_keep isl_ast_node *Node) {
454   IslAstUserPayload *Payload = getNodePayload(Node);
455   return Payload ? isl_pw_aff_copy(Payload->MinimalDependenceDistance)
456                  : nullptr;
457 }
458 
459 IslAstInfo::MemoryAccessSet *
460 IslAstInfo::getBrokenReductions(__isl_keep isl_ast_node *Node) {
461   IslAstUserPayload *Payload = getNodePayload(Node);
462   return Payload ? &Payload->BrokenReductions : nullptr;
463 }
464 
465 isl_ast_build *IslAstInfo::getBuild(__isl_keep isl_ast_node *Node) {
466   IslAstUserPayload *Payload = getNodePayload(Node);
467   return Payload ? Payload->Build : nullptr;
468 }
469 
470 void IslAstInfo::printScop(raw_ostream &OS) const {
471   isl_ast_print_options *Options;
472   isl_ast_node *RootNode = getAst();
473   isl_ast_expr *RunCondition = getRunCondition();
474   char *RtCStr, *AstStr;
475 
476   Scop &S = getCurScop();
477   Options = isl_ast_print_options_alloc(S.getIslCtx());
478   Options = isl_ast_print_options_set_print_for(Options, cbPrintFor, nullptr);
479 
480   isl_printer *P = isl_printer_to_str(S.getIslCtx());
481   P = isl_printer_print_ast_expr(P, RunCondition);
482   RtCStr = isl_printer_get_str(P);
483   P = isl_printer_flush(P);
484   P = isl_printer_indent(P, 4);
485   P = isl_printer_set_output_format(P, ISL_FORMAT_C);
486   P = isl_ast_node_print(RootNode, P, Options);
487   AstStr = isl_printer_get_str(P);
488 
489   Function *F = S.getRegion().getEntry()->getParent();
490   isl_union_map *Schedule =
491       isl_union_map_intersect_domain(S.getSchedule(), S.getDomains());
492 
493   OS << ":: isl ast :: " << F->getName() << " :: " << S.getRegion().getNameStr()
494      << "\n";
495   DEBUG({
496     dbgs() << S.getContextStr() << "\n";
497     dbgs() << stringFromIslObj(Schedule);
498   });
499   OS << "\nif (" << RtCStr << ")\n\n";
500   OS << AstStr << "\n";
501   OS << "else\n";
502   OS << "    {  /* original code */ }\n\n";
503 
504   isl_ast_expr_free(RunCondition);
505   isl_union_map_free(Schedule);
506   isl_ast_node_free(RootNode);
507   isl_printer_free(P);
508 }
509 
/// @brief Declare the analyses this pass depends on.
///
/// Besides the common ScopPass requirements, the pass requires ScopInfo and
/// Dependences. runOnScop hands the Dependences result to IslAst for the
/// parallelism checks.
void IslAstInfo::getAnalysisUsage(AnalysisUsage &AU) const {
  // Get the Common analysis usage of ScopPasses.
  ScopPass::getAnalysisUsage(AU);
  AU.addRequired<ScopInfo>();
  AU.addRequired<Dependences>();
}
516 
517 char IslAstInfo::ID = 0;
518 
519 Pass *polly::createIslAstInfoPass() { return new IslAstInfo(); }
520 
521 INITIALIZE_PASS_BEGIN(IslAstInfo, "polly-ast",
522                       "Polly - Generate an AST of the SCoP (isl)", false,
523                       false);
524 INITIALIZE_PASS_DEPENDENCY(ScopInfo);
525 INITIALIZE_PASS_DEPENDENCY(Dependences);
526 INITIALIZE_PASS_END(IslAstInfo, "polly-ast",
527                     "Polly - Generate an AST from the SCoP (isl)", false, false)
528