1 //===- Generalization.cpp - linalg named ops to generic ops  --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg generalization pass. It converts named
10 // Linalg ops to linalg.generic ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/Linalg/Passes.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Attributes.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/ImplicitLocOpBuilder.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/Support/Debug.h"
26 
27 #define DEBUG_TYPE "linalg-generalization"
28 
29 using namespace mlir;
30 using namespace mlir::linalg;
31 
32 // Creates a linalg.generic op from the given `namedOp`. Returns a null op if
33 // the given `namedOp` does not have a region builder.
34 static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp,
35                                             PatternRewriter &rewriter) {
36   SmallVector<Value> inputOperands = namedOp.getInputOperands();
37   SmallVector<Value> outputOperands = namedOp.getOutputOperands();
38   SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps();
39   SmallVector<StringRef> iterators = llvm::to_vector<4>(
40       namedOp.iterator_types().getAsValueRange<StringAttr>());
41   SmallVector<RankedTensorType> resultTypes = namedOp.getOutputTensorTypes();
42   SmallVector<Type> types(resultTypes.begin(), resultTypes.end());
43 
44   // Inline the existing region if the named operation has a region attached.
45   if (namedOp->getNumRegions() == 1) {
46     GenericOp genericOp =
47         rewriter.create<GenericOp>(namedOp.getLoc(), types, inputOperands,
48                                    outputOperands, indexingMaps, iterators);
49     rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(),
50                                 genericOp.region().begin());
51     return genericOp;
52   }
53 
54   // Otherwise use the region builder to generate a new region.
55   // TODO: Remove this path once all linag operations have a region attached.
56   auto regionBuilder = namedOp.getRegionBuilder();
57   if (!regionBuilder) {
58     LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n");
59     return nullptr;
60   }
61   return rewriter.create<GenericOp>(
62       namedOp.getLoc(), types, inputOperands, outputOperands, indexingMaps,
63       iterators,
64       [&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
65         ImplicitLocOpBuilder b(loc, bodyBuilder);
66         regionBuilder(b, *bodyBuilder.getBlock(),
67                       /*captures=*/{});
68       });
69 }
70 
71 namespace {
72 
73 /// Base class for all linalg generalization patterns. A subclass must provide
74 /// the following method:
75 ///   GenericOp createGenericOp(RootOp, PatternRewriter &)
76 /// for creating the generic op.
77 // TODO: remove this pattern after migrating all manually-written named ops
78 // into auto-generated ones.
79 template <typename ConcretePattern, typename RootOp>
80 struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
81   LinalgGeneralizationPattern(MLIRContext *context,
82                               LinalgTransformationFilter marker,
83                               PatternBenefit benefit = 1)
84       : OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {}
85 
86   LogicalResult matchAndRewrite(RootOp rootOp,
87                                 PatternRewriter &rewriter) const override {
88     auto linalgOp = dyn_cast<LinalgOp>(rootOp.getOperation());
89     if (!linalgOp)
90       return failure();
91     if (failed(marker.checkAndNotify(rewriter, linalgOp)))
92       return failure();
93 
94     auto *pattern = static_cast<const ConcretePattern *>(this);
95     GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter);
96     if (!genericOp)
97       return failure();
98 
99     rewriter.replaceOp(rootOp, genericOp.getResults());
100     marker.replaceLinalgTransformationFilter(rewriter,
101                                              genericOp.getOperation());
102     return success();
103   }
104 
105 private:
106   LinalgTransformationFilter marker;
107 };
108 
109 struct GeneralizeConvOp
110     : public LinalgGeneralizationPattern<GeneralizeConvOp, ConvOp> {
111   using LinalgGeneralizationPattern::LinalgGeneralizationPattern;
112 
113   GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const;
114 };
115 
116 /// Catch-all pattern for converting all named ops with a region builder into
117 /// linalg.generic.
118 struct LinalgNamedOpGeneralizationPattern : RewritePattern {
119   LinalgNamedOpGeneralizationPattern(MLIRContext *context,
120                                      LinalgTransformationFilter marker,
121                                      PatternBenefit benefit = 1)
122       : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
123         marker(std::move(marker)) {}
124 
125   LogicalResult matchAndRewrite(Operation *rootOp,
126                                 PatternRewriter &rewriter) const override {
127     auto linalgOp = dyn_cast<LinalgOp>(rootOp);
128     if (!linalgOp)
129       return failure();
130     if (failed(marker.checkAndNotify(rewriter, linalgOp)))
131       return failure();
132 
133     // No nothing to do for linalg.generic and linalg.indexed_generic.
134     if (isa<GenericOp, IndexedGenericOp>(rootOp))
135       return failure();
136 
137     GenericOp genericOp = createGenericOpFromNamedOp(linalgOp, rewriter);
138     if (!genericOp)
139       return failure();
140 
141     rewriter.replaceOp(rootOp, genericOp.getResults());
142     marker.replaceLinalgTransformationFilter(rewriter,
143                                              genericOp.getOperation());
144     return success();
145   }
146 
147 private:
148   LinalgTransformationFilter marker;
149 };
150 
151 struct LinalgGeneralizationPass
152     : public LinalgGeneralizationBase<LinalgGeneralizationPass> {
153   void runOnFunction() override;
154 };
155 
156 } // namespace
157 
158 void LinalgGeneralizationPass::runOnFunction() {
159   FuncOp func = getFunction();
160   RewritePatternSet patterns(&getContext());
161   populateLinalgConvGeneralizationPatterns(patterns);
162   populateLinalgNamedOpsGeneralizationPatterns(patterns);
163   (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
164 }
165 
166 GenericOp GeneralizeConvOp::createGenericOp(ConvOp convOp,
167                                             OpBuilder &builder) const {
168   SmallVector<AffineMap> indexingMaps = convOp.getIndexingMaps();
169   auto iterators =
170       llvm::to_vector<4>(convOp.iterator_types().getAsValueRange<StringAttr>());
171   SmallVector<Value> inputBuffers = convOp.getInputBufferOperands();
172   SmallVector<Value> outputBuffers = convOp.getOutputBufferOperands();
173   return builder.create<GenericOp>(
174       convOp.getLoc(), /*resultTensorTypes=*/ArrayRef<Type>(), inputBuffers,
175       outputBuffers, indexingMaps, iterators,
176       [](OpBuilder &bodyBuilder, Location bodyLoc, ValueRange bodyArgs) {
177         Value mul =
178             bodyBuilder.create<MulFOp>(bodyLoc, bodyArgs[0], bodyArgs[1]);
179         Value add = bodyBuilder.create<AddFOp>(bodyLoc, mul, bodyArgs[2]);
180         bodyBuilder.create<YieldOp>(bodyLoc, add);
181       });
182 }
183 
184 void mlir::linalg::populateLinalgConvGeneralizationPatterns(
185     RewritePatternSet &patterns, LinalgTransformationFilter marker) {
186   patterns.add<GeneralizeConvOp>(patterns.getContext(), marker);
187 }
188 
189 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
190     RewritePatternSet &patterns, LinalgTransformationFilter marker) {
191   patterns.add<LinalgNamedOpGeneralizationPattern>(patterns.getContext(),
192                                                    marker);
193 }
194 
195 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
196   return std::make_unique<LinalgGeneralizationPass>();
197 }
198