1 //===- polly/ScheduleTreeTransform.cpp --------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Make changes to isl's schedule tree data structure. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "polly/ScheduleTreeTransform.h" 14 #include "polly/Support/ISLTools.h" 15 #include "llvm/ADT/ArrayRef.h" 16 #include "llvm/ADT/SmallVector.h" 17 18 using namespace polly; 19 20 namespace { 21 22 /// This class defines a simple visitor class that may be used for 23 /// various schedule tree analysis purposes. 24 template <typename Derived, typename RetTy = void, typename... Args> 25 struct ScheduleTreeVisitor { 26 Derived &getDerived() { return *static_cast<Derived *>(this); } 27 const Derived &getDerived() const { 28 return *static_cast<const Derived *>(this); 29 } 30 31 RetTy visit(const isl::schedule_node &Node, Args... args) { 32 assert(!Node.is_null()); 33 switch (isl_schedule_node_get_type(Node.get())) { 34 case isl_schedule_node_domain: 35 assert(isl_schedule_node_n_children(Node.get()) == 1); 36 return getDerived().visitDomain(Node, std::forward<Args>(args)...); 37 case isl_schedule_node_band: 38 assert(isl_schedule_node_n_children(Node.get()) == 1); 39 return getDerived().visitBand(Node, std::forward<Args>(args)...); 40 case isl_schedule_node_sequence: 41 assert(isl_schedule_node_n_children(Node.get()) >= 2); 42 return getDerived().visitSequence(Node, std::forward<Args>(args)...); 43 case isl_schedule_node_set: 44 return getDerived().visitSet(Node, std::forward<Args>(args)...); 45 assert(isl_schedule_node_n_children(Node.get()) >= 2); 46 case isl_schedule_node_leaf: 47 assert(isl_schedule_node_n_children(Node.get()) == 0); 48 return getDerived().visitLeaf(Node, std::forward<Args>(args)...); 49 case isl_schedule_node_mark: 50 assert(isl_schedule_node_n_children(Node.get()) == 1); 51 return getDerived().visitMark(Node, std::forward<Args>(args)...); 52 case isl_schedule_node_extension: 53 assert(isl_schedule_node_n_children(Node.get()) == 1); 54 return getDerived().visitExtension(Node, std::forward<Args>(args)...); 55 case isl_schedule_node_filter: 56 assert(isl_schedule_node_n_children(Node.get()) == 1); 57 return getDerived().visitFilter(Node, std::forward<Args>(args)...); 58 default: 59 llvm_unreachable("unimplemented schedule node type"); 60 } 61 } 62 63 RetTy visitDomain(const isl::schedule_node &Domain, Args... args) { 64 return getDerived().visitSingleChild(Domain, std::forward<Args>(args)...); 65 } 66 67 RetTy visitBand(const isl::schedule_node &Band, Args... args) { 68 return getDerived().visitSingleChild(Band, std::forward<Args>(args)...); 69 } 70 71 RetTy visitSequence(const isl::schedule_node &Sequence, Args... args) { 72 return getDerived().visitMultiChild(Sequence, std::forward<Args>(args)...); 73 } 74 75 RetTy visitSet(const isl::schedule_node &Set, Args... args) { 76 return getDerived().visitMultiChild(Set, std::forward<Args>(args)...); 77 } 78 79 RetTy visitLeaf(const isl::schedule_node &Leaf, Args... args) { 80 return getDerived().visitNode(Leaf, std::forward<Args>(args)...); 81 } 82 83 RetTy visitMark(const isl::schedule_node &Mark, Args... args) { 84 return getDerived().visitSingleChild(Mark, std::forward<Args>(args)...); 85 } 86 87 RetTy visitExtension(const isl::schedule_node &Extension, Args... args) { 88 return getDerived().visitSingleChild(Extension, 89 std::forward<Args>(args)...); 90 } 91 92 RetTy visitFilter(const isl::schedule_node &Extension, Args... args) { 93 return getDerived().visitSingleChild(Extension, 94 std::forward<Args>(args)...); 95 } 96 97 RetTy visitSingleChild(const isl::schedule_node &Node, Args... args) { 98 return getDerived().visitNode(Node, std::forward<Args>(args)...); 99 } 100 101 RetTy visitMultiChild(const isl::schedule_node &Node, Args... args) { 102 return getDerived().visitNode(Node, std::forward<Args>(args)...); 103 } 104 105 RetTy visitNode(const isl::schedule_node &Node, Args... args) { 106 llvm_unreachable("Unimplemented other"); 107 } 108 }; 109 110 /// Recursively visit all nodes of a schedule tree. 111 template <typename Derived, typename RetTy = void, typename... Args> 112 struct RecursiveScheduleTreeVisitor 113 : public ScheduleTreeVisitor<Derived, RetTy, Args...> { 114 using BaseTy = ScheduleTreeVisitor<Derived, RetTy, Args...>; 115 BaseTy &getBase() { return *this; } 116 const BaseTy &getBase() const { return *this; } 117 Derived &getDerived() { return *static_cast<Derived *>(this); } 118 const Derived &getDerived() const { 119 return *static_cast<const Derived *>(this); 120 } 121 122 /// When visiting an entire schedule tree, start at its root node. 123 RetTy visit(const isl::schedule &Schedule, Args... args) { 124 return getDerived().visit(Schedule.get_root(), std::forward<Args>(args)...); 125 } 126 127 // Necessary to allow overload resolution with the added visit(isl::schedule) 128 // overload. 129 RetTy visit(const isl::schedule_node &Node, Args... args) { 130 return getBase().visit(Node, std::forward<Args>(args)...); 131 } 132 133 RetTy visitNode(const isl::schedule_node &Node, Args... args) { 134 int NumChildren = isl_schedule_node_n_children(Node.get()); 135 for (int i = 0; i < NumChildren; i += 1) 136 getDerived().visit(Node.child(i), std::forward<Args>(args)...); 137 return RetTy(); 138 } 139 }; 140 141 /// Recursively visit all nodes of a schedule tree while allowing changes. 142 /// 143 /// The visit methods return an isl::schedule_node that is used to continue 144 /// visiting the tree. Structural changes such as returning a different node 145 /// will confuse the visitor. 146 template <typename Derived, typename... Args> 147 struct ScheduleNodeRewriter 148 : public RecursiveScheduleTreeVisitor<Derived, isl::schedule_node, 149 Args...> { 150 Derived &getDerived() { return *static_cast<Derived *>(this); } 151 const Derived &getDerived() const { 152 return *static_cast<const Derived *>(this); 153 } 154 155 isl::schedule_node visitNode(const isl::schedule_node &Node, Args... args) { 156 if (!Node.has_children()) 157 return Node; 158 159 isl::schedule_node It = Node.first_child(); 160 while (true) { 161 It = getDerived().visit(It, std::forward<Args>(args)...); 162 if (!It.has_next_sibling()) 163 break; 164 It = It.next_sibling(); 165 } 166 return It.parent(); 167 } 168 }; 169 170 /// Rewrite a schedule tree by reconstructing it bottom-up. 171 /// 172 /// By default, the original schedule tree is reconstructed. To build a 173 /// different tree, redefine visitor methods in a derived class (CRTP). 174 /// 175 /// Note that AST build options are not applied; Setting the isolate[] option 176 /// makes the schedule tree 'anchored' and cannot be modified afterwards. Hence, 177 /// AST build options must be set after the tree has been constructed. 178 template <typename Derived, typename... Args> 179 struct ScheduleTreeRewriter 180 : public RecursiveScheduleTreeVisitor<Derived, isl::schedule, Args...> { 181 Derived &getDerived() { return *static_cast<Derived *>(this); } 182 const Derived &getDerived() const { 183 return *static_cast<const Derived *>(this); 184 } 185 186 isl::schedule visitDomain(const isl::schedule_node &Node, Args... args) { 187 // Every schedule_tree already has a domain node, no need to add one. 188 return getDerived().visit(Node.first_child(), std::forward<Args>(args)...); 189 } 190 191 isl::schedule visitBand(const isl::schedule_node &Band, Args... args) { 192 isl::multi_union_pw_aff PartialSched = 193 isl::manage(isl_schedule_node_band_get_partial_schedule(Band.get())); 194 isl::schedule NewChild = 195 getDerived().visit(Band.child(0), std::forward<Args>(args)...); 196 isl::schedule_node NewNode = 197 NewChild.insert_partial_schedule(PartialSched).get_root().get_child(0); 198 199 // Reapply permutability and coincidence attributes. 200 NewNode = isl::manage(isl_schedule_node_band_set_permutable( 201 NewNode.release(), isl_schedule_node_band_get_permutable(Band.get()))); 202 unsigned BandDims = isl_schedule_node_band_n_member(Band.get()); 203 for (unsigned i = 0; i < BandDims; i += 1) 204 NewNode = isl::manage(isl_schedule_node_band_member_set_coincident( 205 NewNode.release(), i, 206 isl_schedule_node_band_member_get_coincident(Band.get(), i))); 207 208 return NewNode.get_schedule(); 209 } 210 211 isl::schedule visitSequence(const isl::schedule_node &Sequence, 212 Args... args) { 213 int NumChildren = isl_schedule_node_n_children(Sequence.get()); 214 isl::schedule Result = 215 getDerived().visit(Sequence.child(0), std::forward<Args>(args)...); 216 for (int i = 1; i < NumChildren; i += 1) 217 Result = Result.sequence( 218 getDerived().visit(Sequence.child(i), std::forward<Args>(args)...)); 219 return Result; 220 } 221 222 isl::schedule visitSet(const isl::schedule_node &Set, Args... args) { 223 int NumChildren = isl_schedule_node_n_children(Set.get()); 224 isl::schedule Result = 225 getDerived().visit(Set.child(0), std::forward<Args>(args)...); 226 for (int i = 1; i < NumChildren; i += 1) 227 Result = isl::manage( 228 isl_schedule_set(Result.release(), 229 getDerived() 230 .visit(Set.child(i), std::forward<Args>(args)...) 231 .release())); 232 return Result; 233 } 234 235 isl::schedule visitLeaf(const isl::schedule_node &Leaf, Args... args) { 236 return isl::schedule::from_domain(Leaf.get_domain()); 237 } 238 239 isl::schedule visitMark(const isl::schedule_node &Mark, Args... args) { 240 isl::id TheMark = Mark.mark_get_id(); 241 isl::schedule_node NewChild = 242 getDerived() 243 .visit(Mark.first_child(), std::forward<Args>(args)...) 244 .get_root() 245 .first_child(); 246 return NewChild.insert_mark(TheMark).get_schedule(); 247 } 248 249 isl::schedule visitExtension(const isl::schedule_node &Extension, 250 Args... args) { 251 isl::union_map TheExtension = Extension.extension_get_extension(); 252 isl::schedule_node NewChild = getDerived() 253 .visit(Extension.child(0), args...) 254 .get_root() 255 .first_child(); 256 isl::schedule_node NewExtension = 257 isl::schedule_node::from_extension(TheExtension); 258 return NewChild.graft_before(NewExtension).get_schedule(); 259 } 260 261 isl::schedule visitFilter(const isl::schedule_node &Filter, Args... args) { 262 isl::union_set FilterDomain = Filter.filter_get_filter(); 263 isl::schedule NewSchedule = 264 getDerived().visit(Filter.child(0), std::forward<Args>(args)...); 265 return NewSchedule.intersect_domain(FilterDomain); 266 } 267 268 isl::schedule visitNode(const isl::schedule_node &Node, Args... args) { 269 llvm_unreachable("Not implemented"); 270 } 271 }; 272 273 /// Rewrite a schedule tree to an equivalent one without extension nodes. 274 /// 275 /// Each visit method takes two additional arguments: 276 /// 277 /// * The new domain the node, which is the inherited domain plus any domains 278 /// added by extension nodes. 279 /// 280 /// * A map of extension domains of all children is returned; it is required by 281 /// band nodes to schedule the additional domains at the same position as the 282 /// extension node would. 283 /// 284 struct ExtensionNodeRewriter 285 : public ScheduleTreeRewriter<ExtensionNodeRewriter, const isl::union_set &, 286 isl::union_map &> { 287 using BaseTy = ScheduleTreeRewriter<ExtensionNodeRewriter, 288 const isl::union_set &, isl::union_map &>; 289 BaseTy &getBase() { return *this; } 290 const BaseTy &getBase() const { return *this; } 291 292 isl::schedule visitSchedule(const isl::schedule &Schedule) { 293 isl::union_map Extensions; 294 isl::schedule Result = 295 visit(Schedule.get_root(), Schedule.get_domain(), Extensions); 296 assert(Extensions && Extensions.is_empty()); 297 return Result; 298 } 299 300 isl::schedule visitSequence(const isl::schedule_node &Sequence, 301 const isl::union_set &Domain, 302 isl::union_map &Extensions) { 303 int NumChildren = isl_schedule_node_n_children(Sequence.get()); 304 isl::schedule NewNode = visit(Sequence.first_child(), Domain, Extensions); 305 for (int i = 1; i < NumChildren; i += 1) { 306 isl::schedule_node OldChild = Sequence.child(i); 307 isl::union_map NewChildExtensions; 308 isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions); 309 NewNode = NewNode.sequence(NewChildNode); 310 Extensions = Extensions.unite(NewChildExtensions); 311 } 312 return NewNode; 313 } 314 315 isl::schedule visitSet(const isl::schedule_node &Set, 316 const isl::union_set &Domain, 317 isl::union_map &Extensions) { 318 int NumChildren = isl_schedule_node_n_children(Set.get()); 319 isl::schedule NewNode = visit(Set.first_child(), Domain, Extensions); 320 for (int i = 1; i < NumChildren; i += 1) { 321 isl::schedule_node OldChild = Set.child(i); 322 isl::union_map NewChildExtensions; 323 isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions); 324 NewNode = isl::manage( 325 isl_schedule_set(NewNode.release(), NewChildNode.release())); 326 Extensions = Extensions.unite(NewChildExtensions); 327 } 328 return NewNode; 329 } 330 331 isl::schedule visitLeaf(const isl::schedule_node &Leaf, 332 const isl::union_set &Domain, 333 isl::union_map &Extensions) { 334 isl::ctx Ctx = Leaf.get_ctx(); 335 Extensions = isl::union_map::empty(isl::space::params_alloc(Ctx, 0)); 336 return isl::schedule::from_domain(Domain); 337 } 338 339 isl::schedule visitBand(const isl::schedule_node &OldNode, 340 const isl::union_set &Domain, 341 isl::union_map &OuterExtensions) { 342 isl::schedule_node OldChild = OldNode.first_child(); 343 isl::multi_union_pw_aff PartialSched = 344 isl::manage(isl_schedule_node_band_get_partial_schedule(OldNode.get())); 345 346 isl::union_map NewChildExtensions; 347 isl::schedule NewChild = visit(OldChild, Domain, NewChildExtensions); 348 349 // Add the extensions to the partial schedule. 350 OuterExtensions = isl::union_map::empty(NewChildExtensions.get_space()); 351 isl::union_map NewPartialSchedMap = isl::union_map::from(PartialSched); 352 unsigned BandDims = isl_schedule_node_band_n_member(OldNode.get()); 353 for (isl::map Ext : NewChildExtensions.get_map_list()) { 354 unsigned ExtDims = Ext.dim(isl::dim::in); 355 assert(ExtDims >= BandDims); 356 unsigned OuterDims = ExtDims - BandDims; 357 358 isl::map BandSched = 359 Ext.project_out(isl::dim::in, 0, OuterDims).reverse(); 360 NewPartialSchedMap = NewPartialSchedMap.unite(BandSched); 361 362 // There might be more outer bands that have to schedule the extensions. 363 if (OuterDims > 0) { 364 isl::map OuterSched = 365 Ext.project_out(isl::dim::in, OuterDims, BandDims); 366 OuterExtensions = OuterExtensions.add_map(OuterSched); 367 } 368 } 369 isl::multi_union_pw_aff NewPartialSchedAsAsMultiUnionPwAff = 370 isl::multi_union_pw_aff::from_union_map(NewPartialSchedMap); 371 isl::schedule_node NewNode = 372 NewChild.insert_partial_schedule(NewPartialSchedAsAsMultiUnionPwAff) 373 .get_root() 374 .get_child(0); 375 376 // Reapply permutability and coincidence attributes. 377 NewNode = isl::manage(isl_schedule_node_band_set_permutable( 378 NewNode.release(), 379 isl_schedule_node_band_get_permutable(OldNode.get()))); 380 for (unsigned i = 0; i < BandDims; i += 1) { 381 NewNode = isl::manage(isl_schedule_node_band_member_set_coincident( 382 NewNode.release(), i, 383 isl_schedule_node_band_member_get_coincident(OldNode.get(), i))); 384 } 385 386 return NewNode.get_schedule(); 387 } 388 389 isl::schedule visitFilter(const isl::schedule_node &Filter, 390 const isl::union_set &Domain, 391 isl::union_map &Extensions) { 392 isl::union_set FilterDomain = Filter.filter_get_filter(); 393 isl::union_set NewDomain = Domain.intersect(FilterDomain); 394 395 // A filter is added implicitly if necessary when joining schedule trees. 396 return visit(Filter.first_child(), NewDomain, Extensions); 397 } 398 399 isl::schedule visitExtension(const isl::schedule_node &Extension, 400 const isl::union_set &Domain, 401 isl::union_map &Extensions) { 402 isl::union_map ExtDomain = Extension.extension_get_extension(); 403 isl::union_set NewDomain = Domain.unite(ExtDomain.range()); 404 isl::union_map ChildExtensions; 405 isl::schedule NewChild = 406 visit(Extension.first_child(), NewDomain, ChildExtensions); 407 Extensions = ChildExtensions.unite(ExtDomain); 408 return NewChild; 409 } 410 }; 411 412 /// Collect all AST build options in any schedule tree band. 413 /// 414 /// ScheduleTreeRewriter cannot apply the schedule tree options. This class 415 /// collects these options to apply them later. 416 struct CollectASTBuildOptions 417 : public RecursiveScheduleTreeVisitor<CollectASTBuildOptions> { 418 using BaseTy = RecursiveScheduleTreeVisitor<CollectASTBuildOptions>; 419 BaseTy &getBase() { return *this; } 420 const BaseTy &getBase() const { return *this; } 421 422 llvm::SmallVector<isl::union_set, 8> ASTBuildOptions; 423 424 void visitBand(const isl::schedule_node &Band) { 425 ASTBuildOptions.push_back( 426 isl::manage(isl_schedule_node_band_get_ast_build_options(Band.get()))); 427 return getBase().visitBand(Band); 428 } 429 }; 430 431 /// Apply AST build options to the bands in a schedule tree. 432 /// 433 /// This rewrites a schedule tree with the AST build options applied. We assume 434 /// that the band nodes are visited in the same order as they were when the 435 /// build options were collected, typically by CollectASTBuildOptions. 436 struct ApplyASTBuildOptions 437 : public ScheduleNodeRewriter<ApplyASTBuildOptions> { 438 using BaseTy = ScheduleNodeRewriter<ApplyASTBuildOptions>; 439 BaseTy &getBase() { return *this; } 440 const BaseTy &getBase() const { return *this; } 441 442 size_t Pos; 443 llvm::ArrayRef<isl::union_set> ASTBuildOptions; 444 445 ApplyASTBuildOptions(llvm::ArrayRef<isl::union_set> ASTBuildOptions) 446 : ASTBuildOptions(ASTBuildOptions) {} 447 448 isl::schedule visitSchedule(const isl::schedule &Schedule) { 449 Pos = 0; 450 isl::schedule Result = visit(Schedule).get_schedule(); 451 assert(Pos == ASTBuildOptions.size() && 452 "AST build options must match to band nodes"); 453 return Result; 454 } 455 456 isl::schedule_node visitBand(const isl::schedule_node &Band) { 457 isl::schedule_node Result = 458 Band.band_set_ast_build_options(ASTBuildOptions[Pos]); 459 Pos += 1; 460 return getBase().visitBand(Result); 461 } 462 }; 463 464 } // namespace 465 466 /// Return whether the schedule contains an extension node. 467 static bool containsExtensionNode(isl::schedule Schedule) { 468 assert(!Schedule.is_null()); 469 470 auto Callback = [](__isl_keep isl_schedule_node *Node, 471 void *User) -> isl_bool { 472 if (isl_schedule_node_get_type(Node) == isl_schedule_node_extension) { 473 // Stop walking the schedule tree. 474 return isl_bool_error; 475 } 476 477 // Continue searching the subtree. 478 return isl_bool_true; 479 }; 480 isl_stat RetVal = isl_schedule_foreach_schedule_node_top_down( 481 Schedule.get(), Callback, nullptr); 482 483 // We assume that the traversal itself does not fail, i.e. the only reason to 484 // return isl_stat_error is that an extension node was found. 485 return RetVal == isl_stat_error; 486 } 487 488 isl::schedule polly::hoistExtensionNodes(isl::schedule Sched) { 489 // If there is no extension node in the first place, return the original 490 // schedule tree. 491 if (!containsExtensionNode(Sched)) 492 return Sched; 493 494 // Build options can anchor schedule nodes, such that the schedule tree cannot 495 // be modified anymore. Therefore, apply build options after the tree has been 496 // created. 497 CollectASTBuildOptions Collector; 498 Collector.visit(Sched); 499 500 // Rewrite the schedule tree without extension nodes. 501 ExtensionNodeRewriter Rewriter; 502 isl::schedule NewSched = Rewriter.visitSchedule(Sched); 503 504 // Reapply the AST build options. The rewriter must not change the iteration 505 // order of bands. Any other node type is ignored. 506 ApplyASTBuildOptions Applicator(Collector.ASTBuildOptions); 507 NewSched = Applicator.visitSchedule(NewSched); 508 509 return NewSched; 510 } 511