//===- IslAst.cpp - isl code generator interface --------------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// The isl code generator interface takes a Scop and generates an isl_ast. This
// isl_ast can either be returned directly or it can be pretty printed to
// stdout.
//
// A typical isl_ast output looks like this:
//
// for (c2 = max(0, ceild(n + m, 2)); c2 <= min(511, floord(5 * n, 3)); c2++) {
//   bb2(c2);
// }
//
//===----------------------------------------------------------------------===//

#include "polly/CodeGen/CodeGeneration.h"
#include "polly/CodeGen/IslAst.h"
#include "polly/DependenceInfo.h"
#include "polly/LinkAllPasses.h"
#include "polly/Options.h"
#include "polly/ScopInfo.h"
#include "polly/Support/GICHelper.h"

#include "llvm/Analysis/RegionInfo.h"
#include "llvm/Support/Debug.h"

#include "isl/union_map.h"
#include "isl/list.h"
#include "isl/ast_build.h"
#include "isl/set.h"
#include "isl/map.h"
#include "isl/aff.h"

#include <cstdlib>

#define DEBUG_TYPE "polly-ast"

using namespace llvm;
using namespace polly;

using IslAstUserPayload = IslAstInfo::IslAstUserPayload;

static cl::opt<bool>
    PollyParallel("polly-parallel",
                  cl::desc("Generate thread parallel code (isl codegen only)"),
                  cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<bool> PollyParallelForce(
    "polly-parallel-force",
    cl::desc(
        "Force generation of thread parallel code ignoring any cost model"),
    cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<bool> UseContext("polly-ast-use-context",
                                cl::desc("Use context"), cl::Hidden,
                                cl::init(false), cl::ZeroOrMore,
                                cl::cat(PollyCategory));

63 static cl::opt<bool> DetectParallel("polly-ast-detect-parallel", 64 cl::desc("Detect parallelism"), cl::Hidden, 65 cl::init(false), cl::ZeroOrMore, 66 cl::cat(PollyCategory)); 67 68 static cl::opt<bool> NoEarlyExit( 69 "polly-no-early-exit", 70 cl::desc("Do not exit early if no benefit of the Polly version was found."), 71 cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory)); 72 73 namespace polly { 74 class IslAst { 75 public: 76 IslAst(Scop *Scop, const Dependences &D); 77 78 ~IslAst(); 79 80 /// Print a source code representation of the program. 81 void pprint(llvm::raw_ostream &OS); 82 83 __isl_give isl_ast_node *getAst(); 84 85 /// @brief Get the run-time conditions for the Scop. 86 __isl_give isl_ast_expr *getRunCondition(); 87 88 private: 89 Scop *S; 90 isl_ast_node *Root; 91 isl_ast_expr *RunCondition; 92 93 void buildRunCondition(__isl_keep isl_ast_build *Build); 94 }; 95 } // End namespace polly. 96 97 /// @brief Free an IslAstUserPayload object pointed to by @p Ptr 98 static void freeIslAstUserPayload(void *Ptr) { 99 delete ((IslAstInfo::IslAstUserPayload *)Ptr); 100 } 101 102 IslAstInfo::IslAstUserPayload::~IslAstUserPayload() { 103 isl_ast_build_free(Build); 104 isl_pw_aff_free(MinimalDependenceDistance); 105 } 106 107 /// @brief Temporary information used when building the ast. 108 struct AstBuildUserInfo { 109 /// @brief Construct and initialize the helper struct for AST creation. 110 AstBuildUserInfo() 111 : Deps(nullptr), InParallelFor(false), LastForNodeId(nullptr) {} 112 113 /// @brief The dependence information used for the parallelism check. 114 const Dependences *Deps; 115 116 /// @brief Flag to indicate that we are inside a parallel for node. 117 bool InParallelFor; 118 119 /// @brief The last iterator id created for the current SCoP. 120 isl_id *LastForNodeId; 121 }; 122 123 /// @brief Print a string @p str in a single line using @p Printer. 
124 static isl_printer *printLine(__isl_take isl_printer *Printer, 125 const std::string &str, 126 __isl_keep isl_pw_aff *PWA = nullptr) { 127 Printer = isl_printer_start_line(Printer); 128 Printer = isl_printer_print_str(Printer, str.c_str()); 129 if (PWA) 130 Printer = isl_printer_print_pw_aff(Printer, PWA); 131 return isl_printer_end_line(Printer); 132 } 133 134 /// @brief Return all broken reductions as a string of clauses (OpenMP style). 135 static const std::string getBrokenReductionsStr(__isl_keep isl_ast_node *Node) { 136 IslAstInfo::MemoryAccessSet *BrokenReductions; 137 std::string str; 138 139 BrokenReductions = IslAstInfo::getBrokenReductions(Node); 140 if (!BrokenReductions || BrokenReductions->empty()) 141 return ""; 142 143 // Map each type of reduction to a comma separated list of the base addresses. 144 std::map<MemoryAccess::ReductionType, std::string> Clauses; 145 for (MemoryAccess *MA : *BrokenReductions) 146 if (MA->isWrite()) 147 Clauses[MA->getReductionType()] += 148 ", " + MA->getBaseAddr()->getName().str(); 149 150 // Now print the reductions sorted by type. Each type will cause a clause 151 // like: reduction (+ : sum0, sum1, sum2) 152 for (const auto &ReductionClause : Clauses) { 153 str += " reduction ("; 154 str += MemoryAccess::getReductionOperatorStr(ReductionClause.first); 155 // Remove the first two symbols (", ") to make the output look pretty. 156 str += " : " + ReductionClause.second.substr(2) + ")"; 157 } 158 159 return str; 160 } 161 162 /// @brief Callback executed for each for node in the ast in order to print it. 
163 static isl_printer *cbPrintFor(__isl_take isl_printer *Printer, 164 __isl_take isl_ast_print_options *Options, 165 __isl_keep isl_ast_node *Node, void *) { 166 167 isl_pw_aff *DD = IslAstInfo::getMinimalDependenceDistance(Node); 168 const std::string BrokenReductionsStr = getBrokenReductionsStr(Node); 169 const std::string KnownParallelStr = "#pragma known-parallel"; 170 const std::string DepDisPragmaStr = "#pragma minimal dependence distance: "; 171 const std::string SimdPragmaStr = "#pragma simd"; 172 const std::string OmpPragmaStr = "#pragma omp parallel for"; 173 174 if (DD) 175 Printer = printLine(Printer, DepDisPragmaStr, DD); 176 177 if (IslAstInfo::isInnermostParallel(Node)) 178 Printer = printLine(Printer, SimdPragmaStr + BrokenReductionsStr); 179 180 if (IslAstInfo::isExecutedInParallel(Node)) 181 Printer = printLine(Printer, OmpPragmaStr); 182 else if (IslAstInfo::isOutermostParallel(Node)) 183 Printer = printLine(Printer, KnownParallelStr + BrokenReductionsStr); 184 185 isl_pw_aff_free(DD); 186 return isl_ast_node_for_print(Node, Printer, Options); 187 } 188 189 /// @brief Check if the current scheduling dimension is parallel 190 /// 191 /// In case the dimension is parallel we also check if any reduction 192 /// dependences is broken when we exploit this parallelism. If so, 193 /// @p IsReductionParallel will be set to true. The reduction dependences we use 194 /// to check are actually the union of the transitive closure of the initial 195 /// reduction dependences together with their reveresal. Even though these 196 /// dependences connect all iterations with each other (thus they are cyclic) 197 /// we can perform the parallelism check as we are only interested in a zero 198 /// (or non-zero) dependence distance on the dimension in question. 
199 static bool astScheduleDimIsParallel(__isl_keep isl_ast_build *Build, 200 const Dependences *D, 201 IslAstUserPayload *NodeInfo) { 202 if (!D->hasValidDependences()) 203 return false; 204 205 isl_union_map *Schedule = isl_ast_build_get_schedule(Build); 206 isl_union_map *Deps = D->getDependences( 207 Dependences::TYPE_RAW | Dependences::TYPE_WAW | Dependences::TYPE_WAR); 208 209 if (!D->isParallel(Schedule, Deps, &NodeInfo->MinimalDependenceDistance) && 210 !isl_union_map_free(Schedule)) 211 return false; 212 213 isl_union_map *RedDeps = D->getDependences(Dependences::TYPE_TC_RED); 214 if (!D->isParallel(Schedule, RedDeps)) 215 NodeInfo->IsReductionParallel = true; 216 217 if (!NodeInfo->IsReductionParallel && !isl_union_map_free(Schedule)) 218 return true; 219 220 // Annotate reduction parallel nodes with the memory accesses which caused the 221 // reduction dependences parallel execution of the node conflicts with. 222 for (const auto &MaRedPair : D->getReductionDependences()) { 223 if (!MaRedPair.second) 224 continue; 225 RedDeps = isl_union_map_from_map(isl_map_copy(MaRedPair.second)); 226 if (!D->isParallel(Schedule, RedDeps)) 227 NodeInfo->BrokenReductions.insert(MaRedPair.first); 228 } 229 230 isl_union_map_free(Schedule); 231 return true; 232 } 233 234 // This method is executed before the construction of a for node. It creates 235 // an isl_id that is used to annotate the subsequently generated ast for nodes. 
236 // 237 // In this function we also run the following analyses: 238 // 239 // - Detection of openmp parallel loops 240 // 241 static __isl_give isl_id *astBuildBeforeFor(__isl_keep isl_ast_build *Build, 242 void *User) { 243 AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User; 244 IslAstUserPayload *Payload = new IslAstUserPayload(); 245 isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", Payload); 246 Id = isl_id_set_free_user(Id, freeIslAstUserPayload); 247 BuildInfo->LastForNodeId = Id; 248 249 // Test for parallelism only if we are not already inside a parallel loop 250 if (!BuildInfo->InParallelFor) 251 BuildInfo->InParallelFor = Payload->IsOutermostParallel = 252 astScheduleDimIsParallel(Build, BuildInfo->Deps, Payload); 253 254 return Id; 255 } 256 257 // This method is executed after the construction of a for node. 258 // 259 // It performs the following actions: 260 // 261 // - Reset the 'InParallelFor' flag, as soon as we leave a for node, 262 // that is marked as openmp parallel. 263 // 264 static __isl_give isl_ast_node * 265 astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build, 266 void *User) { 267 isl_id *Id = isl_ast_node_get_annotation(Node); 268 assert(Id && "Post order visit assumes annotated for nodes"); 269 IslAstUserPayload *Payload = (IslAstUserPayload *)isl_id_get_user(Id); 270 assert(Payload && "Post order visit assumes annotated for nodes"); 271 272 AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User; 273 assert(!Payload->Build && "Build environment already set"); 274 Payload->Build = isl_ast_build_copy(Build); 275 Payload->IsInnermost = (Id == BuildInfo->LastForNodeId); 276 277 // Innermost loops that are surrounded by parallel loops have not yet been 278 // tested for parallelism. Test them here to ensure we check all innermost 279 // loops for parallelism. 
280 if (Payload->IsInnermost && BuildInfo->InParallelFor) { 281 if (Payload->IsOutermostParallel) 282 Payload->IsInnermostParallel = true; 283 else 284 Payload->IsInnermostParallel = 285 astScheduleDimIsParallel(Build, BuildInfo->Deps, Payload); 286 } 287 if (Payload->IsOutermostParallel) 288 BuildInfo->InParallelFor = false; 289 290 isl_id_free(Id); 291 return Node; 292 } 293 294 static __isl_give isl_ast_node *AtEachDomain(__isl_take isl_ast_node *Node, 295 __isl_keep isl_ast_build *Build, 296 void *User) { 297 assert(!isl_ast_node_get_annotation(Node) && "Node already annotated"); 298 299 IslAstUserPayload *Payload = new IslAstUserPayload(); 300 isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", Payload); 301 Id = isl_id_set_free_user(Id, freeIslAstUserPayload); 302 303 Payload->Build = isl_ast_build_copy(Build); 304 305 return isl_ast_node_set_annotation(Node, Id); 306 } 307 308 void IslAst::buildRunCondition(__isl_keep isl_ast_build *Build) { 309 // The conditions that need to be checked at run-time for this scop are 310 // available as an isl_set in the AssumedContext from which we can directly 311 // derive a run-time condition. 312 RunCondition = isl_ast_build_expr_from_set(Build, S->getAssumedContext()); 313 314 // Create the alias checks from the minimal/maximal accesses in each alias 315 // group. This operation is by construction quadratic in the number of 316 // elements in each alias group. 
317 isl_ast_expr *NonAliasGroup, *MinExpr, *MaxExpr; 318 for (const Scop::MinMaxVectorTy *MinMaxAccesses : S->getAliasGroups()) { 319 auto AccEnd = MinMaxAccesses->end(); 320 for (auto AccIt0 = MinMaxAccesses->begin(); AccIt0 != AccEnd; ++AccIt0) { 321 for (auto AccIt1 = AccIt0 + 1; AccIt1 != AccEnd; ++AccIt1) { 322 MinExpr = 323 isl_ast_expr_address_of(isl_ast_build_access_from_pw_multi_aff( 324 Build, isl_pw_multi_aff_copy(AccIt0->first))); 325 MaxExpr = 326 isl_ast_expr_address_of(isl_ast_build_access_from_pw_multi_aff( 327 Build, isl_pw_multi_aff_copy(AccIt1->second))); 328 NonAliasGroup = isl_ast_expr_le(MaxExpr, MinExpr); 329 MinExpr = 330 isl_ast_expr_address_of(isl_ast_build_access_from_pw_multi_aff( 331 Build, isl_pw_multi_aff_copy(AccIt1->first))); 332 MaxExpr = 333 isl_ast_expr_address_of(isl_ast_build_access_from_pw_multi_aff( 334 Build, isl_pw_multi_aff_copy(AccIt0->second))); 335 NonAliasGroup = 336 isl_ast_expr_or(NonAliasGroup, isl_ast_expr_le(MaxExpr, MinExpr)); 337 RunCondition = isl_ast_expr_and(RunCondition, NonAliasGroup); 338 } 339 } 340 } 341 } 342 343 /// @brief Simple cost analysis for a given SCoP 344 /// 345 /// TODO: Improve this analysis and extract it to make it usable in other 346 /// places too. 347 /// In order to improve the cost model we could either keep track of 348 /// performed optimizations (e.g., tiling) or compute properties on the 349 /// original as well as optimized SCoP (e.g., #stride-one-accesses). 350 static bool benefitsFromPolly(Scop *Scop, bool PerformParallelTest) { 351 352 // First check the user choice. 353 if (NoEarlyExit) 354 return true; 355 356 // Check if nothing interesting happened. 357 if (!PerformParallelTest && !Scop->isOptimized() && 358 Scop->getAliasGroups().empty()) 359 return false; 360 361 // The default assumption is that Polly improves the code. 
362 return true; 363 } 364 365 IslAst::IslAst(Scop *Scop, const Dependences &D) 366 : S(Scop), Root(nullptr), RunCondition(nullptr) { 367 368 bool PerformParallelTest = PollyParallel || DetectParallel || 369 PollyVectorizerChoice != VECTORIZER_NONE; 370 371 // Skip AST and code generation if there was no benefit achieved. 372 if (!benefitsFromPolly(Scop, PerformParallelTest)) 373 return; 374 375 isl_ctx *Ctx = S->getIslCtx(); 376 isl_options_set_ast_build_atomic_upper_bound(Ctx, true); 377 isl_ast_build *Build; 378 AstBuildUserInfo BuildInfo; 379 380 if (UseContext) 381 Build = isl_ast_build_from_context(S->getContext()); 382 else 383 Build = isl_ast_build_from_context(isl_set_universe(S->getParamSpace())); 384 385 Build = isl_ast_build_set_at_each_domain(Build, AtEachDomain, nullptr); 386 387 isl_union_map *Schedule = 388 isl_union_map_intersect_domain(S->getSchedule(), S->getDomains()); 389 390 if (PerformParallelTest) { 391 BuildInfo.Deps = &D; 392 BuildInfo.InParallelFor = 0; 393 394 Build = isl_ast_build_set_before_each_for(Build, &astBuildBeforeFor, 395 &BuildInfo); 396 Build = 397 isl_ast_build_set_after_each_for(Build, &astBuildAfterFor, &BuildInfo); 398 } 399 400 buildRunCondition(Build); 401 402 Root = isl_ast_build_ast_from_schedule(Build, Schedule); 403 404 isl_ast_build_free(Build); 405 } 406 407 IslAst::~IslAst() { 408 isl_ast_node_free(Root); 409 isl_ast_expr_free(RunCondition); 410 } 411 412 __isl_give isl_ast_node *IslAst::getAst() { return isl_ast_node_copy(Root); } 413 __isl_give isl_ast_expr *IslAst::getRunCondition() { 414 return isl_ast_expr_copy(RunCondition); 415 } 416 417 void IslAstInfo::releaseMemory() { 418 if (Ast) { 419 delete Ast; 420 Ast = nullptr; 421 } 422 } 423 424 bool IslAstInfo::runOnScop(Scop &Scop) { 425 if (Ast) 426 delete Ast; 427 428 S = &Scop; 429 430 const Dependences &D = getAnalysis<DependenceInfo>().getDependences(); 431 432 Ast = new IslAst(&Scop, D); 433 434 DEBUG(printScop(dbgs(), Scop)); 435 return false; 436 } 
437 438 __isl_give isl_ast_node *IslAstInfo::getAst() const { return Ast->getAst(); } 439 __isl_give isl_ast_expr *IslAstInfo::getRunCondition() const { 440 return Ast->getRunCondition(); 441 } 442 443 IslAstUserPayload *IslAstInfo::getNodePayload(__isl_keep isl_ast_node *Node) { 444 isl_id *Id = isl_ast_node_get_annotation(Node); 445 if (!Id) 446 return nullptr; 447 IslAstUserPayload *Payload = (IslAstUserPayload *)isl_id_get_user(Id); 448 isl_id_free(Id); 449 return Payload; 450 } 451 452 bool IslAstInfo::isInnermost(__isl_keep isl_ast_node *Node) { 453 IslAstUserPayload *Payload = getNodePayload(Node); 454 return Payload && Payload->IsInnermost; 455 } 456 457 bool IslAstInfo::isParallel(__isl_keep isl_ast_node *Node) { 458 return IslAstInfo::isInnermostParallel(Node) || 459 IslAstInfo::isOutermostParallel(Node); 460 } 461 462 bool IslAstInfo::isInnermostParallel(__isl_keep isl_ast_node *Node) { 463 IslAstUserPayload *Payload = getNodePayload(Node); 464 return Payload && Payload->IsInnermostParallel; 465 } 466 467 bool IslAstInfo::isOutermostParallel(__isl_keep isl_ast_node *Node) { 468 IslAstUserPayload *Payload = getNodePayload(Node); 469 return Payload && Payload->IsOutermostParallel; 470 } 471 472 bool IslAstInfo::isReductionParallel(__isl_keep isl_ast_node *Node) { 473 IslAstUserPayload *Payload = getNodePayload(Node); 474 return Payload && Payload->IsReductionParallel; 475 } 476 477 bool IslAstInfo::isExecutedInParallel(__isl_keep isl_ast_node *Node) { 478 479 if (!PollyParallel) 480 return false; 481 482 // Do not parallelize innermost loops. 483 // 484 // Parallelizing innermost loops is often not profitable, especially if 485 // they have a low number of iterations. 486 // 487 // TODO: Decide this based on the number of loop iterations that will be 488 // executed. This can possibly require run-time checks, which again 489 // raises the question of both run-time check overhead and code size 490 // costs. 
491 if (!PollyParallelForce && isInnermost(Node)) 492 return false; 493 494 return isOutermostParallel(Node) && !isReductionParallel(Node); 495 } 496 497 isl_union_map *IslAstInfo::getSchedule(__isl_keep isl_ast_node *Node) { 498 IslAstUserPayload *Payload = getNodePayload(Node); 499 return Payload ? isl_ast_build_get_schedule(Payload->Build) : nullptr; 500 } 501 502 isl_pw_aff * 503 IslAstInfo::getMinimalDependenceDistance(__isl_keep isl_ast_node *Node) { 504 IslAstUserPayload *Payload = getNodePayload(Node); 505 return Payload ? isl_pw_aff_copy(Payload->MinimalDependenceDistance) 506 : nullptr; 507 } 508 509 IslAstInfo::MemoryAccessSet * 510 IslAstInfo::getBrokenReductions(__isl_keep isl_ast_node *Node) { 511 IslAstUserPayload *Payload = getNodePayload(Node); 512 return Payload ? &Payload->BrokenReductions : nullptr; 513 } 514 515 isl_ast_build *IslAstInfo::getBuild(__isl_keep isl_ast_node *Node) { 516 IslAstUserPayload *Payload = getNodePayload(Node); 517 return Payload ? Payload->Build : nullptr; 518 } 519 520 void IslAstInfo::printScop(raw_ostream &OS, Scop &S) const { 521 isl_ast_print_options *Options; 522 isl_ast_node *RootNode = getAst(); 523 Function *F = S.getRegion().getEntry()->getParent(); 524 525 OS << ":: isl ast :: " << F->getName() << " :: " << S.getRegion().getNameStr() 526 << "\n"; 527 528 if (!RootNode) { 529 OS << ":: isl ast generation and code generation was skipped!\n\n"; 530 return; 531 } 532 533 isl_ast_expr *RunCondition = getRunCondition(); 534 char *RtCStr, *AstStr; 535 536 Options = isl_ast_print_options_alloc(S.getIslCtx()); 537 Options = isl_ast_print_options_set_print_for(Options, cbPrintFor, nullptr); 538 539 isl_printer *P = isl_printer_to_str(S.getIslCtx()); 540 P = isl_printer_print_ast_expr(P, RunCondition); 541 RtCStr = isl_printer_get_str(P); 542 P = isl_printer_flush(P); 543 P = isl_printer_indent(P, 4); 544 P = isl_printer_set_output_format(P, ISL_FORMAT_C); 545 P = isl_ast_node_print(RootNode, P, Options); 546 AstStr = 
isl_printer_get_str(P); 547 548 isl_union_map *Schedule = 549 isl_union_map_intersect_domain(S.getSchedule(), S.getDomains()); 550 551 DEBUG({ 552 dbgs() << S.getContextStr() << "\n"; 553 dbgs() << stringFromIslObj(Schedule); 554 }); 555 OS << "\nif (" << RtCStr << ")\n\n"; 556 OS << AstStr << "\n"; 557 OS << "else\n"; 558 OS << " { /* original code */ }\n\n"; 559 560 isl_ast_expr_free(RunCondition); 561 isl_union_map_free(Schedule); 562 isl_ast_node_free(RootNode); 563 isl_printer_free(P); 564 } 565 566 void IslAstInfo::getAnalysisUsage(AnalysisUsage &AU) const { 567 // Get the Common analysis usage of ScopPasses. 568 ScopPass::getAnalysisUsage(AU); 569 AU.addRequired<ScopInfo>(); 570 AU.addRequired<DependenceInfo>(); 571 } 572 573 char IslAstInfo::ID = 0; 574 575 Pass *polly::createIslAstInfoPass() { return new IslAstInfo(); } 576 577 INITIALIZE_PASS_BEGIN(IslAstInfo, "polly-ast", 578 "Polly - Generate an AST of the SCoP (isl)", false, 579 false); 580 INITIALIZE_PASS_DEPENDENCY(ScopInfo); 581 INITIALIZE_PASS_DEPENDENCY(DependenceInfo); 582 INITIALIZE_PASS_END(IslAstInfo, "polly-ast", 583 "Polly - Generate an AST from the SCoP (isl)", false, false) 584