//===- InlineScalarOperands.cpp - Pass to inline scalar operands =============//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/pass to inline scalar operands into a generic
// operation. A scalar operand is an input operand whose indexing map has only
// constant results, i.e. every iteration reads the same element.
//
//===----------------------------------------------------------------------===//
14884a6291SStephan Herhut 
15884a6291SStephan Herhut #include "PassDetail.h"
16a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
18884a6291SStephan Herhut #include "mlir/Dialect/Linalg/Passes.h"
19884a6291SStephan Herhut #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20884a6291SStephan Herhut #include "mlir/IR/AffineExpr.h"
21884a6291SStephan Herhut #include "mlir/IR/AffineMap.h"
22884a6291SStephan Herhut #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23884a6291SStephan Herhut 
24884a6291SStephan Herhut using namespace mlir;
25884a6291SStephan Herhut using namespace mlir::linalg;
26884a6291SStephan Herhut 
27884a6291SStephan Herhut namespace {
28884a6291SStephan Herhut struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
29884a6291SStephan Herhut   using OpRewritePattern<GenericOp>::OpRewritePattern;
matchAndRewrite__anonbcb454d30111::InlineScalarOperands30884a6291SStephan Herhut   LogicalResult matchAndRewrite(GenericOp genericOp,
31884a6291SStephan Herhut                                 PatternRewriter &rewriter) const override {
32884a6291SStephan Herhut     if (!genericOp.hasTensorSemantics())
33884a6291SStephan Herhut       return failure();
34884a6291SStephan Herhut 
35884a6291SStephan Herhut     SmallVector<size_t> scalarOperands;
36884a6291SStephan Herhut     SmallVector<AffineMap> newIndexingMaps;
37884a6291SStephan Herhut     SmallVector<Value> newOperands;
38f44e90b9STobias Gysi     for (OpOperand *opOperand : genericOp.getInputOperands()) {
39f44e90b9STobias Gysi       AffineMap map = genericOp.getTiedIndexingMap(opOperand);
40f44e90b9STobias Gysi       if (genericOp.isInputTensor(opOperand) && map.isConstant()) {
41f44e90b9STobias Gysi         scalarOperands.emplace_back(opOperand->getOperandNumber());
42884a6291SStephan Herhut       } else {
43884a6291SStephan Herhut         newIndexingMaps.emplace_back(map);
44f44e90b9STobias Gysi         newOperands.emplace_back(opOperand->get());
45884a6291SStephan Herhut       }
46884a6291SStephan Herhut     }
47884a6291SStephan Herhut 
48884a6291SStephan Herhut     if (scalarOperands.empty())
49884a6291SStephan Herhut       return failure();
50884a6291SStephan Herhut 
51f44e90b9STobias Gysi     for (OpOperand *opOperand : genericOp.getOutputOperands())
52f44e90b9STobias Gysi       newIndexingMaps.emplace_back(genericOp.getTiedIndexingMap(opOperand));
53884a6291SStephan Herhut 
54884a6291SStephan Herhut     Location loc = genericOp->getLoc();
55f44e90b9STobias Gysi     SmallVector<Value> outputOperands = genericOp.getOutputOperands();
56884a6291SStephan Herhut     auto newOp = rewriter.create<GenericOp>(
57f44e90b9STobias Gysi         loc, genericOp->getResultTypes(), newOperands, outputOperands,
58f44e90b9STobias Gysi         newIndexingMaps,
59884a6291SStephan Herhut         llvm::to_vector<4>(
60884a6291SStephan Herhut             genericOp.iterator_types().template getAsValueRange<StringAttr>()));
61884a6291SStephan Herhut     rewriter.cloneRegionBefore(genericOp.region(), newOp.region(),
62884a6291SStephan Herhut                                newOp.region().begin());
63884a6291SStephan Herhut 
64884a6291SStephan Herhut     Block *body = newOp.getBody();
65884a6291SStephan Herhut     PatternRewriter::InsertionGuard guard(rewriter);
66884a6291SStephan Herhut     rewriter.setInsertionPointToStart(body);
67884a6291SStephan Herhut 
68884a6291SStephan Herhut     for (auto idx : llvm::reverse(scalarOperands)) {
69f44e90b9STobias Gysi       OpOperand *opOperand = genericOp.getInputOperand(idx);
70f44e90b9STobias Gysi       AffineMap map = genericOp.getTiedIndexingMap(opOperand);
71884a6291SStephan Herhut       SmallVector<int64_t> indices = map.getConstantResults();
72884a6291SStephan Herhut       SmallVector<Value> indicesValues;
73884a6291SStephan Herhut       for (auto idx : indices)
74a54f4eaeSMogball         indicesValues.emplace_back(
75a54f4eaeSMogball             rewriter.create<arith::ConstantIndexOp>(loc, idx));
76f44e90b9STobias Gysi       Value extractedValue = rewriter.create<tensor::ExtractOp>(
77f44e90b9STobias Gysi           loc, opOperand->get(), indicesValues);
78f44e90b9STobias Gysi       body->getArgument(idx).replaceAllUsesWith(extractedValue);
79884a6291SStephan Herhut       body->eraseArgument(idx);
80884a6291SStephan Herhut     }
81884a6291SStephan Herhut 
82884a6291SStephan Herhut     rewriter.replaceOp(genericOp, newOp->getResults());
83884a6291SStephan Herhut     return success();
84884a6291SStephan Herhut   }
85884a6291SStephan Herhut };
86884a6291SStephan Herhut } // namespace
87884a6291SStephan Herhut 
88884a6291SStephan Herhut /// Patterns that are used to inline constant operands into linalg generic
89884a6291SStephan Herhut /// ops.
populateInlineConstantOperandsPatterns(RewritePatternSet & patterns)90884a6291SStephan Herhut void mlir::linalg::populateInlineConstantOperandsPatterns(
91884a6291SStephan Herhut     RewritePatternSet &patterns) {
92884a6291SStephan Herhut   auto *context = patterns.getContext();
93884a6291SStephan Herhut   patterns.add<InlineScalarOperands>(context);
94884a6291SStephan Herhut }
95884a6291SStephan Herhut 
96884a6291SStephan Herhut namespace {
97884a6291SStephan Herhut /// Pass that removes unit-extent dims within generic ops.
98884a6291SStephan Herhut struct LinalgInlineScalarOperandsPass
99884a6291SStephan Herhut     : public LinalgInlineScalarOperandsBase<LinalgInlineScalarOperandsPass> {
runOnOperation__anonbcb454d30211::LinalgInlineScalarOperandsPass10041574554SRiver Riddle   void runOnOperation() override {
101*58ceae95SRiver Riddle     func::FuncOp funcOp = getOperation();
102884a6291SStephan Herhut     MLIRContext *context = funcOp.getContext();
103884a6291SStephan Herhut     RewritePatternSet patterns(context);
104884a6291SStephan Herhut 
105884a6291SStephan Herhut     populateInlineConstantOperandsPatterns(patterns);
106884a6291SStephan Herhut     (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
107884a6291SStephan Herhut   }
108884a6291SStephan Herhut };
109884a6291SStephan Herhut } // namespace
110884a6291SStephan Herhut 
111*58ceae95SRiver Riddle std::unique_ptr<OperationPass<func::FuncOp>>
createLinalgInlineScalarOperandsPass()112884a6291SStephan Herhut mlir::createLinalgInlineScalarOperandsPass() {
113884a6291SStephan Herhut   return std::make_unique<LinalgInlineScalarOperandsPass>();
114884a6291SStephan Herhut }
115