1 //===- Generalization.cpp - linalg named ops to generic ops  --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg generalization pass. It converts named
10 // Linalg ops to linalg.generic ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/Linalg/Passes.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/EDSC/Builders.h"
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/Support/Debug.h"
26 
27 #define DEBUG_TYPE "linalg-generalization"
28 
29 using namespace mlir;
30 
31 // Creates a linalg.generic op from the given `namedOp`. Returns a null op if
32 // the given `namedOp` does not have a region builder.
33 static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp,
34                                                     OpBuilder &builder) {
35   auto regionBuilder = namedOp.getRegionBuilder();
36   if (!regionBuilder) {
37     LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n");
38     return nullptr;
39   }
40 
41   SmallVector<AffineMap, 4> indexingMaps = namedOp.getIndexingMaps();
42   auto iterators = llvm::to_vector<4>(
43       namedOp.iterator_types().getAsValueRange<StringAttr>());
44   auto resultTypes = namedOp.getOutputTensorTypes();
45   SmallVector<Type, 4> types(resultTypes.begin(), resultTypes.end());
46 
47   return builder.create<linalg::GenericOp>(
48       namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputs(),
49       indexingMaps, iterators,
50       [&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
51         edsc::ScopedContext scope(bodyBuilder, loc);
52         regionBuilder(*bodyBuilder.getBlock(), /*captures=*/{});
53       });
54 }
55 
56 namespace {
57 
58 /// Base class for all linalg generalization patterns. A subclass must provide
59 /// the following method:
60 ///   linalg::GenericOp createGenericOp(RootOp, PatternRewriter &)
61 /// for creating the generic op.
62 // TODO: remove this pattern after migrating all manually-written named ops
63 // into auto-generated ones.
64 template <typename ConcretePattern, typename RootOp>
65 struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
66   LinalgGeneralizationPattern(MLIRContext *context,
67                               linalg::LinalgTransformationFilter marker,
68                               PatternBenefit benefit = 1)
69       : OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {}
70 
71   LogicalResult matchAndRewrite(RootOp rootOp,
72                                 PatternRewriter &rewriter) const override {
73     auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp.getOperation());
74     if (!linalgOp)
75       return failure();
76     if (failed(marker.checkAndNotify(rewriter, linalgOp)))
77       return failure();
78 
79     auto *pattern = static_cast<const ConcretePattern *>(this);
80     linalg::GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter);
81     if (!genericOp)
82       return failure();
83 
84     rewriter.replaceOp(rootOp, genericOp.getResults());
85     marker.replaceLinalgTransformationFilter(rewriter,
86                                              genericOp.getOperation());
87     return success();
88   }
89 
90 private:
91   linalg::LinalgTransformationFilter marker;
92 };
93 
94 struct GeneralizeConvOp
95     : public LinalgGeneralizationPattern<GeneralizeConvOp, linalg::ConvOp> {
96   using LinalgGeneralizationPattern::LinalgGeneralizationPattern;
97 
98   linalg::GenericOp createGenericOp(linalg::ConvOp, OpBuilder &rewriter) const;
99 };
100 
101 /// Catch-all pattern for converting all named ops with a region builder into
102 /// linalg.generic.
103 struct LinalgNamedOpGeneralizationPattern : RewritePattern {
104   LinalgNamedOpGeneralizationPattern(MLIRContext *context,
105                                      linalg::LinalgTransformationFilter marker,
106                                      PatternBenefit benefit = 1)
107       : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
108         marker(std::move(marker)) {}
109 
110   LogicalResult matchAndRewrite(Operation *rootOp,
111                                 PatternRewriter &rewriter) const override {
112     auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp);
113     if (!linalgOp)
114       return failure();
115     if (failed(marker.checkAndNotify(rewriter, linalgOp)))
116       return failure();
117 
118     // No nothing to do for linalg.generic and linalg.indexed_generic.
119     if (isa<linalg::GenericOp, linalg::IndexedGenericOp>(rootOp))
120       return failure();
121 
122     linalg::GenericOp genericOp =
123         createGenericOpFromNamedOp(linalgOp, rewriter);
124     if (!genericOp)
125       return failure();
126 
127     rewriter.replaceOp(rootOp, genericOp.getResults());
128     marker.replaceLinalgTransformationFilter(rewriter,
129                                              genericOp.getOperation());
130     return success();
131   }
132 
133 private:
134   linalg::LinalgTransformationFilter marker;
135 };
136 
137 struct LinalgGeneralizationPass
138     : public LinalgGeneralizationBase<LinalgGeneralizationPass> {
139   void runOnFunction() override;
140 };
141 
142 } // namespace
143 
144 void LinalgGeneralizationPass::runOnFunction() {
145   FuncOp func = getFunction();
146   RewritePatternSet patterns(&getContext());
147   linalg::populateLinalgConvGeneralizationPatterns(patterns);
148   linalg::populateLinalgNamedOpsGeneralizationPatterns(patterns);
149   (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
150 }
151 
152 linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp,
153                                                     OpBuilder &builder) const {
154   SmallVector<AffineMap, 4> indexingMaps = convOp.getIndexingMaps();
155   auto iterators =
156       llvm::to_vector<4>(convOp.iterator_types().getAsValueRange<StringAttr>());
157   return builder.create<linalg::GenericOp>(
158       convOp.getLoc(), /*resultTensorTypes=*/ArrayRef<Type>(),
159       convOp.getInputBuffers(), convOp.getOutputBuffers(), indexingMaps,
160       iterators,
161       [](OpBuilder &bodyBuilder, Location bodyLoc, ValueRange bodyArgs) {
162         Value mul =
163             bodyBuilder.create<MulFOp>(bodyLoc, bodyArgs[0], bodyArgs[1]);
164         Value add = bodyBuilder.create<AddFOp>(bodyLoc, mul, bodyArgs[2]);
165         bodyBuilder.create<linalg::YieldOp>(bodyLoc, add);
166       });
167 }
168 
169 void mlir::linalg::populateLinalgConvGeneralizationPatterns(
170     RewritePatternSet &patterns, linalg::LinalgTransformationFilter marker) {
171   patterns.add<GeneralizeConvOp>(patterns.getContext(), marker);
172 }
173 
174 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
175     RewritePatternSet &patterns, linalg::LinalgTransformationFilter marker) {
176   patterns.add<LinalgNamedOpGeneralizationPattern>(patterns.getContext(),
177                                                    marker);
178 }
179 
180 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
181   return std::make_unique<LinalgGeneralizationPass>();
182 }
183