1 //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements convenience types for working with super-vectorization
10 // operations, in particular super-vector loads and stores.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Vector/IR/VectorOps.h"
15 
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/MemRef/IR/MemRef.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
23 #include "mlir/IR/AffineExpr.h"
24 #include "mlir/IR/AffineMap.h"
25 #include "mlir/IR/BlockAndValueMapping.h"
26 #include "mlir/IR/Builders.h"
27 #include "mlir/IR/BuiltinOps.h"
28 #include "mlir/IR/BuiltinTypes.h"
29 #include "mlir/IR/DialectImplementation.h"
30 #include "mlir/IR/OpImplementation.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
33 #include "mlir/Support/LLVM.h"
34 #include "mlir/Support/MathExtras.h"
35 #include "llvm/ADT/StringSet.h"
36 #include "llvm/ADT/bit.h"
37 #include <numeric>
38 
39 #include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc"
40 // Pull in all enum type and utility function definitions.
41 #include "mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc"
42 
43 using namespace mlir;
44 using namespace mlir::vector;
45 
46 /// Helper enum to classify mask value.
47 enum class MaskFormat {
48   AllTrue = 0,
49   AllFalse = 1,
50   Unknown = 2,
51 };
52 
53 /// Helper method to classify a 1-D mask value. Currently, the method
54 /// looks "under the hood" of a constant value with dense attributes
55 /// and a constant mask operation (since the client may be called at
56 /// various stages during progressive lowering).
57 static MaskFormat get1DMaskFormat(Value mask) {
58   if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
59     // Inspect constant dense values. We count up for bits that
60     // are set, count down for bits that are cleared, and bail
61     // when a mix is detected.
62     if (auto denseElts = c.getValue().dyn_cast<DenseIntElementsAttr>()) {
63       int64_t val = 0;
64       for (bool b : denseElts.getValues<bool>())
65         if (b && val >= 0)
66           val++;
67         else if (!b && val <= 0)
68           val--;
69         else
70           return MaskFormat::Unknown;
71       if (val > 0)
72         return MaskFormat::AllTrue;
73       if (val < 0)
74         return MaskFormat::AllFalse;
75     }
76   } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
77     // Inspect constant mask index. If the index exceeds the
78     // dimension size, all bits are set. If the index is zero
79     // or less, no bits are set.
80     ArrayAttr masks = m.mask_dim_sizes();
81     assert(masks.size() == 1);
82     int64_t i = masks[0].cast<IntegerAttr>().getInt();
83     int64_t u = m.getType().getDimSize(0);
84     if (i >= u)
85       return MaskFormat::AllTrue;
86     if (i <= 0)
87       return MaskFormat::AllFalse;
88   }
89   return MaskFormat::Unknown;
90 }
91 
92 // Helper for verifying combining kinds in contractions and reductions.
93 static bool isSupportedCombiningKind(CombiningKind combiningKind,
94                                      Type elementType) {
95   switch (combiningKind) {
96   case CombiningKind::ADD:
97   case CombiningKind::MUL:
98     return elementType.isIntOrIndexOrFloat();
99   case CombiningKind::MINUI:
100   case CombiningKind::MINSI:
101   case CombiningKind::MAXUI:
102   case CombiningKind::MAXSI:
103   case CombiningKind::AND:
104   case CombiningKind::OR:
105   case CombiningKind::XOR:
106     return elementType.isIntOrIndex();
107   case CombiningKind::MINF:
108   case CombiningKind::MAXF:
109     return elementType.isa<FloatType>();
110   }
111   return false;
112 }
113 
114 /// Return true if the last dimension of the MemRefType has unit stride. Also
115 /// return true for memrefs with no strides.
116 bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) {
117   int64_t offset;
118   SmallVector<int64_t> strides;
119   auto successStrides = getStridesAndOffset(type, strides, offset);
120   return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
121 }
122 
123 AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
124                                                     VectorType vectorType) {
125   int64_t elementVectorRank = 0;
126   VectorType elementVectorType =
127       shapedType.getElementType().dyn_cast<VectorType>();
128   if (elementVectorType)
129     elementVectorRank += elementVectorType.getRank();
130   // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
131   // TODO: replace once we have 0-d vectors.
132   if (shapedType.getRank() == 0 &&
133       vectorType.getShape() == ArrayRef<int64_t>{1})
134     return AffineMap::get(
135         /*numDims=*/0, /*numSymbols=*/0,
136         getAffineConstantExpr(0, shapedType.getContext()));
137   return AffineMap::getMinorIdentityMap(
138       shapedType.getRank(), vectorType.getRank() - elementVectorRank,
139       shapedType.getContext());
140 }
141 
142 bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
143                                      vector::TransferReadOp read) {
144   return !defWrite.hasOutOfBoundsDim() && !defWrite.mask() && !read.mask() &&
145          defWrite.indices() == read.indices() &&
146          defWrite.getVectorType() == read.getVectorType() &&
147          defWrite.permutation_map() == read.permutation_map();
148 }
149 
150 bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
151                                      vector::TransferWriteOp priorWrite) {
152   return priorWrite.indices() == write.indices() &&
153          priorWrite.mask() == write.mask() &&
154          priorWrite.getVectorType() == write.getVectorType() &&
155          priorWrite.permutation_map() == write.permutation_map();
156 }
157 
158 bool mlir::vector::isDisjointTransferIndices(
159     VectorTransferOpInterface transferA, VectorTransferOpInterface transferB) {
160   // For simplicity only look at transfer of same type.
161   if (transferA.getVectorType() != transferB.getVectorType())
162     return false;
163   unsigned rankOffset = transferA.getLeadingShapedRank();
164   for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
165     auto indexA = transferA.indices()[i].getDefiningOp<arith::ConstantOp>();
166     auto indexB = transferB.indices()[i].getDefiningOp<arith::ConstantOp>();
167     // If any of the indices are dynamic we cannot prove anything.
168     if (!indexA || !indexB)
169       continue;
170 
171     if (i < rankOffset) {
172       // For leading dimensions, if we can prove that index are different we
173       // know we are accessing disjoint slices.
174       if (indexA.getValue().cast<IntegerAttr>().getInt() !=
175           indexB.getValue().cast<IntegerAttr>().getInt())
176         return true;
177     } else {
178       // For this dimension, we slice a part of the memref we need to make sure
179       // the intervals accessed don't overlap.
180       int64_t distance =
181           std::abs(indexA.getValue().cast<IntegerAttr>().getInt() -
182                    indexB.getValue().cast<IntegerAttr>().getInt());
183       if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
184         return true;
185     }
186   }
187   return false;
188 }
189 
190 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
191                                          VectorTransferOpInterface transferB) {
192   if (transferA.source() != transferB.source())
193     return false;
194   return isDisjointTransferIndices(transferA, transferB);
195 }
196 
197 //===----------------------------------------------------------------------===//
198 // CombiningKindAttr
199 //===----------------------------------------------------------------------===//
200 
201 namespace mlir {
202 namespace vector {
203 namespace detail {
204 struct BitmaskEnumStorage : public AttributeStorage {
205   using KeyTy = uint64_t;
206 
207   BitmaskEnumStorage(KeyTy val) : value(val) {}
208 
209   bool operator==(const KeyTy &key) const { return value == key; }
210 
211   static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
212                                        const KeyTy &key) {
213     return new (allocator.allocate<BitmaskEnumStorage>())
214         BitmaskEnumStorage(key);
215   }
216 
217   KeyTy value = 0;
218 };
219 } // namespace detail
220 } // namespace vector
221 } // namespace mlir
222 
223 CombiningKindAttr CombiningKindAttr::get(CombiningKind kind,
224                                          MLIRContext *context) {
225   return Base::get(context, static_cast<uint64_t>(kind));
226 }
227 
228 CombiningKind CombiningKindAttr::getKind() const {
229   return static_cast<CombiningKind>(getImpl()->value);
230 }
231 
232 static constexpr const CombiningKind combiningKindsList[] = {
233     // clang-format off
234     CombiningKind::ADD,
235     CombiningKind::MUL,
236     CombiningKind::MINUI,
237     CombiningKind::MINSI,
238     CombiningKind::MINF,
239     CombiningKind::MAXUI,
240     CombiningKind::MAXSI,
241     CombiningKind::MAXF,
242     CombiningKind::AND,
243     CombiningKind::OR,
244     CombiningKind::XOR,
245     // clang-format on
246 };
247 
248 void CombiningKindAttr::print(AsmPrinter &printer) const {
249   printer << "<";
250   auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
251     return bitEnumContains(this->getKind(), kind);
252   });
253   llvm::interleaveComma(kinds, printer,
254                         [&](auto kind) { printer << stringifyEnum(kind); });
255   printer << ">";
256 }
257 
258 Attribute CombiningKindAttr::parse(AsmParser &parser, Type type) {
259   if (failed(parser.parseLess()))
260     return {};
261 
262   StringRef elemName;
263   if (failed(parser.parseKeyword(&elemName)))
264     return {};
265 
266   auto kind = symbolizeCombiningKind(elemName);
267   if (!kind) {
268     parser.emitError(parser.getNameLoc(), "Unknown combining kind: ")
269         << elemName;
270     return {};
271   }
272 
273   if (failed(parser.parseGreater()))
274     return {};
275 
276   return CombiningKindAttr::get(kind.getValue(), parser.getContext());
277 }
278 
279 Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
280                                         Type type) const {
281   StringRef attrKind;
282   if (parser.parseKeyword(&attrKind))
283     return {};
284 
285   if (attrKind == "kind")
286     return CombiningKindAttr::parse(parser, {});
287 
288   parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind;
289   return {};
290 }
291 
292 void VectorDialect::printAttribute(Attribute attr,
293                                    DialectAsmPrinter &os) const {
294   if (auto ck = attr.dyn_cast<CombiningKindAttr>()) {
295     os << "kind";
296     ck.print(os);
297     return;
298   }
299   llvm_unreachable("Unknown attribute type");
300 }
301 
302 //===----------------------------------------------------------------------===//
303 // VectorDialect
304 //===----------------------------------------------------------------------===//
305 
306 void VectorDialect::initialize() {
307   addAttributes<CombiningKindAttr>();
308 
309   addOperations<
310 #define GET_OP_LIST
311 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
312       >();
313 }
314 
315 /// Materialize a single constant operation from a given attribute value with
316 /// the desired resultant type.
317 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
318                                               Attribute value, Type type,
319                                               Location loc) {
320   return builder.create<arith::ConstantOp>(loc, type, value);
321 }
322 
323 IntegerType vector::getVectorSubscriptType(Builder &builder) {
324   return builder.getIntegerType(64);
325 }
326 
327 ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
328                                          ArrayRef<int64_t> values) {
329   return builder.getI64ArrayAttr(values);
330 }
331 
332 //===----------------------------------------------------------------------===//
333 // MultiDimReductionOp
334 //===----------------------------------------------------------------------===//
335 
336 void vector::MultiDimReductionOp::build(OpBuilder &builder,
337                                         OperationState &result, Value source,
338                                         ArrayRef<bool> reductionMask,
339                                         CombiningKind kind) {
340   result.addOperands(source);
341   auto sourceVectorType = source.getType().cast<VectorType>();
342   auto targetType = MultiDimReductionOp::inferDestType(
343       sourceVectorType.getShape(), reductionMask,
344       sourceVectorType.getElementType());
345   result.addTypes(targetType);
346 
347   SmallVector<int64_t> reductionDims;
348   for (const auto &en : llvm::enumerate(reductionMask))
349     if (en.value())
350       reductionDims.push_back(en.index());
351   result.addAttribute(getReductionDimsAttrName(),
352                       builder.getI64ArrayAttr(reductionDims));
353   result.addAttribute(getKindAttrName(),
354                       CombiningKindAttr::get(kind, builder.getContext()));
355 }
356 
357 static LogicalResult verify(MultiDimReductionOp op) {
358   auto reductionMask = op.getReductionMask();
359   auto targetType = MultiDimReductionOp::inferDestType(
360       op.getSourceVectorType().getShape(), reductionMask,
361       op.getSourceVectorType().getElementType());
362   // TODO: update to support 0-d vectors when available.
363   if (targetType != op.getDestType())
364     return op.emitError("invalid output vector type: ")
365            << op.getDestType() << " (expected: " << targetType << ")";
366   return success();
367 }
368 
369 OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
370   // Single parallel dim, this is a noop.
371   if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
372     return source();
373   return {};
374 }
375 
376 //===----------------------------------------------------------------------===//
377 // ReductionOp
378 //===----------------------------------------------------------------------===//
379 
380 static LogicalResult verify(ReductionOp op) {
381   // Verify for 1-D vector.
382   int64_t rank = op.getVectorType().getRank();
383   if (rank != 1)
384     return op.emitOpError("unsupported reduction rank: ") << rank;
385 
386   // Verify supported reduction kind.
387   StringRef strKind = op.kind();
388   auto maybeKind = symbolizeCombiningKind(strKind);
389   if (!maybeKind)
390     return op.emitOpError("unknown reduction kind: ") << strKind;
391 
392   Type eltType = op.dest().getType();
393   if (!isSupportedCombiningKind(*maybeKind, eltType))
394     return op.emitOpError("unsupported reduction type '")
395            << eltType << "' for kind '" << op.kind() << "'";
396 
397   // Verify optional accumulator.
398   if (!op.acc().empty()) {
399     if (strKind != "add" && strKind != "mul")
400       return op.emitOpError("no accumulator for reduction kind: ") << strKind;
401     if (!eltType.isa<FloatType>())
402       return op.emitOpError("no accumulator for type: ") << eltType;
403   }
404 
405   return success();
406 }
407 
408 static ParseResult parseReductionOp(OpAsmParser &parser,
409                                     OperationState &result) {
410   SmallVector<OpAsmParser::OperandType, 2> operandsInfo;
411   Type redType;
412   Type resType;
413   Attribute attr;
414   if (parser.parseAttribute(attr, "kind", result.attributes) ||
415       parser.parseComma() || parser.parseOperandList(operandsInfo) ||
416       parser.parseColonType(redType) ||
417       parser.parseKeywordType("into", resType) ||
418       (!operandsInfo.empty() &&
419        parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
420       (operandsInfo.size() > 1 &&
421        parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
422       parser.addTypeToList(resType, result.types))
423     return failure();
424   if (operandsInfo.empty() || operandsInfo.size() > 2)
425     return parser.emitError(parser.getNameLoc(),
426                             "unsupported number of operands");
427   return success();
428 }
429 
430 static void print(OpAsmPrinter &p, ReductionOp op) {
431   p << " \"" << op.kind() << "\", " << op.vector();
432   if (!op.acc().empty())
433     p << ", " << op.acc();
434   p << " : " << op.vector().getType() << " into " << op.dest().getType();
435 }
436 
437 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
438                                          OpBuilder &builder, Location loc,
439                                          Value vector) {
440   Type scalarType = vector.getType().cast<ShapedType>().getElementType();
441   switch (op) {
442   case arith::AtomicRMWKind::addf:
443   case arith::AtomicRMWKind::addi:
444     return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
445                                                builder.getStringAttr("add"),
446                                                vector, ValueRange{});
447   case arith::AtomicRMWKind::mulf:
448   case arith::AtomicRMWKind::muli:
449     return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
450                                                builder.getStringAttr("mul"),
451                                                vector, ValueRange{});
452   case arith::AtomicRMWKind::minf:
453     return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
454                                                builder.getStringAttr("minf"),
455                                                vector, ValueRange{});
456   case arith::AtomicRMWKind::mins:
457     return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
458                                                builder.getStringAttr("minsi"),
459                                                vector, ValueRange{});
460   case arith::AtomicRMWKind::minu:
461     return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
462                                                builder.getStringAttr("minui"),
463                                                vector, ValueRange{});
464   case arith::AtomicRMWKind::maxf:
465     return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
466                                                builder.getStringAttr("maxf"),
467                                                vector, ValueRange{});
468   case arith::AtomicRMWKind::maxs:
469     return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
470                                                builder.getStringAttr("maxsi"),
471                                                vector, ValueRange{});
472   case arith::AtomicRMWKind::maxu:
473     return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
474                                                builder.getStringAttr("maxui"),
475                                                vector, ValueRange{});
476   // TODO: Add remaining reduction operations.
477   default:
478     (void)emitOptionalError(loc, "Reduction operation type not supported");
479     break;
480   }
481   return nullptr;
482 }
483 
484 //===----------------------------------------------------------------------===//
485 // ContractionOp
486 //===----------------------------------------------------------------------===//
487 
488 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
489                                   Value lhs, Value rhs, Value acc,
490                                   ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
491                                   ArrayRef<StringRef> iteratorTypes) {
492   result.addOperands({lhs, rhs, acc});
493   result.addTypes(acc.getType());
494   result.addAttribute(getIndexingMapsAttrName(),
495                       builder.getAffineMapArrayAttr(
496                           AffineMap::inferFromExprList(indexingExprs)));
497   result.addAttribute(getIteratorTypesAttrName(),
498                       builder.getStrArrayAttr(iteratorTypes));
499 }
500 
501 void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
502                                   Value lhs, Value rhs, Value acc,
503                                   ArrayAttr indexingMaps,
504                                   ArrayAttr iteratorTypes) {
505   result.addOperands({lhs, rhs, acc});
506   result.addTypes(acc.getType());
507   result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
508   result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
509   result.addAttribute(ContractionOp::getKindAttrName(),
510                       CombiningKindAttr::get(ContractionOp::getDefaultKind(),
511                                              builder.getContext()));
512 }
513 
514 static ParseResult parseContractionOp(OpAsmParser &parser,
515                                       OperationState &result) {
516   OpAsmParser::OperandType lhsInfo;
517   OpAsmParser::OperandType rhsInfo;
518   OpAsmParser::OperandType accInfo;
519   SmallVector<OpAsmParser::OperandType, 2> masksInfo;
520   SmallVector<Type, 2> types;
521   Type resultType;
522   auto loc = parser.getCurrentLocation();
523   DictionaryAttr dictAttr;
524   // TODO: Unify linalg op attribute parsing.
525   if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
526       parser.parseOperand(lhsInfo) || parser.parseComma() ||
527       parser.parseOperand(rhsInfo) || parser.parseComma() ||
528       parser.parseOperand(accInfo) ||
529       parser.parseTrailingOperandList(masksInfo) ||
530       parser.parseOptionalAttrDict(result.attributes) ||
531       parser.parseColonTypeList(types) ||
532       parser.parseKeywordType("into", resultType) ||
533       parser.resolveOperand(lhsInfo, types[0], result.operands) ||
534       parser.resolveOperand(rhsInfo, types[1], result.operands) ||
535       parser.resolveOperand(accInfo, resultType, result.operands) ||
536       parser.addTypeToList(resultType, result.types))
537     return failure();
538   result.attributes.assign(dictAttr.getValue().begin(),
539                            dictAttr.getValue().end());
540   if (!result.attributes.get(ContractionOp::getKindAttrName())) {
541     result.addAttribute(ContractionOp::getKindAttrName(),
542                         CombiningKindAttr::get(ContractionOp::getDefaultKind(),
543                                                result.getContext()));
544   }
545   if (masksInfo.empty())
546     return success();
547   if (masksInfo.size() != 2)
548     return parser.emitError(parser.getNameLoc(),
549                             "expected zero or exactly 2 vector mask operands");
550   auto lhsType = types[0].cast<VectorType>();
551   auto rhsType = types[1].cast<VectorType>();
552   auto maskElementType = parser.getBuilder().getI1Type();
553   std::array<Type, 2> maskTypes = {
554       VectorType::Builder(lhsType).setElementType(maskElementType),
555       VectorType::Builder(rhsType).setElementType(maskElementType)};
556   if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
557     return failure();
558   return success();
559 }
560 
561 static void print(OpAsmPrinter &p, ContractionOp op) {
562   // TODO: Unify printing code with linalg ops.
563   auto attrNames = op.getTraitAttrNames();
564   llvm::StringSet<> traitAttrsSet;
565   traitAttrsSet.insert(attrNames.begin(), attrNames.end());
566   SmallVector<NamedAttribute, 8> attrs;
567   for (auto attr : op->getAttrs())
568     if (traitAttrsSet.count(attr.getName().strref()) > 0)
569       attrs.push_back(attr);
570 
571   auto dictAttr = DictionaryAttr::get(op.getContext(), attrs);
572   p << " " << dictAttr << " " << op.lhs() << ", ";
573   p << op.rhs() << ", " << op.acc();
574   if (op.masks().size() == 2)
575     p << ", " << op.masks();
576 
577   p.printOptionalAttrDict(op->getAttrs(), attrNames);
578   p << " : " << op.lhs().getType() << ", " << op.rhs().getType() << " into "
579     << op.getResultType();
580 }
581 
582 static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
583                          const std::vector<std::pair<int64_t, int64_t>> &map) {
584   for (auto &dimPair : map) {
585     if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
586         dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
587         lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
588       return false;
589   }
590   return true;
591 }
592 
593 static LogicalResult verifyOutputShape(
594     ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
595     Type resType,
596     const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
597     const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
598   DenseSet<int64_t> lhsContractingDimSet;
599   DenseSet<int64_t> rhsContractingDimSet;
600   for (auto &dimPair : contractingDimMap) {
601     lhsContractingDimSet.insert(dimPair.first);
602     rhsContractingDimSet.insert(dimPair.second);
603   }
604   DenseSet<int64_t> rhsBatchDimSet;
605   for (auto &dimPair : batchDimMap)
606     rhsBatchDimSet.insert(dimPair.second);
607 
608   // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
609   SmallVector<int64_t, 4> expectedResultDims;
610   for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
611     if (lhsContractingDimSet.count(i) > 0)
612       continue;
613     expectedResultDims.push_back(lhsType.getDimSize(i));
614   }
615 
616   // Add free dimensions from 'rhsType' to 'expectedResultDims'.
617   for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
618     if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
619       continue;
620     expectedResultDims.push_back(rhsType.getDimSize(i));
621   }
622 
623   // Verify 'expectedResultDims'.
624   if (expectedResultDims.empty()) {
625     // No batch or free dimension implies a scalar result.
626     if (resType.isa<VectorType>() || accType.isa<VectorType>())
627       return op.emitOpError("invalid accumulator/result vector shape");
628   } else {
629     // At least one batch or free dimension implies a vector result.
630     auto resVectorType = resType.dyn_cast<VectorType>();
631     auto accVectorType = accType.dyn_cast<VectorType>();
632     if (!resVectorType || !accVectorType)
633       return op.emitOpError("invalid accumulator/result vector shape");
634 
635     // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
636     // types fully define the result vector type. This assumes the affine maps
637     // are well-formed, which must have been verified already.
638     MLIRContext *ctx = op.getContext();
639     AffineMap lhsMap = op.getIndexingMaps()[0];
640     AffineMap rhsMap = op.getIndexingMaps()[1];
641     SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
642     for (auto pair :
643          {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
644       VectorType v = pair.first;
645       auto map = pair.second;
646       for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
647         unsigned pos = map.getDimPosition(idx);
648         if (!extents[pos])
649           extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
650       }
651     }
652     assert(llvm::all_of(extents, [](AffineExpr e) { return e; }) &&
653            "expected extent along all dimensions.");
654 
655     AffineMap resMap = op.getIndexingMaps()[2];
656     auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
657                                      /*symCount=*/0, extents, ctx);
658     // Compose the resMap with the extentsMap, which is a constant map.
659     AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
660     assert(llvm::all_of(
661                expectedMap.getResults(),
662                [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
663            "expected constant extent along all dimensions.");
664     // Extract the expected shape and build the type.
665     auto expectedShape = llvm::to_vector<4>(
666         llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
667           return e.cast<AffineConstantExpr>().getValue();
668         }));
669     auto expected =
670         VectorType::get(expectedShape, resVectorType.getElementType());
671     if (resVectorType != expected || accVectorType != expected)
672       return op.emitOpError(
673                  "invalid accumulator/result vector shape, expected: ")
674              << expected;
675   }
676   return success();
677 }
678 
679 static LogicalResult verify(ContractionOp op) {
680   auto lhsType = op.getLhsType();
681   auto rhsType = op.getRhsType();
682   auto accType = op.getAccType();
683   auto resType = op.getResultType();
684 
685   // Verify that an indexing map was specified for each vector operand.
686   if (op.indexing_maps().size() != 3)
687     return op.emitOpError("expected an indexing map for each vector operand");
688 
689   // Verify that each index map has 'numIterators' inputs, no symbols, and
690   // that the number of map outputs equals the rank of its associated
691   // vector operand.
692   unsigned numIterators = op.iterator_types().getValue().size();
693   for (const auto &it : llvm::enumerate(op.indexing_maps())) {
694     auto index = it.index();
695     auto map = it.value().cast<AffineMapAttr>().getValue();
696     if (map.getNumSymbols() != 0)
697       return op.emitOpError("expected indexing map ")
698              << index << " to have no symbols";
699     auto vectorType = op.getOperand(index).getType().dyn_cast<VectorType>();
700     unsigned rank = vectorType ? vectorType.getShape().size() : 0;
701     // Verify that the map has the right number of inputs, outputs, and indices.
702     // This also correctly accounts for (..) -> () for rank-0 results.
703     if (map.getNumDims() != numIterators)
704       return op.emitOpError("expected indexing map ")
705              << index << " to have " << numIterators << " number of inputs";
706     if (map.getNumResults() != rank)
707       return op.emitOpError("expected indexing map ")
708              << index << " to have " << rank << " number of outputs";
709     if (!map.isProjectedPermutation())
710       return op.emitOpError("expected indexing map ")
711              << index << " to be a projected permutation of its inputs";
712   }
713 
714   auto contractingDimMap = op.getContractingDimMap();
715   auto batchDimMap = op.getBatchDimMap();
716 
717   // Verify at least one contracting dimension pair was specified.
718   if (contractingDimMap.empty())
719     return op.emitOpError("expected at least one contracting dimension pair");
720 
721   // Verify contracting dimension map was properly constructed.
722   if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
723     return op.emitOpError("invalid contracting dimension map");
724 
725   // Verify batch dimension map was properly constructed.
726   if (!verifyDimMap(lhsType, rhsType, batchDimMap))
727     return op.emitOpError("invalid batch dimension map");
728 
729   // Verify 'accType' and 'resType' shape.
730   if (failed(verifyOutputShape(op, lhsType, rhsType, accType, resType,
731                                contractingDimMap, batchDimMap)))
732     return failure();
733 
734   // Verify that either two vector masks are set or none are set.
735   auto lhsMaskType = op.getLHSVectorMaskType();
736   auto rhsMaskType = op.getRHSVectorMaskType();
737   if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
738     return op.emitOpError("invalid number of vector masks specified");
739   if (lhsMaskType && rhsMaskType) {
740     // Verify mask rank == argument rank.
741     if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
742         rhsMaskType.getShape().size() != rhsType.getShape().size())
743       return op.emitOpError("invalid vector mask rank");
744   }
745 
746   // Verify supported combining kind.
747   auto vectorType = resType.dyn_cast<VectorType>();
748   auto elementType = vectorType ? vectorType.getElementType() : resType;
749   if (!isSupportedCombiningKind(op.kind(), elementType))
750     return op.emitOpError("unsupported contraction type");
751 
752   return success();
753 }
754 
755 ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
756   static constexpr StringRef names[3] = {getIndexingMapsAttrName(),
757                                          getIteratorTypesAttrName(),
758                                          ContractionOp::getKindAttrName()};
759   return llvm::makeArrayRef(names);
760 }
761 
762 static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
763   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
764     if (targetExpr == map.getResult(i))
765       return i;
766   return -1;
767 }
768 
769 static std::vector<std::pair<int64_t, int64_t>>
770 getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
771           StringRef targetIteratorTypeName, MLIRContext *context) {
772   std::vector<std::pair<int64_t, int64_t>> dimMap;
773   for (const auto &it : llvm::enumerate(iteratorTypes)) {
774     auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
775     if (iteratorTypeName != targetIteratorTypeName)
776       continue;
777     // Search lhs/rhs map results for 'targetExpr'.
778     auto targetExpr = getAffineDimExpr(it.index(), context);
779     int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
780     int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
781     if (lhsDim >= 0 && rhsDim >= 0)
782       dimMap.emplace_back(lhsDim, rhsDim);
783   }
784   return dimMap;
785 }
786 
787 void ContractionOp::getIterationBounds(
788     SmallVectorImpl<int64_t> &iterationBounds) {
789   auto lhsShape = getLhsType().getShape();
790   auto resVectorType = getResultType().dyn_cast<VectorType>();
791   SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
792   SmallVector<int64_t, 2> iterationShape;
793   for (const auto &it : llvm::enumerate(iterator_types())) {
794     // Search lhs/rhs map results for 'targetExpr'.
795     auto targetExpr = getAffineDimExpr(it.index(), getContext());
796     auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
797     if (iteratorTypeName == getReductionIteratorTypeName()) {
798       // Get reduction dim size from lhs shape (same size in rhsShape).
799       int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
800       assert(lhsDimIndex >= 0);
801       iterationBounds.push_back(lhsShape[lhsDimIndex]);
802       continue;
803     }
804     // Get parallel dimension size from result shape.
805     int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
806     assert(resDimIndex >= 0);
807     assert(resVectorType != nullptr);
808     iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
809   }
810 }
811 
812 void ContractionOp::getIterationIndexMap(
813     std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
814   unsigned numMaps = indexing_maps().getValue().size();
815   iterationIndexMap.resize(numMaps);
816   for (const auto &it : llvm::enumerate(indexing_maps())) {
817     auto index = it.index();
818     auto map = it.value().cast<AffineMapAttr>().getValue();
819     for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
820       auto dim = map.getResult(i).cast<AffineDimExpr>();
821       iterationIndexMap[index][dim.getPosition()] = i;
822     }
823   }
824 }
825 
826 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
827   SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
828   return getDimMap(indexingMaps, iterator_types(),
829                    getReductionIteratorTypeName(), getContext());
830 }
831 
832 std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
833   SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
834   return getDimMap(indexingMaps, iterator_types(),
835                    getParallelIteratorTypeName(), getContext());
836 }
837 
838 SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
839   return llvm::to_vector<4>(
840       llvm::map_range(indexing_maps().getValue(), [](Attribute mapAttr) {
841         return mapAttr.cast<AffineMapAttr>().getValue();
842       }));
843 }
844 
845 Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
846   SmallVector<int64_t, 4> shape;
847   getIterationBounds(shape);
848   return shape;
849 }
850 
851 /// Return a fused vector::ContractionOp which represents a patterns such as:
852 ///
853 /// ```mlir
854 ///    %c0 = vector.constant 0: ...
855 ///    %c = vector.contract %a, %b, %c0: ...
856 ///    %e = add %c, %d: ...
857 /// ```
858 ///
859 /// by:
860 ///
861 /// ```mlir
862 ///    %e = vector.contract %a, %b, %d: ...
863 /// ```
864 ///
865 /// Return null if the canonicalization does not apply.
866 // TODO: This should be a folding of Add into Contract in core but while they
867 // live in different dialects, it is not possible without unnatural
868 // dependencies.
869 template <typename AddOpType>
870 struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
871   using OpRewritePattern<AddOpType>::OpRewritePattern;
872 
873   LogicalResult matchAndRewrite(AddOpType addOp,
874                                 PatternRewriter &rewriter) const override {
875     auto canonicalize = [&](Value maybeContraction,
876                             Value otherOperand) -> vector::ContractionOp {
877       vector::ContractionOp contractionOp =
878           dyn_cast_or_null<vector::ContractionOp>(
879               maybeContraction.getDefiningOp());
880       if (!contractionOp)
881         return vector::ContractionOp();
882       if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
883               contractionOp.acc().getDefiningOp())) {
884         if (maybeZero.getValue() ==
885             rewriter.getZeroAttr(contractionOp.acc().getType())) {
886           BlockAndValueMapping bvm;
887           bvm.map(contractionOp.acc(), otherOperand);
888           auto newContraction =
889               cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
890           rewriter.replaceOp(addOp, newContraction.getResult());
891           return newContraction;
892         }
893       }
894       return vector::ContractionOp();
895     };
896 
897     Value a = addOp->getOperand(0), b = addOp->getOperand(1);
898     vector::ContractionOp contract = canonicalize(a, b);
899     contract = contract ? contract : canonicalize(b, a);
900     return contract ? success() : failure();
901   }
902 };
903 
904 void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
905                                                 MLIRContext *context) {
906   results.add<CanonicalizeContractAdd<arith::AddIOp>,
907               CanonicalizeContractAdd<arith::AddFOp>>(context);
908 }
909 
910 //===----------------------------------------------------------------------===//
911 // ExtractElementOp
912 //===----------------------------------------------------------------------===//
913 
914 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
915                                      Value source) {
916   result.addOperands({source});
917   result.addTypes(source.getType().cast<VectorType>().getElementType());
918 }
919 
920 void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
921                                      Value source, Value position) {
922   result.addOperands({source, position});
923   result.addTypes(source.getType().cast<VectorType>().getElementType());
924 }
925 
926 static LogicalResult verify(vector::ExtractElementOp op) {
927   VectorType vectorType = op.getVectorType();
928   if (vectorType.getRank() == 0) {
929     if (op.position())
930       return op.emitOpError("expected position to be empty with 0-D vector");
931     return success();
932   }
933   if (vectorType.getRank() != 1)
934     return op.emitOpError("unexpected >1 vector rank");
935   if (!op.position())
936     return op.emitOpError("expected position for 1-D vector");
937   return success();
938 }
939 
940 //===----------------------------------------------------------------------===//
941 // ExtractOp
942 //===----------------------------------------------------------------------===//
943 
944 static Type inferExtractOpResultType(VectorType vectorType,
945                                      ArrayAttr position) {
946   if (static_cast<int64_t>(position.size()) == vectorType.getRank())
947     return vectorType.getElementType();
948   return VectorType::get(vectorType.getShape().drop_front(position.size()),
949                          vectorType.getElementType());
950 }
951 
952 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
953                               Value source, ArrayRef<int64_t> position) {
954   result.addOperands(source);
955   auto positionAttr = getVectorSubscriptAttr(builder, position);
956   result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(),
957                                            positionAttr));
958   result.addAttribute(getPositionAttrName(), positionAttr);
959 }
960 
961 // Convenience builder which assumes the values are constant indices.
962 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
963                               Value source, ValueRange position) {
964   SmallVector<int64_t, 4> positionConstants =
965       llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
966         return pos.getDefiningOp<arith::ConstantIndexOp>().value();
967       }));
968   build(builder, result, source, positionConstants);
969 }
970 
971 static void print(OpAsmPrinter &p, vector::ExtractOp op) {
972   p << " " << op.vector() << op.position();
973   p.printOptionalAttrDict(op->getAttrs(), {"position"});
974   p << " : " << op.vector().getType();
975 }
976 
977 static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
978   SMLoc attributeLoc, typeLoc;
979   NamedAttrList attrs;
980   OpAsmParser::OperandType vector;
981   Type type;
982   Attribute attr;
983   if (parser.parseOperand(vector) || parser.getCurrentLocation(&attributeLoc) ||
984       parser.parseAttribute(attr, "position", attrs) ||
985       parser.parseOptionalAttrDict(attrs) ||
986       parser.getCurrentLocation(&typeLoc) || parser.parseColonType(type))
987     return failure();
988 
989   auto vectorType = type.dyn_cast<VectorType>();
990   if (!vectorType)
991     return parser.emitError(typeLoc, "expected vector type");
992 
993   auto positionAttr = attr.dyn_cast<ArrayAttr>();
994   if (!positionAttr ||
995       static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
996     return parser.emitError(
997         attributeLoc,
998         "expected position attribute of rank smaller than vector rank");
999 
1000   Type resType = inferExtractOpResultType(vectorType, positionAttr);
1001   result.attributes = attrs;
1002   return failure(parser.resolveOperand(vector, type, result.operands) ||
1003                  parser.addTypeToList(resType, result.types));
1004 }
1005 
1006 static LogicalResult verify(vector::ExtractOp op) {
1007   auto positionAttr = op.position().getValue();
1008   if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
1009     return op.emitOpError(
1010         "expected position attribute of rank smaller than vector rank");
1011   for (const auto &en : llvm::enumerate(positionAttr)) {
1012     auto attr = en.value().dyn_cast<IntegerAttr>();
1013     if (!attr || attr.getInt() < 0 ||
1014         attr.getInt() >= op.getVectorType().getDimSize(en.index()))
1015       return op.emitOpError("expected position attribute #")
1016              << (en.index() + 1)
1017              << " to be a non-negative integer smaller than the corresponding "
1018                 "vector dimension";
1019   }
1020   return success();
1021 }
1022 
1023 template <typename IntType>
1024 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1025   return llvm::to_vector<4>(llvm::map_range(
1026       arrayAttr.getAsRange<IntegerAttr>(),
1027       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1028 }
1029 
1030 /// Fold the result of chains of ExtractOp in place by simply concatenating the
1031 /// positions.
1032 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1033   if (!extractOp.vector().getDefiningOp<ExtractOp>())
1034     return failure();
1035 
1036   SmallVector<int64_t, 4> globalPosition;
1037   ExtractOp currentOp = extractOp;
1038   auto extrPos = extractVector<int64_t>(currentOp.position());
1039   globalPosition.append(extrPos.rbegin(), extrPos.rend());
1040   while (ExtractOp nextOp = currentOp.vector().getDefiningOp<ExtractOp>()) {
1041     currentOp = nextOp;
1042     auto extrPos = extractVector<int64_t>(currentOp.position());
1043     globalPosition.append(extrPos.rbegin(), extrPos.rend());
1044   }
1045   extractOp.setOperand(currentOp.vector());
1046   // OpBuilder is only used as a helper to build an I64ArrayAttr.
1047   OpBuilder b(extractOp.getContext());
1048   std::reverse(globalPosition.begin(), globalPosition.end());
1049   extractOp->setAttr(ExtractOp::getPositionAttrName(),
1050                      b.getI64ArrayAttr(globalPosition));
1051   return success();
1052 }
1053 
1054 namespace {
1055 /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1056 /// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1057 /// Compose TransposeOp permutations as we walk back.
1058 /// This helper class keeps an updated extraction position `extractPosition`
1059 /// with extra trailing sentinels.
1060 /// The sentinels encode the internal transposition status of the result vector.
1061 /// As we iterate, extractPosition is permuted and updated.
1062 class ExtractFromInsertTransposeChainState {
1063 public:
1064   ExtractFromInsertTransposeChainState(ExtractOp e);
1065 
1066   /// Iterate over producing insert and transpose ops until we find a fold.
1067   Value fold();
1068 
1069 private:
1070   /// Return true if the vector at position `a` is contained within the vector
1071   /// at position `b`. Under insert/extract semantics, this is the same as `a`
1072   /// is a prefix of `b`.
1073   template <typename ContainerA, typename ContainerB>
1074   bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1075     return a.size() <= b.size() &&
1076            std::equal(a.begin(), a.begin() + a.size(), b.begin());
1077   }
1078 
1079   /// Return true if the vector at position `a` intersects the vector at
1080   /// position `b`. Under insert/extract semantics, this is the same as equality
1081   /// of all entries of `a` that are >=0 with the corresponding entries of b.
1082   /// Comparison is on the common prefix (i.e. zip).
1083   template <typename ContainerA, typename ContainerB>
1084   bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1085     for (auto it : llvm::zip(a, b)) {
1086       if (std::get<0>(it) < 0 || std::get<0>(it) < 0)
1087         continue;
1088       if (std::get<0>(it) != std::get<1>(it))
1089         return false;
1090     }
1091     return true;
1092   }
1093 
1094   /// Folding is only possible in the absence of an internal permutation in the
1095   /// result vector.
1096   bool canFold() {
1097     return (sentinels ==
1098             makeArrayRef(extractPosition).drop_front(extractedRank));
1099   }
1100 
1101   // Helper to get the next defining op of interest.
1102   void updateStateForNextIteration(Value v) {
1103     nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1104     nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1105   };
1106 
1107   // Case 1. If we hit a transpose, just compose the map and iterate.
1108   // Invariant: insert + transpose do not change rank, we can always compose.
1109   LogicalResult handleTransposeOp();
1110 
1111   // Case 2: the insert position matches extractPosition exactly, early return.
1112   LogicalResult handleInsertOpWithMatchingPos(Value &res);
1113 
1114   /// Case 3: if the insert position is a prefix of extractPosition, extract a
1115   /// portion of the source of the insert.
1116   /// Example:
1117   /// ```
1118   /// %ins = vector.insert %source, %vest[1]: vector<3x4> into vector<2x3x4x5>
1119   /// // extractPosition == [1, 2, 3]
1120   /// %ext = vector.extract %ins[1, 0]: vector<3x4x5>
1121   /// // can fold to vector.extract %source[0, 3]
1122   /// %ext = vector.extract %source[3]: vector<5x6>
1123   /// ```
1124   /// To traverse through %source, we need to set the leading dims to 0 and
1125   /// drop the extra leading dims.
1126   /// This method updates the internal state.
1127   LogicalResult handleInsertOpWithPrefixPos(Value &res);
1128 
1129   /// Try to fold in place to extract(source, extractPosition) and return the
1130   /// folded result. Return null if folding is not possible (e.g. due to an
1131   /// internal tranposition in the result).
1132   Value tryToFoldExtractOpInPlace(Value source);
1133 
1134   ExtractOp extractOp;
1135   int64_t vectorRank;
1136   int64_t extractedRank;
1137 
1138   InsertOp nextInsertOp;
1139   TransposeOp nextTransposeOp;
1140 
1141   /// Sentinel values that encode the internal permutation status of the result.
1142   /// They are set to (-1, ... , -k) at the beginning and appended to
1143   /// `extractPosition`.
1144   /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1145   /// ensure that there is no internal transposition.
1146   /// Internal transposition cannot be accounted for with a folding pattern.
1147   // TODO: We could relax the internal transposition with an extra transposition
1148   // operation in a future canonicalizer.
1149   SmallVector<int64_t> sentinels;
1150   SmallVector<int64_t> extractPosition;
1151 };
1152 } // namespace
1153 
1154 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1155     ExtractOp e)
1156     : extractOp(e), vectorRank(extractOp.getVectorType().getRank()),
1157       extractedRank(extractOp.position().size()) {
1158   assert(vectorRank >= extractedRank && "extracted pos overflow");
1159   sentinels.reserve(vectorRank - extractedRank);
1160   for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1161     sentinels.push_back(-(i + 1));
1162   extractPosition = extractVector<int64_t>(extractOp.position());
1163   llvm::append_range(extractPosition, sentinels);
1164 }
1165 
1166 // Case 1. If we hit a transpose, just compose the map and iterate.
1167 // Invariant: insert + transpose do not change rank, we can always compose.
1168 LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1169   if (!nextTransposeOp)
1170     return failure();
1171   auto permutation = extractVector<unsigned>(nextTransposeOp.transp());
1172   AffineMap m = inversePermutation(
1173       AffineMap::getPermutationMap(permutation, extractOp.getContext()));
1174   extractPosition = applyPermutationMap(m, makeArrayRef(extractPosition));
1175   return success();
1176 }
1177 
1178 // Case 2: the insert position matches extractPosition exactly, early return.
1179 LogicalResult
1180 ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1181     Value &res) {
1182   auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
1183   if (makeArrayRef(insertedPos) !=
1184       llvm::makeArrayRef(extractPosition).take_front(extractedRank))
1185     return failure();
1186   // Case 2.a. early-exit fold.
1187   res = nextInsertOp.source();
1188   // Case 2.b. if internal transposition is present, canFold will be false.
1189   return success();
1190 }
1191 
1192 /// Case 3: if inserted position is a prefix of extractPosition,
1193 /// extract a portion of the source of the insertion.
1194 /// This method updates the internal state.
1195 LogicalResult
1196 ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1197   auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
1198   if (!isContainedWithin(insertedPos, extractPosition))
1199     return failure();
1200   // Set leading dims to zero.
1201   std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1202   // Drop extra leading dims.
1203   extractPosition.erase(extractPosition.begin(),
1204                         extractPosition.begin() + insertedPos.size());
1205   extractedRank = extractPosition.size() - sentinels.size();
1206   // Case 3.a. early-exit fold (break and delegate to post-while path).
1207   res = nextInsertOp.source();
1208   // Case 3.b. if internal transposition is present, canFold will be false.
1209   return success();
1210 }
1211 
1212 /// Try to fold in place to extract(source, extractPosition) and return the
1213 /// folded result. Return null if folding is not possible (e.g. due to an
1214 /// internal tranposition in the result).
1215 Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1216     Value source) {
1217   // If we can't fold (either internal transposition, or nothing to fold), bail.
1218   bool nothingToFold = (source == extractOp.vector());
1219   if (nothingToFold || !canFold())
1220     return Value();
1221   // Otherwise, fold by updating the op inplace and return its result.
1222   OpBuilder b(extractOp.getContext());
1223   extractOp->setAttr(
1224       extractOp.positionAttrName(),
1225       b.getI64ArrayAttr(
1226           makeArrayRef(extractPosition).take_front(extractedRank)));
1227   extractOp.vectorMutable().assign(source);
1228   return extractOp.getResult();
1229 }
1230 
1231 /// Iterate over producing insert and transpose ops until we find a fold.
1232 Value ExtractFromInsertTransposeChainState::fold() {
1233   Value valueToExtractFrom = extractOp.vector();
1234   updateStateForNextIteration(valueToExtractFrom);
1235   while (nextInsertOp || nextTransposeOp) {
1236     // Case 1. If we hit a transpose, just compose the map and iterate.
1237     // Invariant: insert + transpose do not change rank, we can always compose.
1238     if (succeeded(handleTransposeOp())) {
1239       valueToExtractFrom = nextTransposeOp.vector();
1240       updateStateForNextIteration(valueToExtractFrom);
1241       continue;
1242     }
1243 
1244     Value result;
1245     // Case 2: the position match exactly.
1246     if (succeeded(handleInsertOpWithMatchingPos(result)))
1247       return result;
1248 
1249     // Case 3: if the inserted position is a prefix of extractPosition, we can
1250     // just extract a portion of the source of the insert.
1251     if (succeeded(handleInsertOpWithPrefixPos(result)))
1252       return tryToFoldExtractOpInPlace(result);
1253 
1254     // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1255     // values. This is a more difficult case and we bail.
1256     auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
1257     if (isContainedWithin(extractPosition, insertedPos) ||
1258         intersectsWhereNonNegative(extractPosition, insertedPos))
1259       return Value();
1260 
1261     // Case 5: No intersection, we forward the extract to insertOp.dest().
1262     valueToExtractFrom = nextInsertOp.dest();
1263     updateStateForNextIteration(valueToExtractFrom);
1264   }
1265   // If after all this we can fold, go for it.
1266   return tryToFoldExtractOpInPlace(valueToExtractFrom);
1267 }
1268 
1269 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
1270 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1271   Operation *defOp = extractOp.vector().getDefiningOp();
1272   if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1273     return Value();
1274   Value source = defOp->getOperand(0);
1275   if (extractOp.getType() == source.getType())
1276     return source;
1277   auto getRank = [](Type type) {
1278     return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
1279   };
1280   unsigned broadcastSrcRank = getRank(source.getType());
1281   unsigned extractResultRank = getRank(extractOp.getType());
1282   if (extractResultRank < broadcastSrcRank) {
1283     auto extractPos = extractVector<int64_t>(extractOp.position());
1284     unsigned rankDiff = broadcastSrcRank - extractResultRank;
1285     extractPos.erase(
1286         extractPos.begin(),
1287         std::next(extractPos.begin(), extractPos.size() - rankDiff));
1288     extractOp.setOperand(source);
1289     // OpBuilder is only used as a helper to build an I64ArrayAttr.
1290     OpBuilder b(extractOp.getContext());
1291     extractOp->setAttr(ExtractOp::getPositionAttrName(),
1292                        b.getI64ArrayAttr(extractPos));
1293     return extractOp.getResult();
1294   }
1295   return Value();
1296 }
1297 
1298 // Fold extractOp with source coming from ShapeCast op.
1299 static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1300   auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>();
1301   if (!shapeCastOp)
1302     return Value();
1303   // Get the nth dimension size starting from lowest dimension.
1304   auto getDimReverse = [](VectorType type, int64_t n) {
1305     return type.getShape().take_back(n + 1).front();
1306   };
1307   int64_t destinationRank =
1308       extractOp.getType().isa<VectorType>()
1309           ? extractOp.getType().cast<VectorType>().getRank()
1310           : 0;
1311   if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1312     return Value();
1313   if (destinationRank > 0) {
1314     auto destinationType = extractOp.getResult().getType().cast<VectorType>();
1315     for (int64_t i = 0; i < destinationRank; i++) {
1316       // The lowest dimension of of the destination must match the lowest
1317       // dimension of the shapecast op source.
1318       // TODO: This case could be support in a canonicalization pattern.
1319       if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1320           getDimReverse(destinationType, i))
1321         return Value();
1322     }
1323   }
1324   // Extract the strides associated with the extract op vector source. Then use
1325   // this to calculate a linearized position for the extract.
1326   auto extractedPos = extractVector<int64_t>(extractOp.position());
1327   std::reverse(extractedPos.begin(), extractedPos.end());
1328   SmallVector<int64_t, 4> strides;
1329   int64_t stride = 1;
1330   for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1331     strides.push_back(stride);
1332     stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
1333   }
1334 
1335   int64_t position = linearize(extractedPos, strides);
1336   // Then extract the strides associated to the shapeCast op vector source and
1337   // delinearize the position using those strides.
1338   SmallVector<int64_t, 4> newStrides;
1339   int64_t numDimension =
1340       shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1341   stride = 1;
1342   for (int64_t i = 0; i < numDimension; i++) {
1343     newStrides.push_back(stride);
1344     stride *=
1345         getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1346   }
1347   std::reverse(newStrides.begin(), newStrides.end());
1348   SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
1349   // OpBuilder is only used as a helper to build an I64ArrayAttr.
1350   OpBuilder b(extractOp.getContext());
1351   extractOp->setAttr(ExtractOp::getPositionAttrName(),
1352                      b.getI64ArrayAttr(newPosition));
1353   extractOp.setOperand(shapeCastOp.source());
1354   return extractOp.getResult();
1355 }
1356 
1357 /// Fold an ExtractOp from ExtractStridedSliceOp.
1358 static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1359   auto extractStridedSliceOp =
1360       extractOp.vector().getDefiningOp<vector::ExtractStridedSliceOp>();
1361   if (!extractStridedSliceOp)
1362     return Value();
1363   // Return if 'extractStridedSliceOp' has non-unit strides.
1364   if (extractStridedSliceOp.hasNonUnitStrides())
1365     return Value();
1366 
1367   // Trim offsets for dimensions fully extracted.
1368   auto sliceOffsets = extractVector<int64_t>(extractStridedSliceOp.offsets());
1369   while (!sliceOffsets.empty()) {
1370     size_t lastOffset = sliceOffsets.size() - 1;
1371     if (sliceOffsets.back() != 0 ||
1372         extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1373             extractStridedSliceOp.getVectorType().getDimSize(lastOffset))
1374       break;
1375     sliceOffsets.pop_back();
1376   }
1377   unsigned destinationRank = 0;
1378   if (auto vecType = extractOp.getType().dyn_cast<VectorType>())
1379     destinationRank = vecType.getRank();
1380   // The dimensions of the result need to be untouched by the
1381   // extractStridedSlice op.
1382   if (destinationRank >
1383       extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size())
1384     return Value();
1385   auto extractedPos = extractVector<int64_t>(extractOp.position());
1386   assert(extractedPos.size() >= sliceOffsets.size());
1387   for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1388     extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1389   extractOp.vectorMutable().assign(extractStridedSliceOp.vector());
1390   // OpBuilder is only used as a helper to build an I64ArrayAttr.
1391   OpBuilder b(extractOp.getContext());
1392   extractOp->setAttr(ExtractOp::getPositionAttrName(),
1393                      b.getI64ArrayAttr(extractedPos));
1394   return extractOp.getResult();
1395 }
1396 
1397 /// Fold extract_op fed from a chain of insertStridedSlice ops.
1398 static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
1399   int64_t destinationRank = op.getType().isa<VectorType>()
1400                                 ? op.getType().cast<VectorType>().getRank()
1401                                 : 0;
1402   auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
1403   while (insertOp) {
1404     int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1405                              insertOp.getSourceVectorType().getRank();
1406     if (destinationRank > insertOp.getSourceVectorType().getRank())
1407       return Value();
1408     auto insertOffsets = extractVector<int64_t>(insertOp.offsets());
1409     auto extractOffsets = extractVector<int64_t>(op.position());
1410 
1411     if (llvm::any_of(insertOp.strides(), [](Attribute attr) {
1412           return attr.cast<IntegerAttr>().getInt() != 1;
1413         }))
1414       return Value();
1415     bool disjoint = false;
1416     SmallVector<int64_t, 4> offsetDiffs;
1417     for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1418       int64_t start = insertOffsets[dim];
1419       int64_t size =
1420           (dim < insertRankDiff)
1421               ? 1
1422               : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1423       int64_t end = start + size;
1424       int64_t offset = extractOffsets[dim];
1425       // Check if the start of the extract offset is in the interval inserted.
1426       if (start <= offset && offset < end) {
1427         if (dim >= insertRankDiff)
1428           offsetDiffs.push_back(offset - start);
1429         continue;
1430       }
1431       disjoint = true;
1432       break;
1433     }
1434     // The extract element chunk overlap with the vector inserted.
1435     if (!disjoint) {
1436       // If any of the inner dimensions are only partially inserted we have a
1437       // partial overlap.
1438       int64_t srcRankDiff =
1439           insertOp.getSourceVectorType().getRank() - destinationRank;
1440       for (int64_t i = 0; i < destinationRank; i++) {
1441         if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1442             insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1443                                                     insertRankDiff))
1444           return Value();
1445       }
1446       op.vectorMutable().assign(insertOp.source());
1447       // OpBuilder is only used as a helper to build an I64ArrayAttr.
1448       OpBuilder b(op.getContext());
1449       op->setAttr(ExtractOp::getPositionAttrName(),
1450                   b.getI64ArrayAttr(offsetDiffs));
1451       return op.getResult();
1452     }
1453     // If the chunk extracted is disjoint from the chunk inserted, keep
1454     // looking in the insert chain.
1455     insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
1456   }
1457   return Value();
1458 }
1459 
1460 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
1461   if (position().empty())
1462     return vector();
1463   if (succeeded(foldExtractOpFromExtractChain(*this)))
1464     return getResult();
1465   if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
1466     return res;
1467   if (auto res = foldExtractFromBroadcast(*this))
1468     return res;
1469   if (auto res = foldExtractFromShapeCast(*this))
1470     return res;
1471   if (auto val = foldExtractFromExtractStrided(*this))
1472     return val;
1473   if (auto val = foldExtractStridedOpFromInsertChain(*this))
1474     return val;
1475   return OpFoldResult();
1476 }
1477 
1478 namespace {
1479 
1480 // Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast.
1481 class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
1482 public:
1483   using OpRewritePattern<ExtractOp>::OpRewritePattern;
1484 
1485   LogicalResult matchAndRewrite(ExtractOp extractOp,
1486                                 PatternRewriter &rewriter) const override {
1487     Operation *defOp = extractOp.vector().getDefiningOp();
1488     if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1489       return failure();
1490     Value source = defOp->getOperand(0);
1491     if (extractOp.getType() == source.getType())
1492       return failure();
1493     auto getRank = [](Type type) {
1494       return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
1495     };
1496     unsigned broadcastSrcRank = getRank(source.getType());
1497     unsigned extractResultRank = getRank(extractOp.getType());
1498     // We only consider the case where the rank of the source is smaller than
1499     // the rank of the extract dst. The other cases are handled in the folding
1500     // patterns.
1501     if (extractResultRank <= broadcastSrcRank)
1502       return failure();
1503     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1504         extractOp, extractOp.getType(), source);
1505     return success();
1506   }
1507 };
1508 
1509 // Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
1510 class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> {
1511 public:
1512   using OpRewritePattern<ExtractOp>::OpRewritePattern;
1513 
1514   LogicalResult matchAndRewrite(ExtractOp extractOp,
1515                                 PatternRewriter &rewriter) const override {
1516     // Return if 'extractStridedSliceOp' operand is not defined by a
1517     // ConstantOp.
1518     auto constantOp = extractOp.vector().getDefiningOp<arith::ConstantOp>();
1519     if (!constantOp)
1520       return failure();
1521     auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
1522     if (!dense)
1523       return failure();
1524     Attribute newAttr = dense.getSplatValue<Attribute>();
1525     if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>())
1526       newAttr = DenseElementsAttr::get(vecDstType, newAttr);
1527     rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
1528     return success();
1529   }
1530 };
1531 
1532 } // namespace
1533 
1534 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1535                                             MLIRContext *context) {
1536   results.add<ExtractOpConstantFolder, ExtractOpFromBroadcast>(context);
1537 }
1538 
1539 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
1540                                        SmallVectorImpl<int64_t> &results) {
1541   for (auto attr : arrayAttr)
1542     results.push_back(attr.cast<IntegerAttr>().getInt());
1543 }
1544 
1545 //===----------------------------------------------------------------------===//
1546 // ExtractMapOp
1547 //===----------------------------------------------------------------------===//
1548 
1549 void ExtractMapOp::build(OpBuilder &builder, OperationState &result,
1550                          Value vector, ValueRange ids,
1551                          ArrayRef<int64_t> multiplicity,
1552                          AffineMap permutationMap) {
1553   assert(ids.size() == multiplicity.size() &&
1554          ids.size() == permutationMap.getNumResults());
1555   assert(permutationMap.isProjectedPermutation());
1556   VectorType type = vector.getType().cast<VectorType>();
1557   SmallVector<int64_t, 4> newShape(type.getShape().begin(),
1558                                    type.getShape().end());
1559   for (unsigned i = 0, e = permutationMap.getNumResults(); i < e; i++) {
1560     AffineExpr expr = permutationMap.getResult(i);
1561     auto dim = expr.cast<AffineDimExpr>();
1562     newShape[dim.getPosition()] = newShape[dim.getPosition()] / multiplicity[i];
1563   }
1564   VectorType resultType = VectorType::get(newShape, type.getElementType());
1565   ExtractMapOp::build(builder, result, resultType, vector, ids);
1566 }
1567 
1568 static LogicalResult verify(ExtractMapOp op) {
1569   if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
1570     return op.emitOpError(
1571         "expected source and destination vectors of same rank");
1572   unsigned numId = 0;
1573   for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; ++i) {
1574     if (op.getSourceVectorType().getDimSize(i) %
1575             op.getResultType().getDimSize(i) !=
1576         0)
1577       return op.emitOpError("source vector dimensions must be a multiple of "
1578                             "destination vector dimensions");
1579     if (op.getSourceVectorType().getDimSize(i) !=
1580         op.getResultType().getDimSize(i))
1581       numId++;
1582   }
1583   if (numId != op.ids().size())
1584     return op.emitOpError("expected number of ids must match the number of "
1585                           "dimensions distributed");
1586   return success();
1587 }
1588 
1589 OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) {
1590   auto insert = vector().getDefiningOp<vector::InsertMapOp>();
1591   if (insert == nullptr || getType() != insert.vector().getType() ||
1592       ids() != insert.ids())
1593     return {};
1594   return insert.vector();
1595 }
1596 
1597 void ExtractMapOp::getMultiplicity(SmallVectorImpl<int64_t> &multiplicity) {
1598   assert(multiplicity.empty());
1599   for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; i++) {
1600     if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i))
1601       multiplicity.push_back(getSourceVectorType().getDimSize(i) /
1602                              getResultType().getDimSize(i));
1603   }
1604 }
1605 
1606 template <typename MapOp>
1607 AffineMap calculateImplicitMap(MapOp op) {
1608   SmallVector<AffineExpr, 4> perm;
1609   // Check which dimension have a multiplicity greater than 1 and associated
1610   // them to the IDs in order.
1611   for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; i++) {
1612     if (op.getSourceVectorType().getDimSize(i) !=
1613         op.getResultType().getDimSize(i))
1614       perm.push_back(getAffineDimExpr(i, op.getContext()));
1615   }
1616   auto map = AffineMap::get(op.getSourceVectorType().getRank(), 0, perm,
1617                             op.getContext());
1618   return map;
1619 }
1620 
1621 AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); }
1622 
1623 //===----------------------------------------------------------------------===//
1624 // FmaOp
1625 //===----------------------------------------------------------------------===//
1626 
1627 Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
1628   return llvm::to_vector<4>(getVectorType().getShape());
1629 }
1630 
1631 //===----------------------------------------------------------------------===//
1632 // BroadcastOp
1633 //===----------------------------------------------------------------------===//
1634 
1635 BroadcastableToResult
1636 mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
1637                                 std::pair<int, int> *mismatchingDims) {
1638   // Broadcast scalar to vector of the same element type.
1639   if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
1640       getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
1641     return BroadcastableToResult::Success;
1642   // From now on, only vectors broadcast.
1643   VectorType srcVectorType = srcType.dyn_cast<VectorType>();
1644   if (!srcVectorType)
1645     return BroadcastableToResult::SourceTypeNotAVector;
1646 
1647   int64_t srcRank = srcVectorType.getRank();
1648   int64_t dstRank = dstVectorType.getRank();
1649   if (srcRank > dstRank)
1650     return BroadcastableToResult::SourceRankHigher;
1651   // Source has an exact match or singleton value for all trailing dimensions
1652   // (all leading dimensions are simply duplicated).
1653   int64_t lead = dstRank - srcRank;
1654   for (int64_t r = 0; r < srcRank; ++r) {
1655     int64_t srcDim = srcVectorType.getDimSize(r);
1656     int64_t dstDim = dstVectorType.getDimSize(lead + r);
1657     if (srcDim != 1 && srcDim != dstDim) {
1658       if (mismatchingDims) {
1659         mismatchingDims->first = srcDim;
1660         mismatchingDims->second = dstDim;
1661       }
1662       return BroadcastableToResult::DimensionMismatch;
1663     }
1664   }
1665 
1666   return BroadcastableToResult::Success;
1667 }
1668 
1669 static LogicalResult verify(BroadcastOp op) {
1670   std::pair<int, int> mismatchingDims;
1671   BroadcastableToResult res = isBroadcastableTo(
1672       op.getSourceType(), op.getVectorType(), &mismatchingDims);
1673   if (res == BroadcastableToResult::Success)
1674     return success();
1675   if (res == BroadcastableToResult::SourceRankHigher)
1676     return op.emitOpError("source rank higher than destination rank");
1677   if (res == BroadcastableToResult::DimensionMismatch)
1678     return op.emitOpError("dimension mismatch (")
1679            << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
1680   if (res == BroadcastableToResult::SourceTypeNotAVector)
1681     return op.emitOpError("source type is not a vector");
1682   llvm_unreachable("unexpected vector.broadcast op error");
1683 }
1684 
1685 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
1686   if (getSourceType() == getVectorType())
1687     return source();
1688   if (!operands[0])
1689     return {};
1690   auto vectorType = getVectorType();
1691   if (operands[0].getType().isIntOrIndexOrFloat())
1692     return DenseElementsAttr::get(vectorType, operands[0]);
1693   if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
1694     return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
1695   return {};
1696 }
1697 
1698 namespace {
1699 
1700 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
1701 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
1702   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
1703 
1704   LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
1705                                 PatternRewriter &rewriter) const override {
1706     auto srcBroadcast = broadcastOp.source().getDefiningOp<BroadcastOp>();
1707     if (!srcBroadcast)
1708       return failure();
1709     rewriter.replaceOpWithNewOp<BroadcastOp>(
1710         broadcastOp, broadcastOp.getVectorType(), srcBroadcast.source());
1711     return success();
1712   }
1713 };
1714 } // namespace
1715 
1716 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
1717                                               MLIRContext *context) {
1718   // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
1719   // calling `populateCastAwayVectorLeadingOneDimPatterns`
1720   results.add<BroadcastFolder>(context);
1721 }
1722 
1723 //===----------------------------------------------------------------------===//
1724 // ShuffleOp
1725 //===----------------------------------------------------------------------===//
1726 
1727 void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
1728                       Value v2, ArrayRef<int64_t> mask) {
1729   result.addOperands({v1, v2});
1730   auto maskAttr = getVectorSubscriptAttr(builder, mask);
1731   auto v1Type = v1.getType().cast<VectorType>();
1732   auto shape = llvm::to_vector<4>(v1Type.getShape());
1733   shape[0] = mask.size();
1734   result.addTypes(VectorType::get(shape, v1Type.getElementType()));
1735   result.addAttribute(getMaskAttrName(), maskAttr);
1736 }
1737 
1738 static void print(OpAsmPrinter &p, ShuffleOp op) {
1739   p << " " << op.v1() << ", " << op.v2() << " " << op.mask();
1740   p.printOptionalAttrDict(op->getAttrs(), {ShuffleOp::getMaskAttrName()});
1741   p << " : " << op.v1().getType() << ", " << op.v2().getType();
1742 }
1743 
1744 static LogicalResult verify(ShuffleOp op) {
1745   VectorType resultType = op.getVectorType();
1746   VectorType v1Type = op.getV1VectorType();
1747   VectorType v2Type = op.getV2VectorType();
1748   // Verify ranks.
1749   int64_t resRank = resultType.getRank();
1750   int64_t v1Rank = v1Type.getRank();
1751   int64_t v2Rank = v2Type.getRank();
1752   if (resRank != v1Rank || v1Rank != v2Rank)
1753     return op.emitOpError("rank mismatch");
1754   // Verify all but leading dimension sizes.
1755   for (int64_t r = 1; r < v1Rank; ++r) {
1756     int64_t resDim = resultType.getDimSize(r);
1757     int64_t v1Dim = v1Type.getDimSize(r);
1758     int64_t v2Dim = v2Type.getDimSize(r);
1759     if (resDim != v1Dim || v1Dim != v2Dim)
1760       return op.emitOpError("dimension mismatch");
1761   }
1762   // Verify mask length.
1763   auto maskAttr = op.mask().getValue();
1764   int64_t maskLength = maskAttr.size();
1765   if (maskLength != resultType.getDimSize(0))
1766     return op.emitOpError("mask length mismatch");
1767   // Verify all indices.
1768   int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0);
1769   for (const auto &en : llvm::enumerate(maskAttr)) {
1770     auto attr = en.value().dyn_cast<IntegerAttr>();
1771     if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
1772       return op.emitOpError("mask index #")
1773              << (en.index() + 1) << " out of range";
1774   }
1775   return success();
1776 }
1777 
1778 static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
1779   OpAsmParser::OperandType v1, v2;
1780   Attribute attr;
1781   VectorType v1Type, v2Type;
1782   if (parser.parseOperand(v1) || parser.parseComma() ||
1783       parser.parseOperand(v2) ||
1784       parser.parseAttribute(attr, ShuffleOp::getMaskAttrName(),
1785                             result.attributes) ||
1786       parser.parseOptionalAttrDict(result.attributes) ||
1787       parser.parseColonType(v1Type) || parser.parseComma() ||
1788       parser.parseType(v2Type) ||
1789       parser.resolveOperand(v1, v1Type, result.operands) ||
1790       parser.resolveOperand(v2, v2Type, result.operands))
1791     return failure();
1792   // Construct resulting type: leading dimension matches mask length,
1793   // all trailing dimensions match the operands.
1794   auto maskAttr = attr.dyn_cast<ArrayAttr>();
1795   if (!maskAttr)
1796     return parser.emitError(parser.getNameLoc(), "missing mask attribute");
1797   int64_t maskLength = maskAttr.size();
1798   if (maskLength <= 0)
1799     return parser.emitError(parser.getNameLoc(), "invalid mask length");
1800   int64_t v1Rank = v1Type.getRank();
1801   SmallVector<int64_t, 4> shape;
1802   shape.reserve(v1Rank);
1803   shape.push_back(maskLength);
1804   for (int64_t r = 1; r < v1Rank; ++r)
1805     shape.push_back(v1Type.getDimSize(r));
1806   VectorType resType = VectorType::get(shape, v1Type.getElementType());
1807   parser.addTypeToList(resType, result.types);
1808   return success();
1809 }
1810 
1811 //===----------------------------------------------------------------------===//
1812 // InsertElementOp
1813 //===----------------------------------------------------------------------===//
1814 
1815 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
1816                             Value source, Value dest) {
1817   result.addOperands({source, dest});
1818   result.addTypes(dest.getType());
1819 }
1820 
1821 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
1822                             Value source, Value dest, Value position) {
1823   result.addOperands({source, dest, position});
1824   result.addTypes(dest.getType());
1825 }
1826 
1827 static LogicalResult verify(InsertElementOp op) {
1828   auto dstVectorType = op.getDestVectorType();
1829   if (dstVectorType.getRank() == 0) {
1830     if (op.position())
1831       return op.emitOpError("expected position to be empty with 0-D vector");
1832     return success();
1833   }
1834   if (dstVectorType.getRank() != 1)
1835     return op.emitOpError("unexpected >1 vector rank");
1836   if (!op.position())
1837     return op.emitOpError("expected position for 1-D vector");
1838   return success();
1839 }
1840 
1841 //===----------------------------------------------------------------------===//
1842 // InsertOp
1843 //===----------------------------------------------------------------------===//
1844 
1845 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
1846                      Value dest, ArrayRef<int64_t> position) {
1847   result.addOperands({source, dest});
1848   auto positionAttr = getVectorSubscriptAttr(builder, position);
1849   result.addTypes(dest.getType());
1850   result.addAttribute(getPositionAttrName(), positionAttr);
1851 }
1852 
1853 // Convenience builder which assumes the values are constant indices.
1854 void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
1855                      Value dest, ValueRange position) {
1856   SmallVector<int64_t, 4> positionConstants =
1857       llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
1858         return pos.getDefiningOp<arith::ConstantIndexOp>().value();
1859       }));
1860   build(builder, result, source, dest, positionConstants);
1861 }
1862 
1863 static LogicalResult verify(InsertOp op) {
1864   auto positionAttr = op.position().getValue();
1865   auto destVectorType = op.getDestVectorType();
1866   if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
1867     return op.emitOpError(
1868         "expected position attribute of rank smaller than dest vector rank");
1869   auto srcVectorType = op.getSourceType().dyn_cast<VectorType>();
1870   if (srcVectorType &&
1871       (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
1872        static_cast<unsigned>(destVectorType.getRank())))
1873     return op.emitOpError("expected position attribute rank + source rank to "
1874                           "match dest vector rank");
1875   if (!srcVectorType &&
1876       (positionAttr.size() != static_cast<unsigned>(destVectorType.getRank())))
1877     return op.emitOpError(
1878         "expected position attribute rank to match the dest vector rank");
1879   for (const auto &en : llvm::enumerate(positionAttr)) {
1880     auto attr = en.value().dyn_cast<IntegerAttr>();
1881     if (!attr || attr.getInt() < 0 ||
1882         attr.getInt() >= destVectorType.getDimSize(en.index()))
1883       return op.emitOpError("expected position attribute #")
1884              << (en.index() + 1)
1885              << " to be a non-negative integer smaller than the corresponding "
1886                 "dest vector dimension";
1887   }
1888   return success();
1889 }
1890 
1891 namespace {
1892 
1893 // If insertOp is only inserting unit dimensions it can be transformed to a
1894 // broadcast.
1895 class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
1896 public:
1897   using OpRewritePattern<InsertOp>::OpRewritePattern;
1898 
1899   LogicalResult matchAndRewrite(InsertOp insertOp,
1900                                 PatternRewriter &rewriter) const override {
1901     auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
1902     if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
1903                            srcVecType.getNumElements())
1904       return failure();
1905     rewriter.replaceOpWithNewOp<BroadcastOp>(
1906         insertOp, insertOp.getDestVectorType(), insertOp.source());
1907     return success();
1908   }
1909 };
1910 
1911 } // namespace
1912 
1913 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
1914                                            MLIRContext *context) {
1915   results.add<InsertToBroadcast, BroadcastFolder>(context);
1916 }
1917 
1918 // Eliminates insert operations that produce values identical to their source
1919 // value. This happens when the source and destination vectors have identical
1920 // sizes.
1921 OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
1922   if (position().empty())
1923     return source();
1924   return {};
1925 }
1926 
1927 //===----------------------------------------------------------------------===//
1928 // InsertMapOp
1929 //===----------------------------------------------------------------------===//
1930 
1931 void InsertMapOp::build(OpBuilder &builder, OperationState &result,
1932                         Value vector, Value dest, ValueRange ids) {
1933   InsertMapOp::build(builder, result, dest.getType(), vector, dest, ids);
1934 }
1935 
1936 static LogicalResult verify(InsertMapOp op) {
1937   if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
1938     return op.emitOpError(
1939         "expected source and destination vectors of same rank");
1940   unsigned numId = 0;
1941   for (unsigned i = 0, e = op.getResultType().getRank(); i < e; i++) {
1942     if (op.getResultType().getDimSize(i) %
1943             op.getSourceVectorType().getDimSize(i) !=
1944         0)
1945       return op.emitOpError(
1946           "destination vector size must be a multiple of source vector size");
1947     if (op.getResultType().getDimSize(i) !=
1948         op.getSourceVectorType().getDimSize(i))
1949       numId++;
1950   }
1951   if (numId != op.ids().size())
1952     return op.emitOpError("expected number of ids must match the number of "
1953                           "dimensions distributed");
1954   return success();
1955 }
1956 
1957 AffineMap InsertMapOp::map() { return calculateImplicitMap(*this); }
1958 
1959 //===----------------------------------------------------------------------===//
1960 // InsertStridedSliceOp
1961 //===----------------------------------------------------------------------===//
1962 
1963 void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
1964                                  Value source, Value dest,
1965                                  ArrayRef<int64_t> offsets,
1966                                  ArrayRef<int64_t> strides) {
1967   result.addOperands({source, dest});
1968   auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
1969   auto stridesAttr = getVectorSubscriptAttr(builder, strides);
1970   result.addTypes(dest.getType());
1971   result.addAttribute(getOffsetsAttrName(), offsetsAttr);
1972   result.addAttribute(getStridesAttrName(), stridesAttr);
1973 }
1974 
1975 // TODO: Should be moved to Tablegen Confined attributes.
1976 template <typename OpType>
1977 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
1978                                                         ArrayAttr arrayAttr,
1979                                                         ArrayRef<int64_t> shape,
1980                                                         StringRef attrName) {
1981   if (arrayAttr.size() > shape.size())
1982     return op.emitOpError("expected ")
1983            << attrName << " attribute of rank smaller than vector rank";
1984   return success();
1985 }
1986 
1987 // Returns true if all integers in `arrayAttr` are in the half-open [min, max}
1988 // interval. If `halfOpen` is true then the admissible interval is [min, max).
1989 // Otherwise, the admissible interval is [min, max].
1990 template <typename OpType>
1991 static LogicalResult
1992 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
1993                                   int64_t max, StringRef attrName,
1994                                   bool halfOpen = true) {
1995   for (auto attr : arrayAttr) {
1996     auto val = attr.cast<IntegerAttr>().getInt();
1997     auto upper = max;
1998     if (!halfOpen)
1999       upper += 1;
2000     if (val < min || val >= upper)
2001       return op.emitOpError("expected ") << attrName << " to be confined to ["
2002                                          << min << ", " << upper << ")";
2003   }
2004   return success();
2005 }
2006 
2007 // Returns true if all integers in `arrayAttr` are in the half-open [min, max}
2008 // interval. If `halfOpen` is true then the admissible interval is [min, max).
2009 // Otherwise, the admissible interval is [min, max].
2010 template <typename OpType>
2011 static LogicalResult
2012 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
2013                                   ArrayRef<int64_t> shape, StringRef attrName,
2014                                   bool halfOpen = true, int64_t min = 0) {
2015   assert(arrayAttr.size() <= shape.size());
2016   unsigned index = 0;
2017   for (auto it : llvm::zip(arrayAttr, shape)) {
2018     auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
2019     auto max = std::get<1>(it);
2020     if (!halfOpen)
2021       max += 1;
2022     if (val < min || val >= max)
2023       return op.emitOpError("expected ")
2024              << attrName << " dimension " << index << " to be confined to ["
2025              << min << ", " << max << ")";
2026     ++index;
2027   }
2028   return success();
2029 }
2030 
2031 // Returns true if all integers in `arrayAttr` are in the interval [min, max}.
2032 // interval. If `halfOpen` is true then the admissible interval is [min, max).
2033 // Otherwise, the admissible interval is [min, max].
2034 template <typename OpType>
2035 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
2036     OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
2037     ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
2038     bool halfOpen = true, int64_t min = 1) {
2039   assert(arrayAttr1.size() <= shape.size());
2040   assert(arrayAttr2.size() <= shape.size());
2041   unsigned index = 0;
2042   for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) {
2043     auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
2044     auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
2045     auto max = std::get<2>(it);
2046     if (!halfOpen)
2047       max += 1;
2048     if (val1 + val2 < 0 || val1 + val2 >= max)
2049       return op.emitOpError("expected sum(")
2050              << attrName1 << ", " << attrName2 << ") dimension " << index
2051              << " to be confined to [" << min << ", " << max << ")";
2052     ++index;
2053   }
2054   return success();
2055 }
2056 
2057 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
2058                                   MLIRContext *context) {
2059   auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
2060     return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
2061   });
2062   return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
2063 }
2064 
2065 static LogicalResult verify(InsertStridedSliceOp op) {
2066   auto sourceVectorType = op.getSourceVectorType();
2067   auto destVectorType = op.getDestVectorType();
2068   auto offsets = op.offsets();
2069   auto strides = op.strides();
2070   if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
2071     return op.emitOpError(
2072         "expected offsets of same size as destination vector rank");
2073   if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
2074     return op.emitOpError(
2075         "expected strides of same size as source vector rank");
2076   if (sourceVectorType.getRank() > destVectorType.getRank())
2077     return op.emitOpError(
2078         "expected source rank to be smaller than destination rank");
2079 
2080   auto sourceShape = sourceVectorType.getShape();
2081   auto destShape = destVectorType.getShape();
2082   SmallVector<int64_t, 4> sourceShapeAsDestShape(
2083       destShape.size() - sourceShape.size(), 0);
2084   sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
2085   auto offName = InsertStridedSliceOp::getOffsetsAttrName();
2086   auto stridesName = InsertStridedSliceOp::getStridesAttrName();
2087   if (failed(
2088           isIntegerArrayAttrConfinedToShape(op, offsets, destShape, offName)) ||
2089       failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
2090                                                /*halfOpen=*/false)) ||
2091       failed(isSumOfIntegerArrayAttrConfinedToShape(
2092           op, offsets,
2093           makeI64ArrayAttr(sourceShapeAsDestShape, op.getContext()), destShape,
2094           offName, "source vector shape",
2095           /*halfOpen=*/false, /*min=*/1)))
2096     return failure();
2097 
2098   return success();
2099 }
2100 
2101 OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
2102   if (getSourceVectorType() == getDestVectorType())
2103     return source();
2104   return {};
2105 }
2106 
2107 //===----------------------------------------------------------------------===//
2108 // OuterProductOp
2109 //===----------------------------------------------------------------------===//
2110 
2111 /// Build an op without mask, use the type of `acc` as the return type.
2112 void OuterProductOp::build(OpBuilder &builder, OperationState &result,
2113                            Value lhs, Value rhs, Value acc) {
2114   result.addOperands({lhs, rhs, acc});
2115   result.addTypes(acc.getType());
2116 }
2117 
2118 static void print(OpAsmPrinter &p, OuterProductOp op) {
2119   p << " " << op.lhs() << ", " << op.rhs();
2120   if (!op.acc().empty()) {
2121     p << ", " << op.acc();
2122     p.printOptionalAttrDict(op->getAttrs());
2123   }
2124   p << " : " << op.lhs().getType() << ", " << op.rhs().getType();
2125 }
2126 
2127 static ParseResult parseOuterProductOp(OpAsmParser &parser,
2128                                        OperationState &result) {
2129   SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
2130   Type tLHS, tRHS;
2131   if (parser.parseOperandList(operandsInfo) ||
2132       parser.parseOptionalAttrDict(result.attributes) ||
2133       parser.parseColonType(tLHS) || parser.parseComma() ||
2134       parser.parseType(tRHS))
2135     return failure();
2136   if (operandsInfo.size() < 2)
2137     return parser.emitError(parser.getNameLoc(),
2138                             "expected at least 2 operands");
2139   VectorType vLHS = tLHS.dyn_cast<VectorType>();
2140   VectorType vRHS = tRHS.dyn_cast<VectorType>();
2141   if (!vLHS)
2142     return parser.emitError(parser.getNameLoc(),
2143                             "expected vector type for operand #1");
2144   VectorType resType =
2145       vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
2146                              vLHS.getElementType())
2147            : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
2148 
2149   if (!result.attributes.get(OuterProductOp::getKindAttrName())) {
2150     result.attributes.append(
2151         OuterProductOp::getKindAttrName(),
2152         CombiningKindAttr::get(OuterProductOp::getDefaultKind(),
2153                                result.getContext()));
2154   }
2155 
2156   return failure(
2157       parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
2158       parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
2159       (operandsInfo.size() > 2 &&
2160        parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
2161       parser.addTypeToList(resType, result.types));
2162 }
2163 
2164 static LogicalResult verify(OuterProductOp op) {
2165   Type tRHS = op.getOperandTypeRHS();
2166   VectorType vLHS = op.getOperandVectorTypeLHS(),
2167              vRHS = tRHS.dyn_cast<VectorType>(),
2168              vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();
2169 
2170   if (vLHS.getRank() != 1)
2171     return op.emitOpError("expected 1-d vector for operand #1");
2172 
2173   if (vRHS) {
2174     // Proper OUTER operation.
2175     if (vRHS.getRank() != 1)
2176       return op.emitOpError("expected 1-d vector for operand #2");
2177     if (vRES.getRank() != 2)
2178       return op.emitOpError("expected 2-d vector result");
2179     if (vLHS.getDimSize(0) != vRES.getDimSize(0))
2180       return op.emitOpError("expected #1 operand dim to match result dim #1");
2181     if (vRHS.getDimSize(0) != vRES.getDimSize(1))
2182       return op.emitOpError("expected #2 operand dim to match result dim #2");
2183   } else {
2184     // An AXPY operation.
2185     if (vRES.getRank() != 1)
2186       return op.emitOpError("expected 1-d vector result");
2187     if (vLHS.getDimSize(0) != vRES.getDimSize(0))
2188       return op.emitOpError("expected #1 operand dim to match result dim #1");
2189   }
2190 
2191   if (vACC && vACC != vRES)
2192     return op.emitOpError("expected operand #3 of same type as result type");
2193 
2194   // Verify supported combining kind.
2195   if (!isSupportedCombiningKind(op.kind(), vRES.getElementType()))
2196     return op.emitOpError("unsupported outerproduct type");
2197 
2198   return success();
2199 }
2200 
2201 //===----------------------------------------------------------------------===//
2202 // ReshapeOp
2203 //===----------------------------------------------------------------------===//
2204 
2205 static LogicalResult verify(ReshapeOp op) {
2206   // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
2207   auto inputVectorType = op.getInputVectorType();
2208   auto outputVectorType = op.getOutputVectorType();
2209   int64_t inputShapeRank = op.getNumInputShapeSizes();
2210   int64_t outputShapeRank = op.getNumOutputShapeSizes();
2211   SmallVector<int64_t, 4> fixedVectorSizes;
2212   op.getFixedVectorSizes(fixedVectorSizes);
2213   int64_t numFixedVectorSizes = fixedVectorSizes.size();
2214 
2215   if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
2216     return op.emitError("invalid input shape for vector type ")
2217            << inputVectorType;
2218 
2219   if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
2220     return op.emitError("invalid output shape for vector type ")
2221            << outputVectorType;
2222 
2223   // Verify that the 'fixedVectorSizes' match an input/output vector shape
2224   // suffix.
2225   unsigned inputVectorRank = inputVectorType.getRank();
2226   for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
2227     unsigned index = inputVectorRank - numFixedVectorSizes - i;
2228     if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
2229       return op.emitError("fixed vector size must match input vector for dim ")
2230              << i;
2231   }
2232 
2233   unsigned outputVectorRank = outputVectorType.getRank();
2234   for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
2235     unsigned index = outputVectorRank - numFixedVectorSizes - i;
2236     if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
2237       return op.emitError("fixed vector size must match output vector for dim ")
2238              << i;
2239   }
2240 
2241   // If all shape operands are produced by constant ops, verify that product
2242   // of dimensions for input/output shape match.
2243   auto isDefByConstant = [](Value operand) {
2244     return isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
2245   };
2246   if (llvm::all_of(op.input_shape(), isDefByConstant) &&
2247       llvm::all_of(op.output_shape(), isDefByConstant)) {
2248     int64_t numInputElements = 1;
2249     for (auto operand : op.input_shape())
2250       numInputElements *=
2251           cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
2252     int64_t numOutputElements = 1;
2253     for (auto operand : op.output_shape())
2254       numOutputElements *=
2255           cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
2256     if (numInputElements != numOutputElements)
2257       return op.emitError("product of input and output shape sizes must match");
2258   }
2259   return success();
2260 }
2261 
2262 void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
2263   populateFromInt64AttrArray(fixed_vector_sizes(), results);
2264 }
2265 
2266 //===----------------------------------------------------------------------===//
2267 // ExtractStridedSliceOp
2268 //===----------------------------------------------------------------------===//
2269 
2270 // Inference works as follows:
2271 //   1. Add 'sizes' from prefix of dims in 'offsets'.
2272 //   2. Add sizes from 'vectorType' for remaining dims.
2273 static Type inferStridedSliceOpResultType(VectorType vectorType,
2274                                           ArrayAttr offsets, ArrayAttr sizes,
2275                                           ArrayAttr strides) {
2276   assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
2277   SmallVector<int64_t, 4> shape;
2278   shape.reserve(vectorType.getRank());
2279   unsigned idx = 0;
2280   for (unsigned e = offsets.size(); idx < e; ++idx)
2281     shape.push_back(sizes[idx].cast<IntegerAttr>().getInt());
2282   for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
2283     shape.push_back(vectorType.getShape()[idx]);
2284 
2285   return VectorType::get(shape, vectorType.getElementType());
2286 }
2287 
2288 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
2289                                   Value source, ArrayRef<int64_t> offsets,
2290                                   ArrayRef<int64_t> sizes,
2291                                   ArrayRef<int64_t> strides) {
2292   result.addOperands(source);
2293   auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
2294   auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
2295   auto stridesAttr = getVectorSubscriptAttr(builder, strides);
2296   result.addTypes(
2297       inferStridedSliceOpResultType(source.getType().cast<VectorType>(),
2298                                     offsetsAttr, sizesAttr, stridesAttr));
2299   result.addAttribute(getOffsetsAttrName(), offsetsAttr);
2300   result.addAttribute(getSizesAttrName(), sizesAttr);
2301   result.addAttribute(getStridesAttrName(), stridesAttr);
2302 }
2303 
2304 static LogicalResult verify(ExtractStridedSliceOp op) {
2305   auto type = op.getVectorType();
2306   auto offsets = op.offsets();
2307   auto sizes = op.sizes();
2308   auto strides = op.strides();
2309   if (offsets.size() != sizes.size() || offsets.size() != strides.size()) {
2310     op.emitOpError(
2311         "expected offsets, sizes and strides attributes of same size");
2312     return failure();
2313   }
2314 
2315   auto shape = type.getShape();
2316   auto offName = ExtractStridedSliceOp::getOffsetsAttrName();
2317   auto sizesName = ExtractStridedSliceOp::getSizesAttrName();
2318   auto stridesName = ExtractStridedSliceOp::getStridesAttrName();
2319   if (failed(isIntegerArrayAttrSmallerThanShape(op, offsets, shape, offName)) ||
2320       failed(isIntegerArrayAttrSmallerThanShape(op, sizes, shape, sizesName)) ||
2321       failed(isIntegerArrayAttrSmallerThanShape(op, strides, shape,
2322                                                 stridesName)) ||
2323       failed(isIntegerArrayAttrConfinedToShape(op, offsets, shape, offName)) ||
2324       failed(isIntegerArrayAttrConfinedToShape(op, sizes, shape, sizesName,
2325                                                /*halfOpen=*/false,
2326                                                /*min=*/1)) ||
2327       failed(isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
2328                                                /*halfOpen=*/false)) ||
2329       failed(isSumOfIntegerArrayAttrConfinedToShape(op, offsets, sizes, shape,
2330                                                     offName, sizesName,
2331                                                     /*halfOpen=*/false)))
2332     return failure();
2333 
2334   auto resultType = inferStridedSliceOpResultType(
2335       op.getVectorType(), op.offsets(), op.sizes(), op.strides());
2336   if (op.getResult().getType() != resultType) {
2337     op.emitOpError("expected result type to be ") << resultType;
2338     return failure();
2339   }
2340 
2341   return success();
2342 }
2343 
2344 // When the source of ExtractStrided comes from a chain of InsertStrided ops try
2345 // to use the source of the InsertStrided ops if we can detect that the
2346 // extracted vector is a subset of one of the vector inserted.
2347 static LogicalResult
2348 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
2349   // Helper to extract integer out of ArrayAttr.
2350   auto getElement = [](ArrayAttr array, int idx) {
2351     return array[idx].cast<IntegerAttr>().getInt();
2352   };
2353   ArrayAttr extractOffsets = op.offsets();
2354   ArrayAttr extractStrides = op.strides();
2355   ArrayAttr extractSizes = op.sizes();
2356   auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
2357   while (insertOp) {
2358     if (op.getVectorType().getRank() !=
2359         insertOp.getSourceVectorType().getRank())
2360       return failure();
2361     ArrayAttr insertOffsets = insertOp.offsets();
2362     ArrayAttr insertStrides = insertOp.strides();
2363     // If the rank of extract is greater than the rank of insert, we are likely
2364     // extracting a partial chunk of the vector inserted.
2365     if (extractOffsets.size() > insertOffsets.size())
2366       return failure();
2367     bool patialoverlap = false;
2368     bool disjoint = false;
2369     SmallVector<int64_t, 4> offsetDiffs;
2370     for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
2371       if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
2372         return failure();
2373       int64_t start = getElement(insertOffsets, dim);
2374       int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
2375       int64_t offset = getElement(extractOffsets, dim);
2376       int64_t size = getElement(extractSizes, dim);
2377       // Check if the start of the extract offset is in the interval inserted.
2378       if (start <= offset && offset < end) {
2379         // If the extract interval overlaps but is not fully included we may
2380         // have a partial overlap that will prevent any folding.
2381         if (offset + size > end)
2382           patialoverlap = true;
2383         offsetDiffs.push_back(offset - start);
2384         continue;
2385       }
2386       disjoint = true;
2387       break;
2388     }
2389     // The extract element chunk is a subset of the insert element.
2390     if (!disjoint && !patialoverlap) {
2391       op.setOperand(insertOp.source());
2392       // OpBuilder is only used as a helper to build an I64ArrayAttr.
2393       OpBuilder b(op.getContext());
2394       op->setAttr(ExtractStridedSliceOp::getOffsetsAttrName(),
2395                   b.getI64ArrayAttr(offsetDiffs));
2396       return success();
2397     }
2398     // If the chunk extracted is disjoint from the chunk inserted, keep looking
2399     // in the insert chain.
2400     if (disjoint)
2401       insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
2402     else {
2403       // The extracted vector partially overlap the inserted vector, we cannot
2404       // fold.
2405       return failure();
2406     }
2407   }
2408   return failure();
2409 }
2410 
2411 OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
2412   if (getVectorType() == getResult().getType())
2413     return vector();
2414   if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
2415     return getResult();
2416   return {};
2417 }
2418 
2419 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
2420   populateFromInt64AttrArray(offsets(), results);
2421 }
2422 
2423 namespace {
2424 
2425 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
2426 // ConstantMaskOp.
2427 class StridedSliceConstantMaskFolder final
2428     : public OpRewritePattern<ExtractStridedSliceOp> {
2429 public:
2430   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
2431 
2432   LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
2433                                 PatternRewriter &rewriter) const override {
2434     // Return if 'extractStridedSliceOp' operand is not defined by a
2435     // ConstantMaskOp.
2436     auto *defOp = extractStridedSliceOp.vector().getDefiningOp();
2437     auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
2438     if (!constantMaskOp)
2439       return failure();
2440     // Return if 'extractStridedSliceOp' has non-unit strides.
2441     if (extractStridedSliceOp.hasNonUnitStrides())
2442       return failure();
2443     // Gather constant mask dimension sizes.
2444     SmallVector<int64_t, 4> maskDimSizes;
2445     populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes);
2446     // Gather strided slice offsets and sizes.
2447     SmallVector<int64_t, 4> sliceOffsets;
2448     populateFromInt64AttrArray(extractStridedSliceOp.offsets(), sliceOffsets);
2449     SmallVector<int64_t, 4> sliceSizes;
2450     populateFromInt64AttrArray(extractStridedSliceOp.sizes(), sliceSizes);
2451 
2452     // Compute slice of vector mask region.
2453     SmallVector<int64_t, 4> sliceMaskDimSizes;
2454     assert(sliceOffsets.size() == maskDimSizes.size());
2455     for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
2456       int64_t maskDimSize = std::get<0>(it);
2457       int64_t sliceOffset = std::get<1>(it);
2458       int64_t sliceSize = std::get<2>(it);
2459       int64_t sliceMaskDimSize = std::max(
2460           static_cast<int64_t>(0),
2461           std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
2462       sliceMaskDimSizes.push_back(sliceMaskDimSize);
2463     }
2464     // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
2465     // region is a conjunction of mask dim intervals).
2466     if (llvm::is_contained(sliceMaskDimSizes, 0))
2467       sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
2468 
2469     // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
2470     // region.
2471     rewriter.replaceOpWithNewOp<ConstantMaskOp>(
2472         extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
2473         vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
2474     return success();
2475   }
2476 };
2477 
2478 // Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
2479 class StridedSliceConstantFolder final
2480     : public OpRewritePattern<ExtractStridedSliceOp> {
2481 public:
2482   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
2483 
2484   LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
2485                                 PatternRewriter &rewriter) const override {
2486     // Return if 'extractStridedSliceOp' operand is not defined by a
2487     // ConstantOp.
2488     auto constantOp =
2489         extractStridedSliceOp.vector().getDefiningOp<arith::ConstantOp>();
2490     if (!constantOp)
2491       return failure();
2492     auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
2493     if (!dense)
2494       return failure();
2495     auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(),
2496                                           dense.getSplatValue<Attribute>());
2497     rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
2498                                                    newAttr);
2499     return success();
2500   }
2501 };
2502 
2503 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
2504 // BroadcastOp(ExtractStrideSliceOp).
2505 class StridedSliceBroadcast final
2506     : public OpRewritePattern<ExtractStridedSliceOp> {
2507 public:
2508   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
2509 
2510   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
2511                                 PatternRewriter &rewriter) const override {
2512     auto broadcast = op.vector().getDefiningOp<BroadcastOp>();
2513     if (!broadcast)
2514       return failure();
2515     auto srcVecType = broadcast.source().getType().dyn_cast<VectorType>();
2516     unsigned srcRrank = srcVecType ? srcVecType.getRank() : 0;
2517     auto dstVecType = op.getType().cast<VectorType>();
2518     unsigned dstRank = dstVecType.getRank();
2519     unsigned rankDiff = dstRank - srcRrank;
2520     // Check if the most inner dimensions of the source of the broadcast are the
2521     // same as the destination of the extract. If this is the case we can just
2522     // use a broadcast as the original dimensions are untouched.
2523     bool lowerDimMatch = true;
2524     for (unsigned i = 0; i < srcRrank; i++) {
2525       if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
2526         lowerDimMatch = false;
2527         break;
2528       }
2529     }
2530     Value source = broadcast.source();
2531     if (!lowerDimMatch) {
2532       // The inner dimensions don't match, it means we need to extract from the
2533       // source of the orignal broadcast and then broadcast the extracted value.
2534       source = rewriter.create<ExtractStridedSliceOp>(
2535           op->getLoc(), source,
2536           getI64SubArray(op.offsets(), /* dropFront=*/rankDiff),
2537           getI64SubArray(op.sizes(), /* dropFront=*/rankDiff),
2538           getI64SubArray(op.strides(), /* dropFront=*/rankDiff));
2539     }
2540     rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
2541     return success();
2542   }
2543 };
2544 
2545 /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
2546 class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
2547 public:
2548   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
2549 
2550   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
2551                                 PatternRewriter &rewriter) const override {
2552     auto splat = op.vector().getDefiningOp<SplatOp>();
2553     if (!splat)
2554       return failure();
2555     rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
2556     return success();
2557   }
2558 };
2559 
2560 } // namespace
2561 
2562 void ExtractStridedSliceOp::getCanonicalizationPatterns(
2563     RewritePatternSet &results, MLIRContext *context) {
2564   // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
2565   // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
2566   results.add<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
2567               StridedSliceBroadcast, StridedSliceSplat>(context);
2568 }
2569 
2570 //===----------------------------------------------------------------------===//
2571 // TransferReadOp
2572 //===----------------------------------------------------------------------===//
2573 
2574 /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
2575 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2576                            VectorType vectorType, Value source,
2577                            ValueRange indices, AffineMapAttr permutationMapAttr,
2578                            /*optional*/ ArrayAttr inBoundsAttr) {
2579   Type elemType = source.getType().cast<ShapedType>().getElementType();
2580   Value padding = builder.create<arith::ConstantOp>(
2581       result.location, elemType, builder.getZeroAttr(elemType));
2582   build(builder, result, vectorType, source, indices, permutationMapAttr,
2583         padding, /*mask=*/Value(), inBoundsAttr);
2584 }
2585 
2586 /// 2. Builder that sets padding to zero an empty mask (variant without attrs).
2587 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2588                            VectorType vectorType, Value source,
2589                            ValueRange indices, AffineMap permutationMap,
2590                            Optional<ArrayRef<bool>> inBounds) {
2591   auto permutationMapAttr = AffineMapAttr::get(permutationMap);
2592   auto inBoundsAttr = (inBounds && !inBounds.getValue().empty())
2593                           ? builder.getBoolArrayAttr(inBounds.getValue())
2594                           : ArrayAttr();
2595   build(builder, result, vectorType, source, indices, permutationMapAttr,
2596         inBoundsAttr);
2597 }
2598 
2599 /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
2600 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2601                            VectorType vectorType, Value source,
2602                            ValueRange indices, Value padding,
2603                            Optional<ArrayRef<bool>> inBounds) {
2604   AffineMap permutationMap = getTransferMinorIdentityMap(
2605       source.getType().cast<ShapedType>(), vectorType);
2606   auto permutationMapAttr = AffineMapAttr::get(permutationMap);
2607   auto inBoundsAttr = (inBounds && !inBounds.getValue().empty())
2608                           ? builder.getBoolArrayAttr(inBounds.getValue())
2609                           : ArrayAttr();
2610   build(builder, result, vectorType, source, indices, permutationMapAttr,
2611         padding,
2612         /*mask=*/Value(), inBoundsAttr);
2613 }
2614 
2615 /// 4. Builder that sets padding to zero and permutation map to
2616 /// 'getMinorIdentityMap'.
2617 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2618                            VectorType vectorType, Value source,
2619                            ValueRange indices,
2620                            Optional<ArrayRef<bool>> inBounds) {
2621   Type elemType = source.getType().cast<ShapedType>().getElementType();
2622   Value padding = builder.create<arith::ConstantOp>(
2623       result.location, elemType, builder.getZeroAttr(elemType));
2624   build(builder, result, vectorType, source, indices, padding, inBounds);
2625 }
2626 
2627 template <typename EmitFun>
2628 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
2629                                           EmitFun emitOpError) {
2630   SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
2631   for (auto expr : permutationMap.getResults()) {
2632     auto dim = expr.dyn_cast<AffineDimExpr>();
2633     auto zero = expr.dyn_cast<AffineConstantExpr>();
2634     if (zero) {
2635       if (zero.getValue() != 0) {
2636         return emitOpError(
2637             "requires a projected permutation_map (at most one dim or the zero "
2638             "constant can appear in each result)");
2639       }
2640       continue;
2641     }
2642     if (!dim) {
2643       return emitOpError("requires a projected permutation_map (at most one "
2644                          "dim or the zero constant can appear in each result)");
2645     }
2646     if (seen[dim.getPosition()]) {
2647       return emitOpError(
2648           "requires a permutation_map that is a permutation (found one dim "
2649           "used more than once)");
2650     }
2651     seen[dim.getPosition()] = true;
2652   }
2653   return success();
2654 }
2655 
2656 static LogicalResult
2657 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
2658                  VectorType vectorType, VectorType maskType,
2659                  AffineMap permutationMap, ArrayAttr inBounds) {
2660   if (op->hasAttr("masked")) {
2661     return op->emitOpError("masked attribute has been removed. "
2662                            "Use in_bounds instead.");
2663   }
2664 
2665   if (!shapedType.isa<MemRefType, RankedTensorType>())
2666     return op->emitOpError(
2667         "requires source to be a memref or ranked tensor type");
2668 
2669   auto elementType = shapedType.getElementType();
2670   DataLayout dataLayout = DataLayout::closest(op);
2671   if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
2672     // Memref or tensor has vector element type.
2673     unsigned sourceVecSize =
2674         dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
2675         vectorElementType.getShape().back();
2676     unsigned resultVecSize =
2677         dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
2678         vectorType.getShape().back();
2679     if (resultVecSize % sourceVecSize != 0)
2680       return op->emitOpError(
2681           "requires the bitwidth of the minor 1-D vector to be an integral "
2682           "multiple of the bitwidth of the minor 1-D vector of the source");
2683 
2684     unsigned sourceVecEltRank = vectorElementType.getRank();
2685     unsigned resultVecRank = vectorType.getRank();
2686     if (sourceVecEltRank > resultVecRank)
2687       return op->emitOpError(
2688           "requires source vector element and vector result ranks to match.");
2689     unsigned rankOffset = resultVecRank - sourceVecEltRank;
2690     // Check that permutation map results match 'rankOffset' of vector type.
2691     if (permutationMap.getNumResults() != rankOffset)
2692       return op->emitOpError("requires a permutation_map with result dims of "
2693                              "the same rank as the vector type");
2694 
2695     if (maskType)
2696       return op->emitOpError("does not support masks with vector element type");
2697   } else {
2698     // Memref or tensor has scalar element type.
2699     unsigned minorSize =
2700         vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
2701     unsigned resultVecSize =
2702         dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
2703     if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
2704       return op->emitOpError(
2705           "requires the bitwidth of the minor 1-D vector to be an integral "
2706           "multiple of the bitwidth of the source element type");
2707 
2708     // Check that permutation map results match rank of vector type.
2709     if (permutationMap.getNumResults() != vectorType.getRank())
2710       return op->emitOpError("requires a permutation_map with result dims of "
2711                              "the same rank as the vector type");
2712 
2713     VectorType expectedMaskType =
2714         vector::detail::transferMaskType(vectorType, permutationMap);
2715     if (maskType && expectedMaskType != maskType)
2716       return op->emitOpError("expects mask type consistent with permutation "
2717                              "map: ")
2718              << maskType;
2719   }
2720 
2721   if (permutationMap.getNumSymbols() != 0)
2722     return op->emitOpError("requires permutation_map without symbols");
2723 
2724   if (permutationMap.getNumInputs() != shapedType.getRank())
2725     return op->emitOpError("requires a permutation_map with input dims of the "
2726                            "same rank as the source type");
2727 
2728   if (inBounds) {
2729     if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
2730       return op->emitOpError("expects the optional in_bounds attr of same rank "
2731                              "as permutation_map results: ")
2732              << AffineMapAttr::get(permutationMap)
2733              << " vs inBounds of size: " << inBounds.size();
2734     for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
2735       if (permutationMap.getResult(i).isa<AffineConstantExpr>() &&
2736           !inBounds.getValue()[i].cast<BoolAttr>().getValue())
2737         return op->emitOpError("requires broadcast dimensions to be in-bounds");
2738   }
2739 
2740   return success();
2741 }
2742 
2743 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
2744   SmallVector<StringRef, 3> elidedAttrs;
2745   elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
2746   if (op.permutation_map().isMinorIdentity())
2747     elidedAttrs.push_back(op.getPermutationMapAttrName());
2748   bool elideInBounds = true;
2749   if (auto inBounds = op.in_bounds()) {
2750     for (auto attr : *inBounds) {
2751       if (attr.template cast<BoolAttr>().getValue()) {
2752         elideInBounds = false;
2753         break;
2754       }
2755     }
2756   }
2757   if (elideInBounds)
2758     elidedAttrs.push_back(op.getInBoundsAttrName());
2759   p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
2760 }
2761 
2762 static void print(OpAsmPrinter &p, TransferReadOp op) {
2763   p << " " << op.source() << "[" << op.indices() << "], " << op.padding();
2764   if (op.mask())
2765     p << ", " << op.mask();
2766   printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
2767   p << " : " << op.getShapedType() << ", " << op.getVectorType();
2768 }
2769 
2770 static ParseResult parseTransferReadOp(OpAsmParser &parser,
2771                                        OperationState &result) {
2772   auto &builder = parser.getBuilder();
2773   SMLoc typesLoc;
2774   OpAsmParser::OperandType sourceInfo;
2775   SmallVector<OpAsmParser::OperandType, 8> indexInfo;
2776   OpAsmParser::OperandType paddingInfo;
2777   SmallVector<Type, 2> types;
2778   OpAsmParser::OperandType maskInfo;
2779   // Parsing with support for paddingValue.
2780   if (parser.parseOperand(sourceInfo) ||
2781       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
2782       parser.parseComma() || parser.parseOperand(paddingInfo))
2783     return failure();
2784   ParseResult hasMask = parser.parseOptionalComma();
2785   if (hasMask.succeeded()) {
2786     parser.parseOperand(maskInfo);
2787   }
2788   if (parser.parseOptionalAttrDict(result.attributes) ||
2789       parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
2790     return failure();
2791   if (types.size() != 2)
2792     return parser.emitError(typesLoc, "requires two types");
2793   auto indexType = builder.getIndexType();
2794   auto shapedType = types[0].dyn_cast<ShapedType>();
2795   if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
2796     return parser.emitError(typesLoc, "requires memref or ranked tensor type");
2797   VectorType vectorType = types[1].dyn_cast<VectorType>();
2798   if (!vectorType)
2799     return parser.emitError(typesLoc, "requires vector type");
2800   auto permutationAttrName = TransferReadOp::getPermutationMapAttrName();
2801   Attribute mapAttr = result.attributes.get(permutationAttrName);
2802   if (!mapAttr) {
2803     auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
2804     // Update `mapAttr` that is used later to determine mask type.
2805     mapAttr = AffineMapAttr::get(permMap);
2806     result.attributes.set(permutationAttrName, mapAttr);
2807   }
2808   if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
2809       parser.resolveOperands(indexInfo, indexType, result.operands) ||
2810       parser.resolveOperand(paddingInfo, shapedType.getElementType(),
2811                             result.operands))
2812     return failure();
2813   if (hasMask.succeeded()) {
2814     if (shapedType.getElementType().dyn_cast<VectorType>())
2815       return parser.emitError(
2816           maskInfo.location, "does not support masks with vector element type");
2817     auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue();
2818     // Instead of adding the mask type as an op type, compute it based on the
2819     // vector type and the permutation map (to keep the type signature small).
2820     auto maskType = mlir::vector::detail::transferMaskType(vectorType, map);
2821     if (parser.resolveOperand(maskInfo, maskType, result.operands))
2822       return failure();
2823   }
2824   result.addAttribute(
2825       TransferReadOp::getOperandSegmentSizeAttr(),
2826       builder.getI32VectorAttr({1, static_cast<int32_t>(indexInfo.size()), 1,
2827                                 static_cast<int32_t>(hasMask.succeeded())}));
2828   return parser.addTypeToList(vectorType, result.types);
2829 }
2830 
2831 static LogicalResult verify(TransferReadOp op) {
2832   // Consistency of elemental types in source and vector.
2833   ShapedType shapedType = op.getShapedType();
2834   VectorType vectorType = op.getVectorType();
2835   VectorType maskType = op.getMaskType();
2836   auto paddingType = op.padding().getType();
2837   auto permutationMap = op.permutation_map();
2838   auto sourceElementType = shapedType.getElementType();
2839 
2840   if (static_cast<int64_t>(op.indices().size()) != shapedType.getRank())
2841     return op.emitOpError("requires ") << shapedType.getRank() << " indices";
2842 
2843   if (failed(
2844           verifyTransferOp(cast<VectorTransferOpInterface>(op.getOperation()),
2845                            shapedType, vectorType, maskType, permutationMap,
2846                            op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
2847     return failure();
2848 
2849   if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
2850     // Source has vector element type.
2851     // Check that 'sourceVectorElementType' and 'paddingType' types match.
2852     if (sourceVectorElementType != paddingType)
2853       return op.emitOpError(
2854           "requires source element type and padding type to match.");
2855 
2856   } else {
2857     // Check that 'paddingType' is valid to store in a vector type.
2858     if (!VectorType::isValidElementType(paddingType))
2859       return op.emitOpError("requires valid padding vector elemental type");
2860 
2861     // Check that padding type and vector element types match.
2862     if (paddingType != sourceElementType)
2863       return op.emitOpError(
2864           "requires formal padding and source of the same elemental type");
2865   }
2866 
2867   return verifyPermutationMap(permutationMap,
2868                               [&op](Twine t) { return op.emitOpError(t); });
2869 }
2870 
2871 /// This is a common class used for patterns of the form
2872 /// ```
2873 ///    someop(memrefcast) -> someop
2874 /// ```
2875 /// It folds the source of the memref.cast into the root operation directly.
2876 static LogicalResult foldMemRefCast(Operation *op) {
2877   bool folded = false;
2878   for (OpOperand &operand : op->getOpOperands()) {
2879     auto castOp = operand.get().getDefiningOp<memref::CastOp>();
2880     if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
2881       operand.set(castOp.getOperand());
2882       folded = true;
2883     }
2884   }
2885   return success(folded);
2886 }
2887 
2888 static LogicalResult foldTensorCast(Operation *op) {
2889   bool folded = false;
2890   for (OpOperand &operand : op->getOpOperands()) {
2891     auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
2892     if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
2893       operand.set(castOp.getOperand());
2894       folded = true;
2895     }
2896   }
2897   return success(folded);
2898 }
2899 
2900 template <typename TransferOp>
2901 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
2902   // TODO: support more aggressive createOrFold on:
2903   // `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)`
2904   if (op.getShapedType().isDynamicDim(indicesIdx))
2905     return false;
2906   Value index = op.indices()[indicesIdx];
2907   auto cstOp = index.getDefiningOp<arith::ConstantIndexOp>();
2908   if (!cstOp)
2909     return false;
2910 
2911   int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
2912   int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
2913 
2914   return cstOp.value() + vectorSize <= sourceSize;
2915 }
2916 
2917 template <typename TransferOp>
2918 static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
2919   // TODO: support 0-d corner case.
2920   // TODO: Be less conservative.
2921   if (op.getTransferRank() == 0)
2922     return failure();
2923   AffineMap permutationMap = op.permutation_map();
2924   bool changed = false;
2925   SmallVector<bool, 4> newInBounds;
2926   newInBounds.reserve(op.getTransferRank());
2927   for (unsigned i = 0; i < op.getTransferRank(); ++i) {
2928     // Already marked as in-bounds, nothing to see here.
2929     if (op.isDimInBounds(i)) {
2930       newInBounds.push_back(true);
2931       continue;
2932     }
2933     // Currently out-of-bounds, check whether we can statically determine it is
2934     // inBounds.
2935     auto dimExpr = permutationMap.getResult(i).dyn_cast<AffineDimExpr>();
2936     assert(dimExpr && "Broadcast dims must be in-bounds");
2937     auto inBounds =
2938         isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
2939     newInBounds.push_back(inBounds);
2940     // We commit the pattern if it is "more inbounds".
2941     changed |= inBounds;
2942   }
2943   if (!changed)
2944     return failure();
2945   // OpBuilder is only used as a helper to build an I64ArrayAttr.
2946   OpBuilder b(op.getContext());
2947   op->setAttr(TransferOp::getInBoundsAttrName(),
2948               b.getBoolArrayAttr(newInBounds));
2949   return success();
2950 }
2951 
2952 ///  ```
2953 ///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
2954 ///    : vector<1x4xf32>, tensor<4x4xf32>
2955 ///  %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
2956 ///    : tensor<4x4xf32>, vector<1x4xf32>
2957 ///  ```
2958 ///  -> Folds into
2959 ///  ```
2960 ///  %v0
2961 ///  ```
2962 static Value foldRAW(TransferReadOp readOp) {
2963   if (!readOp.getShapedType().isa<RankedTensorType>())
2964     return {};
2965   auto defWrite = readOp.source().getDefiningOp<vector::TransferWriteOp>();
2966   while (defWrite) {
2967     if (checkSameValueRAW(defWrite, readOp))
2968       return defWrite.vector();
2969     if (!isDisjointTransferIndices(
2970             cast<VectorTransferOpInterface>(defWrite.getOperation()),
2971             cast<VectorTransferOpInterface>(readOp.getOperation())))
2972       break;
2973     defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
2974   }
2975   return {};
2976 }
2977 
2978 OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
2979   if (Value vec = foldRAW(*this))
2980     return vec;
2981   /// transfer_read(memrefcast) -> transfer_read
2982   if (succeeded(foldTransferInBoundsAttribute(*this)))
2983     return getResult();
2984   if (succeeded(foldMemRefCast(*this)))
2985     return getResult();
2986   if (succeeded(foldTensorCast(*this)))
2987     return getResult();
2988   return OpFoldResult();
2989 }
2990 
2991 Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
2992   return llvm::to_vector<4>(getVectorType().getShape());
2993 }
2994 
2995 void TransferReadOp::getEffects(
2996     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2997         &effects) {
2998   if (getShapedType().isa<MemRefType>())
2999     effects.emplace_back(MemoryEffects::Read::get(), source(),
3000                          SideEffects::DefaultResource::get());
3001 }
3002 
3003 namespace {
3004 /// Fold transfer_reads of a tensor.extract_slice op. E.g.:
3005 ///
3006 /// ```
3007 /// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
3008 ///     : tensor<?x?xf32> to tensor<?x?xf32>
3009 /// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
3010 ///     : tensor<?x?xf32>, vector<4x5xf32>
3011 /// ```
3012 /// is rewritten to:
3013 /// ```
3014 /// %p0 = arith.addi %a, %e : index
3015 /// %p1 = arith.addi %b, %f : index
3016 /// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
3017 ///     : tensor<?x?xf32>, vector<4x5xf32>
3018 /// ```
3019 struct FoldExtractSliceIntoTransferRead
3020     : public OpRewritePattern<TransferReadOp> {
3021 public:
3022   using OpRewritePattern<TransferReadOp>::OpRewritePattern;
3023 
3024   LogicalResult matchAndRewrite(TransferReadOp xferOp,
3025                                 PatternRewriter &rewriter) const override {
3026     // TODO: support 0-d corner case.
3027     if (xferOp.getTransferRank() == 0)
3028       return failure();
3029     if (xferOp.hasOutOfBoundsDim())
3030       return failure();
3031     if (!xferOp.permutation_map().isIdentity())
3032       return failure();
3033     if (xferOp.mask())
3034       return failure();
3035     auto extractOp = xferOp.source().getDefiningOp<tensor::ExtractSliceOp>();
3036     if (!extractOp)
3037       return failure();
3038     if (!extractOp.hasUnitStride())
3039       return failure();
3040 
3041     // Bail on illegal rank-reduction: we need to check that the rank-reduced
3042     // dims are exactly the leading dims. I.e. the following is illegal:
3043     // ```
3044     //    %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
3045     //      tensor<2x1x4xf32> to tensor<2x4xf32>
3046     //    %1 = vector.transfer_read %0[0,0], %cst :
3047     //      tensor<2x4xf32>, vector<2x4xf32>
3048     // ```
3049     //
3050     // Cannot fold into:
3051     // ```
3052     //    %0 = vector.transfer_read %t[0,0,0], %cst :
3053     //      tensor<2x1x4xf32>, vector<2x4xf32>
3054     // ```
3055     // For this, check the trailing `vectorRank` dims of the extract_slice
3056     // result tensor match the trailing dims of the inferred result tensor.
3057     int64_t rankReduced =
3058         extractOp.getSourceType().getRank() - extractOp.getType().getRank();
3059     int64_t vectorRank = xferOp.getVectorType().getRank();
3060     RankedTensorType inferredDestTensorType =
3061         tensor::ExtractSliceOp::inferResultType(
3062             extractOp.getSourceType(), extractOp.getMixedOffsets(),
3063             extractOp.getMixedSizes(), extractOp.getMixedStrides());
3064     auto actualDestTensorShape = extractOp.getType().getShape();
3065     if (rankReduced > 0 &&
3066         actualDestTensorShape.take_back(vectorRank) !=
3067             inferredDestTensorType.getShape().take_back(vectorRank))
3068       return failure();
3069 
3070     SmallVector<Value> newIndices;
3071     // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
3072     // indices first.
3073     for (int64_t i = 0; i < rankReduced; ++i) {
3074       OpFoldResult offset = extractOp.getMixedOffsets()[i];
3075       newIndices.push_back(getValueOrCreateConstantIndexOp(
3076           rewriter, extractOp.getLoc(), offset));
3077     }
3078     for (const auto &it : llvm::enumerate(xferOp.indices())) {
3079       OpFoldResult offset =
3080           extractOp.getMixedOffsets()[it.index() + rankReduced];
3081       newIndices.push_back(rewriter.create<arith::AddIOp>(
3082           xferOp->getLoc(), it.value(),
3083           getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
3084                                           offset)));
3085     }
3086     SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
3087     rewriter.replaceOpWithNewOp<TransferReadOp>(
3088         xferOp, xferOp.getVectorType(), extractOp.source(), newIndices,
3089         xferOp.padding(), ArrayRef<bool>{inBounds});
3090 
3091     return success();
3092   }
3093 };
3094 } // namespace
3095 
3096 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3097                                                  MLIRContext *context) {
3098   results.add<FoldExtractSliceIntoTransferRead>(context);
3099 }
3100 
3101 //===----------------------------------------------------------------------===//
3102 // TransferWriteOp
3103 //===----------------------------------------------------------------------===//
3104 
3105 /// 1. Builder with type inference.
3106 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3107                             Value vector, Value dest, ValueRange indices,
3108                             AffineMapAttr permutationMapAttr,
3109                             /*optional*/ Value mask,
3110                             /*optional*/ ArrayAttr inBoundsAttr) {
3111   Type resultType = dest.getType().dyn_cast<RankedTensorType>();
3112   build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
3113         mask, inBoundsAttr);
3114 }
3115 
3116 /// 2. Builder with type inference that sets an empty mask (variant with attrs).
3117 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3118                             Value vector, Value dest, ValueRange indices,
3119                             AffineMapAttr permutationMapAttr,
3120                             /*optional*/ ArrayAttr inBoundsAttr) {
3121   build(builder, result, vector, dest, indices, permutationMapAttr,
3122         /*mask=*/Value(), inBoundsAttr);
3123 }
3124 
3125 /// 3. Builder with type inference that sets an empty mask (variant without
3126 /// attrs)
3127 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3128                             Value vector, Value dest, ValueRange indices,
3129                             AffineMap permutationMap,
3130                             Optional<ArrayRef<bool>> inBounds) {
3131   auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3132   auto inBoundsAttr = (inBounds && !inBounds.getValue().empty())
3133                           ? builder.getBoolArrayAttr(inBounds.getValue())
3134                           : ArrayAttr();
3135   build(builder, result, vector, dest, indices, permutationMapAttr,
3136         /*mask=*/Value(), inBoundsAttr);
3137 }
3138 
3139 /// 4. Builder with type inference that sets an empty mask and sets permutation
3140 ///    map to 'getMinorIdentityMap'.
3141 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3142                             Value vector, Value dest, ValueRange indices,
3143                             Optional<ArrayRef<bool>> inBounds) {
3144   auto vectorType = vector.getType().cast<VectorType>();
3145   AffineMap permutationMap = getTransferMinorIdentityMap(
3146       dest.getType().cast<ShapedType>(), vectorType);
3147   build(builder, result, vector, dest, indices, permutationMap, inBounds);
3148 }
3149 
3150 static ParseResult parseTransferWriteOp(OpAsmParser &parser,
3151                                         OperationState &result) {
3152   auto &builder = parser.getBuilder();
3153   SMLoc typesLoc;
3154   OpAsmParser::OperandType vectorInfo, sourceInfo;
3155   SmallVector<OpAsmParser::OperandType, 8> indexInfo;
3156   SmallVector<Type, 2> types;
3157   OpAsmParser::OperandType maskInfo;
3158   if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
3159       parser.parseOperand(sourceInfo) ||
3160       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
3161     return failure();
3162   ParseResult hasMask = parser.parseOptionalComma();
3163   if (hasMask.succeeded() && parser.parseOperand(maskInfo))
3164     return failure();
3165   if (parser.parseOptionalAttrDict(result.attributes) ||
3166       parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
3167     return failure();
3168   if (types.size() != 2)
3169     return parser.emitError(typesLoc, "requires two types");
3170   auto indexType = builder.getIndexType();
3171   VectorType vectorType = types[0].dyn_cast<VectorType>();
3172   if (!vectorType)
3173     return parser.emitError(typesLoc, "requires vector type");
3174   ShapedType shapedType = types[1].dyn_cast<ShapedType>();
3175   if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
3176     return parser.emitError(typesLoc, "requires memref or ranked tensor type");
3177   auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName();
3178   auto attr = result.attributes.get(permutationAttrName);
3179   if (!attr) {
3180     auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
3181     result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
3182   }
3183   if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
3184       parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
3185       parser.resolveOperands(indexInfo, indexType, result.operands))
3186     return failure();
3187   if (hasMask.succeeded()) {
3188     if (shapedType.getElementType().dyn_cast<VectorType>())
3189       return parser.emitError(
3190           maskInfo.location, "does not support masks with vector element type");
3191     auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
3192     if (parser.resolveOperand(maskInfo, maskType, result.operands))
3193       return failure();
3194   }
3195   result.addAttribute(
3196       TransferWriteOp::getOperandSegmentSizeAttr(),
3197       builder.getI32VectorAttr({1, 1, static_cast<int32_t>(indexInfo.size()),
3198                                 static_cast<int32_t>(hasMask.succeeded())}));
3199   return failure(shapedType.isa<RankedTensorType>() &&
3200                  parser.addTypeToList(shapedType, result.types));
3201 }
3202 
3203 static void print(OpAsmPrinter &p, TransferWriteOp op) {
3204   p << " " << op.vector() << ", " << op.source() << "[" << op.indices() << "]";
3205   if (op.mask())
3206     p << ", " << op.mask();
3207   printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
3208   p << " : " << op.getVectorType() << ", " << op.getShapedType();
3209 }
3210 
3211 static LogicalResult verify(TransferWriteOp op) {
3212   // Consistency of elemental types in shape and vector.
3213   ShapedType shapedType = op.getShapedType();
3214   VectorType vectorType = op.getVectorType();
3215   VectorType maskType = op.getMaskType();
3216   auto permutationMap = op.permutation_map();
3217 
3218   if (llvm::size(op.indices()) != shapedType.getRank())
3219     return op.emitOpError("requires ") << shapedType.getRank() << " indices";
3220 
3221   // We do not allow broadcast dimensions on TransferWriteOps for the moment,
3222   // as the semantics is unclear. This can be revisited later if necessary.
3223   if (op.hasBroadcastDim())
3224     return op.emitOpError("should not have broadcast dimensions");
3225 
3226   if (failed(
3227           verifyTransferOp(cast<VectorTransferOpInterface>(op.getOperation()),
3228                            shapedType, vectorType, maskType, permutationMap,
3229                            op.in_bounds() ? *op.in_bounds() : ArrayAttr())))
3230     return failure();
3231 
3232   return verifyPermutationMap(permutationMap,
3233                               [&op](Twine t) { return op.emitOpError(t); });
3234 }
3235 
3236 /// Fold:
3237 /// ```
3238 ///    %t1 = ...
3239 ///    %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
3240 ///      tensor<static_sizesxf32>, vector<static_sizesxf32>
3241 ///    %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
3242 ///      vector<static_sizesxf32>, tensor<static_sizesxf32>
3243 /// ```
3244 ///
3245 /// into:
3246 ///
3247 /// ```
3248 ///    %t0
3249 /// ```
3250 ///
3251 /// The producer of t1 may or may not be DCE'd depending on whether it is a
3252 /// block argument or has side effects.
3253 static LogicalResult foldReadInitWrite(TransferWriteOp write,
3254                                        ArrayRef<Attribute>,
3255                                        SmallVectorImpl<OpFoldResult> &results) {
3256   // TODO: support 0-d corner case.
3257   if (write.getTransferRank() == 0)
3258     return failure();
3259   auto rankedTensorType = write.source().getType().dyn_cast<RankedTensorType>();
3260   // If not operating on tensors, bail.
3261   if (!rankedTensorType)
3262     return failure();
3263   // If no read, bail.
3264   auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
3265   if (!read)
3266     return failure();
3267   // TODO: support 0-d corner case.
3268   if (read.getTransferRank() == 0)
3269     return failure();
3270   // For now, only accept minor identity. Future: composition is minor identity.
3271   if (!read.permutation_map().isMinorIdentity() ||
3272       !write.permutation_map().isMinorIdentity())
3273     return failure();
3274   // Bail on mismatching ranks.
3275   if (read.getTransferRank() != write.getTransferRank())
3276     return failure();
3277   // Bail on potential out-of-bounds accesses.
3278   if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
3279     return failure();
3280   // Tensor types must be the same.
3281   if (read.source().getType() != rankedTensorType)
3282     return failure();
3283   // Vector types must be the same.
3284   if (read.getVectorType() != write.getVectorType())
3285     return failure();
3286   // Vector and Tensor shapes must match.
3287   if (read.getVectorType().getShape() != rankedTensorType.getShape())
3288     return failure();
3289   // If any index is nonzero.
3290   auto isNotConstantZero = [](Value v) {
3291     auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>();
3292     return !cstOp || cstOp.value() != 0;
3293   };
3294   if (llvm::any_of(read.indices(), isNotConstantZero) ||
3295       llvm::any_of(write.indices(), isNotConstantZero))
3296     return failure();
3297   // Success.
3298   results.push_back(read.source());
3299   return success();
3300 }
3301 
3302 static bool checkSameValueWAR(vector::TransferReadOp read,
3303                               vector::TransferWriteOp write) {
3304   return read.source() == write.source() && read.indices() == write.indices() &&
3305          read.permutation_map() == write.permutation_map() &&
3306          read.getVectorType() == write.getVectorType() && !read.mask() &&
3307          !write.mask();
3308 }
3309 /// Fold transfer_write write after read:
3310 /// ```
3311 ///    %t0 = ...
3312 ///    %v = vector.transfer_read %t0[%c0...] :
3313 ///      tensor<static_sizesxf32>, vector<static_sizesxf32>
3314 ///    %t1 = vector.transfer_write %v, %t0[%c0...] :
3315 ///      vector<static_sizesxf32>, tensor<static_sizesxf32>
3316 /// ```
3317 ///
3318 /// into:
3319 ///
3320 /// ```
3321 ///    %t0
3322 /// ```
3323 static LogicalResult foldWAR(TransferWriteOp write,
3324                              SmallVectorImpl<OpFoldResult> &results) {
3325   if (!write.source().getType().isa<RankedTensorType>())
3326     return failure();
3327   auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
3328   if (!read)
3329     return failure();
3330 
3331   if (!checkSameValueWAR(read, write))
3332     return failure();
3333   results.push_back(read.source());
3334   return success();
3335 }
3336 
3337 LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
3338                                     SmallVectorImpl<OpFoldResult> &results) {
3339   if (succeeded(foldReadInitWrite(*this, operands, results)))
3340     return success();
3341   if (succeeded(foldWAR(*this, results)))
3342     return success();
3343   if (succeeded(foldTransferInBoundsAttribute(*this)))
3344     return success();
3345   return foldMemRefCast(*this);
3346 }
3347 
3348 Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
3349   return llvm::to_vector<4>(getVectorType().getShape());
3350 }
3351 
3352 void TransferWriteOp::getEffects(
3353     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3354         &effects) {
3355   if (getShapedType().isa<MemRefType>())
3356     effects.emplace_back(MemoryEffects::Write::get(), source(),
3357                          SideEffects::DefaultResource::get());
3358 }
3359 
3360 namespace {
3361 /// Remove dead transfer write from the SSA chain so that it an be eliminated by
3362 /// DCE
3363 /// ```
3364 ///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3365 ///    : vector<1x4xf32>, tensor<4x4xf32>
3366 ///  %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
3367 ///    : vector<1x4xf32>, tensor<4x4xf32>
3368 ///  %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
3369 ///    : vector<1x4xf32>, tensor<4x4xf32>
3370 /// ```
3371 ///
3372 /// into:
3373 ///
3374 /// ```
3375 ///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3376 ///    : vector<1x4xf32>, tensor<4x4xf32>
3377 ///  %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
3378 ///    : vector<1x4xf32>, tensor<4x4xf32>
3379 ///  %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
3380 ///    : vector<1x4xf32>, tensor<4x4xf32>
3381 /// ```
3382 ///
3383 /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
3384 /// any other uses.
3385 class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
3386 public:
3387   using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
3388   LogicalResult matchAndRewrite(TransferWriteOp writeOp,
3389                                 PatternRewriter &rewriter) const override {
3390     if (!writeOp.getShapedType().isa<RankedTensorType>())
3391       return failure();
3392     vector::TransferWriteOp writeToModify = writeOp;
3393 
3394     auto defWrite = writeOp.source().getDefiningOp<vector::TransferWriteOp>();
3395     while (defWrite) {
3396       if (checkSameValueWAW(writeOp, defWrite)) {
3397         writeToModify.sourceMutable().assign(defWrite.source());
3398         return success();
3399       }
3400       if (!isDisjointTransferIndices(
3401               cast<VectorTransferOpInterface>(defWrite.getOperation()),
3402               cast<VectorTransferOpInterface>(writeOp.getOperation())))
3403         break;
3404       // If the previous write op doesn't have any other use we an safely look
3405       // at the previous store to see if it can be removed.
3406       if (!defWrite->hasOneUse())
3407         break;
3408       writeToModify = defWrite;
3409       defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
3410     }
3411     return failure();
3412   }
3413 };
3414 
3415 /// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
3416 /// could directly write to the insert_slice's destination. E.g.:
3417 ///
3418 /// ```
3419 /// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
3420 ///     : vector<4x5xf32>, tensor<4x5xf32>
3421 /// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
3422 ///     : tensor<4x5xf32> into tensor<?x?xf32>
3423 /// ```
3424 /// is rewritten to:
3425 /// ```
3426 /// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
3427 ///     : vector<4x5xf32>, tensor<?x?xf32>
3428 /// ```
3429 struct FoldInsertSliceIntoTransferWrite
3430     : public OpRewritePattern<tensor::InsertSliceOp> {
3431 public:
3432   using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
3433 
3434   LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
3435                                 PatternRewriter &rewriter) const override {
3436     if (!insertOp.hasUnitStride())
3437       return failure();
3438 
3439     auto xferOp = insertOp.source().getDefiningOp<TransferWriteOp>();
3440     if (!xferOp)
3441       return failure();
3442     // TODO: support 0-d corner case.
3443     if (xferOp.getTransferRank() == 0)
3444       return failure();
3445 
3446     if (xferOp.hasOutOfBoundsDim())
3447       return failure();
3448     if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
3449       return failure();
3450     if (xferOp.mask())
3451       return failure();
3452     // Fold only if the TransferWriteOp completely overwrites the `source` with
3453     // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
3454     // content is the data of the vector.
3455     if (!llvm::equal(xferOp.getVectorType().getShape(),
3456                      xferOp.getShapedType().getShape()))
3457       return failure();
3458     if (!xferOp.permutation_map().isIdentity())
3459       return failure();
3460 
3461     // Bail on illegal rank-reduction: we need to check that the rank-reduced
3462     // dims are exactly the leading dims. I.e. the following is illegal:
3463     // ```
3464     //    %0 = vector.transfer_write %v, %t[0,0], %cst :
3465     //      vector<2x4xf32>, tensor<2x4xf32>
3466     //    %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
3467     //      tensor<2x4xf32> into tensor<2x1x4xf32>
3468     // ```
3469     //
3470     // Cannot fold into:
3471     // ```
3472     //    %0 = vector.transfer_write %v, %t[0,0,0], %cst :
3473     //      vector<2x4xf32>, tensor<2x1x4xf32>
3474     // ```
3475     // For this, check the trailing `vectorRank` dims of the insert_slice result
3476     // tensor match the trailing dims of the inferred result tensor.
3477     int64_t rankReduced =
3478         insertOp.getType().getRank() - insertOp.getSourceType().getRank();
3479     int64_t vectorRank = xferOp.getVectorType().getRank();
3480     RankedTensorType inferredSourceTensorType =
3481         tensor::ExtractSliceOp::inferResultType(
3482             insertOp.getType(), insertOp.getMixedOffsets(),
3483             insertOp.getMixedSizes(), insertOp.getMixedStrides());
3484     auto actualSourceTensorShape = insertOp.getSourceType().getShape();
3485     if (rankReduced > 0 &&
3486         actualSourceTensorShape.take_back(vectorRank) !=
3487             inferredSourceTensorType.getShape().take_back(vectorRank))
3488       return failure();
3489 
3490     SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
3491         rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
3492     SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
3493     rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.vector(),
3494                                                  insertOp.dest(), indices,
3495                                                  ArrayRef<bool>{inBounds});
3496     return success();
3497   }
3498 };
3499 } // namespace
3500 
3501 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
3502                                                   MLIRContext *context) {
3503   results.add<FoldWaw, FoldInsertSliceIntoTransferWrite>(context);
3504 }
3505 
3506 //===----------------------------------------------------------------------===//
3507 // LoadOp
3508 //===----------------------------------------------------------------------===//
3509 
3510 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
3511                                                  MemRefType memRefTy) {
3512   if (!isLastMemrefDimUnitStride(memRefTy))
3513     return op->emitOpError("most minor memref dim must have unit stride");
3514   return success();
3515 }
3516 
3517 static LogicalResult verify(vector::LoadOp op) {
3518   VectorType resVecTy = op.getVectorType();
3519   MemRefType memRefTy = op.getMemRefType();
3520 
3521   if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
3522     return failure();
3523 
3524   // Checks for vector memrefs.
3525   Type memElemTy = memRefTy.getElementType();
3526   if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
3527     if (memVecTy != resVecTy)
3528       return op.emitOpError("base memref and result vector types should match");
3529     memElemTy = memVecTy.getElementType();
3530   }
3531 
3532   if (resVecTy.getElementType() != memElemTy)
3533     return op.emitOpError("base and result element types should match");
3534   if (llvm::size(op.indices()) != memRefTy.getRank())
3535     return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
3536   return success();
3537 }
3538 
3539 OpFoldResult LoadOp::fold(ArrayRef<Attribute>) {
3540   if (succeeded(foldMemRefCast(*this)))
3541     return getResult();
3542   return OpFoldResult();
3543 }
3544 
3545 //===----------------------------------------------------------------------===//
3546 // StoreOp
3547 //===----------------------------------------------------------------------===//
3548 
3549 static LogicalResult verify(vector::StoreOp op) {
3550   VectorType valueVecTy = op.getVectorType();
3551   MemRefType memRefTy = op.getMemRefType();
3552 
3553   if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
3554     return failure();
3555 
3556   // Checks for vector memrefs.
3557   Type memElemTy = memRefTy.getElementType();
3558   if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
3559     if (memVecTy != valueVecTy)
3560       return op.emitOpError(
3561           "base memref and valueToStore vector types should match");
3562     memElemTy = memVecTy.getElementType();
3563   }
3564 
3565   if (valueVecTy.getElementType() != memElemTy)
3566     return op.emitOpError("base and valueToStore element type should match");
3567   if (llvm::size(op.indices()) != memRefTy.getRank())
3568     return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
3569   return success();
3570 }
3571 
3572 LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
3573                             SmallVectorImpl<OpFoldResult> &results) {
3574   return foldMemRefCast(*this);
3575 }
3576 
3577 //===----------------------------------------------------------------------===//
3578 // MaskedLoadOp
3579 //===----------------------------------------------------------------------===//
3580 
3581 static LogicalResult verify(MaskedLoadOp op) {
3582   VectorType maskVType = op.getMaskVectorType();
3583   VectorType passVType = op.getPassThruVectorType();
3584   VectorType resVType = op.getVectorType();
3585   MemRefType memType = op.getMemRefType();
3586 
3587   if (resVType.getElementType() != memType.getElementType())
3588     return op.emitOpError("base and result element type should match");
3589   if (llvm::size(op.indices()) != memType.getRank())
3590     return op.emitOpError("requires ") << memType.getRank() << " indices";
3591   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
3592     return op.emitOpError("expected result dim to match mask dim");
3593   if (resVType != passVType)
3594     return op.emitOpError("expected pass_thru of same type as result type");
3595   return success();
3596 }
3597 
3598 namespace {
3599 class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
3600 public:
3601   using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
3602   LogicalResult matchAndRewrite(MaskedLoadOp load,
3603                                 PatternRewriter &rewriter) const override {
3604     switch (get1DMaskFormat(load.mask())) {
3605     case MaskFormat::AllTrue:
3606       rewriter.replaceOpWithNewOp<vector::LoadOp>(load, load.getType(),
3607                                                   load.base(), load.indices());
3608       return success();
3609     case MaskFormat::AllFalse:
3610       rewriter.replaceOp(load, load.pass_thru());
3611       return success();
3612     case MaskFormat::Unknown:
3613       return failure();
3614     }
3615     llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
3616   }
3617 };
3618 } // namespace
3619 
3620 void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3621                                                MLIRContext *context) {
3622   results.add<MaskedLoadFolder>(context);
3623 }
3624 
3625 OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) {
3626   if (succeeded(foldMemRefCast(*this)))
3627     return getResult();
3628   return OpFoldResult();
3629 }
3630 
3631 //===----------------------------------------------------------------------===//
3632 // MaskedStoreOp
3633 //===----------------------------------------------------------------------===//
3634 
3635 static LogicalResult verify(MaskedStoreOp op) {
3636   VectorType maskVType = op.getMaskVectorType();
3637   VectorType valueVType = op.getVectorType();
3638   MemRefType memType = op.getMemRefType();
3639 
3640   if (valueVType.getElementType() != memType.getElementType())
3641     return op.emitOpError("base and valueToStore element type should match");
3642   if (llvm::size(op.indices()) != memType.getRank())
3643     return op.emitOpError("requires ") << memType.getRank() << " indices";
3644   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
3645     return op.emitOpError("expected valueToStore dim to match mask dim");
3646   return success();
3647 }
3648 
3649 namespace {
3650 class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
3651 public:
3652   using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
3653   LogicalResult matchAndRewrite(MaskedStoreOp store,
3654                                 PatternRewriter &rewriter) const override {
3655     switch (get1DMaskFormat(store.mask())) {
3656     case MaskFormat::AllTrue:
3657       rewriter.replaceOpWithNewOp<vector::StoreOp>(
3658           store, store.valueToStore(), store.base(), store.indices());
3659       return success();
3660     case MaskFormat::AllFalse:
3661       rewriter.eraseOp(store);
3662       return success();
3663     case MaskFormat::Unknown:
3664       return failure();
3665     }
3666     llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
3667   }
3668 };
3669 } // namespace
3670 
3671 void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
3672                                                 MLIRContext *context) {
3673   results.add<MaskedStoreFolder>(context);
3674 }
3675 
3676 LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
3677                                   SmallVectorImpl<OpFoldResult> &results) {
3678   return foldMemRefCast(*this);
3679 }
3680 
3681 //===----------------------------------------------------------------------===//
3682 // GatherOp
3683 //===----------------------------------------------------------------------===//
3684 
3685 static LogicalResult verify(GatherOp op) {
3686   VectorType indVType = op.getIndexVectorType();
3687   VectorType maskVType = op.getMaskVectorType();
3688   VectorType resVType = op.getVectorType();
3689   MemRefType memType = op.getMemRefType();
3690 
3691   if (resVType.getElementType() != memType.getElementType())
3692     return op.emitOpError("base and result element type should match");
3693   if (llvm::size(op.indices()) != memType.getRank())
3694     return op.emitOpError("requires ") << memType.getRank() << " indices";
3695   if (resVType.getDimSize(0) != indVType.getDimSize(0))
3696     return op.emitOpError("expected result dim to match indices dim");
3697   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
3698     return op.emitOpError("expected result dim to match mask dim");
3699   if (resVType != op.getPassThruVectorType())
3700     return op.emitOpError("expected pass_thru of same type as result type");
3701   return success();
3702 }
3703 
3704 namespace {
3705 class GatherFolder final : public OpRewritePattern<GatherOp> {
3706 public:
3707   using OpRewritePattern<GatherOp>::OpRewritePattern;
3708   LogicalResult matchAndRewrite(GatherOp gather,
3709                                 PatternRewriter &rewriter) const override {
3710     switch (get1DMaskFormat(gather.mask())) {
3711     case MaskFormat::AllTrue:
3712       return failure(); // no unmasked equivalent
3713     case MaskFormat::AllFalse:
3714       rewriter.replaceOp(gather, gather.pass_thru());
3715       return success();
3716     case MaskFormat::Unknown:
3717       return failure();
3718     }
3719     llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
3720   }
3721 };
3722 } // namespace
3723 
3724 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
3725                                            MLIRContext *context) {
3726   results.add<GatherFolder>(context);
3727 }
3728 
3729 //===----------------------------------------------------------------------===//
3730 // ScatterOp
3731 //===----------------------------------------------------------------------===//
3732 
3733 static LogicalResult verify(ScatterOp op) {
3734   VectorType indVType = op.getIndexVectorType();
3735   VectorType maskVType = op.getMaskVectorType();
3736   VectorType valueVType = op.getVectorType();
3737   MemRefType memType = op.getMemRefType();
3738 
3739   if (valueVType.getElementType() != memType.getElementType())
3740     return op.emitOpError("base and valueToStore element type should match");
3741   if (llvm::size(op.indices()) != memType.getRank())
3742     return op.emitOpError("requires ") << memType.getRank() << " indices";
3743   if (valueVType.getDimSize(0) != indVType.getDimSize(0))
3744     return op.emitOpError("expected valueToStore dim to match indices dim");
3745   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
3746     return op.emitOpError("expected valueToStore dim to match mask dim");
3747   return success();
3748 }
3749 
3750 namespace {
3751 class ScatterFolder final : public OpRewritePattern<ScatterOp> {
3752 public:
3753   using OpRewritePattern<ScatterOp>::OpRewritePattern;
3754   LogicalResult matchAndRewrite(ScatterOp scatter,
3755                                 PatternRewriter &rewriter) const override {
3756     switch (get1DMaskFormat(scatter.mask())) {
3757     case MaskFormat::AllTrue:
3758       return failure(); // no unmasked equivalent
3759     case MaskFormat::AllFalse:
3760       rewriter.eraseOp(scatter);
3761       return success();
3762     case MaskFormat::Unknown:
3763       return failure();
3764     }
3765     llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
3766   }
3767 };
3768 } // namespace
3769 
3770 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
3771                                             MLIRContext *context) {
3772   results.add<ScatterFolder>(context);
3773 }
3774 
3775 //===----------------------------------------------------------------------===//
3776 // ExpandLoadOp
3777 //===----------------------------------------------------------------------===//
3778 
3779 static LogicalResult verify(ExpandLoadOp op) {
3780   VectorType maskVType = op.getMaskVectorType();
3781   VectorType passVType = op.getPassThruVectorType();
3782   VectorType resVType = op.getVectorType();
3783   MemRefType memType = op.getMemRefType();
3784 
3785   if (resVType.getElementType() != memType.getElementType())
3786     return op.emitOpError("base and result element type should match");
3787   if (llvm::size(op.indices()) != memType.getRank())
3788     return op.emitOpError("requires ") << memType.getRank() << " indices";
3789   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
3790     return op.emitOpError("expected result dim to match mask dim");
3791   if (resVType != passVType)
3792     return op.emitOpError("expected pass_thru of same type as result type");
3793   return success();
3794 }
3795 
3796 namespace {
3797 class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
3798 public:
3799   using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
3800   LogicalResult matchAndRewrite(ExpandLoadOp expand,
3801                                 PatternRewriter &rewriter) const override {
3802     switch (get1DMaskFormat(expand.mask())) {
3803     case MaskFormat::AllTrue:
3804       rewriter.replaceOpWithNewOp<vector::LoadOp>(
3805           expand, expand.getType(), expand.base(), expand.indices());
3806       return success();
3807     case MaskFormat::AllFalse:
3808       rewriter.replaceOp(expand, expand.pass_thru());
3809       return success();
3810     case MaskFormat::Unknown:
3811       return failure();
3812     }
3813     llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
3814   }
3815 };
3816 } // namespace
3817 
3818 void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3819                                                MLIRContext *context) {
3820   results.add<ExpandLoadFolder>(context);
3821 }
3822 
3823 //===----------------------------------------------------------------------===//
3824 // CompressStoreOp
3825 //===----------------------------------------------------------------------===//
3826 
3827 static LogicalResult verify(CompressStoreOp op) {
3828   VectorType maskVType = op.getMaskVectorType();
3829   VectorType valueVType = op.getVectorType();
3830   MemRefType memType = op.getMemRefType();
3831 
3832   if (valueVType.getElementType() != memType.getElementType())
3833     return op.emitOpError("base and valueToStore element type should match");
3834   if (llvm::size(op.indices()) != memType.getRank())
3835     return op.emitOpError("requires ") << memType.getRank() << " indices";
3836   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
3837     return op.emitOpError("expected valueToStore dim to match mask dim");
3838   return success();
3839 }
3840 
3841 namespace {
3842 class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
3843 public:
3844   using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
3845   LogicalResult matchAndRewrite(CompressStoreOp compress,
3846                                 PatternRewriter &rewriter) const override {
3847     switch (get1DMaskFormat(compress.mask())) {
3848     case MaskFormat::AllTrue:
3849       rewriter.replaceOpWithNewOp<vector::StoreOp>(
3850           compress, compress.valueToStore(), compress.base(),
3851           compress.indices());
3852       return success();
3853     case MaskFormat::AllFalse:
3854       rewriter.eraseOp(compress);
3855       return success();
3856     case MaskFormat::Unknown:
3857       return failure();
3858     }
3859     llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
3860   }
3861 };
3862 } // namespace
3863 
3864 void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
3865                                                   MLIRContext *context) {
3866   results.add<CompressStoreFolder>(context);
3867 }
3868 
3869 //===----------------------------------------------------------------------===//
3870 // ShapeCastOp
3871 //===----------------------------------------------------------------------===//
3872 
3873 /// Returns true if each element of 'a' is equal to the product of a contiguous
3874 /// sequence of the elements of 'b'. Returns false otherwise.
3875 static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
3876   unsigned rankA = a.size();
3877   unsigned rankB = b.size();
3878   assert(rankA < rankB);
3879 
3880   unsigned i = 0;
3881   unsigned j = 0;
3882   while (i < rankA && j < rankB) {
3883     int64_t dimA = a[i];
3884     int64_t dimB = 1;
3885     while (dimB < dimA && j < rankB)
3886       dimB *= b[j++];
3887     if (dimA != dimB)
3888       break;
3889     ++i;
3890 
3891     // Handle the case when trailing dimensions are of size 1.
3892     // Include them into the contiguous sequence.
3893     auto isOne = [](int64_t v) { return v == 1; };
3894     if (i < rankA && llvm::all_of(a.slice(i), isOne))
3895       i = rankA;
3896     if (j < rankB && llvm::all_of(b.slice(j), isOne))
3897       j = rankB;
3898   }
3899 
3900   return i == rankA && j == rankB;
3901 }
3902 
3903 static LogicalResult verifyVectorShapeCast(Operation *op,
3904                                            VectorType sourceVectorType,
3905                                            VectorType resultVectorType) {
3906   // Check that element type is the same.
3907   if (sourceVectorType.getElementType() != resultVectorType.getElementType())
3908     return op->emitOpError("source/result vectors must have same element type");
3909   auto sourceShape = sourceVectorType.getShape();
3910   auto resultShape = resultVectorType.getShape();
3911 
3912   // Check that product of source dim sizes matches product of result dim sizes.
3913   int64_t sourceDimProduct = std::accumulate(
3914       sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
3915   int64_t resultDimProduct = std::accumulate(
3916       resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
3917   if (sourceDimProduct != resultDimProduct)
3918     return op->emitOpError("source/result number of elements must match");
3919 
3920   // Check that expanding/contracting rank cases.
3921   unsigned sourceRank = sourceVectorType.getRank();
3922   unsigned resultRank = resultVectorType.getRank();
3923   if (sourceRank < resultRank) {
3924     if (!isValidShapeCast(sourceShape, resultShape))
3925       return op->emitOpError("invalid shape cast");
3926   } else if (sourceRank > resultRank) {
3927     if (!isValidShapeCast(resultShape, sourceShape))
3928       return op->emitOpError("invalid shape cast");
3929   }
3930   return success();
3931 }
3932 
3933 static LogicalResult verify(ShapeCastOp op) {
3934   auto sourceVectorType = op.source().getType().dyn_cast_or_null<VectorType>();
3935   auto resultVectorType = op.result().getType().dyn_cast_or_null<VectorType>();
3936 
3937   // Check if source/result are of vector type.
3938   if (sourceVectorType && resultVectorType)
3939     return verifyVectorShapeCast(op, sourceVectorType, resultVectorType);
3940 
3941   return success();
3942 }
3943 
3944 OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
3945   // Nop shape cast.
3946   if (source().getType() == result().getType())
3947     return source();
3948 
3949   // Canceling shape casts.
3950   if (auto otherOp = source().getDefiningOp<ShapeCastOp>()) {
3951     if (result().getType() == otherOp.source().getType())
3952       return otherOp.source();
3953 
3954     // Only allows valid transitive folding.
3955     VectorType srcType = otherOp.source().getType().cast<VectorType>();
3956     VectorType resultType = getResult().getType().cast<VectorType>();
3957     if (srcType.getRank() < resultType.getRank()) {
3958       if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
3959         return {};
3960     } else if (srcType.getRank() > resultType.getRank()) {
3961       if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
3962         return {};
3963     } else {
3964       return {};
3965     }
3966 
3967     setOperand(otherOp.source());
3968     return getResult();
3969   }
3970   return {};
3971 }
3972 
3973 namespace {
3974 // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
3975 class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
3976 public:
3977   using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
3978 
3979   LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
3980                                 PatternRewriter &rewriter) const override {
3981     auto constantOp = shapeCastOp.source().getDefiningOp<arith::ConstantOp>();
3982     if (!constantOp)
3983       return failure();
3984     // Only handle splat for now.
3985     auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
3986     if (!dense)
3987       return failure();
3988     auto newAttr =
3989         DenseElementsAttr::get(shapeCastOp.getType().cast<VectorType>(),
3990                                dense.getSplatValue<Attribute>());
3991     rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
3992     return success();
3993   }
3994 };
3995 
3996 } // namespace
3997 
3998 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
3999                                               MLIRContext *context) {
4000   // Pattern to rewrite a ShapeCastOp(ConstantOp) -> ConstantOp.
4001   results.add<ShapeCastConstantFolder>(context);
4002 }
4003 
4004 //===----------------------------------------------------------------------===//
4005 // VectorBitCastOp
4006 //===----------------------------------------------------------------------===//
4007 
4008 static LogicalResult verify(BitCastOp op) {
4009   auto sourceVectorType = op.getSourceVectorType();
4010   auto resultVectorType = op.getResultVectorType();
4011 
4012   for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
4013     if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
4014       return op.emitOpError("dimension size mismatch at: ") << i;
4015   }
4016 
4017   DataLayout dataLayout = DataLayout::closest(op);
4018   auto sourceElementBits =
4019       dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
4020   auto resultElementBits =
4021       dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
4022 
4023   if (sourceVectorType.getRank() == 0) {
4024     if (sourceElementBits != resultElementBits)
4025       return op.emitOpError("source/result bitwidth of the 0-D vector element "
4026                             "types must be equal");
4027   } else if (sourceElementBits * sourceVectorType.getShape().back() !=
4028              resultElementBits * resultVectorType.getShape().back()) {
4029     return op.emitOpError(
4030         "source/result bitwidth of the minor 1-D vectors must be equal");
4031   }
4032 
4033   return success();
4034 }
4035 
4036 OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
4037   // Nop cast.
4038   if (source().getType() == result().getType())
4039     return source();
4040 
4041   // Canceling bitcasts.
4042   if (auto otherOp = source().getDefiningOp<BitCastOp>())
4043     if (result().getType() == otherOp.source().getType())
4044       return otherOp.source();
4045 
4046   Attribute sourceConstant = operands.front();
4047   if (!sourceConstant)
4048     return {};
4049 
4050   Type srcElemType = getSourceVectorType().getElementType();
4051   Type dstElemType = getResultVectorType().getElementType();
4052 
4053   if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) {
4054     if (floatPack.isSplat()) {
4055       auto splat = floatPack.getSplatValue<FloatAttr>();
4056 
4057       // Casting fp16 into fp32.
4058       if (srcElemType.isF16() && dstElemType.isF32()) {
4059         uint32_t bits = static_cast<uint32_t>(
4060             splat.getValue().bitcastToAPInt().getZExtValue());
4061         // Duplicate the 16-bit pattern.
4062         bits = (bits << 16) | (bits & 0xffff);
4063         APInt intBits(32, bits);
4064         APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
4065         return DenseElementsAttr::get(getResultVectorType(), floatBits);
4066       }
4067     }
4068   }
4069 
4070   return {};
4071 }
4072 
4073 //===----------------------------------------------------------------------===//
4074 // TypeCastOp
4075 //===----------------------------------------------------------------------===//
4076 
4077 static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
4078   auto vectorType = memRefType.getElementType().dyn_cast<VectorType>();
4079   SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
4080                               memRefType.getShape().end());
4081   if (vectorType)
4082     res.append(vectorType.getShape().begin(), vectorType.getShape().end());
4083   return res;
4084 }
4085 
4086 /// Build the canonical memRefType with a single vector.
4087 /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
4088 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
4089                        Value source) {
4090   result.addOperands(source);
4091   MemRefType memRefType = source.getType().cast<MemRefType>();
4092   VectorType vectorType =
4093       VectorType::get(extractShape(memRefType),
4094                       getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
4095   result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
4096                                   memRefType.getMemorySpace()));
4097 }
4098 
4099 static LogicalResult verify(TypeCastOp op) {
4100   MemRefType canonicalType = canonicalizeStridedLayout(op.getMemRefType());
4101   if (!canonicalType.getLayout().isIdentity())
4102     return op.emitOpError(
4103         "expects operand to be a memref with identity layout");
4104   if (!op.getResultMemRefType().getLayout().isIdentity())
4105     return op.emitOpError("expects result to be a memref with identity layout");
4106   if (op.getResultMemRefType().getMemorySpace() !=
4107       op.getMemRefType().getMemorySpace())
4108     return op.emitOpError("expects result in same memory space");
4109 
4110   auto sourceType = op.getMemRefType();
4111   auto resultType = op.getResultMemRefType();
4112   if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
4113       getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
4114     return op.emitOpError(
4115                "expects result and operand with same underlying scalar type: ")
4116            << resultType;
4117   if (extractShape(sourceType) != extractShape(resultType))
4118     return op.emitOpError(
4119                "expects concatenated result and operand shapes to be equal: ")
4120            << resultType;
4121   return success();
4122 }
4123 
4124 //===----------------------------------------------------------------------===//
4125 // TransposeOp
4126 //===----------------------------------------------------------------------===//
4127 
4128 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
4129                                 Value vector, ArrayRef<int64_t> transp) {
4130   VectorType vt = vector.getType().cast<VectorType>();
4131   SmallVector<int64_t, 4> transposedShape(vt.getRank());
4132   for (unsigned i = 0; i < transp.size(); ++i)
4133     transposedShape[i] = vt.getShape()[transp[i]];
4134 
4135   result.addOperands(vector);
4136   result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
4137   result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp));
4138 }
4139 
4140 // Eliminates transpose operations, which produce values identical to their
4141 // input values. This happens when the dimensions of the input vector remain in
4142 // their original order after the transpose operation.
4143 OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
4144   SmallVector<int64_t, 4> transp;
4145   getTransp(transp);
4146 
4147   // Check if the permutation of the dimensions contains sequential values:
4148   // {0, 1, 2, ...}.
4149   for (int64_t i = 0, e = transp.size(); i < e; i++) {
4150     if (transp[i] != i)
4151       return {};
4152   }
4153 
4154   return vector();
4155 }
4156 
4157 static LogicalResult verify(vector::TransposeOp op) {
4158   VectorType vectorType = op.getVectorType();
4159   VectorType resultType = op.getResultType();
4160   int64_t rank = resultType.getRank();
4161   if (vectorType.getRank() != rank)
4162     return op.emitOpError("vector result rank mismatch: ") << rank;
4163   // Verify transposition array.
4164   auto transpAttr = op.transp().getValue();
4165   int64_t size = transpAttr.size();
4166   if (rank != size)
4167     return op.emitOpError("transposition length mismatch: ") << size;
4168   SmallVector<bool, 8> seen(rank, false);
4169   for (const auto &ta : llvm::enumerate(transpAttr)) {
4170     int64_t i = ta.value().cast<IntegerAttr>().getInt();
4171     if (i < 0 || i >= rank)
4172       return op.emitOpError("transposition index out of range: ") << i;
4173     if (seen[i])
4174       return op.emitOpError("duplicate position index: ") << i;
4175     seen[i] = true;
4176     if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
4177       return op.emitOpError("dimension size mismatch at: ") << i;
4178   }
4179   return success();
4180 }
4181 
4182 namespace {
4183 
4184 // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
4185 class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
4186 public:
4187   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
4188 
4189   LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
4190                                 PatternRewriter &rewriter) const override {
4191     // Wrapper around vector::TransposeOp::getTransp() for cleaner code.
4192     auto getPermutation = [](vector::TransposeOp transpose) {
4193       SmallVector<int64_t, 4> permutation;
4194       transpose.getTransp(permutation);
4195       return permutation;
4196     };
4197 
4198     // Composes two permutations: result[i] = permutation1[permutation2[i]].
4199     auto composePermutations = [](ArrayRef<int64_t> permutation1,
4200                                   ArrayRef<int64_t> permutation2) {
4201       SmallVector<int64_t, 4> result;
4202       for (auto index : permutation2)
4203         result.push_back(permutation1[index]);
4204       return result;
4205     };
4206 
4207     // Return if the input of 'transposeOp' is not defined by another transpose.
4208     vector::TransposeOp parentTransposeOp =
4209         transposeOp.vector().getDefiningOp<vector::TransposeOp>();
4210     if (!parentTransposeOp)
4211       return failure();
4212 
4213     SmallVector<int64_t, 4> permutation = composePermutations(
4214         getPermutation(parentTransposeOp), getPermutation(transposeOp));
4215     // Replace 'transposeOp' with a new transpose operation.
4216     rewriter.replaceOpWithNewOp<vector::TransposeOp>(
4217         transposeOp, transposeOp.getResult().getType(),
4218         parentTransposeOp.vector(),
4219         vector::getVectorSubscriptAttr(rewriter, permutation));
4220     return success();
4221   }
4222 };
4223 
4224 } // namespace
4225 
4226 void vector::TransposeOp::getCanonicalizationPatterns(
4227     RewritePatternSet &results, MLIRContext *context) {
4228   results.add<TransposeFolder>(context);
4229 }
4230 
4231 void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
4232   populateFromInt64AttrArray(transp(), results);
4233 }
4234 
4235 //===----------------------------------------------------------------------===//
4236 // ConstantMaskOp
4237 //===----------------------------------------------------------------------===//
4238 
4239 static LogicalResult verify(ConstantMaskOp &op) {
4240   auto resultType = op.getResult().getType().cast<VectorType>();
4241   // Check the corner case of 0-D vectors first.
4242   if (resultType.getRank() == 0) {
4243     if (op.mask_dim_sizes().size() != 1)
4244       return op->emitError("array attr must have length 1 for 0-D vectors");
4245     auto dim = op.mask_dim_sizes()[0].cast<IntegerAttr>().getInt();
4246     if (dim != 0 && dim != 1)
4247       return op->emitError(
4248           "mask dim size must be either 0 or 1 for 0-D vectors");
4249     return success();
4250   }
4251 
4252   // Verify that array attr size matches the rank of the vector result.
4253   if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank())
4254     return op.emitOpError(
4255         "must specify array attr of size equal vector result rank");
4256   // Verify that each array attr element is in bounds of corresponding vector
4257   // result dimension size.
4258   auto resultShape = resultType.getShape();
4259   SmallVector<int64_t, 4> maskDimSizes;
4260   for (const auto &it : llvm::enumerate(op.mask_dim_sizes())) {
4261     int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
4262     if (attrValue < 0 || attrValue > resultShape[it.index()])
4263       return op.emitOpError(
4264           "array attr of size out of bounds of vector result dimension size");
4265     maskDimSizes.push_back(attrValue);
4266   }
4267   // Verify that if one mask dim size is zero, they all should be zero (because
4268   // the mask region is a conjunction of each mask dimension interval).
4269   bool anyZeros = llvm::is_contained(maskDimSizes, 0);
4270   bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
4271   if (anyZeros && !allZeros)
4272     return op.emitOpError("expected all mask dim sizes to be zeros, "
4273                           "as a result of conjunction with zero mask dim");
4274   return success();
4275 }
4276 
4277 //===----------------------------------------------------------------------===//
4278 // CreateMaskOp
4279 //===----------------------------------------------------------------------===//
4280 
4281 static LogicalResult verify(CreateMaskOp op) {
4282   auto vectorType = op.getResult().getType().cast<VectorType>();
4283   // Verify that an operand was specified for each result vector each dimension.
4284   if (vectorType.getRank() == 0) {
4285     if (op->getNumOperands() != 1)
4286       return op.emitOpError(
4287           "must specify exactly one operand for 0-D create_mask");
4288   } else if (op.getNumOperands() !=
4289              op.getResult().getType().cast<VectorType>().getRank()) {
4290     return op.emitOpError(
4291         "must specify an operand for each result vector dimension");
4292   }
4293   return success();
4294 }
4295 
4296 namespace {
4297 
4298 // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
4299 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
4300 public:
4301   using OpRewritePattern<CreateMaskOp>::OpRewritePattern;
4302 
4303   LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
4304                                 PatternRewriter &rewriter) const override {
4305     // Return if any of 'createMaskOp' operands are not defined by a constant.
4306     auto isNotDefByConstant = [](Value operand) {
4307       return !isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
4308     };
4309     if (llvm::any_of(createMaskOp.operands(), isNotDefByConstant))
4310       return failure();
4311     // Gather constant mask dimension sizes.
4312     SmallVector<int64_t, 4> maskDimSizes;
4313     for (auto it : llvm::zip(createMaskOp.operands(),
4314                              createMaskOp.getType().getShape())) {
4315       auto *defOp = std::get<0>(it).getDefiningOp();
4316       int64_t maxDimSize = std::get<1>(it);
4317       int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value();
4318       dimSize = std::min(dimSize, maxDimSize);
4319       // If one of dim sizes is zero, set all dims to zero.
4320       if (dimSize <= 0) {
4321         maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
4322         break;
4323       }
4324       maskDimSizes.push_back(dimSize);
4325     }
4326     // Replace 'createMaskOp' with ConstantMaskOp.
4327     rewriter.replaceOpWithNewOp<ConstantMaskOp>(
4328         createMaskOp, createMaskOp.getResult().getType(),
4329         vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
4330     return success();
4331   }
4332 };
4333 
4334 } // namespace
4335 
4336 void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
4337                                                MLIRContext *context) {
4338   results.add<CreateMaskFolder>(context);
4339 }
4340 
4341 //===----------------------------------------------------------------------===//
4342 // ScanOp
4343 //===----------------------------------------------------------------------===//
4344 
4345 static LogicalResult verify(ScanOp op) {
4346   VectorType srcType = op.getSourceType();
4347   VectorType initialType = op.getInitialValueType();
4348   // Check reduction dimension < rank.
4349   int64_t srcRank = srcType.getRank();
4350   int64_t reductionDim = op.reduction_dim();
4351   if (reductionDim >= srcRank)
4352     return op.emitOpError("reduction dimension ")
4353            << reductionDim << " has to be less than " << srcRank;
4354 
4355   // Check that rank(initial_value) = rank(src) - 1.
4356   int64_t initialValueRank = initialType.getRank();
4357   if (initialValueRank != srcRank - 1)
4358     return op.emitOpError("initial value rank ")
4359            << initialValueRank << " has to be equal to " << srcRank - 1;
4360 
4361   // Check shapes of initial value and src.
4362   ArrayRef<int64_t> srcShape = srcType.getShape();
4363   ArrayRef<int64_t> initialValueShapes = initialType.getShape();
4364   SmallVector<int64_t> expectedShape;
4365   for (int i = 0; i < srcRank; i++) {
4366     if (i != reductionDim)
4367       expectedShape.push_back(srcShape[i]);
4368   }
4369   if (llvm::any_of(llvm::zip(initialValueShapes, expectedShape),
4370                    [](std::tuple<int64_t, int64_t> s) {
4371                      return std::get<0>(s) != std::get<1>(s);
4372                    })) {
4373     return op.emitOpError("incompatible input/initial value shapes");
4374   }
4375 
4376   return success();
4377 }
4378 
4379 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
4380     RewritePatternSet &patterns) {
4381   patterns
4382       .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
4383            ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
4384            StridedSliceConstantMaskFolder, TransposeFolder>(
4385           patterns.getContext());
4386 }
4387 
4388 #define GET_OP_CLASSES
4389 #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
4390