1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
13 #include "mlir/Dialect/Utils/StaticValueUtils.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Interfaces/InferTypeOpInterface.h"
21 #include "mlir/Interfaces/ViewLikeInterface.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SmallBitVector.h"
24 
25 using namespace mlir;
26 using namespace mlir::memref;
27 
28 /// Materialize a single constant operation from a given attribute value with
29 /// the desired resultant type.
30 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
31                                               Attribute value, Type type,
32                                               Location loc) {
33   if (arith::ConstantOp::isBuildableWith(value, type))
34     return builder.create<arith::ConstantOp>(loc, value, type);
35   return nullptr;
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // Common canonicalization pattern support logic
40 //===----------------------------------------------------------------------===//
41 
42 /// This is a common class used for patterns of the form
43 /// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
44 /// into the root operation directly.
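/// As an illustrative sketch (operand names are placeholders), when used from
/// a consumer's folder this turns:
///
/// ```mlir
///   %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
///   memref.dealloc %0 : memref<?xf32>
/// ```
///
/// into:
///
/// ```mlir
///   memref.dealloc %arg : memref<8xf32>
/// ```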
45 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
46   bool folded = false;
47   for (OpOperand &operand : op->getOpOperands()) {
48     auto cast = operand.get().getDefiningOp<CastOp>();
49     if (cast && operand.get() != inner &&
50         !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
51       operand.set(cast.getOperand());
52       folded = true;
53     }
54   }
55   return success(folded);
56 }
57 
58 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
59 /// type.
60 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
61   if (auto memref = type.dyn_cast<MemRefType>())
62     return RankedTensorType::get(memref.getShape(), memref.getElementType());
63   if (auto memref = type.dyn_cast<UnrankedMemRefType>())
64     return UnrankedTensorType::get(memref.getElementType());
65   return NoneType::get(type.getContext());
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // AllocOp / AllocaOp
70 //===----------------------------------------------------------------------===//
71 
72 template <typename AllocLikeOp>
73 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
74   static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
75                 "applies to only alloc or alloca");
76   auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
77   if (!memRefType)
78     return op.emitOpError("result must be a memref");
79 
80   if (static_cast<int64_t>(op.dynamicSizes().size()) !=
81       memRefType.getNumDynamicDims())
82     return op.emitOpError("dimension operand count does not equal memref "
83                           "dynamic dimension count");
84 
85   unsigned numSymbols = 0;
86   if (!memRefType.getLayout().isIdentity())
87     numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
88   if (op.symbolOperands().size() != numSymbols)
89     return op.emitOpError("symbol operand count does not equal memref symbol "
90                           "count: expected ")
91            << numSymbols << ", got " << op.symbolOperands().size();
92 
93   return success();
94 }
95 
96 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
97 
98 LogicalResult AllocaOp::verify() {
99   // An alloca op needs to have an ancestor with an allocation scope trait.
100   if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
101     return emitOpError(
102         "requires an ancestor op with AutomaticAllocationScope trait");
103 
104   return verifyAllocLikeOp(*this);
105 }
106 
107 namespace {
108 /// Fold constant dimensions into an alloc like operation.
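/// A minimal sketch of the rewrite, assuming one constant and one dynamic size
/// (names are illustrative):
///
/// ```mlir
///   %c8 = arith.constant 8 : index
///   %0 = memref.alloc(%c8, %n) : memref<?x?xf32>
/// ```
///
/// is rewritten to:
///
/// ```mlir
///   %1 = memref.alloc(%n) : memref<8x?xf32>
///   %0 = memref.cast %1 : memref<8x?xf32> to memref<?x?xf32>
/// ```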
109 template <typename AllocLikeOp>
110 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
111   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
112 
113   LogicalResult matchAndRewrite(AllocLikeOp alloc,
114                                 PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants.  If so, we can
    // substitute and drop them.
117     if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
118           return matchPattern(operand, matchConstantIndex());
119         }))
120       return failure();
121 
122     auto memrefType = alloc.getType();
123 
124     // Ok, we have one or more constant operands.  Collect the non-constant ones
125     // and keep track of the resultant memref type to build.
126     SmallVector<int64_t, 4> newShapeConstants;
127     newShapeConstants.reserve(memrefType.getRank());
128     SmallVector<Value, 4> dynamicSizes;
129 
130     unsigned dynamicDimPos = 0;
131     for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
132       int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
134       if (dimSize != -1) {
135         newShapeConstants.push_back(dimSize);
136         continue;
137       }
138       auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
139       auto *defOp = dynamicSize.getDefiningOp();
140       if (auto constantIndexOp =
141               dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
142         // Dynamic shape dimension will be folded.
143         newShapeConstants.push_back(constantIndexOp.value());
144       } else {
145         // Dynamic shape dimension not folded; copy dynamicSize from old memref.
146         newShapeConstants.push_back(-1);
147         dynamicSizes.push_back(dynamicSize);
148       }
149       dynamicDimPos++;
150     }
151 
152     // Create new memref type (which will have fewer dynamic dimensions).
153     MemRefType newMemRefType =
154         MemRefType::Builder(memrefType).setShape(newShapeConstants);
155     assert(static_cast<int64_t>(dynamicSizes.size()) ==
156            newMemRefType.getNumDynamicDims());
157 
158     // Create and insert the alloc op for the new memref.
159     auto newAlloc = rewriter.create<AllocLikeOp>(
160         alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
161         alloc.alignmentAttr());
162     // Insert a cast so we have the same type as the old alloc.
163     auto resultCast =
164         rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc);
165 
166     rewriter.replaceOp(alloc, {resultCast});
167     return success();
168   }
169 };
170 
171 /// Fold alloc operations with no users or only store and dealloc uses.
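/// For instance (illustrative), the following alloc is only written to and
/// deallocated, so the whole group of ops can be erased:
///
/// ```mlir
///   %0 = memref.alloc() : memref<8xf32>
///   memref.store %v, %0[%i] : memref<8xf32>
///   memref.dealloc %0 : memref<8xf32>
/// ```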
172 template <typename T>
173 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
174   using OpRewritePattern<T>::OpRewritePattern;
175 
176   LogicalResult matchAndRewrite(T alloc,
177                                 PatternRewriter &rewriter) const override {
178     if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
179           if (auto storeOp = dyn_cast<StoreOp>(op))
180             return storeOp.value() == alloc;
181           return !isa<DeallocOp>(op);
182         }))
183       return failure();
184 
185     for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
186       rewriter.eraseOp(user);
187 
188     rewriter.eraseOp(alloc);
189     return success();
190   }
191 };
192 } // namespace
193 
194 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
195                                           MLIRContext *context) {
196   results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
197 }
198 
199 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
200                                            MLIRContext *context) {
201   results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
202       context);
203 }
204 
205 //===----------------------------------------------------------------------===//
206 // AllocaScopeOp
207 //===----------------------------------------------------------------------===//
208 
209 void AllocaScopeOp::print(OpAsmPrinter &p) {
210   bool printBlockTerminators = false;
211 
212   p << ' ';
213   if (!results().empty()) {
214     p << " -> (" << getResultTypes() << ")";
215     printBlockTerminators = true;
216   }
217   p << ' ';
218   p.printRegion(bodyRegion(),
219                 /*printEntryBlockArgs=*/false,
220                 /*printBlockTerminators=*/printBlockTerminators);
221   p.printOptionalAttrDict((*this)->getAttrs());
222 }
223 
224 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
225   // Create a region for the body.
226   result.regions.reserve(1);
227   Region *bodyRegion = result.addRegion();
228 
229   // Parse optional results type list.
230   if (parser.parseOptionalArrowTypeList(result.types))
231     return failure();
232 
233   // Parse the body region.
234   if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
235     return failure();
236   AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
237                                   result.location);
238 
239   // Parse the optional attribute list.
240   if (parser.parseOptionalAttrDict(result.attributes))
241     return failure();
242 
243   return success();
244 }
245 
246 LogicalResult AllocaScopeOp::verify() {
247   return RegionBranchOpInterface::verifyTypes(*this);
248 }
249 
250 void AllocaScopeOp::getSuccessorRegions(
251     Optional<unsigned> index, ArrayRef<Attribute> operands,
252     SmallVectorImpl<RegionSuccessor> &regions) {
253   if (index.hasValue()) {
254     regions.push_back(RegionSuccessor(getResults()));
255     return;
256   }
257 
258   regions.push_back(RegionSuccessor(&bodyRegion()));
259 }
260 
261 //===----------------------------------------------------------------------===//
262 // AssumeAlignmentOp
263 //===----------------------------------------------------------------------===//
264 
265 LogicalResult AssumeAlignmentOp::verify() {
266   if (!llvm::isPowerOf2_32(alignment()))
267     return emitOpError("alignment must be power of 2");
268   return success();
269 }
270 
271 //===----------------------------------------------------------------------===//
272 // CastOp
273 //===----------------------------------------------------------------------===//
274 
275 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
277 /// and implement canonicalization patterns for ops in different dialects that
278 /// may consume the results of memref.cast operations. Such foldable memref.cast
279 /// operations are typically inserted as `view` and `subview` ops are
280 /// canonicalized, to preserve the type compatibility of their uses.
281 ///
282 /// Returns true when all conditions are met:
283 /// 1. source and result are ranked memrefs with strided semantics and same
284 /// element type and rank.
285 /// 2. each of the source's size, offset or stride has more static information
286 /// than the corresponding result's size, offset or stride.
287 ///
288 /// Example 1:
289 /// ```mlir
290 ///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
291 ///   %2 = consumer %1 ... : memref<?x?xf32> ...
292 /// ```
293 ///
294 /// may fold into:
295 ///
296 /// ```mlir
297 ///   %2 = consumer %0 ... : memref<8x16xf32> ...
298 /// ```
299 ///
300 /// Example 2:
301 /// ```
302 ///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
303 ///          to memref<?x?xf32>
304 ///   consumer %1 : memref<?x?xf32> ...
305 /// ```
306 ///
307 /// may fold into:
308 ///
309 /// ```
310 ///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
311 /// ```
312 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
313   MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
314   MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
315 
316   // Requires ranked MemRefType.
317   if (!sourceType || !resultType)
318     return false;
319 
320   // Requires same elemental type.
321   if (sourceType.getElementType() != resultType.getElementType())
322     return false;
323 
324   // Requires same rank.
325   if (sourceType.getRank() != resultType.getRank())
326     return false;
327 
328   // Only fold casts between strided memref forms.
329   int64_t sourceOffset, resultOffset;
330   SmallVector<int64_t, 4> sourceStrides, resultStrides;
331   if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
332       failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
333     return false;
334 
335   // If cast is towards more static sizes along any dimension, don't fold.
336   for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
337     auto ss = std::get<0>(it), st = std::get<1>(it);
338     if (ss != st)
339       if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
340         return false;
341   }
342 
  // If the cast is towards a more static offset, don't fold.
344   if (sourceOffset != resultOffset)
345     if (ShapedType::isDynamicStrideOrOffset(sourceOffset) &&
346         !ShapedType::isDynamicStrideOrOffset(resultOffset))
347       return false;
348 
349   // If cast is towards more static strides along any dimension, don't fold.
350   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
351     auto ss = std::get<0>(it), st = std::get<1>(it);
352     if (ss != st)
353       if (ShapedType::isDynamicStrideOrOffset(ss) &&
354           !ShapedType::isDynamicStrideOrOffset(st))
355         return false;
356   }
357 
358   return true;
359 }
360 
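/// Returns true if a memref.cast may connect the two given types. For example
/// (illustrative):
///
/// ```mlir
///   %0 = memref.cast %a : memref<4x?xf32> to memref<?x?xf32>  // compatible
///   %1 = memref.cast %a : memref<4x?xf32> to memref<*xf32>    // compatible
/// ```
///
/// whereas casting memref<4xf32> to memref<4xi32> is not allowed because the
/// element types differ.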
361 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
362   if (inputs.size() != 1 || outputs.size() != 1)
363     return false;
364   Type a = inputs.front(), b = outputs.front();
365   auto aT = a.dyn_cast<MemRefType>();
366   auto bT = b.dyn_cast<MemRefType>();
367 
368   auto uaT = a.dyn_cast<UnrankedMemRefType>();
369   auto ubT = b.dyn_cast<UnrankedMemRefType>();
370 
371   if (aT && bT) {
372     if (aT.getElementType() != bT.getElementType())
373       return false;
374     if (aT.getLayout() != bT.getLayout()) {
375       int64_t aOffset, bOffset;
376       SmallVector<int64_t, 4> aStrides, bStrides;
377       if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
378           failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
379           aStrides.size() != bStrides.size())
380         return false;
381 
      // A stride or offset in the source and target memrefs is compatible if
      // the two values are equal, or if at least one of them is dynamic (see
      // the description of CastOp for details).
386       auto checkCompatible = [](int64_t a, int64_t b) {
387         return (a == MemRefType::getDynamicStrideOrOffset() ||
388                 b == MemRefType::getDynamicStrideOrOffset() || a == b);
389       };
390       if (!checkCompatible(aOffset, bOffset))
391         return false;
392       for (const auto &aStride : enumerate(aStrides))
393         if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
394           return false;
395     }
396     if (aT.getMemorySpace() != bT.getMemorySpace())
397       return false;
398 
399     // They must have the same rank, and any specified dimensions must match.
400     if (aT.getRank() != bT.getRank())
401       return false;
402 
403     for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
404       int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
405       if (aDim != -1 && bDim != -1 && aDim != bDim)
406         return false;
407     }
408     return true;
409   } else {
410     if (!aT && !uaT)
411       return false;
412     if (!bT && !ubT)
413       return false;
414     // Unranked to unranked casting is unsupported
415     if (uaT && ubT)
416       return false;
417 
418     auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
419     auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
420     if (aEltType != bEltType)
421       return false;
422 
423     auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
424     auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
425     return aMemSpace == bMemSpace;
426   }
427 
428   return false;
429 }
430 
431 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
432   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
433 }
434 
435 //===----------------------------------------------------------------------===//
436 // CopyOp
437 //===----------------------------------------------------------------------===//
438 
439 namespace {
440 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
441 /// and element type, the cast can be skipped. Such CastOps only cast the layout
442 /// of the type.
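/// A minimal sketch (the layout map and names are illustrative):
///
/// ```mlir
///   %0 = memref.cast %arg : memref<42xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
///          to memref<42xf32>
///   memref.copy %0, %dst : memref<42xf32> to memref<42xf32>
/// ```
///
/// can read directly from %arg, since the cast only changes the layout.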
443 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
444   using OpRewritePattern<CopyOp>::OpRewritePattern;
445 
446   LogicalResult matchAndRewrite(CopyOp copyOp,
447                                 PatternRewriter &rewriter) const override {
448     bool modified = false;
449 
450     // Check source.
451     if (auto castOp = copyOp.source().getDefiningOp<CastOp>()) {
      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
      auto toType = castOp.getType().dyn_cast<MemRefType>();
454 
455       if (fromType && toType) {
456         if (fromType.getShape() == toType.getShape() &&
457             fromType.getElementType() == toType.getElementType()) {
458           rewriter.updateRootInPlace(
459               copyOp, [&] { copyOp.sourceMutable().assign(castOp.source()); });
460           modified = true;
461         }
462       }
463     }
464 
465     // Check target.
466     if (auto castOp = copyOp.target().getDefiningOp<CastOp>()) {
      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
      auto toType = castOp.getType().dyn_cast<MemRefType>();
469 
470       if (fromType && toType) {
471         if (fromType.getShape() == toType.getShape() &&
472             fromType.getElementType() == toType.getElementType()) {
473           rewriter.updateRootInPlace(
474               copyOp, [&] { copyOp.targetMutable().assign(castOp.source()); });
475           modified = true;
476         }
477       }
478     }
479 
480     return success(modified);
481   }
482 };
483 
484 /// Fold memref.copy(%x, %x).
485 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
486   using OpRewritePattern<CopyOp>::OpRewritePattern;
487 
488   LogicalResult matchAndRewrite(CopyOp copyOp,
489                                 PatternRewriter &rewriter) const override {
490     if (copyOp.source() != copyOp.target())
491       return failure();
492 
493     rewriter.eraseOp(copyOp);
494     return success();
495   }
496 };
497 } // namespace
498 
499 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
500                                          MLIRContext *context) {
501   results.add<FoldCopyOfCast, FoldSelfCopy>(context);
502 }
503 
504 LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands,
505                            SmallVectorImpl<OpFoldResult> &results) {
506   /// copy(memrefcast) -> copy
507   bool folded = false;
508   Operation *op = *this;
509   for (OpOperand &operand : op->getOpOperands()) {
510     auto castOp = operand.get().getDefiningOp<memref::CastOp>();
511     if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
512       operand.set(castOp.getOperand());
513       folded = true;
514     }
515   }
516   return success(folded);
517 }
518 
519 //===----------------------------------------------------------------------===//
520 // DeallocOp
521 //===----------------------------------------------------------------------===//
522 
523 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
524                               SmallVectorImpl<OpFoldResult> &results) {
525   /// dealloc(memrefcast) -> dealloc
526   return foldMemRefCast(*this);
527 }
528 
529 //===----------------------------------------------------------------------===//
530 // DimOp
531 //===----------------------------------------------------------------------===//
532 
533 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
534                   int64_t index) {
535   auto loc = result.location;
536   Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
537   build(builder, result, source, indexValue);
538 }
539 
540 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
541                   Value index) {
542   auto indexTy = builder.getIndexType();
543   build(builder, result, indexTy, source, index);
544 }
545 
546 Optional<int64_t> DimOp::getConstantIndex() {
547   if (auto constantOp = index().getDefiningOp<arith::ConstantOp>())
548     return constantOp.getValue().cast<IntegerAttr>().getInt();
549   return {};
550 }
551 
552 LogicalResult DimOp::verify() {
553   // Assume unknown index to be in range.
554   Optional<int64_t> index = getConstantIndex();
555   if (!index.hasValue())
556     return success();
557 
  // Check that the constant index is not known to be out of range.
559   auto type = source().getType();
560   if (auto memrefType = type.dyn_cast<MemRefType>()) {
561     if (index.getValue() >= memrefType.getRank())
562       return emitOpError("index is out of range");
563   } else if (type.isa<UnrankedMemRefType>()) {
564     // Assume index to be in range.
565   } else {
566     llvm_unreachable("expected operand with memref type");
567   }
568   return success();
569 }
570 
/// Return a map from each element in `vals` to its number of occurrences.
/// Use std::map, since the `vals` here are strides and the dynamic stride
/// value is the same as the tombstone value for `DenseMap<int64_t>`.
575 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
576   std::map<int64_t, unsigned> numOccurences;
577   for (auto val : vals)
578     numOccurences[val]++;
579   return numOccurences;
580 }
581 
/// Given an `originalType` and a `reducedType` whose shape is assumed to be a
/// subset of `originalType` with some `1` entries erased, return the set of
/// indices that specifies which of the entries of the original shape are
/// dropped to obtain the reduced shape.
/// This accounts for cases where there are multiple unit-dims, but only a
/// subset of those are dropped. For MemRefTypes these can be disambiguated
/// using the strides: if a dimension is dropped, its stride must be dropped
/// too.
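/// For example (an illustrative rank-reducing subview), only the middle
/// dimension survives below, so the returned mask has bits 0 and 2 set:
///
/// ```mlir
///   %0 = memref.subview %arg[0, 0, 0] [1, 16, 1] [1, 1, 1]
///        : memref<1x16x1xf32> to memref<16xf32>
/// ```
///
/// When several unit dims share the same size, the strides decide which of
/// them were actually dropped.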
589 static llvm::Optional<llvm::SmallBitVector>
590 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
591                                ArrayRef<OpFoldResult> sizes) {
592   llvm::SmallBitVector unusedDims(originalType.getRank());
593   if (originalType.getRank() == reducedType.getRank())
594     return unusedDims;
595 
596   for (const auto &dim : llvm::enumerate(sizes))
597     if (auto attr = dim.value().dyn_cast<Attribute>())
598       if (attr.cast<IntegerAttr>().getInt() == 1)
599         unusedDims.set(dim.index());
600 
601   SmallVector<int64_t> originalStrides, candidateStrides;
602   int64_t originalOffset, candidateOffset;
603   if (failed(
604           getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
605       failed(
606           getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
607     return llvm::None;
608 
  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. This is particularly important when more than one of the
  // dims is 1. Track the number of occurrences of the strides in the original
  // type and the candidate type. For each unused dim, that stride should not
  // be present in the candidate type. Note that there could be multiple
  // dimensions that have the same size. We don't need to figure out exactly
  // which dim corresponds to which stride; we just need to verify that the
  // number of repetitions of a stride in the original + the number of unused
  // dims with that stride == the number of repetitions of that stride in the
  // candidate.
618   std::map<int64_t, unsigned> currUnaccountedStrides =
619       getNumOccurences(originalStrides);
620   std::map<int64_t, unsigned> candidateStridesNumOccurences =
621       getNumOccurences(candidateStrides);
622   for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
623     if (!unusedDims.test(dim))
624       continue;
625     int64_t originalStride = originalStrides[dim];
626     if (currUnaccountedStrides[originalStride] >
627         candidateStridesNumOccurences[originalStride]) {
628       // This dim can be treated as dropped.
629       currUnaccountedStrides[originalStride]--;
630       continue;
631     }
632     if (currUnaccountedStrides[originalStride] ==
633         candidateStridesNumOccurences[originalStride]) {
      // The stride for this dimension is not dropped; keep the dimension.
635       unusedDims.reset(dim);
636       continue;
637     }
638     if (currUnaccountedStrides[originalStride] <
639         candidateStridesNumOccurences[originalStride]) {
      // This should never happen. The reduced-rank type cannot have a stride
      // that was not present in the original type.
642       return llvm::None;
643     }
644   }
645 
646   if ((int64_t)unusedDims.count() + reducedType.getRank() !=
647       originalType.getRank())
648     return llvm::None;
649   return unusedDims;
650 }
651 
652 llvm::SmallBitVector SubViewOp::getDroppedDims() {
653   MemRefType sourceType = getSourceType();
654   MemRefType resultType = getType();
655   llvm::Optional<llvm::SmallBitVector> unusedDims =
656       computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
657   assert(unusedDims && "unable to find unused dims of subview");
658   return *unusedDims;
659 }
660 
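// Examples of dim folds handled below (illustrative):
//   memref.dim %m, %c1 : memref<4x8xf32>   folds to the constant 8;
//   memref.dim %a, %c0 with %a = memref.alloc(%n) : memref<?xf32> folds to %n.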
661 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
662   // All forms of folding require a known index.
663   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
664   if (!index)
665     return {};
666 
667   // Folding for unranked types (UnrankedMemRefType) is not supported.
668   auto memrefType = source().getType().dyn_cast<MemRefType>();
669   if (!memrefType)
670     return {};
671 
672   // Fold if the shape extent along the given index is known.
673   if (!memrefType.isDynamicDim(index.getInt())) {
674     Builder builder(getContext());
675     return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
676   }
677 
678   // The size at the given index is now known to be a dynamic size.
679   unsigned unsignedIndex = index.getValue().getZExtValue();
680 
681   // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
682   Operation *definingOp = source().getDefiningOp();
683 
684   if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
685     return *(alloc.getDynamicSizes().begin() +
686              memrefType.getDynamicDimIndex(unsignedIndex));
687 
688   if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
689     return *(alloca.getDynamicSizes().begin() +
690              memrefType.getDynamicDimIndex(unsignedIndex));
691 
692   if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
693     return *(view.getDynamicSizes().begin() +
694              memrefType.getDynamicDimIndex(unsignedIndex));
695 
696   if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
697     llvm::SmallBitVector unusedDims = subview.getDroppedDims();
698     unsigned resultIndex = 0;
699     unsigned sourceRank = subview.getSourceType().getRank();
700     unsigned sourceIndex = 0;
701     for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
702       if (unusedDims.test(i))
703         continue;
704       if (resultIndex == unsignedIndex) {
705         sourceIndex = i;
706         break;
707       }
708       resultIndex++;
709     }
710     assert(subview.isDynamicSize(sourceIndex) &&
711            "expected dynamic subview size");
712     return subview.getDynamicSize(sourceIndex);
713   }
714 
715   if (auto sizeInterface =
716           dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
717     assert(sizeInterface.isDynamicSize(unsignedIndex) &&
718            "Expected dynamic subview size");
719     return sizeInterface.getDynamicSize(unsignedIndex);
720   }
721 
722   // dim(memrefcast) -> dim
723   if (succeeded(foldMemRefCast(*this)))
724     return getResult();
725 
726   return {};
727 }
728 
729 namespace {
/// Fold dim of a memref.reshape operation to a load from the reshape's shape
/// operand.
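/// A sketch of the rewrite (names are illustrative):
///
/// ```mlir
///   %r = memref.reshape %src(%shape)
///        : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
///   %d = memref.dim %r, %c1 : memref<?x?xf32>
/// ```
///
/// becomes a memref.load of %shape[%c1], plus an arith.index_cast when the
/// shape element type is not index.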
732 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
733   using OpRewritePattern<DimOp>::OpRewritePattern;
734 
735   LogicalResult matchAndRewrite(DimOp dim,
736                                 PatternRewriter &rewriter) const override {
737     auto reshape = dim.source().getDefiningOp<ReshapeOp>();
738 
739     if (!reshape)
740       return failure();
741 
    // Place the load directly after the reshape to ensure that the shape
    // memref is read before it is potentially mutated by later operations.
744     rewriter.setInsertionPointAfter(reshape);
745     Location loc = dim.getLoc();
746     Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
747     if (load.getType() != dim.getType())
748       load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
749     rewriter.replaceOp(dim, load);
750     return success();
751   }
752 };
753 
754 } // namespace
755 
756 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
757                                         MLIRContext *context) {
758   results.add<DimOfMemRefReshape>(context);
759 }
760 
761 // ---------------------------------------------------------------------------
762 // DmaStartOp
763 // ---------------------------------------------------------------------------
764 
765 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
766                        Value srcMemRef, ValueRange srcIndices, Value destMemRef,
767                        ValueRange destIndices, Value numElements,
768                        Value tagMemRef, ValueRange tagIndices, Value stride,
769                        Value elementsPerStride) {
770   result.addOperands(srcMemRef);
771   result.addOperands(srcIndices);
772   result.addOperands(destMemRef);
773   result.addOperands(destIndices);
774   result.addOperands({numElements, tagMemRef});
775   result.addOperands(tagIndices);
776   if (stride)
777     result.addOperands({stride, elementsPerStride});
778 }
779 
780 void DmaStartOp::print(OpAsmPrinter &p) {
781   p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
782     << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
783     << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
784   if (isStrided())
785     p << ", " << getStride() << ", " << getNumElementsPerStride();
786 
787   p.printOptionalAttrDict((*this)->getAttrs());
788   p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
789     << ", " << getTagMemRef().getType();
790 }
791 
792 // Parse DmaStartOp.
793 // Ex:
//   memref.dma_start %src[%i, %j], %dst[%k, %l], %size,
//                    %tag[%index], %stride, %num_elt_per_stride
//                  : memref<3076 x f32, 0>,
//                    memref<1024 x f32, 2>,
//                    memref<1 x i32>
799 //
800 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
801   OpAsmParser::OperandType srcMemRefInfo;
802   SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
803   OpAsmParser::OperandType dstMemRefInfo;
804   SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
805   OpAsmParser::OperandType numElementsInfo;
806   OpAsmParser::OperandType tagMemrefInfo;
807   SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
808   SmallVector<OpAsmParser::OperandType, 2> strideInfo;
809 
810   SmallVector<Type, 3> types;
811   auto indexType = parser.getBuilder().getIndexType();
812 
813   // Parse and resolve the following list of operands:
814   // *) source memref followed by its indices (in square brackets).
815   // *) destination memref followed by its indices (in square brackets).
  // *) number of elements to transfer.
817   if (parser.parseOperand(srcMemRefInfo) ||
818       parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
819       parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
820       parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
821       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
822       parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
823       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
824     return failure();
825 
826   // Parse optional stride and elements per stride.
827   if (parser.parseTrailingOperandList(strideInfo))
828     return failure();
829 
830   bool isStrided = strideInfo.size() == 2;
831   if (!strideInfo.empty() && !isStrided) {
832     return parser.emitError(parser.getNameLoc(),
833                             "expected two stride related operands");
834   }
835 
836   if (parser.parseColonTypeList(types))
837     return failure();
838   if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "expected 3 types");
840 
841   if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
842       parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
843       parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
844       parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
845       // size should be an index.
846       parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
847       parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
848       // tag indices should be index.
849       parser.resolveOperands(tagIndexInfos, indexType, result.operands))
850     return failure();
851 
852   if (isStrided) {
853     if (parser.resolveOperands(strideInfo, indexType, result.operands))
854       return failure();
855   }
856 
857   return success();
858 }
859 
860 LogicalResult DmaStartOp::verify() {
861   unsigned numOperands = getNumOperands();
862 
863   // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
864   // the number of elements.
865   if (numOperands < 4)
866     return emitOpError("expected at least 4 operands");
867 
868   // Check types of operands. The order of these calls is important: the later
869   // calls rely on some type properties to compute the operand position.
870   // 1. Source memref.
871   if (!getSrcMemRef().getType().isa<MemRefType>())
872     return emitOpError("expected source to be of memref type");
873   if (numOperands < getSrcMemRefRank() + 4)
874     return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
875                          << " operands";
876   if (!getSrcIndices().empty() &&
877       !llvm::all_of(getSrcIndices().getTypes(),
878                     [](Type t) { return t.isIndex(); }))
879     return emitOpError("expected source indices to be of index type");
880 
881   // 2. Destination memref.
882   if (!getDstMemRef().getType().isa<MemRefType>())
883     return emitOpError("expected destination to be of memref type");
884   unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
885   if (numOperands < numExpectedOperands)
886     return emitOpError() << "expected at least " << numExpectedOperands
887                          << " operands";
888   if (!getDstIndices().empty() &&
889       !llvm::all_of(getDstIndices().getTypes(),
890                     [](Type t) { return t.isIndex(); }))
891     return emitOpError("expected destination indices to be of index type");
892 
893   // 3. Number of elements.
894   if (!getNumElements().getType().isIndex())
895     return emitOpError("expected num elements to be of index type");
896 
897   // 4. Tag memref.
898   if (!getTagMemRef().getType().isa<MemRefType>())
899     return emitOpError("expected tag to be of memref type");
900   numExpectedOperands += getTagMemRefRank();
901   if (numOperands < numExpectedOperands)
902     return emitOpError() << "expected at least " << numExpectedOperands
903                          << " operands";
904   if (!getTagIndices().empty() &&
905       !llvm::all_of(getTagIndices().getTypes(),
906                     [](Type t) { return t.isIndex(); }))
907     return emitOpError("expected tag indices to be of index type");
908 
909   // Optional stride-related operands must be either both present or both
910   // absent.
911   if (numOperands != numExpectedOperands &&
912       numOperands != numExpectedOperands + 2)
913     return emitOpError("incorrect number of operands");
914 
915   // 5. Strides.
916   if (isStrided()) {
917     if (!getStride().getType().isIndex() ||
918         !getNumElementsPerStride().getType().isIndex())
919       return emitOpError(
920           "expected stride and num elements per stride to be of type index");
921   }
922 
923   return success();
924 }
925 
926 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
927                                SmallVectorImpl<OpFoldResult> &results) {
928   /// dma_start(memrefcast) -> dma_start
929   return foldMemRefCast(*this);
930 }
931 
932 // ---------------------------------------------------------------------------
933 // DmaWaitOp
934 // ---------------------------------------------------------------------------
935 
936 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
937                               SmallVectorImpl<OpFoldResult> &results) {
938   /// dma_wait(memrefcast) -> dma_wait
939   return foldMemRefCast(*this);
940 }
941 
942 LogicalResult DmaWaitOp::verify() {
943   // Check that the number of tag indices matches the tagMemRef rank.
944   unsigned numTagIndices = tagIndices().size();
945   unsigned tagMemRefRank = getTagMemRefRank();
946   if (numTagIndices != tagMemRefRank)
947     return emitOpError() << "expected tagIndices to have the same number of "
948                             "elements as the tagMemRef rank, expected "
949                          << tagMemRefRank << ", but got " << numTagIndices;
950   return success();
951 }
952 
953 //===----------------------------------------------------------------------===//
954 // GenericAtomicRMWOp
955 //===----------------------------------------------------------------------===//
956 
957 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
958                                Value memref, ValueRange ivs) {
959   result.addOperands(memref);
960   result.addOperands(ivs);
961 
962   if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
963     Type elementType = memrefType.getElementType();
964     result.addTypes(elementType);
965 
966     Region *bodyRegion = result.addRegion();
967     bodyRegion->push_back(new Block());
968     bodyRegion->addArgument(elementType, memref.getLoc());
969   }
970 }
971 
972 LogicalResult GenericAtomicRMWOp::verify() {
973   auto &body = getRegion();
974   if (body.getNumArguments() != 1)
    return emitOpError("expected exactly one entry block argument");
976 
977   if (getResult().getType() != body.getArgument(0).getType())
    return emitOpError("expected block argument type to match result type");
979 
980   bool hasSideEffects =
981       body.walk([&](Operation *nestedOp) {
982             if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
983               return WalkResult::advance();
984             nestedOp->emitError(
985                 "body of 'memref.generic_atomic_rmw' should contain "
986                 "only operations with no side effects");
987             return WalkResult::interrupt();
988           })
989           .wasInterrupted();
990   return hasSideEffects ? failure() : success();
991 }
992 
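// The custom assembly parsed below has the following shape (illustrative):
//
//   %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
//   ^bb0(%current_value : f32):
//     %inc = arith.addf %cst, %current_value : f32
//     memref.atomic_yield %inc : f32
//   }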
993 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
994                                       OperationState &result) {
995   OpAsmParser::OperandType memref;
996   Type memrefType;
997   SmallVector<OpAsmParser::OperandType, 4> ivs;
998 
999   Type indexType = parser.getBuilder().getIndexType();
1000   if (parser.parseOperand(memref) ||
1001       parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1002       parser.parseColonType(memrefType) ||
1003       parser.resolveOperand(memref, memrefType, result.operands) ||
1004       parser.resolveOperands(ivs, indexType, result.operands))
1005     return failure();
1006 
1007   Region *body = result.addRegion();
1008   if (parser.parseRegion(*body, llvm::None, llvm::None) ||
1009       parser.parseOptionalAttrDict(result.attributes))
1010     return failure();
1011   result.types.push_back(memrefType.cast<MemRefType>().getElementType());
1012   return success();
1013 }
1014 
1015 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1016   p << ' ' << memref() << "[" << indices() << "] : " << memref().getType()
1017     << ' ';
1018   p.printRegion(getRegion());
1019   p.printOptionalAttrDict((*this)->getAttrs());
1020 }
1021 
1022 //===----------------------------------------------------------------------===//
1023 // AtomicYieldOp
1024 //===----------------------------------------------------------------------===//
1025 
1026 LogicalResult AtomicYieldOp::verify() {
1027   Type parentType = (*this)->getParentOp()->getResultTypes().front();
1028   Type resultType = result().getType();
1029   if (parentType != resultType)
1030     return emitOpError() << "types mismatch between yield op: " << resultType
1031                          << " and its parent: " << parentType;
1032   return success();
1033 }
1034 
1035 //===----------------------------------------------------------------------===//
1036 // GlobalOp
1037 //===----------------------------------------------------------------------===//
1038 
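// The type-and-initial-value portion printed/parsed below covers forms such as
// (illustrative):
//
//   memref.global "private" @x : memref<2xf32> = dense<[0.0, 2.0]>
//   memref.global @y : memref<4xi32> = uninitialized
//   memref.global "private" constant @c : memref<2xi32> = dense<[1, 4]>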
1039 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1040                                                    TypeAttr type,
1041                                                    Attribute initialValue) {
1042   p << type;
1043   if (!op.isExternal()) {
1044     p << " = ";
1045     if (op.isUninitialized())
1046       p << "uninitialized";
1047     else
1048       p.printAttributeWithoutType(initialValue);
1049   }
1050 }
1051 
1052 static ParseResult
1053 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1054                                        Attribute &initialValue) {
1055   Type type;
1056   if (parser.parseType(type))
1057     return failure();
1058 
1059   auto memrefType = type.dyn_cast<MemRefType>();
1060   if (!memrefType || !memrefType.hasStaticShape())
1061     return parser.emitError(parser.getNameLoc())
1062            << "type should be static shaped memref, but got " << type;
1063   typeAttr = TypeAttr::get(type);
1064 
1065   if (parser.parseOptionalEqual())
1066     return success();
1067 
1068   if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1069     initialValue = UnitAttr::get(parser.getContext());
1070     return success();
1071   }
1072 
1073   Type tensorType = getTensorTypeFromMemRefType(memrefType);
1074   if (parser.parseAttribute(initialValue, tensorType))
1075     return failure();
1076   if (!initialValue.isa<ElementsAttr>())
1077     return parser.emitError(parser.getNameLoc())
1078            << "initial value should be a unit or elements attribute";
1079   return success();
1080 }
1081 
1082 LogicalResult GlobalOp::verify() {
1083   auto memrefType = type().dyn_cast<MemRefType>();
1084   if (!memrefType || !memrefType.hasStaticShape())
1085     return emitOpError("type should be static shaped memref, but got ")
1086            << type();
1087 
1088   // Verify that the initial value, if present, is either a unit attribute or
1089   // an elements attribute.
1090   if (initial_value().hasValue()) {
1091     Attribute initValue = initial_value().getValue();
1092     if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
1093       return emitOpError("initial value should be a unit or elements "
1094                          "attribute, but got ")
1095              << initValue;
1096 
1097     // Check that the type of the initial value is compatible with the type of
1098     // the global variable.
1099     if (initValue.isa<ElementsAttr>()) {
1100       Type initType = initValue.getType();
1101       Type tensorType = getTensorTypeFromMemRefType(memrefType);
1102       if (initType != tensorType)
1103         return emitOpError("initial value expected to be of type ")
1104                << tensorType << ", but was of type " << initType;
1105     }
1106   }
1107 
1108   if (Optional<uint64_t> alignAttr = alignment()) {
1109     uint64_t alignment = alignAttr.getValue();
1110 
1111     if (!llvm::isPowerOf2_64(alignment))
1112       return emitError() << "alignment attribute value " << alignment
1113                          << " is not a power of 2";
1114   }
1115 
1116   // TODO: verify visibility for declarations.
1117   return success();
1118 }
1119 
1120 //===----------------------------------------------------------------------===//
1121 // GetGlobalOp
1122 //===----------------------------------------------------------------------===//
1123 
1124 LogicalResult
1125 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is the same as the type of the referenced
  // memref.global op.
1128   auto global =
1129       symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
1130   if (!global)
1131     return emitOpError("'")
1132            << name() << "' does not reference a valid global memref";
1133 
1134   Type resultType = result().getType();
1135   if (global.type() != resultType)
1136     return emitOpError("result type ")
1137            << resultType << " does not match type " << global.type()
1138            << " of the global memref @" << name();
1139   return success();
1140 }
1141 
1142 //===----------------------------------------------------------------------===//
1143 // LoadOp
1144 //===----------------------------------------------------------------------===//
1145 
1146 LogicalResult LoadOp::verify() {
1147   if (getNumOperands() != 1 + getMemRefType().getRank())
1148     return emitOpError("incorrect number of indices for load");
1149   return success();
1150 }
1151 
1152 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1153   /// load(memrefcast) -> load
1154   if (succeeded(foldMemRefCast(*this)))
1155     return getResult();
1156   return OpFoldResult();
1157 }
1158 
1159 //===----------------------------------------------------------------------===//
1160 // PrefetchOp
1161 //===----------------------------------------------------------------------===//
1162 
1163 void PrefetchOp::print(OpAsmPrinter &p) {
1164   p << " " << memref() << '[';
1165   p.printOperands(indices());
1166   p << ']' << ", " << (isWrite() ? "write" : "read");
1167   p << ", locality<" << localityHint();
1168   p << ">, " << (isDataCache() ? "data" : "instr");
1169   p.printOptionalAttrDict(
1170       (*this)->getAttrs(),
1171       /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1172   p << " : " << getMemRefType();
1173 }
1174 
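// The custom assembly parsed below looks like (illustrative):
//
//   memref.prefetch %0[%i, %j], write, locality<3>, data : memref<400x400xi32>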
1175 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1176   OpAsmParser::OperandType memrefInfo;
1177   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
1178   IntegerAttr localityHint;
1179   MemRefType type;
1180   StringRef readOrWrite, cacheType;
1181 
1182   auto indexTy = parser.getBuilder().getIndexType();
1183   auto i32Type = parser.getBuilder().getIntegerType(32);
1184   if (parser.parseOperand(memrefInfo) ||
1185       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1186       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1187       parser.parseComma() || parser.parseKeyword("locality") ||
1188       parser.parseLess() ||
1189       parser.parseAttribute(localityHint, i32Type, "localityHint",
1190                             result.attributes) ||
1191       parser.parseGreater() || parser.parseComma() ||
1192       parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1193       parser.resolveOperand(memrefInfo, type, result.operands) ||
1194       parser.resolveOperands(indexInfo, indexTy, result.operands))
1195     return failure();
1196 
1197   if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1198     return parser.emitError(parser.getNameLoc(),
1199                             "rw specifier has to be 'read' or 'write'");
1200   result.addAttribute(
1201       PrefetchOp::getIsWriteAttrName(),
1202       parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1203 
1204   if (!cacheType.equals("data") && !cacheType.equals("instr"))
1205     return parser.emitError(parser.getNameLoc(),
1206                             "cache type has to be 'data' or 'instr'");
1207 
1208   result.addAttribute(
1209       PrefetchOp::getIsDataCacheAttrName(),
1210       parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1211 
1212   return success();
1213 }
1214 
1215 LogicalResult PrefetchOp::verify() {
1216   if (getNumOperands() != 1 + getMemRefType().getRank())
1217     return emitOpError("too few indices");
1218 
1219   return success();
1220 }
1221 
1222 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
1223                                SmallVectorImpl<OpFoldResult> &results) {
1224   // prefetch(memrefcast) -> prefetch
1225   return foldMemRefCast(*this);
1226 }
1227 
1228 //===----------------------------------------------------------------------===//
1229 // RankOp
1230 //===----------------------------------------------------------------------===//
1231 
1232 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
1233   // Constant fold rank when the rank of the operand is known.
1234   auto type = getOperand().getType();
1235   auto shapedType = type.dyn_cast<ShapedType>();
1236   if (shapedType && shapedType.hasRank())
1237     return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1238   return IntegerAttr();
1239 }
1240 
1241 //===----------------------------------------------------------------------===//
1242 // ReinterpretCastOp
1243 //===----------------------------------------------------------------------===//
1244 
1245 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1246 /// `staticSizes` and `staticStrides` are automatically filled with
1247 /// source-memref-rank sentinel values that encode dynamic entries.
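/// For reference, the resulting op in its custom assembly looks like
/// (illustrative, shown here with static entries):
///
/// ```mlir
///   %0 = memref.reinterpret_cast %src to
///          offset: [0], sizes: [10, 10], strides: [10, 1]
///        : memref<?x?xf32> to memref<10x10xf32>
/// ```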
1248 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1249                               MemRefType resultType, Value source,
1250                               OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1251                               ArrayRef<OpFoldResult> strides,
1252                               ArrayRef<NamedAttribute> attrs) {
1253   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1254   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1255   dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
1256                              ShapedType::kDynamicStrideOrOffset);
1257   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1258                              ShapedType::kDynamicSize);
1259   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1260                              ShapedType::kDynamicStrideOrOffset);
1261   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1262         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1263         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1264   result.addAttributes(attrs);
1265 }
1266 
1267 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1268                               MemRefType resultType, Value source,
1269                               int64_t offset, ArrayRef<int64_t> sizes,
1270                               ArrayRef<int64_t> strides,
1271                               ArrayRef<NamedAttribute> attrs) {
1272   SmallVector<OpFoldResult> sizeValues =
1273       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1274         return b.getI64IntegerAttr(v);
1275       }));
1276   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1277       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1278         return b.getI64IntegerAttr(v);
1279       }));
1280   build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1281         strideValues, attrs);
1282 }
1283 
1284 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1285                               MemRefType resultType, Value source, Value offset,
1286                               ValueRange sizes, ValueRange strides,
1287                               ArrayRef<NamedAttribute> attrs) {
1288   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1289       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1290   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1291       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1292   build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1293 }
1294 
1295 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1296 // completed automatically, like we have for subview and extract_slice.
1297 LogicalResult ReinterpretCastOp::verify() {
1298   // The source and result memrefs should be in the same memory space.
1299   auto srcType = source().getType().cast<BaseMemRefType>();
1300   auto resultType = getType().cast<MemRefType>();
1301   if (srcType.getMemorySpace() != resultType.getMemorySpace())
1302     return emitError("different memory spaces specified for source type ")
1303            << srcType << " and result memref type " << resultType;
1304   if (srcType.getElementType() != resultType.getElementType())
1305     return emitError("different element types specified for source type ")
1306            << srcType << " and result memref type " << resultType;
1307 
1308   // Match sizes in result memref type and in static_sizes attribute.
1309   for (auto &en : llvm::enumerate(llvm::zip(
1310            resultType.getShape(), extractFromI64ArrayAttr(static_sizes())))) {
1311     int64_t resultSize = std::get<0>(en.value());
1312     int64_t expectedSize = std::get<1>(en.value());
1313     if (!ShapedType::isDynamic(resultSize) &&
1314         !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
1315       return emitError("expected result type with size = ")
1316              << expectedSize << " instead of " << resultSize
1317              << " in dim = " << en.index();
1318   }
1319 
  // Match offset and strides in static_offsets and static_strides attributes.
  // If the result memref type has no affine map specified, this will assume
  // an identity layout.
1323   int64_t resultOffset;
1324   SmallVector<int64_t, 4> resultStrides;
1325   if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1326     return emitError("expected result type to have strided layout but found ")
1327            << resultType;
1328 
1329   // Match offset in result memref type and in static_offsets attribute.
1330   int64_t expectedOffset = extractFromI64ArrayAttr(static_offsets()).front();
1331   if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
1332       !ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
1333       resultOffset != expectedOffset)
1334     return emitError("expected result type with offset = ")
1335            << expectedOffset << " instead of " << resultOffset;
1336 
1337   // Match strides in result memref type and in static_strides attribute.
1338   for (auto &en : llvm::enumerate(llvm::zip(
1339            resultStrides, extractFromI64ArrayAttr(static_strides())))) {
1340     int64_t resultStride = std::get<0>(en.value());
1341     int64_t expectedStride = std::get<1>(en.value());
1342     if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
1343         !ShapedType::isDynamicStrideOrOffset(expectedStride) &&
1344         resultStride != expectedStride)
1345       return emitError("expected result type with stride = ")
1346              << expectedStride << " instead of " << resultStride
1347              << " in dim = " << en.index();
1348   }
1349 
1350   return success();
1351 }
1352 
1353 //===----------------------------------------------------------------------===//
1354 // Reassociative reshape ops
1355 //===----------------------------------------------------------------------===//
1356 
1357 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1358   return getSymbolLessAffineMaps(getReassociationExprs());
1359 }
1360 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1361   return convertReassociationIndicesToExprs(getContext(),
1362                                             getReassociationIndices());
1363 }
1364 
1365 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1366   return getSymbolLessAffineMaps(getReassociationExprs());
1367 }
1368 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1369   return convertReassociationIndicesToExprs(getContext(),
1370                                             getReassociationIndices());
1371 }
1372 
1373 ParseResult ExpandShapeOp::parse(OpAsmParser &parser, OperationState &result) {
1374   return parseReshapeLikeOp(parser, result);
1375 }
1376 void ExpandShapeOp::print(OpAsmPrinter &p) {
1377   ::mlir::printReshapeOp<ExpandShapeOp>(p, *this);
1378 }
1379 
1380 ParseResult CollapseShapeOp::parse(OpAsmParser &parser,
1381                                    OperationState &result) {
1382   return parseReshapeLikeOp(parser, result);
1383 }
1384 void CollapseShapeOp::print(OpAsmPrinter &p) {
1385   ::mlir::printReshapeOp<CollapseShapeOp>(p, *this);
1386 }
1387 
1388 /// Detect whether memref dims [dim, dim + extent) can be reshaped without
1389 /// copies.
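/// For illustration (arbitrary values): with sizes = [4, 8, 2] and contiguous
/// strides = [16, 2, 1], the band [0, 3) is reshapable because 16 == 2 * 8 and
/// 2 == 1 * 2; with padded strides = [32, 2, 1] it is not, since 32 != 2 * 8.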
1390 static bool isReshapableDimBand(unsigned dim, unsigned extent,
1391                                 ArrayRef<int64_t> sizes,
1392                                 ArrayRef<AffineExpr> strides) {
1393   // Bands of extent one are trivially reshapable: nothing gets reshaped.
1394   if (extent == 1)
1395     return true;
1396   // Otherwise, the size of the first dimension needs to be known.
1397   if (ShapedType::isDynamic(sizes[dim]))
1398     return false;
1399   assert(sizes.size() == strides.size() && "mismatched ranks");
1400   // Use `idx + 1 < e` as the loop bound so that the `idx + 1` accesses on
1401   // `sizes` and `strides` below stay in bounds.
1402   for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) {
1403     // Only bands of static shapes are reshapable. This is due to the fact that
1404     // there is no relation between dynamic sizes and dynamic strides: we do not
1405     // have enough information to know whether a "-1" size corresponds to the
1406     // proper symbol in the AffineExpr of a stride.
1407     if (ShapedType::isDynamic(sizes[idx + 1]))
1408       return false;
1409     // TODO: Refine this by passing the proper nDims and nSymbols so we can
1410     // simplify on the fly and catch more reshapable cases.
1411     if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
1412       return false;
1413   }
1414   return true;
1415 }
1416 
1417 /// Compute the MemRefType obtained by applying the `reassociation` (which is
1418 /// expected to be valid) to `type`.
1419 /// If `type` is a contiguous MemRefType, this always produces a contiguous
1420 /// MemRefType.
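/// For illustration (an assumed example): collapsing a contiguous
/// memref<4x8x2xf32> with reassociation {[0, 1], [2]} produces
/// memref<32x2xf32>; a dynamic size within a group makes the corresponding
/// result size dynamic as well.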
1421 static MemRefType
1422 computeReshapeCollapsedType(MemRefType type,
1423                             ArrayRef<AffineMap> reassociation) {
1424   auto sizes = type.getShape();
1425   AffineExpr offset;
1426   SmallVector<AffineExpr, 4> strides;
1427   auto status = getStridesAndOffset(type, strides, offset);
1428   auto isIdentityLayout = type.getLayout().isIdentity();
1429   (void)status;
1430   assert(succeeded(status) && "expected strided memref");
1431 
1432   SmallVector<int64_t, 4> newSizes;
1433   newSizes.reserve(reassociation.size());
1434   SmallVector<AffineExpr, 4> newStrides;
1435   newStrides.reserve(reassociation.size());
1436 
1437   // Use the fact that reassociation is valid to simplify the logic: only use
1438   // each map's rank.
1439   assert(isReassociationValid(reassociation) && "invalid reassociation");
1440   unsigned currentDim = 0;
1441   for (AffineMap m : reassociation) {
1442     unsigned dim = m.getNumResults();
1443     int64_t size = 1;
1444     AffineExpr stride = strides[currentDim + dim - 1];
1445     if (isIdentityLayout ||
1446         isReshapableDimBand(currentDim, dim, sizes, strides)) {
1447       for (unsigned d = 0; d < dim; ++d) {
1448         int64_t currentSize = sizes[currentDim + d];
1449         if (ShapedType::isDynamic(currentSize)) {
1450           size = ShapedType::kDynamicSize;
1451           break;
1452         }
1453         size *= currentSize;
1454       }
1455     } else {
1456       size = ShapedType::kDynamicSize;
1457       stride = AffineExpr();
1458     }
1459     newSizes.push_back(size);
1460     newStrides.push_back(stride);
1461     currentDim += dim;
1462   }
1463 
1464   // Early-exit: if `type` is contiguous, the result must be contiguous.
1465   if (canonicalizeStridedLayout(type).getLayout().isIdentity())
1466     return MemRefType::Builder(type).setShape(newSizes).setLayout({});
1467 
1468   // Convert back to int64_t because we don't have enough information to create
1469   // new strided layouts from AffineExpr only. This corresponds to a case where
1470   // copies may be necessary.
1471   int64_t intOffset = ShapedType::kDynamicStrideOrOffset;
1472   if (auto o = offset.dyn_cast<AffineConstantExpr>())
1473     intOffset = o.getValue();
1474   SmallVector<int64_t, 4> intStrides;
1475   intStrides.reserve(strides.size());
1476   for (auto stride : newStrides) {
1477     if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>())
1478       intStrides.push_back(cst.getValue());
1479     else
1480       intStrides.push_back(ShapedType::kDynamicStrideOrOffset);
1481   }
1482   auto layout =
1483       makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
1484   return canonicalizeStridedLayout(
1485       MemRefType::Builder(type).setShape(newSizes).setLayout(
1486           AffineMapAttr::get(layout)));
1487 }
1488 
1489 void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1490                           ArrayRef<ReassociationIndices> reassociation,
1491                           ArrayRef<NamedAttribute> attrs) {
1492   auto memRefType = src.getType().cast<MemRefType>();
1493   auto resultType = computeReshapeCollapsedType(
1494       memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1495                       b.getContext(), reassociation)));
1496   build(b, result, resultType, src, attrs);
1497   result.addAttribute(getReassociationAttrName(),
1498                       getReassociationIndicesAttribute(b, reassociation));
1499 }
1500 
1501 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1502                             ArrayRef<ReassociationIndices> reassociation,
1503                             ArrayRef<NamedAttribute> attrs) {
1504   auto memRefType = src.getType().cast<MemRefType>();
1505   auto resultType = computeReshapeCollapsedType(
1506       memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1507                       b.getContext(), reassociation)));
1508   build(b, result, resultType, src, attrs);
1509   result.addAttribute(getReassociationAttrName(),
1510                       getReassociationIndicesAttribute(b, reassociation));
1511 }
1512 
1513 template <typename ReshapeOp,
1514           bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value>
1515 static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
1516                                      MemRefType collapsedType) {
1517   if (failed(
1518           verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
1519     return failure();
1520   auto maps = op.getReassociationMaps();
1521   MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
1522   if (collapsedType != expectedType)
1523     return op.emitOpError("expected collapsed type to be ")
1524            << expectedType << ", but got " << collapsedType;
1525   return success();
1526 }
1527 
1528 LogicalResult ExpandShapeOp::verify() {
1529   return verifyReshapeOp(*this, getResultType(), getSrcType());
1530 }
1531 
1532 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1533                                                 MLIRContext *context) {
1534   results.add<CollapseReshapeOps<ExpandShapeOp>,
1535               CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
1536 }
1537 
1538 LogicalResult CollapseShapeOp::verify() {
1539   return verifyReshapeOp(*this, getSrcType(), getResultType());
1540 }
1541 
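/// Fold a collapse_shape of a memref.cast when the cast merely erases static
/// information (i.e. `CastOp::canFoldIntoConsumerOp` holds). Illustrative
/// example, with types chosen for exposition:
/// ```
///   %0 = memref.cast %arg : memref<4x8xf32> to memref<?x?xf32>
///   %1 = memref.collapse_shape %0 [[0, 1]]
///     : memref<?x?xf32> into memref<?xf32>
/// ```
/// is rewritten into:
/// ```
///   %0 = memref.collapse_shape %arg [[0, 1]]
///     : memref<4x8xf32> into memref<32xf32>
///   %1 = memref.cast %0 : memref<32xf32> to memref<?xf32>
/// ```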
1542 struct CollapseShapeOpMemRefCastFolder
1543     : public OpRewritePattern<CollapseShapeOp> {
1544 public:
1545   using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
1546 
1547   LogicalResult matchAndRewrite(CollapseShapeOp op,
1548                                 PatternRewriter &rewriter) const override {
1549     auto cast = op.getOperand().getDefiningOp<CastOp>();
1550     if (!cast)
1551       return failure();
1552 
1553     if (!CastOp::canFoldIntoConsumerOp(cast))
1554       return failure();
1555 
1556     Type newResultType = computeReshapeCollapsedType(
1557         cast.getOperand().getType().cast<MemRefType>(),
1558         op.getReassociationMaps());
1559 
1560     if (newResultType == op.getResultType()) {
1561       rewriter.updateRootInPlace(
1562           op, [&]() { op.srcMutable().assign(cast.source()); });
1563     } else {
1564       Value newOp = rewriter.create<CollapseShapeOp>(
1565           op->getLoc(), cast.source(), op.getReassociationIndices());
1566       rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
1567     }
1568     return success();
1569   }
1570 };
1571 
1572 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1573                                                   MLIRContext *context) {
1574   results.add<CollapseReshapeOps<CollapseShapeOp>,
1575               CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
1576               CollapseShapeOpMemRefCastFolder>(context);
1577 }
1578 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
1579   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
1580 }
1581 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
1582   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
1583 }
1584 
1585 //===----------------------------------------------------------------------===//
1586 // ReshapeOp
1587 //===----------------------------------------------------------------------===//
1588 
1589 LogicalResult ReshapeOp::verify() {
1590   Type operandType = source().getType();
1591   Type resultType = result().getType();
1592 
1593   Type operandElementType = operandType.cast<ShapedType>().getElementType();
1594   Type resultElementType = resultType.cast<ShapedType>().getElementType();
1595   if (operandElementType != resultElementType)
1596     return emitOpError("element types of source and destination memref "
1597                        "types should be the same");
1598 
1599   if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
1600     if (!operandMemRefType.getLayout().isIdentity())
1601       return emitOpError("source memref type should have identity affine map");
1602 
1603   int64_t shapeSize = shape().getType().cast<MemRefType>().getDimSize(0);
1604   auto resultMemRefType = resultType.dyn_cast<MemRefType>();
1605   if (resultMemRefType) {
1606     if (!resultMemRefType.getLayout().isIdentity())
1607       return emitOpError("result memref type should have identity affine map");
1608     if (shapeSize == ShapedType::kDynamicSize)
1609       return emitOpError("cannot use shape operand with dynamic length to "
1610                          "reshape to statically-ranked memref type");
1611     if (shapeSize != resultMemRefType.getRank())
1612       return emitOpError(
1613           "length of shape operand differs from the result's memref rank");
1614   }
1615   return success();
1616 }
1617 
1618 //===----------------------------------------------------------------------===//
1619 // StoreOp
1620 //===----------------------------------------------------------------------===//
1621 
1622 LogicalResult StoreOp::verify() {
1623   if (getNumOperands() != 2 + getMemRefType().getRank())
1624     return emitOpError("store index operand count not equal to memref rank");
1625 
1626   return success();
1627 }
1628 
1629 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
1630                             SmallVectorImpl<OpFoldResult> &results) {
1631   /// store(memrefcast) -> store
1632   return foldMemRefCast(*this, getValueToStore());
1633 }
1634 
1635 //===----------------------------------------------------------------------===//
1636 // SubViewOp
1637 //===----------------------------------------------------------------------===//
1638 
1639 namespace {
1640 /// Helpers to write more idiomatic operations.
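/// For illustration: `Wrapper(a) + Wrapper(b) * c` computes `a + b * c` when
/// all operands are static, and saturates to
/// ShapedType::kDynamicStrideOrOffset as soon as any operand equals that
/// dynamic sentinel.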
1641 namespace saturated_arith {
1642 struct Wrapper {
1643   explicit Wrapper(int64_t v) : v(v) {}
1644   operator int64_t() { return v; }
1645   int64_t v;
1646 };
1647 Wrapper operator+(Wrapper a, int64_t b) {
1648   if (ShapedType::isDynamicStrideOrOffset(a) ||
1649       ShapedType::isDynamicStrideOrOffset(b))
1650     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1651   return Wrapper(a.v + b);
1652 }
1653 Wrapper operator*(Wrapper a, int64_t b) {
1654   if (ShapedType::isDynamicStrideOrOffset(a) ||
1655       ShapedType::isDynamicStrideOrOffset(b))
1656     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1657   return Wrapper(a.v * b);
1658 }
1659 } // namespace saturated_arith
1660 } // namespace
1661 
1662 /// A subview result type can be fully inferred from the source type and the
1663 /// static representation of offsets, sizes and strides. Special sentinels
1664 /// encode the dynamic case.
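/// For illustration (an assumed example): for a contiguous memref<8x16xf32>
/// source (strides [16, 1], offset 0) with static offsets [2, 4], sizes [4, 4]
/// and strides [1, 2], the inferred type is
/// memref<4x4xf32, affine_map<(d0, d1) -> (d0 * 16 + d1 * 2 + 36)>>, i.e.
/// offset 2 * 16 + 4 * 1 = 36 and strides [16 * 1, 1 * 2] = [16, 2].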
1665 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1666                                 ArrayRef<int64_t> staticOffsets,
1667                                 ArrayRef<int64_t> staticSizes,
1668                                 ArrayRef<int64_t> staticStrides) {
1669   unsigned rank = sourceMemRefType.getRank();
1670   (void)rank;
1671   assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
1672   assert(staticSizes.size() == rank && "staticSizes length mismatch");
1673   assert(staticStrides.size() == rank && "staticStrides length mismatch");
1674 
1675   // Extract source offset and strides.
1676   int64_t sourceOffset;
1677   SmallVector<int64_t, 4> sourceStrides;
1678   auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
1679   assert(succeeded(res) && "SubViewOp expected strided memref type");
1680   (void)res;
1681 
1682   // Compute target offset whose value is:
1683   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
1684   int64_t targetOffset = sourceOffset;
1685   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
1686     auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
1687     using namespace saturated_arith;
1688     targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * sourceStride;
1689   }
1690 
1691   // Compute target stride whose value is:
1692   //   `sourceStrides_i * staticStrides_i`.
1693   SmallVector<int64_t, 4> targetStrides;
1694   targetStrides.reserve(staticOffsets.size());
1695   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
1696     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
1697     using namespace saturated_arith;
1698     targetStrides.push_back(Wrapper(sourceStride) * staticStride);
1699   }
1700 
1701   // The type is now known.
1702   return MemRefType::get(
1703       staticSizes, sourceMemRefType.getElementType(),
1704       makeStridedLinearLayoutMap(targetStrides, targetOffset,
1705                                  sourceMemRefType.getContext()),
1706       sourceMemRefType.getMemorySpace());
1707 }
1708 
1709 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1710                                 ArrayRef<OpFoldResult> offsets,
1711                                 ArrayRef<OpFoldResult> sizes,
1712                                 ArrayRef<OpFoldResult> strides) {
1713   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1714   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1715   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1716                              ShapedType::kDynamicStrideOrOffset);
1717   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1718                              ShapedType::kDynamicSize);
1719   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1720                              ShapedType::kDynamicStrideOrOffset);
1721   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1722                                     staticSizes, staticStrides);
1723 }
1724 
1725 Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
1726                                            MemRefType sourceRankedTensorType,
1727                                            ArrayRef<int64_t> offsets,
1728                                            ArrayRef<int64_t> sizes,
1729                                            ArrayRef<int64_t> strides) {
1730   auto inferredType =
1731       inferResultType(sourceRankedTensorType, offsets, sizes, strides)
1732           .cast<MemRefType>();
1733   assert(inferredType.getRank() >= resultRank && "expected ");
1734   int rankDiff = inferredType.getRank() - resultRank;
1735   if (rankDiff > 0) {
1736     auto shape = inferredType.getShape();
1737     llvm::SmallBitVector dimsToProject =
1738         getPositionsOfShapeOne(rankDiff, shape);
1739     SmallVector<int64_t> projectedShape;
1740     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1741       if (!dimsToProject.test(pos))
1742         projectedShape.push_back(shape[pos]);
1743 
1744     AffineMap map = inferredType.getLayout().getAffineMap();
1745     if (!map.isIdentity())
1746       map = getProjectedMap(map, dimsToProject);
1747     inferredType =
1748         MemRefType::get(projectedShape, inferredType.getElementType(), map,
1749                         inferredType.getMemorySpace());
1750   }
1751   return inferredType;
1752 }
1753 
1754 Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
1755                                            MemRefType sourceRankedTensorType,
1756                                            ArrayRef<OpFoldResult> offsets,
1757                                            ArrayRef<OpFoldResult> sizes,
1758                                            ArrayRef<OpFoldResult> strides) {
1759   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1760   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1761   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1762                              ShapedType::kDynamicStrideOrOffset);
1763   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1764                              ShapedType::kDynamicSize);
1765   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1766                              ShapedType::kDynamicStrideOrOffset);
1767   return SubViewOp::inferRankReducedResultType(
1768       resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1769       staticStrides);
1770 }
1771 // Build a SubViewOp with mixed static and dynamic entries and custom result
1772 // type. If the type passed is nullptr, it is inferred.
1773 void SubViewOp::build(OpBuilder &b, OperationState &result,
1774                       MemRefType resultType, Value source,
1775                       ArrayRef<OpFoldResult> offsets,
1776                       ArrayRef<OpFoldResult> sizes,
1777                       ArrayRef<OpFoldResult> strides,
1778                       ArrayRef<NamedAttribute> attrs) {
1779   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1780   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1781   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1782                              ShapedType::kDynamicStrideOrOffset);
1783   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1784                              ShapedType::kDynamicSize);
1785   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1786                              ShapedType::kDynamicStrideOrOffset);
1787   auto sourceMemRefType = source.getType().cast<MemRefType>();
1788   // Structuring implementation this way avoids duplication between builders.
1789   if (!resultType) {
1790     resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1791                                             staticSizes, staticStrides)
1792                      .cast<MemRefType>();
1793   }
1794   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1795         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1796         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1797   result.addAttributes(attrs);
1798 }
1799 
1800 // Build a SubViewOp with mixed static and dynamic entries and inferred result
1801 // type.
1802 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1803                       ArrayRef<OpFoldResult> offsets,
1804                       ArrayRef<OpFoldResult> sizes,
1805                       ArrayRef<OpFoldResult> strides,
1806                       ArrayRef<NamedAttribute> attrs) {
1807   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1808 }
1809 
1810 // Build a SubViewOp with static entries and inferred result type.
1811 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1812                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1813                       ArrayRef<int64_t> strides,
1814                       ArrayRef<NamedAttribute> attrs) {
1815   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1816       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1817         return b.getI64IntegerAttr(v);
1818       }));
1819   SmallVector<OpFoldResult> sizeValues =
1820       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1821         return b.getI64IntegerAttr(v);
1822       }));
1823   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1824       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1825         return b.getI64IntegerAttr(v);
1826       }));
1827   build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
1828 }
1829 
1830 // Build a SubViewOp with static entries and custom result type. If the
1831 // type passed is nullptr, it is inferred.
1832 void SubViewOp::build(OpBuilder &b, OperationState &result,
1833                       MemRefType resultType, Value source,
1834                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1835                       ArrayRef<int64_t> strides,
1836                       ArrayRef<NamedAttribute> attrs) {
1837   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1838       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1839         return b.getI64IntegerAttr(v);
1840       }));
1841   SmallVector<OpFoldResult> sizeValues =
1842       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1843         return b.getI64IntegerAttr(v);
1844       }));
1845   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1846       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1847         return b.getI64IntegerAttr(v);
1848       }));
1849   build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
1850         attrs);
1851 }
1852 
1853 // Build a SubViewOp with dynamic entries and custom result type. If the type
1854 // passed is nullptr, it is inferred.
1855 void SubViewOp::build(OpBuilder &b, OperationState &result,
1856                       MemRefType resultType, Value source, ValueRange offsets,
1857                       ValueRange sizes, ValueRange strides,
1858                       ArrayRef<NamedAttribute> attrs) {
1859   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1860       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1861   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1862       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1863   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1864       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1865   build(b, result, resultType, source, offsetValues, sizeValues, strideValues, attrs);
1866 }
1867 
1868 // Build a SubViewOp with dynamic entries and inferred result type.
1869 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1870                       ValueRange offsets, ValueRange sizes, ValueRange strides,
1871                       ArrayRef<NamedAttribute> attrs) {
1872   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1873 }
1874 
1875 /// For ViewLikeOpInterface.
1876 Value SubViewOp::getViewSource() { return source(); }
1877 
1878 /// Return true if t1 and t2 have equal offsets (both dynamic or of same static
1879 /// value).
1880 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
1881   AffineExpr t1Offset, t2Offset;
1882   SmallVector<AffineExpr> t1Strides, t2Strides;
1883   auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
1884   auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
1885   return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
1886 }
1887 
1888 /// Checks if `originalType` can be rank-reduced to `candidateRankReducedType`.
1889 /// This function is a slight variant of the `is subsequence` algorithm where
1890 /// non-matching dimensions must be 1.
1891 static SliceVerificationResult
1892 isRankReducedMemRefType(MemRefType originalType,
1893                         MemRefType candidateRankReducedType,
1894                         ArrayRef<OpFoldResult> sizes) {
1895   auto partialRes = isRankReducedType(originalType, candidateRankReducedType);
1896   if (partialRes != SliceVerificationResult::Success)
1897     return partialRes;
1898 
1899   auto optionalUnusedDimsMask = computeMemRefRankReductionMask(
1900       originalType, candidateRankReducedType, sizes);
1901 
1902   // The sizes could not be matched if no rank-reduction mask was computed.
1903   if (!optionalUnusedDimsMask.hasValue())
1904     return SliceVerificationResult::LayoutMismatch;
1905 
1906   if (originalType.getMemorySpace() !=
1907       candidateRankReducedType.getMemorySpace())
1908     return SliceVerificationResult::MemSpaceMismatch;
1909 
1910   // No amount of stride dropping can reconcile incompatible offsets.
1911   if (!haveCompatibleOffsets(originalType, candidateRankReducedType))
1912     return SliceVerificationResult::LayoutMismatch;
1913 
1914   return SliceVerificationResult::Success;
1915 }
1916 
1917 template <typename OpTy>
1918 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
1919                                             OpTy op, Type expectedType) {
1920   auto memrefType = expectedType.cast<ShapedType>();
1921   switch (result) {
1922   case SliceVerificationResult::Success:
1923     return success();
1924   case SliceVerificationResult::RankTooLarge:
1925     return op.emitError("expected result rank to be smaller or equal to ")
1926            << "the source rank. ";
1927   case SliceVerificationResult::SizeMismatch:
1928     return op.emitError("expected result type to be ")
1929            << expectedType
1930            << " or a rank-reduced version. (mismatch of result sizes) ";
1931   case SliceVerificationResult::ElemTypeMismatch:
1932     return op.emitError("expected result element type to be ")
1933            << memrefType.getElementType();
1934   case SliceVerificationResult::MemSpaceMismatch:
1935     return op.emitError("expected result and source memory spaces to match.");
1936   case SliceVerificationResult::LayoutMismatch:
1937     return op.emitError("expected result type to be ")
1938            << expectedType
1939            << " or a rank-reduced version. (mismatch of result layout) ";
1940   }
1941   llvm_unreachable("unexpected subview verification result");
1942 }
1943 
1944 /// Verifier for SubViewOp.
1945 LogicalResult SubViewOp::verify() {
1946   MemRefType baseType = getSourceType();
1947   MemRefType subViewType = getType();
1948 
1949   // The base memref and the view memref should be in the same memory space.
1950   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
1951     return emitError("different memory spaces specified for base memref "
1952                      "type ")
1953            << baseType << " and subview memref type " << subViewType;
1954 
1955   // Verify that the base memref type has a strided layout map.
1956   if (!isStrided(baseType))
1957     return emitError("base type ") << baseType << " is not strided";
1958 
1959   // Verify result type against inferred type.
1960   auto expectedType = SubViewOp::inferResultType(
1961       baseType, extractFromI64ArrayAttr(static_offsets()),
1962       extractFromI64ArrayAttr(static_sizes()),
1963       extractFromI64ArrayAttr(static_strides()));
1964 
1965   auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
1966                                         subViewType, getMixedSizes());
1967   return produceSubViewErrorMsg(result, *this, expectedType);
1968 }
1969 
1970 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
1971   return os << "range " << range.offset << ":" << range.size << ":"
1972             << range.stride;
1973 }
1974 
1975 /// Return the list of Range (i.e. offset, size, stride). Each Range
1976 /// entry contains either the dynamic value or a ConstantIndexOp constructed
1977 /// with `b` at location `loc`.
1978 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
1979                                               OpBuilder &b, Location loc) {
1980   std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
1981   assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
1982   assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
1983   SmallVector<Range, 8> res;
1984   unsigned rank = ranks[0];
1985   res.reserve(rank);
1986   for (unsigned idx = 0; idx < rank; ++idx) {
1987     Value offset =
1988         op.isDynamicOffset(idx)
1989             ? op.getDynamicOffset(idx)
1990             : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
1991     Value size =
1992         op.isDynamicSize(idx)
1993             ? op.getDynamicSize(idx)
1994             : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
1995     Value stride =
1996         op.isDynamicStride(idx)
1997             ? op.getDynamicStride(idx)
1998             : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
1999     res.emplace_back(Range{offset, size, stride});
2000   }
2001   return res;
2002 }
2003 
2004 /// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
2005 /// deduce the result type for the given `sourceType`. Additionally, reduce the
2006 /// rank of the inferred result type if `currentResultType` is lower rank than
2007 /// `currentSourceType`. Use this signature if `sourceType` is updated together
2008 /// with the result type. In this case, it is important to compute the dropped
2009 /// dimensions using `currentSourceType` whose strides align with
2010 /// `currentResultType`.
2011 static MemRefType getCanonicalSubViewResultType(
2012     MemRefType currentResultType, MemRefType currentSourceType,
2013     MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
2014     ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
2015   auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
2016                                                        mixedSizes, mixedStrides)
2017                                 .cast<MemRefType>();
2018   llvm::Optional<llvm::SmallBitVector> unusedDims =
2019       computeMemRefRankReductionMask(currentSourceType, currentResultType,
2020                                      mixedSizes);
2021   // Return nullptr as failure mode.
2022   if (!unusedDims)
2023     return nullptr;
2024   SmallVector<int64_t> shape;
2025   for (const auto &sizes : llvm::enumerate(nonRankReducedType.getShape())) {
2026     if (unusedDims->test(sizes.index()))
2027       continue;
2028     shape.push_back(sizes.value());
2029   }
2030   AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap();
2031   if (!layoutMap.isIdentity())
2032     layoutMap = getProjectedMap(layoutMap, unusedDims.getValue());
2033   return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap,
2034                          nonRankReducedType.getMemorySpace());
2035 }
2036 
2037 /// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
2038 /// deduce the result type. Additionally, reduce the rank of the inferred result
2039 /// type if `currentResultType` is lower rank than `sourceType`.
2040 static MemRefType getCanonicalSubViewResultType(
2041     MemRefType currentResultType, MemRefType sourceType,
2042     ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
2043     ArrayRef<OpFoldResult> mixedStrides) {
2044   return getCanonicalSubViewResultType(currentResultType, sourceType,
2045                                        sourceType, mixedOffsets, mixedSizes,
2046                                        mixedStrides);
2047 }
2048 
2049 /// Helper method to check if a `subview` operation is trivially a no-op. This
2050 /// is the case if all offsets are zero, all strides are one, and the source
2051 /// shape is the same as the sizes of the subview. In such cases, the subview
2052 /// can be folded into its source.
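/// For illustration, the following subview is trivial (zero offsets, unit
/// strides, sizes equal to the static source shape) and folds away:
/// ```
///   %1 = memref.subview %0[0, 0][4, 8][1, 1]
///     : memref<4x8xf32> to memref<4x8xf32>
/// ```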
2053 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
2054   if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
2055     return false;
2056 
2057   auto mixedOffsets = subViewOp.getMixedOffsets();
2058   auto mixedSizes = subViewOp.getMixedSizes();
2059   auto mixedStrides = subViewOp.getMixedStrides();
2060 
2061   // Check offsets are zero.
2062   if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
2063         Optional<int64_t> intValue = getConstantIntValue(ofr);
2064         return !intValue || intValue.getValue() != 0;
2065       }))
2066     return false;
2067 
2068   // Check strides are one.
2069   if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
2070         Optional<int64_t> intValue = getConstantIntValue(ofr);
2071         return !intValue || intValue.getValue() != 1;
2072       }))
2073     return false;
2074 
2075   // Check that all sizes are static and match the (static) source shape.
2076   ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
2077   for (const auto &size : llvm::enumerate(mixedSizes)) {
2078     Optional<int64_t> intValue = getConstantIntValue(size.value());
2079     if (!intValue || intValue.getValue() != sourceShape[size.index()])
2080       return false;
2081   }
2082   // All conditions met. The `SubViewOp` is foldable as a no-op.
2083   return true;
2084 }
2085 
2086 namespace {
2087 /// Pattern to rewrite a subview op with MemRefCast arguments.
2088 /// This essentially pushes memref.cast past its consuming subview when
2089 /// `canFoldIntoConsumerOp` is true.
2090 ///
2091 /// Example:
2092 /// ```
2093 ///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
2094 ///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
2095 ///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
2096 /// ```
2097 /// is rewritten into:
2098 /// ```
2099 ///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
2100 ///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
2101 ///     memref<3x4xf32, offset:?, strides:[?, 1]>
2102 /// ```
2103 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
2104 public:
2105   using OpRewritePattern<SubViewOp>::OpRewritePattern;
2106 
2107   LogicalResult matchAndRewrite(SubViewOp subViewOp,
2108                                 PatternRewriter &rewriter) const override {
2109     // Bail out on constant operands to let SubViewOpConstantFolder kick in.
2110     if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
2111           return matchPattern(operand, matchConstantIndex());
2112         }))
2113       return failure();
2114 
2115     auto castOp = subViewOp.source().getDefiningOp<CastOp>();
2116     if (!castOp)
2117       return failure();
2118 
2119     if (!CastOp::canFoldIntoConsumerOp(castOp))
2120       return failure();
2121 
2122     // Compute the SubViewOp result type after folding the MemRefCastOp. Use the
2123     // MemRefCastOp source operand type to infer the result type and the current
2124     // SubViewOp source operand type to compute the dropped dimensions if the
2125     // operation is rank-reducing.
2126     auto resultType = getCanonicalSubViewResultType(
2127         subViewOp.getType(), subViewOp.getSourceType(),
2128         castOp.source().getType().cast<MemRefType>(),
2129         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
2130         subViewOp.getMixedStrides());
2131     if (!resultType)
2132       return failure();
2133 
2134     Value newSubView = rewriter.create<SubViewOp>(
2135         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
2136         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
2137         subViewOp.static_sizes(), subViewOp.static_strides());
2138     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
2139                                         newSubView);
2140     return success();
2141   }
2142 };
2143 
2144 /// Canonicalize subview ops that are no-ops: replace them with their source,
2145 /// or with a cast when only the layout (`affine_map`) differs, as shown below.
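///
/// For example (types chosen for exposition):
/// ```
///   %1 = memref.subview %0[0, 0][4, 8][1, 1] : memref<4x8xf32> to
///     memref<4x8xf32, affine_map<(d0, d1) -> (d0 * 8 + d1)>>
/// ```
/// becomes a memref.cast of %0 to the result type.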
2146 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
2147 public:
2148   using OpRewritePattern<SubViewOp>::OpRewritePattern;
2149 
2150   LogicalResult matchAndRewrite(SubViewOp subViewOp,
2151                                 PatternRewriter &rewriter) const override {
2152     if (!isTrivialSubViewOp(subViewOp))
2153       return failure();
2154     if (subViewOp.getSourceType() == subViewOp.getType()) {
2155       rewriter.replaceOp(subViewOp, subViewOp.source());
2156       return success();
2157     }
2158     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
2159                                         subViewOp.source());
2160     return success();
2161   }
2162 };
2163 } // namespace
2164 
2165 /// Return the canonical type of the result of a subview.
2166 struct SubViewReturnTypeCanonicalizer {
2167   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
2168                         ArrayRef<OpFoldResult> mixedSizes,
2169                         ArrayRef<OpFoldResult> mixedStrides) {
2170     return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
2171                                          mixedOffsets, mixedSizes,
2172                                          mixedStrides);
2173   }
2174 };
2175 
2176 /// A canonicalizer wrapper to replace SubViewOps.
2177 struct SubViewCanonicalizer {
2178   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
2179     rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2180   }
2181 };
2182 
2183 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2184                                             MLIRContext *context) {
2185   results
2186       .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
2187                SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
2188            SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
2189 }
2190 
2191 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
2192   auto resultShapedType = getResult().getType().cast<ShapedType>();
2193   auto sourceShapedType = source().getType().cast<ShapedType>();
2194 
2195   if (resultShapedType.hasStaticShape() &&
2196       resultShapedType == sourceShapedType) {
2197     return getViewSource();
2198   }
2199 
2200   return {};
2201 }
2202 
2203 //===----------------------------------------------------------------------===//
2204 // TransposeOp
2205 //===----------------------------------------------------------------------===//
2206 
2207 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
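/// For illustration: applying affine_map<(d0, d1) -> (d1, d0)> to a contiguous
/// memref<3x4xf32> (strides [4, 1]) yields
/// memref<4x3xf32, affine_map<(d0, d1) -> (d1 * 4 + d0)>>.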
2208 static MemRefType inferTransposeResultType(MemRefType memRefType,
2209                                            AffineMap permutationMap) {
2210   auto rank = memRefType.getRank();
2211   auto originalSizes = memRefType.getShape();
2212   // Compute permuted sizes.
2213   SmallVector<int64_t, 4> sizes(rank, 0);
2214   for (const auto &en : llvm::enumerate(permutationMap.getResults()))
2215     sizes[en.index()] =
2216         originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
2217 
2218   // Compute permuted strides.
2219   int64_t offset;
2220   SmallVector<int64_t, 4> strides;
2221   auto res = getStridesAndOffset(memRefType, strides, offset);
2222   assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
2223   (void)res;
2224   auto map =
2225       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
2226   map = permutationMap ? map.compose(permutationMap) : map;
2227   return MemRefType::Builder(memRefType)
2228       .setShape(sizes)
2229       .setLayout(AffineMapAttr::get(map));
2230 }
2231 
2232 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
2233                         AffineMapAttr permutation,
2234                         ArrayRef<NamedAttribute> attrs) {
2235   auto permutationMap = permutation.getValue();
2236   assert(permutationMap);
2237 
2238   auto memRefType = in.getType().cast<MemRefType>();
2239   // Compute result type.
2240   MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
2241 
2242   build(b, result, resultType, in, attrs);
2243   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
2244 }
2245 
2246 // transpose $in $permutation attr-dict : type($in) `to` type(results)
2247 void TransposeOp::print(OpAsmPrinter &p) {
2248   p << " " << in() << " " << permutation();
2249   p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
2250   p << " : " << in().getType() << " to " << getType();
2251 }
2252 
2253 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
2254   OpAsmParser::OperandType in;
2255   AffineMap permutation;
2256   MemRefType srcType, dstType;
2257   if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
2258       parser.parseOptionalAttrDict(result.attributes) ||
2259       parser.parseColonType(srcType) ||
2260       parser.resolveOperand(in, srcType, result.operands) ||
2261       parser.parseKeywordType("to", dstType) ||
2262       parser.addTypeToList(dstType, result.types))
2263     return failure();
2264 
2265   result.addAttribute(TransposeOp::getPermutationAttrName(),
2266                       AffineMapAttr::get(permutation));
2267   return success();
2268 }
2269 
2270 LogicalResult TransposeOp::verify() {
2271   if (!permutation().isPermutation())
2272     return emitOpError("expected a permutation map");
2273   if (permutation().getNumDims() != getShapedType().getRank())
2274     return emitOpError("expected a permutation map of same rank as the input");
2275 
2276   auto srcType = in().getType().cast<MemRefType>();
2277   auto dstType = getType().cast<MemRefType>();
2278   auto transposedType = inferTransposeResultType(srcType, permutation());
2279   if (dstType != transposedType)
2280     return emitOpError("output type ")
2281            << dstType << " does not match transposed input type " << srcType
2282            << ", " << transposedType;
2283   return success();
2284 }
2285 
2286 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
2287   if (succeeded(foldMemRefCast(*this)))
2288     return getResult();
2289   return {};
2290 }
2291 
2292 //===----------------------------------------------------------------------===//
2293 // ViewOp
2294 //===----------------------------------------------------------------------===//
2295 
2296 ParseResult ViewOp::parse(OpAsmParser &parser, OperationState &result) {
2297   OpAsmParser::OperandType srcInfo;
2298   SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
2299   SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
2300   auto indexType = parser.getBuilder().getIndexType();
2301   Type srcType, dstType;
2302   SMLoc offsetLoc;
2303   if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
2304       parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
2305     return failure();
2306 
2307   if (offsetInfo.size() != 1)
2308     return parser.emitError(offsetLoc) << "expects 1 offset operand";
2309 
2310   return failure(
2311       parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
2312       parser.parseOptionalAttrDict(result.attributes) ||
2313       parser.parseColonType(srcType) ||
2314       parser.resolveOperand(srcInfo, srcType, result.operands) ||
2315       parser.resolveOperands(offsetInfo, indexType, result.operands) ||
2316       parser.resolveOperands(sizesInfo, indexType, result.operands) ||
2317       parser.parseKeywordType("to", dstType) ||
2318       parser.addTypeToList(dstType, result.types));
2319 }
2320 
2321 void ViewOp::print(OpAsmPrinter &p) {
2322   p << ' ' << getOperand(0) << '[';
2323   p.printOperand(byte_shift());
2324   p << "][" << sizes() << ']';
2325   p.printOptionalAttrDict((*this)->getAttrs());
2326   p << " : " << getOperand(0).getType() << " to " << getType();
2327 }
2328 
2329 LogicalResult ViewOp::verify() {
2330   auto baseType = getOperand(0).getType().cast<MemRefType>();
2331   auto viewType = getType();
2332 
2333   // The base memref should have identity layout map (or none).
2334   if (!baseType.getLayout().isIdentity())
2335     return emitError("unsupported map for base memref type ") << baseType;
2336 
2337   // The result memref should have identity layout map (or none).
2338   if (!viewType.getLayout().isIdentity())
2339     return emitError("unsupported map for result memref type ") << viewType;
2340 
2341   // The base memref and the view memref should be in the same memory space.
2342   if (baseType.getMemorySpace() != viewType.getMemorySpace())
2343     return emitError("different memory spaces specified for base memref "
2344                      "type ")
2345            << baseType << " and view memref type " << viewType;
2346 
2347   // Verify that we have the correct number of sizes for the result type.
2348   unsigned numDynamicDims = viewType.getNumDynamicDims();
2349   if (sizes().size() != numDynamicDims)
2350     return emitError("incorrect number of size operands for type ") << viewType;
2351 
2352   return success();
2353 }
2354 
2355 Value ViewOp::getViewSource() { return source(); }
2356 
2357 namespace {
2358 
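/// Fold dynamic size operands of a ViewOp that are defined by constants into
/// the result type, then cast back to the original type. A sketch, assuming
/// %c7 is an arith.constant 7 : index and %off is an index value:
/// ```
///   %0 = memref.view %buf[%off][%c7] : memref<2048xi8> to memref<?x16xf32>
/// ```
/// is rewritten into:
/// ```
///   %0 = memref.view %buf[%off][] : memref<2048xi8> to memref<7x16xf32>
///   %1 = memref.cast %0 : memref<7x16xf32> to memref<?x16xf32>
/// ```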
2359 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
2360   using OpRewritePattern<ViewOp>::OpRewritePattern;
2361 
2362   LogicalResult matchAndRewrite(ViewOp viewOp,
2363                                 PatternRewriter &rewriter) const override {
2364     // Return if none of the operands are constants.
2365     if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
2366           return matchPattern(operand, matchConstantIndex());
2367         }))
2368       return failure();
2369 
2370     // Get result memref type.
2371     auto memrefType = viewOp.getType();
2372 
2373     // Get offset from old memref view type 'memRefType'.
2374     int64_t oldOffset;
2375     SmallVector<int64_t, 4> oldStrides;
2376     if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
2377       return failure();
2378     assert(oldOffset == 0 && "Expected 0 offset");
2379 
2380     SmallVector<Value, 4> newOperands;
2381 
2382     // Offset cannot be folded into result type.
2383 
2384     // Fold any dynamic dim operands which are produced by a constant.
2385     SmallVector<int64_t, 4> newShapeConstants;
2386     newShapeConstants.reserve(memrefType.getRank());
2387 
2388     unsigned dynamicDimPos = 0;
2389     unsigned rank = memrefType.getRank();
2390     for (unsigned dim = 0, e = rank; dim < e; ++dim) {
2391       int64_t dimSize = memrefType.getDimSize(dim);
2392       // If this is already a static dimension, keep it.
2393       if (!ShapedType::isDynamic(dimSize)) {
2394         newShapeConstants.push_back(dimSize);
2395         continue;
2396       }
2397       auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
2398       if (auto constantIndexOp =
2399               dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
2400         // Dynamic shape dimension will be folded.
2401         newShapeConstants.push_back(constantIndexOp.value());
2402       } else {
2403         // Dynamic shape dimension not folded; copy operand from old memref.
2404         newShapeConstants.push_back(dimSize);
2405         newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
2406       }
2407       dynamicDimPos++;
2408     }
2409 
2410     // Create new memref type with constant folded dims.
2411     MemRefType newMemRefType =
2412         MemRefType::Builder(memrefType).setShape(newShapeConstants);
2413     // Nothing new, don't fold.
2414     if (newMemRefType == memrefType)
2415       return failure();
2416 
2417     // Create new ViewOp.
2418     auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
2419                                              viewOp.getOperand(0),
2420                                              viewOp.byte_shift(), newOperands);
2421     // Insert a cast so we have the same type as the old memref type.
2422     rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
2423     return success();
2424   }
2425 };
2426 
2427 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
2428   using OpRewritePattern<ViewOp>::OpRewritePattern;
2429 
2430   LogicalResult matchAndRewrite(ViewOp viewOp,
2431                                 PatternRewriter &rewriter) const override {
2432     Value memrefOperand = viewOp.getOperand(0);
2433     CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
2434     if (!memrefCastOp)
2435       return failure();
2436     Value allocOperand = memrefCastOp.getOperand();
2437     AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
2438     if (!allocOp)
2439       return failure();
2440     rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
2441                                         viewOp.byte_shift(), viewOp.sizes());
2442     return success();
2443   }
2444 };
2445 
2446 } // namespace
2447 
2448 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2449                                          MLIRContext *context) {
2450   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
2451 }
2452 
2453 //===----------------------------------------------------------------------===//
2454 // AtomicRMWOp
2455 //===----------------------------------------------------------------------===//
2456 
2457 LogicalResult AtomicRMWOp::verify() {
2458   if (getMemRefType().getRank() != getNumOperands() - 2)
2459     return emitOpError(
2460         "expects the number of subscripts to be equal to memref rank");
2461   switch (kind()) {
2462   case arith::AtomicRMWKind::addf:
2463   case arith::AtomicRMWKind::maxf:
2464   case arith::AtomicRMWKind::minf:
2465   case arith::AtomicRMWKind::mulf:
2466     if (!value().getType().isa<FloatType>())
2467       return emitOpError() << "with kind '"
2468                            << arith::stringifyAtomicRMWKind(kind())
2469                            << "' expects a floating-point type";
2470     break;
2471   case arith::AtomicRMWKind::addi:
2472   case arith::AtomicRMWKind::maxs:
2473   case arith::AtomicRMWKind::maxu:
2474   case arith::AtomicRMWKind::mins:
2475   case arith::AtomicRMWKind::minu:
2476   case arith::AtomicRMWKind::muli:
2477   case arith::AtomicRMWKind::ori:
2478   case arith::AtomicRMWKind::andi:
2479     if (!value().getType().isa<IntegerType>())
2480       return emitOpError() << "with kind '"
2481                            << arith::stringifyAtomicRMWKind(kind())
2482                            << "' expects an integer type";
2483     break;
2484   default:
2485     break;
2486   }
2487   return success();
2488 }
2489 
2490 OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) {
2491   /// atomicrmw(memrefcast) -> atomicrmw
2492   if (succeeded(foldMemRefCast(*this, value())))
2493     return getResult();
2494   return OpFoldResult();
2495 }
2496 
2497 //===----------------------------------------------------------------------===//
2498 // TableGen'd op method definitions
2499 //===----------------------------------------------------------------------===//
2500 
2501 #define GET_OP_CLASSES
2502 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
2503