1 //===- Generalization.cpp - linalg named ops to generic ops  --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg generalization pass. It converts named
10 // Linalg ops to linalg.generic ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/Linalg/Passes.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/EDSC/Builders.h"
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/Support/Debug.h"
26 
27 #define DEBUG_TYPE "linalg-generalization"
28 
29 using namespace mlir;
30 using namespace mlir::linalg;
31 
32 // Creates a linalg.generic op from the given `namedOp`. Returns a null op if
33 // the given `namedOp` does not have a region builder.
34 static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp,
35                                             PatternRewriter &rewriter) {
36   SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps();
37   SmallVector<StringRef> iterators = llvm::to_vector<4>(
38       namedOp.iterator_types().getAsValueRange<StringAttr>());
39   SmallVector<RankedTensorType> resultTypes = namedOp.getOutputTensorTypes();
40   SmallVector<Type> types(resultTypes.begin(), resultTypes.end());
41 
42   // Inline the existing region if the named operation has a region attached.
43   if (namedOp->getNumRegions() == 1) {
44     GenericOp genericOp = rewriter.create<GenericOp>(
45         namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputs(),
46         indexingMaps, iterators);
47     rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(),
48                                 genericOp.region().begin());
49     return genericOp;
50   }
51 
52   // Otherwise use the region builder to generate a new region.
53   // TODO: Remove this path once all linag operations have a region attached.
54   auto regionBuilder = namedOp.getRegionBuilder();
55   if (!regionBuilder) {
56     LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n");
57     return nullptr;
58   }
59   return rewriter.create<GenericOp>(
60       namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputs(),
61       indexingMaps, iterators,
62       [&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
63         edsc::ScopedContext scope(bodyBuilder, loc);
64         regionBuilder(*bodyBuilder.getBlock(), /*captures=*/{});
65       });
66 }
67 
68 namespace {
69 
70 /// Base class for all linalg generalization patterns. A subclass must provide
71 /// the following method:
72 ///   GenericOp createGenericOp(RootOp, PatternRewriter &)
73 /// for creating the generic op.
74 // TODO: remove this pattern after migrating all manually-written named ops
75 // into auto-generated ones.
76 template <typename ConcretePattern, typename RootOp>
77 struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
78   LinalgGeneralizationPattern(MLIRContext *context,
79                               LinalgTransformationFilter marker,
80                               PatternBenefit benefit = 1)
81       : OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {}
82 
83   LogicalResult matchAndRewrite(RootOp rootOp,
84                                 PatternRewriter &rewriter) const override {
85     auto linalgOp = dyn_cast<LinalgOp>(rootOp.getOperation());
86     if (!linalgOp)
87       return failure();
88     if (failed(marker.checkAndNotify(rewriter, linalgOp)))
89       return failure();
90 
91     auto *pattern = static_cast<const ConcretePattern *>(this);
92     GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter);
93     if (!genericOp)
94       return failure();
95 
96     rewriter.replaceOp(rootOp, genericOp.getResults());
97     marker.replaceLinalgTransformationFilter(rewriter,
98                                              genericOp.getOperation());
99     return success();
100   }
101 
102 private:
103   LinalgTransformationFilter marker;
104 };
105 
106 struct GeneralizeConvOp
107     : public LinalgGeneralizationPattern<GeneralizeConvOp, ConvOp> {
108   using LinalgGeneralizationPattern::LinalgGeneralizationPattern;
109 
110   GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const;
111 };
112 
113 /// Catch-all pattern for converting all named ops with a region builder into
114 /// linalg.generic.
115 struct LinalgNamedOpGeneralizationPattern : RewritePattern {
116   LinalgNamedOpGeneralizationPattern(MLIRContext *context,
117                                      LinalgTransformationFilter marker,
118                                      PatternBenefit benefit = 1)
119       : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
120         marker(std::move(marker)) {}
121 
122   LogicalResult matchAndRewrite(Operation *rootOp,
123                                 PatternRewriter &rewriter) const override {
124     auto linalgOp = dyn_cast<LinalgOp>(rootOp);
125     if (!linalgOp)
126       return failure();
127     if (failed(marker.checkAndNotify(rewriter, linalgOp)))
128       return failure();
129 
130     // No nothing to do for linalg.generic and linalg.indexed_generic.
131     if (isa<GenericOp, IndexedGenericOp>(rootOp))
132       return failure();
133 
134     GenericOp genericOp = createGenericOpFromNamedOp(linalgOp, rewriter);
135     if (!genericOp)
136       return failure();
137 
138     rewriter.replaceOp(rootOp, genericOp.getResults());
139     marker.replaceLinalgTransformationFilter(rewriter,
140                                              genericOp.getOperation());
141     return success();
142   }
143 
144 private:
145   LinalgTransformationFilter marker;
146 };
147 
148 struct LinalgGeneralizationPass
149     : public LinalgGeneralizationBase<LinalgGeneralizationPass> {
150   void runOnFunction() override;
151 };
152 
153 } // namespace
154 
155 void LinalgGeneralizationPass::runOnFunction() {
156   FuncOp func = getFunction();
157   RewritePatternSet patterns(&getContext());
158   populateLinalgConvGeneralizationPatterns(patterns);
159   populateLinalgNamedOpsGeneralizationPatterns(patterns);
160   (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
161 }
162 
163 GenericOp GeneralizeConvOp::createGenericOp(ConvOp convOp,
164                                             OpBuilder &builder) const {
165   SmallVector<AffineMap, 4> indexingMaps = convOp.getIndexingMaps();
166   auto iterators =
167       llvm::to_vector<4>(convOp.iterator_types().getAsValueRange<StringAttr>());
168   return builder.create<GenericOp>(
169       convOp.getLoc(), /*resultTensorTypes=*/ArrayRef<Type>(),
170       convOp.getInputBuffers(), convOp.getOutputBuffers(), indexingMaps,
171       iterators,
172       [](OpBuilder &bodyBuilder, Location bodyLoc, ValueRange bodyArgs) {
173         Value mul =
174             bodyBuilder.create<MulFOp>(bodyLoc, bodyArgs[0], bodyArgs[1]);
175         Value add = bodyBuilder.create<AddFOp>(bodyLoc, mul, bodyArgs[2]);
176         bodyBuilder.create<YieldOp>(bodyLoc, add);
177       });
178 }
179 
180 void mlir::linalg::populateLinalgConvGeneralizationPatterns(
181     RewritePatternSet &patterns, LinalgTransformationFilter marker) {
182   patterns.add<GeneralizeConvOp>(patterns.getContext(), marker);
183 }
184 
185 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
186     RewritePatternSet &patterns, LinalgTransformationFilter marker) {
187   patterns.add<LinalgNamedOpGeneralizationPattern>(patterns.getContext(),
188                                                    marker);
189 }
190 
191 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
192   return std::make_unique<LinalgGeneralizationPass>();
193 }
194