1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
14 
15 #include <type_traits>
16 
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
19 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
20 #include "mlir/Dialect/Linalg/IR/Linalg.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/SCF/IR/SCF.h"
23 #include "mlir/Dialect/Utils/IndexingUtils.h"
24 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
26 #include "mlir/IR/BuiltinTypes.h"
27 #include "mlir/IR/ImplicitLocOpBuilder.h"
28 #include "mlir/IR/Matchers.h"
29 #include "mlir/IR/PatternMatch.h"
30 #include "mlir/Interfaces/VectorInterfaces.h"
31 
32 #include "llvm/ADT/DenseSet.h"
33 #include "llvm/ADT/MapVector.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/SmallBitVector.h"
36 #include "llvm/Support/CommandLine.h"
37 #include "llvm/Support/Debug.h"
38 #include "llvm/Support/raw_ostream.h"
39 
40 #define DEBUG_TYPE "vector-to-vector"
41 
42 using namespace mlir;
43 using namespace mlir::vector;
44 
// Helper to find the position at which a given dimension index appears among
// the results of an affine map.
46 static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
47   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
48     int64_t idx = map.getDimPosition(i);
49     if (idx == index)
50       return i;
51   }
52   return None;
53 }
54 
55 // Helper to construct iterator types with one index removed.
56 static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
57                                             int64_t index) {
58   SmallVector<Attribute, 4> results;
59   for (const auto &it : llvm::enumerate(iteratorTypes)) {
60     int64_t idx = it.index();
61     if (idx == index)
62       continue;
63     results.push_back(it.value());
64   }
65   return results;
66 }
67 
68 // Helper to construct an affine map with one index removed.
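// E.g., removing index 1 from the map (d0, d1, d2) -> (d2, d1) yields the
// map (d0, d1) -> (d1).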
69 static AffineMap adjustMap(AffineMap map, int64_t index,
70                            PatternRewriter &rewriter) {
71   auto *ctx = rewriter.getContext();
72   SmallVector<AffineExpr, 4> results;
73   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
74     int64_t idx = map.getDimPosition(i);
75     if (idx == index)
76       continue;
    // Re-insert the remaining indices, renumbering those that occur
    // after the removed index.
79     auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
80     results.push_back(targetExpr);
81   }
82   return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
83 }
84 
85 // Helper method to possibly drop a dimension in a load.
86 // TODO
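// E.g., extracting position `pos` along dimension `index` == 1 of a
// vector<4x3x2xf32> value unrolls the leading dimension and produces a
// vector<4x2xf32> value.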
87 static Value reshapeLoad(Location loc, Value val, VectorType type,
88                          int64_t index, int64_t pos,
89                          PatternRewriter &rewriter) {
90   if (index == -1)
91     return val;
92   Type lowType = VectorType::Builder(type).dropDim(0);
93   // At extraction dimension?
94   if (index == 0) {
95     auto posAttr = rewriter.getI64ArrayAttr(pos);
96     return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
97   }
98   // Unroll leading dimensions.
99   VectorType vType = lowType.cast<VectorType>();
100   Type resType = VectorType::Builder(type).dropDim(index);
101   auto resVectorType = resType.cast<VectorType>();
102   Value result = rewriter.create<arith::ConstantOp>(
103       loc, resVectorType, rewriter.getZeroAttr(resVectorType));
104   for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
105     auto posAttr = rewriter.getI64ArrayAttr(d);
106     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
107     Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
108     result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
109                                                posAttr);
110   }
111   return result;
112 }
113 
114 // Helper method to possibly drop a dimension in a store.
115 // TODO
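// E.g., inserting a vector<4x2xf32> value at position `pos` along dimension
// `index` == 1 of a vector<4x3x2xf32> destination unrolls the leading
// dimension and rebuilds the vector<4x3x2xf32> result slice by slice.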
116 static Value reshapeStore(Location loc, Value val, Value result,
117                           VectorType type, int64_t index, int64_t pos,
118                           PatternRewriter &rewriter) {
119   // Unmodified?
120   if (index == -1)
121     return val;
122   // At insertion dimension?
123   if (index == 0) {
124     auto posAttr = rewriter.getI64ArrayAttr(pos);
125     return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
126   }
127   // Unroll leading dimensions.
128   Type lowType = VectorType::Builder(type).dropDim(0);
129   VectorType vType = lowType.cast<VectorType>();
130   Type insType = VectorType::Builder(vType).dropDim(0);
131   for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
132     auto posAttr = rewriter.getI64ArrayAttr(d);
133     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
134     Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
135     Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
136     result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
137   }
138   return result;
139 }
140 
141 template <typename IntType>
142 static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
143   return llvm::to_vector<4>(llvm::map_range(
144       arrayAttr.getAsRange<IntegerAttr>(),
145       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
146 }
147 
148 /// Helper to create arithmetic operation associated with a kind of contraction.
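/// For integers this is an `arith.muli` followed by the reduction matching
/// `kind`; for floats with an ADD kind and a vector accumulator it emits a
/// single `vector.fma`, otherwise an `arith.mulf` followed by the reduction.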
149 static Optional<Value> createContractArithOp(Location loc, Value x, Value y,
150                                              Value acc,
151                                              vector::CombiningKind kind,
152                                              PatternRewriter &rewriter,
153                                              bool isInt) {
154   using vector::CombiningKind;
155   Value mul;
156   if (isInt) {
157     if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
158       // Only valid for floating point types.
159       return Optional<Value>();
160     mul = rewriter.create<arith::MulIOp>(loc, x, y);
161   } else {
162     // Float case.
163     if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
164         kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
165         kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
166         kind == CombiningKind::XOR)
167       // Only valid for integer types.
168       return Optional<Value>();
169     // Special case for fused multiply-add.
170     if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
171       return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
172     }
173     mul = rewriter.create<arith::MulFOp>(loc, x, y);
174   }
175   if (!acc)
176     return Optional<Value>(mul);
177   return makeArithReduction(rewriter, loc, kind, mul, acc);
178 }
179 
/// Return the positions of the reduction dimensions in the given map.
181 static SmallVector<int64_t> getReductionIndex(AffineMap map,
182                                               ArrayAttr iteratorTypes) {
183   SmallVector<int64_t> dimsIdx;
184   for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
185     if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
186       dimsIdx.push_back(i);
187   }
188   return dimsIdx;
189 }
190 
191 /// Look for a given dimension in an affine map and return its position. Return
192 /// llvm::None if the dimension is not in the map results.
193 static llvm::Optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
194   for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
195     if (map.getDimPosition(i) == dim)
196       return i;
197   }
198   return llvm::None;
199 }
200 
201 namespace {
202 
203 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
204 //
205 // Example:
206 //
207 //  The following MLIR with cancelling ShapeCastOps:
208 //
209 //   %0 = source : vector<5x4x2xf32>
210 //   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
211 //   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
212 //   %3 = user %2 : vector<5x4x2xf32>
213 //
214 //  Should canonicalize to the following:
215 //
216 //   %0 = source : vector<5x4x2xf32>
217 //   %1 = user %0 : vector<5x4x2xf32>
218 //
219 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
220   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
221 
222   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
223                                 PatternRewriter &rewriter) const override {
224     // Check if 'shapeCastOp' has vector source/result type.
225     auto sourceVectorType =
226         shapeCastOp.getSource().getType().dyn_cast_or_null<VectorType>();
227     auto resultVectorType =
228         shapeCastOp.getResult().getType().dyn_cast_or_null<VectorType>();
229     if (!sourceVectorType || !resultVectorType)
230       return failure();
231 
232     // Check if shape cast op source operand is also a shape cast op.
233     auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
234         shapeCastOp.getSource().getDefiningOp());
235     if (!sourceShapeCastOp)
236       return failure();
237     auto operandSourceVectorType =
238         sourceShapeCastOp.getSource().getType().cast<VectorType>();
239     auto operandResultVectorType = sourceShapeCastOp.getType();
240 
241     // Check if shape cast operations invert each other.
242     if (operandSourceVectorType != resultVectorType ||
243         operandResultVectorType != sourceVectorType)
244       return failure();
245 
246     rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
247     return success();
248   }
249 };
250 
251 /// Progressive lowering of BroadcastOp.
252 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
253 public:
254   using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
255 
256   LogicalResult matchAndRewrite(vector::BroadcastOp op,
257                                 PatternRewriter &rewriter) const override {
258     auto loc = op.getLoc();
259     VectorType dstType = op.getVectorType();
260     VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
261     Type eltType = dstType.getElementType();
262 
263     // Scalar to any vector can use splat.
264     if (!srcType) {
265       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
266       return success();
267     }
268 
269     // Determine rank of source and destination.
270     int64_t srcRank = srcType.getRank();
271     int64_t dstRank = dstType.getRank();
272 
273     // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
274     if (srcRank <= 1 && dstRank == 1) {
275       Value ext;
276       if (srcRank == 0)
277         ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
278       else
279         ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
280       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
281       return success();
282     }
283 
284     // Duplicate this rank.
285     // For example:
286     //   %x = broadcast %y  : k-D to n-D, k < n
287     // becomes:
288     //   %b = broadcast %y  : k-D to (n-1)-D
289     //   %x = [%b,%b,%b,%b] : n-D
290     // becomes:
291     //   %b = [%y,%y]       : (n-1)-D
292     //   %x = [%b,%b,%b,%b] : n-D
293     if (srcRank < dstRank) {
294       // Duplication.
295       VectorType resType =
296           VectorType::get(dstType.getShape().drop_front(), eltType);
297       Value bcst =
298           rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
299       Value result = rewriter.create<arith::ConstantOp>(
300           loc, dstType, rewriter.getZeroAttr(dstType));
301       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
302         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
303       rewriter.replaceOp(op, result);
304       return success();
305     }
306 
307     // Find non-matching dimension, if any.
308     assert(srcRank == dstRank);
309     int64_t m = -1;
310     for (int64_t r = 0; r < dstRank; r++)
311       if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
312         m = r;
313         break;
314       }
315 
316     // All trailing dimensions are the same. Simply pass through.
317     if (m == -1) {
318       rewriter.replaceOp(op, op.getSource());
319       return success();
320     }
321 
322     // Any non-matching dimension forces a stretch along this rank.
323     // For example:
324     //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
325     // becomes:
326     //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
327     //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
328     //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
329     //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
330     //   %x = [%a,%b,%c,%d]
331     // becomes:
332     //   %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
333     //   %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
334     //   %a = [%u, %v]
335     //   ..
336     //   %x = [%a,%b,%c,%d]
337     VectorType resType =
338         VectorType::get(dstType.getShape().drop_front(), eltType);
339     Value result = rewriter.create<arith::ConstantOp>(
340         loc, dstType, rewriter.getZeroAttr(dstType));
341     if (m == 0) {
      // Stretch at start.
343       Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
344       Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
345       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
346         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
347     } else {
      // Stretch not at start.
349       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
350         Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
351         Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
352         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
353       }
354     }
355     rewriter.replaceOp(op, result);
356     return success();
357   }
358 };
359 
360 /// Given a 'transpose' pattern, prune the rightmost dimensions that are not
361 /// transposed.
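/// E.g., the transposition [1, 0, 2, 3] is pruned to [1, 0] since its two
/// rightmost dimensions are not permuted.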
362 void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
363                             SmallVectorImpl<int64_t> &result) {
364   size_t numTransposedDims = transpose.size();
365   for (size_t transpDim : llvm::reverse(transpose)) {
366     if (transpDim != numTransposedDims - 1)
367       break;
368     numTransposedDims--;
369   }
370 
371   result.append(transpose.begin(), transpose.begin() + numTransposedDims);
372 }
373 
374 /// Progressive lowering of TransposeOp.
375 /// One:
376 ///   %x = vector.transpose %y, [1, 0]
377 /// is replaced by:
378 ///   %z = arith.constant dense<0.000000e+00>
379 ///   %0 = vector.extract %y[0, 0]
380 ///   %1 = vector.insert %0, %z [0, 0]
381 ///   ..
382 ///   %x = vector.insert .., .. [.., ..]
383 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
384 public:
385   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
386 
387   TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
388                       MLIRContext *context)
389       : OpRewritePattern<vector::TransposeOp>(context),
390         vectorTransformOptions(vectorTransformOptions) {}
391 
392   LogicalResult matchAndRewrite(vector::TransposeOp op,
393                                 PatternRewriter &rewriter) const override {
394     auto loc = op.getLoc();
395 
396     Value input = op.getVector();
397     VectorType inputType = op.getVectorType();
398     VectorType resType = op.getResultType();
399 
400     // Set up convenience transposition table.
401     SmallVector<int64_t, 4> transp;
402     for (auto attr : op.getTransp())
403       transp.push_back(attr.cast<IntegerAttr>().getInt());
404 
405     if (vectorTransformOptions.vectorTransposeLowering ==
406             vector::VectorTransposeLowering::Shuffle &&
407         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
408       return rewriter.notifyMatchFailure(
409           op, "Options specifies lowering to shuffle");
410 
411     // Handle a true 2-D matrix transpose differently when requested.
412     if (vectorTransformOptions.vectorTransposeLowering ==
413             vector::VectorTransposeLowering::Flat &&
414         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
415       Type flattenedType =
416           VectorType::get(resType.getNumElements(), resType.getElementType());
417       auto matrix =
418           rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
419       auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
420       auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
421       Value trans = rewriter.create<vector::FlatTransposeOp>(
422           loc, flattenedType, matrix, rows, columns);
423       rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
424       return success();
425     }
426 
427     // Generate unrolled extract/insert ops. We do not unroll the rightmost
428     // (i.e., highest-order) dimensions that are not transposed and leave them
429     // in vector form to improve performance. Therefore, we prune those
430     // dimensions from the shape/transpose data structures used to generate the
431     // extract/insert ops.
432     SmallVector<int64_t, 4> prunedTransp;
433     pruneNonTransposedDims(transp, prunedTransp);
434     size_t numPrunedDims = transp.size() - prunedTransp.size();
435     auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
436     SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
437     auto prunedInStrides = computeStrides(prunedInShape, ones);
438 
439     // Generates the extract/insert operations for every scalar/vector element
440     // of the leftmost transposed dimensions. We traverse every transpose
441     // element using a linearized index that we delinearize to generate the
442     // appropriate indices for the extract/insert operations.
443     Value result = rewriter.create<arith::ConstantOp>(
444         loc, resType, rewriter.getZeroAttr(resType));
445     int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
446 
447     for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
448          ++linearIdx) {
449       auto extractIdxs = delinearize(prunedInStrides, linearIdx);
450       SmallVector<int64_t, 4> insertIdxs(extractIdxs);
451       applyPermutationToVector(insertIdxs, prunedTransp);
452       Value extractOp =
453           rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
454       result =
455           rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
456     }
457 
458     rewriter.replaceOp(op, result);
459     return success();
460   }
461 
462 private:
463   /// Options to control the vector patterns.
464   vector::VectorTransformsOptions vectorTransformOptions;
465 };
466 
467 /// Rewrite a 2-D vector.transpose as a sequence of:
468 ///   vector.shape_cast 2D -> 1D
469 ///   vector.shuffle
470 ///   vector.shape_cast 1D -> 2D
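/// E.g., a vector<2x3xf32> transpose becomes a shape_cast to vector<6xf32>,
/// a shuffle with mask [0, 3, 1, 4, 2, 5], and a shape_cast back to
/// vector<3x2xf32>.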
471 class TransposeOp2DToShuffleLowering
472     : public OpRewritePattern<vector::TransposeOp> {
473 public:
474   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
475 
476   TransposeOp2DToShuffleLowering(
477       vector::VectorTransformsOptions vectorTransformOptions,
478       MLIRContext *context)
479       : OpRewritePattern<vector::TransposeOp>(context),
480         vectorTransformOptions(vectorTransformOptions) {}
481 
482   LogicalResult matchAndRewrite(vector::TransposeOp op,
483                                 PatternRewriter &rewriter) const override {
484     auto loc = op.getLoc();
485 
486     VectorType srcType = op.getVectorType();
487     if (srcType.getRank() != 2)
488       return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
489 
490     SmallVector<int64_t, 4> transp;
491     for (auto attr : op.getTransp())
492       transp.push_back(attr.cast<IntegerAttr>().getInt());
    if (transp[0] != 1 || transp[1] != 0)
494       return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
495 
496     if (vectorTransformOptions.vectorTransposeLowering !=
497         VectorTransposeLowering::Shuffle)
498       return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
499 
500     int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
501     Value casted = rewriter.create<vector::ShapeCastOp>(
502         loc, VectorType::get({m * n}, srcType.getElementType()),
503         op.getVector());
504     SmallVector<int64_t> mask;
505     mask.reserve(m * n);
506     for (int64_t j = 0; j < n; ++j)
507       for (int64_t i = 0; i < m; ++i)
508         mask.push_back(i * n + j);
509 
510     Value shuffled =
511         rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
512     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
513                                                      shuffled);
514 
515     return success();
516   }
517 
518 private:
519   /// Options to control the vector patterns.
520   vector::VectorTransformsOptions vectorTransformOptions;
521 };
522 
523 /// Progressive lowering of OuterProductOp.
524 /// One:
525 ///   %x = vector.outerproduct %lhs, %rhs, %acc
526 /// is replaced by:
527 ///   %z = zero-result
528 ///   %0 = vector.extract %lhs[0]
529 ///   %1 = vector.broadcast %0
530 ///   %2 = vector.extract %acc[0]
531 ///   %3 = vector.fma %1, %rhs, %2
532 ///   %4 = vector.insert %3, %z[0]
533 ///   ..
534 ///   %x = vector.insert %.., %..[N-1]
535 ///
536 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
537 public:
538   using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
539 
540   LogicalResult matchAndRewrite(vector::OuterProductOp op,
541                                 PatternRewriter &rewriter) const override {
542     auto loc = op.getLoc();
543 
544     VectorType lhsType = op.getOperandVectorTypeLHS();
545     VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
546     VectorType resType = op.getVectorType();
547     Type eltType = resType.getElementType();
548     bool isInt = eltType.isa<IntegerType, IndexType>();
549     Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
550     vector::CombiningKind kind = op.getKind();
551 
552     if (!rhsType) {
553       // Special case: AXPY operation.
554       Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
555       Optional<Value> mult = createContractArithOp(loc, op.getLhs(), b, acc,
556                                                    kind, rewriter, isInt);
557       if (!mult.hasValue())
558         return failure();
559       rewriter.replaceOp(op, mult.getValue());
560       return success();
561     }
562 
563     Value result = rewriter.create<arith::ConstantOp>(
564         loc, resType, rewriter.getZeroAttr(resType));
565     for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
566       auto pos = rewriter.getI64ArrayAttr(d);
567       Value x =
568           rewriter.create<vector::ExtractOp>(loc, eltType, op.getLhs(), pos);
569       Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
570       Value r = nullptr;
571       if (acc)
572         r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
573       Optional<Value> m =
574           createContractArithOp(loc, a, op.getRhs(), r, kind, rewriter, isInt);
575       if (!m.hasValue())
576         return failure();
577       result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
578                                                  result, pos);
579     }
580     rewriter.replaceOp(op, result);
581     return success();
582   }
583 };
584 
585 /// Lower vector.contract with all size one reduction dimensions to
586 /// elementwise ops when possible.
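/// E.g., a contraction with a single unit-size reduction dimension such as
///   %r = vector.contract {kind = #vector.kind<add>, ...} %lhs, %rhs, %acc
///     : vector<4x1xf32>, vector<1xf32> into vector<4xf32>
/// lowers (roughly) to broadcasts/transposes/extracts of the operands
/// followed by a single elementwise `vector.fma` on vector<4xf32>.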
587 struct ContractOpToElementwise
588     : public OpRewritePattern<vector::ContractionOp> {
589   using OpRewritePattern::OpRewritePattern;
590   using FilterConstraintType =
591       std::function<LogicalResult(vector::ContractionOp op)>;
592   static LogicalResult defaultFilter(vector::ContractionOp op) {
593     return success();
594   }
595   ContractOpToElementwise(
596       vector::VectorTransformsOptions vectorTransformOptions,
597       MLIRContext *context,
598       const FilterConstraintType &constraint = defaultFilter)
599       : OpRewritePattern<vector::ContractionOp>(context),
        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
601 
602   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
603                                 PatternRewriter &rewriter) const override {
604     // TODO: implement masks
605     if (llvm::size(contractOp.getMasks()) != 0)
606       return failure();
607 
608     if (failed(filter(contractOp)))
609       return failure();
610 
611     if (vectorTransformOptions.vectorContractLowering !=
612         vector::VectorContractLowering::ParallelArith)
613       return failure();
614     ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
615     ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
616     AffineMap lhsMap = contractOp.getIndexingMaps()[0];
617     AffineMap rhsMap = contractOp.getIndexingMaps()[1];
618     SmallVector<int64_t> lhsReductionDims =
619         getReductionIndex(lhsMap, contractOp.getIteratorTypes());
620     SmallVector<int64_t> rhsReductionDims =
621         getReductionIndex(rhsMap, contractOp.getIteratorTypes());
    // All the reduction dimensions must have size 1.
623     for (int64_t dim : lhsReductionDims) {
624       if (lhsShape[dim] != 1)
625         return failure();
626     }
627     for (int64_t dim : rhsReductionDims) {
628       if (rhsShape[dim] != 1)
629         return failure();
630     }
631     AffineMap accMap = contractOp.getIndexingMaps()[2];
632     unsigned numParallelDims = accMap.getNumResults();
633     unsigned numLhsDimToBroadcast =
634         numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
635     unsigned numRhsDimToBroadcast =
636         numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
637     SmallVector<int64_t> lhsDims;
638     SmallVector<int64_t> lhsTranspose;
639     SmallVector<int64_t> rhsDims;
640     SmallVector<int64_t> rhsTranspose;
641     for (int64_t dim : lhsReductionDims)
642       lhsTranspose.push_back(numLhsDimToBroadcast + dim);
643     for (int64_t dim : rhsReductionDims)
644       rhsTranspose.push_back(numRhsDimToBroadcast + dim);
645     // Loop through the parallel dimensions to calculate the dimensions to
646     // broadcast and to permute in order to extract only parallel dimensions.
647     for (unsigned i = 0; i < numParallelDims; i++) {
648       llvm::Optional<unsigned> lhsDim =
649           getDimPosition(lhsMap, accMap.getDimPosition(i));
650       if (lhsDim) {
651         lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
652       } else {
653         // If the parallel dimension doesn't exist we will have to broadcast it.
654         lhsDims.push_back(
655             contractOp.getResultType().cast<VectorType>().getDimSize(i));
656         lhsTranspose.push_back(lhsDims.size() - 1);
657       }
658       llvm::Optional<unsigned> rhsDim =
659           getDimPosition(rhsMap, accMap.getDimPosition(i));
660       if (rhsDim) {
661         rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
662       } else {
663         // If the parallel dimension doesn't exist we will have to broadcast it.
664         rhsDims.push_back(
665             contractOp.getResultType().cast<VectorType>().getDimSize(i));
666         rhsTranspose.push_back(rhsDims.size() - 1);
667       }
668     }
669     Value newLhs = contractOp.getLhs();
670     Value newRhs = contractOp.getRhs();
671     Location loc = contractOp.getLoc();
672     if (!lhsDims.empty()) {
673       lhsDims.append(lhsShape.begin(), lhsShape.end());
674       auto expandedType =
675           VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
676       newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
677     }
678     if (!rhsDims.empty()) {
679       rhsDims.append(rhsShape.begin(), rhsShape.end());
680       auto expandedType =
681           VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
682       newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
683     }
684     bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
685     newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
686     newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
687     SmallVector<int64_t, 4> lhsOffsets(lhsReductionDims.size(), 0);
688     SmallVector<int64_t, 4> rhsOffsets(rhsReductionDims.size(), 0);
689     newLhs = rewriter.create<vector::ExtractOp>(
690         loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
691     newRhs = rewriter.create<vector::ExtractOp>(
692         loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets));
693     Optional<Value> result =
694         createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
695                               contractOp.getKind(), rewriter, isInt);
696     rewriter.replaceOp(contractOp, {*result});
697     return success();
698   }
699 
700 private:
701   /// Options to control the vector patterns.
702   vector::VectorTransformsOptions vectorTransformOptions;
703   FilterConstraintType filter;
704 };
705 
706 /// Progressive lowering of ConstantMaskOp.
707 /// One:
708 ///   %x = vector.constant_mask [a,b]
709 /// is replaced by:
710 ///   %z = zero-result
711 ///   %l = vector.constant_mask [b]
712 ///   %4 = vector.insert %l, %z[0]
713 ///   ..
714 ///   %x = vector.insert %l, %..[a-1]
715 /// until a one-dimensional vector is reached. All these operations
716 /// will be folded at LLVM IR level.
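/// E.g., %x = vector.constant_mask [2, 2] : vector<3x3xi1> first creates
/// %l = vector.constant_mask [2] : vector<3xi1> (i.e. [1, 1, 0]) and then
/// inserts it at positions 0 and 1 of an all-false vector<3x3xi1>.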
717 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
718 public:
719   using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
720 
721   LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
722                                 PatternRewriter &rewriter) const override {
723     auto loc = op.getLoc();
724     auto dstType = op.getType();
725     auto eltType = dstType.getElementType();
726     auto dimSizes = op.getMaskDimSizes();
727     int64_t rank = dstType.getRank();
728 
729     if (rank == 0) {
730       assert(dimSizes.size() == 1 &&
731              "Expected exactly one dim size for a 0-D vector");
732       bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
733       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
734           op, dstType,
735           DenseIntElementsAttr::get(
736               VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
737               ArrayRef<bool>{value}));
738       return success();
739     }
740 
741     // Scalable constant masks can only be lowered for the "none set" case.
742     if (dstType.cast<VectorType>().isScalable()) {
743       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
744           op, DenseElementsAttr::get(dstType, false));
745       return success();
746     }
747 
748     int64_t trueDim = std::min(dstType.getDimSize(0),
749                                dimSizes[0].cast<IntegerAttr>().getInt());
750 
751     if (rank == 1) {
752       // Express constant 1-D case in explicit vector form:
753       //   [T,..,T,F,..,F].
754       SmallVector<bool, 4> values(dstType.getDimSize(0));
755       for (int64_t d = 0; d < trueDim; d++)
756         values[d] = true;
757       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
758           op, dstType, rewriter.getBoolVectorAttr(values));
759       return success();
760     }
761 
762     VectorType lowType =
763         VectorType::get(dstType.getShape().drop_front(), eltType);
764     SmallVector<int64_t, 4> newDimSizes;
765     for (int64_t r = 1; r < rank; r++)
766       newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
767     Value trueVal = rewriter.create<vector::ConstantMaskOp>(
768         loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
769     Value result = rewriter.create<arith::ConstantOp>(
770         loc, dstType, rewriter.getZeroAttr(dstType));
771     for (int64_t d = 0; d < trueDim; d++) {
772       auto pos = rewriter.getI64ArrayAttr(d);
773       result =
774           rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
775     }
776     rewriter.replaceOp(op, result);
777     return success();
778   }
779 };
780 
781 /// Progressive lowering of CreateMaskOp.
782 /// One:
783 ///   %x = vector.create_mask %a, ... : vector<dx...>
784 /// is replaced by:
785 ///   %l = vector.create_mask ... : vector<...>  ; one lower rank
///   %0 = arith.cmpi "slt", %ci, %a       |
///   %1 = arith.select %0, %l, %zeroes    | d-times, once per leading index i
///   %r = vector.insert %1, %pr [i]       |
///   %x = ....
790 /// until a one-dimensional vector is reached.
791 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
792 public:
793   using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
794 
795   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
796                                 PatternRewriter &rewriter) const override {
797     auto dstType = op.getResult().getType().cast<VectorType>();
798     int64_t rank = dstType.getRank();
799     if (rank <= 1)
800       return rewriter.notifyMatchFailure(
801           op, "0-D and 1-D vectors are handled separately");
802 
803     auto loc = op.getLoc();
804     auto eltType = dstType.getElementType();
805     int64_t dim = dstType.getDimSize(0);
806     Value idx = op.getOperand(0);
807 
808     VectorType lowType =
809         VectorType::get(dstType.getShape().drop_front(), eltType);
810     Value trueVal = rewriter.create<vector::CreateMaskOp>(
811         loc, lowType, op.getOperands().drop_front());
812     Value falseVal = rewriter.create<arith::ConstantOp>(
813         loc, lowType, rewriter.getZeroAttr(lowType));
814     Value result = rewriter.create<arith::ConstantOp>(
815         loc, dstType, rewriter.getZeroAttr(dstType));
816     for (int64_t d = 0; d < dim; d++) {
817       Value bnd =
818           rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
819       Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
820                                                  bnd, idx);
821       Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
822       auto pos = rewriter.getI64ArrayAttr(d);
823       result =
824           rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
825     }
826     rewriter.replaceOp(op, result);
827     return success();
828   }
829 };
830 
/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
832 /// vectors progressively on the way to target llvm.matrix intrinsics.
833 /// This iterates over the most major dimension of the 2-D vector and performs
834 /// rewrites into:
835 ///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
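/// E.g., a vector<2x3xf32> -> vector<6xf32> cast extracts the two rows and
/// inserts them at offsets 0 and 3 of the flat result.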
836 class ShapeCastOp2DDownCastRewritePattern
837     : public OpRewritePattern<vector::ShapeCastOp> {
838 public:
839   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
840 
841   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
842                                 PatternRewriter &rewriter) const override {
843     auto sourceVectorType = op.getSourceVectorType();
844     auto resultVectorType = op.getResultVectorType();
845     if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
846       return failure();
847 
848     auto loc = op.getLoc();
849     Value desc = rewriter.create<arith::ConstantOp>(
850         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
851     unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
852     for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
853       Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
854       desc = rewriter.create<vector::InsertStridedSliceOp>(
855           loc, vec, desc,
856           /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
857     }
858     rewriter.replaceOp(op, desc);
859     return success();
860   }
861 };
862 
/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
864 /// vectors progressively.
865 /// This iterates over the most major dimension of the 2-D vector and performs
866 /// rewrites into:
867 ///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
/// Note that 1-D extract_strided_slice ops are lowered to efficient
/// vector.shuffle ops.
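/// E.g., a vector<6xf32> -> vector<2x3xf32> cast extracts slices of size 3 at
/// offsets 0 and 3 and inserts them as the two rows of the result.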
869 class ShapeCastOp2DUpCastRewritePattern
870     : public OpRewritePattern<vector::ShapeCastOp> {
871 public:
872   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
873 
874   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
875                                 PatternRewriter &rewriter) const override {
876     auto sourceVectorType = op.getSourceVectorType();
877     auto resultVectorType = op.getResultVectorType();
878     if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
879       return failure();
880 
881     auto loc = op.getLoc();
882     Value desc = rewriter.create<arith::ConstantOp>(
883         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
884     unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
885     for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
886       Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
887           loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
888           /*sizes=*/mostMinorVectorSize,
889           /*strides=*/1);
890       desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
891     }
892     rewriter.replaceOp(op, desc);
893     return success();
894   }
895 };
896 
897 // We typically should not lower general shape cast operations into data
898 // movement instructions, since the assumption is that these casts are
899 // optimized away during progressive lowering. For completeness, however,
900 // we fall back to a reference implementation that moves all elements
901 // into the right place if we get here.
902 class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
903 public:
904   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
905 
906   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
907                                 PatternRewriter &rewriter) const override {
908     Location loc = op.getLoc();
909     auto sourceVectorType = op.getSourceVectorType();
910     auto resultVectorType = op.getResultVectorType();
911 
912     // Special case 2D/1D lowerings with better implementations.
    // TODO: make it ND/1D to allow generic ND->1D->MD.
914     int64_t srcRank = sourceVectorType.getRank();
915     int64_t resRank = resultVectorType.getRank();
916     if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
917       return failure();
918 
919     // Generic ShapeCast lowering path goes all the way down to unrolled scalar
920     // extract/insert chains.
921     // TODO: consider evolving the semantics to only allow 1D source or dest and
922     // drop this potentially very expensive lowering.
923     // Compute number of elements involved in the reshape.
924     int64_t numElts = 1;
925     for (int64_t r = 0; r < srcRank; r++)
926       numElts *= sourceVectorType.getDimSize(r);
927     // Replace with data movement operations:
928     //    x[0,0,0] = y[0,0]
929     //    x[0,0,1] = y[0,1]
930     //    x[0,1,0] = y[0,2]
931     // etc., incrementing the two index vectors "row-major"
932     // within the source and result shape.
933     SmallVector<int64_t, 4> srcIdx(srcRank);
934     SmallVector<int64_t, 4> resIdx(resRank);
935     Value result = rewriter.create<arith::ConstantOp>(
936         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
937     for (int64_t i = 0; i < numElts; i++) {
938       if (i != 0) {
939         incIdx(srcIdx, sourceVectorType, srcRank - 1);
940         incIdx(resIdx, resultVectorType, resRank - 1);
941       }
942       Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
943       result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
944     }
945     rewriter.replaceOp(op, result);
946     return success();
947   }
948 
949 private:
950   static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
951     assert(0 <= r && r < tp.getRank());
952     if (++idx[r] == tp.getDimSize(r)) {
953       idx[r] = 0;
954       incIdx(idx, tp, r - 1);
955     }
956   }
957 };
958 
959 /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
960 /// Ex:
961 /// ```
962 ///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
///   %1 = vector.multi_reduction add, %0 [2]
///     : vector<8x32x16xf32> to vector<8x32xf32>
965 /// ```
966 /// Gets converted to:
967 /// ```
968 ///   %1 = vector.contract {indexing_maps = [
969 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
970 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
971 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
972 ///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
974 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
975 ///  ```
976 struct MultiReduceToContract
977     : public OpRewritePattern<vector::MultiDimReductionOp> {
978   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
979 
980   LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
981                                 PatternRewriter &rewriter) const override {
982     if (reduceOp.getKind() != vector::CombiningKind::ADD)
983       return failure();
984     Operation *mulOp = reduceOp.getSource().getDefiningOp();
985     if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
986       return failure();
987     SmallVector<bool> reductionMask = reduceOp.getReductionMask();
988     auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
989     SmallVector<AffineExpr> exprs;
990     SmallVector<StringRef> iteratorTypes;
991     for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
992       if (!isReduceDim.value()) {
993         iteratorTypes.push_back(getParallelIteratorTypeName());
994         exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
995       } else {
996         iteratorTypes.push_back(getReductionIteratorTypeName());
997       }
998     }
999     auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
1000                                  /*symCount=*/0, exprs, reduceOp.getContext());
1001     Value zero = rewriter.create<arith::ConstantOp>(
1002         reduceOp.getLoc(), reduceOp.getDestType(),
1003         rewriter.getZeroAttr(reduceOp.getDestType()));
1004     rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
1005         reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
1006         rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
1007         rewriter.getStrArrayAttr(iteratorTypes));
1008     return success();
1009   }
1010 };
1011 
1012 /// Merge TransposeOp into ContractionOp user.
1013 /// Ex:
1014 /// ```
1015 ///   %0 = vector.transpose %arg0, [2, 0, 1]
1016 ///     : vector<32x16x8xf32> to vector<8x32x16xf32>
1017 ///   %1 = vector.contract {indexing_maps = [
1018 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1019 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1020 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
1021 ///    iterator_types = ["parallel", "parallel", "reduction"],
1022 ///    kind = add} %0, %arg1, %cst_f0
1023 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1024 /// ```
1025 /// Gets converted to:
1026 /// ```
1027 ///   %1 = vector.contract {indexing_maps = [
1028 ///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
1029 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1030 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
1031 ///    iterator_types = ["parallel", "parallel", "reduction"],
1032 ///    kind = add} %arg0, %arg1, %cst_f0
1033 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1034 ///  ```
1035 struct CombineContractTranspose
1036     : public OpRewritePattern<vector::ContractionOp> {
1037   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
1038 
1039   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
1040                                 PatternRewriter &rewriter) const override {
1041     SmallVector<AffineMap, 4> maps =
1042         llvm::to_vector<4>(contractOp.getIndexingMaps());
1043     Value lhs = contractOp.getLhs();
1044     Value rhs = contractOp.getRhs();
1045     size_t index = 0;
1046     bool changed = false;
1047     for (Value *operand : {&lhs, &rhs}) {
1048       AffineMap &map = maps[index++];
1049       auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
1050       if (!transposeOp)
1051         continue;
      AffineMap permutationMap = AffineMap::getPermutationMap(
          extractVector<unsigned>(transposeOp.getTransp()),
          contractOp.getContext());
1057       map = inversePermutation(permutationMap).compose(map);
1058       *operand = transposeOp.getVector();
1059       changed = true;
1060     }
1061     if (!changed)
1062       return failure();
1063     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1064         contractOp, lhs, rhs, contractOp.getAcc(),
1065         rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
1066     return success();
1067   }
1068 };
1069 
1070 /// Merge BroadcastOp into ContractionOp user.
1071 /// Ex:
1072 /// ```
1073 ///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
1074 ///   %1 = vector.contract {indexing_maps = [
1075 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1076 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1077 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
1078 ///    iterator_types = ["parallel", "parallel", "reduction"],
1079 ///    kind = add} %0, %arg1, %cst_f0
1080 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1081 /// ```
1082 /// Gets converted to:
1083 /// ```
1084 ///   %1 = vector.contract {indexing_maps = [
1085 ///         affine_map<(d0, d1, d2) -> (d1, d2)>,
1086 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
1087 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
1088 ///    iterator_types = ["parallel", "parallel", "reduction"],
1089 ///    kind = add} %arg0, %arg1, %cst_f0
1090 ///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
1091 ///  ```
1092 struct CombineContractBroadcast
1093     : public OpRewritePattern<vector::ContractionOp> {
1094   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
1095 
1096   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
1097                                 PatternRewriter &rewriter) const override {
1098     SmallVector<AffineMap, 4> maps =
1099         llvm::to_vector<4>(contractOp.getIndexingMaps());
1100     Value lhs = contractOp.getLhs();
1101     Value rhs = contractOp.getRhs();
1102     size_t index = 0;
1103     bool changed = false;
1104     for (Value *operand : {&lhs, &rhs}) {
1105       AffineMap &map = maps[index++];
1106       auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
1107       if (!broadcast)
1108         continue;
      // ContractionOp can only take vectors as operands.
1110       auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
1111       if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
1112         continue;
1113       int64_t rankDiff =
1114           broadcast.getVectorType().getRank() - srcType.getRank();
1115       bool innerDimBroadcast = false;
1116       SmallVector<AffineExpr> originalDims;
1117       for (const auto &dim : llvm::enumerate(srcType.getShape())) {
1118         if (dim.value() !=
1119             broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
1120           innerDimBroadcast = true;
1121           break;
1122         }
1123         originalDims.push_back(
1124             rewriter.getAffineDimExpr(dim.index() + rankDiff));
1125       }
1126       // Contract doesn't support inner dimension broadcast. Once this is
1127       // relaxed we can remove this case.
1128       if (innerDimBroadcast)
1129         continue;
1130 
1131       // It would be incorrect to fold a broadcast onto a reduction dimension
1132       // of non-unit size.
1133       bool nonUnitDimReductionBroadcast = false;
1134       for (int64_t i = 0; i < rankDiff; ++i) {
1135         if (broadcast.getVectorType().getDimSize(i) != 1 &&
1136             isReductionIterator(contractOp.getIteratorTypes()
1137                                     .getValue()[map.getDimPosition(i)])) {
1138           nonUnitDimReductionBroadcast = true;
1139           break;
1140         }
1141       }
1142       if (nonUnitDimReductionBroadcast)
1143         continue;
1144 
1145       AffineMap broadcastMap =
1146           AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
1147                          contractOp.getContext());
1148       map = broadcastMap.compose(map);
1149       *operand = broadcast.getSource();
1150       changed = true;
1151     }
1152 
1153     if (!changed)
1154       return failure();
1155 
    // Determine which dims are unused, now that the maps have been composed
1157     // with the broadcast maps.
1158     unsigned numDims = maps[0].getNumDims();
1159     llvm::SmallBitVector unusedDims(numDims, true);
1160     for (const auto &m : maps) {
1161       for (unsigned i = 0; i < numDims; ++i) {
1162         if (m.isFunctionOfDim(i))
1163           unusedDims.reset(i);
1164       }
1165     }
1166     // Compress unused dims.
1167     for (auto &m : maps)
1168       m = compressDims(m, unusedDims);
1169     // Compute the combined iterators.
1170     SmallVector<Attribute, 4> iterators;
1171     for (unsigned i = 0; i < numDims; ++i) {
1172       if (!unusedDims.test(i))
1173         iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
1174     }
1175     // Check that compressing unused dims isn't removing all reduction
1176     // iterators. For example, if the vector.contract had only one reduction
1177     // iterator and that was a unit-dimension created by a broadcast,
1178     // then we should bail here, otherwise we would create a contract without
1179     // a reduction iterator.
1180     if (!llvm::any_of(iterators, isReductionIterator))
1181       return failure();
1182 
1183     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1184         contractOp, lhs, rhs, contractOp.getAcc(),
1185         rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
1186     return success();
1187   }
1188 };
1189 
/// Reorders cast(broadcast) to broadcast(cast). This brings broadcast ops
/// closer to contraction ops, which lets the CombineContractBroadcast pattern
/// kick in when cast ops sit between these operations.
1193 /// Ex:
1194 /// ```
1195 ///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
1196 ///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
1197 /// ```
1198 /// Gets converted to:
1199 /// ```
///   %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
///   %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
1202 /// ```
1203 struct ReorderCastOpsOnBroadcast
1204     : public OpInterfaceRewritePattern<CastOpInterface> {
1205   using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
1206 
1207   LogicalResult matchAndRewrite(CastOpInterface op,
1208                                 PatternRewriter &rewriter) const override {
1209     if (op->getNumOperands() != 1)
1210       return failure();
1211     auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
1212     if (!bcastOp)
1213       return failure();
1214 
1215     Type castResTy = getElementTypeOrSelf(op->getResult(0));
1216     if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
1217       castResTy = VectorType::get(vecTy.getShape(), castResTy);
1218     auto *castOp =
1219         rewriter.create(op->getLoc(), op->getName().getIdentifier(),
1220                         bcastOp.getSource(), castResTy, op->getAttrs());
1221     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1222         op, op->getResult(0).getType(), castOp->getResult(0));
1223     return success();
1224   }
1225 };
1226 
/// Reorders elementwise(transpose) to transpose(elementwise). This brings
/// transpose ops closer to contraction ops, which lets the
/// CombineContractTranspose pattern kick in when elementwise ops sit between
/// these operations. Ex:
1231 /// ```
1232 /// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
1233 /// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
1234 /// %r = arith.addf %at, %bt : vector<2x4xf32>
1235 /// ```
1236 /// Gets converted to:
1237 /// ```
1238 /// %0 = arith.addf %a, %b : vector<4x2xf32>
/// %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
1240 /// ```
1241 struct ReorderElementwiseOpsOnTranspose final
1242     : public OpTraitRewritePattern<OpTrait::Elementwise> {
1243   using OpTraitRewritePattern::OpTraitRewritePattern;
1244   LogicalResult matchAndRewrite(Operation *op,
1245                                 PatternRewriter &rewriter) const override {
1246     if (op->getNumResults() != 1 || op->getNumRegions() != 0)
1247       return failure();
1248 
1249     // Make sure all operands are transpose/constant ops and collect their
1250     // transposition maps.
1251     SmallVector<ArrayAttr, 4> transposeMaps;
1252     transposeMaps.reserve(op->getNumOperands());
1253     // Record the initial type before transposition. We'll use its shape later.
1254     // Any type will do here as we will check all transpose maps are the same.
1255     VectorType srcType;
1256     for (Value operand : op->getOperands()) {
1257       auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
1258       if (transposeOp) {
1259         transposeMaps.push_back(transposeOp.getTransp());
1260         srcType = transposeOp.getVectorType();
1261       } else if (!matchPattern(operand, m_Constant())) {
1262         return failure();
1263       }
1264     }
1265     if (transposeMaps.empty())
1266       return failure();
1267     // This is an elementwise op, so all transposed operands should have the
    // same type. We need to additionally check that all transposes use the
1269     // same map.
1270     if (!llvm::is_splat(transposeMaps))
1271       return rewriter.notifyMatchFailure(op, "different transpose map");
1272 
1273     SmallVector<Value, 4> srcValues;
1274     srcValues.reserve(op->getNumOperands());
1275 
1276     // If there are constant operands, we need to insert inverse transposes for
1277     // them. Calculate the inverse order first.
1278     auto order = extractVector<unsigned>(transposeMaps.front());
1279     SmallVector<int64_t> invOrder(order.size());
1280     for (int i = 0, e = order.size(); i < e; ++i)
1281       invOrder[order[i]] = i;
1282 
1283     for (Value operand : op->getOperands()) {
1284       auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
1285       if (transposeOp) {
1286         srcValues.push_back(transposeOp.getVector());
1287       } else {
1288         // This is a constant. Create a reverse transpose op for it.
1289         auto vectorType = VectorType::get(
1290             srcType.getShape(),
1291             operand.getType().cast<VectorType>().getElementType());
1292         srcValues.push_back(rewriter.create<vector::TransposeOp>(
1293             operand.getLoc(), vectorType, operand,
1294             rewriter.getI64ArrayAttr(invOrder)));
1295       }
1296     }
1297 
1298     auto vectorType = VectorType::get(
1299         srcType.getShape(),
1300         op->getResultTypes()[0].cast<VectorType>().getElementType());
1301     Operation *elementwiseOp =
1302         rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
1303                         vectorType, op->getAttrs());
1304     rewriter.replaceOpWithNewOp<vector::TransposeOp>(
1305         op, op->getResultTypes()[0], elementwiseOp->getResult(0),
1306         transposeMaps.front());
1307     return success();
1308   }
1309 };
1310 
1311 } // namespace
1312 
1313 /// Creates an arith::AddIOp if `isInt` is true, otherwise an arith::AddFOp,
1314 /// using operands `x` and `y`.
1315 static Value createAdd(Location loc, Value x, Value y, bool isInt,
1316                        PatternRewriter &rewriter) {
1317   if (isInt)
1318     return rewriter.create<arith::AddIOp>(loc, x, y);
1319   return rewriter.create<arith::AddFOp>(loc, x, y);
1320 }
1321 
1322 /// Creates an arith::MulIOp if `isInt` is true, otherwise an arith::MulFOp,
1323 /// using operands `x` and `y`.
1324 static Value createMul(Location loc, Value x, Value y, bool isInt,
1325                        PatternRewriter &rewriter) {
1326   if (isInt)
1327     return rewriter.create<arith::MulIOp>(loc, x, y);
1328   return rewriter.create<arith::MulFOp>(loc, x, y);
1329 }
1330 
1331 namespace mlir {
1332 
1333 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1334 /// semantics to:
1335 /// ```
1336 ///    %mta = maybe_transpose
1337 ///    %mtb = maybe_transpose
1338 ///    %flattened_a = vector.shape_cast %mta
1339 ///    %flattened_b = vector.shape_cast %mtb
1340 ///    %flattened_d = vector.matmul %flattened_a, %flattened_b
1341 ///    %mtd = vector.shape_cast %flattened_d
1342 ///    %d = maybe_untranspose %mtd
1343 ///    %e = add %c, %d
1344 /// ```
1345 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1346 ///
1347 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
1348 /// vector.transpose operations are inserted if the vector.contract op is not a
1349 /// row-major matrix multiply.
1350 LogicalResult
1351 ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
1352                                                  PatternRewriter &rew) const {
1353   // TODO: implement masks
1354   if (llvm::size(op.getMasks()) != 0)
1355     return failure();
1356   if (vectorTransformOptions.vectorContractLowering !=
1357       vector::VectorContractLowering::Matmul)
1358     return failure();
1359   if (failed(filter(op)))
1360     return failure();
1361 
1362   auto iteratorTypes = op.getIteratorTypes().getValue();
1363   if (!isParallelIterator(iteratorTypes[0]) ||
1364       !isParallelIterator(iteratorTypes[1]) ||
1365       !isReductionIterator(iteratorTypes[2]))
1366     return failure();
1367 
1368   Type elementType = op.getLhsType().getElementType();
1369   if (!elementType.isIntOrFloat())
1370     return failure();
1371 
1372   Type dstElementType = op.getType();
1373   if (auto vecType = dstElementType.dyn_cast<VectorType>())
1374     dstElementType = vecType.getElementType();
1375   if (elementType != dstElementType)
1376     return failure();
1377 
1378   // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
1379   // Bail out if the contraction cannot be put in this form.
1380   MLIRContext *ctx = op.getContext();
1381   Location loc = op.getLoc();
1382   AffineExpr m, n, k;
1383   bindDims(rew.getContext(), m, n, k);
1384   // LHS must be A(m, k) or A(k, m).
1385   Value lhs = op.getLhs();
1386   auto lhsMap = op.getIndexingMaps()[0];
1387   if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
1388     lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
1389   else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
1390     return failure();
1391 
1392   // RHS must be B(k, n) or B(n, k).
1393   Value rhs = op.getRhs();
1394   auto rhsMap = op.getIndexingMaps()[1];
1395   if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
1396     rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
1397   else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
1398     return failure();
1399 
1400   // At this point lhs and rhs are in row-major.
1401   VectorType lhsType = lhs.getType().cast<VectorType>();
1402   VectorType rhsType = rhs.getType().cast<VectorType>();
1403   int64_t lhsRows = lhsType.getDimSize(0);
1404   int64_t lhsColumns = lhsType.getDimSize(1);
1405   int64_t rhsColumns = rhsType.getDimSize(1);
1406 
1407   Type flattenedLHSType =
1408       VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
1409   lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
1410 
1411   Type flattenedRHSType =
1412       VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
1413   rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
1414 
1415   Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
1416                                            rhsColumns);
1417   mul = rew.create<vector::ShapeCastOp>(
1418       loc,
1419       VectorType::get({lhsRows, rhsColumns},
1420                       getElementTypeOrSelf(op.getAcc().getType())),
1421       mul);
1422 
1423   // ACC must be C(m, n) or C(n, m).
1424   auto accMap = op.getIndexingMaps()[2];
1425   if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
1426     mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
1427   else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
1428     llvm_unreachable("invalid contraction semantics");
1429 
1430   Value res =
1431       elementType.isa<IntegerType>()
1432           ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
1433           : static_cast<Value>(
1434                 rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
1435 
1436   rew.replaceOp(op, res);
1437   return success();
1438 }
1439 
1440 namespace {
1441 struct IteratorType {
1442   IteratorType(StringRef strRef) : strRef(strRef) {}
1443   bool isOfType(Attribute attr) const {
1444     auto sAttr = attr.dyn_cast<StringAttr>();
1445     return sAttr && sAttr.getValue() == strRef;
1446   }
1447   StringRef strRef;
1448 };
1449 struct Par : public IteratorType {
1450   Par() : IteratorType(getParallelIteratorTypeName()) {}
1451 };
1452 struct Red : public IteratorType {
1453   Red() : IteratorType(getReductionIteratorTypeName()) {}
1454 };
1455 
1456 /// Generate a vector implementation for matmat, matvec and tmatvec.
1457 /// This unrolls outer-products along the reduction dimension.
1458 struct UnrolledOuterProductGenerator
1459     : public StructuredGenerator<vector::ContractionOp> {
1460   UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
1461       : StructuredGenerator<vector::ContractionOp>(builder, op),
1462         kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
1463         res(op.getAcc()), lhsType(op.getLhsType()) {}
1464 
1465   Value t(Value v) {
1466     static constexpr std::array<int64_t, 2> perm = {1, 0};
1467     return builder.create<vector::TransposeOp>(loc, v, perm);
1468   }
1469 
1470   Value promote(Value v, Type dstElementType) {
1471     Type elementType = v.getType();
1472     auto vecType = elementType.dyn_cast<VectorType>();
1473     if (vecType)
1474       elementType = vecType.getElementType();
1475     if (elementType == dstElementType)
1476       return v;
1477     Type promotedType = dstElementType;
1478     if (vecType)
1479       promotedType = VectorType::get(vecType.getShape(), promotedType);
1480     if (dstElementType.isa<FloatType>())
1481       return builder.create<arith::ExtFOp>(loc, promotedType, v);
1482     return builder.create<arith::ExtSIOp>(loc, promotedType, v);
1483   }
1484 
1485   Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
1486     assert(reductionSize > 0);
1487     Type resElementType = res.getType().cast<VectorType>().getElementType();
1488     for (int64_t k = 0; k < reductionSize; ++k) {
1489       Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
1490       Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
1491       a = promote(a, resElementType);
1492       b = promote(b, resElementType);
1493       res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
1494                                                    res, kind);
1495     }
1496     return res;
1497   }
1498 
1499   /// Two outer parallel, one inner reduction (matmat flavor).
1500   FailureOr<Value> matmat() {
1501     if (!iters({Par(), Par(), Red()}))
1502       return failure();
1503     // Set up the parallel/reduction structure in the right form.
1504     AffineExpr m, n, k;
1505     bindDims(builder.getContext(), m, n, k);
1506     // Classical row-major matmul:  Just permute the lhs.
1507     if (layout({{m, k}, {k, n}, {m, n}}))
1508       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1509     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1510     if (layout({{m, k}, {n, k}, {m, n}})) {
1511       Value tlhs = t(lhs);
1512       return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
1513     }
1514     // No need to permute anything.
1515     if (layout({{k, m}, {k, n}, {m, n}}))
1516       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1517     // Just permute the rhs.
1518     if (layout({{k, m}, {n, k}, {m, n}}))
1519       return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
1520     // Transposed output: swap RHS and LHS.
1521     // Classical row-major matmul: permute the lhs.
1522     if (layout({{m, k}, {k, n}, {n, m}}))
1523       return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
1524     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1525     if (layout({{m, k}, {n, k}, {n, m}})) {
1526       Value trhs = t(rhs);
1527       return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
1528     }
1529     if (layout({{k, m}, {k, n}, {n, m}}))
1530       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1531     if (layout({{k, m}, {n, k}, {n, m}}))
1532       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1533     return failure();
1534   }
1535 
1536   /// One outer parallel, one inner reduction (matvec flavor)
1537   FailureOr<Value> matvec() {
1538     if (!iters({Par(), Red()}))
1539       return failure();
1540     AffineExpr m, k;
1541     bindDims(builder.getContext(), m, k);
1542 
1543     // Case mat-vec: transpose.
1544     if (layout({{m, k}, {k}, {m}}))
1545       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1546     // Case mat-trans-vec: ready to go.
1547     if (layout({{k, m}, {k}, {m}}))
1548       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1549     // Case vec-mat: swap and transpose.
1550     if (layout({{k}, {m, k}, {m}}))
1551       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1552     // Case vec-mat-trans: swap and ready to go.
1553     if (layout({{k}, {k, m}, {m}}))
1554       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1555     return failure();
1556   }
1557 
1558   //
1559   // One outer reduction, one inner parallel (tmatvec flavor)
1560   //
1561   FailureOr<Value> tmatvec() {
1562     if (!iters({Red(), Par()}))
1563       return failure();
1564     AffineExpr k, m;
1565     bindDims(builder.getContext(), k, m);
1566 
1567     // Case mat-vec: transpose.
1568     if (layout({{m, k}, {k}, {m}}))
1569       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1570     // Case mat-trans-vec: ready to go.
1571     if (layout({{k, m}, {k}, {m}}))
1572       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1573     // Case vec-mat: swap and transpose.
1574     if (layout({{k}, {m, k}, {m}}))
1575       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1576     // Case vec-mat-trans: swap and ready to go.
1577     if (layout({{k}, {k, m}, {m}}))
1578       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1579     return failure();
1580   }
1581 
1582 private:
1583   vector::CombiningKind kind;
1584   Value lhs, rhs, res;
1585   VectorType lhsType;
1586 };
1587 } // namespace
1588 
1589 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1590 /// semantics to a reduction_size-unrolled sequence:
1591 /// ```
1592 ///    %at = vector.transpose %a, [1, 0]
1593 ///    %bRow0 = vector.extract %b[0]
1594 ///    %atRow0 = vector.extract %at[0]
1595 ///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
1596 ///    ...
1597 ///    %bRowK = vector.extract %b[K]
1598 ///    %atRowK = vector.extract %at[K]
1599 ///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
1600 /// ```
1601 ///
1602 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
1603 /// otherwise supports any layout permutation of the matrix-multiply.
1604 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
1605     vector::ContractionOp op, PatternRewriter &rewriter) const {
1606   // TODO: implement masks
1607   if (llvm::size(op.getMasks()) != 0)
1608     return failure();
1609 
1610   if (vectorTransformOptions.vectorContractLowering !=
1611       vector::VectorContractLowering::OuterProduct)
1612     return failure();
1613 
1614   if (failed(filter(op)))
1615     return failure();
1616 
1617   UnrolledOuterProductGenerator e(rewriter, op);
1618   FailureOr<Value> matmatRes = e.matmat();
1619   if (succeeded(matmatRes)) {
1620     rewriter.replaceOp(op, *matmatRes);
1621     return success();
1622   }
1623   FailureOr<Value> matvecRes = e.matvec();
1624   if (succeeded(matvecRes)) {
1625     rewriter.replaceOp(op, *matvecRes);
1626     return success();
1627   }
1628   FailureOr<Value> tmatvecRes = e.tmatvec();
1629   if (succeeded(tmatvecRes)) {
1630     rewriter.replaceOp(op, *tmatvecRes);
1631     return success();
1632   }
1633 
1634   return failure();
1635 }
1636 
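/// Lower a `vector.contract` to a sequence of elementwise multiplications and
/// `vector.reduction` ops, one per element of the result (the "dot" flavor).
/// Operands are first transposed/swapped so that the reduction dimension is
/// innermost. Schematically (an illustrative sketch only; shapes and exact
/// assembly syntax abbreviated), a matvec-like contraction becomes roughly:
/// ```
///    %a0 = vector.extract %lhs[0]
///    %m0 = arith.mulf %a0, %rhs
///    %r0 = vector.reduction <add>, %m0
///    %res0 = vector.insert %r0, %zero [0]
///    ...
/// ```
/// This only kicks in when VectorTransformsOptions is set to `Dot`.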
1637 LogicalResult
1638 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
1639                                             PatternRewriter &rewriter) const {
1640   // TODO: implement masks
1641   if (llvm::size(op.getMasks()) != 0)
1642     return failure();
1643 
1644   if (failed(filter(op)))
1645     return failure();
1646 
1647   if (vectorTransformOptions.vectorContractLowering !=
1648       vector::VectorContractLowering::Dot)
1649     return failure();
1650 
1651   auto iteratorTypes = op.getIteratorTypes().getValue();
1652   static constexpr std::array<int64_t, 2> perm = {1, 0};
1653   Location loc = op.getLoc();
1654   Value lhs = op.getLhs(), rhs = op.getRhs();
1655 
1656   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1657   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1658   AffineExpr m, n, k;
1659   bindDims(rewriter.getContext(), m, n, k);
1660   SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
1661   //
1662   // In the following we wish to make the reduction dimension innermost so we
1663   // can load vectors and just fmul + reduce into a scalar.
1664   //
1665   if (isParallelIterator(iteratorTypes[0]) &&
1666       isParallelIterator(iteratorTypes[1]) &&
1667       isReductionIterator(iteratorTypes[2])) {
1668     //
1669     // Two outer parallel, one inner reduction (matmat flavor).
1670     //
1671     if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1672       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1673     } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
1674       // No need to permute anything.
1675     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1676       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1677       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1678     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1679       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1680     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1681       // Row-major matmul with transposed output: swap the operands and
1681       // permute the new lhs.
1682       Value tmp = lhs;
1683       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1684       rhs = tmp;
1685     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1686       std::swap(lhs, rhs);
1687     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1688       Value tmp = lhs;
1689       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1690       rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
1691     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1692       Value tmp = rhs;
1693       rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1694       lhs = tmp;
1695     } else {
1696       return failure();
1697     }
1698   } else if (isParallelIterator(iteratorTypes[0]) &&
1699              isReductionIterator(iteratorTypes[1])) {
1700     //
1701     // One outer parallel, one inner reduction (matvec flavor)
1702     //
1703     if (maps == infer({{m, n}, {n}, {m}})) {
1704       // No need to permute anything.
1705     } else if (maps == infer({{n, m}, {n}, {m}})) {
1706       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1707     } else if (maps == infer({{n}, {m, n}, {m}})) {
1708       std::swap(lhs, rhs);
1709     } else if (maps == infer({{n}, {n, m}, {m}})) {
1710       std::swap(lhs, rhs);
1711       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1712     } else {
1713       return failure();
1714     }
1715   } else {
1716     return failure();
1717   }
1718 
1719   VectorType dstType = op.getResultType().cast<VectorType>();
1720   assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
1721          "Expected dst type of rank 1 or 2");
1722 
1723   unsigned rank = dstType.getRank();
1724   unsigned dstRows = dstType.getShape()[0];
1725   unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
1726 
1727   // ExtractOp does not allow dynamic indexing; we must unroll explicitly.
1728   Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
1729                                                  rewriter.getZeroAttr(dstType));
1730   bool isInt = dstType.getElementType().isa<IntegerType>();
1731   for (unsigned r = 0; r < dstRows; ++r) {
1732     Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
1733     for (unsigned c = 0; c < dstColumns; ++c) {
1734       Value b = rank == 1
1735                     ? rhs
1736                     : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
1737       Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
1738       Value reduced = rewriter.create<vector::ReductionOp>(
1739           op.getLoc(), vector::CombiningKind::ADD, m);
1740 
1741       SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
1742                                               : SmallVector<int64_t, 2>{r, c};
1743       res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
1744     }
1745   }
1746   if (auto acc = op.getAcc())
1747     res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
1748   rewriter.replaceOp(op, res);
1749   return success();
1750 }
1751 
1752 /// Progressive lowering of ContractionOp.
1753 /// One:
1754 ///   %x = vector.contract with at least one free/batch dimension
1755 /// is replaced by:
1756 ///   %a = vector.contract with one less free/batch dimension
1757 ///   %b = vector.contract with one less free/batch dimension
1758 ///   ..
1759 ///   %x = combine %a %b ..
1760 /// until a pure contraction is reached (no free/batch dimensions),
1761 /// which is replaced by a dot-product.
1762 ///
1763 /// This only kicks in when either VectorTransformsOptions is set
1764 /// to DOT or when other contraction patterns fail.
1765 //
1766 // TODO: break down into transpose/reshape/cast ops
1767 //               when they become available to avoid code dup
1768 // TODO: investigate lowering order impact on performance
1769 LogicalResult
1770 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
1771                                        PatternRewriter &rewriter) const {
1772   // TODO: implement masks.
1773   if (llvm::size(op.getMasks()) != 0)
1774     return failure();
1775 
1776   if (failed(filter(op)))
1777     return failure();
1778 
1779   // TODO: support mixed mode contract lowering.
1780   if (op.getLhsType().getElementType() !=
1781           getElementTypeOrSelf(op.getAccType()) ||
1782       op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
1783     return failure();
1784 
1785   // TODO: implement benefits, cost models.
1786   MLIRContext *ctx = op.getContext();
1787   ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
1788   if (succeeded(pat1.matchAndRewrite(op, rewriter)))
1789     return success();
1790   ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
1791   if (succeeded(pat2.matchAndRewrite(op, rewriter)))
1792     return success();
1793   ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
1794   if (succeeded(pat3.matchAndRewrite(op, rewriter)))
1795     return success();
1796   ContractOpToElementwise pat4(vectorTransformOptions, ctx);
1797   if (succeeded(pat4.matchAndRewrite(op, rewriter)))
1798     return success();
1799 
1800   // Find first batch dimension in LHS/RHS, and lower when found.
1801   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
1802   if (!batchDimMap.empty()) {
1803     int64_t lhsIndex = batchDimMap[0].first;
1804     int64_t rhsIndex = batchDimMap[0].second;
1805     rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
1806     return success();
1807   }
1808 
1809   // Collect contracting dimensions.
1810   std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
1811       op.getContractingDimMap();
1812   DenseSet<int64_t> lhsContractingDimSet;
1813   DenseSet<int64_t> rhsContractingDimSet;
1814   for (auto &dimPair : contractingDimMap) {
1815     lhsContractingDimSet.insert(dimPair.first);
1816     rhsContractingDimSet.insert(dimPair.second);
1817   }
1818 
1819   // Find first free dimension in LHS, and lower when found.
1820   VectorType lhsType = op.getLhsType();
1821   for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
1822     if (lhsContractingDimSet.count(lhsIndex) == 0) {
1823       rewriter.replaceOp(
1824           op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
1825       return success();
1826     }
1827   }
1828 
1829   // Find first free dimension in RHS, and lower when found.
1830   VectorType rhsType = op.getRhsType();
1831   for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
1832     if (rhsContractingDimSet.count(rhsIndex) == 0) {
1833       rewriter.replaceOp(
1834           op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
1835       return success();
1836     }
1837   }
1838 
1839   // Lower the first remaining reduction dimension.
1840   if (!contractingDimMap.empty()) {
1841     rewriter.replaceOp(op, lowerReduction(op, rewriter));
1842     return success();
1843   }
1844 
1845   return failure();
1846 }
1847 
1848 // Lower one parallel dimension.
1849 // TODO: consider reusing existing contract unrolling
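// E.g. (a schematic sketch, not exact IR): for a parallel dimension of size 2,
// the contraction is unrolled into two lower-rank contractions whose results
// are inserted back into a zero-initialized accumulator; operands that do not
// carry the dimension are passed through unsliced:
//   %r0 = insert(contract(lhs<0>, rhs<0>, acc<0>), %zero)
//   %r1 = insert(contract(lhs<1>, rhs<1>, acc<1>), %r0)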
1850 Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
1851                                            int64_t lhsIndex, int64_t rhsIndex,
1852                                            PatternRewriter &rewriter) const {
1853   VectorType lhsType = op.getLhsType();
1854   VectorType rhsType = op.getRhsType();
1855   VectorType resType = op.getResultType().cast<VectorType>();
1856   // Find the iterator type index and result index.
1857   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
1858   int64_t iterIndex = -1;
1859   int64_t dimSize = -1;
1860   if (lhsIndex >= 0) {
1861     iterIndex = iMap[0].getDimPosition(lhsIndex);
1862     assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
1863            "parallel index should be free in LHS or batch in LHS/RHS");
1864     dimSize = lhsType.getDimSize(lhsIndex);
1865   } else {
1866     assert(rhsIndex >= 0 && "missing parallel index");
1867     iterIndex = iMap[1].getDimPosition(rhsIndex);
1868     dimSize = rhsType.getDimSize(rhsIndex);
1869   }
1870   assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
1871   Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
1872   assert(lookup.hasValue() && "parallel index not listed in reduction");
1873   int64_t resIndex = lookup.getValue();
1874   // Construct new iterator types and affine map array attribute.
1875   std::array<AffineMap, 3> lowIndexingMaps = {
1876       adjustMap(iMap[0], iterIndex, rewriter),
1877       adjustMap(iMap[1], iterIndex, rewriter),
1878       adjustMap(iMap[2], iterIndex, rewriter)};
1879   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1880   auto lowIter =
1881       rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1882   // Unroll into a series of lower dimensional vector.contract ops.
1883   Location loc = op.getLoc();
1884   Value result = rewriter.create<arith::ConstantOp>(
1885       loc, resType, rewriter.getZeroAttr(resType));
1886   for (int64_t d = 0; d < dimSize; ++d) {
1887     auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1888     auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1889     auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
1890     Value lowContract = rewriter.create<vector::ContractionOp>(
1891         loc, lhs, rhs, acc, lowAffine, lowIter);
1892     result =
1893         reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
1894   }
1895   return result;
1896 }
1897 
1898 // Lower one reduction dimension.
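// E.g. (a schematic sketch, not exact IR): for a reduction dimension of size 2,
// the contraction is unrolled into two lower-rank contractions chained through
// the accumulator:
//   %c0 = contract(lhs<0>, rhs<0>, %acc)
//   %c1 = contract(lhs<1>, rhs<1>, %c0)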
1899 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
1900                                             PatternRewriter &rewriter) const {
1901   auto loc = op.getLoc();
1902   VectorType lhsType = op.getLhsType();
1903   VectorType rhsType = op.getRhsType();
1904   Type resType = op.getResultType();
1905   assert(!resType.isa<VectorType>());
1906   bool isInt = resType.isa<IntegerType>();
1907   // Use iterator index 0.
1908   int64_t iterIndex = 0;
1909   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
1910   Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1911   Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1912   assert(lookupLhs.hasValue() && "missing LHS parallel index");
1913   assert(lookupRhs.hasValue() && "missing RHS parallel index");
1914   int64_t lhsIndex = lookupLhs.getValue();
1915   int64_t rhsIndex = lookupRhs.getValue();
1916   int64_t dimSize = lhsType.getDimSize(lhsIndex);
1917   assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
1918   // Base case.
1919   if (lhsType.getRank() == 1) {
1920     assert(rhsType.getRank() == 1 && "corrupt contraction");
1921     Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1922     auto kind = vector::CombiningKind::ADD;
1923     if (auto acc = op.getAcc())
1924       return rewriter.create<vector::ReductionOp>(loc, kind, m, acc);
1925     return rewriter.create<vector::ReductionOp>(loc, kind, m);
1926   }
1927   // Construct new iterator types and affine map array attribute.
1928   std::array<AffineMap, 3> lowIndexingMaps = {
1929       adjustMap(iMap[0], iterIndex, rewriter),
1930       adjustMap(iMap[1], iterIndex, rewriter),
1931       adjustMap(iMap[2], iterIndex, rewriter)};
1932   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1933   auto lowIter =
1934       rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1935   // Unroll into a series of lower dimensional vector.contract ops.
1936   // By feeding the initial accumulator into the first contraction,
1937   // and the result of each contraction into the next, eventually
1938   // the sum of all reductions is computed.
1939   Value result = op.getAcc();
1940   for (int64_t d = 0; d < dimSize; ++d) {
1941     auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1942     auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1943     result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
1944                                                     lowAffine, lowIter);
1945   }
1946   return result;
1947 }
1948 
1949 } // namespace mlir
1950 
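/// Distribute the result of a single-result pointwise vector op across `ids`
/// according to `map` and `multiplicity`: the result is wrapped in a
/// vector.extract_map / vector.insert_map pair so that each id operates on a
/// slice of the original vector. Returns llvm::None if the op does not qualify
/// (multiple results, non-vector result, or a distributed dimension whose size
/// is not a multiple of the multiplicity).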
1951 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
1952     OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
1953     ArrayRef<int64_t> multiplicity, const AffineMap &map) {
1954   OpBuilder::InsertionGuard guard(builder);
1955   builder.setInsertionPointAfter(op);
1956   Location loc = op->getLoc();
1957   if (op->getNumResults() != 1)
1958     return {};
1959   Value result = op->getResult(0);
1960   VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
1961   if (!type || map.getNumResults() != multiplicity.size())
1962     return {};
1963   // For each distributed dimension, check that its size is a multiple of the
1964   // multiplicity. Handling other sizes would require masking support.
1965   unsigned multiplicityCount = 0;
1966   for (auto exp : map.getResults()) {
1967     auto affineExp = exp.dyn_cast<AffineDimExpr>();
1968     if (!affineExp || affineExp.getPosition() >= type.getRank() ||
1969         type.getDimSize(affineExp.getPosition()) %
1970                 multiplicity[multiplicityCount++] !=
1971             0)
1972       return {};
1973   }
1974   DistributeOps ops;
1975   ops.extract =
1976       builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
1977   ops.insert =
1978       builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
1979   return ops;
1980 }
1981 
1982 /// Progressive lowering of transfer_read. This pattern supports lowering of
1983 /// `vector.transfer_read` to a combination of `vector.load` and
1984 /// `vector.broadcast` if all of the following hold:
1985 /// - Stride of most minor memref dimension must be 1.
1986 /// - Out-of-bounds masking is not required.
1987 /// - If the memref's element type is a vector type then it coincides with the
1988 ///   result type.
1989 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
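/// For example (an illustrative sketch), under the conditions above
/// ```
///    %0 = vector.transfer_read %mem[%i], %pad {in_bounds = [true]}
///        : memref<?xf32>, vector<4xf32>
/// ```
/// is rewritten to
/// ```
///    %0 = vector.load %mem[%i] : memref<?xf32>, vector<4xf32>
/// ```
/// (a `vector.maskedload` is used instead when a mask is present, and a
/// `vector.broadcast` is appended when the permutation map broadcasts).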
1990 struct TransferReadToVectorLoadLowering
1991     : public OpRewritePattern<vector::TransferReadOp> {
1992   TransferReadToVectorLoadLowering(MLIRContext *context,
1993                                    llvm::Optional<unsigned> maxRank)
1994       : OpRewritePattern<vector::TransferReadOp>(context),
1995         maxTransferRank(maxRank) {}
1996 
1997   LogicalResult matchAndRewrite(vector::TransferReadOp read,
1998                                 PatternRewriter &rewriter) const override {
1999     if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
2000       return failure();
2001 
2002     SmallVector<unsigned, 4> broadcastedDims;
2003     // Permutations are handled by VectorToSCF or
2004     // populateVectorTransferPermutationMapLoweringPatterns.
2005     // We let the 0-d corner case pass through as it is supported.
2006     if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
2007             &broadcastedDims))
2008       return failure();
2009 
2010     auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
2011     if (!memRefType)
2012       return failure();
2013 
2014     // Non-unit strides are handled by VectorToSCF.
2015     if (!vector::isLastMemrefDimUnitStride(memRefType))
2016       return failure();
2017 
2018     // If there is broadcasting involved then we first load the unbroadcasted
2019     // vector, and then broadcast it with `vector.broadcast`.
2020     ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
2021     SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
2022                                                      vectorShape.end());
2023     for (unsigned i : broadcastedDims)
2024       unbroadcastedVectorShape[i] = 1;
2025     VectorType unbroadcastedVectorType = VectorType::get(
2026         unbroadcastedVectorShape, read.getVectorType().getElementType());
2027 
2028     // `vector.load` supports vector types as memref's elements only when the
2029     // resulting vector type is the same as the element type.
2030     auto memrefElTy = memRefType.getElementType();
2031     if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
2032       return failure();
2033 
2034     // Otherwise, element types of the memref and the vector must match.
2035     if (!memrefElTy.isa<VectorType>() &&
2036         memrefElTy != read.getVectorType().getElementType())
2037       return failure();
2038 
2039     // Out-of-bounds dims are handled by MaterializeTransferMask.
2040     if (read.hasOutOfBoundsDim())
2041       return failure();
2042 
2043     // Create vector load op.
2044     Operation *loadOp;
2045     if (read.getMask()) {
2046       Value fill = rewriter.create<vector::SplatOp>(
2047           read.getLoc(), unbroadcastedVectorType, read.getPadding());
2048       loadOp = rewriter.create<vector::MaskedLoadOp>(
2049           read.getLoc(), unbroadcastedVectorType, read.getSource(),
2050           read.getIndices(), read.getMask(), fill);
2051     } else {
2052       loadOp = rewriter.create<vector::LoadOp>(
2053           read.getLoc(), unbroadcastedVectorType, read.getSource(),
2054           read.getIndices());
2055     }
2056 
2057     // Insert a broadcasting op if required.
2058     if (!broadcastedDims.empty()) {
2059       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2060           read, read.getVectorType(), loadOp->getResult(0));
2061     } else {
2062       rewriter.replaceOp(read, loadOp->getResult(0));
2063     }
2064 
2065     return success();
2066   }
2067 
2068   llvm::Optional<unsigned> maxTransferRank;
2069 };
2070 
2071 /// Replace a 0-d or single-element vector.load with memref.load + vector.broadcast.
2072 // TODO: we shouldn't cross the vector/scalar domains just for this
2073 // but atm we lack the infra to avoid it. Possible solutions include:
2074 // - go directly to LLVM + bitcast
2075 // - introduce a bitcast op and likely a new pointer dialect
2076 // - let memref.load/store additionally support the 0-d vector case
2077 // There are still deeper data layout issues lingering even in this
2078 // trivial case (for architectures for which this matters).
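//
// For example (an illustrative sketch):
//   %0 = vector.load %mem[%i] : memref<?xf32>, vector<1xf32>
// becomes roughly:
//   %s = memref.load %mem[%i] : memref<?xf32>
//   %0 = vector.broadcast %s : f32 to vector<1xf32>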
2079 struct VectorLoadToMemrefLoadLowering
2080     : public OpRewritePattern<vector::LoadOp> {
2081   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
2082 
2083   LogicalResult matchAndRewrite(vector::LoadOp loadOp,
2084                                 PatternRewriter &rewriter) const override {
2085     auto vecType = loadOp.getVectorType();
2086     if (vecType.getNumElements() != 1)
2087       return failure();
2088     auto memrefLoad = rewriter.create<memref::LoadOp>(
2089         loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
2090     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
2091                                                      memrefLoad);
2092     return success();
2093   }
2094 };
2095 
2096 /// Replace a 0-d or single-element vector.store with an element extract + memref.store.
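/// For example (an illustrative sketch):
///   vector.store %v, %mem[%i] : memref<?xf32>, vector<1xf32>
/// becomes roughly:
///   %s = vector.extract %v[0] : vector<1xf32>
///   memref.store %s, %mem[%i] : memref<?xf32>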
2097 struct VectorStoreToMemrefStoreLowering
2098     : public OpRewritePattern<vector::StoreOp> {
2099   using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
2100 
2101   LogicalResult matchAndRewrite(vector::StoreOp storeOp,
2102                                 PatternRewriter &rewriter) const override {
2103     auto vecType = storeOp.getVectorType();
2104     if (vecType.getNumElements() != 1)
2105       return failure();
2106     Value extracted;
2107     if (vecType.getRank() == 0) {
2108       // TODO: Unify once ExtractOp supports 0-d vectors.
2109       extracted = rewriter.create<vector::ExtractElementOp>(
2110           storeOp.getLoc(), storeOp.getValueToStore());
2111     } else {
2112       SmallVector<int64_t> indices(vecType.getRank(), 0);
2113       extracted = rewriter.create<vector::ExtractOp>(
2114           storeOp.getLoc(), storeOp.getValueToStore(), indices);
2115     }
2116 
2117     rewriter.replaceOpWithNewOp<memref::StoreOp>(
2118         storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
2119     return success();
2120   }
2121 };
2122 
2123 /// Progressive lowering of transfer_write. This pattern supports lowering of
2124 /// `vector.transfer_write` to `vector.store` if all of the following hold:
2125 /// - Stride of most minor memref dimension must be 1.
2126 /// - Out-of-bounds masking is not required.
2127 /// - If the memref's element type is a vector type then it coincides with the
2128 ///   type of the written value.
2129 /// - The permutation map is the minor identity map (neither permutation nor
2130 ///   broadcasting is allowed).
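/// For example (an illustrative sketch), under the conditions above
/// ```
///    vector.transfer_write %v, %mem[%i] {in_bounds = [true]}
///        : vector<4xf32>, memref<?xf32>
/// ```
/// is rewritten to
/// ```
///    vector.store %v, %mem[%i] : memref<?xf32>, vector<4xf32>
/// ```
/// (a `vector.maskedstore` is used instead when a mask is present).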
2131 struct TransferWriteToVectorStoreLowering
2132     : public OpRewritePattern<vector::TransferWriteOp> {
2133   TransferWriteToVectorStoreLowering(MLIRContext *context,
2134                                      llvm::Optional<unsigned> maxRank)
2135       : OpRewritePattern<vector::TransferWriteOp>(context),
2136         maxTransferRank(maxRank) {}
2137 
2138   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
2139                                 PatternRewriter &rewriter) const override {
2140     if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
2141       return failure();
2142 
2143     // Permutations are handled by VectorToSCF or
2144     // populateVectorTransferPermutationMapLoweringPatterns.
2145     if ( // pass-through for the 0-d corner case.
2146         !write.getPermutationMap().isMinorIdentity())
2147       return failure();
2148 
2149     auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
2150     if (!memRefType)
2151       return failure();
2152 
2153     // Non-unit strides are handled by VectorToSCF.
2154     if (!vector::isLastMemrefDimUnitStride(memRefType))
2155       return failure();
2156 
2157     // `vector.store` supports vector types as memref's elements only when the
2158     // type of the vector value being written is the same as the element type.
2159     auto memrefElTy = memRefType.getElementType();
2160     if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
2161       return failure();
2162 
2163     // Otherwise, element types of the memref and the vector must match.
2164     if (!memrefElTy.isa<VectorType>() &&
2165         memrefElTy != write.getVectorType().getElementType())
2166       return failure();
2167 
2168     // Out-of-bounds dims are handled by MaterializeTransferMask.
2169     if (write.hasOutOfBoundsDim())
2170       return failure();
2171     if (write.getMask()) {
2172       rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
2173           write, write.getSource(), write.getIndices(), write.getMask(),
2174           write.getVector());
2175     } else {
2176       rewriter.replaceOpWithNewOp<vector::StoreOp>(
2177           write, write.getVector(), write.getSource(), write.getIndices());
2178     }
2179     return success();
2180   }
2181 
2182   llvm::Optional<unsigned> maxTransferRank;
2183 };
2184 
2185 // Returns the values in `arrayAttr` as an integer vector.
2186 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
2187   return llvm::to_vector<4>(
2188       llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
2189                       [](IntegerAttr attr) { return attr.getInt(); }));
2190 }
2191 
2192 // Shuffles vector.bitcast op after vector.extract op.
2193 //
2194 // This transforms IR like:
2195 //   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
2196 //   %1 = vector.extract %0[3] : vector<8xf16>
2197 // Into:
2198 //   %0 = vector.extract %src[1] : vector<4xf32>
2199 //   %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
2200 //   %2 = vector.extract %1[1] : vector<2xf16>
2201 struct BubbleDownVectorBitCastForExtract
2202     : public OpRewritePattern<vector::ExtractOp> {
2203   using OpRewritePattern::OpRewritePattern;
2204 
2205   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2206                                 PatternRewriter &rewriter) const override {
2207     // Only support extracting scalars for now.
2208     if (extractOp.getVectorType().getRank() != 1)
2209       return failure();
2210 
2211     auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
2212     if (!castOp)
2213       return failure();
2214 
2215     VectorType castSrcType = castOp.getSourceVectorType();
2216     VectorType castDstType = castOp.getResultVectorType();
2217     assert(castSrcType.getRank() == castDstType.getRank());
2218 
2219     // Fail to match if the cast op source has only one element; this avoids
2220     // an infinite loop, since this pattern itself can generate such
2221     // cases.
2222     if (castSrcType.getNumElements() == 1)
2223       return failure();
2224 
2225     // Only support casting to a larger number of elements for now.
2226     // E.g., vector<4xf32> -> vector<8xf16>.
2227     if (castSrcType.getNumElements() > castDstType.getNumElements())
2228       return failure();
2229 
2230     unsigned expandRatio =
2231         castDstType.getNumElements() / castSrcType.getNumElements();
2232 
2233     auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
2234       return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
2235     };
2236 
2237     uint64_t index = getFirstIntValue(extractOp.getPosition());
2238 
2239     // Get the single scalar (as a vector) in the source value that packs the
2240     // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
2241     VectorType oneScalarType =
2242         VectorType::get({1}, castSrcType.getElementType());
2243     Value packedValue = rewriter.create<vector::ExtractOp>(
2244         extractOp.getLoc(), oneScalarType, castOp.getSource(),
2245         rewriter.getI64ArrayAttr(index / expandRatio));
2246 
2247     // Cast it to a vector with the desired scalar's type.
2248     // E.g. f32 -> vector<2xf16>
2249     VectorType packedType =
2250         VectorType::get({expandRatio}, castDstType.getElementType());
2251     Value castedValue = rewriter.create<vector::BitCastOp>(
2252         extractOp.getLoc(), packedType, packedValue);
2253 
2254     // Finally extract the desired scalar.
2255     rewriter.replaceOpWithNewOp<vector::ExtractOp>(
2256         extractOp, extractOp.getType(), castedValue,
2257         rewriter.getI64ArrayAttr(index % expandRatio));
2258 
2259     return success();
2260   }
2261 };
2262 
2263 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
2264 //
2265 // This transforms IR like:
2266 //    %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
2267 //     %0 = vector.extract_strided_slice %cast {
2268 //            offsets = [4], sizes = [4], strides = [1]
2269 //          } : vector<8xf16> to vector<4xf16>
2270 // Into:
2271 //   %0 = vector.extract_strided_slice %src {
2272 //          offsets = [2], sizes = [2], strides = [1]
2273 //        } : vector<4xf32> to vector<2xf32>
2274 //   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
2275 struct BubbleDownBitCastForStridedSliceExtract
2276     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
2277   using OpRewritePattern::OpRewritePattern;
2278 
2279   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
2280                                 PatternRewriter &rewriter) const override {
2281     auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
2282     if (!castOp)
2283       return failure();
2284 
2285     VectorType castSrcType = castOp.getSourceVectorType();
2286     VectorType castDstType = castOp.getResultVectorType();
2287     assert(castSrcType.getRank() == castDstType.getRank());
2288 
2289     int64_t castSrcLastDim = castSrcType.getShape().back();
2290     int64_t castDstLastDim = castDstType.getShape().back();
2291     // Require casting to more elements for now; other cases to be implemented.
2292     if (castSrcLastDim > castDstLastDim)
2293       return failure();
2294 
2295     // Only accept all-one strides for now.
2296     if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
2297                      [](const APInt &val) { return !val.isOneValue(); }))
2298       return failure();
2299 
2300     unsigned rank = extractOp.getVectorType().getRank();
2301     assert(castDstLastDim % castSrcLastDim == 0);
2302     int64_t expandRatio = castDstLastDim / castSrcLastDim;
2303 
2304     // If there are fewer offsets than the rank, we are implicitly selecting
2305     // the full range for the last bitcasted dimension; other dimensions are
2306     // not affected. Otherwise, we need to scale down the last dimension's
2307     // offset since we are now extracting from fewer elements.
2308     ArrayAttr newOffsets = extractOp.getOffsets();
2309     if (newOffsets.size() == rank) {
2310       SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2311       if (offsets.back() % expandRatio != 0)
2312         return failure();
2313       offsets.back() = offsets.back() / expandRatio;
2314       newOffsets = rewriter.getI64ArrayAttr(offsets);
2315     }
2316 
2317     // Similarly for sizes.
2318     ArrayAttr newSizes = extractOp.getSizes();
2319     if (newSizes.size() == rank) {
2320       SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
2321       if (sizes.back() % expandRatio != 0)
2322         return failure();
2323       sizes.back() = sizes.back() / expandRatio;
2324       newSizes = rewriter.getI64ArrayAttr(sizes);
2325     }
2326 
2327     SmallVector<int64_t, 4> dims =
2328         llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
2329     dims.back() = dims.back() / expandRatio;
2330     VectorType newExtractType =
2331         VectorType::get(dims, castSrcType.getElementType());
2332 
2333     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
2334         extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
2335         newSizes, extractOp.getStrides());
2336 
2337     rewriter.replaceOpWithNewOp<vector::BitCastOp>(
2338         extractOp, extractOp.getType(), newExtractOp);
2339 
2340     return success();
2341   }
2342 };
2343 
2344 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
2345 //
2346 // This transforms IR like:
2347 //   %0 = vector.insert_strided_slice %src, %dst {
2348 //          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
2349 //   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
2350 // Into:
2351 //   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
2352 //   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
2353 //   %2 = vector.insert_strided_slice %0, %1 {
2354 //          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
2355 struct BubbleUpBitCastForStridedSliceInsert
2356     : public OpRewritePattern<vector::BitCastOp> {
2357   using OpRewritePattern::OpRewritePattern;
2358   LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
2359                                 PatternRewriter &rewriter) const override {
2360     VectorType castSrcType = bitcastOp.getSourceVectorType();
2361     VectorType castDstType = bitcastOp.getResultVectorType();
2362     assert(castSrcType.getRank() == castDstType.getRank());
2363 
2364     int64_t castSrcLastDim = castSrcType.getShape().back();
2365     int64_t castDstLastDim = castDstType.getShape().back();
2366     // Require casting to fewer elements for now; other cases to be implemented.
2367     if (castSrcLastDim < castDstLastDim)
2368       return failure();
2369 
2370     assert(castSrcLastDim % castDstLastDim == 0);
2371     int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
2372 
2373     auto insertOp =
2374         bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
2375     if (!insertOp)
2376       return failure();
2377 
2378     // Only accept all-one strides for now.
2379     if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
2380                      [](const APInt &val) { return !val.isOneValue(); }))
2381       return failure();
2382 
2383     unsigned rank = insertOp.getSourceVectorType().getRank();
2384     // Require insert op to have the same rank for the source and destination
2385     // vector; other cases to be implemented.
2386     if (rank != insertOp.getDestVectorType().getRank())
2387       return failure();
2388 
2389     ArrayAttr newOffsets = insertOp.getOffsets();
2390     assert(newOffsets.size() == rank);
2391     SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2392     if (offsets.back() % shrinkRatio != 0)
2393       return failure();
2394     offsets.back() = offsets.back() / shrinkRatio;
2395     newOffsets = rewriter.getI64ArrayAttr(offsets);
2396 
2397     SmallVector<int64_t, 4> srcDims =
2398         llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
2399     srcDims.back() = srcDims.back() / shrinkRatio;
2400     VectorType newCastSrcType =
2401         VectorType::get(srcDims, castDstType.getElementType());
2402 
2403     auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
2404         bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
2405 
2406     SmallVector<int64_t, 4> dstDims =
2407         llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
2408     dstDims.back() = dstDims.back() / shrinkRatio;
2409     VectorType newCastDstType =
2410         VectorType::get(dstDims, castDstType.getElementType());
2411 
2412     auto newCastDstOp = rewriter.create<vector::BitCastOp>(
2413         bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
2414 
2415     rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
2416         bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
2417         insertOp.getStrides());
2418 
2419     return success();
2420   }
2421 };
2422 
2423 // Helper that returns a vector comparison that constructs a mask:
2424 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
2425 //
2426 // If `dim == 0` then the result will be a 0-D vector.
2427 //
2428 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
2429 //       much more compact, IR for this operation, but LLVM eventually
2430 //       generates more elaborate instructions for this intrinsic since it
2431 //       is very conservative on the boundary conditions.
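//
//       For instance (an illustrative example), with n = 4, o = 1 and b = 3:
//       [0,1,2,3] + [1,1,1,1] = [1,2,3,4] < [3,3,3,3] yields mask [1,1,0,0].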
2432 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
2433                                    bool force32BitVectorIndices, int64_t dim,
2434                                    Value b, Value *off = nullptr) {
2435   auto loc = op->getLoc();
2436   // If we can assume all indices fit in 32-bit, we perform the vector
2437   // comparison in 32-bit to get a higher degree of SIMD parallelism.
2438   // Otherwise we perform the vector comparison using 64-bit indices.
2439   Type idxType =
2440       force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
2441   DenseIntElementsAttr indicesAttr;
2442   if (dim == 0 && force32BitVectorIndices) {
2443     indicesAttr = DenseIntElementsAttr::get(
2444         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
2445   } else if (dim == 0) {
2446     indicesAttr = DenseIntElementsAttr::get(
2447         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
2448   } else if (force32BitVectorIndices) {
2449     indicesAttr = rewriter.getI32VectorAttr(
2450         llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
2451   } else {
2452     indicesAttr = rewriter.getI64VectorAttr(
2453         llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
2454   }
2455   Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
2456   // Add in an offset if requested.
2457   if (off) {
2458     Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
2459     Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
2460     indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
2461   }
2462   // Construct the vector comparison.
2463   Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
2464   Value bounds =
2465       rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
2466   return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
2467                                         bounds);
2468 }
2469 
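/// Materializes an explicit mask for a rank-0/rank-1 vector transfer that has
/// an out-of-bounds dimension: a `vector.create_mask` covering
/// [0 .. dim - offset) is created (and intersected with any existing mask),
/// assigned as the transfer's mask operand, and the transfer is then marked
/// in-bounds.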
2470 template <typename ConcreteOp>
2471 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
2472 public:
2473   explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
2474       : mlir::OpRewritePattern<ConcreteOp>(context),
2475         force32BitVectorIndices(enableIndexOpt) {}
2476 
2477   LogicalResult matchAndRewrite(ConcreteOp xferOp,
2478                                 PatternRewriter &rewriter) const override {
2479     if (!xferOp.hasOutOfBoundsDim())
2480       return failure();
2481 
2482     if (xferOp.getVectorType().getRank() > 1 ||
2483         llvm::size(xferOp.getIndices()) == 0)
2484       return failure();
2485 
2486     Location loc = xferOp->getLoc();
2487     VectorType vtp = xferOp.getVectorType();
2488 
2489     // Create the in-bounds mask with all elements between [0 .. dim - offset)
2490     // set and [dim - offset .. vector_length) unset.
2491     //
2492     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
2493     //       dimensions here.
2494     unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
2495     Value off = xferOp.getIndices()[lastIndex];
2496     Value dim =
2497         vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
2498     Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
2499     Value mask = rewriter.create<vector::CreateMaskOp>(
2500         loc,
2501         VectorType::get(vtp.getShape(), rewriter.getI1Type(),
2502                         vtp.getNumScalableDims()),
2503         b);
2504     if (xferOp.getMask()) {
2505       // Intersect the in-bounds with the mask specified as an op parameter.
2506       mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
2507     }
2508 
2509     rewriter.updateRootInPlace(xferOp, [&]() {
2510       xferOp.getMaskMutable().assign(mask);
2511       xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
2512     });
2513 
2514     return success();
2515   }
2516 
2517 private:
2518   const bool force32BitVectorIndices;
2519 };
2520 
2521 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
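/// E.g. (an illustrative example), `vector.create_mask %c3 : vector<4xi1>`
/// becomes a comparison of the constant index vector [0, 1, 2, 3] against a
/// splat of %c3, yielding the mask [1, 1, 1, 0].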
2522 class VectorCreateMaskOpConversion
2523     : public OpRewritePattern<vector::CreateMaskOp> {
2524 public:
2525   explicit VectorCreateMaskOpConversion(MLIRContext *context,
2526                                         bool enableIndexOpt)
2527       : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
2528         force32BitVectorIndices(enableIndexOpt) {}
2529 
2530   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
2531                                 PatternRewriter &rewriter) const override {
2532     auto dstType = op.getType();
2533     if (dstType.cast<VectorType>().isScalable())
2534       return failure();
2535     int64_t rank = dstType.getRank();
2536     if (rank > 1)
2537       return failure();
2538     rewriter.replaceOp(
2539         op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
2540                                   rank == 0 ? 0 : dstType.getDimSize(0),
2541                                   op.getOperand(0)));
2542     return success();
2543   }
2544 
2545 private:
2546   const bool force32BitVectorIndices;
2547 };
2548 
// Drop innermost contiguous unit dimensions from the transfer_read operand.
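//
// A rough sketch of the rewrite (illustrative only; offsets and the exact
// rank-reduced subview type depend on the source layout):
// ```
//   %v = vector.transfer_read %A[%c0, %c0], %pad {in_bounds = [true, true]}
//       : memref<16x1xf32>, vector<8x1xf32>
// ```
// becomes approximately:
// ```
//   %sv = memref.subview %A[0, 0] [16, 1] [1, 1]
//       : memref<16x1xf32> to memref<16xf32>
//   %r = vector.transfer_read %sv[%c0], %pad {in_bounds = [true]}
//       : memref<16xf32>, vector<8xf32>
//   %v = vector.shape_cast %r : vector<8xf32> to vector<8x1xf32>
// ```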
2550 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
2551   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
2552 
2553   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
2554                                 PatternRewriter &rewriter) const override {
2555     // TODO: support 0-d corner case.
2556     if (readOp.getTransferRank() == 0)
2557       return failure();
2558 
2559     // TODO: support mask.
2560     if (readOp.getMask())
2561       return failure();
2562 
2563     auto srcType = readOp.getSource().getType().dyn_cast<MemRefType>();
2564     if (!srcType || !srcType.hasStaticShape())
2565       return failure();
2566 
2567     if (!readOp.getPermutationMap().isMinorIdentity())
2568       return failure();
2569 
2570     auto targetType = readOp.getVectorType();
2571     if (targetType.getRank() <= 1)
2572       return failure();
2573 
2574     SmallVector<int64_t> srcStrides;
2575     int64_t srcOffset;
2576     if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2577       return failure();
2578 
2579     size_t dimsToDrop = 0;
2580     for (size_t i = 1; i < srcStrides.size(); ++i) {
2581       int dim = srcType.getRank() - i - 1;
2582       if (srcStrides[dim] == 1) {
2583         dimsToDrop++;
2584       } else {
2585         break;
2586       }
2587     }
2588     if (dimsToDrop == 0)
2589       return failure();
2590 
2591     auto resultTargetVecType =
2592         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
2593                         targetType.getElementType());
2594 
2595     MemRefType resultMemrefType;
2596     if (srcType.getLayout().getAffineMap().isIdentity()) {
2597       resultMemrefType = MemRefType::get(
2598           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2599           {}, srcType.getMemorySpaceAsInt());
2600     } else {
2601       AffineMap map = srcType.getLayout().getAffineMap();
2602       int numSymbols = map.getNumSymbols();
2603       for (size_t i = 0; i < dimsToDrop; ++i) {
2604         int dim = srcType.getRank() - i - 1;
2605         map = map.replace(rewriter.getAffineDimExpr(dim),
2606                           rewriter.getAffineConstantExpr(0),
2607                           map.getNumDims() - 1, numSymbols);
2608       }
2609       resultMemrefType = MemRefType::get(
2610           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2611           map, srcType.getMemorySpaceAsInt());
2612     }
2613 
2614     auto loc = readOp.getLoc();
2615     SmallVector<int64_t> offsets(srcType.getRank(), 0);
2616     SmallVector<int64_t> strides(srcType.getRank(), 1);
2617 
2618     ArrayAttr inBoundsAttr =
2619         readOp.getInBounds()
2620             ? rewriter.getArrayAttr(
2621                   readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
2622             : ArrayAttr();
2623     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
2624         loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
2625         strides);
2626     auto permMap = getTransferMinorIdentityMap(
2627         rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
2628     Value result = rewriter.create<vector::TransferReadOp>(
2629         loc, resultTargetVecType, rankedReducedView,
2630         readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
2631         readOp.getPadding(),
2632         // TODO: support mask.
2633         /*mask=*/Value(), inBoundsAttr);
2634     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
2635                                                      result);
2636     return success();
2637   }
2638 };
2639 
2640 namespace {
2641 
/// Returns true if the vector combining kind is consistent with the integer
/// or float element type.
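/// For example, MINF/MAXF are valid only for float element types, AND/OR/XOR
/// and the signed/unsigned min/max kinds only for integer element types, and
/// ADD/MUL are valid for either.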
2644 static bool isValidKind(bool isInt, vector::CombiningKind kind) {
2645   using vector::CombiningKind;
2646   enum class KindType { FLOAT, INT, INVALID };
2647   KindType type{KindType::INVALID};
2648   switch (kind) {
2649   case CombiningKind::MINF:
2650   case CombiningKind::MAXF:
2651     type = KindType::FLOAT;
2652     break;
2653   case CombiningKind::MINUI:
2654   case CombiningKind::MINSI:
2655   case CombiningKind::MAXUI:
2656   case CombiningKind::MAXSI:
2657   case CombiningKind::AND:
2658   case CombiningKind::OR:
2659   case CombiningKind::XOR:
2660     type = KindType::INT;
2661     break;
2662   case CombiningKind::ADD:
2663   case CombiningKind::MUL:
2664     type = isInt ? KindType::INT : KindType::FLOAT;
2665     break;
2666   }
2667   bool isValidIntKind = (type == KindType::INT) && isInt;
2668   bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
2669   return (isValidIntKind || isValidFloatKind);
2670 }
2671 
/// This function constructs the appropriate integer or float
/// operation given the vector combining kind and operands. The
/// supported int operations are: add, mul, min (signed/unsigned),
/// max (signed/unsigned), and, or, xor. The supported float
/// operations are: add, mul, min, and max.
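/// For example, ADD lowers to arith.addf for float operands and to arith.addi
/// for integer operands.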
2677 static Value genOperator(Location loc, Value x, Value y,
2678                          vector::CombiningKind kind,
2679                          PatternRewriter &rewriter) {
2680   using vector::CombiningKind;
2681 
2682   auto elType = x.getType().cast<VectorType>().getElementType();
2683   bool isInt = elType.isIntOrIndex();
2684 
2685   Value combinedResult{nullptr};
2686   switch (kind) {
2687   case CombiningKind::ADD:
2688     if (isInt)
2689       combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
2690     else
2691       combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
2692     break;
2693   case CombiningKind::MUL:
2694     if (isInt)
2695       combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
2696     else
2697       combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
2698     break;
2699   case CombiningKind::MINUI:
2700     combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
2701     break;
2702   case CombiningKind::MINSI:
2703     combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
2704     break;
2705   case CombiningKind::MAXUI:
2706     combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
2707     break;
2708   case CombiningKind::MAXSI:
2709     combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
2710     break;
2711   case CombiningKind::AND:
2712     combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
2713     break;
2714   case CombiningKind::OR:
2715     combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
2716     break;
2717   case CombiningKind::XOR:
2718     combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
2719     break;
2720   case CombiningKind::MINF:
2721     combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
2722     break;
2723   case CombiningKind::MAXF:
2724     combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
2725     break;
2726   }
2727   return combinedResult;
2728 }
2729 
/// Convert a vector.scan op into arith ops and
/// vector.insert_strided_slice / vector.extract_strided_slice ops.
2732 ///
2733 /// Ex:
2734 /// ```
///   %0:2 = vector.scan <mul>, %arg0, %arg1
///     {inclusive = true, reduction_dim = 1} :
///     (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
/// ```
/// Gets converted to:
/// ```
///   %cst = arith.constant dense<0> : vector<2x3xi32>
///   %0 = vector.extract_strided_slice %arg0
///     {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
///     : vector<2x3xi32> to vector<2x1xi32>
///   %1 = vector.insert_strided_slice %0, %cst
///     {offsets = [0, 0], strides = [1, 1]}
///     : vector<2x1xi32> into vector<2x3xi32>
///   %2 = vector.extract_strided_slice %arg0
///     {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
///     : vector<2x3xi32> to vector<2x1xi32>
///   %3 = arith.muli %0, %2 : vector<2x1xi32>
///   %4 = vector.insert_strided_slice %3, %1
///     {offsets = [0, 1], strides = [1, 1]}
///     : vector<2x1xi32> into vector<2x3xi32>
///   %5 = vector.extract_strided_slice %arg0
///     {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
///     : vector<2x3xi32> to vector<2x1xi32>
///   %6 = arith.muli %3, %5 : vector<2x1xi32>
///   %7 = vector.insert_strided_slice %6, %4
///     {offsets = [0, 2], strides = [1, 1]}
///     : vector<2x1xi32> into vector<2x3xi32>
///   %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
///   return %7, %8 : vector<2x3xi32>, vector<2xi32>
2756 /// ```
2757 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
2758   using OpRewritePattern<vector::ScanOp>::OpRewritePattern;
2759 
2760   LogicalResult matchAndRewrite(vector::ScanOp scanOp,
2761                                 PatternRewriter &rewriter) const override {
2762     auto loc = scanOp.getLoc();
2763     VectorType destType = scanOp.getDestType();
2764     ArrayRef<int64_t> destShape = destType.getShape();
2765     auto elType = destType.getElementType();
2766     bool isInt = elType.isIntOrIndex();
2767     if (!isValidKind(isInt, scanOp.getKind()))
2768       return failure();
2769 
2770     VectorType resType = VectorType::get(destShape, elType);
2771     Value result = rewriter.create<arith::ConstantOp>(
2772         loc, resType, rewriter.getZeroAttr(resType));
2773     int64_t reductionDim = scanOp.getReductionDim();
2774     bool inclusive = scanOp.getInclusive();
2775     int64_t destRank = destType.getRank();
2776     VectorType initialValueType = scanOp.getInitialValueType();
2777     int64_t initialValueRank = initialValueType.getRank();
2778 
2779     SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
2780     reductionShape[reductionDim] = 1;
2781     VectorType reductionType = VectorType::get(reductionShape, elType);
2782     SmallVector<int64_t> offsets(destRank, 0);
2783     SmallVector<int64_t> strides(destRank, 1);
2784     SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
2785     sizes[reductionDim] = 1;
2786     ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
2787     ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
2788 
2789     Value lastOutput, lastInput;
2790     for (int i = 0; i < destShape[reductionDim]; i++) {
2791       offsets[reductionDim] = i;
2792       ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
2793       Value input = rewriter.create<vector::ExtractStridedSliceOp>(
2794           loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
2795           scanStrides);
2796       Value output;
2797       if (i == 0) {
2798         if (inclusive) {
2799           output = input;
2800         } else {
2801           if (initialValueRank == 0) {
2802             // ShapeCastOp cannot handle 0-D vectors
2803             output = rewriter.create<vector::BroadcastOp>(
2804                 loc, input.getType(), scanOp.getInitialValue());
2805           } else {
2806             output = rewriter.create<vector::ShapeCastOp>(
2807                 loc, input.getType(), scanOp.getInitialValue());
2808           }
2809         }
2810       } else {
2811         Value y = inclusive ? input : lastInput;
2812         output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
2813         assert(output != nullptr);
2814       }
2815       result = rewriter.create<vector::InsertStridedSliceOp>(
2816           loc, output, result, offsets, strides);
2817       lastOutput = output;
2818       lastInput = input;
2819     }
2820 
2821     Value reduction;
2822     if (initialValueRank == 0) {
2823       Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
2824       reduction =
2825           rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
2826     } else {
2827       reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
2828                                                        lastOutput);
2829     }
2830 
2831     rewriter.replaceOp(scanOp, {result, reduction});
2832     return success();
2833   }
2834 };
2835 
2836 } // namespace
2837 
2838 void mlir::vector::populateVectorMaskMaterializationPatterns(
2839     RewritePatternSet &patterns, bool force32BitVectorIndices) {
2840   patterns.add<VectorCreateMaskOpConversion,
2841                MaterializeTransferMask<vector::TransferReadOp>,
2842                MaterializeTransferMask<vector::TransferWriteOp>>(
2843       patterns.getContext(), force32BitVectorIndices);
2844 }
2845 
2846 void mlir::vector::populateShapeCastFoldingPatterns(
2847     RewritePatternSet &patterns) {
2848   patterns.add<ShapeCastOpFolder>(patterns.getContext());
2849 }
2850 
2851 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
2852     RewritePatternSet &patterns) {
2853   patterns.add<BubbleDownVectorBitCastForExtract,
2854                BubbleDownBitCastForStridedSliceExtract,
2855                BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
2856 }
2857 
2858 void mlir::vector::populateVectorBroadcastLoweringPatterns(
2859     RewritePatternSet &patterns) {
2860   patterns.add<BroadcastOpLowering>(patterns.getContext());
2861 }
2862 
2863 void mlir::vector::populateVectorMaskOpLoweringPatterns(
2864     RewritePatternSet &patterns) {
2865   patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
2866       patterns.getContext());
2867 }
2868 
2869 void mlir::vector::populateVectorShapeCastLoweringPatterns(
2870     RewritePatternSet &patterns) {
2871   patterns.add<ShapeCastOp2DDownCastRewritePattern,
2872                ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
2873       patterns.getContext());
2874 }
2875 
2876 void mlir::vector::populateVectorContractLoweringPatterns(
2877     RewritePatternSet &patterns, VectorTransformsOptions options) {
2878   patterns.add<OuterProductOpLowering>(patterns.getContext());
2879   patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
2880                ContractionOpToOuterProductOpLowering>(options,
2881                                                       patterns.getContext());
2882 }
2883 
2884 void mlir::vector::populateVectorTransposeLoweringPatterns(
2885     RewritePatternSet &patterns, VectorTransformsOptions options) {
2886   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
2887       options, patterns.getContext());
2888 }
2889 
2890 void mlir::vector::populateVectorReductionToContractPatterns(
2891     RewritePatternSet &patterns) {
2892   patterns.add<MultiReduceToContract, CombineContractBroadcast,
2893                CombineContractTranspose, ReorderCastOpsOnBroadcast,
2894                ReorderElementwiseOpsOnTranspose>(patterns.getContext());
2895 }
2896 
2897 void mlir::vector::
2898     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
2899         RewritePatternSet &patterns) {
2900   patterns.add<DropInnerMostUnitDims>(patterns.getContext());
2901 }
2902 
2903 void mlir::vector::populateVectorTransferLoweringPatterns(
2904     RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
2905   patterns.add<TransferReadToVectorLoadLowering,
2906                TransferWriteToVectorStoreLowering>(patterns.getContext(),
2907                                                    maxTransferRank);
2908   patterns
2909       .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
2910           patterns.getContext());
2911 }
2912 
2913 void mlir::vector::populateVectorScanLoweringPatterns(
2914     RewritePatternSet &patterns) {
2915   patterns.add<ScanToArithOps>(patterns.getContext());
2916 }
2917