1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
14 
15 #include <type_traits>
16 
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
19 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21 #include "mlir/Dialect/Linalg/IR/Linalg.h"
22 #include "mlir/Dialect/MemRef/IR/MemRef.h"
23 #include "mlir/Dialect/SCF/SCF.h"
24 #include "mlir/Dialect/Utils/IndexingUtils.h"
25 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
26 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
27 #include "mlir/IR/ImplicitLocOpBuilder.h"
28 #include "mlir/IR/Matchers.h"
29 #include "mlir/IR/PatternMatch.h"
30 #include "mlir/Interfaces/VectorInterfaces.h"
31 
32 #include "llvm/ADT/DenseSet.h"
33 #include "llvm/ADT/MapVector.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/Support/CommandLine.h"
36 #include "llvm/Support/Debug.h"
37 #include "llvm/Support/raw_ostream.h"
38 
39 #define DEBUG_TYPE "vector-to-vector"
40 
41 using namespace mlir;
42 using namespace mlir::vector;
43 
44 // Helper to find an index in an affine map.
45 static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
46   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
47     int64_t idx = map.getDimPosition(i);
48     if (idx == index)
49       return i;
50   }
51   return None;
52 }
53 
54 // Helper to construct iterator types with one index removed.
55 static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
56                                             int64_t index) {
57   SmallVector<Attribute, 4> results;
58   for (const auto &it : llvm::enumerate(iteratorTypes)) {
59     int64_t idx = it.index();
60     if (idx == index)
61       continue;
62     results.push_back(it.value());
63   }
64   return results;
65 }
66 
67 // Helper to construct an affine map with one index removed.
68 static AffineMap adjustMap(AffineMap map, int64_t index,
69                            PatternRewriter &rewriter) {
70   auto *ctx = rewriter.getContext();
71   SmallVector<AffineExpr, 4> results;
72   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
73     int64_t idx = map.getDimPosition(i);
74     if (idx == index)
75       continue;
76     // Re-insert remaining indices, but renamed when occurring
77     // after the removed index.
78     auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
79     results.push_back(targetExpr);
80   }
81   return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
82 }
83 
84 // Helper method to possibly drop a dimension in a load.
85 // TODO
86 static Value reshapeLoad(Location loc, Value val, VectorType type,
87                          int64_t index, int64_t pos,
88                          PatternRewriter &rewriter) {
89   if (index == -1)
90     return val;
91   Type lowType = VectorType::Builder(type).dropDim(0);
92   // At extraction dimension?
93   if (index == 0) {
94     auto posAttr = rewriter.getI64ArrayAttr(pos);
95     return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
96   }
97   // Unroll leading dimensions.
98   VectorType vType = lowType.cast<VectorType>();
99   Type resType = VectorType::Builder(type).dropDim(index);
100   auto resVectorType = resType.cast<VectorType>();
101   Value result = rewriter.create<arith::ConstantOp>(
102       loc, resVectorType, rewriter.getZeroAttr(resVectorType));
103   for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
104     auto posAttr = rewriter.getI64ArrayAttr(d);
105     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
106     Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
107     result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
108                                                posAttr);
109   }
110   return result;
111 }
112 
113 // Helper method to possibly drop a dimension in a store.
114 // TODO
115 static Value reshapeStore(Location loc, Value val, Value result,
116                           VectorType type, int64_t index, int64_t pos,
117                           PatternRewriter &rewriter) {
118   // Unmodified?
119   if (index == -1)
120     return val;
121   // At insertion dimension?
122   if (index == 0) {
123     auto posAttr = rewriter.getI64ArrayAttr(pos);
124     return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
125   }
126   // Unroll leading dimensions.
127   Type lowType = VectorType::Builder(type).dropDim(0);
128   VectorType vType = lowType.cast<VectorType>();
129   Type insType = VectorType::Builder(vType).dropDim(0);
130   for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
131     auto posAttr = rewriter.getI64ArrayAttr(d);
132     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
133     Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
134     Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
135     result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
136   }
137   return result;
138 }
139 
140 template <typename IntType>
141 static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
142   return llvm::to_vector<4>(llvm::map_range(
143       arrayAttr.getAsRange<IntegerAttr>(),
144       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
145 }
146 
147 namespace {
148 
149 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
150 //
151 // Example:
152 //
153 //  The following MLIR with cancelling ShapeCastOps:
154 //
155 //   %0 = source : vector<5x4x2xf32>
156 //   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
157 //   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
158 //   %3 = user %2 : vector<5x4x2xf32>
159 //
160 //  Should canonicalize to the following:
161 //
162 //   %0 = source : vector<5x4x2xf32>
163 //   %1 = user %0 : vector<5x4x2xf32>
164 //
165 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
166   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
167 
168   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
169                                 PatternRewriter &rewriter) const override {
170     // Check if 'shapeCastOp' has vector source/result type.
171     auto sourceVectorType =
172         shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
173     auto resultVectorType =
174         shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
175     if (!sourceVectorType || !resultVectorType)
176       return failure();
177 
178     // Check if shape cast op source operand is also a shape cast op.
179     auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
180         shapeCastOp.source().getDefiningOp());
181     if (!sourceShapeCastOp)
182       return failure();
183     auto operandSourceVectorType =
184         sourceShapeCastOp.source().getType().cast<VectorType>();
185     auto operandResultVectorType = sourceShapeCastOp.getType();
186 
187     // Check if shape cast operations invert each other.
188     if (operandSourceVectorType != resultVectorType ||
189         operandResultVectorType != sourceVectorType)
190       return failure();
191 
192     rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
193     return success();
194   }
195 };
196 
197 /// Progressive lowering of BroadcastOp.
198 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
199 public:
200   using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
201 
202   LogicalResult matchAndRewrite(vector::BroadcastOp op,
203                                 PatternRewriter &rewriter) const override {
204     auto loc = op.getLoc();
205     VectorType dstType = op.getVectorType();
206     VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
207     Type eltType = dstType.getElementType();
208 
209     // Scalar to any vector can use splat.
210     if (!srcType) {
211       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.source());
212       return success();
213     }
214 
215     // Determine rank of source and destination.
216     int64_t srcRank = srcType.getRank();
217     int64_t dstRank = dstType.getRank();
218 
219     // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
220     if (srcRank <= 1 && dstRank == 1) {
221       Value ext;
222       if (srcRank == 0)
223         ext = rewriter.create<vector::ExtractElementOp>(loc, op.source());
224       else
225         ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
226       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
227       return success();
228     }
229 
230     // Duplicate this rank.
231     // For example:
232     //   %x = broadcast %y  : k-D to n-D, k < n
233     // becomes:
234     //   %b = broadcast %y  : k-D to (n-1)-D
235     //   %x = [%b,%b,%b,%b] : n-D
236     // becomes:
237     //   %b = [%y,%y]       : (n-1)-D
238     //   %x = [%b,%b,%b,%b] : n-D
239     if (srcRank < dstRank) {
240       // Duplication.
241       VectorType resType =
242           VectorType::get(dstType.getShape().drop_front(), eltType);
243       Value bcst =
244           rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
245       Value result = rewriter.create<arith::ConstantOp>(
246           loc, dstType, rewriter.getZeroAttr(dstType));
247       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
248         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
249       rewriter.replaceOp(op, result);
250       return success();
251     }
252 
253     // Find non-matching dimension, if any.
254     assert(srcRank == dstRank);
255     int64_t m = -1;
256     for (int64_t r = 0; r < dstRank; r++)
257       if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
258         m = r;
259         break;
260       }
261 
262     // All trailing dimensions are the same. Simply pass through.
263     if (m == -1) {
264       rewriter.replaceOp(op, op.source());
265       return success();
266     }
267 
268     // Any non-matching dimension forces a stretch along this rank.
269     // For example:
270     //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
271     // becomes:
272     //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
273     //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
274     //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
275     //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
276     //   %x = [%a,%b,%c,%d]
277     // becomes:
278     //   %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
279     //   %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
280     //   %a = [%u, %v]
281     //   ..
282     //   %x = [%a,%b,%c,%d]
283     VectorType resType =
284         VectorType::get(dstType.getShape().drop_front(), eltType);
285     Value result = rewriter.create<arith::ConstantOp>(
286         loc, dstType, rewriter.getZeroAttr(dstType));
287     if (m == 0) {
288       // Stetch at start.
289       Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
290       Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
291       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
292         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
293     } else {
294       // Stetch not at start.
295       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
296         Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
297         Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
298         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
299       }
300     }
301     rewriter.replaceOp(op, result);
302     return success();
303   }
304 };
305 
306 /// Given a 'transpose' pattern, prune the rightmost dimensions that are not
307 /// transposed.
308 void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
309                             SmallVectorImpl<int64_t> &result) {
310   size_t numTransposedDims = transpose.size();
311   for (size_t transpDim : llvm::reverse(transpose)) {
312     if (transpDim != numTransposedDims - 1)
313       break;
314     numTransposedDims--;
315   }
316 
317   result.append(transpose.begin(), transpose.begin() + numTransposedDims);
318 }
319 
320 /// Progressive lowering of TransposeOp.
321 /// One:
322 ///   %x = vector.transpose %y, [1, 0]
323 /// is replaced by:
324 ///   %z = arith.constant dense<0.000000e+00>
325 ///   %0 = vector.extract %y[0, 0]
326 ///   %1 = vector.insert %0, %z [0, 0]
327 ///   ..
328 ///   %x = vector.insert .., .. [.., ..]
329 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
330 public:
331   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
332 
333   TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
334                       MLIRContext *context)
335       : OpRewritePattern<vector::TransposeOp>(context),
336         vectorTransformOptions(vectorTransformOptions) {}
337 
338   LogicalResult matchAndRewrite(vector::TransposeOp op,
339                                 PatternRewriter &rewriter) const override {
340     auto loc = op.getLoc();
341 
342     Value input = op.vector();
343     VectorType inputType = op.getVectorType();
344     VectorType resType = op.getResultType();
345 
346     // Set up convenience transposition table.
347     SmallVector<int64_t, 4> transp;
348     for (auto attr : op.transp())
349       transp.push_back(attr.cast<IntegerAttr>().getInt());
350 
351     if (vectorTransformOptions.vectorTransposeLowering ==
352             vector::VectorTransposeLowering::Shuffle &&
353         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
354       return rewriter.notifyMatchFailure(
355           op, "Options specifies lowering to shuffle");
356 
357     // Handle a true 2-D matrix transpose differently when requested.
358     if (vectorTransformOptions.vectorTransposeLowering ==
359             vector::VectorTransposeLowering::Flat &&
360         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
361       Type flattenedType =
362           VectorType::get(resType.getNumElements(), resType.getElementType());
363       auto matrix =
364           rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
365       auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
366       auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
367       Value trans = rewriter.create<vector::FlatTransposeOp>(
368           loc, flattenedType, matrix, rows, columns);
369       rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
370       return success();
371     }
372 
373     // Generate unrolled extract/insert ops. We do not unroll the rightmost
374     // (i.e., highest-order) dimensions that are not transposed and leave them
375     // in vector form to improve performance. Therefore, we prune those
376     // dimensions from the shape/transpose data structures used to generate the
377     // extract/insert ops.
378     SmallVector<int64_t, 4> prunedTransp;
379     pruneNonTransposedDims(transp, prunedTransp);
380     size_t numPrunedDims = transp.size() - prunedTransp.size();
381     auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
382     SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
383     auto prunedInStrides = computeStrides(prunedInShape, ones);
384 
385     // Generates the extract/insert operations for every scalar/vector element
386     // of the leftmost transposed dimensions. We traverse every transpose
387     // element using a linearized index that we delinearize to generate the
388     // appropriate indices for the extract/insert operations.
389     Value result = rewriter.create<arith::ConstantOp>(
390         loc, resType, rewriter.getZeroAttr(resType));
391     int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
392 
393     for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
394          ++linearIdx) {
395       auto extractIdxs = delinearize(prunedInStrides, linearIdx);
396       SmallVector<int64_t, 4> insertIdxs(extractIdxs);
397       applyPermutationToVector(insertIdxs, prunedTransp);
398       Value extractOp =
399           rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
400       result =
401           rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
402     }
403 
404     rewriter.replaceOp(op, result);
405     return success();
406   }
407 
408 private:
409   /// Options to control the vector patterns.
410   vector::VectorTransformsOptions vectorTransformOptions;
411 };
412 
413 /// Rewrite a 2-D vector.transpose as a sequence of:
414 ///   vector.shape_cast 2D -> 1D
415 ///   vector.shuffle
416 ///   vector.shape_cast 1D -> 2D
417 class TransposeOp2DToShuffleLowering
418     : public OpRewritePattern<vector::TransposeOp> {
419 public:
420   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
421 
422   TransposeOp2DToShuffleLowering(
423       vector::VectorTransformsOptions vectorTransformOptions,
424       MLIRContext *context)
425       : OpRewritePattern<vector::TransposeOp>(context),
426         vectorTransformOptions(vectorTransformOptions) {}
427 
428   LogicalResult matchAndRewrite(vector::TransposeOp op,
429                                 PatternRewriter &rewriter) const override {
430     auto loc = op.getLoc();
431 
432     VectorType srcType = op.getVectorType();
433     if (srcType.getRank() != 2)
434       return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
435 
436     SmallVector<int64_t, 4> transp;
437     for (auto attr : op.transp())
438       transp.push_back(attr.cast<IntegerAttr>().getInt());
439     if (transp[0] != 1 && transp[1] != 0)
440       return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
441 
442     if (vectorTransformOptions.vectorTransposeLowering !=
443         VectorTransposeLowering::Shuffle)
444       return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
445 
446     int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
447     Value casted = rewriter.create<vector::ShapeCastOp>(
448         loc, VectorType::get({m * n}, srcType.getElementType()), op.vector());
449     SmallVector<int64_t> mask;
450     mask.reserve(m * n);
451     for (int64_t j = 0; j < n; ++j)
452       for (int64_t i = 0; i < m; ++i)
453         mask.push_back(i * n + j);
454 
455     Value shuffled =
456         rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
457     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
458                                                      shuffled);
459 
460     return success();
461   }
462 
463 private:
464   /// Options to control the vector patterns.
465   vector::VectorTransformsOptions vectorTransformOptions;
466 };
467 
468 /// Progressive lowering of OuterProductOp.
469 /// One:
470 ///   %x = vector.outerproduct %lhs, %rhs, %acc
471 /// is replaced by:
472 ///   %z = zero-result
473 ///   %0 = vector.extract %lhs[0]
474 ///   %1 = vector.broadcast %0
475 ///   %2 = vector.extract %acc[0]
476 ///   %3 = vector.fma %1, %rhs, %2
477 ///   %4 = vector.insert %3, %z[0]
478 ///   ..
479 ///   %x = vector.insert %.., %..[N-1]
480 ///
481 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
482 public:
483   using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
484 
485   LogicalResult matchAndRewrite(vector::OuterProductOp op,
486                                 PatternRewriter &rewriter) const override {
487     auto loc = op.getLoc();
488 
489     VectorType lhsType = op.getOperandVectorTypeLHS();
490     VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
491     VectorType resType = op.getVectorType();
492     Type eltType = resType.getElementType();
493     bool isInt = eltType.isa<IntegerType, IndexType>();
494     Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
495     vector::CombiningKind kind = op.kind();
496 
497     if (!rhsType) {
498       // Special case: AXPY operation.
499       Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
500       Optional<Value> mult =
501           isInt ? genMultI(loc, op.lhs(), b, acc, kind, rewriter)
502                 : genMultF(loc, op.lhs(), b, acc, kind, rewriter);
503       if (!mult.hasValue())
504         return failure();
505       rewriter.replaceOp(op, mult.getValue());
506       return success();
507     }
508 
509     Value result = rewriter.create<arith::ConstantOp>(
510         loc, resType, rewriter.getZeroAttr(resType));
511     for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
512       auto pos = rewriter.getI64ArrayAttr(d);
513       Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
514       Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
515       Value r = nullptr;
516       if (acc)
517         r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
518       Optional<Value> m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter)
519                                 : genMultF(loc, a, op.rhs(), r, kind, rewriter);
520       if (!m.hasValue())
521         return failure();
522       result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
523                                                  result, pos);
524     }
525     rewriter.replaceOp(op, result);
526     return success();
527   }
528 
529 private:
530   static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
531                                   vector::CombiningKind kind,
532                                   PatternRewriter &rewriter) {
533     using vector::CombiningKind;
534 
535     auto mul = rewriter.create<arith::MulIOp>(loc, x, y);
536     if (!acc)
537       return Optional<Value>(mul);
538 
539     if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
540       // Only valid for floating point types.
541       return Optional<Value>();
542 
543     return makeArithReduction(rewriter, loc, kind, mul, acc);
544   }
545 
546   static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
547                                   vector::CombiningKind kind,
548                                   PatternRewriter &rewriter) {
549     using vector::CombiningKind;
550 
551     // Special case for fused multiply-add.
552     if (acc && kind == CombiningKind::ADD) {
553       return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
554     }
555 
556     auto mul = rewriter.create<arith::MulFOp>(loc, x, y);
557 
558     if (!acc)
559       return Optional<Value>(mul);
560 
561     if (kind == CombiningKind::ADD || kind == CombiningKind::AND ||
562         kind == CombiningKind::MINUI || kind == CombiningKind::MINSI ||
563         kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI ||
564         kind == CombiningKind::OR || kind == CombiningKind::XOR)
565       // Already handled or only valid for integer types.
566       return Optional<Value>();
567 
568     return makeArithReduction(rewriter, loc, kind, mul, acc);
569   }
570 };
571 
572 /// Progressive lowering of ConstantMaskOp.
573 /// One:
574 ///   %x = vector.constant_mask [a,b]
575 /// is replaced by:
576 ///   %z = zero-result
577 ///   %l = vector.constant_mask [b]
578 ///   %4 = vector.insert %l, %z[0]
579 ///   ..
580 ///   %x = vector.insert %l, %..[a-1]
581 /// until a one-dimensional vector is reached. All these operations
582 /// will be folded at LLVM IR level.
583 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
584 public:
585   using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
586 
587   LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
588                                 PatternRewriter &rewriter) const override {
589     auto loc = op.getLoc();
590     auto dstType = op.getType();
591     auto eltType = dstType.getElementType();
592     auto dimSizes = op.mask_dim_sizes();
593     int64_t rank = dstType.getRank();
594 
595     if (rank == 0) {
596       assert(dimSizes.size() == 1 &&
597              "Expected exactly one dim size for a 0-D vector");
598       bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
599       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
600           op, dstType,
601           DenseIntElementsAttr::get(
602               VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
603               ArrayRef<bool>{value}));
604       return success();
605     }
606 
607     // Scalable constant masks can only be lowered for the "none set" case.
608     if (dstType.cast<VectorType>().isScalable()) {
609       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
610           op, DenseElementsAttr::get(dstType, false));
611       return success();
612     }
613 
614     int64_t trueDim = std::min(dstType.getDimSize(0),
615                                dimSizes[0].cast<IntegerAttr>().getInt());
616 
617     if (rank == 1) {
618       // Express constant 1-D case in explicit vector form:
619       //   [T,..,T,F,..,F].
620       SmallVector<bool, 4> values(dstType.getDimSize(0));
621       for (int64_t d = 0; d < trueDim; d++)
622         values[d] = true;
623       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
624           op, dstType, rewriter.getBoolVectorAttr(values));
625       return success();
626     }
627 
628     VectorType lowType =
629         VectorType::get(dstType.getShape().drop_front(), eltType);
630     SmallVector<int64_t, 4> newDimSizes;
631     for (int64_t r = 1; r < rank; r++)
632       newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
633     Value trueVal = rewriter.create<vector::ConstantMaskOp>(
634         loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
635     Value result = rewriter.create<arith::ConstantOp>(
636         loc, dstType, rewriter.getZeroAttr(dstType));
637     for (int64_t d = 0; d < trueDim; d++) {
638       auto pos = rewriter.getI64ArrayAttr(d);
639       result =
640           rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
641     }
642     rewriter.replaceOp(op, result);
643     return success();
644   }
645 };
646 
647 /// Progressive lowering of CreateMaskOp.
648 /// One:
649 ///   %x = vector.create_mask %a, ... : vector<dx...>
650 /// is replaced by:
651 ///   %l = vector.create_mask ... : vector<...>  ; one lower rank
652 ///   %0 = arith.cmpi "slt", %ci, %a       |
653 ///   %1 = select %0, %l, %zeroes    |
654 ///   %r = vector.insert %1, %pr [i] | d-times
655 ///   %x = ....
656 /// until a one-dimensional vector is reached.
657 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
658 public:
659   using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
660 
661   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
662                                 PatternRewriter &rewriter) const override {
663     auto dstType = op.getResult().getType().cast<VectorType>();
664     int64_t rank = dstType.getRank();
665     if (rank <= 1)
666       return rewriter.notifyMatchFailure(
667           op, "0-D and 1-D vectors are handled separately");
668 
669     auto loc = op.getLoc();
670     auto eltType = dstType.getElementType();
671     int64_t dim = dstType.getDimSize(0);
672     Value idx = op.getOperand(0);
673 
674     VectorType lowType =
675         VectorType::get(dstType.getShape().drop_front(), eltType);
676     Value trueVal = rewriter.create<vector::CreateMaskOp>(
677         loc, lowType, op.getOperands().drop_front());
678     Value falseVal = rewriter.create<arith::ConstantOp>(
679         loc, lowType, rewriter.getZeroAttr(lowType));
680     Value result = rewriter.create<arith::ConstantOp>(
681         loc, dstType, rewriter.getZeroAttr(dstType));
682     for (int64_t d = 0; d < dim; d++) {
683       Value bnd =
684           rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
685       Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
686                                                  bnd, idx);
687       Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
688       auto pos = rewriter.getI64ArrayAttr(d);
689       result =
690           rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
691     }
692     rewriter.replaceOp(op, result);
693     return success();
694   }
695 };
696 
697 /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
698 /// vectors progressively on the way to target llvm.matrix intrinsics.
699 /// This iterates over the most major dimension of the 2-D vector and performs
700 /// rewrites into:
701 ///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
702 class ShapeCastOp2DDownCastRewritePattern
703     : public OpRewritePattern<vector::ShapeCastOp> {
704 public:
705   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
706 
707   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
708                                 PatternRewriter &rewriter) const override {
709     auto sourceVectorType = op.getSourceVectorType();
710     auto resultVectorType = op.getResultVectorType();
711     if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
712       return failure();
713 
714     auto loc = op.getLoc();
715     Value desc = rewriter.create<arith::ConstantOp>(
716         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
717     unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
718     for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
719       Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
720       desc = rewriter.create<vector::InsertStridedSliceOp>(
721           loc, vec, desc,
722           /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
723     }
724     rewriter.replaceOp(op, desc);
725     return success();
726   }
727 };
728 
729 /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
730 /// vectors progressively.
731 /// This iterates over the most major dimension of the 2-D vector and performs
732 /// rewrites into:
733 ///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
734 /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
735 class ShapeCastOp2DUpCastRewritePattern
736     : public OpRewritePattern<vector::ShapeCastOp> {
737 public:
738   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
739 
740   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
741                                 PatternRewriter &rewriter) const override {
742     auto sourceVectorType = op.getSourceVectorType();
743     auto resultVectorType = op.getResultVectorType();
744     if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
745       return failure();
746 
747     auto loc = op.getLoc();
748     Value desc = rewriter.create<arith::ConstantOp>(
749         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
750     unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
751     for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
752       Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
753           loc, op.source(), /*offsets=*/i * mostMinorVectorSize,
754           /*sizes=*/mostMinorVectorSize,
755           /*strides=*/1);
756       desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
757     }
758     rewriter.replaceOp(op, desc);
759     return success();
760   }
761 };
762 
763 // We typically should not lower general shape cast operations into data
764 // movement instructions, since the assumption is that these casts are
765 // optimized away during progressive lowering. For completeness, however,
766 // we fall back to a reference implementation that moves all elements
767 // into the right place if we get here.
768 class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
769 public:
770   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
771 
772   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
773                                 PatternRewriter &rewriter) const override {
774     Location loc = op.getLoc();
775     auto sourceVectorType = op.getSourceVectorType();
776     auto resultVectorType = op.getResultVectorType();
777 
778     // Special case 2D/1D lowerings with better implementations.
779     // TODO: make is ND/1D to allow generic ND->1D->MD.
780     int64_t srcRank = sourceVectorType.getRank();
781     int64_t resRank = resultVectorType.getRank();
782     if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
783       return failure();
784 
785     // Generic ShapeCast lowering path goes all the way down to unrolled scalar
786     // extract/insert chains.
787     // TODO: consider evolving the semantics to only allow 1D source or dest and
788     // drop this potentially very expensive lowering.
789     // Compute number of elements involved in the reshape.
790     int64_t numElts = 1;
791     for (int64_t r = 0; r < srcRank; r++)
792       numElts *= sourceVectorType.getDimSize(r);
793     // Replace with data movement operations:
794     //    x[0,0,0] = y[0,0]
795     //    x[0,0,1] = y[0,1]
796     //    x[0,1,0] = y[0,2]
797     // etc., incrementing the two index vectors "row-major"
798     // within the source and result shape.
799     SmallVector<int64_t, 4> srcIdx(srcRank);
800     SmallVector<int64_t, 4> resIdx(resRank);
801     Value result = rewriter.create<arith::ConstantOp>(
802         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
803     for (int64_t i = 0; i < numElts; i++) {
804       if (i != 0) {
805         incIdx(srcIdx, sourceVectorType, srcRank - 1);
806         incIdx(resIdx, resultVectorType, resRank - 1);
807       }
808       Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
809       result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
810     }
811     rewriter.replaceOp(op, result);
812     return success();
813   }
814 
815 private:
816   static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
817     assert(0 <= r && r < tp.getRank());
818     if (++idx[r] == tp.getDimSize(r)) {
819       idx[r] = 0;
820       incIdx(idx, tp, r - 1);
821     }
822   }
823 };
824 
825 /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
826 /// Ex:
827 /// ```
828 ///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
829 ///   %1 = vector.multi_reduction add, %0 [1]
830 ///     : vector<8x32x16xf32> to vector<8x16xf32>
831 /// ```
832 /// Gets converted to:
833 /// ```
834 ///   %1 = vector.contract {indexing_maps = [
835 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
836 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
837 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
838 ///    iterator_types = ["parallel", "parallel", "reduction"],
839 ///    kind = add} %0, %arg1, %cst_f0
840 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
841 ///  ```
842 struct MultiReduceToContract
843     : public OpRewritePattern<vector::MultiDimReductionOp> {
844   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
845 
846   LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
847                                 PatternRewriter &rewriter) const override {
848     if (reduceOp.kind() != vector::CombiningKind::ADD)
849       return failure();
850     Operation *mulOp = reduceOp.source().getDefiningOp();
851     if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
852       return failure();
853     SmallVector<bool> reductionMask = reduceOp.getReductionMask();
854     auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
855     SmallVector<AffineExpr> exprs;
856     SmallVector<StringRef> iteratorTypes;
857     for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
858       if (!isReduceDim.value()) {
859         iteratorTypes.push_back(getParallelIteratorTypeName());
860         exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
861       } else {
862         iteratorTypes.push_back(getReductionIteratorTypeName());
863       }
864     }
865     auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
866                                  /*symCount=*/0, exprs, reduceOp.getContext());
867     Value zero = rewriter.create<arith::ConstantOp>(
868         reduceOp.getLoc(), reduceOp.getDestType(),
869         rewriter.getZeroAttr(reduceOp.getDestType()));
870     rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
871         reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
872         rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
873         rewriter.getStrArrayAttr(iteratorTypes));
874     return success();
875   }
876 };
877 
878 /// Merge TransposeOp into ContractionOp user.
879 /// Ex:
880 /// ```
881 ///   %0 = vector.transpose %arg0, [2, 0, 1]
882 ///     : vector<32x16x8xf32> to vector<8x32x16xf32>
883 ///   %1 = vector.contract {indexing_maps = [
884 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
885 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
886 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
887 ///    iterator_types = ["parallel", "parallel", "reduction"],
888 ///    kind = add} %0, %arg1, %cst_f0
889 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
890 /// ```
891 /// Gets converted to:
892 /// ```
893 ///   %1 = vector.contract {indexing_maps = [
894 ///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
895 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
896 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
897 ///    iterator_types = ["parallel", "parallel", "reduction"],
898 ///    kind = add} %arg0, %arg1, %cst_f0
899 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
900 ///  ```
901 struct CombineContractTranspose
902     : public OpRewritePattern<vector::ContractionOp> {
903   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
904 
905   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
906                                 PatternRewriter &rewriter) const override {
907     SmallVector<AffineMap, 4> maps =
908         llvm::to_vector<4>(contractOp.getIndexingMaps());
909     Value lhs = contractOp.lhs();
910     Value rhs = contractOp.rhs();
911     size_t index = 0;
912     bool changed = false;
913     for (Value *operand : {&lhs, &rhs}) {
914       AffineMap &map = maps[index++];
915       auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
916       if (!transposeOp)
917         continue;
918       SmallVector<int64_t> perm;
919       transposeOp.getTransp(perm);
920       AffineMap permutationMap = AffineMap::getPermutationMap(
921           extractVector<unsigned>(transposeOp.transp()),
922           contractOp.getContext());
923       map = inversePermutation(permutationMap).compose(map);
924       *operand = transposeOp.vector();
925       changed = true;
926     }
927     if (!changed)
928       return failure();
929     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
930         contractOp, lhs, rhs, contractOp.acc(),
931         rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
932     return success();
933   }
934 };
935 
936 /// Merge BroadcastOp into ContractionOp user.
937 /// Ex:
938 /// ```
939 ///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
940 ///   %1 = vector.contract {indexing_maps = [
941 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
942 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
943 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
944 ///    iterator_types = ["parallel", "parallel", "reduction"],
945 ///    kind = add} %0, %arg1, %cst_f0
946 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
947 /// ```
948 /// Gets converted to:
949 /// ```
950 ///   %1 = vector.contract {indexing_maps = [
951 ///         affine_map<(d0, d1, d2) -> (d1, d2)>,
952 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
953 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
954 ///    iterator_types = ["parallel", "parallel", "reduction"],
955 ///    kind = add} %arg0, %arg1, %cst_f0
956 ///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
957 ///  ```
958 struct CombineContractBroadcast
959     : public OpRewritePattern<vector::ContractionOp> {
960   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
961 
962   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
963                                 PatternRewriter &rewriter) const override {
964     SmallVector<AffineMap, 4> maps =
965         llvm::to_vector<4>(contractOp.getIndexingMaps());
966     Value lhs = contractOp.lhs();
967     Value rhs = contractOp.rhs();
968     size_t index = 0;
969     bool changed = false;
970     for (Value *operand : {&lhs, &rhs}) {
971       AffineMap &map = maps[index++];
972       auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
973       if (!broadcast)
974         continue;
975       // contractionOp can only take vector as operands.
976       auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
977       if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
978         continue;
979       int64_t rankDiff =
980           broadcast.getVectorType().getRank() - srcType.getRank();
981       bool innerDimBroadcast = false;
982       SmallVector<AffineExpr> originalDims;
983       for (const auto &dim : llvm::enumerate(srcType.getShape())) {
984         if (dim.value() !=
985             broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
986           innerDimBroadcast = true;
987           break;
988         }
989         originalDims.push_back(
990             rewriter.getAffineDimExpr(dim.index() + rankDiff));
991       }
992       // Contract doesn't support inner dimension broadcast. Once this is
993       // relaxed we can remove this case.
994       if (innerDimBroadcast)
995         continue;
996       AffineMap broadcastMap =
997           AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
998                          contractOp.getContext());
999       map = broadcastMap.compose(map);
1000       *operand = broadcast.source();
1001       changed = true;
1002     }
1003     if (!changed)
1004       return failure();
1005     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1006         contractOp, lhs, rhs, contractOp.acc(),
1007         rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
1008     return success();
1009   }
1010 };
1011 
1012 /// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
1013 /// contraction ops closer, which kicks in CombineContractBroadcast pattern when
1014 /// casting ops are around these operations.
1015 /// Ex:
1016 /// ```
1017 ///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
1018 ///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
1019 /// ```
1020 /// Gets converted to:
1021 /// ```
1022 ///   %0 = arith.extsi %0 : vector<32x16xi8> to vector<32x16xi32>
1023 ///   %1 = vector.broadcast %arg0 : vector<32x16xi32> to vector<8x32x16xi32>
1024 /// ```
1025 struct ReorderCastOpsOnBroadcast
1026     : public OpInterfaceRewritePattern<CastOpInterface> {
1027   using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
1028 
1029   LogicalResult matchAndRewrite(CastOpInterface op,
1030                                 PatternRewriter &rewriter) const override {
1031     if (op->getNumOperands() != 1)
1032       return failure();
1033     auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
1034     if (!bcastOp)
1035       return failure();
1036 
1037     Type castResTy = getElementTypeOrSelf(op->getResult(0));
1038     if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
1039       castResTy = VectorType::get(vecTy.getShape(), castResTy);
1040     auto castOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(),
1041                                   bcastOp.source(), castResTy, op->getAttrs());
1042     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1043         op, op->getResult(0).getType(), castOp->getResult(0));
1044     return success();
1045   }
1046 };
1047 
1048 /// Reorders cast(transpose) to transpose(cast). This makes broadcast ops and
1049 /// contraction ops closer, which kicks in CombineContractTranspose pattern when
1050 /// casting ops are around these operations.
1051 /// Ex:
1052 /// ```
1053 ///   %0 = vector.transpose %arg0, [2, 0, 1]
1054 ///     : vector<32x16x8xi8> to vector<8x32x16xi8>
1055 ///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
1056 /// ```
1057 /// Gets converted to:
1058 /// ```
1059 ///   %0 = arith.extsi %0 : vector<32x16x8xi8> to vector<32x16x8xi32>
1060 ///   %1 = vector.transpose %arg0, [2, 0, 1]
1061 ///     : vector<32x16x8xi32> to vector<8x32x16xi32>
1062 /// ```
1063 struct ReorderCastOpsOnTranspose
1064     : public OpInterfaceRewritePattern<CastOpInterface> {
1065 
1066   using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
1067 
1068   LogicalResult matchAndRewrite(CastOpInterface op,
1069                                 PatternRewriter &rewriter) const override {
1070     if (op->getNumOperands() != 1)
1071       return failure();
1072     auto transpOp = op->getOperand(0).getDefiningOp<vector::TransposeOp>();
1073     if (!transpOp)
1074       return failure();
1075 
1076     auto castResTy = transpOp.getVectorType();
1077     castResTy = VectorType::get(castResTy.getShape(),
1078                                 getElementTypeOrSelf(op->getResult(0)));
1079     auto castOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(),
1080                                   transpOp.vector(), castResTy, op->getAttrs());
1081     rewriter.replaceOpWithNewOp<vector::TransposeOp>(
1082         op, op->getResult(0).getType(), castOp->getResult(0),
1083         transpOp.getTransp());
1084     return success();
1085   }
1086 };
1087 
1088 } // namespace
1089 
1090 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
1091 /// operands `x` and `y`.
1092 static Value createAdd(Location loc, Value x, Value y, bool isInt,
1093                        PatternRewriter &rewriter) {
1094   if (isInt)
1095     return rewriter.create<arith::AddIOp>(loc, x, y);
1096   return rewriter.create<arith::AddFOp>(loc, x, y);
1097 }
1098 
1099 /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
1100 /// operands `x and `y`.
1101 static Value createMul(Location loc, Value x, Value y, bool isInt,
1102                        PatternRewriter &rewriter) {
1103   if (isInt)
1104     return rewriter.create<arith::MulIOp>(loc, x, y);
1105   return rewriter.create<arith::MulFOp>(loc, x, y);
1106 }
1107 
1108 namespace mlir {
1109 
1110 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1111 /// semantics to:
1112 /// ```
1113 ///    %mta = maybe_transpose
1114 ///    %mtb = maybe_transpose
1115 ///    %flattened_a = vector.shape_cast %mta
1116 ///    %flattened_b = vector.shape_cast %mtb
1117 ///    %flattened_d = vector.matmul %flattened_a, %flattened_b
1118 ///    %mtd = vector.shape_cast %flattened_d
1119 ///    %d = maybe_untranspose %mtd
1120 ///    %e = add %c, %d
1121 /// ```
1122 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1123 //
1124 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
1125 /// vector.transpose operations are inserted if the vector.contract op is not a
1126 /// row-major matrix multiply.
1127 LogicalResult
1128 ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
1129                                                  PatternRewriter &rew) const {
1130   // TODO: implement masks
1131   if (llvm::size(op.masks()) != 0)
1132     return failure();
1133   if (vectorTransformOptions.vectorContractLowering !=
1134       vector::VectorContractLowering::Matmul)
1135     return failure();
1136   if (failed(filter(op)))
1137     return failure();
1138 
1139   auto iteratorTypes = op.iterator_types().getValue();
1140   if (!isParallelIterator(iteratorTypes[0]) ||
1141       !isParallelIterator(iteratorTypes[1]) ||
1142       !isReductionIterator(iteratorTypes[2]))
1143     return failure();
1144 
1145   Type elementType = op.getLhsType().getElementType();
1146   if (!elementType.isIntOrFloat())
1147     return failure();
1148 
1149   // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
1150   // Bail out if the contraction cannot be put in this form.
1151   MLIRContext *ctx = op.getContext();
1152   Location loc = op.getLoc();
1153   AffineExpr m, n, k;
1154   bindDims(rew.getContext(), m, n, k);
1155   // LHS must be A(m, k) or A(k, m).
1156   Value lhs = op.lhs();
1157   auto lhsMap = op.indexing_maps()[0];
1158   if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
1159     lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
1160   else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
1161     return failure();
1162 
1163   // RHS must be B(k, n) or B(n, k).
1164   Value rhs = op.rhs();
1165   auto rhsMap = op.indexing_maps()[1];
1166   if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
1167     rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
1168   else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
1169     return failure();
1170 
1171   // At this point lhs and rhs are in row-major.
1172   VectorType lhsType = lhs.getType().cast<VectorType>();
1173   VectorType rhsType = rhs.getType().cast<VectorType>();
1174   int64_t lhsRows = lhsType.getDimSize(0);
1175   int64_t lhsColumns = lhsType.getDimSize(1);
1176   int64_t rhsColumns = rhsType.getDimSize(1);
1177 
1178   Type flattenedLHSType =
1179       VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
1180   lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
1181 
1182   Type flattenedRHSType =
1183       VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
1184   rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
1185 
1186   Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
1187                                            rhsColumns);
1188   mul = rew.create<vector::ShapeCastOp>(
1189       loc,
1190       VectorType::get({lhsRows, rhsColumns},
1191                       getElementTypeOrSelf(op.acc().getType())),
1192       mul);
1193 
1194   // ACC must be C(m, n) or C(n, m).
1195   auto accMap = op.indexing_maps()[2];
1196   if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
1197     mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
1198   else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
1199     llvm_unreachable("invalid contraction semantics");
1200 
1201   Value res =
1202       elementType.isa<IntegerType>()
1203           ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.acc(), mul))
1204           : static_cast<Value>(rew.create<arith::AddFOp>(loc, op.acc(), mul));
1205 
1206   rew.replaceOp(op, res);
1207   return success();
1208 }
1209 
1210 namespace {
1211 struct IteratorType {
1212   IteratorType(StringRef strRef) : strRef(strRef) {}
1213   bool isOfType(Attribute attr) const {
1214     auto sAttr = attr.dyn_cast<StringAttr>();
1215     return sAttr && sAttr.getValue() == strRef;
1216   }
1217   StringRef strRef;
1218 };
1219 struct Par : public IteratorType {
1220   Par() : IteratorType(getParallelIteratorTypeName()) {}
1221 };
1222 struct Red : public IteratorType {
1223   Red() : IteratorType(getReductionIteratorTypeName()) {}
1224 };
1225 
1226 /// Generate a vector implementation for matmat, matvec and tmatvec.
1227 /// This unrolls outer-products along the reduction dimension.
1228 struct UnrolledOuterProductGenerator
1229     : public StructuredGenerator<vector::ContractionOp> {
1230 
1231   UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
1232       : StructuredGenerator<vector::ContractionOp>(builder, op),
1233         kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()),
1234         lhsType(op.getLhsType()) {}
1235 
1236   Value t(Value v) {
1237     static constexpr std::array<int64_t, 2> perm = {1, 0};
1238     return builder.create<vector::TransposeOp>(loc, v, perm);
1239   }
1240 
1241   Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
1242     assert(reductionSize > 0);
1243     for (int64_t k = 0; k < reductionSize; ++k) {
1244       Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
1245       Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
1246       res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
1247                                                    res, kind);
1248     }
1249     return res;
1250   }
1251 
1252   /// Two outer parallel, one inner reduction (matmat flavor).
1253   FailureOr<Value> matmat() {
1254     if (!iters({Par(), Par(), Red()}))
1255       return failure();
1256     // Set up the parallel/reduction structure in the right form.
1257     AffineExpr m, n, k;
1258     bindDims(builder.getContext(), m, n, k);
1259     // Classical row-major matmul:  Just permute the lhs.
1260     if (layout({{m, k}, {k, n}, {m, n}}))
1261       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1262     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1263     if (layout({{m, k}, {n, k}, {m, n}})) {
1264       Value tlhs = t(lhs);
1265       return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
1266     }
1267     // No need to permute anything.
1268     if (layout({{k, m}, {k, n}, {m, n}}))
1269       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1270     // Just permute the rhs.
1271     if (layout({{k, m}, {n, k}, {m, n}}))
1272       return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
1273     // Transposed output: swap RHS and LHS.
1274     // Classical row-major matmul: permute the lhs.
1275     if (layout({{m, k}, {k, n}, {n, m}}))
1276       return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
1277     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1278     if (layout({{m, k}, {n, k}, {n, m}})) {
1279       Value trhs = t(rhs);
1280       return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
1281     }
1282     if (layout({{k, m}, {k, n}, {n, m}}))
1283       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1284     if (layout({{k, m}, {n, k}, {n, m}}))
1285       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1286     return failure();
1287   }
1288 
1289   /// One outer parallel, one inner reduction (matvec flavor)
1290   FailureOr<Value> matvec() {
1291     if (!iters({Par(), Red()}))
1292       return failure();
1293     AffineExpr m, k;
1294     bindDims(builder.getContext(), m, k);
1295 
1296     // Case mat-vec: transpose.
1297     if (layout({{m, k}, {k}, {m}}))
1298       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1299     // Case mat-trans-vec: ready to go.
1300     if (layout({{k, m}, {k}, {m}}))
1301       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1302     // Case vec-mat: swap and transpose.
1303     if (layout({{k}, {m, k}, {m}}))
1304       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1305     // Case vec-mat-trans: swap and ready to go.
1306     if (layout({{k}, {k, m}, {m}}))
1307       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1308     return failure();
1309   }
1310 
1311   //
1312   // One outer reduction, one inner parallel (tmatvec flavor)
1313   //
1314   FailureOr<Value> tmatvec() {
1315     if (!iters({Red(), Par()}))
1316       return failure();
1317     AffineExpr k, m;
1318     bindDims(builder.getContext(), k, m);
1319 
1320     // Case mat-vec: transpose.
1321     if (layout({{m, k}, {k}, {m}}))
1322       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1323     // Case mat-trans-vec: ready to go.
1324     if (layout({{k, m}, {k}, {m}}))
1325       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1326     // Case vec-mat: swap and transpose.
1327     if (layout({{k}, {m, k}, {m}}))
1328       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1329     // Case vec-mat-trans: swap and ready to go.
1330     if (layout({{k}, {k, m}, {m}}))
1331       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1332     return failure();
1333   }
1334 
1335 private:
1336   vector::CombiningKind kind;
1337   Value lhs, rhs, res;
1338   VectorType lhsType;
1339 };
1340 } // namespace
1341 
1342 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1343 /// semantics to a reduction_size-unrolled sequence:
1344 /// ```
1345 ///    %at = vector.transpose %a, [1, 0]
1346 ///    %bRow0 = vector.extract %b[0]
1347 ///    %atRow0 = vector.extract %at[0]
1348 ///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
1349 ///    ...
1350 ///    %bRowK = vector.extract %b[K]
1351 ///    %atRowK = vector.extract %at[K]
1352 ///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
1353 /// ```
1354 ///
1355 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
1356 /// otherwise supports any layout permutation of the matrix-multiply.
1357 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
1358     vector::ContractionOp op, PatternRewriter &rewriter) const {
1359   // TODO: implement masks
1360   if (llvm::size(op.masks()) != 0)
1361     return failure();
1362 
1363   if (vectorTransformOptions.vectorContractLowering !=
1364       vector::VectorContractLowering::OuterProduct)
1365     return failure();
1366 
1367   if (failed(filter(op)))
1368     return failure();
1369 
1370   UnrolledOuterProductGenerator e(rewriter, op);
1371   FailureOr<Value> matmatRes = e.matmat();
1372   if (succeeded(matmatRes)) {
1373     rewriter.replaceOp(op, *matmatRes);
1374     return success();
1375   }
1376   FailureOr<Value> matvecRes = e.matvec();
1377   if (succeeded(matvecRes)) {
1378     rewriter.replaceOp(op, *matvecRes);
1379     return success();
1380   }
1381   FailureOr<Value> tmatvecRes = e.tmatvec();
1382   if (succeeded(tmatvecRes)) {
1383     rewriter.replaceOp(op, *tmatvecRes);
1384     return success();
1385   }
1386 
1387   return failure();
1388 }
1389 
1390 LogicalResult
1391 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
1392                                             PatternRewriter &rewriter) const {
1393   // TODO: implement masks
1394   if (llvm::size(op.masks()) != 0)
1395     return failure();
1396 
1397   if (failed(filter(op)))
1398     return failure();
1399 
1400   if (vectorTransformOptions.vectorContractLowering !=
1401       vector::VectorContractLowering::Dot)
1402     return failure();
1403 
1404   auto iteratorTypes = op.iterator_types().getValue();
1405   static constexpr std::array<int64_t, 2> perm = {1, 0};
1406   Location loc = op.getLoc();
1407   Value lhs = op.lhs(), rhs = op.rhs();
1408 
1409   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1410   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1411   AffineExpr m, n, k;
1412   bindDims(rewriter.getContext(), m, n, k);
1413   SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
1414   //
1415   // In the following we wish to make the reduction dimension innermost so we
1416   // can load vectors and just fmul + reduce into a scalar.
1417   //
1418   if (isParallelIterator(iteratorTypes[0]) &&
1419       isParallelIterator(iteratorTypes[1]) &&
1420       isReductionIterator(iteratorTypes[2])) {
1421     //
1422     // Two outer parallel, one inner reduction (matmat flavor).
1423     //
1424     if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1425       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1426     } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
1427       // No need to permute anything.
1428     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1429       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1430       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1431     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1432       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1433     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1434       // This is the classical row-major matmul. Just permute the lhs.
1435       Value tmp = lhs;
1436       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1437       rhs = tmp;
1438     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1439       std::swap(lhs, rhs);
1440     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1441       Value tmp = lhs;
1442       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1443       rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
1444     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1445       Value tmp = rhs;
1446       rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1447       lhs = tmp;
1448     } else {
1449       return failure();
1450     }
1451   } else if (isParallelIterator(iteratorTypes[0]) &&
1452              isReductionIterator(iteratorTypes[1])) {
1453     //
1454     // One outer parallel, one inner reduction (matvec flavor)
1455     //
1456     if (maps == infer({{m, n}, {n}, {m}})) {
1457       // No need to permute anything.
1458     } else if (maps == infer({{n, m}, {n}, {m}})) {
1459       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1460     } else if (maps == infer({{n}, {m, n}, {m}})) {
1461       std::swap(lhs, rhs);
1462     } else if (maps == infer({{n}, {n, m}, {m}})) {
1463       std::swap(lhs, rhs);
1464       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1465     } else {
1466       return failure();
1467     }
1468   } else {
1469     return failure();
1470   }
1471 
1472   VectorType dstType = op.getResultType().cast<VectorType>();
1473   assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
1474          "Expected dst type of rank 1 or 2");
1475 
1476   unsigned rank = dstType.getRank();
1477   unsigned dstRows = dstType.getShape()[0];
1478   unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
1479 
1480   // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
1481   Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
1482                                                  rewriter.getZeroAttr(dstType));
1483   bool isInt = dstType.getElementType().isa<IntegerType>();
1484   for (unsigned r = 0; r < dstRows; ++r) {
1485     Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
1486     for (unsigned c = 0; c < dstColumns; ++c) {
1487       Value b = rank == 1
1488                     ? rhs
1489                     : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
1490       Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
1491       Value reduced = rewriter.create<vector::ReductionOp>(
1492           op.getLoc(), vector::CombiningKind::ADD, m);
1493 
1494       SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
1495                                               : SmallVector<int64_t, 2>{r, c};
1496       res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
1497     }
1498   }
1499   if (auto acc = op.acc())
1500     res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
1501   rewriter.replaceOp(op, res);
1502   return success();
1503 }
1504 
1505 /// Progressive lowering of ContractionOp.
1506 /// One:
1507 ///   %x = vector.contract with at least one free/batch dimension
1508 /// is replaced by:
1509 ///   %a = vector.contract with one less free/batch dimension
1510 ///   %b = vector.contract with one less free/batch dimension
1511 ///   ..
1512 ///   %x = combine %a %b ..
1513 /// until a pure contraction is reached (no free/batch dimensions),
1514 /// which is replaced by a dot-product.
1515 ///
1516 /// This only kicks in when either VectorTransformsOptions is set
1517 /// to DOT or when other contraction patterns fail.
1518 //
1519 // TODO: break down into transpose/reshape/cast ops
1520 //               when they become available to avoid code dup
1521 // TODO: investigate lowering order impact on performance
1522 LogicalResult
1523 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
1524                                        PatternRewriter &rewriter) const {
1525   // TODO: implement masks.
1526   if (llvm::size(op.masks()) != 0)
1527     return failure();
1528 
1529   if (failed(filter(op)))
1530     return failure();
1531 
1532   // TODO: support mixed mode contract lowering.
1533   if (op.getLhsType().getElementType() !=
1534           getElementTypeOrSelf(op.getAccType()) ||
1535       op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
1536     return failure();
1537 
1538   // TODO: implement benefits, cost models.
1539   MLIRContext *ctx = op.getContext();
1540   ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
1541   if (succeeded(pat1.matchAndRewrite(op, rewriter)))
1542     return success();
1543   ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
1544   if (succeeded(pat2.matchAndRewrite(op, rewriter)))
1545     return success();
1546   ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
1547   if (succeeded(pat3.matchAndRewrite(op, rewriter)))
1548     return success();
1549 
1550   // Find first batch dimension in LHS/RHS, and lower when found.
1551   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
1552   if (!batchDimMap.empty()) {
1553     int64_t lhsIndex = batchDimMap[0].first;
1554     int64_t rhsIndex = batchDimMap[0].second;
1555     rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
1556     return success();
1557   }
1558 
1559   // Collect contracting dimensions.
1560   std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
1561       op.getContractingDimMap();
1562   DenseSet<int64_t> lhsContractingDimSet;
1563   DenseSet<int64_t> rhsContractingDimSet;
1564   for (auto &dimPair : contractingDimMap) {
1565     lhsContractingDimSet.insert(dimPair.first);
1566     rhsContractingDimSet.insert(dimPair.second);
1567   }
1568 
1569   // Find first free dimension in LHS, and lower when found.
1570   VectorType lhsType = op.getLhsType();
1571   for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
1572     if (lhsContractingDimSet.count(lhsIndex) == 0) {
1573       rewriter.replaceOp(
1574           op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
1575       return success();
1576     }
1577   }
1578 
1579   // Find first free dimension in RHS, and lower when found.
1580   VectorType rhsType = op.getRhsType();
1581   for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
1582     if (rhsContractingDimSet.count(rhsIndex) == 0) {
1583       rewriter.replaceOp(
1584           op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
1585       return success();
1586     }
1587   }
1588 
1589   // Lower the first remaining reduction dimension.
1590   if (!contractingDimMap.empty()) {
1591     rewriter.replaceOp(op, lowerReduction(op, rewriter));
1592     return success();
1593   }
1594 
1595   return failure();
1596 }
1597 
1598 // Lower one parallel dimension.
1599 // TODO: consider reusing existing contract unrolling
1600 Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
1601                                            int64_t lhsIndex, int64_t rhsIndex,
1602                                            PatternRewriter &rewriter) const {
1603   VectorType lhsType = op.getLhsType();
1604   VectorType rhsType = op.getRhsType();
1605   VectorType resType = op.getResultType().cast<VectorType>();
1606   // Find the iterator type index and result index.
1607   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
1608   int64_t iterIndex = -1;
1609   int64_t dimSize = -1;
1610   if (lhsIndex >= 0) {
1611     iterIndex = iMap[0].getDimPosition(lhsIndex);
1612     assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
1613            "parallel index should be free in LHS or batch in LHS/RHS");
1614     dimSize = lhsType.getDimSize(lhsIndex);
1615   } else {
1616     assert(rhsIndex >= 0 && "missing parallel index");
1617     iterIndex = iMap[1].getDimPosition(rhsIndex);
1618     dimSize = rhsType.getDimSize(rhsIndex);
1619   }
1620   assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
1621   Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
1622   assert(lookup.hasValue() && "parallel index not listed in reduction");
1623   int64_t resIndex = lookup.getValue();
1624   // Construct new iterator types and affine map array attribute.
1625   std::array<AffineMap, 3> lowIndexingMaps = {
1626       adjustMap(iMap[0], iterIndex, rewriter),
1627       adjustMap(iMap[1], iterIndex, rewriter),
1628       adjustMap(iMap[2], iterIndex, rewriter)};
1629   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1630   auto lowIter =
1631       rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
1632   // Unroll into a series of lower dimensional vector.contract ops.
1633   Location loc = op.getLoc();
1634   Value result = rewriter.create<arith::ConstantOp>(
1635       loc, resType, rewriter.getZeroAttr(resType));
1636   for (int64_t d = 0; d < dimSize; ++d) {
1637     auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
1638     auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
1639     auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
1640     Value lowContract = rewriter.create<vector::ContractionOp>(
1641         loc, lhs, rhs, acc, lowAffine, lowIter);
1642     result =
1643         reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
1644   }
1645   return result;
1646 }
1647 
1648 // Lower one reduction dimension.
1649 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
1650                                             PatternRewriter &rewriter) const {
1651   auto loc = op.getLoc();
1652   VectorType lhsType = op.getLhsType();
1653   VectorType rhsType = op.getRhsType();
1654   Type resType = op.getResultType();
1655   assert(!resType.isa<VectorType>());
1656   bool isInt = resType.isa<IntegerType>();
1657   // Use iterator index 0.
1658   int64_t iterIndex = 0;
1659   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
1660   Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1661   Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1662   assert(lookupLhs.hasValue() && "missing LHS parallel index");
1663   assert(lookupRhs.hasValue() && "missing RHS parallel index");
1664   int64_t lhsIndex = lookupLhs.getValue();
1665   int64_t rhsIndex = lookupRhs.getValue();
1666   int64_t dimSize = lhsType.getDimSize(lhsIndex);
1667   assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
1668   // Base case.
1669   if (lhsType.getRank() == 1) {
1670     assert(rhsType.getRank() == 1 && "corrupt contraction");
1671     Value m = createMul(loc, op.lhs(), op.rhs(), isInt, rewriter);
1672     auto kind = vector::CombiningKind::ADD;
1673     Value res = rewriter.create<vector::ReductionOp>(loc, kind, m);
1674     if (auto acc = op.acc())
1675       res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
1676     return res;
1677   }
1678   // Construct new iterator types and affine map array attribute.
1679   std::array<AffineMap, 3> lowIndexingMaps = {
1680       adjustMap(iMap[0], iterIndex, rewriter),
1681       adjustMap(iMap[1], iterIndex, rewriter),
1682       adjustMap(iMap[2], iterIndex, rewriter)};
1683   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1684   auto lowIter =
1685       rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
1686   // Unroll into a series of lower dimensional vector.contract ops.
1687   // By feeding the initial accumulator into the first contraction,
1688   // and the result of each contraction into the next, eventually
1689   // the sum of all reductions is computed.
1690   Value result = op.acc();
1691   for (int64_t d = 0; d < dimSize; ++d) {
1692     auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
1693     auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
1694     result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
1695                                                     lowAffine, lowIter);
1696   }
1697   return result;
1698 }
1699 
1700 } // namespace mlir
1701 
1702 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
1703     OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
1704     ArrayRef<int64_t> multiplicity, const AffineMap &map) {
1705   OpBuilder::InsertionGuard guard(builder);
1706   builder.setInsertionPointAfter(op);
1707   Location loc = op->getLoc();
1708   if (op->getNumResults() != 1)
1709     return {};
1710   Value result = op->getResult(0);
1711   VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
1712   if (!type || map.getNumResults() != multiplicity.size())
1713     return {};
1714   // For each dimension being distributed check that the size is a multiple of
1715   // the multiplicity. To handle more sizes we would need to support masking.
1716   unsigned multiplictyCount = 0;
1717   for (auto exp : map.getResults()) {
1718     auto affinExp = exp.dyn_cast<AffineDimExpr>();
1719     if (!affinExp || affinExp.getPosition() >= type.getRank() ||
1720         type.getDimSize(affinExp.getPosition()) %
1721                 multiplicity[multiplictyCount++] !=
1722             0)
1723       return {};
1724   }
1725   DistributeOps ops;
1726   ops.extract =
1727       builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
1728   ops.insert =
1729       builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
1730   return ops;
1731 }
1732 
1733 /// Progressive lowering of transfer_read. This pattern supports lowering of
1734 /// `vector.transfer_read` to a combination of `vector.load` and
1735 /// `vector.broadcast` if all of the following hold:
1736 /// - Stride of most minor memref dimension must be 1.
1737 /// - Out-of-bounds masking is not required.
1738 /// - If the memref's element type is a vector type then it coincides with the
1739 ///   result type.
1740 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
1741 struct TransferReadToVectorLoadLowering
1742     : public OpRewritePattern<vector::TransferReadOp> {
1743   TransferReadToVectorLoadLowering(MLIRContext *context,
1744                                    llvm::Optional<unsigned> maxRank)
1745       : OpRewritePattern<vector::TransferReadOp>(context),
1746         maxTransferRank(maxRank) {}
1747 
1748   LogicalResult matchAndRewrite(vector::TransferReadOp read,
1749                                 PatternRewriter &rewriter) const override {
1750     if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
1751       return failure();
1752 
1753     SmallVector<unsigned, 4> broadcastedDims;
1754     // Permutations are handled by VectorToSCF or
1755     // populateVectorTransferPermutationMapLoweringPatterns.
1756     // We let the 0-d corner case pass-through as it is supported.
1757     if (!read.permutation_map().isMinorIdentityWithBroadcasting(
1758             &broadcastedDims))
1759       return failure();
1760 
1761     auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
1762     if (!memRefType)
1763       return failure();
1764 
1765     // Non-unit strides are handled by VectorToSCF.
1766     if (!vector::isLastMemrefDimUnitStride(memRefType))
1767       return failure();
1768 
1769     // If there is broadcasting involved then we first load the unbroadcasted
1770     // vector, and then broadcast it with `vector.broadcast`.
1771     ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
1772     SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
1773                                                      vectorShape.end());
1774     for (unsigned i : broadcastedDims)
1775       unbroadcastedVectorShape[i] = 1;
1776     VectorType unbroadcastedVectorType = VectorType::get(
1777         unbroadcastedVectorShape, read.getVectorType().getElementType());
1778 
1779     // `vector.load` supports vector types as memref's elements only when the
1780     // resulting vector type is the same as the element type.
1781     auto memrefElTy = memRefType.getElementType();
1782     if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
1783       return failure();
1784 
1785     // Otherwise, element types of the memref and the vector must match.
1786     if (!memrefElTy.isa<VectorType>() &&
1787         memrefElTy != read.getVectorType().getElementType())
1788       return failure();
1789 
1790     // Out-of-bounds dims are handled by MaterializeTransferMask.
1791     if (read.hasOutOfBoundsDim())
1792       return failure();
1793 
1794     // Create vector load op.
1795     Operation *loadOp;
1796     if (read.mask()) {
1797       Value fill = rewriter.create<vector::SplatOp>(
1798           read.getLoc(), unbroadcastedVectorType, read.padding());
1799       loadOp = rewriter.create<vector::MaskedLoadOp>(
1800           read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(),
1801           read.mask(), fill);
1802     } else {
1803       loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
1804                                                unbroadcastedVectorType,
1805                                                read.source(), read.indices());
1806     }
1807 
1808     // Insert a broadcasting op if required.
1809     if (!broadcastedDims.empty()) {
1810       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1811           read, read.getVectorType(), loadOp->getResult(0));
1812     } else {
1813       rewriter.replaceOp(read, loadOp->getResult(0));
1814     }
1815 
1816     return success();
1817   }
1818 
1819   llvm::Optional<unsigned> maxTransferRank;
1820 };
1821 
1822 /// Replace a 0-d vector.load with a memref.load + vector.broadcast.
1823 // TODO: we shouldn't cross the vector/scalar domains just for this
1824 // but atm we lack the infra to avoid it. Possible solutions include:
1825 // - go directly to LLVM + bitcast
1826 // - introduce a bitcast op and likely a new pointer dialect
1827 // - let memref.load/store additionally support the 0-d vector case
1828 // There are still deeper data layout issues lingering even in this
1829 // trivial case (for architectures for which this matters).
1830 struct VectorLoadToMemrefLoadLowering
1831     : public OpRewritePattern<vector::LoadOp> {
1832   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
1833 
1834   LogicalResult matchAndRewrite(vector::LoadOp loadOp,
1835                                 PatternRewriter &rewriter) const override {
1836     auto vecType = loadOp.getVectorType();
1837     if (vecType.getNumElements() != 1)
1838       return failure();
1839     auto memrefLoad = rewriter.create<memref::LoadOp>(
1840         loadOp.getLoc(), loadOp.base(), loadOp.indices());
1841     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
1842                                                      memrefLoad);
1843     return success();
1844   }
1845 };
1846 
1847 /// Replace a 0-d vector.store with a vector.extractelement + memref.store.
1848 struct VectorStoreToMemrefStoreLowering
1849     : public OpRewritePattern<vector::StoreOp> {
1850   using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
1851 
1852   LogicalResult matchAndRewrite(vector::StoreOp storeOp,
1853                                 PatternRewriter &rewriter) const override {
1854     auto vecType = storeOp.getVectorType();
1855     if (vecType.getNumElements() != 1)
1856       return failure();
1857     Value extracted;
1858     if (vecType.getRank() == 0) {
1859       // TODO: Unifiy once ExtractOp supports 0-d vectors.
1860       extracted = rewriter.create<vector::ExtractElementOp>(
1861           storeOp.getLoc(), storeOp.valueToStore());
1862     } else {
1863       SmallVector<int64_t> indices(vecType.getRank(), 0);
1864       extracted = rewriter.create<vector::ExtractOp>(
1865           storeOp.getLoc(), storeOp.valueToStore(), indices);
1866     }
1867 
1868     rewriter.replaceOpWithNewOp<memref::StoreOp>(
1869         storeOp, extracted, storeOp.base(), storeOp.indices());
1870     return success();
1871   }
1872 };
1873 
1874 /// Progressive lowering of transfer_write. This pattern supports lowering of
1875 /// `vector.transfer_write` to `vector.store` if all of the following hold:
1876 /// - Stride of most minor memref dimension must be 1.
1877 /// - Out-of-bounds masking is not required.
1878 /// - If the memref's element type is a vector type then it coincides with the
1879 ///   type of the written value.
1880 /// - The permutation map is the minor identity map (neither permutation nor
1881 ///   broadcasting is allowed).
1882 struct TransferWriteToVectorStoreLowering
1883     : public OpRewritePattern<vector::TransferWriteOp> {
1884   TransferWriteToVectorStoreLowering(MLIRContext *context,
1885                                      llvm::Optional<unsigned> maxRank)
1886       : OpRewritePattern<vector::TransferWriteOp>(context),
1887         maxTransferRank(maxRank) {}
1888 
1889   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
1890                                 PatternRewriter &rewriter) const override {
1891     if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
1892       return failure();
1893 
1894     // Permutations are handled by VectorToSCF or
1895     // populateVectorTransferPermutationMapLoweringPatterns.
1896     if ( // pass-through for the 0-d corner case.
1897         !write.permutation_map().isMinorIdentity())
1898       return failure();
1899 
1900     auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
1901     if (!memRefType)
1902       return failure();
1903 
1904     // Non-unit strides are handled by VectorToSCF.
1905     if (!vector::isLastMemrefDimUnitStride(memRefType))
1906       return failure();
1907 
1908     // `vector.store` supports vector types as memref's elements only when the
1909     // type of the vector value being written is the same as the element type.
1910     auto memrefElTy = memRefType.getElementType();
1911     if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
1912       return failure();
1913 
1914     // Otherwise, element types of the memref and the vector must match.
1915     if (!memrefElTy.isa<VectorType>() &&
1916         memrefElTy != write.getVectorType().getElementType())
1917       return failure();
1918 
1919     // Out-of-bounds dims are handled by MaterializeTransferMask.
1920     if (write.hasOutOfBoundsDim())
1921       return failure();
1922     if (write.mask()) {
1923       rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
1924           write, write.source(), write.indices(), write.mask(), write.vector());
1925     } else {
1926       rewriter.replaceOpWithNewOp<vector::StoreOp>(
1927           write, write.vector(), write.source(), write.indices());
1928     }
1929     return success();
1930   }
1931 
1932   llvm::Optional<unsigned> maxTransferRank;
1933 };
1934 
1935 // Returns the values in `arrayAttr` as an integer vector.
1936 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
1937   return llvm::to_vector<4>(
1938       llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
1939                       [](IntegerAttr attr) { return attr.getInt(); }));
1940 }
1941 
1942 // Shuffles vector.bitcast op after vector.extract op.
1943 //
1944 // This transforms IR like:
1945 //   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
1946 //   %1 = vector.extract %0[3] : vector<8xf16>
1947 // Into:
1948 //   %0 = vector.extract %src[1] : vector<4xf32>
1949 //   %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
1950 //   %2 = vector.extract %1[1] : vector<2xf16>
1951 struct BubbleDownVectorBitCastForExtract
1952     : public OpRewritePattern<vector::ExtractOp> {
1953   using OpRewritePattern::OpRewritePattern;
1954 
1955   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
1956                                 PatternRewriter &rewriter) const override {
1957     // Only support extracting scalars for now.
1958     if (extractOp.getVectorType().getRank() != 1)
1959       return failure();
1960 
1961     auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
1962     if (!castOp)
1963       return failure();
1964 
1965     VectorType castSrcType = castOp.getSourceVectorType();
1966     VectorType castDstType = castOp.getResultVectorType();
1967     assert(castSrcType.getRank() == castDstType.getRank());
1968 
1969     // Fail to match if we only have one element in the cast op source.
1970     // This is to avoid infinite loop given that this pattern can generate
1971     // such cases.
1972     if (castSrcType.getNumElements() == 1)
1973       return failure();
1974 
1975     // Only support casting to a larger number of elements or now.
1976     // E.g., vector<4xf32> -> vector<8xf16>.
1977     if (castSrcType.getNumElements() > castDstType.getNumElements())
1978       return failure();
1979 
1980     unsigned expandRatio =
1981         castDstType.getNumElements() / castSrcType.getNumElements();
1982 
1983     auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
1984       return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
1985     };
1986 
1987     uint64_t index = getFirstIntValue(extractOp.position());
1988 
1989     // Get the single scalar (as a vector) in the source value that packs the
1990     // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
1991     VectorType oneScalarType =
1992         VectorType::get({1}, castSrcType.getElementType());
1993     Value packedValue = rewriter.create<vector::ExtractOp>(
1994         extractOp.getLoc(), oneScalarType, castOp.source(),
1995         rewriter.getI64ArrayAttr(index / expandRatio));
1996 
1997     // Cast it to a vector with the desired scalar's type.
1998     // E.g. f32 -> vector<2xf16>
1999     VectorType packedType =
2000         VectorType::get({expandRatio}, castDstType.getElementType());
2001     Value castedValue = rewriter.create<vector::BitCastOp>(
2002         extractOp.getLoc(), packedType, packedValue);
2003 
2004     // Finally extract the desired scalar.
2005     rewriter.replaceOpWithNewOp<vector::ExtractOp>(
2006         extractOp, extractOp.getType(), castedValue,
2007         rewriter.getI64ArrayAttr(index % expandRatio));
2008 
2009     return success();
2010   }
2011 };
2012 
2013 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
2014 //
2015 // This transforms IR like:
2016 //    %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
2017 //     %0 = vector.extract_strided_slice %cast {
2018 //            offsets = [4], sizes = [4], strides = [1]
2019 //          } : vector<8xf16> to vector<4xf16>
2020 // Into:
2021 //   %0 = vector.extract_strided_slice %src {
2022 //          offsets = [2], sizes = [2], strides = [1]
2023 //        } : vector<4xf32> to vector<2xf32>
2024 //   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
2025 struct BubbleDownBitCastForStridedSliceExtract
2026     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
2027   using OpRewritePattern::OpRewritePattern;
2028 
2029   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
2030                                 PatternRewriter &rewriter) const override {
2031     auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
2032     if (!castOp)
2033       return failure();
2034 
2035     VectorType castSrcType = castOp.getSourceVectorType();
2036     VectorType castDstType = castOp.getResultVectorType();
2037     assert(castSrcType.getRank() == castDstType.getRank());
2038 
2039     int64_t castSrcLastDim = castSrcType.getShape().back();
2040     int64_t castDstLastDim = castDstType.getShape().back();
2041     // Require casting to more elements for now; other cases to be implemented.
2042     if (castSrcLastDim > castDstLastDim)
2043       return failure();
2044 
2045     // Only accept all one strides for now.
2046     if (llvm::any_of(extractOp.strides().getAsValueRange<IntegerAttr>(),
2047                      [](const APInt &val) { return !val.isOneValue(); }))
2048       return failure();
2049 
2050     unsigned rank = extractOp.getVectorType().getRank();
2051     assert(castDstLastDim % castSrcLastDim == 0);
2052     int64_t expandRatio = castDstLastDim / castSrcLastDim;
2053 
2054     // If we have a less number of offsets than the rank, then implicitly we
2055     // are selecting the full range for the last bitcasted dimension; other
2056     // dimensions aren't affected. Otherwise, we need to scale down the last
2057     // dimension's offset given we are extracting from less elements now.
2058     ArrayAttr newOffsets = extractOp.offsets();
2059     if (newOffsets.size() == rank) {
2060       SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2061       if (offsets.back() % expandRatio != 0)
2062         return failure();
2063       offsets.back() = offsets.back() / expandRatio;
2064       newOffsets = rewriter.getI64ArrayAttr(offsets);
2065     }
2066 
2067     // Similarly for sizes.
2068     ArrayAttr newSizes = extractOp.sizes();
2069     if (newSizes.size() == rank) {
2070       SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
2071       if (sizes.back() % expandRatio != 0)
2072         return failure();
2073       sizes.back() = sizes.back() / expandRatio;
2074       newSizes = rewriter.getI64ArrayAttr(sizes);
2075     }
2076 
2077     SmallVector<int64_t, 4> dims =
2078         llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
2079     dims.back() = dims.back() / expandRatio;
2080     VectorType newExtractType =
2081         VectorType::get(dims, castSrcType.getElementType());
2082 
2083     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
2084         extractOp.getLoc(), newExtractType, castOp.source(), newOffsets,
2085         newSizes, extractOp.strides());
2086 
2087     rewriter.replaceOpWithNewOp<vector::BitCastOp>(
2088         extractOp, extractOp.getType(), newExtractOp);
2089 
2090     return success();
2091   }
2092 };
2093 
2094 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
2095 //
2096 // This transforms IR like:
2097 //   %0 = vector.insert_strided_slice %src, %dst {
2098 //          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
2099 //   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
2100 // Into:
2101 //   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
2102 //   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
2103 //   %2 = vector.insert_strided_slice %src, %dst {
2104 //          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
2105 struct BubbleUpBitCastForStridedSliceInsert
2106     : public OpRewritePattern<vector::BitCastOp> {
2107   using OpRewritePattern::OpRewritePattern;
2108   LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
2109                                 PatternRewriter &rewriter) const override {
2110     VectorType castSrcType = bitcastOp.getSourceVectorType();
2111     VectorType castDstType = bitcastOp.getResultVectorType();
2112     assert(castSrcType.getRank() == castDstType.getRank());
2113 
2114     int64_t castSrcLastDim = castSrcType.getShape().back();
2115     int64_t castDstLastDim = castDstType.getShape().back();
2116     // Require casting to less elements for now; other cases to be implemented.
2117     if (castSrcLastDim < castDstLastDim)
2118       return failure();
2119 
2120     assert(castSrcLastDim % castDstLastDim == 0);
2121     int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
2122 
2123     auto insertOp =
2124         bitcastOp.source().getDefiningOp<vector::InsertStridedSliceOp>();
2125     if (!insertOp)
2126       return failure();
2127 
2128     // Only accept all one strides for now.
2129     if (llvm::any_of(insertOp.strides().getAsValueRange<IntegerAttr>(),
2130                      [](const APInt &val) { return !val.isOneValue(); }))
2131       return failure();
2132 
2133     unsigned rank = insertOp.getSourceVectorType().getRank();
2134     // Require insert op to have the same rank for the source and destination
2135     // vector; other cases to be implemented.
2136     if (rank != insertOp.getDestVectorType().getRank())
2137       return failure();
2138 
2139     ArrayAttr newOffsets = insertOp.offsets();
2140     assert(newOffsets.size() == rank);
2141     SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2142     if (offsets.back() % shrinkRatio != 0)
2143       return failure();
2144     offsets.back() = offsets.back() / shrinkRatio;
2145     newOffsets = rewriter.getI64ArrayAttr(offsets);
2146 
2147     SmallVector<int64_t, 4> srcDims =
2148         llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
2149     srcDims.back() = srcDims.back() / shrinkRatio;
2150     VectorType newCastSrcType =
2151         VectorType::get(srcDims, castDstType.getElementType());
2152 
2153     auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
2154         bitcastOp.getLoc(), newCastSrcType, insertOp.source());
2155 
2156     SmallVector<int64_t, 4> dstDims =
2157         llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
2158     dstDims.back() = dstDims.back() / shrinkRatio;
2159     VectorType newCastDstType =
2160         VectorType::get(dstDims, castDstType.getElementType());
2161 
2162     auto newCastDstOp = rewriter.create<vector::BitCastOp>(
2163         bitcastOp.getLoc(), newCastDstType, insertOp.dest());
2164 
2165     rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
2166         bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
2167         insertOp.strides());
2168 
2169     return success();
2170   }
2171 };
2172 
2173 // Helper that returns a vector comparison that constructs a mask:
2174 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
2175 //
2176 // If `dim == 0` then the result will be a 0-D vector.
2177 //
2178 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
2179 //       much more compact, IR for this operation, but LLVM eventually
2180 //       generates more elaborate instructions for this intrinsic since it
2181 //       is very conservative on the boundary conditions.
2182 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
2183                                    bool indexOptimizations, int64_t dim,
2184                                    Value b, Value *off = nullptr) {
2185   auto loc = op->getLoc();
2186   // If we can assume all indices fit in 32-bit, we perform the vector
2187   // comparison in 32-bit to get a higher degree of SIMD parallelism.
2188   // Otherwise we perform the vector comparison using 64-bit indices.
2189   Type idxType =
2190       indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
2191   DenseIntElementsAttr indicesAttr;
2192   if (dim == 0 && indexOptimizations) {
2193     indicesAttr = DenseIntElementsAttr::get(
2194         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
2195   } else if (dim == 0) {
2196     indicesAttr = DenseIntElementsAttr::get(
2197         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
2198   } else if (indexOptimizations) {
2199     indicesAttr = rewriter.getI32VectorAttr(
2200         llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
2201   } else {
2202     indicesAttr = rewriter.getI64VectorAttr(
2203         llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
2204   }
2205   Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
2206   // Add in an offset if requested.
2207   if (off) {
2208     Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
2209     Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
2210     indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
2211   }
2212   // Construct the vector comparison.
2213   Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
2214   Value bounds =
2215       rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
2216   return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
2217                                         bounds);
2218 }
2219 
2220 template <typename ConcreteOp>
2221 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
2222 public:
2223   explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
2224       : mlir::OpRewritePattern<ConcreteOp>(context),
2225         indexOptimizations(enableIndexOpt) {}
2226 
2227   LogicalResult matchAndRewrite(ConcreteOp xferOp,
2228                                 PatternRewriter &rewriter) const override {
2229     if (!xferOp.hasOutOfBoundsDim())
2230       return failure();
2231 
2232     if (xferOp.getVectorType().getRank() > 1 ||
2233         llvm::size(xferOp.indices()) == 0)
2234       return failure();
2235 
2236     Location loc = xferOp->getLoc();
2237     VectorType vtp = xferOp.getVectorType();
2238 
2239     // Create the in-bounds mask with all elements between [0 .. dim - offset)
2240     // set and [dim - offset .. vector_length) unset.
2241     //
2242     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
2243     //       dimensions here.
2244     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
2245     Value off = xferOp.indices()[lastIndex];
2246     Value dim =
2247         vector::createOrFoldDimOp(rewriter, loc, xferOp.source(), lastIndex);
2248     Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
2249     Value mask = rewriter.create<vector::CreateMaskOp>(
2250         loc,
2251         VectorType::get(vtp.getShape(), rewriter.getI1Type(),
2252                         vtp.getNumScalableDims()),
2253         b);
2254     if (xferOp.mask()) {
2255       // Intersect the in-bounds with the mask specified as an op parameter.
2256       mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.mask());
2257     }
2258 
2259     rewriter.updateRootInPlace(xferOp, [&]() {
2260       xferOp.maskMutable().assign(mask);
2261       xferOp.in_boundsAttr(rewriter.getBoolArrayAttr({true}));
2262     });
2263 
2264     return success();
2265   }
2266 
2267 private:
2268   const bool indexOptimizations;
2269 };
2270 
2271 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
2272 class VectorCreateMaskOpConversion
2273     : public OpRewritePattern<vector::CreateMaskOp> {
2274 public:
2275   explicit VectorCreateMaskOpConversion(MLIRContext *context,
2276                                         bool enableIndexOpt)
2277       : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
2278         indexOptimizations(enableIndexOpt) {}
2279 
2280   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
2281                                 PatternRewriter &rewriter) const override {
2282     auto dstType = op.getType();
2283     if (dstType.cast<VectorType>().isScalable())
2284       return failure();
2285     int64_t rank = dstType.getRank();
2286     if (rank > 1)
2287       return failure();
2288     rewriter.replaceOp(
2289         op, buildVectorComparison(rewriter, op, indexOptimizations,
2290                                   rank == 0 ? 0 : dstType.getDimSize(0),
2291                                   op.getOperand(0)));
2292     return success();
2293   }
2294 
2295 private:
2296   const bool indexOptimizations;
2297 };
2298 
2299 // Drop inner most contiguous unit dimensions from transfer_read operand.
2300 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
2301   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
2302 
2303   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
2304                                 PatternRewriter &rewriter) const override {
2305     // TODO: support 0-d corner case.
2306     if (readOp.getTransferRank() == 0)
2307       return failure();
2308 
2309     // TODO: support mask.
2310     if (readOp.mask())
2311       return failure();
2312 
2313     auto srcType = readOp.source().getType().dyn_cast<MemRefType>();
2314     if (!srcType || !srcType.hasStaticShape())
2315       return failure();
2316 
2317     if (!readOp.permutation_map().isMinorIdentity())
2318       return failure();
2319 
2320     auto targetType = readOp.getVectorType();
2321     if (targetType.getRank() <= 1)
2322       return failure();
2323 
2324     SmallVector<int64_t> srcStrides;
2325     int64_t srcOffset;
2326     if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2327       return failure();
2328 
2329     size_t dimsToDrop = 0;
2330     for (size_t i = 1; i < srcStrides.size(); ++i) {
2331       int dim = srcType.getRank() - i - 1;
2332       if (srcStrides[dim] == 1) {
2333         dimsToDrop++;
2334       } else {
2335         break;
2336       }
2337     }
2338     if (dimsToDrop == 0)
2339       return failure();
2340 
2341     auto resultTargetVecType =
2342         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
2343                         targetType.getElementType());
2344 
2345     MemRefType resultMemrefType;
2346     if (srcType.getLayout().getAffineMap().isIdentity()) {
2347       resultMemrefType = MemRefType::get(
2348           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2349           {}, srcType.getMemorySpaceAsInt());
2350     } else {
2351       AffineMap map = srcType.getLayout().getAffineMap();
2352       int numResultDims = map.getNumDims() - dimsToDrop;
2353       int numSymbols = map.getNumSymbols();
2354       for (size_t i = 0; i < dimsToDrop; ++i) {
2355         int dim = srcType.getRank() - i - 1;
2356         map = map.replace(rewriter.getAffineDimExpr(dim),
2357                           rewriter.getAffineConstantExpr(0), numResultDims,
2358                           numSymbols);
2359       }
2360       resultMemrefType = MemRefType::get(
2361           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2362           map, srcType.getMemorySpaceAsInt());
2363     }
2364 
2365     auto loc = readOp.getLoc();
2366     SmallVector<int64_t> offsets(srcType.getRank(), 0);
2367     SmallVector<int64_t> strides(srcType.getRank(), 1);
2368 
2369     ArrayAttr inBoundsAttr =
2370         readOp.in_bounds()
2371             ? rewriter.getArrayAttr(
2372                   readOp.in_boundsAttr().getValue().drop_back(dimsToDrop))
2373             : ArrayAttr();
2374     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
2375         loc, resultMemrefType, readOp.source(), offsets, srcType.getShape(),
2376         strides);
2377     auto permMap = getTransferMinorIdentityMap(
2378         rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
2379     Value result = rewriter.create<vector::TransferReadOp>(
2380         loc, resultTargetVecType, rankedReducedView,
2381         readOp.indices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
2382         readOp.padding(),
2383         // TODO: support mask.
2384         /*mask=*/Value(), inBoundsAttr);
2385     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
2386                                                      result);
2387     return success();
2388   }
2389 };
2390 
2391 namespace {
2392 
2393 /// This function checks to see if the vector combining kind
2394 /// is consistent with the integer or float element type.
2395 static bool isValidKind(bool isInt, vector::CombiningKind kind) {
2396   using vector::CombiningKind;
2397   enum class KindType { FLOAT, INT, INVALID };
2398   KindType type{KindType::INVALID};
2399   switch (kind) {
2400   case CombiningKind::MINF:
2401   case CombiningKind::MAXF:
2402     type = KindType::FLOAT;
2403     break;
2404   case CombiningKind::MINUI:
2405   case CombiningKind::MINSI:
2406   case CombiningKind::MAXUI:
2407   case CombiningKind::MAXSI:
2408   case CombiningKind::AND:
2409   case CombiningKind::OR:
2410   case CombiningKind::XOR:
2411     type = KindType::INT;
2412     break;
2413   case CombiningKind::ADD:
2414   case CombiningKind::MUL:
2415     type = isInt ? KindType::INT : KindType::FLOAT;
2416     break;
2417   }
2418   bool isValidIntKind = (type == KindType::INT) && isInt;
2419   bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
2420   return (isValidIntKind || isValidFloatKind);
2421 }
2422 
2423 /// This function constructs the appropriate integer or float
2424 /// operation given the vector combining kind and operands. The
2425 /// supported int operations are : add, mul, min (signed/unsigned),
2426 /// max(signed/unsigned), and, or, xor. The supported float
2427 /// operations are : add, mul, min and max.
2428 static Value genOperator(Location loc, Value x, Value y,
2429                          vector::CombiningKind kind,
2430                          PatternRewriter &rewriter) {
2431   using vector::CombiningKind;
2432 
2433   auto elType = x.getType().cast<VectorType>().getElementType();
2434   bool isInt = elType.isIntOrIndex();
2435 
2436   Value combinedResult{nullptr};
2437   switch (kind) {
2438   case CombiningKind::ADD:
2439     if (isInt)
2440       combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
2441     else
2442       combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
2443     break;
2444   case CombiningKind::MUL:
2445     if (isInt)
2446       combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
2447     else
2448       combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
2449     break;
2450   case CombiningKind::MINUI:
2451     combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
2452     break;
2453   case CombiningKind::MINSI:
2454     combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
2455     break;
2456   case CombiningKind::MAXUI:
2457     combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
2458     break;
2459   case CombiningKind::MAXSI:
2460     combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
2461     break;
2462   case CombiningKind::AND:
2463     combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
2464     break;
2465   case CombiningKind::OR:
2466     combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
2467     break;
2468   case CombiningKind::XOR:
2469     combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
2470     break;
2471   case CombiningKind::MINF:
2472     combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
2473     break;
2474   case CombiningKind::MAXF:
2475     combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
2476     break;
2477   }
2478   return combinedResult;
2479 }
2480 
2481 /// Convert vector.scan op into arith ops and
2482 /// vector.insert_strided_slice/extract_strided_slice
2483 ///
2484 /// Ex:
2485 /// ```
2486 ///   %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim =
2487 ///   1} :
2488 ///     (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
2489 /// ```
2490 /// Gets converted to:
2491 /// ```
2492 ///   %cst = arith.constant dense<0> : vector<2x3xi32>
2493 ///   %0 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 1],
2494 ///   strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %1 =
2495 ///   vector.insert_strided_slice %0, %cst {offsets = [0, 0], strides = [1, 1]}
2496 ///   : vector<2x1xi32> into vector<2x3xi32> %2 = vector.extract_strided_slice
2497 ///   %arg0 {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} :
2498 ///   vector<2x3xi32> to vector<2x1xi32> %3 = arith.muli %0, %2 :
2499 ///   vector<2x1xi32> %4 = vector.insert_strided_slice %3, %1 {offsets = [0, 1],
2500 ///   strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %5 =
2501 ///   vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1],
2502 ///   strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %6 = arith.muli %3,
2503 ///   %5 : vector<2x1xi32> %7 = vector.insert_strided_slice %6, %4 {offsets =
2504 ///   [0, 2], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %8 =
2505 ///   vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32> return %7, %8 :
2506 ///   vector<2x3xi32>, vector<2xi32>
2507 /// ```
2508 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
2509   using OpRewritePattern<vector::ScanOp>::OpRewritePattern;
2510 
2511   LogicalResult matchAndRewrite(vector::ScanOp scanOp,
2512                                 PatternRewriter &rewriter) const override {
2513     auto loc = scanOp.getLoc();
2514     VectorType destType = scanOp.getDestType();
2515     ArrayRef<int64_t> destShape = destType.getShape();
2516     auto elType = destType.getElementType();
2517     bool isInt = elType.isIntOrIndex();
2518     if (!isValidKind(isInt, scanOp.kind()))
2519       return failure();
2520 
2521     VectorType resType = VectorType::get(destShape, elType);
2522     Value result = rewriter.create<arith::ConstantOp>(
2523         loc, resType, rewriter.getZeroAttr(resType));
2524     int64_t reductionDim = scanOp.reduction_dim();
2525     bool inclusive = scanOp.inclusive();
2526     int64_t destRank = destType.getRank();
2527     VectorType initialValueType = scanOp.getInitialValueType();
2528     int64_t initialValueRank = initialValueType.getRank();
2529 
2530     SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
2531     reductionShape[reductionDim] = 1;
2532     VectorType reductionType = VectorType::get(reductionShape, elType);
2533     SmallVector<int64_t> offsets(destRank, 0);
2534     SmallVector<int64_t> strides(destRank, 1);
2535     SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
2536     sizes[reductionDim] = 1;
2537     ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
2538     ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
2539 
2540     Value lastOutput, lastInput;
2541     for (int i = 0; i < destShape[reductionDim]; i++) {
2542       offsets[reductionDim] = i;
2543       ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
2544       Value input = rewriter.create<vector::ExtractStridedSliceOp>(
2545           loc, reductionType, scanOp.source(), scanOffsets, scanSizes,
2546           scanStrides);
2547       Value output;
2548       if (i == 0) {
2549         if (inclusive) {
2550           output = input;
2551         } else {
2552           if (initialValueRank == 0) {
2553             // ShapeCastOp cannot handle 0-D vectors
2554             output = rewriter.create<vector::BroadcastOp>(
2555                 loc, input.getType(), scanOp.initial_value());
2556           } else {
2557             output = rewriter.create<vector::ShapeCastOp>(
2558                 loc, input.getType(), scanOp.initial_value());
2559           }
2560         }
2561       } else {
2562         Value y = inclusive ? input : lastInput;
2563         output = genOperator(loc, lastOutput, y, scanOp.kind(), rewriter);
2564         assert(output != nullptr);
2565       }
2566       result = rewriter.create<vector::InsertStridedSliceOp>(
2567           loc, output, result, offsets, strides);
2568       lastOutput = output;
2569       lastInput = input;
2570     }
2571 
2572     Value reduction;
2573     if (initialValueRank == 0) {
2574       Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
2575       reduction =
2576           rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
2577     } else {
2578       reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
2579                                                        lastOutput);
2580     }
2581 
2582     rewriter.replaceOp(scanOp, {result, reduction});
2583     return success();
2584   }
2585 };
2586 
2587 } // namespace
2588 
2589 void mlir::vector::populateVectorMaskMaterializationPatterns(
2590     RewritePatternSet &patterns, bool indexOptimizations) {
2591   patterns.add<VectorCreateMaskOpConversion,
2592                MaterializeTransferMask<vector::TransferReadOp>,
2593                MaterializeTransferMask<vector::TransferWriteOp>>(
2594       patterns.getContext(), indexOptimizations);
2595 }
2596 
2597 void mlir::vector::populateShapeCastFoldingPatterns(
2598     RewritePatternSet &patterns) {
2599   patterns.add<ShapeCastOpFolder>(patterns.getContext());
2600 }
2601 
2602 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
2603     RewritePatternSet &patterns) {
2604   patterns.add<BubbleDownVectorBitCastForExtract,
2605                BubbleDownBitCastForStridedSliceExtract,
2606                BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
2607 }
2608 
2609 void mlir::vector::populateVectorBroadcastLoweringPatterns(
2610     RewritePatternSet &patterns) {
2611   patterns.add<BroadcastOpLowering>(patterns.getContext());
2612 }
2613 
2614 void mlir::vector::populateVectorMaskOpLoweringPatterns(
2615     RewritePatternSet &patterns) {
2616   patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
2617       patterns.getContext());
2618 }
2619 
2620 void mlir::vector::populateVectorShapeCastLoweringPatterns(
2621     RewritePatternSet &patterns) {
2622   patterns.add<ShapeCastOp2DDownCastRewritePattern,
2623                ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
2624       patterns.getContext());
2625 }
2626 
2627 void mlir::vector::populateVectorContractLoweringPatterns(
2628     RewritePatternSet &patterns, VectorTransformsOptions options) {
2629   patterns.add<OuterProductOpLowering>(patterns.getContext());
2630   patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
2631                ContractionOpToOuterProductOpLowering>(options,
2632                                                       patterns.getContext());
2633 }
2634 
2635 void mlir::vector::populateVectorTransposeLoweringPatterns(
2636     RewritePatternSet &patterns, VectorTransformsOptions options) {
2637   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
2638       options, patterns.getContext());
2639 }
2640 
2641 void mlir::vector::populateVectorReductionToContractPatterns(
2642     RewritePatternSet &patterns) {
2643   patterns.add<MultiReduceToContract, CombineContractBroadcast,
2644                CombineContractTranspose, ReorderCastOpsOnBroadcast,
2645                ReorderCastOpsOnTranspose>(patterns.getContext());
2646 }
2647 
2648 void mlir::vector::
2649     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
2650         RewritePatternSet &patterns) {
2651   patterns.add<DropInnerMostUnitDims>(patterns.getContext());
2652 }
2653 
2654 void mlir::vector::populateVectorTransferLoweringPatterns(
2655     RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
2656   patterns.add<TransferReadToVectorLoadLowering,
2657                TransferWriteToVectorStoreLowering>(patterns.getContext(),
2658                                                    maxTransferRank);
2659   patterns
2660       .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
2661           patterns.getContext());
2662 }
2663 
2664 void mlir::vector::populateVectorScanLoweringPatterns(
2665     RewritePatternSet &patterns) {
2666   patterns.add<ScanToArithOps>(patterns.getContext());
2667 }
2668