15f0d4f20SAlex Zinenko //===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===//
25f0d4f20SAlex Zinenko //
35f0d4f20SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
45f0d4f20SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
55f0d4f20SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65f0d4f20SAlex Zinenko //
75f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
85f0d4f20SAlex Zinenko
95f0d4f20SAlex Zinenko #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
105f0d4f20SAlex Zinenko #include "mlir/Dialect/Affine/IR/AffineOps.h"
115f0d4f20SAlex Zinenko #include "mlir/Dialect/Func/IR/FuncOps.h"
12*333ee218SAlex Zinenko #include "mlir/Dialect/PDL/IR/PDL.h"
138b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
148b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Patterns.h"
158b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h"
165f0d4f20SAlex Zinenko #include "mlir/Dialect/SCF/Utils/Utils.h"
175f0d4f20SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformDialect.h"
18e3890b7fSAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
195f0d4f20SAlex Zinenko #include "mlir/Dialect/Vector/IR/VectorOps.h"
205f0d4f20SAlex Zinenko
215f0d4f20SAlex Zinenko using namespace mlir;
225f0d4f20SAlex Zinenko
235f0d4f20SAlex Zinenko namespace {
245f0d4f20SAlex Zinenko /// A simple pattern rewriter that implements no special logic.
255f0d4f20SAlex Zinenko class SimpleRewriter : public PatternRewriter {
265f0d4f20SAlex Zinenko public:
SimpleRewriter(MLIRContext * context)275f0d4f20SAlex Zinenko SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
285f0d4f20SAlex Zinenko };
295f0d4f20SAlex Zinenko } // namespace
305f0d4f20SAlex Zinenko
315f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
325f0d4f20SAlex Zinenko // GetParentForOp
335f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
345f0d4f20SAlex Zinenko
351d45282aSAlex Zinenko DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)365f0d4f20SAlex Zinenko transform::GetParentForOp::apply(transform::TransformResults &results,
375f0d4f20SAlex Zinenko transform::TransformState &state) {
385f0d4f20SAlex Zinenko SetVector<Operation *> parents;
395f0d4f20SAlex Zinenko for (Operation *target : state.getPayloadOps(getTarget())) {
405f0d4f20SAlex Zinenko scf::ForOp loop;
415f0d4f20SAlex Zinenko Operation *current = target;
425f0d4f20SAlex Zinenko for (unsigned i = 0, e = getNumLoops(); i < e; ++i) {
435f0d4f20SAlex Zinenko loop = current->getParentOfType<scf::ForOp>();
445f0d4f20SAlex Zinenko if (!loop) {
451d45282aSAlex Zinenko DiagnosedSilenceableFailure diag = emitSilenceableError()
46e3890b7fSAlex Zinenko << "could not find an '"
475f0d4f20SAlex Zinenko << scf::ForOp::getOperationName()
485f0d4f20SAlex Zinenko << "' parent";
495f0d4f20SAlex Zinenko diag.attachNote(target->getLoc()) << "target op";
505f0d4f20SAlex Zinenko return diag;
515f0d4f20SAlex Zinenko }
525f0d4f20SAlex Zinenko current = loop;
535f0d4f20SAlex Zinenko }
545f0d4f20SAlex Zinenko parents.insert(loop);
555f0d4f20SAlex Zinenko }
565f0d4f20SAlex Zinenko results.set(getResult().cast<OpResult>(), parents.getArrayRef());
571d45282aSAlex Zinenko return DiagnosedSilenceableFailure::success();
585f0d4f20SAlex Zinenko }
595f0d4f20SAlex Zinenko
605f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
615f0d4f20SAlex Zinenko // LoopOutlineOp
625f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
635f0d4f20SAlex Zinenko
645f0d4f20SAlex Zinenko /// Wraps the given operation `op` into an `scf.execute_region` operation. Uses
655f0d4f20SAlex Zinenko /// the provided rewriter for all operations to remain compatible with the
665f0d4f20SAlex Zinenko /// rewriting infra, as opposed to just splicing the op in place.
wrapInExecuteRegion(RewriterBase & b,Operation * op)675f0d4f20SAlex Zinenko static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
685f0d4f20SAlex Zinenko Operation *op) {
695f0d4f20SAlex Zinenko if (op->getNumRegions() != 1)
705f0d4f20SAlex Zinenko return nullptr;
715f0d4f20SAlex Zinenko OpBuilder::InsertionGuard g(b);
725f0d4f20SAlex Zinenko b.setInsertionPoint(op);
735f0d4f20SAlex Zinenko scf::ExecuteRegionOp executeRegionOp =
745f0d4f20SAlex Zinenko b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes());
755f0d4f20SAlex Zinenko {
765f0d4f20SAlex Zinenko OpBuilder::InsertionGuard g(b);
775f0d4f20SAlex Zinenko b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
785f0d4f20SAlex Zinenko Operation *clonedOp = b.cloneWithoutRegions(*op);
795f0d4f20SAlex Zinenko Region &clonedRegion = clonedOp->getRegions().front();
805f0d4f20SAlex Zinenko assert(clonedRegion.empty() && "expected empty region");
815f0d4f20SAlex Zinenko b.inlineRegionBefore(op->getRegions().front(), clonedRegion,
825f0d4f20SAlex Zinenko clonedRegion.end());
835f0d4f20SAlex Zinenko b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults());
845f0d4f20SAlex Zinenko }
855f0d4f20SAlex Zinenko b.replaceOp(op, executeRegionOp.getResults());
865f0d4f20SAlex Zinenko return executeRegionOp;
875f0d4f20SAlex Zinenko }
885f0d4f20SAlex Zinenko
891d45282aSAlex Zinenko DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)905f0d4f20SAlex Zinenko transform::LoopOutlineOp::apply(transform::TransformResults &results,
915f0d4f20SAlex Zinenko transform::TransformState &state) {
925f0d4f20SAlex Zinenko SmallVector<Operation *> transformed;
935f0d4f20SAlex Zinenko DenseMap<Operation *, SymbolTable> symbolTables;
945f0d4f20SAlex Zinenko for (Operation *target : state.getPayloadOps(getTarget())) {
955f0d4f20SAlex Zinenko Location location = target->getLoc();
965f0d4f20SAlex Zinenko Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target);
975f0d4f20SAlex Zinenko SimpleRewriter rewriter(getContext());
985f0d4f20SAlex Zinenko scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
995f0d4f20SAlex Zinenko if (!exec) {
1001d45282aSAlex Zinenko DiagnosedSilenceableFailure diag = emitSilenceableError()
101e3890b7fSAlex Zinenko << "failed to outline";
1025f0d4f20SAlex Zinenko diag.attachNote(target->getLoc()) << "target op";
1035f0d4f20SAlex Zinenko return diag;
1045f0d4f20SAlex Zinenko }
1055f0d4f20SAlex Zinenko func::CallOp call;
1065f0d4f20SAlex Zinenko FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion(
1075f0d4f20SAlex Zinenko rewriter, location, exec.getRegion(), getFuncName(), &call);
1085f0d4f20SAlex Zinenko
109e3890b7fSAlex Zinenko if (failed(outlined)) {
110e3890b7fSAlex Zinenko (void)reportUnknownTransformError(target);
1111d45282aSAlex Zinenko return DiagnosedSilenceableFailure::definiteFailure();
112e3890b7fSAlex Zinenko }
1135f0d4f20SAlex Zinenko
1145f0d4f20SAlex Zinenko if (symbolTableOp) {
1155f0d4f20SAlex Zinenko SymbolTable &symbolTable =
1165f0d4f20SAlex Zinenko symbolTables.try_emplace(symbolTableOp, symbolTableOp)
1175f0d4f20SAlex Zinenko .first->getSecond();
1185f0d4f20SAlex Zinenko symbolTable.insert(*outlined);
1195f0d4f20SAlex Zinenko call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined));
1205f0d4f20SAlex Zinenko }
1215f0d4f20SAlex Zinenko transformed.push_back(*outlined);
1225f0d4f20SAlex Zinenko }
1235f0d4f20SAlex Zinenko results.set(getTransformed().cast<OpResult>(), transformed);
1241d45282aSAlex Zinenko return DiagnosedSilenceableFailure::success();
1255f0d4f20SAlex Zinenko }
1265f0d4f20SAlex Zinenko
1275f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
1285f0d4f20SAlex Zinenko // LoopPeelOp
1295f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
1305f0d4f20SAlex Zinenko
13152307109SNicolas Vasilache DiagnosedSilenceableFailure
applyToOne(scf::ForOp target,SmallVector<Operation * > & results,transform::TransformState & state)13252307109SNicolas Vasilache transform::LoopPeelOp::applyToOne(scf::ForOp target,
13352307109SNicolas Vasilache SmallVector<Operation *> &results,
13452307109SNicolas Vasilache transform::TransformState &state) {
1355f0d4f20SAlex Zinenko scf::ForOp result;
13652307109SNicolas Vasilache IRRewriter rewriter(target->getContext());
13752307109SNicolas Vasilache // This helper returns failure when peeling does not occur (i.e. when the IR
13852307109SNicolas Vasilache // is not modified). This is not a failure for the op as the postcondition:
13952307109SNicolas Vasilache // "the loop trip count is divisible by the step"
14052307109SNicolas Vasilache // is valid.
1415f0d4f20SAlex Zinenko LogicalResult status =
14252307109SNicolas Vasilache scf::peelAndCanonicalizeForLoop(rewriter, target, result);
14352307109SNicolas Vasilache // TODO: Return both the peeled loop and the remainder loop.
14452307109SNicolas Vasilache results.push_back(failed(status) ? target : result);
14552307109SNicolas Vasilache return DiagnosedSilenceableFailure(success());
1465f0d4f20SAlex Zinenko }
1475f0d4f20SAlex Zinenko
1485f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
1495f0d4f20SAlex Zinenko // LoopPipelineOp
1505f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
1515f0d4f20SAlex Zinenko
1525f0d4f20SAlex Zinenko /// Callback for PipeliningOption. Populates `schedule` with the mapping from an
1535f0d4f20SAlex Zinenko /// operation to its logical time position given the iteration interval and the
1545f0d4f20SAlex Zinenko /// read latency. The latter is only relevant for vector transfers.
1555f0d4f20SAlex Zinenko static void
loopScheduling(scf::ForOp forOp,std::vector<std::pair<Operation *,unsigned>> & schedule,unsigned iterationInterval,unsigned readLatency)1565f0d4f20SAlex Zinenko loopScheduling(scf::ForOp forOp,
1575f0d4f20SAlex Zinenko std::vector<std::pair<Operation *, unsigned>> &schedule,
1585f0d4f20SAlex Zinenko unsigned iterationInterval, unsigned readLatency) {
1595f0d4f20SAlex Zinenko auto getLatency = [&](Operation *op) -> unsigned {
1605f0d4f20SAlex Zinenko if (isa<vector::TransferReadOp>(op))
1615f0d4f20SAlex Zinenko return readLatency;
1625f0d4f20SAlex Zinenko return 1;
1635f0d4f20SAlex Zinenko };
1645f0d4f20SAlex Zinenko
1655f0d4f20SAlex Zinenko DenseMap<Operation *, unsigned> opCycles;
1665f0d4f20SAlex Zinenko std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
1675f0d4f20SAlex Zinenko for (Operation &op : forOp.getBody()->getOperations()) {
1685f0d4f20SAlex Zinenko if (isa<scf::YieldOp>(op))
1695f0d4f20SAlex Zinenko continue;
1705f0d4f20SAlex Zinenko unsigned earlyCycle = 0;
1715f0d4f20SAlex Zinenko for (Value operand : op.getOperands()) {
1725f0d4f20SAlex Zinenko Operation *def = operand.getDefiningOp();
1735f0d4f20SAlex Zinenko if (!def)
1745f0d4f20SAlex Zinenko continue;
1755f0d4f20SAlex Zinenko earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def));
1765f0d4f20SAlex Zinenko }
1775f0d4f20SAlex Zinenko opCycles[&op] = earlyCycle;
1785f0d4f20SAlex Zinenko wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
1795f0d4f20SAlex Zinenko }
1808ab925a2SAdrian Kuegel for (const auto &it : wrappedSchedule) {
1815f0d4f20SAlex Zinenko for (Operation *op : it.second) {
1825f0d4f20SAlex Zinenko unsigned cycle = opCycles[op];
183cd417c6aSMehdi Amini schedule.emplace_back(op, cycle / iterationInterval);
1845f0d4f20SAlex Zinenko }
1855f0d4f20SAlex Zinenko }
1865f0d4f20SAlex Zinenko }
1875f0d4f20SAlex Zinenko
18852307109SNicolas Vasilache DiagnosedSilenceableFailure
applyToOne(scf::ForOp target,SmallVector<Operation * > & results,transform::TransformState & state)18952307109SNicolas Vasilache transform::LoopPipelineOp::applyToOne(scf::ForOp target,
19052307109SNicolas Vasilache SmallVector<Operation *> &results,
19152307109SNicolas Vasilache transform::TransformState &state) {
1925f0d4f20SAlex Zinenko scf::PipeliningOption options;
1935f0d4f20SAlex Zinenko options.getScheduleFn =
1945f0d4f20SAlex Zinenko [this](scf::ForOp forOp,
1955f0d4f20SAlex Zinenko std::vector<std::pair<Operation *, unsigned>> &schedule) mutable {
1965f0d4f20SAlex Zinenko loopScheduling(forOp, schedule, getIterationInterval(),
1975f0d4f20SAlex Zinenko getReadLatency());
1985f0d4f20SAlex Zinenko };
19952307109SNicolas Vasilache scf::ForLoopPipeliningPattern pattern(options, target->getContext());
2005f0d4f20SAlex Zinenko SimpleRewriter rewriter(getContext());
20152307109SNicolas Vasilache rewriter.setInsertionPoint(target);
2025f0d4f20SAlex Zinenko FailureOr<scf::ForOp> patternResult =
20352307109SNicolas Vasilache pattern.returningMatchAndRewrite(target, rewriter);
20452307109SNicolas Vasilache if (succeeded(patternResult)) {
20552307109SNicolas Vasilache results.push_back(*patternResult);
20652307109SNicolas Vasilache return DiagnosedSilenceableFailure(success());
20752307109SNicolas Vasilache }
20852307109SNicolas Vasilache results.assign(1, nullptr);
20952307109SNicolas Vasilache return emitDefaultSilenceableFailure(target);
2105f0d4f20SAlex Zinenko }
2115f0d4f20SAlex Zinenko
2125f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
2135f0d4f20SAlex Zinenko // LoopUnrollOp
2145f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
2155f0d4f20SAlex Zinenko
21652307109SNicolas Vasilache DiagnosedSilenceableFailure
applyToOne(scf::ForOp target,SmallVector<Operation * > & results,transform::TransformState & state)21752307109SNicolas Vasilache transform::LoopUnrollOp::applyToOne(scf::ForOp target,
21852307109SNicolas Vasilache SmallVector<Operation *> &results,
21952307109SNicolas Vasilache transform::TransformState &state) {
22052307109SNicolas Vasilache if (failed(loopUnrollByFactor(target, getFactor()))) {
22152307109SNicolas Vasilache Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note);
22252307109SNicolas Vasilache diag << "op failed to unroll";
22352307109SNicolas Vasilache return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
22452307109SNicolas Vasilache }
22552307109SNicolas Vasilache return DiagnosedSilenceableFailure(success());
2265f0d4f20SAlex Zinenko }
2275f0d4f20SAlex Zinenko
2285f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
2295f0d4f20SAlex Zinenko // Transform op registration
2305f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===//
2315f0d4f20SAlex Zinenko
namespace {
/// Transform dialect extension that registers every op listed in the generated
/// SCFTransformOps.cpp.inc op list.
class SCFTransformDialectExtension
    : public transform::TransformDialectExtension<
          SCFTransformDialectExtension> {
public:
  using Base::Base;

  /// Declares the dialects this extension relies on and registers its ops.
  void init() {
    // Dialect declared as a dependency of this extension.
    declareDependentDialect<pdl::PDLDialect>();

    // Dialects declared as generated by these transforms (e.g. `func`, for
    // the function created by loop outlining).
    declareGeneratedDialect<AffineDialect>();
    declareGeneratedDialect<func::FuncDialect>();

    // The op list is spliced in as template arguments by the generated file.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
        >();
  }
};
} // namespace
2525f0d4f20SAlex Zinenko
2535f0d4f20SAlex Zinenko #define GET_OP_CLASSES
2545f0d4f20SAlex Zinenko #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
2555f0d4f20SAlex Zinenko
/// Adds the SCF transform dialect extension to `registry`, so that it is
/// applied in contexts created from that registry.
void mlir::scf::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<SCFTransformDialectExtension>();
}
259