1 //===- IslAst.cpp - isl code generator interface --------------------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // The isl code generator interface takes a Scop and generates a isl_ast. This 11 // ist_ast can either be returned directly or it can be pretty printed to 12 // stdout. 13 // 14 // A typical isl_ast output looks like this: 15 // 16 // for (c2 = max(0, ceild(n + m, 2); c2 <= min(511, floord(5 * n, 3)); c2++) { 17 // bb2(c2); 18 // } 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "polly/CodeGen/CodeGeneration.h" 23 #include "polly/CodeGen/IslAst.h" 24 #include "polly/Dependences.h" 25 #include "polly/LinkAllPasses.h" 26 #include "polly/Options.h" 27 #include "polly/ScopInfo.h" 28 #include "llvm/Support/Debug.h" 29 30 #include "isl/union_map.h" 31 #include "isl/list.h" 32 #include "isl/ast_build.h" 33 #include "isl/set.h" 34 #include "isl/map.h" 35 #include "isl/aff.h" 36 37 using namespace llvm; 38 using namespace polly; 39 40 #define DEBUG_TYPE "polly-ast" 41 42 static cl::opt<bool> UseContext("polly-ast-use-context", 43 cl::desc("Use context"), cl::Hidden, 44 cl::init(false), cl::ZeroOrMore, 45 cl::cat(PollyCategory)); 46 47 static cl::opt<bool> DetectParallel("polly-ast-detect-parallel", 48 cl::desc("Detect parallelism"), cl::Hidden, 49 cl::init(false), cl::ZeroOrMore, 50 cl::cat(PollyCategory)); 51 52 namespace polly { 53 class IslAst { 54 public: 55 IslAst(Scop *Scop, Dependences &D); 56 57 ~IslAst(); 58 59 /// Print a source code representation of the program. 60 void pprint(llvm::raw_ostream &OS); 61 62 __isl_give isl_ast_node *getAst(); 63 64 /// @brief Get the run-time conditions for the Scop. 65 __isl_give isl_ast_expr *getRunCondition(); 66 67 private: 68 Scop *S; 69 isl_ast_node *Root; 70 isl_ast_expr *RunCondition; 71 72 __isl_give isl_union_map *getSchedule(); 73 void buildRunCondition(__isl_keep isl_ast_build *Context); 74 }; 75 } // End namespace polly. 76 77 // Temporary information used when building the ast. 78 struct AstBuildUserInfo { 79 // The dependence information. 80 Dependences *Deps; 81 82 // We are inside a parallel for node. 83 int InParallelFor; 84 }; 85 86 // Print a loop annotated with OpenMP or vector pragmas. 87 static __isl_give isl_printer * 88 printParallelFor(__isl_keep isl_ast_node *Node, __isl_take isl_printer *Printer, 89 __isl_take isl_ast_print_options *PrintOptions, 90 IslAstUser *Info) { 91 if (Info) { 92 if (Info->IsInnermostParallel) { 93 Printer = isl_printer_start_line(Printer); 94 Printer = isl_printer_print_str(Printer, "#pragma simd"); 95 Printer = isl_printer_end_line(Printer); 96 } 97 if (Info->IsOutermostParallel) { 98 Printer = isl_printer_start_line(Printer); 99 Printer = isl_printer_print_str(Printer, "#pragma omp parallel for"); 100 Printer = isl_printer_end_line(Printer); 101 } 102 } 103 return isl_ast_node_for_print(Node, Printer, PrintOptions); 104 } 105 106 // Print an isl_ast_for. 107 static __isl_give isl_printer * 108 printFor(__isl_take isl_printer *Printer, 109 __isl_take isl_ast_print_options *PrintOptions, 110 __isl_keep isl_ast_node *Node, void *User) { 111 isl_id *Id = isl_ast_node_get_annotation(Node); 112 if (!Id) 113 return isl_ast_node_for_print(Node, Printer, PrintOptions); 114 115 struct IslAstUser *Info = (struct IslAstUser *)isl_id_get_user(Id); 116 Printer = printParallelFor(Node, Printer, PrintOptions, Info); 117 isl_id_free(Id); 118 return Printer; 119 } 120 121 // Allocate an AstNodeInfo structure and initialize it with default values. 122 static struct IslAstUser *allocateIslAstUser() { 123 struct IslAstUser *NodeInfo; 124 NodeInfo = (struct IslAstUser *)malloc(sizeof(struct IslAstUser)); 125 NodeInfo->PMA = 0; 126 NodeInfo->Context = 0; 127 NodeInfo->IsOutermostParallel = 0; 128 NodeInfo->IsInnermostParallel = 0; 129 return NodeInfo; 130 } 131 132 // Free the AstNodeInfo structure. 133 static void freeIslAstUser(void *Ptr) { 134 struct IslAstUser *UserStruct = (struct IslAstUser *)Ptr; 135 isl_ast_build_free(UserStruct->Context); 136 isl_pw_multi_aff_free(UserStruct->PMA); 137 free(UserStruct); 138 } 139 140 // Check if the current scheduling dimension is parallel. 141 // 142 // We check for parallelism by verifying that the loop does not carry any 143 // dependences. 144 // 145 // Parallelism test: if the distance is zero in all outer dimensions, then it 146 // has to be zero in the current dimension as well. 147 // 148 // Implementation: first, translate dependences into time space, then force 149 // outer dimensions to be equal. If the distance is zero in the current 150 // dimension, then the loop is parallel. The distance is zero in the current 151 // dimension if it is a subset of a map with equal values for the current 152 // dimension. 153 static bool astScheduleDimIsParallel(__isl_keep isl_ast_build *Build, 154 Dependences *D) { 155 isl_union_map *Schedule, *Deps; 156 isl_map *ScheduleDeps, *Test; 157 isl_space *ScheduleSpace; 158 unsigned Dimension, IsParallel; 159 160 if (!D->hasValidDependences()) { 161 return false; 162 } 163 164 Schedule = isl_ast_build_get_schedule(Build); 165 ScheduleSpace = isl_ast_build_get_schedule_space(Build); 166 167 Dimension = isl_space_dim(ScheduleSpace, isl_dim_out) - 1; 168 169 // FIXME: We can remove ignore reduction dependences in case we privatize the 170 // memory locations the reduction statements reduce into. 171 Deps = D->getDependences(Dependences::TYPE_ALL | Dependences::TYPE_RED); 172 Deps = isl_union_map_apply_range(Deps, isl_union_map_copy(Schedule)); 173 Deps = isl_union_map_apply_domain(Deps, Schedule); 174 175 if (isl_union_map_is_empty(Deps)) { 176 isl_union_map_free(Deps); 177 isl_space_free(ScheduleSpace); 178 return true; 179 } 180 181 ScheduleDeps = isl_map_from_union_map(Deps); 182 183 for (unsigned i = 0; i < Dimension; i++) 184 ScheduleDeps = isl_map_equate(ScheduleDeps, isl_dim_out, i, isl_dim_in, i); 185 186 Test = isl_map_universe(isl_map_get_space(ScheduleDeps)); 187 Test = isl_map_equate(Test, isl_dim_out, Dimension, isl_dim_in, Dimension); 188 IsParallel = isl_map_is_subset(ScheduleDeps, Test); 189 190 isl_space_free(ScheduleSpace); 191 isl_map_free(Test); 192 isl_map_free(ScheduleDeps); 193 194 return IsParallel; 195 } 196 197 // Mark a for node openmp parallel, if it is the outermost parallel for node. 198 static void markOpenmpParallel(__isl_keep isl_ast_build *Build, 199 struct AstBuildUserInfo *BuildInfo, 200 struct IslAstUser *NodeInfo) { 201 if (BuildInfo->InParallelFor) 202 return; 203 204 if (astScheduleDimIsParallel(Build, BuildInfo->Deps)) { 205 BuildInfo->InParallelFor = 1; 206 NodeInfo->IsOutermostParallel = 1; 207 } 208 } 209 210 // This method is executed before the construction of a for node. It creates 211 // an isl_id that is used to annotate the subsequently generated ast for nodes. 212 // 213 // In this function we also run the following analyses: 214 // 215 // - Detection of openmp parallel loops 216 // 217 static __isl_give isl_id *astBuildBeforeFor(__isl_keep isl_ast_build *Build, 218 void *User) { 219 struct AstBuildUserInfo *BuildInfo = (struct AstBuildUserInfo *)User; 220 struct IslAstUser *NodeInfo = allocateIslAstUser(); 221 isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", NodeInfo); 222 Id = isl_id_set_free_user(Id, freeIslAstUser); 223 224 markOpenmpParallel(Build, BuildInfo, NodeInfo); 225 226 return Id; 227 } 228 229 // Returns 0 when Node contains loops, otherwise returns -1. This search 230 // function uses ISL's way to iterate over lists of isl_ast_nodes with 231 // isl_ast_node_list_foreach. Please use the single argument wrapper function 232 // that returns a bool instead of using this function directly. 233 static int containsLoops(__isl_take isl_ast_node *Node, void *User) { 234 if (!Node) 235 return -1; 236 237 switch (isl_ast_node_get_type(Node)) { 238 case isl_ast_node_for: 239 isl_ast_node_free(Node); 240 return 0; 241 case isl_ast_node_block: { 242 isl_ast_node_list *List = isl_ast_node_block_get_children(Node); 243 int Res = isl_ast_node_list_foreach(List, &containsLoops, nullptr); 244 isl_ast_node_list_free(List); 245 isl_ast_node_free(Node); 246 return Res; 247 } 248 case isl_ast_node_if: { 249 int Res = -1; 250 if (0 == containsLoops(isl_ast_node_if_get_then(Node), nullptr) || 251 (isl_ast_node_if_has_else(Node) && 252 0 == containsLoops(isl_ast_node_if_get_else(Node), nullptr))) 253 Res = 0; 254 isl_ast_node_free(Node); 255 return Res; 256 } 257 case isl_ast_node_user: 258 default: 259 isl_ast_node_free(Node); 260 return -1; 261 } 262 } 263 264 // Returns true when Node contains loops. 265 static bool containsLoops(__isl_take isl_ast_node *Node) { 266 return 0 == containsLoops(Node, nullptr); 267 } 268 269 // This method is executed after the construction of a for node. 270 // 271 // It performs the following actions: 272 // 273 // - Reset the 'InParallelFor' flag, as soon as we leave a for node, 274 // that is marked as openmp parallel. 275 // 276 static __isl_give isl_ast_node * 277 astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build, 278 void *User) { 279 isl_id *Id = isl_ast_node_get_annotation(Node); 280 if (!Id) 281 return Node; 282 struct IslAstUser *Info = (struct IslAstUser *)isl_id_get_user(Id); 283 struct AstBuildUserInfo *BuildInfo = (struct AstBuildUserInfo *)User; 284 285 if (Info) { 286 if (Info->IsOutermostParallel) 287 BuildInfo->InParallelFor = 0; 288 if (!containsLoops(isl_ast_node_for_get_body(Node))) 289 if (astScheduleDimIsParallel(Build, BuildInfo->Deps)) 290 Info->IsInnermostParallel = 1; 291 if (!Info->Context) 292 Info->Context = isl_ast_build_copy(Build); 293 } 294 295 isl_id_free(Id); 296 return Node; 297 } 298 299 static __isl_give isl_ast_node *AtEachDomain(__isl_take isl_ast_node *Node, 300 __isl_keep isl_ast_build *Context, 301 void *User) { 302 struct IslAstUser *Info = nullptr; 303 isl_id *Id = isl_ast_node_get_annotation(Node); 304 305 if (Id) 306 Info = (struct IslAstUser *)isl_id_get_user(Id); 307 308 if (!Info) { 309 // Allocate annotations once: parallel for detection might have already 310 // allocated the annotations for this node. 311 Info = allocateIslAstUser(); 312 Id = isl_id_alloc(isl_ast_node_get_ctx(Node), nullptr, Info); 313 Id = isl_id_set_free_user(Id, &freeIslAstUser); 314 } 315 316 if (!Info->PMA) { 317 isl_map *Map = isl_map_from_union_map(isl_ast_build_get_schedule(Context)); 318 Info->PMA = isl_pw_multi_aff_from_map(isl_map_reverse(Map)); 319 } 320 if (!Info->Context) 321 Info->Context = isl_ast_build_copy(Context); 322 323 return isl_ast_node_set_annotation(Node, Id); 324 } 325 326 void IslAst::buildRunCondition(__isl_keep isl_ast_build *Context) { 327 // The conditions that need to be checked at run-time for this scop are 328 // available as an isl_set in the AssumedContext. We generate code for this 329 // check as follows. First, we generate an isl_pw_aff that is 1, if a certain 330 // combination of parameter values fulfills the conditions in the assumed 331 // context, and that is 0 otherwise. We then translate this isl_pw_aff into 332 // an isl_ast_expr. At run-time this expression can be evaluated and the 333 // optimized scop can be executed conditionally according to the result of the 334 // run-time check. 335 336 isl_aff *Zero = 337 isl_aff_zero_on_domain(isl_local_space_from_space(S->getParamSpace())); 338 isl_aff *One = 339 isl_aff_zero_on_domain(isl_local_space_from_space(S->getParamSpace())); 340 341 One = isl_aff_add_constant_si(One, 1); 342 343 isl_pw_aff *PwZero = isl_pw_aff_from_aff(Zero); 344 isl_pw_aff *PwOne = isl_pw_aff_from_aff(One); 345 346 PwOne = isl_pw_aff_intersect_domain(PwOne, S->getAssumedContext()); 347 PwZero = isl_pw_aff_intersect_domain( 348 PwZero, isl_set_complement(S->getAssumedContext())); 349 350 isl_pw_aff *Cond = isl_pw_aff_union_max(PwZero, PwOne); 351 352 RunCondition = isl_ast_build_expr_from_pw_aff(Context, Cond); 353 } 354 355 IslAst::IslAst(Scop *Scop, Dependences &D) : S(Scop) { 356 isl_ctx *Ctx = S->getIslCtx(); 357 isl_options_set_ast_build_atomic_upper_bound(Ctx, true); 358 isl_ast_build *Context; 359 struct AstBuildUserInfo BuildInfo; 360 361 if (UseContext) 362 Context = isl_ast_build_from_context(S->getContext()); 363 else 364 Context = isl_ast_build_from_context(isl_set_universe(S->getParamSpace())); 365 366 Context = isl_ast_build_set_at_each_domain(Context, AtEachDomain, nullptr); 367 368 isl_union_map *Schedule = getSchedule(); 369 370 Function *F = Scop->getRegion().getEntry()->getParent(); 371 (void)F; 372 373 DEBUG(dbgs() << ":: isl ast :: " << F->getName() 374 << " :: " << Scop->getRegion().getNameStr() << "\n"); 375 376 DEBUG(dbgs() << S->getContextStr() << "\n"; isl_union_map_dump(Schedule)); 377 378 if (DetectParallel || PollyVectorizerChoice != VECTORIZER_NONE) { 379 BuildInfo.Deps = &D; 380 BuildInfo.InParallelFor = 0; 381 382 Context = isl_ast_build_set_before_each_for(Context, &astBuildBeforeFor, 383 &BuildInfo); 384 Context = isl_ast_build_set_after_each_for(Context, &astBuildAfterFor, 385 &BuildInfo); 386 } 387 388 buildRunCondition(Context); 389 390 Root = isl_ast_build_ast_from_schedule(Context, Schedule); 391 392 isl_ast_build_free(Context); 393 394 DEBUG(pprint(dbgs())); 395 } 396 397 __isl_give isl_union_map *IslAst::getSchedule() { 398 isl_union_map *Schedule = isl_union_map_empty(S->getParamSpace()); 399 400 for (ScopStmt *Stmt : *S) { 401 isl_map *StmtSchedule = Stmt->getScattering(); 402 403 StmtSchedule = isl_map_intersect_domain(StmtSchedule, Stmt->getDomain()); 404 Schedule = 405 isl_union_map_union(Schedule, isl_union_map_from_map(StmtSchedule)); 406 } 407 408 return Schedule; 409 } 410 411 IslAst::~IslAst() { 412 isl_ast_node_free(Root); 413 isl_ast_expr_free(RunCondition); 414 } 415 416 /// Print a C like representation of the program. 417 void IslAst::pprint(llvm::raw_ostream &OS) { 418 isl_ast_node *Root; 419 isl_ast_print_options *Options; 420 421 Options = isl_ast_print_options_alloc(S->getIslCtx()); 422 Options = isl_ast_print_options_set_print_for(Options, &printFor, nullptr); 423 424 isl_printer *P = isl_printer_to_str(S->getIslCtx()); 425 P = isl_printer_set_output_format(P, ISL_FORMAT_C); 426 427 P = isl_printer_print_ast_expr(P, RunCondition); 428 char *result = isl_printer_get_str(P); 429 P = isl_printer_flush(P); 430 431 OS << "\nif (" << result << ")\n\n"; 432 P = isl_printer_indent(P, 4); 433 434 Root = getAst(); 435 P = isl_ast_node_print(Root, P, Options); 436 result = isl_printer_get_str(P); 437 OS << result << "\n"; 438 OS << "else\n"; 439 OS << " { /* original code */ }\n\n"; 440 isl_printer_free(P); 441 isl_ast_node_free(Root); 442 } 443 444 __isl_give isl_ast_node *IslAst::getAst() { return isl_ast_node_copy(Root); } 445 __isl_give isl_ast_expr *IslAst::getRunCondition() { 446 return isl_ast_expr_copy(RunCondition); 447 } 448 449 void IslAstInfo::pprint(llvm::raw_ostream &OS) { Ast->pprint(OS); } 450 451 void IslAstInfo::releaseMemory() { 452 if (Ast) { 453 delete Ast; 454 Ast = 0; 455 } 456 } 457 458 bool IslAstInfo::runOnScop(Scop &Scop) { 459 if (Ast) 460 delete Ast; 461 462 S = &Scop; 463 464 Dependences &D = getAnalysis<Dependences>(); 465 466 Ast = new IslAst(&Scop, D); 467 468 return false; 469 } 470 471 __isl_give isl_ast_node *IslAstInfo::getAst() { return Ast->getAst(); } 472 __isl_give isl_ast_expr *IslAstInfo::getRunCondition() { 473 return Ast->getRunCondition(); 474 } 475 476 void IslAstInfo::printScop(raw_ostream &OS) const { 477 Function *F = S->getRegion().getEntry()->getParent(); 478 479 OS << F->getName() << "():\n"; 480 481 Ast->pprint(OS); 482 } 483 484 void IslAstInfo::getAnalysisUsage(AnalysisUsage &AU) const { 485 // Get the Common analysis usage of ScopPasses. 486 ScopPass::getAnalysisUsage(AU); 487 AU.addRequired<ScopInfo>(); 488 AU.addRequired<Dependences>(); 489 } 490 491 char IslAstInfo::ID = 0; 492 493 Pass *polly::createIslAstInfoPass() { return new IslAstInfo(); } 494 495 INITIALIZE_PASS_BEGIN(IslAstInfo, "polly-ast", 496 "Polly - Generate an AST of the SCoP (isl)", false, 497 false); 498 INITIALIZE_PASS_DEPENDENCY(ScopInfo); 499 INITIALIZE_PASS_DEPENDENCY(Dependences); 500 INITIALIZE_PASS_END(IslAstInfo, "polly-ast", 501 "Polly - Generate an AST from the SCoP (isl)", false, false) 502