15f0d4f20SAlex Zinenko //===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===//
25f0d4f20SAlex Zinenko //
35f0d4f20SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
45f0d4f20SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
55f0d4f20SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65f0d4f20SAlex Zinenko //
75f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
85f0d4f20SAlex Zinenko 
95f0d4f20SAlex Zinenko #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
105f0d4f20SAlex Zinenko #include "mlir/Dialect/Affine/IR/AffineOps.h"
115f0d4f20SAlex Zinenko #include "mlir/Dialect/Func/IR/FuncOps.h"
12*333ee218SAlex Zinenko #include "mlir/Dialect/PDL/IR/PDL.h"
138b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
148b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Patterns.h"
158b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h"
165f0d4f20SAlex Zinenko #include "mlir/Dialect/SCF/Utils/Utils.h"
175f0d4f20SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformDialect.h"
18e3890b7fSAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
195f0d4f20SAlex Zinenko #include "mlir/Dialect/Vector/IR/VectorOps.h"
205f0d4f20SAlex Zinenko 
215f0d4f20SAlex Zinenko using namespace mlir;
225f0d4f20SAlex Zinenko 
235f0d4f20SAlex Zinenko namespace {
245f0d4f20SAlex Zinenko /// A simple pattern rewriter that implements no special logic.
255f0d4f20SAlex Zinenko class SimpleRewriter : public PatternRewriter {
265f0d4f20SAlex Zinenko public:
SimpleRewriter(MLIRContext * context)275f0d4f20SAlex Zinenko   SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
285f0d4f20SAlex Zinenko };
295f0d4f20SAlex Zinenko } // namespace
305f0d4f20SAlex Zinenko 
315f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
325f0d4f20SAlex Zinenko // GetParentForOp
335f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
345f0d4f20SAlex Zinenko 
351d45282aSAlex Zinenko DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)365f0d4f20SAlex Zinenko transform::GetParentForOp::apply(transform::TransformResults &results,
375f0d4f20SAlex Zinenko                                  transform::TransformState &state) {
385f0d4f20SAlex Zinenko   SetVector<Operation *> parents;
395f0d4f20SAlex Zinenko   for (Operation *target : state.getPayloadOps(getTarget())) {
405f0d4f20SAlex Zinenko     scf::ForOp loop;
415f0d4f20SAlex Zinenko     Operation *current = target;
425f0d4f20SAlex Zinenko     for (unsigned i = 0, e = getNumLoops(); i < e; ++i) {
435f0d4f20SAlex Zinenko       loop = current->getParentOfType<scf::ForOp>();
445f0d4f20SAlex Zinenko       if (!loop) {
451d45282aSAlex Zinenko         DiagnosedSilenceableFailure diag = emitSilenceableError()
46e3890b7fSAlex Zinenko                                            << "could not find an '"
475f0d4f20SAlex Zinenko                                            << scf::ForOp::getOperationName()
485f0d4f20SAlex Zinenko                                            << "' parent";
495f0d4f20SAlex Zinenko         diag.attachNote(target->getLoc()) << "target op";
505f0d4f20SAlex Zinenko         return diag;
515f0d4f20SAlex Zinenko       }
525f0d4f20SAlex Zinenko       current = loop;
535f0d4f20SAlex Zinenko     }
545f0d4f20SAlex Zinenko     parents.insert(loop);
555f0d4f20SAlex Zinenko   }
565f0d4f20SAlex Zinenko   results.set(getResult().cast<OpResult>(), parents.getArrayRef());
571d45282aSAlex Zinenko   return DiagnosedSilenceableFailure::success();
585f0d4f20SAlex Zinenko }
595f0d4f20SAlex Zinenko 
605f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
615f0d4f20SAlex Zinenko // LoopOutlineOp
625f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
635f0d4f20SAlex Zinenko 
645f0d4f20SAlex Zinenko /// Wraps the given operation `op` into an `scf.execute_region` operation. Uses
655f0d4f20SAlex Zinenko /// the provided rewriter for all operations to remain compatible with the
665f0d4f20SAlex Zinenko /// rewriting infra, as opposed to just splicing the op in place.
wrapInExecuteRegion(RewriterBase & b,Operation * op)675f0d4f20SAlex Zinenko static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
685f0d4f20SAlex Zinenko                                                 Operation *op) {
695f0d4f20SAlex Zinenko   if (op->getNumRegions() != 1)
705f0d4f20SAlex Zinenko     return nullptr;
715f0d4f20SAlex Zinenko   OpBuilder::InsertionGuard g(b);
725f0d4f20SAlex Zinenko   b.setInsertionPoint(op);
735f0d4f20SAlex Zinenko   scf::ExecuteRegionOp executeRegionOp =
745f0d4f20SAlex Zinenko       b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes());
755f0d4f20SAlex Zinenko   {
765f0d4f20SAlex Zinenko     OpBuilder::InsertionGuard g(b);
775f0d4f20SAlex Zinenko     b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
785f0d4f20SAlex Zinenko     Operation *clonedOp = b.cloneWithoutRegions(*op);
795f0d4f20SAlex Zinenko     Region &clonedRegion = clonedOp->getRegions().front();
805f0d4f20SAlex Zinenko     assert(clonedRegion.empty() && "expected empty region");
815f0d4f20SAlex Zinenko     b.inlineRegionBefore(op->getRegions().front(), clonedRegion,
825f0d4f20SAlex Zinenko                          clonedRegion.end());
835f0d4f20SAlex Zinenko     b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults());
845f0d4f20SAlex Zinenko   }
855f0d4f20SAlex Zinenko   b.replaceOp(op, executeRegionOp.getResults());
865f0d4f20SAlex Zinenko   return executeRegionOp;
875f0d4f20SAlex Zinenko }
885f0d4f20SAlex Zinenko 
891d45282aSAlex Zinenko DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)905f0d4f20SAlex Zinenko transform::LoopOutlineOp::apply(transform::TransformResults &results,
915f0d4f20SAlex Zinenko                                 transform::TransformState &state) {
925f0d4f20SAlex Zinenko   SmallVector<Operation *> transformed;
935f0d4f20SAlex Zinenko   DenseMap<Operation *, SymbolTable> symbolTables;
945f0d4f20SAlex Zinenko   for (Operation *target : state.getPayloadOps(getTarget())) {
955f0d4f20SAlex Zinenko     Location location = target->getLoc();
965f0d4f20SAlex Zinenko     Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target);
975f0d4f20SAlex Zinenko     SimpleRewriter rewriter(getContext());
985f0d4f20SAlex Zinenko     scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
995f0d4f20SAlex Zinenko     if (!exec) {
1001d45282aSAlex Zinenko       DiagnosedSilenceableFailure diag = emitSilenceableError()
101e3890b7fSAlex Zinenko                                          << "failed to outline";
1025f0d4f20SAlex Zinenko       diag.attachNote(target->getLoc()) << "target op";
1035f0d4f20SAlex Zinenko       return diag;
1045f0d4f20SAlex Zinenko     }
1055f0d4f20SAlex Zinenko     func::CallOp call;
1065f0d4f20SAlex Zinenko     FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion(
1075f0d4f20SAlex Zinenko         rewriter, location, exec.getRegion(), getFuncName(), &call);
1085f0d4f20SAlex Zinenko 
109e3890b7fSAlex Zinenko     if (failed(outlined)) {
110e3890b7fSAlex Zinenko       (void)reportUnknownTransformError(target);
1111d45282aSAlex Zinenko       return DiagnosedSilenceableFailure::definiteFailure();
112e3890b7fSAlex Zinenko     }
1135f0d4f20SAlex Zinenko 
1145f0d4f20SAlex Zinenko     if (symbolTableOp) {
1155f0d4f20SAlex Zinenko       SymbolTable &symbolTable =
1165f0d4f20SAlex Zinenko           symbolTables.try_emplace(symbolTableOp, symbolTableOp)
1175f0d4f20SAlex Zinenko               .first->getSecond();
1185f0d4f20SAlex Zinenko       symbolTable.insert(*outlined);
1195f0d4f20SAlex Zinenko       call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined));
1205f0d4f20SAlex Zinenko     }
1215f0d4f20SAlex Zinenko     transformed.push_back(*outlined);
1225f0d4f20SAlex Zinenko   }
1235f0d4f20SAlex Zinenko   results.set(getTransformed().cast<OpResult>(), transformed);
1241d45282aSAlex Zinenko   return DiagnosedSilenceableFailure::success();
1255f0d4f20SAlex Zinenko }
1265f0d4f20SAlex Zinenko 
1275f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
1285f0d4f20SAlex Zinenko // LoopPeelOp
1295f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
1305f0d4f20SAlex Zinenko 
13152307109SNicolas Vasilache DiagnosedSilenceableFailure
applyToOne(scf::ForOp target,SmallVector<Operation * > & results,transform::TransformState & state)13252307109SNicolas Vasilache transform::LoopPeelOp::applyToOne(scf::ForOp target,
13352307109SNicolas Vasilache                                   SmallVector<Operation *> &results,
13452307109SNicolas Vasilache                                   transform::TransformState &state) {
1355f0d4f20SAlex Zinenko   scf::ForOp result;
13652307109SNicolas Vasilache   IRRewriter rewriter(target->getContext());
13752307109SNicolas Vasilache   // This helper returns failure when peeling does not occur (i.e. when the IR
13852307109SNicolas Vasilache   // is not modified). This is not a failure for the op as the postcondition:
13952307109SNicolas Vasilache   //    "the loop trip count is divisible by the step"
14052307109SNicolas Vasilache   // is valid.
1415f0d4f20SAlex Zinenko   LogicalResult status =
14252307109SNicolas Vasilache       scf::peelAndCanonicalizeForLoop(rewriter, target, result);
14352307109SNicolas Vasilache   // TODO: Return both the peeled loop and the remainder loop.
14452307109SNicolas Vasilache   results.push_back(failed(status) ? target : result);
14552307109SNicolas Vasilache   return DiagnosedSilenceableFailure(success());
1465f0d4f20SAlex Zinenko }
1475f0d4f20SAlex Zinenko 
1485f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
1495f0d4f20SAlex Zinenko // LoopPipelineOp
1505f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
1515f0d4f20SAlex Zinenko 
1525f0d4f20SAlex Zinenko /// Callback for PipeliningOption. Populates `schedule` with the mapping from an
1535f0d4f20SAlex Zinenko /// operation to its logical time position given the iteration interval and the
1545f0d4f20SAlex Zinenko /// read latency. The latter is only relevant for vector transfers.
1555f0d4f20SAlex Zinenko static void
loopScheduling(scf::ForOp forOp,std::vector<std::pair<Operation *,unsigned>> & schedule,unsigned iterationInterval,unsigned readLatency)1565f0d4f20SAlex Zinenko loopScheduling(scf::ForOp forOp,
1575f0d4f20SAlex Zinenko                std::vector<std::pair<Operation *, unsigned>> &schedule,
1585f0d4f20SAlex Zinenko                unsigned iterationInterval, unsigned readLatency) {
1595f0d4f20SAlex Zinenko   auto getLatency = [&](Operation *op) -> unsigned {
1605f0d4f20SAlex Zinenko     if (isa<vector::TransferReadOp>(op))
1615f0d4f20SAlex Zinenko       return readLatency;
1625f0d4f20SAlex Zinenko     return 1;
1635f0d4f20SAlex Zinenko   };
1645f0d4f20SAlex Zinenko 
1655f0d4f20SAlex Zinenko   DenseMap<Operation *, unsigned> opCycles;
1665f0d4f20SAlex Zinenko   std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
1675f0d4f20SAlex Zinenko   for (Operation &op : forOp.getBody()->getOperations()) {
1685f0d4f20SAlex Zinenko     if (isa<scf::YieldOp>(op))
1695f0d4f20SAlex Zinenko       continue;
1705f0d4f20SAlex Zinenko     unsigned earlyCycle = 0;
1715f0d4f20SAlex Zinenko     for (Value operand : op.getOperands()) {
1725f0d4f20SAlex Zinenko       Operation *def = operand.getDefiningOp();
1735f0d4f20SAlex Zinenko       if (!def)
1745f0d4f20SAlex Zinenko         continue;
1755f0d4f20SAlex Zinenko       earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def));
1765f0d4f20SAlex Zinenko     }
1775f0d4f20SAlex Zinenko     opCycles[&op] = earlyCycle;
1785f0d4f20SAlex Zinenko     wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
1795f0d4f20SAlex Zinenko   }
1808ab925a2SAdrian Kuegel   for (const auto &it : wrappedSchedule) {
1815f0d4f20SAlex Zinenko     for (Operation *op : it.second) {
1825f0d4f20SAlex Zinenko       unsigned cycle = opCycles[op];
183cd417c6aSMehdi Amini       schedule.emplace_back(op, cycle / iterationInterval);
1845f0d4f20SAlex Zinenko     }
1855f0d4f20SAlex Zinenko   }
1865f0d4f20SAlex Zinenko }
1875f0d4f20SAlex Zinenko 
18852307109SNicolas Vasilache DiagnosedSilenceableFailure
applyToOne(scf::ForOp target,SmallVector<Operation * > & results,transform::TransformState & state)18952307109SNicolas Vasilache transform::LoopPipelineOp::applyToOne(scf::ForOp target,
19052307109SNicolas Vasilache                                       SmallVector<Operation *> &results,
19152307109SNicolas Vasilache                                       transform::TransformState &state) {
1925f0d4f20SAlex Zinenko   scf::PipeliningOption options;
1935f0d4f20SAlex Zinenko   options.getScheduleFn =
1945f0d4f20SAlex Zinenko       [this](scf::ForOp forOp,
1955f0d4f20SAlex Zinenko              std::vector<std::pair<Operation *, unsigned>> &schedule) mutable {
1965f0d4f20SAlex Zinenko         loopScheduling(forOp, schedule, getIterationInterval(),
1975f0d4f20SAlex Zinenko                        getReadLatency());
1985f0d4f20SAlex Zinenko       };
19952307109SNicolas Vasilache   scf::ForLoopPipeliningPattern pattern(options, target->getContext());
2005f0d4f20SAlex Zinenko   SimpleRewriter rewriter(getContext());
20152307109SNicolas Vasilache   rewriter.setInsertionPoint(target);
2025f0d4f20SAlex Zinenko   FailureOr<scf::ForOp> patternResult =
20352307109SNicolas Vasilache       pattern.returningMatchAndRewrite(target, rewriter);
20452307109SNicolas Vasilache   if (succeeded(patternResult)) {
20552307109SNicolas Vasilache     results.push_back(*patternResult);
20652307109SNicolas Vasilache     return DiagnosedSilenceableFailure(success());
20752307109SNicolas Vasilache   }
20852307109SNicolas Vasilache   results.assign(1, nullptr);
20952307109SNicolas Vasilache   return emitDefaultSilenceableFailure(target);
2105f0d4f20SAlex Zinenko }
2115f0d4f20SAlex Zinenko 
2125f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
2135f0d4f20SAlex Zinenko // LoopUnrollOp
2145f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
2155f0d4f20SAlex Zinenko 
21652307109SNicolas Vasilache DiagnosedSilenceableFailure
applyToOne(scf::ForOp target,SmallVector<Operation * > & results,transform::TransformState & state)21752307109SNicolas Vasilache transform::LoopUnrollOp::applyToOne(scf::ForOp target,
21852307109SNicolas Vasilache                                     SmallVector<Operation *> &results,
21952307109SNicolas Vasilache                                     transform::TransformState &state) {
22052307109SNicolas Vasilache   if (failed(loopUnrollByFactor(target, getFactor()))) {
22152307109SNicolas Vasilache     Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note);
22252307109SNicolas Vasilache     diag << "op failed to unroll";
22352307109SNicolas Vasilache     return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
22452307109SNicolas Vasilache   }
22552307109SNicolas Vasilache   return DiagnosedSilenceableFailure(success());
2265f0d4f20SAlex Zinenko }
2275f0d4f20SAlex Zinenko 
2285f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
2295f0d4f20SAlex Zinenko // Transform op registration
2305f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
2315f0d4f20SAlex Zinenko 
2325f0d4f20SAlex Zinenko namespace {
2335f0d4f20SAlex Zinenko class SCFTransformDialectExtension
2345f0d4f20SAlex Zinenko     : public transform::TransformDialectExtension<
2355f0d4f20SAlex Zinenko           SCFTransformDialectExtension> {
2365f0d4f20SAlex Zinenko public:
237*333ee218SAlex Zinenko   using Base::Base;
238*333ee218SAlex Zinenko 
init()239*333ee218SAlex Zinenko   void init() {
240*333ee218SAlex Zinenko     declareDependentDialect<pdl::PDLDialect>();
241*333ee218SAlex Zinenko 
242*333ee218SAlex Zinenko     declareGeneratedDialect<AffineDialect>();
243*333ee218SAlex Zinenko     declareGeneratedDialect<func::FuncDialect>();
244*333ee218SAlex Zinenko 
2455f0d4f20SAlex Zinenko     registerTransformOps<
2465f0d4f20SAlex Zinenko #define GET_OP_LIST
2475f0d4f20SAlex Zinenko #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
2485f0d4f20SAlex Zinenko         >();
2495f0d4f20SAlex Zinenko   }
2505f0d4f20SAlex Zinenko };
2515f0d4f20SAlex Zinenko } // namespace
2525f0d4f20SAlex Zinenko 
2535f0d4f20SAlex Zinenko #define GET_OP_CLASSES
2545f0d4f20SAlex Zinenko #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
2555f0d4f20SAlex Zinenko 
registerTransformDialectExtension(DialectRegistry & registry)2565f0d4f20SAlex Zinenko void mlir::scf::registerTransformDialectExtension(DialectRegistry &registry) {
2575f0d4f20SAlex Zinenko   registry.addExtensions<SCFTransformDialectExtension>();
2585f0d4f20SAlex Zinenko }
259