13fef2d26SRiver Riddle //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
23fef2d26SRiver Riddle //
33fef2d26SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43fef2d26SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
53fef2d26SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63fef2d26SRiver Riddle //
73fef2d26SRiver Riddle //===----------------------------------------------------------------------===//
83fef2d26SRiver Riddle //
93fef2d26SRiver Riddle // This file implements a pass to test SCF dialect utils.
103fef2d26SRiver Riddle //
113fef2d26SRiver Riddle //===----------------------------------------------------------------------===//
123fef2d26SRiver Riddle 
13a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1436550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
15*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
16*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h"
17f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/Utils.h"
183fef2d26SRiver Riddle #include "mlir/IR/Builders.h"
19f6f88e66Sthomasraoux #include "mlir/IR/PatternMatch.h"
203fef2d26SRiver Riddle #include "mlir/Pass/Pass.h"
21f6f88e66Sthomasraoux #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
223fef2d26SRiver Riddle 
233fef2d26SRiver Riddle #include "llvm/ADT/SetVector.h"
243fef2d26SRiver Riddle 
253fef2d26SRiver Riddle using namespace mlir;
263fef2d26SRiver Riddle 
273fef2d26SRiver Riddle namespace {
285e50dd04SRiver Riddle struct TestSCFForUtilsPass
2958ceae95SRiver Riddle     : public PassWrapper<TestSCFForUtilsPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon3ae6e7e50111::TestSCFForUtilsPass305e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFForUtilsPass)
315e50dd04SRiver Riddle 
32b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-scf-for-utils"; }
getDescription__anon3ae6e7e50111::TestSCFForUtilsPass33b5e22e6dSMehdi Amini   StringRef getDescription() const final { return "test scf.for utils"; }
34b5e22e6dSMehdi Amini   explicit TestSCFForUtilsPass() = default;
TestSCFForUtilsPass__anon3ae6e7e50111::TestSCFForUtilsPass35567fd523SMahesh Ravishankar   TestSCFForUtilsPass(const TestSCFForUtilsPass &pass) : PassWrapper(pass) {}
36567fd523SMahesh Ravishankar 
37567fd523SMahesh Ravishankar   Option<bool> testReplaceWithNewYields{
38567fd523SMahesh Ravishankar       *this, "test-replace-with-new-yields",
39567fd523SMahesh Ravishankar       llvm::cl::desc("Test replacing a loop with a new loop that returns new "
40567fd523SMahesh Ravishankar                      "additional yeild values"),
41567fd523SMahesh Ravishankar       llvm::cl::init(false)};
423fef2d26SRiver Riddle 
runOnOperation__anon3ae6e7e50111::TestSCFForUtilsPass4341574554SRiver Riddle   void runOnOperation() override {
4458ceae95SRiver Riddle     func::FuncOp func = getOperation();
453fef2d26SRiver Riddle     SmallVector<scf::ForOp, 4> toErase;
463fef2d26SRiver Riddle 
47567fd523SMahesh Ravishankar     if (testReplaceWithNewYields) {
48567fd523SMahesh Ravishankar       func.walk([&](scf::ForOp forOp) {
49567fd523SMahesh Ravishankar         if (forOp.getNumResults() == 0)
50567fd523SMahesh Ravishankar           return;
51567fd523SMahesh Ravishankar         auto newInitValues = forOp.getInitArgs();
52567fd523SMahesh Ravishankar         if (newInitValues.empty())
53567fd523SMahesh Ravishankar           return;
54567fd523SMahesh Ravishankar         NewYieldValueFn fn = [&](OpBuilder &b, Location loc,
55567fd523SMahesh Ravishankar                                  ArrayRef<BlockArgument> newBBArgs) {
56567fd523SMahesh Ravishankar           Block *block = newBBArgs.front().getOwner();
57567fd523SMahesh Ravishankar           SmallVector<Value> newYieldValues;
58567fd523SMahesh Ravishankar           for (auto yieldVal :
59567fd523SMahesh Ravishankar                cast<scf::YieldOp>(block->getTerminator()).getResults()) {
60567fd523SMahesh Ravishankar             newYieldValues.push_back(
61567fd523SMahesh Ravishankar                 b.create<arith::AddFOp>(loc, yieldVal, yieldVal));
62567fd523SMahesh Ravishankar           }
63567fd523SMahesh Ravishankar           return newYieldValues;
64567fd523SMahesh Ravishankar         };
65567fd523SMahesh Ravishankar         OpBuilder b(forOp);
66567fd523SMahesh Ravishankar         replaceLoopWithNewYields(b, forOp, newInitValues, fn);
67567fd523SMahesh Ravishankar       });
68567fd523SMahesh Ravishankar     }
69567fd523SMahesh Ravishankar   }
703fef2d26SRiver Riddle };
713fef2d26SRiver Riddle 
725e50dd04SRiver Riddle struct TestSCFIfUtilsPass
7311b67aafSNicolas Vasilache     : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon3ae6e7e50111::TestSCFIfUtilsPass745e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFIfUtilsPass)
755e50dd04SRiver Riddle 
76b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-scf-if-utils"; }
getDescription__anon3ae6e7e50111::TestSCFIfUtilsPass77b5e22e6dSMehdi Amini   StringRef getDescription() const final { return "test scf.if utils"; }
78b5e22e6dSMehdi Amini   explicit TestSCFIfUtilsPass() = default;
793fef2d26SRiver Riddle 
runOnOperation__anon3ae6e7e50111::TestSCFIfUtilsPass8011b67aafSNicolas Vasilache   void runOnOperation() override {
813fef2d26SRiver Riddle     int count = 0;
8211b67aafSNicolas Vasilache     getOperation().walk([&](scf::IfOp ifOp) {
833fef2d26SRiver Riddle       auto strCount = std::to_string(count++);
8458ceae95SRiver Riddle       func::FuncOp thenFn, elseFn;
853fef2d26SRiver Riddle       OpBuilder b(ifOp);
8611b67aafSNicolas Vasilache       IRRewriter rewriter(b);
8711b67aafSNicolas Vasilache       if (failed(outlineIfOp(rewriter, ifOp, &thenFn,
8811b67aafSNicolas Vasilache                              std::string("outlined_then") + strCount, &elseFn,
8911b67aafSNicolas Vasilache                              std::string("outlined_else") + strCount))) {
9011b67aafSNicolas Vasilache         this->signalPassFailure();
9111b67aafSNicolas Vasilache         return WalkResult::interrupt();
9211b67aafSNicolas Vasilache       }
9311b67aafSNicolas Vasilache       return WalkResult::advance();
943fef2d26SRiver Riddle     });
953fef2d26SRiver Riddle   }
963fef2d26SRiver Riddle };
97f6f88e66Sthomasraoux 
98f6f88e66Sthomasraoux static const StringLiteral kTestPipeliningLoopMarker =
99f6f88e66Sthomasraoux     "__test_pipelining_loop__";
100f6f88e66Sthomasraoux static const StringLiteral kTestPipeliningStageMarker =
101f6f88e66Sthomasraoux     "__test_pipelining_stage__";
102567fd523SMahesh Ravishankar /// Marker to express the order in which operations should be after
103567fd523SMahesh Ravishankar /// pipelining.
104f6f88e66Sthomasraoux static const StringLiteral kTestPipeliningOpOrderMarker =
105f6f88e66Sthomasraoux     "__test_pipelining_op_order__";
106f6f88e66Sthomasraoux 
1070736bbd7SThomas Raoux static const StringLiteral kTestPipeliningAnnotationPart =
1080736bbd7SThomas Raoux     "__test_pipelining_part";
1090736bbd7SThomas Raoux static const StringLiteral kTestPipeliningAnnotationIteration =
1100736bbd7SThomas Raoux     "__test_pipelining_iteration";
1110736bbd7SThomas Raoux 
1125e50dd04SRiver Riddle struct TestSCFPipeliningPass
11358ceae95SRiver Riddle     : public PassWrapper<TestSCFPipeliningPass, OperationPass<func::FuncOp>> {
1145e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFPipeliningPass)
1155e50dd04SRiver Riddle 
1160736bbd7SThomas Raoux   TestSCFPipeliningPass() = default;
TestSCFPipeliningPass__anon3ae6e7e50111::TestSCFPipeliningPass1170736bbd7SThomas Raoux   TestSCFPipeliningPass(const TestSCFPipeliningPass &) {}
getArgument__anon3ae6e7e50111::TestSCFPipeliningPass118f6f88e66Sthomasraoux   StringRef getArgument() const final { return "test-scf-pipelining"; }
getDescription__anon3ae6e7e50111::TestSCFPipeliningPass119f6f88e66Sthomasraoux   StringRef getDescription() const final { return "test scf.forOp pipelining"; }
1200736bbd7SThomas Raoux 
1210736bbd7SThomas Raoux   Option<bool> annotatePipeline{
1220736bbd7SThomas Raoux       *this, "annotate",
1230736bbd7SThomas Raoux       llvm::cl::desc("Annote operations during loop pipelining transformation"),
1240736bbd7SThomas Raoux       llvm::cl::init(false)};
125f6f88e66Sthomasraoux 
126205c08b5SThomas Raoux   Option<bool> noEpiloguePeeling{
127205c08b5SThomas Raoux       *this, "no-epilogue-peeling",
128205c08b5SThomas Raoux       llvm::cl::desc("Use predicates instead of peeling the epilogue."),
129205c08b5SThomas Raoux       llvm::cl::init(false)};
130205c08b5SThomas Raoux 
131f6f88e66Sthomasraoux   static void
getSchedule__anon3ae6e7e50111::TestSCFPipeliningPass132f6f88e66Sthomasraoux   getSchedule(scf::ForOp forOp,
133f6f88e66Sthomasraoux               std::vector<std::pair<Operation *, unsigned>> &schedule) {
134f6f88e66Sthomasraoux     if (!forOp->hasAttr(kTestPipeliningLoopMarker))
135f6f88e66Sthomasraoux       return;
136f6f88e66Sthomasraoux     schedule.resize(forOp.getBody()->getOperations().size() - 1);
137f6f88e66Sthomasraoux     forOp.walk([&schedule](Operation *op) {
138f6f88e66Sthomasraoux       auto attrStage =
139f6f88e66Sthomasraoux           op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker);
140f6f88e66Sthomasraoux       auto attrCycle =
141f6f88e66Sthomasraoux           op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker);
142f6f88e66Sthomasraoux       if (attrCycle && attrStage) {
143f6f88e66Sthomasraoux         schedule[attrCycle.getInt()] =
144f6f88e66Sthomasraoux             std::make_pair(op, unsigned(attrStage.getInt()));
145f6f88e66Sthomasraoux       }
146f6f88e66Sthomasraoux     });
147f6f88e66Sthomasraoux   }
148f6f88e66Sthomasraoux 
149205c08b5SThomas Raoux   /// Helper to generate "predicated" version of `op`. For simplicity we just
150205c08b5SThomas Raoux   /// wrap the operation in a scf.ifOp operation.
predicateOp__anon3ae6e7e50111::TestSCFPipeliningPass151205c08b5SThomas Raoux   static Operation *predicateOp(Operation *op, Value pred,
152205c08b5SThomas Raoux                                 PatternRewriter &rewriter) {
153205c08b5SThomas Raoux     Location loc = op->getLoc();
154205c08b5SThomas Raoux     auto ifOp =
155205c08b5SThomas Raoux         rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
156205c08b5SThomas Raoux     // True branch.
157205c08b5SThomas Raoux     op->moveBefore(&ifOp.getThenRegion().front(),
158205c08b5SThomas Raoux                    ifOp.getThenRegion().front().end());
159205c08b5SThomas Raoux     rewriter.setInsertionPointAfter(op);
160205c08b5SThomas Raoux     rewriter.create<scf::YieldOp>(loc, op->getResults());
161205c08b5SThomas Raoux     // False branch.
162205c08b5SThomas Raoux     rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
163205c08b5SThomas Raoux     SmallVector<Value> zeros;
164205c08b5SThomas Raoux     for (Type type : op->getResultTypes()) {
165205c08b5SThomas Raoux       zeros.push_back(
166205c08b5SThomas Raoux           rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(type)));
167205c08b5SThomas Raoux     }
168205c08b5SThomas Raoux     rewriter.create<scf::YieldOp>(loc, zeros);
169205c08b5SThomas Raoux     return ifOp.getOperation();
170205c08b5SThomas Raoux   }
171205c08b5SThomas Raoux 
annotate__anon3ae6e7e50111::TestSCFPipeliningPass1720736bbd7SThomas Raoux   static void annotate(Operation *op,
1730736bbd7SThomas Raoux                        mlir::scf::PipeliningOption::PipelinerPart part,
1740736bbd7SThomas Raoux                        unsigned iteration) {
1750736bbd7SThomas Raoux     OpBuilder b(op);
1760736bbd7SThomas Raoux     switch (part) {
1770736bbd7SThomas Raoux     case mlir::scf::PipeliningOption::PipelinerPart::Prologue:
1780736bbd7SThomas Raoux       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("prologue"));
1790736bbd7SThomas Raoux       break;
1800736bbd7SThomas Raoux     case mlir::scf::PipeliningOption::PipelinerPart::Kernel:
1810736bbd7SThomas Raoux       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("kernel"));
1820736bbd7SThomas Raoux       break;
1830736bbd7SThomas Raoux     case mlir::scf::PipeliningOption::PipelinerPart::Epilogue:
1840736bbd7SThomas Raoux       op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("epilogue"));
1850736bbd7SThomas Raoux       break;
1860736bbd7SThomas Raoux     }
1870736bbd7SThomas Raoux     op->setAttr(kTestPipeliningAnnotationIteration,
1880736bbd7SThomas Raoux                 b.getI32IntegerAttr(iteration));
1890736bbd7SThomas Raoux   }
1900736bbd7SThomas Raoux 
getDependentDialects__anon3ae6e7e50111::TestSCFPipeliningPass191a54f4eaeSMogball   void getDependentDialects(DialectRegistry &registry) const override {
1921f971e23SRiver Riddle     registry.insert<arith::ArithmeticDialect>();
193a54f4eaeSMogball   }
194a54f4eaeSMogball 
runOnOperation__anon3ae6e7e50111::TestSCFPipeliningPass19541574554SRiver Riddle   void runOnOperation() override {
196f6f88e66Sthomasraoux     RewritePatternSet patterns(&getContext());
197f6f88e66Sthomasraoux     mlir::scf::PipeliningOption options;
198f6f88e66Sthomasraoux     options.getScheduleFn = getSchedule;
1990736bbd7SThomas Raoux     if (annotatePipeline)
2000736bbd7SThomas Raoux       options.annotateFn = annotate;
201205c08b5SThomas Raoux     if (noEpiloguePeeling) {
202205c08b5SThomas Raoux       options.peelEpilogue = false;
203205c08b5SThomas Raoux       options.predicateFn = predicateOp;
204205c08b5SThomas Raoux     }
205f6f88e66Sthomasraoux     scf::populateSCFLoopPipeliningPatterns(patterns, options);
20641574554SRiver Riddle     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
20741574554SRiver Riddle     getOperation().walk([](Operation *op) {
208f6f88e66Sthomasraoux       // Clean up the markers.
209f6f88e66Sthomasraoux       op->removeAttr(kTestPipeliningStageMarker);
210f6f88e66Sthomasraoux       op->removeAttr(kTestPipeliningOpOrderMarker);
211f6f88e66Sthomasraoux     });
212f6f88e66Sthomasraoux   }
213f6f88e66Sthomasraoux };
2143fef2d26SRiver Riddle } // namespace
2153fef2d26SRiver Riddle 
2163fef2d26SRiver Riddle namespace mlir {
2173fef2d26SRiver Riddle namespace test {
registerTestSCFUtilsPass()2183fef2d26SRiver Riddle void registerTestSCFUtilsPass() {
219b5e22e6dSMehdi Amini   PassRegistration<TestSCFForUtilsPass>();
220b5e22e6dSMehdi Amini   PassRegistration<TestSCFIfUtilsPass>();
221f6f88e66Sthomasraoux   PassRegistration<TestSCFPipeliningPass>();
2223fef2d26SRiver Riddle }
2233fef2d26SRiver Riddle } // namespace test
2243fef2d26SRiver Riddle } // namespace mlir
225