1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
13 #include "mlir/Dialect/Utils/StaticValueUtils.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Interfaces/InferTypeOpInterface.h"
21 #include "mlir/Interfaces/ViewLikeInterface.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SmallBitVector.h"
24 
25 using namespace mlir;
26 using namespace mlir::memref;
27 
28 /// Materialize a single constant operation from a given attribute value with
29 /// the desired resultant type.
30 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
31                                               Attribute value, Type type,
32                                               Location loc) {
33   if (arith::ConstantOp::isBuildableWith(value, type))
34     return builder.create<arith::ConstantOp>(loc, value, type);
35   return nullptr;
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // Common canonicalization pattern support logic
40 //===----------------------------------------------------------------------===//
41 
42 /// This is a common class used for patterns of the form
43 /// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
44 /// into the root operation directly.
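///
/// For illustration, DeallocOp::fold below relies on this helper; a dealloc of
/// a cast result such as (value names are illustrative):
///
/// ```mlir
///   %0 = memref.cast %m : memref<8xf32> to memref<?xf32>
///   memref.dealloc %0 : memref<?xf32>
/// ```
///
/// is folded to operate directly on the original memref:
///
/// ```mlir
///   memref.dealloc %m : memref<8xf32>
/// ```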
45 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
46   bool folded = false;
47   for (OpOperand &operand : op->getOpOperands()) {
48     auto cast = operand.get().getDefiningOp<CastOp>();
49     if (cast && operand.get() != inner &&
50         !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
51       operand.set(cast.getOperand());
52       folded = true;
53     }
54   }
55   return success(folded);
56 }
57 
58 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
59 /// type.
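///
/// For example (illustrative):
///   memref<4x?xf32>  -> tensor<4x?xf32>
///   memref<*xf32>    -> tensor<*xf32>
///   any other type   -> none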
60 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
61   if (auto memref = type.dyn_cast<MemRefType>())
62     return RankedTensorType::get(memref.getShape(), memref.getElementType());
63   if (auto memref = type.dyn_cast<UnrankedMemRefType>())
64     return UnrankedTensorType::get(memref.getElementType());
65   return NoneType::get(type.getContext());
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // AllocOp / AllocaOp
70 //===----------------------------------------------------------------------===//
71 
72 template <typename AllocLikeOp>
73 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
74   static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
75                 "applies to only alloc or alloca");
76   auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
77   if (!memRefType)
78     return op.emitOpError("result must be a memref");
79 
80   if (static_cast<int64_t>(op.dynamicSizes().size()) !=
81       memRefType.getNumDynamicDims())
82     return op.emitOpError("dimension operand count does not equal memref "
83                           "dynamic dimension count");
84 
85   unsigned numSymbols = 0;
86   if (!memRefType.getLayout().isIdentity())
87     numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
88   if (op.symbolOperands().size() != numSymbols)
89     return op.emitOpError("symbol operand count does not equal memref symbol "
90                           "count: expected ")
91            << numSymbols << ", got " << op.symbolOperands().size();
92 
93   return success();
94 }
95 
96 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
97 
98 LogicalResult AllocaOp::verify() {
99   // An alloca op needs to have an ancestor with an allocation scope trait.
100   if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
101     return emitOpError(
102         "requires an ancestor op with AutomaticAllocationScope trait");
103 
104   return verifyAllocLikeOp(*this);
105 }
106 
107 namespace {
108 /// Fold constant dimensions into an alloc like operation.
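///
/// A sketch of the rewrite (value names are illustrative):
///
/// ```mlir
///   %c16 = arith.constant 16 : index
///   %0 = memref.alloc(%c16, %n) : memref<?x?xf32>
/// ```
///
/// is replaced by an alloc with the constant size folded into the result type,
/// plus a cast back to the original type:
///
/// ```mlir
///   %1 = memref.alloc(%n) : memref<16x?xf32>
///   %0 = memref.cast %1 : memref<16x?xf32> to memref<?x?xf32>
/// ```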
109 template <typename AllocLikeOp>
110 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
111   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
112 
113   LogicalResult matchAndRewrite(AllocLikeOp alloc,
114                                 PatternRewriter &rewriter) const override {
115     // Check to see if any dimension operands are constants.  If so, we can
116     // substitute and drop them.
117     if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
118           return matchPattern(operand, matchConstantIndex());
119         }))
120       return failure();
121 
122     auto memrefType = alloc.getType();
123 
124     // Ok, we have one or more constant operands.  Collect the non-constant ones
125     // and keep track of the resultant memref type to build.
126     SmallVector<int64_t, 4> newShapeConstants;
127     newShapeConstants.reserve(memrefType.getRank());
128     SmallVector<Value, 4> dynamicSizes;
129 
130     unsigned dynamicDimPos = 0;
131     for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
132       int64_t dimSize = memrefType.getDimSize(dim);
133       // If this is already a static dimension, keep it.
134       if (!ShapedType::isDynamic(dimSize)) {
135         newShapeConstants.push_back(dimSize);
136         continue;
137       }
138       auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
139       auto *defOp = dynamicSize.getDefiningOp();
140       if (auto constantIndexOp =
141               dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
142         // Dynamic shape dimension will be folded.
143         newShapeConstants.push_back(constantIndexOp.value());
144       } else {
145         // Dynamic shape dimension not folded; copy dynamicSize from old memref.
146         newShapeConstants.push_back(ShapedType::kDynamicSize);
147         dynamicSizes.push_back(dynamicSize);
148       }
149       dynamicDimPos++;
150     }
151 
152     // Create new memref type (which will have fewer dynamic dimensions).
153     MemRefType newMemRefType =
154         MemRefType::Builder(memrefType).setShape(newShapeConstants);
155     assert(static_cast<int64_t>(dynamicSizes.size()) ==
156            newMemRefType.getNumDynamicDims());
157 
158     // Create and insert the alloc op for the new memref.
159     auto newAlloc = rewriter.create<AllocLikeOp>(
160         alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
161         alloc.alignmentAttr());
162     // Insert a cast so we have the same type as the old alloc.
163     auto resultCast =
164         rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc);
165 
166     rewriter.replaceOp(alloc, {resultCast});
167     return success();
168   }
169 };
170 
171 /// Fold alloc operations with no users or only store and dealloc uses.
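///
/// A sketch of IR removed by this pattern (value names are illustrative):
///
/// ```mlir
///   %0 = memref.alloc() : memref<4xf32>
///   memref.store %cst, %0[%i] : memref<4xf32>
///   memref.dealloc %0 : memref<4xf32>
/// ```
///
/// The stored value is never read, so all three operations are erased.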
172 template <typename T>
173 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
174   using OpRewritePattern<T>::OpRewritePattern;
175 
176   LogicalResult matchAndRewrite(T alloc,
177                                 PatternRewriter &rewriter) const override {
178     if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
179           if (auto storeOp = dyn_cast<StoreOp>(op))
180             return storeOp.value() == alloc;
181           return !isa<DeallocOp>(op);
182         }))
183       return failure();
184 
185     for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
186       rewriter.eraseOp(user);
187 
188     rewriter.eraseOp(alloc);
189     return success();
190   }
191 };
192 } // namespace
193 
194 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
195                                           MLIRContext *context) {
196   results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
197 }
198 
199 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
200                                            MLIRContext *context) {
201   results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
202       context);
203 }
204 
205 //===----------------------------------------------------------------------===//
206 // AllocaScopeOp
207 //===----------------------------------------------------------------------===//
208 
209 void AllocaScopeOp::print(OpAsmPrinter &p) {
210   bool printBlockTerminators = false;
211 
212   p << ' ';
213   if (!results().empty()) {
214     p << " -> (" << getResultTypes() << ")";
215     printBlockTerminators = true;
216   }
217   p << ' ';
218   p.printRegion(bodyRegion(),
219                 /*printEntryBlockArgs=*/false,
220                 /*printBlockTerminators=*/printBlockTerminators);
221   p.printOptionalAttrDict((*this)->getAttrs());
222 }
223 
224 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
225   // Create a region for the body.
226   result.regions.reserve(1);
227   Region *bodyRegion = result.addRegion();
228 
229   // Parse optional results type list.
230   if (parser.parseOptionalArrowTypeList(result.types))
231     return failure();
232 
233   // Parse the body region.
234   if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
235     return failure();
236   AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
237                                   result.location);
238 
239   // Parse the optional attribute list.
240   if (parser.parseOptionalAttrDict(result.attributes))
241     return failure();
242 
243   return success();
244 }
245 
246 LogicalResult AllocaScopeOp::verify() {
247   return RegionBranchOpInterface::verifyTypes(*this);
248 }
249 
250 void AllocaScopeOp::getSuccessorRegions(
251     Optional<unsigned> index, ArrayRef<Attribute> operands,
252     SmallVectorImpl<RegionSuccessor> &regions) {
253   if (index.hasValue()) {
254     regions.push_back(RegionSuccessor(getResults()));
255     return;
256   }
257 
258   regions.push_back(RegionSuccessor(&bodyRegion()));
259 }
260 
261 //===----------------------------------------------------------------------===//
262 // AssumeAlignmentOp
263 //===----------------------------------------------------------------------===//
264 
265 LogicalResult AssumeAlignmentOp::verify() {
266   if (!llvm::isPowerOf2_32(alignment()))
267     return emitOpError("alignment must be power of 2");
268   return success();
269 }
270 
271 //===----------------------------------------------------------------------===//
272 // CastOp
273 //===----------------------------------------------------------------------===//
274 
275 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
276 /// source memref. This is useful to fold a memref.cast into a consuming op
277 /// and implement canonicalization patterns for ops in different dialects that
278 /// may consume the results of memref.cast operations. Such foldable memref.cast
279 /// operations are typically inserted as `view` and `subview` ops are
280 /// canonicalized, to preserve the type compatibility of their uses.
281 ///
282 /// Returns true when all conditions are met:
283 /// 1. source and result are ranked memrefs with strided semantics and same
284 /// element type and rank.
285 /// 2. each size, offset and stride of the source has at least as much static
286 /// information as the corresponding size, offset or stride of the result.
287 ///
288 /// Example 1:
289 /// ```mlir
290 ///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
291 ///   %2 = consumer %1 ... : memref<?x?xf32> ...
292 /// ```
293 ///
294 /// may fold into:
295 ///
296 /// ```mlir
297 ///   %2 = consumer %0 ... : memref<8x16xf32> ...
298 /// ```
299 ///
300 /// Example 2:
301 /// ```
302 ///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
303 ///          to memref<?x?xf32>
304 ///   consumer %1 : memref<?x?xf32> ...
305 /// ```
306 ///
307 /// may fold into:
308 ///
309 /// ```
310 ///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
311 /// ```
312 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
313   MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
314   MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
315 
316   // Requires ranked MemRefType.
317   if (!sourceType || !resultType)
318     return false;
319 
320   // Requires same elemental type.
321   if (sourceType.getElementType() != resultType.getElementType())
322     return false;
323 
324   // Requires same rank.
325   if (sourceType.getRank() != resultType.getRank())
326     return false;
327 
328   // Only fold casts between strided memref forms.
329   int64_t sourceOffset, resultOffset;
330   SmallVector<int64_t, 4> sourceStrides, resultStrides;
331   if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
332       failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
333     return false;
334 
335   // If cast is towards more static sizes along any dimension, don't fold.
336   for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
337     auto ss = std::get<0>(it), st = std::get<1>(it);
338     if (ss != st)
339       if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
340         return false;
341   }
342 
343   // If the cast is towards a more static offset, don't fold.
344   if (sourceOffset != resultOffset)
345     if (ShapedType::isDynamicStrideOrOffset(sourceOffset) &&
346         !ShapedType::isDynamicStrideOrOffset(resultOffset))
347       return false;
348 
349   // If cast is towards more static strides along any dimension, don't fold.
350   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
351     auto ss = std::get<0>(it), st = std::get<1>(it);
352     if (ss != st)
353       if (ShapedType::isDynamicStrideOrOffset(ss) &&
354           !ShapedType::isDynamicStrideOrOffset(st))
355         return false;
356   }
357 
358   return true;
359 }
360 
361 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
362   if (inputs.size() != 1 || outputs.size() != 1)
363     return false;
364   Type a = inputs.front(), b = outputs.front();
365   auto aT = a.dyn_cast<MemRefType>();
366   auto bT = b.dyn_cast<MemRefType>();
367 
368   auto uaT = a.dyn_cast<UnrankedMemRefType>();
369   auto ubT = b.dyn_cast<UnrankedMemRefType>();
370 
371   if (aT && bT) {
372     if (aT.getElementType() != bT.getElementType())
373       return false;
374     if (aT.getLayout() != bT.getLayout()) {
375       int64_t aOffset, bOffset;
376       SmallVector<int64_t, 4> aStrides, bStrides;
377       if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
378           failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
379           aStrides.size() != bStrides.size())
380         return false;
381 
382       // Strides along a dimension/offset are compatible if the value in the
383       // source memref is static and the value in the target memref is the
384       // same. They are also compatible if either one is dynamic (see
385       // description of MemRefCastOp for details).
386       auto checkCompatible = [](int64_t a, int64_t b) {
387         return (a == MemRefType::getDynamicStrideOrOffset() ||
388                 b == MemRefType::getDynamicStrideOrOffset() || a == b);
389       };
390       if (!checkCompatible(aOffset, bOffset))
391         return false;
392       for (const auto &aStride : enumerate(aStrides))
393         if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
394           return false;
395     }
396     if (aT.getMemorySpace() != bT.getMemorySpace())
397       return false;
398 
399     // They must have the same rank, and any specified dimensions must match.
400     if (aT.getRank() != bT.getRank())
401       return false;
402 
403     for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
404       int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
405       if (aDim != -1 && bDim != -1 && aDim != bDim)
406         return false;
407     }
408     return true;
409   } else {
410     if (!aT && !uaT)
411       return false;
412     if (!bT && !ubT)
413       return false;
414     // Unranked to unranked casting is unsupported
415     if (uaT && ubT)
416       return false;
417 
418     auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
419     auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
420     if (aEltType != bEltType)
421       return false;
422 
423     auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
424     auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
425     return aMemSpace == bMemSpace;
426   }
427 
428   return false;
429 }
430 
431 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
432   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
433 }
434 
435 //===----------------------------------------------------------------------===//
436 // CopyOp
437 //===----------------------------------------------------------------------===//
438 
439 namespace {
440 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
441 /// and element type, the cast can be skipped. Such CastOps only cast the layout
442 /// of the type.
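///
/// A sketch of the rewrite on the source operand (names are illustrative):
///
/// ```mlir
///   %0 = memref.cast %src
///        : memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32>
///   memref.copy %0, %dst : memref<4xf32> to memref<4xf32>
/// ```
///
/// becomes a copy reading directly from %src, skipping the layout-only cast.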
443 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
444   using OpRewritePattern<CopyOp>::OpRewritePattern;
445 
446   LogicalResult matchAndRewrite(CopyOp copyOp,
447                                 PatternRewriter &rewriter) const override {
448     bool modified = false;
449 
450     // Check source.
451     if (auto castOp = copyOp.source().getDefiningOp<CastOp>()) {
452       auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
453       auto toType = copyOp.source().getType().dyn_cast<MemRefType>();
454 
455       if (fromType && toType) {
456         if (fromType.getShape() == toType.getShape() &&
457             fromType.getElementType() == toType.getElementType()) {
458           rewriter.updateRootInPlace(
459               copyOp, [&] { copyOp.sourceMutable().assign(castOp.source()); });
460           modified = true;
461         }
462       }
463     }
464 
465     // Check target.
466     if (auto castOp = copyOp.target().getDefiningOp<CastOp>()) {
467       auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
468       auto toType = copyOp.target().getType().dyn_cast<MemRefType>();
469 
470       if (fromType && toType) {
471         if (fromType.getShape() == toType.getShape() &&
472             fromType.getElementType() == toType.getElementType()) {
473           rewriter.updateRootInPlace(
474               copyOp, [&] { copyOp.targetMutable().assign(castOp.source()); });
475           modified = true;
476         }
477       }
478     }
479 
480     return success(modified);
481   }
482 };
483 
484 /// Fold memref.copy(%x, %x).
485 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
486   using OpRewritePattern<CopyOp>::OpRewritePattern;
487 
488   LogicalResult matchAndRewrite(CopyOp copyOp,
489                                 PatternRewriter &rewriter) const override {
490     if (copyOp.source() != copyOp.target())
491       return failure();
492 
493     rewriter.eraseOp(copyOp);
494     return success();
495   }
496 };
497 } // namespace
498 
499 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
500                                          MLIRContext *context) {
501   results.add<FoldCopyOfCast, FoldSelfCopy>(context);
502 }
503 
504 LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands,
505                            SmallVectorImpl<OpFoldResult> &results) {
506   /// copy(memrefcast) -> copy
507   bool folded = false;
508   Operation *op = *this;
509   for (OpOperand &operand : op->getOpOperands()) {
510     auto castOp = operand.get().getDefiningOp<memref::CastOp>();
511     if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
512       operand.set(castOp.getOperand());
513       folded = true;
514     }
515   }
516   return success(folded);
517 }
518 
519 //===----------------------------------------------------------------------===//
520 // DeallocOp
521 //===----------------------------------------------------------------------===//
522 
523 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
524                               SmallVectorImpl<OpFoldResult> &results) {
525   /// dealloc(memrefcast) -> dealloc
526   return foldMemRefCast(*this);
527 }
528 
529 //===----------------------------------------------------------------------===//
530 // DimOp
531 //===----------------------------------------------------------------------===//
532 
533 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
534                   int64_t index) {
535   auto loc = result.location;
536   Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
537   build(builder, result, source, indexValue);
538 }
539 
540 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
541                   Value index) {
542   auto indexTy = builder.getIndexType();
543   build(builder, result, indexTy, source, index);
544 }
545 
546 Optional<int64_t> DimOp::getConstantIndex() {
547   if (auto constantOp = index().getDefiningOp<arith::ConstantOp>())
548     return constantOp.getValue().cast<IntegerAttr>().getInt();
549   return {};
550 }
551 
552 LogicalResult DimOp::verify() {
553   // Assume unknown index to be in range.
554   Optional<int64_t> index = getConstantIndex();
555   if (!index.hasValue())
556     return success();
557 
558   // Check that constant index is not knowingly out of range.
559   auto type = source().getType();
560   if (auto memrefType = type.dyn_cast<MemRefType>()) {
561     if (index.getValue() >= memrefType.getRank())
562       return emitOpError("index is out of range");
563   } else if (type.isa<UnrankedMemRefType>()) {
564     // Assume index to be in range.
565   } else {
566     llvm_unreachable("expected operand with memref type");
567   }
568   return success();
569 }
570 
571 /// Return a map whose keys are the elements of `vals` and whose values are the
572 /// number of occurrences of each element. Use std::map, since the `vals` here
573 /// are strides and the dynamic stride value is the same as the tombstone value
574 /// for `DenseMap<int64_t>`.
575 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
576   std::map<int64_t, unsigned> numOccurences;
577   for (auto val : vals)
578     numOccurences[val]++;
579   return numOccurences;
580 }
581 
582 /// Given `originalType` and `reducedType`, whose shape is assumed to be a
583 /// subset of `originalType` with some `1` entries erased, return the set of
584 /// indices that specifies which of the dimensions of `originalType` are
585 /// dropped to obtain `reducedType`.
586 /// This accounts for cases where there are multiple unit-dims, but only a
587 /// subset of those are dropped. For MemRefTypes these can be disambiguated
588 /// using the strides. If a dimension is dropped, its stride must be dropped too.
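///
/// For example (illustrative): an original type `memref<1x16x1xf32>` has
/// identity strides `[16, 1, 1]`. If the reduced type is `memref<16x1xf32>`
/// (strides `[1, 1]`), the stride 16 is the one that disappeared, so dimension
/// 0 is the dropped unit dimension rather than dimension 2.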
589 static llvm::Optional<llvm::SmallBitVector>
590 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
591                                ArrayRef<OpFoldResult> sizes) {
592   llvm::SmallBitVector unusedDims(originalType.getRank());
593   if (originalType.getRank() == reducedType.getRank())
594     return unusedDims;
595 
596   for (const auto &dim : llvm::enumerate(sizes))
597     if (auto attr = dim.value().dyn_cast<Attribute>())
598       if (attr.cast<IntegerAttr>().getInt() == 1)
599         unusedDims.set(dim.index());
600 
601   SmallVector<int64_t> originalStrides, candidateStrides;
602   int64_t originalOffset, candidateOffset;
603   if (failed(
604           getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
605       failed(
606           getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
607     return llvm::None;
608 
609   // For memrefs, a dimension is truly dropped if its corresponding stride is
610   // also dropped. This is particularly important when more than one of the dims
611   // is 1. Track the number of occurrences of the strides in the original type
612   // and the candidate type. For each unused dim, that stride should not be
613   // present in the candidate type. Note that there could be multiple dimensions
614   // that have the same size. We don't need to figure out exactly which dim
615   // corresponds to which stride; we just need to verify that the number of
616   // repetitions of a stride in the original + the number of unused dims with
617   // that stride == the number of repetitions of that stride in the candidate.
618   std::map<int64_t, unsigned> currUnaccountedStrides =
619       getNumOccurences(originalStrides);
620   std::map<int64_t, unsigned> candidateStridesNumOccurences =
621       getNumOccurences(candidateStrides);
622   for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
623     if (!unusedDims.test(dim))
624       continue;
625     int64_t originalStride = originalStrides[dim];
626     if (currUnaccountedStrides[originalStride] >
627         candidateStridesNumOccurences[originalStride]) {
628       // This dim can be treated as dropped.
629       currUnaccountedStrides[originalStride]--;
630       continue;
631     }
632     if (currUnaccountedStrides[originalStride] ==
633         candidateStridesNumOccurences[originalStride]) {
634       // The stride for this is not dropped. Keep as is.
635       unusedDims.reset(dim);
636       continue;
637     }
638     if (currUnaccountedStrides[originalStride] <
639         candidateStridesNumOccurences[originalStride]) {
640       // This should never happen. We can't have a stride in the reduced-rank
641       // type that wasn't in the original one.
642       return llvm::None;
643     }
644   }
645 
646   if ((int64_t)unusedDims.count() + reducedType.getRank() !=
647       originalType.getRank())
648     return llvm::None;
649   return unusedDims;
650 }
651 
652 llvm::SmallBitVector SubViewOp::getDroppedDims() {
653   MemRefType sourceType = getSourceType();
654   MemRefType resultType = getType();
655   llvm::Optional<llvm::SmallBitVector> unusedDims =
656       computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
657   assert(unusedDims && "unable to find unused dims of subview");
658   return *unusedDims;
659 }
660 
661 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
662   // All forms of folding require a known index.
663   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
664   if (!index)
665     return {};
666 
667   // Folding for unranked types (UnrankedMemRefType) is not supported.
668   auto memrefType = source().getType().dyn_cast<MemRefType>();
669   if (!memrefType)
670     return {};
671 
672   // Fold if the shape extent along the given index is known.
673   if (!memrefType.isDynamicDim(index.getInt())) {
674     Builder builder(getContext());
675     return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
676   }
677 
678   // The size at the given index is now known to be a dynamic size.
679   unsigned unsignedIndex = index.getValue().getZExtValue();
680 
681   // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
682   Operation *definingOp = source().getDefiningOp();
683 
684   if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
685     return *(alloc.getDynamicSizes().begin() +
686              memrefType.getDynamicDimIndex(unsignedIndex));
687 
688   if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
689     return *(alloca.getDynamicSizes().begin() +
690              memrefType.getDynamicDimIndex(unsignedIndex));
691 
692   if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
693     return *(view.getDynamicSizes().begin() +
694              memrefType.getDynamicDimIndex(unsignedIndex));
695 
696   if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
697     llvm::SmallBitVector unusedDims = subview.getDroppedDims();
698     unsigned resultIndex = 0;
699     unsigned sourceRank = subview.getSourceType().getRank();
700     unsigned sourceIndex = 0;
701     for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
702       if (unusedDims.test(i))
703         continue;
704       if (resultIndex == unsignedIndex) {
705         sourceIndex = i;
706         break;
707       }
708       resultIndex++;
709     }
710     assert(subview.isDynamicSize(sourceIndex) &&
711            "expected dynamic subview size");
712     return subview.getDynamicSize(sourceIndex);
713   }
714 
715   if (auto sizeInterface =
716           dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
717     assert(sizeInterface.isDynamicSize(unsignedIndex) &&
718            "Expected dynamic subview size");
719     return sizeInterface.getDynamicSize(unsignedIndex);
720   }
721 
722   // dim(memrefcast) -> dim
723   if (succeeded(foldMemRefCast(*this)))
724     return getResult();
725 
726   return {};
727 }
728 
729 namespace {
730 /// Fold a `dim` of the result of a memref.reshape operation into a load from
731 /// the reshape's shape operand.
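///
/// A sketch of the rewrite (value names are illustrative):
///
/// ```mlir
///   %dst = memref.reshape %src(%shape)
///        : (memref<?xf32>, memref<2xindex>) -> memref<?x?xf32>
///   %dim = memref.dim %dst, %c1 : memref<?x?xf32>
/// ```
///
/// becomes a load of the requested extent from the shape operand:
///
/// ```mlir
///   %dim = memref.load %shape[%c1] : memref<2xindex>
/// ```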
732 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
733   using OpRewritePattern<DimOp>::OpRewritePattern;
734 
735   LogicalResult matchAndRewrite(DimOp dim,
736                                 PatternRewriter &rewriter) const override {
737     auto reshape = dim.source().getDefiningOp<ReshapeOp>();
738 
739     if (!reshape)
740       return failure();
741 
742     // Place the load directly after the reshape to ensure that the shape memref
743     // was not mutated.
744     rewriter.setInsertionPointAfter(reshape);
745     Location loc = dim.getLoc();
746     Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
747     if (load.getType() != dim.getType())
748       load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
749     rewriter.replaceOp(dim, load);
750     return success();
751   }
752 };
753 
754 } // namespace
755 
756 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
757                                         MLIRContext *context) {
758   results.add<DimOfMemRefReshape>(context);
759 }
760 
761 // ---------------------------------------------------------------------------
762 // DmaStartOp
763 // ---------------------------------------------------------------------------
764 
765 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
766                        Value srcMemRef, ValueRange srcIndices, Value destMemRef,
767                        ValueRange destIndices, Value numElements,
768                        Value tagMemRef, ValueRange tagIndices, Value stride,
769                        Value elementsPerStride) {
770   result.addOperands(srcMemRef);
771   result.addOperands(srcIndices);
772   result.addOperands(destMemRef);
773   result.addOperands(destIndices);
774   result.addOperands({numElements, tagMemRef});
775   result.addOperands(tagIndices);
776   if (stride)
777     result.addOperands({stride, elementsPerStride});
778 }
779 
780 void DmaStartOp::print(OpAsmPrinter &p) {
781   p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
782     << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
783     << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
784   if (isStrided())
785     p << ", " << getStride() << ", " << getNumElementsPerStride();
786 
787   p.printOptionalAttrDict((*this)->getAttrs());
788   p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
789     << ", " << getTagMemRef().getType();
790 }
791 
792 // Parse DmaStartOp.
793 // Ex:
794 //   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
795 //                       %tag[%index], %stride, %num_elt_per_stride
796 //                     : memref<3076 x f32, 0>,
797 //                       memref<1024 x f32, 2>,
798 //                       memref<1 x i32>
799 //
800 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
801   OpAsmParser::OperandType srcMemRefInfo;
802   SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
803   OpAsmParser::OperandType dstMemRefInfo;
804   SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
805   OpAsmParser::OperandType numElementsInfo;
806   OpAsmParser::OperandType tagMemrefInfo;
807   SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
808   SmallVector<OpAsmParser::OperandType, 2> strideInfo;
809 
810   SmallVector<Type, 3> types;
811   auto indexType = parser.getBuilder().getIndexType();
812 
813   // Parse and resolve the following list of operands:
814   // *) source memref followed by its indices (in square brackets).
815   // *) destination memref followed by its indices (in square brackets).
816   // *) number of elements to transfer.
817   if (parser.parseOperand(srcMemRefInfo) ||
818       parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
819       parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
820       parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
821       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
822       parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
823       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
824     return failure();
825 
826   // Parse optional stride and elements per stride.
827   if (parser.parseTrailingOperandList(strideInfo))
828     return failure();
829 
830   bool isStrided = strideInfo.size() == 2;
831   if (!strideInfo.empty() && !isStrided) {
832     return parser.emitError(parser.getNameLoc(),
833                             "expected two stride related operands");
834   }
835 
836   if (parser.parseColonTypeList(types))
837     return failure();
838   if (types.size() != 3)
839     return parser.emitError(parser.getNameLoc(), "expected 3 types");
840 
841   if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
842       parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
843       parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
844       parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
845       // size should be an index.
846       parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
847       parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
848       // tag indices should be index.
849       parser.resolveOperands(tagIndexInfos, indexType, result.operands))
850     return failure();
851 
852   if (isStrided) {
853     if (parser.resolveOperands(strideInfo, indexType, result.operands))
854       return failure();
855   }
856 
857   return success();
858 }
859 
860 LogicalResult DmaStartOp::verify() {
861   unsigned numOperands = getNumOperands();
862 
863   // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
864   // the number of elements.
865   if (numOperands < 4)
866     return emitOpError("expected at least 4 operands");
867 
868   // Check types of operands. The order of these calls is important: the later
869   // calls rely on some type properties to compute the operand position.
870   // 1. Source memref.
871   if (!getSrcMemRef().getType().isa<MemRefType>())
872     return emitOpError("expected source to be of memref type");
873   if (numOperands < getSrcMemRefRank() + 4)
874     return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
875                          << " operands";
876   if (!getSrcIndices().empty() &&
877       !llvm::all_of(getSrcIndices().getTypes(),
878                     [](Type t) { return t.isIndex(); }))
879     return emitOpError("expected source indices to be of index type");
880 
881   // 2. Destination memref.
882   if (!getDstMemRef().getType().isa<MemRefType>())
883     return emitOpError("expected destination to be of memref type");
884   unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
885   if (numOperands < numExpectedOperands)
886     return emitOpError() << "expected at least " << numExpectedOperands
887                          << " operands";
888   if (!getDstIndices().empty() &&
889       !llvm::all_of(getDstIndices().getTypes(),
890                     [](Type t) { return t.isIndex(); }))
891     return emitOpError("expected destination indices to be of index type");
892 
893   // 3. Number of elements.
894   if (!getNumElements().getType().isIndex())
895     return emitOpError("expected num elements to be of index type");
896 
897   // 4. Tag memref.
898   if (!getTagMemRef().getType().isa<MemRefType>())
899     return emitOpError("expected tag to be of memref type");
900   numExpectedOperands += getTagMemRefRank();
901   if (numOperands < numExpectedOperands)
902     return emitOpError() << "expected at least " << numExpectedOperands
903                          << " operands";
904   if (!getTagIndices().empty() &&
905       !llvm::all_of(getTagIndices().getTypes(),
906                     [](Type t) { return t.isIndex(); }))
907     return emitOpError("expected tag indices to be of index type");
908 
909   // Optional stride-related operands must be either both present or both
910   // absent.
911   if (numOperands != numExpectedOperands &&
912       numOperands != numExpectedOperands + 2)
913     return emitOpError("incorrect number of operands");
914 
915   // 5. Strides.
916   if (isStrided()) {
917     if (!getStride().getType().isIndex() ||
918         !getNumElementsPerStride().getType().isIndex())
919       return emitOpError(
920           "expected stride and num elements per stride to be of type index");
921   }
922 
923   return success();
924 }
925 
926 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
927                                SmallVectorImpl<OpFoldResult> &results) {
928   /// dma_start(memrefcast) -> dma_start
929   return foldMemRefCast(*this);
930 }
931 
932 // ---------------------------------------------------------------------------
933 // DmaWaitOp
934 // ---------------------------------------------------------------------------
935 
936 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
937                               SmallVectorImpl<OpFoldResult> &results) {
938   /// dma_wait(memrefcast) -> dma_wait
939   return foldMemRefCast(*this);
940 }
941 
942 LogicalResult DmaWaitOp::verify() {
943   // Check that the number of tag indices matches the tagMemRef rank.
944   unsigned numTagIndices = tagIndices().size();
945   unsigned tagMemRefRank = getTagMemRefRank();
946   if (numTagIndices != tagMemRefRank)
947     return emitOpError() << "expected tagIndices to have the same number of "
948                             "elements as the tagMemRef rank, expected "
949                          << tagMemRefRank << ", but got " << numTagIndices;
950   return success();
951 }
952 
953 //===----------------------------------------------------------------------===//
954 // GenericAtomicRMWOp
955 //===----------------------------------------------------------------------===//
956 
957 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
958                                Value memref, ValueRange ivs) {
959   result.addOperands(memref);
960   result.addOperands(ivs);
961 
962   if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
963     Type elementType = memrefType.getElementType();
964     result.addTypes(elementType);
965 
966     Region *bodyRegion = result.addRegion();
967     bodyRegion->push_back(new Block());
968     bodyRegion->addArgument(elementType, memref.getLoc());
969   }
970 }
971 
972 LogicalResult GenericAtomicRMWOp::verify() {
973   auto &body = getRegion();
974   if (body.getNumArguments() != 1)
975     return emitOpError("expected a single entry block argument");
976 
977   if (getResult().getType() != body.getArgument(0).getType())
978     return emitOpError("expected block argument type to match the result type");
979 
980   bool hasSideEffects =
981       body.walk([&](Operation *nestedOp) {
982             if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
983               return WalkResult::advance();
984             nestedOp->emitError(
985                 "body of 'memref.generic_atomic_rmw' should contain "
986                 "only operations with no side effects");
987             return WalkResult::interrupt();
988           })
989           .wasInterrupted();
990   return hasSideEffects ? failure() : success();
991 }
992 
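// Parse GenericAtomicRMWOp.
// An illustrative sketch of the syntax handled below (value names are
// hypothetical):
//
//   %x = memref.generic_atomic_rmw %buf[%i] : memref<10xf32> {
//   ^bb0(%current : f32):
//     %cst = arith.constant 1.0 : f32
//     %new = arith.addf %current, %cst : f32
//     memref.atomic_yield %new : f32
//   }
//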
993 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
994                                       OperationState &result) {
995   OpAsmParser::OperandType memref;
996   Type memrefType;
997   SmallVector<OpAsmParser::OperandType, 4> ivs;
998 
999   Type indexType = parser.getBuilder().getIndexType();
1000   if (parser.parseOperand(memref) ||
1001       parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1002       parser.parseColonType(memrefType) ||
1003       parser.resolveOperand(memref, memrefType, result.operands) ||
1004       parser.resolveOperands(ivs, indexType, result.operands))
1005     return failure();
1006 
1007   Region *body = result.addRegion();
1008   if (parser.parseRegion(*body, llvm::None, llvm::None) ||
1009       parser.parseOptionalAttrDict(result.attributes))
1010     return failure();
1011   result.types.push_back(memrefType.cast<MemRefType>().getElementType());
1012   return success();
1013 }
1014 
1015 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1016   p << ' ' << memref() << "[" << indices() << "] : " << memref().getType()
1017     << ' ';
1018   p.printRegion(getRegion());
1019   p.printOptionalAttrDict((*this)->getAttrs());
1020 }
1021 
1022 //===----------------------------------------------------------------------===//
1023 // AtomicYieldOp
1024 //===----------------------------------------------------------------------===//
1025 
1026 LogicalResult AtomicYieldOp::verify() {
1027   Type parentType = (*this)->getParentOp()->getResultTypes().front();
1028   Type resultType = result().getType();
1029   if (parentType != resultType)
1030     return emitOpError() << "types mismatch between yield op: " << resultType
1031                          << " and its parent: " << parentType;
1032   return success();
1033 }
1034 
1035 //===----------------------------------------------------------------------===//
1036 // GlobalOp
1037 //===----------------------------------------------------------------------===//
1038 
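// Print/parse helpers for the `type = initial-value` tail of memref.global.
// Illustrative forms of the syntax handled below (symbol names are
// hypothetical):
//
//   memref.global @zeros : memref<2xf32> = dense<0.0>
//   memref.global @scratch : memref<4xi32> = uninitialized
//   memref.global @external_buf : memref<8xf32>
//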
1039 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1040                                                    TypeAttr type,
1041                                                    Attribute initialValue) {
1042   p << type;
1043   if (!op.isExternal()) {
1044     p << " = ";
1045     if (op.isUninitialized())
1046       p << "uninitialized";
1047     else
1048       p.printAttributeWithoutType(initialValue);
1049   }
1050 }
1051 
1052 static ParseResult
1053 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1054                                        Attribute &initialValue) {
1055   Type type;
1056   if (parser.parseType(type))
1057     return failure();
1058 
1059   auto memrefType = type.dyn_cast<MemRefType>();
1060   if (!memrefType || !memrefType.hasStaticShape())
1061     return parser.emitError(parser.getNameLoc())
1062            << "type should be static shaped memref, but got " << type;
1063   typeAttr = TypeAttr::get(type);
1064 
1065   if (parser.parseOptionalEqual())
1066     return success();
1067 
1068   if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1069     initialValue = UnitAttr::get(parser.getContext());
1070     return success();
1071   }
1072 
1073   Type tensorType = getTensorTypeFromMemRefType(memrefType);
1074   if (parser.parseAttribute(initialValue, tensorType))
1075     return failure();
1076   if (!initialValue.isa<ElementsAttr>())
1077     return parser.emitError(parser.getNameLoc())
1078            << "initial value should be a unit or elements attribute";
1079   return success();
1080 }
1081 
1082 LogicalResult GlobalOp::verify() {
1083   auto memrefType = type().dyn_cast<MemRefType>();
1084   if (!memrefType || !memrefType.hasStaticShape())
1085     return emitOpError("type should be static shaped memref, but got ")
1086            << type();
1087 
1088   // Verify that the initial value, if present, is either a unit attribute or
1089   // an elements attribute.
1090   if (initial_value().hasValue()) {
1091     Attribute initValue = initial_value().getValue();
1092     if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
1093       return emitOpError("initial value should be a unit or elements "
1094                          "attribute, but got ")
1095              << initValue;
1096 
1097     // Check that the type of the initial value is compatible with the type of
1098     // the global variable.
1099     if (initValue.isa<ElementsAttr>()) {
1100       Type initType = initValue.getType();
1101       Type tensorType = getTensorTypeFromMemRefType(memrefType);
1102       if (initType != tensorType)
1103         return emitOpError("initial value expected to be of type ")
1104                << tensorType << ", but was of type " << initType;
1105     }
1106   }
1107 
1108   if (Optional<uint64_t> alignAttr = alignment()) {
1109     uint64_t alignment = alignAttr.getValue();
1110 
1111     if (!llvm::isPowerOf2_64(alignment))
1112       return emitError() << "alignment attribute value " << alignment
1113                          << " is not a power of 2";
1114   }
1115 
1116   // TODO: verify visibility for declarations.
1117   return success();
1118 }
1119 
1120 //===----------------------------------------------------------------------===//
1121 // GetGlobalOp
1122 //===----------------------------------------------------------------------===//
1123 
1124 LogicalResult
1125 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1126   // Verify that the result type is the same as the type of the referenced
1127   // memref.global op.
1128   auto global =
1129       symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
1130   if (!global)
1131     return emitOpError("'")
1132            << name() << "' does not reference a valid global memref";
1133 
1134   Type resultType = result().getType();
1135   if (global.type() != resultType)
1136     return emitOpError("result type ")
1137            << resultType << " does not match type " << global.type()
1138            << " of the global memref @" << name();
1139   return success();
1140 }
1141 
1142 //===----------------------------------------------------------------------===//
1143 // LoadOp
1144 //===----------------------------------------------------------------------===//
1145 
1146 LogicalResult LoadOp::verify() {
1147   if (getNumOperands() != 1 + getMemRefType().getRank())
1148     return emitOpError("incorrect number of indices for load");
1149   return success();
1150 }
1151 
1152 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1153   /// load(memrefcast) -> load
1154   if (succeeded(foldMemRefCast(*this)))
1155     return getResult();
1156   return OpFoldResult();
1157 }
1158 
1159 //===----------------------------------------------------------------------===//
1160 // PrefetchOp
1161 //===----------------------------------------------------------------------===//
1162 
1163 void PrefetchOp::print(OpAsmPrinter &p) {
1164   p << " " << memref() << '[';
1165   p.printOperands(indices());
1166   p << ']' << ", " << (isWrite() ? "write" : "read");
1167   p << ", locality<" << localityHint();
1168   p << ">, " << (isDataCache() ? "data" : "instr");
1169   p.printOptionalAttrDict(
1170       (*this)->getAttrs(),
1171       /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1172   p << " : " << getMemRefType();
1173 }
1174 
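// Parse PrefetchOp.
// An illustrative sketch of the syntax handled below (value names are
// hypothetical):
//
//   memref.prefetch %buf[%i, %j], write, locality<3>, data : memref<16x16xf32>
//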
1175 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1176   OpAsmParser::OperandType memrefInfo;
1177   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
1178   IntegerAttr localityHint;
1179   MemRefType type;
1180   StringRef readOrWrite, cacheType;
1181 
1182   auto indexTy = parser.getBuilder().getIndexType();
1183   auto i32Type = parser.getBuilder().getIntegerType(32);
1184   if (parser.parseOperand(memrefInfo) ||
1185       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1186       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1187       parser.parseComma() || parser.parseKeyword("locality") ||
1188       parser.parseLess() ||
1189       parser.parseAttribute(localityHint, i32Type, "localityHint",
1190                             result.attributes) ||
1191       parser.parseGreater() || parser.parseComma() ||
1192       parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1193       parser.resolveOperand(memrefInfo, type, result.operands) ||
1194       parser.resolveOperands(indexInfo, indexTy, result.operands))
1195     return failure();
1196 
1197   if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1198     return parser.emitError(parser.getNameLoc(),
1199                             "rw specifier has to be 'read' or 'write'");
1200   result.addAttribute(
1201       PrefetchOp::getIsWriteAttrName(),
1202       parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1203 
1204   if (!cacheType.equals("data") && !cacheType.equals("instr"))
1205     return parser.emitError(parser.getNameLoc(),
1206                             "cache type has to be 'data' or 'instr'");
1207 
1208   result.addAttribute(
1209       PrefetchOp::getIsDataCacheAttrName(),
1210       parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1211 
1212   return success();
1213 }
1214 
1215 LogicalResult PrefetchOp::verify() {
1216   if (getNumOperands() != 1 + getMemRefType().getRank())
1217     return emitOpError("too few indices");
1218 
1219   return success();
1220 }
1221 
1222 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
1223                                SmallVectorImpl<OpFoldResult> &results) {
1224   // prefetch(memrefcast) -> prefetch
1225   return foldMemRefCast(*this);
1226 }
1227 
1228 //===----------------------------------------------------------------------===//
1229 // RankOp
1230 //===----------------------------------------------------------------------===//
1231 
1232 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
1233   // Constant fold rank when the rank of the operand is known.
1234   auto type = getOperand().getType();
1235   auto shapedType = type.dyn_cast<ShapedType>();
1236   if (shapedType && shapedType.hasRank())
1237     return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1238   return IntegerAttr();
1239 }
1240 
1241 //===----------------------------------------------------------------------===//
1242 // ReinterpretCastOp
1243 //===----------------------------------------------------------------------===//
1244 
1245 /// Build a ReinterpretCastOp with mixed static and dynamic entries: each
1246 /// `OpFoldResult` offset, size and stride is dispatched either into the
1247 /// corresponding static attribute or into the list of dynamic operands.
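///
/// For example (illustrative): with `offset` given as the attribute 0, `sizes`
/// as {%d, attribute 4} and `strides` as {attribute 4, attribute 1}, the
/// builder produces static_offsets = [0], static_sizes = [kDynamicSize, 4] and
/// static_strides = [4, 1], and passes only %d as a dynamic size operand.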
1248 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1249                               MemRefType resultType, Value source,
1250                               OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1251                               ArrayRef<OpFoldResult> strides,
1252                               ArrayRef<NamedAttribute> attrs) {
1253   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1254   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1255   dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
1256                              ShapedType::kDynamicStrideOrOffset);
1257   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1258                              ShapedType::kDynamicSize);
1259   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1260                              ShapedType::kDynamicStrideOrOffset);
1261   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1262         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1263         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1264   result.addAttributes(attrs);
1265 }
1266 
1267 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1268                               MemRefType resultType, Value source,
1269                               int64_t offset, ArrayRef<int64_t> sizes,
1270                               ArrayRef<int64_t> strides,
1271                               ArrayRef<NamedAttribute> attrs) {
1272   SmallVector<OpFoldResult> sizeValues =
1273       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1274         return b.getI64IntegerAttr(v);
1275       }));
1276   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1277       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1278         return b.getI64IntegerAttr(v);
1279       }));
1280   build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1281         strideValues, attrs);
1282 }
1283 
1284 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1285                               MemRefType resultType, Value source, Value offset,
1286                               ValueRange sizes, ValueRange strides,
1287                               ArrayRef<NamedAttribute> attrs) {
1288   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1289       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1290   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1291       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1292   build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1293 }
1294 
1295 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1296 // completed automatically, like we have for subview and extract_slice.
1297 LogicalResult ReinterpretCastOp::verify() {
1298   // The source and result memrefs should be in the same memory space.
1299   auto srcType = source().getType().cast<BaseMemRefType>();
1300   auto resultType = getType().cast<MemRefType>();
1301   if (srcType.getMemorySpace() != resultType.getMemorySpace())
1302     return emitError("different memory spaces specified for source type ")
1303            << srcType << " and result memref type " << resultType;
1304   if (srcType.getElementType() != resultType.getElementType())
1305     return emitError("different element types specified for source type ")
1306            << srcType << " and result memref type " << resultType;
1307 
1308   // Match sizes in result memref type and in static_sizes attribute.
1309   for (auto &en : llvm::enumerate(llvm::zip(
1310            resultType.getShape(), extractFromI64ArrayAttr(static_sizes())))) {
1311     int64_t resultSize = std::get<0>(en.value());
1312     int64_t expectedSize = std::get<1>(en.value());
1313     if (!ShapedType::isDynamic(resultSize) &&
1314         !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
1315       return emitError("expected result type with size = ")
1316              << expectedSize << " instead of " << resultSize
1317              << " in dim = " << en.index();
1318   }
1319 
1320   // Match offset and strides in the static_offsets and static_strides
1321   // attributes. If the result memref type has no affine map specified, this
1322   // will assume an identity layout.
1323   int64_t resultOffset;
1324   SmallVector<int64_t, 4> resultStrides;
1325   if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1326     return emitError("expected result type to have strided layout but found ")
1327            << resultType;
1328 
1329   // Match offset in result memref type and in static_offsets attribute.
1330   int64_t expectedOffset = extractFromI64ArrayAttr(static_offsets()).front();
1331   if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
1332       !ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
1333       resultOffset != expectedOffset)
1334     return emitError("expected result type with offset = ")
1335            << resultOffset << " instead of " << expectedOffset;
1336 
1337   // Match strides in result memref type and in static_strides attribute.
1338   for (auto &en : llvm::enumerate(llvm::zip(
1339            resultStrides, extractFromI64ArrayAttr(static_strides())))) {
1340     int64_t resultStride = std::get<0>(en.value());
1341     int64_t expectedStride = std::get<1>(en.value());
1342     if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
1343         !ShapedType::isDynamicStrideOrOffset(expectedStride) &&
1344         resultStride != expectedStride)
1345       return emitError("expected result type with stride = ")
1346              << expectedStride << " instead of " << resultStride
1347              << " in dim = " << en.index();
1348   }
1349 
1350   return success();
1351 }
1352 
1353 //===----------------------------------------------------------------------===//
1354 // Reassociative reshape ops
1355 //===----------------------------------------------------------------------===//
1356 
1357 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1358   return getSymbolLessAffineMaps(getReassociationExprs());
1359 }
1360 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1361   return convertReassociationIndicesToExprs(getContext(),
1362                                             getReassociationIndices());
1363 }
1364 
1365 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1366   return getSymbolLessAffineMaps(getReassociationExprs());
1367 }
1368 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1369   return convertReassociationIndicesToExprs(getContext(),
1370                                             getReassociationIndices());
1371 }
1372 
1373 /// Detect whether memref dims [dim, dim + extent) can be reshaped without
1374 /// copies.
1375 static bool isReshapableDimBand(unsigned dim, unsigned extent,
1376                                 ArrayRef<int64_t> sizes,
1377                                 ArrayRef<AffineExpr> strides) {
  // Bands of extent one are always reshapable, as no actual reshaping occurs.
1379   if (extent == 1)
1380     return true;
1381   // Otherwise, the size of the first dimension needs to be known.
1382   if (ShapedType::isDynamic(sizes[dim]))
1383     return false;
1384   assert(sizes.size() == strides.size() && "mismatched ranks");
  // Iterate with `idx + 1 < e` so that the accesses to `idx + 1` below stay in
  // bounds.
1387   for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) {
1388     // Only bands of static shapes are reshapable. This is due to the fact that
1389     // there is no relation between dynamic sizes and dynamic strides: we do not
1390     // have enough information to know whether a "-1" size corresponds to the
1391     // proper symbol in the AffineExpr of a stride.
1392     if (ShapedType::isDynamic(sizes[idx + 1]))
1393       return false;
1394     // TODO: Refine this by passing the proper nDims and nSymbols so we can
1395     // simplify on the fly and catch more reshapable cases.
1396     if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
1397       return false;
1398   }
1399   return true;
1400 }
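
// For intuition (illustrative numbers): with sizes = [4, 8] and constant
// strides [8, 1], the band [0, 2) is reshapable because
// strides[0] == strides[1] * sizes[1]. With padded rows, e.g. strides [16, 1],
// the equality fails and the band cannot be collapsed without a copy.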
1401 
1402 /// Compute the MemRefType obtained by applying the `reassociation` (which is
1403 /// expected to be valid) to `type`.
/// If `type` is a contiguous MemRefType, this always produces a contiguous
/// MemRefType.
1406 static MemRefType
1407 computeReshapeCollapsedType(MemRefType type,
1408                             ArrayRef<AffineMap> reassociation) {
1409   auto sizes = type.getShape();
1410   AffineExpr offset;
1411   SmallVector<AffineExpr, 4> strides;
1412   auto status = getStridesAndOffset(type, strides, offset);
1413   auto isIdentityLayout = type.getLayout().isIdentity();
1414   (void)status;
1415   assert(succeeded(status) && "expected strided memref");
1416 
1417   SmallVector<int64_t, 4> newSizes;
1418   newSizes.reserve(reassociation.size());
1419   SmallVector<AffineExpr, 4> newStrides;
1420   newStrides.reserve(reassociation.size());
1421 
1422   // Use the fact that reassociation is valid to simplify the logic: only use
1423   // each map's rank.
1424   assert(isReassociationValid(reassociation) && "invalid reassociation");
1425   unsigned currentDim = 0;
1426   for (AffineMap m : reassociation) {
1427     unsigned dim = m.getNumResults();
1428     int64_t size = 1;
1429     AffineExpr stride = strides[currentDim + dim - 1];
1430     if (isIdentityLayout ||
1431         isReshapableDimBand(currentDim, dim, sizes, strides)) {
1432       for (unsigned d = 0; d < dim; ++d) {
1433         int64_t currentSize = sizes[currentDim + d];
1434         if (ShapedType::isDynamic(currentSize)) {
1435           size = ShapedType::kDynamicSize;
1436           break;
1437         }
1438         size *= currentSize;
1439       }
1440     } else {
1441       size = ShapedType::kDynamicSize;
1442       stride = AffineExpr();
1443     }
1444     newSizes.push_back(size);
1445     newStrides.push_back(stride);
1446     currentDim += dim;
1447   }
1448 
1449   // Early-exit: if `type` is contiguous, the result must be contiguous.
1450   if (canonicalizeStridedLayout(type).getLayout().isIdentity())
1451     return MemRefType::Builder(type).setShape(newSizes).setLayout({});
1452 
1453   // Convert back to int64_t because we don't have enough information to create
1454   // new strided layouts from AffineExpr only. This corresponds to a case where
1455   // copies may be necessary.
1456   int64_t intOffset = ShapedType::kDynamicStrideOrOffset;
1457   if (auto o = offset.dyn_cast<AffineConstantExpr>())
1458     intOffset = o.getValue();
1459   SmallVector<int64_t, 4> intStrides;
1460   intStrides.reserve(strides.size());
1461   for (auto stride : newStrides) {
1462     if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>())
1463       intStrides.push_back(cst.getValue());
1464     else
1465       intStrides.push_back(ShapedType::kDynamicStrideOrOffset);
1466   }
1467   auto layout =
1468       makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
1469   return canonicalizeStridedLayout(
1470       MemRefType::Builder(type).setShape(newSizes).setLayout(
1471           AffineMapAttr::get(layout)));
1472 }
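
// As an illustration (a sketch, not exhaustive): collapsing a contiguous
// memref<4x8xf32> with reassociation [[0, 1]] yields memref<32xf32> via the
// early-exit above, whereas a row-padded layout with strides [16, 1] collapses
// to a dynamic size and a dynamic stride, since the band is not contiguous.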
1473 
1474 void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1475                           ArrayRef<ReassociationIndices> reassociation,
1476                           ArrayRef<NamedAttribute> attrs) {
1477   auto memRefType = src.getType().cast<MemRefType>();
1478   auto resultType = computeReshapeCollapsedType(
1479       memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1480                       b.getContext(), reassociation)));
1481   build(b, result, resultType, src, attrs);
1482   result.addAttribute(getReassociationAttrName(),
1483                       getReassociationIndicesAttribute(b, reassociation));
1484 }
1485 
1486 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1487                             ArrayRef<ReassociationIndices> reassociation,
1488                             ArrayRef<NamedAttribute> attrs) {
1489   auto memRefType = src.getType().cast<MemRefType>();
1490   auto resultType = computeReshapeCollapsedType(
1491       memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1492                       b.getContext(), reassociation)));
1493   build(b, result, resultType, src, attrs);
1494   result.addAttribute(getReassociationAttrName(),
1495                       getReassociationIndicesAttribute(b, reassociation));
1496 }
1497 
1498 template <typename ReshapeOp,
1499           bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value>
1500 static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
1501                                      MemRefType collapsedType) {
1502   if (failed(
1503           verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
1504     return failure();
1505   auto maps = op.getReassociationMaps();
1506   MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
1507   if (collapsedType != expectedType)
1508     return op.emitOpError("expected collapsed type to be ")
1509            << expectedType << ", but got " << collapsedType;
1510   return success();
1511 }
1512 
1513 LogicalResult ExpandShapeOp::verify() {
1514   return verifyReshapeOp(*this, getResultType(), getSrcType());
1515 }
1516 
1517 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1518                                                 MLIRContext *context) {
1519   results.add<CollapseReshapeOps<ExpandShapeOp>,
1520               CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
1521 }
1522 
1523 LogicalResult CollapseShapeOp::verify() {
1524   return verifyReshapeOp(*this, getSrcType(), getResultType());
1525 }
1526 
1527 struct CollapseShapeOpMemRefCastFolder
1528     : public OpRewritePattern<CollapseShapeOp> {
1529 public:
1530   using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
1531 
1532   LogicalResult matchAndRewrite(CollapseShapeOp op,
1533                                 PatternRewriter &rewriter) const override {
1534     auto cast = op.getOperand().getDefiningOp<CastOp>();
1535     if (!cast)
1536       return failure();
1537 
1538     if (!CastOp::canFoldIntoConsumerOp(cast))
1539       return failure();
1540 
1541     Type newResultType = computeReshapeCollapsedType(
1542         cast.getOperand().getType().cast<MemRefType>(),
1543         op.getReassociationMaps());
1544 
1545     if (newResultType == op.getResultType()) {
1546       rewriter.updateRootInPlace(
1547           op, [&]() { op.srcMutable().assign(cast.source()); });
1548     } else {
1549       Value newOp = rewriter.create<CollapseShapeOp>(
1550           op->getLoc(), cast.source(), op.getReassociationIndices());
1551       rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
1552     }
1553     return success();
1554   }
1555 };
1556 
1557 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1558                                                   MLIRContext *context) {
1559   results.add<CollapseReshapeOps<CollapseShapeOp>,
1560               CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
1561               CollapseShapeOpMemRefCastFolder>(context);
}

OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
1564   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
1565 }
1566 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
1567   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
1568 }
1569 
1570 //===----------------------------------------------------------------------===//
1571 // ReshapeOp
1572 //===----------------------------------------------------------------------===//
1573 
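// A well-formed example (a sketch; see the op documentation for the exact
// syntax):
//   %dst = memref.reshape %src(%shape)
//            : (memref<4x5xf32>, memref<2xindex>) -> memref<?x?xf32>
// The verifier below only checks element types, identity layouts, and that a
// statically known shape length matches the result rank.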
1574 LogicalResult ReshapeOp::verify() {
1575   Type operandType = source().getType();
1576   Type resultType = result().getType();
1577 
1578   Type operandElementType = operandType.cast<ShapedType>().getElementType();
1579   Type resultElementType = resultType.cast<ShapedType>().getElementType();
1580   if (operandElementType != resultElementType)
1581     return emitOpError("element types of source and destination memref "
1582                        "types should be the same");
1583 
1584   if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
1585     if (!operandMemRefType.getLayout().isIdentity())
1586       return emitOpError("source memref type should have identity affine map");
1587 
1588   int64_t shapeSize = shape().getType().cast<MemRefType>().getDimSize(0);
1589   auto resultMemRefType = resultType.dyn_cast<MemRefType>();
1590   if (resultMemRefType) {
1591     if (!resultMemRefType.getLayout().isIdentity())
1592       return emitOpError("result memref type should have identity affine map");
1593     if (shapeSize == ShapedType::kDynamicSize)
1594       return emitOpError("cannot use shape operand with dynamic length to "
1595                          "reshape to statically-ranked memref type");
1596     if (shapeSize != resultMemRefType.getRank())
1597       return emitOpError(
1598           "length of shape operand differs from the result's memref rank");
1599   }
1600   return success();
1601 }
1602 
1603 //===----------------------------------------------------------------------===//
1604 // StoreOp
1605 //===----------------------------------------------------------------------===//
1606 
1607 LogicalResult StoreOp::verify() {
1608   if (getNumOperands() != 2 + getMemRefType().getRank())
1609     return emitOpError("store index operand count not equal to memref rank");
1610 
1611   return success();
1612 }
1613 
1614 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
1615                             SmallVectorImpl<OpFoldResult> &results) {
1616   /// store(memrefcast) -> store
1617   return foldMemRefCast(*this, getValueToStore());
1618 }
1619 
1620 //===----------------------------------------------------------------------===//
1621 // SubViewOp
1622 //===----------------------------------------------------------------------===//
1623 
1624 namespace {
1625 /// Helpers to write more idiomatic operations.
1626 namespace saturated_arith {
1627 struct Wrapper {
1628   explicit Wrapper(int64_t v) : v(v) {}
1629   operator int64_t() { return v; }
1630   int64_t v;
1631 };
1632 Wrapper operator+(Wrapper a, int64_t b) {
1633   if (ShapedType::isDynamicStrideOrOffset(a) ||
1634       ShapedType::isDynamicStrideOrOffset(b))
1635     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1636   return Wrapper(a.v + b);
1637 }
1638 Wrapper operator*(Wrapper a, int64_t b) {
1639   if (ShapedType::isDynamicStrideOrOffset(a) ||
1640       ShapedType::isDynamicStrideOrOffset(b))
1641     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1642   return Wrapper(a.v * b);
1643 }
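
// For example, with `using namespace saturated_arith;`, `Wrapper(4) * 2`
// yields 8, while `Wrapper(ShapedType::kDynamicStrideOrOffset) * 2` saturates
// to the dynamic sentinel. This keeps the offset/stride arithmetic below
// well-defined in the presence of dynamic values.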
1644 } // namespace saturated_arith
1645 } // namespace
1646 
1647 /// A subview result type can be fully inferred from the source type and the
1648 /// static representation of offsets, sizes and strides. Special sentinels
1649 /// encode the dynamic case.
1650 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1651                                 ArrayRef<int64_t> staticOffsets,
1652                                 ArrayRef<int64_t> staticSizes,
1653                                 ArrayRef<int64_t> staticStrides) {
1654   unsigned rank = sourceMemRefType.getRank();
1655   (void)rank;
1656   assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
1657   assert(staticSizes.size() == rank && "staticSizes length mismatch");
1658   assert(staticStrides.size() == rank && "staticStrides length mismatch");
1659 
1660   // Extract source offset and strides.
1661   int64_t sourceOffset;
1662   SmallVector<int64_t, 4> sourceStrides;
1663   auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
1664   assert(succeeded(res) && "SubViewOp expected strided memref type");
1665   (void)res;
1666 
1667   // Compute target offset whose value is:
1668   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
1669   int64_t targetOffset = sourceOffset;
1670   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
1671     auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
1672     using namespace saturated_arith;
1673     targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride;
1674   }
1675 
1676   // Compute target stride whose value is:
1677   //   `sourceStrides_i * staticStrides_i`.
1678   SmallVector<int64_t, 4> targetStrides;
1679   targetStrides.reserve(staticOffsets.size());
1680   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
1681     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
1682     using namespace saturated_arith;
1683     targetStrides.push_back(Wrapper(sourceStride) * staticStride);
1684   }
1685 
1686   // The type is now known.
1687   return MemRefType::get(
1688       staticSizes, sourceMemRefType.getElementType(),
1689       makeStridedLinearLayoutMap(targetStrides, targetOffset,
1690                                  sourceMemRefType.getContext()),
1691       sourceMemRefType.getMemorySpace());
1692 }
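
// Worked example (illustrative numbers): for a source memref<8x16xf32>
// (strides [16, 1], offset 0) with staticOffsets = [2, 4],
// staticSizes = [4, 4] and staticStrides = [1, 2], the inferred type has
//   offset  = 0 + 2 * 16 + 4 * 1 = 36
//   strides = [16 * 1, 1 * 2] = [16, 2]
// i.e. a 4x4 view with offset 36 and strides [16, 2].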
1693 
1694 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1695                                 ArrayRef<OpFoldResult> offsets,
1696                                 ArrayRef<OpFoldResult> sizes,
1697                                 ArrayRef<OpFoldResult> strides) {
1698   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1699   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1700   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1701                              ShapedType::kDynamicStrideOrOffset);
1702   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1703                              ShapedType::kDynamicSize);
1704   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1705                              ShapedType::kDynamicStrideOrOffset);
1706   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1707                                     staticSizes, staticStrides);
1708 }
1709 
1710 Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
1711                                            MemRefType sourceRankedTensorType,
1712                                            ArrayRef<int64_t> offsets,
1713                                            ArrayRef<int64_t> sizes,
1714                                            ArrayRef<int64_t> strides) {
1715   auto inferredType =
1716       inferResultType(sourceRankedTensorType, offsets, sizes, strides)
1717           .cast<MemRefType>();
  assert(inferredType.getRank() >= resultRank &&
         "expected result rank to be smaller or equal to the inferred rank");
1719   int rankDiff = inferredType.getRank() - resultRank;
1720   if (rankDiff > 0) {
1721     auto shape = inferredType.getShape();
1722     llvm::SmallBitVector dimsToProject =
1723         getPositionsOfShapeOne(rankDiff, shape);
1724     SmallVector<int64_t> projectedShape;
1725     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1726       if (!dimsToProject.test(pos))
1727         projectedShape.push_back(shape[pos]);
1728 
1729     AffineMap map = inferredType.getLayout().getAffineMap();
1730     if (!map.isIdentity())
1731       map = getProjectedMap(map, dimsToProject);
1732     inferredType =
1733         MemRefType::get(projectedShape, inferredType.getElementType(), map,
1734                         inferredType.getMemorySpace());
1735   }
1736   return inferredType;
1737 }
1738 
1739 Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
1740                                            MemRefType sourceRankedTensorType,
1741                                            ArrayRef<OpFoldResult> offsets,
1742                                            ArrayRef<OpFoldResult> sizes,
1743                                            ArrayRef<OpFoldResult> strides) {
1744   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1745   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1746   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1747                              ShapedType::kDynamicStrideOrOffset);
1748   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1749                              ShapedType::kDynamicSize);
1750   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1751                              ShapedType::kDynamicStrideOrOffset);
1752   return SubViewOp::inferRankReducedResultType(
1753       resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1754       staticStrides);
}

// Build a SubViewOp with mixed static and dynamic entries and custom result
1757 // type. If the type passed is nullptr, it is inferred.
1758 void SubViewOp::build(OpBuilder &b, OperationState &result,
1759                       MemRefType resultType, Value source,
1760                       ArrayRef<OpFoldResult> offsets,
1761                       ArrayRef<OpFoldResult> sizes,
1762                       ArrayRef<OpFoldResult> strides,
1763                       ArrayRef<NamedAttribute> attrs) {
1764   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1765   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1766   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1767                              ShapedType::kDynamicStrideOrOffset);
1768   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1769                              ShapedType::kDynamicSize);
1770   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1771                              ShapedType::kDynamicStrideOrOffset);
1772   auto sourceMemRefType = source.getType().cast<MemRefType>();
1773   // Structuring implementation this way avoids duplication between builders.
1774   if (!resultType) {
1775     resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1776                                             staticSizes, staticStrides)
1777                      .cast<MemRefType>();
1778   }
1779   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1780         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1781         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1782   result.addAttributes(attrs);
1783 }
1784 
1785 // Build a SubViewOp with mixed static and dynamic entries and inferred result
1786 // type.
1787 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1788                       ArrayRef<OpFoldResult> offsets,
1789                       ArrayRef<OpFoldResult> sizes,
1790                       ArrayRef<OpFoldResult> strides,
1791                       ArrayRef<NamedAttribute> attrs) {
1792   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1793 }
1794 
1795 // Build a SubViewOp with static entries and inferred result type.
1796 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1797                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1798                       ArrayRef<int64_t> strides,
1799                       ArrayRef<NamedAttribute> attrs) {
1800   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1801       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1802         return b.getI64IntegerAttr(v);
1803       }));
1804   SmallVector<OpFoldResult> sizeValues =
1805       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1806         return b.getI64IntegerAttr(v);
1807       }));
1808   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1809       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1810         return b.getI64IntegerAttr(v);
1811       }));
1812   build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
1813 }
1814 
// Build a SubViewOp with static entries and custom result type. If the
// type passed is nullptr, it is inferred.
1817 void SubViewOp::build(OpBuilder &b, OperationState &result,
1818                       MemRefType resultType, Value source,
1819                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1820                       ArrayRef<int64_t> strides,
1821                       ArrayRef<NamedAttribute> attrs) {
1822   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1823       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1824         return b.getI64IntegerAttr(v);
1825       }));
1826   SmallVector<OpFoldResult> sizeValues =
1827       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1828         return b.getI64IntegerAttr(v);
1829       }));
1830   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1831       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1832         return b.getI64IntegerAttr(v);
1833       }));
1834   build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
1835         attrs);
1836 }
1837 
1838 // Build a SubViewOp with dynamic entries and custom result type. If the type
1839 // passed is nullptr, it is inferred.
1840 void SubViewOp::build(OpBuilder &b, OperationState &result,
1841                       MemRefType resultType, Value source, ValueRange offsets,
1842                       ValueRange sizes, ValueRange strides,
1843                       ArrayRef<NamedAttribute> attrs) {
1844   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1845       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1846   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1847       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1848   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1849       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
1851 }
1852 
1853 // Build a SubViewOp with dynamic entries and inferred result type.
1854 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1855                       ValueRange offsets, ValueRange sizes, ValueRange strides,
1856                       ArrayRef<NamedAttribute> attrs) {
1857   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1858 }
1859 
1860 /// For ViewLikeOpInterface.
1861 Value SubViewOp::getViewSource() { return source(); }
1862 
1863 /// Return true if t1 and t2 have equal offsets (both dynamic or of same static
1864 /// value).
1865 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
1866   AffineExpr t1Offset, t2Offset;
1867   SmallVector<AffineExpr> t1Strides, t2Strides;
1868   auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
1869   auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
1870   return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
1871 }
1872 
/// Checks if `originalType` can be rank-reduced to `candidateRankReducedType`.
/// This function is a slight variant of the "is subsequence" algorithm, where
/// every non-matching dimension must have static size 1.
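/// For example (illustrative), memref<1x4x1x8xf32> can be rank-reduced to
/// memref<4x8xf32>: every dimension dropped from the reduced type has static
/// size 1.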
1876 static SliceVerificationResult
1877 isRankReducedMemRefType(MemRefType originalType,
1878                         MemRefType candidateRankReducedType,
1879                         ArrayRef<OpFoldResult> sizes) {
1880   auto partialRes = isRankReducedType(originalType, candidateRankReducedType);
1881   if (partialRes != SliceVerificationResult::Success)
1882     return partialRes;
1883 
1884   auto optionalUnusedDimsMask = computeMemRefRankReductionMask(
1885       originalType, candidateRankReducedType, sizes);
1886 
  // If no mask could be computed, the sizes cannot be matched.
1888   if (!optionalUnusedDimsMask.hasValue())
1889     return SliceVerificationResult::LayoutMismatch;
1890 
1891   if (originalType.getMemorySpace() !=
1892       candidateRankReducedType.getMemorySpace())
1893     return SliceVerificationResult::MemSpaceMismatch;
1894 
1895   // No amount of stride dropping can reconcile incompatible offsets.
1896   if (!haveCompatibleOffsets(originalType, candidateRankReducedType))
1897     return SliceVerificationResult::LayoutMismatch;
1898 
1899   return SliceVerificationResult::Success;
1900 }
1901 
1902 template <typename OpTy>
1903 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
1904                                             OpTy op, Type expectedType) {
1905   auto memrefType = expectedType.cast<ShapedType>();
1906   switch (result) {
1907   case SliceVerificationResult::Success:
1908     return success();
1909   case SliceVerificationResult::RankTooLarge:
1910     return op.emitError("expected result rank to be smaller or equal to ")
1911            << "the source rank. ";
1912   case SliceVerificationResult::SizeMismatch:
1913     return op.emitError("expected result type to be ")
1914            << expectedType
1915            << " or a rank-reduced version. (mismatch of result sizes) ";
1916   case SliceVerificationResult::ElemTypeMismatch:
1917     return op.emitError("expected result element type to be ")
1918            << memrefType.getElementType();
1919   case SliceVerificationResult::MemSpaceMismatch:
1920     return op.emitError("expected result and source memory spaces to match.");
1921   case SliceVerificationResult::LayoutMismatch:
1922     return op.emitError("expected result type to be ")
1923            << expectedType
1924            << " or a rank-reduced version. (mismatch of result layout) ";
1925   }
1926   llvm_unreachable("unexpected subview verification result");
1927 }
1928 
1929 /// Verifier for SubViewOp.
1930 LogicalResult SubViewOp::verify() {
1931   MemRefType baseType = getSourceType();
1932   MemRefType subViewType = getType();
1933 
1934   // The base memref and the view memref should be in the same memory space.
1935   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
1936     return emitError("different memory spaces specified for base memref "
1937                      "type ")
1938            << baseType << " and subview memref type " << subViewType;
1939 
1940   // Verify that the base memref type has a strided layout map.
1941   if (!isStrided(baseType))
1942     return emitError("base type ") << baseType << " is not strided";
1943 
1944   // Verify result type against inferred type.
1945   auto expectedType = SubViewOp::inferResultType(
1946       baseType, extractFromI64ArrayAttr(static_offsets()),
1947       extractFromI64ArrayAttr(static_sizes()),
1948       extractFromI64ArrayAttr(static_strides()));
1949 
1950   auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
1951                                         subViewType, getMixedSizes());
1952   return produceSubViewErrorMsg(result, *this, expectedType);
1953 }
1954 
1955 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
1956   return os << "range " << range.offset << ":" << range.size << ":"
1957             << range.stride;
1958 }
1959 
1960 /// Return the list of Range (i.e. offset, size, stride). Each Range
1961 /// entry contains either the dynamic value or a ConstantIndexOp constructed
1962 /// with `b` at location `loc`.
1963 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
1964                                               OpBuilder &b, Location loc) {
1965   std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
1966   assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
1967   assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
1968   SmallVector<Range, 8> res;
1969   unsigned rank = ranks[0];
1970   res.reserve(rank);
1971   for (unsigned idx = 0; idx < rank; ++idx) {
1972     Value offset =
1973         op.isDynamicOffset(idx)
1974             ? op.getDynamicOffset(idx)
1975             : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
1976     Value size =
1977         op.isDynamicSize(idx)
1978             ? op.getDynamicSize(idx)
1979             : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
1980     Value stride =
1981         op.isDynamicStride(idx)
1982             ? op.getDynamicStride(idx)
1983             : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
1984     res.emplace_back(Range{offset, size, stride});
1985   }
1986   return res;
1987 }
1988 
1989 /// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
1990 /// deduce the result type for the given `sourceType`. Additionally, reduce the
1991 /// rank of the inferred result type if `currentResultType` is lower rank than
1992 /// `currentSourceType`. Use this signature if `sourceType` is updated together
1993 /// with the result type. In this case, it is important to compute the dropped
1994 /// dimensions using `currentSourceType` whose strides align with
1995 /// `currentResultType`.
1996 static MemRefType getCanonicalSubViewResultType(
1997     MemRefType currentResultType, MemRefType currentSourceType,
1998     MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
1999     ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
2000   auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
2001                                                        mixedSizes, mixedStrides)
2002                                 .cast<MemRefType>();
2003   llvm::Optional<llvm::SmallBitVector> unusedDims =
2004       computeMemRefRankReductionMask(currentSourceType, currentResultType,
2005                                      mixedSizes);
2006   // Return nullptr as failure mode.
2007   if (!unusedDims)
2008     return nullptr;
2009   SmallVector<int64_t> shape;
2010   for (const auto &sizes : llvm::enumerate(nonRankReducedType.getShape())) {
2011     if (unusedDims->test(sizes.index()))
2012       continue;
2013     shape.push_back(sizes.value());
2014   }
2015   AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap();
2016   if (!layoutMap.isIdentity())
2017     layoutMap = getProjectedMap(layoutMap, unusedDims.getValue());
2018   return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap,
2019                          nonRankReducedType.getMemorySpace());
2020 }
2021 
2022 /// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
2023 /// deduce the result type. Additionally, reduce the rank of the inferred result
2024 /// type if `currentResultType` is lower rank than `sourceType`.
2025 static MemRefType getCanonicalSubViewResultType(
2026     MemRefType currentResultType, MemRefType sourceType,
2027     ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
2028     ArrayRef<OpFoldResult> mixedStrides) {
2029   return getCanonicalSubViewResultType(currentResultType, sourceType,
2030                                        sourceType, mixedOffsets, mixedSizes,
2031                                        mixedStrides);
2032 }
2033 
/// Helper method to check if a `subview` operation is trivially a no-op. This
/// is the case if all offsets are zero, all strides are one, and the source
/// shape is the same as the size of the subview. In such cases, the subview
/// can be folded into its source.
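/// For example (illustrative), the following subview is trivial:
///   memref.subview %src[0, 0] [4, 4] [1, 1]
///     : memref<4x4xf32> to memref<4x4xf32>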
2038 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
2039   if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
2040     return false;
2041 
2042   auto mixedOffsets = subViewOp.getMixedOffsets();
2043   auto mixedSizes = subViewOp.getMixedSizes();
2044   auto mixedStrides = subViewOp.getMixedStrides();
2045 
2046   // Check offsets are zero.
2047   if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
2048         Optional<int64_t> intValue = getConstantIntValue(ofr);
2049         return !intValue || intValue.getValue() != 0;
2050       }))
2051     return false;
2052 
2053   // Check strides are one.
2054   if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
2055         Optional<int64_t> intValue = getConstantIntValue(ofr);
2056         return !intValue || intValue.getValue() != 1;
2057       }))
2058     return false;
2059 
  // Check that all size values are static and match the (static) source shape.
2061   ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
2062   for (const auto &size : llvm::enumerate(mixedSizes)) {
2063     Optional<int64_t> intValue = getConstantIntValue(size.value());
2064     if (!intValue || intValue.getValue() != sourceShape[size.index()])
2065       return false;
2066   }
2067   // All conditions met. The `SubViewOp` is foldable as a no-op.
2068   return true;
2069 }
2070 
2071 namespace {
2072 /// Pattern to rewrite a subview op with MemRefCast arguments.
2073 /// This essentially pushes memref.cast past its consuming subview when
2074 /// `canFoldIntoConsumerOp` is true.
2075 ///
2076 /// Example:
2077 /// ```
2078 ///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
2079 ///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
2080 ///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
2081 /// ```
2082 /// is rewritten into:
2083 /// ```
2084 ///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
2085 ///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
2086 ///     memref<3x4xf32, offset:?, strides:[?, 1]>
2087 /// ```
2088 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
2089 public:
2090   using OpRewritePattern<SubViewOp>::OpRewritePattern;
2091 
2092   LogicalResult matchAndRewrite(SubViewOp subViewOp,
2093                                 PatternRewriter &rewriter) const override {
    // If any operand is a constant, bail out and let SubViewOpConstantFolder
    // kick in first.
2095     if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
2096           return matchPattern(operand, matchConstantIndex());
2097         }))
2098       return failure();
2099 
2100     auto castOp = subViewOp.source().getDefiningOp<CastOp>();
2101     if (!castOp)
2102       return failure();
2103 
2104     if (!CastOp::canFoldIntoConsumerOp(castOp))
2105       return failure();
2106 
2107     // Compute the SubViewOp result type after folding the MemRefCastOp. Use the
2108     // MemRefCastOp source operand type to infer the result type and the current
2109     // SubViewOp source operand type to compute the dropped dimensions if the
2110     // operation is rank-reducing.
2111     auto resultType = getCanonicalSubViewResultType(
2112         subViewOp.getType(), subViewOp.getSourceType(),
2113         castOp.source().getType().cast<MemRefType>(),
2114         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
2115         subViewOp.getMixedStrides());
2116     if (!resultType)
2117       return failure();
2118 
2119     Value newSubView = rewriter.create<SubViewOp>(
2120         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
2121         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
2122         subViewOp.static_sizes(), subViewOp.static_strides());
2123     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
2124                                         newSubView);
2125     return success();
2126   }
2127 };
2128 
/// Canonicalize subview ops that are no-ops. If the source and result types
/// differ only in their layout (e.g. due to an `affine_map`), a cast is
/// inserted instead of dropping the op.
2131 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
2132 public:
2133   using OpRewritePattern<SubViewOp>::OpRewritePattern;
2134 
2135   LogicalResult matchAndRewrite(SubViewOp subViewOp,
2136                                 PatternRewriter &rewriter) const override {
2137     if (!isTrivialSubViewOp(subViewOp))
2138       return failure();
2139     if (subViewOp.getSourceType() == subViewOp.getType()) {
2140       rewriter.replaceOp(subViewOp, subViewOp.source());
2141       return success();
2142     }
2143     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
2144                                         subViewOp.source());
2145     return success();
2146   }
2147 };
2148 } // namespace
2149 
2150 /// Return the canonical type of the result of a subview.
2151 struct SubViewReturnTypeCanonicalizer {
2152   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
2153                         ArrayRef<OpFoldResult> mixedSizes,
2154                         ArrayRef<OpFoldResult> mixedStrides) {
2155     return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
2156                                          mixedOffsets, mixedSizes,
2157                                          mixedStrides);
2158   }
2159 };
2160 
2161 /// A canonicalizer wrapper to replace SubViewOps.
2162 struct SubViewCanonicalizer {
2163   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
2164     rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2165   }
2166 };
2167 
2168 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2169                                             MLIRContext *context) {
2170   results
2171       .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
2172                SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
2173            SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
2174 }
2175 
2176 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
2177   auto resultShapedType = getResult().getType().cast<ShapedType>();
2178   auto sourceShapedType = source().getType().cast<ShapedType>();
2179 
2180   if (resultShapedType.hasStaticShape() &&
2181       resultShapedType == sourceShapedType) {
2182     return getViewSource();
2183   }
2184 
2185   return {};
2186 }
2187 
2188 //===----------------------------------------------------------------------===//
2189 // TransposeOp
2190 //===----------------------------------------------------------------------===//
2191 
/// Build a strided memref type by applying `permutationMap` to `memRefType`.
2193 static MemRefType inferTransposeResultType(MemRefType memRefType,
2194                                            AffineMap permutationMap) {
2195   auto rank = memRefType.getRank();
2196   auto originalSizes = memRefType.getShape();
2197   // Compute permuted sizes.
2198   SmallVector<int64_t, 4> sizes(rank, 0);
2199   for (const auto &en : llvm::enumerate(permutationMap.getResults()))
2200     sizes[en.index()] =
2201         originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
2202 
2203   // Compute permuted strides.
2204   int64_t offset;
2205   SmallVector<int64_t, 4> strides;
2206   auto res = getStridesAndOffset(memRefType, strides, offset);
2207   assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
2208   (void)res;
2209   auto map =
2210       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
2211   map = permutationMap ? map.compose(permutationMap) : map;
2212   return MemRefType::Builder(memRefType)
2213       .setShape(sizes)
2214       .setLayout(AffineMapAttr::get(map));
2215 }
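
// For instance (illustrative), permuting memref<3x4xf32> (strides [4, 1]) with
// the map (d0, d1) -> (d1, d0) yields shape [4, 3] with strides [1, 4], i.e. a
// transposed strided view of the same buffer.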
2216 
2217 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
2218                         AffineMapAttr permutation,
2219                         ArrayRef<NamedAttribute> attrs) {
2220   auto permutationMap = permutation.getValue();
2221   assert(permutationMap);
2222 
2223   auto memRefType = in.getType().cast<MemRefType>();
2224   // Compute result type.
2225   MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
2226 
2227   build(b, result, resultType, in, attrs);
2228   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
2229 }
2230 
2231 // transpose $in $permutation attr-dict : type($in) `to` type(results)
2232 void TransposeOp::print(OpAsmPrinter &p) {
2233   p << " " << in() << " " << permutation();
2234   p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
2235   p << " : " << in().getType() << " to " << getType();
2236 }
2237 
2238 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
2239   OpAsmParser::OperandType in;
2240   AffineMap permutation;
2241   MemRefType srcType, dstType;
2242   if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
2243       parser.parseOptionalAttrDict(result.attributes) ||
2244       parser.parseColonType(srcType) ||
2245       parser.resolveOperand(in, srcType, result.operands) ||
2246       parser.parseKeywordType("to", dstType) ||
2247       parser.addTypeToList(dstType, result.types))
2248     return failure();
2249 
2250   result.addAttribute(TransposeOp::getPermutationAttrName(),
2251                       AffineMapAttr::get(permutation));
2252   return success();
2253 }
2254 
2255 LogicalResult TransposeOp::verify() {
2256   if (!permutation().isPermutation())
2257     return emitOpError("expected a permutation map");
2258   if (permutation().getNumDims() != getShapedType().getRank())
2259     return emitOpError("expected a permutation map of same rank as the input");
2260 
2261   auto srcType = in().getType().cast<MemRefType>();
2262   auto dstType = getType().cast<MemRefType>();
2263   auto transposedType = inferTransposeResultType(srcType, permutation());
2264   if (dstType != transposedType)
2265     return emitOpError("output type ")
2266            << dstType << " does not match transposed input type " << srcType
2267            << ", " << transposedType;
2268   return success();
2269 }
2270 
2271 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
2272   if (succeeded(foldMemRefCast(*this)))
2273     return getResult();
2274   return {};
2275 }
2276 
2277 //===----------------------------------------------------------------------===//
2278 // ViewOp
2279 //===----------------------------------------------------------------------===//
2280 
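// A well-formed example (a sketch, assuming the usual memref.view syntax):
//   %1 = memref.view %0[%offset][] : memref<2048xi8> to memref<64x4xf32>
// The dynamic-size list is empty because the result type is fully static.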
2281 LogicalResult ViewOp::verify() {
2282   auto baseType = getOperand(0).getType().cast<MemRefType>();
2283   auto viewType = getType();
2284 
2285   // The base memref should have identity layout map (or none).
2286   if (!baseType.getLayout().isIdentity())
2287     return emitError("unsupported map for base memref type ") << baseType;
2288 
2289   // The result memref should have identity layout map (or none).
2290   if (!viewType.getLayout().isIdentity())
2291     return emitError("unsupported map for result memref type ") << viewType;
2292 
2293   // The base memref and the view memref should be in the same memory space.
2294   if (baseType.getMemorySpace() != viewType.getMemorySpace())
2295     return emitError("different memory spaces specified for base memref "
2296                      "type ")
2297            << baseType << " and view memref type " << viewType;
2298 
2299   // Verify that we have the correct number of sizes for the result type.
2300   unsigned numDynamicDims = viewType.getNumDynamicDims();
2301   if (sizes().size() != numDynamicDims)
2302     return emitError("incorrect number of size operands for type ") << viewType;
2303 
2304   return success();
2305 }
2306 
2307 Value ViewOp::getViewSource() { return source(); }
2308 
2309 namespace {
2310 
2311 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
2312   using OpRewritePattern<ViewOp>::OpRewritePattern;
2313 
2314   LogicalResult matchAndRewrite(ViewOp viewOp,
2315                                 PatternRewriter &rewriter) const override {
2316     // Return if none of the operands are constants.
2317     if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
2318           return matchPattern(operand, matchConstantIndex());
2319         }))
2320       return failure();
2321 
2322     // Get result memref type.
2323     auto memrefType = viewOp.getType();
2324 
2325     // Get offset from old memref view type 'memRefType'.
2326     int64_t oldOffset;
2327     SmallVector<int64_t, 4> oldStrides;
2328     if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
2329       return failure();
2330     assert(oldOffset == 0 && "Expected 0 offset");
2331 
2332     SmallVector<Value, 4> newOperands;
2333 
2334     // Offset cannot be folded into result type.
2335 
2336     // Fold any dynamic dim operands which are produced by a constant.
2337     SmallVector<int64_t, 4> newShapeConstants;
2338     newShapeConstants.reserve(memrefType.getRank());
2339 
2340     unsigned dynamicDimPos = 0;
2341     unsigned rank = memrefType.getRank();
2342     for (unsigned dim = 0, e = rank; dim < e; ++dim) {
2343       int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
2345       if (!ShapedType::isDynamic(dimSize)) {
2346         newShapeConstants.push_back(dimSize);
2347         continue;
2348       }
2349       auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
2350       if (auto constantIndexOp =
2351               dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
2352         // Dynamic shape dimension will be folded.
2353         newShapeConstants.push_back(constantIndexOp.value());
2354       } else {
2355         // Dynamic shape dimension not folded; copy operand from old memref.
2356         newShapeConstants.push_back(dimSize);
2357         newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
2358       }
2359       dynamicDimPos++;
2360     }
2361 
2362     // Create new memref type with constant folded dims.
2363     MemRefType newMemRefType =
2364         MemRefType::Builder(memrefType).setShape(newShapeConstants);
2365     // Nothing new, don't fold.
2366     if (newMemRefType == memrefType)
2367       return failure();
2368 
2369     // Create new ViewOp.
2370     auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
2371                                              viewOp.getOperand(0),
2372                                              viewOp.byte_shift(), newOperands);
2373     // Insert a cast so we have the same type as the old memref type.
2374     rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
2375     return success();
2376   }
2377 };
2378 
2379 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
2380   using OpRewritePattern<ViewOp>::OpRewritePattern;
2381 
2382   LogicalResult matchAndRewrite(ViewOp viewOp,
2383                                 PatternRewriter &rewriter) const override {
2384     Value memrefOperand = viewOp.getOperand(0);
2385     CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
2386     if (!memrefCastOp)
2387       return failure();
2388     Value allocOperand = memrefCastOp.getOperand();
2389     AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
2390     if (!allocOp)
2391       return failure();
2392     rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
2393                                         viewOp.byte_shift(), viewOp.sizes());
2394     return success();
2395   }
2396 };
2397 
2398 } // namespace
2399 
2400 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2401                                          MLIRContext *context) {
2402   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
2403 }
2404 
2405 //===----------------------------------------------------------------------===//
2406 // AtomicRMWOp
2407 //===----------------------------------------------------------------------===//
2408 
2409 LogicalResult AtomicRMWOp::verify() {
2410   if (getMemRefType().getRank() != getNumOperands() - 2)
2411     return emitOpError(
2412         "expects the number of subscripts to be equal to memref rank");
2413   switch (kind()) {
2414   case arith::AtomicRMWKind::addf:
2415   case arith::AtomicRMWKind::maxf:
2416   case arith::AtomicRMWKind::minf:
2417   case arith::AtomicRMWKind::mulf:
2418     if (!value().getType().isa<FloatType>())
2419       return emitOpError() << "with kind '"
2420                            << arith::stringifyAtomicRMWKind(kind())
2421                            << "' expects a floating-point type";
2422     break;
2423   case arith::AtomicRMWKind::addi:
2424   case arith::AtomicRMWKind::maxs:
2425   case arith::AtomicRMWKind::maxu:
2426   case arith::AtomicRMWKind::mins:
2427   case arith::AtomicRMWKind::minu:
2428   case arith::AtomicRMWKind::muli:
2429   case arith::AtomicRMWKind::ori:
2430   case arith::AtomicRMWKind::andi:
2431     if (!value().getType().isa<IntegerType>())
2432       return emitOpError() << "with kind '"
2433                            << arith::stringifyAtomicRMWKind(kind())
2434                            << "' expects an integer type";
2435     break;
2436   default:
2437     break;
2438   }
2439   return success();
2440 }
2441 
2442 OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) {
2443   /// atomicrmw(memrefcast) -> atomicrmw
2444   if (succeeded(foldMemRefCast(*this, value())))
2445     return getResult();
2446   return OpFoldResult();
2447 }
2448 
2449 //===----------------------------------------------------------------------===//
2450 // TableGen'd op method definitions
2451 //===----------------------------------------------------------------------===//
2452 
2453 #define GET_OP_CLASSES
2454 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
2455