1 //===- IslAst.cpp - isl code generator interface --------------------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // The isl code generator interface takes a Scop and generates a isl_ast. This 11 // ist_ast can either be returned directly or it can be pretty printed to 12 // stdout. 13 // 14 // A typical isl_ast output looks like this: 15 // 16 // for (c2 = max(0, ceild(n + m, 2); c2 <= min(511, floord(5 * n, 3)); c2++) { 17 // bb2(c2); 18 // } 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "polly/CodeGen/CodeGeneration.h" 23 #include "polly/CodeGen/IslAst.h" 24 #include "polly/Dependences.h" 25 #include "polly/LinkAllPasses.h" 26 #include "polly/Options.h" 27 #include "polly/ScopInfo.h" 28 #include "polly/Support/GICHelper.h" 29 #include "llvm/Support/Debug.h" 30 31 #include "isl/union_map.h" 32 #include "isl/list.h" 33 #include "isl/ast_build.h" 34 #include "isl/set.h" 35 #include "isl/map.h" 36 #include "isl/aff.h" 37 38 #define DEBUG_TYPE "polly-ast" 39 40 using namespace llvm; 41 using namespace polly; 42 43 using IslAstUserPayload = IslAstInfo::IslAstUserPayload; 44 45 static cl::opt<bool> 46 PollyParallel("polly-parallel", 47 cl::desc("Generate thread parallel code (isl codegen only)"), 48 cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory)); 49 50 static cl::opt<bool> PollyParallelForce( 51 "polly-parallel-force", 52 cl::desc( 53 "Force generation of thread parallel code ignoring any cost model"), 54 cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory)); 55 56 static cl::opt<bool> UseContext("polly-ast-use-context", 57 cl::desc("Use context"), cl::Hidden, 58 cl::init(false), cl::ZeroOrMore, 59 cl::cat(PollyCategory)); 60 61 static cl::opt<bool> DetectParallel("polly-ast-detect-parallel", 62 cl::desc("Detect parallelism"), cl::Hidden, 63 cl::init(false), cl::ZeroOrMore, 64 cl::cat(PollyCategory)); 65 66 namespace polly { 67 class IslAst { 68 public: 69 IslAst(Scop *Scop, Dependences &D); 70 71 ~IslAst(); 72 73 /// Print a source code representation of the program. 74 void pprint(llvm::raw_ostream &OS); 75 76 __isl_give isl_ast_node *getAst(); 77 78 /// @brief Get the run-time conditions for the Scop. 79 __isl_give isl_ast_expr *getRunCondition(); 80 81 private: 82 Scop *S; 83 isl_ast_node *Root; 84 isl_ast_expr *RunCondition; 85 86 void buildRunCondition(__isl_keep isl_ast_build *Build); 87 }; 88 } // End namespace polly. 89 90 /// @brief Free an IslAstUserPayload object pointed to by @p Ptr 91 static void freeIslAstUserPayload(void *Ptr) { 92 delete ((IslAstInfo::IslAstUserPayload *)Ptr); 93 } 94 95 IslAstInfo::IslAstUserPayload::~IslAstUserPayload() { 96 isl_ast_build_free(Build); 97 isl_pw_aff_free(MinimalDependenceDistance); 98 } 99 100 /// @brief Temporary information used when building the ast. 101 struct AstBuildUserInfo { 102 /// @brief Construct and initialize the helper struct for AST creation. 103 AstBuildUserInfo() 104 : Deps(nullptr), InParallelFor(false), LastForNodeId(nullptr) {} 105 106 /// @brief The dependence information used for the parallelism check. 107 Dependences *Deps; 108 109 /// @brief Flag to indicate that we are inside a parallel for node. 110 bool InParallelFor; 111 112 /// @brief The last iterator id created for the current SCoP. 113 isl_id *LastForNodeId; 114 }; 115 116 /// @brief Print a string @p str in a single line using @p Printer. 117 static isl_printer *printLine(__isl_take isl_printer *Printer, 118 const std::string &str, 119 __isl_keep isl_pw_aff *PWA = nullptr) { 120 Printer = isl_printer_start_line(Printer); 121 Printer = isl_printer_print_str(Printer, str.c_str()); 122 if (PWA) 123 Printer = isl_printer_print_pw_aff(Printer, PWA); 124 return isl_printer_end_line(Printer); 125 } 126 127 /// @brief Return all broken reductions as a string of clauses (OpenMP style). 128 static const std::string getBrokenReductionsStr(__isl_keep isl_ast_node *Node) { 129 IslAstInfo::MemoryAccessSet *BrokenReductions; 130 std::string str; 131 132 BrokenReductions = IslAstInfo::getBrokenReductions(Node); 133 if (!BrokenReductions || BrokenReductions->empty()) 134 return ""; 135 136 // Map each type of reduction to a comma separated list of the base addresses. 137 std::map<MemoryAccess::ReductionType, std::string> Clauses; 138 for (MemoryAccess *MA : *BrokenReductions) 139 if (MA->isWrite()) 140 Clauses[MA->getReductionType()] += 141 ", " + MA->getBaseAddr()->getName().str(); 142 143 // Now print the reductions sorted by type. Each type will cause a clause 144 // like: reduction (+ : sum0, sum1, sum2) 145 for (const auto &ReductionClause : Clauses) { 146 str += " reduction ("; 147 str += MemoryAccess::getReductionOperatorStr(ReductionClause.first); 148 // Remove the first two symbols (", ") to make the output look pretty. 149 str += " : " + ReductionClause.second.substr(2) + ")"; 150 } 151 152 return str; 153 } 154 155 /// @brief Callback executed for each for node in the ast in order to print it. 156 static isl_printer *cbPrintFor(__isl_take isl_printer *Printer, 157 __isl_take isl_ast_print_options *Options, 158 __isl_keep isl_ast_node *Node, void *) { 159 160 isl_pw_aff *DD = IslAstInfo::getMinimalDependenceDistance(Node); 161 const std::string BrokenReductionsStr = getBrokenReductionsStr(Node); 162 const std::string KnownParallelStr = "#pragma known-parallel"; 163 const std::string DepDisPragmaStr = "#pragma minimal dependence distance: "; 164 const std::string SimdPragmaStr = "#pragma simd"; 165 const std::string OmpPragmaStr = "#pragma omp parallel for"; 166 167 if (DD) 168 Printer = printLine(Printer, DepDisPragmaStr, DD); 169 170 if (IslAstInfo::isInnermostParallel(Node)) 171 Printer = printLine(Printer, SimdPragmaStr + BrokenReductionsStr); 172 173 if (IslAstInfo::isExecutedInParallel(Node)) 174 Printer = printLine(Printer, OmpPragmaStr); 175 else if (IslAstInfo::isOutermostParallel(Node)) 176 Printer = printLine(Printer, KnownParallelStr + BrokenReductionsStr); 177 178 isl_pw_aff_free(DD); 179 return isl_ast_node_for_print(Node, Printer, Options); 180 } 181 182 /// @brief Check if the current scheduling dimension is parallel 183 /// 184 /// In case the dimension is parallel we also check if any reduction 185 /// dependences is broken when we exploit this parallelism. If so, 186 /// @p IsReductionParallel will be set to true. The reduction dependences we use 187 /// to check are actually the union of the transitive closure of the initial 188 /// reduction dependences together with their reveresal. Even though these 189 /// dependences connect all iterations with each other (thus they are cyclic) 190 /// we can perform the parallelism check as we are only interested in a zero 191 /// (or non-zero) dependence distance on the dimension in question. 192 static bool astScheduleDimIsParallel(__isl_keep isl_ast_build *Build, 193 Dependences *D, 194 IslAstUserPayload *NodeInfo) { 195 if (!D->hasValidDependences()) 196 return false; 197 198 isl_union_map *Schedule = isl_ast_build_get_schedule(Build); 199 isl_union_map *Deps = D->getDependences( 200 Dependences::TYPE_RAW | Dependences::TYPE_WAW | Dependences::TYPE_WAR); 201 202 if (!D->isParallel(Schedule, Deps, &NodeInfo->MinimalDependenceDistance) && 203 !isl_union_map_free(Schedule)) 204 return false; 205 206 isl_union_map *RedDeps = D->getDependences(Dependences::TYPE_TC_RED); 207 if (!D->isParallel(Schedule, RedDeps)) 208 NodeInfo->IsReductionParallel = true; 209 210 if (!NodeInfo->IsReductionParallel && !isl_union_map_free(Schedule)) 211 return true; 212 213 // Annotate reduction parallel nodes with the memory accesses which caused the 214 // reduction dependences parallel execution of the node conflicts with. 215 for (const auto &MaRedPair : D->getReductionDependences()) { 216 if (!MaRedPair.second) 217 continue; 218 RedDeps = isl_union_map_from_map(isl_map_copy(MaRedPair.second)); 219 if (!D->isParallel(Schedule, RedDeps)) 220 NodeInfo->BrokenReductions.insert(MaRedPair.first); 221 } 222 223 isl_union_map_free(Schedule); 224 return true; 225 } 226 227 // This method is executed before the construction of a for node. It creates 228 // an isl_id that is used to annotate the subsequently generated ast for nodes. 229 // 230 // In this function we also run the following analyses: 231 // 232 // - Detection of openmp parallel loops 233 // 234 static __isl_give isl_id *astBuildBeforeFor(__isl_keep isl_ast_build *Build, 235 void *User) { 236 AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User; 237 IslAstUserPayload *Payload = new IslAstUserPayload(); 238 isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", Payload); 239 Id = isl_id_set_free_user(Id, freeIslAstUserPayload); 240 BuildInfo->LastForNodeId = Id; 241 242 // Test for parallelism only if we are not already inside a parallel loop 243 if (!BuildInfo->InParallelFor) 244 BuildInfo->InParallelFor = Payload->IsOutermostParallel = 245 astScheduleDimIsParallel(Build, BuildInfo->Deps, Payload); 246 247 return Id; 248 } 249 250 // This method is executed after the construction of a for node. 251 // 252 // It performs the following actions: 253 // 254 // - Reset the 'InParallelFor' flag, as soon as we leave a for node, 255 // that is marked as openmp parallel. 256 // 257 static __isl_give isl_ast_node * 258 astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build, 259 void *User) { 260 isl_id *Id = isl_ast_node_get_annotation(Node); 261 assert(Id && "Post order visit assumes annotated for nodes"); 262 IslAstUserPayload *Payload = (IslAstUserPayload *)isl_id_get_user(Id); 263 assert(Payload && "Post order visit assumes annotated for nodes"); 264 265 AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User; 266 assert(!Payload->Build && "Build environment already set"); 267 Payload->Build = isl_ast_build_copy(Build); 268 Payload->IsInnermost = (Id == BuildInfo->LastForNodeId); 269 270 // Innermost loops that are surrounded by parallel loops have not yet been 271 // tested for parallelism. Test them here to ensure we check all innermost 272 // loops for parallelism. 273 if (Payload->IsInnermost && BuildInfo->InParallelFor) { 274 if (Payload->IsOutermostParallel) 275 Payload->IsInnermostParallel = true; 276 else 277 Payload->IsInnermostParallel = 278 astScheduleDimIsParallel(Build, BuildInfo->Deps, Payload); 279 } 280 if (Payload->IsOutermostParallel) 281 BuildInfo->InParallelFor = false; 282 283 isl_id_free(Id); 284 return Node; 285 } 286 287 static __isl_give isl_ast_node *AtEachDomain(__isl_take isl_ast_node *Node, 288 __isl_keep isl_ast_build *Build, 289 void *User) { 290 assert(!isl_ast_node_get_annotation(Node) && "Node already annotated"); 291 292 IslAstUserPayload *Payload = new IslAstUserPayload(); 293 isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", Payload); 294 Id = isl_id_set_free_user(Id, freeIslAstUserPayload); 295 296 Payload->Build = isl_ast_build_copy(Build); 297 298 return isl_ast_node_set_annotation(Node, Id); 299 } 300 301 void IslAst::buildRunCondition(__isl_keep isl_ast_build *Build) { 302 // The conditions that need to be checked at run-time for this scop are 303 // available as an isl_set in the AssumedContext. We generate code for this 304 // check as follows. First, we generate an isl_pw_aff that is 1, if a certain 305 // combination of parameter values fulfills the conditions in the assumed 306 // context, and that is 0 otherwise. We then translate this isl_pw_aff into 307 // an isl_ast_expr. At run-time this expression can be evaluated and the 308 // optimized scop can be executed conditionally according to the result of the 309 // run-time check. 310 311 isl_aff *Zero = 312 isl_aff_zero_on_domain(isl_local_space_from_space(S->getParamSpace())); 313 isl_aff *One = 314 isl_aff_zero_on_domain(isl_local_space_from_space(S->getParamSpace())); 315 316 One = isl_aff_add_constant_si(One, 1); 317 318 isl_pw_aff *PwZero = isl_pw_aff_from_aff(Zero); 319 isl_pw_aff *PwOne = isl_pw_aff_from_aff(One); 320 321 PwOne = isl_pw_aff_intersect_domain(PwOne, S->getAssumedContext()); 322 PwZero = isl_pw_aff_intersect_domain( 323 PwZero, isl_set_complement(S->getAssumedContext())); 324 325 isl_pw_aff *Cond = isl_pw_aff_union_max(PwOne, PwZero); 326 327 RunCondition = isl_ast_build_expr_from_pw_aff(Build, Cond); 328 329 // Create the alias checks from the minimal/maximal accesses in each alias 330 // group. This operation is by construction quadratic in the number of 331 // elements in each alias group. 332 isl_ast_expr *NonAliasGroup, *MinExpr, *MaxExpr; 333 for (const Scop::MinMaxVectorTy *MinMaxAccesses : S->getAliasGroups()) { 334 auto AccEnd = MinMaxAccesses->end(); 335 for (auto AccIt0 = MinMaxAccesses->begin(); AccIt0 != AccEnd; ++AccIt0) { 336 for (auto AccIt1 = AccIt0 + 1; AccIt1 != AccEnd; ++AccIt1) { 337 MinExpr = 338 isl_ast_expr_address_of(isl_ast_build_access_from_pw_multi_aff( 339 Build, isl_pw_multi_aff_copy(AccIt0->first))); 340 MaxExpr = 341 isl_ast_expr_address_of(isl_ast_build_access_from_pw_multi_aff( 342 Build, isl_pw_multi_aff_copy(AccIt1->second))); 343 NonAliasGroup = isl_ast_expr_le(MaxExpr, MinExpr); 344 MinExpr = 345 isl_ast_expr_address_of(isl_ast_build_access_from_pw_multi_aff( 346 Build, isl_pw_multi_aff_copy(AccIt1->first))); 347 MaxExpr = 348 isl_ast_expr_address_of(isl_ast_build_access_from_pw_multi_aff( 349 Build, isl_pw_multi_aff_copy(AccIt0->second))); 350 NonAliasGroup = 351 isl_ast_expr_or(NonAliasGroup, isl_ast_expr_le(MaxExpr, MinExpr)); 352 RunCondition = isl_ast_expr_and(RunCondition, NonAliasGroup); 353 } 354 } 355 } 356 } 357 358 IslAst::IslAst(Scop *Scop, Dependences &D) : S(Scop) { 359 isl_ctx *Ctx = S->getIslCtx(); 360 isl_options_set_ast_build_atomic_upper_bound(Ctx, true); 361 isl_ast_build *Build; 362 AstBuildUserInfo BuildInfo; 363 364 if (UseContext) 365 Build = isl_ast_build_from_context(S->getContext()); 366 else 367 Build = isl_ast_build_from_context(isl_set_universe(S->getParamSpace())); 368 369 Build = isl_ast_build_set_at_each_domain(Build, AtEachDomain, nullptr); 370 371 isl_union_map *Schedule = 372 isl_union_map_intersect_domain(S->getSchedule(), S->getDomains()); 373 374 if (PollyParallel || DetectParallel || 375 PollyVectorizerChoice != VECTORIZER_NONE) { 376 BuildInfo.Deps = &D; 377 BuildInfo.InParallelFor = 0; 378 379 Build = isl_ast_build_set_before_each_for(Build, &astBuildBeforeFor, 380 &BuildInfo); 381 Build = 382 isl_ast_build_set_after_each_for(Build, &astBuildAfterFor, &BuildInfo); 383 } 384 385 buildRunCondition(Build); 386 387 Root = isl_ast_build_ast_from_schedule(Build, Schedule); 388 389 isl_ast_build_free(Build); 390 } 391 392 IslAst::~IslAst() { 393 isl_ast_node_free(Root); 394 isl_ast_expr_free(RunCondition); 395 } 396 397 __isl_give isl_ast_node *IslAst::getAst() { return isl_ast_node_copy(Root); } 398 __isl_give isl_ast_expr *IslAst::getRunCondition() { 399 return isl_ast_expr_copy(RunCondition); 400 } 401 402 void IslAstInfo::releaseMemory() { 403 if (Ast) { 404 delete Ast; 405 Ast = nullptr; 406 } 407 } 408 409 bool IslAstInfo::runOnScop(Scop &Scop) { 410 if (Ast) 411 delete Ast; 412 413 S = &Scop; 414 415 Dependences &D = getAnalysis<Dependences>(); 416 417 Ast = new IslAst(&Scop, D); 418 419 DEBUG(printScop(dbgs())); 420 return false; 421 } 422 423 __isl_give isl_ast_node *IslAstInfo::getAst() const { return Ast->getAst(); } 424 __isl_give isl_ast_expr *IslAstInfo::getRunCondition() const { 425 return Ast->getRunCondition(); 426 } 427 428 IslAstUserPayload *IslAstInfo::getNodePayload(__isl_keep isl_ast_node *Node) { 429 isl_id *Id = isl_ast_node_get_annotation(Node); 430 if (!Id) 431 return nullptr; 432 IslAstUserPayload *Payload = (IslAstUserPayload *)isl_id_get_user(Id); 433 isl_id_free(Id); 434 return Payload; 435 } 436 437 bool IslAstInfo::isInnermost(__isl_keep isl_ast_node *Node) { 438 IslAstUserPayload *Payload = getNodePayload(Node); 439 return Payload && Payload->IsInnermost; 440 } 441 442 bool IslAstInfo::isParallel(__isl_keep isl_ast_node *Node) { 443 return IslAstInfo::isInnermostParallel(Node) || 444 IslAstInfo::isOutermostParallel(Node); 445 } 446 447 bool IslAstInfo::isInnermostParallel(__isl_keep isl_ast_node *Node) { 448 IslAstUserPayload *Payload = getNodePayload(Node); 449 return Payload && Payload->IsInnermostParallel; 450 } 451 452 bool IslAstInfo::isOutermostParallel(__isl_keep isl_ast_node *Node) { 453 IslAstUserPayload *Payload = getNodePayload(Node); 454 return Payload && Payload->IsOutermostParallel; 455 } 456 457 bool IslAstInfo::isReductionParallel(__isl_keep isl_ast_node *Node) { 458 IslAstUserPayload *Payload = getNodePayload(Node); 459 return Payload && Payload->IsReductionParallel; 460 } 461 462 bool IslAstInfo::isExecutedInParallel(__isl_keep isl_ast_node *Node) { 463 464 if (!PollyParallel) 465 return false; 466 467 // Do not parallelize innermost loops. 468 // 469 // Parallelizing innermost loops is often not profitable, especially if 470 // they have a low number of iterations. 471 // 472 // TODO: Decide this based on the number of loop iterations that will be 473 // executed. This can possibly require run-time checks, which again 474 // raises the question of both run-time check overhead and code size 475 // costs. 476 if (!PollyParallelForce && isInnermost(Node)) 477 return false; 478 479 return isOutermostParallel(Node) && !isReductionParallel(Node); 480 } 481 482 isl_union_map *IslAstInfo::getSchedule(__isl_keep isl_ast_node *Node) { 483 IslAstUserPayload *Payload = getNodePayload(Node); 484 return Payload ? isl_ast_build_get_schedule(Payload->Build) : nullptr; 485 } 486 487 isl_pw_aff * 488 IslAstInfo::getMinimalDependenceDistance(__isl_keep isl_ast_node *Node) { 489 IslAstUserPayload *Payload = getNodePayload(Node); 490 return Payload ? isl_pw_aff_copy(Payload->MinimalDependenceDistance) 491 : nullptr; 492 } 493 494 IslAstInfo::MemoryAccessSet * 495 IslAstInfo::getBrokenReductions(__isl_keep isl_ast_node *Node) { 496 IslAstUserPayload *Payload = getNodePayload(Node); 497 return Payload ? &Payload->BrokenReductions : nullptr; 498 } 499 500 isl_ast_build *IslAstInfo::getBuild(__isl_keep isl_ast_node *Node) { 501 IslAstUserPayload *Payload = getNodePayload(Node); 502 return Payload ? Payload->Build : nullptr; 503 } 504 505 void IslAstInfo::printScop(raw_ostream &OS) const { 506 isl_ast_print_options *Options; 507 isl_ast_node *RootNode = getAst(); 508 isl_ast_expr *RunCondition = getRunCondition(); 509 char *RtCStr, *AstStr; 510 511 Scop &S = getCurScop(); 512 Options = isl_ast_print_options_alloc(S.getIslCtx()); 513 Options = isl_ast_print_options_set_print_for(Options, cbPrintFor, nullptr); 514 515 isl_printer *P = isl_printer_to_str(S.getIslCtx()); 516 P = isl_printer_print_ast_expr(P, RunCondition); 517 RtCStr = isl_printer_get_str(P); 518 P = isl_printer_flush(P); 519 P = isl_printer_indent(P, 4); 520 P = isl_printer_set_output_format(P, ISL_FORMAT_C); 521 P = isl_ast_node_print(RootNode, P, Options); 522 AstStr = isl_printer_get_str(P); 523 524 Function *F = S.getRegion().getEntry()->getParent(); 525 isl_union_map *Schedule = 526 isl_union_map_intersect_domain(S.getSchedule(), S.getDomains()); 527 528 OS << ":: isl ast :: " << F->getName() << " :: " << S.getRegion().getNameStr() 529 << "\n"; 530 DEBUG({ 531 dbgs() << S.getContextStr() << "\n"; 532 dbgs() << stringFromIslObj(Schedule); 533 }); 534 OS << "\nif (" << RtCStr << ")\n\n"; 535 OS << AstStr << "\n"; 536 OS << "else\n"; 537 OS << " { /* original code */ }\n\n"; 538 539 isl_ast_expr_free(RunCondition); 540 isl_union_map_free(Schedule); 541 isl_ast_node_free(RootNode); 542 isl_printer_free(P); 543 } 544 545 void IslAstInfo::getAnalysisUsage(AnalysisUsage &AU) const { 546 // Get the Common analysis usage of ScopPasses. 547 ScopPass::getAnalysisUsage(AU); 548 AU.addRequired<ScopInfo>(); 549 AU.addRequired<Dependences>(); 550 } 551 552 char IslAstInfo::ID = 0; 553 554 Pass *polly::createIslAstInfoPass() { return new IslAstInfo(); } 555 556 INITIALIZE_PASS_BEGIN(IslAstInfo, "polly-ast", 557 "Polly - Generate an AST of the SCoP (isl)", false, 558 false); 559 INITIALIZE_PASS_DEPENDENCY(ScopInfo); 560 INITIALIZE_PASS_DEPENDENCY(Dependences); 561 INITIALIZE_PASS_END(IslAstInfo, "polly-ast", 562 "Polly - Generate an AST from the SCoP (isl)", false, false) 563