1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
14 
15 #include <type_traits>
16 
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
19 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
20 #include "mlir/Dialect/Linalg/IR/Linalg.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/SCF/SCF.h"
23 #include "mlir/Dialect/Utils/IndexingUtils.h"
24 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
26 #include "mlir/IR/ImplicitLocOpBuilder.h"
27 #include "mlir/IR/Matchers.h"
28 #include "mlir/IR/PatternMatch.h"
29 #include "mlir/Interfaces/VectorInterfaces.h"
30 
31 #include "llvm/ADT/DenseSet.h"
32 #include "llvm/ADT/MapVector.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/Support/CommandLine.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/raw_ostream.h"
37 
38 #define DEBUG_TYPE "vector-to-vector"
39 
40 using namespace mlir;
41 using namespace mlir::vector;
42 
43 // Helper to find an index in an affine map.
44 static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
45   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
46     int64_t idx = map.getDimPosition(i);
47     if (idx == index)
48       return i;
49   }
50   return None;
51 }
52 
53 // Helper to construct iterator types with one index removed.
54 static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
55                                             int64_t index) {
56   SmallVector<Attribute, 4> results;
57   for (const auto &it : llvm::enumerate(iteratorTypes)) {
58     int64_t idx = it.index();
59     if (idx == index)
60       continue;
61     results.push_back(it.value());
62   }
63   return results;
64 }
65 
66 // Helper to construct an affine map with one index removed.
67 static AffineMap adjustMap(AffineMap map, int64_t index,
68                            PatternRewriter &rewriter) {
69   auto *ctx = rewriter.getContext();
70   SmallVector<AffineExpr, 4> results;
71   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
72     int64_t idx = map.getDimPosition(i);
73     if (idx == index)
74       continue;
75     // Re-insert remaining indices, but renamed when occurring
76     // after the removed index.
77     auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
78     results.push_back(targetExpr);
79   }
80   return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
81 }
82 
83 // Helper method to possibly drop a dimension in a load.
84 // TODO
85 static Value reshapeLoad(Location loc, Value val, VectorType type,
86                          int64_t index, int64_t pos,
87                          PatternRewriter &rewriter) {
88   if (index == -1)
89     return val;
90   Type lowType = VectorType::Builder(type).dropDim(0);
91   // At extraction dimension?
92   if (index == 0) {
93     auto posAttr = rewriter.getI64ArrayAttr(pos);
94     return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
95   }
96   // Unroll leading dimensions.
97   VectorType vType = lowType.cast<VectorType>();
98   Type resType = VectorType::Builder(type).dropDim(index);
99   auto resVectorType = resType.cast<VectorType>();
100   Value result = rewriter.create<arith::ConstantOp>(
101       loc, resVectorType, rewriter.getZeroAttr(resVectorType));
102   for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
103     auto posAttr = rewriter.getI64ArrayAttr(d);
104     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
105     Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
106     result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
107                                                posAttr);
108   }
109   return result;
110 }
111 
112 // Helper method to possibly drop a dimension in a store.
113 // TODO
114 static Value reshapeStore(Location loc, Value val, Value result,
115                           VectorType type, int64_t index, int64_t pos,
116                           PatternRewriter &rewriter) {
117   // Unmodified?
118   if (index == -1)
119     return val;
120   // At insertion dimension?
121   if (index == 0) {
122     auto posAttr = rewriter.getI64ArrayAttr(pos);
123     return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
124   }
125   // Unroll leading dimensions.
126   Type lowType = VectorType::Builder(type).dropDim(0);
127   VectorType vType = lowType.cast<VectorType>();
128   Type insType = VectorType::Builder(vType).dropDim(0);
129   for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
130     auto posAttr = rewriter.getI64ArrayAttr(d);
131     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
132     Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
133     Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
134     result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
135   }
136   return result;
137 }
138 
139 template <typename IntType>
140 static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
141   return llvm::to_vector<4>(llvm::map_range(
142       arrayAttr.getAsRange<IntegerAttr>(),
143       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
144 }
145 
146 namespace {
147 
148 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
149 //
150 // Example:
151 //
152 //  The following MLIR with cancelling ShapeCastOps:
153 //
154 //   %0 = source : vector<5x4x2xf32>
155 //   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
156 //   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
157 //   %3 = user %2 : vector<5x4x2xf32>
158 //
159 //  Should canonicalize to the following:
160 //
161 //   %0 = source : vector<5x4x2xf32>
162 //   %1 = user %0 : vector<5x4x2xf32>
163 //
164 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
165   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
166 
167   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
168                                 PatternRewriter &rewriter) const override {
169     // Check if 'shapeCastOp' has vector source/result type.
170     auto sourceVectorType =
171         shapeCastOp.getSource().getType().dyn_cast_or_null<VectorType>();
172     auto resultVectorType =
173         shapeCastOp.getResult().getType().dyn_cast_or_null<VectorType>();
174     if (!sourceVectorType || !resultVectorType)
175       return failure();
176 
177     // Check if shape cast op source operand is also a shape cast op.
178     auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
179         shapeCastOp.getSource().getDefiningOp());
180     if (!sourceShapeCastOp)
181       return failure();
182     auto operandSourceVectorType =
183         sourceShapeCastOp.getSource().getType().cast<VectorType>();
184     auto operandResultVectorType = sourceShapeCastOp.getType();
185 
186     // Check if shape cast operations invert each other.
187     if (operandSourceVectorType != resultVectorType ||
188         operandResultVectorType != sourceVectorType)
189       return failure();
190 
191     rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
192     return success();
193   }
194 };
195 
196 /// Progressive lowering of BroadcastOp.
197 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
198 public:
199   using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
200 
201   LogicalResult matchAndRewrite(vector::BroadcastOp op,
202                                 PatternRewriter &rewriter) const override {
203     auto loc = op.getLoc();
204     VectorType dstType = op.getVectorType();
205     VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
206     Type eltType = dstType.getElementType();
207 
208     // Scalar to any vector can use splat.
209     if (!srcType) {
210       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
211       return success();
212     }
213 
214     // Determine rank of source and destination.
215     int64_t srcRank = srcType.getRank();
216     int64_t dstRank = dstType.getRank();
217 
218     // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
219     if (srcRank <= 1 && dstRank == 1) {
220       Value ext;
221       if (srcRank == 0)
222         ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
223       else
224         ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
225       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
226       return success();
227     }
228 
229     // Duplicate this rank.
230     // For example:
231     //   %x = broadcast %y  : k-D to n-D, k < n
232     // becomes:
233     //   %b = broadcast %y  : k-D to (n-1)-D
234     //   %x = [%b,%b,%b,%b] : n-D
235     // becomes:
236     //   %b = [%y,%y]       : (n-1)-D
237     //   %x = [%b,%b,%b,%b] : n-D
238     if (srcRank < dstRank) {
239       // Duplication.
240       VectorType resType =
241           VectorType::get(dstType.getShape().drop_front(), eltType);
242       Value bcst =
243           rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
244       Value result = rewriter.create<arith::ConstantOp>(
245           loc, dstType, rewriter.getZeroAttr(dstType));
246       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
247         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
248       rewriter.replaceOp(op, result);
249       return success();
250     }
251 
252     // Find non-matching dimension, if any.
253     assert(srcRank == dstRank);
254     int64_t m = -1;
255     for (int64_t r = 0; r < dstRank; r++)
256       if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
257         m = r;
258         break;
259       }
260 
261     // All trailing dimensions are the same. Simply pass through.
262     if (m == -1) {
263       rewriter.replaceOp(op, op.getSource());
264       return success();
265     }
266 
267     // Any non-matching dimension forces a stretch along this rank.
268     // For example:
269     //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
270     // becomes:
271     //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
272     //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
273     //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
274     //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
275     //   %x = [%a,%b,%c,%d]
276     // becomes:
277     //   %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
278     //   %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
279     //   %a = [%u, %v]
280     //   ..
281     //   %x = [%a,%b,%c,%d]
282     VectorType resType =
283         VectorType::get(dstType.getShape().drop_front(), eltType);
284     Value result = rewriter.create<arith::ConstantOp>(
285         loc, dstType, rewriter.getZeroAttr(dstType));
286     if (m == 0) {
287       // Stetch at start.
288       Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
289       Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
290       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
291         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
292     } else {
293       // Stetch not at start.
294       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
295         Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
296         Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
297         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
298       }
299     }
300     rewriter.replaceOp(op, result);
301     return success();
302   }
303 };
304 
305 /// Given a 'transpose' pattern, prune the rightmost dimensions that are not
306 /// transposed.
307 void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
308                             SmallVectorImpl<int64_t> &result) {
309   size_t numTransposedDims = transpose.size();
310   for (size_t transpDim : llvm::reverse(transpose)) {
311     if (transpDim != numTransposedDims - 1)
312       break;
313     numTransposedDims--;
314   }
315 
316   result.append(transpose.begin(), transpose.begin() + numTransposedDims);
317 }
318 
319 /// Progressive lowering of TransposeOp.
320 /// One:
321 ///   %x = vector.transpose %y, [1, 0]
322 /// is replaced by:
323 ///   %z = arith.constant dense<0.000000e+00>
324 ///   %0 = vector.extract %y[0, 0]
325 ///   %1 = vector.insert %0, %z [0, 0]
326 ///   ..
327 ///   %x = vector.insert .., .. [.., ..]
328 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
329 public:
330   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
331 
332   TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
333                       MLIRContext *context)
334       : OpRewritePattern<vector::TransposeOp>(context),
335         vectorTransformOptions(vectorTransformOptions) {}
336 
337   LogicalResult matchAndRewrite(vector::TransposeOp op,
338                                 PatternRewriter &rewriter) const override {
339     auto loc = op.getLoc();
340 
341     Value input = op.getVector();
342     VectorType inputType = op.getVectorType();
343     VectorType resType = op.getResultType();
344 
345     // Set up convenience transposition table.
346     SmallVector<int64_t, 4> transp;
347     for (auto attr : op.getTransp())
348       transp.push_back(attr.cast<IntegerAttr>().getInt());
349 
350     if (vectorTransformOptions.vectorTransposeLowering ==
351             vector::VectorTransposeLowering::Shuffle &&
352         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
353       return rewriter.notifyMatchFailure(
354           op, "Options specifies lowering to shuffle");
355 
356     // Handle a true 2-D matrix transpose differently when requested.
357     if (vectorTransformOptions.vectorTransposeLowering ==
358             vector::VectorTransposeLowering::Flat &&
359         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
360       Type flattenedType =
361           VectorType::get(resType.getNumElements(), resType.getElementType());
362       auto matrix =
363           rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
364       auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
365       auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
366       Value trans = rewriter.create<vector::FlatTransposeOp>(
367           loc, flattenedType, matrix, rows, columns);
368       rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
369       return success();
370     }
371 
372     // Generate unrolled extract/insert ops. We do not unroll the rightmost
373     // (i.e., highest-order) dimensions that are not transposed and leave them
374     // in vector form to improve performance. Therefore, we prune those
375     // dimensions from the shape/transpose data structures used to generate the
376     // extract/insert ops.
377     SmallVector<int64_t, 4> prunedTransp;
378     pruneNonTransposedDims(transp, prunedTransp);
379     size_t numPrunedDims = transp.size() - prunedTransp.size();
380     auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
381     SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
382     auto prunedInStrides = computeStrides(prunedInShape, ones);
383 
384     // Generates the extract/insert operations for every scalar/vector element
385     // of the leftmost transposed dimensions. We traverse every transpose
386     // element using a linearized index that we delinearize to generate the
387     // appropriate indices for the extract/insert operations.
388     Value result = rewriter.create<arith::ConstantOp>(
389         loc, resType, rewriter.getZeroAttr(resType));
390     int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
391 
392     for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
393          ++linearIdx) {
394       auto extractIdxs = delinearize(prunedInStrides, linearIdx);
395       SmallVector<int64_t, 4> insertIdxs(extractIdxs);
396       applyPermutationToVector(insertIdxs, prunedTransp);
397       Value extractOp =
398           rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
399       result =
400           rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
401     }
402 
403     rewriter.replaceOp(op, result);
404     return success();
405   }
406 
407 private:
408   /// Options to control the vector patterns.
409   vector::VectorTransformsOptions vectorTransformOptions;
410 };
411 
412 /// Rewrite a 2-D vector.transpose as a sequence of:
413 ///   vector.shape_cast 2D -> 1D
414 ///   vector.shuffle
415 ///   vector.shape_cast 1D -> 2D
416 class TransposeOp2DToShuffleLowering
417     : public OpRewritePattern<vector::TransposeOp> {
418 public:
419   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
420 
421   TransposeOp2DToShuffleLowering(
422       vector::VectorTransformsOptions vectorTransformOptions,
423       MLIRContext *context)
424       : OpRewritePattern<vector::TransposeOp>(context),
425         vectorTransformOptions(vectorTransformOptions) {}
426 
427   LogicalResult matchAndRewrite(vector::TransposeOp op,
428                                 PatternRewriter &rewriter) const override {
429     auto loc = op.getLoc();
430 
431     VectorType srcType = op.getVectorType();
432     if (srcType.getRank() != 2)
433       return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
434 
435     SmallVector<int64_t, 4> transp;
436     for (auto attr : op.getTransp())
437       transp.push_back(attr.cast<IntegerAttr>().getInt());
438     if (transp[0] != 1 && transp[1] != 0)
439       return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
440 
441     if (vectorTransformOptions.vectorTransposeLowering !=
442         VectorTransposeLowering::Shuffle)
443       return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
444 
445     int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
446     Value casted = rewriter.create<vector::ShapeCastOp>(
447         loc, VectorType::get({m * n}, srcType.getElementType()),
448         op.getVector());
449     SmallVector<int64_t> mask;
450     mask.reserve(m * n);
451     for (int64_t j = 0; j < n; ++j)
452       for (int64_t i = 0; i < m; ++i)
453         mask.push_back(i * n + j);
454 
455     Value shuffled =
456         rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
457     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
458                                                      shuffled);
459 
460     return success();
461   }
462 
463 private:
464   /// Options to control the vector patterns.
465   vector::VectorTransformsOptions vectorTransformOptions;
466 };
467 
468 /// Progressive lowering of OuterProductOp.
469 /// One:
470 ///   %x = vector.outerproduct %lhs, %rhs, %acc
471 /// is replaced by:
472 ///   %z = zero-result
473 ///   %0 = vector.extract %lhs[0]
474 ///   %1 = vector.broadcast %0
475 ///   %2 = vector.extract %acc[0]
476 ///   %3 = vector.fma %1, %rhs, %2
477 ///   %4 = vector.insert %3, %z[0]
478 ///   ..
479 ///   %x = vector.insert %.., %..[N-1]
480 ///
481 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
482 public:
483   using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
484 
485   LogicalResult matchAndRewrite(vector::OuterProductOp op,
486                                 PatternRewriter &rewriter) const override {
487     auto loc = op.getLoc();
488 
489     VectorType lhsType = op.getOperandVectorTypeLHS();
490     VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
491     VectorType resType = op.getVectorType();
492     Type eltType = resType.getElementType();
493     bool isInt = eltType.isa<IntegerType, IndexType>();
494     Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
495     vector::CombiningKind kind = op.getKind();
496 
497     if (!rhsType) {
498       // Special case: AXPY operation.
499       Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
500       Optional<Value> mult =
501           isInt ? genMultI(loc, op.getLhs(), b, acc, kind, rewriter)
502                 : genMultF(loc, op.getLhs(), b, acc, kind, rewriter);
503       if (!mult.hasValue())
504         return failure();
505       rewriter.replaceOp(op, mult.getValue());
506       return success();
507     }
508 
509     Value result = rewriter.create<arith::ConstantOp>(
510         loc, resType, rewriter.getZeroAttr(resType));
511     for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
512       auto pos = rewriter.getI64ArrayAttr(d);
513       Value x =
514           rewriter.create<vector::ExtractOp>(loc, eltType, op.getLhs(), pos);
515       Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
516       Value r = nullptr;
517       if (acc)
518         r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
519       Optional<Value> m =
520           isInt ? genMultI(loc, a, op.getRhs(), r, kind, rewriter)
521                 : genMultF(loc, a, op.getRhs(), r, kind, rewriter);
522       if (!m.hasValue())
523         return failure();
524       result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
525                                                  result, pos);
526     }
527     rewriter.replaceOp(op, result);
528     return success();
529   }
530 
531 private:
532   static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
533                                   vector::CombiningKind kind,
534                                   PatternRewriter &rewriter) {
535     using vector::CombiningKind;
536 
537     auto mul = rewriter.create<arith::MulIOp>(loc, x, y);
538     if (!acc)
539       return Optional<Value>(mul);
540 
541     if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
542       // Only valid for floating point types.
543       return Optional<Value>();
544 
545     return makeArithReduction(rewriter, loc, kind, mul, acc);
546   }
547 
548   static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
549                                   vector::CombiningKind kind,
550                                   PatternRewriter &rewriter) {
551     using vector::CombiningKind;
552 
553     // Special case for fused multiply-add.
554     if (acc && kind == CombiningKind::ADD) {
555       return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
556     }
557 
558     auto mul = rewriter.create<arith::MulFOp>(loc, x, y);
559 
560     if (!acc)
561       return Optional<Value>(mul);
562 
563     if (kind == CombiningKind::ADD || kind == CombiningKind::AND ||
564         kind == CombiningKind::MINUI || kind == CombiningKind::MINSI ||
565         kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI ||
566         kind == CombiningKind::OR || kind == CombiningKind::XOR)
567       // Already handled or only valid for integer types.
568       return Optional<Value>();
569 
570     return makeArithReduction(rewriter, loc, kind, mul, acc);
571   }
572 };
573 
574 /// Progressive lowering of ConstantMaskOp.
575 /// One:
576 ///   %x = vector.constant_mask [a,b]
577 /// is replaced by:
578 ///   %z = zero-result
579 ///   %l = vector.constant_mask [b]
580 ///   %4 = vector.insert %l, %z[0]
581 ///   ..
582 ///   %x = vector.insert %l, %..[a-1]
583 /// until a one-dimensional vector is reached. All these operations
584 /// will be folded at LLVM IR level.
585 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
586 public:
587   using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
588 
589   LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
590                                 PatternRewriter &rewriter) const override {
591     auto loc = op.getLoc();
592     auto dstType = op.getType();
593     auto eltType = dstType.getElementType();
594     auto dimSizes = op.getMaskDimSizes();
595     int64_t rank = dstType.getRank();
596 
597     if (rank == 0) {
598       assert(dimSizes.size() == 1 &&
599              "Expected exactly one dim size for a 0-D vector");
600       bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
601       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
602           op, dstType,
603           DenseIntElementsAttr::get(
604               VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
605               ArrayRef<bool>{value}));
606       return success();
607     }
608 
609     // Scalable constant masks can only be lowered for the "none set" case.
610     if (dstType.cast<VectorType>().isScalable()) {
611       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
612           op, DenseElementsAttr::get(dstType, false));
613       return success();
614     }
615 
616     int64_t trueDim = std::min(dstType.getDimSize(0),
617                                dimSizes[0].cast<IntegerAttr>().getInt());
618 
619     if (rank == 1) {
620       // Express constant 1-D case in explicit vector form:
621       //   [T,..,T,F,..,F].
622       SmallVector<bool, 4> values(dstType.getDimSize(0));
623       for (int64_t d = 0; d < trueDim; d++)
624         values[d] = true;
625       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
626           op, dstType, rewriter.getBoolVectorAttr(values));
627       return success();
628     }
629 
630     VectorType lowType =
631         VectorType::get(dstType.getShape().drop_front(), eltType);
632     SmallVector<int64_t, 4> newDimSizes;
633     for (int64_t r = 1; r < rank; r++)
634       newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
635     Value trueVal = rewriter.create<vector::ConstantMaskOp>(
636         loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
637     Value result = rewriter.create<arith::ConstantOp>(
638         loc, dstType, rewriter.getZeroAttr(dstType));
639     for (int64_t d = 0; d < trueDim; d++) {
640       auto pos = rewriter.getI64ArrayAttr(d);
641       result =
642           rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
643     }
644     rewriter.replaceOp(op, result);
645     return success();
646   }
647 };
648 
649 /// Progressive lowering of CreateMaskOp.
650 /// One:
651 ///   %x = vector.create_mask %a, ... : vector<dx...>
652 /// is replaced by:
653 ///   %l = vector.create_mask ... : vector<...>  ; one lower rank
654 ///   %0 = arith.cmpi "slt", %ci, %a       |
655 ///   %1 = select %0, %l, %zeroes    |
656 ///   %r = vector.insert %1, %pr [i] | d-times
657 ///   %x = ....
658 /// until a one-dimensional vector is reached.
659 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
660 public:
661   using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
662 
663   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
664                                 PatternRewriter &rewriter) const override {
665     auto dstType = op.getResult().getType().cast<VectorType>();
666     int64_t rank = dstType.getRank();
667     if (rank <= 1)
668       return rewriter.notifyMatchFailure(
669           op, "0-D and 1-D vectors are handled separately");
670 
671     auto loc = op.getLoc();
672     auto eltType = dstType.getElementType();
673     int64_t dim = dstType.getDimSize(0);
674     Value idx = op.getOperand(0);
675 
676     VectorType lowType =
677         VectorType::get(dstType.getShape().drop_front(), eltType);
678     Value trueVal = rewriter.create<vector::CreateMaskOp>(
679         loc, lowType, op.getOperands().drop_front());
680     Value falseVal = rewriter.create<arith::ConstantOp>(
681         loc, lowType, rewriter.getZeroAttr(lowType));
682     Value result = rewriter.create<arith::ConstantOp>(
683         loc, dstType, rewriter.getZeroAttr(dstType));
684     for (int64_t d = 0; d < dim; d++) {
685       Value bnd =
686           rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
687       Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
688                                                  bnd, idx);
689       Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
690       auto pos = rewriter.getI64ArrayAttr(d);
691       result =
692           rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
693     }
694     rewriter.replaceOp(op, result);
695     return success();
696   }
697 };
698 
699 /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
700 /// vectors progressively on the way to target llvm.matrix intrinsics.
701 /// This iterates over the most major dimension of the 2-D vector and performs
702 /// rewrites into:
703 ///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
704 class ShapeCastOp2DDownCastRewritePattern
705     : public OpRewritePattern<vector::ShapeCastOp> {
706 public:
707   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
708 
709   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
710                                 PatternRewriter &rewriter) const override {
711     auto sourceVectorType = op.getSourceVectorType();
712     auto resultVectorType = op.getResultVectorType();
713     if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
714       return failure();
715 
716     auto loc = op.getLoc();
717     Value desc = rewriter.create<arith::ConstantOp>(
718         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
719     unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
720     for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
721       Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
722       desc = rewriter.create<vector::InsertStridedSliceOp>(
723           loc, vec, desc,
724           /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
725     }
726     rewriter.replaceOp(op, desc);
727     return success();
728   }
729 };
730 
731 /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
732 /// vectors progressively.
733 /// This iterates over the most major dimension of the 2-D vector and performs
734 /// rewrites into:
735 ///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
736 /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
737 class ShapeCastOp2DUpCastRewritePattern
738     : public OpRewritePattern<vector::ShapeCastOp> {
739 public:
740   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
741 
742   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
743                                 PatternRewriter &rewriter) const override {
744     auto sourceVectorType = op.getSourceVectorType();
745     auto resultVectorType = op.getResultVectorType();
746     if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
747       return failure();
748 
749     auto loc = op.getLoc();
750     Value desc = rewriter.create<arith::ConstantOp>(
751         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
752     unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
753     for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
754       Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
755           loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
756           /*sizes=*/mostMinorVectorSize,
757           /*strides=*/1);
758       desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
759     }
760     rewriter.replaceOp(op, desc);
761     return success();
762   }
763 };
764 
765 // We typically should not lower general shape cast operations into data
766 // movement instructions, since the assumption is that these casts are
767 // optimized away during progressive lowering. For completeness, however,
768 // we fall back to a reference implementation that moves all elements
769 // into the right place if we get here.
770 class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
771 public:
772   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
773 
774   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
775                                 PatternRewriter &rewriter) const override {
776     Location loc = op.getLoc();
777     auto sourceVectorType = op.getSourceVectorType();
778     auto resultVectorType = op.getResultVectorType();
779 
780     // Special case 2D/1D lowerings with better implementations.
781     // TODO: make is ND/1D to allow generic ND->1D->MD.
782     int64_t srcRank = sourceVectorType.getRank();
783     int64_t resRank = resultVectorType.getRank();
784     if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
785       return failure();
786 
787     // Generic ShapeCast lowering path goes all the way down to unrolled scalar
788     // extract/insert chains.
789     // TODO: consider evolving the semantics to only allow 1D source or dest and
790     // drop this potentially very expensive lowering.
791     // Compute number of elements involved in the reshape.
792     int64_t numElts = 1;
793     for (int64_t r = 0; r < srcRank; r++)
794       numElts *= sourceVectorType.getDimSize(r);
795     // Replace with data movement operations:
796     //    x[0,0,0] = y[0,0]
797     //    x[0,0,1] = y[0,1]
798     //    x[0,1,0] = y[0,2]
799     // etc., incrementing the two index vectors "row-major"
800     // within the source and result shape.
801     SmallVector<int64_t, 4> srcIdx(srcRank);
802     SmallVector<int64_t, 4> resIdx(resRank);
803     Value result = rewriter.create<arith::ConstantOp>(
804         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
805     for (int64_t i = 0; i < numElts; i++) {
806       if (i != 0) {
807         incIdx(srcIdx, sourceVectorType, srcRank - 1);
808         incIdx(resIdx, resultVectorType, resRank - 1);
809       }
810       Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
811       result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
812     }
813     rewriter.replaceOp(op, result);
814     return success();
815   }
816 
817 private:
818   static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
819     assert(0 <= r && r < tp.getRank());
820     if (++idx[r] == tp.getDimSize(r)) {
821       idx[r] = 0;
822       incIdx(idx, tp, r - 1);
823     }
824   }
825 };
826 
827 /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
828 /// Ex:
829 /// ```
830 ///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
831 ///   %1 = vector.multi_reduction add, %0 [1]
832 ///     : vector<8x32x16xf32> to vector<8x16xf32>
833 /// ```
834 /// Gets converted to:
835 /// ```
836 ///   %1 = vector.contract {indexing_maps = [
837 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
838 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
839 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
840 ///    iterator_types = ["parallel", "parallel", "reduction"],
841 ///    kind = add} %0, %arg1, %cst_f0
842 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
843 ///  ```
844 struct MultiReduceToContract
845     : public OpRewritePattern<vector::MultiDimReductionOp> {
846   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
847 
848   LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
849                                 PatternRewriter &rewriter) const override {
850     if (reduceOp.getKind() != vector::CombiningKind::ADD)
851       return failure();
852     Operation *mulOp = reduceOp.getSource().getDefiningOp();
853     if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
854       return failure();
855     SmallVector<bool> reductionMask = reduceOp.getReductionMask();
856     auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
857     SmallVector<AffineExpr> exprs;
858     SmallVector<StringRef> iteratorTypes;
859     for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
860       if (!isReduceDim.value()) {
861         iteratorTypes.push_back(getParallelIteratorTypeName());
862         exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
863       } else {
864         iteratorTypes.push_back(getReductionIteratorTypeName());
865       }
866     }
867     auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
868                                  /*symCount=*/0, exprs, reduceOp.getContext());
869     Value zero = rewriter.create<arith::ConstantOp>(
870         reduceOp.getLoc(), reduceOp.getDestType(),
871         rewriter.getZeroAttr(reduceOp.getDestType()));
872     rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
873         reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
874         rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
875         rewriter.getStrArrayAttr(iteratorTypes));
876     return success();
877   }
878 };
879 
880 /// Merge TransposeOp into ContractionOp user.
881 /// Ex:
882 /// ```
883 ///   %0 = vector.transpose %arg0, [2, 0, 1]
884 ///     : vector<32x16x8xf32> to vector<8x32x16xf32>
885 ///   %1 = vector.contract {indexing_maps = [
886 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
887 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
888 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
889 ///    iterator_types = ["parallel", "parallel", "reduction"],
890 ///    kind = add} %0, %arg1, %cst_f0
891 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
892 /// ```
893 /// Gets converted to:
894 /// ```
895 ///   %1 = vector.contract {indexing_maps = [
896 ///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
897 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
898 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
899 ///    iterator_types = ["parallel", "parallel", "reduction"],
900 ///    kind = add} %arg0, %arg1, %cst_f0
901 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
902 ///  ```
903 struct CombineContractTranspose
904     : public OpRewritePattern<vector::ContractionOp> {
905   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
906 
907   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
908                                 PatternRewriter &rewriter) const override {
909     SmallVector<AffineMap, 4> maps =
910         llvm::to_vector<4>(contractOp.getIndexingMaps());
911     Value lhs = contractOp.getLhs();
912     Value rhs = contractOp.getRhs();
913     size_t index = 0;
914     bool changed = false;
915     for (Value *operand : {&lhs, &rhs}) {
916       AffineMap &map = maps[index++];
917       auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
918       if (!transposeOp)
919         continue;
920       SmallVector<int64_t> perm;
921       transposeOp.getTransp(perm);
922       AffineMap permutationMap = AffineMap::getPermutationMap(
923           extractVector<unsigned>(transposeOp.getTransp()),
924           contractOp.getContext());
925       map = inversePermutation(permutationMap).compose(map);
926       *operand = transposeOp.getVector();
927       changed = true;
928     }
929     if (!changed)
930       return failure();
931     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
932         contractOp, lhs, rhs, contractOp.getAcc(),
933         rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
934     return success();
935   }
936 };
937 
938 /// Merge BroadcastOp into ContractionOp user.
939 /// Ex:
940 /// ```
941 ///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
942 ///   %1 = vector.contract {indexing_maps = [
943 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
944 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
945 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
946 ///    iterator_types = ["parallel", "parallel", "reduction"],
947 ///    kind = add} %0, %arg1, %cst_f0
948 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
949 /// ```
950 /// Gets converted to:
951 /// ```
952 ///   %1 = vector.contract {indexing_maps = [
953 ///         affine_map<(d0, d1, d2) -> (d1, d2)>,
954 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
955 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
956 ///    iterator_types = ["parallel", "parallel", "reduction"],
957 ///    kind = add} %arg0, %arg1, %cst_f0
958 ///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
959 ///  ```
960 struct CombineContractBroadcast
961     : public OpRewritePattern<vector::ContractionOp> {
962   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
963 
964   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
965                                 PatternRewriter &rewriter) const override {
966     SmallVector<AffineMap, 4> maps =
967         llvm::to_vector<4>(contractOp.getIndexingMaps());
968     Value lhs = contractOp.getLhs();
969     Value rhs = contractOp.getRhs();
970     size_t index = 0;
971     bool changed = false;
972     for (Value *operand : {&lhs, &rhs}) {
973       AffineMap &map = maps[index++];
974       auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
975       if (!broadcast)
976         continue;
977       // contractionOp can only take vector as operands.
978       auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
979       if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
980         continue;
981       int64_t rankDiff =
982           broadcast.getVectorType().getRank() - srcType.getRank();
983       bool innerDimBroadcast = false;
984       SmallVector<AffineExpr> originalDims;
985       for (const auto &dim : llvm::enumerate(srcType.getShape())) {
986         if (dim.value() !=
987             broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
988           innerDimBroadcast = true;
989           break;
990         }
991         originalDims.push_back(
992             rewriter.getAffineDimExpr(dim.index() + rankDiff));
993       }
994       // Contract doesn't support inner dimension broadcast. Once this is
995       // relaxed we can remove this case.
996       if (innerDimBroadcast)
997         continue;
998       AffineMap broadcastMap =
999           AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
1000                          contractOp.getContext());
1001       map = broadcastMap.compose(map);
1002       *operand = broadcast.getSource();
1003       changed = true;
1004     }
1005     if (!changed)
1006       return failure();
1007     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1008         contractOp, lhs, rhs, contractOp.getAcc(),
1009         rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
1010     return success();
1011   }
1012 };
1013 
1014 /// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
1015 /// contraction ops closer, which kicks in CombineContractBroadcast pattern when
1016 /// casting ops are around these operations.
1017 /// Ex:
1018 /// ```
1019 ///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
1020 ///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
1021 /// ```
1022 /// Gets converted to:
1023 /// ```
1024 ///   %0 = arith.extsi %0 : vector<32x16xi8> to vector<32x16xi32>
1025 ///   %1 = vector.broadcast %arg0 : vector<32x16xi32> to vector<8x32x16xi32>
1026 /// ```
1027 struct ReorderCastOpsOnBroadcast
1028     : public OpInterfaceRewritePattern<CastOpInterface> {
1029   using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
1030 
1031   LogicalResult matchAndRewrite(CastOpInterface op,
1032                                 PatternRewriter &rewriter) const override {
1033     if (op->getNumOperands() != 1)
1034       return failure();
1035     auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
1036     if (!bcastOp)
1037       return failure();
1038 
1039     Type castResTy = getElementTypeOrSelf(op->getResult(0));
1040     if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
1041       castResTy = VectorType::get(vecTy.getShape(), castResTy);
1042     auto castOp =
1043         rewriter.create(op->getLoc(), op->getName().getIdentifier(),
1044                         bcastOp.getSource(), castResTy, op->getAttrs());
1045     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1046         op, op->getResult(0).getType(), castOp->getResult(0));
1047     return success();
1048   }
1049 };
1050 
1051 /// Reorders cast(transpose) to transpose(cast). This makes broadcast ops and
1052 /// contraction ops closer, which kicks in CombineContractTranspose pattern when
1053 /// casting ops are around these operations.
1054 /// Ex:
1055 /// ```
1056 ///   %0 = vector.transpose %arg0, [2, 0, 1]
1057 ///     : vector<32x16x8xi8> to vector<8x32x16xi8>
1058 ///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
1059 /// ```
1060 /// Gets converted to:
1061 /// ```
1062 ///   %0 = arith.extsi %0 : vector<32x16x8xi8> to vector<32x16x8xi32>
1063 ///   %1 = vector.transpose %arg0, [2, 0, 1]
1064 ///     : vector<32x16x8xi32> to vector<8x32x16xi32>
1065 /// ```
1066 struct ReorderCastOpsOnTranspose
1067     : public OpInterfaceRewritePattern<CastOpInterface> {
1068 
1069   using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
1070 
1071   LogicalResult matchAndRewrite(CastOpInterface op,
1072                                 PatternRewriter &rewriter) const override {
1073     if (op->getNumOperands() != 1)
1074       return failure();
1075     auto transpOp = op->getOperand(0).getDefiningOp<vector::TransposeOp>();
1076     if (!transpOp)
1077       return failure();
1078 
1079     auto castResTy = transpOp.getVectorType();
1080     castResTy = VectorType::get(castResTy.getShape(),
1081                                 getElementTypeOrSelf(op->getResult(0)));
1082     auto castOp =
1083         rewriter.create(op->getLoc(), op->getName().getIdentifier(),
1084                         transpOp.getVector(), castResTy, op->getAttrs());
1085     rewriter.replaceOpWithNewOp<vector::TransposeOp>(
1086         op, op->getResult(0).getType(), castOp->getResult(0),
1087         transpOp.getTransp());
1088     return success();
1089   }
1090 };
1091 
1092 } // namespace
1093 
1094 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
1095 /// operands `x` and `y`.
1096 static Value createAdd(Location loc, Value x, Value y, bool isInt,
1097                        PatternRewriter &rewriter) {
1098   if (isInt)
1099     return rewriter.create<arith::AddIOp>(loc, x, y);
1100   return rewriter.create<arith::AddFOp>(loc, x, y);
1101 }
1102 
1103 /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
1104 /// operands `x and `y`.
1105 static Value createMul(Location loc, Value x, Value y, bool isInt,
1106                        PatternRewriter &rewriter) {
1107   if (isInt)
1108     return rewriter.create<arith::MulIOp>(loc, x, y);
1109   return rewriter.create<arith::MulFOp>(loc, x, y);
1110 }
1111 
1112 namespace mlir {
1113 
1114 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1115 /// semantics to:
1116 /// ```
1117 ///    %mta = maybe_transpose
1118 ///    %mtb = maybe_transpose
1119 ///    %flattened_a = vector.shape_cast %mta
1120 ///    %flattened_b = vector.shape_cast %mtb
1121 ///    %flattened_d = vector.matmul %flattened_a, %flattened_b
1122 ///    %mtd = vector.shape_cast %flattened_d
1123 ///    %d = maybe_untranspose %mtd
1124 ///    %e = add %c, %d
1125 /// ```
1126 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1127 //
1128 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
1129 /// vector.transpose operations are inserted if the vector.contract op is not a
1130 /// row-major matrix multiply.
1131 LogicalResult
1132 ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
1133                                                  PatternRewriter &rew) const {
1134   // TODO: implement masks
1135   if (llvm::size(op.getMasks()) != 0)
1136     return failure();
1137   if (vectorTransformOptions.vectorContractLowering !=
1138       vector::VectorContractLowering::Matmul)
1139     return failure();
1140   if (failed(filter(op)))
1141     return failure();
1142 
1143   auto iteratorTypes = op.getIteratorTypes().getValue();
1144   if (!isParallelIterator(iteratorTypes[0]) ||
1145       !isParallelIterator(iteratorTypes[1]) ||
1146       !isReductionIterator(iteratorTypes[2]))
1147     return failure();
1148 
1149   Type elementType = op.getLhsType().getElementType();
1150   if (!elementType.isIntOrFloat())
1151     return failure();
1152 
1153   // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
1154   // Bail out if the contraction cannot be put in this form.
1155   MLIRContext *ctx = op.getContext();
1156   Location loc = op.getLoc();
1157   AffineExpr m, n, k;
1158   bindDims(rew.getContext(), m, n, k);
1159   // LHS must be A(m, k) or A(k, m).
1160   Value lhs = op.getLhs();
1161   auto lhsMap = op.getIndexingMaps()[0];
1162   if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
1163     lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
1164   else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
1165     return failure();
1166 
1167   // RHS must be B(k, n) or B(n, k).
1168   Value rhs = op.getRhs();
1169   auto rhsMap = op.getIndexingMaps()[1];
1170   if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
1171     rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
1172   else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
1173     return failure();
1174 
1175   // At this point lhs and rhs are in row-major.
1176   VectorType lhsType = lhs.getType().cast<VectorType>();
1177   VectorType rhsType = rhs.getType().cast<VectorType>();
1178   int64_t lhsRows = lhsType.getDimSize(0);
1179   int64_t lhsColumns = lhsType.getDimSize(1);
1180   int64_t rhsColumns = rhsType.getDimSize(1);
1181 
1182   Type flattenedLHSType =
1183       VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
1184   lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
1185 
1186   Type flattenedRHSType =
1187       VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
1188   rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
1189 
1190   Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
1191                                            rhsColumns);
1192   mul = rew.create<vector::ShapeCastOp>(
1193       loc,
1194       VectorType::get({lhsRows, rhsColumns},
1195                       getElementTypeOrSelf(op.getAcc().getType())),
1196       mul);
1197 
1198   // ACC must be C(m, n) or C(n, m).
1199   auto accMap = op.getIndexingMaps()[2];
1200   if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
1201     mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
1202   else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
1203     llvm_unreachable("invalid contraction semantics");
1204 
1205   Value res =
1206       elementType.isa<IntegerType>()
1207           ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
1208           : static_cast<Value>(
1209                 rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
1210 
1211   rew.replaceOp(op, res);
1212   return success();
1213 }
1214 
1215 namespace {
1216 struct IteratorType {
1217   IteratorType(StringRef strRef) : strRef(strRef) {}
1218   bool isOfType(Attribute attr) const {
1219     auto sAttr = attr.dyn_cast<StringAttr>();
1220     return sAttr && sAttr.getValue() == strRef;
1221   }
1222   StringRef strRef;
1223 };
1224 struct Par : public IteratorType {
1225   Par() : IteratorType(getParallelIteratorTypeName()) {}
1226 };
1227 struct Red : public IteratorType {
1228   Red() : IteratorType(getReductionIteratorTypeName()) {}
1229 };
1230 
1231 /// Generate a vector implementation for matmat, matvec and tmatvec.
1232 /// This unrolls outer-products along the reduction dimension.
1233 struct UnrolledOuterProductGenerator
1234     : public StructuredGenerator<vector::ContractionOp> {
1235   UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
1236       : StructuredGenerator<vector::ContractionOp>(builder, op),
1237         kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
1238         res(op.getAcc()), lhsType(op.getLhsType()) {}
1239 
1240   Value t(Value v) {
1241     static constexpr std::array<int64_t, 2> perm = {1, 0};
1242     return builder.create<vector::TransposeOp>(loc, v, perm);
1243   }
1244 
1245   Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
1246     assert(reductionSize > 0);
1247     for (int64_t k = 0; k < reductionSize; ++k) {
1248       Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
1249       Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
1250       res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
1251                                                    res, kind);
1252     }
1253     return res;
1254   }
1255 
1256   /// Two outer parallel, one inner reduction (matmat flavor).
1257   FailureOr<Value> matmat() {
1258     if (!iters({Par(), Par(), Red()}))
1259       return failure();
1260     // Set up the parallel/reduction structure in the right form.
1261     AffineExpr m, n, k;
1262     bindDims(builder.getContext(), m, n, k);
1263     // Classical row-major matmul:  Just permute the lhs.
1264     if (layout({{m, k}, {k, n}, {m, n}}))
1265       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1266     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1267     if (layout({{m, k}, {n, k}, {m, n}})) {
1268       Value tlhs = t(lhs);
1269       return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
1270     }
1271     // No need to permute anything.
1272     if (layout({{k, m}, {k, n}, {m, n}}))
1273       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1274     // Just permute the rhs.
1275     if (layout({{k, m}, {n, k}, {m, n}}))
1276       return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
1277     // Transposed output: swap RHS and LHS.
1278     // Classical row-major matmul: permute the lhs.
1279     if (layout({{m, k}, {k, n}, {n, m}}))
1280       return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
1281     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1282     if (layout({{m, k}, {n, k}, {n, m}})) {
1283       Value trhs = t(rhs);
1284       return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
1285     }
1286     if (layout({{k, m}, {k, n}, {n, m}}))
1287       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1288     if (layout({{k, m}, {n, k}, {n, m}}))
1289       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1290     return failure();
1291   }
1292 
1293   /// One outer parallel, one inner reduction (matvec flavor)
1294   FailureOr<Value> matvec() {
1295     if (!iters({Par(), Red()}))
1296       return failure();
1297     AffineExpr m, k;
1298     bindDims(builder.getContext(), m, k);
1299 
1300     // Case mat-vec: transpose.
1301     if (layout({{m, k}, {k}, {m}}))
1302       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1303     // Case mat-trans-vec: ready to go.
1304     if (layout({{k, m}, {k}, {m}}))
1305       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1306     // Case vec-mat: swap and transpose.
1307     if (layout({{k}, {m, k}, {m}}))
1308       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1309     // Case vec-mat-trans: swap and ready to go.
1310     if (layout({{k}, {k, m}, {m}}))
1311       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1312     return failure();
1313   }
1314 
1315   //
1316   // One outer reduction, one inner parallel (tmatvec flavor)
1317   //
1318   FailureOr<Value> tmatvec() {
1319     if (!iters({Red(), Par()}))
1320       return failure();
1321     AffineExpr k, m;
1322     bindDims(builder.getContext(), k, m);
1323 
1324     // Case mat-vec: transpose.
1325     if (layout({{m, k}, {k}, {m}}))
1326       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1327     // Case mat-trans-vec: ready to go.
1328     if (layout({{k, m}, {k}, {m}}))
1329       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1330     // Case vec-mat: swap and transpose.
1331     if (layout({{k}, {m, k}, {m}}))
1332       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1333     // Case vec-mat-trans: swap and ready to go.
1334     if (layout({{k}, {k, m}, {m}}))
1335       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1336     return failure();
1337   }
1338 
1339 private:
1340   vector::CombiningKind kind;
1341   Value lhs, rhs, res;
1342   VectorType lhsType;
1343 };
1344 } // namespace
1345 
1346 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1347 /// semantics to a reduction_size-unrolled sequence:
1348 /// ```
1349 ///    %at = vector.transpose %a, [1, 0]
1350 ///    %bRow0 = vector.extract %b[0]
1351 ///    %atRow0 = vector.extract %at[0]
1352 ///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
1353 ///    ...
1354 ///    %bRowK = vector.extract %b[K]
1355 ///    %atRowK = vector.extract %at[K]
1356 ///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
1357 /// ```
1358 ///
1359 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
1360 /// otherwise supports any layout permutation of the matrix-multiply.
1361 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
1362     vector::ContractionOp op, PatternRewriter &rewriter) const {
1363   // TODO: implement masks
1364   if (llvm::size(op.getMasks()) != 0)
1365     return failure();
1366 
1367   if (vectorTransformOptions.vectorContractLowering !=
1368       vector::VectorContractLowering::OuterProduct)
1369     return failure();
1370 
1371   if (failed(filter(op)))
1372     return failure();
1373 
1374   UnrolledOuterProductGenerator e(rewriter, op);
1375   FailureOr<Value> matmatRes = e.matmat();
1376   if (succeeded(matmatRes)) {
1377     rewriter.replaceOp(op, *matmatRes);
1378     return success();
1379   }
1380   FailureOr<Value> matvecRes = e.matvec();
1381   if (succeeded(matvecRes)) {
1382     rewriter.replaceOp(op, *matvecRes);
1383     return success();
1384   }
1385   FailureOr<Value> tmatvecRes = e.tmatvec();
1386   if (succeeded(tmatvecRes)) {
1387     rewriter.replaceOp(op, *tmatvecRes);
1388     return success();
1389   }
1390 
1391   return failure();
1392 }
1393 
1394 LogicalResult
1395 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
1396                                             PatternRewriter &rewriter) const {
1397   // TODO: implement masks
1398   if (llvm::size(op.getMasks()) != 0)
1399     return failure();
1400 
1401   if (failed(filter(op)))
1402     return failure();
1403 
1404   if (vectorTransformOptions.vectorContractLowering !=
1405       vector::VectorContractLowering::Dot)
1406     return failure();
1407 
1408   auto iteratorTypes = op.getIteratorTypes().getValue();
1409   static constexpr std::array<int64_t, 2> perm = {1, 0};
1410   Location loc = op.getLoc();
1411   Value lhs = op.getLhs(), rhs = op.getRhs();
1412 
1413   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1414   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1415   AffineExpr m, n, k;
1416   bindDims(rewriter.getContext(), m, n, k);
1417   SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
1418   //
1419   // In the following we wish to make the reduction dimension innermost so we
1420   // can load vectors and just fmul + reduce into a scalar.
1421   //
1422   if (isParallelIterator(iteratorTypes[0]) &&
1423       isParallelIterator(iteratorTypes[1]) &&
1424       isReductionIterator(iteratorTypes[2])) {
1425     //
1426     // Two outer parallel, one inner reduction (matmat flavor).
1427     //
1428     if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1429       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1430     } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
1431       // No need to permute anything.
1432     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1433       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1434       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1435     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1436       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1437     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1438       // This is the classical row-major matmul. Just permute the lhs.
1439       Value tmp = lhs;
1440       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1441       rhs = tmp;
1442     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1443       std::swap(lhs, rhs);
1444     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1445       Value tmp = lhs;
1446       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1447       rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
1448     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1449       Value tmp = rhs;
1450       rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1451       lhs = tmp;
1452     } else {
1453       return failure();
1454     }
1455   } else if (isParallelIterator(iteratorTypes[0]) &&
1456              isReductionIterator(iteratorTypes[1])) {
1457     //
1458     // One outer parallel, one inner reduction (matvec flavor)
1459     //
1460     if (maps == infer({{m, n}, {n}, {m}})) {
1461       // No need to permute anything.
1462     } else if (maps == infer({{n, m}, {n}, {m}})) {
1463       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1464     } else if (maps == infer({{n}, {m, n}, {m}})) {
1465       std::swap(lhs, rhs);
1466     } else if (maps == infer({{n}, {n, m}, {m}})) {
1467       std::swap(lhs, rhs);
1468       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1469     } else {
1470       return failure();
1471     }
1472   } else {
1473     return failure();
1474   }
1475 
1476   VectorType dstType = op.getResultType().cast<VectorType>();
1477   assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
1478          "Expected dst type of rank 1 or 2");
1479 
1480   unsigned rank = dstType.getRank();
1481   unsigned dstRows = dstType.getShape()[0];
1482   unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
1483 
1484   // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
1485   Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
1486                                                  rewriter.getZeroAttr(dstType));
1487   bool isInt = dstType.getElementType().isa<IntegerType>();
1488   for (unsigned r = 0; r < dstRows; ++r) {
1489     Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
1490     for (unsigned c = 0; c < dstColumns; ++c) {
1491       Value b = rank == 1
1492                     ? rhs
1493                     : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
1494       Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
1495       Value reduced = rewriter.create<vector::ReductionOp>(
1496           op.getLoc(), vector::CombiningKind::ADD, m);
1497 
1498       SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
1499                                               : SmallVector<int64_t, 2>{r, c};
1500       res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
1501     }
1502   }
1503   if (auto acc = op.getAcc())
1504     res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
1505   rewriter.replaceOp(op, res);
1506   return success();
1507 }
1508 
1509 /// Progressive lowering of ContractionOp.
1510 /// One:
1511 ///   %x = vector.contract with at least one free/batch dimension
1512 /// is replaced by:
1513 ///   %a = vector.contract with one less free/batch dimension
1514 ///   %b = vector.contract with one less free/batch dimension
1515 ///   ..
1516 ///   %x = combine %a %b ..
1517 /// until a pure contraction is reached (no free/batch dimensions),
1518 /// which is replaced by a dot-product.
1519 ///
1520 /// This only kicks in when either VectorTransformsOptions is set
1521 /// to DOT or when other contraction patterns fail.
1522 //
1523 // TODO: break down into transpose/reshape/cast ops
1524 //               when they become available to avoid code dup
1525 // TODO: investigate lowering order impact on performance
1526 LogicalResult
1527 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
1528                                        PatternRewriter &rewriter) const {
1529   // TODO: implement masks.
1530   if (llvm::size(op.getMasks()) != 0)
1531     return failure();
1532 
1533   if (failed(filter(op)))
1534     return failure();
1535 
1536   // TODO: support mixed mode contract lowering.
1537   if (op.getLhsType().getElementType() !=
1538           getElementTypeOrSelf(op.getAccType()) ||
1539       op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
1540     return failure();
1541 
1542   // TODO: implement benefits, cost models.
1543   MLIRContext *ctx = op.getContext();
1544   ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
1545   if (succeeded(pat1.matchAndRewrite(op, rewriter)))
1546     return success();
1547   ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
1548   if (succeeded(pat2.matchAndRewrite(op, rewriter)))
1549     return success();
1550   ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
1551   if (succeeded(pat3.matchAndRewrite(op, rewriter)))
1552     return success();
1553 
1554   // Find first batch dimension in LHS/RHS, and lower when found.
1555   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
1556   if (!batchDimMap.empty()) {
1557     int64_t lhsIndex = batchDimMap[0].first;
1558     int64_t rhsIndex = batchDimMap[0].second;
1559     rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
1560     return success();
1561   }
1562 
1563   // Collect contracting dimensions.
1564   std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
1565       op.getContractingDimMap();
1566   DenseSet<int64_t> lhsContractingDimSet;
1567   DenseSet<int64_t> rhsContractingDimSet;
1568   for (auto &dimPair : contractingDimMap) {
1569     lhsContractingDimSet.insert(dimPair.first);
1570     rhsContractingDimSet.insert(dimPair.second);
1571   }
1572 
1573   // Find first free dimension in LHS, and lower when found.
1574   VectorType lhsType = op.getLhsType();
1575   for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
1576     if (lhsContractingDimSet.count(lhsIndex) == 0) {
1577       rewriter.replaceOp(
1578           op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
1579       return success();
1580     }
1581   }
1582 
1583   // Find first free dimension in RHS, and lower when found.
1584   VectorType rhsType = op.getRhsType();
1585   for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
1586     if (rhsContractingDimSet.count(rhsIndex) == 0) {
1587       rewriter.replaceOp(
1588           op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
1589       return success();
1590     }
1591   }
1592 
1593   // Lower the first remaining reduction dimension.
1594   if (!contractingDimMap.empty()) {
1595     rewriter.replaceOp(op, lowerReduction(op, rewriter));
1596     return success();
1597   }
1598 
1599   return failure();
1600 }
1601 
1602 // Lower one parallel dimension.
1603 // TODO: consider reusing existing contract unrolling
1604 Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
1605                                            int64_t lhsIndex, int64_t rhsIndex,
1606                                            PatternRewriter &rewriter) const {
1607   VectorType lhsType = op.getLhsType();
1608   VectorType rhsType = op.getRhsType();
1609   VectorType resType = op.getResultType().cast<VectorType>();
1610   // Find the iterator type index and result index.
1611   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
1612   int64_t iterIndex = -1;
1613   int64_t dimSize = -1;
1614   if (lhsIndex >= 0) {
1615     iterIndex = iMap[0].getDimPosition(lhsIndex);
1616     assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
1617            "parallel index should be free in LHS or batch in LHS/RHS");
1618     dimSize = lhsType.getDimSize(lhsIndex);
1619   } else {
1620     assert(rhsIndex >= 0 && "missing parallel index");
1621     iterIndex = iMap[1].getDimPosition(rhsIndex);
1622     dimSize = rhsType.getDimSize(rhsIndex);
1623   }
1624   assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
1625   Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
1626   assert(lookup.hasValue() && "parallel index not listed in reduction");
1627   int64_t resIndex = lookup.getValue();
1628   // Construct new iterator types and affine map array attribute.
1629   std::array<AffineMap, 3> lowIndexingMaps = {
1630       adjustMap(iMap[0], iterIndex, rewriter),
1631       adjustMap(iMap[1], iterIndex, rewriter),
1632       adjustMap(iMap[2], iterIndex, rewriter)};
1633   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1634   auto lowIter =
1635       rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1636   // Unroll into a series of lower dimensional vector.contract ops.
1637   Location loc = op.getLoc();
1638   Value result = rewriter.create<arith::ConstantOp>(
1639       loc, resType, rewriter.getZeroAttr(resType));
1640   for (int64_t d = 0; d < dimSize; ++d) {
1641     auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1642     auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1643     auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
1644     Value lowContract = rewriter.create<vector::ContractionOp>(
1645         loc, lhs, rhs, acc, lowAffine, lowIter);
1646     result =
1647         reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
1648   }
1649   return result;
1650 }
1651 
1652 // Lower one reduction dimension.
1653 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
1654                                             PatternRewriter &rewriter) const {
1655   auto loc = op.getLoc();
1656   VectorType lhsType = op.getLhsType();
1657   VectorType rhsType = op.getRhsType();
1658   Type resType = op.getResultType();
1659   assert(!resType.isa<VectorType>());
1660   bool isInt = resType.isa<IntegerType>();
1661   // Use iterator index 0.
1662   int64_t iterIndex = 0;
1663   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
1664   Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1665   Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1666   assert(lookupLhs.hasValue() && "missing LHS parallel index");
1667   assert(lookupRhs.hasValue() && "missing RHS parallel index");
1668   int64_t lhsIndex = lookupLhs.getValue();
1669   int64_t rhsIndex = lookupRhs.getValue();
1670   int64_t dimSize = lhsType.getDimSize(lhsIndex);
1671   assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
1672   // Base case.
1673   if (lhsType.getRank() == 1) {
1674     assert(rhsType.getRank() == 1 && "corrupt contraction");
1675     Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1676     auto kind = vector::CombiningKind::ADD;
1677     Value res = rewriter.create<vector::ReductionOp>(loc, kind, m);
1678     if (auto acc = op.getAcc())
1679       res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
1680     return res;
1681   }
1682   // Construct new iterator types and affine map array attribute.
1683   std::array<AffineMap, 3> lowIndexingMaps = {
1684       adjustMap(iMap[0], iterIndex, rewriter),
1685       adjustMap(iMap[1], iterIndex, rewriter),
1686       adjustMap(iMap[2], iterIndex, rewriter)};
1687   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1688   auto lowIter =
1689       rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1690   // Unroll into a series of lower dimensional vector.contract ops.
1691   // By feeding the initial accumulator into the first contraction,
1692   // and the result of each contraction into the next, eventually
1693   // the sum of all reductions is computed.
1694   Value result = op.getAcc();
1695   for (int64_t d = 0; d < dimSize; ++d) {
1696     auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1697     auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1698     result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
1699                                                     lowAffine, lowIter);
1700   }
1701   return result;
1702 }
1703 
1704 } // namespace mlir
1705 
1706 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
1707     OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
1708     ArrayRef<int64_t> multiplicity, const AffineMap &map) {
1709   OpBuilder::InsertionGuard guard(builder);
1710   builder.setInsertionPointAfter(op);
1711   Location loc = op->getLoc();
1712   if (op->getNumResults() != 1)
1713     return {};
1714   Value result = op->getResult(0);
1715   VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
1716   if (!type || map.getNumResults() != multiplicity.size())
1717     return {};
1718   // For each dimension being distributed check that the size is a multiple of
1719   // the multiplicity. To handle more sizes we would need to support masking.
1720   unsigned multiplictyCount = 0;
1721   for (auto exp : map.getResults()) {
1722     auto affinExp = exp.dyn_cast<AffineDimExpr>();
1723     if (!affinExp || affinExp.getPosition() >= type.getRank() ||
1724         type.getDimSize(affinExp.getPosition()) %
1725                 multiplicity[multiplictyCount++] !=
1726             0)
1727       return {};
1728   }
1729   DistributeOps ops;
1730   ops.extract =
1731       builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
1732   ops.insert =
1733       builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
1734   return ops;
1735 }
1736 
1737 /// Progressive lowering of transfer_read. This pattern supports lowering of
1738 /// `vector.transfer_read` to a combination of `vector.load` and
1739 /// `vector.broadcast` if all of the following hold:
1740 /// - Stride of most minor memref dimension must be 1.
1741 /// - Out-of-bounds masking is not required.
1742 /// - If the memref's element type is a vector type then it coincides with the
1743 ///   result type.
1744 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
1745 struct TransferReadToVectorLoadLowering
1746     : public OpRewritePattern<vector::TransferReadOp> {
1747   TransferReadToVectorLoadLowering(MLIRContext *context,
1748                                    llvm::Optional<unsigned> maxRank)
1749       : OpRewritePattern<vector::TransferReadOp>(context),
1750         maxTransferRank(maxRank) {}
1751 
1752   LogicalResult matchAndRewrite(vector::TransferReadOp read,
1753                                 PatternRewriter &rewriter) const override {
1754     if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
1755       return failure();
1756 
1757     SmallVector<unsigned, 4> broadcastedDims;
1758     // Permutations are handled by VectorToSCF or
1759     // populateVectorTransferPermutationMapLoweringPatterns.
1760     // We let the 0-d corner case pass-through as it is supported.
1761     if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
1762             &broadcastedDims))
1763       return failure();
1764 
1765     auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
1766     if (!memRefType)
1767       return failure();
1768 
1769     // Non-unit strides are handled by VectorToSCF.
1770     if (!vector::isLastMemrefDimUnitStride(memRefType))
1771       return failure();
1772 
1773     // If there is broadcasting involved then we first load the unbroadcasted
1774     // vector, and then broadcast it with `vector.broadcast`.
1775     ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
1776     SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
1777                                                      vectorShape.end());
1778     for (unsigned i : broadcastedDims)
1779       unbroadcastedVectorShape[i] = 1;
1780     VectorType unbroadcastedVectorType = VectorType::get(
1781         unbroadcastedVectorShape, read.getVectorType().getElementType());
1782 
1783     // `vector.load` supports vector types as memref's elements only when the
1784     // resulting vector type is the same as the element type.
1785     auto memrefElTy = memRefType.getElementType();
1786     if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
1787       return failure();
1788 
1789     // Otherwise, element types of the memref and the vector must match.
1790     if (!memrefElTy.isa<VectorType>() &&
1791         memrefElTy != read.getVectorType().getElementType())
1792       return failure();
1793 
1794     // Out-of-bounds dims are handled by MaterializeTransferMask.
1795     if (read.hasOutOfBoundsDim())
1796       return failure();
1797 
1798     // Create vector load op.
1799     Operation *loadOp;
1800     if (read.getMask()) {
1801       Value fill = rewriter.create<vector::SplatOp>(
1802           read.getLoc(), unbroadcastedVectorType, read.getPadding());
1803       loadOp = rewriter.create<vector::MaskedLoadOp>(
1804           read.getLoc(), unbroadcastedVectorType, read.getSource(),
1805           read.getIndices(), read.getMask(), fill);
1806     } else {
1807       loadOp = rewriter.create<vector::LoadOp>(
1808           read.getLoc(), unbroadcastedVectorType, read.getSource(),
1809           read.getIndices());
1810     }
1811 
1812     // Insert a broadcasting op if required.
1813     if (!broadcastedDims.empty()) {
1814       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1815           read, read.getVectorType(), loadOp->getResult(0));
1816     } else {
1817       rewriter.replaceOp(read, loadOp->getResult(0));
1818     }
1819 
1820     return success();
1821   }
1822 
1823   llvm::Optional<unsigned> maxTransferRank;
1824 };
1825 
1826 /// Replace a 0-d vector.load with a memref.load + vector.broadcast.
1827 // TODO: we shouldn't cross the vector/scalar domains just for this
1828 // but atm we lack the infra to avoid it. Possible solutions include:
1829 // - go directly to LLVM + bitcast
1830 // - introduce a bitcast op and likely a new pointer dialect
1831 // - let memref.load/store additionally support the 0-d vector case
1832 // There are still deeper data layout issues lingering even in this
1833 // trivial case (for architectures for which this matters).
1834 struct VectorLoadToMemrefLoadLowering
1835     : public OpRewritePattern<vector::LoadOp> {
1836   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
1837 
1838   LogicalResult matchAndRewrite(vector::LoadOp loadOp,
1839                                 PatternRewriter &rewriter) const override {
1840     auto vecType = loadOp.getVectorType();
1841     if (vecType.getNumElements() != 1)
1842       return failure();
1843     auto memrefLoad = rewriter.create<memref::LoadOp>(
1844         loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
1845     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
1846                                                      memrefLoad);
1847     return success();
1848   }
1849 };
1850 
1851 /// Replace a 0-d vector.store with a vector.extractelement + memref.store.
1852 struct VectorStoreToMemrefStoreLowering
1853     : public OpRewritePattern<vector::StoreOp> {
1854   using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
1855 
1856   LogicalResult matchAndRewrite(vector::StoreOp storeOp,
1857                                 PatternRewriter &rewriter) const override {
1858     auto vecType = storeOp.getVectorType();
1859     if (vecType.getNumElements() != 1)
1860       return failure();
1861     Value extracted;
1862     if (vecType.getRank() == 0) {
1863       // TODO: Unifiy once ExtractOp supports 0-d vectors.
1864       extracted = rewriter.create<vector::ExtractElementOp>(
1865           storeOp.getLoc(), storeOp.getValueToStore());
1866     } else {
1867       SmallVector<int64_t> indices(vecType.getRank(), 0);
1868       extracted = rewriter.create<vector::ExtractOp>(
1869           storeOp.getLoc(), storeOp.getValueToStore(), indices);
1870     }
1871 
1872     rewriter.replaceOpWithNewOp<memref::StoreOp>(
1873         storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
1874     return success();
1875   }
1876 };
1877 
1878 /// Progressive lowering of transfer_write. This pattern supports lowering of
1879 /// `vector.transfer_write` to `vector.store` if all of the following hold:
1880 /// - Stride of most minor memref dimension must be 1.
1881 /// - Out-of-bounds masking is not required.
1882 /// - If the memref's element type is a vector type then it coincides with the
1883 ///   type of the written value.
1884 /// - The permutation map is the minor identity map (neither permutation nor
1885 ///   broadcasting is allowed).
1886 struct TransferWriteToVectorStoreLowering
1887     : public OpRewritePattern<vector::TransferWriteOp> {
1888   TransferWriteToVectorStoreLowering(MLIRContext *context,
1889                                      llvm::Optional<unsigned> maxRank)
1890       : OpRewritePattern<vector::TransferWriteOp>(context),
1891         maxTransferRank(maxRank) {}
1892 
1893   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
1894                                 PatternRewriter &rewriter) const override {
1895     if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
1896       return failure();
1897 
1898     // Permutations are handled by VectorToSCF or
1899     // populateVectorTransferPermutationMapLoweringPatterns.
1900     if ( // pass-through for the 0-d corner case.
1901         !write.getPermutationMap().isMinorIdentity())
1902       return failure();
1903 
1904     auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
1905     if (!memRefType)
1906       return failure();
1907 
1908     // Non-unit strides are handled by VectorToSCF.
1909     if (!vector::isLastMemrefDimUnitStride(memRefType))
1910       return failure();
1911 
1912     // `vector.store` supports vector types as memref's elements only when the
1913     // type of the vector value being written is the same as the element type.
1914     auto memrefElTy = memRefType.getElementType();
1915     if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
1916       return failure();
1917 
1918     // Otherwise, element types of the memref and the vector must match.
1919     if (!memrefElTy.isa<VectorType>() &&
1920         memrefElTy != write.getVectorType().getElementType())
1921       return failure();
1922 
1923     // Out-of-bounds dims are handled by MaterializeTransferMask.
1924     if (write.hasOutOfBoundsDim())
1925       return failure();
1926     if (write.getMask()) {
1927       rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
1928           write, write.getSource(), write.getIndices(), write.getMask(),
1929           write.getVector());
1930     } else {
1931       rewriter.replaceOpWithNewOp<vector::StoreOp>(
1932           write, write.getVector(), write.getSource(), write.getIndices());
1933     }
1934     return success();
1935   }
1936 
1937   llvm::Optional<unsigned> maxTransferRank;
1938 };
1939 
1940 // Returns the values in `arrayAttr` as an integer vector.
1941 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
1942   return llvm::to_vector<4>(
1943       llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
1944                       [](IntegerAttr attr) { return attr.getInt(); }));
1945 }
1946 
1947 // Shuffles vector.bitcast op after vector.extract op.
1948 //
1949 // This transforms IR like:
1950 //   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
1951 //   %1 = vector.extract %0[3] : vector<8xf16>
1952 // Into:
1953 //   %0 = vector.extract %src[1] : vector<4xf32>
1954 //   %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
1955 //   %2 = vector.extract %1[1] : vector<2xf16>
1956 struct BubbleDownVectorBitCastForExtract
1957     : public OpRewritePattern<vector::ExtractOp> {
1958   using OpRewritePattern::OpRewritePattern;
1959 
1960   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
1961                                 PatternRewriter &rewriter) const override {
1962     // Only support extracting scalars for now.
1963     if (extractOp.getVectorType().getRank() != 1)
1964       return failure();
1965 
1966     auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
1967     if (!castOp)
1968       return failure();
1969 
1970     VectorType castSrcType = castOp.getSourceVectorType();
1971     VectorType castDstType = castOp.getResultVectorType();
1972     assert(castSrcType.getRank() == castDstType.getRank());
1973 
1974     // Fail to match if we only have one element in the cast op source.
1975     // This is to avoid infinite loop given that this pattern can generate
1976     // such cases.
1977     if (castSrcType.getNumElements() == 1)
1978       return failure();
1979 
1980     // Only support casting to a larger number of elements or now.
1981     // E.g., vector<4xf32> -> vector<8xf16>.
1982     if (castSrcType.getNumElements() > castDstType.getNumElements())
1983       return failure();
1984 
1985     unsigned expandRatio =
1986         castDstType.getNumElements() / castSrcType.getNumElements();
1987 
1988     auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
1989       return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
1990     };
1991 
1992     uint64_t index = getFirstIntValue(extractOp.getPosition());
1993 
1994     // Get the single scalar (as a vector) in the source value that packs the
1995     // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
1996     VectorType oneScalarType =
1997         VectorType::get({1}, castSrcType.getElementType());
1998     Value packedValue = rewriter.create<vector::ExtractOp>(
1999         extractOp.getLoc(), oneScalarType, castOp.getSource(),
2000         rewriter.getI64ArrayAttr(index / expandRatio));
2001 
2002     // Cast it to a vector with the desired scalar's type.
2003     // E.g. f32 -> vector<2xf16>
2004     VectorType packedType =
2005         VectorType::get({expandRatio}, castDstType.getElementType());
2006     Value castedValue = rewriter.create<vector::BitCastOp>(
2007         extractOp.getLoc(), packedType, packedValue);
2008 
2009     // Finally extract the desired scalar.
2010     rewriter.replaceOpWithNewOp<vector::ExtractOp>(
2011         extractOp, extractOp.getType(), castedValue,
2012         rewriter.getI64ArrayAttr(index % expandRatio));
2013 
2014     return success();
2015   }
2016 };
2017 
2018 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
2019 //
2020 // This transforms IR like:
2021 //    %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
2022 //     %0 = vector.extract_strided_slice %cast {
2023 //            offsets = [4], sizes = [4], strides = [1]
2024 //          } : vector<8xf16> to vector<4xf16>
2025 // Into:
2026 //   %0 = vector.extract_strided_slice %src {
2027 //          offsets = [2], sizes = [2], strides = [1]
2028 //        } : vector<4xf32> to vector<2xf32>
2029 //   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
2030 struct BubbleDownBitCastForStridedSliceExtract
2031     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
2032   using OpRewritePattern::OpRewritePattern;
2033 
2034   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
2035                                 PatternRewriter &rewriter) const override {
2036     auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
2037     if (!castOp)
2038       return failure();
2039 
2040     VectorType castSrcType = castOp.getSourceVectorType();
2041     VectorType castDstType = castOp.getResultVectorType();
2042     assert(castSrcType.getRank() == castDstType.getRank());
2043 
2044     int64_t castSrcLastDim = castSrcType.getShape().back();
2045     int64_t castDstLastDim = castDstType.getShape().back();
2046     // Require casting to more elements for now; other cases to be implemented.
2047     if (castSrcLastDim > castDstLastDim)
2048       return failure();
2049 
2050     // Only accept all one strides for now.
2051     if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
2052                      [](const APInt &val) { return !val.isOneValue(); }))
2053       return failure();
2054 
2055     unsigned rank = extractOp.getVectorType().getRank();
2056     assert(castDstLastDim % castSrcLastDim == 0);
2057     int64_t expandRatio = castDstLastDim / castSrcLastDim;
2058 
2059     // If we have a less number of offsets than the rank, then implicitly we
2060     // are selecting the full range for the last bitcasted dimension; other
2061     // dimensions aren't affected. Otherwise, we need to scale down the last
2062     // dimension's offset given we are extracting from less elements now.
2063     ArrayAttr newOffsets = extractOp.getOffsets();
2064     if (newOffsets.size() == rank) {
2065       SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2066       if (offsets.back() % expandRatio != 0)
2067         return failure();
2068       offsets.back() = offsets.back() / expandRatio;
2069       newOffsets = rewriter.getI64ArrayAttr(offsets);
2070     }
2071 
2072     // Similarly for sizes.
2073     ArrayAttr newSizes = extractOp.getSizes();
2074     if (newSizes.size() == rank) {
2075       SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
2076       if (sizes.back() % expandRatio != 0)
2077         return failure();
2078       sizes.back() = sizes.back() / expandRatio;
2079       newSizes = rewriter.getI64ArrayAttr(sizes);
2080     }
2081 
2082     SmallVector<int64_t, 4> dims =
2083         llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
2084     dims.back() = dims.back() / expandRatio;
2085     VectorType newExtractType =
2086         VectorType::get(dims, castSrcType.getElementType());
2087 
2088     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
2089         extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
2090         newSizes, extractOp.getStrides());
2091 
2092     rewriter.replaceOpWithNewOp<vector::BitCastOp>(
2093         extractOp, extractOp.getType(), newExtractOp);
2094 
2095     return success();
2096   }
2097 };
2098 
2099 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
2100 //
2101 // This transforms IR like:
2102 //   %0 = vector.insert_strided_slice %src, %dst {
2103 //          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
2104 //   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
2105 // Into:
2106 //   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
2107 //   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
2108 //   %2 = vector.insert_strided_slice %src, %dst {
2109 //          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
2110 struct BubbleUpBitCastForStridedSliceInsert
2111     : public OpRewritePattern<vector::BitCastOp> {
2112   using OpRewritePattern::OpRewritePattern;
2113   LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
2114                                 PatternRewriter &rewriter) const override {
2115     VectorType castSrcType = bitcastOp.getSourceVectorType();
2116     VectorType castDstType = bitcastOp.getResultVectorType();
2117     assert(castSrcType.getRank() == castDstType.getRank());
2118 
2119     int64_t castSrcLastDim = castSrcType.getShape().back();
2120     int64_t castDstLastDim = castDstType.getShape().back();
2121     // Require casting to less elements for now; other cases to be implemented.
2122     if (castSrcLastDim < castDstLastDim)
2123       return failure();
2124 
2125     assert(castSrcLastDim % castDstLastDim == 0);
2126     int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
2127 
2128     auto insertOp =
2129         bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
2130     if (!insertOp)
2131       return failure();
2132 
2133     // Only accept all one strides for now.
2134     if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
2135                      [](const APInt &val) { return !val.isOneValue(); }))
2136       return failure();
2137 
2138     unsigned rank = insertOp.getSourceVectorType().getRank();
2139     // Require insert op to have the same rank for the source and destination
2140     // vector; other cases to be implemented.
2141     if (rank != insertOp.getDestVectorType().getRank())
2142       return failure();
2143 
2144     ArrayAttr newOffsets = insertOp.getOffsets();
2145     assert(newOffsets.size() == rank);
2146     SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2147     if (offsets.back() % shrinkRatio != 0)
2148       return failure();
2149     offsets.back() = offsets.back() / shrinkRatio;
2150     newOffsets = rewriter.getI64ArrayAttr(offsets);
2151 
2152     SmallVector<int64_t, 4> srcDims =
2153         llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
2154     srcDims.back() = srcDims.back() / shrinkRatio;
2155     VectorType newCastSrcType =
2156         VectorType::get(srcDims, castDstType.getElementType());
2157 
2158     auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
2159         bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
2160 
2161     SmallVector<int64_t, 4> dstDims =
2162         llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
2163     dstDims.back() = dstDims.back() / shrinkRatio;
2164     VectorType newCastDstType =
2165         VectorType::get(dstDims, castDstType.getElementType());
2166 
2167     auto newCastDstOp = rewriter.create<vector::BitCastOp>(
2168         bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
2169 
2170     rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
2171         bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
2172         insertOp.getStrides());
2173 
2174     return success();
2175   }
2176 };
2177 
2178 // Helper that returns a vector comparison that constructs a mask:
2179 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
2180 //
2181 // If `dim == 0` then the result will be a 0-D vector.
2182 //
2183 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
2184 //       much more compact, IR for this operation, but LLVM eventually
2185 //       generates more elaborate instructions for this intrinsic since it
2186 //       is very conservative on the boundary conditions.
2187 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
2188                                    bool force32BitVectorIndices, int64_t dim,
2189                                    Value b, Value *off = nullptr) {
2190   auto loc = op->getLoc();
2191   // If we can assume all indices fit in 32-bit, we perform the vector
2192   // comparison in 32-bit to get a higher degree of SIMD parallelism.
2193   // Otherwise we perform the vector comparison using 64-bit indices.
2194   Type idxType =
2195       force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
2196   DenseIntElementsAttr indicesAttr;
2197   if (dim == 0 && force32BitVectorIndices) {
2198     indicesAttr = DenseIntElementsAttr::get(
2199         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
2200   } else if (dim == 0) {
2201     indicesAttr = DenseIntElementsAttr::get(
2202         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
2203   } else if (force32BitVectorIndices) {
2204     indicesAttr = rewriter.getI32VectorAttr(
2205         llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
2206   } else {
2207     indicesAttr = rewriter.getI64VectorAttr(
2208         llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
2209   }
2210   Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
2211   // Add in an offset if requested.
2212   if (off) {
2213     Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
2214     Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
2215     indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
2216   }
2217   // Construct the vector comparison.
2218   Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
2219   Value bounds =
2220       rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
2221   return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
2222                                         bounds);
2223 }
2224 
2225 template <typename ConcreteOp>
2226 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
2227 public:
2228   explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
2229       : mlir::OpRewritePattern<ConcreteOp>(context),
2230         force32BitVectorIndices(enableIndexOpt) {}
2231 
2232   LogicalResult matchAndRewrite(ConcreteOp xferOp,
2233                                 PatternRewriter &rewriter) const override {
2234     if (!xferOp.hasOutOfBoundsDim())
2235       return failure();
2236 
2237     if (xferOp.getVectorType().getRank() > 1 ||
2238         llvm::size(xferOp.getIndices()) == 0)
2239       return failure();
2240 
2241     Location loc = xferOp->getLoc();
2242     VectorType vtp = xferOp.getVectorType();
2243 
2244     // Create the in-bounds mask with all elements between [0 .. dim - offset)
2245     // set and [dim - offset .. vector_length) unset.
2246     //
2247     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
2248     //       dimensions here.
2249     unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
2250     Value off = xferOp.getIndices()[lastIndex];
2251     Value dim =
2252         vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
2253     Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
2254     Value mask = rewriter.create<vector::CreateMaskOp>(
2255         loc,
2256         VectorType::get(vtp.getShape(), rewriter.getI1Type(),
2257                         vtp.getNumScalableDims()),
2258         b);
2259     if (xferOp.getMask()) {
2260       // Intersect the in-bounds with the mask specified as an op parameter.
2261       mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
2262     }
2263 
2264     rewriter.updateRootInPlace(xferOp, [&]() {
2265       xferOp.getMaskMutable().assign(mask);
2266       xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
2267     });
2268 
2269     return success();
2270   }
2271 
2272 private:
2273   const bool force32BitVectorIndices;
2274 };
2275 
2276 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
2277 class VectorCreateMaskOpConversion
2278     : public OpRewritePattern<vector::CreateMaskOp> {
2279 public:
2280   explicit VectorCreateMaskOpConversion(MLIRContext *context,
2281                                         bool enableIndexOpt)
2282       : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
2283         force32BitVectorIndices(enableIndexOpt) {}
2284 
2285   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
2286                                 PatternRewriter &rewriter) const override {
2287     auto dstType = op.getType();
2288     if (dstType.cast<VectorType>().isScalable())
2289       return failure();
2290     int64_t rank = dstType.getRank();
2291     if (rank > 1)
2292       return failure();
2293     rewriter.replaceOp(
2294         op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
2295                                   rank == 0 ? 0 : dstType.getDimSize(0),
2296                                   op.getOperand(0)));
2297     return success();
2298   }
2299 
2300 private:
2301   const bool force32BitVectorIndices;
2302 };
2303 
2304 // Drop inner most contiguous unit dimensions from transfer_read operand.
2305 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
2306   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
2307 
2308   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
2309                                 PatternRewriter &rewriter) const override {
2310     // TODO: support 0-d corner case.
2311     if (readOp.getTransferRank() == 0)
2312       return failure();
2313 
2314     // TODO: support mask.
2315     if (readOp.getMask())
2316       return failure();
2317 
2318     auto srcType = readOp.getSource().getType().dyn_cast<MemRefType>();
2319     if (!srcType || !srcType.hasStaticShape())
2320       return failure();
2321 
2322     if (!readOp.getPermutationMap().isMinorIdentity())
2323       return failure();
2324 
2325     auto targetType = readOp.getVectorType();
2326     if (targetType.getRank() <= 1)
2327       return failure();
2328 
2329     SmallVector<int64_t> srcStrides;
2330     int64_t srcOffset;
2331     if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2332       return failure();
2333 
2334     size_t dimsToDrop = 0;
2335     for (size_t i = 1; i < srcStrides.size(); ++i) {
2336       int dim = srcType.getRank() - i - 1;
2337       if (srcStrides[dim] == 1) {
2338         dimsToDrop++;
2339       } else {
2340         break;
2341       }
2342     }
2343     if (dimsToDrop == 0)
2344       return failure();
2345 
2346     auto resultTargetVecType =
2347         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
2348                         targetType.getElementType());
2349 
2350     MemRefType resultMemrefType;
2351     if (srcType.getLayout().getAffineMap().isIdentity()) {
2352       resultMemrefType = MemRefType::get(
2353           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2354           {}, srcType.getMemorySpaceAsInt());
2355     } else {
2356       AffineMap map = srcType.getLayout().getAffineMap();
2357       int numResultDims = map.getNumDims() - dimsToDrop;
2358       int numSymbols = map.getNumSymbols();
2359       for (size_t i = 0; i < dimsToDrop; ++i) {
2360         int dim = srcType.getRank() - i - 1;
2361         map = map.replace(rewriter.getAffineDimExpr(dim),
2362                           rewriter.getAffineConstantExpr(0), numResultDims,
2363                           numSymbols);
2364       }
2365       resultMemrefType = MemRefType::get(
2366           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2367           map, srcType.getMemorySpaceAsInt());
2368     }
2369 
2370     auto loc = readOp.getLoc();
2371     SmallVector<int64_t> offsets(srcType.getRank(), 0);
2372     SmallVector<int64_t> strides(srcType.getRank(), 1);
2373 
2374     ArrayAttr inBoundsAttr =
2375         readOp.getInBounds()
2376             ? rewriter.getArrayAttr(
2377                   readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
2378             : ArrayAttr();
2379     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
2380         loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
2381         strides);
2382     auto permMap = getTransferMinorIdentityMap(
2383         rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
2384     Value result = rewriter.create<vector::TransferReadOp>(
2385         loc, resultTargetVecType, rankedReducedView,
2386         readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
2387         readOp.getPadding(),
2388         // TODO: support mask.
2389         /*mask=*/Value(), inBoundsAttr);
2390     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
2391                                                      result);
2392     return success();
2393   }
2394 };
2395 
2396 namespace {
2397 
2398 /// This function checks to see if the vector combining kind
2399 /// is consistent with the integer or float element type.
2400 static bool isValidKind(bool isInt, vector::CombiningKind kind) {
2401   using vector::CombiningKind;
2402   enum class KindType { FLOAT, INT, INVALID };
2403   KindType type{KindType::INVALID};
2404   switch (kind) {
2405   case CombiningKind::MINF:
2406   case CombiningKind::MAXF:
2407     type = KindType::FLOAT;
2408     break;
2409   case CombiningKind::MINUI:
2410   case CombiningKind::MINSI:
2411   case CombiningKind::MAXUI:
2412   case CombiningKind::MAXSI:
2413   case CombiningKind::AND:
2414   case CombiningKind::OR:
2415   case CombiningKind::XOR:
2416     type = KindType::INT;
2417     break;
2418   case CombiningKind::ADD:
2419   case CombiningKind::MUL:
2420     type = isInt ? KindType::INT : KindType::FLOAT;
2421     break;
2422   }
2423   bool isValidIntKind = (type == KindType::INT) && isInt;
2424   bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
2425   return (isValidIntKind || isValidFloatKind);
2426 }
2427 
2428 /// This function constructs the appropriate integer or float
2429 /// operation given the vector combining kind and operands. The
2430 /// supported int operations are : add, mul, min (signed/unsigned),
2431 /// max(signed/unsigned), and, or, xor. The supported float
2432 /// operations are : add, mul, min and max.
2433 static Value genOperator(Location loc, Value x, Value y,
2434                          vector::CombiningKind kind,
2435                          PatternRewriter &rewriter) {
2436   using vector::CombiningKind;
2437 
2438   auto elType = x.getType().cast<VectorType>().getElementType();
2439   bool isInt = elType.isIntOrIndex();
2440 
2441   Value combinedResult{nullptr};
2442   switch (kind) {
2443   case CombiningKind::ADD:
2444     if (isInt)
2445       combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
2446     else
2447       combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
2448     break;
2449   case CombiningKind::MUL:
2450     if (isInt)
2451       combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
2452     else
2453       combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
2454     break;
2455   case CombiningKind::MINUI:
2456     combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
2457     break;
2458   case CombiningKind::MINSI:
2459     combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
2460     break;
2461   case CombiningKind::MAXUI:
2462     combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
2463     break;
2464   case CombiningKind::MAXSI:
2465     combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
2466     break;
2467   case CombiningKind::AND:
2468     combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
2469     break;
2470   case CombiningKind::OR:
2471     combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
2472     break;
2473   case CombiningKind::XOR:
2474     combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
2475     break;
2476   case CombiningKind::MINF:
2477     combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
2478     break;
2479   case CombiningKind::MAXF:
2480     combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
2481     break;
2482   }
2483   return combinedResult;
2484 }
2485 
2486 /// Convert vector.scan op into arith ops and
2487 /// vector.insert_strided_slice/extract_strided_slice
2488 ///
2489 /// Ex:
2490 /// ```
2491 ///   %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim =
2492 ///   1} :
2493 ///     (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
2494 /// ```
2495 /// Gets converted to:
2496 /// ```
2497 ///   %cst = arith.constant dense<0> : vector<2x3xi32>
2498 ///   %0 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 1],
2499 ///   strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %1 =
2500 ///   vector.insert_strided_slice %0, %cst {offsets = [0, 0], strides = [1, 1]}
2501 ///   : vector<2x1xi32> into vector<2x3xi32> %2 = vector.extract_strided_slice
2502 ///   %arg0 {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} :
2503 ///   vector<2x3xi32> to vector<2x1xi32> %3 = arith.muli %0, %2 :
2504 ///   vector<2x1xi32> %4 = vector.insert_strided_slice %3, %1 {offsets = [0, 1],
2505 ///   strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %5 =
2506 ///   vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1],
2507 ///   strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %6 = arith.muli %3,
2508 ///   %5 : vector<2x1xi32> %7 = vector.insert_strided_slice %6, %4 {offsets =
2509 ///   [0, 2], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %8 =
2510 ///   vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32> return %7, %8 :
2511 ///   vector<2x3xi32>, vector<2xi32>
2512 /// ```
2513 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
2514   using OpRewritePattern<vector::ScanOp>::OpRewritePattern;
2515 
2516   LogicalResult matchAndRewrite(vector::ScanOp scanOp,
2517                                 PatternRewriter &rewriter) const override {
2518     auto loc = scanOp.getLoc();
2519     VectorType destType = scanOp.getDestType();
2520     ArrayRef<int64_t> destShape = destType.getShape();
2521     auto elType = destType.getElementType();
2522     bool isInt = elType.isIntOrIndex();
2523     if (!isValidKind(isInt, scanOp.getKind()))
2524       return failure();
2525 
2526     VectorType resType = VectorType::get(destShape, elType);
2527     Value result = rewriter.create<arith::ConstantOp>(
2528         loc, resType, rewriter.getZeroAttr(resType));
2529     int64_t reductionDim = scanOp.getReductionDim();
2530     bool inclusive = scanOp.getInclusive();
2531     int64_t destRank = destType.getRank();
2532     VectorType initialValueType = scanOp.getInitialValueType();
2533     int64_t initialValueRank = initialValueType.getRank();
2534 
2535     SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
2536     reductionShape[reductionDim] = 1;
2537     VectorType reductionType = VectorType::get(reductionShape, elType);
2538     SmallVector<int64_t> offsets(destRank, 0);
2539     SmallVector<int64_t> strides(destRank, 1);
2540     SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
2541     sizes[reductionDim] = 1;
2542     ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
2543     ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
2544 
2545     Value lastOutput, lastInput;
2546     for (int i = 0; i < destShape[reductionDim]; i++) {
2547       offsets[reductionDim] = i;
2548       ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
2549       Value input = rewriter.create<vector::ExtractStridedSliceOp>(
2550           loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
2551           scanStrides);
2552       Value output;
2553       if (i == 0) {
2554         if (inclusive) {
2555           output = input;
2556         } else {
2557           if (initialValueRank == 0) {
2558             // ShapeCastOp cannot handle 0-D vectors
2559             output = rewriter.create<vector::BroadcastOp>(
2560                 loc, input.getType(), scanOp.getInitialValue());
2561           } else {
2562             output = rewriter.create<vector::ShapeCastOp>(
2563                 loc, input.getType(), scanOp.getInitialValue());
2564           }
2565         }
2566       } else {
2567         Value y = inclusive ? input : lastInput;
2568         output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
2569         assert(output != nullptr);
2570       }
2571       result = rewriter.create<vector::InsertStridedSliceOp>(
2572           loc, output, result, offsets, strides);
2573       lastOutput = output;
2574       lastInput = input;
2575     }
2576 
2577     Value reduction;
2578     if (initialValueRank == 0) {
2579       Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
2580       reduction =
2581           rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
2582     } else {
2583       reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
2584                                                        lastOutput);
2585     }
2586 
2587     rewriter.replaceOp(scanOp, {result, reduction});
2588     return success();
2589   }
2590 };
2591 
2592 } // namespace
2593 
2594 void mlir::vector::populateVectorMaskMaterializationPatterns(
2595     RewritePatternSet &patterns, bool force32BitVectorIndices) {
2596   patterns.add<VectorCreateMaskOpConversion,
2597                MaterializeTransferMask<vector::TransferReadOp>,
2598                MaterializeTransferMask<vector::TransferWriteOp>>(
2599       patterns.getContext(), force32BitVectorIndices);
2600 }
2601 
2602 void mlir::vector::populateShapeCastFoldingPatterns(
2603     RewritePatternSet &patterns) {
2604   patterns.add<ShapeCastOpFolder>(patterns.getContext());
2605 }
2606 
2607 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
2608     RewritePatternSet &patterns) {
2609   patterns.add<BubbleDownVectorBitCastForExtract,
2610                BubbleDownBitCastForStridedSliceExtract,
2611                BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
2612 }
2613 
2614 void mlir::vector::populateVectorBroadcastLoweringPatterns(
2615     RewritePatternSet &patterns) {
2616   patterns.add<BroadcastOpLowering>(patterns.getContext());
2617 }
2618 
2619 void mlir::vector::populateVectorMaskOpLoweringPatterns(
2620     RewritePatternSet &patterns) {
2621   patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
2622       patterns.getContext());
2623 }
2624 
2625 void mlir::vector::populateVectorShapeCastLoweringPatterns(
2626     RewritePatternSet &patterns) {
2627   patterns.add<ShapeCastOp2DDownCastRewritePattern,
2628                ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
2629       patterns.getContext());
2630 }
2631 
2632 void mlir::vector::populateVectorContractLoweringPatterns(
2633     RewritePatternSet &patterns, VectorTransformsOptions options) {
2634   patterns.add<OuterProductOpLowering>(patterns.getContext());
2635   patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
2636                ContractionOpToOuterProductOpLowering>(options,
2637                                                       patterns.getContext());
2638 }
2639 
2640 void mlir::vector::populateVectorTransposeLoweringPatterns(
2641     RewritePatternSet &patterns, VectorTransformsOptions options) {
2642   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
2643       options, patterns.getContext());
2644 }
2645 
2646 void mlir::vector::populateVectorReductionToContractPatterns(
2647     RewritePatternSet &patterns) {
2648   patterns.add<MultiReduceToContract, CombineContractBroadcast,
2649                CombineContractTranspose, ReorderCastOpsOnBroadcast,
2650                ReorderCastOpsOnTranspose>(patterns.getContext());
2651 }
2652 
2653 void mlir::vector::
2654     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
2655         RewritePatternSet &patterns) {
2656   patterns.add<DropInnerMostUnitDims>(patterns.getContext());
2657 }
2658 
2659 void mlir::vector::populateVectorTransferLoweringPatterns(
2660     RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
2661   patterns.add<TransferReadToVectorLoadLowering,
2662                TransferWriteToVectorStoreLowering>(patterns.getContext(),
2663                                                    maxTransferRank);
2664   patterns
2665       .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
2666           patterns.getContext());
2667 }
2668 
2669 void mlir::vector::populateVectorScanLoweringPatterns(
2670     RewritePatternSet &patterns) {
2671   patterns.add<ScanToArithOps>(patterns.getContext());
2672 }
2673