1 //===- polly/ScheduleTreeTransform.cpp --------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Make changes to isl's schedule tree data structure. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "polly/ScheduleTreeTransform.h" 14 #include "polly/Support/ISLTools.h" 15 #include "polly/Support/ScopHelper.h" 16 #include "llvm/ADT/ArrayRef.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/SmallVector.h" 19 #include "llvm/IR/Constants.h" 20 #include "llvm/IR/Metadata.h" 21 #include "llvm/Transforms/Utils/UnrollLoop.h" 22 23 using namespace polly; 24 using namespace llvm; 25 26 namespace { 27 /// Recursively visit all nodes of a schedule tree while allowing changes. 28 /// 29 /// The visit methods return an isl::schedule_node that is used to continue 30 /// visiting the tree. Structural changes such as returning a different node 31 /// will confuse the visitor. 32 template <typename Derived, typename... Args> 33 struct ScheduleNodeRewriter 34 : public RecursiveScheduleTreeVisitor<Derived, isl::schedule_node, 35 Args...> { 36 Derived &getDerived() { return *static_cast<Derived *>(this); } 37 const Derived &getDerived() const { 38 return *static_cast<const Derived *>(this); 39 } 40 41 isl::schedule_node visitNode(const isl::schedule_node &Node, Args... args) { 42 if (!Node.has_children()) 43 return Node; 44 45 isl::schedule_node It = Node.first_child(); 46 while (true) { 47 It = getDerived().visit(It, std::forward<Args>(args)...); 48 if (!It.has_next_sibling()) 49 break; 50 It = It.next_sibling(); 51 } 52 return It.parent(); 53 } 54 }; 55 56 /// Rewrite a schedule tree by reconstructing it bottom-up. 
///
/// By default, the original schedule tree is reconstructed. To build a
/// different tree, redefine visitor methods in a derived class (CRTP).
///
/// Note that AST build options are not applied; Setting the isolate[] option
/// makes the schedule tree 'anchored' and cannot be modified afterwards. Hence,
/// AST build options must be set after the tree has been constructed.
template <typename Derived, typename... Args>
struct ScheduleTreeRewriter
    : public RecursiveScheduleTreeVisitor<Derived, isl::schedule, Args...> {
  /// CRTP accessors for the most-derived rewriter.
  Derived &getDerived() { return *static_cast<Derived *>(this); }
  const Derived &getDerived() const {
    return *static_cast<const Derived *>(this);
  }

  /// The domain node itself is not recreated; the result is whatever the
  /// rewrite of its single child produces.
  isl::schedule visitDomain(isl::schedule_node_domain Node, Args... args) {
    // Every schedule_tree already has a domain node, no need to add one.
    return getDerived().visit(Node.first_child(), std::forward<Args>(args)...);
  }

  /// Rewrite the band's child, then re-insert the band's partial schedule on
  /// top of the rewritten child and copy over the band attributes.
  isl::schedule visitBand(isl::schedule_node_band Band, Args... args) {
    isl::multi_union_pw_aff PartialSched =
        isl::manage(isl_schedule_node_band_get_partial_schedule(Band.get()));
    isl::schedule NewChild =
        getDerived().visit(Band.child(0), std::forward<Args>(args)...);
    // The node below the root of the new schedule is treated as the freshly
    // inserted band by the attribute setters that follow.
    isl::schedule_node NewNode =
        NewChild.insert_partial_schedule(PartialSched).get_root().child(0);

    // Reapply permutability and coincidence attributes.
    NewNode = isl::manage(isl_schedule_node_band_set_permutable(
        NewNode.release(), isl_schedule_node_band_get_permutable(Band.get())));
    unsigned BandDims = isl_schedule_node_band_n_member(Band.get());
    for (unsigned i = 0; i < BandDims; i += 1)
      NewNode = isl::manage(isl_schedule_node_band_member_set_coincident(
          NewNode.release(), i,
          isl_schedule_node_band_member_get_coincident(Band.get(), i)));

    return NewNode.get_schedule();
  }

  /// Rewrite every child and chain the results with isl::schedule::sequence,
  /// preserving the child order.
  isl::schedule visitSequence(isl::schedule_node_sequence Sequence,
                              Args... args) {
    int NumChildren = isl_schedule_node_n_children(Sequence.get());
    isl::schedule Result =
        getDerived().visit(Sequence.child(0), std::forward<Args>(args)...);
    for (int i = 1; i < NumChildren; i += 1)
      Result = Result.sequence(
          getDerived().visit(Sequence.child(i), std::forward<Args>(args)...));
    return Result;
  }

  /// Rewrite every child and combine the results with isl_schedule_set.
  isl::schedule visitSet(isl::schedule_node_set Set, Args... args) {
    int NumChildren = isl_schedule_node_n_children(Set.get());
    isl::schedule Result =
        getDerived().visit(Set.child(0), std::forward<Args>(args)...);
    for (int i = 1; i < NumChildren; i += 1)
      Result = isl::manage(
          isl_schedule_set(Result.release(),
                           getDerived()
                               .visit(Set.child(i), std::forward<Args>(args)...)
                               .release()));
    return Result;
  }

  /// A leaf becomes a schedule created from just the leaf's domain.
  isl::schedule visitLeaf(isl::schedule_node_leaf Leaf, Args... args) {
    return isl::schedule::from_domain(Leaf.get_domain());
  }

  /// Rewrite the mark's child and put a mark with the same id back on top.
  isl::schedule visitMark(const isl::schedule_node &Mark, Args... args) {

    isl::id TheMark = Mark.as<isl::schedule_node_mark>().get_id();
    isl::schedule_node NewChild =
        getDerived()
            .visit(Mark.first_child(), std::forward<Args>(args)...)
            .get_root()
            .first_child();
    return NewChild.insert_mark(TheMark).get_schedule();
  }

  /// Rewrite the extension's child and graft a new extension node, built from
  /// the same extension map, before it.
  isl::schedule visitExtension(isl::schedule_node_extension Extension,
                               Args... args) {
    isl::union_map TheExtension =
        Extension.as<isl::schedule_node_extension>().get_extension();
    // NOTE(review): unlike the other visit methods this passes `args...`
    // without std::forward.
    isl::schedule_node NewChild = getDerived()
                                      .visit(Extension.child(0), args...)
                                      .get_root()
                                      .first_child();
    isl::schedule_node NewExtension =
        isl::schedule_node::from_extension(TheExtension);
    return NewChild.graft_before(NewExtension).get_schedule();
  }

  /// Rewrite the filter's child and restrict the resulting schedule's domain
  /// to the filter set.
  isl::schedule visitFilter(isl::schedule_node_filter Filter, Args... args) {
    isl::union_set FilterDomain =
        Filter.as<isl::schedule_node_filter>().get_filter();
    isl::schedule NewSchedule =
        getDerived().visit(Filter.child(0), std::forward<Args>(args)...);
    return NewSchedule.intersect_domain(FilterDomain);
  }

  /// Presumably the fallback the base visitor dispatches to for node types
  /// without a dedicated method above; reaching it aborts.
  isl::schedule visitNode(isl::schedule_node Node, Args... args) {
    llvm_unreachable("Not implemented");
  }
};

/// Rewrite a schedule tree to an equivalent one without extension nodes.
///
/// Each visit method takes two additional arguments:
///
///  * The new domain of the node, which is the inherited domain plus any
///    domains added by extension nodes.
///
///  * A map of extension domains of all children is returned; it is required by
///    band nodes to schedule the additional domains at the same position as the
///    extension node would.
///
struct ExtensionNodeRewriter
    : public ScheduleTreeRewriter<ExtensionNodeRewriter, const isl::union_set &,
                                  isl::union_map &> {
  using BaseTy = ScheduleTreeRewriter<ExtensionNodeRewriter,
                                      const isl::union_set &, isl::union_map &>;
  BaseTy &getBase() { return *this; }
  const BaseTy &getBase() const { return *this; }

  /// Entry point: rewrite @p Schedule starting at its root with the
  /// schedule's own domain. After the walk, no extension may be left
  /// unscheduled (checked by the assert).
  isl::schedule visitSchedule(isl::schedule Schedule) {
    isl::union_map Extensions;
    isl::schedule Result =
        visit(Schedule.get_root(), Schedule.get_domain(), Extensions);
    assert(!Extensions.is_null() && Extensions.is_empty());
    return Result;
  }

  /// Rewrite all children with the same @p Domain, chain them as a sequence
  /// and return the union of the children's extensions in @p Extensions.
  isl::schedule visitSequence(isl::schedule_node_sequence Sequence,
                              const isl::union_set &Domain,
                              isl::union_map &Extensions) {
    int NumChildren = isl_schedule_node_n_children(Sequence.get());
    // The first child writes directly into the out-parameter; the remaining
    // children are united into it.
    isl::schedule NewNode = visit(Sequence.first_child(), Domain, Extensions);
    for (int i = 1; i < NumChildren; i += 1) {
      isl::schedule_node OldChild = Sequence.child(i);
      isl::union_map NewChildExtensions;
      isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions);
      NewNode = NewNode.sequence(NewChildNode);
      Extensions = Extensions.unite(NewChildExtensions);
    }
    return NewNode;
  }

  /// Same as visitSequence, but the children are combined with
  /// isl_schedule_set.
  isl::schedule visitSet(isl::schedule_node_set Set,
                         const isl::union_set &Domain,
                         isl::union_map &Extensions) {
    int NumChildren = isl_schedule_node_n_children(Set.get());
    isl::schedule NewNode = visit(Set.first_child(), Domain, Extensions);
    for (int i = 1; i < NumChildren; i += 1) {
      isl::schedule_node OldChild = Set.child(i);
      isl::union_map NewChildExtensions;
      isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions);
      NewNode = isl::manage(
          isl_schedule_set(NewNode.release(), NewChildNode.release()));
      Extensions = Extensions.unite(NewChildExtensions);
    }
    return NewNode;
  }

  /// A leaf contributes no extensions; its schedule is created from the
  /// accumulated @p Domain (not from the leaf's own domain).
  isl::schedule visitLeaf(isl::schedule_node_leaf Leaf,
                          const isl::union_set &Domain,
                          isl::union_map &Extensions) {
    Extensions = isl::union_map::empty(Leaf.ctx());
    return isl::schedule::from_domain(Domain);
  }

  /// Rewrite the band's child and rebuild the band so that it also schedules
  /// the domains introduced by extensions below it. Extensions that have more
  /// input dimensions than this band has members are passed on to the caller
  /// through @p OuterExtensions.
  isl::schedule visitBand(isl::schedule_node_band OldNode,
                          const isl::union_set &Domain,
                          isl::union_map &OuterExtensions) {
    isl::schedule_node OldChild = OldNode.first_child();
    isl::multi_union_pw_aff PartialSched =
        isl::manage(isl_schedule_node_band_get_partial_schedule(OldNode.get()));

    isl::union_map NewChildExtensions;
    isl::schedule NewChild = visit(OldChild, Domain, NewChildExtensions);

    // Add the extensions to the partial schedule.
    OuterExtensions = isl::union_map::empty(NewChildExtensions.ctx());
    isl::union_map NewPartialSchedMap = isl::union_map::from(PartialSched);
    unsigned BandDims = isl_schedule_node_band_n_member(OldNode.get());
    for (isl::map Ext : NewChildExtensions.get_map_list()) {
      unsigned ExtDims = Ext.domain_tuple_dim().release();
      assert(ExtDims >= BandDims);
      unsigned OuterDims = ExtDims - BandDims;

      // Drop the leading OuterDims input dimensions and swap domain and
      // range; the result is united with this band's partial schedule map.
      isl::map BandSched =
          Ext.project_out(isl::dim::in, 0, OuterDims).reverse();
      NewPartialSchedMap = NewPartialSchedMap.unite(BandSched);

      // There might be more outer bands that have to schedule the extensions.
      if (OuterDims > 0) {
        // Drop the BandDims input dimensions consumed by this band and hand
        // the remainder outwards.
        isl::map OuterSched =
            Ext.project_out(isl::dim::in, OuterDims, BandDims);
        OuterExtensions = OuterExtensions.unite(OuterSched);
      }
    }
    isl::multi_union_pw_aff NewPartialSchedAsAsMultiUnionPwAff =
        isl::multi_union_pw_aff::from_union_map(NewPartialSchedMap);
    isl::schedule_node NewNode =
        NewChild.insert_partial_schedule(NewPartialSchedAsAsMultiUnionPwAff)
            .get_root()
            .child(0);

    // Reapply permutability and coincidence attributes.
    NewNode = isl::manage(isl_schedule_node_band_set_permutable(
        NewNode.release(),
        isl_schedule_node_band_get_permutable(OldNode.get())));
    for (unsigned i = 0; i < BandDims; i += 1) {
      NewNode = isl::manage(isl_schedule_node_band_member_set_coincident(
          NewNode.release(), i,
          isl_schedule_node_band_member_get_coincident(OldNode.get(), i)));
    }

    return NewNode.get_schedule();
  }

  /// No filter node is recreated here; the child is rewritten with the
  /// inherited domain restricted to the filter set.
  isl::schedule visitFilter(isl::schedule_node_filter Filter,
                            const isl::union_set &Domain,
                            isl::union_map &Extensions) {
    isl::union_set FilterDomain =
        Filter.as<isl::schedule_node_filter>().get_filter();
    isl::union_set NewDomain = Domain.intersect(FilterDomain);

    // A filter is added implicitly if necessary when joining schedule trees.
    return visit(Filter.first_child(), NewDomain, Extensions);
  }

  /// Drop the extension node: its range is added to the domain used for the
  /// child, and its map is reported to the caller via @p Extensions together
  /// with the extensions found below.
  isl::schedule visitExtension(isl::schedule_node_extension Extension,
                               const isl::union_set &Domain,
                               isl::union_map &Extensions) {
    isl::union_map ExtDomain =
        Extension.as<isl::schedule_node_extension>().get_extension();
    isl::union_set NewDomain = Domain.unite(ExtDomain.range());
    isl::union_map ChildExtensions;
    isl::schedule NewChild =
        visit(Extension.first_child(), NewDomain, ChildExtensions);
    Extensions = ChildExtensions.unite(ExtDomain);
    return NewChild;
  }
};

/// Collect all AST build options in any schedule tree band.
///
/// ScheduleTreeRewriter cannot apply the schedule tree options. This class
/// collects these options to apply them later.
struct CollectASTBuildOptions
    : public RecursiveScheduleTreeVisitor<CollectASTBuildOptions> {
  using BaseTy = RecursiveScheduleTreeVisitor<CollectASTBuildOptions>;
  BaseTy &getBase() { return *this; }
  const BaseTy &getBase() const { return *this; }

  /// The collected options, one entry per visited band, in visiting order.
  llvm::SmallVector<isl::union_set, 8> ASTBuildOptions;

  /// Record this band's AST build options, then let the base class continue
  /// the traversal.
  void visitBand(isl::schedule_node_band Band) {
    ASTBuildOptions.push_back(
        isl::manage(isl_schedule_node_band_get_ast_build_options(Band.get())));
    return getBase().visitBand(Band);
  }
};

/// Apply AST build options to the bands in a schedule tree.
///
/// This rewrites a schedule tree with the AST build options applied. We assume
/// that the band nodes are visited in the same order as they were when the
/// build options were collected, typically by CollectASTBuildOptions.
326 struct ApplyASTBuildOptions 327 : public ScheduleNodeRewriter<ApplyASTBuildOptions> { 328 using BaseTy = ScheduleNodeRewriter<ApplyASTBuildOptions>; 329 BaseTy &getBase() { return *this; } 330 const BaseTy &getBase() const { return *this; } 331 332 size_t Pos; 333 llvm::ArrayRef<isl::union_set> ASTBuildOptions; 334 335 ApplyASTBuildOptions(llvm::ArrayRef<isl::union_set> ASTBuildOptions) 336 : ASTBuildOptions(ASTBuildOptions) {} 337 338 isl::schedule visitSchedule(isl::schedule Schedule) { 339 Pos = 0; 340 isl::schedule Result = visit(Schedule).get_schedule(); 341 assert(Pos == ASTBuildOptions.size() && 342 "AST build options must match to band nodes"); 343 return Result; 344 } 345 346 isl::schedule_node visitBand(isl::schedule_node_band Band) { 347 isl::schedule_node_band Result = 348 Band.set_ast_build_options(ASTBuildOptions[Pos]); 349 Pos += 1; 350 return getBase().visitBand(Result); 351 } 352 }; 353 354 /// Return whether the schedule contains an extension node. 355 static bool containsExtensionNode(isl::schedule Schedule) { 356 assert(!Schedule.is_null()); 357 358 auto Callback = [](__isl_keep isl_schedule_node *Node, 359 void *User) -> isl_bool { 360 if (isl_schedule_node_get_type(Node) == isl_schedule_node_extension) { 361 // Stop walking the schedule tree. 362 return isl_bool_error; 363 } 364 365 // Continue searching the subtree. 366 return isl_bool_true; 367 }; 368 isl_stat RetVal = isl_schedule_foreach_schedule_node_top_down( 369 Schedule.get(), Callback, nullptr); 370 371 // We assume that the traversal itself does not fail, i.e. the only reason to 372 // return isl_stat_error is that an extension node was found. 373 return RetVal == isl_stat_error; 374 } 375 376 /// Find a named MDNode property in a LoopID. 377 static MDNode *findOptionalNodeOperand(MDNode *LoopMD, StringRef Name) { 378 return dyn_cast_or_null<MDNode>( 379 findMetadataOperand(LoopMD, Name).getValueOr(nullptr)); 380 } 381 382 /// Is this node of type mark? 
383 static bool isMark(const isl::schedule_node &Node) { 384 return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_mark; 385 } 386 387 #ifndef NDEBUG 388 /// Is this node of type band? 389 static bool isBand(const isl::schedule_node &Node) { 390 return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band; 391 } 392 393 /// Is this node a band of a single dimension (i.e. could represent a loop)? 394 static bool isBandWithSingleLoop(const isl::schedule_node &Node) { 395 396 return isBand(Node) && isl_schedule_node_band_n_member(Node.get()) == 1; 397 } 398 #endif 399 400 static bool isLeaf(const isl::schedule_node &Node) { 401 return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_leaf; 402 } 403 404 /// Create an isl::id representing the output loop after a transformation. 405 static isl::id createGeneratedLoopAttr(isl::ctx Ctx, MDNode *FollowupLoopMD) { 406 // Don't need to id the followup. 407 // TODO: Append llvm.loop.disable_heustistics metadata unless overridden by 408 // user followup-MD 409 if (!FollowupLoopMD) 410 return {}; 411 412 BandAttr *Attr = new BandAttr(); 413 Attr->Metadata = FollowupLoopMD; 414 return getIslLoopAttr(Ctx, Attr); 415 } 416 417 /// A loop consists of a band and an optional marker that wraps it. Return the 418 /// outermost of the two. 419 420 /// That is, either the mark or, if there is not mark, the loop itself. Can 421 /// start with either the mark or the band. 422 static isl::schedule_node moveToBandMark(isl::schedule_node BandOrMark) { 423 if (isBandMark(BandOrMark)) { 424 assert(isBandWithSingleLoop(BandOrMark.child(0))); 425 return BandOrMark; 426 } 427 assert(isBandWithSingleLoop(BandOrMark)); 428 429 isl::schedule_node Mark = BandOrMark.parent(); 430 if (isBandMark(Mark)) 431 return Mark; 432 433 // Band has no loop marker. 
434 return BandOrMark; 435 } 436 437 static isl::schedule_node removeMark(isl::schedule_node MarkOrBand, 438 BandAttr *&Attr) { 439 MarkOrBand = moveToBandMark(MarkOrBand); 440 441 isl::schedule_node Band; 442 if (isMark(MarkOrBand)) { 443 Attr = getLoopAttr(MarkOrBand.as<isl::schedule_node_mark>().get_id()); 444 Band = isl::manage(isl_schedule_node_delete(MarkOrBand.release())); 445 } else { 446 Attr = nullptr; 447 Band = MarkOrBand; 448 } 449 450 assert(isBandWithSingleLoop(Band)); 451 return Band; 452 } 453 454 /// Remove the mark that wraps a loop. Return the band representing the loop. 455 static isl::schedule_node removeMark(isl::schedule_node MarkOrBand) { 456 BandAttr *Attr; 457 return removeMark(MarkOrBand, Attr); 458 } 459 460 static isl::schedule_node insertMark(isl::schedule_node Band, isl::id Mark) { 461 assert(isBand(Band)); 462 assert(moveToBandMark(Band).is_equal(Band) && 463 "Don't add a two marks for a band"); 464 465 return Band.insert_mark(Mark).child(0); 466 } 467 468 /// Return the (one-dimensional) set of numbers that are divisible by @p Factor 469 /// with remainder @p Offset. 470 /// 471 /// isDivisibleBySet(Ctx, 4, 0) = { [i] : floord(i,4) = 0 } 472 /// isDivisibleBySet(Ctx, 4, 1) = { [i] : floord(i,4) = 1 } 473 /// 474 static isl::basic_set isDivisibleBySet(isl::ctx &Ctx, long Factor, 475 long Offset) { 476 isl::val ValFactor{Ctx, Factor}; 477 isl::val ValOffset{Ctx, Offset}; 478 479 isl::space Unispace{Ctx, 0, 1}; 480 isl::local_space LUnispace{Unispace}; 481 isl::aff AffFactor{LUnispace, ValFactor}; 482 isl::aff AffOffset{LUnispace, ValOffset}; 483 484 isl::aff Id = isl::aff::var_on_domain(LUnispace, isl::dim::out, 0); 485 isl::aff DivMul = Id.mod(ValFactor); 486 isl::basic_map Divisible = isl::basic_map::from_aff(DivMul); 487 isl::basic_map Modulo = Divisible.fix_val(isl::dim::out, 0, ValOffset); 488 return Modulo.domain(); 489 } 490 491 /// Make the last dimension of Set to take values from 0 to VectorWidth - 1. 
492 /// 493 /// @param Set A set, which should be modified. 494 /// @param VectorWidth A parameter, which determines the constraint. 495 static isl::set addExtentConstraints(isl::set Set, int VectorWidth) { 496 unsigned Dims = Set.tuple_dim().release(); 497 isl::space Space = Set.get_space(); 498 isl::local_space LocalSpace = isl::local_space(Space); 499 isl::constraint ExtConstr = isl::constraint::alloc_inequality(LocalSpace); 500 ExtConstr = ExtConstr.set_constant_si(0); 501 ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, 1); 502 Set = Set.add_constraint(ExtConstr); 503 ExtConstr = isl::constraint::alloc_inequality(LocalSpace); 504 ExtConstr = ExtConstr.set_constant_si(VectorWidth - 1); 505 ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, -1); 506 return Set.add_constraint(ExtConstr); 507 } 508 } // namespace 509 510 bool polly::isBandMark(const isl::schedule_node &Node) { 511 return isMark(Node) && 512 isLoopAttr(Node.as<isl::schedule_node_mark>().get_id()); 513 } 514 515 BandAttr *polly::getBandAttr(isl::schedule_node MarkOrBand) { 516 MarkOrBand = moveToBandMark(MarkOrBand); 517 if (!isMark(MarkOrBand)) 518 return nullptr; 519 520 return getLoopAttr(MarkOrBand.as<isl::schedule_node_mark>().get_id()); 521 } 522 523 isl::schedule polly::hoistExtensionNodes(isl::schedule Sched) { 524 // If there is no extension node in the first place, return the original 525 // schedule tree. 526 if (!containsExtensionNode(Sched)) 527 return Sched; 528 529 // Build options can anchor schedule nodes, such that the schedule tree cannot 530 // be modified anymore. Therefore, apply build options after the tree has been 531 // created. 532 CollectASTBuildOptions Collector; 533 Collector.visit(Sched); 534 535 // Rewrite the schedule tree without extension nodes. 536 ExtensionNodeRewriter Rewriter; 537 isl::schedule NewSched = Rewriter.visitSchedule(Sched); 538 539 // Reapply the AST build options. 
The rewriter must not change the iteration 540 // order of bands. Any other node type is ignored. 541 ApplyASTBuildOptions Applicator(Collector.ASTBuildOptions); 542 NewSched = Applicator.visitSchedule(NewSched); 543 544 return NewSched; 545 } 546 547 isl::schedule polly::applyFullUnroll(isl::schedule_node BandToUnroll) { 548 isl::ctx Ctx = BandToUnroll.ctx(); 549 550 // Remove the loop's mark, the loop will disappear anyway. 551 BandToUnroll = removeMark(BandToUnroll); 552 assert(isBandWithSingleLoop(BandToUnroll)); 553 554 isl::multi_union_pw_aff PartialSched = isl::manage( 555 isl_schedule_node_band_get_partial_schedule(BandToUnroll.get())); 556 assert(PartialSched.dim(isl::dim::out).release() == 1 && 557 "Can only unroll a single dimension"); 558 isl::union_pw_aff PartialSchedUAff = PartialSched.at(0); 559 560 isl::union_set Domain = BandToUnroll.get_domain(); 561 PartialSchedUAff = PartialSchedUAff.intersect_domain(Domain); 562 isl::union_map PartialSchedUMap = 563 isl::union_map::from(isl::union_pw_multi_aff(PartialSchedUAff)); 564 565 // Enumerator only the scatter elements. 566 isl::union_set ScatterList = PartialSchedUMap.range(); 567 568 // Enumerate all loop iterations. 569 // TODO: Diagnose if not enumerable or depends on a parameter. 570 SmallVector<isl::point, 16> Elts; 571 ScatterList.foreach_point([&Elts](isl::point P) -> isl::stat { 572 Elts.push_back(P); 573 return isl::stat::ok(); 574 }); 575 576 // Don't assume that foreach_point returns in execution order. 577 llvm::sort(Elts, [](isl::point P1, isl::point P2) -> bool { 578 isl::val C1 = P1.get_coordinate_val(isl::dim::set, 0); 579 isl::val C2 = P2.get_coordinate_val(isl::dim::set, 0); 580 return C1.lt(C2); 581 }); 582 583 // Convert the points to a sequence of filters. 584 isl::union_set_list List = isl::union_set_list(Ctx, Elts.size()); 585 for (isl::point P : Elts) { 586 // Determine the domains that map this scatter element. 
587 isl::union_set DomainFilter = PartialSchedUMap.intersect_range(P).domain(); 588 589 List = List.add(DomainFilter); 590 } 591 592 // Replace original band with unrolled sequence. 593 isl::schedule_node Body = 594 isl::manage(isl_schedule_node_delete(BandToUnroll.release())); 595 Body = Body.insert_sequence(List); 596 return Body.get_schedule(); 597 } 598 599 isl::schedule polly::applyPartialUnroll(isl::schedule_node BandToUnroll, 600 int Factor) { 601 assert(Factor > 0 && "Positive unroll factor required"); 602 isl::ctx Ctx = BandToUnroll.ctx(); 603 604 // Remove the mark, save the attribute for later use. 605 BandAttr *Attr; 606 BandToUnroll = removeMark(BandToUnroll, Attr); 607 assert(isBandWithSingleLoop(BandToUnroll)); 608 609 isl::multi_union_pw_aff PartialSched = isl::manage( 610 isl_schedule_node_band_get_partial_schedule(BandToUnroll.get())); 611 612 // { Stmt[] -> [x] } 613 isl::union_pw_aff PartialSchedUAff = PartialSched.at(0); 614 615 // Here we assume the schedule stride is one and starts with 0, which is not 616 // necessarily the case. 
617 isl::union_pw_aff StridedPartialSchedUAff = 618 isl::union_pw_aff::empty(PartialSchedUAff.get_space()); 619 isl::val ValFactor{Ctx, Factor}; 620 PartialSchedUAff.foreach_pw_aff([&StridedPartialSchedUAff, 621 &ValFactor](isl::pw_aff PwAff) -> isl::stat { 622 isl::space Space = PwAff.get_space(); 623 isl::set Universe = isl::set::universe(Space.domain()); 624 isl::pw_aff AffFactor{Universe, ValFactor}; 625 isl::pw_aff DivSchedAff = PwAff.div(AffFactor).floor().mul(AffFactor); 626 StridedPartialSchedUAff = StridedPartialSchedUAff.union_add(DivSchedAff); 627 return isl::stat::ok(); 628 }); 629 630 isl::union_set_list List = isl::union_set_list(Ctx, Factor); 631 for (auto i : seq<int>(0, Factor)) { 632 // { Stmt[] -> [x] } 633 isl::union_map UMap = 634 isl::union_map::from(isl::union_pw_multi_aff(PartialSchedUAff)); 635 636 // { [x] } 637 isl::basic_set Divisible = isDivisibleBySet(Ctx, Factor, i); 638 639 // { Stmt[] } 640 isl::union_set UnrolledDomain = UMap.intersect_range(Divisible).domain(); 641 642 List = List.add(UnrolledDomain); 643 } 644 645 isl::schedule_node Body = 646 isl::manage(isl_schedule_node_delete(BandToUnroll.copy())); 647 Body = Body.insert_sequence(List); 648 isl::schedule_node NewLoop = 649 Body.insert_partial_schedule(StridedPartialSchedUAff); 650 651 MDNode *FollowupMD = nullptr; 652 if (Attr && Attr->Metadata) 653 FollowupMD = 654 findOptionalNodeOperand(Attr->Metadata, LLVMLoopUnrollFollowupUnrolled); 655 656 isl::id NewBandId = createGeneratedLoopAttr(Ctx, FollowupMD); 657 if (!NewBandId.is_null()) 658 NewLoop = insertMark(NewLoop, NewBandId); 659 660 return NewLoop.get_schedule(); 661 } 662 663 isl::set polly::getPartialTilePrefixes(isl::set ScheduleRange, 664 int VectorWidth) { 665 isl_size Dims = ScheduleRange.tuple_dim().release(); 666 isl::set LoopPrefixes = 667 ScheduleRange.drop_constraints_involving_dims(isl::dim::set, Dims - 1, 1); 668 auto ExtentPrefixes = addExtentConstraints(LoopPrefixes, VectorWidth); 669 isl::set BadPrefixes 
= ExtentPrefixes.subtract(ScheduleRange); 670 BadPrefixes = BadPrefixes.project_out(isl::dim::set, Dims - 1, 1); 671 LoopPrefixes = LoopPrefixes.project_out(isl::dim::set, Dims - 1, 1); 672 return LoopPrefixes.subtract(BadPrefixes); 673 } 674 675 isl::union_set polly::getIsolateOptions(isl::set IsolateDomain, 676 isl_size OutDimsNum) { 677 isl_size Dims = IsolateDomain.tuple_dim().release(); 678 assert(OutDimsNum <= Dims && 679 "The isl::set IsolateDomain is used to describe the range of schedule " 680 "dimensions values, which should be isolated. Consequently, the " 681 "number of its dimensions should be greater than or equal to the " 682 "number of the schedule dimensions."); 683 isl::map IsolateRelation = isl::map::from_domain(IsolateDomain); 684 IsolateRelation = IsolateRelation.move_dims(isl::dim::out, 0, isl::dim::in, 685 Dims - OutDimsNum, OutDimsNum); 686 isl::set IsolateOption = IsolateRelation.wrap(); 687 isl::id Id = isl::id::alloc(IsolateOption.ctx(), "isolate", nullptr); 688 IsolateOption = IsolateOption.set_tuple_id(Id); 689 return isl::union_set(IsolateOption); 690 } 691 692 isl::union_set polly::getDimOptions(isl::ctx Ctx, const char *Option) { 693 isl::space Space(Ctx, 0, 1); 694 auto DimOption = isl::set::universe(Space); 695 auto Id = isl::id::alloc(Ctx, Option, nullptr); 696 DimOption = DimOption.set_tuple_id(Id); 697 return isl::union_set(DimOption); 698 } 699 700 isl::schedule_node polly::tileNode(isl::schedule_node Node, 701 const char *Identifier, 702 ArrayRef<int> TileSizes, 703 int DefaultTileSize) { 704 auto Space = isl::manage(isl_schedule_node_band_get_space(Node.get())); 705 auto Dims = Space.dim(isl::dim::set); 706 auto Sizes = isl::multi_val::zero(Space); 707 std::string IdentifierString(Identifier); 708 for (auto i : seq<isl_size>(0, Dims.release())) { 709 auto tileSize = 710 i < (isl_size)TileSizes.size() ? 
TileSizes[i] : DefaultTileSize; 711 Sizes = Sizes.set_val(i, isl::val(Node.ctx(), tileSize)); 712 } 713 auto TileLoopMarkerStr = IdentifierString + " - Tiles"; 714 auto TileLoopMarker = isl::id::alloc(Node.ctx(), TileLoopMarkerStr, nullptr); 715 Node = Node.insert_mark(TileLoopMarker); 716 Node = Node.child(0); 717 Node = 718 isl::manage(isl_schedule_node_band_tile(Node.release(), Sizes.release())); 719 Node = Node.child(0); 720 auto PointLoopMarkerStr = IdentifierString + " - Points"; 721 auto PointLoopMarker = 722 isl::id::alloc(Node.ctx(), PointLoopMarkerStr, nullptr); 723 Node = Node.insert_mark(PointLoopMarker); 724 return Node.child(0); 725 } 726 727 isl::schedule_node polly::applyRegisterTiling(isl::schedule_node Node, 728 ArrayRef<int> TileSizes, 729 int DefaultTileSize) { 730 Node = tileNode(Node, "Register tiling", TileSizes, DefaultTileSize); 731 auto Ctx = Node.ctx(); 732 return Node.as<isl::schedule_node_band>().set_ast_build_options( 733 isl::union_set(Ctx, "{unroll[x]}")); 734 } 735 736 /// Find statements and sub-loops in (possibly nested) sequences. 
737 static void 738 collectFussionableStmts(isl::schedule_node Node, 739 SmallVectorImpl<isl::schedule_node> &ScheduleStmts) { 740 if (isBand(Node) || isLeaf(Node)) { 741 ScheduleStmts.push_back(Node); 742 return; 743 } 744 745 if (Node.has_children()) { 746 isl::schedule_node C = Node.first_child(); 747 while (true) { 748 collectFussionableStmts(C, ScheduleStmts); 749 if (!C.has_next_sibling()) 750 break; 751 C = C.next_sibling(); 752 } 753 } 754 } 755 756 isl::schedule polly::applyMaxFission(isl::schedule_node BandToFission) { 757 isl::ctx Ctx = BandToFission.ctx(); 758 BandToFission = removeMark(BandToFission); 759 isl::schedule_node BandBody = BandToFission.child(0); 760 761 SmallVector<isl::schedule_node> FissionableStmts; 762 collectFussionableStmts(BandBody, FissionableStmts); 763 size_t N = FissionableStmts.size(); 764 765 // Collect the domain for each of the statements that will get their own loop. 766 isl::union_set_list DomList = isl::union_set_list(Ctx, N); 767 for (size_t i = 0; i < N; ++i) { 768 isl::schedule_node BodyPart = FissionableStmts[i]; 769 DomList = DomList.add(BodyPart.get_domain()); 770 } 771 772 // Apply the fission by copying the entire loop, but inserting a filter for 773 // the statement domains for each fissioned loop. 774 isl::schedule_node Fissioned = BandToFission.insert_sequence(DomList); 775 776 return Fissioned.get_schedule(); 777 } 778