1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites as 1->N patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
14 
15 #include <type_traits>
16 
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
19 #include "mlir/Dialect/Linalg/IR/Linalg.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/SCF/SCF.h"
22 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
23 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
24 #include "mlir/IR/ImplicitLocOpBuilder.h"
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "mlir/Interfaces/VectorInterfaces.h"
28 
29 #include "llvm/ADT/DenseSet.h"
30 #include "llvm/ADT/MapVector.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/raw_ostream.h"
35 
36 #define DEBUG_TYPE "vector-to-vector"
37 
38 using namespace mlir;
39 using namespace mlir::vector;
40 
41 // Helper to find an index in an affine map.
42 static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
43   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
44     int64_t idx = map.getDimPosition(i);
45     if (idx == index)
46       return i;
47   }
48   return None;
49 }
50 
51 // Helper to construct iterator types with one index removed.
52 static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
53                                             int64_t index) {
54   SmallVector<Attribute, 4> results;
55   for (const auto &it : llvm::enumerate(iteratorTypes)) {
56     int64_t idx = it.index();
57     if (idx == index)
58       continue;
59     results.push_back(it.value());
60   }
61   return results;
62 }
63 
64 // Helper to construct an affine map with one index removed.
65 static AffineMap adjustMap(AffineMap map, int64_t index,
66                            PatternRewriter &rewriter) {
67   auto *ctx = rewriter.getContext();
68   SmallVector<AffineExpr, 4> results;
69   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
70     int64_t idx = map.getDimPosition(i);
71     if (idx == index)
72       continue;
73     // Re-insert remaining indices, but renamed when occurring
74     // after the removed index.
75     auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
76     results.push_back(targetExpr);
77   }
78   return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
79 }
80 
81 // Helper method to possibly drop a dimension in a load.
82 // TODO
83 static Value reshapeLoad(Location loc, Value val, VectorType type,
84                          int64_t index, int64_t pos,
85                          PatternRewriter &rewriter) {
86   if (index == -1)
87     return val;
88   Type lowType = VectorType::Builder(type).dropDim(0);
89   // At extraction dimension?
90   if (index == 0) {
91     auto posAttr = rewriter.getI64ArrayAttr(pos);
92     return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
93   }
94   // Unroll leading dimensions.
95   VectorType vType = lowType.cast<VectorType>();
96   Type resType = VectorType::Builder(type).dropDim(index);
97   auto resVectorType = resType.cast<VectorType>();
98   Value result = rewriter.create<arith::ConstantOp>(
99       loc, resVectorType, rewriter.getZeroAttr(resVectorType));
100   for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
101     auto posAttr = rewriter.getI64ArrayAttr(d);
102     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
103     Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
104     result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
105                                                posAttr);
106   }
107   return result;
108 }
109 
110 // Helper method to possibly drop a dimension in a store.
111 // TODO
112 static Value reshapeStore(Location loc, Value val, Value result,
113                           VectorType type, int64_t index, int64_t pos,
114                           PatternRewriter &rewriter) {
115   // Unmodified?
116   if (index == -1)
117     return val;
118   // At insertion dimension?
119   if (index == 0) {
120     auto posAttr = rewriter.getI64ArrayAttr(pos);
121     return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
122   }
123   // Unroll leading dimensions.
124   Type lowType = VectorType::Builder(type).dropDim(0);
125   VectorType vType = lowType.cast<VectorType>();
126   Type insType = VectorType::Builder(vType).dropDim(0);
127   for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
128     auto posAttr = rewriter.getI64ArrayAttr(d);
129     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
130     Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
131     Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
132     result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
133   }
134   return result;
135 }
136 
137 template <typename IntType>
138 static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
139   return llvm::to_vector<4>(llvm::map_range(
140       arrayAttr.getAsRange<IntegerAttr>(),
141       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
142 }
143 
144 namespace {
145 
146 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
147 //
148 // Example:
149 //
150 //  The following MLIR with cancelling ShapeCastOps:
151 //
152 //   %0 = source : vector<5x4x2xf32>
153 //   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
154 //   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
155 //   %3 = user %2 : vector<5x4x2xf32>
156 //
157 //  Should canonicalize to the following:
158 //
159 //   %0 = source : vector<5x4x2xf32>
160 //   %1 = user %0 : vector<5x4x2xf32>
161 //
162 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
163   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
164 
165   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
166                                 PatternRewriter &rewriter) const override {
167     // Check if 'shapeCastOp' has vector source/result type.
168     auto sourceVectorType =
169         shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
170     auto resultVectorType =
171         shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
172     if (!sourceVectorType || !resultVectorType)
173       return failure();
174 
175     // Check if shape cast op source operand is also a shape cast op.
176     auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
177         shapeCastOp.source().getDefiningOp());
178     if (!sourceShapeCastOp)
179       return failure();
180     auto operandSourceVectorType =
181         sourceShapeCastOp.source().getType().cast<VectorType>();
182     auto operandResultVectorType = sourceShapeCastOp.getType();
183 
184     // Check if shape cast operations invert each other.
185     if (operandSourceVectorType != resultVectorType ||
186         operandResultVectorType != sourceVectorType)
187       return failure();
188 
189     rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
190     return success();
191   }
192 };
193 
194 /// Progressive lowering of BroadcastOp.
195 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
196 public:
197   using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
198 
199   LogicalResult matchAndRewrite(vector::BroadcastOp op,
200                                 PatternRewriter &rewriter) const override {
201     auto loc = op.getLoc();
202     VectorType dstType = op.getVectorType();
203     VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
204     Type eltType = dstType.getElementType();
205 
206     // Scalar to any vector can use splat.
207     if (!srcType) {
208       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.source());
209       return success();
210     }
211 
212     // Determine rank of source and destination.
213     int64_t srcRank = srcType.getRank();
214     int64_t dstRank = dstType.getRank();
215 
216     // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
217     if (srcRank <= 1 && dstRank == 1) {
218       Value ext;
219       if (srcRank == 0)
220         ext = rewriter.create<vector::ExtractElementOp>(loc, op.source());
221       else
222         ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
223       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
224       return success();
225     }
226 
227     // Duplicate this rank.
228     // For example:
229     //   %x = broadcast %y  : k-D to n-D, k < n
230     // becomes:
231     //   %b = broadcast %y  : k-D to (n-1)-D
232     //   %x = [%b,%b,%b,%b] : n-D
233     // becomes:
234     //   %b = [%y,%y]       : (n-1)-D
235     //   %x = [%b,%b,%b,%b] : n-D
236     if (srcRank < dstRank) {
237       // Duplication.
238       VectorType resType =
239           VectorType::get(dstType.getShape().drop_front(), eltType);
240       Value bcst =
241           rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
242       Value result = rewriter.create<arith::ConstantOp>(
243           loc, dstType, rewriter.getZeroAttr(dstType));
244       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
245         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
246       rewriter.replaceOp(op, result);
247       return success();
248     }
249 
250     // Find non-matching dimension, if any.
251     assert(srcRank == dstRank);
252     int64_t m = -1;
253     for (int64_t r = 0; r < dstRank; r++)
254       if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
255         m = r;
256         break;
257       }
258 
259     // All trailing dimensions are the same. Simply pass through.
260     if (m == -1) {
261       rewriter.replaceOp(op, op.source());
262       return success();
263     }
264 
265     // Any non-matching dimension forces a stretch along this rank.
266     // For example:
267     //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
268     // becomes:
269     //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
270     //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
271     //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
272     //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
273     //   %x = [%a,%b,%c,%d]
274     // becomes:
275     //   %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
276     //   %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
277     //   %a = [%u, %v]
278     //   ..
279     //   %x = [%a,%b,%c,%d]
280     VectorType resType =
281         VectorType::get(dstType.getShape().drop_front(), eltType);
282     Value result = rewriter.create<arith::ConstantOp>(
283         loc, dstType, rewriter.getZeroAttr(dstType));
284     if (m == 0) {
285       // Stetch at start.
286       Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
287       Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
288       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
289         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
290     } else {
291       // Stetch not at start.
292       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
293         Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
294         Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
295         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
296       }
297     }
298     rewriter.replaceOp(op, result);
299     return success();
300   }
301 };
302 
303 /// Progressive lowering of TransposeOp.
304 /// One:
305 ///   %x = vector.transpose %y, [1, 0]
306 /// is replaced by:
307 ///   %z = arith.constant dense<0.000000e+00>
308 ///   %0 = vector.extract %y[0, 0]
309 ///   %1 = vector.insert %0, %z [0, 0]
310 ///   ..
311 ///   %x = vector.insert .., .. [.., ..]
312 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
313 public:
314   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
315 
316   TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
317                       MLIRContext *context)
318       : OpRewritePattern<vector::TransposeOp>(context),
319         vectorTransformOptions(vectorTransformOptions) {}
320 
321   LogicalResult matchAndRewrite(vector::TransposeOp op,
322                                 PatternRewriter &rewriter) const override {
323     auto loc = op.getLoc();
324 
325     VectorType resType = op.getResultType();
326 
327     // Set up convenience transposition table.
328     SmallVector<int64_t, 4> transp;
329     for (auto attr : op.transp())
330       transp.push_back(attr.cast<IntegerAttr>().getInt());
331 
332     if (vectorTransformOptions.vectorTransposeLowering ==
333             vector::VectorTransposeLowering::Shuffle &&
334         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
335       return rewriter.notifyMatchFailure(
336           op, "Options specifies lowering to shuffle");
337 
338     // Handle a true 2-D matrix transpose differently when requested.
339     if (vectorTransformOptions.vectorTransposeLowering ==
340             vector::VectorTransposeLowering::Flat &&
341         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
342       Type flattenedType =
343           VectorType::get(resType.getNumElements(), resType.getElementType());
344       auto matrix =
345           rewriter.create<vector::ShapeCastOp>(loc, flattenedType, op.vector());
346       auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
347       auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
348       Value trans = rewriter.create<vector::FlatTransposeOp>(
349           loc, flattenedType, matrix, rows, columns);
350       rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
351       return success();
352     }
353 
354     // Generate fully unrolled extract/insert ops.
355     Value result = rewriter.create<arith::ConstantOp>(
356         loc, resType, rewriter.getZeroAttr(resType));
357     SmallVector<int64_t, 4> lhs(transp.size(), 0);
358     SmallVector<int64_t, 4> rhs(transp.size(), 0);
359     rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs,
360                                          op.vector(), result, rewriter));
361     return success();
362   }
363 
364 private:
365   // Builds the indices arrays for the lhs and rhs. Generates the extract/insert
366   // operation when al ranks are exhausted.
367   Value expandIndices(Location loc, VectorType resType, int64_t pos,
368                       SmallVector<int64_t, 4> &transp,
369                       SmallVector<int64_t, 4> &lhs,
370                       SmallVector<int64_t, 4> &rhs, Value input, Value result,
371                       PatternRewriter &rewriter) const {
372     if (pos >= resType.getRank()) {
373       auto ridx = rewriter.getI64ArrayAttr(rhs);
374       auto lidx = rewriter.getI64ArrayAttr(lhs);
375       Type eltType = resType.getElementType();
376       Value e = rewriter.create<vector::ExtractOp>(loc, eltType, input, ridx);
377       return rewriter.create<vector::InsertOp>(loc, resType, e, result, lidx);
378     }
379     for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) {
380       lhs[pos] = d;
381       rhs[transp[pos]] = d;
382       result = expandIndices(loc, resType, pos + 1, transp, lhs, rhs, input,
383                              result, rewriter);
384     }
385     return result;
386   }
387 
388   /// Options to control the vector patterns.
389   vector::VectorTransformsOptions vectorTransformOptions;
390 };
391 
392 /// Rewrite a 2-D vector.transpose as a sequence of:
393 ///   vector.shape_cast 2D -> 1D
394 ///   vector.shuffle
395 ///   vector.shape_cast 1D -> 2D
396 class TransposeOp2DToShuffleLowering
397     : public OpRewritePattern<vector::TransposeOp> {
398 public:
399   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
400 
401   TransposeOp2DToShuffleLowering(
402       vector::VectorTransformsOptions vectorTransformOptions,
403       MLIRContext *context)
404       : OpRewritePattern<vector::TransposeOp>(context),
405         vectorTransformOptions(vectorTransformOptions) {}
406 
407   LogicalResult matchAndRewrite(vector::TransposeOp op,
408                                 PatternRewriter &rewriter) const override {
409     auto loc = op.getLoc();
410 
411     VectorType srcType = op.getVectorType();
412     if (srcType.getRank() != 2)
413       return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
414 
415     SmallVector<int64_t, 4> transp;
416     for (auto attr : op.transp())
417       transp.push_back(attr.cast<IntegerAttr>().getInt());
418     if (transp[0] != 1 && transp[1] != 0)
419       return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
420 
421     if (vectorTransformOptions.vectorTransposeLowering !=
422         VectorTransposeLowering::Shuffle)
423       return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
424 
425     int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
426     Value casted = rewriter.create<vector::ShapeCastOp>(
427         loc, VectorType::get({m * n}, srcType.getElementType()), op.vector());
428     SmallVector<int64_t> mask;
429     mask.reserve(m * n);
430     for (int64_t j = 0; j < n; ++j)
431       for (int64_t i = 0; i < m; ++i)
432         mask.push_back(i * n + j);
433 
434     Value shuffled =
435         rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
436     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
437                                                      shuffled);
438 
439     return success();
440   }
441 
442 private:
443   /// Options to control the vector patterns.
444   vector::VectorTransformsOptions vectorTransformOptions;
445 };
446 
447 /// Progressive lowering of OuterProductOp.
448 /// One:
449 ///   %x = vector.outerproduct %lhs, %rhs, %acc
450 /// is replaced by:
451 ///   %z = zero-result
452 ///   %0 = vector.extract %lhs[0]
453 ///   %1 = vector.broadcast %0
454 ///   %2 = vector.extract %acc[0]
455 ///   %3 = vector.fma %1, %rhs, %2
456 ///   %4 = vector.insert %3, %z[0]
457 ///   ..
458 ///   %x = vector.insert %.., %..[N-1]
459 ///
460 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
461 public:
462   using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
463 
464   LogicalResult matchAndRewrite(vector::OuterProductOp op,
465                                 PatternRewriter &rewriter) const override {
466     auto loc = op.getLoc();
467 
468     VectorType lhsType = op.getOperandVectorTypeLHS();
469     VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
470     VectorType resType = op.getVectorType();
471     Type eltType = resType.getElementType();
472     bool isInt = eltType.isa<IntegerType, IndexType>();
473     Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
474     vector::CombiningKind kind = op.kind();
475 
476     if (!rhsType) {
477       // Special case: AXPY operation.
478       Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
479       Optional<Value> mult =
480           isInt ? genMultI(loc, op.lhs(), b, acc, kind, rewriter)
481                 : genMultF(loc, op.lhs(), b, acc, kind, rewriter);
482       if (!mult.hasValue())
483         return failure();
484       rewriter.replaceOp(op, mult.getValue());
485       return success();
486     }
487 
488     Value result = rewriter.create<arith::ConstantOp>(
489         loc, resType, rewriter.getZeroAttr(resType));
490     for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
491       auto pos = rewriter.getI64ArrayAttr(d);
492       Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
493       Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
494       Value r = nullptr;
495       if (acc)
496         r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
497       Optional<Value> m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter)
498                                 : genMultF(loc, a, op.rhs(), r, kind, rewriter);
499       if (!m.hasValue())
500         return failure();
501       result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
502                                                  result, pos);
503     }
504     rewriter.replaceOp(op, result);
505     return success();
506   }
507 
508 private:
509   static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
510                                   vector::CombiningKind kind,
511                                   PatternRewriter &rewriter) {
512     using vector::CombiningKind;
513 
514     auto mul = rewriter.create<arith::MulIOp>(loc, x, y);
515     if (!acc)
516       return Optional<Value>(mul);
517 
518     if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
519       // Only valid for floating point types.
520       return Optional<Value>();
521 
522     return makeArithReduction(rewriter, loc, kind, mul, acc);
523   }
524 
525   static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
526                                   vector::CombiningKind kind,
527                                   PatternRewriter &rewriter) {
528     using vector::CombiningKind;
529 
530     // Special case for fused multiply-add.
531     if (acc && kind == CombiningKind::ADD) {
532       return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
533     }
534 
535     auto mul = rewriter.create<arith::MulFOp>(loc, x, y);
536 
537     if (!acc)
538       return Optional<Value>(mul);
539 
540     if (kind == CombiningKind::ADD || kind == CombiningKind::AND ||
541         kind == CombiningKind::MINUI || kind == CombiningKind::MINSI ||
542         kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI ||
543         kind == CombiningKind::OR || kind == CombiningKind::XOR)
544       // Already handled or only valid for integer types.
545       return Optional<Value>();
546 
547     return makeArithReduction(rewriter, loc, kind, mul, acc);
548   }
549 };
550 
551 /// Progressive lowering of ConstantMaskOp.
552 /// One:
553 ///   %x = vector.constant_mask [a,b]
554 /// is replaced by:
555 ///   %z = zero-result
556 ///   %l = vector.constant_mask [b]
557 ///   %4 = vector.insert %l, %z[0]
558 ///   ..
559 ///   %x = vector.insert %l, %..[a-1]
560 /// until a one-dimensional vector is reached. All these operations
561 /// will be folded at LLVM IR level.
562 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
563 public:
564   using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
565 
566   LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
567                                 PatternRewriter &rewriter) const override {
568     auto loc = op.getLoc();
569     auto dstType = op.getType();
570     auto eltType = dstType.getElementType();
571     auto dimSizes = op.mask_dim_sizes();
572     int64_t rank = dstType.getRank();
573 
574     if (rank == 0) {
575       assert(dimSizes.size() == 1 &&
576              "Expected exactly one dim size for a 0-D vector");
577       bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
578       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
579           op, dstType,
580           DenseIntElementsAttr::get(
581               VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
582               ArrayRef<bool>{value}));
583       return success();
584     }
585 
586     int64_t trueDim = std::min(dstType.getDimSize(0),
587                                dimSizes[0].cast<IntegerAttr>().getInt());
588 
589     if (rank == 1) {
590       // Express constant 1-D case in explicit vector form:
591       //   [T,..,T,F,..,F].
592       SmallVector<bool, 4> values(dstType.getDimSize(0));
593       for (int64_t d = 0; d < trueDim; d++)
594         values[d] = true;
595       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
596           op, dstType, rewriter.getBoolVectorAttr(values));
597       return success();
598     }
599 
600     VectorType lowType =
601         VectorType::get(dstType.getShape().drop_front(), eltType);
602     SmallVector<int64_t, 4> newDimSizes;
603     for (int64_t r = 1; r < rank; r++)
604       newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
605     Value trueVal = rewriter.create<vector::ConstantMaskOp>(
606         loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
607     Value result = rewriter.create<arith::ConstantOp>(
608         loc, dstType, rewriter.getZeroAttr(dstType));
609     for (int64_t d = 0; d < trueDim; d++) {
610       auto pos = rewriter.getI64ArrayAttr(d);
611       result =
612           rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
613     }
614     rewriter.replaceOp(op, result);
615     return success();
616   }
617 };
618 
619 /// Progressive lowering of CreateMaskOp.
620 /// One:
621 ///   %x = vector.create_mask %a, ... : vector<dx...>
622 /// is replaced by:
623 ///   %l = vector.create_mask ... : vector<...>  ; one lower rank
624 ///   %0 = arith.cmpi "slt", %ci, %a       |
625 ///   %1 = select %0, %l, %zeroes    |
626 ///   %r = vector.insert %1, %pr [i] | d-times
627 ///   %x = ....
628 /// until a one-dimensional vector is reached.
629 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
630 public:
631   using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
632 
633   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
634                                 PatternRewriter &rewriter) const override {
635     auto dstType = op.getResult().getType().cast<VectorType>();
636     int64_t rank = dstType.getRank();
637     if (rank <= 1)
638       return rewriter.notifyMatchFailure(
639           op, "0-D and 1-D vectors are handled separately");
640 
641     auto loc = op.getLoc();
642     auto eltType = dstType.getElementType();
643     int64_t dim = dstType.getDimSize(0);
644     Value idx = op.getOperand(0);
645 
646     VectorType lowType =
647         VectorType::get(dstType.getShape().drop_front(), eltType);
648     Value trueVal = rewriter.create<vector::CreateMaskOp>(
649         loc, lowType, op.getOperands().drop_front());
650     Value falseVal = rewriter.create<arith::ConstantOp>(
651         loc, lowType, rewriter.getZeroAttr(lowType));
652     Value result = rewriter.create<arith::ConstantOp>(
653         loc, dstType, rewriter.getZeroAttr(dstType));
654     for (int64_t d = 0; d < dim; d++) {
655       Value bnd =
656           rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
657       Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
658                                                  bnd, idx);
659       Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
660       auto pos = rewriter.getI64ArrayAttr(d);
661       result =
662           rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
663     }
664     rewriter.replaceOp(op, result);
665     return success();
666   }
667 };
668 
669 /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
670 /// vectors progressively on the way to target llvm.matrix intrinsics.
671 /// This iterates over the most major dimension of the 2-D vector and performs
672 /// rewrites into:
673 ///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
674 class ShapeCastOp2DDownCastRewritePattern
675     : public OpRewritePattern<vector::ShapeCastOp> {
676 public:
677   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
678 
679   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
680                                 PatternRewriter &rewriter) const override {
681     auto sourceVectorType = op.getSourceVectorType();
682     auto resultVectorType = op.getResultVectorType();
683     if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
684       return failure();
685 
686     auto loc = op.getLoc();
687     Value desc = rewriter.create<arith::ConstantOp>(
688         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
689     unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
690     for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
691       Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
692       desc = rewriter.create<vector::InsertStridedSliceOp>(
693           loc, vec, desc,
694           /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
695     }
696     rewriter.replaceOp(op, desc);
697     return success();
698   }
699 };
700 
701 /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
702 /// vectors progressively.
703 /// This iterates over the most major dimension of the 2-D vector and performs
704 /// rewrites into:
705 ///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
706 /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
707 class ShapeCastOp2DUpCastRewritePattern
708     : public OpRewritePattern<vector::ShapeCastOp> {
709 public:
710   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
711 
712   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
713                                 PatternRewriter &rewriter) const override {
714     auto sourceVectorType = op.getSourceVectorType();
715     auto resultVectorType = op.getResultVectorType();
716     if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
717       return failure();
718 
719     auto loc = op.getLoc();
720     Value desc = rewriter.create<arith::ConstantOp>(
721         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
722     unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
723     for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
724       Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
725           loc, op.source(), /*offsets=*/i * mostMinorVectorSize,
726           /*sizes=*/mostMinorVectorSize,
727           /*strides=*/1);
728       desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
729     }
730     rewriter.replaceOp(op, desc);
731     return success();
732   }
733 };
734 
735 // We typically should not lower general shape cast operations into data
736 // movement instructions, since the assumption is that these casts are
737 // optimized away during progressive lowering. For completeness, however,
738 // we fall back to a reference implementation that moves all elements
739 // into the right place if we get here.
740 class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
741 public:
742   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
743 
744   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
745                                 PatternRewriter &rewriter) const override {
746     Location loc = op.getLoc();
747     auto sourceVectorType = op.getSourceVectorType();
748     auto resultVectorType = op.getResultVectorType();
749 
750     // Special case 2D/1D lowerings with better implementations.
751     // TODO: make is ND/1D to allow generic ND->1D->MD.
752     int64_t srcRank = sourceVectorType.getRank();
753     int64_t resRank = resultVectorType.getRank();
754     if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
755       return failure();
756 
757     // Generic ShapeCast lowering path goes all the way down to unrolled scalar
758     // extract/insert chains.
759     // TODO: consider evolving the semantics to only allow 1D source or dest and
760     // drop this potentially very expensive lowering.
761     // Compute number of elements involved in the reshape.
762     int64_t numElts = 1;
763     for (int64_t r = 0; r < srcRank; r++)
764       numElts *= sourceVectorType.getDimSize(r);
765     // Replace with data movement operations:
766     //    x[0,0,0] = y[0,0]
767     //    x[0,0,1] = y[0,1]
768     //    x[0,1,0] = y[0,2]
769     // etc., incrementing the two index vectors "row-major"
770     // within the source and result shape.
771     SmallVector<int64_t, 4> srcIdx(srcRank);
772     SmallVector<int64_t, 4> resIdx(resRank);
773     Value result = rewriter.create<arith::ConstantOp>(
774         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
775     for (int64_t i = 0; i < numElts; i++) {
776       if (i != 0) {
777         incIdx(srcIdx, sourceVectorType, srcRank - 1);
778         incIdx(resIdx, resultVectorType, resRank - 1);
779       }
780       Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
781       result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
782     }
783     rewriter.replaceOp(op, result);
784     return success();
785   }
786 
787 private:
788   static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
789     assert(0 <= r && r < tp.getRank());
790     if (++idx[r] == tp.getDimSize(r)) {
791       idx[r] = 0;
792       incIdx(idx, tp, r - 1);
793     }
794   }
795 };
796 
797 /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
798 /// Ex:
799 /// ```
800 ///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
801 ///   %1 = vector.multi_reduction add, %0 [1]
802 ///     : vector<8x32x16xf32> to vector<8x16xf32>
803 /// ```
804 /// Gets converted to:
805 /// ```
806 ///   %1 = vector.contract {indexing_maps = [
807 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
808 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
809 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
810 ///    iterator_types = ["parallel", "parallel", "reduction"],
811 ///    kind = add} %0, %arg1, %cst_f0
812 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
813 ///  ```
814 struct MultiReduceToContract
815     : public OpRewritePattern<vector::MultiDimReductionOp> {
816   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
817 
818   LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
819                                 PatternRewriter &rewriter) const override {
820     if (reduceOp.kind() != vector::CombiningKind::ADD)
821       return failure();
822     Operation *mulOp = reduceOp.source().getDefiningOp();
823     if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
824       return failure();
825     SmallVector<bool> reductionMask = reduceOp.getReductionMask();
826     auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
827     SmallVector<AffineExpr> exprs;
828     SmallVector<StringRef> iteratorTypes;
829     for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
830       if (!isReduceDim.value()) {
831         iteratorTypes.push_back(getParallelIteratorTypeName());
832         exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
833       } else {
834         iteratorTypes.push_back(getReductionIteratorTypeName());
835       }
836     }
837     auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
838                                  /*symCount=*/0, exprs, reduceOp.getContext());
839     Value zero = rewriter.create<arith::ConstantOp>(
840         reduceOp.getLoc(), reduceOp.getDestType(),
841         rewriter.getZeroAttr(reduceOp.getDestType()));
842     rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
843         reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
844         rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
845         rewriter.getStrArrayAttr(iteratorTypes));
846     return success();
847   }
848 };
849 
850 /// Merge TransposeOp into ContractionOp user.
851 /// Ex:
852 /// ```
853 ///   %0 = vector.transpose %arg0, [2, 0, 1]
854 ///     : vector<32x16x8xf32> to vector<8x32x16xf32>
855 ///   %1 = vector.contract {indexing_maps = [
856 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
857 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
858 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
859 ///    iterator_types = ["parallel", "parallel", "reduction"],
860 ///    kind = add} %0, %arg1, %cst_f0
861 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
862 /// ```
863 /// Gets converted to:
864 /// ```
865 ///   %1 = vector.contract {indexing_maps = [
866 ///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
867 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
868 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
869 ///    iterator_types = ["parallel", "parallel", "reduction"],
870 ///    kind = add} %arg0, %arg1, %cst_f0
871 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
872 ///  ```
873 struct CombineContractTranspose
874     : public OpRewritePattern<vector::ContractionOp> {
875   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
876 
877   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
878                                 PatternRewriter &rewriter) const override {
879     SmallVector<AffineMap, 4> maps =
880         llvm::to_vector<4>(contractOp.getIndexingMaps());
881     Value lhs = contractOp.lhs();
882     Value rhs = contractOp.rhs();
883     size_t index = 0;
884     bool changed = false;
885     for (Value *operand : {&lhs, &rhs}) {
886       AffineMap &map = maps[index++];
887       auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
888       if (!transposeOp)
889         continue;
890       SmallVector<int64_t> perm;
891       transposeOp.getTransp(perm);
892       AffineMap permutationMap = AffineMap::getPermutationMap(
893           extractVector<unsigned>(transposeOp.transp()),
894           contractOp.getContext());
895       map = inversePermutation(permutationMap).compose(map);
896       *operand = transposeOp.vector();
897       changed = true;
898     }
899     if (!changed)
900       return failure();
901     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
902         contractOp, lhs, rhs, contractOp.acc(),
903         rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
904     return success();
905   }
906 };
907 
908 /// Merge BroadcastOp into ContractionOp user.
909 /// Ex:
910 /// ```
911 ///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
912 ///   %1 = vector.contract {indexing_maps = [
913 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
914 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
915 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
916 ///    iterator_types = ["parallel", "parallel", "reduction"],
917 ///    kind = add} %0, %arg1, %cst_f0
918 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
919 /// ```
920 /// Gets converted to:
921 /// ```
922 ///   %1 = vector.contract {indexing_maps = [
923 ///         affine_map<(d0, d1, d2) -> (d1, d2)>,
924 ///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
925 ///         affine_map<(d0, d1, d2) -> (d0, d1)>],
926 ///    iterator_types = ["parallel", "parallel", "reduction"],
927 ///    kind = add} %arg0, %arg1, %cst_f0
928 ///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
929 ///  ```
930 struct CombineContractBroadcast
931     : public OpRewritePattern<vector::ContractionOp> {
932   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
933 
934   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
935                                 PatternRewriter &rewriter) const override {
936     SmallVector<AffineMap, 4> maps =
937         llvm::to_vector<4>(contractOp.getIndexingMaps());
938     Value lhs = contractOp.lhs();
939     Value rhs = contractOp.rhs();
940     size_t index = 0;
941     bool changed = false;
942     for (Value *operand : {&lhs, &rhs}) {
943       AffineMap &map = maps[index++];
944       auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
945       if (!broadcast)
946         continue;
947       // contractionOp can only take vector as operands.
948       auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
949       if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
950         continue;
951       int64_t rankDiff =
952           broadcast.getVectorType().getRank() - srcType.getRank();
953       bool innerDimBroadcast = false;
954       SmallVector<AffineExpr> originalDims;
955       for (const auto &dim : llvm::enumerate(srcType.getShape())) {
956         if (dim.value() !=
957             broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
958           innerDimBroadcast = true;
959           break;
960         }
961         originalDims.push_back(
962             rewriter.getAffineDimExpr(dim.index() + rankDiff));
963       }
964       // Contract doesn't support inner dimension broadcast. Once this is
965       // relaxed we can remove this case.
966       if (innerDimBroadcast)
967         continue;
968       AffineMap broadcastMap =
969           AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
970                          contractOp.getContext());
971       map = broadcastMap.compose(map);
972       *operand = broadcast.source();
973       changed = true;
974     }
975     if (!changed)
976       return failure();
977     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
978         contractOp, lhs, rhs, contractOp.acc(),
979         rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
980     return success();
981   }
982 };
983 
984 } // namespace
985 
986 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
987 /// operands `x` and `y`.
988 static Value createAdd(Location loc, Value x, Value y, bool isInt,
989                        PatternRewriter &rewriter) {
990   if (isInt)
991     return rewriter.create<arith::AddIOp>(loc, x, y);
992   return rewriter.create<arith::AddFOp>(loc, x, y);
993 }
994 
995 /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
996 /// operands `x and `y`.
997 static Value createMul(Location loc, Value x, Value y, bool isInt,
998                        PatternRewriter &rewriter) {
999   if (isInt)
1000     return rewriter.create<arith::MulIOp>(loc, x, y);
1001   return rewriter.create<arith::MulFOp>(loc, x, y);
1002 }
1003 
1004 namespace mlir {
1005 
1006 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1007 /// semantics to:
1008 /// ```
1009 ///    %mta = maybe_transpose
1010 ///    %mtb = maybe_transpose
1011 ///    %flattened_a = vector.shape_cast %mta
1012 ///    %flattened_b = vector.shape_cast %mtb
1013 ///    %flattened_d = vector.matmul %flattened_a, %flattened_b
1014 ///    %mtd = vector.shape_cast %flattened_d
1015 ///    %d = maybe_untranspose %mtd
1016 ///    %e = add %c, %d
1017 /// ```
1018 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1019 //
1020 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
1021 /// vector.transpose operations are inserted if the vector.contract op is not a
1022 /// row-major matrix multiply.
1023 LogicalResult
1024 ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
1025                                                  PatternRewriter &rew) const {
1026   // TODO: implement masks
1027   if (llvm::size(op.masks()) != 0)
1028     return failure();
1029   if (vectorTransformOptions.vectorContractLowering !=
1030       vector::VectorContractLowering::Matmul)
1031     return failure();
1032   if (failed(filter(op)))
1033     return failure();
1034 
1035   auto iteratorTypes = op.iterator_types().getValue();
1036   if (!isParallelIterator(iteratorTypes[0]) ||
1037       !isParallelIterator(iteratorTypes[1]) ||
1038       !isReductionIterator(iteratorTypes[2]))
1039     return failure();
1040 
1041   Type elementType = op.getLhsType().getElementType();
1042   if (!elementType.isIntOrFloat())
1043     return failure();
1044 
1045   // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
1046   // Bail out if the contraction cannot be put in this form.
1047   MLIRContext *ctx = op.getContext();
1048   Location loc = op.getLoc();
1049   AffineExpr m, n, k;
1050   bindDims(rew.getContext(), m, n, k);
1051   // LHS must be A(m, k) or A(k, m).
1052   Value lhs = op.lhs();
1053   auto lhsMap = op.indexing_maps()[0];
1054   if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
1055     lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
1056   else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
1057     return failure();
1058 
1059   // RHS must be B(k, n) or B(n, k).
1060   Value rhs = op.rhs();
1061   auto rhsMap = op.indexing_maps()[1];
1062   if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
1063     rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
1064   else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
1065     return failure();
1066 
1067   // At this point lhs and rhs are in row-major.
1068   VectorType lhsType = lhs.getType().cast<VectorType>();
1069   VectorType rhsType = rhs.getType().cast<VectorType>();
1070   int64_t lhsRows = lhsType.getDimSize(0);
1071   int64_t lhsColumns = lhsType.getDimSize(1);
1072   int64_t rhsColumns = rhsType.getDimSize(1);
1073 
1074   Type flattenedLHSType =
1075       VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
1076   lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
1077 
1078   Type flattenedRHSType =
1079       VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
1080   rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
1081 
1082   Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
1083                                            rhsColumns);
1084   mul = rew.create<vector::ShapeCastOp>(
1085       loc,
1086       VectorType::get({lhsRows, rhsColumns},
1087                       getElementTypeOrSelf(op.acc().getType())),
1088       mul);
1089 
1090   // ACC must be C(m, n) or C(n, m).
1091   auto accMap = op.indexing_maps()[2];
1092   if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
1093     mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
1094   else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
1095     llvm_unreachable("invalid contraction semantics");
1096 
1097   Value res =
1098       elementType.isa<IntegerType>()
1099           ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.acc(), mul))
1100           : static_cast<Value>(rew.create<arith::AddFOp>(loc, op.acc(), mul));
1101 
1102   rew.replaceOp(op, res);
1103   return success();
1104 }
1105 
1106 namespace {
1107 struct IteratorType {
1108   IteratorType(StringRef strRef) : strRef(strRef) {}
1109   bool isOfType(Attribute attr) const {
1110     auto sAttr = attr.dyn_cast<StringAttr>();
1111     return sAttr && sAttr.getValue() == strRef;
1112   }
1113   StringRef strRef;
1114 };
1115 struct Par : public IteratorType {
1116   Par() : IteratorType(getParallelIteratorTypeName()) {}
1117 };
1118 struct Red : public IteratorType {
1119   Red() : IteratorType(getReductionIteratorTypeName()) {}
1120 };
1121 
1122 /// Generate a vector implementation for matmat, matvec and tmatvec.
1123 /// This unrolls outer-products along the reduction dimension.
1124 struct UnrolledOuterProductGenerator
1125     : public StructuredGenerator<vector::ContractionOp> {
1126 
1127   UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
1128       : StructuredGenerator<vector::ContractionOp>(builder, op),
1129         kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()),
1130         lhsType(op.getLhsType()) {}
1131 
1132   Value t(Value v) {
1133     static constexpr std::array<int64_t, 2> perm = {1, 0};
1134     return builder.create<vector::TransposeOp>(loc, v, perm);
1135   }
1136 
1137   Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
1138     assert(reductionSize > 0);
1139     for (int64_t k = 0; k < reductionSize; ++k) {
1140       Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
1141       Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
1142       res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
1143                                                    res, kind);
1144     }
1145     return res;
1146   }
1147 
1148   /// Two outer parallel, one inner reduction (matmat flavor).
1149   FailureOr<Value> matmat() {
1150     if (!iters({Par(), Par(), Red()}))
1151       return failure();
1152     // Set up the parallel/reduction structure in the right form.
1153     AffineExpr m, n, k;
1154     bindDims(builder.getContext(), m, n, k);
1155     // Classical row-major matmul:  Just permute the lhs.
1156     if (layout({{m, k}, {k, n}, {m, n}}))
1157       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1158     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1159     if (layout({{m, k}, {n, k}, {m, n}})) {
1160       Value tlhs = t(lhs);
1161       return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
1162     }
1163     // No need to permute anything.
1164     if (layout({{k, m}, {k, n}, {m, n}}))
1165       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1166     // Just permute the rhs.
1167     if (layout({{k, m}, {n, k}, {m, n}}))
1168       return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
1169     // Transposed output: swap RHS and LHS.
1170     // Classical row-major matmul: permute the lhs.
1171     if (layout({{m, k}, {k, n}, {n, m}}))
1172       return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
1173     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
1174     if (layout({{m, k}, {n, k}, {n, m}})) {
1175       Value trhs = t(rhs);
1176       return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
1177     }
1178     if (layout({{k, m}, {k, n}, {n, m}}))
1179       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1180     if (layout({{k, m}, {n, k}, {n, m}}))
1181       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1182     return failure();
1183   }
1184 
1185   /// One outer parallel, one inner reduction (matvec flavor)
1186   FailureOr<Value> matvec() {
1187     if (!iters({Par(), Red()}))
1188       return failure();
1189     AffineExpr m, k;
1190     bindDims(builder.getContext(), m, k);
1191 
1192     // Case mat-vec: transpose.
1193     if (layout({{m, k}, {k}, {m}}))
1194       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1195     // Case mat-trans-vec: ready to go.
1196     if (layout({{k, m}, {k}, {m}}))
1197       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1198     // Case vec-mat: swap and transpose.
1199     if (layout({{k}, {m, k}, {m}}))
1200       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1201     // Case vec-mat-trans: swap and ready to go.
1202     if (layout({{k}, {k, m}, {m}}))
1203       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1204     return failure();
1205   }
1206 
1207   //
1208   // One outer reduction, one inner parallel (tmatvec flavor)
1209   //
1210   FailureOr<Value> tmatvec() {
1211     if (!iters({Red(), Par()}))
1212       return failure();
1213     AffineExpr k, m;
1214     bindDims(builder.getContext(), k, m);
1215 
1216     // Case mat-vec: transpose.
1217     if (layout({{m, k}, {k}, {m}}))
1218       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
1219     // Case mat-trans-vec: ready to go.
1220     if (layout({{k, m}, {k}, {m}}))
1221       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
1222     // Case vec-mat: swap and transpose.
1223     if (layout({{k}, {m, k}, {m}}))
1224       return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
1225     // Case vec-mat-trans: swap and ready to go.
1226     if (layout({{k}, {k, m}, {m}}))
1227       return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
1228     return failure();
1229   }
1230 
1231 private:
1232   vector::CombiningKind kind;
1233   Value lhs, rhs, res;
1234   VectorType lhsType;
1235 };
1236 } // namespace
1237 
1238 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1239 /// semantics to a reduction_size-unrolled sequence:
1240 /// ```
1241 ///    %at = vector.transpose %a, [1, 0]
1242 ///    %bRow0 = vector.extract %b[0]
1243 ///    %atRow0 = vector.extract %at[0]
1244 ///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
1245 ///    ...
1246 ///    %bRowK = vector.extract %b[K]
1247 ///    %atRowK = vector.extract %at[K]
1248 ///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
1249 /// ```
1250 ///
1251 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
1252 /// otherwise supports any layout permutation of the matrix-multiply.
1253 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
1254     vector::ContractionOp op, PatternRewriter &rewriter) const {
1255   // TODO: implement masks
1256   if (llvm::size(op.masks()) != 0)
1257     return failure();
1258 
1259   if (vectorTransformOptions.vectorContractLowering !=
1260       vector::VectorContractLowering::OuterProduct)
1261     return failure();
1262 
1263   if (failed(filter(op)))
1264     return failure();
1265 
1266   UnrolledOuterProductGenerator e(rewriter, op);
1267   FailureOr<Value> matmatRes = e.matmat();
1268   if (succeeded(matmatRes)) {
1269     rewriter.replaceOp(op, *matmatRes);
1270     return success();
1271   }
1272   FailureOr<Value> matvecRes = e.matvec();
1273   if (succeeded(matvecRes)) {
1274     rewriter.replaceOp(op, *matvecRes);
1275     return success();
1276   }
1277   FailureOr<Value> tmatvecRes = e.tmatvec();
1278   if (succeeded(tmatvecRes)) {
1279     rewriter.replaceOp(op, *tmatvecRes);
1280     return success();
1281   }
1282 
1283   return failure();
1284 }
1285 
1286 LogicalResult
1287 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
1288                                             PatternRewriter &rewriter) const {
1289   // TODO: implement masks
1290   if (llvm::size(op.masks()) != 0)
1291     return failure();
1292 
1293   if (failed(filter(op)))
1294     return failure();
1295 
1296   if (vectorTransformOptions.vectorContractLowering !=
1297       vector::VectorContractLowering::Dot)
1298     return failure();
1299 
1300   auto iteratorTypes = op.iterator_types().getValue();
1301   static constexpr std::array<int64_t, 2> perm = {1, 0};
1302   Location loc = op.getLoc();
1303   Value lhs = op.lhs(), rhs = op.rhs();
1304 
1305   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1306   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
1307   AffineExpr m, n, k;
1308   bindDims(rewriter.getContext(), m, n, k);
1309   SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
1310   //
1311   // In the following we wish to make the reduction dimension innermost so we
1312   // can load vectors and just fmul + reduce into a scalar.
1313   //
1314   if (isParallelIterator(iteratorTypes[0]) &&
1315       isParallelIterator(iteratorTypes[1]) &&
1316       isReductionIterator(iteratorTypes[2])) {
1317     //
1318     // Two outer parallel, one inner reduction (matmat flavor).
1319     //
1320     if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1321       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1322     } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
1323       // No need to permute anything.
1324     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1325       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1326       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1327     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1328       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1329     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1330       // This is the classical row-major matmul. Just permute the lhs.
1331       Value tmp = lhs;
1332       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1333       rhs = tmp;
1334     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1335       std::swap(lhs, rhs);
1336     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1337       Value tmp = lhs;
1338       lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
1339       rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
1340     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1341       Value tmp = rhs;
1342       rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1343       lhs = tmp;
1344     } else {
1345       return failure();
1346     }
1347   } else if (isParallelIterator(iteratorTypes[0]) &&
1348              isReductionIterator(iteratorTypes[1])) {
1349     //
1350     // One outer parallel, one inner reduction (matvec flavor)
1351     //
1352     if (maps == infer({{m, n}, {n}, {m}})) {
1353       // No need to permute anything.
1354     } else if (maps == infer({{n, m}, {n}, {m}})) {
1355       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1356     } else if (maps == infer({{n}, {m, n}, {m}})) {
1357       std::swap(lhs, rhs);
1358     } else if (maps == infer({{n}, {n, m}, {m}})) {
1359       std::swap(lhs, rhs);
1360       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
1361     } else {
1362       return failure();
1363     }
1364   } else {
1365     return failure();
1366   }
1367 
1368   VectorType dstType = op.getResultType().cast<VectorType>();
1369   assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
1370          "Expected dst type of rank 1 or 2");
1371 
1372   unsigned rank = dstType.getRank();
1373   unsigned dstRows = dstType.getShape()[0];
1374   unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
1375 
1376   // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
1377   Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
1378                                                  rewriter.getZeroAttr(dstType));
1379   bool isInt = dstType.getElementType().isa<IntegerType>();
1380   for (unsigned r = 0; r < dstRows; ++r) {
1381     Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
1382     for (unsigned c = 0; c < dstColumns; ++c) {
1383       Value b = rank == 1
1384                     ? rhs
1385                     : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
1386       Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
1387       Value reduced = rewriter.create<vector::ReductionOp>(
1388           op.getLoc(), vector::CombiningKind::ADD, m);
1389 
1390       SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
1391                                               : SmallVector<int64_t, 2>{r, c};
1392       res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
1393     }
1394   }
1395   if (auto acc = op.acc())
1396     res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
1397   rewriter.replaceOp(op, res);
1398   return success();
1399 }
1400 
1401 /// Progressive lowering of ContractionOp.
1402 /// One:
1403 ///   %x = vector.contract with at least one free/batch dimension
1404 /// is replaced by:
1405 ///   %a = vector.contract with one less free/batch dimension
1406 ///   %b = vector.contract with one less free/batch dimension
1407 ///   ..
1408 ///   %x = combine %a %b ..
1409 /// until a pure contraction is reached (no free/batch dimensions),
1410 /// which is replaced by a dot-product.
1411 ///
1412 /// This only kicks in when either VectorTransformsOptions is set
1413 /// to DOT or when other contraction patterns fail.
1414 //
1415 // TODO: break down into transpose/reshape/cast ops
1416 //               when they become available to avoid code dup
1417 // TODO: investigate lowering order impact on performance
1418 LogicalResult
1419 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
1420                                        PatternRewriter &rewriter) const {
1421   // TODO: implement masks.
1422   if (llvm::size(op.masks()) != 0)
1423     return failure();
1424 
1425   if (failed(filter(op)))
1426     return failure();
1427 
1428   // TODO: support mixed mode contract lowering.
1429   if (op.getLhsType().getElementType() !=
1430           getElementTypeOrSelf(op.getAccType()) ||
1431       op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
1432     return failure();
1433 
1434   // TODO: implement benefits, cost models.
1435   MLIRContext *ctx = op.getContext();
1436   ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
1437   if (succeeded(pat1.matchAndRewrite(op, rewriter)))
1438     return success();
1439   ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
1440   if (succeeded(pat2.matchAndRewrite(op, rewriter)))
1441     return success();
1442   ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
1443   if (succeeded(pat3.matchAndRewrite(op, rewriter)))
1444     return success();
1445 
1446   // Find first batch dimension in LHS/RHS, and lower when found.
1447   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
1448   if (!batchDimMap.empty()) {
1449     int64_t lhsIndex = batchDimMap[0].first;
1450     int64_t rhsIndex = batchDimMap[0].second;
1451     rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
1452     return success();
1453   }
1454 
1455   // Collect contracting dimensions.
1456   std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
1457       op.getContractingDimMap();
1458   DenseSet<int64_t> lhsContractingDimSet;
1459   DenseSet<int64_t> rhsContractingDimSet;
1460   for (auto &dimPair : contractingDimMap) {
1461     lhsContractingDimSet.insert(dimPair.first);
1462     rhsContractingDimSet.insert(dimPair.second);
1463   }
1464 
1465   // Find first free dimension in LHS, and lower when found.
1466   VectorType lhsType = op.getLhsType();
1467   for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
1468     if (lhsContractingDimSet.count(lhsIndex) == 0) {
1469       rewriter.replaceOp(
1470           op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
1471       return success();
1472     }
1473   }
1474 
1475   // Find first free dimension in RHS, and lower when found.
1476   VectorType rhsType = op.getRhsType();
1477   for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
1478     if (rhsContractingDimSet.count(rhsIndex) == 0) {
1479       rewriter.replaceOp(
1480           op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
1481       return success();
1482     }
1483   }
1484 
1485   // Lower the first remaining reduction dimension.
1486   if (!contractingDimMap.empty()) {
1487     rewriter.replaceOp(op, lowerReduction(op, rewriter));
1488     return success();
1489   }
1490 
1491   return failure();
1492 }
1493 
1494 // Lower one parallel dimension.
1495 // TODO: consider reusing existing contract unrolling
1496 Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
1497                                            int64_t lhsIndex, int64_t rhsIndex,
1498                                            PatternRewriter &rewriter) const {
1499   VectorType lhsType = op.getLhsType();
1500   VectorType rhsType = op.getRhsType();
1501   VectorType resType = op.getResultType().cast<VectorType>();
1502   // Find the iterator type index and result index.
1503   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
1504   int64_t iterIndex = -1;
1505   int64_t dimSize = -1;
1506   if (lhsIndex >= 0) {
1507     iterIndex = iMap[0].getDimPosition(lhsIndex);
1508     assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
1509            "parallel index should be free in LHS or batch in LHS/RHS");
1510     dimSize = lhsType.getDimSize(lhsIndex);
1511   } else {
1512     assert(rhsIndex >= 0 && "missing parallel index");
1513     iterIndex = iMap[1].getDimPosition(rhsIndex);
1514     dimSize = rhsType.getDimSize(rhsIndex);
1515   }
1516   assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
1517   Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
1518   assert(lookup.hasValue() && "parallel index not listed in reduction");
1519   int64_t resIndex = lookup.getValue();
1520   // Construct new iterator types and affine map array attribute.
1521   std::array<AffineMap, 3> lowIndexingMaps = {
1522       adjustMap(iMap[0], iterIndex, rewriter),
1523       adjustMap(iMap[1], iterIndex, rewriter),
1524       adjustMap(iMap[2], iterIndex, rewriter)};
1525   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1526   auto lowIter =
1527       rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
1528   // Unroll into a series of lower dimensional vector.contract ops.
1529   Location loc = op.getLoc();
1530   Value result = rewriter.create<arith::ConstantOp>(
1531       loc, resType, rewriter.getZeroAttr(resType));
1532   for (int64_t d = 0; d < dimSize; ++d) {
1533     auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
1534     auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
1535     auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
1536     Value lowContract = rewriter.create<vector::ContractionOp>(
1537         loc, lhs, rhs, acc, lowAffine, lowIter);
1538     result =
1539         reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
1540   }
1541   return result;
1542 }
1543 
1544 // Lower one reduction dimension.
1545 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
1546                                             PatternRewriter &rewriter) const {
1547   auto loc = op.getLoc();
1548   VectorType lhsType = op.getLhsType();
1549   VectorType rhsType = op.getRhsType();
1550   Type resType = op.getResultType();
1551   assert(!resType.isa<VectorType>());
1552   bool isInt = resType.isa<IntegerType>();
1553   // Use iterator index 0.
1554   int64_t iterIndex = 0;
1555   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
1556   Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1557   Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1558   assert(lookupLhs.hasValue() && "missing LHS parallel index");
1559   assert(lookupRhs.hasValue() && "missing RHS parallel index");
1560   int64_t lhsIndex = lookupLhs.getValue();
1561   int64_t rhsIndex = lookupRhs.getValue();
1562   int64_t dimSize = lhsType.getDimSize(lhsIndex);
1563   assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
1564   // Base case.
1565   if (lhsType.getRank() == 1) {
1566     assert(rhsType.getRank() == 1 && "corrupt contraction");
1567     Value m = createMul(loc, op.lhs(), op.rhs(), isInt, rewriter);
1568     auto kind = vector::CombiningKind::ADD;
1569     Value res = rewriter.create<vector::ReductionOp>(loc, kind, m);
1570     if (auto acc = op.acc())
1571       res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
1572     return res;
1573   }
1574   // Construct new iterator types and affine map array attribute.
1575   std::array<AffineMap, 3> lowIndexingMaps = {
1576       adjustMap(iMap[0], iterIndex, rewriter),
1577       adjustMap(iMap[1], iterIndex, rewriter),
1578       adjustMap(iMap[2], iterIndex, rewriter)};
1579   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1580   auto lowIter =
1581       rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
1582   // Unroll into a series of lower dimensional vector.contract ops.
1583   // By feeding the initial accumulator into the first contraction,
1584   // and the result of each contraction into the next, eventually
1585   // the sum of all reductions is computed.
1586   Value result = op.acc();
1587   for (int64_t d = 0; d < dimSize; ++d) {
1588     auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
1589     auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
1590     result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
1591                                                     lowAffine, lowIter);
1592   }
1593   return result;
1594 }
1595 
1596 } // namespace mlir
1597 
1598 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
1599     OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
1600     ArrayRef<int64_t> multiplicity, const AffineMap &map) {
1601   OpBuilder::InsertionGuard guard(builder);
1602   builder.setInsertionPointAfter(op);
1603   Location loc = op->getLoc();
1604   if (op->getNumResults() != 1)
1605     return {};
1606   Value result = op->getResult(0);
1607   VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
1608   if (!type || map.getNumResults() != multiplicity.size())
1609     return {};
1610   // For each dimension being distributed check that the size is a multiple of
1611   // the multiplicity. To handle more sizes we would need to support masking.
1612   unsigned multiplictyCount = 0;
1613   for (auto exp : map.getResults()) {
1614     auto affinExp = exp.dyn_cast<AffineDimExpr>();
1615     if (!affinExp || affinExp.getPosition() >= type.getRank() ||
1616         type.getDimSize(affinExp.getPosition()) %
1617                 multiplicity[multiplictyCount++] !=
1618             0)
1619       return {};
1620   }
1621   DistributeOps ops;
1622   ops.extract =
1623       builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
1624   ops.insert =
1625       builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
1626   return ops;
1627 }
1628 
1629 /// Progressive lowering of transfer_read. This pattern supports lowering of
1630 /// `vector.transfer_read` to a combination of `vector.load` and
1631 /// `vector.broadcast` if all of the following hold:
1632 /// - Stride of most minor memref dimension must be 1.
1633 /// - Out-of-bounds masking is not required.
1634 /// - If the memref's element type is a vector type then it coincides with the
1635 ///   result type.
1636 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
1637 struct TransferReadToVectorLoadLowering
1638     : public OpRewritePattern<vector::TransferReadOp> {
1639   TransferReadToVectorLoadLowering(MLIRContext *context,
1640                                    llvm::Optional<unsigned> maxRank)
1641       : OpRewritePattern<vector::TransferReadOp>(context),
1642         maxTransferRank(maxRank) {}
1643 
1644   LogicalResult matchAndRewrite(vector::TransferReadOp read,
1645                                 PatternRewriter &rewriter) const override {
1646     if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
1647       return failure();
1648 
1649     SmallVector<unsigned, 4> broadcastedDims;
1650     // Permutations are handled by VectorToSCF or
1651     // populateVectorTransferPermutationMapLoweringPatterns.
1652     // We let the 0-d corner case pass-through as it is supported.
1653     if (!read.permutation_map().isMinorIdentityWithBroadcasting(
1654             &broadcastedDims))
1655       return failure();
1656 
1657     auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
1658     if (!memRefType)
1659       return failure();
1660 
1661     // Non-unit strides are handled by VectorToSCF.
1662     if (!vector::isLastMemrefDimUnitStride(memRefType))
1663       return failure();
1664 
1665     // If there is broadcasting involved then we first load the unbroadcasted
1666     // vector, and then broadcast it with `vector.broadcast`.
1667     ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
1668     SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
1669                                                      vectorShape.end());
1670     for (unsigned i : broadcastedDims)
1671       unbroadcastedVectorShape[i] = 1;
1672     VectorType unbroadcastedVectorType = VectorType::get(
1673         unbroadcastedVectorShape, read.getVectorType().getElementType());
1674 
1675     // `vector.load` supports vector types as memref's elements only when the
1676     // resulting vector type is the same as the element type.
1677     auto memrefElTy = memRefType.getElementType();
1678     if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
1679       return failure();
1680 
1681     // Otherwise, element types of the memref and the vector must match.
1682     if (!memrefElTy.isa<VectorType>() &&
1683         memrefElTy != read.getVectorType().getElementType())
1684       return failure();
1685 
1686     // Out-of-bounds dims are handled by MaterializeTransferMask.
1687     if (read.hasOutOfBoundsDim())
1688       return failure();
1689 
1690     // Create vector load op.
1691     Operation *loadOp;
1692     if (read.mask()) {
1693       Value fill = rewriter.create<vector::SplatOp>(
1694           read.getLoc(), unbroadcastedVectorType, read.padding());
1695       loadOp = rewriter.create<vector::MaskedLoadOp>(
1696           read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(),
1697           read.mask(), fill);
1698     } else {
1699       loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
1700                                                unbroadcastedVectorType,
1701                                                read.source(), read.indices());
1702     }
1703 
1704     // Insert a broadcasting op if required.
1705     if (!broadcastedDims.empty()) {
1706       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1707           read, read.getVectorType(), loadOp->getResult(0));
1708     } else {
1709       rewriter.replaceOp(read, loadOp->getResult(0));
1710     }
1711 
1712     return success();
1713   }
1714 
1715   llvm::Optional<unsigned> maxTransferRank;
1716 };
1717 
1718 /// Replace a 0-d vector.load with a memref.load + vector.broadcast.
1719 // TODO: we shouldn't cross the vector/scalar domains just for this
1720 // but atm we lack the infra to avoid it. Possible solutions include:
1721 // - go directly to LLVM + bitcast
1722 // - introduce a bitcast op and likely a new pointer dialect
1723 // - let memref.load/store additionally support the 0-d vector case
1724 // There are still deeper data layout issues lingering even in this
1725 // trivial case (for architectures for which this matters).
1726 struct VectorLoadToMemrefLoadLowering
1727     : public OpRewritePattern<vector::LoadOp> {
1728   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
1729 
1730   LogicalResult matchAndRewrite(vector::LoadOp loadOp,
1731                                 PatternRewriter &rewriter) const override {
1732     auto vecType = loadOp.getVectorType();
1733     if (vecType.getNumElements() != 1)
1734       return failure();
1735     auto memrefLoad = rewriter.create<memref::LoadOp>(
1736         loadOp.getLoc(), loadOp.base(), loadOp.indices());
1737     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
1738                                                      memrefLoad);
1739     return success();
1740   }
1741 };
1742 
1743 /// Replace a 0-d vector.store with a vector.extractelement + memref.store.
1744 struct VectorStoreToMemrefStoreLowering
1745     : public OpRewritePattern<vector::StoreOp> {
1746   using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
1747 
1748   LogicalResult matchAndRewrite(vector::StoreOp storeOp,
1749                                 PatternRewriter &rewriter) const override {
1750     auto vecType = storeOp.getVectorType();
1751     if (vecType.getNumElements() != 1)
1752       return failure();
1753     Value extracted;
1754     if (vecType.getRank() == 0) {
1755       // TODO: Unifiy once ExtractOp supports 0-d vectors.
1756       extracted = rewriter.create<vector::ExtractElementOp>(
1757           storeOp.getLoc(), storeOp.valueToStore());
1758     } else {
1759       SmallVector<int64_t> indices(vecType.getRank(), 0);
1760       extracted = rewriter.create<vector::ExtractOp>(
1761           storeOp.getLoc(), storeOp.valueToStore(), indices);
1762     }
1763 
1764     rewriter.replaceOpWithNewOp<memref::StoreOp>(
1765         storeOp, extracted, storeOp.base(), storeOp.indices());
1766     return success();
1767   }
1768 };
1769 
1770 /// Progressive lowering of transfer_write. This pattern supports lowering of
1771 /// `vector.transfer_write` to `vector.store` if all of the following hold:
1772 /// - Stride of most minor memref dimension must be 1.
1773 /// - Out-of-bounds masking is not required.
1774 /// - If the memref's element type is a vector type then it coincides with the
1775 ///   type of the written value.
1776 /// - The permutation map is the minor identity map (neither permutation nor
1777 ///   broadcasting is allowed).
1778 struct TransferWriteToVectorStoreLowering
1779     : public OpRewritePattern<vector::TransferWriteOp> {
1780   TransferWriteToVectorStoreLowering(MLIRContext *context,
1781                                      llvm::Optional<unsigned> maxRank)
1782       : OpRewritePattern<vector::TransferWriteOp>(context),
1783         maxTransferRank(maxRank) {}
1784 
1785   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
1786                                 PatternRewriter &rewriter) const override {
1787     if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
1788       return failure();
1789 
1790     // Permutations are handled by VectorToSCF or
1791     // populateVectorTransferPermutationMapLoweringPatterns.
1792     if ( // pass-through for the 0-d corner case.
1793         !write.permutation_map().isMinorIdentity())
1794       return failure();
1795 
1796     auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
1797     if (!memRefType)
1798       return failure();
1799 
1800     // Non-unit strides are handled by VectorToSCF.
1801     if (!vector::isLastMemrefDimUnitStride(memRefType))
1802       return failure();
1803 
1804     // `vector.store` supports vector types as memref's elements only when the
1805     // type of the vector value being written is the same as the element type.
1806     auto memrefElTy = memRefType.getElementType();
1807     if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
1808       return failure();
1809 
1810     // Otherwise, element types of the memref and the vector must match.
1811     if (!memrefElTy.isa<VectorType>() &&
1812         memrefElTy != write.getVectorType().getElementType())
1813       return failure();
1814 
1815     // Out-of-bounds dims are handled by MaterializeTransferMask.
1816     if (write.hasOutOfBoundsDim())
1817       return failure();
1818     if (write.mask()) {
1819       rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
1820           write, write.source(), write.indices(), write.mask(), write.vector());
1821     } else {
1822       rewriter.replaceOpWithNewOp<vector::StoreOp>(
1823           write, write.vector(), write.source(), write.indices());
1824     }
1825     return success();
1826   }
1827 
1828   llvm::Optional<unsigned> maxTransferRank;
1829 };
1830 
1831 // Returns the values in `arrayAttr` as an integer vector.
1832 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
1833   return llvm::to_vector<4>(
1834       llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
1835                       [](IntegerAttr attr) { return attr.getInt(); }));
1836 }
1837 
1838 // Shuffles vector.bitcast op after vector.extract op.
1839 //
1840 // This transforms IR like:
1841 //   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
1842 //   %1 = vector.extract %0[3] : vector<8xf16>
1843 // Into:
1844 //   %0 = vector.extract %src[1] : vector<4xf32>
1845 //   %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
1846 //   %2 = vector.extract %1[1] : vector<2xf16>
1847 struct BubbleDownVectorBitCastForExtract
1848     : public OpRewritePattern<vector::ExtractOp> {
1849   using OpRewritePattern::OpRewritePattern;
1850 
1851   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
1852                                 PatternRewriter &rewriter) const override {
1853     // Only support extracting scalars for now.
1854     if (extractOp.getVectorType().getRank() != 1)
1855       return failure();
1856 
1857     auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
1858     if (!castOp)
1859       return failure();
1860 
1861     VectorType castSrcType = castOp.getSourceVectorType();
1862     VectorType castDstType = castOp.getResultVectorType();
1863     assert(castSrcType.getRank() == castDstType.getRank());
1864 
1865     // Fail to match if we only have one element in the cast op source.
1866     // This is to avoid infinite loop given that this pattern can generate
1867     // such cases.
1868     if (castSrcType.getNumElements() == 1)
1869       return failure();
1870 
1871     // Only support casting to a larger number of elements or now.
1872     // E.g., vector<4xf32> -> vector<8xf16>.
1873     if (castSrcType.getNumElements() > castDstType.getNumElements())
1874       return failure();
1875 
1876     unsigned expandRatio =
1877         castDstType.getNumElements() / castSrcType.getNumElements();
1878 
1879     auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
1880       return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
1881     };
1882 
1883     uint64_t index = getFirstIntValue(extractOp.position());
1884 
1885     // Get the single scalar (as a vector) in the source value that packs the
1886     // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
1887     VectorType oneScalarType =
1888         VectorType::get({1}, castSrcType.getElementType());
1889     Value packedValue = rewriter.create<vector::ExtractOp>(
1890         extractOp.getLoc(), oneScalarType, castOp.source(),
1891         rewriter.getI64ArrayAttr(index / expandRatio));
1892 
1893     // Cast it to a vector with the desired scalar's type.
1894     // E.g. f32 -> vector<2xf16>
1895     VectorType packedType =
1896         VectorType::get({expandRatio}, castDstType.getElementType());
1897     Value castedValue = rewriter.create<vector::BitCastOp>(
1898         extractOp.getLoc(), packedType, packedValue);
1899 
1900     // Finally extract the desired scalar.
1901     rewriter.replaceOpWithNewOp<vector::ExtractOp>(
1902         extractOp, extractOp.getType(), castedValue,
1903         rewriter.getI64ArrayAttr(index % expandRatio));
1904 
1905     return success();
1906   }
1907 };
1908 
1909 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
1910 //
1911 // This transforms IR like:
1912 //    %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
1913 //     %0 = vector.extract_strided_slice %cast {
1914 //            offsets = [4], sizes = [4], strides = [1]
1915 //          } : vector<8xf16> to vector<4xf16>
1916 // Into:
1917 //   %0 = vector.extract_strided_slice %src {
1918 //          offsets = [2], sizes = [2], strides = [1]
1919 //        } : vector<4xf32> to vector<2xf32>
1920 //   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
1921 struct BubbleDownBitCastForStridedSliceExtract
1922     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
1923   using OpRewritePattern::OpRewritePattern;
1924 
1925   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
1926                                 PatternRewriter &rewriter) const override {
1927     auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
1928     if (!castOp)
1929       return failure();
1930 
1931     VectorType castSrcType = castOp.getSourceVectorType();
1932     VectorType castDstType = castOp.getResultVectorType();
1933     assert(castSrcType.getRank() == castDstType.getRank());
1934 
1935     int64_t castSrcLastDim = castSrcType.getShape().back();
1936     int64_t castDstLastDim = castDstType.getShape().back();
1937     // Require casting to more elements for now; other cases to be implemented.
1938     if (castSrcLastDim > castDstLastDim)
1939       return failure();
1940 
1941     // Only accept all one strides for now.
1942     if (llvm::any_of(extractOp.strides().getAsValueRange<IntegerAttr>(),
1943                      [](const APInt &val) { return !val.isOneValue(); }))
1944       return failure();
1945 
1946     unsigned rank = extractOp.getVectorType().getRank();
1947     assert(castDstLastDim % castSrcLastDim == 0);
1948     int64_t expandRatio = castDstLastDim / castSrcLastDim;
1949 
1950     // If we have a less number of offsets than the rank, then implicitly we
1951     // are selecting the full range for the last bitcasted dimension; other
1952     // dimensions aren't affected. Otherwise, we need to scale down the last
1953     // dimension's offset given we are extracting from less elements now.
1954     ArrayAttr newOffsets = extractOp.offsets();
1955     if (newOffsets.size() == rank) {
1956       SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
1957       if (offsets.back() % expandRatio != 0)
1958         return failure();
1959       offsets.back() = offsets.back() / expandRatio;
1960       newOffsets = rewriter.getI64ArrayAttr(offsets);
1961     }
1962 
1963     // Similarly for sizes.
1964     ArrayAttr newSizes = extractOp.sizes();
1965     if (newSizes.size() == rank) {
1966       SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
1967       if (sizes.back() % expandRatio != 0)
1968         return failure();
1969       sizes.back() = sizes.back() / expandRatio;
1970       newSizes = rewriter.getI64ArrayAttr(sizes);
1971     }
1972 
1973     SmallVector<int64_t, 4> dims =
1974         llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
1975     dims.back() = dims.back() / expandRatio;
1976     VectorType newExtractType =
1977         VectorType::get(dims, castSrcType.getElementType());
1978 
1979     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
1980         extractOp.getLoc(), newExtractType, castOp.source(), newOffsets,
1981         newSizes, extractOp.strides());
1982 
1983     rewriter.replaceOpWithNewOp<vector::BitCastOp>(
1984         extractOp, extractOp.getType(), newExtractOp);
1985 
1986     return success();
1987   }
1988 };
1989 
1990 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
1991 //
1992 // This transforms IR like:
1993 //   %0 = vector.insert_strided_slice %src, %dst {
1994 //          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
1995 //   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
1996 // Into:
1997 //   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
1998 //   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
1999 //   %2 = vector.insert_strided_slice %src, %dst {
2000 //          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
2001 struct BubbleUpBitCastForStridedSliceInsert
2002     : public OpRewritePattern<vector::BitCastOp> {
2003   using OpRewritePattern::OpRewritePattern;
2004   LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
2005                                 PatternRewriter &rewriter) const override {
2006     VectorType castSrcType = bitcastOp.getSourceVectorType();
2007     VectorType castDstType = bitcastOp.getResultVectorType();
2008     assert(castSrcType.getRank() == castDstType.getRank());
2009 
2010     int64_t castSrcLastDim = castSrcType.getShape().back();
2011     int64_t castDstLastDim = castDstType.getShape().back();
2012     // Require casting to less elements for now; other cases to be implemented.
2013     if (castSrcLastDim < castDstLastDim)
2014       return failure();
2015 
2016     assert(castSrcLastDim % castDstLastDim == 0);
2017     int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
2018 
2019     auto insertOp =
2020         bitcastOp.source().getDefiningOp<vector::InsertStridedSliceOp>();
2021     if (!insertOp)
2022       return failure();
2023 
2024     // Only accept all one strides for now.
2025     if (llvm::any_of(insertOp.strides().getAsValueRange<IntegerAttr>(),
2026                      [](const APInt &val) { return !val.isOneValue(); }))
2027       return failure();
2028 
2029     unsigned rank = insertOp.getSourceVectorType().getRank();
2030     // Require insert op to have the same rank for the source and destination
2031     // vector; other cases to be implemented.
2032     if (rank != insertOp.getDestVectorType().getRank())
2033       return failure();
2034 
2035     ArrayAttr newOffsets = insertOp.offsets();
2036     assert(newOffsets.size() == rank);
2037     SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2038     if (offsets.back() % shrinkRatio != 0)
2039       return failure();
2040     offsets.back() = offsets.back() / shrinkRatio;
2041     newOffsets = rewriter.getI64ArrayAttr(offsets);
2042 
2043     SmallVector<int64_t, 4> srcDims =
2044         llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
2045     srcDims.back() = srcDims.back() / shrinkRatio;
2046     VectorType newCastSrcType =
2047         VectorType::get(srcDims, castDstType.getElementType());
2048 
2049     auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
2050         bitcastOp.getLoc(), newCastSrcType, insertOp.source());
2051 
2052     SmallVector<int64_t, 4> dstDims =
2053         llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
2054     dstDims.back() = dstDims.back() / shrinkRatio;
2055     VectorType newCastDstType =
2056         VectorType::get(dstDims, castDstType.getElementType());
2057 
2058     auto newCastDstOp = rewriter.create<vector::BitCastOp>(
2059         bitcastOp.getLoc(), newCastDstType, insertOp.dest());
2060 
2061     rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
2062         bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
2063         insertOp.strides());
2064 
2065     return success();
2066   }
2067 };
2068 
2069 static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc,
2070                                    Type targetType, Value value) {
2071   if (targetType == value.getType())
2072     return value;
2073 
2074   bool targetIsIndex = targetType.isIndex();
2075   bool valueIsIndex = value.getType().isIndex();
2076   if (targetIsIndex ^ valueIsIndex)
2077     return rewriter.create<arith::IndexCastOp>(loc, targetType, value);
2078 
2079   auto targetIntegerType = targetType.dyn_cast<IntegerType>();
2080   auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
2081   assert(targetIntegerType && valueIntegerType &&
2082          "unexpected cast between types other than integers and index");
2083   assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
2084 
2085   if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
2086     return rewriter.create<arith::ExtSIOp>(loc, targetIntegerType, value);
2087   return rewriter.create<arith::TruncIOp>(loc, targetIntegerType, value);
2088 }
2089 
2090 // Helper that returns a vector comparison that constructs a mask:
2091 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
2092 //
2093 // If `dim == 0` then the result will be a 0-D vector.
2094 //
2095 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
2096 //       much more compact, IR for this operation, but LLVM eventually
2097 //       generates more elaborate instructions for this intrinsic since it
2098 //       is very conservative on the boundary conditions.
2099 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
2100                                    bool indexOptimizations, int64_t dim,
2101                                    Value b, Value *off = nullptr) {
2102   auto loc = op->getLoc();
2103   // If we can assume all indices fit in 32-bit, we perform the vector
2104   // comparison in 32-bit to get a higher degree of SIMD parallelism.
2105   // Otherwise we perform the vector comparison using 64-bit indices.
2106   Type idxType =
2107       indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
2108   DenseIntElementsAttr indicesAttr;
2109   if (dim == 0 && indexOptimizations) {
2110     indicesAttr = DenseIntElementsAttr::get(
2111         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
2112   } else if (dim == 0) {
2113     indicesAttr = DenseIntElementsAttr::get(
2114         VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
2115   } else if (indexOptimizations) {
2116     indicesAttr = rewriter.getI32VectorAttr(
2117         llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
2118   } else {
2119     indicesAttr = rewriter.getI64VectorAttr(
2120         llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
2121   }
2122   Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
2123   // Add in an offset if requested.
2124   if (off) {
2125     Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
2126     Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
2127     indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
2128   }
2129   // Construct the vector comparison.
2130   Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
2131   Value bounds =
2132       rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
2133   return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
2134                                         bounds);
2135 }
2136 
2137 template <typename ConcreteOp>
2138 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
2139 public:
2140   explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
2141       : mlir::OpRewritePattern<ConcreteOp>(context),
2142         indexOptimizations(enableIndexOpt) {}
2143 
2144   LogicalResult matchAndRewrite(ConcreteOp xferOp,
2145                                 PatternRewriter &rewriter) const override {
2146     if (!xferOp.hasOutOfBoundsDim())
2147       return failure();
2148 
2149     if (xferOp.getVectorType().getRank() > 1 ||
2150         llvm::size(xferOp.indices()) == 0)
2151       return failure();
2152 
2153     Location loc = xferOp->getLoc();
2154     VectorType vtp = xferOp.getVectorType();
2155 
2156     // * Create a vector with linear indices [ 0 .. vector_length - 1 ].
2157     // * Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
2158     // * Let dim the memref dimension, compute the vector comparison mask
2159     //   (in-bounds mask):
2160     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
2161     //
2162     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
2163     //       dimensions here.
2164     unsigned vecWidth = vtp.getNumElements();
2165     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
2166     Value off = xferOp.indices()[lastIndex];
2167     Value dim =
2168         vector::createOrFoldDimOp(rewriter, loc, xferOp.source(), lastIndex);
2169     Value mask = buildVectorComparison(rewriter, xferOp, indexOptimizations,
2170                                        vecWidth, dim, &off);
2171 
2172     if (xferOp.mask()) {
2173       // Intersect the in-bounds with the mask specified as an op parameter.
2174       mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.mask());
2175     }
2176 
2177     rewriter.updateRootInPlace(xferOp, [&]() {
2178       xferOp.maskMutable().assign(mask);
2179       xferOp.in_boundsAttr(rewriter.getBoolArrayAttr({true}));
2180     });
2181 
2182     return success();
2183   }
2184 
2185 private:
2186   const bool indexOptimizations;
2187 };
2188 
2189 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
2190 class VectorCreateMaskOpConversion
2191     : public OpRewritePattern<vector::CreateMaskOp> {
2192 public:
2193   explicit VectorCreateMaskOpConversion(MLIRContext *context,
2194                                         bool enableIndexOpt)
2195       : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
2196         indexOptimizations(enableIndexOpt) {}
2197 
2198   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
2199                                 PatternRewriter &rewriter) const override {
2200     auto dstType = op.getType();
2201     int64_t rank = dstType.getRank();
2202     if (rank > 1)
2203       return failure();
2204     rewriter.replaceOp(
2205         op, buildVectorComparison(rewriter, op, indexOptimizations,
2206                                   rank == 0 ? 0 : dstType.getDimSize(0),
2207                                   op.getOperand(0)));
2208     return success();
2209   }
2210 
2211 private:
2212   const bool indexOptimizations;
2213 };
2214 
2215 // Drop inner most contiguous unit dimensions from transfer_read operand.
2216 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
2217   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
2218 
2219   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
2220                                 PatternRewriter &rewriter) const override {
2221     // TODO: support 0-d corner case.
2222     if (readOp.getTransferRank() == 0)
2223       return failure();
2224 
2225     // TODO: support mask.
2226     if (readOp.mask())
2227       return failure();
2228 
2229     auto srcType = readOp.source().getType().dyn_cast<MemRefType>();
2230     if (!srcType || !srcType.hasStaticShape())
2231       return failure();
2232 
2233     if (!readOp.permutation_map().isMinorIdentity())
2234       return failure();
2235 
2236     auto targetType = readOp.getVectorType();
2237     if (targetType.getRank() <= 1)
2238       return failure();
2239 
2240     SmallVector<int64_t> srcStrides;
2241     int64_t srcOffset;
2242     if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
2243       return failure();
2244 
2245     size_t dimsToDrop = 0;
2246     for (size_t i = 1; i < srcStrides.size(); ++i) {
2247       int dim = srcType.getRank() - i - 1;
2248       if (srcStrides[dim] == 1) {
2249         dimsToDrop++;
2250       } else {
2251         break;
2252       }
2253     }
2254     if (dimsToDrop == 0)
2255       return failure();
2256 
2257     auto resultTargetVecType =
2258         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
2259                         targetType.getElementType());
2260 
2261     MemRefType resultMemrefType;
2262     if (srcType.getLayout().getAffineMap().isIdentity()) {
2263       resultMemrefType = MemRefType::get(
2264           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2265           {}, srcType.getMemorySpaceAsInt());
2266     } else {
2267       AffineMap map = srcType.getLayout().getAffineMap();
2268       int numResultDims = map.getNumDims() - dimsToDrop;
2269       int numSymbols = map.getNumSymbols();
2270       for (size_t i = 0; i < dimsToDrop; ++i) {
2271         int dim = srcType.getRank() - i - 1;
2272         map = map.replace(rewriter.getAffineDimExpr(dim),
2273                           rewriter.getAffineConstantExpr(0), numResultDims,
2274                           numSymbols);
2275       }
2276       resultMemrefType = MemRefType::get(
2277           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
2278           map, srcType.getMemorySpaceAsInt());
2279     }
2280 
2281     auto loc = readOp.getLoc();
2282     SmallVector<int64_t> offsets(srcType.getRank(), 0);
2283     SmallVector<int64_t> strides(srcType.getRank(), 1);
2284 
2285     ArrayAttr inBoundsAttr =
2286         readOp.in_bounds()
2287             ? rewriter.getArrayAttr(
2288                   readOp.in_boundsAttr().getValue().drop_back(dimsToDrop))
2289             : ArrayAttr();
2290     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
2291         loc, resultMemrefType, readOp.source(), offsets, srcType.getShape(),
2292         strides);
2293     auto permMap = getTransferMinorIdentityMap(
2294         rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
2295     Value result = rewriter.create<vector::TransferReadOp>(
2296         loc, resultTargetVecType, rankedReducedView,
2297         readOp.indices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
2298         readOp.padding(),
2299         // TODO: support mask.
2300         /*mask=*/Value(), inBoundsAttr);
2301     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
2302                                                      result);
2303     return success();
2304   }
2305 };
2306 
2307 namespace {
2308 
2309 /// This function checks to see if the vector combining kind
2310 /// is consistent with the integer or float element type.
2311 static bool isValidKind(bool isInt, vector::CombiningKind kind) {
2312   using vector::CombiningKind;
2313   enum class KindType { FLOAT, INT, INVALID };
2314   KindType type{KindType::INVALID};
2315   switch (kind) {
2316   case CombiningKind::MINF:
2317   case CombiningKind::MAXF:
2318     type = KindType::FLOAT;
2319     break;
2320   case CombiningKind::MINUI:
2321   case CombiningKind::MINSI:
2322   case CombiningKind::MAXUI:
2323   case CombiningKind::MAXSI:
2324   case CombiningKind::AND:
2325   case CombiningKind::OR:
2326   case CombiningKind::XOR:
2327     type = KindType::INT;
2328     break;
2329   case CombiningKind::ADD:
2330   case CombiningKind::MUL:
2331     type = isInt ? KindType::INT : KindType::FLOAT;
2332     break;
2333   }
2334   bool isValidIntKind = (type == KindType::INT) && isInt;
2335   bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
2336   return (isValidIntKind || isValidFloatKind);
2337 }
2338 
2339 /// This function constructs the appropriate integer or float
2340 /// operation given the vector combining kind and operands. The
2341 /// supported int operations are : add, mul, min (signed/unsigned),
2342 /// max(signed/unsigned), and, or, xor. The supported float
2343 /// operations are : add, mul, min and max.
2344 static Value genOperator(Location loc, Value x, Value y,
2345                          vector::CombiningKind kind,
2346                          PatternRewriter &rewriter) {
2347   using vector::CombiningKind;
2348 
2349   auto elType = x.getType().cast<VectorType>().getElementType();
2350   bool isInt = elType.isIntOrIndex();
2351 
2352   Value combinedResult{nullptr};
2353   switch (kind) {
2354   case CombiningKind::ADD:
2355     if (isInt)
2356       combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
2357     else
2358       combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
2359     break;
2360   case CombiningKind::MUL:
2361     if (isInt)
2362       combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
2363     else
2364       combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
2365     break;
2366   case CombiningKind::MINUI:
2367     combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
2368     break;
2369   case CombiningKind::MINSI:
2370     combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
2371     break;
2372   case CombiningKind::MAXUI:
2373     combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
2374     break;
2375   case CombiningKind::MAXSI:
2376     combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
2377     break;
2378   case CombiningKind::AND:
2379     combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
2380     break;
2381   case CombiningKind::OR:
2382     combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
2383     break;
2384   case CombiningKind::XOR:
2385     combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
2386     break;
2387   case CombiningKind::MINF:
2388     combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
2389     break;
2390   case CombiningKind::MAXF:
2391     combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
2392     break;
2393   }
2394   return combinedResult;
2395 }
2396 
2397 /// Convert vector.scan op into arith ops and
2398 /// vector.insert_strided_slice/extract_strided_slice
2399 ///
2400 /// Ex:
2401 /// ```
2402 ///   %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim =
2403 ///   1} :
2404 ///     (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
2405 /// ```
2406 /// Gets converted to:
2407 /// ```
2408 ///   %cst = arith.constant dense<0> : vector<2x3xi32>
2409 ///   %0 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 1],
2410 ///   strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %1 =
2411 ///   vector.insert_strided_slice %0, %cst {offsets = [0, 0], strides = [1, 1]}
2412 ///   : vector<2x1xi32> into vector<2x3xi32> %2 = vector.extract_strided_slice
2413 ///   %arg0 {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} :
2414 ///   vector<2x3xi32> to vector<2x1xi32> %3 = arith.muli %0, %2 :
2415 ///   vector<2x1xi32> %4 = vector.insert_strided_slice %3, %1 {offsets = [0, 1],
2416 ///   strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %5 =
2417 ///   vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1],
2418 ///   strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %6 = arith.muli %3,
2419 ///   %5 : vector<2x1xi32> %7 = vector.insert_strided_slice %6, %4 {offsets =
2420 ///   [0, 2], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %8 =
2421 ///   vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32> return %7, %8 :
2422 ///   vector<2x3xi32>, vector<2xi32>
2423 /// ```
2424 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
2425   using OpRewritePattern<vector::ScanOp>::OpRewritePattern;
2426 
2427   LogicalResult matchAndRewrite(vector::ScanOp scanOp,
2428                                 PatternRewriter &rewriter) const override {
2429     auto loc = scanOp.getLoc();
2430     VectorType destType = scanOp.getDestType();
2431     ArrayRef<int64_t> destShape = destType.getShape();
2432     auto elType = destType.getElementType();
2433     bool isInt = elType.isIntOrIndex();
2434     if (!isValidKind(isInt, scanOp.kind()))
2435       return failure();
2436 
2437     VectorType resType = VectorType::get(destShape, elType);
2438     Value result = rewriter.create<arith::ConstantOp>(
2439         loc, resType, rewriter.getZeroAttr(resType));
2440     int64_t reductionDim = scanOp.reduction_dim();
2441     bool inclusive = scanOp.inclusive();
2442     int64_t destRank = destType.getRank();
2443     VectorType initialValueType = scanOp.getInitialValueType();
2444     int64_t initialValueRank = initialValueType.getRank();
2445 
2446     SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
2447     reductionShape[reductionDim] = 1;
2448     VectorType reductionType = VectorType::get(reductionShape, elType);
2449     SmallVector<int64_t> offsets(destRank, 0);
2450     SmallVector<int64_t> strides(destRank, 1);
2451     SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
2452     sizes[reductionDim] = 1;
2453     ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
2454     ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
2455 
2456     Value lastOutput, lastInput;
2457     for (int i = 0; i < destShape[reductionDim]; i++) {
2458       offsets[reductionDim] = i;
2459       ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
2460       Value input = rewriter.create<vector::ExtractStridedSliceOp>(
2461           loc, reductionType, scanOp.source(), scanOffsets, scanSizes,
2462           scanStrides);
2463       Value output;
2464       if (i == 0) {
2465         if (inclusive) {
2466           output = input;
2467         } else {
2468           if (initialValueRank == 0) {
2469             // ShapeCastOp cannot handle 0-D vectors
2470             output = rewriter.create<vector::BroadcastOp>(
2471                 loc, input.getType(), scanOp.initial_value());
2472           } else {
2473             output = rewriter.create<vector::ShapeCastOp>(
2474                 loc, input.getType(), scanOp.initial_value());
2475           }
2476         }
2477       } else {
2478         Value y = inclusive ? input : lastInput;
2479         output = genOperator(loc, lastOutput, y, scanOp.kind(), rewriter);
2480         assert(output != nullptr);
2481       }
2482       result = rewriter.create<vector::InsertStridedSliceOp>(
2483           loc, output, result, offsets, strides);
2484       lastOutput = output;
2485       lastInput = input;
2486     }
2487 
2488     Value reduction;
2489     if (initialValueRank == 0) {
2490       Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
2491       reduction =
2492           rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
2493     } else {
2494       reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
2495                                                        lastOutput);
2496     }
2497 
2498     rewriter.replaceOp(scanOp, {result, reduction});
2499     return success();
2500   }
2501 };
2502 
2503 } // namespace
2504 
2505 void mlir::vector::populateVectorMaskMaterializationPatterns(
2506     RewritePatternSet &patterns, bool indexOptimizations) {
2507   patterns.add<VectorCreateMaskOpConversion,
2508                MaterializeTransferMask<vector::TransferReadOp>,
2509                MaterializeTransferMask<vector::TransferWriteOp>>(
2510       patterns.getContext(), indexOptimizations);
2511 }
2512 
2513 void mlir::vector::populateShapeCastFoldingPatterns(
2514     RewritePatternSet &patterns) {
2515   patterns.add<ShapeCastOpFolder>(patterns.getContext());
2516 }
2517 
2518 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
2519     RewritePatternSet &patterns) {
2520   patterns.add<BubbleDownVectorBitCastForExtract,
2521                BubbleDownBitCastForStridedSliceExtract,
2522                BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
2523 }
2524 
2525 void mlir::vector::populateVectorBroadcastLoweringPatterns(
2526     RewritePatternSet &patterns) {
2527   patterns.add<BroadcastOpLowering>(patterns.getContext());
2528 }
2529 
2530 void mlir::vector::populateVectorMaskOpLoweringPatterns(
2531     RewritePatternSet &patterns) {
2532   patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
2533       patterns.getContext());
2534 }
2535 
2536 void mlir::vector::populateVectorShapeCastLoweringPatterns(
2537     RewritePatternSet &patterns) {
2538   patterns.add<ShapeCastOp2DDownCastRewritePattern,
2539                ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
2540       patterns.getContext());
2541 }
2542 
2543 void mlir::vector::populateVectorContractLoweringPatterns(
2544     RewritePatternSet &patterns, VectorTransformsOptions options) {
2545   patterns.add<OuterProductOpLowering>(patterns.getContext());
2546   patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
2547                ContractionOpToOuterProductOpLowering>(options,
2548                                                       patterns.getContext());
2549 }
2550 
2551 void mlir::vector::populateVectorTransposeLoweringPatterns(
2552     RewritePatternSet &patterns, VectorTransformsOptions options) {
2553   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
2554       options, patterns.getContext());
2555 }
2556 
2557 void mlir::vector::populateVectorReductionToContractPatterns(
2558     RewritePatternSet &patterns) {
2559   patterns.add<MultiReduceToContract, CombineContractBroadcast,
2560                CombineContractTranspose>(patterns.getContext());
2561 }
2562 
2563 void mlir::vector::
2564     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
2565         RewritePatternSet &patterns) {
2566   patterns.add<DropInnerMostUnitDims>(patterns.getContext());
2567 }
2568 
2569 void mlir::vector::populateVectorTransferLoweringPatterns(
2570     RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
2571   patterns.add<TransferReadToVectorLoadLowering,
2572                TransferWriteToVectorStoreLowering>(patterns.getContext(),
2573                                                    maxTransferRank);
2574   patterns
2575       .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
2576           patterns.getContext());
2577 }
2578 
2579 void mlir::vector::populateVectorScanLoweringPatterns(
2580     RewritePatternSet &patterns) {
2581   patterns.add<ScanToArithOps>(patterns.getContext());
2582 }
2583