1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic and helpers to expose Linalg transforms as rewrite
10 // patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/Utils/Utils.h"
18 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
19 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
20 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
21 #include "mlir/Dialect/Vector/VectorOps.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Support/LLVM.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include <type_traits>
30 
31 #define DEBUG_TYPE "linalg-transforms"
32 
33 using namespace mlir;
34 using namespace mlir::edsc;
35 using namespace mlir::edsc::intrinsics;
36 using namespace mlir::linalg;
37 
38 using llvm::dbgs;
39 
40 #define DEBUG_TYPE "linalg-transforms"
41 
42 //===----------------------------------------------------------------------===//
43 // Transformations exposed as rewrite patterns.
44 //===----------------------------------------------------------------------===//
45 // Marker used as attribute name in generated Linalg rewriting transformations.
46 const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
47     "__internal_linalg_transform__";
48 
49 mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<Identifier> matchDisjunction,
50                                          Optional<Identifier> replacement)
51     : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
52       replacement(replacement) {}
53 
54 LogicalResult
55 mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
56                                            Operation *op) const {
57   auto attr = op->template getAttrOfType<StringAttr>(
58       LinalgTransforms::kLinalgTransformMarker);
59 
60   if (!attr) {
61     // 1. Has no marker case and matchDisjunction is empty.
62     if (matchDisjunction.empty())
63       return success();
64 
65     // 2. Has no marker but was expecting a marker.
66     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
67       diag << " does not have any marker from list: ";
68       interleaveComma(matchDisjunction, diag);
69     });
70   }
71 
72   // 4. Match explicit marker.
73   for (auto marker : matchDisjunction)
74     if (attr.getValue() == marker)
75       return success();
76 
77   // 5. Fail to match.
78   return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
79     diag << " does not have any marker from list: ";
80     interleaveComma(matchDisjunction, diag);
81   });
82 }
83 
84 void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter,
85                                                      Operation *op) const {
86   if (replacement.hasValue())
87     op->setAttr(LinalgTransforms::kLinalgTransformMarker,
88                 rewriter.getStringAttr(replacement.getValue()));
89   else
90     op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
91                                    rewriter.getContext()));
92 }
93 
94 LinalgTilingOptions &
95 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
96   SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
97   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
98     OpBuilder::InsertionGuard guard(b);
99     b.setInsertionPointToStart(
100         &op->getParentOfType<FuncOp>().getBody().front());
101     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
102       Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
103       return v;
104     }));
105   };
106   return *this;
107 };
108 
109 /// Linalg base tiling pattern.
110 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
111     StringRef opName, MLIRContext *context, LinalgTilingOptions options,
112     LinalgMarker marker, PatternBenefit benefit)
113     : RewritePattern(opName, {}, benefit, context), marker(marker),
114       options(options) {}
115 
116 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite(
117     Operation *op, PatternRewriter &rewriter) const {
118   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
119   if (!linalgOp)
120     return failure();
121   if (failed(marker.checkAndNotify(rewriter, linalgOp)))
122     return failure();
123   Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);
124 
125   if (!res)
126     return failure();
127 
128   // New marker if specified.
129   marker.replaceLinalgMarker(rewriter, res->op.getOperation());
130 
131   rewriter.eraseOp(op);
132   return success();
133 }
134 
135 /// Linalg base interchange pattern.
136 mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
137     StringRef opName, MLIRContext *context,
138     ArrayRef<unsigned> interchangeVector, LinalgMarker marker,
139     PatternBenefit benefit)
140     : RewritePattern(opName, {}, benefit, context), marker(marker),
141       interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
142 
143 LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
144     Operation *op, PatternRewriter &rewriter) const {
145   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
146   if (!linalgOp)
147     return failure();
148   if (failed(marker.checkAndNotify(rewriter, linalgOp)))
149     return failure();
150   if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector)))
151     return failure();
152 
153   // TODO: figure out how this interplays with named ops. In particular this
154   // should break the named op property.
155   rewriter.updateRootInPlace(op, [&]() {
156     interchange(linalgOp, interchangeVector);
157     // New marker if specified.
158     marker.replaceLinalgMarker(rewriter, op);
159   });
160   return success();
161 }
162 
163 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
164     StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
165     LinalgMarker marker, PatternBenefit benefit)
166     : RewritePattern(opName, {}, benefit, context), marker(marker),
167       options(options) {}
168 
169 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
170     Operation *op, PatternRewriter &rewriter) const {
171   if (failed(marker.checkAndNotify(rewriter, op)))
172     return failure();
173   if (failed(promoteSubviewsPrecondition(op, options)))
174     return failure();
175 
176   // TODO: We cannot use root update here. This pattern is creating other ops,
177   // so if the promotion fails, those need to be cleaned up, which doesnt seem
178   // to be happening here. So to fail properly, we should be cloning the op and
179   // deleting the previous op. This needs more investigation.
180   rewriter.startRootUpdate(op);
181   Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
182   if (!promotedOp) {
183     rewriter.cancelRootUpdate(op);
184     return op->emitError("subview promotion failed");
185   }
186   rewriter.finalizeRootUpdate(op);
187   marker.replaceLinalgMarker(rewriter, op);
188   return success();
189 }
190 
191 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
192     StringRef opName, MLIRContext *context, LinalgMarker marker,
193     PatternBenefit benefit)
194     : RewritePattern(opName, {}, benefit, context), marker(marker) {}
195 
196 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
197     Operation *op, PatternRewriter &rewriter) const {
198   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
199   if (!linalgOp)
200     return failure();
201   if (failed(marker.checkAndNotify(rewriter, linalgOp)))
202     return failure();
203   if (failed(vectorizeLinalgOpPrecondition(op)))
204     return failure();
205   vectorizeLinalgOp(rewriter, op);
206   rewriter.eraseOp(op);
207   return success();
208 }
209 
210 LogicalResult mlir::linalg::applyStagedPatterns(
211     Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
212     const OwningRewritePatternList &stage2Patterns,
213     function_ref<LogicalResult(Operation *)> stage3Lambda) {
214   unsigned iteration = 0;
215   (void)iteration;
216   StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
217   (void)dbgPref;
218   for (const auto &patterns : stage1Patterns) {
219     if (!applyPatternsAndFoldGreedily(op, patterns)) {
220       dbgs() << "Underlying first stage rewrite did not converge";
221       return failure();
222     }
223     LLVM_DEBUG(dbgs()
224                << dbgPref << "After 1st stage, iter: " << ++iteration << "\n"
225                << *op);
226     if (!applyPatternsAndFoldGreedily(op, stage2Patterns)) {
227       LLVM_DEBUG(dbgs()
228                  << dbgPref << "Underlying 2nd stage rewrite did not converge");
229       return failure();
230     }
231     LLVM_DEBUG(dbgs()
232                << dbgPref << "After 2nd stage, iter : " << iteration << "\n"
233                << *op);
234     if (stage3Lambda) {
235       if (failed(stage3Lambda(op)))
236         return failure();
237       LLVM_DEBUG(dbgs()
238                  << dbgPref << "After 3rd stage, iter : " << iteration << "\n"
239                  << *op);
240     }
241   }
242   return success();
243 }
244