1 //===- polly/ScheduleTreeTransform.cpp --------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Make changes to isl's schedule tree data structure. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "polly/ScheduleTreeTransform.h" 14 #include "polly/Support/GICHelper.h" 15 #include "polly/Support/ISLTools.h" 16 #include "polly/Support/ScopHelper.h" 17 #include "llvm/ADT/ArrayRef.h" 18 #include "llvm/ADT/Sequence.h" 19 #include "llvm/ADT/SmallVector.h" 20 #include "llvm/IR/Constants.h" 21 #include "llvm/IR/Metadata.h" 22 #include "llvm/Transforms/Utils/UnrollLoop.h" 23 24 #define DEBUG_TYPE "polly-opt-isl" 25 26 using namespace polly; 27 using namespace llvm; 28 29 namespace { 30 31 /// Copy the band member attributes (coincidence, loop type, isolate ast loop 32 /// type) from one band to another. 33 static isl::schedule_node_band 34 applyBandMemberAttributes(isl::schedule_node_band Target, int TargetIdx, 35 const isl::schedule_node_band &Source, 36 int SourceIdx) { 37 bool Coincident = Source.member_get_coincident(SourceIdx).release(); 38 Target = Target.member_set_coincident(TargetIdx, Coincident); 39 40 isl_ast_loop_type LoopType = 41 isl_schedule_node_band_member_get_ast_loop_type(Source.get(), SourceIdx); 42 Target = isl::manage(isl_schedule_node_band_member_set_ast_loop_type( 43 Target.release(), TargetIdx, LoopType)) 44 .as<isl::schedule_node_band>(); 45 46 isl_ast_loop_type IsolateType = 47 isl_schedule_node_band_member_get_isolate_ast_loop_type(Source.get(), 48 SourceIdx); 49 Target = isl::manage(isl_schedule_node_band_member_set_isolate_ast_loop_type( 50 Target.release(), TargetIdx, IsolateType)) 51 .as<isl::schedule_node_band>(); 52 53 return Target; 54 } 55 56 /// Create a new band by copying members from another @p Band. @p IncludeCb 57 /// decides which band indices are copied to the result. 58 template <typename CbTy> 59 static isl::schedule rebuildBand(isl::schedule_node_band OldBand, 60 isl::schedule Body, CbTy IncludeCb) { 61 int NumBandDims = unsignedFromIslSize(OldBand.n_member()); 62 63 bool ExcludeAny = false; 64 bool IncludeAny = false; 65 for (auto OldIdx : seq<int>(0, NumBandDims)) { 66 if (IncludeCb(OldIdx)) 67 IncludeAny = true; 68 else 69 ExcludeAny = true; 70 } 71 72 // Instead of creating a zero-member band, don't create a band at all. 73 if (!IncludeAny) 74 return Body; 75 76 isl::multi_union_pw_aff PartialSched = OldBand.get_partial_schedule(); 77 isl::multi_union_pw_aff NewPartialSched; 78 if (ExcludeAny) { 79 // Select the included partial scatter functions. 80 isl::union_pw_aff_list List = PartialSched.list(); 81 int NewIdx = 0; 82 for (auto OldIdx : seq<int>(0, NumBandDims)) { 83 if (IncludeCb(OldIdx)) 84 NewIdx += 1; 85 else 86 List = List.drop(NewIdx, 1); 87 } 88 isl::space ParamSpace = PartialSched.get_space().params(); 89 isl::space NewScatterSpace = ParamSpace.add_unnamed_tuple(NewIdx); 90 NewPartialSched = isl::multi_union_pw_aff(NewScatterSpace, List); 91 } else { 92 // Just reuse original scatter function of copying all of them. 93 NewPartialSched = PartialSched; 94 } 95 96 // Create the new band node. 97 isl::schedule_node_band NewBand = 98 Body.insert_partial_schedule(NewPartialSched) 99 .get_root() 100 .child(0) 101 .as<isl::schedule_node_band>(); 102 103 // If OldBand was permutable, so is the new one, even if some dimensions are 104 // missing. 105 bool IsPermutable = OldBand.permutable().release(); 106 NewBand = NewBand.set_permutable(IsPermutable); 107 108 // Reapply member attributes. 109 int NewIdx = 0; 110 for (auto OldIdx : seq<int>(0, NumBandDims)) { 111 if (!IncludeCb(OldIdx)) 112 continue; 113 NewBand = 114 applyBandMemberAttributes(std::move(NewBand), NewIdx, OldBand, OldIdx); 115 NewIdx += 1; 116 } 117 118 return NewBand.get_schedule(); 119 } 120 121 /// Recursively visit all nodes of a schedule tree while allowing changes. 122 /// 123 /// The visit methods return an isl::schedule_node that is used to continue 124 /// visiting the tree. Structural changes such as returning a different node 125 /// will confuse the visitor. 126 template <typename Derived, typename... Args> 127 struct ScheduleNodeRewriter 128 : public RecursiveScheduleTreeVisitor<Derived, isl::schedule_node, 129 Args...> { 130 Derived &getDerived() { return *static_cast<Derived *>(this); } 131 const Derived &getDerived() const { 132 return *static_cast<const Derived *>(this); 133 } 134 135 isl::schedule_node visitNode(const isl::schedule_node &Node, Args... args) { 136 if (!Node.has_children()) 137 return Node; 138 139 isl::schedule_node It = Node.first_child(); 140 while (true) { 141 It = getDerived().visit(It, std::forward<Args>(args)...); 142 if (!It.has_next_sibling()) 143 break; 144 It = It.next_sibling(); 145 } 146 return It.parent(); 147 } 148 }; 149 150 /// Rewrite a schedule tree by reconstructing it bottom-up. 151 /// 152 /// By default, the original schedule tree is reconstructed. To build a 153 /// different tree, redefine visitor methods in a derived class (CRTP). 154 /// 155 /// Note that AST build options are not applied; Setting the isolate[] option 156 /// makes the schedule tree 'anchored' and cannot be modified afterwards. Hence, 157 /// AST build options must be set after the tree has been constructed. 158 template <typename Derived, typename... Args> 159 struct ScheduleTreeRewriter 160 : public RecursiveScheduleTreeVisitor<Derived, isl::schedule, Args...> { 161 Derived &getDerived() { return *static_cast<Derived *>(this); } 162 const Derived &getDerived() const { 163 return *static_cast<const Derived *>(this); 164 } 165 166 isl::schedule visitDomain(isl::schedule_node_domain Node, Args... args) { 167 // Every schedule_tree already has a domain node, no need to add one. 168 return getDerived().visit(Node.first_child(), std::forward<Args>(args)...); 169 } 170 171 isl::schedule visitBand(isl::schedule_node_band Band, Args... args) { 172 isl::schedule NewChild = 173 getDerived().visit(Band.child(0), std::forward<Args>(args)...); 174 return rebuildBand(Band, NewChild, [](int) { return true; }); 175 } 176 177 isl::schedule visitSequence(isl::schedule_node_sequence Sequence, 178 Args... args) { 179 int NumChildren = isl_schedule_node_n_children(Sequence.get()); 180 isl::schedule Result = 181 getDerived().visit(Sequence.child(0), std::forward<Args>(args)...); 182 for (int i = 1; i < NumChildren; i += 1) 183 Result = Result.sequence( 184 getDerived().visit(Sequence.child(i), std::forward<Args>(args)...)); 185 return Result; 186 } 187 188 isl::schedule visitSet(isl::schedule_node_set Set, Args... args) { 189 int NumChildren = isl_schedule_node_n_children(Set.get()); 190 isl::schedule Result = 191 getDerived().visit(Set.child(0), std::forward<Args>(args)...); 192 for (int i = 1; i < NumChildren; i += 1) 193 Result = isl::manage( 194 isl_schedule_set(Result.release(), 195 getDerived() 196 .visit(Set.child(i), std::forward<Args>(args)...) 197 .release())); 198 return Result; 199 } 200 201 isl::schedule visitLeaf(isl::schedule_node_leaf Leaf, Args... args) { 202 return isl::schedule::from_domain(Leaf.get_domain()); 203 } 204 205 isl::schedule visitMark(const isl::schedule_node &Mark, Args... args) { 206 207 isl::id TheMark = Mark.as<isl::schedule_node_mark>().get_id(); 208 isl::schedule_node NewChild = 209 getDerived() 210 .visit(Mark.first_child(), std::forward<Args>(args)...) 211 .get_root() 212 .first_child(); 213 return NewChild.insert_mark(TheMark).get_schedule(); 214 } 215 216 isl::schedule visitExtension(isl::schedule_node_extension Extension, 217 Args... args) { 218 isl::union_map TheExtension = 219 Extension.as<isl::schedule_node_extension>().get_extension(); 220 isl::schedule_node NewChild = getDerived() 221 .visit(Extension.child(0), args...) 222 .get_root() 223 .first_child(); 224 isl::schedule_node NewExtension = 225 isl::schedule_node::from_extension(TheExtension); 226 return NewChild.graft_before(NewExtension).get_schedule(); 227 } 228 229 isl::schedule visitFilter(isl::schedule_node_filter Filter, Args... args) { 230 isl::union_set FilterDomain = 231 Filter.as<isl::schedule_node_filter>().get_filter(); 232 isl::schedule NewSchedule = 233 getDerived().visit(Filter.child(0), std::forward<Args>(args)...); 234 return NewSchedule.intersect_domain(FilterDomain); 235 } 236 237 isl::schedule visitNode(isl::schedule_node Node, Args... args) { 238 llvm_unreachable("Not implemented"); 239 } 240 }; 241 242 /// Rewrite the schedule tree without any changes. Useful to copy a subtree into 243 /// a new schedule, discarding everything but. 244 struct IdentityRewriter : public ScheduleTreeRewriter<IdentityRewriter> {}; 245 246 /// Rewrite a schedule tree to an equivalent one without extension nodes. 247 /// 248 /// Each visit method takes two additional arguments: 249 /// 250 /// * The new domain the node, which is the inherited domain plus any domains 251 /// added by extension nodes. 252 /// 253 /// * A map of extension domains of all children is returned; it is required by 254 /// band nodes to schedule the additional domains at the same position as the 255 /// extension node would. 256 /// 257 struct ExtensionNodeRewriter 258 : public ScheduleTreeRewriter<ExtensionNodeRewriter, const isl::union_set &, 259 isl::union_map &> { 260 using BaseTy = ScheduleTreeRewriter<ExtensionNodeRewriter, 261 const isl::union_set &, isl::union_map &>; 262 BaseTy &getBase() { return *this; } 263 const BaseTy &getBase() const { return *this; } 264 265 isl::schedule visitSchedule(isl::schedule Schedule) { 266 isl::union_map Extensions; 267 isl::schedule Result = 268 visit(Schedule.get_root(), Schedule.get_domain(), Extensions); 269 assert(!Extensions.is_null() && Extensions.is_empty()); 270 return Result; 271 } 272 273 isl::schedule visitSequence(isl::schedule_node_sequence Sequence, 274 const isl::union_set &Domain, 275 isl::union_map &Extensions) { 276 int NumChildren = isl_schedule_node_n_children(Sequence.get()); 277 isl::schedule NewNode = visit(Sequence.first_child(), Domain, Extensions); 278 for (int i = 1; i < NumChildren; i += 1) { 279 isl::schedule_node OldChild = Sequence.child(i); 280 isl::union_map NewChildExtensions; 281 isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions); 282 NewNode = NewNode.sequence(NewChildNode); 283 Extensions = Extensions.unite(NewChildExtensions); 284 } 285 return NewNode; 286 } 287 288 isl::schedule visitSet(isl::schedule_node_set Set, 289 const isl::union_set &Domain, 290 isl::union_map &Extensions) { 291 int NumChildren = isl_schedule_node_n_children(Set.get()); 292 isl::schedule NewNode = visit(Set.first_child(), Domain, Extensions); 293 for (int i = 1; i < NumChildren; i += 1) { 294 isl::schedule_node OldChild = Set.child(i); 295 isl::union_map NewChildExtensions; 296 isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions); 297 NewNode = isl::manage( 298 isl_schedule_set(NewNode.release(), NewChildNode.release())); 299 Extensions = Extensions.unite(NewChildExtensions); 300 } 301 return NewNode; 302 } 303 304 isl::schedule visitLeaf(isl::schedule_node_leaf Leaf, 305 const isl::union_set &Domain, 306 isl::union_map &Extensions) { 307 Extensions = isl::union_map::empty(Leaf.ctx()); 308 return isl::schedule::from_domain(Domain); 309 } 310 311 isl::schedule visitBand(isl::schedule_node_band OldNode, 312 const isl::union_set &Domain, 313 isl::union_map &OuterExtensions) { 314 isl::schedule_node OldChild = OldNode.first_child(); 315 isl::multi_union_pw_aff PartialSched = 316 isl::manage(isl_schedule_node_band_get_partial_schedule(OldNode.get())); 317 318 isl::union_map NewChildExtensions; 319 isl::schedule NewChild = visit(OldChild, Domain, NewChildExtensions); 320 321 // Add the extensions to the partial schedule. 322 OuterExtensions = isl::union_map::empty(NewChildExtensions.ctx()); 323 isl::union_map NewPartialSchedMap = isl::union_map::from(PartialSched); 324 unsigned BandDims = isl_schedule_node_band_n_member(OldNode.get()); 325 for (isl::map Ext : NewChildExtensions.get_map_list()) { 326 unsigned ExtDims = unsignedFromIslSize(Ext.domain_tuple_dim()); 327 assert(ExtDims >= BandDims); 328 unsigned OuterDims = ExtDims - BandDims; 329 330 isl::map BandSched = 331 Ext.project_out(isl::dim::in, 0, OuterDims).reverse(); 332 NewPartialSchedMap = NewPartialSchedMap.unite(BandSched); 333 334 // There might be more outer bands that have to schedule the extensions. 335 if (OuterDims > 0) { 336 isl::map OuterSched = 337 Ext.project_out(isl::dim::in, OuterDims, BandDims); 338 OuterExtensions = OuterExtensions.unite(OuterSched); 339 } 340 } 341 isl::multi_union_pw_aff NewPartialSchedAsAsMultiUnionPwAff = 342 isl::multi_union_pw_aff::from_union_map(NewPartialSchedMap); 343 isl::schedule_node NewNode = 344 NewChild.insert_partial_schedule(NewPartialSchedAsAsMultiUnionPwAff) 345 .get_root() 346 .child(0); 347 348 // Reapply permutability and coincidence attributes. 349 NewNode = isl::manage(isl_schedule_node_band_set_permutable( 350 NewNode.release(), 351 isl_schedule_node_band_get_permutable(OldNode.get()))); 352 for (unsigned i = 0; i < BandDims; i += 1) 353 NewNode = applyBandMemberAttributes(NewNode.as<isl::schedule_node_band>(), 354 i, OldNode, i); 355 356 return NewNode.get_schedule(); 357 } 358 359 isl::schedule visitFilter(isl::schedule_node_filter Filter, 360 const isl::union_set &Domain, 361 isl::union_map &Extensions) { 362 isl::union_set FilterDomain = 363 Filter.as<isl::schedule_node_filter>().get_filter(); 364 isl::union_set NewDomain = Domain.intersect(FilterDomain); 365 366 // A filter is added implicitly if necessary when joining schedule trees. 367 return visit(Filter.first_child(), NewDomain, Extensions); 368 } 369 370 isl::schedule visitExtension(isl::schedule_node_extension Extension, 371 const isl::union_set &Domain, 372 isl::union_map &Extensions) { 373 isl::union_map ExtDomain = 374 Extension.as<isl::schedule_node_extension>().get_extension(); 375 isl::union_set NewDomain = Domain.unite(ExtDomain.range()); 376 isl::union_map ChildExtensions; 377 isl::schedule NewChild = 378 visit(Extension.first_child(), NewDomain, ChildExtensions); 379 Extensions = ChildExtensions.unite(ExtDomain); 380 return NewChild; 381 } 382 }; 383 384 /// Collect all AST build options in any schedule tree band. 385 /// 386 /// ScheduleTreeRewriter cannot apply the schedule tree options. This class 387 /// collects these options to apply them later. 388 struct CollectASTBuildOptions 389 : public RecursiveScheduleTreeVisitor<CollectASTBuildOptions> { 390 using BaseTy = RecursiveScheduleTreeVisitor<CollectASTBuildOptions>; 391 BaseTy &getBase() { return *this; } 392 const BaseTy &getBase() const { return *this; } 393 394 llvm::SmallVector<isl::union_set, 8> ASTBuildOptions; 395 396 void visitBand(isl::schedule_node_band Band) { 397 ASTBuildOptions.push_back( 398 isl::manage(isl_schedule_node_band_get_ast_build_options(Band.get()))); 399 return getBase().visitBand(Band); 400 } 401 }; 402 403 /// Apply AST build options to the bands in a schedule tree. 404 /// 405 /// This rewrites a schedule tree with the AST build options applied. We assume 406 /// that the band nodes are visited in the same order as they were when the 407 /// build options were collected, typically by CollectASTBuildOptions. 408 struct ApplyASTBuildOptions 409 : public ScheduleNodeRewriter<ApplyASTBuildOptions> { 410 using BaseTy = ScheduleNodeRewriter<ApplyASTBuildOptions>; 411 BaseTy &getBase() { return *this; } 412 const BaseTy &getBase() const { return *this; } 413 414 size_t Pos; 415 llvm::ArrayRef<isl::union_set> ASTBuildOptions; 416 417 ApplyASTBuildOptions(llvm::ArrayRef<isl::union_set> ASTBuildOptions) 418 : ASTBuildOptions(ASTBuildOptions) {} 419 420 isl::schedule visitSchedule(isl::schedule Schedule) { 421 Pos = 0; 422 isl::schedule Result = visit(Schedule).get_schedule(); 423 assert(Pos == ASTBuildOptions.size() && 424 "AST build options must match to band nodes"); 425 return Result; 426 } 427 428 isl::schedule_node visitBand(isl::schedule_node_band Band) { 429 isl::schedule_node_band Result = 430 Band.set_ast_build_options(ASTBuildOptions[Pos]); 431 Pos += 1; 432 return getBase().visitBand(Result); 433 } 434 }; 435 436 /// Return whether the schedule contains an extension node. 437 static bool containsExtensionNode(isl::schedule Schedule) { 438 assert(!Schedule.is_null()); 439 440 auto Callback = [](__isl_keep isl_schedule_node *Node, 441 void *User) -> isl_bool { 442 if (isl_schedule_node_get_type(Node) == isl_schedule_node_extension) { 443 // Stop walking the schedule tree. 444 return isl_bool_error; 445 } 446 447 // Continue searching the subtree. 448 return isl_bool_true; 449 }; 450 isl_stat RetVal = isl_schedule_foreach_schedule_node_top_down( 451 Schedule.get(), Callback, nullptr); 452 453 // We assume that the traversal itself does not fail, i.e. the only reason to 454 // return isl_stat_error is that an extension node was found. 455 return RetVal == isl_stat_error; 456 } 457 458 /// Find a named MDNode property in a LoopID. 459 static MDNode *findOptionalNodeOperand(MDNode *LoopMD, StringRef Name) { 460 return dyn_cast_or_null<MDNode>( 461 findMetadataOperand(LoopMD, Name).getValueOr(nullptr)); 462 } 463 464 /// Is this node of type mark? 465 static bool isMark(const isl::schedule_node &Node) { 466 return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_mark; 467 } 468 469 /// Is this node of type band? 470 static bool isBand(const isl::schedule_node &Node) { 471 return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band; 472 } 473 474 #ifndef NDEBUG 475 /// Is this node a band of a single dimension (i.e. could represent a loop)? 476 static bool isBandWithSingleLoop(const isl::schedule_node &Node) { 477 return isBand(Node) && isl_schedule_node_band_n_member(Node.get()) == 1; 478 } 479 #endif 480 481 static bool isLeaf(const isl::schedule_node &Node) { 482 return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_leaf; 483 } 484 485 /// Create an isl::id representing the output loop after a transformation. 486 static isl::id createGeneratedLoopAttr(isl::ctx Ctx, MDNode *FollowupLoopMD) { 487 // Don't need to id the followup. 488 // TODO: Append llvm.loop.disable_heustistics metadata unless overridden by 489 // user followup-MD 490 if (!FollowupLoopMD) 491 return {}; 492 493 BandAttr *Attr = new BandAttr(); 494 Attr->Metadata = FollowupLoopMD; 495 return getIslLoopAttr(Ctx, Attr); 496 } 497 498 /// A loop consists of a band and an optional marker that wraps it. Return the 499 /// outermost of the two. 500 501 /// That is, either the mark or, if there is not mark, the loop itself. Can 502 /// start with either the mark or the band. 503 static isl::schedule_node moveToBandMark(isl::schedule_node BandOrMark) { 504 if (isBandMark(BandOrMark)) { 505 assert(isBandWithSingleLoop(BandOrMark.child(0))); 506 return BandOrMark; 507 } 508 assert(isBandWithSingleLoop(BandOrMark)); 509 510 isl::schedule_node Mark = BandOrMark.parent(); 511 if (isBandMark(Mark)) 512 return Mark; 513 514 // Band has no loop marker. 515 return BandOrMark; 516 } 517 518 static isl::schedule_node removeMark(isl::schedule_node MarkOrBand, 519 BandAttr *&Attr) { 520 MarkOrBand = moveToBandMark(MarkOrBand); 521 522 isl::schedule_node Band; 523 if (isMark(MarkOrBand)) { 524 Attr = getLoopAttr(MarkOrBand.as<isl::schedule_node_mark>().get_id()); 525 Band = isl::manage(isl_schedule_node_delete(MarkOrBand.release())); 526 } else { 527 Attr = nullptr; 528 Band = MarkOrBand; 529 } 530 531 assert(isBandWithSingleLoop(Band)); 532 return Band; 533 } 534 535 /// Remove the mark that wraps a loop. Return the band representing the loop. 536 static isl::schedule_node removeMark(isl::schedule_node MarkOrBand) { 537 BandAttr *Attr; 538 return removeMark(MarkOrBand, Attr); 539 } 540 541 static isl::schedule_node insertMark(isl::schedule_node Band, isl::id Mark) { 542 assert(isBand(Band)); 543 assert(moveToBandMark(Band).is_equal(Band) && 544 "Don't add a two marks for a band"); 545 546 return Band.insert_mark(Mark).child(0); 547 } 548 549 /// Return the (one-dimensional) set of numbers that are divisible by @p Factor 550 /// with remainder @p Offset. 551 /// 552 /// isDivisibleBySet(Ctx, 4, 0) = { [i] : floord(i,4) = 0 } 553 /// isDivisibleBySet(Ctx, 4, 1) = { [i] : floord(i,4) = 1 } 554 /// 555 static isl::basic_set isDivisibleBySet(isl::ctx &Ctx, long Factor, 556 long Offset) { 557 isl::val ValFactor{Ctx, Factor}; 558 isl::val ValOffset{Ctx, Offset}; 559 560 isl::space Unispace{Ctx, 0, 1}; 561 isl::local_space LUnispace{Unispace}; 562 isl::aff AffFactor{LUnispace, ValFactor}; 563 isl::aff AffOffset{LUnispace, ValOffset}; 564 565 isl::aff Id = isl::aff::var_on_domain(LUnispace, isl::dim::out, 0); 566 isl::aff DivMul = Id.mod(ValFactor); 567 isl::basic_map Divisible = isl::basic_map::from_aff(DivMul); 568 isl::basic_map Modulo = Divisible.fix_val(isl::dim::out, 0, ValOffset); 569 return Modulo.domain(); 570 } 571 572 /// Make the last dimension of Set to take values from 0 to VectorWidth - 1. 573 /// 574 /// @param Set A set, which should be modified. 575 /// @param VectorWidth A parameter, which determines the constraint. 576 static isl::set addExtentConstraints(isl::set Set, int VectorWidth) { 577 unsigned Dims = unsignedFromIslSize(Set.tuple_dim()); 578 assert(Dims >= 1); 579 isl::space Space = Set.get_space(); 580 isl::local_space LocalSpace = isl::local_space(Space); 581 isl::constraint ExtConstr = isl::constraint::alloc_inequality(LocalSpace); 582 ExtConstr = ExtConstr.set_constant_si(0); 583 ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, 1); 584 Set = Set.add_constraint(ExtConstr); 585 ExtConstr = isl::constraint::alloc_inequality(LocalSpace); 586 ExtConstr = ExtConstr.set_constant_si(VectorWidth - 1); 587 ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, -1); 588 return Set.add_constraint(ExtConstr); 589 } 590 591 /// Collapse perfectly nested bands into a single band. 592 class BandCollapseRewriter : public ScheduleTreeRewriter<BandCollapseRewriter> { 593 private: 594 using BaseTy = ScheduleTreeRewriter<BandCollapseRewriter>; 595 BaseTy &getBase() { return *this; } 596 const BaseTy &getBase() const { return *this; } 597 598 public: 599 isl::schedule visitBand(isl::schedule_node_band RootBand) { 600 isl::schedule_node_band Band = RootBand; 601 isl::ctx Ctx = Band.ctx(); 602 603 // Do not merge permutable band to avoid loosing the permutability property. 604 // Cannot collapse even two permutable loops, they might be permutable 605 // individually, but not necassarily accross. 606 if (unsignedFromIslSize(Band.n_member()) > 1u && Band.permutable()) 607 return getBase().visitBand(Band); 608 609 // Find collapsable bands. 610 SmallVector<isl::schedule_node_band> Nest; 611 int NumTotalLoops = 0; 612 isl::schedule_node Body; 613 while (true) { 614 Nest.push_back(Band); 615 NumTotalLoops += unsignedFromIslSize(Band.n_member()); 616 Body = Band.first_child(); 617 if (!Body.isa<isl::schedule_node_band>()) 618 break; 619 Band = Body.as<isl::schedule_node_band>(); 620 621 // Do not include next band if it is permutable to not lose its 622 // permutability property. 623 if (unsignedFromIslSize(Band.n_member()) > 1u && Band.permutable()) 624 break; 625 } 626 627 // Nothing to collapse, preserve permutability. 628 if (Nest.size() <= 1) 629 return getBase().visitBand(Band); 630 631 LLVM_DEBUG({ 632 dbgs() << "Found loops to collapse between\n"; 633 dumpIslObj(RootBand, dbgs()); 634 dbgs() << "and\n"; 635 dumpIslObj(Body, dbgs()); 636 dbgs() << "\n"; 637 }); 638 639 isl::schedule NewBody = visit(Body); 640 641 // Collect partial schedules from all members. 642 isl::union_pw_aff_list PartScheds{Ctx, NumTotalLoops}; 643 for (isl::schedule_node_band Band : Nest) { 644 int NumLoops = unsignedFromIslSize(Band.n_member()); 645 isl::multi_union_pw_aff BandScheds = Band.get_partial_schedule(); 646 for (auto j : seq<int>(0, NumLoops)) 647 PartScheds = PartScheds.add(BandScheds.at(j)); 648 } 649 isl::space ScatterSpace = isl::space(Ctx, 0, NumTotalLoops); 650 isl::multi_union_pw_aff PartSchedsMulti{ScatterSpace, PartScheds}; 651 652 isl::schedule_node_band CollapsedBand = 653 NewBody.insert_partial_schedule(PartSchedsMulti) 654 .get_root() 655 .first_child() 656 .as<isl::schedule_node_band>(); 657 658 // Copy over loop attributes form original bands. 659 int LoopIdx = 0; 660 for (isl::schedule_node_band Band : Nest) { 661 int NumLoops = unsignedFromIslSize(Band.n_member()); 662 for (int i : seq<int>(0, NumLoops)) { 663 CollapsedBand = applyBandMemberAttributes(std::move(CollapsedBand), 664 LoopIdx, Band, i); 665 LoopIdx += 1; 666 } 667 } 668 assert(LoopIdx == NumTotalLoops && 669 "Expect the same number of loops to add up again"); 670 671 return CollapsedBand.get_schedule(); 672 } 673 }; 674 675 static isl::schedule collapseBands(isl::schedule Sched) { 676 LLVM_DEBUG(dbgs() << "Collapse bands in schedule\n"); 677 BandCollapseRewriter Rewriter; 678 return Rewriter.visit(Sched); 679 } 680 681 /// Collect sequentially executed bands (or anything else), even if nested in a 682 /// mark or other nodes whose child is executed just once. If we can 683 /// successfully fuse the bands, we allow them to be removed. 684 static void collectPotentiallyFusableBands( 685 isl::schedule_node Node, 686 SmallVectorImpl<std::pair<isl::schedule_node, isl::schedule_node>> 687 &ScheduleBands, 688 const isl::schedule_node &DirectChild) { 689 switch (isl_schedule_node_get_type(Node.get())) { 690 case isl_schedule_node_sequence: 691 case isl_schedule_node_set: 692 case isl_schedule_node_mark: 693 case isl_schedule_node_domain: 694 case isl_schedule_node_filter: 695 if (Node.has_children()) { 696 isl::schedule_node C = Node.first_child(); 697 while (true) { 698 collectPotentiallyFusableBands(C, ScheduleBands, DirectChild); 699 if (!C.has_next_sibling()) 700 break; 701 C = C.next_sibling(); 702 } 703 } 704 break; 705 706 default: 707 // Something that does not execute suquentially (e.g. a band) 708 ScheduleBands.push_back({Node, DirectChild}); 709 break; 710 } 711 } 712 713 /// Remove dependencies that are resolved by @p PartSched. That is, remove 714 /// everything that we already know is executed in-order. 715 static isl::union_map remainingDepsFromPartialSchedule(isl::union_map PartSched, 716 isl::union_map Deps) { 717 unsigned NumDims = getNumScatterDims(PartSched); 718 auto ParamSpace = PartSched.get_space().params(); 719 720 // { Scatter[] } 721 isl::space ScatterSpace = 722 ParamSpace.set_from_params().add_dims(isl::dim::set, NumDims); 723 724 // { Scatter[] -> Domain[] } 725 isl::union_map PartSchedRev = PartSched.reverse(); 726 727 // { Scatter[] -> Scatter[] } 728 isl::map MaybeBefore = isl::map::lex_le(ScatterSpace); 729 730 // { Domain[] -> Domain[] } 731 isl::union_map DomMaybeBefore = 732 MaybeBefore.apply_domain(PartSchedRev).apply_range(PartSchedRev); 733 734 // { Domain[] -> Domain[] } 735 isl::union_map ChildRemainingDeps = Deps.intersect(DomMaybeBefore); 736 737 return ChildRemainingDeps; 738 } 739 740 /// Remove dependencies that are resolved by executing them in the order 741 /// specified by @p Domains; 742 static isl::union_map remainigDepsFromSequence(ArrayRef<isl::union_set> Domains, 743 isl::union_map Deps) { 744 isl::ctx Ctx = Deps.ctx(); 745 isl::space ParamSpace = Deps.get_space().params(); 746 747 // Create a partial schedule mapping to constants that reflect the execution 748 // order. 749 isl::union_map PartialSchedules = isl::union_map::empty(Ctx); 750 for (auto P : enumerate(Domains)) { 751 isl::val ExecTime = isl::val(Ctx, P.index()); 752 isl::union_pw_aff DomSched{P.value(), ExecTime}; 753 PartialSchedules = PartialSchedules.unite(DomSched.as_union_map()); 754 } 755 756 return remainingDepsFromPartialSchedule(PartialSchedules, Deps); 757 } 758 759 /// Determine whether the outermost loop of to bands can be fused while 760 /// respecting validity dependencies. 761 static bool canFuseOutermost(const isl::schedule_node_band &LHS, 762 const isl::schedule_node_band &RHS, 763 const isl::union_map &Deps) { 764 // { LHSDomain[] -> Scatter[] } 765 isl::union_map LHSPartSched = 766 LHS.get_partial_schedule().get_at(0).as_union_map(); 767 768 // { Domain[] -> Scatter[] } 769 isl::union_map RHSPartSched = 770 RHS.get_partial_schedule().get_at(0).as_union_map(); 771 772 // Dependencies that are already resolved because LHS executes before RHS, but 773 // will not be anymore after fusion. { DefDomain[] -> UseDomain[] } 774 isl::union_map OrderedBySequence = 775 Deps.intersect_domain(LHSPartSched.domain()) 776 .intersect_range(RHSPartSched.domain()); 777 778 isl::space ParamSpace = OrderedBySequence.get_space().params(); 779 isl::space NewScatterSpace = ParamSpace.add_unnamed_tuple(1); 780 781 // { Scatter[] -> Scatter[] } 782 isl::map After = isl::map::lex_gt(NewScatterSpace); 783 784 // After fusion, instances with smaller (or equal, which means they will be 785 // executed in the same iteration, but the LHS instance is still sequenced 786 // before RHS) scatter value will still be executed before. This are the 787 // orderings where this is not necessarily the case. 788 // { LHSDomain[] -> RHSDomain[] } 789 isl::union_map MightBeAfterDoms = After.apply_domain(LHSPartSched.reverse()) 790 .apply_range(RHSPartSched.reverse()); 791 792 // Dependencies that are not resolved by the new execution order. 793 isl::union_map WithBefore = OrderedBySequence.intersect(MightBeAfterDoms); 794 795 return WithBefore.is_empty(); 796 } 797 798 /// Fuse @p LHS and @p RHS if possible while preserving validity dependenvies. 799 static isl::schedule tryGreedyFuse(isl::schedule_node_band LHS, 800 isl::schedule_node_band RHS, 801 const isl::union_map &Deps) { 802 if (!canFuseOutermost(LHS, RHS, Deps)) 803 return {}; 804 805 LLVM_DEBUG({ 806 dbgs() << "Found loops for greedy fusion:\n"; 807 dumpIslObj(LHS, dbgs()); 808 dbgs() << "and\n"; 809 dumpIslObj(RHS, dbgs()); 810 dbgs() << "\n"; 811 }); 812 813 // The partial schedule of the bands outermost loop that we need to combine 814 // for the fusion. 815 isl::union_pw_aff LHSPartOuterSched = LHS.get_partial_schedule().get_at(0); 816 isl::union_pw_aff RHSPartOuterSched = RHS.get_partial_schedule().get_at(0); 817 818 // Isolate band bodies as roots of their own schedule trees. 819 IdentityRewriter Rewriter; 820 isl::schedule LHSBody = Rewriter.visit(LHS.first_child()); 821 isl::schedule RHSBody = Rewriter.visit(RHS.first_child()); 822 823 // Reconstruct the non-outermost (not going to be fused) loops from both 824 // bands. 825 // TODO: Maybe it is possibly to transfer the 'permutability' property from 826 // LHS+RHS. At minimum we need merge multiple band members at once, otherwise 827 // permutability has no meaning. 828 isl::schedule LHSNewBody = 829 rebuildBand(LHS, LHSBody, [](int i) { return i > 0; }); 830 isl::schedule RHSNewBody = 831 rebuildBand(RHS, RHSBody, [](int i) { return i > 0; }); 832 833 // The loop body of the fused loop. 834 isl::schedule NewCommonBody = LHSNewBody.sequence(RHSNewBody); 835 836 // Combine the partial schedules of both loops to a new one. Instances with 837 // the same scatter value are put together. 838 isl::union_map NewCommonPartialSched = 839 LHSPartOuterSched.as_union_map().unite(RHSPartOuterSched.as_union_map()); 840 isl::schedule NewCommonSchedule = NewCommonBody.insert_partial_schedule( 841 NewCommonPartialSched.as_multi_union_pw_aff()); 842 843 return NewCommonSchedule; 844 } 845 846 static isl::schedule tryGreedyFuse(isl::schedule_node LHS, 847 isl::schedule_node RHS, 848 const isl::union_map &Deps) { 849 // TODO: Non-bands could be interpreted as a band with just as single 850 // iteration. However, this is only useful if both ends of a fused loop were 851 // originally loops themselves. 852 if (!LHS.isa<isl::schedule_node_band>()) 853 return {}; 854 if (!RHS.isa<isl::schedule_node_band>()) 855 return {}; 856 857 return tryGreedyFuse(LHS.as<isl::schedule_node_band>(), 858 RHS.as<isl::schedule_node_band>(), Deps); 859 } 860 861 /// Fuse all fusable loop top-down in a schedule tree. 862 /// 863 /// The isl::union_map parameters is the set of validity dependencies that have 864 /// not been resolved/carried by a parent schedule node. 865 class GreedyFusionRewriter 866 : public ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map> { 867 private: 868 using BaseTy = ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map>; 869 BaseTy &getBase() { return *this; } 870 const BaseTy &getBase() const { return *this; } 871 872 public: 873 /// Is set to true if anything has been fused. 874 bool AnyChange = false; 875 876 isl::schedule visitBand(isl::schedule_node_band Band, isl::union_map Deps) { 877 // { Domain[] -> Scatter[] } 878 isl::union_map PartSched = 879 isl::union_map::from(Band.get_partial_schedule()); 880 assert(getNumScatterDims(PartSched) == 881 unsignedFromIslSize(Band.n_member())); 882 isl::space ParamSpace = PartSched.get_space().params(); 883 884 // { Scatter[] -> Domain[] } 885 isl::union_map PartSchedRev = PartSched.reverse(); 886 887 // Possible within the same iteration. Dependencies with smaller scatter 888 // value are carried by this loop and therefore have been resolved by the 889 // in-order execution if the loop iteration. A dependency with small scatter 890 // value would be a dependency violation that we assume did not happen. { 891 // Domain[] -> Domain[] } 892 isl::union_map Unsequenced = PartSchedRev.apply_domain(PartSchedRev); 893 894 // Actual dependencies within the same iteration. 895 // { DefDomain[] -> UseDomain[] } 896 isl::union_map RemDeps = Deps.intersect(Unsequenced); 897 898 return getBase().visitBand(Band, RemDeps); 899 } 900 901 isl::schedule visitSequence(isl::schedule_node_sequence Sequence, 902 isl::union_map Deps) { 903 int NumChildren = isl_schedule_node_n_children(Sequence.get()); 904 905 // List of fusion candidates. The first element is the fusion candidate, the 906 // second is candidate's ancestor that is the sequence's direct child. It is 907 // preferable to use the direct child if not if its non-direct children is 908 // fused to preserve its structure such as mark nodes. 909 SmallVector<std::pair<isl::schedule_node, isl::schedule_node>> Bands; 910 for (auto i : seq<int>(0, NumChildren)) { 911 isl::schedule_node Child = Sequence.child(i); 912 collectPotentiallyFusableBands(Child, Bands, Child); 913 } 914 915 // Direct children that had at least one of its decendants fused. 916 SmallDenseSet<isl_schedule_node *, 4> ChangedDirectChildren; 917 918 // Fuse neigboring bands until reaching the end of candidates. 919 int i = 0; 920 while (i + 1 < (int)Bands.size()) { 921 isl::schedule Fused = 922 tryGreedyFuse(Bands[i].first, Bands[i + 1].first, Deps); 923 if (Fused.is_null()) { 924 // Cannot merge this node with the next; look at next pair. 925 i += 1; 926 continue; 927 } 928 929 // Mark the direct children as (partially) fused. 930 if (!Bands[i].second.is_null()) 931 ChangedDirectChildren.insert(Bands[i].second.get()); 932 if (!Bands[i + 1].second.is_null()) 933 ChangedDirectChildren.insert(Bands[i + 1].second.get()); 934 935 // Collapse the neigbros to a single new candidate that could be fused 936 // with the next candidate. 937 Bands[i] = {Fused.get_root(), {}}; 938 Bands.erase(Bands.begin() + i + 1); 939 940 AnyChange = true; 941 } 942 943 // By construction equal if done with collectPotentiallyFusableBands's 944 // output. 945 SmallVector<isl::union_set> SubDomains; 946 SubDomains.reserve(NumChildren); 947 for (int i = 0; i < NumChildren; i += 1) 948 SubDomains.push_back(Sequence.child(i).domain()); 949 auto SubRemainingDeps = remainigDepsFromSequence(SubDomains, Deps); 950 951 // We may iterate over direct children multiple times, be sure to add each 952 // at most once. 953 SmallDenseSet<isl_schedule_node *, 4> AlreadyAdded; 954 955 isl::schedule Result; 956 for (auto &P : Bands) { 957 isl::schedule_node MaybeFused = P.first; 958 isl::schedule_node DirectChild = P.second; 959 960 // If not modified, use the direct child. 961 if (!DirectChild.is_null() && 962 !ChangedDirectChildren.count(DirectChild.get())) { 963 if (AlreadyAdded.count(DirectChild.get())) 964 continue; 965 AlreadyAdded.insert(DirectChild.get()); 966 MaybeFused = DirectChild; 967 } else { 968 assert(AnyChange && 969 "Need changed flag for be consistent with actual change"); 970 } 971 972 // Top-down recursion: If the outermost loop has been fused, their nested 973 // bands might be fusable now as well. 974 isl::schedule InnerFused = visit(MaybeFused, SubRemainingDeps); 975 976 // Reconstruct the sequence, with some of the children fused. 977 if (Result.is_null()) 978 Result = InnerFused; 979 else 980 Result = Result.sequence(InnerFused); 981 } 982 983 return Result; 984 } 985 }; 986 987 } // namespace 988 989 bool polly::isBandMark(const isl::schedule_node &Node) { 990 return isMark(Node) && 991 isLoopAttr(Node.as<isl::schedule_node_mark>().get_id()); 992 } 993 994 BandAttr *polly::getBandAttr(isl::schedule_node MarkOrBand) { 995 MarkOrBand = moveToBandMark(MarkOrBand); 996 if (!isMark(MarkOrBand)) 997 return nullptr; 998 999 return getLoopAttr(MarkOrBand.as<isl::schedule_node_mark>().get_id()); 1000 } 1001 1002 isl::schedule polly::hoistExtensionNodes(isl::schedule Sched) { 1003 // If there is no extension node in the first place, return the original 1004 // schedule tree. 1005 if (!containsExtensionNode(Sched)) 1006 return Sched; 1007 1008 // Build options can anchor schedule nodes, such that the schedule tree cannot 1009 // be modified anymore. Therefore, apply build options after the tree has been 1010 // created. 1011 CollectASTBuildOptions Collector; 1012 Collector.visit(Sched); 1013 1014 // Rewrite the schedule tree without extension nodes. 1015 ExtensionNodeRewriter Rewriter; 1016 isl::schedule NewSched = Rewriter.visitSchedule(Sched); 1017 1018 // Reapply the AST build options. The rewriter must not change the iteration 1019 // order of bands. Any other node type is ignored. 1020 ApplyASTBuildOptions Applicator(Collector.ASTBuildOptions); 1021 NewSched = Applicator.visitSchedule(NewSched); 1022 1023 return NewSched; 1024 } 1025 1026 isl::schedule polly::applyFullUnroll(isl::schedule_node BandToUnroll) { 1027 isl::ctx Ctx = BandToUnroll.ctx(); 1028 1029 // Remove the loop's mark, the loop will disappear anyway. 1030 BandToUnroll = removeMark(BandToUnroll); 1031 assert(isBandWithSingleLoop(BandToUnroll)); 1032 1033 isl::multi_union_pw_aff PartialSched = isl::manage( 1034 isl_schedule_node_band_get_partial_schedule(BandToUnroll.get())); 1035 assert(unsignedFromIslSize(PartialSched.dim(isl::dim::out)) == 1u && 1036 "Can only unroll a single dimension"); 1037 isl::union_pw_aff PartialSchedUAff = PartialSched.at(0); 1038 1039 isl::union_set Domain = BandToUnroll.get_domain(); 1040 PartialSchedUAff = PartialSchedUAff.intersect_domain(Domain); 1041 isl::union_map PartialSchedUMap = 1042 isl::union_map::from(isl::union_pw_multi_aff(PartialSchedUAff)); 1043 1044 // Enumerator only the scatter elements. 1045 isl::union_set ScatterList = PartialSchedUMap.range(); 1046 1047 // Enumerate all loop iterations. 1048 // TODO: Diagnose if not enumerable or depends on a parameter. 1049 SmallVector<isl::point, 16> Elts; 1050 ScatterList.foreach_point([&Elts](isl::point P) -> isl::stat { 1051 Elts.push_back(P); 1052 return isl::stat::ok(); 1053 }); 1054 1055 // Don't assume that foreach_point returns in execution order. 1056 llvm::sort(Elts, [](isl::point P1, isl::point P2) -> bool { 1057 isl::val C1 = P1.get_coordinate_val(isl::dim::set, 0); 1058 isl::val C2 = P2.get_coordinate_val(isl::dim::set, 0); 1059 return C1.lt(C2); 1060 }); 1061 1062 // Convert the points to a sequence of filters. 1063 isl::union_set_list List = isl::union_set_list(Ctx, Elts.size()); 1064 for (isl::point P : Elts) { 1065 // Determine the domains that map this scatter element. 1066 isl::union_set DomainFilter = PartialSchedUMap.intersect_range(P).domain(); 1067 1068 List = List.add(DomainFilter); 1069 } 1070 1071 // Replace original band with unrolled sequence. 1072 isl::schedule_node Body = 1073 isl::manage(isl_schedule_node_delete(BandToUnroll.release())); 1074 Body = Body.insert_sequence(List); 1075 return Body.get_schedule(); 1076 } 1077 1078 isl::schedule polly::applyPartialUnroll(isl::schedule_node BandToUnroll, 1079 int Factor) { 1080 assert(Factor > 0 && "Positive unroll factor required"); 1081 isl::ctx Ctx = BandToUnroll.ctx(); 1082 1083 // Remove the mark, save the attribute for later use. 1084 BandAttr *Attr; 1085 BandToUnroll = removeMark(BandToUnroll, Attr); 1086 assert(isBandWithSingleLoop(BandToUnroll)); 1087 1088 isl::multi_union_pw_aff PartialSched = isl::manage( 1089 isl_schedule_node_band_get_partial_schedule(BandToUnroll.get())); 1090 1091 // { Stmt[] -> [x] } 1092 isl::union_pw_aff PartialSchedUAff = PartialSched.at(0); 1093 1094 // Here we assume the schedule stride is one and starts with 0, which is not 1095 // necessarily the case. 1096 isl::union_pw_aff StridedPartialSchedUAff = 1097 isl::union_pw_aff::empty(PartialSchedUAff.get_space()); 1098 isl::val ValFactor{Ctx, Factor}; 1099 PartialSchedUAff.foreach_pw_aff([&StridedPartialSchedUAff, 1100 &ValFactor](isl::pw_aff PwAff) -> isl::stat { 1101 isl::space Space = PwAff.get_space(); 1102 isl::set Universe = isl::set::universe(Space.domain()); 1103 isl::pw_aff AffFactor{Universe, ValFactor}; 1104 isl::pw_aff DivSchedAff = PwAff.div(AffFactor).floor().mul(AffFactor); 1105 StridedPartialSchedUAff = StridedPartialSchedUAff.union_add(DivSchedAff); 1106 return isl::stat::ok(); 1107 }); 1108 1109 isl::union_set_list List = isl::union_set_list(Ctx, Factor); 1110 for (auto i : seq<int>(0, Factor)) { 1111 // { Stmt[] -> [x] } 1112 isl::union_map UMap = 1113 isl::union_map::from(isl::union_pw_multi_aff(PartialSchedUAff)); 1114 1115 // { [x] } 1116 isl::basic_set Divisible = isDivisibleBySet(Ctx, Factor, i); 1117 1118 // { Stmt[] } 1119 isl::union_set UnrolledDomain = UMap.intersect_range(Divisible).domain(); 1120 1121 List = List.add(UnrolledDomain); 1122 } 1123 1124 isl::schedule_node Body = 1125 isl::manage(isl_schedule_node_delete(BandToUnroll.copy())); 1126 Body = Body.insert_sequence(List); 1127 isl::schedule_node NewLoop = 1128 Body.insert_partial_schedule(StridedPartialSchedUAff); 1129 1130 MDNode *FollowupMD = nullptr; 1131 if (Attr && Attr->Metadata) 1132 FollowupMD = 1133 findOptionalNodeOperand(Attr->Metadata, LLVMLoopUnrollFollowupUnrolled); 1134 1135 isl::id NewBandId = createGeneratedLoopAttr(Ctx, FollowupMD); 1136 if (!NewBandId.is_null()) 1137 NewLoop = insertMark(NewLoop, NewBandId); 1138 1139 return NewLoop.get_schedule(); 1140 } 1141 1142 isl::set polly::getPartialTilePrefixes(isl::set ScheduleRange, 1143 int VectorWidth) { 1144 unsigned Dims = unsignedFromIslSize(ScheduleRange.tuple_dim()); 1145 assert(Dims >= 1); 1146 isl::set LoopPrefixes = 1147 ScheduleRange.drop_constraints_involving_dims(isl::dim::set, Dims - 1, 1); 1148 auto ExtentPrefixes = addExtentConstraints(LoopPrefixes, VectorWidth); 1149 isl::set BadPrefixes = ExtentPrefixes.subtract(ScheduleRange); 1150 BadPrefixes = BadPrefixes.project_out(isl::dim::set, Dims - 1, 1); 1151 LoopPrefixes = LoopPrefixes.project_out(isl::dim::set, Dims - 1, 1); 1152 return LoopPrefixes.subtract(BadPrefixes); 1153 } 1154 1155 isl::union_set polly::getIsolateOptions(isl::set IsolateDomain, 1156 unsigned OutDimsNum) { 1157 unsigned Dims = unsignedFromIslSize(IsolateDomain.tuple_dim()); 1158 assert(OutDimsNum <= Dims && 1159 "The isl::set IsolateDomain is used to describe the range of schedule " 1160 "dimensions values, which should be isolated. Consequently, the " 1161 "number of its dimensions should be greater than or equal to the " 1162 "number of the schedule dimensions."); 1163 isl::map IsolateRelation = isl::map::from_domain(IsolateDomain); 1164 IsolateRelation = IsolateRelation.move_dims(isl::dim::out, 0, isl::dim::in, 1165 Dims - OutDimsNum, OutDimsNum); 1166 isl::set IsolateOption = IsolateRelation.wrap(); 1167 isl::id Id = isl::id::alloc(IsolateOption.ctx(), "isolate", nullptr); 1168 IsolateOption = IsolateOption.set_tuple_id(Id); 1169 return isl::union_set(IsolateOption); 1170 } 1171 1172 isl::union_set polly::getDimOptions(isl::ctx Ctx, const char *Option) { 1173 isl::space Space(Ctx, 0, 1); 1174 auto DimOption = isl::set::universe(Space); 1175 auto Id = isl::id::alloc(Ctx, Option, nullptr); 1176 DimOption = DimOption.set_tuple_id(Id); 1177 return isl::union_set(DimOption); 1178 } 1179 1180 isl::schedule_node polly::tileNode(isl::schedule_node Node, 1181 const char *Identifier, 1182 ArrayRef<int> TileSizes, 1183 int DefaultTileSize) { 1184 auto Space = isl::manage(isl_schedule_node_band_get_space(Node.get())); 1185 auto Dims = Space.dim(isl::dim::set); 1186 auto Sizes = isl::multi_val::zero(Space); 1187 std::string IdentifierString(Identifier); 1188 for (unsigned i : rangeIslSize(0, Dims)) { 1189 unsigned tileSize = i < TileSizes.size() ? TileSizes[i] : DefaultTileSize; 1190 Sizes = Sizes.set_val(i, isl::val(Node.ctx(), tileSize)); 1191 } 1192 auto TileLoopMarkerStr = IdentifierString + " - Tiles"; 1193 auto TileLoopMarker = isl::id::alloc(Node.ctx(), TileLoopMarkerStr, nullptr); 1194 Node = Node.insert_mark(TileLoopMarker); 1195 Node = Node.child(0); 1196 Node = 1197 isl::manage(isl_schedule_node_band_tile(Node.release(), Sizes.release())); 1198 Node = Node.child(0); 1199 auto PointLoopMarkerStr = IdentifierString + " - Points"; 1200 auto PointLoopMarker = 1201 isl::id::alloc(Node.ctx(), PointLoopMarkerStr, nullptr); 1202 Node = Node.insert_mark(PointLoopMarker); 1203 return Node.child(0); 1204 } 1205 1206 isl::schedule_node polly::applyRegisterTiling(isl::schedule_node Node, 1207 ArrayRef<int> TileSizes, 1208 int DefaultTileSize) { 1209 Node = tileNode(Node, "Register tiling", TileSizes, DefaultTileSize); 1210 auto Ctx = Node.ctx(); 1211 return Node.as<isl::schedule_node_band>().set_ast_build_options( 1212 isl::union_set(Ctx, "{unroll[x]}")); 1213 } 1214 1215 /// Find statements and sub-loops in (possibly nested) sequences. 1216 static void 1217 collectFussionableStmts(isl::schedule_node Node, 1218 SmallVectorImpl<isl::schedule_node> &ScheduleStmts) { 1219 if (isBand(Node) || isLeaf(Node)) { 1220 ScheduleStmts.push_back(Node); 1221 return; 1222 } 1223 1224 if (Node.has_children()) { 1225 isl::schedule_node C = Node.first_child(); 1226 while (true) { 1227 collectFussionableStmts(C, ScheduleStmts); 1228 if (!C.has_next_sibling()) 1229 break; 1230 C = C.next_sibling(); 1231 } 1232 } 1233 } 1234 1235 isl::schedule polly::applyMaxFission(isl::schedule_node BandToFission) { 1236 isl::ctx Ctx = BandToFission.ctx(); 1237 BandToFission = removeMark(BandToFission); 1238 isl::schedule_node BandBody = BandToFission.child(0); 1239 1240 SmallVector<isl::schedule_node> FissionableStmts; 1241 collectFussionableStmts(BandBody, FissionableStmts); 1242 size_t N = FissionableStmts.size(); 1243 1244 // Collect the domain for each of the statements that will get their own loop. 1245 isl::union_set_list DomList = isl::union_set_list(Ctx, N); 1246 for (size_t i = 0; i < N; ++i) { 1247 isl::schedule_node BodyPart = FissionableStmts[i]; 1248 DomList = DomList.add(BodyPart.get_domain()); 1249 } 1250 1251 // Apply the fission by copying the entire loop, but inserting a filter for 1252 // the statement domains for each fissioned loop. 1253 isl::schedule_node Fissioned = BandToFission.insert_sequence(DomList); 1254 1255 return Fissioned.get_schedule(); 1256 } 1257 1258 isl::schedule polly::applyGreedyFusion(isl::schedule Sched, 1259 const isl::union_map &Deps) { 1260 LLVM_DEBUG(dbgs() << "Greedy loop fusion\n"); 1261 1262 GreedyFusionRewriter Rewriter; 1263 isl::schedule Result = Rewriter.visit(Sched, Deps); 1264 if (!Rewriter.AnyChange) { 1265 LLVM_DEBUG(dbgs() << "Found nothing to fuse\n"); 1266 return Sched; 1267 } 1268 1269 // GreedyFusionRewriter due to working loop-by-loop, bands with multiple loops 1270 // may have been split into multiple bands. 1271 return collapseBands(Result); 1272 } 1273