1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic and helpers to expose Linalg transforms as rewrite
10 // patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/Utils/Utils.h"
18 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
19 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
20 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
21 #include "mlir/Dialect/Vector/VectorOps.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Support/LLVM.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include <type_traits>
30 
31 #define DEBUG_TYPE "linalg-transforms"
32 
33 using namespace mlir;
34 using namespace mlir::edsc;
35 using namespace mlir::edsc::intrinsics;
36 using namespace mlir::linalg;
37 
38 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
39 //===----------------------------------------------------------------------===//
40 // Transformations exposed as rewrite patterns.
41 //===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
// The attribute's StringAttr value is what LinalgMarker::checkAndNotify compares
// against its match disjunction, and what LinalgMarker::replaceLinalgMarker
// overwrites (or removes) after a successful rewrite.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";
45 
46 mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<Identifier> matchDisjunction,
47                                          Optional<Identifier> replacement)
48     : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
49       replacement(replacement) {}
50 
51 LogicalResult
52 mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
53                                            Operation *op) const {
54   auto attr = op->template getAttrOfType<StringAttr>(
55       LinalgTransforms::kLinalgTransformMarker);
56 
57   if (!attr) {
58     // 1. Has no marker case and matchDisjunction is empty.
59     if (matchDisjunction.empty())
60       return success();
61 
62     // 2. Has no marker but was expecting a marker.
63     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
64       diag << " does not have any marker from list: ";
65       interleaveComma(matchDisjunction, diag);
66     });
67   }
68 
69   // 4. Match explicit marker.
70   for (auto marker : matchDisjunction)
71     if (attr.getValue() == marker)
72       return success();
73 
74   // 5. Fail to match.
75   return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
76     diag << " does not have any marker from list: ";
77     interleaveComma(matchDisjunction, diag);
78   });
79 }
80 
81 void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter,
82                                                      Operation *op) const {
83   if (replacement.hasValue())
84     op->setAttr(LinalgTransforms::kLinalgTransformMarker,
85                 rewriter.getStringAttr(replacement.getValue()));
86   else
87     op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
88                                    rewriter.getContext()));
89 }
90 
91 LinalgTilingOptions &
92 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
93   SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
94   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
95     OpBuilder::InsertionGuard guard(b);
96     b.setInsertionPointToStart(
97         &op->getParentOfType<FuncOp>().getBody().front());
98     return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
99       Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
100       return v;
101     }));
102   };
103   return *this;
104 }
105 
106 /// Linalg base tiling pattern.
107 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
108     StringRef opName, MLIRContext *context, LinalgTilingOptions options,
109     LinalgMarker marker, PatternBenefit benefit)
110     : RewritePattern(opName, {}, benefit, context), marker(marker),
111       options(options) {}
112 
113 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite(
114     Operation *op, PatternRewriter &rewriter) const {
115   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
116   if (!linalgOp)
117     return failure();
118   if (failed(marker.checkAndNotify(rewriter, linalgOp)))
119     return failure();
120 
121   Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);
122 
123   if (!res)
124     return failure();
125 
126   // New marker if specified.
127   marker.replaceLinalgMarker(rewriter, res->op.getOperation());
128 
129   rewriter.eraseOp(op);
130   return success();
131 }
132 
133 /// Linalg base interchange pattern.
134 mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
135     StringRef opName, MLIRContext *context,
136     ArrayRef<unsigned> interchangeVector, LinalgMarker marker,
137     PatternBenefit benefit)
138     : RewritePattern(opName, {}, benefit, context), marker(marker),
139       interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
140 
141 LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
142     Operation *op, PatternRewriter &rewriter) const {
143   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
144   if (!linalgOp)
145     return failure();
146   if (failed(marker.checkAndNotify(rewriter, linalgOp)))
147     return failure();
148   if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector)))
149     return failure();
150 
151   // TODO: figure out how this interplays with named ops. In particular this
152   // should break the named op property.
153   rewriter.updateRootInPlace(op, [&]() {
154     interchange(linalgOp, interchangeVector);
155     // New marker if specified.
156     marker.replaceLinalgMarker(rewriter, op);
157   });
158   return success();
159 }
160 
161 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
162     StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
163     LinalgMarker marker, PatternBenefit benefit)
164     : RewritePattern(opName, {}, benefit, context), marker(marker),
165       options(options) {}
166 
167 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
168     Operation *op, PatternRewriter &rewriter) const {
169   if (failed(marker.checkAndNotify(rewriter, op)))
170     return failure();
171   if (failed(promoteSubviewsPrecondition(op, options)))
172     return failure();
173 
174   // TODO: We cannot use root update here. This pattern is creating other ops,
175   // so if the promotion fails, those need to be cleaned up, which doesnt seem
176   // to be happening here. So to fail properly, we should be cloning the op and
177   // deleting the previous op. This needs more investigation.
178   rewriter.startRootUpdate(op);
179   Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
180   if (!promotedOp) {
181     rewriter.cancelRootUpdate(op);
182     return op->emitError("subview promotion failed");
183   }
184   rewriter.finalizeRootUpdate(op);
185   marker.replaceLinalgMarker(rewriter, op);
186   return success();
187 }
188 
189 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
190     StringRef opName, MLIRContext *context, LinalgMarker marker,
191     PatternBenefit benefit)
192     : RewritePattern(opName, {}, benefit, context), marker(marker) {}
193 
194 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
195     Operation *op, PatternRewriter &rewriter) const {
196   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
197   if (!linalgOp)
198     return failure();
199   if (failed(marker.checkAndNotify(rewriter, linalgOp)))
200     return failure();
201   if (failed(vectorizeLinalgOpPrecondition(op)))
202     return failure();
203   vectorizeLinalgOp(rewriter, op);
204   rewriter.eraseOp(op);
205   return success();
206 }
207 
208 LogicalResult mlir::linalg::applyStagedPatterns(
209     Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
210     const OwningRewritePatternList &stage2Patterns,
211     function_ref<LogicalResult(Operation *)> stage3Lambda) {
212   unsigned iteration = 0;
213   (void)iteration;
214   for (const auto &patterns : stage1Patterns) {
215     LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
216                       << *op);
217     if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
218       LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
219       return failure();
220     }
221     LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
222                       << *op);
223     if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
224       LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
225       return failure();
226     }
227     LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n"
228                       << *op);
229     if (stage3Lambda) {
230       if (failed(stage3Lambda(op)))
231         return failure();
232       LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n"
233                         << *op);
234     }
235   }
236   return success();
237 }
238