1 //===- IslAst.cpp - isl code generator interface --------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
// The isl code generator interface takes a Scop and generates an isl_ast. This
// isl_ast can either be returned directly or it can be pretty printed to an
// output stream.
13 //
14 // A typical isl_ast output looks like this:
15 //
// for (c2 = max(0, ceild(n + m, 2)); c2 <= min(511, floord(5 * n, 3)); c2++) {
17 //   bb2(c2);
18 // }
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "polly/CodeGen/CodeGeneration.h"
23 #include "polly/CodeGen/IslAst.h"
24 
25 #include "polly/LinkAllPasses.h"
26 #include "polly/Dependences.h"
27 #include "polly/ScopInfo.h"
28 
29 #define DEBUG_TYPE "polly-ast"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/Debug.h"
32 
33 #include "isl/union_map.h"
34 #include "isl/list.h"
35 #include "isl/ast_build.h"
36 #include "isl/set.h"
37 #include "isl/map.h"
38 #include "isl/aff.h"
39 
40 using namespace llvm;
41 using namespace polly;
42 
43 static cl::opt<bool>
44 UseContext("polly-ast-use-context", cl::desc("Use context"), cl::Hidden,
45            cl::init(false), cl::ZeroOrMore);
46 
47 static cl::opt<bool>
48 DetectParallel("polly-ast-detect-parallel", cl::desc("Detect parallelism"),
49                cl::Hidden, cl::init(false), cl::ZeroOrMore);
50 
51 namespace polly {
52 class IslAst {
53 public:
54   IslAst(Scop *Scop, Dependences &D);
55 
56   ~IslAst();
57 
58   /// Print a source code representation of the program.
59   void pprint(llvm::raw_ostream &OS);
60 
61   __isl_give isl_ast_node *getAst();
62 
63 private:
64   Scop *S;
65   isl_ast_node *Root;
66 
67   __isl_give isl_union_map *getSchedule();
68 };
69 } // End namespace polly.
70 
71 // Temporary information used when building the ast.
72 struct AstBuildUserInfo {
73   // The dependence information.
74   Dependences *Deps;
75 
76   // We are inside a parallel for node.
77   int InParallelFor;
78 };
79 
80 // Print a loop annotated with OpenMP or vector pragmas.
81 static __isl_give isl_printer *printParallelFor(
82     __isl_keep isl_ast_node *Node, __isl_take isl_printer *Printer,
83     __isl_take isl_ast_print_options *PrintOptions, IslAstUser *Info) {
84   if (Info) {
85     if (Info->IsInnermostParallel) {
86       Printer = isl_printer_start_line(Printer);
87       Printer = isl_printer_print_str(Printer, "#pragma simd");
88       Printer = isl_printer_end_line(Printer);
89     }
90     if (Info->IsOutermostParallel) {
91       Printer = isl_printer_start_line(Printer);
92       Printer = isl_printer_print_str(Printer, "#pragma omp parallel for");
93       Printer = isl_printer_end_line(Printer);
94     }
95   }
96   return isl_ast_node_for_print(Node, Printer, PrintOptions);
97 }
98 
99 // Print an isl_ast_for.
100 static __isl_give isl_printer *
101 printFor(__isl_take isl_printer *Printer,
102          __isl_take isl_ast_print_options *PrintOptions,
103          __isl_keep isl_ast_node *Node, void *User) {
104   isl_id *Id = isl_ast_node_get_annotation(Node);
105   if (!Id)
106     return isl_ast_node_for_print(Node, Printer, PrintOptions);
107 
108   struct IslAstUser *Info = (struct IslAstUser *)isl_id_get_user(Id);
109   Printer = printParallelFor(Node, Printer, PrintOptions, Info);
110   isl_id_free(Id);
111   return Printer;
112 }
113 
114 // Allocate an AstNodeInfo structure and initialize it with default values.
115 static struct IslAstUser *allocateIslAstUser() {
116   struct IslAstUser *NodeInfo;
117   NodeInfo = (struct IslAstUser *)malloc(sizeof(struct IslAstUser));
118   NodeInfo->PMA = 0;
119   NodeInfo->Context = 0;
120   NodeInfo->IsOutermostParallel = 0;
121   NodeInfo->IsInnermostParallel = 0;
122   return NodeInfo;
123 }
124 
125 // Free the AstNodeInfo structure.
126 static void freeIslAstUser(void *Ptr) {
127   struct IslAstUser *UserStruct = (struct IslAstUser *)Ptr;
128   isl_ast_build_free(UserStruct->Context);
129   isl_pw_multi_aff_free(UserStruct->PMA);
130   free(UserStruct);
131 }
132 
133 // Check if the current scheduling dimension is parallel.
134 //
135 // We check for parallelism by verifying that the loop does not carry any
136 // dependences.
137 //
138 // Parallelism test: if the distance is zero in all outer dimensions, then it
139 // has to be zero in the current dimension as well.
140 //
141 // Implementation: first, translate dependences into time space, then force
142 // outer dimensions to be equal. If the distance is zero in the current
143 // dimension, then the loop is parallel. The distance is zero in the current
144 // dimension if it is a subset of a map with equal values for the current
145 // dimension.
146 static bool astScheduleDimIsParallel(__isl_keep isl_ast_build *Build,
147                                      Dependences *D) {
148   isl_union_map *Schedule, *Deps;
149   isl_map *ScheduleDeps, *Test;
150   isl_space *ScheduleSpace;
151   unsigned Dimension, IsParallel;
152 
153   Schedule = isl_ast_build_get_schedule(Build);
154   ScheduleSpace = isl_ast_build_get_schedule_space(Build);
155 
156   Dimension = isl_space_dim(ScheduleSpace, isl_dim_out) - 1;
157 
158   Deps = D->getDependences(Dependences::TYPE_ALL);
159   Deps = isl_union_map_apply_range(Deps, isl_union_map_copy(Schedule));
160   Deps = isl_union_map_apply_domain(Deps, Schedule);
161 
162   if (isl_union_map_is_empty(Deps)) {
163     isl_union_map_free(Deps);
164     isl_space_free(ScheduleSpace);
165     return 1;
166   }
167 
168   ScheduleDeps = isl_map_from_union_map(Deps);
169 
170   for (unsigned i = 0; i < Dimension; i++)
171     ScheduleDeps = isl_map_equate(ScheduleDeps, isl_dim_out, i, isl_dim_in, i);
172 
173   Test = isl_map_universe(isl_map_get_space(ScheduleDeps));
174   Test = isl_map_equate(Test, isl_dim_out, Dimension, isl_dim_in, Dimension);
175   IsParallel = isl_map_is_subset(ScheduleDeps, Test);
176 
177   isl_space_free(ScheduleSpace);
178   isl_map_free(Test);
179   isl_map_free(ScheduleDeps);
180 
181   return IsParallel;
182 }
183 
184 // Mark a for node openmp parallel, if it is the outermost parallel for node.
185 static void markOpenmpParallel(__isl_keep isl_ast_build *Build,
186                                struct AstBuildUserInfo *BuildInfo,
187                                struct IslAstUser *NodeInfo) {
188   if (BuildInfo->InParallelFor)
189     return;
190 
191   if (astScheduleDimIsParallel(Build, BuildInfo->Deps)) {
192     BuildInfo->InParallelFor = 1;
193     NodeInfo->IsOutermostParallel = 1;
194   }
195 }
196 
197 // This method is executed before the construction of a for node. It creates
198 // an isl_id that is used to annotate the subsequently generated ast for nodes.
199 //
200 // In this function we also run the following analyses:
201 //
202 // - Detection of openmp parallel loops
203 //
204 static __isl_give isl_id *
205 astBuildBeforeFor(__isl_keep isl_ast_build *Build, void *User) {
206   struct AstBuildUserInfo *BuildInfo = (struct AstBuildUserInfo *)User;
207   struct IslAstUser *NodeInfo = allocateIslAstUser();
208   isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", NodeInfo);
209   Id = isl_id_set_free_user(Id, freeIslAstUser);
210 
211   markOpenmpParallel(Build, BuildInfo, NodeInfo);
212 
213   return Id;
214 }
215 
216 // Returns 0 when Node contains loops, otherwise returns -1. This search
217 // function uses ISL's way to iterate over lists of isl_ast_nodes with
218 // isl_ast_node_list_foreach. Please use the single argument wrapper function
219 // that returns a bool instead of using this function directly.
220 static int containsLoops(__isl_take isl_ast_node *Node, void *User) {
221   if (!Node)
222     return -1;
223 
224   switch (isl_ast_node_get_type(Node)) {
225   case isl_ast_node_for:
226     isl_ast_node_free(Node);
227     return 0;
228   case isl_ast_node_block: {
229     isl_ast_node_list *List = isl_ast_node_block_get_children(Node);
230     int Res = isl_ast_node_list_foreach(List, &containsLoops, NULL);
231     isl_ast_node_list_free(List);
232     isl_ast_node_free(Node);
233     return Res;
234   }
235   case isl_ast_node_if: {
236     int Res = -1;
237     if (0 == containsLoops(isl_ast_node_if_get_then(Node), NULL) ||
238         (isl_ast_node_if_has_else(Node) &&
239          0 == containsLoops(isl_ast_node_if_get_else(Node), NULL)))
240       Res = 0;
241     isl_ast_node_free(Node);
242     return Res;
243   }
244   case isl_ast_node_user:
245   default:
246     isl_ast_node_free(Node);
247     return -1;
248   }
249 }
250 
251 // Returns true when Node contains loops.
252 static bool containsLoops(__isl_take isl_ast_node *Node) {
253   return 0 == containsLoops(Node, NULL);
254 }
255 
256 // This method is executed after the construction of a for node.
257 //
258 // It performs the following actions:
259 //
260 // - Reset the 'InParallelFor' flag, as soon as we leave a for node,
261 //   that is marked as openmp parallel.
262 //
263 static __isl_give isl_ast_node *
264 astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build,
265                  void *User) {
266   isl_id *Id = isl_ast_node_get_annotation(Node);
267   if (!Id)
268     return Node;
269   struct IslAstUser *Info = (struct IslAstUser *)isl_id_get_user(Id);
270   struct AstBuildUserInfo *BuildInfo = (struct AstBuildUserInfo *)User;
271 
272   if (Info) {
273     if (Info->IsOutermostParallel)
274       BuildInfo->InParallelFor = 0;
275     if (!containsLoops(isl_ast_node_for_get_body(Node)))
276       if (astScheduleDimIsParallel(Build, BuildInfo->Deps))
277         Info->IsInnermostParallel = 1;
278     if (!Info->Context)
279       Info->Context = isl_ast_build_copy(Build);
280   }
281 
282   isl_id_free(Id);
283   return Node;
284 }
285 
286 static __isl_give isl_ast_node *
287 AtEachDomain(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Context,
288              void *User) {
289   struct IslAstUser *Info = NULL;
290   isl_id *Id = isl_ast_node_get_annotation(Node);
291 
292   if (Id)
293     Info = (struct IslAstUser *)isl_id_get_user(Id);
294 
295   if (!Info) {
296     // Allocate annotations once: parallel for detection might have already
297     // allocated the annotations for this node.
298     Info = allocateIslAstUser();
299     Id = isl_id_alloc(isl_ast_node_get_ctx(Node), NULL, Info);
300     Id = isl_id_set_free_user(Id, &freeIslAstUser);
301   }
302 
303   if (!Info->PMA) {
304     isl_map *Map = isl_map_from_union_map(isl_ast_build_get_schedule(Context));
305     Info->PMA = isl_pw_multi_aff_from_map(isl_map_reverse(Map));
306   }
307   if (!Info->Context)
308     Info->Context = isl_ast_build_copy(Context);
309 
310   return isl_ast_node_set_annotation(Node, Id);
311 }
312 
313 IslAst::IslAst(Scop *Scop, Dependences &D) : S(Scop) {
314   isl_ctx *Ctx = S->getIslCtx();
315   isl_options_set_ast_build_atomic_upper_bound(Ctx, true);
316   isl_ast_build *Context;
317   struct AstBuildUserInfo BuildInfo;
318 
319   if (UseContext)
320     Context = isl_ast_build_from_context(S->getContext());
321   else
322     Context = isl_ast_build_from_context(isl_set_universe(S->getParamSpace()));
323 
324   Context = isl_ast_build_set_at_each_domain(Context, AtEachDomain, NULL);
325 
326   isl_union_map *Schedule = getSchedule();
327 
328   Function *F = Scop->getRegion().getEntry()->getParent();
329   (void)F;
330 
331   DEBUG(dbgs() << ":: isl ast :: " << F->getName()
332                << " :: " << Scop->getRegion().getNameStr() << "\n");
333 
334   DEBUG(dbgs() << S->getContextStr() << "\n";
335         isl_union_map_dump(Schedule));
336 
337   if (DetectParallel || PollyVectorizerChoice != VECTORIZER_NONE) {
338     BuildInfo.Deps = &D;
339     BuildInfo.InParallelFor = 0;
340 
341     Context = isl_ast_build_set_before_each_for(Context, &astBuildBeforeFor,
342                                                 &BuildInfo);
343     Context = isl_ast_build_set_after_each_for(Context, &astBuildAfterFor,
344                                                &BuildInfo);
345   }
346 
347   Root = isl_ast_build_ast_from_schedule(Context, Schedule);
348 
349   isl_ast_build_free(Context);
350 
351   DEBUG(pprint(dbgs()));
352 }
353 
354 __isl_give isl_union_map *IslAst::getSchedule() {
355   isl_union_map *Schedule = isl_union_map_empty(S->getParamSpace());
356 
357   for (Scop::iterator SI = S->begin(), SE = S->end(); SI != SE; ++SI) {
358     ScopStmt *Stmt = *SI;
359     isl_map *StmtSchedule = Stmt->getScattering();
360 
361     StmtSchedule = isl_map_intersect_domain(StmtSchedule, Stmt->getDomain());
362     Schedule =
363         isl_union_map_union(Schedule, isl_union_map_from_map(StmtSchedule));
364   }
365 
366   return Schedule;
367 }
368 
369 IslAst::~IslAst() { isl_ast_node_free(Root); }
370 
371 /// Print a C like representation of the program.
372 void IslAst::pprint(llvm::raw_ostream &OS) {
373   isl_ast_node *Root;
374   isl_ast_print_options *Options;
375 
376   Options = isl_ast_print_options_alloc(S->getIslCtx());
377   Options = isl_ast_print_options_set_print_for(Options, &printFor, NULL);
378 
379   isl_printer *P = isl_printer_to_str(S->getIslCtx());
380   P = isl_printer_set_output_format(P, ISL_FORMAT_C);
381   Root = getAst();
382   P = isl_ast_node_print(Root, P, Options);
383   char *result = isl_printer_get_str(P);
384   OS << result << "\n";
385   isl_printer_free(P);
386   isl_ast_node_free(Root);
387 }
388 
389 /// Create the isl_ast from this program.
390 __isl_give isl_ast_node *IslAst::getAst() { return isl_ast_node_copy(Root); }
391 
392 void IslAstInfo::pprint(llvm::raw_ostream &OS) { Ast->pprint(OS); }
393 
394 void IslAstInfo::releaseMemory() {
395   if (Ast) {
396     delete Ast;
397     Ast = 0;
398   }
399 }
400 
401 bool IslAstInfo::runOnScop(Scop &Scop) {
402   if (Ast)
403     delete Ast;
404 
405   S = &Scop;
406 
407   Dependences &D = getAnalysis<Dependences>();
408 
409   Ast = new IslAst(&Scop, D);
410 
411   return false;
412 }
413 
414 __isl_give isl_ast_node *IslAstInfo::getAst() { return Ast->getAst(); }
415 
416 void IslAstInfo::printScop(raw_ostream &OS) const {
417   Function *F = S->getRegion().getEntry()->getParent();
418 
419   OS << F->getName() << "():\n";
420 
421   Ast->pprint(OS);
422 }
423 
424 void IslAstInfo::getAnalysisUsage(AnalysisUsage &AU) const {
425   // Get the Common analysis usage of ScopPasses.
426   ScopPass::getAnalysisUsage(AU);
427   AU.addRequired<ScopInfo>();
428   AU.addRequired<Dependences>();
429 }
430 
431 char IslAstInfo::ID = 0;
432 
433 Pass *polly::createIslAstInfoPass() { return new IslAstInfo(); }
434 
435 INITIALIZE_PASS_BEGIN(IslAstInfo, "polly-ast",
436                       "Generate an AST of the SCoP (isl)", false, false);
437 INITIALIZE_PASS_DEPENDENCY(ScopInfo);
438 INITIALIZE_PASS_DEPENDENCY(Dependences);
439 INITIALIZE_PASS_END(IslAstInfo, "polly-ast",
440                     "Generate an AST from the SCoP (isl)", false, false)
441 
442