1 //===- IslAst.cpp - isl code generator interface --------------------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // The isl code generator interface takes a Scop and generates a isl_ast. This 11 // ist_ast can either be returned directly or it can be pretty printed to 12 // stdout. 13 // 14 // A typical isl_ast output looks like this: 15 // 16 // for (c2 = max(0, ceild(n + m, 2); c2 <= min(511, floord(5 * n, 3)); c2++) { 17 // bb2(c2); 18 // } 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "polly/CodeGen/CodeGeneration.h" 23 #include "polly/CodeGen/IslAst.h" 24 #include "polly/Dependences.h" 25 #include "polly/LinkAllPasses.h" 26 #include "polly/Options.h" 27 #include "polly/ScopInfo.h" 28 #include "polly/Support/GICHelper.h" 29 #include "llvm/Support/Debug.h" 30 31 #include "isl/union_map.h" 32 #include "isl/list.h" 33 #include "isl/ast_build.h" 34 #include "isl/set.h" 35 #include "isl/map.h" 36 #include "isl/aff.h" 37 38 #define DEBUG_TYPE "polly-ast" 39 40 using namespace llvm; 41 using namespace polly; 42 43 using IslAstUserPayload = IslAstInfo::IslAstUserPayload; 44 45 static cl::opt<bool> UseContext("polly-ast-use-context", 46 cl::desc("Use context"), cl::Hidden, 47 cl::init(false), cl::ZeroOrMore, 48 cl::cat(PollyCategory)); 49 50 static cl::opt<bool> DetectParallel("polly-ast-detect-parallel", 51 cl::desc("Detect parallelism"), cl::Hidden, 52 cl::init(false), cl::ZeroOrMore, 53 cl::cat(PollyCategory)); 54 55 namespace polly { 56 class IslAst { 57 public: 58 IslAst(Scop *Scop, Dependences &D); 59 60 ~IslAst(); 61 62 /// Print a source code representation of the program. 63 void pprint(llvm::raw_ostream &OS); 64 65 __isl_give isl_ast_node *getAst(); 66 67 /// @brief Get the run-time conditions for the Scop. 68 __isl_give isl_ast_expr *getRunCondition(); 69 70 private: 71 Scop *S; 72 isl_ast_node *Root; 73 isl_ast_expr *RunCondition; 74 75 void buildRunCondition(__isl_keep isl_ast_build *Build); 76 }; 77 } // End namespace polly. 78 79 /// @brief Free an IslAstUserPayload object pointed to by @p Ptr 80 static void freeIslAstUserPayload(void *Ptr) { 81 delete ((IslAstInfo::IslAstUserPayload *)Ptr); 82 } 83 84 IslAstInfo::IslAstUserPayload::~IslAstUserPayload() { 85 isl_ast_build_free(Build); 86 isl_pw_aff_free(MinimalDependenceDistance); 87 } 88 89 /// @brief Temporary information used when building the ast. 90 struct AstBuildUserInfo { 91 /// @brief Construct and initialize the helper struct for AST creation. 92 AstBuildUserInfo() 93 : Deps(nullptr), InParallelFor(false), LastForNodeId(nullptr) {} 94 95 /// @brief The dependence information used for the parallelism check. 96 Dependences *Deps; 97 98 /// @brief Flag to indicate that we are inside a parallel for node. 99 bool InParallelFor; 100 101 /// @brief The last iterator id created for the current SCoP. 102 isl_id *LastForNodeId; 103 }; 104 105 /// @brief Print a string @p str in a single line using @p Printer. 106 static isl_printer *printLine(__isl_take isl_printer *Printer, 107 const std::string &str, 108 __isl_keep isl_pw_aff *PWA = nullptr) { 109 Printer = isl_printer_start_line(Printer); 110 Printer = isl_printer_print_str(Printer, str.c_str()); 111 if (PWA) 112 Printer = isl_printer_print_pw_aff(Printer, PWA); 113 return isl_printer_end_line(Printer); 114 } 115 116 /// @brief Return all broken reductions as a string of clauses (OpenMP style). 117 static const std::string getBrokenReductionsStr(__isl_keep isl_ast_node *Node) { 118 IslAstInfo::MemoryAccessSet *BrokenReductions; 119 std::string str; 120 121 BrokenReductions = IslAstInfo::getBrokenReductions(Node); 122 if (!BrokenReductions || BrokenReductions->empty()) 123 return ""; 124 125 // Map each type of reduction to a comma separated list of the base addresses. 126 std::map<MemoryAccess::ReductionType, std::string> Clauses; 127 for (MemoryAccess *MA : *BrokenReductions) 128 if (MA->isWrite()) 129 Clauses[MA->getReductionType()] += 130 ", " + MA->getBaseAddr()->getName().str(); 131 132 // Now print the reductions sorted by type. Each type will cause a clause 133 // like: reduction (+ : sum0, sum1, sum2) 134 for (const auto &ReductionClause : Clauses) { 135 str += " reduction ("; 136 str += MemoryAccess::getReductionOperatorStr(ReductionClause.first); 137 // Remove the first two symbols (", ") to make the output look pretty. 138 str += " : " + ReductionClause.second.substr(2) + ")"; 139 } 140 141 return str; 142 } 143 144 /// @brief Callback executed for each for node in the ast in order to print it. 145 static isl_printer *cbPrintFor(__isl_take isl_printer *Printer, 146 __isl_take isl_ast_print_options *Options, 147 __isl_keep isl_ast_node *Node, void *) { 148 149 isl_pw_aff *DD = IslAstInfo::getMinimalDependenceDistance(Node); 150 const std::string BrokenReductionsStr = getBrokenReductionsStr(Node); 151 const std::string DepDisPragmaStr = "#pragma minimal dependence distance: "; 152 const std::string SimdPragmaStr = "#pragma simd"; 153 const std::string OmpPragmaStr = "#pragma omp parallel for"; 154 155 if (DD) 156 Printer = printLine(Printer, DepDisPragmaStr, DD); 157 158 if (IslAstInfo::isInnermostParallel(Node)) 159 Printer = printLine(Printer, SimdPragmaStr + BrokenReductionsStr); 160 161 if (IslAstInfo::isOutermostParallel(Node)) 162 Printer = printLine(Printer, OmpPragmaStr + BrokenReductionsStr); 163 164 isl_pw_aff_free(DD); 165 return isl_ast_node_for_print(Node, Printer, Options); 166 } 167 168 /// @brief Check if the current scheduling dimension is parallel 169 /// 170 /// In case the dimension is parallel we also check if any reduction 171 /// dependences is broken when we exploit this parallelism. If so, 172 /// @p IsReductionParallel will be set to true. The reduction dependences we use 173 /// to check are actually the union of the transitive closure of the initial 174 /// reduction dependences together with their reveresal. Even though these 175 /// dependences connect all iterations with each other (thus they are cyclic) 176 /// we can perform the parallelism check as we are only interested in a zero 177 /// (or non-zero) dependence distance on the dimension in question. 178 static bool astScheduleDimIsParallel(__isl_keep isl_ast_build *Build, 179 Dependences *D, 180 IslAstUserPayload *NodeInfo) { 181 if (!D->hasValidDependences()) 182 return false; 183 184 isl_union_map *Schedule = isl_ast_build_get_schedule(Build); 185 isl_union_map *Deps = D->getDependences( 186 Dependences::TYPE_RAW | Dependences::TYPE_WAW | Dependences::TYPE_WAR); 187 188 if (!D->isParallel(Schedule, Deps, &NodeInfo->MinimalDependenceDistance) && 189 !isl_union_map_free(Schedule)) 190 return false; 191 192 isl_union_map *RedDeps = D->getDependences(Dependences::TYPE_TC_RED); 193 if (!D->isParallel(Schedule, RedDeps)) 194 NodeInfo->IsReductionParallel = true; 195 196 if (!NodeInfo->IsReductionParallel && !isl_union_map_free(Schedule)) 197 return true; 198 199 // Annotate reduction parallel nodes with the memory accesses which caused the 200 // reduction dependences parallel execution of the node conflicts with. 201 for (const auto &MaRedPair : D->getReductionDependences()) { 202 if (!MaRedPair.second) 203 continue; 204 RedDeps = isl_union_map_from_map(isl_map_copy(MaRedPair.second)); 205 if (!D->isParallel(Schedule, RedDeps)) 206 NodeInfo->BrokenReductions.insert(MaRedPair.first); 207 } 208 209 isl_union_map_free(Schedule); 210 return true; 211 } 212 213 // This method is executed before the construction of a for node. It creates 214 // an isl_id that is used to annotate the subsequently generated ast for nodes. 215 // 216 // In this function we also run the following analyses: 217 // 218 // - Detection of openmp parallel loops 219 // 220 static __isl_give isl_id *astBuildBeforeFor(__isl_keep isl_ast_build *Build, 221 void *User) { 222 AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User; 223 IslAstUserPayload *Payload = new IslAstUserPayload(); 224 isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", Payload); 225 Id = isl_id_set_free_user(Id, freeIslAstUserPayload); 226 BuildInfo->LastForNodeId = Id; 227 228 // Test for parallelism only if we are not already inside a parallel loop 229 if (!BuildInfo->InParallelFor) 230 BuildInfo->InParallelFor = Payload->IsOutermostParallel = 231 astScheduleDimIsParallel(Build, BuildInfo->Deps, Payload); 232 233 return Id; 234 } 235 236 // This method is executed after the construction of a for node. 237 // 238 // It performs the following actions: 239 // 240 // - Reset the 'InParallelFor' flag, as soon as we leave a for node, 241 // that is marked as openmp parallel. 242 // 243 static __isl_give isl_ast_node * 244 astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build, 245 void *User) { 246 isl_id *Id = isl_ast_node_get_annotation(Node); 247 assert(Id && "Post order visit assumes annotated for nodes"); 248 IslAstUserPayload *Payload = (IslAstUserPayload *)isl_id_get_user(Id); 249 assert(Payload && "Post order visit assumes annotated for nodes"); 250 251 AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User; 252 assert(!Payload->Build && "Build environment already set"); 253 Payload->Build = isl_ast_build_copy(Build); 254 Payload->IsInnermost = (Id == BuildInfo->LastForNodeId); 255 256 // Innermost loops that are surrounded by parallel loops have not yet been 257 // tested for parallelism. Test them here to ensure we check all innermost 258 // loops for parallelism. 259 if (Payload->IsInnermost && BuildInfo->InParallelFor) { 260 if (Payload->IsOutermostParallel) 261 Payload->IsInnermostParallel = true; 262 else 263 Payload->IsInnermostParallel = 264 astScheduleDimIsParallel(Build, BuildInfo->Deps, Payload); 265 } 266 if (Payload->IsOutermostParallel) 267 BuildInfo->InParallelFor = false; 268 269 isl_id_free(Id); 270 return Node; 271 } 272 273 static __isl_give isl_ast_node *AtEachDomain(__isl_take isl_ast_node *Node, 274 __isl_keep isl_ast_build *Build, 275 void *User) { 276 assert(!isl_ast_node_get_annotation(Node) && "Node already annotated"); 277 278 IslAstUserPayload *Payload = new IslAstUserPayload(); 279 isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", Payload); 280 Id = isl_id_set_free_user(Id, freeIslAstUserPayload); 281 282 Payload->Build = isl_ast_build_copy(Build); 283 284 return isl_ast_node_set_annotation(Node, Id); 285 } 286 287 void IslAst::buildRunCondition(__isl_keep isl_ast_build *Build) { 288 // The conditions that need to be checked at run-time for this scop are 289 // available as an isl_set in the AssumedContext. We generate code for this 290 // check as follows. First, we generate an isl_pw_aff that is 1, if a certain 291 // combination of parameter values fulfills the conditions in the assumed 292 // context, and that is 0 otherwise. We then translate this isl_pw_aff into 293 // an isl_ast_expr. At run-time this expression can be evaluated and the 294 // optimized scop can be executed conditionally according to the result of the 295 // run-time check. 296 297 isl_aff *Zero = 298 isl_aff_zero_on_domain(isl_local_space_from_space(S->getParamSpace())); 299 isl_aff *One = 300 isl_aff_zero_on_domain(isl_local_space_from_space(S->getParamSpace())); 301 302 One = isl_aff_add_constant_si(One, 1); 303 304 isl_pw_aff *PwZero = isl_pw_aff_from_aff(Zero); 305 isl_pw_aff *PwOne = isl_pw_aff_from_aff(One); 306 307 PwOne = isl_pw_aff_intersect_domain(PwOne, S->getAssumedContext()); 308 PwZero = isl_pw_aff_intersect_domain( 309 PwZero, isl_set_complement(S->getAssumedContext())); 310 311 isl_pw_aff *Cond = isl_pw_aff_union_max(PwOne, PwZero); 312 313 RunCondition = isl_ast_build_expr_from_pw_aff(Build, Cond); 314 315 // Create the alias checks from the minimal/maximal accesses in each alias 316 // group. This operation is by construction quadratic in the number of 317 // elements in each alias group. 318 isl_ast_expr *NonAliasGroup, *MinExpr, *MaxExpr; 319 for (const Scop::MinMaxVectorTy *MinMaxAccesses : S->getAliasGroups()) { 320 auto AccEnd = MinMaxAccesses->end(); 321 for (auto AccIt0 = MinMaxAccesses->begin(); AccIt0 != AccEnd; ++AccIt0) { 322 for (auto AccIt1 = AccIt0 + 1; AccIt1 != AccEnd; ++AccIt1) { 323 MinExpr = 324 isl_ast_expr_address_of(isl_ast_build_access_from_pw_multi_aff( 325 Build, isl_pw_multi_aff_copy(AccIt0->first))); 326 MaxExpr = 327 isl_ast_expr_address_of(isl_ast_build_access_from_pw_multi_aff( 328 Build, isl_pw_multi_aff_copy(AccIt1->second))); 329 NonAliasGroup = isl_ast_expr_le(MaxExpr, MinExpr); 330 MinExpr = 331 isl_ast_expr_address_of(isl_ast_build_access_from_pw_multi_aff( 332 Build, isl_pw_multi_aff_copy(AccIt1->first))); 333 MaxExpr = 334 isl_ast_expr_address_of(isl_ast_build_access_from_pw_multi_aff( 335 Build, isl_pw_multi_aff_copy(AccIt0->second))); 336 NonAliasGroup = 337 isl_ast_expr_or(NonAliasGroup, isl_ast_expr_le(MaxExpr, MinExpr)); 338 RunCondition = isl_ast_expr_and(RunCondition, NonAliasGroup); 339 } 340 } 341 } 342 } 343 344 IslAst::IslAst(Scop *Scop, Dependences &D) : S(Scop) { 345 isl_ctx *Ctx = S->getIslCtx(); 346 isl_options_set_ast_build_atomic_upper_bound(Ctx, true); 347 isl_ast_build *Build; 348 AstBuildUserInfo BuildInfo; 349 350 if (UseContext) 351 Build = isl_ast_build_from_context(S->getContext()); 352 else 353 Build = isl_ast_build_from_context(isl_set_universe(S->getParamSpace())); 354 355 Build = isl_ast_build_set_at_each_domain(Build, AtEachDomain, nullptr); 356 357 isl_union_map *Schedule = 358 isl_union_map_intersect_domain(S->getSchedule(), S->getDomains()); 359 360 if (DetectParallel || PollyVectorizerChoice != VECTORIZER_NONE) { 361 BuildInfo.Deps = &D; 362 BuildInfo.InParallelFor = 0; 363 364 Build = isl_ast_build_set_before_each_for(Build, &astBuildBeforeFor, 365 &BuildInfo); 366 Build = 367 isl_ast_build_set_after_each_for(Build, &astBuildAfterFor, &BuildInfo); 368 } 369 370 buildRunCondition(Build); 371 372 Root = isl_ast_build_ast_from_schedule(Build, Schedule); 373 374 isl_ast_build_free(Build); 375 } 376 377 IslAst::~IslAst() { 378 isl_ast_node_free(Root); 379 isl_ast_expr_free(RunCondition); 380 } 381 382 __isl_give isl_ast_node *IslAst::getAst() { return isl_ast_node_copy(Root); } 383 __isl_give isl_ast_expr *IslAst::getRunCondition() { 384 return isl_ast_expr_copy(RunCondition); 385 } 386 387 void IslAstInfo::releaseMemory() { 388 if (Ast) { 389 delete Ast; 390 Ast = 0; 391 } 392 } 393 394 bool IslAstInfo::runOnScop(Scop &Scop) { 395 if (Ast) 396 delete Ast; 397 398 S = &Scop; 399 400 Dependences &D = getAnalysis<Dependences>(); 401 402 Ast = new IslAst(&Scop, D); 403 404 DEBUG(printScop(dbgs())); 405 return false; 406 } 407 408 __isl_give isl_ast_node *IslAstInfo::getAst() const { return Ast->getAst(); } 409 __isl_give isl_ast_expr *IslAstInfo::getRunCondition() const { 410 return Ast->getRunCondition(); 411 } 412 413 IslAstUserPayload *IslAstInfo::getNodePayload(__isl_keep isl_ast_node *Node) { 414 isl_id *Id = isl_ast_node_get_annotation(Node); 415 if (!Id) 416 return nullptr; 417 IslAstUserPayload *Payload = (IslAstUserPayload *)isl_id_get_user(Id); 418 isl_id_free(Id); 419 return Payload; 420 } 421 422 bool IslAstInfo::isInnermost(__isl_keep isl_ast_node *Node) { 423 IslAstUserPayload *Payload = getNodePayload(Node); 424 return Payload && Payload->IsInnermost; 425 } 426 427 bool IslAstInfo::isParallel(__isl_keep isl_ast_node *Node) { 428 return IslAstInfo::isInnermostParallel(Node) || 429 IslAstInfo::isOutermostParallel(Node); 430 } 431 432 bool IslAstInfo::isInnermostParallel(__isl_keep isl_ast_node *Node) { 433 IslAstUserPayload *Payload = getNodePayload(Node); 434 return Payload && Payload->IsInnermostParallel; 435 } 436 437 bool IslAstInfo::isOutermostParallel(__isl_keep isl_ast_node *Node) { 438 IslAstUserPayload *Payload = getNodePayload(Node); 439 return Payload && Payload->IsOutermostParallel; 440 } 441 442 bool IslAstInfo::isReductionParallel(__isl_keep isl_ast_node *Node) { 443 IslAstUserPayload *Payload = getNodePayload(Node); 444 return Payload && Payload->IsReductionParallel; 445 } 446 447 isl_union_map *IslAstInfo::getSchedule(__isl_keep isl_ast_node *Node) { 448 IslAstUserPayload *Payload = getNodePayload(Node); 449 return Payload ? isl_ast_build_get_schedule(Payload->Build) : nullptr; 450 } 451 452 isl_pw_aff * 453 IslAstInfo::getMinimalDependenceDistance(__isl_keep isl_ast_node *Node) { 454 IslAstUserPayload *Payload = getNodePayload(Node); 455 return Payload ? isl_pw_aff_copy(Payload->MinimalDependenceDistance) 456 : nullptr; 457 } 458 459 IslAstInfo::MemoryAccessSet * 460 IslAstInfo::getBrokenReductions(__isl_keep isl_ast_node *Node) { 461 IslAstUserPayload *Payload = getNodePayload(Node); 462 return Payload ? &Payload->BrokenReductions : nullptr; 463 } 464 465 isl_ast_build *IslAstInfo::getBuild(__isl_keep isl_ast_node *Node) { 466 IslAstUserPayload *Payload = getNodePayload(Node); 467 return Payload ? Payload->Build : nullptr; 468 } 469 470 void IslAstInfo::printScop(raw_ostream &OS) const { 471 isl_ast_print_options *Options; 472 isl_ast_node *RootNode = getAst(); 473 isl_ast_expr *RunCondition = getRunCondition(); 474 char *RtCStr, *AstStr; 475 476 Scop &S = getCurScop(); 477 Options = isl_ast_print_options_alloc(S.getIslCtx()); 478 Options = isl_ast_print_options_set_print_for(Options, cbPrintFor, nullptr); 479 480 isl_printer *P = isl_printer_to_str(S.getIslCtx()); 481 P = isl_printer_print_ast_expr(P, RunCondition); 482 RtCStr = isl_printer_get_str(P); 483 P = isl_printer_flush(P); 484 P = isl_printer_indent(P, 4); 485 P = isl_printer_set_output_format(P, ISL_FORMAT_C); 486 P = isl_ast_node_print(RootNode, P, Options); 487 AstStr = isl_printer_get_str(P); 488 489 Function *F = S.getRegion().getEntry()->getParent(); 490 isl_union_map *Schedule = 491 isl_union_map_intersect_domain(S.getSchedule(), S.getDomains()); 492 493 OS << ":: isl ast :: " << F->getName() << " :: " << S.getRegion().getNameStr() 494 << "\n"; 495 DEBUG({ 496 dbgs() << S.getContextStr() << "\n"; 497 dbgs() << stringFromIslObj(Schedule); 498 }); 499 OS << "\nif (" << RtCStr << ")\n\n"; 500 OS << AstStr << "\n"; 501 OS << "else\n"; 502 OS << " { /* original code */ }\n\n"; 503 504 isl_ast_expr_free(RunCondition); 505 isl_union_map_free(Schedule); 506 isl_ast_node_free(RootNode); 507 isl_printer_free(P); 508 } 509 510 void IslAstInfo::getAnalysisUsage(AnalysisUsage &AU) const { 511 // Get the Common analysis usage of ScopPasses. 512 ScopPass::getAnalysisUsage(AU); 513 AU.addRequired<ScopInfo>(); 514 AU.addRequired<Dependences>(); 515 } 516 517 char IslAstInfo::ID = 0; 518 519 Pass *polly::createIslAstInfoPass() { return new IslAstInfo(); } 520 521 INITIALIZE_PASS_BEGIN(IslAstInfo, "polly-ast", 522 "Polly - Generate an AST of the SCoP (isl)", false, 523 false); 524 INITIALIZE_PASS_DEPENDENCY(ScopInfo); 525 INITIALIZE_PASS_DEPENDENCY(Dependences); 526 INITIALIZE_PASS_END(IslAstInfo, "polly-ast", 527 "Polly - Generate an AST from the SCoP (isl)", false, false) 528