1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
14 
15 #include <type_traits>
16 
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
19 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
20 #include "mlir/Dialect/Linalg/IR/Linalg.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/SCF/SCF.h"
23 #include "mlir/Dialect/Utils/IndexingUtils.h"
24 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
26 #include "mlir/IR/BuiltinTypes.h"
27 #include "mlir/IR/ImplicitLocOpBuilder.h"
28 #include "mlir/IR/Matchers.h"
29 #include "mlir/IR/PatternMatch.h"
30 #include "mlir/Interfaces/VectorInterfaces.h"
31 
32 #include "llvm/ADT/DenseSet.h"
33 #include "llvm/ADT/MapVector.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/Support/CommandLine.h"
36 #include "llvm/Support/Debug.h"
37 #include "llvm/Support/raw_ostream.h"
38 
39 #define DEBUG_TYPE "vector-to-vector"
40 
41 using namespace mlir;
42 using namespace mlir::vector;
43 
44 // Helper to find an index in an affine map.
45 static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
46   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
47     int64_t idx = map.getDimPosition(i);
48     if (idx == index)
49       return i;
50   }
51   return None;
52 }
53 
54 // Helper to construct iterator types with one index removed.
55 static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
56                                             int64_t index) {
57   SmallVector<Attribute, 4> results;
58   for (const auto &it : llvm::enumerate(iteratorTypes)) {
59     int64_t idx = it.index();
60     if (idx == index)
61       continue;
62     results.push_back(it.value());
63   }
64   return results;
65 }
66 
67 // Helper to construct an affine map with one index removed.
68 static AffineMap adjustMap(AffineMap map, int64_t index,
69                            PatternRewriter &rewriter) {
70   auto *ctx = rewriter.getContext();
71   SmallVector<AffineExpr, 4> results;
72   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
73     int64_t idx = map.getDimPosition(i);
74     if (idx == index)
75       continue;
76     // Re-insert remaining indices, but renamed when occurring
77     // after the removed index.
78     auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
79     results.push_back(targetExpr);
80   }
81   return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
82 }
83 
84 // Helper method to possibly drop a dimension in a load.
85 // TODO
86 static Value reshapeLoad(Location loc, Value val, VectorType type,
87                          int64_t index, int64_t pos,
88                          PatternRewriter &rewriter) {
89   if (index == -1)
90     return val;
91   Type lowType = VectorType::Builder(type).dropDim(0);
92   // At extraction dimension?
93   if (index == 0) {
94     auto posAttr = rewriter.getI64ArrayAttr(pos);
95     return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
96   }
97   // Unroll leading dimensions.
98   VectorType vType = lowType.cast<VectorType>();
99   Type resType = VectorType::Builder(type).dropDim(index);
100   auto resVectorType = resType.cast<VectorType>();
101   Value result = rewriter.create<arith::ConstantOp>(
102       loc, resVectorType, rewriter.getZeroAttr(resVectorType));
103   for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
104     auto posAttr = rewriter.getI64ArrayAttr(d);
105     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
106     Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
107     result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
108                                                posAttr);
109   }
110   return result;
111 }
112 
113 // Helper method to possibly drop a dimension in a store.
114 // TODO
115 static Value reshapeStore(Location loc, Value val, Value result,
116                           VectorType type, int64_t index, int64_t pos,
117                           PatternRewriter &rewriter) {
118   // Unmodified?
119   if (index == -1)
120     return val;
121   // At insertion dimension?
122   if (index == 0) {
123     auto posAttr = rewriter.getI64ArrayAttr(pos);
124     return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
125   }
126   // Unroll leading dimensions.
127   Type lowType = VectorType::Builder(type).dropDim(0);
128   VectorType vType = lowType.cast<VectorType>();
129   Type insType = VectorType::Builder(vType).dropDim(0);
130   for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
131     auto posAttr = rewriter.getI64ArrayAttr(d);
132     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
133     Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
134     Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
135     result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
136   }
137   return result;
138 }
139 
140 template <typename IntType>
141 static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
142   return llvm::to_vector<4>(llvm::map_range(
143       arrayAttr.getAsRange<IntegerAttr>(),
144       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
145 }
146 
147 namespace {
148 
149 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
150 //
151 // Example:
152 //
153 //  The following MLIR with cancelling ShapeCastOps:
154 //
155 //   %0 = source : vector<5x4x2xf32>
156 //   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
157 //   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
158 //   %3 = user %2 : vector<5x4x2xf32>
159 //
160 //  Should canonicalize to the following:
161 //
162 //   %0 = source : vector<5x4x2xf32>
163 //   %1 = user %0 : vector<5x4x2xf32>
164 //
165 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
166   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
167 
168   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
169                                 PatternRewriter &rewriter) const override {
170     // Check if 'shapeCastOp' has vector source/result type.
171     auto sourceVectorType =
172         shapeCastOp.getSource().getType().dyn_cast_or_null<VectorType>();
173     auto resultVectorType =
174         shapeCastOp.getResult().getType().dyn_cast_or_null<VectorType>();
175     if (!sourceVectorType || !resultVectorType)
176       return failure();
177 
178     // Check if shape cast op source operand is also a shape cast op.
179     auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
180         shapeCastOp.getSource().getDefiningOp());
181     if (!sourceShapeCastOp)
182       return failure();
183     auto operandSourceVectorType =
184         sourceShapeCastOp.getSource().getType().cast<VectorType>();
185     auto operandResultVectorType = sourceShapeCastOp.getType();
186 
187     // Check if shape cast operations invert each other.
188     if (operandSourceVectorType != resultVectorType ||
189         operandResultVectorType != sourceVectorType)
190       return failure();
191 
192     rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
193     return success();
194   }
195 };
196 
197 /// Progressive lowering of BroadcastOp.
198 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
199 public:
200   using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
201 
202   LogicalResult matchAndRewrite(vector::BroadcastOp op,
203                                 PatternRewriter &rewriter) const override {
204     auto loc = op.getLoc();
205     VectorType dstType = op.getVectorType();
206     VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
207     Type eltType = dstType.getElementType();
208 
209     // Scalar to any vector can use splat.
210     if (!srcType) {
211       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
212       return success();
213     }
214 
215     // Determine rank of source and destination.
216     int64_t srcRank = srcType.getRank();
217     int64_t dstRank = dstType.getRank();
218 
219     // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
220     if (srcRank <= 1 && dstRank == 1) {
221       Value ext;
222       if (srcRank == 0)
223         ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
224       else
225         ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
226       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
227       return success();
228     }
229 
230     // Duplicate this rank.
231     // For example:
232     //   %x = broadcast %y  : k-D to n-D, k < n
233     // becomes:
234     //   %b = broadcast %y  : k-D to (n-1)-D
235     //   %x = [%b,%b,%b,%b] : n-D
236     // becomes:
237     //   %b = [%y,%y]       : (n-1)-D
238     //   %x = [%b,%b,%b,%b] : n-D
239     if (srcRank < dstRank) {
240       // Duplication.
241       VectorType resType =
242           VectorType::get(dstType.getShape().drop_front(), eltType);
243       Value bcst =
244           rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
245       Value result = rewriter.create<arith::ConstantOp>(
246           loc, dstType, rewriter.getZeroAttr(dstType));
247       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
248         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
249       rewriter.replaceOp(op, result);
250       return success();
251     }
252 
253     // Find non-matching dimension, if any.
254     assert(srcRank == dstRank);
255     int64_t m = -1;
256     for (int64_t r = 0; r < dstRank; r++)
257       if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
258         m = r;
259         break;
260       }
261 
262     // All trailing dimensions are the same. Simply pass through.
263     if (m == -1) {
264       rewriter.replaceOp(op, op.getSource());
265       return success();
266     }
267 
268     // Any non-matching dimension forces a stretch along this rank.
269     // For example:
270     //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
271     // becomes:
272     //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
273     //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
274     //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
275     //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
276     //   %x = [%a,%b,%c,%d]
277     // becomes:
278     //   %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
279     //   %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
280     //   %a = [%u, %v]
281     //   ..
282     //   %x = [%a,%b,%c,%d]
283     VectorType resType =
284         VectorType::get(dstType.getShape().drop_front(), eltType);
285     Value result = rewriter.create<arith::ConstantOp>(
286         loc, dstType, rewriter.getZeroAttr(dstType));
287     if (m == 0) {
288       // Stetch at start.
289       Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
290       Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
291       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
292         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
293     } else {
294       // Stetch not at start.
295       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
296         Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
297         Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
298         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
299       }
300     }
301     rewriter.replaceOp(op, result);
302     return success();
303   }
304 };
305 
306 /// Given a 'transpose' pattern, prune the rightmost dimensions that are not
307 /// transposed.
308 void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
309                             SmallVectorImpl<int64_t> &result) {
310   size_t numTransposedDims = transpose.size();
311   for (size_t transpDim : llvm::reverse(transpose)) {
312     if (transpDim != numTransposedDims - 1)
313       break;
314     numTransposedDims--;
315   }
316 
317   result.append(transpose.begin(), transpose.begin() + numTransposedDims);
318 }
319 
320 /// Progressive lowering of TransposeOp.
321 /// One:
322 ///   %x = vector.transpose %y, [1, 0]
323 /// is replaced by:
324 ///   %z = arith.constant dense<0.000000e+00>
325 ///   %0 = vector.extract %y[0, 0]
326 ///   %1 = vector.insert %0, %z [0, 0]
327 ///   ..
328 ///   %x = vector.insert .., .. [.., ..]
329 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
330 public:
331   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
332 
333   TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
334                       MLIRContext *context)
335       : OpRewritePattern<vector::TransposeOp>(context),
336         vectorTransformOptions(vectorTransformOptions) {}
337 
338   LogicalResult matchAndRewrite(vector::TransposeOp op,
339                                 PatternRewriter &rewriter) const override {
340     auto loc = op.getLoc();
341 
342     Value input = op.getVector();
343     VectorType inputType = op.getVectorType();
344     VectorType resType = op.getResultType();
345 
346     // Set up convenience transposition table.
347     SmallVector<int64_t, 4> transp;
348     for (auto attr : op.getTransp())
349       transp.push_back(attr.cast<IntegerAttr>().getInt());
350 
351     if (vectorTransformOptions.vectorTransposeLowering ==
352             vector::VectorTransposeLowering::Shuffle &&
353         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
354       return rewriter.notifyMatchFailure(
355           op, "Options specifies lowering to shuffle");
356 
357     // Handle a true 2-D matrix transpose differently when requested.
358     if (vectorTransformOptions.vectorTransposeLowering ==
359             vector::VectorTransposeLowering::Flat &&
360         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
361       Type flattenedType =
362           VectorType::get(resType.getNumElements(), resType.getElementType());
363       auto matrix =
364           rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
365       auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
366       auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
367       Value trans = rewriter.create<vector::FlatTransposeOp>(
368           loc, flattenedType, matrix, rows, columns);
369       rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
370       return success();
371     }
372 
373     // Generate unrolled extract/insert ops. We do not unroll the rightmost
374     // (i.e., highest-order) dimensions that are not transposed and leave them
375     // in vector form to improve performance. Therefore, we prune those
376     // dimensions from the shape/transpose data structures used to generate the
377     // extract/insert ops.
378     SmallVector<int64_t, 4> prunedTransp;
379     pruneNonTransposedDims(transp, prunedTransp);
380     size_t numPrunedDims = transp.size() - prunedTransp.size();
381     auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
382     SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
383     auto prunedInStrides = computeStrides(prunedInShape, ones);
384 
385     // Generates the extract/insert operations for every scalar/vector element
386     // of the leftmost transposed dimensions. We traverse every transpose
387     // element using a linearized index that we delinearize to generate the
388     // appropriate indices for the extract/insert operations.
389     Value result = rewriter.create<arith::ConstantOp>(
390         loc, resType, rewriter.getZeroAttr(resType));
391     int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
392 
393     for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
394          ++linearIdx) {
395       auto extractIdxs = delinearize(prunedInStrides, linearIdx);
396       SmallVector<int64_t, 4> insertIdxs(extractIdxs);
397       applyPermutationToVector(insertIdxs, prunedTransp);
398       Value extractOp =
399           rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
400       result =
401           rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
402     }
403 
404     rewriter.replaceOp(op, result);
405     return success();
406   }
407 
408 private:
409   /// Options to control the vector patterns.
410   vector::VectorTransformsOptions vectorTransformOptions;
411 };
412 
413 /// Rewrite a 2-D vector.transpose as a sequence of:
414 ///   vector.shape_cast 2D -> 1D
415 ///   vector.shuffle
416 ///   vector.shape_cast 1D -> 2D
417 class TransposeOp2DToShuffleLowering
418     : public OpRewritePattern<vector::TransposeOp> {
419 public:
420   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
421 
422   TransposeOp2DToShuffleLowering(
423       vector::VectorTransformsOptions vectorTransformOptions,
424       MLIRContext *context)
425       : OpRewritePattern<vector::TransposeOp>(context),
426         vectorTransformOptions(vectorTransformOptions) {}
427 
428   LogicalResult matchAndRewrite(vector::TransposeOp op,
429                                 PatternRewriter &rewriter) const override {
430     auto loc = op.getLoc();
431 
432     VectorType srcType = op.getVectorType();
433     if (srcType.getRank() != 2)
434       return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
435 
436     SmallVector<int64_t, 4> transp;
437     for (auto attr : op.getTransp())
438       transp.push_back(attr.cast<IntegerAttr>().getInt());
439     if (transp[0] != 1 && transp[1] != 0)
440       return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
441 
442     if (vectorTransformOptions.vectorTransposeLowering !=
443         VectorTransposeLowering::Shuffle)
444       return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
445 
446     int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
447     Value casted = rewriter.create<vector::ShapeCastOp>(
448         loc, VectorType::get({m * n}, srcType.getElementType()),
449         op.getVector());
450     SmallVector<int64_t> mask;
451     mask.reserve(m * n);
452     for (int64_t j = 0; j < n; ++j)
453       for (int64_t i = 0; i < m; ++i)
454         mask.push_back(i * n + j);
455 
456     Value shuffled =
457         rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
458     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
459                                                      shuffled);
460 
461     return success();
462   }
463 
464 private:
465   /// Options to control the vector patterns.
466   vector::VectorTransformsOptions vectorTransformOptions;
467 };
468 
469 /// Progressive lowering of OuterProductOp.
470 /// One:
471 ///   %x = vector.outerproduct %lhs, %rhs, %acc
472 /// is replaced by:
473 ///   %z = zero-result
474 ///   %0 = vector.extract %lhs[0]
475 ///   %1 = vector.broadcast %0
476 ///   %2 = vector.extract %acc[0]
477 ///   %3 = vector.fma %1, %rhs, %2
478 ///   %4 = vector.insert %3, %z[0]
479 ///   ..
480 ///   %x = vector.insert %.., %..[N-1]
481 ///
482 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
483 public:
484   using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
485 
486   LogicalResult matchAndRewrite(vector::OuterProductOp op,
487                                 PatternRewriter &rewriter) const override {
488     auto loc = op.getLoc();
489 
490     VectorType lhsType = op.getOperandVectorTypeLHS();
491     VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
492     VectorType resType = op.getVectorType();
493     Type eltType = resType.getElementType();
494     bool isInt = eltType.isa<IntegerType, IndexType>();
495     Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
496     vector::CombiningKind kind = op.getKind();
497 
498     if (!rhsType) {
499       // Special case: AXPY operation.
500       Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
501       Optional<Value> mult =
502           isInt ? genMultI(loc, op.getLhs(), b, acc, kind, rewriter)
503                 : genMultF(loc, op.getLhs(), b, acc, kind, rewriter);
504       if (!mult.hasValue())
505         return failure();
506       rewriter.replaceOp(op, mult.getValue());
507       return success();
508     }
509 
510     Value result = rewriter.create<arith::ConstantOp>(
511         loc, resType, rewriter.getZeroAttr(resType));
512     for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
513       auto pos = rewriter.getI64ArrayAttr(d);
514       Value x =
515           rewriter.create<vector::ExtractOp>(loc, eltType, op.getLhs(), pos);
516       Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
517       Value r = nullptr;
518       if (acc)
519         r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
520       Optional<Value> m =
521           isInt ? genMultI(loc, a, op.getRhs(), r, kind, rewriter)
522                 : genMultF(loc, a, op.getRhs(), r, kind, rewriter);
523       if (!m.hasValue())
524         return failure();
525       result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
526                                                  result, pos);
527     }
528     rewriter.replaceOp(op, result);
529     return success();
530   }
531 
532 private:
533   static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
534                                   vector::CombiningKind kind,
535                                   PatternRewriter &rewriter) {
536     using vector::CombiningKind;
537 
538     auto mul = rewriter.create<arith::MulIOp>(loc, x, y);
539     if (!acc)
540       return Optional<Value>(mul);
541 
542     if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
543       // Only valid for floating point types.
544       return Optional<Value>();
545 
546     return makeArithReduction(rewriter, loc, kind, mul, acc);
547   }
548 
549   static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
550                                   vector::CombiningKind kind,
551                                   PatternRewriter &rewriter) {
552     using vector::CombiningKind;
553 
554     // Special case for fused multiply-add.
555     if (acc && kind == CombiningKind::ADD) {
556       return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
557     }
558 
559     auto mul = rewriter.create<arith::MulFOp>(loc, x, y);
560 
561     if (!acc)
562       return Optional<Value>(mul);
563 
564     if (kind == CombiningKind::ADD || kind == CombiningKind::AND ||
565         kind == CombiningKind::MINUI || kind == CombiningKind::MINSI ||
566         kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI ||
567         kind == CombiningKind::OR || kind == CombiningKind::XOR)
568       // Already handled or only valid for integer types.
569       return Optional<Value>();
570 
571     return makeArithReduction(rewriter, loc, kind, mul, acc);
572   }
573 };
574 
575 /// Progressive lowering of ConstantMaskOp.
576 /// One:
577 ///   %x = vector.constant_mask [a,b]
578 /// is replaced by:
579 ///   %z = zero-result
580 ///   %l = vector.constant_mask [b]
581 ///   %4 = vector.insert %l, %z[0]
582 ///   ..
583 ///   %x = vector.insert %l, %..[a-1]
584 /// until a one-dimensional vector is reached. All these operations
585 /// will be folded at LLVM IR level.
586 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
587 public:
588   using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
589 
590   LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
591                                 PatternRewriter &rewriter) const override {
592     auto loc = op.getLoc();
593     auto dstType = op.getType();
594     auto eltType = dstType.getElementType();
595     auto dimSizes = op.getMaskDimSizes();
596     int64_t rank = dstType.getRank();
597 
598     if (rank == 0) {
599       assert(dimSizes.size() == 1 &&
600              "Expected exactly one dim size for a 0-D vector");
601       bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
602       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
603           op, dstType,
604           DenseIntElementsAttr::get(
605               VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
606               ArrayRef<bool>{value}));
607       return success();
608     }
609 
610     // Scalable constant masks can only be lowered for the "none set" case.
611     if (dstType.cast<VectorType>().isScalable()) {
612       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
613           op, DenseElementsAttr::get(dstType, false));
614       return success();
615     }
616 
617     int64_t trueDim = std::min(dstType.getDimSize(0),
618                                dimSizes[0].cast<IntegerAttr>().getInt());
619 
620     if (rank == 1) {
621       // Express constant 1-D case in explicit vector form:
622       //   [T,..,T,F,..,F].
623       SmallVector<bool, 4> values(dstType.getDimSize(0));
624       for (int64_t d = 0; d < trueDim; d++)
625         values[d] = true;
626       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
627           op, dstType, rewriter.getBoolVectorAttr(values));
628       return success();
629     }
630 
631     VectorType lowType =
632         VectorType::get(dstType.getShape().drop_front(), eltType);
633     SmallVector<int64_t, 4> newDimSizes;
634     for (int64_t r = 1; r < rank; r++)
635       newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
636     Value trueVal = rewriter.create<vector::ConstantMaskOp>(
637         loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
638     Value result = rewriter.create<arith::ConstantOp>(
639         loc, dstType, rewriter.getZeroAttr(dstType));
640     for (int64_t d = 0; d < trueDim; d++) {
641       auto pos = rewriter.getI64ArrayAttr(d);
642       result =
643           rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
644     }
645     rewriter.replaceOp(op, result);
646     return success();
647   }
648 };
649 
650 /// Progressive lowering of CreateMaskOp.
651 /// One:
652 ///   %x = vector.create_mask %a, ... : vector<dx...>
653 /// is replaced by:
654 ///   %l = vector.create_mask ... : vector<...>  ; one lower rank
655 ///   %0 = arith.cmpi "slt", %ci, %a       |
656 ///   %1 = select %0, %l, %zeroes    |
657 ///   %r = vector.insert %1, %pr [i] | d-times
658 ///   %x = ....
659 /// until a one-dimensional vector is reached.
660 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
661 public:
662   using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
663 
664   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
665                                 PatternRewriter &rewriter) const override {
666     auto dstType = op.getResult().getType().cast<VectorType>();
667     int64_t rank = dstType.getRank();
668     if (rank <= 1)
669       return rewriter.notifyMatchFailure(
670           op, "0-D and 1-D vectors are handled separately");
671 
672     auto loc = op.getLoc();
673     auto eltType = dstType.getElementType();
674     int64_t dim = dstType.getDimSize(0);
675     Value idx = op.getOperand(0);
676 
677     VectorType lowType =
678         VectorType::get(dstType.getShape().drop_front(), eltType);
679     Value trueVal = rewriter.create<vector::CreateMaskOp>(
680         loc, lowType, op.getOperands().drop_front());
681     Value falseVal = rewriter.create<arith::ConstantOp>(
682         loc, lowType, rewriter.getZeroAttr(lowType));
683     Value result = rewriter.create<arith::ConstantOp>(
684         loc, dstType, rewriter.getZeroAttr(dstType));
685     for (int64_t d = 0; d < dim; d++) {
686       Value bnd =
687           rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
688       Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
689                                                  bnd, idx);
690       Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
691       auto pos = rewriter.getI64ArrayAttr(d);
692       result =
693           rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
694     }
695     rewriter.replaceOp(op, result);
696     return success();
697   }
698 };
699 
700 /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
701 /// vectors progressively on the way to target llvm.matrix intrinsics.
702 /// This iterates over the most major dimension of the 2-D vector and performs
703 /// rewrites into:
704 ///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
705 class ShapeCastOp2DDownCastRewritePattern
706     : public OpRewritePattern<vector::ShapeCastOp> {
707 public:
708   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
709 
710   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
711                                 PatternRewriter &rewriter) const override {
712     auto sourceVectorType = op.getSourceVectorType();
713     auto resultVectorType = op.getResultVectorType();
714     if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
715       return failure();
716 
717     auto loc = op.getLoc();
718     Value desc = rewriter.create<arith::ConstantOp>(
719         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
720     unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
721     for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
722       Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
723       desc = rewriter.create<vector::InsertStridedSliceOp>(
724           loc, vec, desc,
725           /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
726     }
727     rewriter.replaceOp(op, desc);
728     return success();
729   }
730 };
731 
732 /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
733 /// vectors progressively.
734 /// This iterates over the most major dimension of the 2-D vector and performs
735 /// rewrites into:
736 ///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
737 /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
738 class ShapeCastOp2DUpCastRewritePattern
739     : public OpRewritePattern<vector::ShapeCastOp> {
740 public:
741   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
742 
743   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
744                                 PatternRewriter &rewriter) const override {
745     auto sourceVectorType = op.getSourceVectorType();
746     auto resultVectorType = op.getResultVectorType();
747     if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
748       return failure();
749 
750     auto loc = op.getLoc();
751     Value desc = rewriter.create<arith::ConstantOp>(
752         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
753     unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
754     for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
755       Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
756           loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
757           /*sizes=*/mostMinorVectorSize,
758           /*strides=*/1);
759       desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
760     }
761     rewriter.replaceOp(op, desc);
762     return success();
763   }
764 };
765 
766 // We typically should not lower general shape cast operations into data
767 // movement instructions, since the assumption is that these casts are
768 // optimized away during progressive lowering. For completeness, however,
769 // we fall back to a reference implementation that moves all elements
770 // into the right place if we get here.
771 class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
772 public:
773   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
774 
775   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
776                                 PatternRewriter &rewriter) const override {
777     Location loc = op.getLoc();
778     auto sourceVectorType = op.getSourceVectorType();
779     auto resultVectorType = op.getResultVectorType();
780 
781     // Special case 2D/1D lowerings with better implementations.
782     // TODO: make is ND/1D to allow generic ND->1D->MD.
783     int64_t srcRank = sourceVectorType.getRank();
784     int64_t resRank = resultVectorType.getRank();
785     if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
786       return failure();
787 
788     // Generic ShapeCast lowering path goes all the way down to unrolled scalar
789     // extract/insert chains.
790     // TODO: consider evolving the semantics to only allow 1D source or dest and
791     // drop this potentially very expensive lowering.
792     // Compute number of elements involved in the reshape.
793     int64_t numElts = 1;
794     for (int64_t r = 0; r < srcRank; r++)
795       numElts *= sourceVectorType.getDimSize(r);
796     // Replace with data movement operations:
797     //    x[0,0,0] = y[0,0]
798     //    x[0,0,1] = y[0,1]
799     //    x[0,1,0] = y[0,2]
800     // etc., incrementing the two index vectors "row-major"
801     // within the source and result shape.
802     SmallVector<int64_t, 4> srcIdx(srcRank);
803     SmallVector<int64_t, 4> resIdx(resRank);
804     Value result = rewriter.create<arith::ConstantOp>(
805         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
806     for (int64_t i = 0; i < numElts; i++) {
807       if (i != 0) {
808         incIdx(srcIdx, sourceVectorType, srcRank - 1);
809         incIdx(resIdx, resultVectorType, resRank - 1);
810       }
811       Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
812       result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
813     }
814     rewriter.replaceOp(op, result);
815     return success();
816   }
817 
818 private:
819   static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
820     assert(0 <= r && r < tp.getRank());
821     if (++idx[r] == tp.getDimSize(r)) {
822       idx[r] = 0;
823       incIdx(idx, tp, r - 1);
824     }
825   }
826 };
827 
828 /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
829 /// Ex:
830 /// ```
831 ///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
832 ///   %1 = vector.multi_reduction add, %0 [1]
833 ///     : vector<8x32x16xf32> to vector<8x16xf32>
834 /// ```
835 /// Gets converted to:
836 /// ```
837 ///   %1 = vector.contract {indexing_maps = [
838 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
839 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
840 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
841 ///    iterator_types = ["parallel", "parallel", "reduction"],
842 ///    kind = add} %0, %arg1, %cst_f0
843 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
844 ///  ```
845 struct MultiReduceToContract
846     : public OpRewritePattern<vector::MultiDimReductionOp> {
847   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
848 
849   LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
850                                 PatternRewriter &rewriter) const override {
851     if (reduceOp.getKind() != vector::CombiningKind::ADD)
852       return failure();
853     Operation *mulOp = reduceOp.getSource().getDefiningOp();
854     if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
855       return failure();
856     SmallVector<bool> reductionMask = reduceOp.getReductionMask();
857     auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
858     SmallVector<AffineExpr> exprs;
859     SmallVector<StringRef> iteratorTypes;
860     for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
861       if (!isReduceDim.value()) {
862         iteratorTypes.push_back(getParallelIteratorTypeName());
863         exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
864       } else {
865         iteratorTypes.push_back(getReductionIteratorTypeName());
866       }
867     }
868     auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
869                                  /*symCount=*/0, exprs, reduceOp.getContext());
870     Value zero = rewriter.create<arith::ConstantOp>(
871         reduceOp.getLoc(), reduceOp.getDestType(),
872         rewriter.getZeroAttr(reduceOp.getDestType()));
873     rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
874         reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
875         rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
876         rewriter.getStrArrayAttr(iteratorTypes));
877     return success();
878   }
879 };
880 
881 /// Merge TransposeOp into ContractionOp user.
882 /// Ex:
883 /// ```
884 ///   %0 = vector.transpose %arg0, [2, 0, 1]
885 ///     : vector<32x16x8xf32> to vector<8x32x16xf32>
886 ///   %1 = vector.contract {indexing_maps = [
887 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
888 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
889 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
890 ///    iterator_types = ["parallel", "parallel", "reduction"],
891 ///    kind = add} %0, %arg1, %cst_f0
892 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
893 /// ```
894 /// Gets converted to:
895 /// ```
896 ///   %1 = vector.contract {indexing_maps = [
897 ///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
898 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
899 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
900 ///    iterator_types = ["parallel", "parallel", "reduction"],
901 ///    kind = add} %arg0, %arg1, %cst_f0
902 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
903 ///  ```
904 struct CombineContractTranspose
905     : public OpRewritePattern<vector::ContractionOp> {
906   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
907 
908   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
909                                 PatternRewriter &rewriter) const override {
910     SmallVector<AffineMap, 4> maps =
911         llvm::to_vector<4>(contractOp.getIndexingMaps());
912     Value lhs = contractOp.getLhs();
913     Value rhs = contractOp.getRhs();
914     size_t index = 0;
915     bool changed = false;
916     for (Value *operand : {&lhs, &rhs}) {
917       AffineMap &map = maps[index++];
918       auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
919       if (!transposeOp)
920         continue;
921       SmallVector<int64_t> perm;
922       transposeOp.getTransp(perm);
923       AffineMap permutationMap = AffineMap::getPermutationMap(
924           extractVector<unsigned>(transposeOp.getTransp()),
925           contractOp.getContext());
926       map = inversePermutation(permutationMap).compose(map);
927       *operand = transposeOp.getVector();
928       changed = true;
929     }
930     if (!changed)
931       return failure();
932     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
933         contractOp, lhs, rhs, contractOp.getAcc(),
934         rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
935     return success();
936   }
937 };
938 
939 /// Merge BroadcastOp into ContractionOp user.
940 /// Ex:
941 /// ```
942 ///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
943 ///   %1 = vector.contract {indexing_maps = [
944 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
945 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
946 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
947 ///    iterator_types = ["parallel", "parallel", "reduction"],
948 ///    kind = add} %0, %arg1, %cst_f0
949 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
950 /// ```
951 /// Gets converted to:
952 /// ```
953 ///   %1 = vector.contract {indexing_maps = [
954 ///         affine_map<(d0, d1, d2) -> (d1, d2)>,
955 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
956 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
957 ///    iterator_types = ["parallel", "parallel", "reduction"],
958 ///    kind = add} %arg0, %arg1, %cst_f0
959 ///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
960 ///  ```
961 struct CombineContractBroadcast
962     : public OpRewritePattern<vector::ContractionOp> {
963   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
964 
965   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
966                                 PatternRewriter &rewriter) const override {
967     SmallVector<AffineMap, 4> maps =
968         llvm::to_vector<4>(contractOp.getIndexingMaps());
969     Value lhs = contractOp.getLhs();
970     Value rhs = contractOp.getRhs();
971     size_t index = 0;
972     bool changed = false;
973     for (Value *operand : {&lhs, &rhs}) {
974       AffineMap &map = maps[index++];
975       auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
976       if (!broadcast)
977         continue;
978       // contractionOp can only take vector as operands.
979       auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
980       if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
981         continue;
982       int64_t rankDiff =
983           broadcast.getVectorType().getRank() - srcType.getRank();
984       bool innerDimBroadcast = false;
985       SmallVector<AffineExpr> originalDims;
986       for (const auto &dim : llvm::enumerate(srcType.getShape())) {
987         if (dim.value() !=
988             broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
989           innerDimBroadcast = true;
990           break;
991         }
992         originalDims.push_back(
993             rewriter.getAffineDimExpr(dim.index() + rankDiff));
994       }
995       // Contract doesn't support inner dimension broadcast. Once this is
996       // relaxed we can remove this case.
997       if (innerDimBroadcast)
998         continue;
999       AffineMap broadcastMap =
1000           AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
1001                          contractOp.getContext());
1002       map = broadcastMap.compose(map);
1003       *operand = broadcast.getSource();
1004       changed = true;
1005     }
1006     if (!changed)
1007       return failure();
1008     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1009         contractOp, lhs, rhs, contractOp.getAcc(),
1010         rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
1011     return success();
1012   }
1013 };
1014 
1015 /// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
1016 /// contraction ops closer, which kicks in CombineContractBroadcast pattern when
1017 /// casting ops are around these operations.
1018 /// Ex:
1019 /// ```
1020 ///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
1021 ///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
1022 /// ```
1023 /// Gets converted to:
1024 /// ```
1025 ///   %0 = arith.extsi %0 : vector<32x16xi8> to vector<32x16xi32>
1026 ///   %1 = vector.broadcast %arg0 : vector<32x16xi32> to vector<8x32x16xi32>
1027 /// ```
1028 struct ReorderCastOpsOnBroadcast
1029     : public OpInterfaceRewritePattern<CastOpInterface> {
1030   using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
1031 
1032   LogicalResult matchAndRewrite(CastOpInterface op,
1033                                 PatternRewriter &rewriter) const override {
1034     if (op->getNumOperands() != 1)
1035       return failure();
1036     auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
1037     if (!bcastOp)
1038       return failure();
1039 
1040     Type castResTy = getElementTypeOrSelf(op->getResult(0));
1041     if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
1042       castResTy = VectorType::get(vecTy.getShape(), castResTy);
1043     auto *castOp =
1044         rewriter.create(op->getLoc(), op->getName().getIdentifier(),
1045                         bcastOp.getSource(), castResTy, op->getAttrs());
1046     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1047         op, op->getResult(0).getType(), castOp->getResult(0));
1048     return success();
1049   }
1050 };
1051 
1052 /// Reorders elementwise(transpose) to transpose(elementwise). This makes
1053 /// transpose ops and contraction ops closer, which kicks in
1054 /// CombineContractTranspose pattern when elementwise ops are between these
1055 /// operations. Ex:
1056 /// ```
1057 /// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
1058 /// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
1059 /// %r = arith.addf %at, %bt : vector<2x4xf32>
1060 /// ```
1061 /// Gets converted to:
1062 /// ```
1063 /// %0 = arith.addf %a, %b : vector<4x2xf32>
1064 /// %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
1065 /// ```
1066 struct ReorderElementwiseOpsOnTranspose final
1067     : public OpTraitRewritePattern<OpTrait::Elementwise> {
1068   using OpTraitRewritePattern::OpTraitRewritePattern;
1069   LogicalResult matchAndRewrite(Operation *op,
1070                                 PatternRewriter &rewriter) const override {
1071     if (op->getNumResults() != 1 || op->getNumRegions() != 0)
1072       return failure();
1073 
1074     // Make sure all operands are transpose/constant ops and collect their
1075     // transposition maps.
1076     SmallVector<ArrayAttr, 4> transposeMaps;
1077     transposeMaps.reserve(op->getNumOperands());
1078     // Record the initial type before transposition. We'll use its shape later.
1079     // Any type will do here as we will check all transpose maps are the same.
1080     VectorType srcType;
1081     for (Value operand : op->getOperands()) {
1082       auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
1083       if (transposeOp) {
1084         transposeMaps.push_back(transposeOp.getTransp());
1085         srcType = transposeOp.getVectorType();
1086       } else if (!matchPattern(operand, m_Constant())) {
1087         return failure();
1088       }
1089     }
1090     if (transposeMaps.empty())
1091       return failure();
1092     // This is an elementwise op, so all transposed operands should have the
1093     // same type. We need to additionally check that all transposes uses the
1094     // same map.
1095     if (!llvm::is_splat(transposeMaps))
1096       return rewriter.notifyMatchFailure(op, "different transpose map");
1097 
1098     SmallVector<Value, 4> srcValues;
1099     srcValues.reserve(op->getNumOperands());
1100 
1101     // If there are constant operands, we need to insert inverse transposes for
1102     // them. Calculate the inverse order first.
1103     auto order = extractVector<unsigned>(transposeMaps.front());
1104     SmallVector<int64_t> invOrder(order.size());
1105     for (int i = 0, e = order.size(); i < e; ++i)
1106       invOrder[order[i]] = i;
1107 
1108     for (Value operand : op->getOperands()) {
1109       auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
1110       if (transposeOp) {
1111         srcValues.push_back(transposeOp.getVector());
1112       } else {
1113         // This is a constant. Create a reverse transpose op for it.
1114         auto vectorType = VectorType::get(
1115             srcType.getShape(),
1116             operand.getType().cast<VectorType>().getElementType());
1117         srcValues.push_back(rewriter.create<vector::TransposeOp>(
1118             operand.getLoc(), vectorType, operand,
1119             rewriter.getI64ArrayAttr(invOrder)));
1120       }
1121     }
1122 
1123     auto vectorType = VectorType::get(
1124         srcType.getShape(),
1125         op->getResultTypes()[0].cast<VectorType>().getElementType());
1126     Operation *elementwiseOp =
1127         rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
1128                         vectorType, op->getAttrs());
1129     rewriter.replaceOpWithNewOp<vector::TransposeOp>(
1130         op, op->getResultTypes()[0], elementwiseOp->getResult(0),
1131         transposeMaps.front());
1132     return success();
1133   }
1134 };
1135 
1136 } // namespace
1137 
1138 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
1139 /// operands `x` and `y`.
1140 static Value createAdd(Location loc, Value x, Value y, bool isInt,
1141                        PatternRewriter &rewriter) {
1142   if (isInt)
1143     return rewriter.create<arith::AddIOp>(loc, x, y);
1144   return rewriter.create<arith::AddFOp>(loc, x, y);
1145 }
1146 
1147 /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
1148 /// operands `x and `y`.
1149 static Value createMul(Location loc, Value x, Value y, bool isInt,
1150                        PatternRewriter &rewriter) {
1151   if (isInt)
1152     return rewriter.create<arith::MulIOp>(loc, x, y);
1153   return rewriter.create<arith::MulFOp>(loc, x, y);
1154 }
1155 
1156 namespace mlir {
1157 
1158 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1159 /// semantics to:
1160 /// ```
1161 ///    %mta = maybe_transpose
1162 ///    %mtb = maybe_transpose
1163 ///    %flattened_a = vector.shape_cast %mta
1164 ///    %flattened_b = vector.shape_cast %mtb
1165 ///    %flattened_d = vector.matmul %flattened_a, %flattened_b
1166 ///    %mtd = vector.shape_cast %flattened_d
1167 ///    %d = maybe_untranspose %mtd
1168 ///    %e = add %c, %d
1169 /// ```
1170 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1171 //
1172 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
1173 /// vector.transpose operations are inserted if the vector.contract op is not a
1174 /// row-major matrix multiply.
1175 LogicalResult
1176 ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
1177                                                  PatternRewriter &rew) const {
1178   // TODO: implement masks
1179   if (llvm::size(op.getMasks()) != 0)
1180     return failure();
1181   if (vectorTransformOptions.vectorContractLowering !=
1182       vector::VectorContractLowering::Matmul)
1183     return failure();
1184   if (failed(filter(op)))
1185     return failure();
1186 
1187   auto iteratorTypes = op.getIteratorTypes().getValue();
1188   if (!isParallelIterator(iteratorTypes[0]) ||
1189       !isParallelIterator(iteratorTypes[1]) ||
1190       !isReductionIterator(iteratorTypes[2]))
1191     return failure();
1192 
1193   Type elementType = op.getLhsType().getElementType();
1194   if (!elementType.isIntOrFloat())
1195     return failure();
1196 
1197   // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
1198   // Bail out if the contraction cannot be put in this form.
1199   MLIRContext *ctx = op.getContext();
1200   Location loc = op.getLoc();
1201   AffineExpr m, n, k;
1202   bindDims(rew.getContext(), m, n, k);
1203   // LHS must be A(m, k) or A(k, m).
1204   Value lhs = op.getLhs();
1205   auto lhsMap = op.getIndexingMaps()[0];
1206   if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
1207     lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
1208   else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
1209     return failure();
1210 
1211   // RHS must be B(k, n) or B(n, k).
1212   Value rhs = op.getRhs();
1213   auto rhsMap = op.getIndexingMaps()[1];
1214   if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
1215     rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
1216   else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
1217     return failure();
1218 
1219   // At this point lhs and rhs are in row-major.
1220   VectorType lhsType = lhs.getType().cast<VectorType>();
1221   VectorType rhsType = rhs.getType().cast<VectorType>();
1222   int64_t lhsRows = lhsType.getDimSize(0);
1223   int64_t lhsColumns = lhsType.getDimSize(1);
1224   int64_t rhsColumns = rhsType.getDimSize(1);
1225 
1226   Type flattenedLHSType =
1227       VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
1228   lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
1229 
1230   Type flattenedRHSType =
1231       VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
1232   rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
1233 
1234   Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
1235                                            rhsColumns);
1236   mul = rew.create<vector::ShapeCastOp>(
1237       loc,
1238       VectorType::get({lhsRows, rhsColumns},
1239                       getElementTypeOrSelf(op.getAcc().getType())),
1240       mul);
1241 
1242   // ACC must be C(m, n) or C(n, m).
1243   auto accMap = op.getIndexingMaps()[2];
1244   if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
1245     mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
1246   else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
1247     llvm_unreachable("invalid contraction semantics");
1248 
1249   Value res =
1250       elementType.isa<IntegerType>()
1251           ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
1252           : static_cast<Value>(
1253                 rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
1254 
1255   rew.replaceOp(op, res);
1256   return success();
1257 }
1258 
1259 namespace {
1260 struct IteratorType {
1261   IteratorType(StringRef strRef) : strRef(strRef) {}
1262   bool isOfType(Attribute attr) const {
1263     auto sAttr = attr.dyn_cast<StringAttr>();
1264     return sAttr && sAttr.getValue() == strRef;
1265   }
1266   StringRef strRef;
1267 };
1268 struct Par : public IteratorType {
1269   Par() : IteratorType(getParallelIteratorTypeName()) {}
1270 };
1271 struct Red : public IteratorType {
1272   Red() : IteratorType(getReductionIteratorTypeName()) {}
1273 };
1274 
1275 /// Generate a vector implementation for matmat, matvec and tmatvec.
1276 /// This unrolls outer-products along the reduction dimension.
1277 struct UnrolledOuterProductGenerator
1278     : public StructuredGenerator<vector::ContractionOp> {
1279   UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
1280       : StructuredGenerator<vector::ContractionOp>(builder, op),
1281         kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
1282         res(op.getAcc()), lhsType(op.getLhsType()) {}
1283 
1284   Value t(Value v) {
1285     static constexpr std::array<int64_t, 2> perm = {1, 0};
1286     return builder.create<vector::TransposeOp>(loc, v, perm);
1287   }
1288 
1289   Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
1290     assert(reductionSize > 0);
1291     for (int64_t k = 0; k < reductionSize; ++k) {
1292       Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
1293       Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
1294       res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
1295                                                    res, kind);
1296     }
1297     return res;
1298   }
1299 
1300   /// Two outer parallel, one inner reduction (matmat flavor).
1301   FailureOr<Value> matmat() {
1302     if (!iters({Par(), Par(), Red()}))
1303       return failure();
1304     // Set up the parallel/reduction structure in the right form.
1305     AffineExpr m, n, k;
1306     bindDims(builder.getContext(), m, n, k);
1307     // Classical row-major matmul:  Just permute the lhs.
1308     if (layout({{m, k}, {k, n}, {m, n}}))
1309       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1310     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1311     if (layout({{m, k}, {n, k}, {m, n}})) {
1312       Value tlhs = t(lhs);
1313       return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
1314     }
1315     // No need to permute anything.
1316     if (layout({{k, m}, {k, n}, {m, n}}))
1317       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1318     // Just permute the rhs.
1319     if (layout({{k, m}, {n, k}, {m, n}}))
1320       return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
1321     // Transposed output: swap RHS and LHS.
1322     // Classical row-major matmul: permute the lhs.
1323     if (layout({{m, k}, {k, n}, {n, m}}))
1324       return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
1325     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1326     if (layout({{m, k}, {n, k}, {n, m}})) {
1327       Value trhs = t(rhs);
1328       return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
1329     }
1330     if (layout({{k, m}, {k, n}, {n, m}}))
1331       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1332     if (layout({{k, m}, {n, k}, {n, m}}))
1333       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1334     return failure();
1335   }
1336 
1337   /// One outer parallel, one inner reduction (matvec flavor)
1338   FailureOr<Value> matvec() {
1339     if (!iters({Par(), Red()}))
1340       return failure();
1341     AffineExpr m, k;
1342     bindDims(builder.getContext(), m, k);
1343 
1344     // Case mat-vec: transpose.
1345     if (layout({{m, k}, {k}, {m}}))
1346       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1347     // Case mat-trans-vec: ready to go.
1348     if (layout({{k, m}, {k}, {m}}))
1349       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1350     // Case vec-mat: swap and transpose.
1351     if (layout({{k}, {m, k}, {m}}))
1352       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1353     // Case vec-mat-trans: swap and ready to go.
1354     if (layout({{k}, {k, m}, {m}}))
1355       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1356     return failure();
1357   }
1358 
1359   //
1360   // One outer reduction, one inner parallel (tmatvec flavor)
1361   //
1362   FailureOr<Value> tmatvec() {
1363     if (!iters({Red(), Par()}))
1364       return failure();
1365     AffineExpr k, m;
1366     bindDims(builder.getContext(), k, m);
1367 
1368     // Case mat-vec: transpose.
1369     if (layout({{m, k}, {k}, {m}}))
1370       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1371     // Case mat-trans-vec: ready to go.
1372     if (layout({{k, m}, {k}, {m}}))
1373       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1374     // Case vec-mat: swap and transpose.
1375     if (layout({{k}, {m, k}, {m}}))
1376       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1377     // Case vec-mat-trans: swap and ready to go.
1378     if (layout({{k}, {k, m}, {m}}))
1379       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1380     return failure();
1381   }
1382 
1383 private:
1384   vector::CombiningKind kind;
1385   Value lhs, rhs, res;
1386   VectorType lhsType;
1387 };
1388 } // namespace
1389 
1390 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1391 /// semantics to a reduction_size-unrolled sequence:
1392 /// ```
1393 ///    %at = vector.transpose %a, [1, 0]
1394 ///    %bRow0 = vector.extract %b[0]
1395 ///    %atRow0 = vector.extract %at[0]
1396 ///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
1397 ///    ...
1398 ///    %bRowK = vector.extract %b[K]
1399 ///    %atRowK = vector.extract %at[K]
1400 ///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
1401 /// ```
1402 ///
1403 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
1404 /// otherwise supports any layout permutation of the matrix-multiply.
1405 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
1406     vector::ContractionOp op, PatternRewriter &rewriter) const {
1407   // TODO: implement masks
1408   if (llvm::size(op.getMasks()) != 0)
1409     return failure();
1410 
1411   if (vectorTransformOptions.vectorContractLowering !=
1412       vector::VectorContractLowering::OuterProduct)
1413     return failure();
1414 
1415   if (failed(filter(op)))
1416     return failure();
1417 
1418   UnrolledOuterProductGenerator e(rewriter, op);
1419   FailureOr<Value> matmatRes = e.matmat();
1420   if (succeeded(matmatRes)) {
1421     rewriter.replaceOp(op, *matmatRes);
1422     return success();
1423   }
1424   FailureOr<Value> matvecRes = e.matvec();
1425   if (succeeded(matvecRes)) {
1426     rewriter.replaceOp(op, *matvecRes);
1427     return success();
1428   }
1429   FailureOr<Value> tmatvecRes = e.tmatvec();
1430   if (succeeded(tmatvecRes)) {
1431     rewriter.replaceOp(op, *tmatvecRes);
1432     return success();
1433   }
1434 
1435   return failure();
1436 }
1437 
1438 LogicalResult
1439 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
1440                                             PatternRewriter &rewriter) const {
1441   // TODO: implement masks
1442   if (llvm::size(op.getMasks()) != 0)
1443     return failure();
1444 
1445   if (failed(filter(op)))
1446     return failure();
1447 
1448   if (vectorTransformOptions.vectorContractLowering !=
1449       vector::VectorContractLowering::Dot)
1450     return failure();
1451 
1452   auto iteratorTypes = op.getIteratorTypes().getValue();
1453   static constexpr std::array<int64_t, 2> perm = {1, 0};
1454   Location loc = op.getLoc();
1455   Value lhs = op.getLhs(), rhs = op.getRhs();
1456 
1457   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1458   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1459   AffineExpr m, n, k;
1460   bindDims(rewriter.getContext(), m, n, k);
1461   SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
1462   //
1463   // In the following we wish to make the reduction dimension innermost so we
1464   // can load vectors and just fmul + reduce into a scalar.
1465   //
1466   if (isParallelIterator(iteratorTypes[0]) &&
1467       isParallelIterator(iteratorTypes[1]) &&
1468       isReductionIterator(iteratorTypes[2])) {
1469     //
1470     // Two outer parallel, one inner reduction (matmat flavor).
1471     //
1472     if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1473       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1474     } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
1475       // No need to permute anything.
1476     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1477       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1478       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1479     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1480       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1481     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1482       // This is the classical row-major matmul. Just permute the lhs.
1483       Value tmp = lhs;
1484       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1485       rhs = tmp;
1486     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1487       std::swap(lhs, rhs);
1488     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1489       Value tmp = lhs;
1490       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1491       rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
1492     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1493       Value tmp = rhs;
1494       rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1495       lhs = tmp;
1496     } else {
1497       return failure();
1498     }
1499   } else if (isParallelIterator(iteratorTypes[0]) &&
1500              isReductionIterator(iteratorTypes[1])) {
1501     //
1502     // One outer parallel, one inner reduction (matvec flavor)
1503     //
1504     if (maps == infer({{m, n}, {n}, {m}})) {
1505       // No need to permute anything.
1506     } else if (maps == infer({{n, m}, {n}, {m}})) {
1507       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1508     } else if (maps == infer({{n}, {m, n}, {m}})) {
1509       std::swap(lhs, rhs);
1510     } else if (maps == infer({{n}, {n, m}, {m}})) {
1511       std::swap(lhs, rhs);
1512       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1513     } else {
1514       return failure();
1515     }
1516   } else {
1517     return failure();
1518   }
1519 
1520   VectorType dstType = op.getResultType().cast<VectorType>();
1521   assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
1522          "Expected dst type of rank 1 or 2");
1523 
1524   unsigned rank = dstType.getRank();
1525   unsigned dstRows = dstType.getShape()[0];
1526   unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
1527 
1528   // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
1529   Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
1530                                                  rewriter.getZeroAttr(dstType));
1531   bool isInt = dstType.getElementType().isa<IntegerType>();
1532   for (unsigned r = 0; r < dstRows; ++r) {
1533     Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
1534     for (unsigned c = 0; c < dstColumns; ++c) {
1535       Value b = rank == 1
1536                     ? rhs
1537                     : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
1538       Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
1539       Value reduced = rewriter.create<vector::ReductionOp>(
1540           op.getLoc(), vector::CombiningKind::ADD, m);
1541 
1542       SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
1543                                               : SmallVector<int64_t, 2>{r, c};
1544       res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
1545     }
1546   }
1547   if (auto acc = op.getAcc())
1548     res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
1549   rewriter.replaceOp(op, res);
1550   return success();
1551 }
1552 
1553 /// Progressive lowering of ContractionOp.
1554 /// One:
1555 ///   %x = vector.contract with at least one free/batch dimension
1556 /// is replaced by:
1557 ///   %a = vector.contract with one less free/batch dimension
1558 ///   %b = vector.contract with one less free/batch dimension
1559 ///   ..
1560 ///   %x = combine %a %b ..
1561 /// until a pure contraction is reached (no free/batch dimensions),
1562 /// which is replaced by a dot-product.
1563 ///
1564 /// This only kicks in when either VectorTransformsOptions is set
1565 /// to DOT or when other contraction patterns fail.
1566 //
1567 // TODO: break down into transpose/reshape/cast ops
1568 //               when they become available to avoid code dup
1569 // TODO: investigate lowering order impact on performance
1570 LogicalResult
1571 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
1572                                        PatternRewriter &rewriter) const {
1573   // TODO: implement masks.
1574   if (llvm::size(op.getMasks()) != 0)
1575     return failure();
1576 
1577   if (failed(filter(op)))
1578     return failure();
1579 
1580   // TODO: support mixed mode contract lowering.
1581   if (op.getLhsType().getElementType() !=
1582           getElementTypeOrSelf(op.getAccType()) ||
1583       op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
1584     return failure();
1585 
1586   // TODO: implement benefits, cost models.
1587   MLIRContext *ctx = op.getContext();
1588   ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
1589   if (succeeded(pat1.matchAndRewrite(op, rewriter)))
1590     return success();
1591   ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
1592   if (succeeded(pat2.matchAndRewrite(op, rewriter)))
1593     return success();
1594   ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
1595   if (succeeded(pat3.matchAndRewrite(op, rewriter)))
1596     return success();
1597 
1598   // Find first batch dimension in LHS/RHS, and lower when found.
1599   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
1600   if (!batchDimMap.empty()) {
1601     int64_t lhsIndex = batchDimMap[0].first;
1602     int64_t rhsIndex = batchDimMap[0].second;
1603     rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
1604     return success();
1605   }
1606 
1607   // Collect contracting dimensions.
1608   std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
1609       op.getContractingDimMap();
1610   DenseSet<int64_t> lhsContractingDimSet;
1611   DenseSet<int64_t> rhsContractingDimSet;
1612   for (auto &dimPair : contractingDimMap) {
1613     lhsContractingDimSet.insert(dimPair.first);
1614     rhsContractingDimSet.insert(dimPair.second);
1615   }
1616 
1617   // Find first free dimension in LHS, and lower when found.
1618   VectorType lhsType = op.getLhsType();
1619   for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
1620     if (lhsContractingDimSet.count(lhsIndex) == 0) {
1621       rewriter.replaceOp(
1622           op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
1623       return success();
1624     }
1625   }
1626 
1627   // Find first free dimension in RHS, and lower when found.
1628   VectorType rhsType = op.getRhsType();
1629   for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
1630     if (rhsContractingDimSet.count(rhsIndex) == 0) {
1631       rewriter.replaceOp(
1632           op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
1633       return success();
1634     }
1635   }
1636 
1637   // Lower the first remaining reduction dimension.
1638   if (!contractingDimMap.empty()) {
1639     rewriter.replaceOp(op, lowerReduction(op, rewriter));
1640     return success();
1641   }
1642 
1643   return failure();
1644 }
1645 
1646 // Lower one parallel dimension.
1647 // TODO: consider reusing existing contract unrolling
1648 Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
1649                                            int64_t lhsIndex, int64_t rhsIndex,
1650                                            PatternRewriter &rewriter) const {
1651   VectorType lhsType = op.getLhsType();
1652   VectorType rhsType = op.getRhsType();
1653   VectorType resType = op.getResultType().cast<VectorType>();
1654   // Find the iterator type index and result index.
1655   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
1656   int64_t iterIndex = -1;
1657   int64_t dimSize = -1;
1658   if (lhsIndex >= 0) {
1659     iterIndex = iMap[0].getDimPosition(lhsIndex);
1660     assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
1661            "parallel index should be free in LHS or batch in LHS/RHS");
1662     dimSize = lhsType.getDimSize(lhsIndex);
1663   } else {
1664     assert(rhsIndex >= 0 && "missing parallel index");
1665     iterIndex = iMap[1].getDimPosition(rhsIndex);
1666     dimSize = rhsType.getDimSize(rhsIndex);
1667   }
1668   assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
1669   Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
1670   assert(lookup.hasValue() && "parallel index not listed in reduction");
1671   int64_t resIndex = lookup.getValue();
1672   // Construct new iterator types and affine map array attribute.
1673   std::array<AffineMap, 3> lowIndexingMaps = {
1674       adjustMap(iMap[0], iterIndex, rewriter),
1675       adjustMap(iMap[1], iterIndex, rewriter),
1676       adjustMap(iMap[2], iterIndex, rewriter)};
1677   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1678   auto lowIter =
1679       rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1680   // Unroll into a series of lower dimensional vector.contract ops.
1681   Location loc = op.getLoc();
1682   Value result = rewriter.create<arith::ConstantOp>(
1683       loc, resType, rewriter.getZeroAttr(resType));
1684   for (int64_t d = 0; d < dimSize; ++d) {
1685     auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1686     auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1687     auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
1688     Value lowContract = rewriter.create<vector::ContractionOp>(
1689         loc, lhs, rhs, acc, lowAffine, lowIter);
1690     result =
1691         reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
1692   }
1693   return result;
1694 }
1695 
1696 // Lower one reduction dimension.
1697 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
1698                                             PatternRewriter &rewriter) const {
1699   auto loc = op.getLoc();
1700   VectorType lhsType = op.getLhsType();
1701   VectorType rhsType = op.getRhsType();
1702   Type resType = op.getResultType();
1703   assert(!resType.isa<VectorType>());
1704   bool isInt = resType.isa<IntegerType>();
1705   // Use iterator index 0.
1706   int64_t iterIndex = 0;
1707   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
1708   Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1709   Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1710   assert(lookupLhs.hasValue() && "missing LHS parallel index");
1711   assert(lookupRhs.hasValue() && "missing RHS parallel index");
1712   int64_t lhsIndex = lookupLhs.getValue();
1713   int64_t rhsIndex = lookupRhs.getValue();
1714   int64_t dimSize = lhsType.getDimSize(lhsIndex);
1715   assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
1716   // Base case.
1717   if (lhsType.getRank() == 1) {
1718     assert(rhsType.getRank() == 1 && "corrupt contraction");
1719     Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1720     auto kind = vector::CombiningKind::ADD;
1721     Value res = rewriter.create<vector::ReductionOp>(loc, kind, m);
1722     if (auto acc = op.getAcc())
1723       res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
1724     return res;
1725   }
1726   // Construct new iterator types and affine map array attribute.
1727   std::array<AffineMap, 3> lowIndexingMaps = {
1728       adjustMap(iMap[0], iterIndex, rewriter),
1729       adjustMap(iMap[1], iterIndex, rewriter),
1730       adjustMap(iMap[2], iterIndex, rewriter)};
1731   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1732   auto lowIter =
1733       rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1734   // Unroll into a series of lower dimensional vector.contract ops.
1735   // By feeding the initial accumulator into the first contraction,
1736   // and the result of each contraction into the next, eventually
1737   // the sum of all reductions is computed.
1738   Value result = op.getAcc();
1739   for (int64_t d = 0; d < dimSize; ++d) {
1740     auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1741     auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1742     result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
1743                                                     lowAffine, lowIter);
1744   }
1745   return result;
1746 }
1747 
1748 } // namespace mlir
1749 
1750 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
1751     OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
1752     ArrayRef<int64_t> multiplicity, const AffineMap &map) {
1753   OpBuilder::InsertionGuard guard(builder);
1754   builder.setInsertionPointAfter(op);
1755   Location loc = op->getLoc();
1756   if (op->getNumResults() != 1)
1757     return {};
1758   Value result = op->getResult(0);
1759   VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
1760   if (!type || map.getNumResults() != multiplicity.size())
1761     return {};
1762   // For each dimension being distributed check that the size is a multiple of
1763   // the multiplicity. To handle more sizes we would need to support masking.
1764   unsigned multiplictyCount = 0;
1765   for (auto exp : map.getResults()) {
1766     auto affinExp = exp.dyn_cast<AffineDimExpr>();
1767     if (!affinExp || affinExp.getPosition() >= type.getRank() ||
1768         type.getDimSize(affinExp.getPosition()) %
1769                 multiplicity[multiplictyCount++] !=
1770             0)
1771       return {};
1772   }
1773   DistributeOps ops;
1774   ops.extract =
1775       builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
1776   ops.insert =
1777       builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
1778   return ops;
1779 }
1780 
1781 /// Progressive lowering of transfer_read. This pattern supports lowering of
1782 /// `vector.transfer_read` to a combination of `vector.load` and
1783 /// `vector.broadcast` if all of the following hold:
1784 /// - Stride of most minor memref dimension must be 1.
1785 /// - Out-of-bounds masking is not required.
1786 /// - If the memref's element type is a vector type then it coincides with the
1787 ///   result type.
1788 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
1789 struct TransferReadToVectorLoadLowering
1790     : public OpRewritePattern<vector::TransferReadOp> {
1791   TransferReadToVectorLoadLowering(MLIRContext *context,
1792                                    llvm::Optional<unsigned> maxRank)
1793       : OpRewritePattern<vector::TransferReadOp>(context),
1794         maxTransferRank(maxRank) {}
1795 
1796   LogicalResult matchAndRewrite(vector::TransferReadOp read,
1797                                 PatternRewriter &rewriter) const override {
1798     if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
1799       return failure();
1800 
1801     SmallVector<unsigned, 4> broadcastedDims;
1802     // Permutations are handled by VectorToSCF or
1803     // populateVectorTransferPermutationMapLoweringPatterns.
1804     // We let the 0-d corner case pass-through as it is supported.
1805     if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
1806             &broadcastedDims))
1807       return failure();
1808 
1809     auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
1810     if (!memRefType)
1811       return failure();
1812 
1813     // Non-unit strides are handled by VectorToSCF.
1814     if (!vector::isLastMemrefDimUnitStride(memRefType))
1815       return failure();
1816 
1817     // If there is broadcasting involved then we first load the unbroadcasted
1818     // vector, and then broadcast it with `vector.broadcast`.
1819     ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
1820     SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
1821                                                      vectorShape.end());
1822     for (unsigned i : broadcastedDims)
1823       unbroadcastedVectorShape[i] = 1;
1824     VectorType unbroadcastedVectorType = VectorType::get(
1825         unbroadcastedVectorShape, read.getVectorType().getElementType());
1826 
1827     // `vector.load` supports vector types as memref's elements only when the
1828     // resulting vector type is the same as the element type.
1829     auto memrefElTy = memRefType.getElementType();
1830     if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
1831       return failure();
1832 
1833     // Otherwise, element types of the memref and the vector must match.
1834     if (!memrefElTy.isa<VectorType>() &&
1835         memrefElTy != read.getVectorType().getElementType())
1836       return failure();
1837 
1838     // Out-of-bounds dims are handled by MaterializeTransferMask.
1839     if (read.hasOutOfBoundsDim())
1840       return failure();
1841 
1842     // Create vector load op.
1843     Operation *loadOp;
1844     if (read.getMask()) {
1845       Value fill = rewriter.create<vector::SplatOp>(
1846           read.getLoc(), unbroadcastedVectorType, read.getPadding());
1847       loadOp = rewriter.create<vector::MaskedLoadOp>(
1848           read.getLoc(), unbroadcastedVectorType, read.getSource(),
1849           read.getIndices(), read.getMask(), fill);
1850     } else {
1851       loadOp = rewriter.create<vector::LoadOp>(
1852           read.getLoc(), unbroadcastedVectorType, read.getSource(),
1853           read.getIndices());
1854     }
1855 
1856     // Insert a broadcasting op if required.
1857     if (!broadcastedDims.empty()) {
1858       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1859           read, read.getVectorType(), loadOp->getResult(0));
1860     } else {
1861       rewriter.replaceOp(read, loadOp->getResult(0));
1862     }
1863 
1864     return success();
1865   }
1866 
1867   llvm::Optional<unsigned> maxTransferRank;
1868 };
1869 
1870 /// Replace a 0-d vector.load with a memref.load + vector.broadcast.
1871 // TODO: we shouldn't cross the vector/scalar domains just for this
1872 // but atm we lack the infra to avoid it. Possible solutions include:
1873 // - go directly to LLVM + bitcast
1874 // - introduce a bitcast op and likely a new pointer dialect
1875 // - let memref.load/store additionally support the 0-d vector case
1876 // There are still deeper data layout issues lingering even in this
1877 // trivial case (for architectures for which this matters).
1878 struct VectorLoadToMemrefLoadLowering
1879     : public OpRewritePattern<vector::LoadOp> {
1880   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
1881 
1882   LogicalResult matchAndRewrite(vector::LoadOp loadOp,
1883                                 PatternRewriter &rewriter) const override {
1884     auto vecType = loadOp.getVectorType();
1885     if (vecType.getNumElements() != 1)
1886       return failure();
1887     auto memrefLoad = rewriter.create<memref::LoadOp>(
1888         loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
1889     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
1890                                                      memrefLoad);
1891     return success();
1892   }
1893 };
1894 
1895 /// Replace a 0-d vector.store with a vector.extractelement + memref.store.
1896 struct VectorStoreToMemrefStoreLowering
1897     : public OpRewritePattern<vector::StoreOp> {
1898   using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
1899 
1900   LogicalResult matchAndRewrite(vector::StoreOp storeOp,
1901                                 PatternRewriter &rewriter) const override {
1902     auto vecType = storeOp.getVectorType();
1903     if (vecType.getNumElements() != 1)
1904       return failure();
1905     Value extracted;
1906     if (vecType.getRank() == 0) {
1907       // TODO: Unifiy once ExtractOp supports 0-d vectors.
1908       extracted = rewriter.create<vector::ExtractElementOp>(
1909           storeOp.getLoc(), storeOp.getValueToStore());
1910     } else {
1911       SmallVector<int64_t> indices(vecType.getRank(), 0);
1912       extracted = rewriter.create<vector::ExtractOp>(
1913           storeOp.getLoc(), storeOp.getValueToStore(), indices);
1914     }
1915 
1916     rewriter.replaceOpWithNewOp<memref::StoreOp>(
1917         storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
1918     return success();
1919   }
1920 };
1921 
1922 /// Progressive lowering of transfer_write. This pattern supports lowering of
1923 /// `vector.transfer_write` to `vector.store` if all of the following hold:
1924 /// - Stride of most minor memref dimension must be 1.
1925 /// - Out-of-bounds masking is not required.
1926 /// - If the memref's element type is a vector type then it coincides with the
1927 ///   type of the written value.
1928 /// - The permutation map is the minor identity map (neither permutation nor
1929 ///   broadcasting is allowed).
1930 struct TransferWriteToVectorStoreLowering
1931     : public OpRewritePattern<vector::TransferWriteOp> {
1932   TransferWriteToVectorStoreLowering(MLIRContext *context,
1933                                      llvm::Optional<unsigned> maxRank)
1934       : OpRewritePattern<vector::TransferWriteOp>(context),
1935         maxTransferRank(maxRank) {}
1936 
1937   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
1938                                 PatternRewriter &rewriter) const override {
1939     if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
1940       return failure();
1941 
1942     // Permutations are handled by VectorToSCF or
1943     // populateVectorTransferPermutationMapLoweringPatterns.
1944     if ( // pass-through for the 0-d corner case.
1945         !write.getPermutationMap().isMinorIdentity())
1946       return failure();
1947 
1948     auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
1949     if (!memRefType)
1950       return failure();
1951 
1952     // Non-unit strides are handled by VectorToSCF.
1953     if (!vector::isLastMemrefDimUnitStride(memRefType))
1954       return failure();
1955 
1956     // `vector.store` supports vector types as memref's elements only when the
1957     // type of the vector value being written is the same as the element type.
1958     auto memrefElTy = memRefType.getElementType();
1959     if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
1960       return failure();
1961 
1962     // Otherwise, element types of the memref and the vector must match.
1963     if (!memrefElTy.isa<VectorType>() &&
1964         memrefElTy != write.getVectorType().getElementType())
1965       return failure();
1966 
1967     // Out-of-bounds dims are handled by MaterializeTransferMask.
1968     if (write.hasOutOfBoundsDim())
1969       return failure();
1970     if (write.getMask()) {
1971       rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
1972           write, write.getSource(), write.getIndices(), write.getMask(),
1973           write.getVector());
1974     } else {
1975       rewriter.replaceOpWithNewOp<vector::StoreOp>(
1976           write, write.getVector(), write.getSource(), write.getIndices());
1977     }
1978     return success();
1979   }
1980 
1981   llvm::Optional<unsigned> maxTransferRank;
1982 };
1983 
1984 // Returns the values in `arrayAttr` as an integer vector.
1985 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
1986   return llvm::to_vector<4>(
1987       llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
1988                       [](IntegerAttr attr) { return attr.getInt(); }));
1989 }
1990 
1991 // Shuffles vector.bitcast op after vector.extract op.
1992 //
1993 // This transforms IR like:
1994 //   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
1995 //   %1 = vector.extract %0[3] : vector<8xf16>
1996 // Into:
1997 //   %0 = vector.extract %src[1] : vector<4xf32>
1998 //   %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
1999 //   %2 = vector.extract %1[1] : vector<2xf16>
2000 struct BubbleDownVectorBitCastForExtract
2001     : public OpRewritePattern<vector::ExtractOp> {
2002   using OpRewritePattern::OpRewritePattern;
2003 
2004   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2005                                 PatternRewriter &rewriter) const override {
2006     // Only support extracting scalars for now.
2007     if (extractOp.getVectorType().getRank() != 1)
2008       return failure();
2009 
2010     auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
2011     if (!castOp)
2012       return failure();
2013 
2014     VectorType castSrcType = castOp.getSourceVectorType();
2015     VectorType castDstType = castOp.getResultVectorType();
2016     assert(castSrcType.getRank() == castDstType.getRank());
2017 
2018     // Fail to match if we only have one element in the cast op source.
2019     // This is to avoid infinite loop given that this pattern can generate
2020     // such cases.
2021     if (castSrcType.getNumElements() == 1)
2022       return failure();
2023 
2024     // Only support casting to a larger number of elements or now.
2025     // E.g., vector<4xf32> -> vector<8xf16>.
2026     if (castSrcType.getNumElements() > castDstType.getNumElements())
2027       return failure();
2028 
2029     unsigned expandRatio =
2030         castDstType.getNumElements() / castSrcType.getNumElements();
2031 
2032     auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
2033       return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
2034     };
2035 
2036     uint64_t index = getFirstIntValue(extractOp.getPosition());
2037 
2038     // Get the single scalar (as a vector) in the source value that packs the
2039     // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
2040     VectorType oneScalarType =
2041         VectorType::get({1}, castSrcType.getElementType());
2042     Value packedValue = rewriter.create<vector::ExtractOp>(
2043         extractOp.getLoc(), oneScalarType, castOp.getSource(),
2044         rewriter.getI64ArrayAttr(index / expandRatio));
2045 
2046     // Cast it to a vector with the desired scalar's type.
2047     // E.g. f32 -> vector<2xf16>
2048     VectorType packedType =
2049         VectorType::get({expandRatio}, castDstType.getElementType());
2050     Value castedValue = rewriter.create<vector::BitCastOp>(
2051         extractOp.getLoc(), packedType, packedValue);
2052 
2053     // Finally extract the desired scalar.
2054     rewriter.replaceOpWithNewOp<vector::ExtractOp>(
2055         extractOp, extractOp.getType(), castedValue,
2056         rewriter.getI64ArrayAttr(index % expandRatio));
2057 
2058     return success();
2059   }
2060 };
2061 
2062 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
2063 //
2064 // This transforms IR like:
2065 //    %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
2066 //     %0 = vector.extract_strided_slice %cast {
2067 //            offsets = [4], sizes = [4], strides = [1]
2068 //          } : vector<8xf16> to vector<4xf16>
2069 // Into:
2070 //   %0 = vector.extract_strided_slice %src {
2071 //          offsets = [2], sizes = [2], strides = [1]
2072 //        } : vector<4xf32> to vector<2xf32>
2073 //   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
2074 struct BubbleDownBitCastForStridedSliceExtract
2075     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
2076   using OpRewritePattern::OpRewritePattern;
2077 
2078   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
2079                                 PatternRewriter &rewriter) const override {
2080     auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
2081     if (!castOp)
2082       return failure();
2083 
2084     VectorType castSrcType = castOp.getSourceVectorType();
2085     VectorType castDstType = castOp.getResultVectorType();
2086     assert(castSrcType.getRank() == castDstType.getRank());
2087 
2088     int64_t castSrcLastDim = castSrcType.getShape().back();
2089     int64_t castDstLastDim = castDstType.getShape().back();
2090     // Require casting to more elements for now; other cases to be implemented.
2091     if (castSrcLastDim > castDstLastDim)
2092       return failure();
2093 
2094     // Only accept all one strides for now.
2095     if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
2096                      [](const APInt &val) { return !val.isOneValue(); }))
2097       return failure();
2098 
2099     unsigned rank = extractOp.getVectorType().getRank();
2100     assert(castDstLastDim % castSrcLastDim == 0);
2101     int64_t expandRatio = castDstLastDim / castSrcLastDim;
2102 
2103     // If we have a less number of offsets than the rank, then implicitly we
2104     // are selecting the full range for the last bitcasted dimension; other
2105     // dimensions aren't affected. Otherwise, we need to scale down the last
2106     // dimension's offset given we are extracting from less elements now.
2107     ArrayAttr newOffsets = extractOp.getOffsets();
2108     if (newOffsets.size() == rank) {
2109       SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2110       if (offsets.back() % expandRatio != 0)
2111         return failure();
2112       offsets.back() = offsets.back() / expandRatio;
2113       newOffsets = rewriter.getI64ArrayAttr(offsets);
2114     }
2115 
2116     // Similarly for sizes.
2117     ArrayAttr newSizes = extractOp.getSizes();
2118     if (newSizes.size() == rank) {
2119       SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
2120       if (sizes.back() % expandRatio != 0)
2121         return failure();
2122       sizes.back() = sizes.back() / expandRatio;
2123       newSizes = rewriter.getI64ArrayAttr(sizes);
2124     }
2125 
2126     SmallVector<int64_t, 4> dims =
2127         llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
2128     dims.back() = dims.back() / expandRatio;
2129     VectorType newExtractType =
2130         VectorType::get(dims, castSrcType.getElementType());
2131 
2132     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
2133         extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
2134         newSizes, extractOp.getStrides());
2135 
2136     rewriter.replaceOpWithNewOp<vector::BitCastOp>(
2137         extractOp, extractOp.getType(), newExtractOp);
2138 
2139     return success();
2140   }
2141 };
2142 
2143 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
2144 //
2145 // This transforms IR like:
2146 //   %0 = vector.insert_strided_slice %src, %dst {
2147 //          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
2148 //   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
2149 // Into:
2150 //   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
2151 //   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
2152 //   %2 = vector.insert_strided_slice %src, %dst {
2153 //          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
2154 struct BubbleUpBitCastForStridedSliceInsert
2155     : public OpRewritePattern<vector::BitCastOp> {
2156   using OpRewritePattern::OpRewritePattern;
2157   LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
2158                                 PatternRewriter &rewriter) const override {
2159     VectorType castSrcType = bitcastOp.getSourceVectorType();
2160     VectorType castDstType = bitcastOp.getResultVectorType();
2161     assert(castSrcType.getRank() == castDstType.getRank());
2162 
2163     int64_t castSrcLastDim = castSrcType.getShape().back();
2164     int64_t castDstLastDim = castDstType.getShape().back();
2165     // Require casting to less elements for now; other cases to be implemented.
2166     if (castSrcLastDim < castDstLastDim)
2167       return failure();
2168 
2169     assert(castSrcLastDim % castDstLastDim == 0);
2170     int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
2171 
2172     auto insertOp =
2173         bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
2174     if (!insertOp)
2175       return failure();
2176 
2177     // Only accept all one strides for now.
2178     if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
2179                      [](const APInt &val) { return !val.isOneValue(); }))
2180       return failure();
2181 
2182     unsigned rank = insertOp.getSourceVectorType().getRank();
2183     // Require insert op to have the same rank for the source and destination
2184     // vector; other cases to be implemented.
2185     if (rank != insertOp.getDestVectorType().getRank())
2186       return failure();
2187 
2188     ArrayAttr newOffsets = insertOp.getOffsets();
2189     assert(newOffsets.size() == rank);
2190     SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2191     if (offsets.back() % shrinkRatio != 0)
2192       return failure();
2193     offsets.back() = offsets.back() / shrinkRatio;
2194     newOffsets = rewriter.getI64ArrayAttr(offsets);
2195 
2196     SmallVector<int64_t, 4> srcDims =
2197         llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
2198     srcDims.back() = srcDims.back() / shrinkRatio;
2199     VectorType newCastSrcType =
2200         VectorType::get(srcDims, castDstType.getElementType());
2201 
2202     auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
2203         bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
2204 
2205     SmallVector<int64_t, 4> dstDims =
2206         llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
2207     dstDims.back() = dstDims.back() / shrinkRatio;
2208     VectorType newCastDstType =
2209         VectorType::get(dstDims, castDstType.getElementType());
2210 
2211     auto newCastDstOp = rewriter.create<vector::BitCastOp>(
2212         bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
2213 
2214     rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
2215         bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
2216         insertOp.getStrides());
2217 
2218     return success();
2219   }
2220 };
2221 
2222 // Helper that returns a vector comparison that constructs a mask:
2223 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
2224 //
2225 // If `dim == 0` then the result will be a 0-D vector.
2226 //
2227 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
2228 //       much more compact, IR for this operation, but LLVM eventually
2229 //       generates more elaborate instructions for this intrinsic since it
2230 //       is very conservative on the boundary conditions.
2231 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
2232                                    bool force32BitVectorIndices, int64_t dim,
2233                                    Value b, Value *off = nullptr) {
2234   auto loc = op->getLoc();
2235   // If we can assume all indices fit in 32-bit, we perform the vector
2236   // comparison in 32-bit to get a higher degree of SIMD parallelism.
2237   // Otherwise we perform the vector comparison using 64-bit indices.
2238   Type idxType =
2239       force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
2240   DenseIntElementsAttr indicesAttr;
2241   if (dim == 0 && force32BitVectorIndices) {
2242     indicesAttr = DenseIntElementsAttr::get(
2243         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
2244   } else if (dim == 0) {
2245     indicesAttr = DenseIntElementsAttr::get(
2246         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
2247   } else if (force32BitVectorIndices) {
2248     indicesAttr = rewriter.getI32VectorAttr(
2249         llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
2250   } else {
2251     indicesAttr = rewriter.getI64VectorAttr(
2252         llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
2253   }
2254   Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
2255   // Add in an offset if requested.
2256   if (off) {
2257     Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
2258     Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
2259     indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
2260   }
2261   // Construct the vector comparison.
2262   Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
2263   Value bounds =
2264       rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
2265   return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
2266                                         bounds);
2267 }
2268 
2269 template <typename ConcreteOp>
2270 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
2271 public:
2272   explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
2273       : mlir::OpRewritePattern<ConcreteOp>(context),
2274         force32BitVectorIndices(enableIndexOpt) {}
2275 
2276   LogicalResult matchAndRewrite(ConcreteOp xferOp,
2277                                 PatternRewriter &rewriter) const override {
2278     if (!xferOp.hasOutOfBoundsDim())
2279       return failure();
2280 
2281     if (xferOp.getVectorType().getRank() > 1 ||
2282         llvm::size(xferOp.getIndices()) == 0)
2283       return failure();
2284 
2285     Location loc = xferOp->getLoc();
2286     VectorType vtp = xferOp.getVectorType();
2287 
2288     // Create the in-bounds mask with all elements between [0 .. dim - offset)
2289     // set and [dim - offset .. vector_length) unset.
2290     //
2291     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
2292     //       dimensions here.
2293     unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
2294     Value off = xferOp.getIndices()[lastIndex];
2295     Value dim =
2296         vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
2297     Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
2298     Value mask = rewriter.create<vector::CreateMaskOp>(
2299         loc,
2300         VectorType::get(vtp.getShape(), rewriter.getI1Type(),
2301                         vtp.getNumScalableDims()),
2302         b);
2303     if (xferOp.getMask()) {
2304       // Intersect the in-bounds with the mask specified as an op parameter.
2305       mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
2306     }
2307 
2308     rewriter.updateRootInPlace(xferOp, [&]() {
2309       xferOp.getMaskMutable().assign(mask);
2310       xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
2311     });
2312 
2313     return success();
2314   }
2315 
2316 private:
2317   const bool force32BitVectorIndices;
2318 };
2319 
2320 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
2321 class VectorCreateMaskOpConversion
2322     : public OpRewritePattern<vector::CreateMaskOp> {
2323 public:
2324   explicit VectorCreateMaskOpConversion(MLIRContext *context,
2325                                         bool enableIndexOpt)
2326       : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
2327         force32BitVectorIndices(enableIndexOpt) {}
2328 
2329   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
2330                                 PatternRewriter &rewriter) const override {
2331     auto dstType = op.getType();
2332     if (dstType.cast<VectorType>().isScalable())
2333       return failure();
2334     int64_t rank = dstType.getRank();
2335     if (rank > 1)
2336       return failure();
2337     rewriter.replaceOp(
2338         op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
2339                                   rank == 0 ? 0 : dstType.getDimSize(0),
2340                                   op.getOperand(0)));
2341     return success();
2342   }
2343 
2344 private:
2345   const bool force32BitVectorIndices;
2346 };
2347 
2348 // Drop inner most contiguous unit dimensions from transfer_read operand.
2349 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
2350   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
2351 
2352   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
2353                                 PatternRewriter &rewriter) const override {
2354     // TODO: support 0-d corner case.
2355     if (readOp.getTransferRank() == 0)
2356       return failure();
2357 
2358     // TODO: support mask.
2359     if (readOp.getMask())
2360       return failure();
2361 
2362     auto srcType = readOp.getSource().getType().dyn_cast<MemRefType>();
2363     if (!srcType || !srcType.hasStaticShape())
2364       return failure();
2365 
2366     if (!readOp.getPermutationMap().isMinorIdentity())
2367       return failure();
2368 
2369     auto targetType = readOp.getVectorType();
2370     if (targetType.getRank() <= 1)
2371       return failure();
2372 
2373     SmallVector<int64_t> srcStrides;
2374     int64_t srcOffset;
2375     if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2376       return failure();
2377 
2378     size_t dimsToDrop = 0;
2379     for (size_t i = 1; i < srcStrides.size(); ++i) {
2380       int dim = srcType.getRank() - i - 1;
2381       if (srcStrides[dim] == 1) {
2382         dimsToDrop++;
2383       } else {
2384         break;
2385       }
2386     }
2387     if (dimsToDrop == 0)
2388       return failure();
2389 
2390     auto resultTargetVecType =
2391         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
2392                         targetType.getElementType());
2393 
2394     MemRefType resultMemrefType;
2395     if (srcType.getLayout().getAffineMap().isIdentity()) {
2396       resultMemrefType = MemRefType::get(
2397           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2398           {}, srcType.getMemorySpaceAsInt());
2399     } else {
2400       AffineMap map = srcType.getLayout().getAffineMap();
2401       int numSymbols = map.getNumSymbols();
2402       for (size_t i = 0; i < dimsToDrop; ++i) {
2403         int dim = srcType.getRank() - i - 1;
2404         map = map.replace(rewriter.getAffineDimExpr(dim),
2405                           rewriter.getAffineConstantExpr(0),
2406                           map.getNumDims() - 1, numSymbols);
2407       }
2408       resultMemrefType = MemRefType::get(
2409           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2410           map, srcType.getMemorySpaceAsInt());
2411     }
2412 
2413     auto loc = readOp.getLoc();
2414     SmallVector<int64_t> offsets(srcType.getRank(), 0);
2415     SmallVector<int64_t> strides(srcType.getRank(), 1);
2416 
2417     ArrayAttr inBoundsAttr =
2418         readOp.getInBounds()
2419             ? rewriter.getArrayAttr(
2420                   readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
2421             : ArrayAttr();
2422     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
2423         loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
2424         strides);
2425     auto permMap = getTransferMinorIdentityMap(
2426         rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
2427     Value result = rewriter.create<vector::TransferReadOp>(
2428         loc, resultTargetVecType, rankedReducedView,
2429         readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
2430         readOp.getPadding(),
2431         // TODO: support mask.
2432         /*mask=*/Value(), inBoundsAttr);
2433     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
2434                                                      result);
2435     return success();
2436   }
2437 };
2438 
2439 namespace {
2440 
2441 /// This function checks to see if the vector combining kind
2442 /// is consistent with the integer or float element type.
2443 static bool isValidKind(bool isInt, vector::CombiningKind kind) {
2444   using vector::CombiningKind;
2445   enum class KindType { FLOAT, INT, INVALID };
2446   KindType type{KindType::INVALID};
2447   switch (kind) {
2448   case CombiningKind::MINF:
2449   case CombiningKind::MAXF:
2450     type = KindType::FLOAT;
2451     break;
2452   case CombiningKind::MINUI:
2453   case CombiningKind::MINSI:
2454   case CombiningKind::MAXUI:
2455   case CombiningKind::MAXSI:
2456   case CombiningKind::AND:
2457   case CombiningKind::OR:
2458   case CombiningKind::XOR:
2459     type = KindType::INT;
2460     break;
2461   case CombiningKind::ADD:
2462   case CombiningKind::MUL:
2463     type = isInt ? KindType::INT : KindType::FLOAT;
2464     break;
2465   }
2466   bool isValidIntKind = (type == KindType::INT) && isInt;
2467   bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
2468   return (isValidIntKind || isValidFloatKind);
2469 }
2470 
2471 /// This function constructs the appropriate integer or float
2472 /// operation given the vector combining kind and operands. The
2473 /// supported int operations are : add, mul, min (signed/unsigned),
2474 /// max(signed/unsigned), and, or, xor. The supported float
2475 /// operations are : add, mul, min and max.
2476 static Value genOperator(Location loc, Value x, Value y,
2477                          vector::CombiningKind kind,
2478                          PatternRewriter &rewriter) {
2479   using vector::CombiningKind;
2480 
2481   auto elType = x.getType().cast<VectorType>().getElementType();
2482   bool isInt = elType.isIntOrIndex();
2483 
2484   Value combinedResult{nullptr};
2485   switch (kind) {
2486   case CombiningKind::ADD:
2487     if (isInt)
2488       combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
2489     else
2490       combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
2491     break;
2492   case CombiningKind::MUL:
2493     if (isInt)
2494       combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
2495     else
2496       combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
2497     break;
2498   case CombiningKind::MINUI:
2499     combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
2500     break;
2501   case CombiningKind::MINSI:
2502     combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
2503     break;
2504   case CombiningKind::MAXUI:
2505     combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
2506     break;
2507   case CombiningKind::MAXSI:
2508     combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
2509     break;
2510   case CombiningKind::AND:
2511     combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
2512     break;
2513   case CombiningKind::OR:
2514     combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
2515     break;
2516   case CombiningKind::XOR:
2517     combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
2518     break;
2519   case CombiningKind::MINF:
2520     combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
2521     break;
2522   case CombiningKind::MAXF:
2523     combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
2524     break;
2525   }
2526   return combinedResult;
2527 }
2528 
2529 /// Convert vector.scan op into arith ops and
2530 /// vector.insert_strided_slice/extract_strided_slice
2531 ///
2532 /// Ex:
2533 /// ```
2534 ///   %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim =
2535 ///   1} :
2536 ///     (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
2537 /// ```
2538 /// Gets converted to:
2539 /// ```
2540 ///   %cst = arith.constant dense<0> : vector<2x3xi32>
2541 ///   %0 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 1],
2542 ///   strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %1 =
2543 ///   vector.insert_strided_slice %0, %cst {offsets = [0, 0], strides = [1, 1]}
2544 ///   : vector<2x1xi32> into vector<2x3xi32> %2 = vector.extract_strided_slice
2545 ///   %arg0 {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} :
2546 ///   vector<2x3xi32> to vector<2x1xi32> %3 = arith.muli %0, %2 :
2547 ///   vector<2x1xi32> %4 = vector.insert_strided_slice %3, %1 {offsets = [0, 1],
2548 ///   strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %5 =
2549 ///   vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1],
2550 ///   strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %6 = arith.muli %3,
2551 ///   %5 : vector<2x1xi32> %7 = vector.insert_strided_slice %6, %4 {offsets =
2552 ///   [0, 2], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %8 =
2553 ///   vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32> return %7, %8 :
2554 ///   vector<2x3xi32>, vector<2xi32>
2555 /// ```
2556 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
2557   using OpRewritePattern<vector::ScanOp>::OpRewritePattern;
2558 
2559   LogicalResult matchAndRewrite(vector::ScanOp scanOp,
2560                                 PatternRewriter &rewriter) const override {
2561     auto loc = scanOp.getLoc();
2562     VectorType destType = scanOp.getDestType();
2563     ArrayRef<int64_t> destShape = destType.getShape();
2564     auto elType = destType.getElementType();
2565     bool isInt = elType.isIntOrIndex();
2566     if (!isValidKind(isInt, scanOp.getKind()))
2567       return failure();
2568 
2569     VectorType resType = VectorType::get(destShape, elType);
2570     Value result = rewriter.create<arith::ConstantOp>(
2571         loc, resType, rewriter.getZeroAttr(resType));
2572     int64_t reductionDim = scanOp.getReductionDim();
2573     bool inclusive = scanOp.getInclusive();
2574     int64_t destRank = destType.getRank();
2575     VectorType initialValueType = scanOp.getInitialValueType();
2576     int64_t initialValueRank = initialValueType.getRank();
2577 
2578     SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
2579     reductionShape[reductionDim] = 1;
2580     VectorType reductionType = VectorType::get(reductionShape, elType);
2581     SmallVector<int64_t> offsets(destRank, 0);
2582     SmallVector<int64_t> strides(destRank, 1);
2583     SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
2584     sizes[reductionDim] = 1;
2585     ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
2586     ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
2587 
2588     Value lastOutput, lastInput;
2589     for (int i = 0; i < destShape[reductionDim]; i++) {
2590       offsets[reductionDim] = i;
2591       ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
2592       Value input = rewriter.create<vector::ExtractStridedSliceOp>(
2593           loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
2594           scanStrides);
2595       Value output;
2596       if (i == 0) {
2597         if (inclusive) {
2598           output = input;
2599         } else {
2600           if (initialValueRank == 0) {
2601             // ShapeCastOp cannot handle 0-D vectors
2602             output = rewriter.create<vector::BroadcastOp>(
2603                 loc, input.getType(), scanOp.getInitialValue());
2604           } else {
2605             output = rewriter.create<vector::ShapeCastOp>(
2606                 loc, input.getType(), scanOp.getInitialValue());
2607           }
2608         }
2609       } else {
2610         Value y = inclusive ? input : lastInput;
2611         output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
2612         assert(output != nullptr);
2613       }
2614       result = rewriter.create<vector::InsertStridedSliceOp>(
2615           loc, output, result, offsets, strides);
2616       lastOutput = output;
2617       lastInput = input;
2618     }
2619 
2620     Value reduction;
2621     if (initialValueRank == 0) {
2622       Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
2623       reduction =
2624           rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
2625     } else {
2626       reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
2627                                                        lastOutput);
2628     }
2629 
2630     rewriter.replaceOp(scanOp, {result, reduction});
2631     return success();
2632   }
2633 };
2634 
2635 } // namespace
2636 
2637 void mlir::vector::populateVectorMaskMaterializationPatterns(
2638     RewritePatternSet &patterns, bool force32BitVectorIndices) {
2639   patterns.add<VectorCreateMaskOpConversion,
2640                MaterializeTransferMask<vector::TransferReadOp>,
2641                MaterializeTransferMask<vector::TransferWriteOp>>(
2642       patterns.getContext(), force32BitVectorIndices);
2643 }
2644 
2645 void mlir::vector::populateShapeCastFoldingPatterns(
2646     RewritePatternSet &patterns) {
2647   patterns.add<ShapeCastOpFolder>(patterns.getContext());
2648 }
2649 
2650 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
2651     RewritePatternSet &patterns) {
2652   patterns.add<BubbleDownVectorBitCastForExtract,
2653                BubbleDownBitCastForStridedSliceExtract,
2654                BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
2655 }
2656 
2657 void mlir::vector::populateVectorBroadcastLoweringPatterns(
2658     RewritePatternSet &patterns) {
2659   patterns.add<BroadcastOpLowering>(patterns.getContext());
2660 }
2661 
2662 void mlir::vector::populateVectorMaskOpLoweringPatterns(
2663     RewritePatternSet &patterns) {
2664   patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
2665       patterns.getContext());
2666 }
2667 
2668 void mlir::vector::populateVectorShapeCastLoweringPatterns(
2669     RewritePatternSet &patterns) {
2670   patterns.add<ShapeCastOp2DDownCastRewritePattern,
2671                ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
2672       patterns.getContext());
2673 }
2674 
2675 void mlir::vector::populateVectorContractLoweringPatterns(
2676     RewritePatternSet &patterns, VectorTransformsOptions options) {
2677   patterns.add<OuterProductOpLowering>(patterns.getContext());
2678   patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
2679                ContractionOpToOuterProductOpLowering>(options,
2680                                                       patterns.getContext());
2681 }
2682 
2683 void mlir::vector::populateVectorTransposeLoweringPatterns(
2684     RewritePatternSet &patterns, VectorTransformsOptions options) {
2685   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
2686       options, patterns.getContext());
2687 }
2688 
2689 void mlir::vector::populateVectorReductionToContractPatterns(
2690     RewritePatternSet &patterns) {
2691   patterns.add<MultiReduceToContract, CombineContractBroadcast,
2692                CombineContractTranspose, ReorderCastOpsOnBroadcast,
2693                ReorderElementwiseOpsOnTranspose>(patterns.getContext());
2694 }
2695 
2696 void mlir::vector::
2697     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
2698         RewritePatternSet &patterns) {
2699   patterns.add<DropInnerMostUnitDims>(patterns.getContext());
2700 }
2701 
2702 void mlir::vector::populateVectorTransferLoweringPatterns(
2703     RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
2704   patterns.add<TransferReadToVectorLoadLowering,
2705                TransferWriteToVectorStoreLowering>(patterns.getContext(),
2706                                                    maxTransferRank);
2707   patterns
2708       .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
2709           patterns.getContext());
2710 }
2711 
2712 void mlir::vector::populateVectorScanLoweringPatterns(
2713     RewritePatternSet &patterns) {
2714   patterns.add<ScanToArithOps>(patterns.getContext());
2715 }
2716