1 //===- IslAst.cpp - isl code generator interface --------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
// The isl code generator interface takes a Scop and generates an isl_ast. This
// isl_ast can either be returned directly or it can be pretty printed to an
// output stream.
13 //
14 // A typical isl_ast output looks like this:
15 //
// for (c2 = max(0, ceild(n + m, 2)); c2 <= min(511, floord(5 * n, 3)); c2++) {
17 //   bb2(c2);
18 // }
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "polly/CodeGen/CodeGeneration.h"
23 #include "polly/CodeGen/IslAst.h"
24 
25 #include "polly/LinkAllPasses.h"
26 #include "polly/Dependences.h"
27 #include "polly/ScopInfo.h"
28 
29 #define DEBUG_TYPE "polly-ast"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/Debug.h"
32 
33 #include "isl/union_map.h"
34 #include "isl/list.h"
35 #include "isl/ast_build.h"
36 #include "isl/set.h"
37 #include "isl/map.h"
38 #include "isl/aff.h"
39 
40 using namespace llvm;
41 using namespace polly;
42 
43 static cl::opt<bool>
44 UseContext("polly-ast-use-context", cl::desc("Use context"), cl::Hidden,
45            cl::init(false), cl::ZeroOrMore);
46 
47 static cl::opt<bool>
48 DetectParallel("polly-ast-detect-parallel", cl::desc("Detect parallelism"),
49                cl::Hidden, cl::init(false), cl::ZeroOrMore);
50 
namespace polly {
/// Builds and owns the isl AST generated for a single Scop.
class IslAst {
public:
  /// Build the AST for @p Scop. @p D is consulted for parallelism detection
  /// when that analysis is enabled.
  IslAst(Scop *Scop, Dependences &D);

  /// Release the AST root owned by this object.
  ~IslAst();

  /// Print a source code representation of the program.
  void pprint(llvm::raw_ostream &OS);

  /// Return a new reference to the AST root; the caller must free it.
  __isl_give isl_ast_node *getAst();

private:
  // The Scop the AST was generated for (not freed by IslAst).
  Scop *S;
  // Root of the generated AST (freed in the destructor).
  isl_ast_node *Root;

  /// Union of all statement schedules, each restricted to its domain.
  __isl_give isl_union_map *getSchedule();
};
} // End namespace polly.
70 
// Temporary information used when building the ast. A pointer to one instance
// is passed as the user argument of the before/after-each-for callbacks.
struct AstBuildUserInfo {
  // The dependence information.
  Dependences *Deps;

  // Non-zero while we are inside a for node that has been marked as the
  // outermost (openmp) parallel loop.
  int InParallelFor;
};
79 
80 // Print a loop annotated with OpenMP or vector pragmas.
81 static __isl_give isl_printer *printParallelFor(
82     __isl_keep isl_ast_node *Node, __isl_take isl_printer *Printer,
83     __isl_take isl_ast_print_options *PrintOptions, IslAstUser *Info) {
84   if (Info) {
85     if (Info->IsInnermostParallel) {
86       Printer = isl_printer_start_line(Printer);
87       Printer = isl_printer_print_str(Printer, "#pragma simd");
88       Printer = isl_printer_end_line(Printer);
89     }
90     if (Info->IsOutermostParallel) {
91       Printer = isl_printer_start_line(Printer);
92       Printer = isl_printer_print_str(Printer, "#pragma omp parallel for");
93       Printer = isl_printer_end_line(Printer);
94     }
95   }
96   return isl_ast_node_for_print(Node, Printer, PrintOptions);
97 }
98 
99 // Print an isl_ast_for.
100 static __isl_give isl_printer *
101 printFor(__isl_take isl_printer *Printer,
102          __isl_take isl_ast_print_options *PrintOptions,
103          __isl_keep isl_ast_node *Node, void *User) {
104   isl_id *Id = isl_ast_node_get_annotation(Node);
105   if (!Id)
106     return isl_ast_node_for_print(Node, Printer, PrintOptions);
107 
108   struct IslAstUser *Info = (struct IslAstUser *)isl_id_get_user(Id);
109   Printer = printParallelFor(Node, Printer, PrintOptions, Info);
110   isl_id_free(Id);
111   return Printer;
112 }
113 
114 // Allocate an AstNodeInfo structure and initialize it with default values.
115 static struct IslAstUser *allocateIslAstUser() {
116   struct IslAstUser *NodeInfo;
117   NodeInfo = (struct IslAstUser *)malloc(sizeof(struct IslAstUser));
118   NodeInfo->PMA = 0;
119   NodeInfo->Context = 0;
120   NodeInfo->IsOutermostParallel = 0;
121   NodeInfo->IsInnermostParallel = 0;
122   return NodeInfo;
123 }
124 
125 // Free the AstNodeInfo structure.
126 static void freeIslAstUser(void *Ptr) {
127   struct IslAstUser *UserStruct = (struct IslAstUser *)Ptr;
128   isl_ast_build_free(UserStruct->Context);
129   isl_pw_multi_aff_free(UserStruct->PMA);
130   free(UserStruct);
131 }
132 
133 // Check if the current scheduling dimension is parallel.
134 //
135 // We check for parallelism by verifying that the loop does not carry any
136 // dependences.
137 //
138 // Parallelism test: if the distance is zero in all outer dimensions, then it
139 // has to be zero in the current dimension as well.
140 //
141 // Implementation: first, translate dependences into time space, then force
142 // outer dimensions to be equal. If the distance is zero in the current
143 // dimension, then the loop is parallel. The distance is zero in the current
144 // dimension if it is a subset of a map with equal values for the current
145 // dimension.
146 static bool astScheduleDimIsParallel(__isl_keep isl_ast_build *Build,
147                                      Dependences *D) {
148   isl_union_map *Schedule, *Deps;
149   isl_map *ScheduleDeps, *Test;
150   isl_space *ScheduleSpace;
151   unsigned Dimension, IsParallel;
152 
153   Schedule = isl_ast_build_get_schedule(Build);
154   ScheduleSpace = isl_ast_build_get_schedule_space(Build);
155 
156   Dimension = isl_space_dim(ScheduleSpace, isl_dim_out) - 1;
157 
158   Deps = D->getDependences(Dependences::TYPE_ALL);
159   Deps = isl_union_map_apply_range(Deps, isl_union_map_copy(Schedule));
160   Deps = isl_union_map_apply_domain(Deps, Schedule);
161 
162   if (isl_union_map_is_empty(Deps)) {
163     isl_union_map_free(Deps);
164     isl_space_free(ScheduleSpace);
165     return 1;
166   }
167 
168   ScheduleDeps = isl_map_from_union_map(Deps);
169 
170   for (unsigned i = 0; i < Dimension; i++)
171     ScheduleDeps = isl_map_equate(ScheduleDeps, isl_dim_out, i, isl_dim_in, i);
172 
173   Test = isl_map_universe(isl_map_get_space(ScheduleDeps));
174   Test = isl_map_equate(Test, isl_dim_out, Dimension, isl_dim_in, Dimension);
175   IsParallel = isl_map_is_subset(ScheduleDeps, Test);
176 
177   isl_space_free(ScheduleSpace);
178   isl_map_free(Test);
179   isl_map_free(ScheduleDeps);
180 
181   return IsParallel;
182 }
183 
184 // Mark a for node openmp parallel, if it is the outermost parallel for node.
185 static void markOpenmpParallel(__isl_keep isl_ast_build *Build,
186                                struct AstBuildUserInfo *BuildInfo,
187                                struct IslAstUser *NodeInfo) {
188   if (BuildInfo->InParallelFor)
189     return;
190 
191   if (astScheduleDimIsParallel(Build, BuildInfo->Deps)) {
192     BuildInfo->InParallelFor = 1;
193     NodeInfo->IsOutermostParallel = 1;
194   }
195 }
196 
197 // This method is executed before the construction of a for node. It creates
198 // an isl_id that is used to annotate the subsequently generated ast for nodes.
199 //
200 // In this function we also run the following analyses:
201 //
202 // - Detection of openmp parallel loops
203 //
204 static __isl_give isl_id *
205 astBuildBeforeFor(__isl_keep isl_ast_build *Build, void *User) {
206   struct AstBuildUserInfo *BuildInfo = (struct AstBuildUserInfo *)User;
207   struct IslAstUser *NodeInfo = allocateIslAstUser();
208   isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", NodeInfo);
209   Id = isl_id_set_free_user(Id, freeIslAstUser);
210 
211   markOpenmpParallel(Build, BuildInfo, NodeInfo);
212 
213   return Id;
214 }
215 
216 // Returns 0 when Node contains loops, otherwise returns -1. This search
217 // function uses ISL's way to iterate over lists of isl_ast_nodes with
218 // isl_ast_node_list_foreach. Please use the single argument wrapper function
219 // that returns a bool instead of using this function directly.
220 static int containsLoops(__isl_take isl_ast_node *Node, void *User) {
221   if (!Node)
222     return -1;
223 
224   switch (isl_ast_node_get_type(Node)) {
225   case isl_ast_node_for:
226     isl_ast_node_free(Node);
227     return 0;
228   case isl_ast_node_block: {
229     isl_ast_node_list *List = isl_ast_node_block_get_children(Node);
230     int Res = isl_ast_node_list_foreach(List, &containsLoops, NULL);
231     isl_ast_node_list_free(List);
232     isl_ast_node_free(Node);
233     return Res;
234   }
235   case isl_ast_node_if: {
236     int Res = -1;
237     if (0 == containsLoops(isl_ast_node_if_get_then(Node), NULL) ||
238         (isl_ast_node_if_has_else(Node) &&
239          0 == containsLoops(isl_ast_node_if_get_else(Node), NULL)))
240       Res = 0;
241     isl_ast_node_free(Node);
242     return Res;
243   }
244   case isl_ast_node_user:
245   default:
246     isl_ast_node_free(Node);
247     return -1;
248   }
249 }
250 
251 // Returns true when Node contains loops.
252 static bool containsLoops(__isl_take isl_ast_node *Node) {
253   return 0 == containsLoops(Node, NULL);
254 }
255 
256 // This method is executed after the construction of a for node.
257 //
258 // It performs the following actions:
259 //
260 // - Reset the 'InParallelFor' flag, as soon as we leave a for node,
261 //   that is marked as openmp parallel.
262 //
263 static __isl_give isl_ast_node *
264 astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build,
265                  void *User) {
266   isl_id *Id = isl_ast_node_get_annotation(Node);
267   if (!Id)
268     return Node;
269   struct IslAstUser *Info = (struct IslAstUser *)isl_id_get_user(Id);
270   struct AstBuildUserInfo *BuildInfo = (struct AstBuildUserInfo *)User;
271 
272   if (Info) {
273     if (Info->IsOutermostParallel)
274       BuildInfo->InParallelFor = 0;
275     if (!containsLoops(isl_ast_node_for_get_body(Node)))
276       if (astScheduleDimIsParallel(Build, BuildInfo->Deps))
277         Info->IsInnermostParallel = 1;
278     if (!Info->Context)
279       Info->Context = isl_ast_build_copy(Build);
280   }
281 
282   isl_id_free(Id);
283   return Node;
284 }
285 
286 static __isl_give isl_ast_node *
287 AtEachDomain(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Context,
288              void *User) {
289   struct IslAstUser *Info = NULL;
290   isl_id *Id = isl_ast_node_get_annotation(Node);
291 
292   if (Id)
293     Info = (struct IslAstUser *)isl_id_get_user(Id);
294 
295   if (!Info) {
296     // Allocate annotations once: parallel for detection might have already
297     // allocated the annotations for this node.
298     Info = allocateIslAstUser();
299     Id = isl_id_alloc(isl_ast_node_get_ctx(Node), NULL, Info);
300     Id = isl_id_set_free_user(Id, &freeIslAstUser);
301   }
302 
303   if (!Info->PMA) {
304     isl_map *Map = isl_map_from_union_map(isl_ast_build_get_schedule(Context));
305     Info->PMA = isl_pw_multi_aff_from_map(isl_map_reverse(Map));
306   }
307   if (!Info->Context)
308     Info->Context = isl_ast_build_copy(Context);
309 
310   return isl_ast_node_set_annotation(Node, Id);
311 }
312 
313 IslAst::IslAst(Scop *Scop, Dependences &D) : S(Scop) {
314   isl_ctx *Ctx = S->getIslCtx();
315   isl_options_set_ast_build_atomic_upper_bound(Ctx, true);
316   isl_ast_build *Context;
317   struct AstBuildUserInfo BuildInfo;
318 
319   if (UseContext)
320     Context = isl_ast_build_from_context(S->getContext());
321   else
322     Context = isl_ast_build_from_context(isl_set_universe(S->getParamSpace()));
323 
324   Context = isl_ast_build_set_at_each_domain(Context, AtEachDomain, NULL);
325 
326   isl_union_map *Schedule = getSchedule();
327 
328   Function *F = Scop->getRegion().getEntry()->getParent();
329 
330   DEBUG(dbgs() << ":: isl ast :: " << F->getName()
331                << " :: " << Scop->getRegion().getNameStr() << "\n");;
332   DEBUG(dbgs() << S->getContextStr() << "\n";
333     isl_union_map_dump(Schedule);
334   );
335 
336   if (DetectParallel || PollyVectorizerChoice != VECTORIZER_NONE) {
337     BuildInfo.Deps = &D;
338     BuildInfo.InParallelFor = 0;
339 
340     Context = isl_ast_build_set_before_each_for(Context, &astBuildBeforeFor,
341                                                 &BuildInfo);
342     Context = isl_ast_build_set_after_each_for(Context, &astBuildAfterFor,
343                                                &BuildInfo);
344   }
345 
346   Root = isl_ast_build_ast_from_schedule(Context, Schedule);
347 
348   isl_ast_build_free(Context);
349 
350   DEBUG(pprint(dbgs()));
351 }
352 
353 __isl_give isl_union_map *IslAst::getSchedule() {
354   isl_union_map *Schedule = isl_union_map_empty(S->getParamSpace());
355 
356   for (Scop::iterator SI = S->begin(), SE = S->end(); SI != SE; ++SI) {
357     ScopStmt *Stmt = *SI;
358     isl_map *StmtSchedule = Stmt->getScattering();
359 
360     StmtSchedule = isl_map_intersect_domain(StmtSchedule, Stmt->getDomain());
361     Schedule =
362         isl_union_map_union(Schedule, isl_union_map_from_map(StmtSchedule));
363   }
364 
365   return Schedule;
366 }
367 
368 IslAst::~IslAst() { isl_ast_node_free(Root); }
369 
370 /// Print a C like representation of the program.
371 void IslAst::pprint(llvm::raw_ostream &OS) {
372   isl_ast_node *Root;
373   isl_ast_print_options *Options;
374 
375   Options = isl_ast_print_options_alloc(S->getIslCtx());
376   Options = isl_ast_print_options_set_print_for(Options, &printFor, NULL);
377 
378   isl_printer *P = isl_printer_to_str(S->getIslCtx());
379   P = isl_printer_set_output_format(P, ISL_FORMAT_C);
380   Root = getAst();
381   P = isl_ast_node_print(Root, P, Options);
382   char *result = isl_printer_get_str(P);
383   OS << result << "\n";
384   isl_printer_free(P);
385   isl_ast_node_free(Root);
386 }
387 
388 /// Create the isl_ast from this program.
389 __isl_give isl_ast_node *IslAst::getAst() { return isl_ast_node_copy(Root); }
390 
391 void IslAstInfo::pprint(llvm::raw_ostream &OS) { Ast->pprint(OS); }
392 
393 void IslAstInfo::releaseMemory() {
394   if (Ast) {
395     delete Ast;
396     Ast = 0;
397   }
398 }
399 
400 bool IslAstInfo::runOnScop(Scop &Scop) {
401   if (Ast)
402     delete Ast;
403 
404   S = &Scop;
405 
406   Dependences &D = getAnalysis<Dependences>();
407 
408   Ast = new IslAst(&Scop, D);
409 
410   return false;
411 }
412 
413 __isl_give isl_ast_node *IslAstInfo::getAst() { return Ast->getAst(); }
414 
415 void IslAstInfo::printScop(raw_ostream &OS) const {
416   Function *F = S->getRegion().getEntry()->getParent();
417 
418   OS << F->getName() << "():\n";
419 
420   Ast->pprint(OS);
421 }
422 
423 void IslAstInfo::getAnalysisUsage(AnalysisUsage &AU) const {
424   // Get the Common analysis usage of ScopPasses.
425   ScopPass::getAnalysisUsage(AU);
426   AU.addRequired<ScopInfo>();
427   AU.addRequired<Dependences>();
428 }
429 char IslAstInfo::ID = 0;
430 
431 INITIALIZE_PASS_BEGIN(IslAstInfo, "polly-ast",
432                       "Generate an AST of the SCoP (isl)", false, false)
433 INITIALIZE_PASS_DEPENDENCY(ScopInfo)
434 INITIALIZE_PASS_DEPENDENCY(Dependences)
435 INITIALIZE_PASS_END(IslAstInfo, "polly-ast",
436                     "Generate an AST from the SCoP (isl)", false, false)
437 
438 Pass *polly::createIslAstInfoPass() {
439   return new IslAstInfo();
440 }
441