1 //===- IslAst.cpp - isl code generator interface --------------------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // The isl code generator interface takes a Scop and generates a isl_ast. This 11 // ist_ast can either be returned directly or it can be pretty printed to 12 // stdout. 13 // 14 // A typical isl_ast output looks like this: 15 // 16 // for (c2 = max(0, ceild(n + m, 2); c2 <= min(511, floord(5 * n, 3)); c2++) { 17 // bb2(c2); 18 // } 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "polly/CodeGen/CodeGeneration.h" 23 #include "polly/CodeGen/IslAst.h" 24 #include "polly/Dependences.h" 25 #include "polly/LinkAllPasses.h" 26 #include "polly/Options.h" 27 #include "polly/ScopInfo.h" 28 #include "llvm/Support/Debug.h" 29 30 #include "isl/union_map.h" 31 #include "isl/list.h" 32 #include "isl/ast_build.h" 33 #include "isl/set.h" 34 #include "isl/map.h" 35 #include "isl/aff.h" 36 37 using namespace llvm; 38 using namespace polly; 39 40 #define DEBUG_TYPE "polly-ast" 41 42 static cl::opt<bool> UseContext("polly-ast-use-context", 43 cl::desc("Use context"), cl::Hidden, 44 cl::init(false), cl::ZeroOrMore, 45 cl::cat(PollyCategory)); 46 47 static cl::opt<bool> DetectParallel("polly-ast-detect-parallel", 48 cl::desc("Detect parallelism"), cl::Hidden, 49 cl::init(false), cl::ZeroOrMore, 50 cl::cat(PollyCategory)); 51 52 namespace polly { 53 class IslAst { 54 public: 55 IslAst(Scop *Scop, Dependences &D); 56 57 ~IslAst(); 58 59 /// Print a source code representation of the program. 60 void pprint(llvm::raw_ostream &OS); 61 62 __isl_give isl_ast_node *getAst(); 63 64 /// @brief Get the run-time conditions for the Scop. 65 __isl_give isl_ast_expr *getRunCondition(); 66 67 private: 68 Scop *S; 69 isl_ast_node *Root; 70 isl_ast_expr *RunCondition; 71 72 __isl_give isl_union_map *getSchedule(); 73 void buildRunCondition(__isl_keep isl_ast_build *Context); 74 }; 75 } // End namespace polly. 76 77 // Temporary information used when building the ast. 78 struct AstBuildUserInfo { 79 // The dependence information. 80 Dependences *Deps; 81 82 // We are inside a parallel for node. 83 int InParallelFor; 84 }; 85 86 // Print a loop annotated with OpenMP or vector pragmas. 87 static __isl_give isl_printer * 88 printParallelFor(__isl_keep isl_ast_node *Node, __isl_take isl_printer *Printer, 89 __isl_take isl_ast_print_options *PrintOptions, 90 IslAstUser *Info) { 91 if (Info) { 92 if (Info->IsInnermostParallel) { 93 Printer = isl_printer_start_line(Printer); 94 Printer = isl_printer_print_str(Printer, "#pragma simd"); 95 Printer = isl_printer_end_line(Printer); 96 } 97 if (Info->IsOutermostParallel) { 98 Printer = isl_printer_start_line(Printer); 99 Printer = isl_printer_print_str(Printer, "#pragma omp parallel for"); 100 Printer = isl_printer_end_line(Printer); 101 } 102 } 103 return isl_ast_node_for_print(Node, Printer, PrintOptions); 104 } 105 106 // Print an isl_ast_for. 107 static __isl_give isl_printer * 108 printFor(__isl_take isl_printer *Printer, 109 __isl_take isl_ast_print_options *PrintOptions, 110 __isl_keep isl_ast_node *Node, void *User) { 111 isl_id *Id = isl_ast_node_get_annotation(Node); 112 if (!Id) 113 return isl_ast_node_for_print(Node, Printer, PrintOptions); 114 115 struct IslAstUser *Info = (struct IslAstUser *)isl_id_get_user(Id); 116 Printer = printParallelFor(Node, Printer, PrintOptions, Info); 117 isl_id_free(Id); 118 return Printer; 119 } 120 121 // Allocate an AstNodeInfo structure and initialize it with default values. 122 static struct IslAstUser *allocateIslAstUser() { 123 struct IslAstUser *NodeInfo; 124 NodeInfo = (struct IslAstUser *)malloc(sizeof(struct IslAstUser)); 125 NodeInfo->PMA = 0; 126 NodeInfo->Context = 0; 127 NodeInfo->IsOutermostParallel = 0; 128 NodeInfo->IsInnermostParallel = 0; 129 return NodeInfo; 130 } 131 132 // Free the AstNodeInfo structure. 133 static void freeIslAstUser(void *Ptr) { 134 struct IslAstUser *UserStruct = (struct IslAstUser *)Ptr; 135 isl_ast_build_free(UserStruct->Context); 136 isl_pw_multi_aff_free(UserStruct->PMA); 137 free(UserStruct); 138 } 139 140 // Check if the current scheduling dimension is parallel. 141 // 142 // We check for parallelism by verifying that the loop does not carry any 143 // dependences. 144 // 145 // Parallelism test: if the distance is zero in all outer dimensions, then it 146 // has to be zero in the current dimension as well. 147 // 148 // Implementation: first, translate dependences into time space, then force 149 // outer dimensions to be equal. If the distance is zero in the current 150 // dimension, then the loop is parallel. The distance is zero in the current 151 // dimension if it is a subset of a map with equal values for the current 152 // dimension. 153 static bool astScheduleDimIsParallel(__isl_keep isl_ast_build *Build, 154 Dependences *D) { 155 isl_union_map *Schedule, *Deps; 156 isl_map *ScheduleDeps, *Test; 157 isl_space *ScheduleSpace; 158 unsigned Dimension, IsParallel; 159 160 if (!D->hasValidDependences()) { 161 return false; 162 } 163 164 Schedule = isl_ast_build_get_schedule(Build); 165 ScheduleSpace = isl_ast_build_get_schedule_space(Build); 166 167 Dimension = isl_space_dim(ScheduleSpace, isl_dim_out) - 1; 168 169 Deps = D->getDependences(Dependences::TYPE_ALL); 170 Deps = isl_union_map_apply_range(Deps, isl_union_map_copy(Schedule)); 171 Deps = isl_union_map_apply_domain(Deps, Schedule); 172 173 if (isl_union_map_is_empty(Deps)) { 174 isl_union_map_free(Deps); 175 isl_space_free(ScheduleSpace); 176 return true; 177 } 178 179 ScheduleDeps = isl_map_from_union_map(Deps); 180 181 for (unsigned i = 0; i < Dimension; i++) 182 ScheduleDeps = isl_map_equate(ScheduleDeps, isl_dim_out, i, isl_dim_in, i); 183 184 Test = isl_map_universe(isl_map_get_space(ScheduleDeps)); 185 Test = isl_map_equate(Test, isl_dim_out, Dimension, isl_dim_in, Dimension); 186 IsParallel = isl_map_is_subset(ScheduleDeps, Test); 187 188 isl_space_free(ScheduleSpace); 189 isl_map_free(Test); 190 isl_map_free(ScheduleDeps); 191 192 return IsParallel; 193 } 194 195 // Mark a for node openmp parallel, if it is the outermost parallel for node. 196 static void markOpenmpParallel(__isl_keep isl_ast_build *Build, 197 struct AstBuildUserInfo *BuildInfo, 198 struct IslAstUser *NodeInfo) { 199 if (BuildInfo->InParallelFor) 200 return; 201 202 if (astScheduleDimIsParallel(Build, BuildInfo->Deps)) { 203 BuildInfo->InParallelFor = 1; 204 NodeInfo->IsOutermostParallel = 1; 205 } 206 } 207 208 // This method is executed before the construction of a for node. It creates 209 // an isl_id that is used to annotate the subsequently generated ast for nodes. 210 // 211 // In this function we also run the following analyses: 212 // 213 // - Detection of openmp parallel loops 214 // 215 static __isl_give isl_id *astBuildBeforeFor(__isl_keep isl_ast_build *Build, 216 void *User) { 217 struct AstBuildUserInfo *BuildInfo = (struct AstBuildUserInfo *)User; 218 struct IslAstUser *NodeInfo = allocateIslAstUser(); 219 isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", NodeInfo); 220 Id = isl_id_set_free_user(Id, freeIslAstUser); 221 222 markOpenmpParallel(Build, BuildInfo, NodeInfo); 223 224 return Id; 225 } 226 227 // Returns 0 when Node contains loops, otherwise returns -1. This search 228 // function uses ISL's way to iterate over lists of isl_ast_nodes with 229 // isl_ast_node_list_foreach. Please use the single argument wrapper function 230 // that returns a bool instead of using this function directly. 231 static int containsLoops(__isl_take isl_ast_node *Node, void *User) { 232 if (!Node) 233 return -1; 234 235 switch (isl_ast_node_get_type(Node)) { 236 case isl_ast_node_for: 237 isl_ast_node_free(Node); 238 return 0; 239 case isl_ast_node_block: { 240 isl_ast_node_list *List = isl_ast_node_block_get_children(Node); 241 int Res = isl_ast_node_list_foreach(List, &containsLoops, nullptr); 242 isl_ast_node_list_free(List); 243 isl_ast_node_free(Node); 244 return Res; 245 } 246 case isl_ast_node_if: { 247 int Res = -1; 248 if (0 == containsLoops(isl_ast_node_if_get_then(Node), nullptr) || 249 (isl_ast_node_if_has_else(Node) && 250 0 == containsLoops(isl_ast_node_if_get_else(Node), nullptr))) 251 Res = 0; 252 isl_ast_node_free(Node); 253 return Res; 254 } 255 case isl_ast_node_user: 256 default: 257 isl_ast_node_free(Node); 258 return -1; 259 } 260 } 261 262 // Returns true when Node contains loops. 263 static bool containsLoops(__isl_take isl_ast_node *Node) { 264 return 0 == containsLoops(Node, nullptr); 265 } 266 267 // This method is executed after the construction of a for node. 268 // 269 // It performs the following actions: 270 // 271 // - Reset the 'InParallelFor' flag, as soon as we leave a for node, 272 // that is marked as openmp parallel. 273 // 274 static __isl_give isl_ast_node * 275 astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build, 276 void *User) { 277 isl_id *Id = isl_ast_node_get_annotation(Node); 278 if (!Id) 279 return Node; 280 struct IslAstUser *Info = (struct IslAstUser *)isl_id_get_user(Id); 281 struct AstBuildUserInfo *BuildInfo = (struct AstBuildUserInfo *)User; 282 283 if (Info) { 284 if (Info->IsOutermostParallel) 285 BuildInfo->InParallelFor = 0; 286 if (!containsLoops(isl_ast_node_for_get_body(Node))) 287 if (astScheduleDimIsParallel(Build, BuildInfo->Deps)) 288 Info->IsInnermostParallel = 1; 289 if (!Info->Context) 290 Info->Context = isl_ast_build_copy(Build); 291 } 292 293 isl_id_free(Id); 294 return Node; 295 } 296 297 static __isl_give isl_ast_node *AtEachDomain(__isl_take isl_ast_node *Node, 298 __isl_keep isl_ast_build *Context, 299 void *User) { 300 struct IslAstUser *Info = nullptr; 301 isl_id *Id = isl_ast_node_get_annotation(Node); 302 303 if (Id) 304 Info = (struct IslAstUser *)isl_id_get_user(Id); 305 306 if (!Info) { 307 // Allocate annotations once: parallel for detection might have already 308 // allocated the annotations for this node. 309 Info = allocateIslAstUser(); 310 Id = isl_id_alloc(isl_ast_node_get_ctx(Node), nullptr, Info); 311 Id = isl_id_set_free_user(Id, &freeIslAstUser); 312 } 313 314 if (!Info->PMA) { 315 isl_map *Map = isl_map_from_union_map(isl_ast_build_get_schedule(Context)); 316 Info->PMA = isl_pw_multi_aff_from_map(isl_map_reverse(Map)); 317 } 318 if (!Info->Context) 319 Info->Context = isl_ast_build_copy(Context); 320 321 return isl_ast_node_set_annotation(Node, Id); 322 } 323 324 void IslAst::buildRunCondition(__isl_keep isl_ast_build *Context) { 325 // The conditions that need to be checked at run-time for this scop are 326 // available as an isl_set in the AssumedContext. We generate code for this 327 // check as follows. First, we generate an isl_pw_aff that is 1, if a certain 328 // combination of parameter values fulfills the conditions in the assumed 329 // context, and that is 0 otherwise. We then translate this isl_pw_aff into 330 // an isl_ast_expr. At run-time this expression can be evaluated and the 331 // optimized scop can be executed conditionally according to the result of the 332 // run-time check. 333 334 isl_aff *Zero = 335 isl_aff_zero_on_domain(isl_local_space_from_space(S->getParamSpace())); 336 isl_aff *One = 337 isl_aff_zero_on_domain(isl_local_space_from_space(S->getParamSpace())); 338 339 One = isl_aff_add_constant_si(One, 1); 340 341 isl_pw_aff *PwZero = isl_pw_aff_from_aff(Zero); 342 isl_pw_aff *PwOne = isl_pw_aff_from_aff(One); 343 344 PwOne = isl_pw_aff_intersect_domain(PwOne, S->getAssumedContext()); 345 PwZero = isl_pw_aff_intersect_domain( 346 PwZero, isl_set_complement(S->getAssumedContext())); 347 348 isl_pw_aff *Cond = isl_pw_aff_union_max(PwZero, PwOne); 349 350 RunCondition = isl_ast_build_expr_from_pw_aff(Context, Cond); 351 } 352 353 IslAst::IslAst(Scop *Scop, Dependences &D) : S(Scop) { 354 isl_ctx *Ctx = S->getIslCtx(); 355 isl_options_set_ast_build_atomic_upper_bound(Ctx, true); 356 isl_ast_build *Context; 357 struct AstBuildUserInfo BuildInfo; 358 359 if (UseContext) 360 Context = isl_ast_build_from_context(S->getContext()); 361 else 362 Context = isl_ast_build_from_context(isl_set_universe(S->getParamSpace())); 363 364 Context = isl_ast_build_set_at_each_domain(Context, AtEachDomain, nullptr); 365 366 isl_union_map *Schedule = getSchedule(); 367 368 Function *F = Scop->getRegion().getEntry()->getParent(); 369 (void)F; 370 371 DEBUG(dbgs() << ":: isl ast :: " << F->getName() 372 << " :: " << Scop->getRegion().getNameStr() << "\n"); 373 374 DEBUG(dbgs() << S->getContextStr() << "\n"; isl_union_map_dump(Schedule)); 375 376 if (DetectParallel || PollyVectorizerChoice != VECTORIZER_NONE) { 377 BuildInfo.Deps = &D; 378 BuildInfo.InParallelFor = 0; 379 380 Context = isl_ast_build_set_before_each_for(Context, &astBuildBeforeFor, 381 &BuildInfo); 382 Context = isl_ast_build_set_after_each_for(Context, &astBuildAfterFor, 383 &BuildInfo); 384 } 385 386 buildRunCondition(Context); 387 388 Root = isl_ast_build_ast_from_schedule(Context, Schedule); 389 390 isl_ast_build_free(Context); 391 392 DEBUG(pprint(dbgs())); 393 } 394 395 __isl_give isl_union_map *IslAst::getSchedule() { 396 isl_union_map *Schedule = isl_union_map_empty(S->getParamSpace()); 397 398 for (Scop::iterator SI = S->begin(), SE = S->end(); SI != SE; ++SI) { 399 ScopStmt *Stmt = *SI; 400 isl_map *StmtSchedule = Stmt->getScattering(); 401 402 StmtSchedule = isl_map_intersect_domain(StmtSchedule, Stmt->getDomain()); 403 Schedule = 404 isl_union_map_union(Schedule, isl_union_map_from_map(StmtSchedule)); 405 } 406 407 return Schedule; 408 } 409 410 IslAst::~IslAst() { 411 isl_ast_node_free(Root); 412 isl_ast_expr_free(RunCondition); 413 } 414 415 /// Print a C like representation of the program. 416 void IslAst::pprint(llvm::raw_ostream &OS) { 417 isl_ast_node *Root; 418 isl_ast_print_options *Options; 419 420 Options = isl_ast_print_options_alloc(S->getIslCtx()); 421 Options = isl_ast_print_options_set_print_for(Options, &printFor, nullptr); 422 423 isl_printer *P = isl_printer_to_str(S->getIslCtx()); 424 P = isl_printer_set_output_format(P, ISL_FORMAT_C); 425 426 P = isl_printer_print_ast_expr(P, RunCondition); 427 char *result = isl_printer_get_str(P); 428 P = isl_printer_flush(P); 429 430 OS << "\nif (" << result << ")\n\n"; 431 P = isl_printer_indent(P, 4); 432 433 Root = getAst(); 434 P = isl_ast_node_print(Root, P, Options); 435 result = isl_printer_get_str(P); 436 OS << result << "\n"; 437 OS << "else\n"; 438 OS << " { /* original code */ }\n\n"; 439 isl_printer_free(P); 440 isl_ast_node_free(Root); 441 } 442 443 __isl_give isl_ast_node *IslAst::getAst() { return isl_ast_node_copy(Root); } 444 __isl_give isl_ast_expr *IslAst::getRunCondition() { 445 return isl_ast_expr_copy(RunCondition); 446 } 447 448 void IslAstInfo::pprint(llvm::raw_ostream &OS) { Ast->pprint(OS); } 449 450 void IslAstInfo::releaseMemory() { 451 if (Ast) { 452 delete Ast; 453 Ast = 0; 454 } 455 } 456 457 bool IslAstInfo::runOnScop(Scop &Scop) { 458 if (Ast) 459 delete Ast; 460 461 S = &Scop; 462 463 Dependences &D = getAnalysis<Dependences>(); 464 465 Ast = new IslAst(&Scop, D); 466 467 return false; 468 } 469 470 __isl_give isl_ast_node *IslAstInfo::getAst() { return Ast->getAst(); } 471 __isl_give isl_ast_expr *IslAstInfo::getRunCondition() { 472 return Ast->getRunCondition(); 473 } 474 475 void IslAstInfo::printScop(raw_ostream &OS) const { 476 Function *F = S->getRegion().getEntry()->getParent(); 477 478 OS << F->getName() << "():\n"; 479 480 Ast->pprint(OS); 481 } 482 483 void IslAstInfo::getAnalysisUsage(AnalysisUsage &AU) const { 484 // Get the Common analysis usage of ScopPasses. 485 ScopPass::getAnalysisUsage(AU); 486 AU.addRequired<ScopInfo>(); 487 AU.addRequired<Dependences>(); 488 } 489 490 char IslAstInfo::ID = 0; 491 492 Pass *polly::createIslAstInfoPass() { return new IslAstInfo(); } 493 494 INITIALIZE_PASS_BEGIN(IslAstInfo, "polly-ast", 495 "Polly - Generate an AST of the SCoP (isl)", false, 496 false); 497 INITIALIZE_PASS_DEPENDENCY(ScopInfo); 498 INITIALIZE_PASS_DEPENDENCY(Dependences); 499 INITIALIZE_PASS_END(IslAstInfo, "polly-ast", 500 "Polly - Generate an AST from the SCoP (isl)", false, false) 501