1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16 #include "mlir/Dialect/Linalg/Utils/Utils.h"
17 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
18 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
19 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Support/LLVM.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include <type_traits>
29 
30 using namespace mlir;
31 using namespace mlir::edsc;
32 using namespace mlir::edsc::intrinsics;
33 using namespace mlir::linalg;
34 
35 using llvm::dbgs;
36 
37 #define DEBUG_TYPE "linalg-vectorization"
38 
39 static bool hasMultiplyAddBody(linalg::GenericOp op) {
40   auto &r = op.region();
41   if (!llvm::hasSingleElement(r))
42     return false;
43   if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3))
44     return false;
45 
46   using mlir::matchers::m_Val;
47   auto a = m_Val(r.front().getArgument(0));
48   auto b = m_Val(r.front().getArgument(1));
49   auto c = m_Val(r.front().getArgument(2));
50   // TODO: Update this detection once we have  matcher support for specifying
51   // that any permutation of operands matches.
52   auto pattern1 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
53   auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
54   auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
55   auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
56   return pattern1.match(&r.front().back()) ||
57          pattern2.match(&r.front().back()) ||
58          pattern3.match(&r.front().back()) || pattern4.match(&r.front().back());
59 }
60 
61 // TODO: Should be Tablegen'd from a single source that generates the op itself.
62 static bool isRowMajorMatmul(linalg::GenericOp genericOp) {
63   return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
64          isRowMajorMatmul(genericOp.indexing_maps()) &&
65          hasMultiplyAddBody(genericOp);
66 }
67 
68 // TODO: This is in fact much more general than just vectorization for matmul
69 // and fill ops.
70 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
71   auto linalgOp = cast<linalg::LinalgOp>(op);
72   // All types must be static shape to go to vector.
73   for (Value operand : linalgOp.getInputsAndOutputBuffers())
74     if (!operand.getType().cast<ShapedType>().hasStaticShape())
75       return failure();
76   for (Type outputTensorType : linalgOp.getOutputTensorTypes())
77     if (!outputTensorType.cast<ShapedType>().hasStaticShape())
78       return failure();
79   if (isa<linalg::MatmulOp, linalg::FillOp>(op))
80     return success();
81 
82   auto genericOp = dyn_cast<linalg::GenericOp>(op);
83   if (!genericOp || !::isRowMajorMatmul(genericOp))
84     return failure();
85 
86   // TODO(ntv): non-identity layout.
87   auto isStaticMemRefWithIdentityLayout = [](Value v) {
88     auto m = v.getType().dyn_cast<MemRefType>();
89     if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty())
90       return false;
91     return true;
92   };
93   return success(llvm::all_of(genericOp.getInputsAndOutputBuffers(),
94                               isStaticMemRefWithIdentityLayout));
95 }
96 
97 void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
98   assert(succeeded(vectorizeLinalgOpPrecondition(op)));
99 
100   if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
101     // TODO: add a level of indirection to linalg.generic.
102     if (convOp.padding())
103       llvm_unreachable("Unexpected conv with padding");
104   }
105 
106   StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
107   (void)dbgPref;
108   edsc::ScopedContext scope(builder, op->getLoc());
109   if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
110     // Vectorize fill as a vector.broadcast.
111     LLVM_DEBUG(dbgs() << dbgPref
112                       << "Rewrite linalg.fill as vector.broadcast: " << *op);
113     Value memref = vector_type_cast(fillOp.getOutputBuffer(0));
114     Value dst = std_load(memref);
115     Value res = vector_broadcast(dst.getType(), fillOp.value());
116     std_store(res, memref);
117     return;
118   }
119 
120   // Vectorize other ops as vector contraction (currently only matmul).
121   LLVM_DEBUG(dbgs() << dbgPref
122                     << "Rewrite linalg op as vector.contract: " << *op);
123   auto extractVectorTypeFromScalarView = [](Value v) {
124     MemRefType mt = v.getType().cast<MemRefType>();
125     return VectorType::get(mt.getShape(), mt.getElementType());
126   };
127   auto linalgOp = cast<linalg::LinalgOp>(op);
128   Value viewA = linalgOp.getInput(0);
129   Value viewB = linalgOp.getInput(1);
130   Value viewC = linalgOp.getOutputBuffer(0);
131   Value zero = std_constant_index(0);
132   SmallVector<Value, 4> indicesA(linalgOp.getInputShapedType(0).getRank(),
133                                  zero);
134   SmallVector<Value, 4> indicesB(linalgOp.getInputShapedType(1).getRank(),
135                                  zero);
136   SmallVector<Value, 4> indicesC(linalgOp.getOutputShapedType(0).getRank(),
137                                  zero);
138   Value a = vector_transfer_read(extractVectorTypeFromScalarView(viewA), viewA,
139                                  indicesA);
140   Value b = vector_transfer_read(extractVectorTypeFromScalarView(viewB), viewB,
141                                  indicesB);
142   Value c = vector_transfer_read(extractVectorTypeFromScalarView(viewC), viewC,
143                                  indicesC);
144   Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
145                               linalgOp.iterator_types());
146   vector_transfer_write(res, viewC, indicesC);
147 }
148 
149 /// Check whether there is any interleaved use of any `values` between `firstOp`
150 /// and `secondOp`. Conservatively return `true` if any op or value is in a
151 /// different block.
152 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
153                                     ValueRange values) {
154   StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
155   (void)dbgPref;
156   if (firstOp->getBlock() != secondOp->getBlock() ||
157       !firstOp->isBeforeInBlock(secondOp)) {
158     LLVM_DEBUG(llvm::dbgs()
159                << dbgPref << "interleavedUses precondition failed, firstOp: "
160                << *firstOp << ", second op: " << *secondOp);
161     return true;
162   }
163   for (auto v : values) {
164     for (auto &u : v.getUses()) {
165       Operation *owner = u.getOwner();
166       if (owner == firstOp || owner == secondOp)
167         continue;
168       // TODO: this is too conservative, use dominance info in the future.
169       if (owner->getBlock() == firstOp->getBlock() &&
170           (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
171         continue;
172       LLVM_DEBUG(llvm::dbgs()
173                  << dbgPref << " found interleaved op " << *owner
174                  << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
175       return true;
176     }
177   }
178   return false;
179 }
180 
181 /// Return the unique subview use of `v` if it is indeed unique, null otherwise.
182 static SubViewOp getSubViewUseIfUnique(Value v) {
183   SubViewOp subViewOp;
184   for (auto &u : v.getUses()) {
185     if (auto newSubViewOp = dyn_cast<SubViewOp>(u.getOwner())) {
186       if (subViewOp)
187         return SubViewOp();
188       subViewOp = newSubViewOp;
189     }
190   }
191   return subViewOp;
192 }
193 
194 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
195 /// when available.
196 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
197     vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
198 
199   // Transfer into `view`.
200   Value viewOrAlloc = xferOp.memref();
201   if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
202       !viewOrAlloc.getDefiningOp<AllocOp>())
203     return failure();
204 
205   StringRef dbgPref = "\n[" DEBUG_TYPE "]: VTRForwarding: ";
206   (void)dbgPref;
207   LLVM_DEBUG(llvm::dbgs() << dbgPref << viewOrAlloc);
208 
209   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
210   SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
211   if (!subViewOp)
212     return failure();
213   Value subView = subViewOp.getResult();
214   LLVM_DEBUG(llvm::dbgs() << dbgPref << "with subView " << subView);
215 
216   // Find the copy into `subView` without interleaved uses.
217   CopyOp copyOp;
218   for (auto &u : subView.getUses()) {
219     if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
220       if (newCopyOp.getOutputBuffer(0) != subView)
221         continue;
222       LLVM_DEBUG(llvm::dbgs() << dbgPref << "copy candidate " << *newCopyOp);
223       if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
224         continue;
225       copyOp = newCopyOp;
226       break;
227     }
228   }
229   if (!copyOp)
230     return failure();
231   LLVM_DEBUG(llvm::dbgs() << dbgPref << "with copy " << *copyOp);
232 
233   // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
234   FillOp maybeFillOp;
235   for (auto &u : viewOrAlloc.getUses()) {
236     if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
237       if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
238         continue;
239       LLVM_DEBUG(llvm::dbgs() << dbgPref << "fill candidate " << *newFillOp);
240       if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
241         continue;
242       maybeFillOp = newFillOp;
243       break;
244     }
245   }
246   // Ensure padding matches.
247   if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
248     return failure();
249   if (maybeFillOp)
250     LLVM_DEBUG(llvm::dbgs() << dbgPref << "with maybeFillOp " << *maybeFillOp);
251 
252   // `in` is the subview that linalg.copy reads. Replace it.
253   Value in = copyOp.getInput(0);
254 
255   Value res = rewriter.create<vector::TransferReadOp>(
256       xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
257       xferOp.permutation_map(), xferOp.padding(),
258       xferOp.masked() ? *xferOp.masked() : ArrayAttr());
259 
260   if (maybeFillOp)
261     rewriter.eraseOp(maybeFillOp);
262   rewriter.eraseOp(copyOp);
263   rewriter.replaceOp(xferOp, res);
264 
265   return success();
266 }
267 
268 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
269 /// when available.
270 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
271     vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
272   // Transfer into `viewOrAlloc`.
273   Value viewOrAlloc = xferOp.memref();
274   if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
275       !viewOrAlloc.getDefiningOp<AllocOp>())
276     return failure();
277 
278   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
279   SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
280   if (!subViewOp)
281     return failure();
282   Value subView = subViewOp.getResult();
283 
284   // Find the copy from `subView` without interleaved uses.
285   CopyOp copyOp;
286   for (auto &u : subViewOp.getResult().getUses()) {
287     if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
288       if (newCopyOp.getInput(0) != subView)
289         continue;
290       if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
291         continue;
292       copyOp = newCopyOp;
293       break;
294     }
295   }
296   if (!copyOp)
297     return failure();
298 
299   // `out` is the subview copied into that we replace.
300   Value out = copyOp.getOutputBuffer(0);
301 
302   // Forward vector.transfer into copy.
303   rewriter.create<vector::TransferWriteOp>(
304       xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
305       xferOp.permutation_map(),
306       xferOp.masked() ? *xferOp.masked() : ArrayAttr());
307 
308   rewriter.eraseOp(copyOp);
309   rewriter.eraseOp(xferOp);
310 
311   return success();
312 }
313