1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic and helpers to expose Linalg transforms as rewrite
10 // patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/Utils/Utils.h"
18 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
19 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
20 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
21 #include "mlir/Dialect/Vector/VectorOps.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Support/LLVM.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include <type_traits>
30 
31 #define DEBUG_TYPE "linalg-transforms"
32 
33 using namespace mlir;
34 using namespace mlir::edsc;
35 using namespace mlir::edsc::intrinsics;
36 using namespace mlir::linalg;
37 
38 using llvm::dbgs;
39 
40 #define DEBUG_TYPE "linalg-transforms"
41 
42 //===----------------------------------------------------------------------===//
43 // Transformations exposed as rewrite patterns.
44 //===----------------------------------------------------------------------===//
45 // Marker used as attribute name in generated Linalg rewriting transformations.
46 const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
47     "__internal_linalg_transform__";
48 
49 mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
50                                          Optional<StringRef> replacement)
51     : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
52       replacement(replacement) {}
53 
54 mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
55                                          StringRef replacement)
56     : LinalgMarker(matchDisjunction, Optional<StringRef>{replacement}) {}
57 
58 LogicalResult
59 mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
60                                            Operation *op) const {
61   auto attr = op->template getAttrOfType<StringAttr>(
62       LinalgTransforms::kLinalgTransformMarker);
63 
64   if (!attr) {
65     // 1. Has no marker case and matchDisjunction is empty.
66     if (matchDisjunction.empty())
67       return success();
68 
69     // 2. Has no marker and matchDisjuntion matches the no-moarker case.
70     for (auto marker : matchDisjunction)
71       if (marker.empty())
72         return success();
73 
74     // 3. Has no marker but was expecting a marker.
75     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
76       diag << " does not have any marker from list: ";
77       interleaveComma(matchDisjunction, diag);
78     });
79   }
80 
81   // 4. Match explicit marker.
82   for (auto marker : matchDisjunction)
83     if (attr.getValue() == marker)
84       return success();
85 
86   // 5. Fail to match.
87   return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
88     diag << " does not have any marker from list: ";
89     interleaveComma(matchDisjunction, diag);
90   });
91 }
92 
93 void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter,
94                                                      Operation *op) const {
95   if (replacement.hasValue())
96     op->setAttr(LinalgTransforms::kLinalgTransformMarker,
97                 rewriter.getStringAttr(replacement.getValue()));
98   else
99     op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
100                                    rewriter.getContext()));
101 }
102 
103 LinalgTilingOptions &
104 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
105   SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
106   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
107     OpBuilder::InsertionGuard guard(b);
108     b.setInsertionPointToStart(
109         &op->getParentOfType<FuncOp>().getBody().front());
110     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
111       Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
112       return v;
113     }));
114   };
115   return *this;
116 };
117 
118 /// Linalg base tiling pattern.
119 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
120     StringRef opName, MLIRContext *context, LinalgTilingOptions options,
121     LinalgMarker marker, PatternBenefit benefit)
122     : RewritePattern(opName, {}, benefit, context), marker(marker),
123       options(options) {}
124 
125 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite(
126     Operation *op, PatternRewriter &rewriter) const {
127   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
128   if (!linalgOp)
129     return failure();
130   if (failed(marker.checkAndNotify(rewriter, linalgOp)))
131     return failure();
132   Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);
133 
134   if (!res)
135     return failure();
136 
137   // New marker if specified.
138   marker.replaceLinalgMarker(rewriter, res->op.getOperation());
139 
140   rewriter.eraseOp(op);
141   return success();
142 }
143 
144 /// Linalg base interchange pattern.
145 mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
146     StringRef opName, MLIRContext *context,
147     ArrayRef<unsigned> interchangeVector, LinalgMarker marker,
148     PatternBenefit benefit)
149     : RewritePattern(opName, {}, benefit, context), marker(marker),
150       interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
151 
152 LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
153     Operation *op, PatternRewriter &rewriter) const {
154   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
155   if (!linalgOp)
156     return failure();
157   if (failed(marker.checkAndNotify(rewriter, linalgOp)))
158     return failure();
159   if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector)))
160     return failure();
161 
162   // TODO: figure out how this interplays with named ops. In particular this
163   // should break the named op property.
164   rewriter.updateRootInPlace(op, [&]() {
165     interchange(linalgOp, interchangeVector);
166     // New marker if specified.
167     marker.replaceLinalgMarker(rewriter, op);
168   });
169   return success();
170 }
171 
172 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
173     StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
174     LinalgMarker marker, PatternBenefit benefit)
175     : RewritePattern(opName, {}, benefit, context), marker(marker),
176       options(options) {}
177 
178 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
179     Operation *op, PatternRewriter &rewriter) const {
180   if (failed(marker.checkAndNotify(rewriter, op)))
181     return failure();
182   if (failed(promoteSubviewsPrecondition(op, options)))
183     return failure();
184 
185   // TODO: We cannot use root update here. This pattern is creating other ops,
186   // so if the promotion fails, those need to be cleaned up, which doesnt seem
187   // to be happening here. So to fail properly, we should be cloning the op and
188   // deleting the previous op. This needs more investigation.
189   rewriter.startRootUpdate(op);
190   Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
191   if (!promotedOp) {
192     rewriter.cancelRootUpdate(op);
193     return op->emitError("subview promotion failed");
194   }
195   rewriter.finalizeRootUpdate(op);
196   marker.replaceLinalgMarker(rewriter, op);
197   return success();
198 }
199 
200 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
201     StringRef opName, MLIRContext *context, LinalgMarker marker,
202     PatternBenefit benefit)
203     : RewritePattern(opName, {}, benefit, context), marker(marker) {}
204 
205 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
206     Operation *op, PatternRewriter &rewriter) const {
207   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
208   if (!linalgOp)
209     return failure();
210   if (failed(marker.checkAndNotify(rewriter, linalgOp)))
211     return failure();
212   if (failed(vectorizeLinalgOpPrecondition(op)))
213     return failure();
214   vectorizeLinalgOp(rewriter, op);
215   rewriter.eraseOp(op);
216   return success();
217 }
218 
219 LogicalResult mlir::linalg::applyStagedPatterns(
220     Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
221     const OwningRewritePatternList &stage2Patterns,
222     function_ref<LogicalResult(Operation *)> stage3Lambda) {
223   unsigned iteration = 0;
224   (void)iteration;
225   StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
226   (void)dbgPref;
227   for (const auto &patterns : stage1Patterns) {
228     if (!applyPatternsAndFoldGreedily(op, patterns)) {
229       dbgs() << "Underlying first stage rewrite did not converge";
230       return failure();
231     }
232     LLVM_DEBUG(dbgs()
233                << dbgPref << "After 1st stage, iter: " << ++iteration << "\n"
234                << *op);
235     if (!applyPatternsAndFoldGreedily(op, stage2Patterns)) {
236       LLVM_DEBUG(dbgs()
237                  << dbgPref << "Underlying 2nd stage rewrite did not converge");
238       return failure();
239     }
240     LLVM_DEBUG(dbgs()
241                << dbgPref << "After 2nd stage, iter : " << iteration << "\n"
242                << *op);
243     if (stage3Lambda) {
244       if (failed(stage3Lambda(op)))
245         return failure();
246       LLVM_DEBUG(dbgs()
247                  << dbgPref << "After 3rd stage, iter : " << iteration << "\n"
248                  << *op);
249     }
250   }
251   return success();
252 }
253