1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <type_traits>
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Linalg/IR/Linalg.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/SCF/SCF.h"
20 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
21 
22 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
23 #include "mlir/IR/ImplicitLocOpBuilder.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Interfaces/VectorInterfaces.h"
27 
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/MapVector.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/Support/CommandLine.h"
32 #include "llvm/Support/Debug.h"
33 #include "llvm/Support/raw_ostream.h"
34 
35 #define DEBUG_TYPE "vector-to-vector"
36 
37 using namespace mlir;
38 using namespace mlir::vector;
39 
40 // Helper to find an index in an affine map.
41 static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
42   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
43     int64_t idx = map.getDimPosition(i);
44     if (idx == index)
45       return i;
46   }
47   return None;
48 }
49 
50 // Helper to construct iterator types with one index removed.
51 static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
52                                             int64_t index) {
53   SmallVector<Attribute, 4> results;
54   for (const auto &it : llvm::enumerate(iteratorTypes)) {
55     int64_t idx = it.index();
56     if (idx == index)
57       continue;
58     results.push_back(it.value());
59   }
60   return results;
61 }
62 
63 // Helper to construct an affine map with one index removed.
64 static AffineMap adjustMap(AffineMap map, int64_t index,
65                            PatternRewriter &rewriter) {
66   auto *ctx = rewriter.getContext();
67   SmallVector<AffineExpr, 4> results;
68   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
69     int64_t idx = map.getDimPosition(i);
70     if (idx == index)
71       continue;
72     // Re-insert remaining indices, but renamed when occurring
73     // after the removed index.
74     auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
75     results.push_back(targetExpr);
76   }
77   return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
78 }
79 
// Helper method to possibly drop a dimension in a load.
// Recursively extracts position `pos` along dimension `index` of `val`,
// producing a value of one rank less along that dimension. An `index` of -1
// means "no reshaping needed" and returns `val` unchanged.
// TODO(review): share the unrolling scaffolding with reshapeStore below.
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  Type lowType = VectorType::Builder(type).dropDim(0);
  // At extraction dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
  }
  // Unroll leading dimensions: extract each leading slice, recurse on it with
  // the index shifted down by one, and re-insert the result.
  VectorType vType = lowType.cast<VectorType>();
  Type resType = VectorType::Builder(type).dropDim(index);
  auto resVectorType = resType.cast<VectorType>();
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resVectorType, rewriter.getZeroAttr(resVectorType));
  for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
                                               posAttr);
  }
  return result;
}
108 
// Helper method to possibly drop a dimension in a store.
// Recursively inserts `val` into `result` at position `pos` along dimension
// `index` (the inverse of reshapeLoad). An `index` of -1 means "no reshaping
// needed" and returns `val` unchanged.
// TODO(review): share the unrolling scaffolding with reshapeLoad above.
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
  }
  // Unroll leading dimensions: recurse on each leading slice of both the
  // value to insert and the destination, then re-insert the combined slice.
  Type lowType = VectorType::Builder(type).dropDim(0);
  VectorType vType = lowType.cast<VectorType>();
  Type insType = VectorType::Builder(vType).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
    Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
  }
  return result;
}
135 
136 template <typename IntType>
137 static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
138   return llvm::to_vector<4>(llvm::map_range(
139       arrayAttr.getAsRange<IntegerAttr>(),
140       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
141 }
142 
143 namespace {
144 
145 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
146 //
147 // Example:
148 //
149 //  The following MLIR with cancelling ShapeCastOps:
150 //
151 //   %0 = source : vector<5x4x2xf32>
152 //   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
153 //   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
154 //   %3 = user %2 : vector<5x4x2xf32>
155 //
156 //  Should canonicalize to the following:
157 //
158 //   %0 = source : vector<5x4x2xf32>
159 //   %1 = user %0 : vector<5x4x2xf32>
160 //
161 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
162   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
163 
164   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
165                                 PatternRewriter &rewriter) const override {
166     // Check if 'shapeCastOp' has vector source/result type.
167     auto sourceVectorType =
168         shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
169     auto resultVectorType =
170         shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
171     if (!sourceVectorType || !resultVectorType)
172       return failure();
173 
174     // Check if shape cast op source operand is also a shape cast op.
175     auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
176         shapeCastOp.source().getDefiningOp());
177     if (!sourceShapeCastOp)
178       return failure();
179     auto operandSourceVectorType =
180         sourceShapeCastOp.source().getType().cast<VectorType>();
181     auto operandResultVectorType = sourceShapeCastOp.getType();
182 
183     // Check if shape cast operations invert each other.
184     if (operandSourceVectorType != resultVectorType ||
185         operandResultVectorType != sourceVectorType)
186       return failure();
187 
188     rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
189     return success();
190   }
191 };
192 
/// Progressive lowering of BroadcastOp: peels one rank per rewrite until the
/// broadcast can be expressed as a splat or a pass-through.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    VectorType dstType = op.getVectorType();
    VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
    Type eltType = dstType.getElementType();

    // Scalar to any vector can use splat.
    if (!srcType) {
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.source());
      return success();
    }

    // Determine rank of source and destination.
    int64_t srcRank = srcType.getRank();
    int64_t dstRank = dstType.getRank();

    // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
    // Extract the single element first, then splat it into the result.
    if (srcRank <= 1 && dstRank == 1) {
      Value ext;
      if (srcRank == 0)
        ext = rewriter.create<vector::ExtractElementOp>(loc, op.source());
      else
        ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
      return success();
    }

    // Duplicate this rank.
    // For example:
    //   %x = broadcast %y  : k-D to n-D, k < n
    // becomes:
    //   %b = broadcast %y  : k-D to (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    // becomes:
    //   %b = [%y,%y]       : (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    if (srcRank < dstRank) {
      // Duplication: broadcast to one rank less, then insert that value into
      // every position of the leading dimension of a zero vector.
      VectorType resType =
          VectorType::get(dstType.getShape().drop_front(), eltType);
      Value bcst =
          rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
      Value result = rewriter.create<arith::ConstantOp>(
          loc, dstType, rewriter.getZeroAttr(dstType));
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      rewriter.replaceOp(op, result);
      return success();
    }

    // Find non-matching dimension, if any. Equal ranks here, so a mismatch
    // must be a source dimension of size 1 being stretched.
    assert(srcRank == dstRank);
    int64_t m = -1;
    for (int64_t r = 0; r < dstRank; r++)
      if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
        m = r;
        break;
      }

    // All trailing dimensions are the same. Simply pass through.
    if (m == -1) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    // Any non-matching dimension forces a stretch along this rank.
    // For example:
    //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
    // becomes:
    //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
    //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
    //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
    //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
    //   %x = [%a,%b,%c,%d]
    // becomes:
    //   %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
    //   %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
    //   %a = [%u, %v]
    //   ..
    //   %x = [%a,%b,%c,%d]
    VectorType resType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    if (m == 0) {
      // Stretch at start: a single sub-broadcast can be reused for all
      // positions of the stretched leading dimension.
      Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
      Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
    } else {
      // Stretch not at start: each leading position needs its own
      // sub-broadcast of the corresponding source slice.
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
        Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
        Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
301 
/// Progressive lowering of TransposeOp.
/// One:
///   %x = vector.transpose %y, [1, 0]
/// is replaced by:
///   %z = arith.constant dense<0.000000e+00>
///   %0 = vector.extract %y[0, 0]
///   %1 = vector.insert %0, %z [0, 0]
///   ..
///   %x = vector.insert .., .. [.., ..]
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType resType = op.getResultType();

    // Set up convenience transposition table.
    SmallVector<int64_t, 4> transp;
    for (auto attr : op.transp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());

    // Leave 2-D [1, 0] transposes to the shuffle-based pattern when the
    // options request the Shuffle lowering.
    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Shuffle &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
      return rewriter.notifyMatchFailure(
          op, "Options specifies lowering to shuffle");

    // Handle a true 2-D matrix transpose differently when requested:
    // flatten to 1-D, use the flat transpose intrinsic, and unflatten.
    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Flat &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
      Type flattenedType =
          VectorType::get(resType.getNumElements(), resType.getElementType());
      auto matrix =
          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, op.vector());
      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
      Value trans = rewriter.create<vector::FlatTransposeOp>(
          loc, flattenedType, matrix, rows, columns);
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
      return success();
    }

    // Generate fully unrolled extract/insert ops.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    SmallVector<int64_t, 4> lhs(transp.size(), 0);
    SmallVector<int64_t, 4> rhs(transp.size(), 0);
    rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs,
                                         op.vector(), result, rewriter));
    return success();
  }

private:
  // Builds the indices arrays for the lhs and rhs. Generates the extract/insert
  // operation when all ranks are exhausted. `lhs` indexes into the result and
  // `rhs` into the input; `transp` maps result dimensions to input dimensions.
  Value expandIndices(Location loc, VectorType resType, int64_t pos,
                      SmallVector<int64_t, 4> &transp,
                      SmallVector<int64_t, 4> &lhs,
                      SmallVector<int64_t, 4> &rhs, Value input, Value result,
                      PatternRewriter &rewriter) const {
    if (pos >= resType.getRank()) {
      auto ridx = rewriter.getI64ArrayAttr(rhs);
      auto lidx = rewriter.getI64ArrayAttr(lhs);
      Type eltType = resType.getElementType();
      Value e = rewriter.create<vector::ExtractOp>(loc, eltType, input, ridx);
      return rewriter.create<vector::InsertOp>(loc, resType, e, result, lidx);
    }
    // Recurse over every index of the current result dimension.
    for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) {
      lhs[pos] = d;
      rhs[transp[pos]] = d;
      result = expandIndices(loc, resType, pos + 1, transp, lhs, rhs, input,
                             result, rewriter);
    }
    return result;
  }

  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};
390 
391 /// Rewrite a 2-D vector.transpose as a sequence of:
392 ///   vector.shape_cast 2D -> 1D
393 ///   vector.shuffle
394 ///   vector.shape_cast 1D -> 2D
395 class TransposeOp2DToShuffleLowering
396     : public OpRewritePattern<vector::TransposeOp> {
397 public:
398   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
399 
400   TransposeOp2DToShuffleLowering(
401       vector::VectorTransformsOptions vectorTransformOptions,
402       MLIRContext *context)
403       : OpRewritePattern<vector::TransposeOp>(context),
404         vectorTransformOptions(vectorTransformOptions) {}
405 
406   LogicalResult matchAndRewrite(vector::TransposeOp op,
407                                 PatternRewriter &rewriter) const override {
408     auto loc = op.getLoc();
409 
410     VectorType srcType = op.getVectorType();
411     if (srcType.getRank() != 2)
412       return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
413 
414     SmallVector<int64_t, 4> transp;
415     for (auto attr : op.transp())
416       transp.push_back(attr.cast<IntegerAttr>().getInt());
417     if (transp[0] != 1 && transp[1] != 0)
418       return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
419 
420     if (vectorTransformOptions.vectorTransposeLowering !=
421         VectorTransposeLowering::Shuffle)
422       return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
423 
424     int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
425     Value casted = rewriter.create<vector::ShapeCastOp>(
426         loc, VectorType::get({m * n}, srcType.getElementType()), op.vector());
427     SmallVector<int64_t> mask;
428     mask.reserve(m * n);
429     for (int64_t j = 0; j < n; ++j)
430       for (int64_t i = 0; i < m; ++i)
431         mask.push_back(i * n + j);
432 
433     Value shuffled =
434         rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
435     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
436                                                      shuffled);
437 
438     return success();
439   }
440 
441 private:
442   /// Options to control the vector patterns.
443   vector::VectorTransformsOptions vectorTransformOptions;
444 };
445 
/// Progressive lowering of OuterProductOp.
/// One:
///   %x = vector.outerproduct %lhs, %rhs, %acc
/// is replaced by:
///   %z = zero-result
///   %0 = vector.extract %lhs[0]
///   %1 = vector.broadcast %0
///   %2 = vector.extract %acc[0]
///   %3 = vector.fma %1, %rhs, %2
///   %4 = vector.insert %3, %z[0]
///   ..
///   %x = vector.insert %.., %..[N-1]
///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    // rhs may be a scalar (AXPY case), in which case this cast yields null.
    VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
    VectorType resType = op.getVectorType();
    Type eltType = resType.getElementType();
    bool isInt = eltType.isa<IntegerType, IndexType>();
    // The accumulator operand is optional (variadic of size 0 or 1).
    Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
    vector::CombiningKind kind = op.kind();

    if (!rhsType) {
      // Special case: AXPY operation, i.e. x * scalar (+ acc). Broadcast the
      // scalar to the lhs vector shape and combine elementwise.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
      Optional<Value> mult =
          isInt ? genMultI(loc, op.lhs(), b, acc, kind, rewriter)
                : genMultF(loc, op.lhs(), b, acc, kind, rewriter);
      if (!mult.hasValue())
        return failure();
      rewriter.replaceOp(op, mult.getValue());
      return success();
    }

    // General case: one broadcast-multiply-combine per element of the lhs,
    // inserted row by row into a zero-initialized result.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      auto pos = rewriter.getI64ArrayAttr(d);
      Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
      Optional<Value> m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter)
                                : genMultF(loc, a, op.rhs(), r, kind, rewriter);
      if (!m.hasValue())
        return failure();
      result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
                                                 result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Emits x * y combined with `acc` (if non-null) using the integer
  /// operation matching `kind`. Returns None for kinds only valid on floats.
  static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
                                  vector::CombiningKind kind,
                                  PatternRewriter &rewriter) {
    using vector::CombiningKind;

    auto mul = rewriter.create<arith::MulIOp>(loc, x, y);
    if (!acc)
      return Optional<Value>(mul);

    Value combinedResult;
    switch (kind) {
    case CombiningKind::ADD:
      combinedResult = rewriter.create<arith::AddIOp>(loc, mul, acc);
      break;
    case CombiningKind::MUL:
      combinedResult = rewriter.create<arith::MulIOp>(loc, mul, acc);
      break;
    case CombiningKind::MINUI:
      combinedResult = rewriter.create<arith::MinUIOp>(loc, mul, acc);
      break;
    case CombiningKind::MINSI:
      combinedResult = rewriter.create<arith::MinSIOp>(loc, mul, acc);
      break;
    case CombiningKind::MAXUI:
      combinedResult = rewriter.create<arith::MaxUIOp>(loc, mul, acc);
      break;
    case CombiningKind::MAXSI:
      combinedResult = rewriter.create<arith::MaxSIOp>(loc, mul, acc);
      break;
    case CombiningKind::AND:
      combinedResult = rewriter.create<arith::AndIOp>(loc, mul, acc);
      break;
    case CombiningKind::OR:
      combinedResult = rewriter.create<arith::OrIOp>(loc, mul, acc);
      break;
    case CombiningKind::XOR:
      combinedResult = rewriter.create<arith::XOrIOp>(loc, mul, acc);
      break;
    case CombiningKind::MINF: // Only valid for floating point types.
    case CombiningKind::MAXF: // Only valid for floating point types.
      return Optional<Value>();
    }
    return Optional<Value>(combinedResult);
  }

  /// Emits x * y combined with `acc` (if non-null) using the floating-point
  /// operation matching `kind`. ADD with an accumulator becomes a fused
  /// multiply-add. Returns None for kinds only valid on integers.
  static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
                                  vector::CombiningKind kind,
                                  PatternRewriter &rewriter) {
    using vector::CombiningKind;

    // Special case for fused multiply-add.
    if (acc && kind == CombiningKind::ADD) {
      return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
    }

    auto mul = rewriter.create<arith::MulFOp>(loc, x, y);

    if (!acc)
      return Optional<Value>(mul);

    Value combinedResult;
    switch (kind) {
    case CombiningKind::MUL:
      combinedResult = rewriter.create<arith::MulFOp>(loc, mul, acc);
      break;
    case CombiningKind::MINF:
      combinedResult = rewriter.create<arith::MinFOp>(loc, mul, acc);
      break;
    case CombiningKind::MAXF:
      combinedResult = rewriter.create<arith::MaxFOp>(loc, mul, acc);
      break;
    case CombiningKind::ADD:   // Already handled this special case above.
    case CombiningKind::AND:   // Only valid for integer types.
    case CombiningKind::MINUI: // Only valid for integer types.
    case CombiningKind::MINSI: // Only valid for integer types.
    case CombiningKind::MAXUI: // Only valid for integer types.
    case CombiningKind::MAXSI: // Only valid for integer types.
    case CombiningKind::OR:    // Only valid for integer types.
    case CombiningKind::XOR:   // Only valid for integer types.
      return Optional<Value>();
    }
    return Optional<Value>(combinedResult);
  }
};
592 
593 /// Progressive lowering of ConstantMaskOp.
594 /// One:
595 ///   %x = vector.constant_mask [a,b]
596 /// is replaced by:
597 ///   %z = zero-result
598 ///   %l = vector.constant_mask [b]
599 ///   %4 = vector.insert %l, %z[0]
600 ///   ..
601 ///   %x = vector.insert %l, %..[a-1]
602 /// until a one-dimensional vector is reached. All these operations
603 /// will be folded at LLVM IR level.
604 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
605 public:
606   using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
607 
608   LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
609                                 PatternRewriter &rewriter) const override {
610     auto loc = op.getLoc();
611     auto dstType = op.getType();
612     auto eltType = dstType.getElementType();
613     auto dimSizes = op.mask_dim_sizes();
614     int64_t rank = dstType.getRank();
615 
616     if (rank == 0) {
617       assert(dimSizes.size() == 1 &&
618              "Expected exactly one dim size for a 0-D vector");
619       bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
620       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
621           op, dstType,
622           DenseIntElementsAttr::get(
623               VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
624               ArrayRef<bool>{value}));
625       return success();
626     }
627 
628     int64_t trueDim = std::min(dstType.getDimSize(0),
629                                dimSizes[0].cast<IntegerAttr>().getInt());
630 
631     if (rank == 1) {
632       // Express constant 1-D case in explicit vector form:
633       //   [T,..,T,F,..,F].
634       SmallVector<bool, 4> values(dstType.getDimSize(0));
635       for (int64_t d = 0; d < trueDim; d++)
636         values[d] = true;
637       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
638           op, dstType, rewriter.getBoolVectorAttr(values));
639       return success();
640     }
641 
642     VectorType lowType =
643         VectorType::get(dstType.getShape().drop_front(), eltType);
644     SmallVector<int64_t, 4> newDimSizes;
645     for (int64_t r = 1; r < rank; r++)
646       newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
647     Value trueVal = rewriter.create<vector::ConstantMaskOp>(
648         loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
649     Value result = rewriter.create<arith::ConstantOp>(
650         loc, dstType, rewriter.getZeroAttr(dstType));
651     for (int64_t d = 0; d < trueDim; d++) {
652       auto pos = rewriter.getI64ArrayAttr(d);
653       result =
654           rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
655     }
656     rewriter.replaceOp(op, result);
657     return success();
658   }
659 };
660 
661 /// Progressive lowering of CreateMaskOp.
662 /// One:
663 ///   %x = vector.create_mask %a, ... : vector<dx...>
664 /// is replaced by:
665 ///   %l = vector.create_mask ... : vector<...>  ; one lower rank
666 ///   %0 = arith.cmpi "slt", %ci, %a       |
667 ///   %1 = select %0, %l, %zeroes    |
668 ///   %r = vector.insert %1, %pr [i] | d-times
669 ///   %x = ....
670 /// until a one-dimensional vector is reached.
671 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
672 public:
673   using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
674 
675   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
676                                 PatternRewriter &rewriter) const override {
677     auto dstType = op.getResult().getType().cast<VectorType>();
678     int64_t rank = dstType.getRank();
679     if (rank <= 1)
680       return rewriter.notifyMatchFailure(
681           op, "0-D and 1-D vectors are handled separately");
682 
683     auto loc = op.getLoc();
684     auto eltType = dstType.getElementType();
685     int64_t dim = dstType.getDimSize(0);
686     Value idx = op.getOperand(0);
687 
688     VectorType lowType =
689         VectorType::get(dstType.getShape().drop_front(), eltType);
690     Value trueVal = rewriter.create<vector::CreateMaskOp>(
691         loc, lowType, op.getOperands().drop_front());
692     Value falseVal = rewriter.create<arith::ConstantOp>(
693         loc, lowType, rewriter.getZeroAttr(lowType));
694     Value result = rewriter.create<arith::ConstantOp>(
695         loc, dstType, rewriter.getZeroAttr(dstType));
696     for (int64_t d = 0; d < dim; d++) {
697       Value bnd =
698           rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
699       Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
700                                                  bnd, idx);
701       Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
702       auto pos = rewriter.getI64ArrayAttr(d);
703       result =
704           rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
705     }
706     rewriter.replaceOp(op, result);
707     return success();
708   }
709 };
710 
711 /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
712 /// vectors progressively on the way to target llvm.matrix intrinsics.
713 /// This iterates over the most major dimension of the 2-D vector and performs
714 /// rewrites into:
715 ///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
716 class ShapeCastOp2DDownCastRewritePattern
717     : public OpRewritePattern<vector::ShapeCastOp> {
718 public:
719   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
720 
721   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
722                                 PatternRewriter &rewriter) const override {
723     auto sourceVectorType = op.getSourceVectorType();
724     auto resultVectorType = op.getResultVectorType();
725     if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
726       return failure();
727 
728     auto loc = op.getLoc();
729     Value desc = rewriter.create<arith::ConstantOp>(
730         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
731     unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
732     for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
733       Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
734       desc = rewriter.create<vector::InsertStridedSliceOp>(
735           loc, vec, desc,
736           /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
737     }
738     rewriter.replaceOp(op, desc);
739     return success();
740   }
741 };
742 
743 /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
744 /// vectors progressively.
745 /// This iterates over the most major dimension of the 2-D vector and performs
746 /// rewrites into:
747 ///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
748 /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
749 class ShapeCastOp2DUpCastRewritePattern
750     : public OpRewritePattern<vector::ShapeCastOp> {
751 public:
752   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
753 
754   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
755                                 PatternRewriter &rewriter) const override {
756     auto sourceVectorType = op.getSourceVectorType();
757     auto resultVectorType = op.getResultVectorType();
758     if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
759       return failure();
760 
761     auto loc = op.getLoc();
762     Value desc = rewriter.create<arith::ConstantOp>(
763         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
764     unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
765     for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
766       Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
767           loc, op.source(), /*offsets=*/i * mostMinorVectorSize,
768           /*sizes=*/mostMinorVectorSize,
769           /*strides=*/1);
770       desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
771     }
772     rewriter.replaceOp(op, desc);
773     return success();
774   }
775 };
776 
// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
// we fall back to a reference implementation that moves all elements
// into the right place if we get here.
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    // Special case 2D/1D lowerings with better implementations.
    // TODO: make it ND/1D to allow generic ND -> 1D -> MD.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    // Defer to the dedicated 2D<->1D patterns above for those shapes.
    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
      return failure();

    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
    // extract/insert chains.
    // TODO: consider evolving the semantics to only allow 1D source or dest and
    // drop this potentially very expensive lowering.
    // Compute number of elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
    // Replace with data movement operations:
    //    x[0,0,0] = y[0,0]
    //    x[0,0,1] = y[0,1]
    //    x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
    SmallVector<int64_t, 4> srcIdx(srcRank);
    SmallVector<int64_t, 4> resIdx(resRank);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; i++) {
      // Both counters start at all-zeros, so only advance after the first
      // element has been moved.
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, srcRank - 1);
        incIdx(resIdx, resultVectorType, resRank - 1);
      }
      // Move one scalar from its source position to the result position with
      // the same row-major linear index.
      Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  // Advances the multi-dimensional index `idx` by one step in row-major order
  // over the shape of `tp`. `r` is the dimension currently being incremented
  // (the innermost one on the initial call); overflow carries recursively into
  // the next more-major dimension. The assert fires if the counter is
  // advanced past the last element (carry past dimension 0).
  static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
    assert(0 <= r && r < tp.getRank());
    if (++idx[r] == tp.getDimSize(r)) {
      idx[r] = 0;
      incIdx(idx, tp, r - 1);
    }
  }
};
838 
/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
///   %1 = vector.multi_reduction add, %0 [1]
///     : vector<8x32x16xf32> to vector<8x16xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d2)>],
///    iterator_types = ["parallel", "reduction", "parallel"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
///  ```
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    // Only an additive reduction of a multiplication matches vector.contract
    // semantics (multiply-accumulate).
    if (reduceOp.kind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.source().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
    // Both multiplicands read every dimension: identity indexing maps.
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
    SmallVector<AffineExpr> exprs;
    SmallVector<StringRef> iteratorTypes;
    // The destination map keeps only the non-reduced (parallel) dimensions.
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(getParallelIteratorTypeName());
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(getReductionIteratorTypeName());
      }
    }
    auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
                                 /*symCount=*/0, exprs, reduceOp.getContext());
    // The contraction accumulates into a zero-initialized accumulator so the
    // result equals the plain multiply-reduce.
    Value zero = rewriter.create<arith::ConstantOp>(
        reduceOp.getLoc(), reduceOp.getDestType(),
        rewriter.getZeroAttr(reduceOp.getDestType()));
    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getStrArrayAttr(iteratorTypes));
    return success();
  }
};
891 
892 /// Merge TransposeOp into ContractionOp user.
893 /// Ex:
894 /// ```
895 ///   %0 = vector.transpose %arg0, [2, 0, 1]
896 ///     : vector<32x16x8xf32> to vector<8x32x16xf32>
897 ///   %1 = vector.contract {indexing_maps = [
898 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
899 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
900 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
901 ///    iterator_types = ["parallel", "parallel", "reduction"],
902 ///    kind = add} %0, %arg1, %cst_f0
903 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
904 /// ```
905 /// Gets converted to:
906 /// ```
907 ///   %1 = vector.contract {indexing_maps = [
908 ///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
909 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
910 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
911 ///    iterator_types = ["parallel", "parallel", "reduction"],
912 ///    kind = add} %arg0, %arg1, %cst_f0
913 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
914 ///  ```
915 struct CombineContractTranspose
916     : public OpRewritePattern<vector::ContractionOp> {
917   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
918 
919   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
920                                 PatternRewriter &rewriter) const override {
921     SmallVector<AffineMap, 4> maps =
922         llvm::to_vector<4>(contractOp.getIndexingMaps());
923     Value lhs = contractOp.lhs();
924     Value rhs = contractOp.rhs();
925     size_t index = 0;
926     bool changed = false;
927     for (Value *operand : {&lhs, &rhs}) {
928       AffineMap &map = maps[index++];
929       auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
930       if (!transposeOp)
931         continue;
932       SmallVector<int64_t> perm;
933       transposeOp.getTransp(perm);
934       AffineMap permutationMap = AffineMap::getPermutationMap(
935           extractVector<unsigned>(transposeOp.transp()),
936           contractOp.getContext());
937       map = inversePermutation(permutationMap).compose(map);
938       *operand = transposeOp.vector();
939       changed = true;
940     }
941     if (!changed)
942       return failure();
943     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
944         contractOp, lhs, rhs, contractOp.acc(),
945         rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
946     return success();
947   }
948 };
949 
/// Merge BroadcastOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
///  ```
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMaps());
    Value lhs = contractOp.lhs();
    Value rhs = contractOp.rhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // contractionOp can only take vector as operands, so a broadcast from a
      // scalar source cannot be folded away.
      auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
      if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
        continue;
      // The broadcast adds `rankDiff` leading dimensions; the source dims line
      // up with the trailing dims of the broadcast result.
      int64_t rankDiff =
          broadcast.getVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        // A size mismatch on a trailing dim means the broadcast stretches an
        // existing dimension, not just a pure rank expansion.
        if (dim.value() !=
            broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // Contract doesn't support inner dimension broadcast. Once this is
      // relaxed we can remove this case.
      if (innerDimBroadcast)
        continue;
      // Drop the broadcast dims from the operand's indexing map and use the
      // pre-broadcast value directly as the contraction operand.
      AffineMap broadcastMap =
          AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
                         contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.source();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.acc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
    return success();
  }
};
1025 
1026 } // namespace
1027 
1028 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
1029 /// operands `x` and `y`.
1030 static Value createAdd(Location loc, Value x, Value y, bool isInt,
1031                        PatternRewriter &rewriter) {
1032   if (isInt)
1033     return rewriter.create<arith::AddIOp>(loc, x, y);
1034   return rewriter.create<arith::AddFOp>(loc, x, y);
1035 }
1036 
1037 /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
1038 /// operands `x and `y`.
1039 static Value createMul(Location loc, Value x, Value y, bool isInt,
1040                        PatternRewriter &rewriter) {
1041   if (isInt)
1042     return rewriter.create<arith::MulIOp>(loc, x, y);
1043   return rewriter.create<arith::MulFOp>(loc, x, y);
1044 }
1045 
1046 namespace mlir {
1047 
/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %mta = maybe_transpose
///    %mtb = maybe_transpose
///    %flattened_a = vector.shape_cast %mta
///    %flattened_b = vector.shape_cast %mtb
///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %mtd = vector.shape_cast %flattened_d
///    %d = maybe_untranspose %mtd
///    %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
///
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
LogicalResult
ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
                                                 PatternRewriter &rew) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();
  // Only fire when this specific lowering strategy was requested.
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  // Expect exactly (parallel, parallel, reduction), i.e. (m, n, k) iterators.
  auto iteratorTypes = op.iterator_types().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.lhs();
  auto lhsMap = op.indexing_maps()[0].cast<AffineMapAttr>().getValue();
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.rhs();
  auto rhsMap = op.indexing_maps()[1].cast<AffineMapAttr>().getValue();
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
  VectorType lhsType = lhs.getType().cast<VectorType>();
  VectorType rhsType = rhs.getType().cast<VectorType>();
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  // vector.matmul operates on flattened 1-D operands; shape_cast both sides.
  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  // Cast the flat matmul result back to the 2-D (m, n) shape.
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.acc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m); transpose the product to match C's layout.
  auto accMap = op.indexing_maps()[2].cast<AffineMapAttr>().getValue();
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  // Accumulate into C with an add matching the element type.
  Value res =
      elementType.isa<IntegerType>()
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.acc(), mul))
          : static_cast<Value>(rew.create<arith::AddFOp>(loc, op.acc(), mul));

  rew.replaceOp(op, res);
  return success();
}
1147 
1148 namespace {
1149 struct IteratorType {
1150   IteratorType(StringRef strRef) : strRef(strRef) {}
1151   bool isOfType(Attribute attr) const {
1152     auto sAttr = attr.dyn_cast<StringAttr>();
1153     return sAttr && sAttr.getValue() == strRef;
1154   }
1155   StringRef strRef;
1156 };
1157 struct Par : public IteratorType {
1158   Par() : IteratorType(getParallelIteratorTypeName()) {}
1159 };
1160 struct Red : public IteratorType {
1161   Red() : IteratorType(getReductionIteratorTypeName()) {}
1162 };
1163 
/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp> {

  UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp>(builder, op),
        kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()),
        lhsType(op.getLhsType()) {}

  /// Emits a 2-D transpose of `v` (swaps the two dimensions).
  Value t(Value v) {
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    return builder.create<vector::TransposeOp>(loc, v, perm);
  }

  /// Unrolls the reduction dimension: for each reduction step k, extracts
  /// slice k of both operands and chains a vector.outerproduct through the
  /// accumulator `res`, combining with `kind`.
  Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
    assert(reductionSize > 0);
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
      Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
      res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
                                                   res, kind);
    }
    return res;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  /// Each `layout` case transposes/swaps operands so the reduction dimension
  /// ends up outermost on both operands, as `outerProd` requires.
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(builder.getContext(), m, n, k);
    // Classical row-major matmul:  Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      // NOTE(review): lhs is transposed before rhs, presumably to fix the
      // order the transpose ops appear in the IR — confirm before changing.
      Value tlhs = t(lhs);
      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}}))
      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}}))
      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      Value trhs = t(rhs);
      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
    }
    if (layout({{k, m}, {k, n}, {n, m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    if (layout({{k, m}, {n, k}, {n, m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  /// One outer parallel, one inner reduction (matvec flavor)
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(builder.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor)
  //
  // Note: dims are bound in (k, m) order here — the `layout` patterns look
  // identical to matvec's but match against maps where d0 is the reduction.
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(builder.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

private:
  vector::CombiningKind kind;
  Value lhs, rhs, res;
  VectorType lhsType;
};
1278 } // namespace
1279 
1280 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1281 /// semantics to a reduction_size-unrolled sequence:
1282 /// ```
1283 ///    %at = vector.transpose %a, [1, 0]
1284 ///    %bRow0 = vector.extract %b[0]
1285 ///    %atRow0 = vector.extract %at[0]
1286 ///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
1287 ///    ...
1288 ///    %bRowK = vector.extract %b[K]
1289 ///    %atRowK = vector.extract %at[K]
1290 ///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
1291 /// ```
1292 ///
1293 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
1294 /// otherwise supports any layout permutation of the matrix-multiply.
1295 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
1296     vector::ContractionOp op, PatternRewriter &rewriter) const {
1297   // TODO: implement masks
1298   if (llvm::size(op.masks()) != 0)
1299     return failure();
1300 
1301   if (vectorTransformOptions.vectorContractLowering !=
1302       vector::VectorContractLowering::OuterProduct)
1303     return failure();
1304 
1305   if (failed(filter(op)))
1306     return failure();
1307 
1308   UnrolledOuterProductGenerator e(rewriter, op);
1309   FailureOr<Value> matmatRes = e.matmat();
1310   if (succeeded(matmatRes)) {
1311     rewriter.replaceOp(op, *matmatRes);
1312     return success();
1313   }
1314   FailureOr<Value> matvecRes = e.matvec();
1315   if (succeeded(matvecRes)) {
1316     rewriter.replaceOp(op, *matvecRes);
1317     return success();
1318   }
1319   FailureOr<Value> tmatvecRes = e.tmatvec();
1320   if (succeeded(tmatvecRes)) {
1321     rewriter.replaceOp(op, *tmatvecRes);
1322     return success();
1323   }
1324 
1325   return failure();
1326 }
1327 
/// Lower vector.contract to an unrolled sequence of elementwise multiplies
/// plus a vector.reduction<add> per result element. Operands are first
/// transposed/swapped so the reduction dimension is innermost on both sides.
/// This only kicks in when VectorTransformsOptions is set to `Dot`.
LogicalResult
ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
                                            PatternRewriter &rewriter) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();

  if (failed(filter(op)))
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.iterator_types().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.lhs(), rhs = op.rhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    // Target form: lhs = A(m, k), rhs = B(n, k), result C(m, n).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // Transposed output: new lhs is rhs^T, new rhs is the original lhs.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      // Transposed output with both operands already reduction-innermost:
      // swapping the operands is enough.
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      // Transposed output: swap operands and transpose both.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      // Transposed output: new rhs is lhs^T, new lhs is the original rhs.
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor)
    // Target form: lhs = A(m, n), rhs = b(n), result c(m).
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = op.getResultType().cast<VectorType>();
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = dstType.getElementType().isa<IntegerType>();
  for (unsigned r = 0; r < dstRows; ++r) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      // Each result element is a dot-product: elementwise mul then reduce.
      Value b = rank == 1
                    ? rhs
                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
      Value reduced = rewriter.create<vector::ReductionOp>(
          op.getLoc(), vector::CombiningKind::ADD, m);

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
    }
  }
  // Fold in the accumulator, if present.
  if (auto acc = op.acc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  rewriter.replaceOp(op, res);
  return success();
}
1442 
/// Progressive lowering of ContractionOp.
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to DOT or when other contraction patterns fail.
//
// TODO: break down into transpose/reshape/cast ops
//               when they become available to avoid code dup
// TODO: investigate lowering order impact on performance
LogicalResult
ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
                                       PatternRewriter &rewriter) const {
  // TODO: implement masks.
  if (llvm::size(op.masks()) != 0)
    return failure();

  if (failed(filter(op)))
    return failure();

  // TODO: support mixed mode contract lowering.
  if (op.getLhsType().getElementType() !=
          getElementTypeOrSelf(op.getAccType()) ||
      op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
    return failure();

  // TODO: implement benefits, cost models.
  // Try the specialized lowerings first (each bails out unless its strategy
  // was requested via the options); only fall through to the generic
  // dimension-by-dimension decomposition below when all of them fail.
  MLIRContext *ctx = op.getContext();
  ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
  if (succeeded(pat1.matchAndRewrite(op, rewriter)))
    return success();
  ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
  if (succeeded(pat2.matchAndRewrite(op, rewriter)))
    return success();
  ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
  if (succeeded(pat3.matchAndRewrite(op, rewriter)))
    return success();

  // Find first batch dimension in LHS/RHS, and lower when found.
  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
  if (!batchDimMap.empty()) {
    int64_t lhsIndex = batchDimMap[0].first;
    int64_t rhsIndex = batchDimMap[0].second;
    rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
    return success();
  }

  // Collect contracting dimensions so the free (non-contracting) dimensions
  // of each operand can be identified below.
  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
      op.getContractingDimMap();
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }

  // Find first free dimension in LHS, and lower when found.
  VectorType lhsType = op.getLhsType();
  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
    if (lhsContractingDimSet.count(lhsIndex) == 0) {
      rewriter.replaceOp(
          op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
      return success();
    }
  }

  // Find first free dimension in RHS, and lower when found.
  VectorType rhsType = op.getRhsType();
  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
    if (rhsContractingDimSet.count(rhsIndex) == 0) {
      rewriter.replaceOp(
          op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
      return success();
    }
  }

  // Lower the first remaining reduction dimension.
  if (!contractingDimMap.empty()) {
    rewriter.replaceOp(op, lowerReduction(op, rewriter));
    return success();
  }

  return failure();
}
1535 
// Lower one parallel dimension.
// Unrolls one free/batch dimension of a vector.contract into `dimSize`
// lower-dimensional contractions, inserting each slice's result back into a
// single result vector. A negative `lhsIndex` (resp. `rhsIndex`) indicates
// the dimension does not appear in the LHS (resp. RHS) operand.
// TODO: consider reusing existing contract unrolling
Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
                                           int64_t lhsIndex, int64_t rhsIndex,
                                           PatternRewriter &rewriter) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  VectorType resType = op.getResultType().cast<VectorType>();
  // Find the iterator type index and result index.
  SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
  int64_t iterIndex = -1;
  int64_t dimSize = -1;
  if (lhsIndex >= 0) {
    // The dimension appears in the LHS; if it also appears in the RHS (a
    // batch dimension), both must map to the same iterator.
    iterIndex = iMap[0].getDimPosition(lhsIndex);
    assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
           "parallel index should be free in LHS or batch in LHS/RHS");
    dimSize = lhsType.getDimSize(lhsIndex);
  } else {
    // Free dimension of the RHS only.
    assert(rhsIndex >= 0 && "missing parallel index");
    iterIndex = iMap[1].getDimPosition(rhsIndex);
    dimSize = rhsType.getDimSize(rhsIndex);
  }
  assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
  // A parallel dimension must also appear in the result's indexing map.
  Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
  assert(lookup.hasValue() && "parallel index not listed in reduction");
  int64_t resIndex = lookup.getValue();
  // Construct new iterator types and affine map array attribute, with the
  // unrolled iterator removed from all three indexing maps.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  Location loc = op.getLoc();
  // Start from a zero vector; each iteration inserts one slice's result.
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));
  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
    auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
    Value lowContract = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, acc, lowAffine, lowIter);
    result =
        reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
  }
  return result;
}
1585 
// Lower one reduction dimension.
// Unrolls the reduction dimension mapped to iterator index 0 into a chain of
// lower-dimensional vector.contract ops, or — in the rank-1 base case — into
// an elementwise multiply followed by a vector.reduction (a dot-product).
Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
                                            PatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  Type resType = op.getResultType();
  // All parallel dimensions have been lowered away by this point, so the
  // result must be a scalar.
  assert(!resType.isa<VectorType>());
  bool isInt = resType.isa<IntegerType>();
  // Use iterator index 0.
  int64_t iterIndex = 0;
  SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
  // The reduced iterator must appear in both operand indexing maps.
  Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
  Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
  assert(lookupLhs.hasValue() && "missing LHS parallel index");
  assert(lookupRhs.hasValue() && "missing RHS parallel index");
  int64_t lhsIndex = lookupLhs.getValue();
  int64_t rhsIndex = lookupRhs.getValue();
  int64_t dimSize = lhsType.getDimSize(lhsIndex);
  assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
  // Base case: 1-D operands reduce to a dot-product, adding in the
  // accumulator when one is present.
  if (lhsType.getRank() == 1) {
    assert(rhsType.getRank() == 1 && "corrupt contraction");
    Value m = createMul(loc, op.lhs(), op.rhs(), isInt, rewriter);
    auto kind = vector::CombiningKind::ADD;
    Value res = rewriter.create<vector::ReductionOp>(loc, kind, m);
    if (auto acc = op.acc())
      res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
    return res;
  }
  // Construct new iterator types and affine map array attribute, with the
  // reduced iterator removed from all three indexing maps.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  // By feeding the initial accumulator into the first contraction,
  // and the result of each contraction into the next, eventually
  // the sum of all reductions is computed.
  Value result = op.acc();
  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
    result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
                                                    lowAffine, lowIter);
  }
  return result;
}
1637 
1638 } // namespace mlir
1639 
1640 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
1641     OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
1642     ArrayRef<int64_t> multiplicity, const AffineMap &map) {
1643   OpBuilder::InsertionGuard guard(builder);
1644   builder.setInsertionPointAfter(op);
1645   Location loc = op->getLoc();
1646   if (op->getNumResults() != 1)
1647     return {};
1648   Value result = op->getResult(0);
1649   VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
1650   if (!type || map.getNumResults() != multiplicity.size())
1651     return {};
1652   // For each dimension being distributed check that the size is a multiple of
1653   // the multiplicity. To handle more sizes we would need to support masking.
1654   unsigned multiplictyCount = 0;
1655   for (auto exp : map.getResults()) {
1656     auto affinExp = exp.dyn_cast<AffineDimExpr>();
1657     if (!affinExp || affinExp.getPosition() >= type.getRank() ||
1658         type.getDimSize(affinExp.getPosition()) %
1659                 multiplicity[multiplictyCount++] !=
1660             0)
1661       return {};
1662   }
1663   DistributeOps ops;
1664   ops.extract =
1665       builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
1666   ops.insert =
1667       builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
1668   return ops;
1669 }
1670 
1671 /// Progressive lowering of transfer_read. This pattern supports lowering of
1672 /// `vector.transfer_read` to a combination of `vector.load` and
1673 /// `vector.broadcast` if all of the following hold:
1674 /// - Stride of most minor memref dimension must be 1.
1675 /// - Out-of-bounds masking is not required.
1676 /// - If the memref's element type is a vector type then it coincides with the
1677 ///   result type.
1678 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
struct TransferReadToVectorLoadLowering
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadToVectorLoadLowering(MLIRContext *context,
                                   llvm::Optional<unsigned> maxRank)
      : OpRewritePattern<vector::TransferReadOp>(context),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    // Bail out when the vector rank exceeds the configured limit.
    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
      return failure();

    SmallVector<unsigned, 4> broadcastedDims;
    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // We let the 0-d corner case pass-through as it is supported.
    if (!read.permutation_map().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return failure();

    // Only memref sources are handled here.
    auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
    if (!memRefType)
      return failure();

    // Non-unit strides are handled by VectorToSCF.
    if (!vector::isLastMemrefDimUnitStride(memRefType))
      return failure();

    // If there is broadcasting involved then we first load the unbroadcasted
    // vector, and then broadcast it with `vector.broadcast`. The
    // unbroadcasted shape replaces each broadcast dimension with 1.
    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
    SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
                                                     vectorShape.end());
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = VectorType::get(
        unbroadcastedVectorShape, read.getVectorType().getElementType());

    // `vector.load` supports vector types as memref's elements only when the
    // resulting vector type is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
      return failure();

    // Otherwise, element types of the memref and the vector must match.
    if (!memrefElTy.isa<VectorType>() &&
        memrefElTy != read.getVectorType().getElementType())
      return failure();

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (read.hasOutOfBoundsDim())
      return failure();

    // Create vector load op.
    Operation *loadOp;
    if (read.mask()) {
      // A masked transfer becomes a masked load; the padding value is
      // splatted to serve as the pass-through vector for masked-off lanes.
      Value fill = rewriter.create<vector::SplatOp>(
          read.getLoc(), unbroadcastedVectorType, read.padding());
      loadOp = rewriter.create<vector::MaskedLoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(),
          read.mask(), fill);
    } else {
      loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
                                               unbroadcastedVectorType,
                                               read.source(), read.indices());
    }

    // Insert a broadcasting op if required.
    if (!broadcastedDims.empty()) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          read, read.getVectorType(), loadOp->getResult(0));
    } else {
      rewriter.replaceOp(read, loadOp->getResult(0));
    }

    return success();
  }

  // If set, transfers of vector rank greater than this are not lowered.
  llvm::Optional<unsigned> maxTransferRank;
};
1759 
1760 /// Replace a 0-d vector.load with a memref.load + vector.broadcast.
1761 // TODO: we shouldn't cross the vector/scalar domains just for this
1762 // but atm we lack the infra to avoid it. Possible solutions include:
1763 // - go directly to LLVM + bitcast
1764 // - introduce a bitcast op and likely a new pointer dialect
1765 // - let memref.load/store additionally support the 0-d vector case
1766 // There are still deeper data layout issues lingering even in this
1767 // trivial case (for architectures for which this matters).
1768 struct VectorLoadToMemrefLoadLowering
1769     : public OpRewritePattern<vector::LoadOp> {
1770   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
1771 
1772   LogicalResult matchAndRewrite(vector::LoadOp loadOp,
1773                                 PatternRewriter &rewriter) const override {
1774     auto vecType = loadOp.getVectorType();
1775     if (vecType.getNumElements() != 1)
1776       return failure();
1777     auto memrefLoad = rewriter.create<memref::LoadOp>(
1778         loadOp.getLoc(), loadOp.base(), loadOp.indices());
1779     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
1780                                                      memrefLoad);
1781     return success();
1782   }
1783 };
1784 
1785 /// Replace a 0-d vector.store with a vector.extractelement + memref.store.
1786 struct VectorStoreToMemrefStoreLowering
1787     : public OpRewritePattern<vector::StoreOp> {
1788   using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
1789 
1790   LogicalResult matchAndRewrite(vector::StoreOp storeOp,
1791                                 PatternRewriter &rewriter) const override {
1792     auto vecType = storeOp.getVectorType();
1793     if (vecType.getNumElements() != 1)
1794       return failure();
1795     Value extracted;
1796     if (vecType.getRank() == 0) {
1797       // TODO: Unifiy once ExtractOp supports 0-d vectors.
1798       extracted = rewriter.create<vector::ExtractElementOp>(
1799           storeOp.getLoc(), storeOp.valueToStore());
1800     } else {
1801       SmallVector<int64_t> indices(vecType.getRank(), 0);
1802       extracted = rewriter.create<vector::ExtractOp>(
1803           storeOp.getLoc(), storeOp.valueToStore(), indices);
1804     }
1805 
1806     rewriter.replaceOpWithNewOp<memref::StoreOp>(
1807         storeOp, extracted, storeOp.base(), storeOp.indices());
1808     return success();
1809   }
1810 };
1811 
1812 /// Progressive lowering of transfer_write. This pattern supports lowering of
1813 /// `vector.transfer_write` to `vector.store` if all of the following hold:
1814 /// - Stride of most minor memref dimension must be 1.
1815 /// - Out-of-bounds masking is not required.
1816 /// - If the memref's element type is a vector type then it coincides with the
1817 ///   type of the written value.
1818 /// - The permutation map is the minor identity map (neither permutation nor
1819 ///   broadcasting is allowed).
1820 struct TransferWriteToVectorStoreLowering
1821     : public OpRewritePattern<vector::TransferWriteOp> {
1822   TransferWriteToVectorStoreLowering(MLIRContext *context,
1823                                      llvm::Optional<unsigned> maxRank)
1824       : OpRewritePattern<vector::TransferWriteOp>(context),
1825         maxTransferRank(maxRank) {}
1826 
1827   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
1828                                 PatternRewriter &rewriter) const override {
1829     if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
1830       return failure();
1831 
1832     // Permutations are handled by VectorToSCF or
1833     // populateVectorTransferPermutationMapLoweringPatterns.
1834     if ( // pass-through for the 0-d corner case.
1835         !write.permutation_map().isMinorIdentity())
1836       return failure();
1837 
1838     auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
1839     if (!memRefType)
1840       return failure();
1841 
1842     // Non-unit strides are handled by VectorToSCF.
1843     if (!vector::isLastMemrefDimUnitStride(memRefType))
1844       return failure();
1845 
1846     // `vector.store` supports vector types as memref's elements only when the
1847     // type of the vector value being written is the same as the element type.
1848     auto memrefElTy = memRefType.getElementType();
1849     if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
1850       return failure();
1851 
1852     // Otherwise, element types of the memref and the vector must match.
1853     if (!memrefElTy.isa<VectorType>() &&
1854         memrefElTy != write.getVectorType().getElementType())
1855       return failure();
1856 
1857     // Out-of-bounds dims are handled by MaterializeTransferMask.
1858     if (write.hasOutOfBoundsDim())
1859       return failure();
1860     if (write.mask()) {
1861       rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
1862           write, write.source(), write.indices(), write.mask(), write.vector());
1863     } else {
1864       rewriter.replaceOpWithNewOp<vector::StoreOp>(
1865           write, write.vector(), write.source(), write.indices());
1866     }
1867     return success();
1868   }
1869 
1870   llvm::Optional<unsigned> maxTransferRank;
1871 };
1872 
1873 // Returns the values in `arrayAttr` as an integer vector.
1874 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
1875   return llvm::to_vector<4>(
1876       llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
1877                       [](IntegerAttr attr) { return attr.getInt(); }));
1878 }
1879 
1880 // Shuffles vector.bitcast op after vector.extract op.
1881 //
1882 // This transforms IR like:
1883 //   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
1884 //   %1 = vector.extract %0[3] : vector<8xf16>
1885 // Into:
1886 //   %0 = vector.extract %src[1] : vector<4xf32>
1887 //   %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
1888 //   %2 = vector.extract %1[1] : vector<2xf16>
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only support extracting scalars for now.
    if (extractOp.getVectorType().getRank() != 1)
      return failure();

    // The extract source must be produced by a bitcast.
    auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Fail to match if we only have one element in the cast op source.
    // This is to avoid infinite loop given that this pattern can generate
    // such cases.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only support casting to a larger number of elements for now.
    // E.g., vector<4xf32> -> vector<8xf16>.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    // Each source element expands into `expandRatio` destination elements.
    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    // Helper reading the first (only, for rank-1 extracts) position entry.
    auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
      return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
    };

    uint64_t index = getFirstIntValue(extractOp.position());

    // Get the single scalar (as a vector) in the source value that packs the
    // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
    VectorType oneScalarType =
        VectorType::get({1}, castSrcType.getElementType());
    Value packedValue = rewriter.create<vector::ExtractOp>(
        extractOp.getLoc(), oneScalarType, castOp.source(),
        rewriter.getI64ArrayAttr(index / expandRatio));

    // Cast it to a vector with the desired scalar's type.
    // E.g. f32 -> vector<2xf16>
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value castedValue = rewriter.create<vector::BitCastOp>(
        extractOp.getLoc(), packedType, packedValue);

    // Finally extract the desired scalar, at its offset within the packed
    // group.
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
        extractOp, extractOp.getType(), castedValue,
        rewriter.getI64ArrayAttr(index % expandRatio));

    return success();
  }
};
1950 
1951 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
1952 //
1953 // This transforms IR like:
1954 //    %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
1955 //     %0 = vector.extract_strided_slice %cast {
1956 //            offsets = [4], sizes = [4], strides = [1]
1957 //          } : vector<8xf16> to vector<4xf16>
1958 // Into:
1959 //   %0 = vector.extract_strided_slice %src {
1960 //          offsets = [2], sizes = [2], strides = [1]
1961 //        } : vector<4xf32> to vector<2xf32>
1962 //   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    // The slice source must be produced by a bitcast.
    auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to more elements for now; other cases to be implemented.
    if (castSrcLastDim > castDstLastDim)
      return failure();

    // Only accept all one strides for now.
    if (llvm::any_of(extractOp.strides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOneValue(); }))
      return failure();

    unsigned rank = extractOp.getVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    // Each pre-bitcast element expands into `expandRatio` post-bitcast
    // elements along the last dimension.
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    // If we have a less number of offsets than the rank, then implicitly we
    // are selecting the full range for the last bitcasted dimension; other
    // dimensions aren't affected. Otherwise, we need to scale down the last
    // dimension's offset given we are extracting from less elements now.
    ArrayAttr newOffsets = extractOp.offsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
      // An offset not on a pre-bitcast element boundary cannot be expressed
      // on the source vector.
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;
      newOffsets = rewriter.getI64ArrayAttr(offsets);
    }

    // Similarly for sizes.
    ArrayAttr newSizes = extractOp.sizes();
    if (newSizes.size() == rank) {
      SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;
      newSizes = rewriter.getI64ArrayAttr(sizes);
    }

    // The new slice has the source element type, with its last dimension
    // scaled down accordingly.
    SmallVector<int64_t, 4> dims =
        llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =
        VectorType::get(dims, castSrcType.getElementType());

    // Extract from the bitcast source first, then bitcast only the slice.
    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        extractOp.getLoc(), newExtractType, castOp.source(), newOffsets,
        newSizes, extractOp.strides());

    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
        extractOp, extractOp.getType(), newExtractOp);

    return success();
  }
};
2031 
2032 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
2033 //
2034 // This transforms IR like:
2035 //   %0 = vector.insert_strided_slice %src, %dst {
2036 //          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
2037 //   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
2038 // Into:
2039 //   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
2040 //   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
2041 //   %2 = vector.insert_strided_slice %src, %dst {
2042 //          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to less elements for now; other cases to be implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    // Each post-bitcast element packs `shrinkRatio` pre-bitcast elements
    // along the last dimension.
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;

    // The bitcast source must be produced by an insert_strided_slice.
    auto insertOp =
        bitcastOp.source().getDefiningOp<vector::InsertStridedSliceOp>();
    if (!insertOp)
      return failure();

    // Only accept all one strides for now.
    if (llvm::any_of(insertOp.strides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOneValue(); }))
      return failure();

    unsigned rank = insertOp.getSourceVectorType().getRank();
    // Require insert op to have the same rank for the source and destination
    // vector; other cases to be implemented.
    if (rank != insertOp.getDestVectorType().getRank())
      return failure();

    // Scale down the last-dimension offset; it must fall on a post-bitcast
    // element boundary to be representable.
    ArrayAttr newOffsets = insertOp.offsets();
    assert(newOffsets.size() == rank);
    SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;
    newOffsets = rewriter.getI64ArrayAttr(offsets);

    // Bitcast the inserted slice to the destination element type, with its
    // last dimension scaled down.
    SmallVector<int64_t, 4> srcDims =
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());

    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.source());

    // Bitcast the insert destination likewise.
    SmallVector<int64_t, 4> dstDims =
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());

    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.dest());

    // Re-create the insert on the bitcast operands, replacing the original
    // bitcast.
    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.strides());

    return success();
  }
};
2110 
2111 static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc,
2112                                    Type targetType, Value value) {
2113   if (targetType == value.getType())
2114     return value;
2115 
2116   bool targetIsIndex = targetType.isIndex();
2117   bool valueIsIndex = value.getType().isIndex();
2118   if (targetIsIndex ^ valueIsIndex)
2119     return rewriter.create<arith::IndexCastOp>(loc, targetType, value);
2120 
2121   auto targetIntegerType = targetType.dyn_cast<IntegerType>();
2122   auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
2123   assert(targetIntegerType && valueIntegerType &&
2124          "unexpected cast between types other than integers and index");
2125   assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
2126 
2127   if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
2128     return rewriter.create<arith::ExtSIOp>(loc, targetIntegerType, value);
2129   return rewriter.create<arith::TruncIOp>(loc, targetIntegerType, value);
2130 }
2131 
2132 // Helper that returns a vector comparison that constructs a mask:
2133 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
2134 //
2135 // If `dim == 0` then the result will be a 0-D vector.
2136 //
2137 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
2138 //       much more compact, IR for this operation, but LLVM eventually
2139 //       generates more elaborate instructions for this intrinsic since it
2140 //       is very conservative on the boundary conditions.
static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
                                   bool indexOptimizations, int64_t dim,
                                   Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If we can assume all indices fit in 32-bit, we perform the vector
  // comparison in 32-bit to get a higher degree of SIMD parallelism.
  // Otherwise we perform the vector comparison using 64-bit indices.
  Type idxType =
      indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
  DenseIntElementsAttr indicesAttr;
  if (dim == 0 && indexOptimizations) {
    // 0-D case, 32-bit indices: a 0-D vector holding {0}.
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
  } else if (dim == 0) {
    // 0-D case, 64-bit indices.
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
  } else if (indexOptimizations) {
    // 1-D case, 32-bit indices: constant vector [0, 1, ..., dim-1].
    indicesAttr = rewriter.getI32VectorAttr(
        llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
  } else {
    // 1-D case, 64-bit indices: constant vector [0, 1, ..., dim-1].
    indicesAttr = rewriter.getI64VectorAttr(
        llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
  }
  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
  // Add in an offset if requested.
  if (off) {
    Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
    Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison: (indices + offset) < splat(bound).
  Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
  Value bounds =
      rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
                                        bounds);
}
2178 
2179 template <typename ConcreteOp>
2180 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
2181 public:
2182   explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
2183       : mlir::OpRewritePattern<ConcreteOp>(context),
2184         indexOptimizations(enableIndexOpt) {}
2185 
2186   LogicalResult matchAndRewrite(ConcreteOp xferOp,
2187                                 PatternRewriter &rewriter) const override {
2188     if (!xferOp.hasOutOfBoundsDim())
2189       return failure();
2190 
2191     if (xferOp.getVectorType().getRank() > 1 ||
2192         llvm::size(xferOp.indices()) == 0)
2193       return failure();
2194 
2195     Location loc = xferOp->getLoc();
2196     VectorType vtp = xferOp.getVectorType();
2197 
2198     // * Create a vector with linear indices [ 0 .. vector_length - 1 ].
2199     // * Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
2200     // * Let dim the memref dimension, compute the vector comparison mask
2201     //   (in-bounds mask):
2202     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
2203     //
2204     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
2205     //       dimensions here.
2206     unsigned vecWidth = vtp.getNumElements();
2207     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
2208     Value off = xferOp.indices()[lastIndex];
2209     Value dim =
2210         vector::createOrFoldDimOp(rewriter, loc, xferOp.source(), lastIndex);
2211     Value mask = buildVectorComparison(rewriter, xferOp, indexOptimizations,
2212                                        vecWidth, dim, &off);
2213 
2214     if (xferOp.mask()) {
2215       // Intersect the in-bounds with the mask specified as an op parameter.
2216       mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.mask());
2217     }
2218 
2219     rewriter.updateRootInPlace(xferOp, [&]() {
2220       xferOp.maskMutable().assign(mask);
2221       xferOp.in_boundsAttr(rewriter.getBoolArrayAttr({true}));
2222     });
2223 
2224     return success();
2225   }
2226 
2227 private:
2228   const bool indexOptimizations;
2229 };
2230 
2231 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
2232 class VectorCreateMaskOpConversion
2233     : public OpRewritePattern<vector::CreateMaskOp> {
2234 public:
2235   explicit VectorCreateMaskOpConversion(MLIRContext *context,
2236                                         bool enableIndexOpt)
2237       : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
2238         indexOptimizations(enableIndexOpt) {}
2239 
2240   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
2241                                 PatternRewriter &rewriter) const override {
2242     auto dstType = op.getType();
2243     int64_t rank = dstType.getRank();
2244     if (rank > 1)
2245       return failure();
2246     rewriter.replaceOp(
2247         op, buildVectorComparison(rewriter, op, indexOptimizations,
2248                                   rank == 0 ? 0 : dstType.getDimSize(0),
2249                                   op.getOperand(0)));
2250     return success();
2251   }
2252 
2253 private:
2254   const bool indexOptimizations;
2255 };
2256 
// Drop inner most contiguous unit dimensions from transfer_read operand.
// Replaces the read with a rank-reduced memref.subview + transfer_read,
// followed by a vector.shape_cast back to the original vector type.
class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (readOp.mask())
      return failure();

    // Only statically-shaped memref sources are handled (strides/offset must
    // be computable below).
    auto srcType = readOp.source().getType().dyn_cast<MemRefType>();
    if (!srcType || !srcType.hasStaticShape())
      return failure();

    // The transfer must read the minor dimensions identically (no
    // permutation/broadcast).
    if (!readOp.permutation_map().isMinorIdentity())
      return failure();

    // Nothing to collapse for rank <= 1 vectors.
    auto targetType = readOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    SmallVector<int64_t> srcStrides;
    int64_t srcOffset;
    if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
      return failure();

    // Count trailing dimensions (excluding the innermost one: the scan
    // starts at i = 1, i.e. dim rank-2) whose stride is 1, stopping at the
    // first non-unit stride.
    // NOTE(review): the innermost stride and the unit-size of the dropped
    // vector dims are not explicitly verified here — presumably guaranteed
    // by the callers/matching above; TODO confirm.
    size_t dimsToDrop = 0;
    for (size_t i = 1; i < srcStrides.size(); ++i) {
      int dim = srcType.getRank() - i - 1;
      if (srcStrides[dim] == 1) {
        dimsToDrop++;
      } else {
        break;
      }
    }
    if (dimsToDrop == 0)
      return failure();

    // Vector type with the trailing `dimsToDrop` dims removed.
    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType());

    // Compute the rank-reduced memref type. For a non-identity layout, pin
    // the dropped dims to constant 0 in the layout map.
    MemRefType resultMemrefType;
    if (srcType.getLayout().getAffineMap().isIdentity()) {
      resultMemrefType = MemRefType::get(
          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
          {}, srcType.getMemorySpaceAsInt());
    } else {
      AffineMap map = srcType.getLayout().getAffineMap();
      int numResultDims = map.getNumDims() - dimsToDrop;
      int numSymbols = map.getNumSymbols();
      for (size_t i = 0; i < dimsToDrop; ++i) {
        int dim = srcType.getRank() - i - 1;
        map = map.replace(rewriter.getAffineDimExpr(dim),
                          rewriter.getAffineConstantExpr(0), numResultDims,
                          numSymbols);
      }
      resultMemrefType = MemRefType::get(
          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
          map, srcType.getMemorySpaceAsInt());
    }

    auto loc = readOp.getLoc();
    // Full-view subview: zero offsets, unit strides, full sizes; rank
    // reduction comes from the explicit result type.
    SmallVector<int64_t> offsets(srcType.getRank(), 0);
    SmallVector<int64_t> strides(srcType.getRank(), 1);

    // Drop the in_bounds entries for the removed trailing dims, if present.
    ArrayAttr inBoundsAttr =
        readOp.in_bounds()
            ? rewriter.getArrayAttr(
                  readOp.in_boundsAttr().getValue().drop_back(dimsToDrop))
            : ArrayAttr();
    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, readOp.source(), offsets, srcType.getShape(),
        strides);
    auto permMap = getTransferMinorIdentityMap(
        rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
    // New read on the reduced view; trailing indices are dropped alongside
    // the dims.
    Value result = rewriter.create<vector::TransferReadOp>(
        loc, resultTargetVecType, rankedReducedView,
        readOp.indices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        readOp.padding(),
        // TODO: support mask.
        /*mask=*/Value(), inBoundsAttr);
    // Restore the original vector type for all existing users.
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
                                                     result);
    return success();
  }
};
2348 
2349 namespace {
2350 
2351 /// This function checks to see if the vector combining kind
2352 /// is consistent with the integer or float element type.
2353 static bool isValidKind(bool isInt, vector::CombiningKind kind) {
2354   using vector::CombiningKind;
2355   enum class KindType { FLOAT, INT, INVALID };
2356   KindType type{KindType::INVALID};
2357   switch (kind) {
2358   case CombiningKind::MINF:
2359   case CombiningKind::MAXF:
2360     type = KindType::FLOAT;
2361     break;
2362   case CombiningKind::MINUI:
2363   case CombiningKind::MINSI:
2364   case CombiningKind::MAXUI:
2365   case CombiningKind::MAXSI:
2366   case CombiningKind::AND:
2367   case CombiningKind::OR:
2368   case CombiningKind::XOR:
2369     type = KindType::INT;
2370     break;
2371   case CombiningKind::ADD:
2372   case CombiningKind::MUL:
2373     type = isInt ? KindType::INT : KindType::FLOAT;
2374     break;
2375   }
2376   bool isValidIntKind = (type == KindType::INT) && isInt;
2377   bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
2378   return (isValidIntKind || isValidFloatKind);
2379 }
2380 
2381 /// This function constructs the appropriate integer or float
2382 /// operation given the vector combining kind and operands. The
2383 /// supported int operations are : add, mul, min (signed/unsigned),
2384 /// max(signed/unsigned), and, or, xor. The supported float
2385 /// operations are : add, mul, min and max.
2386 static Value genOperator(Location loc, Value x, Value y,
2387                          vector::CombiningKind kind,
2388                          PatternRewriter &rewriter) {
2389   using vector::CombiningKind;
2390 
2391   auto elType = x.getType().cast<VectorType>().getElementType();
2392   bool isInt = elType.isIntOrIndex();
2393 
2394   Value combinedResult{nullptr};
2395   switch (kind) {
2396   case CombiningKind::ADD:
2397     if (isInt)
2398       combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
2399     else
2400       combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
2401     break;
2402   case CombiningKind::MUL:
2403     if (isInt)
2404       combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
2405     else
2406       combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
2407     break;
2408   case CombiningKind::MINUI:
2409     combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
2410     break;
2411   case CombiningKind::MINSI:
2412     combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
2413     break;
2414   case CombiningKind::MAXUI:
2415     combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
2416     break;
2417   case CombiningKind::MAXSI:
2418     combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
2419     break;
2420   case CombiningKind::AND:
2421     combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
2422     break;
2423   case CombiningKind::OR:
2424     combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
2425     break;
2426   case CombiningKind::XOR:
2427     combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
2428     break;
2429   case CombiningKind::MINF:
2430     combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
2431     break;
2432   case CombiningKind::MAXF:
2433     combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
2434     break;
2435   }
2436   return combinedResult;
2437 }
2438 
/// Convert vector.scan op into arith ops and
/// vector.insert_strided_slice/extract_strided_slice
///
/// Ex:
/// ```
///   %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim =
///   1} :
///     (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
/// ```
/// Gets converted to:
/// ```
///   %cst = arith.constant dense<0> : vector<2x3xi32>
///   %0 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 1],
///   strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %1 =
///   vector.insert_strided_slice %0, %cst {offsets = [0, 0], strides = [1, 1]}
///   : vector<2x1xi32> into vector<2x3xi32> %2 = vector.extract_strided_slice
///   %arg0 {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} :
///   vector<2x3xi32> to vector<2x1xi32> %3 = arith.muli %0, %2 :
///   vector<2x1xi32> %4 = vector.insert_strided_slice %3, %1 {offsets = [0, 1],
///   strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %5 =
///   vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1],
///   strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %6 = arith.muli %3,
///   %5 : vector<2x1xi32> %7 = vector.insert_strided_slice %6, %4 {offsets =
///   [0, 2], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %8 =
///   vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32> return %7, %8 :
///   vector<2x3xi32>, vector<2xi32>
/// ```
struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
  using OpRewritePattern<vector::ScanOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ScanOp scanOp,
                                PatternRewriter &rewriter) const override {
    auto loc = scanOp.getLoc();
    VectorType destType = scanOp.getDestType();
    ArrayRef<int64_t> destShape = destType.getShape();
    auto elType = destType.getElementType();
    bool isInt = elType.isIntOrIndex();
    // Bail out on combining kinds incompatible with the element type.
    if (!isValidKind(isInt, scanOp.kind()))
      return failure();

    // Accumulator vector, initialized to zero; slices are inserted into it
    // one reduction step at a time.
    VectorType resType = VectorType::get(destShape, elType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    int64_t reductionDim = scanOp.reduction_dim();
    bool inclusive = scanOp.inclusive();
    int64_t destRank = destType.getRank();
    VectorType initialValueType = scanOp.getInitialValueType();
    int64_t initialValueRank = initialValueType.getRank();

    // Slice type: the destination shape with the reduction dim collapsed
    // to 1.
    SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
    reductionShape[reductionDim] = 1;
    VectorType reductionType = VectorType::get(reductionShape, elType);
    SmallVector<int64_t> offsets(destRank, 0);
    SmallVector<int64_t> strides(destRank, 1);
    SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
    sizes[reductionDim] = 1;
    ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
    ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);

    // Walk the reduction dimension, carrying the previous slice's scan
    // output (and, for exclusive scans, the previous raw input).
    Value lastOutput, lastInput;
    for (int i = 0; i < destShape[reductionDim]; i++) {
      // Extract the i-th slice along the reduction dim.
      offsets[reductionDim] = i;
      ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
      Value input = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionType, scanOp.source(), scanOffsets, scanSizes,
          scanStrides);
      Value output;
      if (i == 0) {
        // First slice: the input itself (inclusive), or the broadcast
        // initial value (exclusive).
        if (inclusive) {
          output = input;
        } else {
          if (initialValueRank == 0) {
            // ShapeCastOp cannot handle 0-D vectors
            output = rewriter.create<vector::BroadcastOp>(
                loc, input.getType(), scanOp.initial_value());
          } else {
            output = rewriter.create<vector::ShapeCastOp>(
                loc, input.getType(), scanOp.initial_value());
          }
        }
      } else {
        // Combine the running result with the current (inclusive) or
        // previous (exclusive) input slice.
        Value y = inclusive ? input : lastInput;
        output = genOperator(loc, lastOutput, y, scanOp.kind(), rewriter);
        assert(output != nullptr);
      }
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, output, result, offsets, strides);
      lastOutput = output;
      lastInput = input;
    }

    // Second result: the final accumulated slice, reshaped (or broadcast,
    // for the 0-D case) to the initial-value type.
    Value reduction;
    if (initialValueRank == 0) {
      Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
      reduction =
          rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
    } else {
      reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
                                                       lastOutput);
    }

    rewriter.replaceOp(scanOp, {result, reduction});
    return success();
  }
};
2544 
2545 } // namespace
2546 
/// Registers the patterns that materialize explicit masks for
/// vector.create_mask and out-of-bounds transfer_read/write ops.
void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool indexOptimizations) {
  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      patterns.getContext(), indexOptimizations);
}
2554 
/// Registers the pattern that folds chains of vector.shape_cast ops.
void mlir::vector::populateShapeCastFoldingPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ShapeCastOpFolder>(patterns.getContext());
}
2559 
/// Registers patterns that move vector.bitcast ops across extract/insert
/// strided-slice ops.
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleDownVectorBitCastForExtract,
               BubbleDownBitCastForStridedSliceExtract,
               BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
}
2566 
/// Registers the pattern that lowers vector.broadcast ops.
void mlir::vector::populateVectorBroadcastLoweringPatterns(
    RewritePatternSet &patterns) {
  patterns.add<BroadcastOpLowering>(patterns.getContext());
}
2571 
/// Registers patterns that lower vector.create_mask and vector.constant_mask.
void mlir::vector::populateVectorMaskOpLoweringPatterns(
    RewritePatternSet &patterns) {
  patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
      patterns.getContext());
}
2577 
/// Registers patterns that lower vector.shape_cast ops.
void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ShapeCastOp2DDownCastRewritePattern,
               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
      patterns.getContext());
}
2584 
/// Registers patterns that lower vector.contract and vector.outerproduct,
/// parameterized by `options` (e.g. the chosen contraction lowering).
void mlir::vector::populateVectorContractLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options) {
  patterns.add<OuterProductOpLowering>(patterns.getContext());
  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
               ContractionOpToOuterProductOpLowering>(options,
                                                      patterns.getContext());
}
2592 
/// Registers patterns that lower vector.transpose ops, parameterized by
/// `options`.
void mlir::vector::populateVectorTransposeLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options) {
  patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
      options, patterns.getContext());
}
2598 
/// Registers patterns that fold multi-reductions/broadcasts/transposes into
/// vector.contract ops.
void mlir::vector::populateVectorReductionToContractPatterns(
    RewritePatternSet &patterns) {
  patterns.add<MultiReduceToContract, CombineContractBroadcast,
               CombineContractTranspose>(patterns.getContext());
}
2604 
/// Registers the pattern that collapses innermost contiguous unit dims of
/// transfer_read sources.
void mlir::vector::
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
        RewritePatternSet &patterns) {
  patterns.add<DropInnerMostUnitDims>(patterns.getContext());
}
2610 
/// Registers patterns that lower transfer_read/write to vector.load/store
/// (up to `maxTransferRank`, if provided) and vector.load/store to
/// memref.load/store.
void mlir::vector::populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                   maxTransferRank);
  patterns
      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
          patterns.getContext());
}
2620 
/// Registers the pattern that lowers vector.scan to arith + strided-slice
/// ops.
void mlir::vector::populateVectorScanLoweringPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ScanToArithOps>(patterns.getContext());
}
2625