1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
14 
15 #include <type_traits>
16 
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
19 #include "mlir/Dialect/Linalg/IR/Linalg.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/SCF/SCF.h"
22 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
23 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
24 #include "mlir/IR/ImplicitLocOpBuilder.h"
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "mlir/Interfaces/VectorInterfaces.h"
28 
29 #include "llvm/ADT/DenseSet.h"
30 #include "llvm/ADT/MapVector.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/raw_ostream.h"
35 
36 #define DEBUG_TYPE "vector-to-vector"
37 
38 using namespace mlir;
39 using namespace mlir::vector;
40 
41 // Helper to find an index in an affine map.
42 static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
43   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
44     int64_t idx = map.getDimPosition(i);
45     if (idx == index)
46       return i;
47   }
48   return None;
49 }
50 
51 // Helper to construct iterator types with one index removed.
52 static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
53                                             int64_t index) {
54   SmallVector<Attribute, 4> results;
55   for (const auto &it : llvm::enumerate(iteratorTypes)) {
56     int64_t idx = it.index();
57     if (idx == index)
58       continue;
59     results.push_back(it.value());
60   }
61   return results;
62 }
63 
64 // Helper to construct an affine map with one index removed.
65 static AffineMap adjustMap(AffineMap map, int64_t index,
66                            PatternRewriter &rewriter) {
67   auto *ctx = rewriter.getContext();
68   SmallVector<AffineExpr, 4> results;
69   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
70     int64_t idx = map.getDimPosition(i);
71     if (idx == index)
72       continue;
73     // Re-insert remaining indices, but renamed when occurring
74     // after the removed index.
75     auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
76     results.push_back(targetExpr);
77   }
78   return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
79 }
80 
81 // Helper method to possibly drop a dimension in a load.
82 // TODO
83 static Value reshapeLoad(Location loc, Value val, VectorType type,
84                          int64_t index, int64_t pos,
85                          PatternRewriter &rewriter) {
86   if (index == -1)
87     return val;
88   Type lowType = VectorType::Builder(type).dropDim(0);
89   // At extraction dimension?
90   if (index == 0) {
91     auto posAttr = rewriter.getI64ArrayAttr(pos);
92     return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
93   }
94   // Unroll leading dimensions.
95   VectorType vType = lowType.cast<VectorType>();
96   Type resType = VectorType::Builder(type).dropDim(index);
97   auto resVectorType = resType.cast<VectorType>();
98   Value result = rewriter.create<arith::ConstantOp>(
99       loc, resVectorType, rewriter.getZeroAttr(resVectorType));
100   for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
101     auto posAttr = rewriter.getI64ArrayAttr(d);
102     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
103     Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
104     result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
105                                                posAttr);
106   }
107   return result;
108 }
109 
110 // Helper method to possibly drop a dimension in a store.
111 // TODO
112 static Value reshapeStore(Location loc, Value val, Value result,
113                           VectorType type, int64_t index, int64_t pos,
114                           PatternRewriter &rewriter) {
115   // Unmodified?
116   if (index == -1)
117     return val;
118   // At insertion dimension?
119   if (index == 0) {
120     auto posAttr = rewriter.getI64ArrayAttr(pos);
121     return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
122   }
123   // Unroll leading dimensions.
124   Type lowType = VectorType::Builder(type).dropDim(0);
125   VectorType vType = lowType.cast<VectorType>();
126   Type insType = VectorType::Builder(vType).dropDim(0);
127   for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
128     auto posAttr = rewriter.getI64ArrayAttr(d);
129     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
130     Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
131     Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
132     result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
133   }
134   return result;
135 }
136 
137 template <typename IntType>
138 static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
139   return llvm::to_vector<4>(llvm::map_range(
140       arrayAttr.getAsRange<IntegerAttr>(),
141       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
142 }
143 
144 namespace {
145 
146 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
147 //
148 // Example:
149 //
150 //  The following MLIR with cancelling ShapeCastOps:
151 //
152 //   %0 = source : vector<5x4x2xf32>
153 //   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
154 //   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
155 //   %3 = user %2 : vector<5x4x2xf32>
156 //
157 //  Should canonicalize to the following:
158 //
159 //   %0 = source : vector<5x4x2xf32>
160 //   %1 = user %0 : vector<5x4x2xf32>
161 //
162 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
163   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
164 
165   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
166                                 PatternRewriter &rewriter) const override {
167     // Check if 'shapeCastOp' has vector source/result type.
168     auto sourceVectorType =
169         shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
170     auto resultVectorType =
171         shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
172     if (!sourceVectorType || !resultVectorType)
173       return failure();
174 
175     // Check if shape cast op source operand is also a shape cast op.
176     auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
177         shapeCastOp.source().getDefiningOp());
178     if (!sourceShapeCastOp)
179       return failure();
180     auto operandSourceVectorType =
181         sourceShapeCastOp.source().getType().cast<VectorType>();
182     auto operandResultVectorType = sourceShapeCastOp.getType();
183 
184     // Check if shape cast operations invert each other.
185     if (operandSourceVectorType != resultVectorType ||
186         operandResultVectorType != sourceVectorType)
187       return failure();
188 
189     rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
190     return success();
191   }
192 };
193 
194 /// Progressive lowering of BroadcastOp.
195 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
196 public:
197   using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
198 
199   LogicalResult matchAndRewrite(vector::BroadcastOp op,
200                                 PatternRewriter &rewriter) const override {
201     auto loc = op.getLoc();
202     VectorType dstType = op.getVectorType();
203     VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
204     Type eltType = dstType.getElementType();
205 
206     // Scalar to any vector can use splat.
207     if (!srcType) {
208       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.source());
209       return success();
210     }
211 
212     // Determine rank of source and destination.
213     int64_t srcRank = srcType.getRank();
214     int64_t dstRank = dstType.getRank();
215 
216     // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
217     if (srcRank <= 1 && dstRank == 1) {
218       Value ext;
219       if (srcRank == 0)
220         ext = rewriter.create<vector::ExtractElementOp>(loc, op.source());
221       else
222         ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
223       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
224       return success();
225     }
226 
227     // Duplicate this rank.
228     // For example:
229     //   %x = broadcast %y  : k-D to n-D, k < n
230     // becomes:
231     //   %b = broadcast %y  : k-D to (n-1)-D
232     //   %x = [%b,%b,%b,%b] : n-D
233     // becomes:
234     //   %b = [%y,%y]       : (n-1)-D
235     //   %x = [%b,%b,%b,%b] : n-D
236     if (srcRank < dstRank) {
237       // Duplication.
238       VectorType resType =
239           VectorType::get(dstType.getShape().drop_front(), eltType);
240       Value bcst =
241           rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
242       Value result = rewriter.create<arith::ConstantOp>(
243           loc, dstType, rewriter.getZeroAttr(dstType));
244       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
245         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
246       rewriter.replaceOp(op, result);
247       return success();
248     }
249 
250     // Find non-matching dimension, if any.
251     assert(srcRank == dstRank);
252     int64_t m = -1;
253     for (int64_t r = 0; r < dstRank; r++)
254       if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
255         m = r;
256         break;
257       }
258 
259     // All trailing dimensions are the same. Simply pass through.
260     if (m == -1) {
261       rewriter.replaceOp(op, op.source());
262       return success();
263     }
264 
265     // Any non-matching dimension forces a stretch along this rank.
266     // For example:
267     //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
268     // becomes:
269     //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
270     //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
271     //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
272     //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
273     //   %x = [%a,%b,%c,%d]
274     // becomes:
275     //   %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
276     //   %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
277     //   %a = [%u, %v]
278     //   ..
279     //   %x = [%a,%b,%c,%d]
280     VectorType resType =
281         VectorType::get(dstType.getShape().drop_front(), eltType);
282     Value result = rewriter.create<arith::ConstantOp>(
283         loc, dstType, rewriter.getZeroAttr(dstType));
284     if (m == 0) {
285       // Stetch at start.
286       Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
287       Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
288       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
289         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
290     } else {
291       // Stetch not at start.
292       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
293         Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
294         Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
295         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
296       }
297     }
298     rewriter.replaceOp(op, result);
299     return success();
300   }
301 };
302 
303 /// Return the number of leftmost dimensions from the first rightmost transposed
304 /// dimension found in 'transpose'.
305 size_t getNumDimsFromFirstTransposedDim(ArrayRef<int64_t> transpose) {
306   size_t numTransposedDims = transpose.size();
307   for (size_t transpDim : llvm::reverse(transpose)) {
308     if (transpDim != numTransposedDims - 1)
309       break;
310     numTransposedDims--;
311   }
312   return numTransposedDims;
313 }
314 
315 /// Progressive lowering of TransposeOp.
316 /// One:
317 ///   %x = vector.transpose %y, [1, 0]
318 /// is replaced by:
319 ///   %z = arith.constant dense<0.000000e+00>
320 ///   %0 = vector.extract %y[0, 0]
321 ///   %1 = vector.insert %0, %z [0, 0]
322 ///   ..
323 ///   %x = vector.insert .., .. [.., ..]
324 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
325 public:
326   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
327 
328   TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
329                       MLIRContext *context)
330       : OpRewritePattern<vector::TransposeOp>(context),
331         vectorTransformOptions(vectorTransformOptions) {}
332 
333   LogicalResult matchAndRewrite(vector::TransposeOp op,
334                                 PatternRewriter &rewriter) const override {
335     auto loc = op.getLoc();
336 
337     VectorType resType = op.getResultType();
338 
339     // Set up convenience transposition table.
340     SmallVector<int64_t, 4> transp;
341     for (auto attr : op.transp())
342       transp.push_back(attr.cast<IntegerAttr>().getInt());
343 
344     if (vectorTransformOptions.vectorTransposeLowering ==
345             vector::VectorTransposeLowering::Shuffle &&
346         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
347       return rewriter.notifyMatchFailure(
348           op, "Options specifies lowering to shuffle");
349 
350     // Handle a true 2-D matrix transpose differently when requested.
351     if (vectorTransformOptions.vectorTransposeLowering ==
352             vector::VectorTransposeLowering::Flat &&
353         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
354       Type flattenedType =
355           VectorType::get(resType.getNumElements(), resType.getElementType());
356       auto matrix =
357           rewriter.create<vector::ShapeCastOp>(loc, flattenedType, op.vector());
358       auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
359       auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
360       Value trans = rewriter.create<vector::FlatTransposeOp>(
361           loc, flattenedType, matrix, rows, columns);
362       rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
363       return success();
364     }
365 
366     // Generate unrolled extract/insert ops. We do not unroll the rightmost
367     // (i.e., highest-order) dimensions that are not transposed and leave them
368     // in vector form to improve performance.
369     size_t numLeftmostTransposedDims = getNumDimsFromFirstTransposedDim(transp);
370 
371     // The type of the extract operation will be scalar if all the dimensions
372     // are unrolled. Otherwise, it will be a vector with the shape of the
373     // dimensions that are not transposed.
374     Type extractType =
375         numLeftmostTransposedDims == transp.size()
376             ? resType.getElementType()
377             : VectorType::Builder(resType).setShape(
378                   resType.getShape().drop_front(numLeftmostTransposedDims));
379 
380     Value result = rewriter.create<arith::ConstantOp>(
381         loc, resType, rewriter.getZeroAttr(resType));
382     SmallVector<int64_t, 4> lhs(numLeftmostTransposedDims, 0);
383     SmallVector<int64_t, 4> rhs(numLeftmostTransposedDims, 0);
384     rewriter.replaceOp(op, expandIndices(loc, resType, extractType, 0,
385                                          numLeftmostTransposedDims, transp, lhs,
386                                          rhs, op.vector(), result, rewriter));
387     return success();
388   }
389 
390 private:
391   // Builds the indices arrays for the lhs and rhs. Generates the extract/insert
392   // operations when all the ranks go over the last dimension being transposed.
393   Value expandIndices(Location loc, VectorType resType, Type extractType,
394                       int64_t pos, int64_t numLeftmostTransposedDims,
395                       SmallVector<int64_t, 4> &transp,
396                       SmallVector<int64_t, 4> &lhs,
397                       SmallVector<int64_t, 4> &rhs, Value input, Value result,
398                       PatternRewriter &rewriter) const {
399     if (pos >= numLeftmostTransposedDims) {
400       auto ridx = rewriter.getI64ArrayAttr(rhs);
401       auto lidx = rewriter.getI64ArrayAttr(lhs);
402       Value e =
403           rewriter.create<vector::ExtractOp>(loc, extractType, input, ridx);
404       return rewriter.create<vector::InsertOp>(loc, resType, e, result, lidx);
405     }
406     for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) {
407       lhs[pos] = d;
408       rhs[transp[pos]] = d;
409       result = expandIndices(loc, resType, extractType, pos + 1,
410                              numLeftmostTransposedDims, transp, lhs, rhs, input,
411                              result, rewriter);
412     }
413     return result;
414   }
415 
416   /// Options to control the vector patterns.
417   vector::VectorTransformsOptions vectorTransformOptions;
418 };
419 
420 /// Rewrite a 2-D vector.transpose as a sequence of:
421 ///   vector.shape_cast 2D -> 1D
422 ///   vector.shuffle
423 ///   vector.shape_cast 1D -> 2D
424 class TransposeOp2DToShuffleLowering
425     : public OpRewritePattern<vector::TransposeOp> {
426 public:
427   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
428 
429   TransposeOp2DToShuffleLowering(
430       vector::VectorTransformsOptions vectorTransformOptions,
431       MLIRContext *context)
432       : OpRewritePattern<vector::TransposeOp>(context),
433         vectorTransformOptions(vectorTransformOptions) {}
434 
435   LogicalResult matchAndRewrite(vector::TransposeOp op,
436                                 PatternRewriter &rewriter) const override {
437     auto loc = op.getLoc();
438 
439     VectorType srcType = op.getVectorType();
440     if (srcType.getRank() != 2)
441       return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
442 
443     SmallVector<int64_t, 4> transp;
444     for (auto attr : op.transp())
445       transp.push_back(attr.cast<IntegerAttr>().getInt());
446     if (transp[0] != 1 && transp[1] != 0)
447       return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
448 
449     if (vectorTransformOptions.vectorTransposeLowering !=
450         VectorTransposeLowering::Shuffle)
451       return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
452 
453     int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
454     Value casted = rewriter.create<vector::ShapeCastOp>(
455         loc, VectorType::get({m * n}, srcType.getElementType()), op.vector());
456     SmallVector<int64_t> mask;
457     mask.reserve(m * n);
458     for (int64_t j = 0; j < n; ++j)
459       for (int64_t i = 0; i < m; ++i)
460         mask.push_back(i * n + j);
461 
462     Value shuffled =
463         rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
464     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
465                                                      shuffled);
466 
467     return success();
468   }
469 
470 private:
471   /// Options to control the vector patterns.
472   vector::VectorTransformsOptions vectorTransformOptions;
473 };
474 
475 /// Progressive lowering of OuterProductOp.
476 /// One:
477 ///   %x = vector.outerproduct %lhs, %rhs, %acc
478 /// is replaced by:
479 ///   %z = zero-result
480 ///   %0 = vector.extract %lhs[0]
481 ///   %1 = vector.broadcast %0
482 ///   %2 = vector.extract %acc[0]
483 ///   %3 = vector.fma %1, %rhs, %2
484 ///   %4 = vector.insert %3, %z[0]
485 ///   ..
486 ///   %x = vector.insert %.., %..[N-1]
487 ///
488 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
489 public:
490   using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
491 
492   LogicalResult matchAndRewrite(vector::OuterProductOp op,
493                                 PatternRewriter &rewriter) const override {
494     auto loc = op.getLoc();
495 
496     VectorType lhsType = op.getOperandVectorTypeLHS();
497     VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
498     VectorType resType = op.getVectorType();
499     Type eltType = resType.getElementType();
500     bool isInt = eltType.isa<IntegerType, IndexType>();
501     Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
502     vector::CombiningKind kind = op.kind();
503 
504     if (!rhsType) {
505       // Special case: AXPY operation.
506       Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
507       Optional<Value> mult =
508           isInt ? genMultI(loc, op.lhs(), b, acc, kind, rewriter)
509                 : genMultF(loc, op.lhs(), b, acc, kind, rewriter);
510       if (!mult.hasValue())
511         return failure();
512       rewriter.replaceOp(op, mult.getValue());
513       return success();
514     }
515 
516     Value result = rewriter.create<arith::ConstantOp>(
517         loc, resType, rewriter.getZeroAttr(resType));
518     for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
519       auto pos = rewriter.getI64ArrayAttr(d);
520       Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
521       Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
522       Value r = nullptr;
523       if (acc)
524         r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
525       Optional<Value> m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter)
526                                 : genMultF(loc, a, op.rhs(), r, kind, rewriter);
527       if (!m.hasValue())
528         return failure();
529       result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
530                                                  result, pos);
531     }
532     rewriter.replaceOp(op, result);
533     return success();
534   }
535 
536 private:
537   static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
538                                   vector::CombiningKind kind,
539                                   PatternRewriter &rewriter) {
540     using vector::CombiningKind;
541 
542     auto mul = rewriter.create<arith::MulIOp>(loc, x, y);
543     if (!acc)
544       return Optional<Value>(mul);
545 
546     if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
547       // Only valid for floating point types.
548       return Optional<Value>();
549 
550     return makeArithReduction(rewriter, loc, kind, mul, acc);
551   }
552 
553   static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
554                                   vector::CombiningKind kind,
555                                   PatternRewriter &rewriter) {
556     using vector::CombiningKind;
557 
558     // Special case for fused multiply-add.
559     if (acc && kind == CombiningKind::ADD) {
560       return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
561     }
562 
563     auto mul = rewriter.create<arith::MulFOp>(loc, x, y);
564 
565     if (!acc)
566       return Optional<Value>(mul);
567 
568     if (kind == CombiningKind::ADD || kind == CombiningKind::AND ||
569         kind == CombiningKind::MINUI || kind == CombiningKind::MINSI ||
570         kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI ||
571         kind == CombiningKind::OR || kind == CombiningKind::XOR)
572       // Already handled or only valid for integer types.
573       return Optional<Value>();
574 
575     return makeArithReduction(rewriter, loc, kind, mul, acc);
576   }
577 };
578 
579 /// Progressive lowering of ConstantMaskOp.
580 /// One:
581 ///   %x = vector.constant_mask [a,b]
582 /// is replaced by:
583 ///   %z = zero-result
584 ///   %l = vector.constant_mask [b]
585 ///   %4 = vector.insert %l, %z[0]
586 ///   ..
587 ///   %x = vector.insert %l, %..[a-1]
588 /// until a one-dimensional vector is reached. All these operations
589 /// will be folded at LLVM IR level.
590 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
591 public:
592   using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
593 
594   LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
595                                 PatternRewriter &rewriter) const override {
596     auto loc = op.getLoc();
597     auto dstType = op.getType();
598     auto eltType = dstType.getElementType();
599     auto dimSizes = op.mask_dim_sizes();
600     int64_t rank = dstType.getRank();
601 
602     if (rank == 0) {
603       assert(dimSizes.size() == 1 &&
604              "Expected exactly one dim size for a 0-D vector");
605       bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
606       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
607           op, dstType,
608           DenseIntElementsAttr::get(
609               VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
610               ArrayRef<bool>{value}));
611       return success();
612     }
613 
614     int64_t trueDim = std::min(dstType.getDimSize(0),
615                                dimSizes[0].cast<IntegerAttr>().getInt());
616 
617     if (rank == 1) {
618       // Express constant 1-D case in explicit vector form:
619       //   [T,..,T,F,..,F].
620       SmallVector<bool, 4> values(dstType.getDimSize(0));
621       for (int64_t d = 0; d < trueDim; d++)
622         values[d] = true;
623       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
624           op, dstType, rewriter.getBoolVectorAttr(values));
625       return success();
626     }
627 
628     VectorType lowType =
629         VectorType::get(dstType.getShape().drop_front(), eltType);
630     SmallVector<int64_t, 4> newDimSizes;
631     for (int64_t r = 1; r < rank; r++)
632       newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
633     Value trueVal = rewriter.create<vector::ConstantMaskOp>(
634         loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
635     Value result = rewriter.create<arith::ConstantOp>(
636         loc, dstType, rewriter.getZeroAttr(dstType));
637     for (int64_t d = 0; d < trueDim; d++) {
638       auto pos = rewriter.getI64ArrayAttr(d);
639       result =
640           rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
641     }
642     rewriter.replaceOp(op, result);
643     return success();
644   }
645 };
646 
647 /// Progressive lowering of CreateMaskOp.
648 /// One:
649 ///   %x = vector.create_mask %a, ... : vector<dx...>
650 /// is replaced by:
651 ///   %l = vector.create_mask ... : vector<...>  ; one lower rank
652 ///   %0 = arith.cmpi "slt", %ci, %a       |
653 ///   %1 = select %0, %l, %zeroes    |
654 ///   %r = vector.insert %1, %pr [i] | d-times
655 ///   %x = ....
656 /// until a one-dimensional vector is reached.
657 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
658 public:
659   using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
660 
661   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
662                                 PatternRewriter &rewriter) const override {
663     auto dstType = op.getResult().getType().cast<VectorType>();
664     int64_t rank = dstType.getRank();
665     if (rank <= 1)
666       return rewriter.notifyMatchFailure(
667           op, "0-D and 1-D vectors are handled separately");
668 
669     auto loc = op.getLoc();
670     auto eltType = dstType.getElementType();
671     int64_t dim = dstType.getDimSize(0);
672     Value idx = op.getOperand(0);
673 
674     VectorType lowType =
675         VectorType::get(dstType.getShape().drop_front(), eltType);
676     Value trueVal = rewriter.create<vector::CreateMaskOp>(
677         loc, lowType, op.getOperands().drop_front());
678     Value falseVal = rewriter.create<arith::ConstantOp>(
679         loc, lowType, rewriter.getZeroAttr(lowType));
680     Value result = rewriter.create<arith::ConstantOp>(
681         loc, dstType, rewriter.getZeroAttr(dstType));
682     for (int64_t d = 0; d < dim; d++) {
683       Value bnd =
684           rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
685       Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
686                                                  bnd, idx);
687       Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
688       auto pos = rewriter.getI64ArrayAttr(d);
689       result =
690           rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
691     }
692     rewriter.replaceOp(op, result);
693     return success();
694   }
695 };
696 
697 /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
698 /// vectors progressively on the way to target llvm.matrix intrinsics.
699 /// This iterates over the most major dimension of the 2-D vector and performs
700 /// rewrites into:
701 ///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
702 class ShapeCastOp2DDownCastRewritePattern
703     : public OpRewritePattern<vector::ShapeCastOp> {
704 public:
705   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
706 
707   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
708                                 PatternRewriter &rewriter) const override {
709     auto sourceVectorType = op.getSourceVectorType();
710     auto resultVectorType = op.getResultVectorType();
711     if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
712       return failure();
713 
714     auto loc = op.getLoc();
715     Value desc = rewriter.create<arith::ConstantOp>(
716         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
717     unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
718     for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
719       Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
720       desc = rewriter.create<vector::InsertStridedSliceOp>(
721           loc, vec, desc,
722           /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
723     }
724     rewriter.replaceOp(op, desc);
725     return success();
726   }
727 };
728 
729 /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
730 /// vectors progressively.
731 /// This iterates over the most major dimension of the 2-D vector and performs
732 /// rewrites into:
733 ///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
734 /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
735 class ShapeCastOp2DUpCastRewritePattern
736     : public OpRewritePattern<vector::ShapeCastOp> {
737 public:
738   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
739 
740   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
741                                 PatternRewriter &rewriter) const override {
742     auto sourceVectorType = op.getSourceVectorType();
743     auto resultVectorType = op.getResultVectorType();
744     if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
745       return failure();
746 
747     auto loc = op.getLoc();
748     Value desc = rewriter.create<arith::ConstantOp>(
749         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
750     unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
751     for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
752       Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
753           loc, op.source(), /*offsets=*/i * mostMinorVectorSize,
754           /*sizes=*/mostMinorVectorSize,
755           /*strides=*/1);
756       desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
757     }
758     rewriter.replaceOp(op, desc);
759     return success();
760   }
761 };
762 
763 // We typically should not lower general shape cast operations into data
764 // movement instructions, since the assumption is that these casts are
765 // optimized away during progressive lowering. For completeness, however,
766 // we fall back to a reference implementation that moves all elements
767 // into the right place if we get here.
768 class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
769 public:
770   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
771 
772   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
773                                 PatternRewriter &rewriter) const override {
774     Location loc = op.getLoc();
775     auto sourceVectorType = op.getSourceVectorType();
776     auto resultVectorType = op.getResultVectorType();
777 
778     // Special case 2D/1D lowerings with better implementations.
779     // TODO: make is ND/1D to allow generic ND->1D->MD.
780     int64_t srcRank = sourceVectorType.getRank();
781     int64_t resRank = resultVectorType.getRank();
782     if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
783       return failure();
784 
785     // Generic ShapeCast lowering path goes all the way down to unrolled scalar
786     // extract/insert chains.
787     // TODO: consider evolving the semantics to only allow 1D source or dest and
788     // drop this potentially very expensive lowering.
789     // Compute number of elements involved in the reshape.
790     int64_t numElts = 1;
791     for (int64_t r = 0; r < srcRank; r++)
792       numElts *= sourceVectorType.getDimSize(r);
793     // Replace with data movement operations:
794     //    x[0,0,0] = y[0,0]
795     //    x[0,0,1] = y[0,1]
796     //    x[0,1,0] = y[0,2]
797     // etc., incrementing the two index vectors "row-major"
798     // within the source and result shape.
799     SmallVector<int64_t, 4> srcIdx(srcRank);
800     SmallVector<int64_t, 4> resIdx(resRank);
801     Value result = rewriter.create<arith::ConstantOp>(
802         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
803     for (int64_t i = 0; i < numElts; i++) {
804       if (i != 0) {
805         incIdx(srcIdx, sourceVectorType, srcRank - 1);
806         incIdx(resIdx, resultVectorType, resRank - 1);
807       }
808       Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
809       result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
810     }
811     rewriter.replaceOp(op, result);
812     return success();
813   }
814 
815 private:
816   static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
817     assert(0 <= r && r < tp.getRank());
818     if (++idx[r] == tp.getDimSize(r)) {
819       idx[r] = 0;
820       incIdx(idx, tp, r - 1);
821     }
822   }
823 };
824 
825 /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
826 /// Ex:
827 /// ```
828 ///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
829 ///   %1 = vector.multi_reduction add, %0 [1]
830 ///     : vector<8x32x16xf32> to vector<8x16xf32>
831 /// ```
832 /// Gets converted to:
833 /// ```
834 ///   %1 = vector.contract {indexing_maps = [
835 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
836 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
837 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
838 ///    iterator_types = ["parallel", "parallel", "reduction"],
839 ///    kind = add} %0, %arg1, %cst_f0
840 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
841 ///  ```
842 struct MultiReduceToContract
843     : public OpRewritePattern<vector::MultiDimReductionOp> {
844   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
845 
846   LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
847                                 PatternRewriter &rewriter) const override {
848     if (reduceOp.kind() != vector::CombiningKind::ADD)
849       return failure();
850     Operation *mulOp = reduceOp.source().getDefiningOp();
851     if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
852       return failure();
853     SmallVector<bool> reductionMask = reduceOp.getReductionMask();
854     auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
855     SmallVector<AffineExpr> exprs;
856     SmallVector<StringRef> iteratorTypes;
857     for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
858       if (!isReduceDim.value()) {
859         iteratorTypes.push_back(getParallelIteratorTypeName());
860         exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
861       } else {
862         iteratorTypes.push_back(getReductionIteratorTypeName());
863       }
864     }
865     auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
866                                  /*symCount=*/0, exprs, reduceOp.getContext());
867     Value zero = rewriter.create<arith::ConstantOp>(
868         reduceOp.getLoc(), reduceOp.getDestType(),
869         rewriter.getZeroAttr(reduceOp.getDestType()));
870     rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
871         reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
872         rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
873         rewriter.getStrArrayAttr(iteratorTypes));
874     return success();
875   }
876 };
877 
878 /// Merge TransposeOp into ContractionOp user.
879 /// Ex:
880 /// ```
881 ///   %0 = vector.transpose %arg0, [2, 0, 1]
882 ///     : vector<32x16x8xf32> to vector<8x32x16xf32>
883 ///   %1 = vector.contract {indexing_maps = [
884 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
885 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
886 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
887 ///    iterator_types = ["parallel", "parallel", "reduction"],
888 ///    kind = add} %0, %arg1, %cst_f0
889 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
890 /// ```
891 /// Gets converted to:
892 /// ```
893 ///   %1 = vector.contract {indexing_maps = [
894 ///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
895 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
896 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
897 ///    iterator_types = ["parallel", "parallel", "reduction"],
898 ///    kind = add} %arg0, %arg1, %cst_f0
899 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
900 ///  ```
901 struct CombineContractTranspose
902     : public OpRewritePattern<vector::ContractionOp> {
903   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
904 
905   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
906                                 PatternRewriter &rewriter) const override {
907     SmallVector<AffineMap, 4> maps =
908         llvm::to_vector<4>(contractOp.getIndexingMaps());
909     Value lhs = contractOp.lhs();
910     Value rhs = contractOp.rhs();
911     size_t index = 0;
912     bool changed = false;
913     for (Value *operand : {&lhs, &rhs}) {
914       AffineMap &map = maps[index++];
915       auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
916       if (!transposeOp)
917         continue;
918       SmallVector<int64_t> perm;
919       transposeOp.getTransp(perm);
920       AffineMap permutationMap = AffineMap::getPermutationMap(
921           extractVector<unsigned>(transposeOp.transp()),
922           contractOp.getContext());
923       map = inversePermutation(permutationMap).compose(map);
924       *operand = transposeOp.vector();
925       changed = true;
926     }
927     if (!changed)
928       return failure();
929     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
930         contractOp, lhs, rhs, contractOp.acc(),
931         rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
932     return success();
933   }
934 };
935 
936 /// Merge BroadcastOp into ContractionOp user.
937 /// Ex:
938 /// ```
939 ///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
940 ///   %1 = vector.contract {indexing_maps = [
941 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
942 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
943 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
944 ///    iterator_types = ["parallel", "parallel", "reduction"],
945 ///    kind = add} %0, %arg1, %cst_f0
946 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
947 /// ```
948 /// Gets converted to:
949 /// ```
950 ///   %1 = vector.contract {indexing_maps = [
951 ///         affine_map<(d0, d1, d2) -> (d1, d2)>,
952 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
953 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
954 ///    iterator_types = ["parallel", "parallel", "reduction"],
955 ///    kind = add} %arg0, %arg1, %cst_f0
956 ///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
957 ///  ```
958 struct CombineContractBroadcast
959     : public OpRewritePattern<vector::ContractionOp> {
960   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
961 
962   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
963                                 PatternRewriter &rewriter) const override {
964     SmallVector<AffineMap, 4> maps =
965         llvm::to_vector<4>(contractOp.getIndexingMaps());
966     Value lhs = contractOp.lhs();
967     Value rhs = contractOp.rhs();
968     size_t index = 0;
969     bool changed = false;
970     for (Value *operand : {&lhs, &rhs}) {
971       AffineMap &map = maps[index++];
972       auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
973       if (!broadcast)
974         continue;
975       // contractionOp can only take vector as operands.
976       auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
977       if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
978         continue;
979       int64_t rankDiff =
980           broadcast.getVectorType().getRank() - srcType.getRank();
981       bool innerDimBroadcast = false;
982       SmallVector<AffineExpr> originalDims;
983       for (const auto &dim : llvm::enumerate(srcType.getShape())) {
984         if (dim.value() !=
985             broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
986           innerDimBroadcast = true;
987           break;
988         }
989         originalDims.push_back(
990             rewriter.getAffineDimExpr(dim.index() + rankDiff));
991       }
992       // Contract doesn't support inner dimension broadcast. Once this is
993       // relaxed we can remove this case.
994       if (innerDimBroadcast)
995         continue;
996       AffineMap broadcastMap =
997           AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
998                          contractOp.getContext());
999       map = broadcastMap.compose(map);
1000       *operand = broadcast.source();
1001       changed = true;
1002     }
1003     if (!changed)
1004       return failure();
1005     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1006         contractOp, lhs, rhs, contractOp.acc(),
1007         rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
1008     return success();
1009   }
1010 };
1011 
1012 /// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
1013 /// contraction ops closer, which kicks in CombineContractBroadcast pattern when
1014 /// casting ops are around these operations.
1015 /// Ex:
1016 /// ```
1017 ///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
1018 ///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
1019 /// ```
1020 /// Gets converted to:
1021 /// ```
1022 ///   %0 = arith.extsi %0 : vector<32x16xi8> to vector<32x16xi32>
1023 ///   %1 = vector.broadcast %arg0 : vector<32x16xi32> to vector<8x32x16xi32>
1024 /// ```
1025 struct ReorderCastOpsOnBroadcast
1026     : public OpInterfaceRewritePattern<CastOpInterface> {
1027   using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
1028 
1029   LogicalResult matchAndRewrite(CastOpInterface op,
1030                                 PatternRewriter &rewriter) const override {
1031     if (op->getNumOperands() != 1)
1032       return failure();
1033     auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
1034     if (!bcastOp)
1035       return failure();
1036 
1037     Type castResTy = getElementTypeOrSelf(op->getResult(0));
1038     if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
1039       castResTy = VectorType::get(vecTy.getShape(), castResTy);
1040     OperationState state(op->getLoc(), op->getName(), bcastOp.source(),
1041                          castResTy, op->getAttrs());
1042     auto castOp = rewriter.createOperation(state);
1043     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1044         op, op->getResult(0).getType(), castOp->getResult(0));
1045     return success();
1046   }
1047 };
1048 
1049 /// Reorders cast(transpose) to transpose(cast). This makes broadcast ops and
1050 /// contraction ops closer, which kicks in CombineContractTranspose pattern when
1051 /// casting ops are around these operations.
1052 /// Ex:
1053 /// ```
1054 ///   %0 = vector.transpose %arg0, [2, 0, 1]
1055 ///     : vector<32x16x8xi8> to vector<8x32x16xi8>
1056 ///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
1057 /// ```
1058 /// Gets converted to:
1059 /// ```
1060 ///   %0 = arith.extsi %0 : vector<32x16x8xi8> to vector<32x16x8xi32>
1061 ///   %1 = vector.transpose %arg0, [2, 0, 1]
1062 ///     : vector<32x16x8xi32> to vector<8x32x16xi32>
1063 /// ```
1064 struct ReorderCastOpsOnTranspose
1065     : public OpInterfaceRewritePattern<CastOpInterface> {
1066 
1067   using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
1068 
1069   LogicalResult matchAndRewrite(CastOpInterface op,
1070                                 PatternRewriter &rewriter) const override {
1071     if (op->getNumOperands() != 1)
1072       return failure();
1073     auto transpOp = op->getOperand(0).getDefiningOp<vector::TransposeOp>();
1074     if (!transpOp)
1075       return failure();
1076 
1077     auto castResTy = transpOp.getVectorType();
1078     castResTy = VectorType::get(castResTy.getShape(),
1079                                 getElementTypeOrSelf(op->getResult(0)));
1080     OperationState state(op->getLoc(), op->getName(), transpOp.vector(),
1081                          castResTy, op->getAttrs());
1082     auto castOp = rewriter.createOperation(state);
1083     rewriter.replaceOpWithNewOp<vector::TransposeOp>(
1084         op, op->getResult(0).getType(), castOp->getResult(0),
1085         transpOp.getTransp());
1086     return success();
1087   }
1088 };
1089 
1090 } // namespace
1091 
1092 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
1093 /// operands `x` and `y`.
1094 static Value createAdd(Location loc, Value x, Value y, bool isInt,
1095                        PatternRewriter &rewriter) {
1096   if (isInt)
1097     return rewriter.create<arith::AddIOp>(loc, x, y);
1098   return rewriter.create<arith::AddFOp>(loc, x, y);
1099 }
1100 
1101 /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
1102 /// operands `x and `y`.
1103 static Value createMul(Location loc, Value x, Value y, bool isInt,
1104                        PatternRewriter &rewriter) {
1105   if (isInt)
1106     return rewriter.create<arith::MulIOp>(loc, x, y);
1107   return rewriter.create<arith::MulFOp>(loc, x, y);
1108 }
1109 
1110 namespace mlir {
1111 
1112 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1113 /// semantics to:
1114 /// ```
1115 ///    %mta = maybe_transpose
1116 ///    %mtb = maybe_transpose
1117 ///    %flattened_a = vector.shape_cast %mta
1118 ///    %flattened_b = vector.shape_cast %mtb
1119 ///    %flattened_d = vector.matmul %flattened_a, %flattened_b
1120 ///    %mtd = vector.shape_cast %flattened_d
1121 ///    %d = maybe_untranspose %mtd
1122 ///    %e = add %c, %d
1123 /// ```
1124 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1125 //
1126 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
1127 /// vector.transpose operations are inserted if the vector.contract op is not a
1128 /// row-major matrix multiply.
1129 LogicalResult
1130 ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
1131                                                  PatternRewriter &rew) const {
1132   // TODO: implement masks
1133   if (llvm::size(op.masks()) != 0)
1134     return failure();
1135   if (vectorTransformOptions.vectorContractLowering !=
1136       vector::VectorContractLowering::Matmul)
1137     return failure();
1138   if (failed(filter(op)))
1139     return failure();
1140 
1141   auto iteratorTypes = op.iterator_types().getValue();
1142   if (!isParallelIterator(iteratorTypes[0]) ||
1143       !isParallelIterator(iteratorTypes[1]) ||
1144       !isReductionIterator(iteratorTypes[2]))
1145     return failure();
1146 
1147   Type elementType = op.getLhsType().getElementType();
1148   if (!elementType.isIntOrFloat())
1149     return failure();
1150 
1151   // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
1152   // Bail out if the contraction cannot be put in this form.
1153   MLIRContext *ctx = op.getContext();
1154   Location loc = op.getLoc();
1155   AffineExpr m, n, k;
1156   bindDims(rew.getContext(), m, n, k);
1157   // LHS must be A(m, k) or A(k, m).
1158   Value lhs = op.lhs();
1159   auto lhsMap = op.indexing_maps()[0];
1160   if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
1161     lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
1162   else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
1163     return failure();
1164 
1165   // RHS must be B(k, n) or B(n, k).
1166   Value rhs = op.rhs();
1167   auto rhsMap = op.indexing_maps()[1];
1168   if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
1169     rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
1170   else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
1171     return failure();
1172 
1173   // At this point lhs and rhs are in row-major.
1174   VectorType lhsType = lhs.getType().cast<VectorType>();
1175   VectorType rhsType = rhs.getType().cast<VectorType>();
1176   int64_t lhsRows = lhsType.getDimSize(0);
1177   int64_t lhsColumns = lhsType.getDimSize(1);
1178   int64_t rhsColumns = rhsType.getDimSize(1);
1179 
1180   Type flattenedLHSType =
1181       VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
1182   lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
1183 
1184   Type flattenedRHSType =
1185       VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
1186   rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
1187 
1188   Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
1189                                            rhsColumns);
1190   mul = rew.create<vector::ShapeCastOp>(
1191       loc,
1192       VectorType::get({lhsRows, rhsColumns},
1193                       getElementTypeOrSelf(op.acc().getType())),
1194       mul);
1195 
1196   // ACC must be C(m, n) or C(n, m).
1197   auto accMap = op.indexing_maps()[2];
1198   if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
1199     mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
1200   else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
1201     llvm_unreachable("invalid contraction semantics");
1202 
1203   Value res =
1204       elementType.isa<IntegerType>()
1205           ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.acc(), mul))
1206           : static_cast<Value>(rew.create<arith::AddFOp>(loc, op.acc(), mul));
1207 
1208   rew.replaceOp(op, res);
1209   return success();
1210 }
1211 
1212 namespace {
1213 struct IteratorType {
1214   IteratorType(StringRef strRef) : strRef(strRef) {}
1215   bool isOfType(Attribute attr) const {
1216     auto sAttr = attr.dyn_cast<StringAttr>();
1217     return sAttr && sAttr.getValue() == strRef;
1218   }
1219   StringRef strRef;
1220 };
1221 struct Par : public IteratorType {
1222   Par() : IteratorType(getParallelIteratorTypeName()) {}
1223 };
1224 struct Red : public IteratorType {
1225   Red() : IteratorType(getReductionIteratorTypeName()) {}
1226 };
1227 
1228 /// Generate a vector implementation for matmat, matvec and tmatvec.
1229 /// This unrolls outer-products along the reduction dimension.
1230 struct UnrolledOuterProductGenerator
1231     : public StructuredGenerator<vector::ContractionOp> {
1232 
1233   UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
1234       : StructuredGenerator<vector::ContractionOp>(builder, op),
1235         kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()),
1236         lhsType(op.getLhsType()) {}
1237 
1238   Value t(Value v) {
1239     static constexpr std::array<int64_t, 2> perm = {1, 0};
1240     return builder.create<vector::TransposeOp>(loc, v, perm);
1241   }
1242 
1243   Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
1244     assert(reductionSize > 0);
1245     for (int64_t k = 0; k < reductionSize; ++k) {
1246       Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
1247       Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
1248       res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
1249                                                    res, kind);
1250     }
1251     return res;
1252   }
1253 
1254   /// Two outer parallel, one inner reduction (matmat flavor).
1255   FailureOr<Value> matmat() {
1256     if (!iters({Par(), Par(), Red()}))
1257       return failure();
1258     // Set up the parallel/reduction structure in the right form.
1259     AffineExpr m, n, k;
1260     bindDims(builder.getContext(), m, n, k);
1261     // Classical row-major matmul:  Just permute the lhs.
1262     if (layout({{m, k}, {k, n}, {m, n}}))
1263       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1264     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1265     if (layout({{m, k}, {n, k}, {m, n}})) {
1266       Value tlhs = t(lhs);
1267       return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
1268     }
1269     // No need to permute anything.
1270     if (layout({{k, m}, {k, n}, {m, n}}))
1271       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1272     // Just permute the rhs.
1273     if (layout({{k, m}, {n, k}, {m, n}}))
1274       return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
1275     // Transposed output: swap RHS and LHS.
1276     // Classical row-major matmul: permute the lhs.
1277     if (layout({{m, k}, {k, n}, {n, m}}))
1278       return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
1279     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1280     if (layout({{m, k}, {n, k}, {n, m}})) {
1281       Value trhs = t(rhs);
1282       return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
1283     }
1284     if (layout({{k, m}, {k, n}, {n, m}}))
1285       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1286     if (layout({{k, m}, {n, k}, {n, m}}))
1287       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1288     return failure();
1289   }
1290 
1291   /// One outer parallel, one inner reduction (matvec flavor)
1292   FailureOr<Value> matvec() {
1293     if (!iters({Par(), Red()}))
1294       return failure();
1295     AffineExpr m, k;
1296     bindDims(builder.getContext(), m, k);
1297 
1298     // Case mat-vec: transpose.
1299     if (layout({{m, k}, {k}, {m}}))
1300       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1301     // Case mat-trans-vec: ready to go.
1302     if (layout({{k, m}, {k}, {m}}))
1303       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1304     // Case vec-mat: swap and transpose.
1305     if (layout({{k}, {m, k}, {m}}))
1306       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1307     // Case vec-mat-trans: swap and ready to go.
1308     if (layout({{k}, {k, m}, {m}}))
1309       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1310     return failure();
1311   }
1312 
1313   //
1314   // One outer reduction, one inner parallel (tmatvec flavor)
1315   //
1316   FailureOr<Value> tmatvec() {
1317     if (!iters({Red(), Par()}))
1318       return failure();
1319     AffineExpr k, m;
1320     bindDims(builder.getContext(), k, m);
1321 
1322     // Case mat-vec: transpose.
1323     if (layout({{m, k}, {k}, {m}}))
1324       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1325     // Case mat-trans-vec: ready to go.
1326     if (layout({{k, m}, {k}, {m}}))
1327       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1328     // Case vec-mat: swap and transpose.
1329     if (layout({{k}, {m, k}, {m}}))
1330       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1331     // Case vec-mat-trans: swap and ready to go.
1332     if (layout({{k}, {k, m}, {m}}))
1333       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1334     return failure();
1335   }
1336 
1337 private:
1338   vector::CombiningKind kind;
1339   Value lhs, rhs, res;
1340   VectorType lhsType;
1341 };
1342 } // namespace
1343 
1344 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1345 /// semantics to a reduction_size-unrolled sequence:
1346 /// ```
1347 ///    %at = vector.transpose %a, [1, 0]
1348 ///    %bRow0 = vector.extract %b[0]
1349 ///    %atRow0 = vector.extract %at[0]
1350 ///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
1351 ///    ...
1352 ///    %bRowK = vector.extract %b[K]
1353 ///    %atRowK = vector.extract %at[K]
1354 ///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
1355 /// ```
1356 ///
1357 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
1358 /// otherwise supports any layout permutation of the matrix-multiply.
1359 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
1360     vector::ContractionOp op, PatternRewriter &rewriter) const {
1361   // TODO: implement masks
1362   if (llvm::size(op.masks()) != 0)
1363     return failure();
1364 
1365   if (vectorTransformOptions.vectorContractLowering !=
1366       vector::VectorContractLowering::OuterProduct)
1367     return failure();
1368 
1369   if (failed(filter(op)))
1370     return failure();
1371 
1372   UnrolledOuterProductGenerator e(rewriter, op);
1373   FailureOr<Value> matmatRes = e.matmat();
1374   if (succeeded(matmatRes)) {
1375     rewriter.replaceOp(op, *matmatRes);
1376     return success();
1377   }
1378   FailureOr<Value> matvecRes = e.matvec();
1379   if (succeeded(matvecRes)) {
1380     rewriter.replaceOp(op, *matvecRes);
1381     return success();
1382   }
1383   FailureOr<Value> tmatvecRes = e.tmatvec();
1384   if (succeeded(tmatvecRes)) {
1385     rewriter.replaceOp(op, *tmatvecRes);
1386     return success();
1387   }
1388 
1389   return failure();
1390 }
1391 
1392 LogicalResult
1393 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
1394                                             PatternRewriter &rewriter) const {
1395   // TODO: implement masks
1396   if (llvm::size(op.masks()) != 0)
1397     return failure();
1398 
1399   if (failed(filter(op)))
1400     return failure();
1401 
1402   if (vectorTransformOptions.vectorContractLowering !=
1403       vector::VectorContractLowering::Dot)
1404     return failure();
1405 
1406   auto iteratorTypes = op.iterator_types().getValue();
1407   static constexpr std::array<int64_t, 2> perm = {1, 0};
1408   Location loc = op.getLoc();
1409   Value lhs = op.lhs(), rhs = op.rhs();
1410 
1411   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1412   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1413   AffineExpr m, n, k;
1414   bindDims(rewriter.getContext(), m, n, k);
1415   SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
1416   //
1417   // In the following we wish to make the reduction dimension innermost so we
1418   // can load vectors and just fmul + reduce into a scalar.
1419   //
1420   if (isParallelIterator(iteratorTypes[0]) &&
1421       isParallelIterator(iteratorTypes[1]) &&
1422       isReductionIterator(iteratorTypes[2])) {
1423     //
1424     // Two outer parallel, one inner reduction (matmat flavor).
1425     //
1426     if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1427       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1428     } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
1429       // No need to permute anything.
1430     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1431       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1432       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1433     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1434       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1435     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1436       // This is the classical row-major matmul. Just permute the lhs.
1437       Value tmp = lhs;
1438       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1439       rhs = tmp;
1440     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1441       std::swap(lhs, rhs);
1442     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1443       Value tmp = lhs;
1444       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1445       rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
1446     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1447       Value tmp = rhs;
1448       rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1449       lhs = tmp;
1450     } else {
1451       return failure();
1452     }
1453   } else if (isParallelIterator(iteratorTypes[0]) &&
1454              isReductionIterator(iteratorTypes[1])) {
1455     //
1456     // One outer parallel, one inner reduction (matvec flavor)
1457     //
1458     if (maps == infer({{m, n}, {n}, {m}})) {
1459       // No need to permute anything.
1460     } else if (maps == infer({{n, m}, {n}, {m}})) {
1461       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1462     } else if (maps == infer({{n}, {m, n}, {m}})) {
1463       std::swap(lhs, rhs);
1464     } else if (maps == infer({{n}, {n, m}, {m}})) {
1465       std::swap(lhs, rhs);
1466       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1467     } else {
1468       return failure();
1469     }
1470   } else {
1471     return failure();
1472   }
1473 
1474   VectorType dstType = op.getResultType().cast<VectorType>();
1475   assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
1476          "Expected dst type of rank 1 or 2");
1477 
1478   unsigned rank = dstType.getRank();
1479   unsigned dstRows = dstType.getShape()[0];
1480   unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
1481 
1482   // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
1483   Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
1484                                                  rewriter.getZeroAttr(dstType));
1485   bool isInt = dstType.getElementType().isa<IntegerType>();
1486   for (unsigned r = 0; r < dstRows; ++r) {
1487     Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
1488     for (unsigned c = 0; c < dstColumns; ++c) {
1489       Value b = rank == 1
1490                     ? rhs
1491                     : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
1492       Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
1493       Value reduced = rewriter.create<vector::ReductionOp>(
1494           op.getLoc(), vector::CombiningKind::ADD, m);
1495 
1496       SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
1497                                               : SmallVector<int64_t, 2>{r, c};
1498       res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
1499     }
1500   }
1501   if (auto acc = op.acc())
1502     res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
1503   rewriter.replaceOp(op, res);
1504   return success();
1505 }
1506 
1507 /// Progressive lowering of ContractionOp.
1508 /// One:
1509 ///   %x = vector.contract with at least one free/batch dimension
1510 /// is replaced by:
1511 ///   %a = vector.contract with one less free/batch dimension
1512 ///   %b = vector.contract with one less free/batch dimension
1513 ///   ..
1514 ///   %x = combine %a %b ..
1515 /// until a pure contraction is reached (no free/batch dimensions),
1516 /// which is replaced by a dot-product.
1517 ///
1518 /// This only kicks in when either VectorTransformsOptions is set
1519 /// to DOT or when other contraction patterns fail.
1520 //
1521 // TODO: break down into transpose/reshape/cast ops
1522 //               when they become available to avoid code dup
1523 // TODO: investigate lowering order impact on performance
1524 LogicalResult
1525 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
1526                                        PatternRewriter &rewriter) const {
1527   // TODO: implement masks.
1528   if (llvm::size(op.masks()) != 0)
1529     return failure();
1530 
1531   if (failed(filter(op)))
1532     return failure();
1533 
1534   // TODO: support mixed mode contract lowering.
1535   if (op.getLhsType().getElementType() !=
1536           getElementTypeOrSelf(op.getAccType()) ||
1537       op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
1538     return failure();
1539 
1540   // TODO: implement benefits, cost models.
1541   MLIRContext *ctx = op.getContext();
1542   ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
1543   if (succeeded(pat1.matchAndRewrite(op, rewriter)))
1544     return success();
1545   ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
1546   if (succeeded(pat2.matchAndRewrite(op, rewriter)))
1547     return success();
1548   ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
1549   if (succeeded(pat3.matchAndRewrite(op, rewriter)))
1550     return success();
1551 
1552   // Find first batch dimension in LHS/RHS, and lower when found.
1553   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
1554   if (!batchDimMap.empty()) {
1555     int64_t lhsIndex = batchDimMap[0].first;
1556     int64_t rhsIndex = batchDimMap[0].second;
1557     rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
1558     return success();
1559   }
1560 
1561   // Collect contracting dimensions.
1562   std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
1563       op.getContractingDimMap();
1564   DenseSet<int64_t> lhsContractingDimSet;
1565   DenseSet<int64_t> rhsContractingDimSet;
1566   for (auto &dimPair : contractingDimMap) {
1567     lhsContractingDimSet.insert(dimPair.first);
1568     rhsContractingDimSet.insert(dimPair.second);
1569   }
1570 
1571   // Find first free dimension in LHS, and lower when found.
1572   VectorType lhsType = op.getLhsType();
1573   for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
1574     if (lhsContractingDimSet.count(lhsIndex) == 0) {
1575       rewriter.replaceOp(
1576           op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
1577       return success();
1578     }
1579   }
1580 
1581   // Find first free dimension in RHS, and lower when found.
1582   VectorType rhsType = op.getRhsType();
1583   for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
1584     if (rhsContractingDimSet.count(rhsIndex) == 0) {
1585       rewriter.replaceOp(
1586           op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
1587       return success();
1588     }
1589   }
1590 
1591   // Lower the first remaining reduction dimension.
1592   if (!contractingDimMap.empty()) {
1593     rewriter.replaceOp(op, lowerReduction(op, rewriter));
1594     return success();
1595   }
1596 
1597   return failure();
1598 }
1599 
1600 // Lower one parallel dimension.
1601 // TODO: consider reusing existing contract unrolling
1602 Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
1603                                            int64_t lhsIndex, int64_t rhsIndex,
1604                                            PatternRewriter &rewriter) const {
1605   VectorType lhsType = op.getLhsType();
1606   VectorType rhsType = op.getRhsType();
1607   VectorType resType = op.getResultType().cast<VectorType>();
1608   // Find the iterator type index and result index.
1609   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
1610   int64_t iterIndex = -1;
1611   int64_t dimSize = -1;
1612   if (lhsIndex >= 0) {
1613     iterIndex = iMap[0].getDimPosition(lhsIndex);
1614     assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
1615            "parallel index should be free in LHS or batch in LHS/RHS");
1616     dimSize = lhsType.getDimSize(lhsIndex);
1617   } else {
1618     assert(rhsIndex >= 0 && "missing parallel index");
1619     iterIndex = iMap[1].getDimPosition(rhsIndex);
1620     dimSize = rhsType.getDimSize(rhsIndex);
1621   }
1622   assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
1623   Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
1624   assert(lookup.hasValue() && "parallel index not listed in reduction");
1625   int64_t resIndex = lookup.getValue();
1626   // Construct new iterator types and affine map array attribute.
1627   std::array<AffineMap, 3> lowIndexingMaps = {
1628       adjustMap(iMap[0], iterIndex, rewriter),
1629       adjustMap(iMap[1], iterIndex, rewriter),
1630       adjustMap(iMap[2], iterIndex, rewriter)};
1631   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1632   auto lowIter =
1633       rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
1634   // Unroll into a series of lower dimensional vector.contract ops.
1635   Location loc = op.getLoc();
1636   Value result = rewriter.create<arith::ConstantOp>(
1637       loc, resType, rewriter.getZeroAttr(resType));
1638   for (int64_t d = 0; d < dimSize; ++d) {
1639     auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
1640     auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
1641     auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
1642     Value lowContract = rewriter.create<vector::ContractionOp>(
1643         loc, lhs, rhs, acc, lowAffine, lowIter);
1644     result =
1645         reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
1646   }
1647   return result;
1648 }
1649 
1650 // Lower one reduction dimension.
1651 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
1652                                             PatternRewriter &rewriter) const {
1653   auto loc = op.getLoc();
1654   VectorType lhsType = op.getLhsType();
1655   VectorType rhsType = op.getRhsType();
1656   Type resType = op.getResultType();
1657   assert(!resType.isa<VectorType>());
1658   bool isInt = resType.isa<IntegerType>();
1659   // Use iterator index 0.
1660   int64_t iterIndex = 0;
1661   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
1662   Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1663   Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1664   assert(lookupLhs.hasValue() && "missing LHS parallel index");
1665   assert(lookupRhs.hasValue() && "missing RHS parallel index");
1666   int64_t lhsIndex = lookupLhs.getValue();
1667   int64_t rhsIndex = lookupRhs.getValue();
1668   int64_t dimSize = lhsType.getDimSize(lhsIndex);
1669   assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
1670   // Base case.
1671   if (lhsType.getRank() == 1) {
1672     assert(rhsType.getRank() == 1 && "corrupt contraction");
1673     Value m = createMul(loc, op.lhs(), op.rhs(), isInt, rewriter);
1674     auto kind = vector::CombiningKind::ADD;
1675     Value res = rewriter.create<vector::ReductionOp>(loc, kind, m);
1676     if (auto acc = op.acc())
1677       res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
1678     return res;
1679   }
1680   // Construct new iterator types and affine map array attribute.
1681   std::array<AffineMap, 3> lowIndexingMaps = {
1682       adjustMap(iMap[0], iterIndex, rewriter),
1683       adjustMap(iMap[1], iterIndex, rewriter),
1684       adjustMap(iMap[2], iterIndex, rewriter)};
1685   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1686   auto lowIter =
1687       rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
1688   // Unroll into a series of lower dimensional vector.contract ops.
1689   // By feeding the initial accumulator into the first contraction,
1690   // and the result of each contraction into the next, eventually
1691   // the sum of all reductions is computed.
1692   Value result = op.acc();
1693   for (int64_t d = 0; d < dimSize; ++d) {
1694     auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
1695     auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
1696     result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
1697                                                     lowAffine, lowIter);
1698   }
1699   return result;
1700 }
1701 
1702 } // namespace mlir
1703 
1704 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
1705     OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
1706     ArrayRef<int64_t> multiplicity, const AffineMap &map) {
1707   OpBuilder::InsertionGuard guard(builder);
1708   builder.setInsertionPointAfter(op);
1709   Location loc = op->getLoc();
1710   if (op->getNumResults() != 1)
1711     return {};
1712   Value result = op->getResult(0);
1713   VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
1714   if (!type || map.getNumResults() != multiplicity.size())
1715     return {};
1716   // For each dimension being distributed check that the size is a multiple of
1717   // the multiplicity. To handle more sizes we would need to support masking.
1718   unsigned multiplictyCount = 0;
1719   for (auto exp : map.getResults()) {
1720     auto affinExp = exp.dyn_cast<AffineDimExpr>();
1721     if (!affinExp || affinExp.getPosition() >= type.getRank() ||
1722         type.getDimSize(affinExp.getPosition()) %
1723                 multiplicity[multiplictyCount++] !=
1724             0)
1725       return {};
1726   }
1727   DistributeOps ops;
1728   ops.extract =
1729       builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
1730   ops.insert =
1731       builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
1732   return ops;
1733 }
1734 
1735 /// Progressive lowering of transfer_read. This pattern supports lowering of
1736 /// `vector.transfer_read` to a combination of `vector.load` and
1737 /// `vector.broadcast` if all of the following hold:
1738 /// - Stride of most minor memref dimension must be 1.
1739 /// - Out-of-bounds masking is not required.
1740 /// - If the memref's element type is a vector type then it coincides with the
1741 ///   result type.
1742 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
1743 struct TransferReadToVectorLoadLowering
1744     : public OpRewritePattern<vector::TransferReadOp> {
1745   TransferReadToVectorLoadLowering(MLIRContext *context,
1746                                    llvm::Optional<unsigned> maxRank)
1747       : OpRewritePattern<vector::TransferReadOp>(context),
1748         maxTransferRank(maxRank) {}
1749 
1750   LogicalResult matchAndRewrite(vector::TransferReadOp read,
1751                                 PatternRewriter &rewriter) const override {
1752     if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
1753       return failure();
1754 
1755     SmallVector<unsigned, 4> broadcastedDims;
1756     // Permutations are handled by VectorToSCF or
1757     // populateVectorTransferPermutationMapLoweringPatterns.
1758     // We let the 0-d corner case pass-through as it is supported.
1759     if (!read.permutation_map().isMinorIdentityWithBroadcasting(
1760             &broadcastedDims))
1761       return failure();
1762 
1763     auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
1764     if (!memRefType)
1765       return failure();
1766 
1767     // Non-unit strides are handled by VectorToSCF.
1768     if (!vector::isLastMemrefDimUnitStride(memRefType))
1769       return failure();
1770 
1771     // If there is broadcasting involved then we first load the unbroadcasted
1772     // vector, and then broadcast it with `vector.broadcast`.
1773     ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
1774     SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
1775                                                      vectorShape.end());
1776     for (unsigned i : broadcastedDims)
1777       unbroadcastedVectorShape[i] = 1;
1778     VectorType unbroadcastedVectorType = VectorType::get(
1779         unbroadcastedVectorShape, read.getVectorType().getElementType());
1780 
1781     // `vector.load` supports vector types as memref's elements only when the
1782     // resulting vector type is the same as the element type.
1783     auto memrefElTy = memRefType.getElementType();
1784     if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
1785       return failure();
1786 
1787     // Otherwise, element types of the memref and the vector must match.
1788     if (!memrefElTy.isa<VectorType>() &&
1789         memrefElTy != read.getVectorType().getElementType())
1790       return failure();
1791 
1792     // Out-of-bounds dims are handled by MaterializeTransferMask.
1793     if (read.hasOutOfBoundsDim())
1794       return failure();
1795 
1796     // Create vector load op.
1797     Operation *loadOp;
1798     if (read.mask()) {
1799       Value fill = rewriter.create<vector::SplatOp>(
1800           read.getLoc(), unbroadcastedVectorType, read.padding());
1801       loadOp = rewriter.create<vector::MaskedLoadOp>(
1802           read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(),
1803           read.mask(), fill);
1804     } else {
1805       loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
1806                                                unbroadcastedVectorType,
1807                                                read.source(), read.indices());
1808     }
1809 
1810     // Insert a broadcasting op if required.
1811     if (!broadcastedDims.empty()) {
1812       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1813           read, read.getVectorType(), loadOp->getResult(0));
1814     } else {
1815       rewriter.replaceOp(read, loadOp->getResult(0));
1816     }
1817 
1818     return success();
1819   }
1820 
1821   llvm::Optional<unsigned> maxTransferRank;
1822 };
1823 
1824 /// Replace a 0-d vector.load with a memref.load + vector.broadcast.
1825 // TODO: we shouldn't cross the vector/scalar domains just for this
1826 // but atm we lack the infra to avoid it. Possible solutions include:
1827 // - go directly to LLVM + bitcast
1828 // - introduce a bitcast op and likely a new pointer dialect
1829 // - let memref.load/store additionally support the 0-d vector case
1830 // There are still deeper data layout issues lingering even in this
1831 // trivial case (for architectures for which this matters).
1832 struct VectorLoadToMemrefLoadLowering
1833     : public OpRewritePattern<vector::LoadOp> {
1834   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
1835 
1836   LogicalResult matchAndRewrite(vector::LoadOp loadOp,
1837                                 PatternRewriter &rewriter) const override {
1838     auto vecType = loadOp.getVectorType();
1839     if (vecType.getNumElements() != 1)
1840       return failure();
1841     auto memrefLoad = rewriter.create<memref::LoadOp>(
1842         loadOp.getLoc(), loadOp.base(), loadOp.indices());
1843     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
1844                                                      memrefLoad);
1845     return success();
1846   }
1847 };
1848 
1849 /// Replace a 0-d vector.store with a vector.extractelement + memref.store.
1850 struct VectorStoreToMemrefStoreLowering
1851     : public OpRewritePattern<vector::StoreOp> {
1852   using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
1853 
1854   LogicalResult matchAndRewrite(vector::StoreOp storeOp,
1855                                 PatternRewriter &rewriter) const override {
1856     auto vecType = storeOp.getVectorType();
1857     if (vecType.getNumElements() != 1)
1858       return failure();
1859     Value extracted;
1860     if (vecType.getRank() == 0) {
1861       // TODO: Unifiy once ExtractOp supports 0-d vectors.
1862       extracted = rewriter.create<vector::ExtractElementOp>(
1863           storeOp.getLoc(), storeOp.valueToStore());
1864     } else {
1865       SmallVector<int64_t> indices(vecType.getRank(), 0);
1866       extracted = rewriter.create<vector::ExtractOp>(
1867           storeOp.getLoc(), storeOp.valueToStore(), indices);
1868     }
1869 
1870     rewriter.replaceOpWithNewOp<memref::StoreOp>(
1871         storeOp, extracted, storeOp.base(), storeOp.indices());
1872     return success();
1873   }
1874 };
1875 
1876 /// Progressive lowering of transfer_write. This pattern supports lowering of
1877 /// `vector.transfer_write` to `vector.store` if all of the following hold:
1878 /// - Stride of most minor memref dimension must be 1.
1879 /// - Out-of-bounds masking is not required.
1880 /// - If the memref's element type is a vector type then it coincides with the
1881 ///   type of the written value.
1882 /// - The permutation map is the minor identity map (neither permutation nor
1883 ///   broadcasting is allowed).
1884 struct TransferWriteToVectorStoreLowering
1885     : public OpRewritePattern<vector::TransferWriteOp> {
1886   TransferWriteToVectorStoreLowering(MLIRContext *context,
1887                                      llvm::Optional<unsigned> maxRank)
1888       : OpRewritePattern<vector::TransferWriteOp>(context),
1889         maxTransferRank(maxRank) {}
1890 
1891   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
1892                                 PatternRewriter &rewriter) const override {
1893     if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
1894       return failure();
1895 
1896     // Permutations are handled by VectorToSCF or
1897     // populateVectorTransferPermutationMapLoweringPatterns.
1898     if ( // pass-through for the 0-d corner case.
1899         !write.permutation_map().isMinorIdentity())
1900       return failure();
1901 
1902     auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
1903     if (!memRefType)
1904       return failure();
1905 
1906     // Non-unit strides are handled by VectorToSCF.
1907     if (!vector::isLastMemrefDimUnitStride(memRefType))
1908       return failure();
1909 
1910     // `vector.store` supports vector types as memref's elements only when the
1911     // type of the vector value being written is the same as the element type.
1912     auto memrefElTy = memRefType.getElementType();
1913     if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
1914       return failure();
1915 
1916     // Otherwise, element types of the memref and the vector must match.
1917     if (!memrefElTy.isa<VectorType>() &&
1918         memrefElTy != write.getVectorType().getElementType())
1919       return failure();
1920 
1921     // Out-of-bounds dims are handled by MaterializeTransferMask.
1922     if (write.hasOutOfBoundsDim())
1923       return failure();
1924     if (write.mask()) {
1925       rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
1926           write, write.source(), write.indices(), write.mask(), write.vector());
1927     } else {
1928       rewriter.replaceOpWithNewOp<vector::StoreOp>(
1929           write, write.vector(), write.source(), write.indices());
1930     }
1931     return success();
1932   }
1933 
1934   llvm::Optional<unsigned> maxTransferRank;
1935 };
1936 
1937 // Returns the values in `arrayAttr` as an integer vector.
1938 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
1939   return llvm::to_vector<4>(
1940       llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
1941                       [](IntegerAttr attr) { return attr.getInt(); }));
1942 }
1943 
1944 // Shuffles vector.bitcast op after vector.extract op.
1945 //
1946 // This transforms IR like:
1947 //   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
1948 //   %1 = vector.extract %0[3] : vector<8xf16>
1949 // Into:
1950 //   %0 = vector.extract %src[1] : vector<4xf32>
1951 //   %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
1952 //   %2 = vector.extract %1[1] : vector<2xf16>
1953 struct BubbleDownVectorBitCastForExtract
1954     : public OpRewritePattern<vector::ExtractOp> {
1955   using OpRewritePattern::OpRewritePattern;
1956 
1957   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
1958                                 PatternRewriter &rewriter) const override {
1959     // Only support extracting scalars for now.
1960     if (extractOp.getVectorType().getRank() != 1)
1961       return failure();
1962 
1963     auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
1964     if (!castOp)
1965       return failure();
1966 
1967     VectorType castSrcType = castOp.getSourceVectorType();
1968     VectorType castDstType = castOp.getResultVectorType();
1969     assert(castSrcType.getRank() == castDstType.getRank());
1970 
1971     // Fail to match if we only have one element in the cast op source.
1972     // This is to avoid infinite loop given that this pattern can generate
1973     // such cases.
1974     if (castSrcType.getNumElements() == 1)
1975       return failure();
1976 
1977     // Only support casting to a larger number of elements or now.
1978     // E.g., vector<4xf32> -> vector<8xf16>.
1979     if (castSrcType.getNumElements() > castDstType.getNumElements())
1980       return failure();
1981 
1982     unsigned expandRatio =
1983         castDstType.getNumElements() / castSrcType.getNumElements();
1984 
1985     auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
1986       return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
1987     };
1988 
1989     uint64_t index = getFirstIntValue(extractOp.position());
1990 
1991     // Get the single scalar (as a vector) in the source value that packs the
1992     // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
1993     VectorType oneScalarType =
1994         VectorType::get({1}, castSrcType.getElementType());
1995     Value packedValue = rewriter.create<vector::ExtractOp>(
1996         extractOp.getLoc(), oneScalarType, castOp.source(),
1997         rewriter.getI64ArrayAttr(index / expandRatio));
1998 
1999     // Cast it to a vector with the desired scalar's type.
2000     // E.g. f32 -> vector<2xf16>
2001     VectorType packedType =
2002         VectorType::get({expandRatio}, castDstType.getElementType());
2003     Value castedValue = rewriter.create<vector::BitCastOp>(
2004         extractOp.getLoc(), packedType, packedValue);
2005 
2006     // Finally extract the desired scalar.
2007     rewriter.replaceOpWithNewOp<vector::ExtractOp>(
2008         extractOp, extractOp.getType(), castedValue,
2009         rewriter.getI64ArrayAttr(index % expandRatio));
2010 
2011     return success();
2012   }
2013 };
2014 
2015 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
2016 //
2017 // This transforms IR like:
2018 //    %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
2019 //     %0 = vector.extract_strided_slice %cast {
2020 //            offsets = [4], sizes = [4], strides = [1]
2021 //          } : vector<8xf16> to vector<4xf16>
2022 // Into:
2023 //   %0 = vector.extract_strided_slice %src {
2024 //          offsets = [2], sizes = [2], strides = [1]
2025 //        } : vector<4xf32> to vector<2xf32>
2026 //   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
2027 struct BubbleDownBitCastForStridedSliceExtract
2028     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
2029   using OpRewritePattern::OpRewritePattern;
2030 
2031   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
2032                                 PatternRewriter &rewriter) const override {
2033     auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
2034     if (!castOp)
2035       return failure();
2036 
2037     VectorType castSrcType = castOp.getSourceVectorType();
2038     VectorType castDstType = castOp.getResultVectorType();
2039     assert(castSrcType.getRank() == castDstType.getRank());
2040 
2041     int64_t castSrcLastDim = castSrcType.getShape().back();
2042     int64_t castDstLastDim = castDstType.getShape().back();
2043     // Require casting to more elements for now; other cases to be implemented.
2044     if (castSrcLastDim > castDstLastDim)
2045       return failure();
2046 
2047     // Only accept all one strides for now.
2048     if (llvm::any_of(extractOp.strides().getAsValueRange<IntegerAttr>(),
2049                      [](const APInt &val) { return !val.isOneValue(); }))
2050       return failure();
2051 
2052     unsigned rank = extractOp.getVectorType().getRank();
2053     assert(castDstLastDim % castSrcLastDim == 0);
2054     int64_t expandRatio = castDstLastDim / castSrcLastDim;
2055 
2056     // If we have a less number of offsets than the rank, then implicitly we
2057     // are selecting the full range for the last bitcasted dimension; other
2058     // dimensions aren't affected. Otherwise, we need to scale down the last
2059     // dimension's offset given we are extracting from less elements now.
2060     ArrayAttr newOffsets = extractOp.offsets();
2061     if (newOffsets.size() == rank) {
2062       SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2063       if (offsets.back() % expandRatio != 0)
2064         return failure();
2065       offsets.back() = offsets.back() / expandRatio;
2066       newOffsets = rewriter.getI64ArrayAttr(offsets);
2067     }
2068 
2069     // Similarly for sizes.
2070     ArrayAttr newSizes = extractOp.sizes();
2071     if (newSizes.size() == rank) {
2072       SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
2073       if (sizes.back() % expandRatio != 0)
2074         return failure();
2075       sizes.back() = sizes.back() / expandRatio;
2076       newSizes = rewriter.getI64ArrayAttr(sizes);
2077     }
2078 
2079     SmallVector<int64_t, 4> dims =
2080         llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
2081     dims.back() = dims.back() / expandRatio;
2082     VectorType newExtractType =
2083         VectorType::get(dims, castSrcType.getElementType());
2084 
2085     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
2086         extractOp.getLoc(), newExtractType, castOp.source(), newOffsets,
2087         newSizes, extractOp.strides());
2088 
2089     rewriter.replaceOpWithNewOp<vector::BitCastOp>(
2090         extractOp, extractOp.getType(), newExtractOp);
2091 
2092     return success();
2093   }
2094 };
2095 
2096 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
2097 //
2098 // This transforms IR like:
2099 //   %0 = vector.insert_strided_slice %src, %dst {
2100 //          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
2101 //   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
2102 // Into:
2103 //   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
2104 //   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
2105 //   %2 = vector.insert_strided_slice %src, %dst {
2106 //          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
2107 struct BubbleUpBitCastForStridedSliceInsert
2108     : public OpRewritePattern<vector::BitCastOp> {
2109   using OpRewritePattern::OpRewritePattern;
2110   LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
2111                                 PatternRewriter &rewriter) const override {
2112     VectorType castSrcType = bitcastOp.getSourceVectorType();
2113     VectorType castDstType = bitcastOp.getResultVectorType();
2114     assert(castSrcType.getRank() == castDstType.getRank());
2115 
2116     int64_t castSrcLastDim = castSrcType.getShape().back();
2117     int64_t castDstLastDim = castDstType.getShape().back();
2118     // Require casting to less elements for now; other cases to be implemented.
2119     if (castSrcLastDim < castDstLastDim)
2120       return failure();
2121 
2122     assert(castSrcLastDim % castDstLastDim == 0);
2123     int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
2124 
2125     auto insertOp =
2126         bitcastOp.source().getDefiningOp<vector::InsertStridedSliceOp>();
2127     if (!insertOp)
2128       return failure();
2129 
2130     // Only accept all one strides for now.
2131     if (llvm::any_of(insertOp.strides().getAsValueRange<IntegerAttr>(),
2132                      [](const APInt &val) { return !val.isOneValue(); }))
2133       return failure();
2134 
2135     unsigned rank = insertOp.getSourceVectorType().getRank();
2136     // Require insert op to have the same rank for the source and destination
2137     // vector; other cases to be implemented.
2138     if (rank != insertOp.getDestVectorType().getRank())
2139       return failure();
2140 
2141     ArrayAttr newOffsets = insertOp.offsets();
2142     assert(newOffsets.size() == rank);
2143     SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2144     if (offsets.back() % shrinkRatio != 0)
2145       return failure();
2146     offsets.back() = offsets.back() / shrinkRatio;
2147     newOffsets = rewriter.getI64ArrayAttr(offsets);
2148 
2149     SmallVector<int64_t, 4> srcDims =
2150         llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
2151     srcDims.back() = srcDims.back() / shrinkRatio;
2152     VectorType newCastSrcType =
2153         VectorType::get(srcDims, castDstType.getElementType());
2154 
2155     auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
2156         bitcastOp.getLoc(), newCastSrcType, insertOp.source());
2157 
2158     SmallVector<int64_t, 4> dstDims =
2159         llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
2160     dstDims.back() = dstDims.back() / shrinkRatio;
2161     VectorType newCastDstType =
2162         VectorType::get(dstDims, castDstType.getElementType());
2163 
2164     auto newCastDstOp = rewriter.create<vector::BitCastOp>(
2165         bitcastOp.getLoc(), newCastDstType, insertOp.dest());
2166 
2167     rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
2168         bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
2169         insertOp.strides());
2170 
2171     return success();
2172   }
2173 };
2174 
2175 static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc,
2176                                    Type targetType, Value value) {
2177   if (targetType == value.getType())
2178     return value;
2179 
2180   bool targetIsIndex = targetType.isIndex();
2181   bool valueIsIndex = value.getType().isIndex();
2182   if (targetIsIndex ^ valueIsIndex)
2183     return rewriter.create<arith::IndexCastOp>(loc, targetType, value);
2184 
2185   auto targetIntegerType = targetType.dyn_cast<IntegerType>();
2186   auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
2187   assert(targetIntegerType && valueIntegerType &&
2188          "unexpected cast between types other than integers and index");
2189   assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
2190 
2191   if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
2192     return rewriter.create<arith::ExtSIOp>(loc, targetIntegerType, value);
2193   return rewriter.create<arith::TruncIOp>(loc, targetIntegerType, value);
2194 }
2195 
2196 // Helper that returns a vector comparison that constructs a mask:
2197 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
2198 //
2199 // If `dim == 0` then the result will be a 0-D vector.
2200 //
2201 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
2202 //       much more compact, IR for this operation, but LLVM eventually
2203 //       generates more elaborate instructions for this intrinsic since it
2204 //       is very conservative on the boundary conditions.
2205 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
2206                                    bool indexOptimizations, int64_t dim,
2207                                    Value b, Value *off = nullptr) {
2208   auto loc = op->getLoc();
2209   // If we can assume all indices fit in 32-bit, we perform the vector
2210   // comparison in 32-bit to get a higher degree of SIMD parallelism.
2211   // Otherwise we perform the vector comparison using 64-bit indices.
2212   Type idxType =
2213       indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
2214   DenseIntElementsAttr indicesAttr;
2215   if (dim == 0 && indexOptimizations) {
2216     indicesAttr = DenseIntElementsAttr::get(
2217         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
2218   } else if (dim == 0) {
2219     indicesAttr = DenseIntElementsAttr::get(
2220         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
2221   } else if (indexOptimizations) {
2222     indicesAttr = rewriter.getI32VectorAttr(
2223         llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
2224   } else {
2225     indicesAttr = rewriter.getI64VectorAttr(
2226         llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
2227   }
2228   Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
2229   // Add in an offset if requested.
2230   if (off) {
2231     Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
2232     Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
2233     indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
2234   }
2235   // Construct the vector comparison.
2236   Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
2237   Value bounds =
2238       rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
2239   return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
2240                                         bounds);
2241 }
2242 
2243 template <typename ConcreteOp>
2244 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
2245 public:
2246   explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
2247       : mlir::OpRewritePattern<ConcreteOp>(context),
2248         indexOptimizations(enableIndexOpt) {}
2249 
2250   LogicalResult matchAndRewrite(ConcreteOp xferOp,
2251                                 PatternRewriter &rewriter) const override {
2252     if (!xferOp.hasOutOfBoundsDim())
2253       return failure();
2254 
2255     if (xferOp.getVectorType().getRank() > 1 ||
2256         llvm::size(xferOp.indices()) == 0)
2257       return failure();
2258 
2259     Location loc = xferOp->getLoc();
2260     VectorType vtp = xferOp.getVectorType();
2261 
2262     // Create the in-bounds mask with all elements between [0 .. dim - offset)
2263     // set and [dim - offset .. vector_length) unset.
2264     //
2265     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
2266     //       dimensions here.
2267     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
2268     Value off = xferOp.indices()[lastIndex];
2269     Value dim =
2270         vector::createOrFoldDimOp(rewriter, loc, xferOp.source(), lastIndex);
2271     Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
2272     Value mask = rewriter.create<vector::CreateMaskOp>(
2273         loc,
2274         VectorType::get(vtp.getShape(), rewriter.getI1Type(),
2275                         vtp.getNumScalableDims()),
2276         b);
2277     if (xferOp.mask()) {
2278       // Intersect the in-bounds with the mask specified as an op parameter.
2279       mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.mask());
2280     }
2281 
2282     rewriter.updateRootInPlace(xferOp, [&]() {
2283       xferOp.maskMutable().assign(mask);
2284       xferOp.in_boundsAttr(rewriter.getBoolArrayAttr({true}));
2285     });
2286 
2287     return success();
2288   }
2289 
2290 private:
2291   const bool indexOptimizations;
2292 };
2293 
2294 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
2295 class VectorCreateMaskOpConversion
2296     : public OpRewritePattern<vector::CreateMaskOp> {
2297 public:
2298   explicit VectorCreateMaskOpConversion(MLIRContext *context,
2299                                         bool enableIndexOpt)
2300       : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
2301         indexOptimizations(enableIndexOpt) {}
2302 
2303   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
2304                                 PatternRewriter &rewriter) const override {
2305     auto dstType = op.getType();
2306     int64_t rank = dstType.getRank();
2307     if (rank > 1)
2308       return failure();
2309     rewriter.replaceOp(
2310         op, buildVectorComparison(rewriter, op, indexOptimizations,
2311                                   rank == 0 ? 0 : dstType.getDimSize(0),
2312                                   op.getOperand(0)));
2313     return success();
2314   }
2315 
2316 private:
2317   const bool indexOptimizations;
2318 };
2319 
2320 // Drop inner most contiguous unit dimensions from transfer_read operand.
2321 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
2322   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
2323 
2324   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
2325                                 PatternRewriter &rewriter) const override {
2326     // TODO: support 0-d corner case.
2327     if (readOp.getTransferRank() == 0)
2328       return failure();
2329 
2330     // TODO: support mask.
2331     if (readOp.mask())
2332       return failure();
2333 
2334     auto srcType = readOp.source().getType().dyn_cast<MemRefType>();
2335     if (!srcType || !srcType.hasStaticShape())
2336       return failure();
2337 
2338     if (!readOp.permutation_map().isMinorIdentity())
2339       return failure();
2340 
2341     auto targetType = readOp.getVectorType();
2342     if (targetType.getRank() <= 1)
2343       return failure();
2344 
2345     SmallVector<int64_t> srcStrides;
2346     int64_t srcOffset;
2347     if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2348       return failure();
2349 
2350     size_t dimsToDrop = 0;
2351     for (size_t i = 1; i < srcStrides.size(); ++i) {
2352       int dim = srcType.getRank() - i - 1;
2353       if (srcStrides[dim] == 1) {
2354         dimsToDrop++;
2355       } else {
2356         break;
2357       }
2358     }
2359     if (dimsToDrop == 0)
2360       return failure();
2361 
2362     auto resultTargetVecType =
2363         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
2364                         targetType.getElementType());
2365 
2366     MemRefType resultMemrefType;
2367     if (srcType.getLayout().getAffineMap().isIdentity()) {
2368       resultMemrefType = MemRefType::get(
2369           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2370           {}, srcType.getMemorySpaceAsInt());
2371     } else {
2372       AffineMap map = srcType.getLayout().getAffineMap();
2373       int numResultDims = map.getNumDims() - dimsToDrop;
2374       int numSymbols = map.getNumSymbols();
2375       for (size_t i = 0; i < dimsToDrop; ++i) {
2376         int dim = srcType.getRank() - i - 1;
2377         map = map.replace(rewriter.getAffineDimExpr(dim),
2378                           rewriter.getAffineConstantExpr(0), numResultDims,
2379                           numSymbols);
2380       }
2381       resultMemrefType = MemRefType::get(
2382           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2383           map, srcType.getMemorySpaceAsInt());
2384     }
2385 
2386     auto loc = readOp.getLoc();
2387     SmallVector<int64_t> offsets(srcType.getRank(), 0);
2388     SmallVector<int64_t> strides(srcType.getRank(), 1);
2389 
2390     ArrayAttr inBoundsAttr =
2391         readOp.in_bounds()
2392             ? rewriter.getArrayAttr(
2393                   readOp.in_boundsAttr().getValue().drop_back(dimsToDrop))
2394             : ArrayAttr();
2395     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
2396         loc, resultMemrefType, readOp.source(), offsets, srcType.getShape(),
2397         strides);
2398     auto permMap = getTransferMinorIdentityMap(
2399         rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
2400     Value result = rewriter.create<vector::TransferReadOp>(
2401         loc, resultTargetVecType, rankedReducedView,
2402         readOp.indices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
2403         readOp.padding(),
2404         // TODO: support mask.
2405         /*mask=*/Value(), inBoundsAttr);
2406     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
2407                                                      result);
2408     return success();
2409   }
2410 };
2411 
2412 namespace {
2413 
2414 /// This function checks to see if the vector combining kind
2415 /// is consistent with the integer or float element type.
2416 static bool isValidKind(bool isInt, vector::CombiningKind kind) {
2417   using vector::CombiningKind;
2418   enum class KindType { FLOAT, INT, INVALID };
2419   KindType type{KindType::INVALID};
2420   switch (kind) {
2421   case CombiningKind::MINF:
2422   case CombiningKind::MAXF:
2423     type = KindType::FLOAT;
2424     break;
2425   case CombiningKind::MINUI:
2426   case CombiningKind::MINSI:
2427   case CombiningKind::MAXUI:
2428   case CombiningKind::MAXSI:
2429   case CombiningKind::AND:
2430   case CombiningKind::OR:
2431   case CombiningKind::XOR:
2432     type = KindType::INT;
2433     break;
2434   case CombiningKind::ADD:
2435   case CombiningKind::MUL:
2436     type = isInt ? KindType::INT : KindType::FLOAT;
2437     break;
2438   }
2439   bool isValidIntKind = (type == KindType::INT) && isInt;
2440   bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
2441   return (isValidIntKind || isValidFloatKind);
2442 }
2443 
2444 /// This function constructs the appropriate integer or float
2445 /// operation given the vector combining kind and operands. The
2446 /// supported int operations are : add, mul, min (signed/unsigned),
2447 /// max(signed/unsigned), and, or, xor. The supported float
2448 /// operations are : add, mul, min and max.
2449 static Value genOperator(Location loc, Value x, Value y,
2450                          vector::CombiningKind kind,
2451                          PatternRewriter &rewriter) {
2452   using vector::CombiningKind;
2453 
2454   auto elType = x.getType().cast<VectorType>().getElementType();
2455   bool isInt = elType.isIntOrIndex();
2456 
2457   Value combinedResult{nullptr};
2458   switch (kind) {
2459   case CombiningKind::ADD:
2460     if (isInt)
2461       combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
2462     else
2463       combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
2464     break;
2465   case CombiningKind::MUL:
2466     if (isInt)
2467       combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
2468     else
2469       combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
2470     break;
2471   case CombiningKind::MINUI:
2472     combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
2473     break;
2474   case CombiningKind::MINSI:
2475     combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
2476     break;
2477   case CombiningKind::MAXUI:
2478     combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
2479     break;
2480   case CombiningKind::MAXSI:
2481     combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
2482     break;
2483   case CombiningKind::AND:
2484     combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
2485     break;
2486   case CombiningKind::OR:
2487     combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
2488     break;
2489   case CombiningKind::XOR:
2490     combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
2491     break;
2492   case CombiningKind::MINF:
2493     combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
2494     break;
2495   case CombiningKind::MAXF:
2496     combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
2497     break;
2498   }
2499   return combinedResult;
2500 }
2501 
2502 /// Convert vector.scan op into arith ops and
2503 /// vector.insert_strided_slice/extract_strided_slice
2504 ///
2505 /// Ex:
2506 /// ```
2507 ///   %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim =
2508 ///   1} :
2509 ///     (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
2510 /// ```
2511 /// Gets converted to:
2512 /// ```
2513 ///   %cst = arith.constant dense<0> : vector<2x3xi32>
2514 ///   %0 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 1],
2515 ///   strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %1 =
2516 ///   vector.insert_strided_slice %0, %cst {offsets = [0, 0], strides = [1, 1]}
2517 ///   : vector<2x1xi32> into vector<2x3xi32> %2 = vector.extract_strided_slice
2518 ///   %arg0 {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} :
2519 ///   vector<2x3xi32> to vector<2x1xi32> %3 = arith.muli %0, %2 :
2520 ///   vector<2x1xi32> %4 = vector.insert_strided_slice %3, %1 {offsets = [0, 1],
2521 ///   strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %5 =
2522 ///   vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1],
2523 ///   strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %6 = arith.muli %3,
2524 ///   %5 : vector<2x1xi32> %7 = vector.insert_strided_slice %6, %4 {offsets =
2525 ///   [0, 2], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %8 =
2526 ///   vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32> return %7, %8 :
2527 ///   vector<2x3xi32>, vector<2xi32>
2528 /// ```
2529 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
2530   using OpRewritePattern<vector::ScanOp>::OpRewritePattern;
2531 
2532   LogicalResult matchAndRewrite(vector::ScanOp scanOp,
2533                                 PatternRewriter &rewriter) const override {
2534     auto loc = scanOp.getLoc();
2535     VectorType destType = scanOp.getDestType();
2536     ArrayRef<int64_t> destShape = destType.getShape();
2537     auto elType = destType.getElementType();
2538     bool isInt = elType.isIntOrIndex();
2539     if (!isValidKind(isInt, scanOp.kind()))
2540       return failure();
2541 
2542     VectorType resType = VectorType::get(destShape, elType);
2543     Value result = rewriter.create<arith::ConstantOp>(
2544         loc, resType, rewriter.getZeroAttr(resType));
2545     int64_t reductionDim = scanOp.reduction_dim();
2546     bool inclusive = scanOp.inclusive();
2547     int64_t destRank = destType.getRank();
2548     VectorType initialValueType = scanOp.getInitialValueType();
2549     int64_t initialValueRank = initialValueType.getRank();
2550 
2551     SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
2552     reductionShape[reductionDim] = 1;
2553     VectorType reductionType = VectorType::get(reductionShape, elType);
2554     SmallVector<int64_t> offsets(destRank, 0);
2555     SmallVector<int64_t> strides(destRank, 1);
2556     SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
2557     sizes[reductionDim] = 1;
2558     ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
2559     ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
2560 
2561     Value lastOutput, lastInput;
2562     for (int i = 0; i < destShape[reductionDim]; i++) {
2563       offsets[reductionDim] = i;
2564       ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
2565       Value input = rewriter.create<vector::ExtractStridedSliceOp>(
2566           loc, reductionType, scanOp.source(), scanOffsets, scanSizes,
2567           scanStrides);
2568       Value output;
2569       if (i == 0) {
2570         if (inclusive) {
2571           output = input;
2572         } else {
2573           if (initialValueRank == 0) {
2574             // ShapeCastOp cannot handle 0-D vectors
2575             output = rewriter.create<vector::BroadcastOp>(
2576                 loc, input.getType(), scanOp.initial_value());
2577           } else {
2578             output = rewriter.create<vector::ShapeCastOp>(
2579                 loc, input.getType(), scanOp.initial_value());
2580           }
2581         }
2582       } else {
2583         Value y = inclusive ? input : lastInput;
2584         output = genOperator(loc, lastOutput, y, scanOp.kind(), rewriter);
2585         assert(output != nullptr);
2586       }
2587       result = rewriter.create<vector::InsertStridedSliceOp>(
2588           loc, output, result, offsets, strides);
2589       lastOutput = output;
2590       lastInput = input;
2591     }
2592 
2593     Value reduction;
2594     if (initialValueRank == 0) {
2595       Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
2596       reduction =
2597           rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
2598     } else {
2599       reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
2600                                                        lastOutput);
2601     }
2602 
2603     rewriter.replaceOp(scanOp, {result, reduction});
2604     return success();
2605   }
2606 };
2607 
2608 } // namespace
2609 
2610 void mlir::vector::populateVectorMaskMaterializationPatterns(
2611     RewritePatternSet &patterns, bool indexOptimizations) {
2612   patterns.add<VectorCreateMaskOpConversion,
2613                MaterializeTransferMask<vector::TransferReadOp>,
2614                MaterializeTransferMask<vector::TransferWriteOp>>(
2615       patterns.getContext(), indexOptimizations);
2616 }
2617 
2618 void mlir::vector::populateShapeCastFoldingPatterns(
2619     RewritePatternSet &patterns) {
2620   patterns.add<ShapeCastOpFolder>(patterns.getContext());
2621 }
2622 
2623 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
2624     RewritePatternSet &patterns) {
2625   patterns.add<BubbleDownVectorBitCastForExtract,
2626                BubbleDownBitCastForStridedSliceExtract,
2627                BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
2628 }
2629 
2630 void mlir::vector::populateVectorBroadcastLoweringPatterns(
2631     RewritePatternSet &patterns) {
2632   patterns.add<BroadcastOpLowering>(patterns.getContext());
2633 }
2634 
2635 void mlir::vector::populateVectorMaskOpLoweringPatterns(
2636     RewritePatternSet &patterns) {
2637   patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
2638       patterns.getContext());
2639 }
2640 
2641 void mlir::vector::populateVectorShapeCastLoweringPatterns(
2642     RewritePatternSet &patterns) {
2643   patterns.add<ShapeCastOp2DDownCastRewritePattern,
2644                ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
2645       patterns.getContext());
2646 }
2647 
2648 void mlir::vector::populateVectorContractLoweringPatterns(
2649     RewritePatternSet &patterns, VectorTransformsOptions options) {
2650   patterns.add<OuterProductOpLowering>(patterns.getContext());
2651   patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
2652                ContractionOpToOuterProductOpLowering>(options,
2653                                                       patterns.getContext());
2654 }
2655 
2656 void mlir::vector::populateVectorTransposeLoweringPatterns(
2657     RewritePatternSet &patterns, VectorTransformsOptions options) {
2658   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
2659       options, patterns.getContext());
2660 }
2661 
2662 void mlir::vector::populateVectorReductionToContractPatterns(
2663     RewritePatternSet &patterns) {
2664   patterns.add<MultiReduceToContract, CombineContractBroadcast,
2665                CombineContractTranspose, ReorderCastOpsOnBroadcast,
2666                ReorderCastOpsOnTranspose>(patterns.getContext());
2667 }
2668 
2669 void mlir::vector::
2670     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
2671         RewritePatternSet &patterns) {
2672   patterns.add<DropInnerMostUnitDims>(patterns.getContext());
2673 }
2674 
2675 void mlir::vector::populateVectorTransferLoweringPatterns(
2676     RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
2677   patterns.add<TransferReadToVectorLoadLowering,
2678                TransferWriteToVectorStoreLowering>(patterns.getContext(),
2679                                                    maxTransferRank);
2680   patterns
2681       .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
2682           patterns.getContext());
2683 }
2684 
2685 void mlir::vector::populateVectorScanLoweringPatterns(
2686     RewritePatternSet &patterns) {
2687   patterns.add<ScanToArithOps>(patterns.getContext());
2688 }
2689