1 //===- IslAst.cpp - isl code generator interface --------------------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // The isl code generator interface takes a Scop and generates a isl_ast. This 11 // ist_ast can either be returned directly or it can be pretty printed to 12 // stdout. 13 // 14 // A typical isl_ast output looks like this: 15 // 16 // for (c2 = max(0, ceild(n + m, 2); c2 <= min(511, floord(5 * n, 3)); c2++) { 17 // bb2(c2); 18 // } 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "polly/CodeGen/CodeGeneration.h" 23 #include "polly/CodeGen/IslAst.h" 24 #include "polly/Dependences.h" 25 #include "polly/LinkAllPasses.h" 26 #include "polly/Options.h" 27 #include "polly/ScopInfo.h" 28 29 #define DEBUG_TYPE "polly-ast" 30 #include "llvm/Support/Debug.h" 31 32 #include "isl/union_map.h" 33 #include "isl/list.h" 34 #include "isl/ast_build.h" 35 #include "isl/set.h" 36 #include "isl/map.h" 37 #include "isl/aff.h" 38 39 using namespace llvm; 40 using namespace polly; 41 42 static cl::opt<bool> UseContext("polly-ast-use-context", 43 cl::desc("Use context"), cl::Hidden, 44 cl::init(false), cl::ZeroOrMore, 45 cl::cat(PollyCategory)); 46 47 static cl::opt<bool> DetectParallel("polly-ast-detect-parallel", 48 cl::desc("Detect parallelism"), cl::Hidden, 49 cl::init(false), cl::ZeroOrMore, 50 cl::cat(PollyCategory)); 51 52 namespace polly { 53 class IslAst { 54 public: 55 IslAst(Scop *Scop, Dependences &D); 56 57 ~IslAst(); 58 59 /// Print a source code representation of the program. 60 void pprint(llvm::raw_ostream &OS); 61 62 __isl_give isl_ast_node *getAst(); 63 64 /// @brief Get the run-time conditions for the Scop. 65 __isl_give isl_ast_expr *getRunCondition(); 66 67 private: 68 Scop *S; 69 isl_ast_node *Root; 70 isl_ast_expr *RunCondition; 71 72 __isl_give isl_union_map *getSchedule(); 73 void buildRunCondition(__isl_keep isl_ast_build *Context); 74 }; 75 } // End namespace polly. 76 77 // Temporary information used when building the ast. 78 struct AstBuildUserInfo { 79 // The dependence information. 80 Dependences *Deps; 81 82 // We are inside a parallel for node. 83 int InParallelFor; 84 }; 85 86 // Print a loop annotated with OpenMP or vector pragmas. 87 static __isl_give isl_printer * 88 printParallelFor(__isl_keep isl_ast_node *Node, __isl_take isl_printer *Printer, 89 __isl_take isl_ast_print_options *PrintOptions, 90 IslAstUser *Info) { 91 if (Info) { 92 if (Info->IsInnermostParallel) { 93 Printer = isl_printer_start_line(Printer); 94 Printer = isl_printer_print_str(Printer, "#pragma simd"); 95 Printer = isl_printer_end_line(Printer); 96 } 97 if (Info->IsOutermostParallel) { 98 Printer = isl_printer_start_line(Printer); 99 Printer = isl_printer_print_str(Printer, "#pragma omp parallel for"); 100 Printer = isl_printer_end_line(Printer); 101 } 102 } 103 return isl_ast_node_for_print(Node, Printer, PrintOptions); 104 } 105 106 // Print an isl_ast_for. 107 static __isl_give isl_printer * 108 printFor(__isl_take isl_printer *Printer, 109 __isl_take isl_ast_print_options *PrintOptions, 110 __isl_keep isl_ast_node *Node, void *User) { 111 isl_id *Id = isl_ast_node_get_annotation(Node); 112 if (!Id) 113 return isl_ast_node_for_print(Node, Printer, PrintOptions); 114 115 struct IslAstUser *Info = (struct IslAstUser *)isl_id_get_user(Id); 116 Printer = printParallelFor(Node, Printer, PrintOptions, Info); 117 isl_id_free(Id); 118 return Printer; 119 } 120 121 // Allocate an AstNodeInfo structure and initialize it with default values. 122 static struct IslAstUser *allocateIslAstUser() { 123 struct IslAstUser *NodeInfo; 124 NodeInfo = (struct IslAstUser *)malloc(sizeof(struct IslAstUser)); 125 NodeInfo->PMA = 0; 126 NodeInfo->Context = 0; 127 NodeInfo->IsOutermostParallel = 0; 128 NodeInfo->IsInnermostParallel = 0; 129 return NodeInfo; 130 } 131 132 // Free the AstNodeInfo structure. 133 static void freeIslAstUser(void *Ptr) { 134 struct IslAstUser *UserStruct = (struct IslAstUser *)Ptr; 135 isl_ast_build_free(UserStruct->Context); 136 isl_pw_multi_aff_free(UserStruct->PMA); 137 free(UserStruct); 138 } 139 140 // Check if the current scheduling dimension is parallel. 141 // 142 // We check for parallelism by verifying that the loop does not carry any 143 // dependences. 144 // 145 // Parallelism test: if the distance is zero in all outer dimensions, then it 146 // has to be zero in the current dimension as well. 147 // 148 // Implementation: first, translate dependences into time space, then force 149 // outer dimensions to be equal. If the distance is zero in the current 150 // dimension, then the loop is parallel. The distance is zero in the current 151 // dimension if it is a subset of a map with equal values for the current 152 // dimension. 153 static bool astScheduleDimIsParallel(__isl_keep isl_ast_build *Build, 154 Dependences *D) { 155 isl_union_map *Schedule, *Deps; 156 isl_map *ScheduleDeps, *Test; 157 isl_space *ScheduleSpace; 158 unsigned Dimension, IsParallel; 159 160 Schedule = isl_ast_build_get_schedule(Build); 161 ScheduleSpace = isl_ast_build_get_schedule_space(Build); 162 163 Dimension = isl_space_dim(ScheduleSpace, isl_dim_out) - 1; 164 165 Deps = D->getDependences(Dependences::TYPE_ALL); 166 Deps = isl_union_map_apply_range(Deps, isl_union_map_copy(Schedule)); 167 Deps = isl_union_map_apply_domain(Deps, Schedule); 168 169 if (isl_union_map_is_empty(Deps)) { 170 isl_union_map_free(Deps); 171 isl_space_free(ScheduleSpace); 172 return 1; 173 } 174 175 ScheduleDeps = isl_map_from_union_map(Deps); 176 177 for (unsigned i = 0; i < Dimension; i++) 178 ScheduleDeps = isl_map_equate(ScheduleDeps, isl_dim_out, i, isl_dim_in, i); 179 180 Test = isl_map_universe(isl_map_get_space(ScheduleDeps)); 181 Test = isl_map_equate(Test, isl_dim_out, Dimension, isl_dim_in, Dimension); 182 IsParallel = isl_map_is_subset(ScheduleDeps, Test); 183 184 isl_space_free(ScheduleSpace); 185 isl_map_free(Test); 186 isl_map_free(ScheduleDeps); 187 188 return IsParallel; 189 } 190 191 // Mark a for node openmp parallel, if it is the outermost parallel for node. 192 static void markOpenmpParallel(__isl_keep isl_ast_build *Build, 193 struct AstBuildUserInfo *BuildInfo, 194 struct IslAstUser *NodeInfo) { 195 if (BuildInfo->InParallelFor) 196 return; 197 198 if (astScheduleDimIsParallel(Build, BuildInfo->Deps)) { 199 BuildInfo->InParallelFor = 1; 200 NodeInfo->IsOutermostParallel = 1; 201 } 202 } 203 204 // This method is executed before the construction of a for node. It creates 205 // an isl_id that is used to annotate the subsequently generated ast for nodes. 206 // 207 // In this function we also run the following analyses: 208 // 209 // - Detection of openmp parallel loops 210 // 211 static __isl_give isl_id *astBuildBeforeFor(__isl_keep isl_ast_build *Build, 212 void *User) { 213 struct AstBuildUserInfo *BuildInfo = (struct AstBuildUserInfo *)User; 214 struct IslAstUser *NodeInfo = allocateIslAstUser(); 215 isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", NodeInfo); 216 Id = isl_id_set_free_user(Id, freeIslAstUser); 217 218 markOpenmpParallel(Build, BuildInfo, NodeInfo); 219 220 return Id; 221 } 222 223 // Returns 0 when Node contains loops, otherwise returns -1. This search 224 // function uses ISL's way to iterate over lists of isl_ast_nodes with 225 // isl_ast_node_list_foreach. Please use the single argument wrapper function 226 // that returns a bool instead of using this function directly. 227 static int containsLoops(__isl_take isl_ast_node *Node, void *User) { 228 if (!Node) 229 return -1; 230 231 switch (isl_ast_node_get_type(Node)) { 232 case isl_ast_node_for: 233 isl_ast_node_free(Node); 234 return 0; 235 case isl_ast_node_block: { 236 isl_ast_node_list *List = isl_ast_node_block_get_children(Node); 237 int Res = isl_ast_node_list_foreach(List, &containsLoops, NULL); 238 isl_ast_node_list_free(List); 239 isl_ast_node_free(Node); 240 return Res; 241 } 242 case isl_ast_node_if: { 243 int Res = -1; 244 if (0 == containsLoops(isl_ast_node_if_get_then(Node), NULL) || 245 (isl_ast_node_if_has_else(Node) && 246 0 == containsLoops(isl_ast_node_if_get_else(Node), NULL))) 247 Res = 0; 248 isl_ast_node_free(Node); 249 return Res; 250 } 251 case isl_ast_node_user: 252 default: 253 isl_ast_node_free(Node); 254 return -1; 255 } 256 } 257 258 // Returns true when Node contains loops. 259 static bool containsLoops(__isl_take isl_ast_node *Node) { 260 return 0 == containsLoops(Node, NULL); 261 } 262 263 // This method is executed after the construction of a for node. 264 // 265 // It performs the following actions: 266 // 267 // - Reset the 'InParallelFor' flag, as soon as we leave a for node, 268 // that is marked as openmp parallel. 269 // 270 static __isl_give isl_ast_node * 271 astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build, 272 void *User) { 273 isl_id *Id = isl_ast_node_get_annotation(Node); 274 if (!Id) 275 return Node; 276 struct IslAstUser *Info = (struct IslAstUser *)isl_id_get_user(Id); 277 struct AstBuildUserInfo *BuildInfo = (struct AstBuildUserInfo *)User; 278 279 if (Info) { 280 if (Info->IsOutermostParallel) 281 BuildInfo->InParallelFor = 0; 282 if (!containsLoops(isl_ast_node_for_get_body(Node))) 283 if (astScheduleDimIsParallel(Build, BuildInfo->Deps)) 284 Info->IsInnermostParallel = 1; 285 if (!Info->Context) 286 Info->Context = isl_ast_build_copy(Build); 287 } 288 289 isl_id_free(Id); 290 return Node; 291 } 292 293 static __isl_give isl_ast_node *AtEachDomain(__isl_take isl_ast_node *Node, 294 __isl_keep isl_ast_build *Context, 295 void *User) { 296 struct IslAstUser *Info = NULL; 297 isl_id *Id = isl_ast_node_get_annotation(Node); 298 299 if (Id) 300 Info = (struct IslAstUser *)isl_id_get_user(Id); 301 302 if (!Info) { 303 // Allocate annotations once: parallel for detection might have already 304 // allocated the annotations for this node. 305 Info = allocateIslAstUser(); 306 Id = isl_id_alloc(isl_ast_node_get_ctx(Node), NULL, Info); 307 Id = isl_id_set_free_user(Id, &freeIslAstUser); 308 } 309 310 if (!Info->PMA) { 311 isl_map *Map = isl_map_from_union_map(isl_ast_build_get_schedule(Context)); 312 Info->PMA = isl_pw_multi_aff_from_map(isl_map_reverse(Map)); 313 } 314 if (!Info->Context) 315 Info->Context = isl_ast_build_copy(Context); 316 317 return isl_ast_node_set_annotation(Node, Id); 318 } 319 320 void IslAst::buildRunCondition(__isl_keep isl_ast_build *Context) { 321 // The conditions that need to be checked at run-time for this scop are 322 // available as an isl_set in the AssumedContext. We generate code for this 323 // check as follows. First, we generate an isl_pw_aff that is 1, if a certain 324 // combination of parameter values fulfills the conditions in the assumed 325 // context, and that is 0 otherwise. We then translate this isl_pw_aff into 326 // an isl_ast_expr. At run-time this expression can be evaluated and the 327 // optimized scop can be executed conditionally according to the result of the 328 // run-time check. 329 330 isl_aff *Zero = 331 isl_aff_zero_on_domain(isl_local_space_from_space(S->getParamSpace())); 332 isl_aff *One = 333 isl_aff_zero_on_domain(isl_local_space_from_space(S->getParamSpace())); 334 335 One = isl_aff_add_constant_si(One, 1); 336 337 isl_pw_aff *PwZero = isl_pw_aff_from_aff(Zero); 338 isl_pw_aff *PwOne = isl_pw_aff_from_aff(One); 339 340 PwOne = isl_pw_aff_intersect_domain(PwOne, S->getAssumedContext()); 341 PwZero = isl_pw_aff_intersect_domain( 342 PwZero, isl_set_complement(S->getAssumedContext())); 343 344 isl_pw_aff *Cond = isl_pw_aff_union_max(PwZero, PwOne); 345 346 RunCondition = isl_ast_build_expr_from_pw_aff(Context, Cond); 347 } 348 349 IslAst::IslAst(Scop *Scop, Dependences &D) : S(Scop) { 350 isl_ctx *Ctx = S->getIslCtx(); 351 isl_options_set_ast_build_atomic_upper_bound(Ctx, true); 352 isl_ast_build *Context; 353 struct AstBuildUserInfo BuildInfo; 354 355 if (UseContext) 356 Context = isl_ast_build_from_context(S->getContext()); 357 else 358 Context = isl_ast_build_from_context(isl_set_universe(S->getParamSpace())); 359 360 Context = isl_ast_build_set_at_each_domain(Context, AtEachDomain, NULL); 361 362 isl_union_map *Schedule = getSchedule(); 363 364 Function *F = Scop->getRegion().getEntry()->getParent(); 365 (void)F; 366 367 DEBUG(dbgs() << ":: isl ast :: " << F->getName() 368 << " :: " << Scop->getRegion().getNameStr() << "\n"); 369 370 DEBUG(dbgs() << S->getContextStr() << "\n"; isl_union_map_dump(Schedule)); 371 372 if (DetectParallel || PollyVectorizerChoice != VECTORIZER_NONE) { 373 BuildInfo.Deps = &D; 374 BuildInfo.InParallelFor = 0; 375 376 Context = isl_ast_build_set_before_each_for(Context, &astBuildBeforeFor, 377 &BuildInfo); 378 Context = isl_ast_build_set_after_each_for(Context, &astBuildAfterFor, 379 &BuildInfo); 380 } 381 382 buildRunCondition(Context); 383 384 Root = isl_ast_build_ast_from_schedule(Context, Schedule); 385 386 isl_ast_build_free(Context); 387 388 DEBUG(pprint(dbgs())); 389 } 390 391 __isl_give isl_union_map *IslAst::getSchedule() { 392 isl_union_map *Schedule = isl_union_map_empty(S->getParamSpace()); 393 394 for (Scop::iterator SI = S->begin(), SE = S->end(); SI != SE; ++SI) { 395 ScopStmt *Stmt = *SI; 396 isl_map *StmtSchedule = Stmt->getScattering(); 397 398 StmtSchedule = isl_map_intersect_domain(StmtSchedule, Stmt->getDomain()); 399 Schedule = 400 isl_union_map_union(Schedule, isl_union_map_from_map(StmtSchedule)); 401 } 402 403 return Schedule; 404 } 405 406 IslAst::~IslAst() { 407 isl_ast_node_free(Root); 408 isl_ast_expr_free(RunCondition); 409 } 410 411 /// Print a C like representation of the program. 412 void IslAst::pprint(llvm::raw_ostream &OS) { 413 isl_ast_node *Root; 414 isl_ast_print_options *Options; 415 416 Options = isl_ast_print_options_alloc(S->getIslCtx()); 417 Options = isl_ast_print_options_set_print_for(Options, &printFor, NULL); 418 419 isl_printer *P = isl_printer_to_str(S->getIslCtx()); 420 P = isl_printer_set_output_format(P, ISL_FORMAT_C); 421 422 P = isl_printer_print_ast_expr(P, RunCondition); 423 char *result = isl_printer_get_str(P); 424 P = isl_printer_flush(P); 425 426 OS << "\nif (" << result << ")\n\n"; 427 P = isl_printer_indent(P, 4); 428 429 Root = getAst(); 430 P = isl_ast_node_print(Root, P, Options); 431 result = isl_printer_get_str(P); 432 OS << result << "\n"; 433 OS << "else\n"; 434 OS << " { /* original code */ }\n\n"; 435 isl_printer_free(P); 436 isl_ast_node_free(Root); 437 } 438 439 __isl_give isl_ast_node *IslAst::getAst() { return isl_ast_node_copy(Root); } 440 __isl_give isl_ast_expr *IslAst::getRunCondition() { 441 return isl_ast_expr_copy(RunCondition); 442 } 443 444 void IslAstInfo::pprint(llvm::raw_ostream &OS) { Ast->pprint(OS); } 445 446 void IslAstInfo::releaseMemory() { 447 if (Ast) { 448 delete Ast; 449 Ast = 0; 450 } 451 } 452 453 bool IslAstInfo::runOnScop(Scop &Scop) { 454 if (Ast) 455 delete Ast; 456 457 S = &Scop; 458 459 Dependences &D = getAnalysis<Dependences>(); 460 461 Ast = new IslAst(&Scop, D); 462 463 return false; 464 } 465 466 __isl_give isl_ast_node *IslAstInfo::getAst() { return Ast->getAst(); } 467 __isl_give isl_ast_expr *IslAstInfo::getRunCondition() { 468 return Ast->getRunCondition(); 469 } 470 471 void IslAstInfo::printScop(raw_ostream &OS) const { 472 Function *F = S->getRegion().getEntry()->getParent(); 473 474 OS << F->getName() << "():\n"; 475 476 Ast->pprint(OS); 477 } 478 479 void IslAstInfo::getAnalysisUsage(AnalysisUsage &AU) const { 480 // Get the Common analysis usage of ScopPasses. 481 ScopPass::getAnalysisUsage(AU); 482 AU.addRequired<ScopInfo>(); 483 AU.addRequired<Dependences>(); 484 } 485 486 char IslAstInfo::ID = 0; 487 488 Pass *polly::createIslAstInfoPass() { return new IslAstInfo(); } 489 490 INITIALIZE_PASS_BEGIN(IslAstInfo, "polly-ast", 491 "Polly - Generate an AST of the SCoP (isl)", false, 492 false); 493 INITIALIZE_PASS_DEPENDENCY(ScopInfo); 494 INITIALIZE_PASS_DEPENDENCY(Dependences); 495 INITIALIZE_PASS_END(IslAstInfo, "polly-ast", 496 "Polly - Generate an AST from the SCoP (isl)", false, false) 497