1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
14 
15 #include <type_traits>
16 
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
19 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
20 #include "mlir/Dialect/Linalg/IR/Linalg.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/SCF/SCF.h"
23 #include "mlir/Dialect/Utils/IndexingUtils.h"
24 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
26 #include "mlir/IR/ImplicitLocOpBuilder.h"
27 #include "mlir/IR/Matchers.h"
28 #include "mlir/IR/PatternMatch.h"
29 #include "mlir/Interfaces/VectorInterfaces.h"
30 
31 #include "llvm/ADT/DenseSet.h"
32 #include "llvm/ADT/MapVector.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/Support/CommandLine.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/raw_ostream.h"
37 
38 #define DEBUG_TYPE "vector-to-vector"
39 
40 using namespace mlir;
41 using namespace mlir::vector;
42 
43 // Helper to find an index in an affine map.
44 static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
45   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
46     int64_t idx = map.getDimPosition(i);
47     if (idx == index)
48       return i;
49   }
50   return None;
51 }
52 
53 // Helper to construct iterator types with one index removed.
54 static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
55                                             int64_t index) {
56   SmallVector<Attribute, 4> results;
57   for (const auto &it : llvm::enumerate(iteratorTypes)) {
58     int64_t idx = it.index();
59     if (idx == index)
60       continue;
61     results.push_back(it.value());
62   }
63   return results;
64 }
65 
66 // Helper to construct an affine map with one index removed.
67 static AffineMap adjustMap(AffineMap map, int64_t index,
68                            PatternRewriter &rewriter) {
69   auto *ctx = rewriter.getContext();
70   SmallVector<AffineExpr, 4> results;
71   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
72     int64_t idx = map.getDimPosition(i);
73     if (idx == index)
74       continue;
75     // Re-insert remaining indices, but renamed when occurring
76     // after the removed index.
77     auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
78     results.push_back(targetExpr);
79   }
80   return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
81 }
82 
83 // Helper method to possibly drop a dimension in a load.
84 // TODO
85 static Value reshapeLoad(Location loc, Value val, VectorType type,
86                          int64_t index, int64_t pos,
87                          PatternRewriter &rewriter) {
88   if (index == -1)
89     return val;
90   Type lowType = VectorType::Builder(type).dropDim(0);
91   // At extraction dimension?
92   if (index == 0) {
93     auto posAttr = rewriter.getI64ArrayAttr(pos);
94     return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
95   }
96   // Unroll leading dimensions.
97   VectorType vType = lowType.cast<VectorType>();
98   Type resType = VectorType::Builder(type).dropDim(index);
99   auto resVectorType = resType.cast<VectorType>();
100   Value result = rewriter.create<arith::ConstantOp>(
101       loc, resVectorType, rewriter.getZeroAttr(resVectorType));
102   for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
103     auto posAttr = rewriter.getI64ArrayAttr(d);
104     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
105     Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
106     result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
107                                                posAttr);
108   }
109   return result;
110 }
111 
112 // Helper method to possibly drop a dimension in a store.
113 // TODO
114 static Value reshapeStore(Location loc, Value val, Value result,
115                           VectorType type, int64_t index, int64_t pos,
116                           PatternRewriter &rewriter) {
117   // Unmodified?
118   if (index == -1)
119     return val;
120   // At insertion dimension?
121   if (index == 0) {
122     auto posAttr = rewriter.getI64ArrayAttr(pos);
123     return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
124   }
125   // Unroll leading dimensions.
126   Type lowType = VectorType::Builder(type).dropDim(0);
127   VectorType vType = lowType.cast<VectorType>();
128   Type insType = VectorType::Builder(vType).dropDim(0);
129   for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
130     auto posAttr = rewriter.getI64ArrayAttr(d);
131     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
132     Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
133     Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
134     result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
135   }
136   return result;
137 }
138 
139 template <typename IntType>
140 static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
141   return llvm::to_vector<4>(llvm::map_range(
142       arrayAttr.getAsRange<IntegerAttr>(),
143       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
144 }
145 
146 namespace {
147 
148 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
149 //
150 // Example:
151 //
152 //  The following MLIR with cancelling ShapeCastOps:
153 //
154 //   %0 = source : vector<5x4x2xf32>
155 //   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
156 //   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
157 //   %3 = user %2 : vector<5x4x2xf32>
158 //
159 //  Should canonicalize to the following:
160 //
161 //   %0 = source : vector<5x4x2xf32>
162 //   %1 = user %0 : vector<5x4x2xf32>
163 //
164 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
165   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
166 
167   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
168                                 PatternRewriter &rewriter) const override {
169     // Check if 'shapeCastOp' has vector source/result type.
170     auto sourceVectorType =
171         shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
172     auto resultVectorType =
173         shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
174     if (!sourceVectorType || !resultVectorType)
175       return failure();
176 
177     // Check if shape cast op source operand is also a shape cast op.
178     auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
179         shapeCastOp.source().getDefiningOp());
180     if (!sourceShapeCastOp)
181       return failure();
182     auto operandSourceVectorType =
183         sourceShapeCastOp.source().getType().cast<VectorType>();
184     auto operandResultVectorType = sourceShapeCastOp.getType();
185 
186     // Check if shape cast operations invert each other.
187     if (operandSourceVectorType != resultVectorType ||
188         operandResultVectorType != sourceVectorType)
189       return failure();
190 
191     rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
192     return success();
193   }
194 };
195 
196 /// Progressive lowering of BroadcastOp.
197 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
198 public:
199   using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
200 
201   LogicalResult matchAndRewrite(vector::BroadcastOp op,
202                                 PatternRewriter &rewriter) const override {
203     auto loc = op.getLoc();
204     VectorType dstType = op.getVectorType();
205     VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
206     Type eltType = dstType.getElementType();
207 
208     // Scalar to any vector can use splat.
209     if (!srcType) {
210       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.source());
211       return success();
212     }
213 
214     // Determine rank of source and destination.
215     int64_t srcRank = srcType.getRank();
216     int64_t dstRank = dstType.getRank();
217 
218     // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
219     if (srcRank <= 1 && dstRank == 1) {
220       Value ext;
221       if (srcRank == 0)
222         ext = rewriter.create<vector::ExtractElementOp>(loc, op.source());
223       else
224         ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
225       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
226       return success();
227     }
228 
229     // Duplicate this rank.
230     // For example:
231     //   %x = broadcast %y  : k-D to n-D, k < n
232     // becomes:
233     //   %b = broadcast %y  : k-D to (n-1)-D
234     //   %x = [%b,%b,%b,%b] : n-D
235     // becomes:
236     //   %b = [%y,%y]       : (n-1)-D
237     //   %x = [%b,%b,%b,%b] : n-D
238     if (srcRank < dstRank) {
239       // Duplication.
240       VectorType resType =
241           VectorType::get(dstType.getShape().drop_front(), eltType);
242       Value bcst =
243           rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
244       Value result = rewriter.create<arith::ConstantOp>(
245           loc, dstType, rewriter.getZeroAttr(dstType));
246       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
247         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
248       rewriter.replaceOp(op, result);
249       return success();
250     }
251 
252     // Find non-matching dimension, if any.
253     assert(srcRank == dstRank);
254     int64_t m = -1;
255     for (int64_t r = 0; r < dstRank; r++)
256       if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
257         m = r;
258         break;
259       }
260 
261     // All trailing dimensions are the same. Simply pass through.
262     if (m == -1) {
263       rewriter.replaceOp(op, op.source());
264       return success();
265     }
266 
267     // Any non-matching dimension forces a stretch along this rank.
268     // For example:
269     //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
270     // becomes:
271     //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
272     //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
273     //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
274     //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
275     //   %x = [%a,%b,%c,%d]
276     // becomes:
277     //   %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
278     //   %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
279     //   %a = [%u, %v]
280     //   ..
281     //   %x = [%a,%b,%c,%d]
282     VectorType resType =
283         VectorType::get(dstType.getShape().drop_front(), eltType);
284     Value result = rewriter.create<arith::ConstantOp>(
285         loc, dstType, rewriter.getZeroAttr(dstType));
286     if (m == 0) {
287       // Stetch at start.
288       Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
289       Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
290       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
291         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
292     } else {
293       // Stetch not at start.
294       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
295         Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
296         Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
297         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
298       }
299     }
300     rewriter.replaceOp(op, result);
301     return success();
302   }
303 };
304 
305 /// Given a 'transpose' pattern, prune the rightmost dimensions that are not
306 /// transposed.
307 void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
308                             SmallVectorImpl<int64_t> &result) {
309   size_t numTransposedDims = transpose.size();
310   for (size_t transpDim : llvm::reverse(transpose)) {
311     if (transpDim != numTransposedDims - 1)
312       break;
313     numTransposedDims--;
314   }
315 
316   result.append(transpose.begin(), transpose.begin() + numTransposedDims);
317 }
318 
319 /// Progressive lowering of TransposeOp.
320 /// One:
321 ///   %x = vector.transpose %y, [1, 0]
322 /// is replaced by:
323 ///   %z = arith.constant dense<0.000000e+00>
324 ///   %0 = vector.extract %y[0, 0]
325 ///   %1 = vector.insert %0, %z [0, 0]
326 ///   ..
327 ///   %x = vector.insert .., .. [.., ..]
328 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
329 public:
330   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
331 
332   TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
333                       MLIRContext *context)
334       : OpRewritePattern<vector::TransposeOp>(context),
335         vectorTransformOptions(vectorTransformOptions) {}
336 
337   LogicalResult matchAndRewrite(vector::TransposeOp op,
338                                 PatternRewriter &rewriter) const override {
339     auto loc = op.getLoc();
340 
341     Value input = op.vector();
342     VectorType inputType = op.getVectorType();
343     VectorType resType = op.getResultType();
344 
345     // Set up convenience transposition table.
346     SmallVector<int64_t, 4> transp;
347     for (auto attr : op.transp())
348       transp.push_back(attr.cast<IntegerAttr>().getInt());
349 
350     if (vectorTransformOptions.vectorTransposeLowering ==
351             vector::VectorTransposeLowering::Shuffle &&
352         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
353       return rewriter.notifyMatchFailure(
354           op, "Options specifies lowering to shuffle");
355 
356     // Handle a true 2-D matrix transpose differently when requested.
357     if (vectorTransformOptions.vectorTransposeLowering ==
358             vector::VectorTransposeLowering::Flat &&
359         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
360       Type flattenedType =
361           VectorType::get(resType.getNumElements(), resType.getElementType());
362       auto matrix =
363           rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
364       auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
365       auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
366       Value trans = rewriter.create<vector::FlatTransposeOp>(
367           loc, flattenedType, matrix, rows, columns);
368       rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
369       return success();
370     }
371 
372     // Generate unrolled extract/insert ops. We do not unroll the rightmost
373     // (i.e., highest-order) dimensions that are not transposed and leave them
374     // in vector form to improve performance. Therefore, we prune those
375     // dimensions from the shape/transpose data structures used to generate the
376     // extract/insert ops.
377     SmallVector<int64_t, 4> prunedTransp;
378     pruneNonTransposedDims(transp, prunedTransp);
379     size_t numPrunedDims = transp.size() - prunedTransp.size();
380     auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
381     SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
382     auto prunedInStrides = computeStrides(prunedInShape, ones);
383 
384     // Generates the extract/insert operations for every scalar/vector element
385     // of the leftmost transposed dimensions. We traverse every transpose
386     // element using a linearized index that we delinearize to generate the
387     // appropriate indices for the extract/insert operations.
388     Value result = rewriter.create<arith::ConstantOp>(
389         loc, resType, rewriter.getZeroAttr(resType));
390     int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
391 
392     for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
393          ++linearIdx) {
394       auto extractIdxs = delinearize(prunedInStrides, linearIdx);
395       SmallVector<int64_t, 4> insertIdxs(extractIdxs);
396       applyPermutationToVector(insertIdxs, prunedTransp);
397       Value extractOp =
398           rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
399       result =
400           rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
401     }
402 
403     rewriter.replaceOp(op, result);
404     return success();
405   }
406 
407 private:
408   /// Options to control the vector patterns.
409   vector::VectorTransformsOptions vectorTransformOptions;
410 };
411 
412 /// Rewrite a 2-D vector.transpose as a sequence of:
413 ///   vector.shape_cast 2D -> 1D
414 ///   vector.shuffle
415 ///   vector.shape_cast 1D -> 2D
416 class TransposeOp2DToShuffleLowering
417     : public OpRewritePattern<vector::TransposeOp> {
418 public:
419   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
420 
421   TransposeOp2DToShuffleLowering(
422       vector::VectorTransformsOptions vectorTransformOptions,
423       MLIRContext *context)
424       : OpRewritePattern<vector::TransposeOp>(context),
425         vectorTransformOptions(vectorTransformOptions) {}
426 
427   LogicalResult matchAndRewrite(vector::TransposeOp op,
428                                 PatternRewriter &rewriter) const override {
429     auto loc = op.getLoc();
430 
431     VectorType srcType = op.getVectorType();
432     if (srcType.getRank() != 2)
433       return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
434 
435     SmallVector<int64_t, 4> transp;
436     for (auto attr : op.transp())
437       transp.push_back(attr.cast<IntegerAttr>().getInt());
438     if (transp[0] != 1 && transp[1] != 0)
439       return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
440 
441     if (vectorTransformOptions.vectorTransposeLowering !=
442         VectorTransposeLowering::Shuffle)
443       return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
444 
445     int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
446     Value casted = rewriter.create<vector::ShapeCastOp>(
447         loc, VectorType::get({m * n}, srcType.getElementType()), op.vector());
448     SmallVector<int64_t> mask;
449     mask.reserve(m * n);
450     for (int64_t j = 0; j < n; ++j)
451       for (int64_t i = 0; i < m; ++i)
452         mask.push_back(i * n + j);
453 
454     Value shuffled =
455         rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
456     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
457                                                      shuffled);
458 
459     return success();
460   }
461 
462 private:
463   /// Options to control the vector patterns.
464   vector::VectorTransformsOptions vectorTransformOptions;
465 };
466 
467 /// Progressive lowering of OuterProductOp.
468 /// One:
469 ///   %x = vector.outerproduct %lhs, %rhs, %acc
470 /// is replaced by:
471 ///   %z = zero-result
472 ///   %0 = vector.extract %lhs[0]
473 ///   %1 = vector.broadcast %0
474 ///   %2 = vector.extract %acc[0]
475 ///   %3 = vector.fma %1, %rhs, %2
476 ///   %4 = vector.insert %3, %z[0]
477 ///   ..
478 ///   %x = vector.insert %.., %..[N-1]
479 ///
480 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
481 public:
482   using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
483 
484   LogicalResult matchAndRewrite(vector::OuterProductOp op,
485                                 PatternRewriter &rewriter) const override {
486     auto loc = op.getLoc();
487 
488     VectorType lhsType = op.getOperandVectorTypeLHS();
489     VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
490     VectorType resType = op.getVectorType();
491     Type eltType = resType.getElementType();
492     bool isInt = eltType.isa<IntegerType, IndexType>();
493     Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
494     vector::CombiningKind kind = op.kind();
495 
496     if (!rhsType) {
497       // Special case: AXPY operation.
498       Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
499       Optional<Value> mult =
500           isInt ? genMultI(loc, op.lhs(), b, acc, kind, rewriter)
501                 : genMultF(loc, op.lhs(), b, acc, kind, rewriter);
502       if (!mult.hasValue())
503         return failure();
504       rewriter.replaceOp(op, mult.getValue());
505       return success();
506     }
507 
508     Value result = rewriter.create<arith::ConstantOp>(
509         loc, resType, rewriter.getZeroAttr(resType));
510     for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
511       auto pos = rewriter.getI64ArrayAttr(d);
512       Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
513       Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
514       Value r = nullptr;
515       if (acc)
516         r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
517       Optional<Value> m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter)
518                                 : genMultF(loc, a, op.rhs(), r, kind, rewriter);
519       if (!m.hasValue())
520         return failure();
521       result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
522                                                  result, pos);
523     }
524     rewriter.replaceOp(op, result);
525     return success();
526   }
527 
528 private:
529   static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
530                                   vector::CombiningKind kind,
531                                   PatternRewriter &rewriter) {
532     using vector::CombiningKind;
533 
534     auto mul = rewriter.create<arith::MulIOp>(loc, x, y);
535     if (!acc)
536       return Optional<Value>(mul);
537 
538     if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
539       // Only valid for floating point types.
540       return Optional<Value>();
541 
542     return makeArithReduction(rewriter, loc, kind, mul, acc);
543   }
544 
545   static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
546                                   vector::CombiningKind kind,
547                                   PatternRewriter &rewriter) {
548     using vector::CombiningKind;
549 
550     // Special case for fused multiply-add.
551     if (acc && kind == CombiningKind::ADD) {
552       return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
553     }
554 
555     auto mul = rewriter.create<arith::MulFOp>(loc, x, y);
556 
557     if (!acc)
558       return Optional<Value>(mul);
559 
560     if (kind == CombiningKind::ADD || kind == CombiningKind::AND ||
561         kind == CombiningKind::MINUI || kind == CombiningKind::MINSI ||
562         kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI ||
563         kind == CombiningKind::OR || kind == CombiningKind::XOR)
564       // Already handled or only valid for integer types.
565       return Optional<Value>();
566 
567     return makeArithReduction(rewriter, loc, kind, mul, acc);
568   }
569 };
570 
571 /// Progressive lowering of ConstantMaskOp.
572 /// One:
573 ///   %x = vector.constant_mask [a,b]
574 /// is replaced by:
575 ///   %z = zero-result
576 ///   %l = vector.constant_mask [b]
577 ///   %4 = vector.insert %l, %z[0]
578 ///   ..
579 ///   %x = vector.insert %l, %..[a-1]
580 /// until a one-dimensional vector is reached. All these operations
581 /// will be folded at LLVM IR level.
582 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
583 public:
584   using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
585 
586   LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
587                                 PatternRewriter &rewriter) const override {
588     auto loc = op.getLoc();
589     auto dstType = op.getType();
590     auto eltType = dstType.getElementType();
591     auto dimSizes = op.mask_dim_sizes();
592     int64_t rank = dstType.getRank();
593 
594     if (rank == 0) {
595       assert(dimSizes.size() == 1 &&
596              "Expected exactly one dim size for a 0-D vector");
597       bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
598       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
599           op, dstType,
600           DenseIntElementsAttr::get(
601               VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
602               ArrayRef<bool>{value}));
603       return success();
604     }
605 
606     // Scalable constant masks can only be lowered for the "none set" case.
607     if (dstType.cast<VectorType>().isScalable()) {
608       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
609           op, DenseElementsAttr::get(dstType, false));
610       return success();
611     }
612 
613     int64_t trueDim = std::min(dstType.getDimSize(0),
614                                dimSizes[0].cast<IntegerAttr>().getInt());
615 
616     if (rank == 1) {
617       // Express constant 1-D case in explicit vector form:
618       //   [T,..,T,F,..,F].
619       SmallVector<bool, 4> values(dstType.getDimSize(0));
620       for (int64_t d = 0; d < trueDim; d++)
621         values[d] = true;
622       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
623           op, dstType, rewriter.getBoolVectorAttr(values));
624       return success();
625     }
626 
627     VectorType lowType =
628         VectorType::get(dstType.getShape().drop_front(), eltType);
629     SmallVector<int64_t, 4> newDimSizes;
630     for (int64_t r = 1; r < rank; r++)
631       newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
632     Value trueVal = rewriter.create<vector::ConstantMaskOp>(
633         loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
634     Value result = rewriter.create<arith::ConstantOp>(
635         loc, dstType, rewriter.getZeroAttr(dstType));
636     for (int64_t d = 0; d < trueDim; d++) {
637       auto pos = rewriter.getI64ArrayAttr(d);
638       result =
639           rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
640     }
641     rewriter.replaceOp(op, result);
642     return success();
643   }
644 };
645 
646 /// Progressive lowering of CreateMaskOp.
647 /// One:
648 ///   %x = vector.create_mask %a, ... : vector<dx...>
649 /// is replaced by:
650 ///   %l = vector.create_mask ... : vector<...>  ; one lower rank
651 ///   %0 = arith.cmpi "slt", %ci, %a       |
652 ///   %1 = select %0, %l, %zeroes    |
653 ///   %r = vector.insert %1, %pr [i] | d-times
654 ///   %x = ....
655 /// until a one-dimensional vector is reached.
656 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
657 public:
658   using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
659 
660   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
661                                 PatternRewriter &rewriter) const override {
662     auto dstType = op.getResult().getType().cast<VectorType>();
663     int64_t rank = dstType.getRank();
664     if (rank <= 1)
665       return rewriter.notifyMatchFailure(
666           op, "0-D and 1-D vectors are handled separately");
667 
668     auto loc = op.getLoc();
669     auto eltType = dstType.getElementType();
670     int64_t dim = dstType.getDimSize(0);
671     Value idx = op.getOperand(0);
672 
673     VectorType lowType =
674         VectorType::get(dstType.getShape().drop_front(), eltType);
675     Value trueVal = rewriter.create<vector::CreateMaskOp>(
676         loc, lowType, op.getOperands().drop_front());
677     Value falseVal = rewriter.create<arith::ConstantOp>(
678         loc, lowType, rewriter.getZeroAttr(lowType));
679     Value result = rewriter.create<arith::ConstantOp>(
680         loc, dstType, rewriter.getZeroAttr(dstType));
681     for (int64_t d = 0; d < dim; d++) {
682       Value bnd =
683           rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
684       Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
685                                                  bnd, idx);
686       Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
687       auto pos = rewriter.getI64ArrayAttr(d);
688       result =
689           rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
690     }
691     rewriter.replaceOp(op, result);
692     return success();
693   }
694 };
695 
696 /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
697 /// vectors progressively on the way to target llvm.matrix intrinsics.
698 /// This iterates over the most major dimension of the 2-D vector and performs
699 /// rewrites into:
700 ///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
701 class ShapeCastOp2DDownCastRewritePattern
702     : public OpRewritePattern<vector::ShapeCastOp> {
703 public:
704   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
705 
706   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
707                                 PatternRewriter &rewriter) const override {
708     auto sourceVectorType = op.getSourceVectorType();
709     auto resultVectorType = op.getResultVectorType();
710     if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
711       return failure();
712 
713     auto loc = op.getLoc();
714     Value desc = rewriter.create<arith::ConstantOp>(
715         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
716     unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
717     for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
718       Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
719       desc = rewriter.create<vector::InsertStridedSliceOp>(
720           loc, vec, desc,
721           /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
722     }
723     rewriter.replaceOp(op, desc);
724     return success();
725   }
726 };
727 
728 /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
729 /// vectors progressively.
730 /// This iterates over the most major dimension of the 2-D vector and performs
731 /// rewrites into:
732 ///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
733 /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
734 class ShapeCastOp2DUpCastRewritePattern
735     : public OpRewritePattern<vector::ShapeCastOp> {
736 public:
737   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
738 
739   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
740                                 PatternRewriter &rewriter) const override {
741     auto sourceVectorType = op.getSourceVectorType();
742     auto resultVectorType = op.getResultVectorType();
743     if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
744       return failure();
745 
746     auto loc = op.getLoc();
747     Value desc = rewriter.create<arith::ConstantOp>(
748         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
749     unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
750     for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
751       Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
752           loc, op.source(), /*offsets=*/i * mostMinorVectorSize,
753           /*sizes=*/mostMinorVectorSize,
754           /*strides=*/1);
755       desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
756     }
757     rewriter.replaceOp(op, desc);
758     return success();
759   }
760 };
761 
762 // We typically should not lower general shape cast operations into data
763 // movement instructions, since the assumption is that these casts are
764 // optimized away during progressive lowering. For completeness, however,
765 // we fall back to a reference implementation that moves all elements
766 // into the right place if we get here.
767 class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
768 public:
769   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
770 
771   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
772                                 PatternRewriter &rewriter) const override {
773     Location loc = op.getLoc();
774     auto sourceVectorType = op.getSourceVectorType();
775     auto resultVectorType = op.getResultVectorType();
776 
777     // Special case 2D/1D lowerings with better implementations.
778     // TODO: make is ND/1D to allow generic ND->1D->MD.
779     int64_t srcRank = sourceVectorType.getRank();
780     int64_t resRank = resultVectorType.getRank();
781     if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
782       return failure();
783 
784     // Generic ShapeCast lowering path goes all the way down to unrolled scalar
785     // extract/insert chains.
786     // TODO: consider evolving the semantics to only allow 1D source or dest and
787     // drop this potentially very expensive lowering.
788     // Compute number of elements involved in the reshape.
789     int64_t numElts = 1;
790     for (int64_t r = 0; r < srcRank; r++)
791       numElts *= sourceVectorType.getDimSize(r);
792     // Replace with data movement operations:
793     //    x[0,0,0] = y[0,0]
794     //    x[0,0,1] = y[0,1]
795     //    x[0,1,0] = y[0,2]
796     // etc., incrementing the two index vectors "row-major"
797     // within the source and result shape.
798     SmallVector<int64_t, 4> srcIdx(srcRank);
799     SmallVector<int64_t, 4> resIdx(resRank);
800     Value result = rewriter.create<arith::ConstantOp>(
801         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
802     for (int64_t i = 0; i < numElts; i++) {
803       if (i != 0) {
804         incIdx(srcIdx, sourceVectorType, srcRank - 1);
805         incIdx(resIdx, resultVectorType, resRank - 1);
806       }
807       Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
808       result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
809     }
810     rewriter.replaceOp(op, result);
811     return success();
812   }
813 
814 private:
815   static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
816     assert(0 <= r && r < tp.getRank());
817     if (++idx[r] == tp.getDimSize(r)) {
818       idx[r] = 0;
819       incIdx(idx, tp, r - 1);
820     }
821   }
822 };
823 
824 /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
825 /// Ex:
826 /// ```
827 ///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
828 ///   %1 = vector.multi_reduction add, %0 [1]
829 ///     : vector<8x32x16xf32> to vector<8x16xf32>
830 /// ```
831 /// Gets converted to:
832 /// ```
833 ///   %1 = vector.contract {indexing_maps = [
834 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
835 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
836 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
837 ///    iterator_types = ["parallel", "parallel", "reduction"],
838 ///    kind = add} %0, %arg1, %cst_f0
839 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
840 ///  ```
841 struct MultiReduceToContract
842     : public OpRewritePattern<vector::MultiDimReductionOp> {
843   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
844 
845   LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
846                                 PatternRewriter &rewriter) const override {
847     if (reduceOp.kind() != vector::CombiningKind::ADD)
848       return failure();
849     Operation *mulOp = reduceOp.source().getDefiningOp();
850     if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
851       return failure();
852     SmallVector<bool> reductionMask = reduceOp.getReductionMask();
853     auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
854     SmallVector<AffineExpr> exprs;
855     SmallVector<StringRef> iteratorTypes;
856     for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
857       if (!isReduceDim.value()) {
858         iteratorTypes.push_back(getParallelIteratorTypeName());
859         exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
860       } else {
861         iteratorTypes.push_back(getReductionIteratorTypeName());
862       }
863     }
864     auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
865                                  /*symCount=*/0, exprs, reduceOp.getContext());
866     Value zero = rewriter.create<arith::ConstantOp>(
867         reduceOp.getLoc(), reduceOp.getDestType(),
868         rewriter.getZeroAttr(reduceOp.getDestType()));
869     rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
870         reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
871         rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
872         rewriter.getStrArrayAttr(iteratorTypes));
873     return success();
874   }
875 };
876 
877 /// Merge TransposeOp into ContractionOp user.
878 /// Ex:
879 /// ```
880 ///   %0 = vector.transpose %arg0, [2, 0, 1]
881 ///     : vector<32x16x8xf32> to vector<8x32x16xf32>
882 ///   %1 = vector.contract {indexing_maps = [
883 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
884 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
885 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
886 ///    iterator_types = ["parallel", "parallel", "reduction"],
887 ///    kind = add} %0, %arg1, %cst_f0
888 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
889 /// ```
890 /// Gets converted to:
891 /// ```
892 ///   %1 = vector.contract {indexing_maps = [
893 ///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
894 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
895 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
896 ///    iterator_types = ["parallel", "parallel", "reduction"],
897 ///    kind = add} %arg0, %arg1, %cst_f0
898 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
899 ///  ```
900 struct CombineContractTranspose
901     : public OpRewritePattern<vector::ContractionOp> {
902   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
903 
904   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
905                                 PatternRewriter &rewriter) const override {
906     SmallVector<AffineMap, 4> maps =
907         llvm::to_vector<4>(contractOp.getIndexingMaps());
908     Value lhs = contractOp.lhs();
909     Value rhs = contractOp.rhs();
910     size_t index = 0;
911     bool changed = false;
912     for (Value *operand : {&lhs, &rhs}) {
913       AffineMap &map = maps[index++];
914       auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
915       if (!transposeOp)
916         continue;
917       SmallVector<int64_t> perm;
918       transposeOp.getTransp(perm);
919       AffineMap permutationMap = AffineMap::getPermutationMap(
920           extractVector<unsigned>(transposeOp.transp()),
921           contractOp.getContext());
922       map = inversePermutation(permutationMap).compose(map);
923       *operand = transposeOp.vector();
924       changed = true;
925     }
926     if (!changed)
927       return failure();
928     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
929         contractOp, lhs, rhs, contractOp.acc(),
930         rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
931     return success();
932   }
933 };
934 
935 /// Merge BroadcastOp into ContractionOp user.
936 /// Ex:
937 /// ```
938 ///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
939 ///   %1 = vector.contract {indexing_maps = [
940 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
941 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
942 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
943 ///    iterator_types = ["parallel", "parallel", "reduction"],
944 ///    kind = add} %0, %arg1, %cst_f0
945 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
946 /// ```
947 /// Gets converted to:
948 /// ```
949 ///   %1 = vector.contract {indexing_maps = [
950 ///         affine_map<(d0, d1, d2) -> (d1, d2)>,
951 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
952 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
953 ///    iterator_types = ["parallel", "parallel", "reduction"],
954 ///    kind = add} %arg0, %arg1, %cst_f0
955 ///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
956 ///  ```
957 struct CombineContractBroadcast
958     : public OpRewritePattern<vector::ContractionOp> {
959   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
960 
961   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
962                                 PatternRewriter &rewriter) const override {
963     SmallVector<AffineMap, 4> maps =
964         llvm::to_vector<4>(contractOp.getIndexingMaps());
965     Value lhs = contractOp.lhs();
966     Value rhs = contractOp.rhs();
967     size_t index = 0;
968     bool changed = false;
969     for (Value *operand : {&lhs, &rhs}) {
970       AffineMap &map = maps[index++];
971       auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
972       if (!broadcast)
973         continue;
974       // contractionOp can only take vector as operands.
975       auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
976       if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
977         continue;
978       int64_t rankDiff =
979           broadcast.getVectorType().getRank() - srcType.getRank();
980       bool innerDimBroadcast = false;
981       SmallVector<AffineExpr> originalDims;
982       for (const auto &dim : llvm::enumerate(srcType.getShape())) {
983         if (dim.value() !=
984             broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
985           innerDimBroadcast = true;
986           break;
987         }
988         originalDims.push_back(
989             rewriter.getAffineDimExpr(dim.index() + rankDiff));
990       }
991       // Contract doesn't support inner dimension broadcast. Once this is
992       // relaxed we can remove this case.
993       if (innerDimBroadcast)
994         continue;
995       AffineMap broadcastMap =
996           AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
997                          contractOp.getContext());
998       map = broadcastMap.compose(map);
999       *operand = broadcast.source();
1000       changed = true;
1001     }
1002     if (!changed)
1003       return failure();
1004     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1005         contractOp, lhs, rhs, contractOp.acc(),
1006         rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
1007     return success();
1008   }
1009 };
1010 
1011 /// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
1012 /// contraction ops closer, which kicks in CombineContractBroadcast pattern when
1013 /// casting ops are around these operations.
1014 /// Ex:
1015 /// ```
1016 ///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
1017 ///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
1018 /// ```
1019 /// Gets converted to:
1020 /// ```
1021 ///   %0 = arith.extsi %0 : vector<32x16xi8> to vector<32x16xi32>
1022 ///   %1 = vector.broadcast %arg0 : vector<32x16xi32> to vector<8x32x16xi32>
1023 /// ```
1024 struct ReorderCastOpsOnBroadcast
1025     : public OpInterfaceRewritePattern<CastOpInterface> {
1026   using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
1027 
1028   LogicalResult matchAndRewrite(CastOpInterface op,
1029                                 PatternRewriter &rewriter) const override {
1030     if (op->getNumOperands() != 1)
1031       return failure();
1032     auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
1033     if (!bcastOp)
1034       return failure();
1035 
1036     Type castResTy = getElementTypeOrSelf(op->getResult(0));
1037     if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
1038       castResTy = VectorType::get(vecTy.getShape(), castResTy);
1039     auto castOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(),
1040                                   bcastOp.source(), castResTy, op->getAttrs());
1041     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1042         op, op->getResult(0).getType(), castOp->getResult(0));
1043     return success();
1044   }
1045 };
1046 
1047 /// Reorders cast(transpose) to transpose(cast). This makes broadcast ops and
1048 /// contraction ops closer, which kicks in CombineContractTranspose pattern when
1049 /// casting ops are around these operations.
1050 /// Ex:
1051 /// ```
1052 ///   %0 = vector.transpose %arg0, [2, 0, 1]
1053 ///     : vector<32x16x8xi8> to vector<8x32x16xi8>
1054 ///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
1055 /// ```
1056 /// Gets converted to:
1057 /// ```
1058 ///   %0 = arith.extsi %0 : vector<32x16x8xi8> to vector<32x16x8xi32>
1059 ///   %1 = vector.transpose %arg0, [2, 0, 1]
1060 ///     : vector<32x16x8xi32> to vector<8x32x16xi32>
1061 /// ```
1062 struct ReorderCastOpsOnTranspose
1063     : public OpInterfaceRewritePattern<CastOpInterface> {
1064 
1065   using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
1066 
1067   LogicalResult matchAndRewrite(CastOpInterface op,
1068                                 PatternRewriter &rewriter) const override {
1069     if (op->getNumOperands() != 1)
1070       return failure();
1071     auto transpOp = op->getOperand(0).getDefiningOp<vector::TransposeOp>();
1072     if (!transpOp)
1073       return failure();
1074 
1075     auto castResTy = transpOp.getVectorType();
1076     castResTy = VectorType::get(castResTy.getShape(),
1077                                 getElementTypeOrSelf(op->getResult(0)));
1078     auto castOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(),
1079                                   transpOp.vector(), castResTy, op->getAttrs());
1080     rewriter.replaceOpWithNewOp<vector::TransposeOp>(
1081         op, op->getResult(0).getType(), castOp->getResult(0),
1082         transpOp.getTransp());
1083     return success();
1084   }
1085 };
1086 
1087 } // namespace
1088 
1089 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
1090 /// operands `x` and `y`.
1091 static Value createAdd(Location loc, Value x, Value y, bool isInt,
1092                        PatternRewriter &rewriter) {
1093   if (isInt)
1094     return rewriter.create<arith::AddIOp>(loc, x, y);
1095   return rewriter.create<arith::AddFOp>(loc, x, y);
1096 }
1097 
1098 /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
1099 /// operands `x and `y`.
1100 static Value createMul(Location loc, Value x, Value y, bool isInt,
1101                        PatternRewriter &rewriter) {
1102   if (isInt)
1103     return rewriter.create<arith::MulIOp>(loc, x, y);
1104   return rewriter.create<arith::MulFOp>(loc, x, y);
1105 }
1106 
1107 namespace mlir {
1108 
1109 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1110 /// semantics to:
1111 /// ```
1112 ///    %mta = maybe_transpose
1113 ///    %mtb = maybe_transpose
1114 ///    %flattened_a = vector.shape_cast %mta
1115 ///    %flattened_b = vector.shape_cast %mtb
1116 ///    %flattened_d = vector.matmul %flattened_a, %flattened_b
1117 ///    %mtd = vector.shape_cast %flattened_d
1118 ///    %d = maybe_untranspose %mtd
1119 ///    %e = add %c, %d
1120 /// ```
1121 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1122 //
1123 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
1124 /// vector.transpose operations are inserted if the vector.contract op is not a
1125 /// row-major matrix multiply.
1126 LogicalResult
1127 ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
1128                                                  PatternRewriter &rew) const {
1129   // TODO: implement masks
1130   if (llvm::size(op.masks()) != 0)
1131     return failure();
1132   if (vectorTransformOptions.vectorContractLowering !=
1133       vector::VectorContractLowering::Matmul)
1134     return failure();
1135   if (failed(filter(op)))
1136     return failure();
1137 
1138   auto iteratorTypes = op.iterator_types().getValue();
1139   if (!isParallelIterator(iteratorTypes[0]) ||
1140       !isParallelIterator(iteratorTypes[1]) ||
1141       !isReductionIterator(iteratorTypes[2]))
1142     return failure();
1143 
1144   Type elementType = op.getLhsType().getElementType();
1145   if (!elementType.isIntOrFloat())
1146     return failure();
1147 
1148   // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
1149   // Bail out if the contraction cannot be put in this form.
1150   MLIRContext *ctx = op.getContext();
1151   Location loc = op.getLoc();
1152   AffineExpr m, n, k;
1153   bindDims(rew.getContext(), m, n, k);
1154   // LHS must be A(m, k) or A(k, m).
1155   Value lhs = op.lhs();
1156   auto lhsMap = op.indexing_maps()[0];
1157   if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
1158     lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
1159   else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
1160     return failure();
1161 
1162   // RHS must be B(k, n) or B(n, k).
1163   Value rhs = op.rhs();
1164   auto rhsMap = op.indexing_maps()[1];
1165   if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
1166     rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
1167   else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
1168     return failure();
1169 
1170   // At this point lhs and rhs are in row-major.
1171   VectorType lhsType = lhs.getType().cast<VectorType>();
1172   VectorType rhsType = rhs.getType().cast<VectorType>();
1173   int64_t lhsRows = lhsType.getDimSize(0);
1174   int64_t lhsColumns = lhsType.getDimSize(1);
1175   int64_t rhsColumns = rhsType.getDimSize(1);
1176 
1177   Type flattenedLHSType =
1178       VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
1179   lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
1180 
1181   Type flattenedRHSType =
1182       VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
1183   rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
1184 
1185   Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
1186                                            rhsColumns);
1187   mul = rew.create<vector::ShapeCastOp>(
1188       loc,
1189       VectorType::get({lhsRows, rhsColumns},
1190                       getElementTypeOrSelf(op.acc().getType())),
1191       mul);
1192 
1193   // ACC must be C(m, n) or C(n, m).
1194   auto accMap = op.indexing_maps()[2];
1195   if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
1196     mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
1197   else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
1198     llvm_unreachable("invalid contraction semantics");
1199 
1200   Value res =
1201       elementType.isa<IntegerType>()
1202           ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.acc(), mul))
1203           : static_cast<Value>(rew.create<arith::AddFOp>(loc, op.acc(), mul));
1204 
1205   rew.replaceOp(op, res);
1206   return success();
1207 }
1208 
1209 namespace {
1210 struct IteratorType {
1211   IteratorType(StringRef strRef) : strRef(strRef) {}
1212   bool isOfType(Attribute attr) const {
1213     auto sAttr = attr.dyn_cast<StringAttr>();
1214     return sAttr && sAttr.getValue() == strRef;
1215   }
1216   StringRef strRef;
1217 };
1218 struct Par : public IteratorType {
1219   Par() : IteratorType(getParallelIteratorTypeName()) {}
1220 };
1221 struct Red : public IteratorType {
1222   Red() : IteratorType(getReductionIteratorTypeName()) {}
1223 };
1224 
1225 /// Generate a vector implementation for matmat, matvec and tmatvec.
1226 /// This unrolls outer-products along the reduction dimension.
1227 struct UnrolledOuterProductGenerator
1228     : public StructuredGenerator<vector::ContractionOp> {
1229 
1230   UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
1231       : StructuredGenerator<vector::ContractionOp>(builder, op),
1232         kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()),
1233         lhsType(op.getLhsType()) {}
1234 
1235   Value t(Value v) {
1236     static constexpr std::array<int64_t, 2> perm = {1, 0};
1237     return builder.create<vector::TransposeOp>(loc, v, perm);
1238   }
1239 
1240   Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
1241     assert(reductionSize > 0);
1242     for (int64_t k = 0; k < reductionSize; ++k) {
1243       Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
1244       Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
1245       res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
1246                                                    res, kind);
1247     }
1248     return res;
1249   }
1250 
1251   /// Two outer parallel, one inner reduction (matmat flavor).
1252   FailureOr<Value> matmat() {
1253     if (!iters({Par(), Par(), Red()}))
1254       return failure();
1255     // Set up the parallel/reduction structure in the right form.
1256     AffineExpr m, n, k;
1257     bindDims(builder.getContext(), m, n, k);
1258     // Classical row-major matmul:  Just permute the lhs.
1259     if (layout({{m, k}, {k, n}, {m, n}}))
1260       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1261     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1262     if (layout({{m, k}, {n, k}, {m, n}})) {
1263       Value tlhs = t(lhs);
1264       return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
1265     }
1266     // No need to permute anything.
1267     if (layout({{k, m}, {k, n}, {m, n}}))
1268       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1269     // Just permute the rhs.
1270     if (layout({{k, m}, {n, k}, {m, n}}))
1271       return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
1272     // Transposed output: swap RHS and LHS.
1273     // Classical row-major matmul: permute the lhs.
1274     if (layout({{m, k}, {k, n}, {n, m}}))
1275       return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
1276     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1277     if (layout({{m, k}, {n, k}, {n, m}})) {
1278       Value trhs = t(rhs);
1279       return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
1280     }
1281     if (layout({{k, m}, {k, n}, {n, m}}))
1282       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1283     if (layout({{k, m}, {n, k}, {n, m}}))
1284       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1285     return failure();
1286   }
1287 
1288   /// One outer parallel, one inner reduction (matvec flavor)
1289   FailureOr<Value> matvec() {
1290     if (!iters({Par(), Red()}))
1291       return failure();
1292     AffineExpr m, k;
1293     bindDims(builder.getContext(), m, k);
1294 
1295     // Case mat-vec: transpose.
1296     if (layout({{m, k}, {k}, {m}}))
1297       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1298     // Case mat-trans-vec: ready to go.
1299     if (layout({{k, m}, {k}, {m}}))
1300       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1301     // Case vec-mat: swap and transpose.
1302     if (layout({{k}, {m, k}, {m}}))
1303       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1304     // Case vec-mat-trans: swap and ready to go.
1305     if (layout({{k}, {k, m}, {m}}))
1306       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1307     return failure();
1308   }
1309 
1310   //
1311   // One outer reduction, one inner parallel (tmatvec flavor)
1312   //
1313   FailureOr<Value> tmatvec() {
1314     if (!iters({Red(), Par()}))
1315       return failure();
1316     AffineExpr k, m;
1317     bindDims(builder.getContext(), k, m);
1318 
1319     // Case mat-vec: transpose.
1320     if (layout({{m, k}, {k}, {m}}))
1321       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1322     // Case mat-trans-vec: ready to go.
1323     if (layout({{k, m}, {k}, {m}}))
1324       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1325     // Case vec-mat: swap and transpose.
1326     if (layout({{k}, {m, k}, {m}}))
1327       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1328     // Case vec-mat-trans: swap and ready to go.
1329     if (layout({{k}, {k, m}, {m}}))
1330       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1331     return failure();
1332   }
1333 
1334 private:
1335   vector::CombiningKind kind;
1336   Value lhs, rhs, res;
1337   VectorType lhsType;
1338 };
1339 } // namespace
1340 
1341 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1342 /// semantics to a reduction_size-unrolled sequence:
1343 /// ```
1344 ///    %at = vector.transpose %a, [1, 0]
1345 ///    %bRow0 = vector.extract %b[0]
1346 ///    %atRow0 = vector.extract %at[0]
1347 ///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
1348 ///    ...
1349 ///    %bRowK = vector.extract %b[K]
1350 ///    %atRowK = vector.extract %at[K]
1351 ///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
1352 /// ```
1353 ///
1354 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
1355 /// otherwise supports any layout permutation of the matrix-multiply.
1356 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
1357     vector::ContractionOp op, PatternRewriter &rewriter) const {
1358   // TODO: implement masks
1359   if (llvm::size(op.masks()) != 0)
1360     return failure();
1361 
1362   if (vectorTransformOptions.vectorContractLowering !=
1363       vector::VectorContractLowering::OuterProduct)
1364     return failure();
1365 
1366   if (failed(filter(op)))
1367     return failure();
1368 
1369   UnrolledOuterProductGenerator e(rewriter, op);
1370   FailureOr<Value> matmatRes = e.matmat();
1371   if (succeeded(matmatRes)) {
1372     rewriter.replaceOp(op, *matmatRes);
1373     return success();
1374   }
1375   FailureOr<Value> matvecRes = e.matvec();
1376   if (succeeded(matvecRes)) {
1377     rewriter.replaceOp(op, *matvecRes);
1378     return success();
1379   }
1380   FailureOr<Value> tmatvecRes = e.tmatvec();
1381   if (succeeded(tmatvecRes)) {
1382     rewriter.replaceOp(op, *tmatvecRes);
1383     return success();
1384   }
1385 
1386   return failure();
1387 }
1388 
1389 LogicalResult
1390 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
1391                                             PatternRewriter &rewriter) const {
1392   // TODO: implement masks
1393   if (llvm::size(op.masks()) != 0)
1394     return failure();
1395 
1396   if (failed(filter(op)))
1397     return failure();
1398 
1399   if (vectorTransformOptions.vectorContractLowering !=
1400       vector::VectorContractLowering::Dot)
1401     return failure();
1402 
1403   auto iteratorTypes = op.iterator_types().getValue();
1404   static constexpr std::array<int64_t, 2> perm = {1, 0};
1405   Location loc = op.getLoc();
1406   Value lhs = op.lhs(), rhs = op.rhs();
1407 
1408   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1409   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1410   AffineExpr m, n, k;
1411   bindDims(rewriter.getContext(), m, n, k);
1412   SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
1413   //
1414   // In the following we wish to make the reduction dimension innermost so we
1415   // can load vectors and just fmul + reduce into a scalar.
1416   //
1417   if (isParallelIterator(iteratorTypes[0]) &&
1418       isParallelIterator(iteratorTypes[1]) &&
1419       isReductionIterator(iteratorTypes[2])) {
1420     //
1421     // Two outer parallel, one inner reduction (matmat flavor).
1422     //
1423     if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1424       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1425     } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
1426       // No need to permute anything.
1427     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1428       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1429       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1430     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1431       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1432     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1433       // This is the classical row-major matmul. Just permute the lhs.
1434       Value tmp = lhs;
1435       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1436       rhs = tmp;
1437     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1438       std::swap(lhs, rhs);
1439     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1440       Value tmp = lhs;
1441       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1442       rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
1443     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1444       Value tmp = rhs;
1445       rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1446       lhs = tmp;
1447     } else {
1448       return failure();
1449     }
1450   } else if (isParallelIterator(iteratorTypes[0]) &&
1451              isReductionIterator(iteratorTypes[1])) {
1452     //
1453     // One outer parallel, one inner reduction (matvec flavor)
1454     //
1455     if (maps == infer({{m, n}, {n}, {m}})) {
1456       // No need to permute anything.
1457     } else if (maps == infer({{n, m}, {n}, {m}})) {
1458       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1459     } else if (maps == infer({{n}, {m, n}, {m}})) {
1460       std::swap(lhs, rhs);
1461     } else if (maps == infer({{n}, {n, m}, {m}})) {
1462       std::swap(lhs, rhs);
1463       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1464     } else {
1465       return failure();
1466     }
1467   } else {
1468     return failure();
1469   }
1470 
1471   VectorType dstType = op.getResultType().cast<VectorType>();
1472   assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
1473          "Expected dst type of rank 1 or 2");
1474 
1475   unsigned rank = dstType.getRank();
1476   unsigned dstRows = dstType.getShape()[0];
1477   unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
1478 
1479   // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
1480   Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
1481                                                  rewriter.getZeroAttr(dstType));
1482   bool isInt = dstType.getElementType().isa<IntegerType>();
1483   for (unsigned r = 0; r < dstRows; ++r) {
1484     Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
1485     for (unsigned c = 0; c < dstColumns; ++c) {
1486       Value b = rank == 1
1487                     ? rhs
1488                     : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
1489       Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
1490       Value reduced = rewriter.create<vector::ReductionOp>(
1491           op.getLoc(), vector::CombiningKind::ADD, m);
1492 
1493       SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
1494                                               : SmallVector<int64_t, 2>{r, c};
1495       res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
1496     }
1497   }
1498   if (auto acc = op.acc())
1499     res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
1500   rewriter.replaceOp(op, res);
1501   return success();
1502 }
1503 
1504 /// Progressive lowering of ContractionOp.
1505 /// One:
1506 ///   %x = vector.contract with at least one free/batch dimension
1507 /// is replaced by:
1508 ///   %a = vector.contract with one less free/batch dimension
1509 ///   %b = vector.contract with one less free/batch dimension
1510 ///   ..
1511 ///   %x = combine %a %b ..
1512 /// until a pure contraction is reached (no free/batch dimensions),
1513 /// which is replaced by a dot-product.
1514 ///
1515 /// This only kicks in when either VectorTransformsOptions is set
1516 /// to DOT or when other contraction patterns fail.
1517 //
1518 // TODO: break down into transpose/reshape/cast ops
1519 //               when they become available to avoid code dup
1520 // TODO: investigate lowering order impact on performance
1521 LogicalResult
1522 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
1523                                        PatternRewriter &rewriter) const {
1524   // TODO: implement masks.
1525   if (llvm::size(op.masks()) != 0)
1526     return failure();
1527 
1528   if (failed(filter(op)))
1529     return failure();
1530 
1531   // TODO: support mixed mode contract lowering.
1532   if (op.getLhsType().getElementType() !=
1533           getElementTypeOrSelf(op.getAccType()) ||
1534       op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
1535     return failure();
1536 
1537   // TODO: implement benefits, cost models.
1538   MLIRContext *ctx = op.getContext();
1539   ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
1540   if (succeeded(pat1.matchAndRewrite(op, rewriter)))
1541     return success();
1542   ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
1543   if (succeeded(pat2.matchAndRewrite(op, rewriter)))
1544     return success();
1545   ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
1546   if (succeeded(pat3.matchAndRewrite(op, rewriter)))
1547     return success();
1548 
1549   // Find first batch dimension in LHS/RHS, and lower when found.
1550   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
1551   if (!batchDimMap.empty()) {
1552     int64_t lhsIndex = batchDimMap[0].first;
1553     int64_t rhsIndex = batchDimMap[0].second;
1554     rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
1555     return success();
1556   }
1557 
1558   // Collect contracting dimensions.
1559   std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
1560       op.getContractingDimMap();
1561   DenseSet<int64_t> lhsContractingDimSet;
1562   DenseSet<int64_t> rhsContractingDimSet;
1563   for (auto &dimPair : contractingDimMap) {
1564     lhsContractingDimSet.insert(dimPair.first);
1565     rhsContractingDimSet.insert(dimPair.second);
1566   }
1567 
1568   // Find first free dimension in LHS, and lower when found.
1569   VectorType lhsType = op.getLhsType();
1570   for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
1571     if (lhsContractingDimSet.count(lhsIndex) == 0) {
1572       rewriter.replaceOp(
1573           op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
1574       return success();
1575     }
1576   }
1577 
1578   // Find first free dimension in RHS, and lower when found.
1579   VectorType rhsType = op.getRhsType();
1580   for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
1581     if (rhsContractingDimSet.count(rhsIndex) == 0) {
1582       rewriter.replaceOp(
1583           op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
1584       return success();
1585     }
1586   }
1587 
1588   // Lower the first remaining reduction dimension.
1589   if (!contractingDimMap.empty()) {
1590     rewriter.replaceOp(op, lowerReduction(op, rewriter));
1591     return success();
1592   }
1593 
1594   return failure();
1595 }
1596 
1597 // Lower one parallel dimension.
1598 // TODO: consider reusing existing contract unrolling
1599 Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
1600                                            int64_t lhsIndex, int64_t rhsIndex,
1601                                            PatternRewriter &rewriter) const {
1602   VectorType lhsType = op.getLhsType();
1603   VectorType rhsType = op.getRhsType();
1604   VectorType resType = op.getResultType().cast<VectorType>();
1605   // Find the iterator type index and result index.
1606   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
1607   int64_t iterIndex = -1;
1608   int64_t dimSize = -1;
1609   if (lhsIndex >= 0) {
1610     iterIndex = iMap[0].getDimPosition(lhsIndex);
1611     assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
1612            "parallel index should be free in LHS or batch in LHS/RHS");
1613     dimSize = lhsType.getDimSize(lhsIndex);
1614   } else {
1615     assert(rhsIndex >= 0 && "missing parallel index");
1616     iterIndex = iMap[1].getDimPosition(rhsIndex);
1617     dimSize = rhsType.getDimSize(rhsIndex);
1618   }
1619   assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
1620   Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
1621   assert(lookup.hasValue() && "parallel index not listed in reduction");
1622   int64_t resIndex = lookup.getValue();
1623   // Construct new iterator types and affine map array attribute.
1624   std::array<AffineMap, 3> lowIndexingMaps = {
1625       adjustMap(iMap[0], iterIndex, rewriter),
1626       adjustMap(iMap[1], iterIndex, rewriter),
1627       adjustMap(iMap[2], iterIndex, rewriter)};
1628   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1629   auto lowIter =
1630       rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
1631   // Unroll into a series of lower dimensional vector.contract ops.
1632   Location loc = op.getLoc();
1633   Value result = rewriter.create<arith::ConstantOp>(
1634       loc, resType, rewriter.getZeroAttr(resType));
1635   for (int64_t d = 0; d < dimSize; ++d) {
1636     auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
1637     auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
1638     auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
1639     Value lowContract = rewriter.create<vector::ContractionOp>(
1640         loc, lhs, rhs, acc, lowAffine, lowIter);
1641     result =
1642         reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
1643   }
1644   return result;
1645 }
1646 
1647 // Lower one reduction dimension.
1648 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
1649                                             PatternRewriter &rewriter) const {
1650   auto loc = op.getLoc();
1651   VectorType lhsType = op.getLhsType();
1652   VectorType rhsType = op.getRhsType();
1653   Type resType = op.getResultType();
1654   assert(!resType.isa<VectorType>());
1655   bool isInt = resType.isa<IntegerType>();
1656   // Use iterator index 0.
1657   int64_t iterIndex = 0;
1658   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
1659   Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1660   Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1661   assert(lookupLhs.hasValue() && "missing LHS parallel index");
1662   assert(lookupRhs.hasValue() && "missing RHS parallel index");
1663   int64_t lhsIndex = lookupLhs.getValue();
1664   int64_t rhsIndex = lookupRhs.getValue();
1665   int64_t dimSize = lhsType.getDimSize(lhsIndex);
1666   assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
1667   // Base case.
1668   if (lhsType.getRank() == 1) {
1669     assert(rhsType.getRank() == 1 && "corrupt contraction");
1670     Value m = createMul(loc, op.lhs(), op.rhs(), isInt, rewriter);
1671     auto kind = vector::CombiningKind::ADD;
1672     Value res = rewriter.create<vector::ReductionOp>(loc, kind, m);
1673     if (auto acc = op.acc())
1674       res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
1675     return res;
1676   }
1677   // Construct new iterator types and affine map array attribute.
1678   std::array<AffineMap, 3> lowIndexingMaps = {
1679       adjustMap(iMap[0], iterIndex, rewriter),
1680       adjustMap(iMap[1], iterIndex, rewriter),
1681       adjustMap(iMap[2], iterIndex, rewriter)};
1682   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1683   auto lowIter =
1684       rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
1685   // Unroll into a series of lower dimensional vector.contract ops.
1686   // By feeding the initial accumulator into the first contraction,
1687   // and the result of each contraction into the next, eventually
1688   // the sum of all reductions is computed.
1689   Value result = op.acc();
1690   for (int64_t d = 0; d < dimSize; ++d) {
1691     auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
1692     auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
1693     result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
1694                                                     lowAffine, lowIter);
1695   }
1696   return result;
1697 }
1698 
1699 } // namespace mlir
1700 
1701 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
1702     OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
1703     ArrayRef<int64_t> multiplicity, const AffineMap &map) {
1704   OpBuilder::InsertionGuard guard(builder);
1705   builder.setInsertionPointAfter(op);
1706   Location loc = op->getLoc();
1707   if (op->getNumResults() != 1)
1708     return {};
1709   Value result = op->getResult(0);
1710   VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
1711   if (!type || map.getNumResults() != multiplicity.size())
1712     return {};
1713   // For each dimension being distributed check that the size is a multiple of
1714   // the multiplicity. To handle more sizes we would need to support masking.
1715   unsigned multiplictyCount = 0;
1716   for (auto exp : map.getResults()) {
1717     auto affinExp = exp.dyn_cast<AffineDimExpr>();
1718     if (!affinExp || affinExp.getPosition() >= type.getRank() ||
1719         type.getDimSize(affinExp.getPosition()) %
1720                 multiplicity[multiplictyCount++] !=
1721             0)
1722       return {};
1723   }
1724   DistributeOps ops;
1725   ops.extract =
1726       builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
1727   ops.insert =
1728       builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
1729   return ops;
1730 }
1731 
1732 /// Progressive lowering of transfer_read. This pattern supports lowering of
1733 /// `vector.transfer_read` to a combination of `vector.load` and
1734 /// `vector.broadcast` if all of the following hold:
1735 /// - Stride of most minor memref dimension must be 1.
1736 /// - Out-of-bounds masking is not required.
1737 /// - If the memref's element type is a vector type then it coincides with the
1738 ///   result type.
1739 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
1740 struct TransferReadToVectorLoadLowering
1741     : public OpRewritePattern<vector::TransferReadOp> {
1742   TransferReadToVectorLoadLowering(MLIRContext *context,
1743                                    llvm::Optional<unsigned> maxRank)
1744       : OpRewritePattern<vector::TransferReadOp>(context),
1745         maxTransferRank(maxRank) {}
1746 
1747   LogicalResult matchAndRewrite(vector::TransferReadOp read,
1748                                 PatternRewriter &rewriter) const override {
1749     if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
1750       return failure();
1751 
1752     SmallVector<unsigned, 4> broadcastedDims;
1753     // Permutations are handled by VectorToSCF or
1754     // populateVectorTransferPermutationMapLoweringPatterns.
1755     // We let the 0-d corner case pass-through as it is supported.
1756     if (!read.permutation_map().isMinorIdentityWithBroadcasting(
1757             &broadcastedDims))
1758       return failure();
1759 
1760     auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
1761     if (!memRefType)
1762       return failure();
1763 
1764     // Non-unit strides are handled by VectorToSCF.
1765     if (!vector::isLastMemrefDimUnitStride(memRefType))
1766       return failure();
1767 
1768     // If there is broadcasting involved then we first load the unbroadcasted
1769     // vector, and then broadcast it with `vector.broadcast`.
1770     ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
1771     SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
1772                                                      vectorShape.end());
1773     for (unsigned i : broadcastedDims)
1774       unbroadcastedVectorShape[i] = 1;
1775     VectorType unbroadcastedVectorType = VectorType::get(
1776         unbroadcastedVectorShape, read.getVectorType().getElementType());
1777 
1778     // `vector.load` supports vector types as memref's elements only when the
1779     // resulting vector type is the same as the element type.
1780     auto memrefElTy = memRefType.getElementType();
1781     if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
1782       return failure();
1783 
1784     // Otherwise, element types of the memref and the vector must match.
1785     if (!memrefElTy.isa<VectorType>() &&
1786         memrefElTy != read.getVectorType().getElementType())
1787       return failure();
1788 
1789     // Out-of-bounds dims are handled by MaterializeTransferMask.
1790     if (read.hasOutOfBoundsDim())
1791       return failure();
1792 
1793     // Create vector load op.
1794     Operation *loadOp;
1795     if (read.mask()) {
1796       Value fill = rewriter.create<vector::SplatOp>(
1797           read.getLoc(), unbroadcastedVectorType, read.padding());
1798       loadOp = rewriter.create<vector::MaskedLoadOp>(
1799           read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(),
1800           read.mask(), fill);
1801     } else {
1802       loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
1803                                                unbroadcastedVectorType,
1804                                                read.source(), read.indices());
1805     }
1806 
1807     // Insert a broadcasting op if required.
1808     if (!broadcastedDims.empty()) {
1809       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1810           read, read.getVectorType(), loadOp->getResult(0));
1811     } else {
1812       rewriter.replaceOp(read, loadOp->getResult(0));
1813     }
1814 
1815     return success();
1816   }
1817 
1818   llvm::Optional<unsigned> maxTransferRank;
1819 };
1820 
1821 /// Replace a 0-d vector.load with a memref.load + vector.broadcast.
1822 // TODO: we shouldn't cross the vector/scalar domains just for this
1823 // but atm we lack the infra to avoid it. Possible solutions include:
1824 // - go directly to LLVM + bitcast
1825 // - introduce a bitcast op and likely a new pointer dialect
1826 // - let memref.load/store additionally support the 0-d vector case
1827 // There are still deeper data layout issues lingering even in this
1828 // trivial case (for architectures for which this matters).
1829 struct VectorLoadToMemrefLoadLowering
1830     : public OpRewritePattern<vector::LoadOp> {
1831   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
1832 
1833   LogicalResult matchAndRewrite(vector::LoadOp loadOp,
1834                                 PatternRewriter &rewriter) const override {
1835     auto vecType = loadOp.getVectorType();
1836     if (vecType.getNumElements() != 1)
1837       return failure();
1838     auto memrefLoad = rewriter.create<memref::LoadOp>(
1839         loadOp.getLoc(), loadOp.base(), loadOp.indices());
1840     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
1841                                                      memrefLoad);
1842     return success();
1843   }
1844 };
1845 
1846 /// Replace a 0-d vector.store with a vector.extractelement + memref.store.
1847 struct VectorStoreToMemrefStoreLowering
1848     : public OpRewritePattern<vector::StoreOp> {
1849   using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
1850 
1851   LogicalResult matchAndRewrite(vector::StoreOp storeOp,
1852                                 PatternRewriter &rewriter) const override {
1853     auto vecType = storeOp.getVectorType();
1854     if (vecType.getNumElements() != 1)
1855       return failure();
1856     Value extracted;
1857     if (vecType.getRank() == 0) {
1858       // TODO: Unifiy once ExtractOp supports 0-d vectors.
1859       extracted = rewriter.create<vector::ExtractElementOp>(
1860           storeOp.getLoc(), storeOp.valueToStore());
1861     } else {
1862       SmallVector<int64_t> indices(vecType.getRank(), 0);
1863       extracted = rewriter.create<vector::ExtractOp>(
1864           storeOp.getLoc(), storeOp.valueToStore(), indices);
1865     }
1866 
1867     rewriter.replaceOpWithNewOp<memref::StoreOp>(
1868         storeOp, extracted, storeOp.base(), storeOp.indices());
1869     return success();
1870   }
1871 };
1872 
1873 /// Progressive lowering of transfer_write. This pattern supports lowering of
1874 /// `vector.transfer_write` to `vector.store` if all of the following hold:
1875 /// - Stride of most minor memref dimension must be 1.
1876 /// - Out-of-bounds masking is not required.
1877 /// - If the memref's element type is a vector type then it coincides with the
1878 ///   type of the written value.
1879 /// - The permutation map is the minor identity map (neither permutation nor
1880 ///   broadcasting is allowed).
1881 struct TransferWriteToVectorStoreLowering
1882     : public OpRewritePattern<vector::TransferWriteOp> {
1883   TransferWriteToVectorStoreLowering(MLIRContext *context,
1884                                      llvm::Optional<unsigned> maxRank)
1885       : OpRewritePattern<vector::TransferWriteOp>(context),
1886         maxTransferRank(maxRank) {}
1887 
1888   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
1889                                 PatternRewriter &rewriter) const override {
1890     if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
1891       return failure();
1892 
1893     // Permutations are handled by VectorToSCF or
1894     // populateVectorTransferPermutationMapLoweringPatterns.
1895     if ( // pass-through for the 0-d corner case.
1896         !write.permutation_map().isMinorIdentity())
1897       return failure();
1898 
1899     auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
1900     if (!memRefType)
1901       return failure();
1902 
1903     // Non-unit strides are handled by VectorToSCF.
1904     if (!vector::isLastMemrefDimUnitStride(memRefType))
1905       return failure();
1906 
1907     // `vector.store` supports vector types as memref's elements only when the
1908     // type of the vector value being written is the same as the element type.
1909     auto memrefElTy = memRefType.getElementType();
1910     if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
1911       return failure();
1912 
1913     // Otherwise, element types of the memref and the vector must match.
1914     if (!memrefElTy.isa<VectorType>() &&
1915         memrefElTy != write.getVectorType().getElementType())
1916       return failure();
1917 
1918     // Out-of-bounds dims are handled by MaterializeTransferMask.
1919     if (write.hasOutOfBoundsDim())
1920       return failure();
1921     if (write.mask()) {
1922       rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
1923           write, write.source(), write.indices(), write.mask(), write.vector());
1924     } else {
1925       rewriter.replaceOpWithNewOp<vector::StoreOp>(
1926           write, write.vector(), write.source(), write.indices());
1927     }
1928     return success();
1929   }
1930 
1931   llvm::Optional<unsigned> maxTransferRank;
1932 };
1933 
1934 // Returns the values in `arrayAttr` as an integer vector.
1935 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
1936   return llvm::to_vector<4>(
1937       llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
1938                       [](IntegerAttr attr) { return attr.getInt(); }));
1939 }
1940 
1941 // Shuffles vector.bitcast op after vector.extract op.
1942 //
1943 // This transforms IR like:
1944 //   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
1945 //   %1 = vector.extract %0[3] : vector<8xf16>
1946 // Into:
1947 //   %0 = vector.extract %src[1] : vector<4xf32>
1948 //   %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
1949 //   %2 = vector.extract %1[1] : vector<2xf16>
1950 struct BubbleDownVectorBitCastForExtract
1951     : public OpRewritePattern<vector::ExtractOp> {
1952   using OpRewritePattern::OpRewritePattern;
1953 
1954   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
1955                                 PatternRewriter &rewriter) const override {
1956     // Only support extracting scalars for now.
1957     if (extractOp.getVectorType().getRank() != 1)
1958       return failure();
1959 
1960     auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
1961     if (!castOp)
1962       return failure();
1963 
1964     VectorType castSrcType = castOp.getSourceVectorType();
1965     VectorType castDstType = castOp.getResultVectorType();
1966     assert(castSrcType.getRank() == castDstType.getRank());
1967 
1968     // Fail to match if we only have one element in the cast op source.
1969     // This is to avoid infinite loop given that this pattern can generate
1970     // such cases.
1971     if (castSrcType.getNumElements() == 1)
1972       return failure();
1973 
1974     // Only support casting to a larger number of elements or now.
1975     // E.g., vector<4xf32> -> vector<8xf16>.
1976     if (castSrcType.getNumElements() > castDstType.getNumElements())
1977       return failure();
1978 
1979     unsigned expandRatio =
1980         castDstType.getNumElements() / castSrcType.getNumElements();
1981 
1982     auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
1983       return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
1984     };
1985 
1986     uint64_t index = getFirstIntValue(extractOp.position());
1987 
1988     // Get the single scalar (as a vector) in the source value that packs the
1989     // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
1990     VectorType oneScalarType =
1991         VectorType::get({1}, castSrcType.getElementType());
1992     Value packedValue = rewriter.create<vector::ExtractOp>(
1993         extractOp.getLoc(), oneScalarType, castOp.source(),
1994         rewriter.getI64ArrayAttr(index / expandRatio));
1995 
1996     // Cast it to a vector with the desired scalar's type.
1997     // E.g. f32 -> vector<2xf16>
1998     VectorType packedType =
1999         VectorType::get({expandRatio}, castDstType.getElementType());
2000     Value castedValue = rewriter.create<vector::BitCastOp>(
2001         extractOp.getLoc(), packedType, packedValue);
2002 
2003     // Finally extract the desired scalar.
2004     rewriter.replaceOpWithNewOp<vector::ExtractOp>(
2005         extractOp, extractOp.getType(), castedValue,
2006         rewriter.getI64ArrayAttr(index % expandRatio));
2007 
2008     return success();
2009   }
2010 };
2011 
2012 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
2013 //
2014 // This transforms IR like:
2015 //    %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
2016 //     %0 = vector.extract_strided_slice %cast {
2017 //            offsets = [4], sizes = [4], strides = [1]
2018 //          } : vector<8xf16> to vector<4xf16>
2019 // Into:
2020 //   %0 = vector.extract_strided_slice %src {
2021 //          offsets = [2], sizes = [2], strides = [1]
2022 //        } : vector<4xf32> to vector<2xf32>
2023 //   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
2024 struct BubbleDownBitCastForStridedSliceExtract
2025     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
2026   using OpRewritePattern::OpRewritePattern;
2027 
2028   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
2029                                 PatternRewriter &rewriter) const override {
2030     auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
2031     if (!castOp)
2032       return failure();
2033 
2034     VectorType castSrcType = castOp.getSourceVectorType();
2035     VectorType castDstType = castOp.getResultVectorType();
2036     assert(castSrcType.getRank() == castDstType.getRank());
2037 
2038     int64_t castSrcLastDim = castSrcType.getShape().back();
2039     int64_t castDstLastDim = castDstType.getShape().back();
2040     // Require casting to more elements for now; other cases to be implemented.
2041     if (castSrcLastDim > castDstLastDim)
2042       return failure();
2043 
2044     // Only accept all one strides for now.
2045     if (llvm::any_of(extractOp.strides().getAsValueRange<IntegerAttr>(),
2046                      [](const APInt &val) { return !val.isOneValue(); }))
2047       return failure();
2048 
2049     unsigned rank = extractOp.getVectorType().getRank();
2050     assert(castDstLastDim % castSrcLastDim == 0);
2051     int64_t expandRatio = castDstLastDim / castSrcLastDim;
2052 
2053     // If we have a less number of offsets than the rank, then implicitly we
2054     // are selecting the full range for the last bitcasted dimension; other
2055     // dimensions aren't affected. Otherwise, we need to scale down the last
2056     // dimension's offset given we are extracting from less elements now.
2057     ArrayAttr newOffsets = extractOp.offsets();
2058     if (newOffsets.size() == rank) {
2059       SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2060       if (offsets.back() % expandRatio != 0)
2061         return failure();
2062       offsets.back() = offsets.back() / expandRatio;
2063       newOffsets = rewriter.getI64ArrayAttr(offsets);
2064     }
2065 
2066     // Similarly for sizes.
2067     ArrayAttr newSizes = extractOp.sizes();
2068     if (newSizes.size() == rank) {
2069       SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
2070       if (sizes.back() % expandRatio != 0)
2071         return failure();
2072       sizes.back() = sizes.back() / expandRatio;
2073       newSizes = rewriter.getI64ArrayAttr(sizes);
2074     }
2075 
2076     SmallVector<int64_t, 4> dims =
2077         llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
2078     dims.back() = dims.back() / expandRatio;
2079     VectorType newExtractType =
2080         VectorType::get(dims, castSrcType.getElementType());
2081 
2082     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
2083         extractOp.getLoc(), newExtractType, castOp.source(), newOffsets,
2084         newSizes, extractOp.strides());
2085 
2086     rewriter.replaceOpWithNewOp<vector::BitCastOp>(
2087         extractOp, extractOp.getType(), newExtractOp);
2088 
2089     return success();
2090   }
2091 };
2092 
2093 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
2094 //
2095 // This transforms IR like:
2096 //   %0 = vector.insert_strided_slice %src, %dst {
2097 //          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
2098 //   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
2099 // Into:
2100 //   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
2101 //   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
2102 //   %2 = vector.insert_strided_slice %src, %dst {
2103 //          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
2104 struct BubbleUpBitCastForStridedSliceInsert
2105     : public OpRewritePattern<vector::BitCastOp> {
2106   using OpRewritePattern::OpRewritePattern;
2107   LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
2108                                 PatternRewriter &rewriter) const override {
2109     VectorType castSrcType = bitcastOp.getSourceVectorType();
2110     VectorType castDstType = bitcastOp.getResultVectorType();
2111     assert(castSrcType.getRank() == castDstType.getRank());
2112 
2113     int64_t castSrcLastDim = castSrcType.getShape().back();
2114     int64_t castDstLastDim = castDstType.getShape().back();
2115     // Require casting to less elements for now; other cases to be implemented.
2116     if (castSrcLastDim < castDstLastDim)
2117       return failure();
2118 
2119     assert(castSrcLastDim % castDstLastDim == 0);
2120     int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
2121 
2122     auto insertOp =
2123         bitcastOp.source().getDefiningOp<vector::InsertStridedSliceOp>();
2124     if (!insertOp)
2125       return failure();
2126 
2127     // Only accept all one strides for now.
2128     if (llvm::any_of(insertOp.strides().getAsValueRange<IntegerAttr>(),
2129                      [](const APInt &val) { return !val.isOneValue(); }))
2130       return failure();
2131 
2132     unsigned rank = insertOp.getSourceVectorType().getRank();
2133     // Require insert op to have the same rank for the source and destination
2134     // vector; other cases to be implemented.
2135     if (rank != insertOp.getDestVectorType().getRank())
2136       return failure();
2137 
2138     ArrayAttr newOffsets = insertOp.offsets();
2139     assert(newOffsets.size() == rank);
2140     SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2141     if (offsets.back() % shrinkRatio != 0)
2142       return failure();
2143     offsets.back() = offsets.back() / shrinkRatio;
2144     newOffsets = rewriter.getI64ArrayAttr(offsets);
2145 
2146     SmallVector<int64_t, 4> srcDims =
2147         llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
2148     srcDims.back() = srcDims.back() / shrinkRatio;
2149     VectorType newCastSrcType =
2150         VectorType::get(srcDims, castDstType.getElementType());
2151 
2152     auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
2153         bitcastOp.getLoc(), newCastSrcType, insertOp.source());
2154 
2155     SmallVector<int64_t, 4> dstDims =
2156         llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
2157     dstDims.back() = dstDims.back() / shrinkRatio;
2158     VectorType newCastDstType =
2159         VectorType::get(dstDims, castDstType.getElementType());
2160 
2161     auto newCastDstOp = rewriter.create<vector::BitCastOp>(
2162         bitcastOp.getLoc(), newCastDstType, insertOp.dest());
2163 
2164     rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
2165         bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
2166         insertOp.strides());
2167 
2168     return success();
2169   }
2170 };
2171 
2172 // Helper that returns a vector comparison that constructs a mask:
2173 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
2174 //
2175 // If `dim == 0` then the result will be a 0-D vector.
2176 //
2177 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
2178 //       much more compact, IR for this operation, but LLVM eventually
2179 //       generates more elaborate instructions for this intrinsic since it
2180 //       is very conservative on the boundary conditions.
2181 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
2182                                    bool indexOptimizations, int64_t dim,
2183                                    Value b, Value *off = nullptr) {
2184   auto loc = op->getLoc();
2185   // If we can assume all indices fit in 32-bit, we perform the vector
2186   // comparison in 32-bit to get a higher degree of SIMD parallelism.
2187   // Otherwise we perform the vector comparison using 64-bit indices.
2188   Type idxType =
2189       indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
2190   DenseIntElementsAttr indicesAttr;
2191   if (dim == 0 && indexOptimizations) {
2192     indicesAttr = DenseIntElementsAttr::get(
2193         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
2194   } else if (dim == 0) {
2195     indicesAttr = DenseIntElementsAttr::get(
2196         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
2197   } else if (indexOptimizations) {
2198     indicesAttr = rewriter.getI32VectorAttr(
2199         llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
2200   } else {
2201     indicesAttr = rewriter.getI64VectorAttr(
2202         llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
2203   }
2204   Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
2205   // Add in an offset if requested.
2206   if (off) {
2207     Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
2208     Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
2209     indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
2210   }
2211   // Construct the vector comparison.
2212   Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
2213   Value bounds =
2214       rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
2215   return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
2216                                         bounds);
2217 }
2218 
2219 template <typename ConcreteOp>
2220 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
2221 public:
2222   explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
2223       : mlir::OpRewritePattern<ConcreteOp>(context),
2224         indexOptimizations(enableIndexOpt) {}
2225 
2226   LogicalResult matchAndRewrite(ConcreteOp xferOp,
2227                                 PatternRewriter &rewriter) const override {
2228     if (!xferOp.hasOutOfBoundsDim())
2229       return failure();
2230 
2231     if (xferOp.getVectorType().getRank() > 1 ||
2232         llvm::size(xferOp.indices()) == 0)
2233       return failure();
2234 
2235     Location loc = xferOp->getLoc();
2236     VectorType vtp = xferOp.getVectorType();
2237 
2238     // Create the in-bounds mask with all elements between [0 .. dim - offset)
2239     // set and [dim - offset .. vector_length) unset.
2240     //
2241     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
2242     //       dimensions here.
2243     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
2244     Value off = xferOp.indices()[lastIndex];
2245     Value dim =
2246         vector::createOrFoldDimOp(rewriter, loc, xferOp.source(), lastIndex);
2247     Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
2248     Value mask = rewriter.create<vector::CreateMaskOp>(
2249         loc,
2250         VectorType::get(vtp.getShape(), rewriter.getI1Type(),
2251                         vtp.getNumScalableDims()),
2252         b);
2253     if (xferOp.mask()) {
2254       // Intersect the in-bounds with the mask specified as an op parameter.
2255       mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.mask());
2256     }
2257 
2258     rewriter.updateRootInPlace(xferOp, [&]() {
2259       xferOp.maskMutable().assign(mask);
2260       xferOp.in_boundsAttr(rewriter.getBoolArrayAttr({true}));
2261     });
2262 
2263     return success();
2264   }
2265 
2266 private:
2267   const bool indexOptimizations;
2268 };
2269 
2270 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
2271 class VectorCreateMaskOpConversion
2272     : public OpRewritePattern<vector::CreateMaskOp> {
2273 public:
2274   explicit VectorCreateMaskOpConversion(MLIRContext *context,
2275                                         bool enableIndexOpt)
2276       : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
2277         indexOptimizations(enableIndexOpt) {}
2278 
2279   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
2280                                 PatternRewriter &rewriter) const override {
2281     auto dstType = op.getType();
2282     if (dstType.cast<VectorType>().isScalable())
2283       return failure();
2284     int64_t rank = dstType.getRank();
2285     if (rank > 1)
2286       return failure();
2287     rewriter.replaceOp(
2288         op, buildVectorComparison(rewriter, op, indexOptimizations,
2289                                   rank == 0 ? 0 : dstType.getDimSize(0),
2290                                   op.getOperand(0)));
2291     return success();
2292   }
2293 
2294 private:
2295   const bool indexOptimizations;
2296 };
2297 
2298 // Drop inner most contiguous unit dimensions from transfer_read operand.
2299 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
2300   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
2301 
2302   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
2303                                 PatternRewriter &rewriter) const override {
2304     // TODO: support 0-d corner case.
2305     if (readOp.getTransferRank() == 0)
2306       return failure();
2307 
2308     // TODO: support mask.
2309     if (readOp.mask())
2310       return failure();
2311 
2312     auto srcType = readOp.source().getType().dyn_cast<MemRefType>();
2313     if (!srcType || !srcType.hasStaticShape())
2314       return failure();
2315 
2316     if (!readOp.permutation_map().isMinorIdentity())
2317       return failure();
2318 
2319     auto targetType = readOp.getVectorType();
2320     if (targetType.getRank() <= 1)
2321       return failure();
2322 
2323     SmallVector<int64_t> srcStrides;
2324     int64_t srcOffset;
2325     if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2326       return failure();
2327 
2328     size_t dimsToDrop = 0;
2329     for (size_t i = 1; i < srcStrides.size(); ++i) {
2330       int dim = srcType.getRank() - i - 1;
2331       if (srcStrides[dim] == 1) {
2332         dimsToDrop++;
2333       } else {
2334         break;
2335       }
2336     }
2337     if (dimsToDrop == 0)
2338       return failure();
2339 
2340     auto resultTargetVecType =
2341         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
2342                         targetType.getElementType());
2343 
2344     MemRefType resultMemrefType;
2345     if (srcType.getLayout().getAffineMap().isIdentity()) {
2346       resultMemrefType = MemRefType::get(
2347           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2348           {}, srcType.getMemorySpaceAsInt());
2349     } else {
2350       AffineMap map = srcType.getLayout().getAffineMap();
2351       int numResultDims = map.getNumDims() - dimsToDrop;
2352       int numSymbols = map.getNumSymbols();
2353       for (size_t i = 0; i < dimsToDrop; ++i) {
2354         int dim = srcType.getRank() - i - 1;
2355         map = map.replace(rewriter.getAffineDimExpr(dim),
2356                           rewriter.getAffineConstantExpr(0), numResultDims,
2357                           numSymbols);
2358       }
2359       resultMemrefType = MemRefType::get(
2360           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2361           map, srcType.getMemorySpaceAsInt());
2362     }
2363 
2364     auto loc = readOp.getLoc();
2365     SmallVector<int64_t> offsets(srcType.getRank(), 0);
2366     SmallVector<int64_t> strides(srcType.getRank(), 1);
2367 
2368     ArrayAttr inBoundsAttr =
2369         readOp.in_bounds()
2370             ? rewriter.getArrayAttr(
2371                   readOp.in_boundsAttr().getValue().drop_back(dimsToDrop))
2372             : ArrayAttr();
2373     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
2374         loc, resultMemrefType, readOp.source(), offsets, srcType.getShape(),
2375         strides);
2376     auto permMap = getTransferMinorIdentityMap(
2377         rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
2378     Value result = rewriter.create<vector::TransferReadOp>(
2379         loc, resultTargetVecType, rankedReducedView,
2380         readOp.indices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
2381         readOp.padding(),
2382         // TODO: support mask.
2383         /*mask=*/Value(), inBoundsAttr);
2384     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
2385                                                      result);
2386     return success();
2387   }
2388 };
2389 
2390 namespace {
2391 
2392 /// This function checks to see if the vector combining kind
2393 /// is consistent with the integer or float element type.
2394 static bool isValidKind(bool isInt, vector::CombiningKind kind) {
2395   using vector::CombiningKind;
2396   enum class KindType { FLOAT, INT, INVALID };
2397   KindType type{KindType::INVALID};
2398   switch (kind) {
2399   case CombiningKind::MINF:
2400   case CombiningKind::MAXF:
2401     type = KindType::FLOAT;
2402     break;
2403   case CombiningKind::MINUI:
2404   case CombiningKind::MINSI:
2405   case CombiningKind::MAXUI:
2406   case CombiningKind::MAXSI:
2407   case CombiningKind::AND:
2408   case CombiningKind::OR:
2409   case CombiningKind::XOR:
2410     type = KindType::INT;
2411     break;
2412   case CombiningKind::ADD:
2413   case CombiningKind::MUL:
2414     type = isInt ? KindType::INT : KindType::FLOAT;
2415     break;
2416   }
2417   bool isValidIntKind = (type == KindType::INT) && isInt;
2418   bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
2419   return (isValidIntKind || isValidFloatKind);
2420 }
2421 
2422 /// This function constructs the appropriate integer or float
2423 /// operation given the vector combining kind and operands. The
2424 /// supported int operations are : add, mul, min (signed/unsigned),
2425 /// max(signed/unsigned), and, or, xor. The supported float
2426 /// operations are : add, mul, min and max.
2427 static Value genOperator(Location loc, Value x, Value y,
2428                          vector::CombiningKind kind,
2429                          PatternRewriter &rewriter) {
2430   using vector::CombiningKind;
2431 
2432   auto elType = x.getType().cast<VectorType>().getElementType();
2433   bool isInt = elType.isIntOrIndex();
2434 
2435   Value combinedResult{nullptr};
2436   switch (kind) {
2437   case CombiningKind::ADD:
2438     if (isInt)
2439       combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
2440     else
2441       combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
2442     break;
2443   case CombiningKind::MUL:
2444     if (isInt)
2445       combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
2446     else
2447       combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
2448     break;
2449   case CombiningKind::MINUI:
2450     combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
2451     break;
2452   case CombiningKind::MINSI:
2453     combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
2454     break;
2455   case CombiningKind::MAXUI:
2456     combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
2457     break;
2458   case CombiningKind::MAXSI:
2459     combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
2460     break;
2461   case CombiningKind::AND:
2462     combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
2463     break;
2464   case CombiningKind::OR:
2465     combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
2466     break;
2467   case CombiningKind::XOR:
2468     combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
2469     break;
2470   case CombiningKind::MINF:
2471     combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
2472     break;
2473   case CombiningKind::MAXF:
2474     combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
2475     break;
2476   }
2477   return combinedResult;
2478 }
2479 
2480 /// Convert vector.scan op into arith ops and
2481 /// vector.insert_strided_slice/extract_strided_slice
2482 ///
2483 /// Ex:
2484 /// ```
2485 ///   %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim =
2486 ///   1} :
2487 ///     (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
2488 /// ```
2489 /// Gets converted to:
2490 /// ```
2491 ///   %cst = arith.constant dense<0> : vector<2x3xi32>
2492 ///   %0 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 1],
2493 ///   strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %1 =
2494 ///   vector.insert_strided_slice %0, %cst {offsets = [0, 0], strides = [1, 1]}
2495 ///   : vector<2x1xi32> into vector<2x3xi32> %2 = vector.extract_strided_slice
2496 ///   %arg0 {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} :
2497 ///   vector<2x3xi32> to vector<2x1xi32> %3 = arith.muli %0, %2 :
2498 ///   vector<2x1xi32> %4 = vector.insert_strided_slice %3, %1 {offsets = [0, 1],
2499 ///   strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %5 =
2500 ///   vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1],
2501 ///   strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %6 = arith.muli %3,
2502 ///   %5 : vector<2x1xi32> %7 = vector.insert_strided_slice %6, %4 {offsets =
2503 ///   [0, 2], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %8 =
2504 ///   vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32> return %7, %8 :
2505 ///   vector<2x3xi32>, vector<2xi32>
2506 /// ```
2507 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
2508   using OpRewritePattern<vector::ScanOp>::OpRewritePattern;
2509 
2510   LogicalResult matchAndRewrite(vector::ScanOp scanOp,
2511                                 PatternRewriter &rewriter) const override {
2512     auto loc = scanOp.getLoc();
2513     VectorType destType = scanOp.getDestType();
2514     ArrayRef<int64_t> destShape = destType.getShape();
2515     auto elType = destType.getElementType();
2516     bool isInt = elType.isIntOrIndex();
2517     if (!isValidKind(isInt, scanOp.kind()))
2518       return failure();
2519 
2520     VectorType resType = VectorType::get(destShape, elType);
2521     Value result = rewriter.create<arith::ConstantOp>(
2522         loc, resType, rewriter.getZeroAttr(resType));
2523     int64_t reductionDim = scanOp.reduction_dim();
2524     bool inclusive = scanOp.inclusive();
2525     int64_t destRank = destType.getRank();
2526     VectorType initialValueType = scanOp.getInitialValueType();
2527     int64_t initialValueRank = initialValueType.getRank();
2528 
2529     SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
2530     reductionShape[reductionDim] = 1;
2531     VectorType reductionType = VectorType::get(reductionShape, elType);
2532     SmallVector<int64_t> offsets(destRank, 0);
2533     SmallVector<int64_t> strides(destRank, 1);
2534     SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
2535     sizes[reductionDim] = 1;
2536     ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
2537     ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
2538 
2539     Value lastOutput, lastInput;
2540     for (int i = 0; i < destShape[reductionDim]; i++) {
2541       offsets[reductionDim] = i;
2542       ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
2543       Value input = rewriter.create<vector::ExtractStridedSliceOp>(
2544           loc, reductionType, scanOp.source(), scanOffsets, scanSizes,
2545           scanStrides);
2546       Value output;
2547       if (i == 0) {
2548         if (inclusive) {
2549           output = input;
2550         } else {
2551           if (initialValueRank == 0) {
2552             // ShapeCastOp cannot handle 0-D vectors
2553             output = rewriter.create<vector::BroadcastOp>(
2554                 loc, input.getType(), scanOp.initial_value());
2555           } else {
2556             output = rewriter.create<vector::ShapeCastOp>(
2557                 loc, input.getType(), scanOp.initial_value());
2558           }
2559         }
2560       } else {
2561         Value y = inclusive ? input : lastInput;
2562         output = genOperator(loc, lastOutput, y, scanOp.kind(), rewriter);
2563         assert(output != nullptr);
2564       }
2565       result = rewriter.create<vector::InsertStridedSliceOp>(
2566           loc, output, result, offsets, strides);
2567       lastOutput = output;
2568       lastInput = input;
2569     }
2570 
2571     Value reduction;
2572     if (initialValueRank == 0) {
2573       Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
2574       reduction =
2575           rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
2576     } else {
2577       reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
2578                                                        lastOutput);
2579     }
2580 
2581     rewriter.replaceOp(scanOp, {result, reduction});
2582     return success();
2583   }
2584 };
2585 
2586 } // namespace
2587 
2588 void mlir::vector::populateVectorMaskMaterializationPatterns(
2589     RewritePatternSet &patterns, bool indexOptimizations) {
2590   patterns.add<VectorCreateMaskOpConversion,
2591                MaterializeTransferMask<vector::TransferReadOp>,
2592                MaterializeTransferMask<vector::TransferWriteOp>>(
2593       patterns.getContext(), indexOptimizations);
2594 }
2595 
2596 void mlir::vector::populateShapeCastFoldingPatterns(
2597     RewritePatternSet &patterns) {
2598   patterns.add<ShapeCastOpFolder>(patterns.getContext());
2599 }
2600 
2601 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
2602     RewritePatternSet &patterns) {
2603   patterns.add<BubbleDownVectorBitCastForExtract,
2604                BubbleDownBitCastForStridedSliceExtract,
2605                BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
2606 }
2607 
2608 void mlir::vector::populateVectorBroadcastLoweringPatterns(
2609     RewritePatternSet &patterns) {
2610   patterns.add<BroadcastOpLowering>(patterns.getContext());
2611 }
2612 
2613 void mlir::vector::populateVectorMaskOpLoweringPatterns(
2614     RewritePatternSet &patterns) {
2615   patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
2616       patterns.getContext());
2617 }
2618 
2619 void mlir::vector::populateVectorShapeCastLoweringPatterns(
2620     RewritePatternSet &patterns) {
2621   patterns.add<ShapeCastOp2DDownCastRewritePattern,
2622                ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
2623       patterns.getContext());
2624 }
2625 
2626 void mlir::vector::populateVectorContractLoweringPatterns(
2627     RewritePatternSet &patterns, VectorTransformsOptions options) {
2628   patterns.add<OuterProductOpLowering>(patterns.getContext());
2629   patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
2630                ContractionOpToOuterProductOpLowering>(options,
2631                                                       patterns.getContext());
2632 }
2633 
2634 void mlir::vector::populateVectorTransposeLoweringPatterns(
2635     RewritePatternSet &patterns, VectorTransformsOptions options) {
2636   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
2637       options, patterns.getContext());
2638 }
2639 
2640 void mlir::vector::populateVectorReductionToContractPatterns(
2641     RewritePatternSet &patterns) {
2642   patterns.add<MultiReduceToContract, CombineContractBroadcast,
2643                CombineContractTranspose, ReorderCastOpsOnBroadcast,
2644                ReorderCastOpsOnTranspose>(patterns.getContext());
2645 }
2646 
2647 void mlir::vector::
2648     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
2649         RewritePatternSet &patterns) {
2650   patterns.add<DropInnerMostUnitDims>(patterns.getContext());
2651 }
2652 
2653 void mlir::vector::populateVectorTransferLoweringPatterns(
2654     RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
2655   patterns.add<TransferReadToVectorLoadLowering,
2656                TransferWriteToVectorStoreLowering>(patterns.getContext(),
2657                                                    maxTransferRank);
2658   patterns
2659       .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
2660           patterns.getContext());
2661 }
2662 
2663 void mlir::vector::populateVectorScanLoweringPatterns(
2664     RewritePatternSet &patterns) {
2665   patterns.add<ScanToArithOps>(patterns.getContext());
2666 }
2667