1 //===- polly/ScheduleTreeTransform.cpp --------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Make changes to isl's schedule tree data structure. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "polly/ScheduleTreeTransform.h" 14 #include "polly/Support/ISLTools.h" 15 #include "polly/Support/ScopHelper.h" 16 #include "llvm/ADT/ArrayRef.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/SmallVector.h" 19 #include "llvm/IR/Constants.h" 20 #include "llvm/IR/Metadata.h" 21 #include "llvm/Transforms/Utils/UnrollLoop.h" 22 23 using namespace polly; 24 using namespace llvm; 25 26 namespace { 27 /// Recursively visit all nodes of a schedule tree while allowing changes. 28 /// 29 /// The visit methods return an isl::schedule_node that is used to continue 30 /// visiting the tree. Structural changes such as returning a different node 31 /// will confuse the visitor. 32 template <typename Derived, typename... Args> 33 struct ScheduleNodeRewriter 34 : public RecursiveScheduleTreeVisitor<Derived, isl::schedule_node, 35 Args...> { 36 Derived &getDerived() { return *static_cast<Derived *>(this); } 37 const Derived &getDerived() const { 38 return *static_cast<const Derived *>(this); 39 } 40 41 isl::schedule_node visitNode(const isl::schedule_node &Node, Args... args) { 42 if (!Node.has_children()) 43 return Node; 44 45 isl::schedule_node It = Node.first_child(); 46 while (true) { 47 It = getDerived().visit(It, std::forward<Args>(args)...); 48 if (!It.has_next_sibling()) 49 break; 50 It = It.next_sibling(); 51 } 52 return It.parent(); 53 } 54 }; 55 56 /// Rewrite a schedule tree by reconstructing it bottom-up. 57 /// 58 /// By default, the original schedule tree is reconstructed. To build a 59 /// different tree, redefine visitor methods in a derived class (CRTP). 60 /// 61 /// Note that AST build options are not applied; Setting the isolate[] option 62 /// makes the schedule tree 'anchored' and cannot be modified afterwards. Hence, 63 /// AST build options must be set after the tree has been constructed. 64 template <typename Derived, typename... Args> 65 struct ScheduleTreeRewriter 66 : public RecursiveScheduleTreeVisitor<Derived, isl::schedule, Args...> { 67 Derived &getDerived() { return *static_cast<Derived *>(this); } 68 const Derived &getDerived() const { 69 return *static_cast<const Derived *>(this); 70 } 71 72 isl::schedule visitDomain(const isl::schedule_node &Node, Args... args) { 73 // Every schedule_tree already has a domain node, no need to add one. 74 return getDerived().visit(Node.first_child(), std::forward<Args>(args)...); 75 } 76 77 isl::schedule visitBand(const isl::schedule_node &Band, Args... args) { 78 isl::multi_union_pw_aff PartialSched = 79 isl::manage(isl_schedule_node_band_get_partial_schedule(Band.get())); 80 isl::schedule NewChild = 81 getDerived().visit(Band.child(0), std::forward<Args>(args)...); 82 isl::schedule_node NewNode = 83 NewChild.insert_partial_schedule(PartialSched).get_root().get_child(0); 84 85 // Reapply permutability and coincidence attributes. 86 NewNode = isl::manage(isl_schedule_node_band_set_permutable( 87 NewNode.release(), isl_schedule_node_band_get_permutable(Band.get()))); 88 unsigned BandDims = isl_schedule_node_band_n_member(Band.get()); 89 for (unsigned i = 0; i < BandDims; i += 1) 90 NewNode = isl::manage(isl_schedule_node_band_member_set_coincident( 91 NewNode.release(), i, 92 isl_schedule_node_band_member_get_coincident(Band.get(), i))); 93 94 return NewNode.get_schedule(); 95 } 96 97 isl::schedule visitSequence(const isl::schedule_node &Sequence, 98 Args... args) { 99 int NumChildren = isl_schedule_node_n_children(Sequence.get()); 100 isl::schedule Result = 101 getDerived().visit(Sequence.child(0), std::forward<Args>(args)...); 102 for (int i = 1; i < NumChildren; i += 1) 103 Result = Result.sequence( 104 getDerived().visit(Sequence.child(i), std::forward<Args>(args)...)); 105 return Result; 106 } 107 108 isl::schedule visitSet(const isl::schedule_node &Set, Args... args) { 109 int NumChildren = isl_schedule_node_n_children(Set.get()); 110 isl::schedule Result = 111 getDerived().visit(Set.child(0), std::forward<Args>(args)...); 112 for (int i = 1; i < NumChildren; i += 1) 113 Result = isl::manage( 114 isl_schedule_set(Result.release(), 115 getDerived() 116 .visit(Set.child(i), std::forward<Args>(args)...) 117 .release())); 118 return Result; 119 } 120 121 isl::schedule visitLeaf(const isl::schedule_node &Leaf, Args... args) { 122 return isl::schedule::from_domain(Leaf.get_domain()); 123 } 124 125 isl::schedule visitMark(const isl::schedule_node &Mark, Args... args) { 126 isl::id TheMark = Mark.mark_get_id(); 127 isl::schedule_node NewChild = 128 getDerived() 129 .visit(Mark.first_child(), std::forward<Args>(args)...) 130 .get_root() 131 .first_child(); 132 return NewChild.insert_mark(TheMark).get_schedule(); 133 } 134 135 isl::schedule visitExtension(const isl::schedule_node &Extension, 136 Args... args) { 137 isl::union_map TheExtension = Extension.extension_get_extension(); 138 isl::schedule_node NewChild = getDerived() 139 .visit(Extension.child(0), args...) 140 .get_root() 141 .first_child(); 142 isl::schedule_node NewExtension = 143 isl::schedule_node::from_extension(TheExtension); 144 return NewChild.graft_before(NewExtension).get_schedule(); 145 } 146 147 isl::schedule visitFilter(const isl::schedule_node &Filter, Args... args) { 148 isl::union_set FilterDomain = Filter.filter_get_filter(); 149 isl::schedule NewSchedule = 150 getDerived().visit(Filter.child(0), std::forward<Args>(args)...); 151 return NewSchedule.intersect_domain(FilterDomain); 152 } 153 154 isl::schedule visitNode(const isl::schedule_node &Node, Args... args) { 155 llvm_unreachable("Not implemented"); 156 } 157 }; 158 159 /// Rewrite a schedule tree to an equivalent one without extension nodes. 160 /// 161 /// Each visit method takes two additional arguments: 162 /// 163 /// * The new domain the node, which is the inherited domain plus any domains 164 /// added by extension nodes. 165 /// 166 /// * A map of extension domains of all children is returned; it is required by 167 /// band nodes to schedule the additional domains at the same position as the 168 /// extension node would. 169 /// 170 struct ExtensionNodeRewriter 171 : public ScheduleTreeRewriter<ExtensionNodeRewriter, const isl::union_set &, 172 isl::union_map &> { 173 using BaseTy = ScheduleTreeRewriter<ExtensionNodeRewriter, 174 const isl::union_set &, isl::union_map &>; 175 BaseTy &getBase() { return *this; } 176 const BaseTy &getBase() const { return *this; } 177 178 isl::schedule visitSchedule(const isl::schedule &Schedule) { 179 isl::union_map Extensions; 180 isl::schedule Result = 181 visit(Schedule.get_root(), Schedule.get_domain(), Extensions); 182 assert(Extensions && Extensions.is_empty()); 183 return Result; 184 } 185 186 isl::schedule visitSequence(const isl::schedule_node &Sequence, 187 const isl::union_set &Domain, 188 isl::union_map &Extensions) { 189 int NumChildren = isl_schedule_node_n_children(Sequence.get()); 190 isl::schedule NewNode = visit(Sequence.first_child(), Domain, Extensions); 191 for (int i = 1; i < NumChildren; i += 1) { 192 isl::schedule_node OldChild = Sequence.child(i); 193 isl::union_map NewChildExtensions; 194 isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions); 195 NewNode = NewNode.sequence(NewChildNode); 196 Extensions = Extensions.unite(NewChildExtensions); 197 } 198 return NewNode; 199 } 200 201 isl::schedule visitSet(const isl::schedule_node &Set, 202 const isl::union_set &Domain, 203 isl::union_map &Extensions) { 204 int NumChildren = isl_schedule_node_n_children(Set.get()); 205 isl::schedule NewNode = visit(Set.first_child(), Domain, Extensions); 206 for (int i = 1; i < NumChildren; i += 1) { 207 isl::schedule_node OldChild = Set.child(i); 208 isl::union_map NewChildExtensions; 209 isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions); 210 NewNode = isl::manage( 211 isl_schedule_set(NewNode.release(), NewChildNode.release())); 212 Extensions = Extensions.unite(NewChildExtensions); 213 } 214 return NewNode; 215 } 216 217 isl::schedule visitLeaf(const isl::schedule_node &Leaf, 218 const isl::union_set &Domain, 219 isl::union_map &Extensions) { 220 isl::ctx Ctx = Leaf.get_ctx(); 221 Extensions = isl::union_map::empty(isl::space::params_alloc(Ctx, 0)); 222 return isl::schedule::from_domain(Domain); 223 } 224 225 isl::schedule visitBand(const isl::schedule_node &OldNode, 226 const isl::union_set &Domain, 227 isl::union_map &OuterExtensions) { 228 isl::schedule_node OldChild = OldNode.first_child(); 229 isl::multi_union_pw_aff PartialSched = 230 isl::manage(isl_schedule_node_band_get_partial_schedule(OldNode.get())); 231 232 isl::union_map NewChildExtensions; 233 isl::schedule NewChild = visit(OldChild, Domain, NewChildExtensions); 234 235 // Add the extensions to the partial schedule. 236 OuterExtensions = isl::union_map::empty(NewChildExtensions.get_space()); 237 isl::union_map NewPartialSchedMap = isl::union_map::from(PartialSched); 238 unsigned BandDims = isl_schedule_node_band_n_member(OldNode.get()); 239 for (isl::map Ext : NewChildExtensions.get_map_list()) { 240 unsigned ExtDims = Ext.dim(isl::dim::in); 241 assert(ExtDims >= BandDims); 242 unsigned OuterDims = ExtDims - BandDims; 243 244 isl::map BandSched = 245 Ext.project_out(isl::dim::in, 0, OuterDims).reverse(); 246 NewPartialSchedMap = NewPartialSchedMap.unite(BandSched); 247 248 // There might be more outer bands that have to schedule the extensions. 249 if (OuterDims > 0) { 250 isl::map OuterSched = 251 Ext.project_out(isl::dim::in, OuterDims, BandDims); 252 OuterExtensions = OuterExtensions.add_map(OuterSched); 253 } 254 } 255 isl::multi_union_pw_aff NewPartialSchedAsAsMultiUnionPwAff = 256 isl::multi_union_pw_aff::from_union_map(NewPartialSchedMap); 257 isl::schedule_node NewNode = 258 NewChild.insert_partial_schedule(NewPartialSchedAsAsMultiUnionPwAff) 259 .get_root() 260 .get_child(0); 261 262 // Reapply permutability and coincidence attributes. 263 NewNode = isl::manage(isl_schedule_node_band_set_permutable( 264 NewNode.release(), 265 isl_schedule_node_band_get_permutable(OldNode.get()))); 266 for (unsigned i = 0; i < BandDims; i += 1) { 267 NewNode = isl::manage(isl_schedule_node_band_member_set_coincident( 268 NewNode.release(), i, 269 isl_schedule_node_band_member_get_coincident(OldNode.get(), i))); 270 } 271 272 return NewNode.get_schedule(); 273 } 274 275 isl::schedule visitFilter(const isl::schedule_node &Filter, 276 const isl::union_set &Domain, 277 isl::union_map &Extensions) { 278 isl::union_set FilterDomain = Filter.filter_get_filter(); 279 isl::union_set NewDomain = Domain.intersect(FilterDomain); 280 281 // A filter is added implicitly if necessary when joining schedule trees. 282 return visit(Filter.first_child(), NewDomain, Extensions); 283 } 284 285 isl::schedule visitExtension(const isl::schedule_node &Extension, 286 const isl::union_set &Domain, 287 isl::union_map &Extensions) { 288 isl::union_map ExtDomain = Extension.extension_get_extension(); 289 isl::union_set NewDomain = Domain.unite(ExtDomain.range()); 290 isl::union_map ChildExtensions; 291 isl::schedule NewChild = 292 visit(Extension.first_child(), NewDomain, ChildExtensions); 293 Extensions = ChildExtensions.unite(ExtDomain); 294 return NewChild; 295 } 296 }; 297 298 /// Collect all AST build options in any schedule tree band. 299 /// 300 /// ScheduleTreeRewriter cannot apply the schedule tree options. This class 301 /// collects these options to apply them later. 302 struct CollectASTBuildOptions 303 : public RecursiveScheduleTreeVisitor<CollectASTBuildOptions> { 304 using BaseTy = RecursiveScheduleTreeVisitor<CollectASTBuildOptions>; 305 BaseTy &getBase() { return *this; } 306 const BaseTy &getBase() const { return *this; } 307 308 llvm::SmallVector<isl::union_set, 8> ASTBuildOptions; 309 310 void visitBand(const isl::schedule_node &Band) { 311 ASTBuildOptions.push_back( 312 isl::manage(isl_schedule_node_band_get_ast_build_options(Band.get()))); 313 return getBase().visitBand(Band); 314 } 315 }; 316 317 /// Apply AST build options to the bands in a schedule tree. 318 /// 319 /// This rewrites a schedule tree with the AST build options applied. We assume 320 /// that the band nodes are visited in the same order as they were when the 321 /// build options were collected, typically by CollectASTBuildOptions. 322 struct ApplyASTBuildOptions 323 : public ScheduleNodeRewriter<ApplyASTBuildOptions> { 324 using BaseTy = ScheduleNodeRewriter<ApplyASTBuildOptions>; 325 BaseTy &getBase() { return *this; } 326 const BaseTy &getBase() const { return *this; } 327 328 size_t Pos; 329 llvm::ArrayRef<isl::union_set> ASTBuildOptions; 330 331 ApplyASTBuildOptions(llvm::ArrayRef<isl::union_set> ASTBuildOptions) 332 : ASTBuildOptions(ASTBuildOptions) {} 333 334 isl::schedule visitSchedule(const isl::schedule &Schedule) { 335 Pos = 0; 336 isl::schedule Result = visit(Schedule).get_schedule(); 337 assert(Pos == ASTBuildOptions.size() && 338 "AST build options must match to band nodes"); 339 return Result; 340 } 341 342 isl::schedule_node visitBand(const isl::schedule_node &Band) { 343 isl::schedule_node Result = 344 Band.band_set_ast_build_options(ASTBuildOptions[Pos]); 345 Pos += 1; 346 return getBase().visitBand(Result); 347 } 348 }; 349 350 /// Return whether the schedule contains an extension node. 351 static bool containsExtensionNode(isl::schedule Schedule) { 352 assert(!Schedule.is_null()); 353 354 auto Callback = [](__isl_keep isl_schedule_node *Node, 355 void *User) -> isl_bool { 356 if (isl_schedule_node_get_type(Node) == isl_schedule_node_extension) { 357 // Stop walking the schedule tree. 358 return isl_bool_error; 359 } 360 361 // Continue searching the subtree. 362 return isl_bool_true; 363 }; 364 isl_stat RetVal = isl_schedule_foreach_schedule_node_top_down( 365 Schedule.get(), Callback, nullptr); 366 367 // We assume that the traversal itself does not fail, i.e. the only reason to 368 // return isl_stat_error is that an extension node was found. 369 return RetVal == isl_stat_error; 370 } 371 372 /// Find a named MDNode property in a LoopID. 373 static MDNode *findOptionalNodeOperand(MDNode *LoopMD, StringRef Name) { 374 return dyn_cast_or_null<MDNode>( 375 findMetadataOperand(LoopMD, Name).getValueOr(nullptr)); 376 } 377 378 /// Is this node of type mark? 379 static bool isMark(const isl::schedule_node &Node) { 380 return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_mark; 381 } 382 383 #ifndef NDEBUG 384 /// Is this node of type band? 385 static bool isBand(const isl::schedule_node &Node) { 386 return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band; 387 } 388 389 /// Is this node a band of a single dimension (i.e. could represent a loop)? 390 static bool isBandWithSingleLoop(const isl::schedule_node &Node) { 391 392 return isBand(Node) && isl_schedule_node_band_n_member(Node.get()) == 1; 393 } 394 #endif 395 396 /// Create an isl::id representing the output loop after a transformation. 397 static isl::id createGeneratedLoopAttr(isl::ctx Ctx, MDNode *FollowupLoopMD) { 398 // Don't need to id the followup. 399 // TODO: Append llvm.loop.disable_heustistics metadata unless overridden by 400 // user followup-MD 401 if (!FollowupLoopMD) 402 return {}; 403 404 BandAttr *Attr = new BandAttr(); 405 Attr->Metadata = FollowupLoopMD; 406 return getIslLoopAttr(Ctx, Attr); 407 } 408 409 /// A loop consists of a band and an optional marker that wraps it. Return the 410 /// outermost of the two. 411 412 /// That is, either the mark or, if there is not mark, the loop itself. Can 413 /// start with either the mark or the band. 414 static isl::schedule_node moveToBandMark(isl::schedule_node BandOrMark) { 415 if (isBandMark(BandOrMark)) { 416 assert(isBandWithSingleLoop(BandOrMark.get_child(0))); 417 return BandOrMark; 418 } 419 assert(isBandWithSingleLoop(BandOrMark)); 420 421 isl::schedule_node Mark = BandOrMark.parent(); 422 if (isBandMark(Mark)) 423 return Mark; 424 425 // Band has no loop marker. 426 return BandOrMark; 427 } 428 429 static isl::schedule_node removeMark(isl::schedule_node MarkOrBand, 430 BandAttr *&Attr) { 431 MarkOrBand = moveToBandMark(MarkOrBand); 432 433 isl::schedule_node Band; 434 if (isMark(MarkOrBand)) { 435 Attr = getLoopAttr(MarkOrBand.mark_get_id()); 436 Band = isl::manage(isl_schedule_node_delete(MarkOrBand.release())); 437 } else { 438 Attr = nullptr; 439 Band = MarkOrBand; 440 } 441 442 assert(isBandWithSingleLoop(Band)); 443 return Band; 444 } 445 446 /// Remove the mark that wraps a loop. Return the band representing the loop. 447 static isl::schedule_node removeMark(isl::schedule_node MarkOrBand) { 448 BandAttr *Attr; 449 return removeMark(MarkOrBand, Attr); 450 } 451 452 static isl::schedule_node insertMark(isl::schedule_node Band, isl::id Mark) { 453 assert(isBand(Band)); 454 assert(moveToBandMark(Band).is_equal(Band) && 455 "Don't add a two marks for a band"); 456 457 return Band.insert_mark(Mark).get_child(0); 458 } 459 460 /// Return the (one-dimensional) set of numbers that are divisible by @p Factor 461 /// with remainder @p Offset. 462 /// 463 /// isDivisibleBySet(Ctx, 4, 0) = { [i] : floord(i,4) = 0 } 464 /// isDivisibleBySet(Ctx, 4, 1) = { [i] : floord(i,4) = 1 } 465 /// 466 static isl::basic_set isDivisibleBySet(isl::ctx &Ctx, long Factor, 467 long Offset) { 468 isl::val ValFactor{Ctx, Factor}; 469 isl::val ValOffset{Ctx, Offset}; 470 471 isl::space Unispace{Ctx, 0, 1}; 472 isl::local_space LUnispace{Unispace}; 473 isl::aff AffFactor{LUnispace, ValFactor}; 474 isl::aff AffOffset{LUnispace, ValOffset}; 475 476 isl::aff Id = isl::aff::var_on_domain(LUnispace, isl::dim::out, 0); 477 isl::aff DivMul = Id.mod(ValFactor); 478 isl::basic_map Divisible = isl::basic_map::from_aff(DivMul); 479 isl::basic_map Modulo = Divisible.fix_val(isl::dim::out, 0, ValOffset); 480 return Modulo.domain(); 481 } 482 483 } // namespace 484 485 bool polly::isBandMark(const isl::schedule_node &Node) { 486 return isMark(Node) && isLoopAttr(Node.mark_get_id()); 487 } 488 489 BandAttr *polly::getBandAttr(isl::schedule_node MarkOrBand) { 490 MarkOrBand = moveToBandMark(MarkOrBand); 491 if (!isMark(MarkOrBand)) 492 return nullptr; 493 494 return getLoopAttr(MarkOrBand.mark_get_id()); 495 } 496 497 isl::schedule polly::hoistExtensionNodes(isl::schedule Sched) { 498 // If there is no extension node in the first place, return the original 499 // schedule tree. 500 if (!containsExtensionNode(Sched)) 501 return Sched; 502 503 // Build options can anchor schedule nodes, such that the schedule tree cannot 504 // be modified anymore. Therefore, apply build options after the tree has been 505 // created. 506 CollectASTBuildOptions Collector; 507 Collector.visit(Sched); 508 509 // Rewrite the schedule tree without extension nodes. 510 ExtensionNodeRewriter Rewriter; 511 isl::schedule NewSched = Rewriter.visitSchedule(Sched); 512 513 // Reapply the AST build options. The rewriter must not change the iteration 514 // order of bands. Any other node type is ignored. 515 ApplyASTBuildOptions Applicator(Collector.ASTBuildOptions); 516 NewSched = Applicator.visitSchedule(NewSched); 517 518 return NewSched; 519 } 520 521 isl::schedule polly::applyFullUnroll(isl::schedule_node BandToUnroll) { 522 isl::ctx Ctx = BandToUnroll.get_ctx(); 523 524 // Remove the loop's mark, the loop will disappear anyway. 525 BandToUnroll = removeMark(BandToUnroll); 526 assert(isBandWithSingleLoop(BandToUnroll)); 527 528 isl::multi_union_pw_aff PartialSched = isl::manage( 529 isl_schedule_node_band_get_partial_schedule(BandToUnroll.get())); 530 assert(PartialSched.dim(isl::dim::out) == 1 && 531 "Can only unroll a single dimension"); 532 isl::union_pw_aff PartialSchedUAff = PartialSched.get_union_pw_aff(0); 533 534 isl::union_set Domain = BandToUnroll.get_domain(); 535 PartialSchedUAff = PartialSchedUAff.intersect_domain(Domain); 536 isl::union_map PartialSchedUMap = isl::union_map(PartialSchedUAff); 537 538 // Enumerator only the scatter elements. 539 isl::union_set ScatterList = PartialSchedUMap.range(); 540 541 // Enumerate all loop iterations. 542 // TODO: Diagnose if not enumerable or depends on a parameter. 543 SmallVector<isl::point, 16> Elts; 544 ScatterList.foreach_point([&Elts](isl::point P) -> isl::stat { 545 Elts.push_back(P); 546 return isl::stat::ok(); 547 }); 548 549 // Don't assume that foreach_point returns in execution order. 550 llvm::sort(Elts, [](isl::point P1, isl::point P2) -> bool { 551 isl::val C1 = P1.get_coordinate_val(isl::dim::set, 0); 552 isl::val C2 = P2.get_coordinate_val(isl::dim::set, 0); 553 return C1.lt(C2); 554 }); 555 556 // Convert the points to a sequence of filters. 557 isl::union_set_list List = isl::union_set_list::alloc(Ctx, Elts.size()); 558 for (isl::point P : Elts) { 559 // Determine the domains that map this scatter element. 560 isl::union_set DomainFilter = PartialSchedUMap.intersect_range(P).domain(); 561 562 List = List.add(DomainFilter); 563 } 564 565 // Replace original band with unrolled sequence. 566 isl::schedule_node Body = 567 isl::manage(isl_schedule_node_delete(BandToUnroll.release())); 568 Body = Body.insert_sequence(List); 569 return Body.get_schedule(); 570 } 571 572 isl::schedule polly::applyPartialUnroll(isl::schedule_node BandToUnroll, 573 int Factor) { 574 assert(Factor > 0 && "Positive unroll factor required"); 575 isl::ctx Ctx = BandToUnroll.get_ctx(); 576 577 // Remove the mark, save the attribute for later use. 578 BandAttr *Attr; 579 BandToUnroll = removeMark(BandToUnroll, Attr); 580 assert(isBandWithSingleLoop(BandToUnroll)); 581 582 isl::multi_union_pw_aff PartialSched = isl::manage( 583 isl_schedule_node_band_get_partial_schedule(BandToUnroll.get())); 584 585 // { Stmt[] -> [x] } 586 isl::union_pw_aff PartialSchedUAff = PartialSched.get_union_pw_aff(0); 587 588 // Here we assume the schedule stride is one and starts with 0, which is not 589 // necessarily the case. 590 isl::union_pw_aff StridedPartialSchedUAff = 591 isl::union_pw_aff::empty(PartialSchedUAff.get_space()); 592 isl::val ValFactor{Ctx, Factor}; 593 PartialSchedUAff.foreach_pw_aff([&StridedPartialSchedUAff, 594 &ValFactor](isl::pw_aff PwAff) -> isl::stat { 595 isl::space Space = PwAff.get_space(); 596 isl::set Universe = isl::set::universe(Space.domain()); 597 isl::pw_aff AffFactor{Universe, ValFactor}; 598 isl::pw_aff DivSchedAff = PwAff.div(AffFactor).floor().mul(AffFactor); 599 StridedPartialSchedUAff = StridedPartialSchedUAff.union_add(DivSchedAff); 600 return isl::stat::ok(); 601 }); 602 603 isl::union_set_list List = isl::union_set_list::alloc(Ctx, Factor); 604 for (auto i : seq<int>(0, Factor)) { 605 // { Stmt[] -> [x] } 606 isl::union_map UMap{PartialSchedUAff}; 607 608 // { [x] } 609 isl::basic_set Divisible = isDivisibleBySet(Ctx, Factor, i); 610 611 // { Stmt[] } 612 isl::union_set UnrolledDomain = UMap.intersect_range(Divisible).domain(); 613 614 List = List.add(UnrolledDomain); 615 } 616 617 isl::schedule_node Body = 618 isl::manage(isl_schedule_node_delete(BandToUnroll.copy())); 619 Body = Body.insert_sequence(List); 620 isl::schedule_node NewLoop = 621 Body.insert_partial_schedule(StridedPartialSchedUAff); 622 623 MDNode *FollowupMD = nullptr; 624 if (Attr && Attr->Metadata) 625 FollowupMD = 626 findOptionalNodeOperand(Attr->Metadata, LLVMLoopUnrollFollowupUnrolled); 627 628 isl::id NewBandId = createGeneratedLoopAttr(Ctx, FollowupMD); 629 if (NewBandId) 630 NewLoop = insertMark(NewLoop, NewBandId); 631 632 return NewLoop.get_schedule(); 633 } 634