1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic and helpers to expose Linalg transforms as rewrite
10 // patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/Utils/Utils.h"
18 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
19 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
20 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
21 #include "mlir/Dialect/Vector/VectorOps.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Support/LLVM.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include <type_traits>
30 
31 #define DEBUG_TYPE "linalg-transforms"
32 
33 using namespace mlir;
34 using namespace mlir::edsc;
35 using namespace mlir::edsc::intrinsics;
36 using namespace mlir::linalg;
37 
38 using llvm::dbgs;
39 
40 //===----------------------------------------------------------------------===//
41 // Transformations exposed as rewrite patterns.
42 //===----------------------------------------------------------------------===//
43 // Marker used as attribute name in generated Linalg rewriting transformations.
44 const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
45     "__internal_linalg_transform__";
46 
47 mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
48                                          llvm::Optional<StringRef> replacement)
49     : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
50       replacement(replacement) {}
51 
52 mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
53                                          StringRef replacement)
54     : LinalgMarker(matchDisjunction, llvm::Optional<StringRef>{replacement}) {}
55 
56 LogicalResult
57 mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
58                                            Operation *op) const {
59   auto attr = op->template getAttrOfType<StringAttr>(
60       LinalgTransforms::kLinalgTransformMarker);
61 
62   if (!attr) {
63     // 1. Has no marker case and matchDisjunction is empty.
64     if (matchDisjunction.empty())
65       return success();
66 
67     // 2. Has no marker and matchDisjuntion matches the no-moarker case.
68     for (auto marker : matchDisjunction)
69       if (marker.empty())
70         return success();
71 
72     // 3. Has no marker but was expecting a marker.
73     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
74       diag << " does not have any marker from list: ";
75       llvm::interleaveComma(matchDisjunction, diag);
76     });
77   }
78 
79   // 4. Match explicit marker.
80   for (auto marker : matchDisjunction)
81     if (attr.getValue() == marker)
82       return success();
83 
84   // 5. Fail to match.
85   return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
86     diag << " does not have any marker from list: ";
87     llvm::interleaveComma(matchDisjunction, diag);
88   });
89 }
90 
91 void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter,
92                                                      Operation *op) const {
93   if (replacement.hasValue())
94     op->setAttr(LinalgTransforms::kLinalgTransformMarker,
95                 rewriter.getStringAttr(replacement.getValue()));
96   else
97     op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
98                                    rewriter.getContext()));
99 }
100 
101 /// Linalg base tiling pattern.
102 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
103     StringRef opName, MLIRContext *context, LinalgTilingOptions options,
104     LinalgMarker marker, PatternBenefit benefit)
105     : RewritePattern(opName, {}, benefit, context), marker(marker),
106       options(options) {}
107 
108 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite(
109     Operation *op, PatternRewriter &rewriter) const {
110   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
111   if (!linalgOp)
112     return failure();
113   if (failed(marker.checkAndNotify(rewriter, linalgOp)))
114     return failure();
115   Optional<TiledLinalgOp> res;
116   if (options.loopType == LinalgTilingLoopType::Loops)
117     res = tileLinalgOp(rewriter, linalgOp, options.tileSizes,
118                        options.interchangeVector);
119   else if (options.loopType == LinalgTilingLoopType::ParallelLoops)
120     res = tileLinalgOpToParallelLoops(rewriter, linalgOp, options.tileSizes,
121                                       options.interchangeVector);
122   // TODO: Impl tiling to affine loops when it makes sense.
123 
124   if (!res)
125     return failure();
126 
127   // New marker if specified.
128   marker.replaceLinalgMarker(rewriter, res->op.getOperation());
129 
130   rewriter.eraseOp(op);
131   return success();
132 }
133 
134 /// Linalg base interchange pattern.
135 mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
136     StringRef opName, MLIRContext *context,
137     ArrayRef<unsigned> interchangeVector, LinalgMarker marker,
138     PatternBenefit benefit)
139     : RewritePattern(opName, {}, benefit, context), marker(marker),
140       interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
141 
142 LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
143     Operation *op, PatternRewriter &rewriter) const {
144   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
145   if (!linalgOp)
146     return failure();
147   if (failed(marker.checkAndNotify(rewriter, linalgOp)))
148     return failure();
149   if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector)))
150     return failure();
151 
152   // TODO: figure out how this interplays with named ops. In particular this
153   // should break the named op property.
154   rewriter.updateRootInPlace(op, [&]() {
155     interchange(linalgOp, interchangeVector);
156     // New marker if specified.
157     marker.replaceLinalgMarker(rewriter, op);
158   });
159   return success();
160 }
161 
162 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
163     StringRef opName, MLIRContext *context,
164     ArrayRef<unsigned> operandsToPromote, unsigned alignment,
165     LinalgMarker marker, PatternBenefit benefit)
166     : RewritePattern(opName, {}, benefit, context), marker(marker),
167       operandsToPromote(operandsToPromote.begin(), operandsToPromote.end()),
168       alignment(alignment) {}
169 
170 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
171     Operation *op, PatternRewriter &rewriter) const {
172   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
173   if (!linalgOp)
174     return failure();
175   if (failed(marker.checkAndNotify(rewriter, linalgOp)))
176     return failure();
177   if (operandsToPromote.empty()) {
178     if (failed(promoteSubviewsLinalgOpPrecondition(op, llvm::None)))
179       return failure();
180   } else {
181     DenseSet<unsigned> set;
182     set.insert(operandsToPromote.begin(), operandsToPromote.end());
183     if (failed(promoteSubviewsLinalgOpPrecondition(op, set)))
184       return failure();
185   }
186 
187   llvm::SetVector<Value> subViews;
188   if (!operandsToPromote.empty()) {
189     for (unsigned idx : operandsToPromote) {
190       auto *op = linalgOp.getBuffer(idx).getDefiningOp();
191       if (auto sv = dyn_cast_or_null<SubViewOp>(op))
192         subViews.insert(sv);
193     }
194   } else {
195     unsigned nBuffers = linalgOp.getNumInputsAndOutputBuffers();
196     for (unsigned idx = 0; idx < nBuffers; ++idx) {
197       auto *op = linalgOp.getBuffer(idx).getDefiningOp();
198       if (auto sv = dyn_cast_or_null<SubViewOp>(op))
199         subViews.insert(sv);
200     }
201   }
202 
203   auto promotedOp =
204       promoteSubViewOperands(rewriter, op, subViews, /*dynamicBuffers=*/false,
205                              /*alignment=*/alignment);
206   marker.replaceLinalgMarker(rewriter, promotedOp.getOperation());
207   rewriter.eraseOp(op);
208   return success();
209 }
210 
211 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
212     StringRef opName, MLIRContext *context, LinalgMarker marker,
213     PatternBenefit benefit)
214     : RewritePattern(opName, {}, benefit, context), marker(marker) {}
215 
216 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
217     Operation *op, PatternRewriter &rewriter) const {
218   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
219   if (!linalgOp)
220     return failure();
221   if (failed(marker.checkAndNotify(rewriter, linalgOp)))
222     return failure();
223   if (failed(vectorizeLinalgOpPrecondition(op)))
224     return failure();
225   vectorizeLinalgOp(rewriter, op);
226   rewriter.eraseOp(op);
227   return success();
228 }
229