//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Vectorization transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using llvm::dbgs;

#define DEBUG_TYPE "linalg-vectorization"

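/// Detects whether the single-block region `r` implements a scalar
/// multiply-accumulate, i.e. a body of the following shape, up to
/// commutativity of the `mulf` operands and of the `addf` operands
/// (shown schematically):
///   ^bb0(%a: f32, %b: f32, %c: f32):
///     %0 = mulf %a, %b : f32
///     %1 = addf %0, %c : f32
///     linalg.yield %1 : f32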
static bool hasMultiplyAddBody(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3))
    return false;

  using mlir::matchers::m_Val;
  auto a = m_Val(r.front().getArgument(0));
  auto b = m_Val(r.front().getArgument(1));
  auto c = m_Val(r.front().getArgument(2));
  // TODO: Update this detection once we have matcher support for specifying
  // that any permutation of operands matches.
  auto pattern1 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
  auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
  auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
  auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
  return pattern1.match(&r.front().back()) ||
         pattern2.match(&r.front().back()) ||
         pattern3.match(&r.front().back()) || pattern4.match(&r.front().back());
}

// TODO: Should be Tablegen'd from a single source that generates the op itself.
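/// Detects whether `op` is a contraction that can be rewritten as a
/// vector.contract: either one of the named contraction ops below, or a
/// linalg.generic with two inputs, one output, projected-permutation indexing
/// maps and a multiply-add body.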
static LogicalResult isContraction(Operation *op) {
  // TODO: interface for named ops.
  if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatvecOp,
          linalg::DotOp>(op))
    return success();

  auto genericOp = dyn_cast<linalg::GenericOp>(op);
  if (!genericOp)
    return failure();

  auto mapRange =
      genericOp.indexing_maps().getAsRange<AffineMapAttr, AffineMap>();

  return success(
      genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
      llvm::all_of(mapRange,
                   [](AffineMap m) { return m.isProjectedPermutation(); }) &&
      hasMultiplyAddBody(genericOp.region()));
}

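/// A linalg op can be vectorized by vectorizeLinalgOp if all of its shaped
/// operands have static shape and the op is either a linalg.fill or a
/// contraction.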
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
  auto linalgOp = cast<linalg::LinalgOp>(op);
  // All shaped operands must have a static shape to be turned into vectors.
  for (Value operand : linalgOp.getInputsAndOutputBuffers())
    if (!operand.getType().cast<ShapedType>().hasStaticShape())
      return failure();
  for (Type outputTensorType : linalgOp.getOutputTensorTypes())
    if (!outputTensorType.cast<ShapedType>().hasStaticShape())
      return failure();

  if (isa<linalg::FillOp>(op))
    return success();

  return isContraction(op);
}

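/// Vectorizes a linalg op that satisfies vectorizeLinalgOpPrecondition:
///   - linalg.fill is rewritten as a vector.broadcast stored through a
///     vector.type_cast of the output buffer;
///   - contractions are rewritten as a vector.transfer_read of each view,
///     followed by a vector.contract and a vector.transfer_write of the
///     result.
/// For a matmul-like contraction this produces, schematically:
///   %a = vector.transfer_read %A[%c0, %c0] ... : vector<MxKxf32>
///   %b = vector.transfer_read %B[%c0, %c0] ... : vector<KxNxf32>
///   %c = vector.transfer_read %C[%c0, %c0] ... : vector<MxNxf32>
///   %r = vector.contract ... %a, %b, %c ...
///   vector.transfer_write %r, %C[%c0, %c0] ...
/// 0-D views are special-cased to scalar load/store.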
void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
  assert(succeeded(vectorizeLinalgOpPrecondition(op)));

  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
  (void)dbgPref;
  edsc::ScopedContext scope(builder, op->getLoc());
  if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
    // Vectorize fill as a vector.broadcast.
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg.fill as vector.broadcast: " << *op);
    Value memref = vector_type_cast(fillOp.getOutputBuffer(0));
    Value dst = std_load(memref);
    Value res = vector_broadcast(dst.getType(), fillOp.value());
    std_store(res, memref);
    return;
  }

  assert(succeeded(isContraction(op)) && "Expected contraction");

  // Vectorize other ops as vector contraction.
  // TODO: interface.
  LLVM_DEBUG(dbgs() << dbgPref
                    << "Rewrite linalg op as vector.contract: " << *op);
  // In the case of 0-D memrefs, return null and special case to scalar load or
  // store later.
  auto extractVectorTypeFromScalarView = [](Value v) {
    MemRefType mt = v.getType().cast<MemRefType>();
    return mt.getShape().empty()
               ? VectorType()
               : VectorType::get(mt.getShape(), mt.getElementType());
  };
  auto linalgOp = cast<linalg::LinalgOp>(op);
  Value viewA = linalgOp.getInput(0);
  Value viewB = linalgOp.getInput(1);
  Value viewC = linalgOp.getOutputBuffer(0);
  VectorType vtA = extractVectorTypeFromScalarView(viewA);
  VectorType vtB = extractVectorTypeFromScalarView(viewB);
  VectorType vtC = extractVectorTypeFromScalarView(viewC);
  Value zero = std_constant_index(0);
  SmallVector<Value, 4> indicesA, indicesB, indicesC;
  if (vtA)
    indicesA = SmallVector<Value, 4>(vtA.getRank(), zero);
  if (vtB)
    indicesB = SmallVector<Value, 4>(vtB.getRank(), zero);
  if (vtC)
    indicesC = SmallVector<Value, 4>(vtC.getRank(), zero);
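  // Read each full view as a vector starting at the origin; 0-D views fall
  // back to scalar loads/stores.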
  Value a = vtA ? vector_transfer_read(vtA, viewA, indicesA).value
                : std_load(viewA, indicesA).value;
  Value b = vtB ? vector_transfer_read(vtB, viewB, indicesB).value
                : std_load(viewB, indicesB).value;
  Value c = vtC ? vector_transfer_read(vtC, viewC, indicesC).value
                : std_load(viewC, indicesC).value;
  Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
                              linalgOp.iterator_types());
  if (vtC)
    vector_transfer_write(res, viewC, indicesC);
  else
    std_store(res, viewC, indicesC);
}

/// Check whether there is any interleaved use of any `values` between `firstOp`
/// and `secondOp`. Conservatively return `true` if any op or value is in a
/// different block.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
  (void)dbgPref;
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LLVM_DEBUG(llvm::dbgs()
               << dbgPref << "interleavedUses precondition failed, firstOp: "
               << *firstOp << ", second op: " << *secondOp);
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
        continue;
      LLVM_DEBUG(llvm::dbgs()
                 << dbgPref << " found interleaved op " << *owner
                 << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
      return true;
    }
  }
  return false;
}

/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
static SubViewOp getSubViewUseIfUnique(Value v) {
  SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
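/// Forwards a vector.transfer_read on a padded temporary buffer to the
/// original source view. Schematically, this rewrites
///   %sv = subview %buf[...]
///   linalg.fill(%buf, %pad)
///   linalg.copy(%in, %sv)
///   %v = vector.transfer_read %buf[...], %pad
/// into
///   %v = vector.transfer_read %in[...], %pad
/// when there are no interleaved uses of the buffers and, if a fill is
/// present, its value matches the transfer padding. The `masked` attribute is
/// reset conservatively since it is only known to be valid on the padded
/// buffer.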
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {

  // Transfer read from `viewOrAlloc`, which must come from a ViewOp or an
  // AllocOp.
  Value viewOrAlloc = xferOp.memref();
  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
      !viewOrAlloc.getDefiningOp<AllocOp>())
    return failure();

  StringRef dbgPref = "\n[" DEBUG_TYPE "]: VTRForwarding: ";
  (void)dbgPref;
  LLVM_DEBUG(llvm::dbgs() << dbgPref << viewOrAlloc);

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();
  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with subView " << subView);

  // Find the copy into `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getOutputBuffer(0) != subView)
        continue;
      LLVM_DEBUG(llvm::dbgs() << dbgPref << "copy candidate " << *newCopyOp);
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();
  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with copy " << *copyOp);

  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
        continue;
      LLVM_DEBUG(llvm::dbgs() << dbgPref << "fill candidate " << *newFillOp);
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure the transfer padding matches the fill value.
  if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
    return failure();
  if (maybeFillOp)
    LLVM_DEBUG(llvm::dbgs() << dbgPref << "with maybeFillOp " << *maybeFillOp);

  // `in` is the subview that linalg.copy reads from. Forward the
  // vector.transfer_read to it.
  Value in = copyOp.getInput(0);

  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_read, the attribute must be reset
  // conservatively.
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
      xferOp.permutation_map(), xferOp.padding(), ArrayAttr());

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
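/// Forwards a vector.transfer_write through a linalg.copy out of a local
/// buffer. Schematically, this rewrites
///   %sv = subview %buf[...]
///   vector.transfer_write %v, %buf[...]
///   linalg.copy(%sv, %out)
/// into
///   vector.transfer_write %v, %out[...]
/// when there are no interleaved uses of the buffers. The `masked` attribute
/// is reset conservatively since it is only known to be valid on the local
/// buffer.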
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // Transfer write into `viewOrAlloc`, which must come from a ViewOp or an
  // AllocOp.
  Value viewOrAlloc = xferOp.memref();
  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
      !viewOrAlloc.getDefiningOp<AllocOp>())
    return failure();

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getInput(0) != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();

  // `out` is the buffer that linalg.copy writes into; forward the
  // vector.transfer_write to it.
  Value out = copyOp.getOutputBuffer(0);

  // Forward vector.transfer into copy.
  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_write, the attribute must be reset
  // conservatively.
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
      xferOp.permutation_map(), ArrayAttr());

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}