1884a6291SStephan Herhut //===- InlineScalarOperands.cpp - Pass to inline scalar operands =============//
2884a6291SStephan Herhut //
3884a6291SStephan Herhut // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4884a6291SStephan Herhut // See https://llvm.org/LICENSE.txt for license information.
5884a6291SStephan Herhut // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6884a6291SStephan Herhut //
7884a6291SStephan Herhut //===----------------------------------------------------------------------===//
8884a6291SStephan Herhut //
// This file implements patterns and a pass that inline scalar operands into a
// linalg.generic operation. A scalar operand is a tensor input whose indexing
// map has only constant results, so every iteration reads the same element.
12884a6291SStephan Herhut //
13884a6291SStephan Herhut //===----------------------------------------------------------------------===//
14884a6291SStephan Herhut
15884a6291SStephan Herhut #include "PassDetail.h"
16a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
18884a6291SStephan Herhut #include "mlir/Dialect/Linalg/Passes.h"
19884a6291SStephan Herhut #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20884a6291SStephan Herhut #include "mlir/IR/AffineExpr.h"
21884a6291SStephan Herhut #include "mlir/IR/AffineMap.h"
22884a6291SStephan Herhut #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23884a6291SStephan Herhut
24884a6291SStephan Herhut using namespace mlir;
25884a6291SStephan Herhut using namespace mlir::linalg;
26884a6291SStephan Herhut
27884a6291SStephan Herhut namespace {
28884a6291SStephan Herhut struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
29884a6291SStephan Herhut using OpRewritePattern<GenericOp>::OpRewritePattern;
matchAndRewrite__anonbcb454d30111::InlineScalarOperands30884a6291SStephan Herhut LogicalResult matchAndRewrite(GenericOp genericOp,
31884a6291SStephan Herhut PatternRewriter &rewriter) const override {
32884a6291SStephan Herhut if (!genericOp.hasTensorSemantics())
33884a6291SStephan Herhut return failure();
34884a6291SStephan Herhut
35884a6291SStephan Herhut SmallVector<size_t> scalarOperands;
36884a6291SStephan Herhut SmallVector<AffineMap> newIndexingMaps;
37884a6291SStephan Herhut SmallVector<Value> newOperands;
38f44e90b9STobias Gysi for (OpOperand *opOperand : genericOp.getInputOperands()) {
39f44e90b9STobias Gysi AffineMap map = genericOp.getTiedIndexingMap(opOperand);
40f44e90b9STobias Gysi if (genericOp.isInputTensor(opOperand) && map.isConstant()) {
41f44e90b9STobias Gysi scalarOperands.emplace_back(opOperand->getOperandNumber());
42884a6291SStephan Herhut } else {
43884a6291SStephan Herhut newIndexingMaps.emplace_back(map);
44f44e90b9STobias Gysi newOperands.emplace_back(opOperand->get());
45884a6291SStephan Herhut }
46884a6291SStephan Herhut }
47884a6291SStephan Herhut
48884a6291SStephan Herhut if (scalarOperands.empty())
49884a6291SStephan Herhut return failure();
50884a6291SStephan Herhut
51f44e90b9STobias Gysi for (OpOperand *opOperand : genericOp.getOutputOperands())
52f44e90b9STobias Gysi newIndexingMaps.emplace_back(genericOp.getTiedIndexingMap(opOperand));
53884a6291SStephan Herhut
54884a6291SStephan Herhut Location loc = genericOp->getLoc();
55f44e90b9STobias Gysi SmallVector<Value> outputOperands = genericOp.getOutputOperands();
56884a6291SStephan Herhut auto newOp = rewriter.create<GenericOp>(
57f44e90b9STobias Gysi loc, genericOp->getResultTypes(), newOperands, outputOperands,
58f44e90b9STobias Gysi newIndexingMaps,
59884a6291SStephan Herhut llvm::to_vector<4>(
60884a6291SStephan Herhut genericOp.iterator_types().template getAsValueRange<StringAttr>()));
61884a6291SStephan Herhut rewriter.cloneRegionBefore(genericOp.region(), newOp.region(),
62884a6291SStephan Herhut newOp.region().begin());
63884a6291SStephan Herhut
64884a6291SStephan Herhut Block *body = newOp.getBody();
65884a6291SStephan Herhut PatternRewriter::InsertionGuard guard(rewriter);
66884a6291SStephan Herhut rewriter.setInsertionPointToStart(body);
67884a6291SStephan Herhut
68884a6291SStephan Herhut for (auto idx : llvm::reverse(scalarOperands)) {
69f44e90b9STobias Gysi OpOperand *opOperand = genericOp.getInputOperand(idx);
70f44e90b9STobias Gysi AffineMap map = genericOp.getTiedIndexingMap(opOperand);
71884a6291SStephan Herhut SmallVector<int64_t> indices = map.getConstantResults();
72884a6291SStephan Herhut SmallVector<Value> indicesValues;
73884a6291SStephan Herhut for (auto idx : indices)
74a54f4eaeSMogball indicesValues.emplace_back(
75a54f4eaeSMogball rewriter.create<arith::ConstantIndexOp>(loc, idx));
76f44e90b9STobias Gysi Value extractedValue = rewriter.create<tensor::ExtractOp>(
77f44e90b9STobias Gysi loc, opOperand->get(), indicesValues);
78f44e90b9STobias Gysi body->getArgument(idx).replaceAllUsesWith(extractedValue);
79884a6291SStephan Herhut body->eraseArgument(idx);
80884a6291SStephan Herhut }
81884a6291SStephan Herhut
82884a6291SStephan Herhut rewriter.replaceOp(genericOp, newOp->getResults());
83884a6291SStephan Herhut return success();
84884a6291SStephan Herhut }
85884a6291SStephan Herhut };
86884a6291SStephan Herhut } // namespace
87884a6291SStephan Herhut
88884a6291SStephan Herhut /// Patterns that are used to inline constant operands into linalg generic
89884a6291SStephan Herhut /// ops.
populateInlineConstantOperandsPatterns(RewritePatternSet & patterns)90884a6291SStephan Herhut void mlir::linalg::populateInlineConstantOperandsPatterns(
91884a6291SStephan Herhut RewritePatternSet &patterns) {
92884a6291SStephan Herhut auto *context = patterns.getContext();
93884a6291SStephan Herhut patterns.add<InlineScalarOperands>(context);
94884a6291SStephan Herhut }
95884a6291SStephan Herhut
96884a6291SStephan Herhut namespace {
97884a6291SStephan Herhut /// Pass that removes unit-extent dims within generic ops.
98884a6291SStephan Herhut struct LinalgInlineScalarOperandsPass
99884a6291SStephan Herhut : public LinalgInlineScalarOperandsBase<LinalgInlineScalarOperandsPass> {
runOnOperation__anonbcb454d30211::LinalgInlineScalarOperandsPass10041574554SRiver Riddle void runOnOperation() override {
101*58ceae95SRiver Riddle func::FuncOp funcOp = getOperation();
102884a6291SStephan Herhut MLIRContext *context = funcOp.getContext();
103884a6291SStephan Herhut RewritePatternSet patterns(context);
104884a6291SStephan Herhut
105884a6291SStephan Herhut populateInlineConstantOperandsPatterns(patterns);
106884a6291SStephan Herhut (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
107884a6291SStephan Herhut }
108884a6291SStephan Herhut };
109884a6291SStephan Herhut } // namespace
110884a6291SStephan Herhut
111*58ceae95SRiver Riddle std::unique_ptr<OperationPass<func::FuncOp>>
createLinalgInlineScalarOperandsPass()112884a6291SStephan Herhut mlir::createLinalgInlineScalarOperandsPass() {
113884a6291SStephan Herhut return std::make_unique<LinalgInlineScalarOperandsPass>();
114884a6291SStephan Herhut }
115