1 //===- polly/ScheduleTreeTransform.cpp --------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Make changes to isl's schedule tree data structure. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "polly/ScheduleTreeTransform.h" 14 #include "polly/Support/GICHelper.h" 15 #include "polly/Support/ISLTools.h" 16 #include "polly/Support/ScopHelper.h" 17 #include "llvm/ADT/ArrayRef.h" 18 #include "llvm/ADT/Sequence.h" 19 #include "llvm/ADT/SmallVector.h" 20 #include "llvm/IR/Constants.h" 21 #include "llvm/IR/Metadata.h" 22 #include "llvm/Transforms/Utils/UnrollLoop.h" 23 24 #define DEBUG_TYPE "polly-opt-isl" 25 26 using namespace polly; 27 using namespace llvm; 28 29 namespace { 30 31 /// Copy the band member attributes (coincidence, loop type, isolate ast loop 32 /// type) from one band to another. 33 static isl::schedule_node_band 34 applyBandMemberAttributes(isl::schedule_node_band Target, int TargetIdx, 35 const isl::schedule_node_band &Source, 36 int SourceIdx) { 37 bool Coincident = Source.member_get_coincident(SourceIdx).release(); 38 Target = Target.member_set_coincident(TargetIdx, Coincident); 39 40 isl_ast_loop_type LoopType = 41 isl_schedule_node_band_member_get_ast_loop_type(Source.get(), SourceIdx); 42 Target = isl::manage(isl_schedule_node_band_member_set_ast_loop_type( 43 Target.release(), TargetIdx, LoopType)) 44 .as<isl::schedule_node_band>(); 45 46 isl_ast_loop_type IsolateType = 47 isl_schedule_node_band_member_get_isolate_ast_loop_type(Source.get(), 48 SourceIdx); 49 Target = isl::manage(isl_schedule_node_band_member_set_isolate_ast_loop_type( 50 Target.release(), TargetIdx, IsolateType)) 51 .as<isl::schedule_node_band>(); 52 53 return Target; 54 } 55 56 /// Create a new band by copying members from another @p Band. @p IncludeCb 57 /// decides which band indices are copied to the result. 58 template <typename CbTy> 59 static isl::schedule rebuildBand(isl::schedule_node_band OldBand, 60 isl::schedule Body, CbTy IncludeCb) { 61 int NumBandDims = unsignedFromIslSize(OldBand.n_member()); 62 63 bool ExcludeAny = false; 64 bool IncludeAny = false; 65 for (auto OldIdx : seq<int>(0, NumBandDims)) { 66 if (IncludeCb(OldIdx)) 67 IncludeAny = true; 68 else 69 ExcludeAny = true; 70 } 71 72 // Instead of creating a zero-member band, don't create a band at all. 73 if (!IncludeAny) 74 return Body; 75 76 isl::multi_union_pw_aff PartialSched = OldBand.get_partial_schedule(); 77 isl::multi_union_pw_aff NewPartialSched; 78 if (ExcludeAny) { 79 // Select the included partial scatter functions. 80 isl::union_pw_aff_list List = PartialSched.list(); 81 int NewIdx = 0; 82 for (auto OldIdx : seq<int>(0, NumBandDims)) { 83 if (IncludeCb(OldIdx)) 84 NewIdx += 1; 85 else 86 List = List.drop(NewIdx, 1); 87 } 88 isl::space ParamSpace = PartialSched.get_space().params(); 89 isl::space NewScatterSpace = ParamSpace.add_unnamed_tuple(NewIdx); 90 NewPartialSched = isl::multi_union_pw_aff(NewScatterSpace, List); 91 } else { 92 // Just reuse original scatter function of copying all of them. 93 NewPartialSched = PartialSched; 94 } 95 96 // Create the new band node. 97 isl::schedule_node_band NewBand = 98 Body.insert_partial_schedule(NewPartialSched) 99 .get_root() 100 .child(0) 101 .as<isl::schedule_node_band>(); 102 103 // If OldBand was permutable, so is the new one, even if some dimensions are 104 // missing. 105 bool IsPermutable = OldBand.permutable().release(); 106 NewBand = NewBand.set_permutable(IsPermutable); 107 108 // Reapply member attributes. 109 int NewIdx = 0; 110 for (auto OldIdx : seq<int>(0, NumBandDims)) { 111 if (!IncludeCb(OldIdx)) 112 continue; 113 NewBand = 114 applyBandMemberAttributes(std::move(NewBand), NewIdx, OldBand, OldIdx); 115 NewIdx += 1; 116 } 117 118 return NewBand.get_schedule(); 119 } 120 121 /// Rewrite a schedule tree by reconstructing it bottom-up. 122 /// 123 /// By default, the original schedule tree is reconstructed. To build a 124 /// different tree, redefine visitor methods in a derived class (CRTP). 125 /// 126 /// Note that AST build options are not applied; Setting the isolate[] option 127 /// makes the schedule tree 'anchored' and cannot be modified afterwards. Hence, 128 /// AST build options must be set after the tree has been constructed. 129 template <typename Derived, typename... Args> 130 struct ScheduleTreeRewriter 131 : public RecursiveScheduleTreeVisitor<Derived, isl::schedule, Args...> { 132 Derived &getDerived() { return *static_cast<Derived *>(this); } 133 const Derived &getDerived() const { 134 return *static_cast<const Derived *>(this); 135 } 136 137 isl::schedule visitDomain(isl::schedule_node_domain Node, Args... args) { 138 // Every schedule_tree already has a domain node, no need to add one. 139 return getDerived().visit(Node.first_child(), std::forward<Args>(args)...); 140 } 141 142 isl::schedule visitBand(isl::schedule_node_band Band, Args... args) { 143 isl::schedule NewChild = 144 getDerived().visit(Band.child(0), std::forward<Args>(args)...); 145 return rebuildBand(Band, NewChild, [](int) { return true; }); 146 } 147 148 isl::schedule visitSequence(isl::schedule_node_sequence Sequence, 149 Args... args) { 150 int NumChildren = isl_schedule_node_n_children(Sequence.get()); 151 isl::schedule Result = 152 getDerived().visit(Sequence.child(0), std::forward<Args>(args)...); 153 for (int i = 1; i < NumChildren; i += 1) 154 Result = Result.sequence( 155 getDerived().visit(Sequence.child(i), std::forward<Args>(args)...)); 156 return Result; 157 } 158 159 isl::schedule visitSet(isl::schedule_node_set Set, Args... args) { 160 int NumChildren = isl_schedule_node_n_children(Set.get()); 161 isl::schedule Result = 162 getDerived().visit(Set.child(0), std::forward<Args>(args)...); 163 for (int i = 1; i < NumChildren; i += 1) 164 Result = isl::manage( 165 isl_schedule_set(Result.release(), 166 getDerived() 167 .visit(Set.child(i), std::forward<Args>(args)...) 168 .release())); 169 return Result; 170 } 171 172 isl::schedule visitLeaf(isl::schedule_node_leaf Leaf, Args... args) { 173 return isl::schedule::from_domain(Leaf.get_domain()); 174 } 175 176 isl::schedule visitMark(const isl::schedule_node &Mark, Args... args) { 177 178 isl::id TheMark = Mark.as<isl::schedule_node_mark>().get_id(); 179 isl::schedule_node NewChild = 180 getDerived() 181 .visit(Mark.first_child(), std::forward<Args>(args)...) 182 .get_root() 183 .first_child(); 184 return NewChild.insert_mark(TheMark).get_schedule(); 185 } 186 187 isl::schedule visitExtension(isl::schedule_node_extension Extension, 188 Args... args) { 189 isl::union_map TheExtension = 190 Extension.as<isl::schedule_node_extension>().get_extension(); 191 isl::schedule_node NewChild = getDerived() 192 .visit(Extension.child(0), args...) 193 .get_root() 194 .first_child(); 195 isl::schedule_node NewExtension = 196 isl::schedule_node::from_extension(TheExtension); 197 return NewChild.graft_before(NewExtension).get_schedule(); 198 } 199 200 isl::schedule visitFilter(isl::schedule_node_filter Filter, Args... args) { 201 isl::union_set FilterDomain = 202 Filter.as<isl::schedule_node_filter>().get_filter(); 203 isl::schedule NewSchedule = 204 getDerived().visit(Filter.child(0), std::forward<Args>(args)...); 205 return NewSchedule.intersect_domain(FilterDomain); 206 } 207 208 isl::schedule visitNode(isl::schedule_node Node, Args... args) { 209 llvm_unreachable("Not implemented"); 210 } 211 }; 212 213 /// Rewrite the schedule tree without any changes. Useful to copy a subtree into 214 /// a new schedule, discarding everything but. 215 struct IdentityRewriter : public ScheduleTreeRewriter<IdentityRewriter> {}; 216 217 /// Rewrite a schedule tree to an equivalent one without extension nodes. 218 /// 219 /// Each visit method takes two additional arguments: 220 /// 221 /// * The new domain the node, which is the inherited domain plus any domains 222 /// added by extension nodes. 223 /// 224 /// * A map of extension domains of all children is returned; it is required by 225 /// band nodes to schedule the additional domains at the same position as the 226 /// extension node would. 227 /// 228 struct ExtensionNodeRewriter 229 : public ScheduleTreeRewriter<ExtensionNodeRewriter, const isl::union_set &, 230 isl::union_map &> { 231 using BaseTy = ScheduleTreeRewriter<ExtensionNodeRewriter, 232 const isl::union_set &, isl::union_map &>; 233 BaseTy &getBase() { return *this; } 234 const BaseTy &getBase() const { return *this; } 235 236 isl::schedule visitSchedule(isl::schedule Schedule) { 237 isl::union_map Extensions; 238 isl::schedule Result = 239 visit(Schedule.get_root(), Schedule.get_domain(), Extensions); 240 assert(!Extensions.is_null() && Extensions.is_empty()); 241 return Result; 242 } 243 244 isl::schedule visitSequence(isl::schedule_node_sequence Sequence, 245 const isl::union_set &Domain, 246 isl::union_map &Extensions) { 247 int NumChildren = isl_schedule_node_n_children(Sequence.get()); 248 isl::schedule NewNode = visit(Sequence.first_child(), Domain, Extensions); 249 for (int i = 1; i < NumChildren; i += 1) { 250 isl::schedule_node OldChild = Sequence.child(i); 251 isl::union_map NewChildExtensions; 252 isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions); 253 NewNode = NewNode.sequence(NewChildNode); 254 Extensions = Extensions.unite(NewChildExtensions); 255 } 256 return NewNode; 257 } 258 259 isl::schedule visitSet(isl::schedule_node_set Set, 260 const isl::union_set &Domain, 261 isl::union_map &Extensions) { 262 int NumChildren = isl_schedule_node_n_children(Set.get()); 263 isl::schedule NewNode = visit(Set.first_child(), Domain, Extensions); 264 for (int i = 1; i < NumChildren; i += 1) { 265 isl::schedule_node OldChild = Set.child(i); 266 isl::union_map NewChildExtensions; 267 isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions); 268 NewNode = isl::manage( 269 isl_schedule_set(NewNode.release(), NewChildNode.release())); 270 Extensions = Extensions.unite(NewChildExtensions); 271 } 272 return NewNode; 273 } 274 275 isl::schedule visitLeaf(isl::schedule_node_leaf Leaf, 276 const isl::union_set &Domain, 277 isl::union_map &Extensions) { 278 Extensions = isl::union_map::empty(Leaf.ctx()); 279 return isl::schedule::from_domain(Domain); 280 } 281 282 isl::schedule visitBand(isl::schedule_node_band OldNode, 283 const isl::union_set &Domain, 284 isl::union_map &OuterExtensions) { 285 isl::schedule_node OldChild = OldNode.first_child(); 286 isl::multi_union_pw_aff PartialSched = 287 isl::manage(isl_schedule_node_band_get_partial_schedule(OldNode.get())); 288 289 isl::union_map NewChildExtensions; 290 isl::schedule NewChild = visit(OldChild, Domain, NewChildExtensions); 291 292 // Add the extensions to the partial schedule. 293 OuterExtensions = isl::union_map::empty(NewChildExtensions.ctx()); 294 isl::union_map NewPartialSchedMap = isl::union_map::from(PartialSched); 295 unsigned BandDims = isl_schedule_node_band_n_member(OldNode.get()); 296 for (isl::map Ext : NewChildExtensions.get_map_list()) { 297 unsigned ExtDims = unsignedFromIslSize(Ext.domain_tuple_dim()); 298 assert(ExtDims >= BandDims); 299 unsigned OuterDims = ExtDims - BandDims; 300 301 isl::map BandSched = 302 Ext.project_out(isl::dim::in, 0, OuterDims).reverse(); 303 NewPartialSchedMap = NewPartialSchedMap.unite(BandSched); 304 305 // There might be more outer bands that have to schedule the extensions. 306 if (OuterDims > 0) { 307 isl::map OuterSched = 308 Ext.project_out(isl::dim::in, OuterDims, BandDims); 309 OuterExtensions = OuterExtensions.unite(OuterSched); 310 } 311 } 312 isl::multi_union_pw_aff NewPartialSchedAsAsMultiUnionPwAff = 313 isl::multi_union_pw_aff::from_union_map(NewPartialSchedMap); 314 isl::schedule_node NewNode = 315 NewChild.insert_partial_schedule(NewPartialSchedAsAsMultiUnionPwAff) 316 .get_root() 317 .child(0); 318 319 // Reapply permutability and coincidence attributes. 320 NewNode = isl::manage(isl_schedule_node_band_set_permutable( 321 NewNode.release(), 322 isl_schedule_node_band_get_permutable(OldNode.get()))); 323 for (unsigned i = 0; i < BandDims; i += 1) 324 NewNode = applyBandMemberAttributes(NewNode.as<isl::schedule_node_band>(), 325 i, OldNode, i); 326 327 return NewNode.get_schedule(); 328 } 329 330 isl::schedule visitFilter(isl::schedule_node_filter Filter, 331 const isl::union_set &Domain, 332 isl::union_map &Extensions) { 333 isl::union_set FilterDomain = 334 Filter.as<isl::schedule_node_filter>().get_filter(); 335 isl::union_set NewDomain = Domain.intersect(FilterDomain); 336 337 // A filter is added implicitly if necessary when joining schedule trees. 338 return visit(Filter.first_child(), NewDomain, Extensions); 339 } 340 341 isl::schedule visitExtension(isl::schedule_node_extension Extension, 342 const isl::union_set &Domain, 343 isl::union_map &Extensions) { 344 isl::union_map ExtDomain = 345 Extension.as<isl::schedule_node_extension>().get_extension(); 346 isl::union_set NewDomain = Domain.unite(ExtDomain.range()); 347 isl::union_map ChildExtensions; 348 isl::schedule NewChild = 349 visit(Extension.first_child(), NewDomain, ChildExtensions); 350 Extensions = ChildExtensions.unite(ExtDomain); 351 return NewChild; 352 } 353 }; 354 355 /// Collect all AST build options in any schedule tree band. 356 /// 357 /// ScheduleTreeRewriter cannot apply the schedule tree options. This class 358 /// collects these options to apply them later. 359 struct CollectASTBuildOptions 360 : public RecursiveScheduleTreeVisitor<CollectASTBuildOptions> { 361 using BaseTy = RecursiveScheduleTreeVisitor<CollectASTBuildOptions>; 362 BaseTy &getBase() { return *this; } 363 const BaseTy &getBase() const { return *this; } 364 365 llvm::SmallVector<isl::union_set, 8> ASTBuildOptions; 366 367 void visitBand(isl::schedule_node_band Band) { 368 ASTBuildOptions.push_back( 369 isl::manage(isl_schedule_node_band_get_ast_build_options(Band.get()))); 370 return getBase().visitBand(Band); 371 } 372 }; 373 374 /// Apply AST build options to the bands in a schedule tree. 375 /// 376 /// This rewrites a schedule tree with the AST build options applied. We assume 377 /// that the band nodes are visited in the same order as they were when the 378 /// build options were collected, typically by CollectASTBuildOptions. 379 struct ApplyASTBuildOptions 380 : public ScheduleNodeRewriter<ApplyASTBuildOptions> { 381 using BaseTy = ScheduleNodeRewriter<ApplyASTBuildOptions>; 382 BaseTy &getBase() { return *this; } 383 const BaseTy &getBase() const { return *this; } 384 385 size_t Pos; 386 llvm::ArrayRef<isl::union_set> ASTBuildOptions; 387 388 ApplyASTBuildOptions(llvm::ArrayRef<isl::union_set> ASTBuildOptions) 389 : ASTBuildOptions(ASTBuildOptions) {} 390 391 isl::schedule visitSchedule(isl::schedule Schedule) { 392 Pos = 0; 393 isl::schedule Result = visit(Schedule).get_schedule(); 394 assert(Pos == ASTBuildOptions.size() && 395 "AST build options must match to band nodes"); 396 return Result; 397 } 398 399 isl::schedule_node visitBand(isl::schedule_node_band Band) { 400 isl::schedule_node_band Result = 401 Band.set_ast_build_options(ASTBuildOptions[Pos]); 402 Pos += 1; 403 return getBase().visitBand(Result); 404 } 405 }; 406 407 /// Return whether the schedule contains an extension node. 408 static bool containsExtensionNode(isl::schedule Schedule) { 409 assert(!Schedule.is_null()); 410 411 auto Callback = [](__isl_keep isl_schedule_node *Node, 412 void *User) -> isl_bool { 413 if (isl_schedule_node_get_type(Node) == isl_schedule_node_extension) { 414 // Stop walking the schedule tree. 415 return isl_bool_error; 416 } 417 418 // Continue searching the subtree. 419 return isl_bool_true; 420 }; 421 isl_stat RetVal = isl_schedule_foreach_schedule_node_top_down( 422 Schedule.get(), Callback, nullptr); 423 424 // We assume that the traversal itself does not fail, i.e. the only reason to 425 // return isl_stat_error is that an extension node was found. 426 return RetVal == isl_stat_error; 427 } 428 429 /// Find a named MDNode property in a LoopID. 430 static MDNode *findOptionalNodeOperand(MDNode *LoopMD, StringRef Name) { 431 return dyn_cast_or_null<MDNode>( 432 findMetadataOperand(LoopMD, Name).getValueOr(nullptr)); 433 } 434 435 /// Is this node of type mark? 436 static bool isMark(const isl::schedule_node &Node) { 437 return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_mark; 438 } 439 440 /// Is this node of type band? 441 static bool isBand(const isl::schedule_node &Node) { 442 return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band; 443 } 444 445 #ifndef NDEBUG 446 /// Is this node a band of a single dimension (i.e. could represent a loop)? 447 static bool isBandWithSingleLoop(const isl::schedule_node &Node) { 448 return isBand(Node) && isl_schedule_node_band_n_member(Node.get()) == 1; 449 } 450 #endif 451 452 static bool isLeaf(const isl::schedule_node &Node) { 453 return isl_schedule_node_get_type(Node.get()) == isl_schedule_node_leaf; 454 } 455 456 /// Create an isl::id representing the output loop after a transformation. 457 static isl::id createGeneratedLoopAttr(isl::ctx Ctx, MDNode *FollowupLoopMD) { 458 // Don't need to id the followup. 459 // TODO: Append llvm.loop.disable_heustistics metadata unless overridden by 460 // user followup-MD 461 if (!FollowupLoopMD) 462 return {}; 463 464 BandAttr *Attr = new BandAttr(); 465 Attr->Metadata = FollowupLoopMD; 466 return getIslLoopAttr(Ctx, Attr); 467 } 468 469 /// A loop consists of a band and an optional marker that wraps it. Return the 470 /// outermost of the two. 471 472 /// That is, either the mark or, if there is not mark, the loop itself. Can 473 /// start with either the mark or the band. 474 static isl::schedule_node moveToBandMark(isl::schedule_node BandOrMark) { 475 if (isBandMark(BandOrMark)) { 476 assert(isBandWithSingleLoop(BandOrMark.child(0))); 477 return BandOrMark; 478 } 479 assert(isBandWithSingleLoop(BandOrMark)); 480 481 isl::schedule_node Mark = BandOrMark.parent(); 482 if (isBandMark(Mark)) 483 return Mark; 484 485 // Band has no loop marker. 486 return BandOrMark; 487 } 488 489 static isl::schedule_node removeMark(isl::schedule_node MarkOrBand, 490 BandAttr *&Attr) { 491 MarkOrBand = moveToBandMark(MarkOrBand); 492 493 isl::schedule_node Band; 494 if (isMark(MarkOrBand)) { 495 Attr = getLoopAttr(MarkOrBand.as<isl::schedule_node_mark>().get_id()); 496 Band = isl::manage(isl_schedule_node_delete(MarkOrBand.release())); 497 } else { 498 Attr = nullptr; 499 Band = MarkOrBand; 500 } 501 502 assert(isBandWithSingleLoop(Band)); 503 return Band; 504 } 505 506 /// Remove the mark that wraps a loop. Return the band representing the loop. 507 static isl::schedule_node removeMark(isl::schedule_node MarkOrBand) { 508 BandAttr *Attr; 509 return removeMark(MarkOrBand, Attr); 510 } 511 512 static isl::schedule_node insertMark(isl::schedule_node Band, isl::id Mark) { 513 assert(isBand(Band)); 514 assert(moveToBandMark(Band).is_equal(Band) && 515 "Don't add a two marks for a band"); 516 517 return Band.insert_mark(Mark).child(0); 518 } 519 520 /// Return the (one-dimensional) set of numbers that are divisible by @p Factor 521 /// with remainder @p Offset. 522 /// 523 /// isDivisibleBySet(Ctx, 4, 0) = { [i] : floord(i,4) = 0 } 524 /// isDivisibleBySet(Ctx, 4, 1) = { [i] : floord(i,4) = 1 } 525 /// 526 static isl::basic_set isDivisibleBySet(isl::ctx &Ctx, long Factor, 527 long Offset) { 528 isl::val ValFactor{Ctx, Factor}; 529 isl::val ValOffset{Ctx, Offset}; 530 531 isl::space Unispace{Ctx, 0, 1}; 532 isl::local_space LUnispace{Unispace}; 533 isl::aff AffFactor{LUnispace, ValFactor}; 534 isl::aff AffOffset{LUnispace, ValOffset}; 535 536 isl::aff Id = isl::aff::var_on_domain(LUnispace, isl::dim::out, 0); 537 isl::aff DivMul = Id.mod(ValFactor); 538 isl::basic_map Divisible = isl::basic_map::from_aff(DivMul); 539 isl::basic_map Modulo = Divisible.fix_val(isl::dim::out, 0, ValOffset); 540 return Modulo.domain(); 541 } 542 543 /// Make the last dimension of Set to take values from 0 to VectorWidth - 1. 544 /// 545 /// @param Set A set, which should be modified. 546 /// @param VectorWidth A parameter, which determines the constraint. 547 static isl::set addExtentConstraints(isl::set Set, int VectorWidth) { 548 unsigned Dims = unsignedFromIslSize(Set.tuple_dim()); 549 assert(Dims >= 1); 550 isl::space Space = Set.get_space(); 551 isl::local_space LocalSpace = isl::local_space(Space); 552 isl::constraint ExtConstr = isl::constraint::alloc_inequality(LocalSpace); 553 ExtConstr = ExtConstr.set_constant_si(0); 554 ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, 1); 555 Set = Set.add_constraint(ExtConstr); 556 ExtConstr = isl::constraint::alloc_inequality(LocalSpace); 557 ExtConstr = ExtConstr.set_constant_si(VectorWidth - 1); 558 ExtConstr = ExtConstr.set_coefficient_si(isl::dim::set, Dims - 1, -1); 559 return Set.add_constraint(ExtConstr); 560 } 561 562 /// Collapse perfectly nested bands into a single band. 563 class BandCollapseRewriter : public ScheduleTreeRewriter<BandCollapseRewriter> { 564 private: 565 using BaseTy = ScheduleTreeRewriter<BandCollapseRewriter>; 566 BaseTy &getBase() { return *this; } 567 const BaseTy &getBase() const { return *this; } 568 569 public: 570 isl::schedule visitBand(isl::schedule_node_band RootBand) { 571 isl::schedule_node_band Band = RootBand; 572 isl::ctx Ctx = Band.ctx(); 573 574 // Do not merge permutable band to avoid loosing the permutability property. 575 // Cannot collapse even two permutable loops, they might be permutable 576 // individually, but not necassarily accross. 577 if (unsignedFromIslSize(Band.n_member()) > 1u && Band.permutable()) 578 return getBase().visitBand(Band); 579 580 // Find collapsable bands. 581 SmallVector<isl::schedule_node_band> Nest; 582 int NumTotalLoops = 0; 583 isl::schedule_node Body; 584 while (true) { 585 Nest.push_back(Band); 586 NumTotalLoops += unsignedFromIslSize(Band.n_member()); 587 Body = Band.first_child(); 588 if (!Body.isa<isl::schedule_node_band>()) 589 break; 590 Band = Body.as<isl::schedule_node_band>(); 591 592 // Do not include next band if it is permutable to not lose its 593 // permutability property. 594 if (unsignedFromIslSize(Band.n_member()) > 1u && Band.permutable()) 595 break; 596 } 597 598 // Nothing to collapse, preserve permutability. 599 if (Nest.size() <= 1) 600 return getBase().visitBand(Band); 601 602 LLVM_DEBUG({ 603 dbgs() << "Found loops to collapse between\n"; 604 dumpIslObj(RootBand, dbgs()); 605 dbgs() << "and\n"; 606 dumpIslObj(Body, dbgs()); 607 dbgs() << "\n"; 608 }); 609 610 isl::schedule NewBody = visit(Body); 611 612 // Collect partial schedules from all members. 613 isl::union_pw_aff_list PartScheds{Ctx, NumTotalLoops}; 614 for (isl::schedule_node_band Band : Nest) { 615 int NumLoops = unsignedFromIslSize(Band.n_member()); 616 isl::multi_union_pw_aff BandScheds = Band.get_partial_schedule(); 617 for (auto j : seq<int>(0, NumLoops)) 618 PartScheds = PartScheds.add(BandScheds.at(j)); 619 } 620 isl::space ScatterSpace = isl::space(Ctx, 0, NumTotalLoops); 621 isl::multi_union_pw_aff PartSchedsMulti{ScatterSpace, PartScheds}; 622 623 isl::schedule_node_band CollapsedBand = 624 NewBody.insert_partial_schedule(PartSchedsMulti) 625 .get_root() 626 .first_child() 627 .as<isl::schedule_node_band>(); 628 629 // Copy over loop attributes form original bands. 630 int LoopIdx = 0; 631 for (isl::schedule_node_band Band : Nest) { 632 int NumLoops = unsignedFromIslSize(Band.n_member()); 633 for (int i : seq<int>(0, NumLoops)) { 634 CollapsedBand = applyBandMemberAttributes(std::move(CollapsedBand), 635 LoopIdx, Band, i); 636 LoopIdx += 1; 637 } 638 } 639 assert(LoopIdx == NumTotalLoops && 640 "Expect the same number of loops to add up again"); 641 642 return CollapsedBand.get_schedule(); 643 } 644 }; 645 646 static isl::schedule collapseBands(isl::schedule Sched) { 647 LLVM_DEBUG(dbgs() << "Collapse bands in schedule\n"); 648 BandCollapseRewriter Rewriter; 649 return Rewriter.visit(Sched); 650 } 651 652 /// Collect sequentially executed bands (or anything else), even if nested in a 653 /// mark or other nodes whose child is executed just once. If we can 654 /// successfully fuse the bands, we allow them to be removed. 655 static void collectPotentiallyFusableBands( 656 isl::schedule_node Node, 657 SmallVectorImpl<std::pair<isl::schedule_node, isl::schedule_node>> 658 &ScheduleBands, 659 const isl::schedule_node &DirectChild) { 660 switch (isl_schedule_node_get_type(Node.get())) { 661 case isl_schedule_node_sequence: 662 case isl_schedule_node_set: 663 case isl_schedule_node_mark: 664 case isl_schedule_node_domain: 665 case isl_schedule_node_filter: 666 if (Node.has_children()) { 667 isl::schedule_node C = Node.first_child(); 668 while (true) { 669 collectPotentiallyFusableBands(C, ScheduleBands, DirectChild); 670 if (!C.has_next_sibling()) 671 break; 672 C = C.next_sibling(); 673 } 674 } 675 break; 676 677 default: 678 // Something that does not execute suquentially (e.g. a band) 679 ScheduleBands.push_back({Node, DirectChild}); 680 break; 681 } 682 } 683 684 /// Remove dependencies that are resolved by @p PartSched. That is, remove 685 /// everything that we already know is executed in-order. 686 static isl::union_map remainingDepsFromPartialSchedule(isl::union_map PartSched, 687 isl::union_map Deps) { 688 unsigned NumDims = getNumScatterDims(PartSched); 689 auto ParamSpace = PartSched.get_space().params(); 690 691 // { Scatter[] } 692 isl::space ScatterSpace = 693 ParamSpace.set_from_params().add_dims(isl::dim::set, NumDims); 694 695 // { Scatter[] -> Domain[] } 696 isl::union_map PartSchedRev = PartSched.reverse(); 697 698 // { Scatter[] -> Scatter[] } 699 isl::map MaybeBefore = isl::map::lex_le(ScatterSpace); 700 701 // { Domain[] -> Domain[] } 702 isl::union_map DomMaybeBefore = 703 MaybeBefore.apply_domain(PartSchedRev).apply_range(PartSchedRev); 704 705 // { Domain[] -> Domain[] } 706 isl::union_map ChildRemainingDeps = Deps.intersect(DomMaybeBefore); 707 708 return ChildRemainingDeps; 709 } 710 711 /// Remove dependencies that are resolved by executing them in the order 712 /// specified by @p Domains; 713 static isl::union_map remainigDepsFromSequence(ArrayRef<isl::union_set> Domains, 714 isl::union_map Deps) { 715 isl::ctx Ctx = Deps.ctx(); 716 isl::space ParamSpace = Deps.get_space().params(); 717 718 // Create a partial schedule mapping to constants that reflect the execution 719 // order. 720 isl::union_map PartialSchedules = isl::union_map::empty(Ctx); 721 for (auto P : enumerate(Domains)) { 722 isl::val ExecTime = isl::val(Ctx, P.index()); 723 isl::union_pw_aff DomSched{P.value(), ExecTime}; 724 PartialSchedules = PartialSchedules.unite(DomSched.as_union_map()); 725 } 726 727 return remainingDepsFromPartialSchedule(PartialSchedules, Deps); 728 } 729 730 /// Determine whether the outermost loop of to bands can be fused while 731 /// respecting validity dependencies. 732 static bool canFuseOutermost(const isl::schedule_node_band &LHS, 733 const isl::schedule_node_band &RHS, 734 const isl::union_map &Deps) { 735 // { LHSDomain[] -> Scatter[] } 736 isl::union_map LHSPartSched = 737 LHS.get_partial_schedule().get_at(0).as_union_map(); 738 739 // { Domain[] -> Scatter[] } 740 isl::union_map RHSPartSched = 741 RHS.get_partial_schedule().get_at(0).as_union_map(); 742 743 // Dependencies that are already resolved because LHS executes before RHS, but 744 // will not be anymore after fusion. { DefDomain[] -> UseDomain[] } 745 isl::union_map OrderedBySequence = 746 Deps.intersect_domain(LHSPartSched.domain()) 747 .intersect_range(RHSPartSched.domain()); 748 749 isl::space ParamSpace = OrderedBySequence.get_space().params(); 750 isl::space NewScatterSpace = ParamSpace.add_unnamed_tuple(1); 751 752 // { Scatter[] -> Scatter[] } 753 isl::map After = isl::map::lex_gt(NewScatterSpace); 754 755 // After fusion, instances with smaller (or equal, which means they will be 756 // executed in the same iteration, but the LHS instance is still sequenced 757 // before RHS) scatter value will still be executed before. This are the 758 // orderings where this is not necessarily the case. 759 // { LHSDomain[] -> RHSDomain[] } 760 isl::union_map MightBeAfterDoms = After.apply_domain(LHSPartSched.reverse()) 761 .apply_range(RHSPartSched.reverse()); 762 763 // Dependencies that are not resolved by the new execution order. 764 isl::union_map WithBefore = OrderedBySequence.intersect(MightBeAfterDoms); 765 766 return WithBefore.is_empty(); 767 } 768 769 /// Fuse @p LHS and @p RHS if possible while preserving validity dependenvies. 770 static isl::schedule tryGreedyFuse(isl::schedule_node_band LHS, 771 isl::schedule_node_band RHS, 772 const isl::union_map &Deps) { 773 if (!canFuseOutermost(LHS, RHS, Deps)) 774 return {}; 775 776 LLVM_DEBUG({ 777 dbgs() << "Found loops for greedy fusion:\n"; 778 dumpIslObj(LHS, dbgs()); 779 dbgs() << "and\n"; 780 dumpIslObj(RHS, dbgs()); 781 dbgs() << "\n"; 782 }); 783 784 // The partial schedule of the bands outermost loop that we need to combine 785 // for the fusion. 786 isl::union_pw_aff LHSPartOuterSched = LHS.get_partial_schedule().get_at(0); 787 isl::union_pw_aff RHSPartOuterSched = RHS.get_partial_schedule().get_at(0); 788 789 // Isolate band bodies as roots of their own schedule trees. 790 IdentityRewriter Rewriter; 791 isl::schedule LHSBody = Rewriter.visit(LHS.first_child()); 792 isl::schedule RHSBody = Rewriter.visit(RHS.first_child()); 793 794 // Reconstruct the non-outermost (not going to be fused) loops from both 795 // bands. 796 // TODO: Maybe it is possibly to transfer the 'permutability' property from 797 // LHS+RHS. At minimum we need merge multiple band members at once, otherwise 798 // permutability has no meaning. 799 isl::schedule LHSNewBody = 800 rebuildBand(LHS, LHSBody, [](int i) { return i > 0; }); 801 isl::schedule RHSNewBody = 802 rebuildBand(RHS, RHSBody, [](int i) { return i > 0; }); 803 804 // The loop body of the fused loop. 805 isl::schedule NewCommonBody = LHSNewBody.sequence(RHSNewBody); 806 807 // Combine the partial schedules of both loops to a new one. Instances with 808 // the same scatter value are put together. 809 isl::union_map NewCommonPartialSched = 810 LHSPartOuterSched.as_union_map().unite(RHSPartOuterSched.as_union_map()); 811 isl::schedule NewCommonSchedule = NewCommonBody.insert_partial_schedule( 812 NewCommonPartialSched.as_multi_union_pw_aff()); 813 814 return NewCommonSchedule; 815 } 816 817 static isl::schedule tryGreedyFuse(isl::schedule_node LHS, 818 isl::schedule_node RHS, 819 const isl::union_map &Deps) { 820 // TODO: Non-bands could be interpreted as a band with just as single 821 // iteration. However, this is only useful if both ends of a fused loop were 822 // originally loops themselves. 823 if (!LHS.isa<isl::schedule_node_band>()) 824 return {}; 825 if (!RHS.isa<isl::schedule_node_band>()) 826 return {}; 827 828 return tryGreedyFuse(LHS.as<isl::schedule_node_band>(), 829 RHS.as<isl::schedule_node_band>(), Deps); 830 } 831 832 /// Fuse all fusable loop top-down in a schedule tree. 833 /// 834 /// The isl::union_map parameters is the set of validity dependencies that have 835 /// not been resolved/carried by a parent schedule node. 836 class GreedyFusionRewriter 837 : public ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map> { 838 private: 839 using BaseTy = ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map>; 840 BaseTy &getBase() { return *this; } 841 const BaseTy &getBase() const { return *this; } 842 843 public: 844 /// Is set to true if anything has been fused. 845 bool AnyChange = false; 846 847 isl::schedule visitBand(isl::schedule_node_band Band, isl::union_map Deps) { 848 // { Domain[] -> Scatter[] } 849 isl::union_map PartSched = 850 isl::union_map::from(Band.get_partial_schedule()); 851 assert(getNumScatterDims(PartSched) == 852 unsignedFromIslSize(Band.n_member())); 853 isl::space ParamSpace = PartSched.get_space().params(); 854 855 // { Scatter[] -> Domain[] } 856 isl::union_map PartSchedRev = PartSched.reverse(); 857 858 // Possible within the same iteration. Dependencies with smaller scatter 859 // value are carried by this loop and therefore have been resolved by the 860 // in-order execution if the loop iteration. A dependency with small scatter 861 // value would be a dependency violation that we assume did not happen. { 862 // Domain[] -> Domain[] } 863 isl::union_map Unsequenced = PartSchedRev.apply_domain(PartSchedRev); 864 865 // Actual dependencies within the same iteration. 866 // { DefDomain[] -> UseDomain[] } 867 isl::union_map RemDeps = Deps.intersect(Unsequenced); 868 869 return getBase().visitBand(Band, RemDeps); 870 } 871 872 isl::schedule visitSequence(isl::schedule_node_sequence Sequence, 873 isl::union_map Deps) { 874 int NumChildren = isl_schedule_node_n_children(Sequence.get()); 875 876 // List of fusion candidates. The first element is the fusion candidate, the 877 // second is candidate's ancestor that is the sequence's direct child. It is 878 // preferable to use the direct child if not if its non-direct children is 879 // fused to preserve its structure such as mark nodes. 880 SmallVector<std::pair<isl::schedule_node, isl::schedule_node>> Bands; 881 for (auto i : seq<int>(0, NumChildren)) { 882 isl::schedule_node Child = Sequence.child(i); 883 collectPotentiallyFusableBands(Child, Bands, Child); 884 } 885 886 // Direct children that had at least one of its decendants fused. 887 SmallDenseSet<isl_schedule_node *, 4> ChangedDirectChildren; 888 889 // Fuse neigboring bands until reaching the end of candidates. 890 int i = 0; 891 while (i + 1 < (int)Bands.size()) { 892 isl::schedule Fused = 893 tryGreedyFuse(Bands[i].first, Bands[i + 1].first, Deps); 894 if (Fused.is_null()) { 895 // Cannot merge this node with the next; look at next pair. 896 i += 1; 897 continue; 898 } 899 900 // Mark the direct children as (partially) fused. 901 if (!Bands[i].second.is_null()) 902 ChangedDirectChildren.insert(Bands[i].second.get()); 903 if (!Bands[i + 1].second.is_null()) 904 ChangedDirectChildren.insert(Bands[i + 1].second.get()); 905 906 // Collapse the neigbros to a single new candidate that could be fused 907 // with the next candidate. 908 Bands[i] = {Fused.get_root(), {}}; 909 Bands.erase(Bands.begin() + i + 1); 910 911 AnyChange = true; 912 } 913 914 // By construction equal if done with collectPotentiallyFusableBands's 915 // output. 916 SmallVector<isl::union_set> SubDomains; 917 SubDomains.reserve(NumChildren); 918 for (int i = 0; i < NumChildren; i += 1) 919 SubDomains.push_back(Sequence.child(i).domain()); 920 auto SubRemainingDeps = remainigDepsFromSequence(SubDomains, Deps); 921 922 // We may iterate over direct children multiple times, be sure to add each 923 // at most once. 924 SmallDenseSet<isl_schedule_node *, 4> AlreadyAdded; 925 926 isl::schedule Result; 927 for (auto &P : Bands) { 928 isl::schedule_node MaybeFused = P.first; 929 isl::schedule_node DirectChild = P.second; 930 931 // If not modified, use the direct child. 932 if (!DirectChild.is_null() && 933 !ChangedDirectChildren.count(DirectChild.get())) { 934 if (AlreadyAdded.count(DirectChild.get())) 935 continue; 936 AlreadyAdded.insert(DirectChild.get()); 937 MaybeFused = DirectChild; 938 } else { 939 assert(AnyChange && 940 "Need changed flag for be consistent with actual change"); 941 } 942 943 // Top-down recursion: If the outermost loop has been fused, their nested 944 // bands might be fusable now as well. 945 isl::schedule InnerFused = visit(MaybeFused, SubRemainingDeps); 946 947 // Reconstruct the sequence, with some of the children fused. 948 if (Result.is_null()) 949 Result = InnerFused; 950 else 951 Result = Result.sequence(InnerFused); 952 } 953 954 return Result; 955 } 956 }; 957 958 } // namespace 959 960 bool polly::isBandMark(const isl::schedule_node &Node) { 961 return isMark(Node) && 962 isLoopAttr(Node.as<isl::schedule_node_mark>().get_id()); 963 } 964 965 BandAttr *polly::getBandAttr(isl::schedule_node MarkOrBand) { 966 MarkOrBand = moveToBandMark(MarkOrBand); 967 if (!isMark(MarkOrBand)) 968 return nullptr; 969 970 return getLoopAttr(MarkOrBand.as<isl::schedule_node_mark>().get_id()); 971 } 972 973 isl::schedule polly::hoistExtensionNodes(isl::schedule Sched) { 974 // If there is no extension node in the first place, return the original 975 // schedule tree. 976 if (!containsExtensionNode(Sched)) 977 return Sched; 978 979 // Build options can anchor schedule nodes, such that the schedule tree cannot 980 // be modified anymore. Therefore, apply build options after the tree has been 981 // created. 982 CollectASTBuildOptions Collector; 983 Collector.visit(Sched); 984 985 // Rewrite the schedule tree without extension nodes. 986 ExtensionNodeRewriter Rewriter; 987 isl::schedule NewSched = Rewriter.visitSchedule(Sched); 988 989 // Reapply the AST build options. The rewriter must not change the iteration 990 // order of bands. Any other node type is ignored. 991 ApplyASTBuildOptions Applicator(Collector.ASTBuildOptions); 992 NewSched = Applicator.visitSchedule(NewSched); 993 994 return NewSched; 995 } 996 997 isl::schedule polly::applyFullUnroll(isl::schedule_node BandToUnroll) { 998 isl::ctx Ctx = BandToUnroll.ctx(); 999 1000 // Remove the loop's mark, the loop will disappear anyway. 1001 BandToUnroll = removeMark(BandToUnroll); 1002 assert(isBandWithSingleLoop(BandToUnroll)); 1003 1004 isl::multi_union_pw_aff PartialSched = isl::manage( 1005 isl_schedule_node_band_get_partial_schedule(BandToUnroll.get())); 1006 assert(unsignedFromIslSize(PartialSched.dim(isl::dim::out)) == 1u && 1007 "Can only unroll a single dimension"); 1008 isl::union_pw_aff PartialSchedUAff = PartialSched.at(0); 1009 1010 isl::union_set Domain = BandToUnroll.get_domain(); 1011 PartialSchedUAff = PartialSchedUAff.intersect_domain(Domain); 1012 isl::union_map PartialSchedUMap = 1013 isl::union_map::from(isl::union_pw_multi_aff(PartialSchedUAff)); 1014 1015 // Enumerator only the scatter elements. 1016 isl::union_set ScatterList = PartialSchedUMap.range(); 1017 1018 // Enumerate all loop iterations. 1019 // TODO: Diagnose if not enumerable or depends on a parameter. 1020 SmallVector<isl::point, 16> Elts; 1021 ScatterList.foreach_point([&Elts](isl::point P) -> isl::stat { 1022 Elts.push_back(P); 1023 return isl::stat::ok(); 1024 }); 1025 1026 // Don't assume that foreach_point returns in execution order. 1027 llvm::sort(Elts, [](isl::point P1, isl::point P2) -> bool { 1028 isl::val C1 = P1.get_coordinate_val(isl::dim::set, 0); 1029 isl::val C2 = P2.get_coordinate_val(isl::dim::set, 0); 1030 return C1.lt(C2); 1031 }); 1032 1033 // Convert the points to a sequence of filters. 1034 isl::union_set_list List = isl::union_set_list(Ctx, Elts.size()); 1035 for (isl::point P : Elts) { 1036 // Determine the domains that map this scatter element. 1037 isl::union_set DomainFilter = PartialSchedUMap.intersect_range(P).domain(); 1038 1039 List = List.add(DomainFilter); 1040 } 1041 1042 // Replace original band with unrolled sequence. 1043 isl::schedule_node Body = 1044 isl::manage(isl_schedule_node_delete(BandToUnroll.release())); 1045 Body = Body.insert_sequence(List); 1046 return Body.get_schedule(); 1047 } 1048 1049 isl::schedule polly::applyPartialUnroll(isl::schedule_node BandToUnroll, 1050 int Factor) { 1051 assert(Factor > 0 && "Positive unroll factor required"); 1052 isl::ctx Ctx = BandToUnroll.ctx(); 1053 1054 // Remove the mark, save the attribute for later use. 1055 BandAttr *Attr; 1056 BandToUnroll = removeMark(BandToUnroll, Attr); 1057 assert(isBandWithSingleLoop(BandToUnroll)); 1058 1059 isl::multi_union_pw_aff PartialSched = isl::manage( 1060 isl_schedule_node_band_get_partial_schedule(BandToUnroll.get())); 1061 1062 // { Stmt[] -> [x] } 1063 isl::union_pw_aff PartialSchedUAff = PartialSched.at(0); 1064 1065 // Here we assume the schedule stride is one and starts with 0, which is not 1066 // necessarily the case. 1067 isl::union_pw_aff StridedPartialSchedUAff = 1068 isl::union_pw_aff::empty(PartialSchedUAff.get_space()); 1069 isl::val ValFactor{Ctx, Factor}; 1070 PartialSchedUAff.foreach_pw_aff([&StridedPartialSchedUAff, 1071 &ValFactor](isl::pw_aff PwAff) -> isl::stat { 1072 isl::space Space = PwAff.get_space(); 1073 isl::set Universe = isl::set::universe(Space.domain()); 1074 isl::pw_aff AffFactor{Universe, ValFactor}; 1075 isl::pw_aff DivSchedAff = PwAff.div(AffFactor).floor().mul(AffFactor); 1076 StridedPartialSchedUAff = StridedPartialSchedUAff.union_add(DivSchedAff); 1077 return isl::stat::ok(); 1078 }); 1079 1080 isl::union_set_list List = isl::union_set_list(Ctx, Factor); 1081 for (auto i : seq<int>(0, Factor)) { 1082 // { Stmt[] -> [x] } 1083 isl::union_map UMap = 1084 isl::union_map::from(isl::union_pw_multi_aff(PartialSchedUAff)); 1085 1086 // { [x] } 1087 isl::basic_set Divisible = isDivisibleBySet(Ctx, Factor, i); 1088 1089 // { Stmt[] } 1090 isl::union_set UnrolledDomain = UMap.intersect_range(Divisible).domain(); 1091 1092 List = List.add(UnrolledDomain); 1093 } 1094 1095 isl::schedule_node Body = 1096 isl::manage(isl_schedule_node_delete(BandToUnroll.copy())); 1097 Body = Body.insert_sequence(List); 1098 isl::schedule_node NewLoop = 1099 Body.insert_partial_schedule(StridedPartialSchedUAff); 1100 1101 MDNode *FollowupMD = nullptr; 1102 if (Attr && Attr->Metadata) 1103 FollowupMD = 1104 findOptionalNodeOperand(Attr->Metadata, LLVMLoopUnrollFollowupUnrolled); 1105 1106 isl::id NewBandId = createGeneratedLoopAttr(Ctx, FollowupMD); 1107 if (!NewBandId.is_null()) 1108 NewLoop = insertMark(NewLoop, NewBandId); 1109 1110 return NewLoop.get_schedule(); 1111 } 1112 1113 isl::set polly::getPartialTilePrefixes(isl::set ScheduleRange, 1114 int VectorWidth) { 1115 unsigned Dims = unsignedFromIslSize(ScheduleRange.tuple_dim()); 1116 assert(Dims >= 1); 1117 isl::set LoopPrefixes = 1118 ScheduleRange.drop_constraints_involving_dims(isl::dim::set, Dims - 1, 1); 1119 auto ExtentPrefixes = addExtentConstraints(LoopPrefixes, VectorWidth); 1120 isl::set BadPrefixes = ExtentPrefixes.subtract(ScheduleRange); 1121 BadPrefixes = BadPrefixes.project_out(isl::dim::set, Dims - 1, 1); 1122 LoopPrefixes = LoopPrefixes.project_out(isl::dim::set, Dims - 1, 1); 1123 return LoopPrefixes.subtract(BadPrefixes); 1124 } 1125 1126 isl::union_set polly::getIsolateOptions(isl::set IsolateDomain, 1127 unsigned OutDimsNum) { 1128 unsigned Dims = unsignedFromIslSize(IsolateDomain.tuple_dim()); 1129 assert(OutDimsNum <= Dims && 1130 "The isl::set IsolateDomain is used to describe the range of schedule " 1131 "dimensions values, which should be isolated. Consequently, the " 1132 "number of its dimensions should be greater than or equal to the " 1133 "number of the schedule dimensions."); 1134 isl::map IsolateRelation = isl::map::from_domain(IsolateDomain); 1135 IsolateRelation = IsolateRelation.move_dims(isl::dim::out, 0, isl::dim::in, 1136 Dims - OutDimsNum, OutDimsNum); 1137 isl::set IsolateOption = IsolateRelation.wrap(); 1138 isl::id Id = isl::id::alloc(IsolateOption.ctx(), "isolate", nullptr); 1139 IsolateOption = IsolateOption.set_tuple_id(Id); 1140 return isl::union_set(IsolateOption); 1141 } 1142 1143 isl::union_set polly::getDimOptions(isl::ctx Ctx, const char *Option) { 1144 isl::space Space(Ctx, 0, 1); 1145 auto DimOption = isl::set::universe(Space); 1146 auto Id = isl::id::alloc(Ctx, Option, nullptr); 1147 DimOption = DimOption.set_tuple_id(Id); 1148 return isl::union_set(DimOption); 1149 } 1150 1151 isl::schedule_node polly::tileNode(isl::schedule_node Node, 1152 const char *Identifier, 1153 ArrayRef<int> TileSizes, 1154 int DefaultTileSize) { 1155 auto Space = isl::manage(isl_schedule_node_band_get_space(Node.get())); 1156 auto Dims = Space.dim(isl::dim::set); 1157 auto Sizes = isl::multi_val::zero(Space); 1158 std::string IdentifierString(Identifier); 1159 for (unsigned i : rangeIslSize(0, Dims)) { 1160 unsigned tileSize = i < TileSizes.size() ? TileSizes[i] : DefaultTileSize; 1161 Sizes = Sizes.set_val(i, isl::val(Node.ctx(), tileSize)); 1162 } 1163 auto TileLoopMarkerStr = IdentifierString + " - Tiles"; 1164 auto TileLoopMarker = isl::id::alloc(Node.ctx(), TileLoopMarkerStr, nullptr); 1165 Node = Node.insert_mark(TileLoopMarker); 1166 Node = Node.child(0); 1167 Node = 1168 isl::manage(isl_schedule_node_band_tile(Node.release(), Sizes.release())); 1169 Node = Node.child(0); 1170 auto PointLoopMarkerStr = IdentifierString + " - Points"; 1171 auto PointLoopMarker = 1172 isl::id::alloc(Node.ctx(), PointLoopMarkerStr, nullptr); 1173 Node = Node.insert_mark(PointLoopMarker); 1174 return Node.child(0); 1175 } 1176 1177 isl::schedule_node polly::applyRegisterTiling(isl::schedule_node Node, 1178 ArrayRef<int> TileSizes, 1179 int DefaultTileSize) { 1180 Node = tileNode(Node, "Register tiling", TileSizes, DefaultTileSize); 1181 auto Ctx = Node.ctx(); 1182 return Node.as<isl::schedule_node_band>().set_ast_build_options( 1183 isl::union_set(Ctx, "{unroll[x]}")); 1184 } 1185 1186 /// Find statements and sub-loops in (possibly nested) sequences. 1187 static void 1188 collectFussionableStmts(isl::schedule_node Node, 1189 SmallVectorImpl<isl::schedule_node> &ScheduleStmts) { 1190 if (isBand(Node) || isLeaf(Node)) { 1191 ScheduleStmts.push_back(Node); 1192 return; 1193 } 1194 1195 if (Node.has_children()) { 1196 isl::schedule_node C = Node.first_child(); 1197 while (true) { 1198 collectFussionableStmts(C, ScheduleStmts); 1199 if (!C.has_next_sibling()) 1200 break; 1201 C = C.next_sibling(); 1202 } 1203 } 1204 } 1205 1206 isl::schedule polly::applyMaxFission(isl::schedule_node BandToFission) { 1207 isl::ctx Ctx = BandToFission.ctx(); 1208 BandToFission = removeMark(BandToFission); 1209 isl::schedule_node BandBody = BandToFission.child(0); 1210 1211 SmallVector<isl::schedule_node> FissionableStmts; 1212 collectFussionableStmts(BandBody, FissionableStmts); 1213 size_t N = FissionableStmts.size(); 1214 1215 // Collect the domain for each of the statements that will get their own loop. 1216 isl::union_set_list DomList = isl::union_set_list(Ctx, N); 1217 for (size_t i = 0; i < N; ++i) { 1218 isl::schedule_node BodyPart = FissionableStmts[i]; 1219 DomList = DomList.add(BodyPart.get_domain()); 1220 } 1221 1222 // Apply the fission by copying the entire loop, but inserting a filter for 1223 // the statement domains for each fissioned loop. 1224 isl::schedule_node Fissioned = BandToFission.insert_sequence(DomList); 1225 1226 return Fissioned.get_schedule(); 1227 } 1228 1229 isl::schedule polly::applyGreedyFusion(isl::schedule Sched, 1230 const isl::union_map &Deps) { 1231 LLVM_DEBUG(dbgs() << "Greedy loop fusion\n"); 1232 1233 GreedyFusionRewriter Rewriter; 1234 isl::schedule Result = Rewriter.visit(Sched, Deps); 1235 if (!Rewriter.AnyChange) { 1236 LLVM_DEBUG(dbgs() << "Found nothing to fuse\n"); 1237 return Sched; 1238 } 1239 1240 // GreedyFusionRewriter due to working loop-by-loop, bands with multiple loops 1241 // may have been split into multiple bands. 1242 return collapseBands(Result); 1243 } 1244