1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
14 
15 #include <type_traits>
16 
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
19 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
20 #include "mlir/Dialect/Linalg/IR/Linalg.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/SCF/IR/SCF.h"
23 #include "mlir/Dialect/Utils/IndexingUtils.h"
24 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
26 #include "mlir/IR/BuiltinTypes.h"
27 #include "mlir/IR/ImplicitLocOpBuilder.h"
28 #include "mlir/IR/Matchers.h"
29 #include "mlir/IR/PatternMatch.h"
30 #include "mlir/Interfaces/VectorInterfaces.h"
31 
32 #include "llvm/ADT/DenseSet.h"
33 #include "llvm/ADT/MapVector.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/Support/CommandLine.h"
36 #include "llvm/Support/Debug.h"
37 #include "llvm/Support/raw_ostream.h"
38 
39 #define DEBUG_TYPE "vector-to-vector"
40 
41 using namespace mlir;
42 using namespace mlir::vector;
43 
44 // Helper to find an index in an affine map.
getResultIndex(AffineMap map,int64_t index)45 static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
46   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
47     int64_t idx = map.getDimPosition(i);
48     if (idx == index)
49       return i;
50   }
51   return None;
52 }
53 
54 // Helper to construct iterator types with one index removed.
adjustIter(ArrayAttr iteratorTypes,int64_t index)55 static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
56                                             int64_t index) {
57   SmallVector<Attribute, 4> results;
58   for (const auto &it : llvm::enumerate(iteratorTypes)) {
59     int64_t idx = it.index();
60     if (idx == index)
61       continue;
62     results.push_back(it.value());
63   }
64   return results;
65 }
66 
67 // Helper to construct an affine map with one index removed.
adjustMap(AffineMap map,int64_t index,PatternRewriter & rewriter)68 static AffineMap adjustMap(AffineMap map, int64_t index,
69                            PatternRewriter &rewriter) {
70   auto *ctx = rewriter.getContext();
71   SmallVector<AffineExpr, 4> results;
72   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
73     int64_t idx = map.getDimPosition(i);
74     if (idx == index)
75       continue;
76     // Re-insert remaining indices, but renamed when occurring
77     // after the removed index.
78     auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
79     results.push_back(targetExpr);
80   }
81   return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
82 }
83 
84 // Helper method to possibly drop a dimension in a load.
85 // TODO
reshapeLoad(Location loc,Value val,VectorType type,int64_t index,int64_t pos,PatternRewriter & rewriter)86 static Value reshapeLoad(Location loc, Value val, VectorType type,
87                          int64_t index, int64_t pos,
88                          PatternRewriter &rewriter) {
89   if (index == -1)
90     return val;
91   Type lowType = VectorType::Builder(type).dropDim(0);
92   // At extraction dimension?
93   if (index == 0) {
94     auto posAttr = rewriter.getI64ArrayAttr(pos);
95     return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
96   }
97   // Unroll leading dimensions.
98   VectorType vType = lowType.cast<VectorType>();
99   Type resType = VectorType::Builder(type).dropDim(index);
100   auto resVectorType = resType.cast<VectorType>();
101   Value result = rewriter.create<arith::ConstantOp>(
102       loc, resVectorType, rewriter.getZeroAttr(resVectorType));
103   for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
104     auto posAttr = rewriter.getI64ArrayAttr(d);
105     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
106     Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
107     result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
108                                                posAttr);
109   }
110   return result;
111 }
112 
113 // Helper method to possibly drop a dimension in a store.
114 // TODO
reshapeStore(Location loc,Value val,Value result,VectorType type,int64_t index,int64_t pos,PatternRewriter & rewriter)115 static Value reshapeStore(Location loc, Value val, Value result,
116                           VectorType type, int64_t index, int64_t pos,
117                           PatternRewriter &rewriter) {
118   // Unmodified?
119   if (index == -1)
120     return val;
121   // At insertion dimension?
122   if (index == 0) {
123     auto posAttr = rewriter.getI64ArrayAttr(pos);
124     return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
125   }
126   // Unroll leading dimensions.
127   Type lowType = VectorType::Builder(type).dropDim(0);
128   VectorType vType = lowType.cast<VectorType>();
129   Type insType = VectorType::Builder(vType).dropDim(0);
130   for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
131     auto posAttr = rewriter.getI64ArrayAttr(d);
132     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
133     Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
134     Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
135     result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
136   }
137   return result;
138 }
139 
140 template <typename IntType>
extractVector(ArrayAttr arrayAttr)141 static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
142   return llvm::to_vector<4>(llvm::map_range(
143       arrayAttr.getAsRange<IntegerAttr>(),
144       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
145 }
146 
147 /// Helper to create arithmetic operation associated with a kind of contraction.
createContractArithOp(Location loc,Value x,Value y,Value acc,vector::CombiningKind kind,PatternRewriter & rewriter,bool isInt)148 static Optional<Value> createContractArithOp(Location loc, Value x, Value y,
149                                              Value acc,
150                                              vector::CombiningKind kind,
151                                              PatternRewriter &rewriter,
152                                              bool isInt) {
153   using vector::CombiningKind;
154   Value mul;
155   if (isInt) {
156     if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
157       // Only valid for floating point types.
158       return Optional<Value>();
159     mul = rewriter.create<arith::MulIOp>(loc, x, y);
160   } else {
161     // Float case.
162     if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
163         kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
164         kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
165         kind == CombiningKind::XOR)
166       // Only valid for integer types.
167       return Optional<Value>();
168     // Special case for fused multiply-add.
169     if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
170       return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
171     }
172     mul = rewriter.create<arith::MulFOp>(loc, x, y);
173   }
174   if (!acc)
175     return Optional<Value>(mul);
176   return makeArithReduction(rewriter, loc, kind, mul, acc);
177 }
178 
179 /// Return the positions of the reductions in the given map.
getReductionIndex(AffineMap map,ArrayAttr iteratorTypes)180 static SmallVector<int64_t> getReductionIndex(AffineMap map,
181                                               ArrayAttr iteratorTypes) {
182   SmallVector<int64_t> dimsIdx;
183   for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
184     if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
185       dimsIdx.push_back(i);
186   }
187   return dimsIdx;
188 }
189 
190 /// Look for a given dimension in an affine map and return its position. Return
191 /// llvm::None if the dimension is not in the map results.
getDimPosition(AffineMap map,unsigned dim)192 static llvm::Optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
193   for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
194     if (map.getDimPosition(i) == dim)
195       return i;
196   }
197   return llvm::None;
198 }
199 
200 namespace {
201 
202 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
203 //
204 // Example:
205 //
206 //  The following MLIR with cancelling ShapeCastOps:
207 //
208 //   %0 = source : vector<5x4x2xf32>
209 //   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
210 //   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
211 //   %3 = user %2 : vector<5x4x2xf32>
212 //
213 //  Should canonicalize to the following:
214 //
215 //   %0 = source : vector<5x4x2xf32>
216 //   %1 = user %0 : vector<5x4x2xf32>
217 //
218 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
219   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
220 
matchAndRewrite__anon5c5a5b800211::ShapeCastOpFolder221   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
222                                 PatternRewriter &rewriter) const override {
223     // Check if 'shapeCastOp' has vector source/result type.
224     auto sourceVectorType =
225         shapeCastOp.getSource().getType().dyn_cast_or_null<VectorType>();
226     auto resultVectorType =
227         shapeCastOp.getResult().getType().dyn_cast_or_null<VectorType>();
228     if (!sourceVectorType || !resultVectorType)
229       return failure();
230 
231     // Check if shape cast op source operand is also a shape cast op.
232     auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
233         shapeCastOp.getSource().getDefiningOp());
234     if (!sourceShapeCastOp)
235       return failure();
236     auto operandSourceVectorType =
237         sourceShapeCastOp.getSource().getType().cast<VectorType>();
238     auto operandResultVectorType = sourceShapeCastOp.getType();
239 
240     // Check if shape cast operations invert each other.
241     if (operandSourceVectorType != resultVectorType ||
242         operandResultVectorType != sourceVectorType)
243       return failure();
244 
245     rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
246     return success();
247   }
248 };
249 
250 /// Progressive lowering of BroadcastOp.
251 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
252 public:
253   using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
254 
matchAndRewrite(vector::BroadcastOp op,PatternRewriter & rewriter) const255   LogicalResult matchAndRewrite(vector::BroadcastOp op,
256                                 PatternRewriter &rewriter) const override {
257     auto loc = op.getLoc();
258     VectorType dstType = op.getVectorType();
259     VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
260     Type eltType = dstType.getElementType();
261 
262     // Scalar to any vector can use splat.
263     if (!srcType) {
264       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
265       return success();
266     }
267 
268     // Determine rank of source and destination.
269     int64_t srcRank = srcType.getRank();
270     int64_t dstRank = dstType.getRank();
271 
272     // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
273     if (srcRank <= 1 && dstRank == 1) {
274       Value ext;
275       if (srcRank == 0)
276         ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
277       else
278         ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
279       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
280       return success();
281     }
282 
283     // Duplicate this rank.
284     // For example:
285     //   %x = broadcast %y  : k-D to n-D, k < n
286     // becomes:
287     //   %b = broadcast %y  : k-D to (n-1)-D
288     //   %x = [%b,%b,%b,%b] : n-D
289     // becomes:
290     //   %b = [%y,%y]       : (n-1)-D
291     //   %x = [%b,%b,%b,%b] : n-D
292     if (srcRank < dstRank) {
293       // Duplication.
294       VectorType resType =
295           VectorType::get(dstType.getShape().drop_front(), eltType);
296       Value bcst =
297           rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
298       Value result = rewriter.create<arith::ConstantOp>(
299           loc, dstType, rewriter.getZeroAttr(dstType));
300       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
301         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
302       rewriter.replaceOp(op, result);
303       return success();
304     }
305 
306     // Find non-matching dimension, if any.
307     assert(srcRank == dstRank);
308     int64_t m = -1;
309     for (int64_t r = 0; r < dstRank; r++)
310       if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
311         m = r;
312         break;
313       }
314 
315     // All trailing dimensions are the same. Simply pass through.
316     if (m == -1) {
317       rewriter.replaceOp(op, op.getSource());
318       return success();
319     }
320 
321     // Any non-matching dimension forces a stretch along this rank.
322     // For example:
323     //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
324     // becomes:
325     //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
326     //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
327     //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
328     //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
329     //   %x = [%a,%b,%c,%d]
330     // becomes:
331     //   %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
332     //   %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
333     //   %a = [%u, %v]
334     //   ..
335     //   %x = [%a,%b,%c,%d]
336     VectorType resType =
337         VectorType::get(dstType.getShape().drop_front(), eltType);
338     Value result = rewriter.create<arith::ConstantOp>(
339         loc, dstType, rewriter.getZeroAttr(dstType));
340     if (m == 0) {
341       // Stetch at start.
342       Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
343       Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
344       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
345         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
346     } else {
347       // Stetch not at start.
348       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
349         Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
350         Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
351         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
352       }
353     }
354     rewriter.replaceOp(op, result);
355     return success();
356   }
357 };
358 
359 /// Given a 'transpose' pattern, prune the rightmost dimensions that are not
360 /// transposed.
pruneNonTransposedDims(ArrayRef<int64_t> transpose,SmallVectorImpl<int64_t> & result)361 void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
362                             SmallVectorImpl<int64_t> &result) {
363   size_t numTransposedDims = transpose.size();
364   for (size_t transpDim : llvm::reverse(transpose)) {
365     if (transpDim != numTransposedDims - 1)
366       break;
367     numTransposedDims--;
368   }
369 
370   result.append(transpose.begin(), transpose.begin() + numTransposedDims);
371 }
372 
373 /// Progressive lowering of TransposeOp.
374 /// One:
375 ///   %x = vector.transpose %y, [1, 0]
376 /// is replaced by:
377 ///   %z = arith.constant dense<0.000000e+00>
378 ///   %0 = vector.extract %y[0, 0]
379 ///   %1 = vector.insert %0, %z [0, 0]
380 ///   ..
381 ///   %x = vector.insert .., .. [.., ..]
382 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
383 public:
384   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
385 
TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,MLIRContext * context)386   TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
387                       MLIRContext *context)
388       : OpRewritePattern<vector::TransposeOp>(context),
389         vectorTransformOptions(vectorTransformOptions) {}
390 
matchAndRewrite(vector::TransposeOp op,PatternRewriter & rewriter) const391   LogicalResult matchAndRewrite(vector::TransposeOp op,
392                                 PatternRewriter &rewriter) const override {
393     auto loc = op.getLoc();
394 
395     Value input = op.getVector();
396     VectorType inputType = op.getVectorType();
397     VectorType resType = op.getResultType();
398 
399     // Set up convenience transposition table.
400     SmallVector<int64_t, 4> transp;
401     for (auto attr : op.getTransp())
402       transp.push_back(attr.cast<IntegerAttr>().getInt());
403 
404     if (vectorTransformOptions.vectorTransposeLowering ==
405             vector::VectorTransposeLowering::Shuffle &&
406         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
407       return rewriter.notifyMatchFailure(
408           op, "Options specifies lowering to shuffle");
409 
410     // Handle a true 2-D matrix transpose differently when requested.
411     if (vectorTransformOptions.vectorTransposeLowering ==
412             vector::VectorTransposeLowering::Flat &&
413         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
414       Type flattenedType =
415           VectorType::get(resType.getNumElements(), resType.getElementType());
416       auto matrix =
417           rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
418       auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
419       auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
420       Value trans = rewriter.create<vector::FlatTransposeOp>(
421           loc, flattenedType, matrix, rows, columns);
422       rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
423       return success();
424     }
425 
426     // Generate unrolled extract/insert ops. We do not unroll the rightmost
427     // (i.e., highest-order) dimensions that are not transposed and leave them
428     // in vector form to improve performance. Therefore, we prune those
429     // dimensions from the shape/transpose data structures used to generate the
430     // extract/insert ops.
431     SmallVector<int64_t, 4> prunedTransp;
432     pruneNonTransposedDims(transp, prunedTransp);
433     size_t numPrunedDims = transp.size() - prunedTransp.size();
434     auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
435     SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
436     auto prunedInStrides = computeStrides(prunedInShape, ones);
437 
438     // Generates the extract/insert operations for every scalar/vector element
439     // of the leftmost transposed dimensions. We traverse every transpose
440     // element using a linearized index that we delinearize to generate the
441     // appropriate indices for the extract/insert operations.
442     Value result = rewriter.create<arith::ConstantOp>(
443         loc, resType, rewriter.getZeroAttr(resType));
444     int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
445 
446     for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
447          ++linearIdx) {
448       auto extractIdxs = delinearize(prunedInStrides, linearIdx);
449       SmallVector<int64_t, 4> insertIdxs(extractIdxs);
450       applyPermutationToVector(insertIdxs, prunedTransp);
451       Value extractOp =
452           rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
453       result =
454           rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
455     }
456 
457     rewriter.replaceOp(op, result);
458     return success();
459   }
460 
461 private:
462   /// Options to control the vector patterns.
463   vector::VectorTransformsOptions vectorTransformOptions;
464 };
465 
466 /// Rewrite a 2-D vector.transpose as a sequence of:
467 ///   vector.shape_cast 2D -> 1D
468 ///   vector.shuffle
469 ///   vector.shape_cast 1D -> 2D
470 class TransposeOp2DToShuffleLowering
471     : public OpRewritePattern<vector::TransposeOp> {
472 public:
473   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
474 
TransposeOp2DToShuffleLowering(vector::VectorTransformsOptions vectorTransformOptions,MLIRContext * context)475   TransposeOp2DToShuffleLowering(
476       vector::VectorTransformsOptions vectorTransformOptions,
477       MLIRContext *context)
478       : OpRewritePattern<vector::TransposeOp>(context),
479         vectorTransformOptions(vectorTransformOptions) {}
480 
matchAndRewrite(vector::TransposeOp op,PatternRewriter & rewriter) const481   LogicalResult matchAndRewrite(vector::TransposeOp op,
482                                 PatternRewriter &rewriter) const override {
483     auto loc = op.getLoc();
484 
485     VectorType srcType = op.getVectorType();
486     if (srcType.getRank() != 2)
487       return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
488 
489     SmallVector<int64_t, 4> transp;
490     for (auto attr : op.getTransp())
491       transp.push_back(attr.cast<IntegerAttr>().getInt());
492     if (transp[0] != 1 && transp[1] != 0)
493       return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
494 
495     if (vectorTransformOptions.vectorTransposeLowering !=
496         VectorTransposeLowering::Shuffle)
497       return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
498 
499     int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
500     Value casted = rewriter.create<vector::ShapeCastOp>(
501         loc, VectorType::get({m * n}, srcType.getElementType()),
502         op.getVector());
503     SmallVector<int64_t> mask;
504     mask.reserve(m * n);
505     for (int64_t j = 0; j < n; ++j)
506       for (int64_t i = 0; i < m; ++i)
507         mask.push_back(i * n + j);
508 
509     Value shuffled =
510         rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
511     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
512                                                      shuffled);
513 
514     return success();
515   }
516 
517 private:
518   /// Options to control the vector patterns.
519   vector::VectorTransformsOptions vectorTransformOptions;
520 };
521 
522 /// Progressive lowering of OuterProductOp.
523 /// One:
524 ///   %x = vector.outerproduct %lhs, %rhs, %acc
525 /// is replaced by:
526 ///   %z = zero-result
527 ///   %0 = vector.extract %lhs[0]
528 ///   %1 = vector.broadcast %0
529 ///   %2 = vector.extract %acc[0]
530 ///   %3 = vector.fma %1, %rhs, %2
531 ///   %4 = vector.insert %3, %z[0]
532 ///   ..
533 ///   %x = vector.insert %.., %..[N-1]
534 ///
535 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
536 public:
537   using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
538 
matchAndRewrite(vector::OuterProductOp op,PatternRewriter & rewriter) const539   LogicalResult matchAndRewrite(vector::OuterProductOp op,
540                                 PatternRewriter &rewriter) const override {
541     auto loc = op.getLoc();
542 
543     VectorType lhsType = op.getOperandVectorTypeLHS();
544     VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
545     VectorType resType = op.getVectorType();
546     Type eltType = resType.getElementType();
547     bool isInt = eltType.isa<IntegerType, IndexType>();
548     Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
549     vector::CombiningKind kind = op.getKind();
550 
551     if (!rhsType) {
552       // Special case: AXPY operation.
553       Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
554       Optional<Value> mult = createContractArithOp(loc, op.getLhs(), b, acc,
555                                                    kind, rewriter, isInt);
556       if (!mult.has_value())
557         return failure();
558       rewriter.replaceOp(op, mult.value());
559       return success();
560     }
561 
562     Value result = rewriter.create<arith::ConstantOp>(
563         loc, resType, rewriter.getZeroAttr(resType));
564     for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
565       auto pos = rewriter.getI64ArrayAttr(d);
566       Value x =
567           rewriter.create<vector::ExtractOp>(loc, eltType, op.getLhs(), pos);
568       Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
569       Value r = nullptr;
570       if (acc)
571         r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
572       Optional<Value> m =
573           createContractArithOp(loc, a, op.getRhs(), r, kind, rewriter, isInt);
574       if (!m.has_value())
575         return failure();
576       result = rewriter.create<vector::InsertOp>(loc, resType, m.value(),
577                                                  result, pos);
578     }
579     rewriter.replaceOp(op, result);
580     return success();
581   }
582 };
583 
584 /// Lower vector.contract with all size one reduction dimensions to
585 /// elementwise ops when possible.
586 struct ContractOpToElementwise
587     : public OpRewritePattern<vector::ContractionOp> {
588   using OpRewritePattern::OpRewritePattern;
589   using FilterConstraintType =
590       std::function<LogicalResult(vector::ContractionOp op)>;
defaultFilter__anon5c5a5b800211::ContractOpToElementwise591   static LogicalResult defaultFilter(vector::ContractionOp op) {
592     return success();
593   }
ContractOpToElementwise__anon5c5a5b800211::ContractOpToElementwise594   ContractOpToElementwise(
595       vector::VectorTransformsOptions vectorTransformOptions,
596       MLIRContext *context,
597       const FilterConstraintType &constraint = defaultFilter)
598       : OpRewritePattern<vector::ContractionOp>(context),
599         vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
600 
matchAndRewrite__anon5c5a5b800211::ContractOpToElementwise601   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
602                                 PatternRewriter &rewriter) const override {
603     // TODO: implement masks
604     if (llvm::size(contractOp.getMasks()) != 0)
605       return failure();
606 
607     if (failed(filter(contractOp)))
608       return failure();
609 
610     if (vectorTransformOptions.vectorContractLowering !=
611         vector::VectorContractLowering::ParallelArith)
612       return failure();
613     ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
614     ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
615     AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
616     AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
617     SmallVector<int64_t> lhsReductionDims =
618         getReductionIndex(lhsMap, contractOp.getIteratorTypes());
619     SmallVector<int64_t> rhsReductionDims =
620         getReductionIndex(rhsMap, contractOp.getIteratorTypes());
621     // All the reduction dimensions must be a size 1.
622     for (int64_t dim : lhsReductionDims) {
623       if (lhsShape[dim] != 1)
624         return failure();
625     }
626     for (int64_t dim : rhsReductionDims) {
627       if (rhsShape[dim] != 1)
628         return failure();
629     }
630     AffineMap accMap = contractOp.getIndexingMapsArray()[2];
631     unsigned numParallelDims = accMap.getNumResults();
632     unsigned numLhsDimToBroadcast =
633         numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
634     unsigned numRhsDimToBroadcast =
635         numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
636     SmallVector<int64_t> lhsDims;
637     SmallVector<int64_t> lhsTranspose;
638     SmallVector<int64_t> rhsDims;
639     SmallVector<int64_t> rhsTranspose;
640     for (int64_t dim : lhsReductionDims)
641       lhsTranspose.push_back(numLhsDimToBroadcast + dim);
642     for (int64_t dim : rhsReductionDims)
643       rhsTranspose.push_back(numRhsDimToBroadcast + dim);
644     // Loop through the parallel dimensions to calculate the dimensions to
645     // broadcast and to permute in order to extract only parallel dimensions.
646     for (unsigned i = 0; i < numParallelDims; i++) {
647       llvm::Optional<unsigned> lhsDim =
648           getDimPosition(lhsMap, accMap.getDimPosition(i));
649       if (lhsDim) {
650         lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
651       } else {
652         // If the parallel dimension doesn't exist we will have to broadcast it.
653         lhsDims.push_back(
654             contractOp.getResultType().cast<VectorType>().getDimSize(i));
655         lhsTranspose.push_back(lhsDims.size() - 1);
656       }
657       llvm::Optional<unsigned> rhsDim =
658           getDimPosition(rhsMap, accMap.getDimPosition(i));
659       if (rhsDim) {
660         rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
661       } else {
662         // If the parallel dimension doesn't exist we will have to broadcast it.
663         rhsDims.push_back(
664             contractOp.getResultType().cast<VectorType>().getDimSize(i));
665         rhsTranspose.push_back(rhsDims.size() - 1);
666       }
667     }
668     Value newLhs = contractOp.getLhs();
669     Value newRhs = contractOp.getRhs();
670     Location loc = contractOp.getLoc();
671     if (!lhsDims.empty()) {
672       lhsDims.append(lhsShape.begin(), lhsShape.end());
673       auto expandedType =
674           VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
675       newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
676     }
677     if (!rhsDims.empty()) {
678       rhsDims.append(rhsShape.begin(), rhsShape.end());
679       auto expandedType =
680           VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
681       newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
682     }
683     bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
684     newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
685     newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
686     SmallVector<int64_t, 4> lhsOffsets(lhsReductionDims.size(), 0);
687     SmallVector<int64_t, 4> rhsOffsets(rhsReductionDims.size(), 0);
688     newLhs = rewriter.create<vector::ExtractOp>(
689         loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
690     newRhs = rewriter.create<vector::ExtractOp>(
691         loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets));
692     Optional<Value> result =
693         createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
694                               contractOp.getKind(), rewriter, isInt);
695     rewriter.replaceOp(contractOp, {*result});
696     return success();
697   }
698 
699 private:
700   /// Options to control the vector patterns.
701   vector::VectorTransformsOptions vectorTransformOptions;
702   FilterConstraintType filter;
703 };
704 
705 /// Progressive lowering of ConstantMaskOp.
706 /// One:
707 ///   %x = vector.constant_mask [a,b]
708 /// is replaced by:
709 ///   %z = zero-result
710 ///   %l = vector.constant_mask [b]
711 ///   %4 = vector.insert %l, %z[0]
712 ///   ..
713 ///   %x = vector.insert %l, %..[a-1]
714 /// until a one-dimensional vector is reached. All these operations
715 /// will be folded at LLVM IR level.
716 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
717 public:
718   using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
719 
matchAndRewrite(vector::ConstantMaskOp op,PatternRewriter & rewriter) const720   LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
721                                 PatternRewriter &rewriter) const override {
722     auto loc = op.getLoc();
723     auto dstType = op.getType();
724     auto eltType = dstType.getElementType();
725     auto dimSizes = op.getMaskDimSizes();
726     int64_t rank = dstType.getRank();
727 
728     if (rank == 0) {
729       assert(dimSizes.size() == 1 &&
730              "Expected exactly one dim size for a 0-D vector");
731       bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
732       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
733           op, dstType,
734           DenseIntElementsAttr::get(
735               VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
736               ArrayRef<bool>{value}));
737       return success();
738     }
739 
740     // Scalable constant masks can only be lowered for the "none set" case.
741     if (dstType.cast<VectorType>().isScalable()) {
742       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
743           op, DenseElementsAttr::get(dstType, false));
744       return success();
745     }
746 
747     int64_t trueDim = std::min(dstType.getDimSize(0),
748                                dimSizes[0].cast<IntegerAttr>().getInt());
749 
750     if (rank == 1) {
751       // Express constant 1-D case in explicit vector form:
752       //   [T,..,T,F,..,F].
753       SmallVector<bool, 4> values(dstType.getDimSize(0));
754       for (int64_t d = 0; d < trueDim; d++)
755         values[d] = true;
756       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
757           op, dstType, rewriter.getBoolVectorAttr(values));
758       return success();
759     }
760 
761     VectorType lowType =
762         VectorType::get(dstType.getShape().drop_front(), eltType);
763     SmallVector<int64_t, 4> newDimSizes;
764     for (int64_t r = 1; r < rank; r++)
765       newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
766     Value trueVal = rewriter.create<vector::ConstantMaskOp>(
767         loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
768     Value result = rewriter.create<arith::ConstantOp>(
769         loc, dstType, rewriter.getZeroAttr(dstType));
770     for (int64_t d = 0; d < trueDim; d++) {
771       auto pos = rewriter.getI64ArrayAttr(d);
772       result =
773           rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
774     }
775     rewriter.replaceOp(op, result);
776     return success();
777   }
778 };
779 
780 /// Progressive lowering of CreateMaskOp.
781 /// One:
782 ///   %x = vector.create_mask %a, ... : vector<dx...>
783 /// is replaced by:
784 ///   %l = vector.create_mask ... : vector<...>  ; one lower rank
785 ///   %0 = arith.cmpi "slt", %ci, %a       |
786 ///   %1 = select %0, %l, %zeroes    |
787 ///   %r = vector.insert %1, %pr [i] | d-times
788 ///   %x = ....
789 /// until a one-dimensional vector is reached.
790 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
791 public:
792   using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
793 
matchAndRewrite(vector::CreateMaskOp op,PatternRewriter & rewriter) const794   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
795                                 PatternRewriter &rewriter) const override {
796     auto dstType = op.getResult().getType().cast<VectorType>();
797     int64_t rank = dstType.getRank();
798     if (rank <= 1)
799       return rewriter.notifyMatchFailure(
800           op, "0-D and 1-D vectors are handled separately");
801 
802     auto loc = op.getLoc();
803     auto eltType = dstType.getElementType();
804     int64_t dim = dstType.getDimSize(0);
805     Value idx = op.getOperand(0);
806 
807     VectorType lowType =
808         VectorType::get(dstType.getShape().drop_front(), eltType);
809     Value trueVal = rewriter.create<vector::CreateMaskOp>(
810         loc, lowType, op.getOperands().drop_front());
811     Value falseVal = rewriter.create<arith::ConstantOp>(
812         loc, lowType, rewriter.getZeroAttr(lowType));
813     Value result = rewriter.create<arith::ConstantOp>(
814         loc, dstType, rewriter.getZeroAttr(dstType));
815     for (int64_t d = 0; d < dim; d++) {
816       Value bnd =
817           rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
818       Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
819                                                  bnd, idx);
820       Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
821       auto pos = rewriter.getI64ArrayAttr(d);
822       result =
823           rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
824     }
825     rewriter.replaceOp(op, result);
826     return success();
827   }
828 };
829 
830 /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
831 /// vectors progressively on the way to target llvm.matrix intrinsics.
832 /// This iterates over the most major dimension of the 2-D vector and performs
833 /// rewrites into:
834 ///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
835 class ShapeCastOp2DDownCastRewritePattern
836     : public OpRewritePattern<vector::ShapeCastOp> {
837 public:
838   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
839 
matchAndRewrite(vector::ShapeCastOp op,PatternRewriter & rewriter) const840   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
841                                 PatternRewriter &rewriter) const override {
842     auto sourceVectorType = op.getSourceVectorType();
843     auto resultVectorType = op.getResultVectorType();
844     if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
845       return failure();
846 
847     auto loc = op.getLoc();
848     Value desc = rewriter.create<arith::ConstantOp>(
849         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
850     unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
851     for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
852       Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
853       desc = rewriter.create<vector::InsertStridedSliceOp>(
854           loc, vec, desc,
855           /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
856     }
857     rewriter.replaceOp(op, desc);
858     return success();
859   }
860 };
861 
862 /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
863 /// vectors progressively.
864 /// This iterates over the most major dimension of the 2-D vector and performs
865 /// rewrites into:
866 ///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
867 /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
868 class ShapeCastOp2DUpCastRewritePattern
869     : public OpRewritePattern<vector::ShapeCastOp> {
870 public:
871   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
872 
matchAndRewrite(vector::ShapeCastOp op,PatternRewriter & rewriter) const873   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
874                                 PatternRewriter &rewriter) const override {
875     auto sourceVectorType = op.getSourceVectorType();
876     auto resultVectorType = op.getResultVectorType();
877     if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
878       return failure();
879 
880     auto loc = op.getLoc();
881     Value desc = rewriter.create<arith::ConstantOp>(
882         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
883     unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
884     for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
885       Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
886           loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
887           /*sizes=*/mostMinorVectorSize,
888           /*strides=*/1);
889       desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
890     }
891     rewriter.replaceOp(op, desc);
892     return success();
893   }
894 };
895 
896 // We typically should not lower general shape cast operations into data
897 // movement instructions, since the assumption is that these casts are
898 // optimized away during progressive lowering. For completeness, however,
899 // we fall back to a reference implementation that moves all elements
900 // into the right place if we get here.
901 class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
902 public:
903   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
904 
matchAndRewrite(vector::ShapeCastOp op,PatternRewriter & rewriter) const905   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
906                                 PatternRewriter &rewriter) const override {
907     Location loc = op.getLoc();
908     auto sourceVectorType = op.getSourceVectorType();
909     auto resultVectorType = op.getResultVectorType();
910 
911     // Special case 2D/1D lowerings with better implementations.
912     // TODO: make is ND/1D to allow generic ND->1D->MD.
913     int64_t srcRank = sourceVectorType.getRank();
914     int64_t resRank = resultVectorType.getRank();
915     if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
916       return failure();
917 
918     // Generic ShapeCast lowering path goes all the way down to unrolled scalar
919     // extract/insert chains.
920     // TODO: consider evolving the semantics to only allow 1D source or dest and
921     // drop this potentially very expensive lowering.
922     // Compute number of elements involved in the reshape.
923     int64_t numElts = 1;
924     for (int64_t r = 0; r < srcRank; r++)
925       numElts *= sourceVectorType.getDimSize(r);
926     // Replace with data movement operations:
927     //    x[0,0,0] = y[0,0]
928     //    x[0,0,1] = y[0,1]
929     //    x[0,1,0] = y[0,2]
930     // etc., incrementing the two index vectors "row-major"
931     // within the source and result shape.
932     SmallVector<int64_t, 4> srcIdx(srcRank);
933     SmallVector<int64_t, 4> resIdx(resRank);
934     Value result = rewriter.create<arith::ConstantOp>(
935         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
936     for (int64_t i = 0; i < numElts; i++) {
937       if (i != 0) {
938         incIdx(srcIdx, sourceVectorType, srcRank - 1);
939         incIdx(resIdx, resultVectorType, resRank - 1);
940       }
941       Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
942       result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
943     }
944     rewriter.replaceOp(op, result);
945     return success();
946   }
947 
948 private:
incIdx(SmallVector<int64_t,4> & idx,VectorType tp,int64_t r)949   static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
950     assert(0 <= r && r < tp.getRank());
951     if (++idx[r] == tp.getDimSize(r)) {
952       idx[r] = 0;
953       incIdx(idx, tp, r - 1);
954     }
955   }
956 };
957 
958 /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
959 /// Ex:
960 /// ```
961 ///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
962 ///   %1 = vector.multi_reduction add, %0 [1]
963 ///     : vector<8x32x16xf32> to vector<8x16xf32>
964 /// ```
965 /// Gets converted to:
966 /// ```
967 ///   %1 = vector.contract {indexing_maps = [
968 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
969 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
970 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
971 ///    iterator_types = ["parallel", "parallel", "reduction"],
972 ///    kind = add} %0, %arg1, %cst_f0
973 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
974 ///  ```
975 struct MultiReduceToContract
976     : public OpRewritePattern<vector::MultiDimReductionOp> {
977   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
978 
matchAndRewrite__anon5c5a5b800211::MultiReduceToContract979   LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
980                                 PatternRewriter &rewriter) const override {
981     if (reduceOp.getKind() != vector::CombiningKind::ADD)
982       return failure();
983     Operation *mulOp = reduceOp.getSource().getDefiningOp();
984     if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
985       return failure();
986     SmallVector<bool> reductionMask = reduceOp.getReductionMask();
987     auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
988     SmallVector<AffineExpr> exprs;
989     SmallVector<StringRef> iteratorTypes;
990     for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
991       if (!isReduceDim.value()) {
992         iteratorTypes.push_back(getParallelIteratorTypeName());
993         exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
994       } else {
995         iteratorTypes.push_back(getReductionIteratorTypeName());
996       }
997     }
998     auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
999                                  /*symCount=*/0, exprs, reduceOp.getContext());
1000     rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
1001         reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
1002         rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
1003         rewriter.getStrArrayAttr(iteratorTypes));
1004     return success();
1005   }
1006 };
1007 
1008 /// Merge TransposeOp into ContractionOp user.
1009 /// Ex:
1010 /// ```
1011 ///   %0 = vector.transpose %arg0, [2, 0, 1]
1012 ///     : vector<32x16x8xf32> to vector<8x32x16xf32>
1013 ///   %1 = vector.contract {indexing_maps = [
1014 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1015 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1016 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
1017 ///    iterator_types = ["parallel", "parallel", "reduction"],
1018 ///    kind = add} %0, %arg1, %cst_f0
1019 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1020 /// ```
1021 /// Gets converted to:
1022 /// ```
1023 ///   %1 = vector.contract {indexing_maps = [
1024 ///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
1025 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1026 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
1027 ///    iterator_types = ["parallel", "parallel", "reduction"],
1028 ///    kind = add} %arg0, %arg1, %cst_f0
1029 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1030 ///  ```
1031 struct CombineContractTranspose
1032     : public OpRewritePattern<vector::ContractionOp> {
1033   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
1034 
matchAndRewrite__anon5c5a5b800211::CombineContractTranspose1035   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
1036                                 PatternRewriter &rewriter) const override {
1037     SmallVector<AffineMap, 4> maps =
1038         llvm::to_vector<4>(contractOp.getIndexingMapsArray());
1039     Value lhs = contractOp.getLhs();
1040     Value rhs = contractOp.getRhs();
1041     size_t index = 0;
1042     bool changed = false;
1043     for (Value *operand : {&lhs, &rhs}) {
1044       AffineMap &map = maps[index++];
1045       auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
1046       if (!transposeOp)
1047         continue;
1048       SmallVector<int64_t> perm;
1049       transposeOp.getTransp(perm);
1050       AffineMap permutationMap = AffineMap::getPermutationMap(
1051           extractVector<unsigned>(transposeOp.getTransp()),
1052           contractOp.getContext());
1053       map = inversePermutation(permutationMap).compose(map);
1054       *operand = transposeOp.getVector();
1055       changed = true;
1056     }
1057     if (!changed)
1058       return failure();
1059     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1060         contractOp, lhs, rhs, contractOp.getAcc(),
1061         rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
1062     return success();
1063   }
1064 };
1065 
1066 /// Merge BroadcastOp into ContractionOp user.
1067 /// Ex:
1068 /// ```
1069 ///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
1070 ///   %1 = vector.contract {indexing_maps = [
1071 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1072 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1073 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
1074 ///    iterator_types = ["parallel", "parallel", "reduction"],
1075 ///    kind = add} %0, %arg1, %cst_f0
1076 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1077 /// ```
1078 /// Gets converted to:
1079 /// ```
1080 ///   %1 = vector.contract {indexing_maps = [
1081 ///         affine_map<(d0, d1, d2) -> (d1, d2)>,
1082 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1083 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
1084 ///    iterator_types = ["parallel", "parallel", "reduction"],
1085 ///    kind = add} %arg0, %arg1, %cst_f0
1086 ///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1087 ///  ```
1088 struct CombineContractBroadcast
1089     : public OpRewritePattern<vector::ContractionOp> {
1090   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
1091 
matchAndRewrite__anon5c5a5b800211::CombineContractBroadcast1092   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
1093                                 PatternRewriter &rewriter) const override {
1094     SmallVector<AffineMap, 4> maps =
1095         llvm::to_vector<4>(contractOp.getIndexingMapsArray());
1096     Value lhs = contractOp.getLhs();
1097     Value rhs = contractOp.getRhs();
1098     size_t index = 0;
1099     bool changed = false;
1100     for (Value *operand : {&lhs, &rhs}) {
1101       AffineMap &map = maps[index++];
1102       auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
1103       if (!broadcast)
1104         continue;
1105       // contractionOp can only take vector as operands.
1106       auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
1107       if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
1108         continue;
1109       int64_t rankDiff =
1110           broadcast.getVectorType().getRank() - srcType.getRank();
1111       bool innerDimBroadcast = false;
1112       SmallVector<AffineExpr> originalDims;
1113       for (const auto &dim : llvm::enumerate(srcType.getShape())) {
1114         if (dim.value() !=
1115             broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
1116           innerDimBroadcast = true;
1117           break;
1118         }
1119         originalDims.push_back(
1120             rewriter.getAffineDimExpr(dim.index() + rankDiff));
1121       }
1122       // Contract doesn't support inner dimension broadcast. Once this is
1123       // relaxed we can remove this case.
1124       if (innerDimBroadcast)
1125         continue;
1126 
1127       // It would be incorrect to fold a broadcast onto a reduction dimension
1128       // of non-unit size.
1129       bool nonUnitDimReductionBroadcast = false;
1130       for (int64_t i = 0; i < rankDiff; ++i) {
1131         if (broadcast.getVectorType().getDimSize(i) != 1 &&
1132             isReductionIterator(contractOp.getIteratorTypes()
1133                                     .getValue()[map.getDimPosition(i)])) {
1134           nonUnitDimReductionBroadcast = true;
1135           break;
1136         }
1137       }
1138       if (nonUnitDimReductionBroadcast)
1139         continue;
1140 
1141       AffineMap broadcastMap =
1142           AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
1143                          contractOp.getContext());
1144       map = broadcastMap.compose(map);
1145       *operand = broadcast.getSource();
1146       changed = true;
1147     }
1148 
1149     if (!changed)
1150       return failure();
1151 
1152     // Determine which dims are usused, now that the maps have been composed
1153     // with the broadcast maps.
1154     llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
1155     // Compress unused dims.
1156     for (auto &m : maps)
1157       m = compressDims(m, unusedDimsBitVector);
1158     // Compute the combined iterators.
1159     SmallVector<Attribute, 4> iterators;
1160     for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
1161       if (!unusedDimsBitVector.test(i))
1162         iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
1163     }
1164     // Check that compressing unused dims isn't removing all reduction dimension
1165     // pairs. For example, if the vector.contract had only one reduction
1166     // iterator and that was a unit-dimension created by a broadcast,
1167     // then we should bail here, otherwise we would create a contract without
1168     // a reduction dimension pair.
1169     bool hasReductionIteratorApplyingOnBothSides = false;
1170     for (unsigned i = 0; i < iterators.size(); ++i) {
1171       if (!isReductionIterator(iterators[i]))
1172         continue;
1173       if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
1174         hasReductionIteratorApplyingOnBothSides = true;
1175         break;
1176       }
1177     }
1178     if (!hasReductionIteratorApplyingOnBothSides)
1179       return failure();
1180 
1181     // If the compressed maps have a dimension that is not used by either LHS or
1182     // RHS then the ContractionOp verifier would fail.
1183     if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
1184       return failure();
1185     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1186         contractOp, lhs, rhs, contractOp.getAcc(),
1187         rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
1188     return success();
1189   }
1190 };
1191 
1192 /// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
1193 /// contraction ops closer, which kicks in CombineContractBroadcast pattern when
1194 /// casting ops are around these operations.
1195 /// Ex:
1196 /// ```
1197 ///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
1198 ///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
1199 /// ```
1200 /// Gets converted to:
1201 /// ```
1202 ///   %0 = arith.extsi %0 : vector<32x16xi8> to vector<32x16xi32>
1203 ///   %1 = vector.broadcast %arg0 : vector<32x16xi32> to vector<8x32x16xi32>
1204 /// ```
1205 struct ReorderCastOpsOnBroadcast
1206     : public OpInterfaceRewritePattern<CastOpInterface> {
1207   using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
1208 
matchAndRewrite__anon5c5a5b800211::ReorderCastOpsOnBroadcast1209   LogicalResult matchAndRewrite(CastOpInterface op,
1210                                 PatternRewriter &rewriter) const override {
1211     if (op->getNumOperands() != 1)
1212       return failure();
1213     auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
1214     if (!bcastOp)
1215       return failure();
1216 
1217     Type castResTy = getElementTypeOrSelf(op->getResult(0));
1218     if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
1219       castResTy = VectorType::get(vecTy.getShape(), castResTy);
1220     auto *castOp =
1221         rewriter.create(op->getLoc(), op->getName().getIdentifier(),
1222                         bcastOp.getSource(), castResTy, op->getAttrs());
1223     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1224         op, op->getResult(0).getType(), castOp->getResult(0));
1225     return success();
1226   }
1227 };
1228 
1229 /// Reorders elementwise(transpose) to transpose(elementwise). This makes
1230 /// transpose ops and contraction ops closer, which kicks in
1231 /// CombineContractTranspose pattern when elementwise ops are between these
1232 /// operations. Ex:
1233 /// ```
1234 /// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
1235 /// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
1236 /// %r = arith.addf %at, %bt : vector<2x4xf32>
1237 /// ```
1238 /// Gets converted to:
1239 /// ```
1240 /// %0 = arith.addf %a, %b : vector<4x2xf32>
1241 /// %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
1242 /// ```
1243 struct ReorderElementwiseOpsOnTranspose final
1244     : public OpTraitRewritePattern<OpTrait::Elementwise> {
1245   using OpTraitRewritePattern::OpTraitRewritePattern;
matchAndRewrite__anon5c5a5b800211::ReorderElementwiseOpsOnTranspose1246   LogicalResult matchAndRewrite(Operation *op,
1247                                 PatternRewriter &rewriter) const override {
1248     if (op->getNumResults() != 1 || op->getNumRegions() != 0)
1249       return failure();
1250 
1251     // Make sure all operands are transpose/constant ops and collect their
1252     // transposition maps.
1253     SmallVector<ArrayAttr, 4> transposeMaps;
1254     transposeMaps.reserve(op->getNumOperands());
1255     // Record the initial type before transposition. We'll use its shape later.
1256     // Any type will do here as we will check all transpose maps are the same.
1257     VectorType srcType;
1258     for (Value operand : op->getOperands()) {
1259       auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
1260       if (transposeOp) {
1261         transposeMaps.push_back(transposeOp.getTransp());
1262         srcType = transposeOp.getVectorType();
1263       } else if (!matchPattern(operand, m_Constant())) {
1264         return failure();
1265       }
1266     }
1267     if (transposeMaps.empty())
1268       return failure();
1269     // This is an elementwise op, so all transposed operands should have the
1270     // same type. We need to additionally check that all transposes uses the
1271     // same map.
1272     if (!llvm::is_splat(transposeMaps))
1273       return rewriter.notifyMatchFailure(op, "different transpose map");
1274 
1275     SmallVector<Value, 4> srcValues;
1276     srcValues.reserve(op->getNumOperands());
1277 
1278     // If there are constant operands, we need to insert inverse transposes for
1279     // them. Calculate the inverse order first.
1280     auto order = extractVector<unsigned>(transposeMaps.front());
1281     SmallVector<int64_t> invOrder(order.size());
1282     for (int i = 0, e = order.size(); i < e; ++i)
1283       invOrder[order[i]] = i;
1284 
1285     for (Value operand : op->getOperands()) {
1286       auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
1287       if (transposeOp) {
1288         srcValues.push_back(transposeOp.getVector());
1289       } else {
1290         // This is a constant. Create a reverse transpose op for it.
1291         auto vectorType = VectorType::get(
1292             srcType.getShape(),
1293             operand.getType().cast<VectorType>().getElementType());
1294         srcValues.push_back(rewriter.create<vector::TransposeOp>(
1295             operand.getLoc(), vectorType, operand,
1296             rewriter.getI64ArrayAttr(invOrder)));
1297       }
1298     }
1299 
1300     auto vectorType = VectorType::get(
1301         srcType.getShape(),
1302         op->getResultTypes()[0].cast<VectorType>().getElementType());
1303     Operation *elementwiseOp =
1304         rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
1305                         vectorType, op->getAttrs());
1306     rewriter.replaceOpWithNewOp<vector::TransposeOp>(
1307         op, op->getResultTypes()[0], elementwiseOp->getResult(0),
1308         transposeMaps.front());
1309     return success();
1310   }
1311 };
1312 
1313 } // namespace
1314 
1315 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
1316 /// operands `x` and `y`.
createAdd(Location loc,Value x,Value y,bool isInt,PatternRewriter & rewriter)1317 static Value createAdd(Location loc, Value x, Value y, bool isInt,
1318                        PatternRewriter &rewriter) {
1319   if (isInt)
1320     return rewriter.create<arith::AddIOp>(loc, x, y);
1321   return rewriter.create<arith::AddFOp>(loc, x, y);
1322 }
1323 
1324 /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
1325 /// operands `x and `y`.
createMul(Location loc,Value x,Value y,bool isInt,PatternRewriter & rewriter)1326 static Value createMul(Location loc, Value x, Value y, bool isInt,
1327                        PatternRewriter &rewriter) {
1328   if (isInt)
1329     return rewriter.create<arith::MulIOp>(loc, x, y);
1330   return rewriter.create<arith::MulFOp>(loc, x, y);
1331 }
1332 
1333 namespace mlir {
1334 
1335 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1336 /// semantics to:
1337 /// ```
1338 ///    %mta = maybe_transpose
1339 ///    %mtb = maybe_transpose
1340 ///    %flattened_a = vector.shape_cast %mta
1341 ///    %flattened_b = vector.shape_cast %mtb
1342 ///    %flattened_d = vector.matmul %flattened_a, %flattened_b
1343 ///    %mtd = vector.shape_cast %flattened_d
1344 ///    %d = maybe_untranspose %mtd
1345 ///    %e = add %c, %d
1346 /// ```
1347 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1348 //
1349 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
1350 /// vector.transpose operations are inserted if the vector.contract op is not a
1351 /// row-major matrix multiply.
1352 LogicalResult
matchAndRewrite(vector::ContractionOp op,PatternRewriter & rew) const1353 ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
1354                                                  PatternRewriter &rew) const {
1355   // TODO: implement masks
1356   if (llvm::size(op.getMasks()) != 0)
1357     return failure();
1358   if (vectorTransformOptions.vectorContractLowering !=
1359       vector::VectorContractLowering::Matmul)
1360     return failure();
1361   if (failed(filter(op)))
1362     return failure();
1363 
1364   auto iteratorTypes = op.getIteratorTypes().getValue();
1365   if (!isParallelIterator(iteratorTypes[0]) ||
1366       !isParallelIterator(iteratorTypes[1]) ||
1367       !isReductionIterator(iteratorTypes[2]))
1368     return failure();
1369 
1370   Type elementType = op.getLhsType().getElementType();
1371   if (!elementType.isIntOrFloat())
1372     return failure();
1373 
1374   Type dstElementType = op.getType();
1375   if (auto vecType = dstElementType.dyn_cast<VectorType>())
1376     dstElementType = vecType.getElementType();
1377   if (elementType != dstElementType)
1378     return failure();
1379 
1380   // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
1381   // Bail out if the contraction cannot be put in this form.
1382   MLIRContext *ctx = op.getContext();
1383   Location loc = op.getLoc();
1384   AffineExpr m, n, k;
1385   bindDims(rew.getContext(), m, n, k);
1386   // LHS must be A(m, k) or A(k, m).
1387   Value lhs = op.getLhs();
1388   auto lhsMap = op.getIndexingMapsArray()[0];
1389   if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
1390     lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
1391   else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
1392     return failure();
1393 
1394   // RHS must be B(k, n) or B(n, k).
1395   Value rhs = op.getRhs();
1396   auto rhsMap = op.getIndexingMapsArray()[1];
1397   if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
1398     rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
1399   else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
1400     return failure();
1401 
1402   // At this point lhs and rhs are in row-major.
1403   VectorType lhsType = lhs.getType().cast<VectorType>();
1404   VectorType rhsType = rhs.getType().cast<VectorType>();
1405   int64_t lhsRows = lhsType.getDimSize(0);
1406   int64_t lhsColumns = lhsType.getDimSize(1);
1407   int64_t rhsColumns = rhsType.getDimSize(1);
1408 
1409   Type flattenedLHSType =
1410       VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
1411   lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
1412 
1413   Type flattenedRHSType =
1414       VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
1415   rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
1416 
1417   Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
1418                                            rhsColumns);
1419   mul = rew.create<vector::ShapeCastOp>(
1420       loc,
1421       VectorType::get({lhsRows, rhsColumns},
1422                       getElementTypeOrSelf(op.getAcc().getType())),
1423       mul);
1424 
1425   // ACC must be C(m, n) or C(n, m).
1426   auto accMap = op.getIndexingMapsArray()[2];
1427   if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
1428     mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
1429   else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
1430     llvm_unreachable("invalid contraction semantics");
1431 
1432   Value res =
1433       elementType.isa<IntegerType>()
1434           ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
1435           : static_cast<Value>(
1436                 rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
1437 
1438   rew.replaceOp(op, res);
1439   return success();
1440 }
1441 
1442 namespace {
1443 struct IteratorType {
IteratorTypemlir::__anon5c5a5b800311::IteratorType1444   IteratorType(StringRef strRef) : strRef(strRef) {}
isOfTypemlir::__anon5c5a5b800311::IteratorType1445   bool isOfType(Attribute attr) const {
1446     auto sAttr = attr.dyn_cast<StringAttr>();
1447     return sAttr && sAttr.getValue() == strRef;
1448   }
1449   StringRef strRef;
1450 };
1451 struct Par : public IteratorType {
Parmlir::__anon5c5a5b800311::Par1452   Par() : IteratorType(getParallelIteratorTypeName()) {}
1453 };
1454 struct Red : public IteratorType {
Redmlir::__anon5c5a5b800311::Red1455   Red() : IteratorType(getReductionIteratorTypeName()) {}
1456 };
1457 
1458 /// Generate a vector implementation for matmat, matvec and tmatvec.
1459 /// This unrolls outer-products along the reduction dimension.
1460 struct UnrolledOuterProductGenerator
1461     : public StructuredGenerator<vector::ContractionOp> {
UnrolledOuterProductGeneratormlir::__anon5c5a5b800311::UnrolledOuterProductGenerator1462   UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
1463       : StructuredGenerator<vector::ContractionOp>(builder, op),
1464         kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
1465         res(op.getAcc()), lhsType(op.getLhsType()) {}
1466 
tmlir::__anon5c5a5b800311::UnrolledOuterProductGenerator1467   Value t(Value v) {
1468     static constexpr std::array<int64_t, 2> perm = {1, 0};
1469     return builder.create<vector::TransposeOp>(loc, v, perm);
1470   }
1471 
promotemlir::__anon5c5a5b800311::UnrolledOuterProductGenerator1472   Value promote(Value v, Type dstElementType) {
1473     Type elementType = v.getType();
1474     auto vecType = elementType.dyn_cast<VectorType>();
1475     if (vecType)
1476       elementType = vecType.getElementType();
1477     if (elementType == dstElementType)
1478       return v;
1479     Type promotedType = dstElementType;
1480     if (vecType)
1481       promotedType = VectorType::get(vecType.getShape(), promotedType);
1482     if (dstElementType.isa<FloatType>())
1483       return builder.create<arith::ExtFOp>(loc, promotedType, v);
1484     return builder.create<arith::ExtSIOp>(loc, promotedType, v);
1485   }
1486 
outerProdmlir::__anon5c5a5b800311::UnrolledOuterProductGenerator1487   Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
1488     assert(reductionSize > 0);
1489     Type resElementType = res.getType().cast<VectorType>().getElementType();
1490     for (int64_t k = 0; k < reductionSize; ++k) {
1491       Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
1492       Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
1493       a = promote(a, resElementType);
1494       b = promote(b, resElementType);
1495       res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
1496                                                    res, kind);
1497     }
1498     return res;
1499   }
1500 
1501   /// Two outer parallel, one inner reduction (matmat flavor).
matmatmlir::__anon5c5a5b800311::UnrolledOuterProductGenerator1502   FailureOr<Value> matmat() {
1503     if (!iters({Par(), Par(), Red()}))
1504       return failure();
1505     // Set up the parallel/reduction structure in the right form.
1506     AffineExpr m, n, k;
1507     bindDims(builder.getContext(), m, n, k);
1508     // Classical row-major matmul:  Just permute the lhs.
1509     if (layout({{m, k}, {k, n}, {m, n}}))
1510       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1511     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1512     if (layout({{m, k}, {n, k}, {m, n}})) {
1513       Value tlhs = t(lhs);
1514       return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
1515     }
1516     // No need to permute anything.
1517     if (layout({{k, m}, {k, n}, {m, n}}))
1518       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1519     // Just permute the rhs.
1520     if (layout({{k, m}, {n, k}, {m, n}}))
1521       return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
1522     // Transposed output: swap RHS and LHS.
1523     // Classical row-major matmul: permute the lhs.
1524     if (layout({{m, k}, {k, n}, {n, m}}))
1525       return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
1526     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1527     if (layout({{m, k}, {n, k}, {n, m}})) {
1528       Value trhs = t(rhs);
1529       return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
1530     }
1531     if (layout({{k, m}, {k, n}, {n, m}}))
1532       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1533     if (layout({{k, m}, {n, k}, {n, m}}))
1534       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1535     return failure();
1536   }
1537 
1538   /// One outer parallel, one inner reduction (matvec flavor)
matvecmlir::__anon5c5a5b800311::UnrolledOuterProductGenerator1539   FailureOr<Value> matvec() {
1540     if (!iters({Par(), Red()}))
1541       return failure();
1542     AffineExpr m, k;
1543     bindDims(builder.getContext(), m, k);
1544 
1545     // Case mat-vec: transpose.
1546     if (layout({{m, k}, {k}, {m}}))
1547       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1548     // Case mat-trans-vec: ready to go.
1549     if (layout({{k, m}, {k}, {m}}))
1550       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1551     // Case vec-mat: swap and transpose.
1552     if (layout({{k}, {m, k}, {m}}))
1553       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1554     // Case vec-mat-trans: swap and ready to go.
1555     if (layout({{k}, {k, m}, {m}}))
1556       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1557     return failure();
1558   }
1559 
1560   //
1561   // One outer reduction, one inner parallel (tmatvec flavor)
1562   //
tmatvecmlir::__anon5c5a5b800311::UnrolledOuterProductGenerator1563   FailureOr<Value> tmatvec() {
1564     if (!iters({Red(), Par()}))
1565       return failure();
1566     AffineExpr k, m;
1567     bindDims(builder.getContext(), k, m);
1568 
1569     // Case mat-vec: transpose.
1570     if (layout({{m, k}, {k}, {m}}))
1571       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1572     // Case mat-trans-vec: ready to go.
1573     if (layout({{k, m}, {k}, {m}}))
1574       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1575     // Case vec-mat: swap and transpose.
1576     if (layout({{k}, {m, k}, {m}}))
1577       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1578     // Case vec-mat-trans: swap and ready to go.
1579     if (layout({{k}, {k, m}, {m}}))
1580       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1581     return failure();
1582   }
1583 
1584 private:
1585   vector::CombiningKind kind;
1586   Value lhs, rhs, res;
1587   VectorType lhsType;
1588 };
1589 } // namespace
1590 
1591 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1592 /// semantics to a reduction_size-unrolled sequence:
1593 /// ```
1594 ///    %at = vector.transpose %a, [1, 0]
1595 ///    %bRow0 = vector.extract %b[0]
1596 ///    %atRow0 = vector.extract %at[0]
1597 ///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
1598 ///    ...
1599 ///    %bRowK = vector.extract %b[K]
1600 ///    %atRowK = vector.extract %at[K]
1601 ///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
1602 /// ```
1603 ///
1604 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
1605 /// otherwise supports any layout permutation of the matrix-multiply.
matchAndRewrite(vector::ContractionOp op,PatternRewriter & rewriter) const1606 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
1607     vector::ContractionOp op, PatternRewriter &rewriter) const {
1608   // TODO: implement masks
1609   if (llvm::size(op.getMasks()) != 0)
1610     return failure();
1611 
1612   if (vectorTransformOptions.vectorContractLowering !=
1613       vector::VectorContractLowering::OuterProduct)
1614     return failure();
1615 
1616   if (failed(filter(op)))
1617     return failure();
1618 
1619   UnrolledOuterProductGenerator e(rewriter, op);
1620   FailureOr<Value> matmatRes = e.matmat();
1621   if (succeeded(matmatRes)) {
1622     rewriter.replaceOp(op, *matmatRes);
1623     return success();
1624   }
1625   FailureOr<Value> matvecRes = e.matvec();
1626   if (succeeded(matvecRes)) {
1627     rewriter.replaceOp(op, *matvecRes);
1628     return success();
1629   }
1630   FailureOr<Value> tmatvecRes = e.tmatvec();
1631   if (succeeded(tmatvecRes)) {
1632     rewriter.replaceOp(op, *tmatvecRes);
1633     return success();
1634   }
1635 
1636   return failure();
1637 }
1638 
1639 LogicalResult
matchAndRewrite(vector::ContractionOp op,PatternRewriter & rewriter) const1640 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
1641                                             PatternRewriter &rewriter) const {
1642   // TODO: implement masks
1643   if (llvm::size(op.getMasks()) != 0)
1644     return failure();
1645 
1646   if (failed(filter(op)))
1647     return failure();
1648 
1649   if (vectorTransformOptions.vectorContractLowering !=
1650       vector::VectorContractLowering::Dot)
1651     return failure();
1652 
1653   auto iteratorTypes = op.getIteratorTypes().getValue();
1654   static constexpr std::array<int64_t, 2> perm = {1, 0};
1655   Location loc = op.getLoc();
1656   Value lhs = op.getLhs(), rhs = op.getRhs();
1657 
1658   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1659   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1660   AffineExpr m, n, k;
1661   bindDims(rewriter.getContext(), m, n, k);
1662   SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1663   //
1664   // In the following we wish to make the reduction dimension innermost so we
1665   // can load vectors and just fmul + reduce into a scalar.
1666   //
1667   if (isParallelIterator(iteratorTypes[0]) &&
1668       isParallelIterator(iteratorTypes[1]) &&
1669       isReductionIterator(iteratorTypes[2])) {
1670     //
1671     // Two outer parallel, one inner reduction (matmat flavor).
1672     //
1673     if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1674       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1675     } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
1676       // No need to permute anything.
1677     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1678       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1679       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1680     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1681       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1682     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1683       // This is the classical row-major matmul. Just permute the lhs.
1684       Value tmp = lhs;
1685       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1686       rhs = tmp;
1687     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1688       std::swap(lhs, rhs);
1689     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1690       Value tmp = lhs;
1691       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1692       rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
1693     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1694       Value tmp = rhs;
1695       rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1696       lhs = tmp;
1697     } else {
1698       return failure();
1699     }
1700   } else if (isParallelIterator(iteratorTypes[0]) &&
1701              isReductionIterator(iteratorTypes[1])) {
1702     //
1703     // One outer parallel, one inner reduction (matvec flavor)
1704     //
1705     if (maps == infer({{m, n}, {n}, {m}})) {
1706       // No need to permute anything.
1707     } else if (maps == infer({{n, m}, {n}, {m}})) {
1708       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1709     } else if (maps == infer({{n}, {m, n}, {m}})) {
1710       std::swap(lhs, rhs);
1711     } else if (maps == infer({{n}, {n, m}, {m}})) {
1712       std::swap(lhs, rhs);
1713       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1714     } else {
1715       return failure();
1716     }
1717   } else {
1718     return failure();
1719   }
1720 
1721   VectorType dstType = op.getResultType().cast<VectorType>();
1722   assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
1723          "Expected dst type of rank 1 or 2");
1724 
1725   unsigned rank = dstType.getRank();
1726   unsigned dstRows = dstType.getShape()[0];
1727   unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
1728 
1729   // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
1730   Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
1731                                                  rewriter.getZeroAttr(dstType));
1732   bool isInt = dstType.getElementType().isa<IntegerType>();
1733   for (unsigned r = 0; r < dstRows; ++r) {
1734     Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
1735     for (unsigned c = 0; c < dstColumns; ++c) {
1736       Value b = rank == 1
1737                     ? rhs
1738                     : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
1739       Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
1740       Value reduced = rewriter.create<vector::ReductionOp>(
1741           op.getLoc(), vector::CombiningKind::ADD, m);
1742 
1743       SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
1744                                               : SmallVector<int64_t, 2>{r, c};
1745       res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
1746     }
1747   }
1748   if (auto acc = op.getAcc())
1749     res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
1750   rewriter.replaceOp(op, res);
1751   return success();
1752 }
1753 
1754 /// Progressive lowering of ContractionOp.
1755 /// One:
1756 ///   %x = vector.contract with at least one free/batch dimension
1757 /// is replaced by:
1758 ///   %a = vector.contract with one less free/batch dimension
1759 ///   %b = vector.contract with one less free/batch dimension
1760 ///   ..
1761 ///   %x = combine %a %b ..
1762 /// until a pure contraction is reached (no free/batch dimensions),
1763 /// which is replaced by a dot-product.
1764 ///
1765 /// This only kicks in when either VectorTransformsOptions is set
1766 /// to DOT or when other contraction patterns fail.
1767 //
1768 // TODO: break down into transpose/reshape/cast ops
1769 //               when they become available to avoid code dup
1770 // TODO: investigate lowering order impact on performance
1771 LogicalResult
matchAndRewrite(vector::ContractionOp op,PatternRewriter & rewriter) const1772 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
1773                                        PatternRewriter &rewriter) const {
1774   // TODO: implement masks.
1775   if (llvm::size(op.getMasks()) != 0)
1776     return failure();
1777 
1778   if (failed(filter(op)))
1779     return failure();
1780 
1781   // TODO: support mixed mode contract lowering.
1782   if (op.getLhsType().getElementType() !=
1783           getElementTypeOrSelf(op.getAccType()) ||
1784       op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
1785     return failure();
1786 
1787   // TODO: implement benefits, cost models.
1788   MLIRContext *ctx = op.getContext();
1789   ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
1790   if (succeeded(pat1.matchAndRewrite(op, rewriter)))
1791     return success();
1792   ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
1793   if (succeeded(pat2.matchAndRewrite(op, rewriter)))
1794     return success();
1795   ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
1796   if (succeeded(pat3.matchAndRewrite(op, rewriter)))
1797     return success();
1798   ContractOpToElementwise pat4(vectorTransformOptions, ctx);
1799   if (succeeded(pat4.matchAndRewrite(op, rewriter)))
1800     return success();
1801 
1802   // Find first batch dimension in LHS/RHS, and lower when found.
1803   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
1804   if (!batchDimMap.empty()) {
1805     int64_t lhsIndex = batchDimMap[0].first;
1806     int64_t rhsIndex = batchDimMap[0].second;
1807     auto newOp = lowerParallel(op, lhsIndex, rhsIndex, rewriter);
1808     if (failed(newOp))
1809       return failure();
1810     rewriter.replaceOp(op, newOp.value());
1811     return success();
1812   }
1813 
1814   // Collect contracting dimensions.
1815   std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
1816       op.getContractingDimMap();
1817   DenseSet<int64_t> lhsContractingDimSet;
1818   DenseSet<int64_t> rhsContractingDimSet;
1819   for (auto &dimPair : contractingDimMap) {
1820     lhsContractingDimSet.insert(dimPair.first);
1821     rhsContractingDimSet.insert(dimPair.second);
1822   }
1823 
1824   // Find first free dimension in LHS, and lower when found.
1825   VectorType lhsType = op.getLhsType();
1826   for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
1827     if (lhsContractingDimSet.count(lhsIndex) == 0) {
1828       auto newOp = lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter);
1829       if (failed(newOp))
1830         return failure();
1831       rewriter.replaceOp(op, newOp.value());
1832       return success();
1833     }
1834   }
1835 
1836   // Find first free dimension in RHS, and lower when found.
1837   VectorType rhsType = op.getRhsType();
1838   for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
1839     if (rhsContractingDimSet.count(rhsIndex) == 0) {
1840       auto newOp = lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter);
1841       if (failed(newOp))
1842         return failure();
1843       rewriter.replaceOp(op, newOp.value());
1844       return success();
1845     }
1846   }
1847 
1848   // Lower the first remaining reduction dimension.
1849   if (!contractingDimMap.empty()) {
1850     auto newOp = lowerReduction(op, rewriter);
1851     if (failed(newOp))
1852       return failure();
1853     rewriter.replaceOp(op, newOp.value());
1854     return success();
1855   }
1856 
1857   return failure();
1858 }
1859 
1860 // Lower one parallel dimension.
1861 // Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
1862 // TODO: consider reusing existing contract unrolling
1863 FailureOr<Value>
lowerParallel(vector::ContractionOp op,int64_t lhsIndex,int64_t rhsIndex,PatternRewriter & rewriter) const1864 ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
1865                                      int64_t rhsIndex,
1866                                      PatternRewriter &rewriter) const {
1867   VectorType lhsType = op.getLhsType();
1868   VectorType rhsType = op.getRhsType();
1869   VectorType resType = op.getResultType().cast<VectorType>();
1870   // Find the iterator type index and result index.
1871   SmallVector<AffineMap, 4> iMap = op.getIndexingMapsArray();
1872   int64_t iterIndex = -1;
1873   int64_t dimSize = -1;
1874   if (lhsIndex >= 0) {
1875     iterIndex = iMap[0].getDimPosition(lhsIndex);
1876     if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
1877       return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1878         diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
1879              << " to map to the same dimension";
1880       });
1881     dimSize = lhsType.getDimSize(lhsIndex);
1882   } else if (rhsIndex >= 0) {
1883     iterIndex = iMap[1].getDimPosition(rhsIndex);
1884     dimSize = rhsType.getDimSize(rhsIndex);
1885   }
1886   if (iterIndex < 0)
1887     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1888       diag << "expected either lhsIndex=" << lhsIndex
1889            << " or rhsIndex=" << rhsIndex << " to be nonnegative";
1890     });
1891   // value_or(-1) means that we tolerate a dimension not appearing
1892   // in the result map. That can't happen for actual parallel iterators, but
1893   // the caller ContractionOpLowering::matchAndRewrite is currently calling
1894   // lowerParallel also for the case of unit-size reduction dims appearing only
1895   // on one of LHS or RHS, not both. At the moment, such cases are created by
1896   // CastAwayContractionLeadingOneDim, so we need to either support that or
1897   // modify that pattern.
1898   int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
1899   if (resIndex == -1 && dimSize != 1)
1900     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1901       diag << "expected the dimension for iterIndex=" << iterIndex
1902            << " to either appear in the result map, or to be a unit dimension";
1903     });
1904   // Construct new iterator types and affine map array attribute.
1905   std::array<AffineMap, 3> lowIndexingMaps = {
1906       adjustMap(iMap[0], iterIndex, rewriter),
1907       adjustMap(iMap[1], iterIndex, rewriter),
1908       adjustMap(iMap[2], iterIndex, rewriter)};
1909   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1910   auto lowIter =
1911       rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1912   // Unroll into a series of lower dimensional vector.contract ops.
1913   Location loc = op.getLoc();
1914   Value result = rewriter.create<arith::ConstantOp>(
1915       loc, resType, rewriter.getZeroAttr(resType));
1916   for (int64_t d = 0; d < dimSize; ++d) {
1917     auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1918     auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1919     auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
1920     Value lowContract = rewriter.create<vector::ContractionOp>(
1921         loc, lhs, rhs, acc, lowAffine, lowIter);
1922     result =
1923         reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
1924   }
1925   return result;
1926 }
1927 
1928 // Lower one reduction dimension.
1929 FailureOr<Value>
lowerReduction(vector::ContractionOp op,PatternRewriter & rewriter) const1930 ContractionOpLowering::lowerReduction(vector::ContractionOp op,
1931                                       PatternRewriter &rewriter) const {
1932   auto loc = op.getLoc();
1933   VectorType lhsType = op.getLhsType();
1934   VectorType rhsType = op.getRhsType();
1935   Type resType = op.getResultType();
1936   if (resType.isa<VectorType>())
1937     return rewriter.notifyMatchFailure(op,
1938                                        "did not expect a VectorType result");
1939   bool isInt = resType.isa<IntegerType>();
1940   // Use iterator index 0.
1941   int64_t iterIndex = 0;
1942   SmallVector<AffineMap, 4> iMap = op.getIndexingMapsArray();
1943   Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1944   Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1945   if (!lookupLhs.has_value())
1946     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1947       diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
1948     });
1949   if (!lookupRhs.has_value())
1950     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1951       diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
1952     });
1953   int64_t lhsIndex = lookupLhs.value();
1954   int64_t rhsIndex = lookupRhs.value();
1955   int64_t dimSize = lhsType.getDimSize(lhsIndex);
1956   if (dimSize != rhsType.getDimSize(rhsIndex))
1957     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1958       diag << "expect LHS dimension " << lhsIndex
1959            << " to have the same size as RHS dimension " << rhsIndex;
1960     });
1961   // Base case.
1962   if (lhsType.getRank() == 1) {
1963     if (rhsType.getRank() != 1)
1964       return rewriter.notifyMatchFailure(
1965           op, "When LHS has rank 1, expected also RHS to have rank 1");
1966     Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1967     auto kind = vector::CombiningKind::ADD;
1968     if (auto acc = op.getAcc())
1969       return rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
1970           .getResult();
1971     return rewriter.create<vector::ReductionOp>(loc, kind, m).getResult();
1972   }
1973   // Construct new iterator types and affine map array attribute.
1974   std::array<AffineMap, 3> lowIndexingMaps = {
1975       adjustMap(iMap[0], iterIndex, rewriter),
1976       adjustMap(iMap[1], iterIndex, rewriter),
1977       adjustMap(iMap[2], iterIndex, rewriter)};
1978   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1979   auto lowIter =
1980       rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1981   // Unroll into a series of lower dimensional vector.contract ops.
1982   // By feeding the initial accumulator into the first contraction,
1983   // and the result of each contraction into the next, eventually
1984   // the sum of all reductions is computed.
1985   Value result = op.getAcc();
1986   for (int64_t d = 0; d < dimSize; ++d) {
1987     auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1988     auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1989     result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
1990                                                     lowAffine, lowIter);
1991   }
1992   return result;
1993 }
1994 
1995 } // namespace mlir
1996 
distributPointwiseVectorOp(OpBuilder & builder,Operation * op,ArrayRef<Value> ids,ArrayRef<int64_t> multiplicity,const AffineMap & map)1997 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
1998     OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
1999     ArrayRef<int64_t> multiplicity, const AffineMap &map) {
2000   OpBuilder::InsertionGuard guard(builder);
2001   builder.setInsertionPointAfter(op);
2002   Location loc = op->getLoc();
2003   if (op->getNumResults() != 1)
2004     return {};
2005   Value result = op->getResult(0);
2006   VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
2007   if (!type || map.getNumResults() != multiplicity.size())
2008     return {};
2009   // For each dimension being distributed check that the size is a multiple of
2010   // the multiplicity. To handle more sizes we would need to support masking.
2011   unsigned multiplictyCount = 0;
2012   for (auto exp : map.getResults()) {
2013     auto affinExp = exp.dyn_cast<AffineDimExpr>();
2014     if (!affinExp || affinExp.getPosition() >= type.getRank() ||
2015         type.getDimSize(affinExp.getPosition()) %
2016                 multiplicity[multiplictyCount++] !=
2017             0)
2018       return {};
2019   }
2020   DistributeOps ops;
2021   ops.extract =
2022       builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
2023   ops.insert =
2024       builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
2025   return ops;
2026 }
2027 
2028 /// Progressive lowering of transfer_read. This pattern supports lowering of
2029 /// `vector.transfer_read` to a combination of `vector.load` and
2030 /// `vector.broadcast` if all of the following hold:
2031 /// - Stride of most minor memref dimension must be 1.
2032 /// - Out-of-bounds masking is not required.
2033 /// - If the memref's element type is a vector type then it coincides with the
2034 ///   result type.
2035 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
2036 struct TransferReadToVectorLoadLowering
2037     : public OpRewritePattern<vector::TransferReadOp> {
TransferReadToVectorLoadLoweringTransferReadToVectorLoadLowering2038   TransferReadToVectorLoadLowering(MLIRContext *context,
2039                                    llvm::Optional<unsigned> maxRank)
2040       : OpRewritePattern<vector::TransferReadOp>(context),
2041         maxTransferRank(maxRank) {}
2042 
matchAndRewriteTransferReadToVectorLoadLowering2043   LogicalResult matchAndRewrite(vector::TransferReadOp read,
2044                                 PatternRewriter &rewriter) const override {
2045     if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
2046       return failure();
2047 
2048     SmallVector<unsigned, 4> broadcastedDims;
2049     // Permutations are handled by VectorToSCF or
2050     // populateVectorTransferPermutationMapLoweringPatterns.
2051     // We let the 0-d corner case pass-through as it is supported.
2052     if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
2053             &broadcastedDims))
2054       return failure();
2055 
2056     auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
2057     if (!memRefType)
2058       return failure();
2059 
2060     // Non-unit strides are handled by VectorToSCF.
2061     if (!vector::isLastMemrefDimUnitStride(memRefType))
2062       return failure();
2063 
2064     // If there is broadcasting involved then we first load the unbroadcasted
2065     // vector, and then broadcast it with `vector.broadcast`.
2066     ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
2067     SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
2068                                                      vectorShape.end());
2069     for (unsigned i : broadcastedDims)
2070       unbroadcastedVectorShape[i] = 1;
2071     VectorType unbroadcastedVectorType = VectorType::get(
2072         unbroadcastedVectorShape, read.getVectorType().getElementType());
2073 
2074     // `vector.load` supports vector types as memref's elements only when the
2075     // resulting vector type is the same as the element type.
2076     auto memrefElTy = memRefType.getElementType();
2077     if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
2078       return failure();
2079 
2080     // Otherwise, element types of the memref and the vector must match.
2081     if (!memrefElTy.isa<VectorType>() &&
2082         memrefElTy != read.getVectorType().getElementType())
2083       return failure();
2084 
2085     // Out-of-bounds dims are handled by MaterializeTransferMask.
2086     if (read.hasOutOfBoundsDim())
2087       return failure();
2088 
2089     // Create vector load op.
2090     Operation *loadOp;
2091     if (read.getMask()) {
2092       Value fill = rewriter.create<vector::SplatOp>(
2093           read.getLoc(), unbroadcastedVectorType, read.getPadding());
2094       loadOp = rewriter.create<vector::MaskedLoadOp>(
2095           read.getLoc(), unbroadcastedVectorType, read.getSource(),
2096           read.getIndices(), read.getMask(), fill);
2097     } else {
2098       loadOp = rewriter.create<vector::LoadOp>(
2099           read.getLoc(), unbroadcastedVectorType, read.getSource(),
2100           read.getIndices());
2101     }
2102 
2103     // Insert a broadcasting op if required.
2104     if (!broadcastedDims.empty()) {
2105       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2106           read, read.getVectorType(), loadOp->getResult(0));
2107     } else {
2108       rewriter.replaceOp(read, loadOp->getResult(0));
2109     }
2110 
2111     return success();
2112   }
2113 
2114   llvm::Optional<unsigned> maxTransferRank;
2115 };
2116 
2117 /// Replace a 0-d vector.load with a memref.load + vector.broadcast.
2118 // TODO: we shouldn't cross the vector/scalar domains just for this
2119 // but atm we lack the infra to avoid it. Possible solutions include:
2120 // - go directly to LLVM + bitcast
2121 // - introduce a bitcast op and likely a new pointer dialect
2122 // - let memref.load/store additionally support the 0-d vector case
2123 // There are still deeper data layout issues lingering even in this
2124 // trivial case (for architectures for which this matters).
2125 struct VectorLoadToMemrefLoadLowering
2126     : public OpRewritePattern<vector::LoadOp> {
2127   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
2128 
matchAndRewriteVectorLoadToMemrefLoadLowering2129   LogicalResult matchAndRewrite(vector::LoadOp loadOp,
2130                                 PatternRewriter &rewriter) const override {
2131     auto vecType = loadOp.getVectorType();
2132     if (vecType.getNumElements() != 1)
2133       return failure();
2134     auto memrefLoad = rewriter.create<memref::LoadOp>(
2135         loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
2136     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
2137                                                      memrefLoad);
2138     return success();
2139   }
2140 };
2141 
2142 /// Replace a 0-d vector.store with a vector.extractelement + memref.store.
2143 struct VectorStoreToMemrefStoreLowering
2144     : public OpRewritePattern<vector::StoreOp> {
2145   using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
2146 
matchAndRewriteVectorStoreToMemrefStoreLowering2147   LogicalResult matchAndRewrite(vector::StoreOp storeOp,
2148                                 PatternRewriter &rewriter) const override {
2149     auto vecType = storeOp.getVectorType();
2150     if (vecType.getNumElements() != 1)
2151       return failure();
2152     Value extracted;
2153     if (vecType.getRank() == 0) {
2154       // TODO: Unifiy once ExtractOp supports 0-d vectors.
2155       extracted = rewriter.create<vector::ExtractElementOp>(
2156           storeOp.getLoc(), storeOp.getValueToStore());
2157     } else {
2158       SmallVector<int64_t> indices(vecType.getRank(), 0);
2159       extracted = rewriter.create<vector::ExtractOp>(
2160           storeOp.getLoc(), storeOp.getValueToStore(), indices);
2161     }
2162 
2163     rewriter.replaceOpWithNewOp<memref::StoreOp>(
2164         storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
2165     return success();
2166   }
2167 };
2168 
2169 /// Progressive lowering of transfer_write. This pattern supports lowering of
2170 /// `vector.transfer_write` to `vector.store` if all of the following hold:
2171 /// - Stride of most minor memref dimension must be 1.
2172 /// - Out-of-bounds masking is not required.
2173 /// - If the memref's element type is a vector type then it coincides with the
2174 ///   type of the written value.
2175 /// - The permutation map is the minor identity map (neither permutation nor
2176 ///   broadcasting is allowed).
2177 struct TransferWriteToVectorStoreLowering
2178     : public OpRewritePattern<vector::TransferWriteOp> {
TransferWriteToVectorStoreLoweringTransferWriteToVectorStoreLowering2179   TransferWriteToVectorStoreLowering(MLIRContext *context,
2180                                      llvm::Optional<unsigned> maxRank)
2181       : OpRewritePattern<vector::TransferWriteOp>(context),
2182         maxTransferRank(maxRank) {}
2183 
matchAndRewriteTransferWriteToVectorStoreLowering2184   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
2185                                 PatternRewriter &rewriter) const override {
2186     if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
2187       return failure();
2188 
2189     // Permutations are handled by VectorToSCF or
2190     // populateVectorTransferPermutationMapLoweringPatterns.
2191     if ( // pass-through for the 0-d corner case.
2192         !write.getPermutationMap().isMinorIdentity())
2193       return failure();
2194 
2195     auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
2196     if (!memRefType)
2197       return failure();
2198 
2199     // Non-unit strides are handled by VectorToSCF.
2200     if (!vector::isLastMemrefDimUnitStride(memRefType))
2201       return failure();
2202 
2203     // `vector.store` supports vector types as memref's elements only when the
2204     // type of the vector value being written is the same as the element type.
2205     auto memrefElTy = memRefType.getElementType();
2206     if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
2207       return failure();
2208 
2209     // Otherwise, element types of the memref and the vector must match.
2210     if (!memrefElTy.isa<VectorType>() &&
2211         memrefElTy != write.getVectorType().getElementType())
2212       return failure();
2213 
2214     // Out-of-bounds dims are handled by MaterializeTransferMask.
2215     if (write.hasOutOfBoundsDim())
2216       return failure();
2217     if (write.getMask()) {
2218       rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
2219           write, write.getSource(), write.getIndices(), write.getMask(),
2220           write.getVector());
2221     } else {
2222       rewriter.replaceOpWithNewOp<vector::StoreOp>(
2223           write, write.getVector(), write.getSource(), write.getIndices());
2224     }
2225     return success();
2226   }
2227 
2228   llvm::Optional<unsigned> maxTransferRank;
2229 };
2230 
2231 // Returns the values in `arrayAttr` as an integer vector.
getIntValueVector(ArrayAttr arrayAttr)2232 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
2233   return llvm::to_vector<4>(
2234       llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
2235                       [](IntegerAttr attr) { return attr.getInt(); }));
2236 }
2237 
2238 // Shuffles vector.bitcast op after vector.extract op.
2239 //
2240 // This transforms IR like:
2241 //   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
2242 //   %1 = vector.extract %0[3] : vector<8xf16>
2243 // Into:
2244 //   %0 = vector.extract %src[1] : vector<4xf32>
2245 //   %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
2246 //   %2 = vector.extract %1[1] : vector<2xf16>
2247 struct BubbleDownVectorBitCastForExtract
2248     : public OpRewritePattern<vector::ExtractOp> {
2249   using OpRewritePattern::OpRewritePattern;
2250 
matchAndRewriteBubbleDownVectorBitCastForExtract2251   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2252                                 PatternRewriter &rewriter) const override {
2253     // Only support extracting scalars for now.
2254     if (extractOp.getVectorType().getRank() != 1)
2255       return failure();
2256 
2257     auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
2258     if (!castOp)
2259       return failure();
2260 
2261     VectorType castSrcType = castOp.getSourceVectorType();
2262     VectorType castDstType = castOp.getResultVectorType();
2263     assert(castSrcType.getRank() == castDstType.getRank());
2264 
2265     // Fail to match if we only have one element in the cast op source.
2266     // This is to avoid infinite loop given that this pattern can generate
2267     // such cases.
2268     if (castSrcType.getNumElements() == 1)
2269       return failure();
2270 
2271     // Only support casting to a larger number of elements or now.
2272     // E.g., vector<4xf32> -> vector<8xf16>.
2273     if (castSrcType.getNumElements() > castDstType.getNumElements())
2274       return failure();
2275 
2276     unsigned expandRatio =
2277         castDstType.getNumElements() / castSrcType.getNumElements();
2278 
2279     auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
2280       return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
2281     };
2282 
2283     uint64_t index = getFirstIntValue(extractOp.getPosition());
2284 
2285     // Get the single scalar (as a vector) in the source value that packs the
2286     // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
2287     VectorType oneScalarType =
2288         VectorType::get({1}, castSrcType.getElementType());
2289     Value packedValue = rewriter.create<vector::ExtractOp>(
2290         extractOp.getLoc(), oneScalarType, castOp.getSource(),
2291         rewriter.getI64ArrayAttr(index / expandRatio));
2292 
2293     // Cast it to a vector with the desired scalar's type.
2294     // E.g. f32 -> vector<2xf16>
2295     VectorType packedType =
2296         VectorType::get({expandRatio}, castDstType.getElementType());
2297     Value castedValue = rewriter.create<vector::BitCastOp>(
2298         extractOp.getLoc(), packedType, packedValue);
2299 
2300     // Finally extract the desired scalar.
2301     rewriter.replaceOpWithNewOp<vector::ExtractOp>(
2302         extractOp, extractOp.getType(), castedValue,
2303         rewriter.getI64ArrayAttr(index % expandRatio));
2304 
2305     return success();
2306   }
2307 };
2308 
2309 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
2310 //
2311 // This transforms IR like:
2312 //    %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
2313 //     %0 = vector.extract_strided_slice %cast {
2314 //            offsets = [4], sizes = [4], strides = [1]
2315 //          } : vector<8xf16> to vector<4xf16>
2316 // Into:
2317 //   %0 = vector.extract_strided_slice %src {
2318 //          offsets = [2], sizes = [2], strides = [1]
2319 //        } : vector<4xf32> to vector<2xf32>
2320 //   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
2321 struct BubbleDownBitCastForStridedSliceExtract
2322     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
2323   using OpRewritePattern::OpRewritePattern;
2324 
matchAndRewriteBubbleDownBitCastForStridedSliceExtract2325   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
2326                                 PatternRewriter &rewriter) const override {
2327     auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
2328     if (!castOp)
2329       return failure();
2330 
2331     VectorType castSrcType = castOp.getSourceVectorType();
2332     VectorType castDstType = castOp.getResultVectorType();
2333     assert(castSrcType.getRank() == castDstType.getRank());
2334 
2335     int64_t castSrcLastDim = castSrcType.getShape().back();
2336     int64_t castDstLastDim = castDstType.getShape().back();
2337     // Require casting to more elements for now; other cases to be implemented.
2338     if (castSrcLastDim > castDstLastDim)
2339       return failure();
2340 
2341     // Only accept all one strides for now.
2342     if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
2343                      [](const APInt &val) { return !val.isOneValue(); }))
2344       return failure();
2345 
2346     unsigned rank = extractOp.getVectorType().getRank();
2347     assert(castDstLastDim % castSrcLastDim == 0);
2348     int64_t expandRatio = castDstLastDim / castSrcLastDim;
2349 
2350     // If we have a less number of offsets than the rank, then implicitly we
2351     // are selecting the full range for the last bitcasted dimension; other
2352     // dimensions aren't affected. Otherwise, we need to scale down the last
2353     // dimension's offset given we are extracting from less elements now.
2354     ArrayAttr newOffsets = extractOp.getOffsets();
2355     if (newOffsets.size() == rank) {
2356       SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2357       if (offsets.back() % expandRatio != 0)
2358         return failure();
2359       offsets.back() = offsets.back() / expandRatio;
2360       newOffsets = rewriter.getI64ArrayAttr(offsets);
2361     }
2362 
2363     // Similarly for sizes.
2364     ArrayAttr newSizes = extractOp.getSizes();
2365     if (newSizes.size() == rank) {
2366       SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
2367       if (sizes.back() % expandRatio != 0)
2368         return failure();
2369       sizes.back() = sizes.back() / expandRatio;
2370       newSizes = rewriter.getI64ArrayAttr(sizes);
2371     }
2372 
2373     SmallVector<int64_t, 4> dims =
2374         llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
2375     dims.back() = dims.back() / expandRatio;
2376     VectorType newExtractType =
2377         VectorType::get(dims, castSrcType.getElementType());
2378 
2379     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
2380         extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
2381         newSizes, extractOp.getStrides());
2382 
2383     rewriter.replaceOpWithNewOp<vector::BitCastOp>(
2384         extractOp, extractOp.getType(), newExtractOp);
2385 
2386     return success();
2387   }
2388 };
2389 
2390 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
2391 //
2392 // This transforms IR like:
2393 //   %0 = vector.insert_strided_slice %src, %dst {
2394 //          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
2395 //   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
2396 // Into:
2397 //   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
2398 //   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
2399 //   %2 = vector.insert_strided_slice %src, %dst {
2400 //          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
2401 struct BubbleUpBitCastForStridedSliceInsert
2402     : public OpRewritePattern<vector::BitCastOp> {
2403   using OpRewritePattern::OpRewritePattern;
matchAndRewriteBubbleUpBitCastForStridedSliceInsert2404   LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
2405                                 PatternRewriter &rewriter) const override {
2406     VectorType castSrcType = bitcastOp.getSourceVectorType();
2407     VectorType castDstType = bitcastOp.getResultVectorType();
2408     assert(castSrcType.getRank() == castDstType.getRank());
2409 
2410     int64_t castSrcLastDim = castSrcType.getShape().back();
2411     int64_t castDstLastDim = castDstType.getShape().back();
2412     // Require casting to less elements for now; other cases to be implemented.
2413     if (castSrcLastDim < castDstLastDim)
2414       return failure();
2415 
2416     assert(castSrcLastDim % castDstLastDim == 0);
2417     int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
2418 
2419     auto insertOp =
2420         bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
2421     if (!insertOp)
2422       return failure();
2423 
2424     // Only accept all one strides for now.
2425     if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
2426                      [](const APInt &val) { return !val.isOneValue(); }))
2427       return failure();
2428 
2429     unsigned rank = insertOp.getSourceVectorType().getRank();
2430     // Require insert op to have the same rank for the source and destination
2431     // vector; other cases to be implemented.
2432     if (rank != insertOp.getDestVectorType().getRank())
2433       return failure();
2434 
2435     ArrayAttr newOffsets = insertOp.getOffsets();
2436     assert(newOffsets.size() == rank);
2437     SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2438     if (offsets.back() % shrinkRatio != 0)
2439       return failure();
2440     offsets.back() = offsets.back() / shrinkRatio;
2441     newOffsets = rewriter.getI64ArrayAttr(offsets);
2442 
2443     SmallVector<int64_t, 4> srcDims =
2444         llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
2445     srcDims.back() = srcDims.back() / shrinkRatio;
2446     VectorType newCastSrcType =
2447         VectorType::get(srcDims, castDstType.getElementType());
2448 
2449     auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
2450         bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
2451 
2452     SmallVector<int64_t, 4> dstDims =
2453         llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
2454     dstDims.back() = dstDims.back() / shrinkRatio;
2455     VectorType newCastDstType =
2456         VectorType::get(dstDims, castDstType.getElementType());
2457 
2458     auto newCastDstOp = rewriter.create<vector::BitCastOp>(
2459         bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
2460 
2461     rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
2462         bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
2463         insertOp.getStrides());
2464 
2465     return success();
2466   }
2467 };
2468 
2469 // Helper that returns a vector comparison that constructs a mask:
2470 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
2471 //
2472 // If `dim == 0` then the result will be a 0-D vector.
2473 //
2474 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
2475 //       much more compact, IR for this operation, but LLVM eventually
2476 //       generates more elaborate instructions for this intrinsic since it
2477 //       is very conservative on the boundary conditions.
buildVectorComparison(PatternRewriter & rewriter,Operation * op,bool force32BitVectorIndices,int64_t dim,Value b,Value * off=nullptr)2478 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
2479                                    bool force32BitVectorIndices, int64_t dim,
2480                                    Value b, Value *off = nullptr) {
2481   auto loc = op->getLoc();
2482   // If we can assume all indices fit in 32-bit, we perform the vector
2483   // comparison in 32-bit to get a higher degree of SIMD parallelism.
2484   // Otherwise we perform the vector comparison using 64-bit indices.
2485   Type idxType =
2486       force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
2487   DenseIntElementsAttr indicesAttr;
2488   if (dim == 0 && force32BitVectorIndices) {
2489     indicesAttr = DenseIntElementsAttr::get(
2490         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
2491   } else if (dim == 0) {
2492     indicesAttr = DenseIntElementsAttr::get(
2493         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
2494   } else if (force32BitVectorIndices) {
2495     indicesAttr = rewriter.getI32VectorAttr(
2496         llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
2497   } else {
2498     indicesAttr = rewriter.getI64VectorAttr(
2499         llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
2500   }
2501   Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
2502   // Add in an offset if requested.
2503   if (off) {
2504     Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
2505     Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
2506     indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
2507   }
2508   // Construct the vector comparison.
2509   Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
2510   Value bounds =
2511       rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
2512   return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
2513                                         bounds);
2514 }
2515 
2516 template <typename ConcreteOp>
2517 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
2518 public:
MaterializeTransferMaskMaterializeTransferMask2519   explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
2520       : mlir::OpRewritePattern<ConcreteOp>(context),
2521         force32BitVectorIndices(enableIndexOpt) {}
2522 
matchAndRewriteMaterializeTransferMask2523   LogicalResult matchAndRewrite(ConcreteOp xferOp,
2524                                 PatternRewriter &rewriter) const override {
2525     if (!xferOp.hasOutOfBoundsDim())
2526       return failure();
2527 
2528     if (xferOp.getVectorType().getRank() > 1 ||
2529         llvm::size(xferOp.getIndices()) == 0)
2530       return failure();
2531 
2532     Location loc = xferOp->getLoc();
2533     VectorType vtp = xferOp.getVectorType();
2534 
2535     // Create the in-bounds mask with all elements between [0 .. dim - offset)
2536     // set and [dim - offset .. vector_length) unset.
2537     //
2538     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
2539     //       dimensions here.
2540     unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
2541     Value off = xferOp.getIndices()[lastIndex];
2542     Value dim =
2543         vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
2544     Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
2545     Value mask = rewriter.create<vector::CreateMaskOp>(
2546         loc,
2547         VectorType::get(vtp.getShape(), rewriter.getI1Type(),
2548                         vtp.getNumScalableDims()),
2549         b);
2550     if (xferOp.getMask()) {
2551       // Intersect the in-bounds with the mask specified as an op parameter.
2552       mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
2553     }
2554 
2555     rewriter.updateRootInPlace(xferOp, [&]() {
2556       xferOp.getMaskMutable().assign(mask);
2557       xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
2558     });
2559 
2560     return success();
2561   }
2562 
2563 private:
2564   const bool force32BitVectorIndices;
2565 };
2566 
2567 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
2568 class VectorCreateMaskOpConversion
2569     : public OpRewritePattern<vector::CreateMaskOp> {
2570 public:
VectorCreateMaskOpConversion(MLIRContext * context,bool enableIndexOpt)2571   explicit VectorCreateMaskOpConversion(MLIRContext *context,
2572                                         bool enableIndexOpt)
2573       : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
2574         force32BitVectorIndices(enableIndexOpt) {}
2575 
matchAndRewrite(vector::CreateMaskOp op,PatternRewriter & rewriter) const2576   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
2577                                 PatternRewriter &rewriter) const override {
2578     auto dstType = op.getType();
2579     if (dstType.cast<VectorType>().isScalable())
2580       return failure();
2581     int64_t rank = dstType.getRank();
2582     if (rank > 1)
2583       return failure();
2584     rewriter.replaceOp(
2585         op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
2586                                   rank == 0 ? 0 : dstType.getDimSize(0),
2587                                   op.getOperand(0)));
2588     return success();
2589   }
2590 
2591 private:
2592   const bool force32BitVectorIndices;
2593 };
2594 
2595 // Drop inner most contiguous unit dimensions from transfer_read operand.
2596 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
2597   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
2598 
matchAndRewrite(vector::TransferReadOp readOp,PatternRewriter & rewriter) const2599   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
2600                                 PatternRewriter &rewriter) const override {
2601     // TODO: support 0-d corner case.
2602     if (readOp.getTransferRank() == 0)
2603       return failure();
2604 
2605     // TODO: support mask.
2606     if (readOp.getMask())
2607       return failure();
2608 
2609     auto srcType = readOp.getSource().getType().dyn_cast<MemRefType>();
2610     if (!srcType || !srcType.hasStaticShape())
2611       return failure();
2612 
2613     if (!readOp.getPermutationMap().isMinorIdentity())
2614       return failure();
2615 
2616     auto targetType = readOp.getVectorType();
2617     if (targetType.getRank() <= 1)
2618       return failure();
2619 
2620     SmallVector<int64_t> srcStrides;
2621     int64_t srcOffset;
2622     if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2623       return failure();
2624 
2625     size_t dimsToDrop = 0;
2626     for (size_t i = 1; i < srcStrides.size(); ++i) {
2627       int dim = srcType.getRank() - i - 1;
2628       if (srcStrides[dim] == 1) {
2629         dimsToDrop++;
2630       } else {
2631         break;
2632       }
2633     }
2634     if (dimsToDrop == 0)
2635       return failure();
2636 
2637     auto resultTargetVecType =
2638         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
2639                         targetType.getElementType());
2640 
2641     MemRefType resultMemrefType;
2642     if (srcType.getLayout().getAffineMap().isIdentity()) {
2643       resultMemrefType = MemRefType::get(
2644           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2645           {}, srcType.getMemorySpaceAsInt());
2646     } else {
2647       AffineMap map = srcType.getLayout().getAffineMap();
2648       int numSymbols = map.getNumSymbols();
2649       for (size_t i = 0; i < dimsToDrop; ++i) {
2650         int dim = srcType.getRank() - i - 1;
2651         map = map.replace(rewriter.getAffineDimExpr(dim),
2652                           rewriter.getAffineConstantExpr(0),
2653                           map.getNumDims() - 1, numSymbols);
2654       }
2655       resultMemrefType = MemRefType::get(
2656           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2657           map, srcType.getMemorySpaceAsInt());
2658     }
2659 
2660     auto loc = readOp.getLoc();
2661     SmallVector<int64_t> offsets(srcType.getRank(), 0);
2662     SmallVector<int64_t> strides(srcType.getRank(), 1);
2663 
2664     ArrayAttr inBoundsAttr =
2665         readOp.getInBounds()
2666             ? rewriter.getArrayAttr(
2667                   readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
2668             : ArrayAttr();
2669     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
2670         loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
2671         strides);
2672     auto permMap = getTransferMinorIdentityMap(
2673         rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
2674     Value result = rewriter.create<vector::TransferReadOp>(
2675         loc, resultTargetVecType, rankedReducedView,
2676         readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
2677         readOp.getPadding(),
2678         // TODO: support mask.
2679         /*mask=*/Value(), inBoundsAttr);
2680     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
2681                                                      result);
2682     return success();
2683   }
2684 };
2685 
2686 namespace {
2687 
2688 /// This function checks to see if the vector combining kind
2689 /// is consistent with the integer or float element type.
isValidKind(bool isInt,vector::CombiningKind kind)2690 static bool isValidKind(bool isInt, vector::CombiningKind kind) {
2691   using vector::CombiningKind;
2692   enum class KindType { FLOAT, INT, INVALID };
2693   KindType type{KindType::INVALID};
2694   switch (kind) {
2695   case CombiningKind::MINF:
2696   case CombiningKind::MAXF:
2697     type = KindType::FLOAT;
2698     break;
2699   case CombiningKind::MINUI:
2700   case CombiningKind::MINSI:
2701   case CombiningKind::MAXUI:
2702   case CombiningKind::MAXSI:
2703   case CombiningKind::AND:
2704   case CombiningKind::OR:
2705   case CombiningKind::XOR:
2706     type = KindType::INT;
2707     break;
2708   case CombiningKind::ADD:
2709   case CombiningKind::MUL:
2710     type = isInt ? KindType::INT : KindType::FLOAT;
2711     break;
2712   }
2713   bool isValidIntKind = (type == KindType::INT) && isInt;
2714   bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
2715   return (isValidIntKind || isValidFloatKind);
2716 }
2717 
2718 /// This function constructs the appropriate integer or float
2719 /// operation given the vector combining kind and operands. The
2720 /// supported int operations are : add, mul, min (signed/unsigned),
2721 /// max(signed/unsigned), and, or, xor. The supported float
2722 /// operations are : add, mul, min and max.
genOperator(Location loc,Value x,Value y,vector::CombiningKind kind,PatternRewriter & rewriter)2723 static Value genOperator(Location loc, Value x, Value y,
2724                          vector::CombiningKind kind,
2725                          PatternRewriter &rewriter) {
2726   using vector::CombiningKind;
2727 
2728   auto elType = x.getType().cast<VectorType>().getElementType();
2729   bool isInt = elType.isIntOrIndex();
2730 
2731   Value combinedResult{nullptr};
2732   switch (kind) {
2733   case CombiningKind::ADD:
2734     if (isInt)
2735       combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
2736     else
2737       combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
2738     break;
2739   case CombiningKind::MUL:
2740     if (isInt)
2741       combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
2742     else
2743       combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
2744     break;
2745   case CombiningKind::MINUI:
2746     combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
2747     break;
2748   case CombiningKind::MINSI:
2749     combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
2750     break;
2751   case CombiningKind::MAXUI:
2752     combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
2753     break;
2754   case CombiningKind::MAXSI:
2755     combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
2756     break;
2757   case CombiningKind::AND:
2758     combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
2759     break;
2760   case CombiningKind::OR:
2761     combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
2762     break;
2763   case CombiningKind::XOR:
2764     combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
2765     break;
2766   case CombiningKind::MINF:
2767     combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
2768     break;
2769   case CombiningKind::MAXF:
2770     combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
2771     break;
2772   }
2773   return combinedResult;
2774 }
2775 
2776 /// Convert vector.scan op into arith ops and
2777 /// vector.insert_strided_slice/extract_strided_slice
2778 ///
2779 /// Ex:
2780 /// ```
2781 ///   %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim =
2782 ///   1} :
2783 ///     (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
2784 /// ```
2785 /// Gets converted to:
2786 /// ```
2787 ///   %cst = arith.constant dense<0> : vector<2x3xi32>
2788 ///   %0 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 1],
2789 ///   strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %1 =
2790 ///   vector.insert_strided_slice %0, %cst {offsets = [0, 0], strides = [1, 1]}
2791 ///   : vector<2x1xi32> into vector<2x3xi32> %2 = vector.extract_strided_slice
2792 ///   %arg0 {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} :
2793 ///   vector<2x3xi32> to vector<2x1xi32> %3 = arith.muli %0, %2 :
2794 ///   vector<2x1xi32> %4 = vector.insert_strided_slice %3, %1 {offsets = [0, 1],
2795 ///   strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %5 =
2796 ///   vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1],
2797 ///   strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %6 = arith.muli %3,
2798 ///   %5 : vector<2x1xi32> %7 = vector.insert_strided_slice %6, %4 {offsets =
2799 ///   [0, 2], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %8 =
2800 ///   vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32> return %7, %8 :
2801 ///   vector<2x3xi32>, vector<2xi32>
2802 /// ```
2803 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
2804   using OpRewritePattern<vector::ScanOp>::OpRewritePattern;
2805 
matchAndRewrite__anon5c5a5b801011::ScanToArithOps2806   LogicalResult matchAndRewrite(vector::ScanOp scanOp,
2807                                 PatternRewriter &rewriter) const override {
2808     auto loc = scanOp.getLoc();
2809     VectorType destType = scanOp.getDestType();
2810     ArrayRef<int64_t> destShape = destType.getShape();
2811     auto elType = destType.getElementType();
2812     bool isInt = elType.isIntOrIndex();
2813     if (!isValidKind(isInt, scanOp.getKind()))
2814       return failure();
2815 
2816     VectorType resType = VectorType::get(destShape, elType);
2817     Value result = rewriter.create<arith::ConstantOp>(
2818         loc, resType, rewriter.getZeroAttr(resType));
2819     int64_t reductionDim = scanOp.getReductionDim();
2820     bool inclusive = scanOp.getInclusive();
2821     int64_t destRank = destType.getRank();
2822     VectorType initialValueType = scanOp.getInitialValueType();
2823     int64_t initialValueRank = initialValueType.getRank();
2824 
2825     SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
2826     reductionShape[reductionDim] = 1;
2827     VectorType reductionType = VectorType::get(reductionShape, elType);
2828     SmallVector<int64_t> offsets(destRank, 0);
2829     SmallVector<int64_t> strides(destRank, 1);
2830     SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
2831     sizes[reductionDim] = 1;
2832     ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
2833     ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
2834 
2835     Value lastOutput, lastInput;
2836     for (int i = 0; i < destShape[reductionDim]; i++) {
2837       offsets[reductionDim] = i;
2838       ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
2839       Value input = rewriter.create<vector::ExtractStridedSliceOp>(
2840           loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
2841           scanStrides);
2842       Value output;
2843       if (i == 0) {
2844         if (inclusive) {
2845           output = input;
2846         } else {
2847           if (initialValueRank == 0) {
2848             // ShapeCastOp cannot handle 0-D vectors
2849             output = rewriter.create<vector::BroadcastOp>(
2850                 loc, input.getType(), scanOp.getInitialValue());
2851           } else {
2852             output = rewriter.create<vector::ShapeCastOp>(
2853                 loc, input.getType(), scanOp.getInitialValue());
2854           }
2855         }
2856       } else {
2857         Value y = inclusive ? input : lastInput;
2858         output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
2859         assert(output != nullptr);
2860       }
2861       result = rewriter.create<vector::InsertStridedSliceOp>(
2862           loc, output, result, offsets, strides);
2863       lastOutput = output;
2864       lastInput = input;
2865     }
2866 
2867     Value reduction;
2868     if (initialValueRank == 0) {
2869       Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
2870       reduction =
2871           rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
2872     } else {
2873       reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
2874                                                        lastOutput);
2875     }
2876 
2877     rewriter.replaceOp(scanOp, {result, reduction});
2878     return success();
2879   }
2880 };
2881 
2882 } // namespace
2883 
populateVectorMaskMaterializationPatterns(RewritePatternSet & patterns,bool force32BitVectorIndices)2884 void mlir::vector::populateVectorMaskMaterializationPatterns(
2885     RewritePatternSet &patterns, bool force32BitVectorIndices) {
2886   patterns.add<VectorCreateMaskOpConversion,
2887                MaterializeTransferMask<vector::TransferReadOp>,
2888                MaterializeTransferMask<vector::TransferWriteOp>>(
2889       patterns.getContext(), force32BitVectorIndices);
2890 }
2891 
populateShapeCastFoldingPatterns(RewritePatternSet & patterns)2892 void mlir::vector::populateShapeCastFoldingPatterns(
2893     RewritePatternSet &patterns) {
2894   patterns.add<ShapeCastOpFolder>(patterns.getContext());
2895 }
2896 
populateBubbleVectorBitCastOpPatterns(RewritePatternSet & patterns)2897 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
2898     RewritePatternSet &patterns) {
2899   patterns.add<BubbleDownVectorBitCastForExtract,
2900                BubbleDownBitCastForStridedSliceExtract,
2901                BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
2902 }
2903 
populateVectorBroadcastLoweringPatterns(RewritePatternSet & patterns)2904 void mlir::vector::populateVectorBroadcastLoweringPatterns(
2905     RewritePatternSet &patterns) {
2906   patterns.add<BroadcastOpLowering>(patterns.getContext());
2907 }
2908 
populateVectorMaskOpLoweringPatterns(RewritePatternSet & patterns)2909 void mlir::vector::populateVectorMaskOpLoweringPatterns(
2910     RewritePatternSet &patterns) {
2911   patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
2912       patterns.getContext());
2913 }
2914 
populateVectorShapeCastLoweringPatterns(RewritePatternSet & patterns)2915 void mlir::vector::populateVectorShapeCastLoweringPatterns(
2916     RewritePatternSet &patterns) {
2917   patterns.add<ShapeCastOp2DDownCastRewritePattern,
2918                ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
2919       patterns.getContext());
2920 }
2921 
populateVectorContractLoweringPatterns(RewritePatternSet & patterns,VectorTransformsOptions options)2922 void mlir::vector::populateVectorContractLoweringPatterns(
2923     RewritePatternSet &patterns, VectorTransformsOptions options) {
2924   patterns.add<OuterProductOpLowering>(patterns.getContext());
2925   patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
2926                ContractionOpToOuterProductOpLowering>(options,
2927                                                       patterns.getContext());
2928 }
2929 
populateVectorTransposeLoweringPatterns(RewritePatternSet & patterns,VectorTransformsOptions options)2930 void mlir::vector::populateVectorTransposeLoweringPatterns(
2931     RewritePatternSet &patterns, VectorTransformsOptions options) {
2932   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
2933       options, patterns.getContext());
2934 }
2935 
populateVectorReductionToContractPatterns(RewritePatternSet & patterns)2936 void mlir::vector::populateVectorReductionToContractPatterns(
2937     RewritePatternSet &patterns) {
2938   patterns.add<MultiReduceToContract, CombineContractBroadcast,
2939                CombineContractTranspose, ReorderCastOpsOnBroadcast,
2940                ReorderElementwiseOpsOnTranspose>(patterns.getContext());
2941 }
2942 
2943 void mlir::vector::
populateVectorTransferCollapseInnerMostContiguousDimsPatterns(RewritePatternSet & patterns)2944     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
2945         RewritePatternSet &patterns) {
2946   patterns.add<DropInnerMostUnitDims>(patterns.getContext());
2947 }
2948 
populateVectorTransferLoweringPatterns(RewritePatternSet & patterns,llvm::Optional<unsigned> maxTransferRank)2949 void mlir::vector::populateVectorTransferLoweringPatterns(
2950     RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
2951   patterns.add<TransferReadToVectorLoadLowering,
2952                TransferWriteToVectorStoreLowering>(patterns.getContext(),
2953                                                    maxTransferRank);
2954   patterns
2955       .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
2956           patterns.getContext());
2957 }
2958 
populateVectorScanLoweringPatterns(RewritePatternSet & patterns)2959 void mlir::vector::populateVectorScanLoweringPatterns(
2960     RewritePatternSet &patterns) {
2961   patterns.add<ScanToArithOps>(patterns.getContext());
2962 }
2963