1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
14 
15 #include <type_traits>
16 
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
19 #include "mlir/Dialect/Linalg/IR/Linalg.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/SCF/SCF.h"
22 #include "mlir/Dialect/Utils/IndexingUtils.h"
23 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
24 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
25 #include "mlir/IR/ImplicitLocOpBuilder.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/Interfaces/VectorInterfaces.h"
29 
30 #include "llvm/ADT/DenseSet.h"
31 #include "llvm/ADT/MapVector.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/raw_ostream.h"
36 
37 #define DEBUG_TYPE "vector-to-vector"
38 
39 using namespace mlir;
40 using namespace mlir::vector;
41 
42 // Helper to find an index in an affine map.
43 static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
44   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
45     int64_t idx = map.getDimPosition(i);
46     if (idx == index)
47       return i;
48   }
49   return None;
50 }
51 
52 // Helper to construct iterator types with one index removed.
53 static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
54                                             int64_t index) {
55   SmallVector<Attribute, 4> results;
56   for (const auto &it : llvm::enumerate(iteratorTypes)) {
57     int64_t idx = it.index();
58     if (idx == index)
59       continue;
60     results.push_back(it.value());
61   }
62   return results;
63 }
64 
65 // Helper to construct an affine map with one index removed.
66 static AffineMap adjustMap(AffineMap map, int64_t index,
67                            PatternRewriter &rewriter) {
68   auto *ctx = rewriter.getContext();
69   SmallVector<AffineExpr, 4> results;
70   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
71     int64_t idx = map.getDimPosition(i);
72     if (idx == index)
73       continue;
74     // Re-insert remaining indices, but renamed when occurring
75     // after the removed index.
76     auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
77     results.push_back(targetExpr);
78   }
79   return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
80 }
81 
82 // Helper method to possibly drop a dimension in a load.
83 // TODO
84 static Value reshapeLoad(Location loc, Value val, VectorType type,
85                          int64_t index, int64_t pos,
86                          PatternRewriter &rewriter) {
87   if (index == -1)
88     return val;
89   Type lowType = VectorType::Builder(type).dropDim(0);
90   // At extraction dimension?
91   if (index == 0) {
92     auto posAttr = rewriter.getI64ArrayAttr(pos);
93     return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
94   }
95   // Unroll leading dimensions.
96   VectorType vType = lowType.cast<VectorType>();
97   Type resType = VectorType::Builder(type).dropDim(index);
98   auto resVectorType = resType.cast<VectorType>();
99   Value result = rewriter.create<arith::ConstantOp>(
100       loc, resVectorType, rewriter.getZeroAttr(resVectorType));
101   for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
102     auto posAttr = rewriter.getI64ArrayAttr(d);
103     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
104     Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
105     result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
106                                                posAttr);
107   }
108   return result;
109 }
110 
111 // Helper method to possibly drop a dimension in a store.
112 // TODO
113 static Value reshapeStore(Location loc, Value val, Value result,
114                           VectorType type, int64_t index, int64_t pos,
115                           PatternRewriter &rewriter) {
116   // Unmodified?
117   if (index == -1)
118     return val;
119   // At insertion dimension?
120   if (index == 0) {
121     auto posAttr = rewriter.getI64ArrayAttr(pos);
122     return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
123   }
124   // Unroll leading dimensions.
125   Type lowType = VectorType::Builder(type).dropDim(0);
126   VectorType vType = lowType.cast<VectorType>();
127   Type insType = VectorType::Builder(vType).dropDim(0);
128   for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
129     auto posAttr = rewriter.getI64ArrayAttr(d);
130     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
131     Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
132     Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
133     result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
134   }
135   return result;
136 }
137 
138 template <typename IntType>
139 static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
140   return llvm::to_vector<4>(llvm::map_range(
141       arrayAttr.getAsRange<IntegerAttr>(),
142       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
143 }
144 
145 namespace {
146 
147 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
148 //
149 // Example:
150 //
151 //  The following MLIR with cancelling ShapeCastOps:
152 //
153 //   %0 = source : vector<5x4x2xf32>
154 //   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
155 //   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
156 //   %3 = user %2 : vector<5x4x2xf32>
157 //
158 //  Should canonicalize to the following:
159 //
160 //   %0 = source : vector<5x4x2xf32>
161 //   %1 = user %0 : vector<5x4x2xf32>
162 //
163 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
164   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
165 
166   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
167                                 PatternRewriter &rewriter) const override {
168     // Check if 'shapeCastOp' has vector source/result type.
169     auto sourceVectorType =
170         shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
171     auto resultVectorType =
172         shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
173     if (!sourceVectorType || !resultVectorType)
174       return failure();
175 
176     // Check if shape cast op source operand is also a shape cast op.
177     auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
178         shapeCastOp.source().getDefiningOp());
179     if (!sourceShapeCastOp)
180       return failure();
181     auto operandSourceVectorType =
182         sourceShapeCastOp.source().getType().cast<VectorType>();
183     auto operandResultVectorType = sourceShapeCastOp.getType();
184 
185     // Check if shape cast operations invert each other.
186     if (operandSourceVectorType != resultVectorType ||
187         operandResultVectorType != sourceVectorType)
188       return failure();
189 
190     rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
191     return success();
192   }
193 };
194 
195 /// Progressive lowering of BroadcastOp.
196 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
197 public:
198   using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
199 
200   LogicalResult matchAndRewrite(vector::BroadcastOp op,
201                                 PatternRewriter &rewriter) const override {
202     auto loc = op.getLoc();
203     VectorType dstType = op.getVectorType();
204     VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
205     Type eltType = dstType.getElementType();
206 
207     // Scalar to any vector can use splat.
208     if (!srcType) {
209       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.source());
210       return success();
211     }
212 
213     // Determine rank of source and destination.
214     int64_t srcRank = srcType.getRank();
215     int64_t dstRank = dstType.getRank();
216 
217     // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
218     if (srcRank <= 1 && dstRank == 1) {
219       Value ext;
220       if (srcRank == 0)
221         ext = rewriter.create<vector::ExtractElementOp>(loc, op.source());
222       else
223         ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
224       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
225       return success();
226     }
227 
228     // Duplicate this rank.
229     // For example:
230     //   %x = broadcast %y  : k-D to n-D, k < n
231     // becomes:
232     //   %b = broadcast %y  : k-D to (n-1)-D
233     //   %x = [%b,%b,%b,%b] : n-D
234     // becomes:
235     //   %b = [%y,%y]       : (n-1)-D
236     //   %x = [%b,%b,%b,%b] : n-D
237     if (srcRank < dstRank) {
238       // Duplication.
239       VectorType resType =
240           VectorType::get(dstType.getShape().drop_front(), eltType);
241       Value bcst =
242           rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
243       Value result = rewriter.create<arith::ConstantOp>(
244           loc, dstType, rewriter.getZeroAttr(dstType));
245       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
246         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
247       rewriter.replaceOp(op, result);
248       return success();
249     }
250 
251     // Find non-matching dimension, if any.
252     assert(srcRank == dstRank);
253     int64_t m = -1;
254     for (int64_t r = 0; r < dstRank; r++)
255       if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
256         m = r;
257         break;
258       }
259 
260     // All trailing dimensions are the same. Simply pass through.
261     if (m == -1) {
262       rewriter.replaceOp(op, op.source());
263       return success();
264     }
265 
266     // Any non-matching dimension forces a stretch along this rank.
267     // For example:
268     //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
269     // becomes:
270     //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
271     //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
272     //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
273     //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
274     //   %x = [%a,%b,%c,%d]
275     // becomes:
276     //   %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
277     //   %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
278     //   %a = [%u, %v]
279     //   ..
280     //   %x = [%a,%b,%c,%d]
281     VectorType resType =
282         VectorType::get(dstType.getShape().drop_front(), eltType);
283     Value result = rewriter.create<arith::ConstantOp>(
284         loc, dstType, rewriter.getZeroAttr(dstType));
285     if (m == 0) {
286       // Stetch at start.
287       Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
288       Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
289       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
290         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
291     } else {
292       // Stetch not at start.
293       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
294         Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
295         Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
296         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
297       }
298     }
299     rewriter.replaceOp(op, result);
300     return success();
301   }
302 };
303 
304 /// Given a 'transpose' pattern, prune the rightmost dimensions that are not
305 /// transposed.
306 void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
307                             SmallVectorImpl<int64_t> &result) {
308   size_t numTransposedDims = transpose.size();
309   for (size_t transpDim : llvm::reverse(transpose)) {
310     if (transpDim != numTransposedDims - 1)
311       break;
312     numTransposedDims--;
313   }
314 
315   result.append(transpose.begin(), transpose.begin() + numTransposedDims);
316 }
317 
318 /// Progressive lowering of TransposeOp.
319 /// One:
320 ///   %x = vector.transpose %y, [1, 0]
321 /// is replaced by:
322 ///   %z = arith.constant dense<0.000000e+00>
323 ///   %0 = vector.extract %y[0, 0]
324 ///   %1 = vector.insert %0, %z [0, 0]
325 ///   ..
326 ///   %x = vector.insert .., .. [.., ..]
327 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
328 public:
329   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
330 
331   TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
332                       MLIRContext *context)
333       : OpRewritePattern<vector::TransposeOp>(context),
334         vectorTransformOptions(vectorTransformOptions) {}
335 
336   LogicalResult matchAndRewrite(vector::TransposeOp op,
337                                 PatternRewriter &rewriter) const override {
338     auto loc = op.getLoc();
339 
340     Value input = op.vector();
341     VectorType inputType = op.getVectorType();
342     VectorType resType = op.getResultType();
343 
344     // Set up convenience transposition table.
345     SmallVector<int64_t, 4> transp;
346     for (auto attr : op.transp())
347       transp.push_back(attr.cast<IntegerAttr>().getInt());
348 
349     if (vectorTransformOptions.vectorTransposeLowering ==
350             vector::VectorTransposeLowering::Shuffle &&
351         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
352       return rewriter.notifyMatchFailure(
353           op, "Options specifies lowering to shuffle");
354 
355     // Handle a true 2-D matrix transpose differently when requested.
356     if (vectorTransformOptions.vectorTransposeLowering ==
357             vector::VectorTransposeLowering::Flat &&
358         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
359       Type flattenedType =
360           VectorType::get(resType.getNumElements(), resType.getElementType());
361       auto matrix =
362           rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
363       auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
364       auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
365       Value trans = rewriter.create<vector::FlatTransposeOp>(
366           loc, flattenedType, matrix, rows, columns);
367       rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
368       return success();
369     }
370 
371     // Generate unrolled extract/insert ops. We do not unroll the rightmost
372     // (i.e., highest-order) dimensions that are not transposed and leave them
373     // in vector form to improve performance. Therefore, we prune those
374     // dimensions from the shape/transpose data structures used to generate the
375     // extract/insert ops.
376     SmallVector<int64_t, 4> prunedTransp;
377     pruneNonTransposedDims(transp, prunedTransp);
378     size_t numPrunedDims = transp.size() - prunedTransp.size();
379     auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
380     SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
381     auto prunedInStrides = computeStrides(prunedInShape, ones);
382 
383     // Generates the extract/insert operations for every scalar/vector element
384     // of the leftmost transposed dimensions. We traverse every transpose
385     // element using a linearized index that we delinearize to generate the
386     // appropriate indices for the extract/insert operations.
387     Value result = rewriter.create<arith::ConstantOp>(
388         loc, resType, rewriter.getZeroAttr(resType));
389     int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
390 
391     for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
392          ++linearIdx) {
393       auto extractIdxs = delinearize(prunedInStrides, linearIdx);
394       SmallVector<int64_t, 4> insertIdxs(extractIdxs);
395       applyPermutationToVector(insertIdxs, prunedTransp);
396       Value extractOp =
397           rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
398       result =
399           rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
400     }
401 
402     rewriter.replaceOp(op, result);
403     return success();
404   }
405 
406 private:
407   /// Options to control the vector patterns.
408   vector::VectorTransformsOptions vectorTransformOptions;
409 };
410 
411 /// Rewrite a 2-D vector.transpose as a sequence of:
412 ///   vector.shape_cast 2D -> 1D
413 ///   vector.shuffle
414 ///   vector.shape_cast 1D -> 2D
415 class TransposeOp2DToShuffleLowering
416     : public OpRewritePattern<vector::TransposeOp> {
417 public:
418   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
419 
420   TransposeOp2DToShuffleLowering(
421       vector::VectorTransformsOptions vectorTransformOptions,
422       MLIRContext *context)
423       : OpRewritePattern<vector::TransposeOp>(context),
424         vectorTransformOptions(vectorTransformOptions) {}
425 
426   LogicalResult matchAndRewrite(vector::TransposeOp op,
427                                 PatternRewriter &rewriter) const override {
428     auto loc = op.getLoc();
429 
430     VectorType srcType = op.getVectorType();
431     if (srcType.getRank() != 2)
432       return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
433 
434     SmallVector<int64_t, 4> transp;
435     for (auto attr : op.transp())
436       transp.push_back(attr.cast<IntegerAttr>().getInt());
437     if (transp[0] != 1 && transp[1] != 0)
438       return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
439 
440     if (vectorTransformOptions.vectorTransposeLowering !=
441         VectorTransposeLowering::Shuffle)
442       return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
443 
444     int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
445     Value casted = rewriter.create<vector::ShapeCastOp>(
446         loc, VectorType::get({m * n}, srcType.getElementType()), op.vector());
447     SmallVector<int64_t> mask;
448     mask.reserve(m * n);
449     for (int64_t j = 0; j < n; ++j)
450       for (int64_t i = 0; i < m; ++i)
451         mask.push_back(i * n + j);
452 
453     Value shuffled =
454         rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
455     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
456                                                      shuffled);
457 
458     return success();
459   }
460 
461 private:
462   /// Options to control the vector patterns.
463   vector::VectorTransformsOptions vectorTransformOptions;
464 };
465 
466 /// Progressive lowering of OuterProductOp.
467 /// One:
468 ///   %x = vector.outerproduct %lhs, %rhs, %acc
469 /// is replaced by:
470 ///   %z = zero-result
471 ///   %0 = vector.extract %lhs[0]
472 ///   %1 = vector.broadcast %0
473 ///   %2 = vector.extract %acc[0]
474 ///   %3 = vector.fma %1, %rhs, %2
475 ///   %4 = vector.insert %3, %z[0]
476 ///   ..
477 ///   %x = vector.insert %.., %..[N-1]
478 ///
479 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
480 public:
481   using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
482 
483   LogicalResult matchAndRewrite(vector::OuterProductOp op,
484                                 PatternRewriter &rewriter) const override {
485     auto loc = op.getLoc();
486 
487     VectorType lhsType = op.getOperandVectorTypeLHS();
488     VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
489     VectorType resType = op.getVectorType();
490     Type eltType = resType.getElementType();
491     bool isInt = eltType.isa<IntegerType, IndexType>();
492     Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
493     vector::CombiningKind kind = op.kind();
494 
495     if (!rhsType) {
496       // Special case: AXPY operation.
497       Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
498       Optional<Value> mult =
499           isInt ? genMultI(loc, op.lhs(), b, acc, kind, rewriter)
500                 : genMultF(loc, op.lhs(), b, acc, kind, rewriter);
501       if (!mult.hasValue())
502         return failure();
503       rewriter.replaceOp(op, mult.getValue());
504       return success();
505     }
506 
507     Value result = rewriter.create<arith::ConstantOp>(
508         loc, resType, rewriter.getZeroAttr(resType));
509     for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
510       auto pos = rewriter.getI64ArrayAttr(d);
511       Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
512       Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
513       Value r = nullptr;
514       if (acc)
515         r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
516       Optional<Value> m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter)
517                                 : genMultF(loc, a, op.rhs(), r, kind, rewriter);
518       if (!m.hasValue())
519         return failure();
520       result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
521                                                  result, pos);
522     }
523     rewriter.replaceOp(op, result);
524     return success();
525   }
526 
527 private:
528   static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
529                                   vector::CombiningKind kind,
530                                   PatternRewriter &rewriter) {
531     using vector::CombiningKind;
532 
533     auto mul = rewriter.create<arith::MulIOp>(loc, x, y);
534     if (!acc)
535       return Optional<Value>(mul);
536 
537     if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
538       // Only valid for floating point types.
539       return Optional<Value>();
540 
541     return makeArithReduction(rewriter, loc, kind, mul, acc);
542   }
543 
544   static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
545                                   vector::CombiningKind kind,
546                                   PatternRewriter &rewriter) {
547     using vector::CombiningKind;
548 
549     // Special case for fused multiply-add.
550     if (acc && kind == CombiningKind::ADD) {
551       return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
552     }
553 
554     auto mul = rewriter.create<arith::MulFOp>(loc, x, y);
555 
556     if (!acc)
557       return Optional<Value>(mul);
558 
559     if (kind == CombiningKind::ADD || kind == CombiningKind::AND ||
560         kind == CombiningKind::MINUI || kind == CombiningKind::MINSI ||
561         kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI ||
562         kind == CombiningKind::OR || kind == CombiningKind::XOR)
563       // Already handled or only valid for integer types.
564       return Optional<Value>();
565 
566     return makeArithReduction(rewriter, loc, kind, mul, acc);
567   }
568 };
569 
570 /// Progressive lowering of ConstantMaskOp.
571 /// One:
572 ///   %x = vector.constant_mask [a,b]
573 /// is replaced by:
574 ///   %z = zero-result
575 ///   %l = vector.constant_mask [b]
576 ///   %4 = vector.insert %l, %z[0]
577 ///   ..
578 ///   %x = vector.insert %l, %..[a-1]
579 /// until a one-dimensional vector is reached. All these operations
580 /// will be folded at LLVM IR level.
581 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
582 public:
583   using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
584 
585   LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
586                                 PatternRewriter &rewriter) const override {
587     auto loc = op.getLoc();
588     auto dstType = op.getType();
589     auto eltType = dstType.getElementType();
590     auto dimSizes = op.mask_dim_sizes();
591     int64_t rank = dstType.getRank();
592 
593     if (rank == 0) {
594       assert(dimSizes.size() == 1 &&
595              "Expected exactly one dim size for a 0-D vector");
596       bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
597       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
598           op, dstType,
599           DenseIntElementsAttr::get(
600               VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
601               ArrayRef<bool>{value}));
602       return success();
603     }
604 
605     int64_t trueDim = std::min(dstType.getDimSize(0),
606                                dimSizes[0].cast<IntegerAttr>().getInt());
607 
608     if (rank == 1) {
609       // Express constant 1-D case in explicit vector form:
610       //   [T,..,T,F,..,F].
611       SmallVector<bool, 4> values(dstType.getDimSize(0));
612       for (int64_t d = 0; d < trueDim; d++)
613         values[d] = true;
614       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
615           op, dstType, rewriter.getBoolVectorAttr(values));
616       return success();
617     }
618 
619     VectorType lowType =
620         VectorType::get(dstType.getShape().drop_front(), eltType);
621     SmallVector<int64_t, 4> newDimSizes;
622     for (int64_t r = 1; r < rank; r++)
623       newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
624     Value trueVal = rewriter.create<vector::ConstantMaskOp>(
625         loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
626     Value result = rewriter.create<arith::ConstantOp>(
627         loc, dstType, rewriter.getZeroAttr(dstType));
628     for (int64_t d = 0; d < trueDim; d++) {
629       auto pos = rewriter.getI64ArrayAttr(d);
630       result =
631           rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
632     }
633     rewriter.replaceOp(op, result);
634     return success();
635   }
636 };
637 
638 /// Progressive lowering of CreateMaskOp.
639 /// One:
640 ///   %x = vector.create_mask %a, ... : vector<dx...>
641 /// is replaced by:
642 ///   %l = vector.create_mask ... : vector<...>  ; one lower rank
643 ///   %0 = arith.cmpi "slt", %ci, %a       |
644 ///   %1 = select %0, %l, %zeroes    |
645 ///   %r = vector.insert %1, %pr [i] | d-times
646 ///   %x = ....
647 /// until a one-dimensional vector is reached.
648 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
649 public:
650   using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
651 
652   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
653                                 PatternRewriter &rewriter) const override {
654     auto dstType = op.getResult().getType().cast<VectorType>();
655     int64_t rank = dstType.getRank();
656     if (rank <= 1)
657       return rewriter.notifyMatchFailure(
658           op, "0-D and 1-D vectors are handled separately");
659 
660     auto loc = op.getLoc();
661     auto eltType = dstType.getElementType();
662     int64_t dim = dstType.getDimSize(0);
663     Value idx = op.getOperand(0);
664 
665     VectorType lowType =
666         VectorType::get(dstType.getShape().drop_front(), eltType);
667     Value trueVal = rewriter.create<vector::CreateMaskOp>(
668         loc, lowType, op.getOperands().drop_front());
669     Value falseVal = rewriter.create<arith::ConstantOp>(
670         loc, lowType, rewriter.getZeroAttr(lowType));
671     Value result = rewriter.create<arith::ConstantOp>(
672         loc, dstType, rewriter.getZeroAttr(dstType));
673     for (int64_t d = 0; d < dim; d++) {
674       Value bnd =
675           rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
676       Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
677                                                  bnd, idx);
678       Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
679       auto pos = rewriter.getI64ArrayAttr(d);
680       result =
681           rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
682     }
683     rewriter.replaceOp(op, result);
684     return success();
685   }
686 };
687 
688 /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
689 /// vectors progressively on the way to target llvm.matrix intrinsics.
690 /// This iterates over the most major dimension of the 2-D vector and performs
691 /// rewrites into:
692 ///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
693 class ShapeCastOp2DDownCastRewritePattern
694     : public OpRewritePattern<vector::ShapeCastOp> {
695 public:
696   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
697 
698   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
699                                 PatternRewriter &rewriter) const override {
700     auto sourceVectorType = op.getSourceVectorType();
701     auto resultVectorType = op.getResultVectorType();
702     if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
703       return failure();
704 
705     auto loc = op.getLoc();
706     Value desc = rewriter.create<arith::ConstantOp>(
707         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
708     unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
709     for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
710       Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
711       desc = rewriter.create<vector::InsertStridedSliceOp>(
712           loc, vec, desc,
713           /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
714     }
715     rewriter.replaceOp(op, desc);
716     return success();
717   }
718 };
719 
720 /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
721 /// vectors progressively.
722 /// This iterates over the most major dimension of the 2-D vector and performs
723 /// rewrites into:
724 ///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
725 /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
726 class ShapeCastOp2DUpCastRewritePattern
727     : public OpRewritePattern<vector::ShapeCastOp> {
728 public:
729   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
730 
731   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
732                                 PatternRewriter &rewriter) const override {
733     auto sourceVectorType = op.getSourceVectorType();
734     auto resultVectorType = op.getResultVectorType();
735     if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
736       return failure();
737 
738     auto loc = op.getLoc();
739     Value desc = rewriter.create<arith::ConstantOp>(
740         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
741     unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
742     for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
743       Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
744           loc, op.source(), /*offsets=*/i * mostMinorVectorSize,
745           /*sizes=*/mostMinorVectorSize,
746           /*strides=*/1);
747       desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
748     }
749     rewriter.replaceOp(op, desc);
750     return success();
751   }
752 };
753 
754 // We typically should not lower general shape cast operations into data
755 // movement instructions, since the assumption is that these casts are
756 // optimized away during progressive lowering. For completeness, however,
757 // we fall back to a reference implementation that moves all elements
758 // into the right place if we get here.
759 class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
760 public:
761   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
762 
763   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
764                                 PatternRewriter &rewriter) const override {
765     Location loc = op.getLoc();
766     auto sourceVectorType = op.getSourceVectorType();
767     auto resultVectorType = op.getResultVectorType();
768 
769     // Special case 2D/1D lowerings with better implementations.
770     // TODO: make is ND/1D to allow generic ND->1D->MD.
771     int64_t srcRank = sourceVectorType.getRank();
772     int64_t resRank = resultVectorType.getRank();
773     if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
774       return failure();
775 
776     // Generic ShapeCast lowering path goes all the way down to unrolled scalar
777     // extract/insert chains.
778     // TODO: consider evolving the semantics to only allow 1D source or dest and
779     // drop this potentially very expensive lowering.
780     // Compute number of elements involved in the reshape.
781     int64_t numElts = 1;
782     for (int64_t r = 0; r < srcRank; r++)
783       numElts *= sourceVectorType.getDimSize(r);
784     // Replace with data movement operations:
785     //    x[0,0,0] = y[0,0]
786     //    x[0,0,1] = y[0,1]
787     //    x[0,1,0] = y[0,2]
788     // etc., incrementing the two index vectors "row-major"
789     // within the source and result shape.
790     SmallVector<int64_t, 4> srcIdx(srcRank);
791     SmallVector<int64_t, 4> resIdx(resRank);
792     Value result = rewriter.create<arith::ConstantOp>(
793         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
794     for (int64_t i = 0; i < numElts; i++) {
795       if (i != 0) {
796         incIdx(srcIdx, sourceVectorType, srcRank - 1);
797         incIdx(resIdx, resultVectorType, resRank - 1);
798       }
799       Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
800       result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
801     }
802     rewriter.replaceOp(op, result);
803     return success();
804   }
805 
806 private:
807   static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
808     assert(0 <= r && r < tp.getRank());
809     if (++idx[r] == tp.getDimSize(r)) {
810       idx[r] = 0;
811       incIdx(idx, tp, r - 1);
812     }
813   }
814 };
815 
816 /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
817 /// Ex:
818 /// ```
819 ///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
820 ///   %1 = vector.multi_reduction add, %0 [1]
821 ///     : vector<8x32x16xf32> to vector<8x16xf32>
822 /// ```
823 /// Gets converted to:
824 /// ```
825 ///   %1 = vector.contract {indexing_maps = [
826 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
827 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
828 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
829 ///    iterator_types = ["parallel", "parallel", "reduction"],
830 ///    kind = add} %0, %arg1, %cst_f0
831 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
832 ///  ```
833 struct MultiReduceToContract
834     : public OpRewritePattern<vector::MultiDimReductionOp> {
835   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
836 
837   LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
838                                 PatternRewriter &rewriter) const override {
839     if (reduceOp.kind() != vector::CombiningKind::ADD)
840       return failure();
841     Operation *mulOp = reduceOp.source().getDefiningOp();
842     if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
843       return failure();
844     SmallVector<bool> reductionMask = reduceOp.getReductionMask();
845     auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
846     SmallVector<AffineExpr> exprs;
847     SmallVector<StringRef> iteratorTypes;
848     for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
849       if (!isReduceDim.value()) {
850         iteratorTypes.push_back(getParallelIteratorTypeName());
851         exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
852       } else {
853         iteratorTypes.push_back(getReductionIteratorTypeName());
854       }
855     }
856     auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
857                                  /*symCount=*/0, exprs, reduceOp.getContext());
858     Value zero = rewriter.create<arith::ConstantOp>(
859         reduceOp.getLoc(), reduceOp.getDestType(),
860         rewriter.getZeroAttr(reduceOp.getDestType()));
861     rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
862         reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
863         rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
864         rewriter.getStrArrayAttr(iteratorTypes));
865     return success();
866   }
867 };
868 
869 /// Merge TransposeOp into ContractionOp user.
870 /// Ex:
871 /// ```
872 ///   %0 = vector.transpose %arg0, [2, 0, 1]
873 ///     : vector<32x16x8xf32> to vector<8x32x16xf32>
874 ///   %1 = vector.contract {indexing_maps = [
875 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
876 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
877 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
878 ///    iterator_types = ["parallel", "parallel", "reduction"],
879 ///    kind = add} %0, %arg1, %cst_f0
880 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
881 /// ```
882 /// Gets converted to:
883 /// ```
884 ///   %1 = vector.contract {indexing_maps = [
885 ///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
886 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
887 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
888 ///    iterator_types = ["parallel", "parallel", "reduction"],
889 ///    kind = add} %arg0, %arg1, %cst_f0
890 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
891 ///  ```
892 struct CombineContractTranspose
893     : public OpRewritePattern<vector::ContractionOp> {
894   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
895 
896   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
897                                 PatternRewriter &rewriter) const override {
898     SmallVector<AffineMap, 4> maps =
899         llvm::to_vector<4>(contractOp.getIndexingMaps());
900     Value lhs = contractOp.lhs();
901     Value rhs = contractOp.rhs();
902     size_t index = 0;
903     bool changed = false;
904     for (Value *operand : {&lhs, &rhs}) {
905       AffineMap &map = maps[index++];
906       auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
907       if (!transposeOp)
908         continue;
909       SmallVector<int64_t> perm;
910       transposeOp.getTransp(perm);
911       AffineMap permutationMap = AffineMap::getPermutationMap(
912           extractVector<unsigned>(transposeOp.transp()),
913           contractOp.getContext());
914       map = inversePermutation(permutationMap).compose(map);
915       *operand = transposeOp.vector();
916       changed = true;
917     }
918     if (!changed)
919       return failure();
920     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
921         contractOp, lhs, rhs, contractOp.acc(),
922         rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
923     return success();
924   }
925 };
926 
927 /// Merge BroadcastOp into ContractionOp user.
928 /// Ex:
929 /// ```
930 ///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
931 ///   %1 = vector.contract {indexing_maps = [
932 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
933 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
934 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
935 ///    iterator_types = ["parallel", "parallel", "reduction"],
936 ///    kind = add} %0, %arg1, %cst_f0
937 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
938 /// ```
939 /// Gets converted to:
940 /// ```
941 ///   %1 = vector.contract {indexing_maps = [
942 ///         affine_map<(d0, d1, d2) -> (d1, d2)>,
943 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
944 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
945 ///    iterator_types = ["parallel", "parallel", "reduction"],
946 ///    kind = add} %arg0, %arg1, %cst_f0
947 ///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
948 ///  ```
949 struct CombineContractBroadcast
950     : public OpRewritePattern<vector::ContractionOp> {
951   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
952 
953   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
954                                 PatternRewriter &rewriter) const override {
955     SmallVector<AffineMap, 4> maps =
956         llvm::to_vector<4>(contractOp.getIndexingMaps());
957     Value lhs = contractOp.lhs();
958     Value rhs = contractOp.rhs();
959     size_t index = 0;
960     bool changed = false;
961     for (Value *operand : {&lhs, &rhs}) {
962       AffineMap &map = maps[index++];
963       auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
964       if (!broadcast)
965         continue;
966       // contractionOp can only take vector as operands.
967       auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
968       if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
969         continue;
970       int64_t rankDiff =
971           broadcast.getVectorType().getRank() - srcType.getRank();
972       bool innerDimBroadcast = false;
973       SmallVector<AffineExpr> originalDims;
974       for (const auto &dim : llvm::enumerate(srcType.getShape())) {
975         if (dim.value() !=
976             broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
977           innerDimBroadcast = true;
978           break;
979         }
980         originalDims.push_back(
981             rewriter.getAffineDimExpr(dim.index() + rankDiff));
982       }
983       // Contract doesn't support inner dimension broadcast. Once this is
984       // relaxed we can remove this case.
985       if (innerDimBroadcast)
986         continue;
987       AffineMap broadcastMap =
988           AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
989                          contractOp.getContext());
990       map = broadcastMap.compose(map);
991       *operand = broadcast.source();
992       changed = true;
993     }
994     if (!changed)
995       return failure();
996     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
997         contractOp, lhs, rhs, contractOp.acc(),
998         rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
999     return success();
1000   }
1001 };
1002 
1003 /// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
1004 /// contraction ops closer, which kicks in CombineContractBroadcast pattern when
1005 /// casting ops are around these operations.
1006 /// Ex:
1007 /// ```
1008 ///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
1009 ///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
1010 /// ```
1011 /// Gets converted to:
1012 /// ```
1013 ///   %0 = arith.extsi %0 : vector<32x16xi8> to vector<32x16xi32>
1014 ///   %1 = vector.broadcast %arg0 : vector<32x16xi32> to vector<8x32x16xi32>
1015 /// ```
1016 struct ReorderCastOpsOnBroadcast
1017     : public OpInterfaceRewritePattern<CastOpInterface> {
1018   using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
1019 
1020   LogicalResult matchAndRewrite(CastOpInterface op,
1021                                 PatternRewriter &rewriter) const override {
1022     if (op->getNumOperands() != 1)
1023       return failure();
1024     auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
1025     if (!bcastOp)
1026       return failure();
1027 
1028     Type castResTy = getElementTypeOrSelf(op->getResult(0));
1029     if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
1030       castResTy = VectorType::get(vecTy.getShape(), castResTy);
1031     OperationState state(op->getLoc(), op->getName(), bcastOp.source(),
1032                          castResTy, op->getAttrs());
1033     auto castOp = rewriter.createOperation(state);
1034     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1035         op, op->getResult(0).getType(), castOp->getResult(0));
1036     return success();
1037   }
1038 };
1039 
1040 /// Reorders cast(transpose) to transpose(cast). This makes broadcast ops and
1041 /// contraction ops closer, which kicks in CombineContractTranspose pattern when
1042 /// casting ops are around these operations.
1043 /// Ex:
1044 /// ```
1045 ///   %0 = vector.transpose %arg0, [2, 0, 1]
1046 ///     : vector<32x16x8xi8> to vector<8x32x16xi8>
1047 ///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
1048 /// ```
1049 /// Gets converted to:
1050 /// ```
1051 ///   %0 = arith.extsi %0 : vector<32x16x8xi8> to vector<32x16x8xi32>
1052 ///   %1 = vector.transpose %arg0, [2, 0, 1]
1053 ///     : vector<32x16x8xi32> to vector<8x32x16xi32>
1054 /// ```
1055 struct ReorderCastOpsOnTranspose
1056     : public OpInterfaceRewritePattern<CastOpInterface> {
1057 
1058   using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
1059 
1060   LogicalResult matchAndRewrite(CastOpInterface op,
1061                                 PatternRewriter &rewriter) const override {
1062     if (op->getNumOperands() != 1)
1063       return failure();
1064     auto transpOp = op->getOperand(0).getDefiningOp<vector::TransposeOp>();
1065     if (!transpOp)
1066       return failure();
1067 
1068     auto castResTy = transpOp.getVectorType();
1069     castResTy = VectorType::get(castResTy.getShape(),
1070                                 getElementTypeOrSelf(op->getResult(0)));
1071     OperationState state(op->getLoc(), op->getName(), transpOp.vector(),
1072                          castResTy, op->getAttrs());
1073     auto castOp = rewriter.createOperation(state);
1074     rewriter.replaceOpWithNewOp<vector::TransposeOp>(
1075         op, op->getResult(0).getType(), castOp->getResult(0),
1076         transpOp.getTransp());
1077     return success();
1078   }
1079 };
1080 
1081 } // namespace
1082 
1083 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
1084 /// operands `x` and `y`.
1085 static Value createAdd(Location loc, Value x, Value y, bool isInt,
1086                        PatternRewriter &rewriter) {
1087   if (isInt)
1088     return rewriter.create<arith::AddIOp>(loc, x, y);
1089   return rewriter.create<arith::AddFOp>(loc, x, y);
1090 }
1091 
1092 /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
1093 /// operands `x and `y`.
1094 static Value createMul(Location loc, Value x, Value y, bool isInt,
1095                        PatternRewriter &rewriter) {
1096   if (isInt)
1097     return rewriter.create<arith::MulIOp>(loc, x, y);
1098   return rewriter.create<arith::MulFOp>(loc, x, y);
1099 }
1100 
1101 namespace mlir {
1102 
1103 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1104 /// semantics to:
1105 /// ```
1106 ///    %mta = maybe_transpose
1107 ///    %mtb = maybe_transpose
1108 ///    %flattened_a = vector.shape_cast %mta
1109 ///    %flattened_b = vector.shape_cast %mtb
1110 ///    %flattened_d = vector.matmul %flattened_a, %flattened_b
1111 ///    %mtd = vector.shape_cast %flattened_d
1112 ///    %d = maybe_untranspose %mtd
1113 ///    %e = add %c, %d
1114 /// ```
1115 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1116 //
1117 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
1118 /// vector.transpose operations are inserted if the vector.contract op is not a
1119 /// row-major matrix multiply.
1120 LogicalResult
1121 ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
1122                                                  PatternRewriter &rew) const {
1123   // TODO: implement masks
1124   if (llvm::size(op.masks()) != 0)
1125     return failure();
1126   if (vectorTransformOptions.vectorContractLowering !=
1127       vector::VectorContractLowering::Matmul)
1128     return failure();
1129   if (failed(filter(op)))
1130     return failure();
1131 
1132   auto iteratorTypes = op.iterator_types().getValue();
1133   if (!isParallelIterator(iteratorTypes[0]) ||
1134       !isParallelIterator(iteratorTypes[1]) ||
1135       !isReductionIterator(iteratorTypes[2]))
1136     return failure();
1137 
1138   Type elementType = op.getLhsType().getElementType();
1139   if (!elementType.isIntOrFloat())
1140     return failure();
1141 
1142   // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
1143   // Bail out if the contraction cannot be put in this form.
1144   MLIRContext *ctx = op.getContext();
1145   Location loc = op.getLoc();
1146   AffineExpr m, n, k;
1147   bindDims(rew.getContext(), m, n, k);
1148   // LHS must be A(m, k) or A(k, m).
1149   Value lhs = op.lhs();
1150   auto lhsMap = op.indexing_maps()[0];
1151   if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
1152     lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
1153   else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
1154     return failure();
1155 
1156   // RHS must be B(k, n) or B(n, k).
1157   Value rhs = op.rhs();
1158   auto rhsMap = op.indexing_maps()[1];
1159   if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
1160     rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
1161   else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
1162     return failure();
1163 
1164   // At this point lhs and rhs are in row-major.
1165   VectorType lhsType = lhs.getType().cast<VectorType>();
1166   VectorType rhsType = rhs.getType().cast<VectorType>();
1167   int64_t lhsRows = lhsType.getDimSize(0);
1168   int64_t lhsColumns = lhsType.getDimSize(1);
1169   int64_t rhsColumns = rhsType.getDimSize(1);
1170 
1171   Type flattenedLHSType =
1172       VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
1173   lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
1174 
1175   Type flattenedRHSType =
1176       VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
1177   rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
1178 
1179   Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
1180                                            rhsColumns);
1181   mul = rew.create<vector::ShapeCastOp>(
1182       loc,
1183       VectorType::get({lhsRows, rhsColumns},
1184                       getElementTypeOrSelf(op.acc().getType())),
1185       mul);
1186 
1187   // ACC must be C(m, n) or C(n, m).
1188   auto accMap = op.indexing_maps()[2];
1189   if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
1190     mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
1191   else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
1192     llvm_unreachable("invalid contraction semantics");
1193 
1194   Value res =
1195       elementType.isa<IntegerType>()
1196           ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.acc(), mul))
1197           : static_cast<Value>(rew.create<arith::AddFOp>(loc, op.acc(), mul));
1198 
1199   rew.replaceOp(op, res);
1200   return success();
1201 }
1202 
1203 namespace {
1204 struct IteratorType {
1205   IteratorType(StringRef strRef) : strRef(strRef) {}
1206   bool isOfType(Attribute attr) const {
1207     auto sAttr = attr.dyn_cast<StringAttr>();
1208     return sAttr && sAttr.getValue() == strRef;
1209   }
1210   StringRef strRef;
1211 };
1212 struct Par : public IteratorType {
1213   Par() : IteratorType(getParallelIteratorTypeName()) {}
1214 };
1215 struct Red : public IteratorType {
1216   Red() : IteratorType(getReductionIteratorTypeName()) {}
1217 };
1218 
1219 /// Generate a vector implementation for matmat, matvec and tmatvec.
1220 /// This unrolls outer-products along the reduction dimension.
1221 struct UnrolledOuterProductGenerator
1222     : public StructuredGenerator<vector::ContractionOp> {
1223 
1224   UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
1225       : StructuredGenerator<vector::ContractionOp>(builder, op),
1226         kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()),
1227         lhsType(op.getLhsType()) {}
1228 
1229   Value t(Value v) {
1230     static constexpr std::array<int64_t, 2> perm = {1, 0};
1231     return builder.create<vector::TransposeOp>(loc, v, perm);
1232   }
1233 
1234   Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
1235     assert(reductionSize > 0);
1236     for (int64_t k = 0; k < reductionSize; ++k) {
1237       Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
1238       Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
1239       res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
1240                                                    res, kind);
1241     }
1242     return res;
1243   }
1244 
1245   /// Two outer parallel, one inner reduction (matmat flavor).
1246   FailureOr<Value> matmat() {
1247     if (!iters({Par(), Par(), Red()}))
1248       return failure();
1249     // Set up the parallel/reduction structure in the right form.
1250     AffineExpr m, n, k;
1251     bindDims(builder.getContext(), m, n, k);
1252     // Classical row-major matmul:  Just permute the lhs.
1253     if (layout({{m, k}, {k, n}, {m, n}}))
1254       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1255     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1256     if (layout({{m, k}, {n, k}, {m, n}})) {
1257       Value tlhs = t(lhs);
1258       return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
1259     }
1260     // No need to permute anything.
1261     if (layout({{k, m}, {k, n}, {m, n}}))
1262       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1263     // Just permute the rhs.
1264     if (layout({{k, m}, {n, k}, {m, n}}))
1265       return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
1266     // Transposed output: swap RHS and LHS.
1267     // Classical row-major matmul: permute the lhs.
1268     if (layout({{m, k}, {k, n}, {n, m}}))
1269       return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
1270     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1271     if (layout({{m, k}, {n, k}, {n, m}})) {
1272       Value trhs = t(rhs);
1273       return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
1274     }
1275     if (layout({{k, m}, {k, n}, {n, m}}))
1276       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1277     if (layout({{k, m}, {n, k}, {n, m}}))
1278       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1279     return failure();
1280   }
1281 
1282   /// One outer parallel, one inner reduction (matvec flavor)
1283   FailureOr<Value> matvec() {
1284     if (!iters({Par(), Red()}))
1285       return failure();
1286     AffineExpr m, k;
1287     bindDims(builder.getContext(), m, k);
1288 
1289     // Case mat-vec: transpose.
1290     if (layout({{m, k}, {k}, {m}}))
1291       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1292     // Case mat-trans-vec: ready to go.
1293     if (layout({{k, m}, {k}, {m}}))
1294       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1295     // Case vec-mat: swap and transpose.
1296     if (layout({{k}, {m, k}, {m}}))
1297       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1298     // Case vec-mat-trans: swap and ready to go.
1299     if (layout({{k}, {k, m}, {m}}))
1300       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1301     return failure();
1302   }
1303 
1304   //
1305   // One outer reduction, one inner parallel (tmatvec flavor)
1306   //
1307   FailureOr<Value> tmatvec() {
1308     if (!iters({Red(), Par()}))
1309       return failure();
1310     AffineExpr k, m;
1311     bindDims(builder.getContext(), k, m);
1312 
1313     // Case mat-vec: transpose.
1314     if (layout({{m, k}, {k}, {m}}))
1315       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1316     // Case mat-trans-vec: ready to go.
1317     if (layout({{k, m}, {k}, {m}}))
1318       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1319     // Case vec-mat: swap and transpose.
1320     if (layout({{k}, {m, k}, {m}}))
1321       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1322     // Case vec-mat-trans: swap and ready to go.
1323     if (layout({{k}, {k, m}, {m}}))
1324       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1325     return failure();
1326   }
1327 
1328 private:
1329   vector::CombiningKind kind;
1330   Value lhs, rhs, res;
1331   VectorType lhsType;
1332 };
1333 } // namespace
1334 
1335 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1336 /// semantics to a reduction_size-unrolled sequence:
1337 /// ```
1338 ///    %at = vector.transpose %a, [1, 0]
1339 ///    %bRow0 = vector.extract %b[0]
1340 ///    %atRow0 = vector.extract %at[0]
1341 ///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
1342 ///    ...
1343 ///    %bRowK = vector.extract %b[K]
1344 ///    %atRowK = vector.extract %at[K]
1345 ///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
1346 /// ```
1347 ///
1348 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
1349 /// otherwise supports any layout permutation of the matrix-multiply.
1350 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
1351     vector::ContractionOp op, PatternRewriter &rewriter) const {
1352   // TODO: implement masks
1353   if (llvm::size(op.masks()) != 0)
1354     return failure();
1355 
1356   if (vectorTransformOptions.vectorContractLowering !=
1357       vector::VectorContractLowering::OuterProduct)
1358     return failure();
1359 
1360   if (failed(filter(op)))
1361     return failure();
1362 
1363   UnrolledOuterProductGenerator e(rewriter, op);
1364   FailureOr<Value> matmatRes = e.matmat();
1365   if (succeeded(matmatRes)) {
1366     rewriter.replaceOp(op, *matmatRes);
1367     return success();
1368   }
1369   FailureOr<Value> matvecRes = e.matvec();
1370   if (succeeded(matvecRes)) {
1371     rewriter.replaceOp(op, *matvecRes);
1372     return success();
1373   }
1374   FailureOr<Value> tmatvecRes = e.tmatvec();
1375   if (succeeded(tmatvecRes)) {
1376     rewriter.replaceOp(op, *tmatvecRes);
1377     return success();
1378   }
1379 
1380   return failure();
1381 }
1382 
1383 LogicalResult
1384 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
1385                                             PatternRewriter &rewriter) const {
1386   // TODO: implement masks
1387   if (llvm::size(op.masks()) != 0)
1388     return failure();
1389 
1390   if (failed(filter(op)))
1391     return failure();
1392 
1393   if (vectorTransformOptions.vectorContractLowering !=
1394       vector::VectorContractLowering::Dot)
1395     return failure();
1396 
1397   auto iteratorTypes = op.iterator_types().getValue();
1398   static constexpr std::array<int64_t, 2> perm = {1, 0};
1399   Location loc = op.getLoc();
1400   Value lhs = op.lhs(), rhs = op.rhs();
1401 
1402   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1403   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1404   AffineExpr m, n, k;
1405   bindDims(rewriter.getContext(), m, n, k);
1406   SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
1407   //
1408   // In the following we wish to make the reduction dimension innermost so we
1409   // can load vectors and just fmul + reduce into a scalar.
1410   //
1411   if (isParallelIterator(iteratorTypes[0]) &&
1412       isParallelIterator(iteratorTypes[1]) &&
1413       isReductionIterator(iteratorTypes[2])) {
1414     //
1415     // Two outer parallel, one inner reduction (matmat flavor).
1416     //
1417     if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1418       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1419     } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
1420       // No need to permute anything.
1421     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1422       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1423       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1424     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1425       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1426     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1427       // This is the classical row-major matmul. Just permute the lhs.
1428       Value tmp = lhs;
1429       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1430       rhs = tmp;
1431     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1432       std::swap(lhs, rhs);
1433     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1434       Value tmp = lhs;
1435       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1436       rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
1437     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1438       Value tmp = rhs;
1439       rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1440       lhs = tmp;
1441     } else {
1442       return failure();
1443     }
1444   } else if (isParallelIterator(iteratorTypes[0]) &&
1445              isReductionIterator(iteratorTypes[1])) {
1446     //
1447     // One outer parallel, one inner reduction (matvec flavor)
1448     //
1449     if (maps == infer({{m, n}, {n}, {m}})) {
1450       // No need to permute anything.
1451     } else if (maps == infer({{n, m}, {n}, {m}})) {
1452       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1453     } else if (maps == infer({{n}, {m, n}, {m}})) {
1454       std::swap(lhs, rhs);
1455     } else if (maps == infer({{n}, {n, m}, {m}})) {
1456       std::swap(lhs, rhs);
1457       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1458     } else {
1459       return failure();
1460     }
1461   } else {
1462     return failure();
1463   }
1464 
1465   VectorType dstType = op.getResultType().cast<VectorType>();
1466   assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
1467          "Expected dst type of rank 1 or 2");
1468 
1469   unsigned rank = dstType.getRank();
1470   unsigned dstRows = dstType.getShape()[0];
1471   unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
1472 
1473   // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
1474   Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
1475                                                  rewriter.getZeroAttr(dstType));
1476   bool isInt = dstType.getElementType().isa<IntegerType>();
1477   for (unsigned r = 0; r < dstRows; ++r) {
1478     Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
1479     for (unsigned c = 0; c < dstColumns; ++c) {
1480       Value b = rank == 1
1481                     ? rhs
1482                     : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
1483       Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
1484       Value reduced = rewriter.create<vector::ReductionOp>(
1485           op.getLoc(), vector::CombiningKind::ADD, m);
1486 
1487       SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
1488                                               : SmallVector<int64_t, 2>{r, c};
1489       res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
1490     }
1491   }
1492   if (auto acc = op.acc())
1493     res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
1494   rewriter.replaceOp(op, res);
1495   return success();
1496 }
1497 
1498 /// Progressive lowering of ContractionOp.
1499 /// One:
1500 ///   %x = vector.contract with at least one free/batch dimension
1501 /// is replaced by:
1502 ///   %a = vector.contract with one less free/batch dimension
1503 ///   %b = vector.contract with one less free/batch dimension
1504 ///   ..
1505 ///   %x = combine %a %b ..
1506 /// until a pure contraction is reached (no free/batch dimensions),
1507 /// which is replaced by a dot-product.
1508 ///
1509 /// This only kicks in when either VectorTransformsOptions is set
1510 /// to DOT or when other contraction patterns fail.
1511 //
1512 // TODO: break down into transpose/reshape/cast ops
1513 //               when they become available to avoid code dup
1514 // TODO: investigate lowering order impact on performance
1515 LogicalResult
1516 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
1517                                        PatternRewriter &rewriter) const {
1518   // TODO: implement masks.
1519   if (llvm::size(op.masks()) != 0)
1520     return failure();
1521 
1522   if (failed(filter(op)))
1523     return failure();
1524 
1525   // TODO: support mixed mode contract lowering.
1526   if (op.getLhsType().getElementType() !=
1527           getElementTypeOrSelf(op.getAccType()) ||
1528       op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
1529     return failure();
1530 
1531   // TODO: implement benefits, cost models.
1532   MLIRContext *ctx = op.getContext();
1533   ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
1534   if (succeeded(pat1.matchAndRewrite(op, rewriter)))
1535     return success();
1536   ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
1537   if (succeeded(pat2.matchAndRewrite(op, rewriter)))
1538     return success();
1539   ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
1540   if (succeeded(pat3.matchAndRewrite(op, rewriter)))
1541     return success();
1542 
1543   // Find first batch dimension in LHS/RHS, and lower when found.
1544   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
1545   if (!batchDimMap.empty()) {
1546     int64_t lhsIndex = batchDimMap[0].first;
1547     int64_t rhsIndex = batchDimMap[0].second;
1548     rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
1549     return success();
1550   }
1551 
1552   // Collect contracting dimensions.
1553   std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
1554       op.getContractingDimMap();
1555   DenseSet<int64_t> lhsContractingDimSet;
1556   DenseSet<int64_t> rhsContractingDimSet;
1557   for (auto &dimPair : contractingDimMap) {
1558     lhsContractingDimSet.insert(dimPair.first);
1559     rhsContractingDimSet.insert(dimPair.second);
1560   }
1561 
1562   // Find first free dimension in LHS, and lower when found.
1563   VectorType lhsType = op.getLhsType();
1564   for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
1565     if (lhsContractingDimSet.count(lhsIndex) == 0) {
1566       rewriter.replaceOp(
1567           op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
1568       return success();
1569     }
1570   }
1571 
1572   // Find first free dimension in RHS, and lower when found.
1573   VectorType rhsType = op.getRhsType();
1574   for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
1575     if (rhsContractingDimSet.count(rhsIndex) == 0) {
1576       rewriter.replaceOp(
1577           op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
1578       return success();
1579     }
1580   }
1581 
1582   // Lower the first remaining reduction dimension.
1583   if (!contractingDimMap.empty()) {
1584     rewriter.replaceOp(op, lowerReduction(op, rewriter));
1585     return success();
1586   }
1587 
1588   return failure();
1589 }
1590 
1591 // Lower one parallel dimension.
1592 // TODO: consider reusing existing contract unrolling
1593 Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
1594                                            int64_t lhsIndex, int64_t rhsIndex,
1595                                            PatternRewriter &rewriter) const {
1596   VectorType lhsType = op.getLhsType();
1597   VectorType rhsType = op.getRhsType();
1598   VectorType resType = op.getResultType().cast<VectorType>();
1599   // Find the iterator type index and result index.
1600   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
1601   int64_t iterIndex = -1;
1602   int64_t dimSize = -1;
1603   if (lhsIndex >= 0) {
1604     iterIndex = iMap[0].getDimPosition(lhsIndex);
1605     assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
1606            "parallel index should be free in LHS or batch in LHS/RHS");
1607     dimSize = lhsType.getDimSize(lhsIndex);
1608   } else {
1609     assert(rhsIndex >= 0 && "missing parallel index");
1610     iterIndex = iMap[1].getDimPosition(rhsIndex);
1611     dimSize = rhsType.getDimSize(rhsIndex);
1612   }
1613   assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
1614   Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
1615   assert(lookup.hasValue() && "parallel index not listed in reduction");
1616   int64_t resIndex = lookup.getValue();
1617   // Construct new iterator types and affine map array attribute.
1618   std::array<AffineMap, 3> lowIndexingMaps = {
1619       adjustMap(iMap[0], iterIndex, rewriter),
1620       adjustMap(iMap[1], iterIndex, rewriter),
1621       adjustMap(iMap[2], iterIndex, rewriter)};
1622   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1623   auto lowIter =
1624       rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
1625   // Unroll into a series of lower dimensional vector.contract ops.
1626   Location loc = op.getLoc();
1627   Value result = rewriter.create<arith::ConstantOp>(
1628       loc, resType, rewriter.getZeroAttr(resType));
1629   for (int64_t d = 0; d < dimSize; ++d) {
1630     auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
1631     auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
1632     auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
1633     Value lowContract = rewriter.create<vector::ContractionOp>(
1634         loc, lhs, rhs, acc, lowAffine, lowIter);
1635     result =
1636         reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
1637   }
1638   return result;
1639 }
1640 
1641 // Lower one reduction dimension.
1642 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
1643                                             PatternRewriter &rewriter) const {
1644   auto loc = op.getLoc();
1645   VectorType lhsType = op.getLhsType();
1646   VectorType rhsType = op.getRhsType();
1647   Type resType = op.getResultType();
1648   assert(!resType.isa<VectorType>());
1649   bool isInt = resType.isa<IntegerType>();
1650   // Use iterator index 0.
1651   int64_t iterIndex = 0;
1652   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
1653   Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1654   Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1655   assert(lookupLhs.hasValue() && "missing LHS parallel index");
1656   assert(lookupRhs.hasValue() && "missing RHS parallel index");
1657   int64_t lhsIndex = lookupLhs.getValue();
1658   int64_t rhsIndex = lookupRhs.getValue();
1659   int64_t dimSize = lhsType.getDimSize(lhsIndex);
1660   assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
1661   // Base case.
1662   if (lhsType.getRank() == 1) {
1663     assert(rhsType.getRank() == 1 && "corrupt contraction");
1664     Value m = createMul(loc, op.lhs(), op.rhs(), isInt, rewriter);
1665     auto kind = vector::CombiningKind::ADD;
1666     Value res = rewriter.create<vector::ReductionOp>(loc, kind, m);
1667     if (auto acc = op.acc())
1668       res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
1669     return res;
1670   }
1671   // Construct new iterator types and affine map array attribute.
1672   std::array<AffineMap, 3> lowIndexingMaps = {
1673       adjustMap(iMap[0], iterIndex, rewriter),
1674       adjustMap(iMap[1], iterIndex, rewriter),
1675       adjustMap(iMap[2], iterIndex, rewriter)};
1676   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1677   auto lowIter =
1678       rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
1679   // Unroll into a series of lower dimensional vector.contract ops.
1680   // By feeding the initial accumulator into the first contraction,
1681   // and the result of each contraction into the next, eventually
1682   // the sum of all reductions is computed.
1683   Value result = op.acc();
1684   for (int64_t d = 0; d < dimSize; ++d) {
1685     auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
1686     auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
1687     result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
1688                                                     lowAffine, lowIter);
1689   }
1690   return result;
1691 }
1692 
1693 } // namespace mlir
1694 
1695 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
1696     OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
1697     ArrayRef<int64_t> multiplicity, const AffineMap &map) {
1698   OpBuilder::InsertionGuard guard(builder);
1699   builder.setInsertionPointAfter(op);
1700   Location loc = op->getLoc();
1701   if (op->getNumResults() != 1)
1702     return {};
1703   Value result = op->getResult(0);
1704   VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
1705   if (!type || map.getNumResults() != multiplicity.size())
1706     return {};
1707   // For each dimension being distributed check that the size is a multiple of
1708   // the multiplicity. To handle more sizes we would need to support masking.
1709   unsigned multiplictyCount = 0;
1710   for (auto exp : map.getResults()) {
1711     auto affinExp = exp.dyn_cast<AffineDimExpr>();
1712     if (!affinExp || affinExp.getPosition() >= type.getRank() ||
1713         type.getDimSize(affinExp.getPosition()) %
1714                 multiplicity[multiplictyCount++] !=
1715             0)
1716       return {};
1717   }
1718   DistributeOps ops;
1719   ops.extract =
1720       builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
1721   ops.insert =
1722       builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
1723   return ops;
1724 }
1725 
1726 /// Progressive lowering of transfer_read. This pattern supports lowering of
1727 /// `vector.transfer_read` to a combination of `vector.load` and
1728 /// `vector.broadcast` if all of the following hold:
1729 /// - Stride of most minor memref dimension must be 1.
1730 /// - Out-of-bounds masking is not required.
1731 /// - If the memref's element type is a vector type then it coincides with the
1732 ///   result type.
1733 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
1734 struct TransferReadToVectorLoadLowering
1735     : public OpRewritePattern<vector::TransferReadOp> {
1736   TransferReadToVectorLoadLowering(MLIRContext *context,
1737                                    llvm::Optional<unsigned> maxRank)
1738       : OpRewritePattern<vector::TransferReadOp>(context),
1739         maxTransferRank(maxRank) {}
1740 
1741   LogicalResult matchAndRewrite(vector::TransferReadOp read,
1742                                 PatternRewriter &rewriter) const override {
1743     if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
1744       return failure();
1745 
1746     SmallVector<unsigned, 4> broadcastedDims;
1747     // Permutations are handled by VectorToSCF or
1748     // populateVectorTransferPermutationMapLoweringPatterns.
1749     // We let the 0-d corner case pass-through as it is supported.
1750     if (!read.permutation_map().isMinorIdentityWithBroadcasting(
1751             &broadcastedDims))
1752       return failure();
1753 
1754     auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
1755     if (!memRefType)
1756       return failure();
1757 
1758     // Non-unit strides are handled by VectorToSCF.
1759     if (!vector::isLastMemrefDimUnitStride(memRefType))
1760       return failure();
1761 
1762     // If there is broadcasting involved then we first load the unbroadcasted
1763     // vector, and then broadcast it with `vector.broadcast`.
1764     ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
1765     SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
1766                                                      vectorShape.end());
1767     for (unsigned i : broadcastedDims)
1768       unbroadcastedVectorShape[i] = 1;
1769     VectorType unbroadcastedVectorType = VectorType::get(
1770         unbroadcastedVectorShape, read.getVectorType().getElementType());
1771 
1772     // `vector.load` supports vector types as memref's elements only when the
1773     // resulting vector type is the same as the element type.
1774     auto memrefElTy = memRefType.getElementType();
1775     if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
1776       return failure();
1777 
1778     // Otherwise, element types of the memref and the vector must match.
1779     if (!memrefElTy.isa<VectorType>() &&
1780         memrefElTy != read.getVectorType().getElementType())
1781       return failure();
1782 
1783     // Out-of-bounds dims are handled by MaterializeTransferMask.
1784     if (read.hasOutOfBoundsDim())
1785       return failure();
1786 
1787     // Create vector load op.
1788     Operation *loadOp;
1789     if (read.mask()) {
1790       Value fill = rewriter.create<vector::SplatOp>(
1791           read.getLoc(), unbroadcastedVectorType, read.padding());
1792       loadOp = rewriter.create<vector::MaskedLoadOp>(
1793           read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(),
1794           read.mask(), fill);
1795     } else {
1796       loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
1797                                                unbroadcastedVectorType,
1798                                                read.source(), read.indices());
1799     }
1800 
1801     // Insert a broadcasting op if required.
1802     if (!broadcastedDims.empty()) {
1803       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1804           read, read.getVectorType(), loadOp->getResult(0));
1805     } else {
1806       rewriter.replaceOp(read, loadOp->getResult(0));
1807     }
1808 
1809     return success();
1810   }
1811 
1812   llvm::Optional<unsigned> maxTransferRank;
1813 };
1814 
1815 /// Replace a 0-d vector.load with a memref.load + vector.broadcast.
1816 // TODO: we shouldn't cross the vector/scalar domains just for this
1817 // but atm we lack the infra to avoid it. Possible solutions include:
1818 // - go directly to LLVM + bitcast
1819 // - introduce a bitcast op and likely a new pointer dialect
1820 // - let memref.load/store additionally support the 0-d vector case
1821 // There are still deeper data layout issues lingering even in this
1822 // trivial case (for architectures for which this matters).
1823 struct VectorLoadToMemrefLoadLowering
1824     : public OpRewritePattern<vector::LoadOp> {
1825   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
1826 
1827   LogicalResult matchAndRewrite(vector::LoadOp loadOp,
1828                                 PatternRewriter &rewriter) const override {
1829     auto vecType = loadOp.getVectorType();
1830     if (vecType.getNumElements() != 1)
1831       return failure();
1832     auto memrefLoad = rewriter.create<memref::LoadOp>(
1833         loadOp.getLoc(), loadOp.base(), loadOp.indices());
1834     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
1835                                                      memrefLoad);
1836     return success();
1837   }
1838 };
1839 
1840 /// Replace a 0-d vector.store with a vector.extractelement + memref.store.
1841 struct VectorStoreToMemrefStoreLowering
1842     : public OpRewritePattern<vector::StoreOp> {
1843   using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
1844 
1845   LogicalResult matchAndRewrite(vector::StoreOp storeOp,
1846                                 PatternRewriter &rewriter) const override {
1847     auto vecType = storeOp.getVectorType();
1848     if (vecType.getNumElements() != 1)
1849       return failure();
1850     Value extracted;
1851     if (vecType.getRank() == 0) {
1852       // TODO: Unifiy once ExtractOp supports 0-d vectors.
1853       extracted = rewriter.create<vector::ExtractElementOp>(
1854           storeOp.getLoc(), storeOp.valueToStore());
1855     } else {
1856       SmallVector<int64_t> indices(vecType.getRank(), 0);
1857       extracted = rewriter.create<vector::ExtractOp>(
1858           storeOp.getLoc(), storeOp.valueToStore(), indices);
1859     }
1860 
1861     rewriter.replaceOpWithNewOp<memref::StoreOp>(
1862         storeOp, extracted, storeOp.base(), storeOp.indices());
1863     return success();
1864   }
1865 };
1866 
1867 /// Progressive lowering of transfer_write. This pattern supports lowering of
1868 /// `vector.transfer_write` to `vector.store` if all of the following hold:
1869 /// - Stride of most minor memref dimension must be 1.
1870 /// - Out-of-bounds masking is not required.
1871 /// - If the memref's element type is a vector type then it coincides with the
1872 ///   type of the written value.
1873 /// - The permutation map is the minor identity map (neither permutation nor
1874 ///   broadcasting is allowed).
1875 struct TransferWriteToVectorStoreLowering
1876     : public OpRewritePattern<vector::TransferWriteOp> {
1877   TransferWriteToVectorStoreLowering(MLIRContext *context,
1878                                      llvm::Optional<unsigned> maxRank)
1879       : OpRewritePattern<vector::TransferWriteOp>(context),
1880         maxTransferRank(maxRank) {}
1881 
1882   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
1883                                 PatternRewriter &rewriter) const override {
1884     if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
1885       return failure();
1886 
1887     // Permutations are handled by VectorToSCF or
1888     // populateVectorTransferPermutationMapLoweringPatterns.
1889     if ( // pass-through for the 0-d corner case.
1890         !write.permutation_map().isMinorIdentity())
1891       return failure();
1892 
1893     auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
1894     if (!memRefType)
1895       return failure();
1896 
1897     // Non-unit strides are handled by VectorToSCF.
1898     if (!vector::isLastMemrefDimUnitStride(memRefType))
1899       return failure();
1900 
1901     // `vector.store` supports vector types as memref's elements only when the
1902     // type of the vector value being written is the same as the element type.
1903     auto memrefElTy = memRefType.getElementType();
1904     if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
1905       return failure();
1906 
1907     // Otherwise, element types of the memref and the vector must match.
1908     if (!memrefElTy.isa<VectorType>() &&
1909         memrefElTy != write.getVectorType().getElementType())
1910       return failure();
1911 
1912     // Out-of-bounds dims are handled by MaterializeTransferMask.
1913     if (write.hasOutOfBoundsDim())
1914       return failure();
1915     if (write.mask()) {
1916       rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
1917           write, write.source(), write.indices(), write.mask(), write.vector());
1918     } else {
1919       rewriter.replaceOpWithNewOp<vector::StoreOp>(
1920           write, write.vector(), write.source(), write.indices());
1921     }
1922     return success();
1923   }
1924 
1925   llvm::Optional<unsigned> maxTransferRank;
1926 };
1927 
1928 // Returns the values in `arrayAttr` as an integer vector.
1929 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
1930   return llvm::to_vector<4>(
1931       llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
1932                       [](IntegerAttr attr) { return attr.getInt(); }));
1933 }
1934 
1935 // Shuffles vector.bitcast op after vector.extract op.
1936 //
1937 // This transforms IR like:
1938 //   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
1939 //   %1 = vector.extract %0[3] : vector<8xf16>
1940 // Into:
1941 //   %0 = vector.extract %src[1] : vector<4xf32>
1942 //   %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
1943 //   %2 = vector.extract %1[1] : vector<2xf16>
1944 struct BubbleDownVectorBitCastForExtract
1945     : public OpRewritePattern<vector::ExtractOp> {
1946   using OpRewritePattern::OpRewritePattern;
1947 
1948   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
1949                                 PatternRewriter &rewriter) const override {
1950     // Only support extracting scalars for now.
1951     if (extractOp.getVectorType().getRank() != 1)
1952       return failure();
1953 
1954     auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
1955     if (!castOp)
1956       return failure();
1957 
1958     VectorType castSrcType = castOp.getSourceVectorType();
1959     VectorType castDstType = castOp.getResultVectorType();
1960     assert(castSrcType.getRank() == castDstType.getRank());
1961 
1962     // Fail to match if we only have one element in the cast op source.
1963     // This is to avoid infinite loop given that this pattern can generate
1964     // such cases.
1965     if (castSrcType.getNumElements() == 1)
1966       return failure();
1967 
1968     // Only support casting to a larger number of elements or now.
1969     // E.g., vector<4xf32> -> vector<8xf16>.
1970     if (castSrcType.getNumElements() > castDstType.getNumElements())
1971       return failure();
1972 
1973     unsigned expandRatio =
1974         castDstType.getNumElements() / castSrcType.getNumElements();
1975 
1976     auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
1977       return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
1978     };
1979 
1980     uint64_t index = getFirstIntValue(extractOp.position());
1981 
1982     // Get the single scalar (as a vector) in the source value that packs the
1983     // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
1984     VectorType oneScalarType =
1985         VectorType::get({1}, castSrcType.getElementType());
1986     Value packedValue = rewriter.create<vector::ExtractOp>(
1987         extractOp.getLoc(), oneScalarType, castOp.source(),
1988         rewriter.getI64ArrayAttr(index / expandRatio));
1989 
1990     // Cast it to a vector with the desired scalar's type.
1991     // E.g. f32 -> vector<2xf16>
1992     VectorType packedType =
1993         VectorType::get({expandRatio}, castDstType.getElementType());
1994     Value castedValue = rewriter.create<vector::BitCastOp>(
1995         extractOp.getLoc(), packedType, packedValue);
1996 
1997     // Finally extract the desired scalar.
1998     rewriter.replaceOpWithNewOp<vector::ExtractOp>(
1999         extractOp, extractOp.getType(), castedValue,
2000         rewriter.getI64ArrayAttr(index % expandRatio));
2001 
2002     return success();
2003   }
2004 };
2005 
2006 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
2007 //
2008 // This transforms IR like:
2009 //    %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
2010 //     %0 = vector.extract_strided_slice %cast {
2011 //            offsets = [4], sizes = [4], strides = [1]
2012 //          } : vector<8xf16> to vector<4xf16>
2013 // Into:
2014 //   %0 = vector.extract_strided_slice %src {
2015 //          offsets = [2], sizes = [2], strides = [1]
2016 //        } : vector<4xf32> to vector<2xf32>
2017 //   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
2018 struct BubbleDownBitCastForStridedSliceExtract
2019     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
2020   using OpRewritePattern::OpRewritePattern;
2021 
2022   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
2023                                 PatternRewriter &rewriter) const override {
2024     auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
2025     if (!castOp)
2026       return failure();
2027 
2028     VectorType castSrcType = castOp.getSourceVectorType();
2029     VectorType castDstType = castOp.getResultVectorType();
2030     assert(castSrcType.getRank() == castDstType.getRank());
2031 
2032     int64_t castSrcLastDim = castSrcType.getShape().back();
2033     int64_t castDstLastDim = castDstType.getShape().back();
2034     // Require casting to more elements for now; other cases to be implemented.
2035     if (castSrcLastDim > castDstLastDim)
2036       return failure();
2037 
2038     // Only accept all one strides for now.
2039     if (llvm::any_of(extractOp.strides().getAsValueRange<IntegerAttr>(),
2040                      [](const APInt &val) { return !val.isOneValue(); }))
2041       return failure();
2042 
2043     unsigned rank = extractOp.getVectorType().getRank();
2044     assert(castDstLastDim % castSrcLastDim == 0);
2045     int64_t expandRatio = castDstLastDim / castSrcLastDim;
2046 
2047     // If we have a less number of offsets than the rank, then implicitly we
2048     // are selecting the full range for the last bitcasted dimension; other
2049     // dimensions aren't affected. Otherwise, we need to scale down the last
2050     // dimension's offset given we are extracting from less elements now.
2051     ArrayAttr newOffsets = extractOp.offsets();
2052     if (newOffsets.size() == rank) {
2053       SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2054       if (offsets.back() % expandRatio != 0)
2055         return failure();
2056       offsets.back() = offsets.back() / expandRatio;
2057       newOffsets = rewriter.getI64ArrayAttr(offsets);
2058     }
2059 
2060     // Similarly for sizes.
2061     ArrayAttr newSizes = extractOp.sizes();
2062     if (newSizes.size() == rank) {
2063       SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
2064       if (sizes.back() % expandRatio != 0)
2065         return failure();
2066       sizes.back() = sizes.back() / expandRatio;
2067       newSizes = rewriter.getI64ArrayAttr(sizes);
2068     }
2069 
2070     SmallVector<int64_t, 4> dims =
2071         llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
2072     dims.back() = dims.back() / expandRatio;
2073     VectorType newExtractType =
2074         VectorType::get(dims, castSrcType.getElementType());
2075 
2076     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
2077         extractOp.getLoc(), newExtractType, castOp.source(), newOffsets,
2078         newSizes, extractOp.strides());
2079 
2080     rewriter.replaceOpWithNewOp<vector::BitCastOp>(
2081         extractOp, extractOp.getType(), newExtractOp);
2082 
2083     return success();
2084   }
2085 };
2086 
2087 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
2088 //
2089 // This transforms IR like:
2090 //   %0 = vector.insert_strided_slice %src, %dst {
2091 //          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
2092 //   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
2093 // Into:
2094 //   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
2095 //   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
2096 //   %2 = vector.insert_strided_slice %src, %dst {
2097 //          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
2098 struct BubbleUpBitCastForStridedSliceInsert
2099     : public OpRewritePattern<vector::BitCastOp> {
2100   using OpRewritePattern::OpRewritePattern;
2101   LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
2102                                 PatternRewriter &rewriter) const override {
2103     VectorType castSrcType = bitcastOp.getSourceVectorType();
2104     VectorType castDstType = bitcastOp.getResultVectorType();
2105     assert(castSrcType.getRank() == castDstType.getRank());
2106 
2107     int64_t castSrcLastDim = castSrcType.getShape().back();
2108     int64_t castDstLastDim = castDstType.getShape().back();
2109     // Require casting to less elements for now; other cases to be implemented.
2110     if (castSrcLastDim < castDstLastDim)
2111       return failure();
2112 
2113     assert(castSrcLastDim % castDstLastDim == 0);
2114     int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
2115 
2116     auto insertOp =
2117         bitcastOp.source().getDefiningOp<vector::InsertStridedSliceOp>();
2118     if (!insertOp)
2119       return failure();
2120 
2121     // Only accept all one strides for now.
2122     if (llvm::any_of(insertOp.strides().getAsValueRange<IntegerAttr>(),
2123                      [](const APInt &val) { return !val.isOneValue(); }))
2124       return failure();
2125 
2126     unsigned rank = insertOp.getSourceVectorType().getRank();
2127     // Require insert op to have the same rank for the source and destination
2128     // vector; other cases to be implemented.
2129     if (rank != insertOp.getDestVectorType().getRank())
2130       return failure();
2131 
2132     ArrayAttr newOffsets = insertOp.offsets();
2133     assert(newOffsets.size() == rank);
2134     SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2135     if (offsets.back() % shrinkRatio != 0)
2136       return failure();
2137     offsets.back() = offsets.back() / shrinkRatio;
2138     newOffsets = rewriter.getI64ArrayAttr(offsets);
2139 
2140     SmallVector<int64_t, 4> srcDims =
2141         llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
2142     srcDims.back() = srcDims.back() / shrinkRatio;
2143     VectorType newCastSrcType =
2144         VectorType::get(srcDims, castDstType.getElementType());
2145 
2146     auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
2147         bitcastOp.getLoc(), newCastSrcType, insertOp.source());
2148 
2149     SmallVector<int64_t, 4> dstDims =
2150         llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
2151     dstDims.back() = dstDims.back() / shrinkRatio;
2152     VectorType newCastDstType =
2153         VectorType::get(dstDims, castDstType.getElementType());
2154 
2155     auto newCastDstOp = rewriter.create<vector::BitCastOp>(
2156         bitcastOp.getLoc(), newCastDstType, insertOp.dest());
2157 
2158     rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
2159         bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
2160         insertOp.strides());
2161 
2162     return success();
2163   }
2164 };
2165 
2166 static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc,
2167                                    Type targetType, Value value) {
2168   if (targetType == value.getType())
2169     return value;
2170 
2171   bool targetIsIndex = targetType.isIndex();
2172   bool valueIsIndex = value.getType().isIndex();
2173   if (targetIsIndex ^ valueIsIndex)
2174     return rewriter.create<arith::IndexCastOp>(loc, targetType, value);
2175 
2176   auto targetIntegerType = targetType.dyn_cast<IntegerType>();
2177   auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
2178   assert(targetIntegerType && valueIntegerType &&
2179          "unexpected cast between types other than integers and index");
2180   assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
2181 
2182   if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
2183     return rewriter.create<arith::ExtSIOp>(loc, targetIntegerType, value);
2184   return rewriter.create<arith::TruncIOp>(loc, targetIntegerType, value);
2185 }
2186 
2187 // Helper that returns a vector comparison that constructs a mask:
2188 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
2189 //
2190 // If `dim == 0` then the result will be a 0-D vector.
2191 //
2192 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
2193 //       much more compact, IR for this operation, but LLVM eventually
2194 //       generates more elaborate instructions for this intrinsic since it
2195 //       is very conservative on the boundary conditions.
2196 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
2197                                    bool indexOptimizations, int64_t dim,
2198                                    Value b, Value *off = nullptr) {
2199   auto loc = op->getLoc();
2200   // If we can assume all indices fit in 32-bit, we perform the vector
2201   // comparison in 32-bit to get a higher degree of SIMD parallelism.
2202   // Otherwise we perform the vector comparison using 64-bit indices.
2203   Type idxType =
2204       indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
2205   DenseIntElementsAttr indicesAttr;
2206   if (dim == 0 && indexOptimizations) {
2207     indicesAttr = DenseIntElementsAttr::get(
2208         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
2209   } else if (dim == 0) {
2210     indicesAttr = DenseIntElementsAttr::get(
2211         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
2212   } else if (indexOptimizations) {
2213     indicesAttr = rewriter.getI32VectorAttr(
2214         llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
2215   } else {
2216     indicesAttr = rewriter.getI64VectorAttr(
2217         llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
2218   }
2219   Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
2220   // Add in an offset if requested.
2221   if (off) {
2222     Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
2223     Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
2224     indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
2225   }
2226   // Construct the vector comparison.
2227   Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
2228   Value bounds =
2229       rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
2230   return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
2231                                         bounds);
2232 }
2233 
2234 template <typename ConcreteOp>
2235 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
2236 public:
2237   explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
2238       : mlir::OpRewritePattern<ConcreteOp>(context),
2239         indexOptimizations(enableIndexOpt) {}
2240 
2241   LogicalResult matchAndRewrite(ConcreteOp xferOp,
2242                                 PatternRewriter &rewriter) const override {
2243     if (!xferOp.hasOutOfBoundsDim())
2244       return failure();
2245 
2246     if (xferOp.getVectorType().getRank() > 1 ||
2247         llvm::size(xferOp.indices()) == 0)
2248       return failure();
2249 
2250     Location loc = xferOp->getLoc();
2251     VectorType vtp = xferOp.getVectorType();
2252 
2253     // Create the in-bounds mask with all elements between [0 .. dim - offset)
2254     // set and [dim - offset .. vector_length) unset.
2255     //
2256     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
2257     //       dimensions here.
2258     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
2259     Value off = xferOp.indices()[lastIndex];
2260     Value dim =
2261         vector::createOrFoldDimOp(rewriter, loc, xferOp.source(), lastIndex);
2262     Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
2263     Value mask = rewriter.create<vector::CreateMaskOp>(
2264         loc,
2265         VectorType::get(vtp.getShape(), rewriter.getI1Type(),
2266                         vtp.getNumScalableDims()),
2267         b);
2268     if (xferOp.mask()) {
2269       // Intersect the in-bounds with the mask specified as an op parameter.
2270       mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.mask());
2271     }
2272 
2273     rewriter.updateRootInPlace(xferOp, [&]() {
2274       xferOp.maskMutable().assign(mask);
2275       xferOp.in_boundsAttr(rewriter.getBoolArrayAttr({true}));
2276     });
2277 
2278     return success();
2279   }
2280 
2281 private:
2282   const bool indexOptimizations;
2283 };
2284 
2285 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
2286 class VectorCreateMaskOpConversion
2287     : public OpRewritePattern<vector::CreateMaskOp> {
2288 public:
2289   explicit VectorCreateMaskOpConversion(MLIRContext *context,
2290                                         bool enableIndexOpt)
2291       : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
2292         indexOptimizations(enableIndexOpt) {}
2293 
2294   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
2295                                 PatternRewriter &rewriter) const override {
2296     auto dstType = op.getType();
2297     int64_t rank = dstType.getRank();
2298     if (rank > 1)
2299       return failure();
2300     rewriter.replaceOp(
2301         op, buildVectorComparison(rewriter, op, indexOptimizations,
2302                                   rank == 0 ? 0 : dstType.getDimSize(0),
2303                                   op.getOperand(0)));
2304     return success();
2305   }
2306 
2307 private:
2308   const bool indexOptimizations;
2309 };
2310 
2311 // Drop inner most contiguous unit dimensions from transfer_read operand.
2312 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
2313   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
2314 
2315   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
2316                                 PatternRewriter &rewriter) const override {
2317     // TODO: support 0-d corner case.
2318     if (readOp.getTransferRank() == 0)
2319       return failure();
2320 
2321     // TODO: support mask.
2322     if (readOp.mask())
2323       return failure();
2324 
2325     auto srcType = readOp.source().getType().dyn_cast<MemRefType>();
2326     if (!srcType || !srcType.hasStaticShape())
2327       return failure();
2328 
2329     if (!readOp.permutation_map().isMinorIdentity())
2330       return failure();
2331 
2332     auto targetType = readOp.getVectorType();
2333     if (targetType.getRank() <= 1)
2334       return failure();
2335 
2336     SmallVector<int64_t> srcStrides;
2337     int64_t srcOffset;
2338     if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2339       return failure();
2340 
2341     size_t dimsToDrop = 0;
2342     for (size_t i = 1; i < srcStrides.size(); ++i) {
2343       int dim = srcType.getRank() - i - 1;
2344       if (srcStrides[dim] == 1) {
2345         dimsToDrop++;
2346       } else {
2347         break;
2348       }
2349     }
2350     if (dimsToDrop == 0)
2351       return failure();
2352 
2353     auto resultTargetVecType =
2354         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
2355                         targetType.getElementType());
2356 
2357     MemRefType resultMemrefType;
2358     if (srcType.getLayout().getAffineMap().isIdentity()) {
2359       resultMemrefType = MemRefType::get(
2360           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2361           {}, srcType.getMemorySpaceAsInt());
2362     } else {
2363       AffineMap map = srcType.getLayout().getAffineMap();
2364       int numResultDims = map.getNumDims() - dimsToDrop;
2365       int numSymbols = map.getNumSymbols();
2366       for (size_t i = 0; i < dimsToDrop; ++i) {
2367         int dim = srcType.getRank() - i - 1;
2368         map = map.replace(rewriter.getAffineDimExpr(dim),
2369                           rewriter.getAffineConstantExpr(0), numResultDims,
2370                           numSymbols);
2371       }
2372       resultMemrefType = MemRefType::get(
2373           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2374           map, srcType.getMemorySpaceAsInt());
2375     }
2376 
2377     auto loc = readOp.getLoc();
2378     SmallVector<int64_t> offsets(srcType.getRank(), 0);
2379     SmallVector<int64_t> strides(srcType.getRank(), 1);
2380 
2381     ArrayAttr inBoundsAttr =
2382         readOp.in_bounds()
2383             ? rewriter.getArrayAttr(
2384                   readOp.in_boundsAttr().getValue().drop_back(dimsToDrop))
2385             : ArrayAttr();
2386     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
2387         loc, resultMemrefType, readOp.source(), offsets, srcType.getShape(),
2388         strides);
2389     auto permMap = getTransferMinorIdentityMap(
2390         rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
2391     Value result = rewriter.create<vector::TransferReadOp>(
2392         loc, resultTargetVecType, rankedReducedView,
2393         readOp.indices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
2394         readOp.padding(),
2395         // TODO: support mask.
2396         /*mask=*/Value(), inBoundsAttr);
2397     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
2398                                                      result);
2399     return success();
2400   }
2401 };
2402 
2403 namespace {
2404 
2405 /// This function checks to see if the vector combining kind
2406 /// is consistent with the integer or float element type.
2407 static bool isValidKind(bool isInt, vector::CombiningKind kind) {
2408   using vector::CombiningKind;
2409   enum class KindType { FLOAT, INT, INVALID };
2410   KindType type{KindType::INVALID};
2411   switch (kind) {
2412   case CombiningKind::MINF:
2413   case CombiningKind::MAXF:
2414     type = KindType::FLOAT;
2415     break;
2416   case CombiningKind::MINUI:
2417   case CombiningKind::MINSI:
2418   case CombiningKind::MAXUI:
2419   case CombiningKind::MAXSI:
2420   case CombiningKind::AND:
2421   case CombiningKind::OR:
2422   case CombiningKind::XOR:
2423     type = KindType::INT;
2424     break;
2425   case CombiningKind::ADD:
2426   case CombiningKind::MUL:
2427     type = isInt ? KindType::INT : KindType::FLOAT;
2428     break;
2429   }
2430   bool isValidIntKind = (type == KindType::INT) && isInt;
2431   bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
2432   return (isValidIntKind || isValidFloatKind);
2433 }
2434 
2435 /// This function constructs the appropriate integer or float
2436 /// operation given the vector combining kind and operands. The
2437 /// supported int operations are : add, mul, min (signed/unsigned),
2438 /// max(signed/unsigned), and, or, xor. The supported float
2439 /// operations are : add, mul, min and max.
2440 static Value genOperator(Location loc, Value x, Value y,
2441                          vector::CombiningKind kind,
2442                          PatternRewriter &rewriter) {
2443   using vector::CombiningKind;
2444 
2445   auto elType = x.getType().cast<VectorType>().getElementType();
2446   bool isInt = elType.isIntOrIndex();
2447 
2448   Value combinedResult{nullptr};
2449   switch (kind) {
2450   case CombiningKind::ADD:
2451     if (isInt)
2452       combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
2453     else
2454       combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
2455     break;
2456   case CombiningKind::MUL:
2457     if (isInt)
2458       combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
2459     else
2460       combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
2461     break;
2462   case CombiningKind::MINUI:
2463     combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
2464     break;
2465   case CombiningKind::MINSI:
2466     combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
2467     break;
2468   case CombiningKind::MAXUI:
2469     combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
2470     break;
2471   case CombiningKind::MAXSI:
2472     combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
2473     break;
2474   case CombiningKind::AND:
2475     combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
2476     break;
2477   case CombiningKind::OR:
2478     combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
2479     break;
2480   case CombiningKind::XOR:
2481     combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
2482     break;
2483   case CombiningKind::MINF:
2484     combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
2485     break;
2486   case CombiningKind::MAXF:
2487     combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
2488     break;
2489   }
2490   return combinedResult;
2491 }
2492 
2493 /// Convert vector.scan op into arith ops and
2494 /// vector.insert_strided_slice/extract_strided_slice
2495 ///
2496 /// Ex:
2497 /// ```
2498 ///   %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim =
2499 ///   1} :
2500 ///     (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
2501 /// ```
2502 /// Gets converted to:
2503 /// ```
2504 ///   %cst = arith.constant dense<0> : vector<2x3xi32>
2505 ///   %0 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 1],
2506 ///   strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %1 =
2507 ///   vector.insert_strided_slice %0, %cst {offsets = [0, 0], strides = [1, 1]}
2508 ///   : vector<2x1xi32> into vector<2x3xi32> %2 = vector.extract_strided_slice
2509 ///   %arg0 {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} :
2510 ///   vector<2x3xi32> to vector<2x1xi32> %3 = arith.muli %0, %2 :
2511 ///   vector<2x1xi32> %4 = vector.insert_strided_slice %3, %1 {offsets = [0, 1],
2512 ///   strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %5 =
2513 ///   vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1],
2514 ///   strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %6 = arith.muli %3,
2515 ///   %5 : vector<2x1xi32> %7 = vector.insert_strided_slice %6, %4 {offsets =
2516 ///   [0, 2], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %8 =
2517 ///   vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32> return %7, %8 :
2518 ///   vector<2x3xi32>, vector<2xi32>
2519 /// ```
2520 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
2521   using OpRewritePattern<vector::ScanOp>::OpRewritePattern;
2522 
2523   LogicalResult matchAndRewrite(vector::ScanOp scanOp,
2524                                 PatternRewriter &rewriter) const override {
2525     auto loc = scanOp.getLoc();
2526     VectorType destType = scanOp.getDestType();
2527     ArrayRef<int64_t> destShape = destType.getShape();
2528     auto elType = destType.getElementType();
2529     bool isInt = elType.isIntOrIndex();
2530     if (!isValidKind(isInt, scanOp.kind()))
2531       return failure();
2532 
2533     VectorType resType = VectorType::get(destShape, elType);
2534     Value result = rewriter.create<arith::ConstantOp>(
2535         loc, resType, rewriter.getZeroAttr(resType));
2536     int64_t reductionDim = scanOp.reduction_dim();
2537     bool inclusive = scanOp.inclusive();
2538     int64_t destRank = destType.getRank();
2539     VectorType initialValueType = scanOp.getInitialValueType();
2540     int64_t initialValueRank = initialValueType.getRank();
2541 
2542     SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
2543     reductionShape[reductionDim] = 1;
2544     VectorType reductionType = VectorType::get(reductionShape, elType);
2545     SmallVector<int64_t> offsets(destRank, 0);
2546     SmallVector<int64_t> strides(destRank, 1);
2547     SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
2548     sizes[reductionDim] = 1;
2549     ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
2550     ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
2551 
2552     Value lastOutput, lastInput;
2553     for (int i = 0; i < destShape[reductionDim]; i++) {
2554       offsets[reductionDim] = i;
2555       ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
2556       Value input = rewriter.create<vector::ExtractStridedSliceOp>(
2557           loc, reductionType, scanOp.source(), scanOffsets, scanSizes,
2558           scanStrides);
2559       Value output;
2560       if (i == 0) {
2561         if (inclusive) {
2562           output = input;
2563         } else {
2564           if (initialValueRank == 0) {
2565             // ShapeCastOp cannot handle 0-D vectors
2566             output = rewriter.create<vector::BroadcastOp>(
2567                 loc, input.getType(), scanOp.initial_value());
2568           } else {
2569             output = rewriter.create<vector::ShapeCastOp>(
2570                 loc, input.getType(), scanOp.initial_value());
2571           }
2572         }
2573       } else {
2574         Value y = inclusive ? input : lastInput;
2575         output = genOperator(loc, lastOutput, y, scanOp.kind(), rewriter);
2576         assert(output != nullptr);
2577       }
2578       result = rewriter.create<vector::InsertStridedSliceOp>(
2579           loc, output, result, offsets, strides);
2580       lastOutput = output;
2581       lastInput = input;
2582     }
2583 
2584     Value reduction;
2585     if (initialValueRank == 0) {
2586       Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
2587       reduction =
2588           rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
2589     } else {
2590       reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
2591                                                        lastOutput);
2592     }
2593 
2594     rewriter.replaceOp(scanOp, {result, reduction});
2595     return success();
2596   }
2597 };
2598 
2599 } // namespace
2600 
2601 void mlir::vector::populateVectorMaskMaterializationPatterns(
2602     RewritePatternSet &patterns, bool indexOptimizations) {
2603   patterns.add<VectorCreateMaskOpConversion,
2604                MaterializeTransferMask<vector::TransferReadOp>,
2605                MaterializeTransferMask<vector::TransferWriteOp>>(
2606       patterns.getContext(), indexOptimizations);
2607 }
2608 
2609 void mlir::vector::populateShapeCastFoldingPatterns(
2610     RewritePatternSet &patterns) {
2611   patterns.add<ShapeCastOpFolder>(patterns.getContext());
2612 }
2613 
2614 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
2615     RewritePatternSet &patterns) {
2616   patterns.add<BubbleDownVectorBitCastForExtract,
2617                BubbleDownBitCastForStridedSliceExtract,
2618                BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
2619 }
2620 
2621 void mlir::vector::populateVectorBroadcastLoweringPatterns(
2622     RewritePatternSet &patterns) {
2623   patterns.add<BroadcastOpLowering>(patterns.getContext());
2624 }
2625 
2626 void mlir::vector::populateVectorMaskOpLoweringPatterns(
2627     RewritePatternSet &patterns) {
2628   patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
2629       patterns.getContext());
2630 }
2631 
2632 void mlir::vector::populateVectorShapeCastLoweringPatterns(
2633     RewritePatternSet &patterns) {
2634   patterns.add<ShapeCastOp2DDownCastRewritePattern,
2635                ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
2636       patterns.getContext());
2637 }
2638 
2639 void mlir::vector::populateVectorContractLoweringPatterns(
2640     RewritePatternSet &patterns, VectorTransformsOptions options) {
2641   patterns.add<OuterProductOpLowering>(patterns.getContext());
2642   patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
2643                ContractionOpToOuterProductOpLowering>(options,
2644                                                       patterns.getContext());
2645 }
2646 
2647 void mlir::vector::populateVectorTransposeLoweringPatterns(
2648     RewritePatternSet &patterns, VectorTransformsOptions options) {
2649   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
2650       options, patterns.getContext());
2651 }
2652 
2653 void mlir::vector::populateVectorReductionToContractPatterns(
2654     RewritePatternSet &patterns) {
2655   patterns.add<MultiReduceToContract, CombineContractBroadcast,
2656                CombineContractTranspose, ReorderCastOpsOnBroadcast,
2657                ReorderCastOpsOnTranspose>(patterns.getContext());
2658 }
2659 
2660 void mlir::vector::
2661     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
2662         RewritePatternSet &patterns) {
2663   patterns.add<DropInnerMostUnitDims>(patterns.getContext());
2664 }
2665 
2666 void mlir::vector::populateVectorTransferLoweringPatterns(
2667     RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
2668   patterns.add<TransferReadToVectorLoadLowering,
2669                TransferWriteToVectorStoreLowering>(patterns.getContext(),
2670                                                    maxTransferRank);
2671   patterns
2672       .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
2673           patterns.getContext());
2674 }
2675 
2676 void mlir::vector::populateVectorScanLoweringPatterns(
2677     RewritePatternSet &patterns) {
2678   patterns.add<ScanToArithOps>(patterns.getContext());
2679 }
2680