1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic and helpers to expose Linalg transforms as rewrite
10 // patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/Utils/Utils.h"
18 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
19 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
20 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
21 #include "mlir/Dialect/Vector/VectorOps.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Support/LLVM.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include <type_traits>
30 
31 #define DEBUG_TYPE "linalg-transforms"
32 
33 using namespace mlir;
34 using namespace mlir::edsc;
35 using namespace mlir::edsc::intrinsics;
36 using namespace mlir::linalg;
37 
38 using llvm::dbgs;
39 
40 //===----------------------------------------------------------------------===//
41 // Transformations exposed as rewrite patterns.
42 //===----------------------------------------------------------------------===//
// Name of the string attribute used as a marker by the Linalg rewriting
// transformations in this file: patterns read it to decide whether they apply
// to an op, and set or remove it once a transformation has been applied.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";
46 
/// Builds a marker from the list of marker strings to match
/// (`matchDisjunction`, copied into the member of the same name) and an
/// optional `replacement` marker string, stored as-is.
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
                                         llvm::Optional<StringRef> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}
51 
52 mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
53                                          StringRef replacement)
54     : LinalgMarker(matchDisjunction, llvm::Optional<StringRef>{replacement}) {}
55 
56 LogicalResult
57 mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
58                                            Operation *op) const {
59   auto attr = op->template getAttrOfType<StringAttr>(
60       LinalgTransforms::kLinalgTransformMarker);
61 
62   if (!attr) {
63     // 1. Has no marker case and matchDisjunction is empty.
64     if (matchDisjunction.empty())
65       return success();
66 
67     // 2. Has no marker and matchDisjuntion matches the no-moarker case.
68     for (auto marker : matchDisjunction)
69       if (marker.empty())
70         return success();
71 
72     // 3. Has no marker but was expecting a marker.
73     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
74       diag << " does not have any marker from list: ";
75       llvm::interleaveComma(matchDisjunction, diag);
76     });
77   }
78 
79   // 4. Match explicit marker.
80   for (auto marker : matchDisjunction)
81     if (attr.getValue() == marker)
82       return success();
83 
84   // 5. Fail to match.
85   return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
86     diag << " does not have any marker from list: ";
87     llvm::interleaveComma(matchDisjunction, diag);
88   });
89 }
90 
91 void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter,
92                                                      Operation *op) const {
93   if (replacement.hasValue())
94     op->setAttr(LinalgTransforms::kLinalgTransformMarker,
95                 rewriter.getStringAttr(replacement.getValue()));
96   else
97     op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
98                                    rewriter.getContext()));
99 }
100 
101 LinalgTilingOptions &
102 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
103   SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
104   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
105     OpBuilder::InsertionGuard guard(b);
106     b.setInsertionPointToStart(
107         &op->getParentOfType<FuncOp>().getBody().front());
108     return llvm::to_vector<4>(llvm::map_range(tileSizes, [&](int64_t s) {
109       Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
110       return v;
111     }));
112   };
113   return *this;
114 };
115 
/// Linalg base tiling pattern.
///
/// Forwards `opName`, `benefit` and `context` to the RewritePattern base and
/// stores copies of `marker` (used to filter and re-tag ops) and `options`
/// (passed to the tiling transformation in matchAndRewrite).
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgMarker marker, PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker),
      options(options) {}
122 
123 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite(
124     Operation *op, PatternRewriter &rewriter) const {
125   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
126   if (!linalgOp)
127     return failure();
128   if (failed(marker.checkAndNotify(rewriter, linalgOp)))
129     return failure();
130   Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);
131 
132   if (!res)
133     return failure();
134 
135   // New marker if specified.
136   marker.replaceLinalgMarker(rewriter, res->op.getOperation());
137 
138   rewriter.eraseOp(op);
139   return success();
140 }
141 
/// Linalg base interchange pattern.
///
/// Forwards `opName`, `benefit` and `context` to the RewritePattern base,
/// stores a copy of `marker`, and copies the elements of `interchangeVector`
/// into the member of the same name.
mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
    StringRef opName, MLIRContext *context,
    ArrayRef<unsigned> interchangeVector, LinalgMarker marker,
    PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
149 
150 LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
151     Operation *op, PatternRewriter &rewriter) const {
152   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
153   if (!linalgOp)
154     return failure();
155   if (failed(marker.checkAndNotify(rewriter, linalgOp)))
156     return failure();
157   if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector)))
158     return failure();
159 
160   // TODO: figure out how this interplays with named ops. In particular this
161   // should break the named op property.
162   rewriter.updateRootInPlace(op, [&]() {
163     interchange(linalgOp, interchangeVector);
164     // New marker if specified.
165     marker.replaceLinalgMarker(rewriter, op);
166   });
167   return success();
168 }
169 
/// Linalg base promotion pattern.
///
/// Forwards `opName`, `benefit` and `context` to the RewritePattern base and
/// stores copies of `marker` and of the promotion `options` used by
/// matchAndRewrite.
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgMarker marker, PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker),
      options(options) {}
175 
176 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
177     Operation *op, PatternRewriter &rewriter) const {
178   if (failed(marker.checkAndNotify(rewriter, op)))
179     return failure();
180   if (failed(promoteSubviewsPrecondition(op, options)))
181     return failure();
182 
183   // TODO: We cannot use root update here. This pattern is creating other ops,
184   // so if the promotion fails, those need to be cleaned up, which doesnt seem
185   // to be happening here. So to fail properly, we should be cloning the op and
186   // deleting the previous op. This needs more investigation.
187   rewriter.startRootUpdate(op);
188   Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
189   if (!promotedOp) {
190     rewriter.cancelRootUpdate(op);
191     return op->emitError("subview promotion failed");
192   }
193   rewriter.finalizeRootUpdate(op);
194   marker.replaceLinalgMarker(rewriter, op);
195   return success();
196 }
197 
/// Linalg base vectorization pattern.
///
/// Forwards `opName`, `benefit` and `context` to the RewritePattern base and
/// stores a copy of `marker`, used by matchAndRewrite to filter ops.
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgMarker marker,
    PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker) {}
202 
203 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
204     Operation *op, PatternRewriter &rewriter) const {
205   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
206   if (!linalgOp)
207     return failure();
208   if (failed(marker.checkAndNotify(rewriter, linalgOp)))
209     return failure();
210   if (failed(vectorizeLinalgOpPrecondition(op)))
211     return failure();
212   vectorizeLinalgOp(rewriter, op);
213   rewriter.eraseOp(op);
214   return success();
215 }
216 
217 LogicalResult mlir::linalg::applyStagedPatterns(
218     Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
219     const OwningRewritePatternList &stage2Patterns,
220     llvm::function_ref<LogicalResult(Operation *)> stage3Lambda) {
221   for (const auto &patterns : stage1Patterns) {
222     if (!applyPatternsAndFoldGreedily(op, patterns)) {
223       llvm::dbgs() << "Underlying first stage rewrite did not converge";
224       return failure();
225     }
226     if (!applyPatternsAndFoldGreedily(op, stage2Patterns)) {
227       llvm::dbgs() << "Underlying second stage rewrite did not converge";
228       return failure();
229     }
230     if (stage3Lambda) {
231       if (failed(stage3Lambda(op)))
232         return failure();
233     }
234   }
235   return success();
236 }
237