1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MemRef/IR/MemRef.h"
10 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Interfaces/InferTypeOpInterface.h"
21 #include "mlir/Interfaces/ViewLikeInterface.h"
22 #include "llvm/ADT/STLExtras.h"
23 
24 using namespace mlir;
25 using namespace mlir::memref;
26 
27 /// Materialize a single constant operation from a given attribute value with
28 /// the desired resultant type.
29 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
30                                               Attribute value, Type type,
31                                               Location loc) {
32   return builder.create<mlir::ConstantOp>(loc, type, value);
33 }
34 
35 /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
36 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
37   return llvm::to_vector<4>(
38       llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
39         return a.cast<IntegerAttr>().getInt();
40       }));
41 }
42 
43 /// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
44 /// it is a Value or into `staticVec` if it is an IntegerAttr.
45 /// In the case of a Value, a copy of the `sentinel` value is also pushed to
46 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
47 /// come from an AttrSizedOperandSegments trait.
48 static void dispatchIndexOpFoldResult(OpFoldResult ofr,
49                                       SmallVectorImpl<Value> &dynamicVec,
50                                       SmallVectorImpl<int64_t> &staticVec,
51                                       int64_t sentinel) {
52   if (auto v = ofr.dyn_cast<Value>()) {
53     dynamicVec.push_back(v);
54     staticVec.push_back(sentinel);
55     return;
56   }
57   APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
58   staticVec.push_back(apInt.getSExtValue());
59 }
60 
61 static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
62                                        SmallVectorImpl<Value> &dynamicVec,
63                                        SmallVectorImpl<int64_t> &staticVec,
64                                        int64_t sentinel) {
65   for (auto ofr : ofrs)
66     dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
67 }
68 
69 //===----------------------------------------------------------------------===//
70 // Common canonicalization pattern support logic
71 //===----------------------------------------------------------------------===//
72 
73 /// This is a common class used for patterns of the form
74 /// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
75 /// into the root operation directly.
76 static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) {
77   bool folded = false;
78   for (OpOperand &operand : op->getOpOperands()) {
79     auto cast = operand.get().getDefiningOp<CastOp>();
80     if (cast && operand.get() != inner &&
81         !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
82       operand.set(cast.getOperand());
83       folded = true;
84     }
85   }
86   return success(folded);
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // Helpers for GlobalOp
91 //===----------------------------------------------------------------------===//
92 
93 static Type getTensorTypeFromMemRefType(Type type) {
94   if (auto memref = type.dyn_cast<MemRefType>())
95     return RankedTensorType::get(memref.getShape(), memref.getElementType());
96   if (auto memref = type.dyn_cast<UnrankedMemRefType>())
97     return UnrankedTensorType::get(memref.getElementType());
98   return NoneType::get(type.getContext());
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // AllocOp / AllocaOp
103 //===----------------------------------------------------------------------===//
104 
105 template <typename AllocLikeOp>
106 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
107   static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
108                 "applies to only alloc or alloca");
109   auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
110   if (!memRefType)
111     return op.emitOpError("result must be a memref");
112 
113   if (static_cast<int64_t>(op.dynamicSizes().size()) !=
114       memRefType.getNumDynamicDims())
115     return op.emitOpError("dimension operand count does not equal memref "
116                           "dynamic dimension count");
117 
118   unsigned numSymbols = 0;
119   if (!memRefType.getAffineMaps().empty())
120     numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
121   if (op.symbolOperands().size() != numSymbols)
122     return op.emitOpError("symbol operand count does not equal memref symbol "
123                           "count: expected ")
124            << numSymbols << ", got " << op.symbolOperands().size();
125 
126   return success();
127 }
128 
129 static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }
130 
131 static LogicalResult verify(AllocaOp op) {
132   // An alloca op needs to have an ancestor with an allocation scope trait.
133   if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
134     return op.emitOpError(
135         "requires an ancestor op with AutomaticAllocationScope trait");
136 
137   return verifyAllocLikeOp(op);
138 }
139 
140 namespace {
141 /// Fold constant dimensions into an alloc like operation.
142 template <typename AllocLikeOp>
143 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
144   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
145 
146   LogicalResult matchAndRewrite(AllocLikeOp alloc,
147                                 PatternRewriter &rewriter) const override {
148     // Check to see if any dimensions operands are constants.  If so, we can
149     // substitute and drop them.
150     if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
151           return matchPattern(operand, matchConstantIndex());
152         }))
153       return failure();
154 
155     auto memrefType = alloc.getType();
156 
157     // Ok, we have one or more constant operands.  Collect the non-constant ones
158     // and keep track of the resultant memref type to build.
159     SmallVector<int64_t, 4> newShapeConstants;
160     newShapeConstants.reserve(memrefType.getRank());
161     SmallVector<Value, 4> dynamicSizes;
162 
163     unsigned dynamicDimPos = 0;
164     for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
165       int64_t dimSize = memrefType.getDimSize(dim);
166       // If this is already static dimension, keep it.
167       if (dimSize != -1) {
168         newShapeConstants.push_back(dimSize);
169         continue;
170       }
171       auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
172       auto *defOp = dynamicSize.getDefiningOp();
173       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
174         // Dynamic shape dimension will be folded.
175         newShapeConstants.push_back(constantIndexOp.getValue());
176       } else {
177         // Dynamic shape dimension not folded; copy dynamicSize from old memref.
178         newShapeConstants.push_back(-1);
179         dynamicSizes.push_back(dynamicSize);
180       }
181       dynamicDimPos++;
182     }
183 
184     // Create new memref type (which will have fewer dynamic dimensions).
185     MemRefType newMemRefType =
186         MemRefType::Builder(memrefType).setShape(newShapeConstants);
187     assert(static_cast<int64_t>(dynamicSizes.size()) ==
188            newMemRefType.getNumDynamicDims());
189 
190     // Create and insert the alloc op for the new memref.
191     auto newAlloc = rewriter.create<AllocLikeOp>(
192         alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
193         alloc.alignmentAttr());
194     // Insert a cast so we have the same type as the old alloc.
195     auto resultCast =
196         rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
197 
198     rewriter.replaceOp(alloc, {resultCast});
199     return success();
200   }
201 };
202 
203 /// Fold alloc operations with no users or only store and dealloc uses.
204 template <typename T>
205 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
206   using OpRewritePattern<T>::OpRewritePattern;
207 
208   LogicalResult matchAndRewrite(T alloc,
209                                 PatternRewriter &rewriter) const override {
210     if (llvm::any_of(alloc->getUsers(), [](Operation *op) {
211           return !isa<StoreOp, DeallocOp>(op);
212         }))
213       return failure();
214 
215     for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
216       rewriter.eraseOp(user);
217 
218     rewriter.eraseOp(alloc);
219     return success();
220   }
221 };
222 } // end anonymous namespace.
223 
224 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
225                                           MLIRContext *context) {
226   results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
227 }
228 
229 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
230                                            MLIRContext *context) {
231   results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
232       context);
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // AllocaScopeOp
237 //===----------------------------------------------------------------------===//
238 
239 static void print(OpAsmPrinter &p, AllocaScopeOp &op) {
240   bool printBlockTerminators = false;
241 
242   p << AllocaScopeOp::getOperationName() << " ";
243   if (!op.results().empty()) {
244     p << " -> (" << op.getResultTypes() << ")";
245     printBlockTerminators = true;
246   }
247   p.printRegion(op.bodyRegion(),
248                 /*printEntryBlockArgs=*/false,
249                 /*printBlockTerminators=*/printBlockTerminators);
250   p.printOptionalAttrDict(op->getAttrs());
251 }
252 
253 static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
254                                       OperationState &result) {
255   // Create a region for the body.
256   result.regions.reserve(1);
257   Region *bodyRegion = result.addRegion();
258 
259   // Parse optional results type list.
260   if (parser.parseOptionalArrowTypeList(result.types))
261     return failure();
262 
263   // Parse the body region.
264   if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
265     return failure();
266   AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
267                                   result.location);
268 
269   // Parse the optional attribute list.
270   if (parser.parseOptionalAttrDict(result.attributes))
271     return failure();
272 
273   return success();
274 }
275 
276 static LogicalResult verify(AllocaScopeOp op) {
277   if (failed(RegionBranchOpInterface::verifyTypes(op)))
278     return failure();
279 
280   return success();
281 }
282 
283 void AllocaScopeOp::getSuccessorRegions(
284     Optional<unsigned> index, ArrayRef<Attribute> operands,
285     SmallVectorImpl<RegionSuccessor> &regions) {
286   if (index.hasValue()) {
287     regions.push_back(RegionSuccessor(getResults()));
288     return;
289   }
290 
291   regions.push_back(RegionSuccessor(&bodyRegion()));
292 }
293 
294 //===----------------------------------------------------------------------===//
295 // AssumeAlignmentOp
296 //===----------------------------------------------------------------------===//
297 
298 static LogicalResult verify(AssumeAlignmentOp op) {
299   unsigned alignment = op.alignment();
300   if (!llvm::isPowerOf2_32(alignment))
301     return op.emitOpError("alignment must be power of 2");
302   return success();
303 }
304 
305 //===----------------------------------------------------------------------===//
306 // BufferCastOp
307 //===----------------------------------------------------------------------===//
308 
309 OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
310   if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
311     if (tensorLoad.memref().getType() == getType())
312       return tensorLoad.memref();
313   return {};
314 }
315 
316 namespace {
317 /// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast.
318 struct BufferCast : public OpRewritePattern<BufferCastOp> {
319   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
320 
321   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
322                                 PatternRewriter &rewriter) const final {
323     auto tensorCastOperand =
324         bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
325     if (!tensorCastOperand)
326       return failure();
327     auto srcTensorType =
328         tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
329     if (!srcTensorType)
330       return failure();
331     auto memrefType = MemRefType::get(srcTensorType.getShape(),
332                                       srcTensorType.getElementType());
333     Value memref = rewriter.create<BufferCastOp>(
334         bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
335     rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
336                                         memref);
337     return success();
338   }
339 };
340 
341 /// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
342 /// type mismatches prevent `BufferCastOp::fold` to kick in.
343 struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
344   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
345 
346   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
347                                 PatternRewriter &rewriter) const final {
348     auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
349     // Bail unless we have a tensor_load + memref.buffer_cast with different
350     // types. `BufferCastOp::fold` handles the same type case.
351     if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
352       return failure();
353     // If types are not cast-compatible, bail.
354     if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
355                                    bufferCast.getType()))
356       return failure();
357     rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
358                                         tensorLoad.memref());
359     return success();
360   }
361 };
362 
363 } // namespace
364 
365 void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
366                                                MLIRContext *context) {
367   results.add<BufferCast, TensorLoadToMemRef>(context);
368 }
369 
370 //===----------------------------------------------------------------------===//
371 // CastOp
372 //===----------------------------------------------------------------------===//
373 
374 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
375 /// source memref. This is useful to to fold a memref.cast into a consuming op
376 /// and implement canonicalization patterns for ops in different dialects that
377 /// may consume the results of memref.cast operations. Such foldable memref.cast
378 /// operations are typically inserted as `view` and `subview` ops are
379 /// canonicalized, to preserve the type compatibility of their uses.
380 ///
381 /// Returns true when all conditions are met:
382 /// 1. source and result are ranked memrefs with strided semantics and same
383 /// element type and rank.
384 /// 2. each of the source's size, offset or stride has more static information
385 /// than the corresponding result's size, offset or stride.
386 ///
387 /// Example 1:
388 /// ```mlir
389 ///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
390 ///   %2 = consumer %1 ... : memref<?x?xf32> ...
391 /// ```
392 ///
393 /// may fold into:
394 ///
395 /// ```mlir
396 ///   %2 = consumer %0 ... : memref<8x16xf32> ...
397 /// ```
398 ///
399 /// Example 2:
400 /// ```
401 ///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
402 ///          to memref<?x?xf32>
403 ///   consumer %1 : memref<?x?xf32> ...
404 /// ```
405 ///
406 /// may fold into:
407 ///
408 /// ```
409 ///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
410 /// ```
411 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
412   MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
413   MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
414 
415   // Requires ranked MemRefType.
416   if (!sourceType || !resultType)
417     return false;
418 
419   // Requires same elemental type.
420   if (sourceType.getElementType() != resultType.getElementType())
421     return false;
422 
423   // Requires same rank.
424   if (sourceType.getRank() != resultType.getRank())
425     return false;
426 
427   // Only fold casts between strided memref forms.
428   int64_t sourceOffset, resultOffset;
429   SmallVector<int64_t, 4> sourceStrides, resultStrides;
430   if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
431       failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
432     return false;
433 
434   // If cast is towards more static sizes along any dimension, don't fold.
435   for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
436     auto ss = std::get<0>(it), st = std::get<1>(it);
437     if (ss != st)
438       if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
439         return false;
440   }
441 
442   // If cast is towards more static offset along any dimension, don't fold.
443   if (sourceOffset != resultOffset)
444     if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
445         !MemRefType::isDynamicStrideOrOffset(resultOffset))
446       return false;
447 
448   // If cast is towards more static strides along any dimension, don't fold.
449   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
450     auto ss = std::get<0>(it), st = std::get<1>(it);
451     if (ss != st)
452       if (MemRefType::isDynamicStrideOrOffset(ss) &&
453           !MemRefType::isDynamicStrideOrOffset(st))
454         return false;
455   }
456 
457   return true;
458 }
459 
460 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
461   if (inputs.size() != 1 || outputs.size() != 1)
462     return false;
463   Type a = inputs.front(), b = outputs.front();
464   auto aT = a.dyn_cast<MemRefType>();
465   auto bT = b.dyn_cast<MemRefType>();
466 
467   auto uaT = a.dyn_cast<UnrankedMemRefType>();
468   auto ubT = b.dyn_cast<UnrankedMemRefType>();
469 
470   if (aT && bT) {
471     if (aT.getElementType() != bT.getElementType())
472       return false;
473     if (aT.getAffineMaps() != bT.getAffineMaps()) {
474       int64_t aOffset, bOffset;
475       SmallVector<int64_t, 4> aStrides, bStrides;
476       if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
477           failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
478           aStrides.size() != bStrides.size())
479         return false;
480 
481       // Strides along a dimension/offset are compatible if the value in the
482       // source memref is static and the value in the target memref is the
483       // same. They are also compatible if either one is dynamic (see
484       // description of MemRefCastOp for details).
485       auto checkCompatible = [](int64_t a, int64_t b) {
486         return (a == MemRefType::getDynamicStrideOrOffset() ||
487                 b == MemRefType::getDynamicStrideOrOffset() || a == b);
488       };
489       if (!checkCompatible(aOffset, bOffset))
490         return false;
491       for (auto aStride : enumerate(aStrides))
492         if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
493           return false;
494     }
495     if (aT.getMemorySpace() != bT.getMemorySpace())
496       return false;
497 
498     // They must have the same rank, and any specified dimensions must match.
499     if (aT.getRank() != bT.getRank())
500       return false;
501 
502     for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
503       int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
504       if (aDim != -1 && bDim != -1 && aDim != bDim)
505         return false;
506     }
507     return true;
508   } else {
509     if (!aT && !uaT)
510       return false;
511     if (!bT && !ubT)
512       return false;
513     // Unranked to unranked casting is unsupported
514     if (uaT && ubT)
515       return false;
516 
517     auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
518     auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
519     if (aEltType != bEltType)
520       return false;
521 
522     auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
523     auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
524     if (aMemSpace != bMemSpace)
525       return false;
526 
527     return true;
528   }
529 
530   return false;
531 }
532 
533 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
534   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
535 }
536 
537 //===----------------------------------------------------------------------===//
538 // CloneOp
539 //===----------------------------------------------------------------------===//
540 
541 void CloneOp::getEffects(
542     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
543         &effects) {
544   effects.emplace_back(MemoryEffects::Read::get(), input(),
545                        SideEffects::DefaultResource::get());
546   effects.emplace_back(MemoryEffects::Write::get(), output(),
547                        SideEffects::DefaultResource::get());
548 }
549 
550 namespace {
551 /// Fold Dealloc operations that are deallocating an AllocOp that is only used
552 /// by other Dealloc operations.
553 struct SimplifyClones : public OpRewritePattern<CloneOp> {
554   using OpRewritePattern<CloneOp>::OpRewritePattern;
555 
556   LogicalResult matchAndRewrite(CloneOp cloneOp,
557                                 PatternRewriter &rewriter) const override {
558     if (cloneOp.use_empty()) {
559       rewriter.eraseOp(cloneOp);
560       return success();
561     }
562 
563     Value source = cloneOp.input();
564 
565     // This only finds dealloc operations for the immediate value. It should
566     // also consider aliases. That would also make the safety check below
567     // redundant.
568     Operation *cloneDeallocOp = findDealloc(cloneOp.output());
569     Operation *sourceDeallocOp = findDealloc(source);
570 
571     // If both are deallocated in the same block, their in-block lifetimes
572     // might not fully overlap, so we cannot decide which one to drop.
573     if (cloneDeallocOp && sourceDeallocOp &&
574         cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
575       return failure();
576 
577     Block *currentBlock = cloneOp->getBlock();
578     Operation *redundantDealloc = nullptr;
579     if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
580       redundantDealloc = cloneDeallocOp;
581     } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
582       redundantDealloc = sourceDeallocOp;
583     }
584 
585     if (!redundantDealloc)
586       return failure();
587 
588     // Safety check that there are no other deallocations inbetween
589     // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
590     // of source before the uses of the clone. With alias information, we could
591     // restrict this to only fail of the dealloc's operand is an alias
592     // of the source.
593     for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
594          pos = pos->getNextNode()) {
595       auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
596       if (!effectInterface)
597         continue;
598       if (effectInterface.hasEffect<MemoryEffects::Free>())
599         return failure();
600     }
601 
602     rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
603                                                 source);
604     rewriter.eraseOp(redundantDealloc);
605     return success();
606   }
607 };
608 
609 } // end anonymous namespace.
610 
611 void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
612                                           MLIRContext *context) {
613   results.insert<SimplifyClones>(context);
614 }
615 
616 OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
617   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
618 }
619 
620 //===----------------------------------------------------------------------===//
621 // DeallocOp
622 //===----------------------------------------------------------------------===//
623 
624 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
625                               SmallVectorImpl<OpFoldResult> &results) {
626   /// dealloc(memrefcast) -> dealloc
627   return foldMemRefCast(*this);
628 }
629 
630 //===----------------------------------------------------------------------===//
631 // DimOp
632 //===----------------------------------------------------------------------===//
633 
634 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
635                   int64_t index) {
636   auto loc = result.location;
637   Value indexValue = builder.create<ConstantIndexOp>(loc, index);
638   build(builder, result, memref, indexValue);
639 }
640 
641 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
642                   Value index) {
643   auto indexTy = builder.getIndexType();
644   build(builder, result, indexTy, memref, index);
645 }
646 
647 Optional<int64_t> DimOp::getConstantIndex() {
648   if (auto constantOp = index().getDefiningOp<ConstantOp>())
649     return constantOp.getValue().cast<IntegerAttr>().getInt();
650   return {};
651 }
652 
653 static LogicalResult verify(DimOp op) {
654   // Assume unknown index to be in range.
655   Optional<int64_t> index = op.getConstantIndex();
656   if (!index.hasValue())
657     return success();
658 
659   // Check that constant index is not knowingly out of range.
660   auto type = op.memrefOrTensor().getType();
661   if (auto memrefType = type.dyn_cast<MemRefType>()) {
662     if (index.getValue() >= memrefType.getRank())
663       return op.emitOpError("index is out of range");
664   } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
665     if (index.getValue() >= tensorType.getRank())
666       return op.emitOpError("index is out of range");
667   } else if (type.isa<UnrankedMemRefType>() || type.isa<UnrankedTensorType>()) {
668     // Assume index to be in range.
669   } else {
670     llvm_unreachable("expected operand with memref type");
671   }
672   return success();
673 }
674 
675 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
676   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
677 
678   // All forms of folding require a known index.
679   if (!index)
680     return {};
681 
682   auto argTy = memrefOrTensor().getType();
683   // Fold if the shape extent along the given index is known.
684   if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
685     // Folding for unranked types (UnrankedMemRefType) is not supported.
686     if (!shapedTy.hasRank())
687       return {};
688     if (!shapedTy.isDynamicDim(index.getInt())) {
689       Builder builder(getContext());
690       return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]);
691     }
692   }
693 
694   Operation *definingOp = memrefOrTensor().getDefiningOp();
695 
696   // dim(memref.tensor_load(memref)) -> dim(memref)
697   if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) {
698     setOperand(0, tensorLoadOp.memref());
699     return getResult();
700   }
701 
702   // Fold dim to the operand of tensor.generate.
703   if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
704     auto resultType =
705         fromElements.getResult().getType().cast<RankedTensorType>();
706     // The case where the type encodes the size of the dimension is handled
707     // above.
708     assert(resultType.getShape()[index.getInt()] ==
709            RankedTensorType::kDynamicSize);
710 
711     // Find the operand of the fromElements that corresponds to this index.
712     auto dynExtents = fromElements.dynamicExtents().begin();
713     for (auto dim : resultType.getShape().take_front(index.getInt()))
714       if (dim == RankedTensorType::kDynamicSize)
715         dynExtents++;
716 
717     return Value{*dynExtents};
718   }
719 
720   // The size at the given index is now known to be a dynamic size.
721   unsigned unsignedIndex = index.getValue().getZExtValue();
722 
723   if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
724     assert(sliceOp.isDynamicSize(unsignedIndex) &&
725            "Expected dynamic slice size");
726     return sliceOp.getDynamicSize(unsignedIndex);
727   }
728 
729   // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
730   auto memrefType = argTy.dyn_cast<MemRefType>();
731   if (!memrefType)
732     return {};
733 
734   if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
735     return *(alloc.getDynamicSizes().begin() +
736              memrefType.getDynamicDimIndex(unsignedIndex));
737 
738   if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
739     return *(alloca.getDynamicSizes().begin() +
740              memrefType.getDynamicDimIndex(unsignedIndex));
741 
742   if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
743     return *(view.getDynamicSizes().begin() +
744              memrefType.getDynamicDimIndex(unsignedIndex));
745 
746   if (auto sizeInterface =
747           dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
748     assert(sizeInterface.isDynamicSize(unsignedIndex) &&
749            "Expected dynamic subview size");
750     return sizeInterface.getDynamicSize(unsignedIndex);
751   }
752 
753   // dim(memrefcast) -> dim
754   if (succeeded(foldMemRefCast(*this)))
755     return getResult();
756 
757   return {};
758 }
759 
760 namespace {
761 /// Fold dim of a memref reshape operation to a load into the reshape's shape
762 /// operand.
763 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
764   using OpRewritePattern<DimOp>::OpRewritePattern;
765 
766   LogicalResult matchAndRewrite(DimOp dim,
767                                 PatternRewriter &rewriter) const override {
768     auto reshape = dim.memrefOrTensor().getDefiningOp<ReshapeOp>();
769 
770     if (!reshape)
771       return failure();
772 
773     // Place the load directly after the reshape to ensure that the shape memref
774     // was not mutated.
775     rewriter.setInsertionPointAfter(reshape);
776     Location loc = dim.getLoc();
777     Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
778     if (load.getType() != dim.getType())
779       load = rewriter.create<IndexCastOp>(loc, dim.getType(), load);
780     rewriter.replaceOp(dim, load);
781     return success();
782   }
783 };
784 
785 /// Fold dim of a dim of a cast into the dim of the source of the tensor cast.
786 template <typename CastOpTy>
787 struct DimOfCastOp : public OpRewritePattern<DimOp> {
788   using OpRewritePattern<DimOp>::OpRewritePattern;
789 
790   LogicalResult matchAndRewrite(DimOp dimOp,
791                                 PatternRewriter &rewriter) const override {
792     auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>();
793     if (!castOp)
794       return failure();
795     Value newSource = castOp.getOperand();
796     rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
797     return success();
798   }
799 };
800 } // end anonymous namespace.
801 
802 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
803                                         MLIRContext *context) {
804   results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
805               DimOfCastOp<tensor::CastOp>>(context);
806 }
807 
808 // ---------------------------------------------------------------------------
809 // DmaStartOp
810 // ---------------------------------------------------------------------------
811 
812 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
813                        Value srcMemRef, ValueRange srcIndices, Value destMemRef,
814                        ValueRange destIndices, Value numElements,
815                        Value tagMemRef, ValueRange tagIndices, Value stride,
816                        Value elementsPerStride) {
817   result.addOperands(srcMemRef);
818   result.addOperands(srcIndices);
819   result.addOperands(destMemRef);
820   result.addOperands(destIndices);
821   result.addOperands({numElements, tagMemRef});
822   result.addOperands(tagIndices);
823   if (stride)
824     result.addOperands({stride, elementsPerStride});
825 }
826 
827 void DmaStartOp::print(OpAsmPrinter &p) {
828   p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices()
829     << "], " << getDstMemRef() << '[' << getDstIndices() << "], "
830     << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices()
831     << ']';
832   if (isStrided())
833     p << ", " << getStride() << ", " << getNumElementsPerStride();
834 
835   p.printOptionalAttrDict((*this)->getAttrs());
836   p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
837     << ", " << getTagMemRef().getType();
838 }
839 
840 // Parse DmaStartOp.
841 // Ex:
842 //   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
843 //                       %tag[%index], %stride, %num_elt_per_stride :
844 //                     : memref<3076 x f32, 0>,
845 //                       memref<1024 x f32, 2>,
846 //                       memref<1 x i32>
847 //
848 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
849   OpAsmParser::OperandType srcMemRefInfo;
850   SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
851   OpAsmParser::OperandType dstMemRefInfo;
852   SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
853   OpAsmParser::OperandType numElementsInfo;
854   OpAsmParser::OperandType tagMemrefInfo;
855   SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
856   SmallVector<OpAsmParser::OperandType, 2> strideInfo;
857 
858   SmallVector<Type, 3> types;
859   auto indexType = parser.getBuilder().getIndexType();
860 
861   // Parse and resolve the following list of operands:
862   // *) source memref followed by its indices (in square brackets).
863   // *) destination memref followed by its indices (in square brackets).
864   // *) dma size in KiB.
865   if (parser.parseOperand(srcMemRefInfo) ||
866       parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
867       parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
868       parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
869       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
870       parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
871       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
872     return failure();
873 
874   // Parse optional stride and elements per stride.
875   if (parser.parseTrailingOperandList(strideInfo))
876     return failure();
877 
878   bool isStrided = strideInfo.size() == 2;
879   if (!strideInfo.empty() && !isStrided) {
880     return parser.emitError(parser.getNameLoc(),
881                             "expected two stride related operands");
882   }
883 
884   if (parser.parseColonTypeList(types))
885     return failure();
886   if (types.size() != 3)
887     return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
888 
889   if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
890       parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
891       parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
892       parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
893       // size should be an index.
894       parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
895       parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
896       // tag indices should be index.
897       parser.resolveOperands(tagIndexInfos, indexType, result.operands))
898     return failure();
899 
900   if (isStrided) {
901     if (parser.resolveOperands(strideInfo, indexType, result.operands))
902       return failure();
903   }
904 
905   return success();
906 }
907 
908 LogicalResult DmaStartOp::verify() {
909   unsigned numOperands = getNumOperands();
910 
911   // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
912   // the number of elements.
913   if (numOperands < 4)
914     return emitOpError("expected at least 4 operands");
915 
916   // Check types of operands. The order of these calls is important: the later
917   // calls rely on some type properties to compute the operand position.
918   // 1. Source memref.
919   if (!getSrcMemRef().getType().isa<MemRefType>())
920     return emitOpError("expected source to be of memref type");
921   if (numOperands < getSrcMemRefRank() + 4)
922     return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
923                          << " operands";
924   if (!getSrcIndices().empty() &&
925       !llvm::all_of(getSrcIndices().getTypes(),
926                     [](Type t) { return t.isIndex(); }))
927     return emitOpError("expected source indices to be of index type");
928 
929   // 2. Destination memref.
930   if (!getDstMemRef().getType().isa<MemRefType>())
931     return emitOpError("expected destination to be of memref type");
932   unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
933   if (numOperands < numExpectedOperands)
934     return emitOpError() << "expected at least " << numExpectedOperands
935                          << " operands";
936   if (!getDstIndices().empty() &&
937       !llvm::all_of(getDstIndices().getTypes(),
938                     [](Type t) { return t.isIndex(); }))
939     return emitOpError("expected destination indices to be of index type");
940 
941   // 3. Number of elements.
942   if (!getNumElements().getType().isIndex())
943     return emitOpError("expected num elements to be of index type");
944 
945   // 4. Tag memref.
946   if (!getTagMemRef().getType().isa<MemRefType>())
947     return emitOpError("expected tag to be of memref type");
948   numExpectedOperands += getTagMemRefRank();
949   if (numOperands < numExpectedOperands)
950     return emitOpError() << "expected at least " << numExpectedOperands
951                          << " operands";
952   if (!getTagIndices().empty() &&
953       !llvm::all_of(getTagIndices().getTypes(),
954                     [](Type t) { return t.isIndex(); }))
955     return emitOpError("expected tag indices to be of index type");
956 
957   // Optional stride-related operands must be either both present or both
958   // absent.
959   if (numOperands != numExpectedOperands &&
960       numOperands != numExpectedOperands + 2)
961     return emitOpError("incorrect number of operands");
962 
963   // 5. Strides.
964   if (isStrided()) {
965     if (!getStride().getType().isIndex() ||
966         !getNumElementsPerStride().getType().isIndex())
967       return emitOpError(
968           "expected stride and num elements per stride to be of type index");
969   }
970 
971   return success();
972 }
973 
974 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
975                                SmallVectorImpl<OpFoldResult> &results) {
976   /// dma_start(memrefcast) -> dma_start
977   return foldMemRefCast(*this);
978 }
979 
980 // ---------------------------------------------------------------------------
981 // DmaWaitOp
982 // ---------------------------------------------------------------------------
983 
984 void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
985                       Value tagMemRef, ValueRange tagIndices,
986                       Value numElements) {
987   result.addOperands(tagMemRef);
988   result.addOperands(tagIndices);
989   result.addOperands(numElements);
990 }
991 
992 void DmaWaitOp::print(OpAsmPrinter &p) {
993   p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices()
994     << "], " << getNumElements();
995   p.printOptionalAttrDict((*this)->getAttrs());
996   p << " : " << getTagMemRef().getType();
997 }
998 
999 // Parse DmaWaitOp.
1000 // Eg:
1001 //   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
1002 //
1003 ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
1004   OpAsmParser::OperandType tagMemrefInfo;
1005   SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
1006   Type type;
1007   auto indexType = parser.getBuilder().getIndexType();
1008   OpAsmParser::OperandType numElementsInfo;
1009 
1010   // Parse tag memref, its indices, and dma size.
1011   if (parser.parseOperand(tagMemrefInfo) ||
1012       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
1013       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1014       parser.parseColonType(type) ||
1015       parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
1016       parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
1017       parser.resolveOperand(numElementsInfo, indexType, result.operands))
1018     return failure();
1019 
1020   return success();
1021 }
1022 
1023 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1024                               SmallVectorImpl<OpFoldResult> &results) {
1025   /// dma_wait(memrefcast) -> dma_wait
1026   return foldMemRefCast(*this);
1027 }
1028 
1029 LogicalResult DmaWaitOp::verify() {
1030   // Mandatory non-variadic operands are tag and the number of elements.
1031   if (getNumOperands() < 2)
1032     return emitOpError() << "expected at least 2 operands";
1033 
1034   // Check types of operands. The order of these calls is important: the later
1035   // calls rely on some type properties to compute the operand position.
1036   if (!getTagMemRef().getType().isa<MemRefType>())
1037     return emitOpError() << "expected tag to be of memref type";
1038 
1039   if (getNumOperands() != 2 + getTagMemRefRank())
1040     return emitOpError() << "expected " << 2 + getTagMemRefRank()
1041                          << " operands";
1042 
1043   if (!getTagIndices().empty() &&
1044       !llvm::all_of(getTagIndices().getTypes(),
1045                     [](Type t) { return t.isIndex(); }))
1046     return emitOpError() << "expected tag indices to be of index type";
1047 
1048   if (!getNumElements().getType().isIndex())
1049     return emitOpError()
1050            << "expected the number of elements to be of index type";
1051 
1052   return success();
1053 }
1054 
1055 //===----------------------------------------------------------------------===//
1056 // GlobalOp
1057 //===----------------------------------------------------------------------===//
1058 
1059 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1060                                                    TypeAttr type,
1061                                                    Attribute initialValue) {
1062   p << type;
1063   if (!op.isExternal()) {
1064     p << " = ";
1065     if (op.isUninitialized())
1066       p << "uninitialized";
1067     else
1068       p.printAttributeWithoutType(initialValue);
1069   }
1070 }
1071 
1072 static ParseResult
1073 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1074                                        Attribute &initialValue) {
1075   Type type;
1076   if (parser.parseType(type))
1077     return failure();
1078 
1079   auto memrefType = type.dyn_cast<MemRefType>();
1080   if (!memrefType || !memrefType.hasStaticShape())
1081     return parser.emitError(parser.getNameLoc())
1082            << "type should be static shaped memref, but got " << type;
1083   typeAttr = TypeAttr::get(type);
1084 
1085   if (parser.parseOptionalEqual())
1086     return success();
1087 
1088   if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1089     initialValue = UnitAttr::get(parser.getBuilder().getContext());
1090     return success();
1091   }
1092 
1093   Type tensorType = getTensorTypeFromMemRefType(memrefType);
1094   if (parser.parseAttribute(initialValue, tensorType))
1095     return failure();
1096   if (!initialValue.isa<ElementsAttr>())
1097     return parser.emitError(parser.getNameLoc())
1098            << "initial value should be a unit or elements attribute";
1099   return success();
1100 }
1101 
1102 static LogicalResult verify(GlobalOp op) {
1103   auto memrefType = op.type().dyn_cast<MemRefType>();
1104   if (!memrefType || !memrefType.hasStaticShape())
1105     return op.emitOpError("type should be static shaped memref, but got ")
1106            << op.type();
1107 
1108   // Verify that the initial value, if present, is either a unit attribute or
1109   // an elements attribute.
1110   if (op.initial_value().hasValue()) {
1111     Attribute initValue = op.initial_value().getValue();
1112     if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
1113       return op.emitOpError("initial value should be a unit or elements "
1114                             "attribute, but got ")
1115              << initValue;
1116 
1117     // Check that the type of the initial value is compatible with the type of
1118     // the global variable.
1119     if (initValue.isa<ElementsAttr>()) {
1120       Type initType = initValue.getType();
1121       Type tensorType = getTensorTypeFromMemRefType(memrefType);
1122       if (initType != tensorType)
1123         return op.emitOpError("initial value expected to be of type ")
1124                << tensorType << ", but was of type " << initType;
1125     }
1126   }
1127 
1128   // TODO: verify visibility for declarations.
1129   return success();
1130 }
1131 
1132 //===----------------------------------------------------------------------===//
1133 // GetGlobalOp
1134 //===----------------------------------------------------------------------===//
1135 
1136 LogicalResult
1137 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1138   // Verify that the result type is same as the type of the referenced
1139   // memref.global op.
1140   auto global =
1141       symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
1142   if (!global)
1143     return emitOpError("'")
1144            << name() << "' does not reference a valid global memref";
1145 
1146   Type resultType = result().getType();
1147   if (global.type() != resultType)
1148     return emitOpError("result type ")
1149            << resultType << " does not match type " << global.type()
1150            << " of the global memref @" << name();
1151   return success();
1152 }
1153 
1154 //===----------------------------------------------------------------------===//
1155 // LoadOp
1156 //===----------------------------------------------------------------------===//
1157 
1158 static LogicalResult verify(LoadOp op) {
1159   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1160     return op.emitOpError("incorrect number of indices for load");
1161   return success();
1162 }
1163 
1164 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1165   /// load(memrefcast) -> load
1166   if (succeeded(foldMemRefCast(*this)))
1167     return getResult();
1168   return OpFoldResult();
1169 }
1170 
1171 namespace {
1172 /// Fold a load on a buffer_cast operation into an tensor.extract on the
1173 /// corresponding tensor.
1174 struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
1175   using OpRewritePattern<LoadOp>::OpRewritePattern;
1176 
1177   LogicalResult matchAndRewrite(LoadOp load,
1178                                 PatternRewriter &rewriter) const override {
1179     auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
1180     if (!buffercast)
1181       return failure();
1182 
1183     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
1184                                                    load.indices());
1185     return success();
1186   }
1187 };
1188 } // end anonymous namespace.
1189 
/// Registers the canonicalization patterns of LoadOp into `results`.
void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  // load(buffer_cast(tensor)) -> tensor.extract(tensor).
  results.add<LoadOfBufferCast>(context);
}
1194 
1195 //===----------------------------------------------------------------------===//
1196 // PrefetchOp
1197 //===----------------------------------------------------------------------===//
1198 
1199 static void print(OpAsmPrinter &p, PrefetchOp op) {
1200   p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
1201   p.printOperands(op.indices());
1202   p << ']' << ", " << (op.isWrite() ? "write" : "read");
1203   p << ", locality<" << op.localityHint();
1204   p << ">, " << (op.isDataCache() ? "data" : "instr");
1205   p.printOptionalAttrDict(
1206       op->getAttrs(),
1207       /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1208   p << " : " << op.getMemRefType();
1209 }
1210 
1211 static ParseResult parsePrefetchOp(OpAsmParser &parser,
1212                                    OperationState &result) {
1213   OpAsmParser::OperandType memrefInfo;
1214   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
1215   IntegerAttr localityHint;
1216   MemRefType type;
1217   StringRef readOrWrite, cacheType;
1218 
1219   auto indexTy = parser.getBuilder().getIndexType();
1220   auto i32Type = parser.getBuilder().getIntegerType(32);
1221   if (parser.parseOperand(memrefInfo) ||
1222       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1223       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1224       parser.parseComma() || parser.parseKeyword("locality") ||
1225       parser.parseLess() ||
1226       parser.parseAttribute(localityHint, i32Type, "localityHint",
1227                             result.attributes) ||
1228       parser.parseGreater() || parser.parseComma() ||
1229       parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1230       parser.resolveOperand(memrefInfo, type, result.operands) ||
1231       parser.resolveOperands(indexInfo, indexTy, result.operands))
1232     return failure();
1233 
1234   if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1235     return parser.emitError(parser.getNameLoc(),
1236                             "rw specifier has to be 'read' or 'write'");
1237   result.addAttribute(
1238       PrefetchOp::getIsWriteAttrName(),
1239       parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1240 
1241   if (!cacheType.equals("data") && !cacheType.equals("instr"))
1242     return parser.emitError(parser.getNameLoc(),
1243                             "cache type has to be 'data' or 'instr'");
1244 
1245   result.addAttribute(
1246       PrefetchOp::getIsDataCacheAttrName(),
1247       parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1248 
1249   return success();
1250 }
1251 
1252 static LogicalResult verify(PrefetchOp op) {
1253   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1254     return op.emitOpError("too few indices");
1255 
1256   return success();
1257 }
1258 
1259 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
1260                                SmallVectorImpl<OpFoldResult> &results) {
1261   // prefetch(memrefcast) -> prefetch
1262   return foldMemRefCast(*this);
1263 }
1264 
1265 //===----------------------------------------------------------------------===//
1266 // ReinterpretCastOp
1267 //===----------------------------------------------------------------------===//
1268 
1269 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1270 /// `staticSizes` and `staticStrides` are automatically filled with
1271 /// source-memref-rank sentinel values that encode dynamic entries.
1272 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1273                               MemRefType resultType, Value source,
1274                               OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1275                               ArrayRef<OpFoldResult> strides,
1276                               ArrayRef<NamedAttribute> attrs) {
1277   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1278   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1279   dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
1280                              ShapedType::kDynamicStrideOrOffset);
1281   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1282                              ShapedType::kDynamicSize);
1283   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1284                              ShapedType::kDynamicStrideOrOffset);
1285   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1286         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1287         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1288   result.addAttributes(attrs);
1289 }
1290 
1291 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1292                               MemRefType resultType, Value source,
1293                               int64_t offset, ArrayRef<int64_t> sizes,
1294                               ArrayRef<int64_t> strides,
1295                               ArrayRef<NamedAttribute> attrs) {
1296   SmallVector<OpFoldResult> sizeValues =
1297       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1298         return b.getI64IntegerAttr(v);
1299       }));
1300   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1301       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1302         return b.getI64IntegerAttr(v);
1303       }));
1304   build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1305         strideValues, attrs);
1306 }
1307 
1308 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1309                               MemRefType resultType, Value source, Value offset,
1310                               ValueRange sizes, ValueRange strides,
1311                               ArrayRef<NamedAttribute> attrs) {
1312   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1313       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1314   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1315       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1316   build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1317 }
1318 
1319 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1320 // completed automatically, like we have for subview and extract_slice.
1321 static LogicalResult verify(ReinterpretCastOp op) {
1322   // The source and result memrefs should be in the same memory space.
1323   auto srcType = op.source().getType().cast<BaseMemRefType>();
1324   auto resultType = op.getType().cast<MemRefType>();
1325   if (srcType.getMemorySpace() != resultType.getMemorySpace())
1326     return op.emitError("different memory spaces specified for source type ")
1327            << srcType << " and result memref type " << resultType;
1328   if (srcType.getElementType() != resultType.getElementType())
1329     return op.emitError("different element types specified for source type ")
1330            << srcType << " and result memref type " << resultType;
1331 
1332   // Match sizes in result memref type and in static_sizes attribute.
1333   for (auto &en :
1334        llvm::enumerate(llvm::zip(resultType.getShape(),
1335                                  extractFromI64ArrayAttr(op.static_sizes())))) {
1336     int64_t resultSize = std::get<0>(en.value());
1337     int64_t expectedSize = std::get<1>(en.value());
1338     if (resultSize != expectedSize)
1339       return op.emitError("expected result type with size = ")
1340              << expectedSize << " instead of " << resultSize
1341              << " in dim = " << en.index();
1342   }
1343 
1344   // Match offset and strides in static_offset and static_strides attributes if
1345   // result memref type has an affine map specified.
1346   if (!resultType.getAffineMaps().empty()) {
1347     int64_t resultOffset;
1348     SmallVector<int64_t, 4> resultStrides;
1349     if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1350       return failure();
1351 
1352     // Match offset in result memref type and in static_offsets attribute.
1353     int64_t expectedOffset =
1354         extractFromI64ArrayAttr(op.static_offsets()).front();
1355     if (resultOffset != expectedOffset)
1356       return op.emitError("expected result type with offset = ")
1357              << resultOffset << " instead of " << expectedOffset;
1358 
1359     // Match strides in result memref type and in static_strides attribute.
1360     for (auto &en : llvm::enumerate(llvm::zip(
1361              resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
1362       int64_t resultStride = std::get<0>(en.value());
1363       int64_t expectedStride = std::get<1>(en.value());
1364       if (resultStride != expectedStride)
1365         return op.emitError("expected result type with stride = ")
1366                << expectedStride << " instead of " << resultStride
1367                << " in dim = " << en.index();
1368     }
1369   }
1370   return success();
1371 }
1372 
1373 //===----------------------------------------------------------------------===//
1374 // ReshapeOp
1375 //===----------------------------------------------------------------------===//
1376 
1377 static LogicalResult verify(ReshapeOp op) {
1378   Type operandType = op.source().getType();
1379   Type resultType = op.result().getType();
1380 
1381   Type operandElementType = operandType.cast<ShapedType>().getElementType();
1382   Type resultElementType = resultType.cast<ShapedType>().getElementType();
1383   if (operandElementType != resultElementType)
1384     return op.emitOpError("element types of source and destination memref "
1385                           "types should be the same");
1386 
1387   if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
1388     if (!operandMemRefType.getAffineMaps().empty())
1389       return op.emitOpError(
1390           "source memref type should have identity affine map");
1391 
1392   int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
1393   auto resultMemRefType = resultType.dyn_cast<MemRefType>();
1394   if (resultMemRefType) {
1395     if (!resultMemRefType.getAffineMaps().empty())
1396       return op.emitOpError(
1397           "result memref type should have identity affine map");
1398     if (shapeSize == ShapedType::kDynamicSize)
1399       return op.emitOpError("cannot use shape operand with dynamic length to "
1400                             "reshape to statically-ranked memref type");
1401     if (shapeSize != resultMemRefType.getRank())
1402       return op.emitOpError(
1403           "length of shape operand differs from the result's memref rank");
1404   }
1405   return success();
1406 }
1407 
1408 //===----------------------------------------------------------------------===//
1409 // StoreOp
1410 //===----------------------------------------------------------------------===//
1411 
1412 static LogicalResult verify(StoreOp op) {
1413   if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
1414     return op.emitOpError("store index operand count not equal to memref rank");
1415 
1416   return success();
1417 }
1418 
1419 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
1420                             SmallVectorImpl<OpFoldResult> &results) {
1421   /// store(memrefcast) -> store
1422   return foldMemRefCast(*this, getValueToStore());
1423 }
1424 
1425 //===----------------------------------------------------------------------===//
1426 // SubViewOp
1427 //===----------------------------------------------------------------------===//
1428 
1429 namespace {
1430 /// Helpers to write more idiomatic operations.
1431 namespace saturated_arith {
1432 struct Wrapper {
1433   explicit Wrapper(int64_t v) : v(v) {}
1434   operator int64_t() { return v; }
1435   int64_t v;
1436 };
1437 Wrapper operator+(Wrapper a, int64_t b) {
1438   if (ShapedType::isDynamicStrideOrOffset(a) ||
1439       ShapedType::isDynamicStrideOrOffset(b))
1440     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1441   return Wrapper(a.v + b);
1442 }
1443 Wrapper operator*(Wrapper a, int64_t b) {
1444   if (ShapedType::isDynamicStrideOrOffset(a) ||
1445       ShapedType::isDynamicStrideOrOffset(b))
1446     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1447   return Wrapper(a.v * b);
1448 }
1449 } // end namespace saturated_arith
1450 } // end namespace
1451 
1452 /// A subview result type can be fully inferred from the source type and the
1453 /// static representation of offsets, sizes and strides. Special sentinels
1454 /// encode the dynamic case.
1455 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1456                                 ArrayRef<int64_t> leadingStaticOffsets,
1457                                 ArrayRef<int64_t> leadingStaticSizes,
1458                                 ArrayRef<int64_t> leadingStaticStrides) {
1459   // A subview may specify only a leading subset of offset/sizes/strides in
1460   // which case we complete with offset=0, sizes from memref type and strides=1.
1461   unsigned rank = sourceMemRefType.getRank();
1462   assert(leadingStaticOffsets.size() <= rank &&
1463          "unexpected leadingStaticOffsets overflow");
1464   assert(leadingStaticSizes.size() <= rank &&
1465          "unexpected leadingStaticSizes overflow");
1466   assert(leadingStaticStrides.size() <= rank &&
1467          "unexpected leadingStaticStrides overflow");
1468   auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
1469   auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
1470   auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
1471   unsigned numTrailingOffsets = rank - staticOffsets.size();
1472   unsigned numTrailingSizes = rank - staticSizes.size();
1473   unsigned numTrailingStrides = rank - staticStrides.size();
1474   staticOffsets.append(numTrailingOffsets, 0);
1475   llvm::append_range(staticSizes,
1476                      sourceMemRefType.getShape().take_back(numTrailingSizes));
1477   staticStrides.append(numTrailingStrides, 1);
1478 
1479   // Extract source offset and strides.
1480   int64_t sourceOffset;
1481   SmallVector<int64_t, 4> sourceStrides;
1482   auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
1483   assert(succeeded(res) && "SubViewOp expected strided memref type");
1484   (void)res;
1485 
1486   // Compute target offset whose value is:
1487   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
1488   int64_t targetOffset = sourceOffset;
1489   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
1490     auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
1491     using namespace saturated_arith;
1492     targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride;
1493   }
1494 
1495   // Compute target stride whose value is:
1496   //   `sourceStrides_i * staticStrides_i`.
1497   SmallVector<int64_t, 4> targetStrides;
1498   targetStrides.reserve(staticOffsets.size());
1499   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
1500     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
1501     using namespace saturated_arith;
1502     targetStrides.push_back(Wrapper(sourceStride) * staticStride);
1503   }
1504 
1505   // The type is now known.
1506   return MemRefType::get(
1507       staticSizes, sourceMemRefType.getElementType(),
1508       makeStridedLinearLayoutMap(targetStrides, targetOffset,
1509                                  sourceMemRefType.getContext()),
1510       sourceMemRefType.getMemorySpace());
1511 }
1512 
1513 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1514                                 ArrayRef<OpFoldResult> leadingStaticOffsets,
1515                                 ArrayRef<OpFoldResult> leadingStaticSizes,
1516                                 ArrayRef<OpFoldResult> leadingStaticStrides) {
1517   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1518   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1519   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1520                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1521   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1522                              ShapedType::kDynamicSize);
1523   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1524                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1525   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1526                                     staticSizes, staticStrides)
1527       .cast<MemRefType>();
1528 }
1529 
1530 Type SubViewOp::inferRankReducedResultType(
1531     unsigned resultRank, MemRefType sourceRankedTensorType,
1532     ArrayRef<int64_t> leadingStaticOffsets,
1533     ArrayRef<int64_t> leadingStaticSizes,
1534     ArrayRef<int64_t> leadingStaticStrides) {
1535   auto inferredType =
1536       inferResultType(sourceRankedTensorType, leadingStaticOffsets,
1537                       leadingStaticSizes, leadingStaticStrides)
1538           .cast<MemRefType>();
1539   assert(inferredType.getRank() >= resultRank && "expected ");
1540   int rankDiff = inferredType.getRank() - resultRank;
1541   if (rankDiff > 0) {
1542     auto shape = inferredType.getShape();
1543     llvm::SmallDenseSet<unsigned> dimsToProject;
1544     mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
1545     SmallVector<int64_t> projectedShape;
1546     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1547       if (!dimsToProject.contains(pos))
1548         projectedShape.push_back(shape[pos]);
1549 
1550     AffineMap map;
1551     auto maps = inferredType.getAffineMaps();
1552     if (!maps.empty() && maps.front())
1553       map = getProjectedMap(maps.front(), dimsToProject);
1554     inferredType =
1555         MemRefType::get(projectedShape, inferredType.getElementType(), map,
1556                         inferredType.getMemorySpace());
1557   }
1558   return inferredType;
1559 }
1560 
1561 Type SubViewOp::inferRankReducedResultType(
1562     unsigned resultRank, MemRefType sourceRankedTensorType,
1563     ArrayRef<OpFoldResult> leadingStaticOffsets,
1564     ArrayRef<OpFoldResult> leadingStaticSizes,
1565     ArrayRef<OpFoldResult> leadingStaticStrides) {
1566   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1567   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1568   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1569                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1570   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1571                              ShapedType::kDynamicSize);
1572   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1573                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1574   return SubViewOp::inferRankReducedResultType(
1575       resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1576       staticStrides);
1577 }
1578 // Build a SubViewOp with mixed static and dynamic entries and custom result
1579 // type. If the type passed is nullptr, it is inferred.
1580 void SubViewOp::build(OpBuilder &b, OperationState &result,
1581                       MemRefType resultType, Value source,
1582                       ArrayRef<OpFoldResult> offsets,
1583                       ArrayRef<OpFoldResult> sizes,
1584                       ArrayRef<OpFoldResult> strides,
1585                       ArrayRef<NamedAttribute> attrs) {
1586   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1587   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1588   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1589                              ShapedType::kDynamicStrideOrOffset);
1590   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1591                              ShapedType::kDynamicSize);
1592   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1593                              ShapedType::kDynamicStrideOrOffset);
1594   auto sourceMemRefType = source.getType().cast<MemRefType>();
1595   // Structuring implementation this way avoids duplication between builders.
1596   if (!resultType) {
1597     resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1598                                             staticSizes, staticStrides)
1599                      .cast<MemRefType>();
1600   }
1601   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1602         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1603         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1604   result.addAttributes(attrs);
1605 }
1606 
1607 // Build a SubViewOp with mixed static and dynamic entries and inferred result
1608 // type.
1609 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1610                       ArrayRef<OpFoldResult> offsets,
1611                       ArrayRef<OpFoldResult> sizes,
1612                       ArrayRef<OpFoldResult> strides,
1613                       ArrayRef<NamedAttribute> attrs) {
1614   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1615 }
1616 
1617 // Build a SubViewOp with static entries and inferred result type.
1618 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1619                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1620                       ArrayRef<int64_t> strides,
1621                       ArrayRef<NamedAttribute> attrs) {
1622   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1623       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1624         return b.getI64IntegerAttr(v);
1625       }));
1626   SmallVector<OpFoldResult> sizeValues =
1627       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1628         return b.getI64IntegerAttr(v);
1629       }));
1630   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1631       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1632         return b.getI64IntegerAttr(v);
1633       }));
1634   build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
1635 }
1636 
1637 // Build a SubViewOp with dynamic entries and custom result type. If the
1638 // type passed is nullptr, it is inferred.
1639 void SubViewOp::build(OpBuilder &b, OperationState &result,
1640                       MemRefType resultType, Value source,
1641                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1642                       ArrayRef<int64_t> strides,
1643                       ArrayRef<NamedAttribute> attrs) {
1644   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1645       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1646         return b.getI64IntegerAttr(v);
1647       }));
1648   SmallVector<OpFoldResult> sizeValues =
1649       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1650         return b.getI64IntegerAttr(v);
1651       }));
1652   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1653       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1654         return b.getI64IntegerAttr(v);
1655       }));
1656   build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
1657         attrs);
1658 }
1659 
1660 // Build a SubViewOp with dynamic entries and custom result type. If the type
1661 // passed is nullptr, it is inferred.
1662 void SubViewOp::build(OpBuilder &b, OperationState &result,
1663                       MemRefType resultType, Value source, ValueRange offsets,
1664                       ValueRange sizes, ValueRange strides,
1665                       ArrayRef<NamedAttribute> attrs) {
1666   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1667       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1668   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1669       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1670   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1671       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1672   build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
1673 }
1674 
1675 // Build a SubViewOp with dynamic entries and inferred result type.
1676 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1677                       ValueRange offsets, ValueRange sizes, ValueRange strides,
1678                       ArrayRef<NamedAttribute> attrs) {
1679   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1680 }
1681 
1682 /// For ViewLikeOpInterface.
1683 Value SubViewOp::getViewSource() { return source(); }
1684 
/// Outcome of checking a candidate subview result type against the type
/// inferred from the source. Every value other than `Success` names the
/// property that failed to match.
enum SubViewVerificationResult {
  /// The candidate type is accepted.
  Success,
  /// The candidate type has a larger rank than the original type.
  RankTooLarge,
  /// No rank-reduction mask maps the original shape onto the candidate shape.
  SizeMismatch,
  /// The element types differ.
  ElemTypeMismatch,
  /// The memory spaces differ.
  MemSpaceMismatch,
  /// The layout maps differ after projecting out the dropped dimensions.
  AffineMapMismatch
};
1693 
1694 /// Checks if `original` Type type can be rank reduced to `reduced` type.
1695 /// This function is slight variant of `is subsequence` algorithm where
1696 /// not matching dimension must be 1.
1697 static SubViewVerificationResult
1698 isRankReducedType(Type originalType, Type candidateReducedType,
1699                   std::string *errMsg = nullptr) {
1700   if (originalType == candidateReducedType)
1701     return SubViewVerificationResult::Success;
1702   if (!originalType.isa<MemRefType>())
1703     return SubViewVerificationResult::Success;
1704   if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
1705     return SubViewVerificationResult::Success;
1706 
1707   ShapedType originalShapedType = originalType.cast<ShapedType>();
1708   ShapedType candidateReducedShapedType =
1709       candidateReducedType.cast<ShapedType>();
1710 
1711   // Rank and size logic is valid for all ShapedTypes.
1712   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
1713   ArrayRef<int64_t> candidateReducedShape =
1714       candidateReducedShapedType.getShape();
1715   unsigned originalRank = originalShape.size(),
1716            candidateReducedRank = candidateReducedShape.size();
1717   if (candidateReducedRank > originalRank)
1718     return SubViewVerificationResult::RankTooLarge;
1719 
1720   auto optionalUnusedDimsMask =
1721       computeRankReductionMask(originalShape, candidateReducedShape);
1722 
1723   // Sizes cannot be matched in case empty vector is returned.
1724   if (!optionalUnusedDimsMask.hasValue())
1725     return SubViewVerificationResult::SizeMismatch;
1726 
1727   if (originalShapedType.getElementType() !=
1728       candidateReducedShapedType.getElementType())
1729     return SubViewVerificationResult::ElemTypeMismatch;
1730 
1731   // Strided layout logic is relevant for MemRefType only.
1732   MemRefType original = originalType.cast<MemRefType>();
1733   MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
1734   if (original.getMemorySpace() != candidateReduced.getMemorySpace())
1735     return SubViewVerificationResult::MemSpaceMismatch;
1736 
1737   llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
1738   auto inferredType =
1739       getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
1740   AffineMap candidateLayout;
1741   if (candidateReduced.getAffineMaps().empty())
1742     candidateLayout = getStridedLinearLayoutMap(candidateReduced);
1743   else
1744     candidateLayout = candidateReduced.getAffineMaps().front();
1745   assert(inferredType.getNumResults() == 1 &&
1746          candidateLayout.getNumResults() == 1);
1747   if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
1748       inferredType.getNumDims() != candidateLayout.getNumDims()) {
1749     if (errMsg) {
1750       llvm::raw_string_ostream os(*errMsg);
1751       os << "inferred type: " << inferredType;
1752     }
1753     return SubViewVerificationResult::AffineMapMismatch;
1754   }
1755   // Check that the difference of the affine maps simplifies to 0.
1756   AffineExpr diffExpr =
1757       inferredType.getResult(0) - candidateLayout.getResult(0);
1758   diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
1759                                 inferredType.getNumSymbols());
1760   auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
1761   if (!(cst && cst.getValue() == 0)) {
1762     if (errMsg) {
1763       llvm::raw_string_ostream os(*errMsg);
1764       os << "inferred type: " << inferredType;
1765     }
1766     return SubViewVerificationResult::AffineMapMismatch;
1767   }
1768   return SubViewVerificationResult::Success;
1769 }
1770 
/// Translate a `SubViewVerificationResult` into a `LogicalResult`, emitting a
/// diagnostic on `op` for every value other than `Success`. `expectedType` is
/// the type that was expected for the result (it must be a ShapedType) and
/// `errMsg` is appended verbatim to the diagnostic.
template <typename OpTy>
static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
                                            OpTy op, Type expectedType,
                                            StringRef errMsg = "") {
  // Only used to report the expected element type.
  auto memrefType = expectedType.cast<ShapedType>();
  switch (result) {
  case SubViewVerificationResult::Success:
    return success();
  case SubViewVerificationResult::RankTooLarge:
    return op.emitError("expected result rank to be smaller or equal to ")
           << "the source rank. " << errMsg;
  case SubViewVerificationResult::SizeMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) "
           << errMsg;
  case SubViewVerificationResult::ElemTypeMismatch:
    return op.emitError("expected result element type to be ")
           << memrefType.getElementType() << errMsg;
  case SubViewVerificationResult::MemSpaceMismatch:
    return op.emitError("expected result and source memory spaces to match.")
           << errMsg;
  case SubViewVerificationResult::AffineMapMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result affine map) "
           << errMsg;
  }
  // Every enumerator returns above.
  llvm_unreachable("unexpected subview verification result");
}
1801 
1802 /// Verifier for SubViewOp.
1803 static LogicalResult verify(SubViewOp op) {
1804   MemRefType baseType = op.getSourceType();
1805   MemRefType subViewType = op.getType();
1806 
1807   // The base memref and the view memref should be in the same memory space.
1808   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
1809     return op.emitError("different memory spaces specified for base memref "
1810                         "type ")
1811            << baseType << " and subview memref type " << subViewType;
1812 
1813   // Verify that the base memref type has a strided layout map.
1814   if (!isStrided(baseType))
1815     return op.emitError("base type ") << baseType << " is not strided";
1816 
1817   // Verify result type against inferred type.
1818   auto expectedType = SubViewOp::inferResultType(
1819       baseType, extractFromI64ArrayAttr(op.static_offsets()),
1820       extractFromI64ArrayAttr(op.static_sizes()),
1821       extractFromI64ArrayAttr(op.static_strides()));
1822 
1823   std::string errMsg;
1824   auto result = isRankReducedType(expectedType, subViewType, &errMsg);
1825   return produceSubViewErrorMsg(result, op, expectedType, errMsg);
1826 }
1827 
1828 raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
1829   return os << "range " << range.offset << ":" << range.size << ":"
1830             << range.stride;
1831 }
1832 
1833 /// Return the list of Range (i.e. offset, size, stride). Each Range
1834 /// entry contains either the dynamic value or a ConstantIndexOp constructed
1835 /// with `b` at location `loc`.
1836 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
1837                                               OpBuilder &b, Location loc) {
1838   std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
1839   assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
1840   assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
1841   SmallVector<Range, 8> res;
1842   unsigned rank = ranks[0];
1843   res.reserve(rank);
1844   for (unsigned idx = 0; idx < rank; ++idx) {
1845     Value offset =
1846         op.isDynamicOffset(idx)
1847             ? op.getDynamicOffset(idx)
1848             : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
1849     Value size = op.isDynamicSize(idx)
1850                      ? op.getDynamicSize(idx)
1851                      : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
1852     Value stride =
1853         op.isDynamicStride(idx)
1854             ? op.getDynamicStride(idx)
1855             : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
1856     res.emplace_back(Range{offset, size, stride});
1857   }
1858   return res;
1859 }
1860 
1861 /// Infer the canonical type of the result of a subview operation. Returns a
1862 /// type with rank `resultRank` that is either the rank of the rank-reduced
1863 /// type, or the non-rank-reduced type.
1864 static MemRefType
1865 getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
1866                               ArrayRef<OpFoldResult> mixedOffsets,
1867                               ArrayRef<OpFoldResult> mixedSizes,
1868                               ArrayRef<OpFoldResult> mixedStrides) {
1869   auto resultType =
1870       SubViewOp::inferRankReducedResultType(
1871           resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
1872           .cast<MemRefType>();
1873   if (resultType.getRank() != resultRank) {
1874     resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
1875                                             mixedSizes, mixedStrides)
1876                      .cast<MemRefType>();
1877   }
1878   return resultType;
1879 }
1880 
1881 namespace {
1882 /// Pattern to rewrite a subview op with MemRefCast arguments.
1883 /// This essentially pushes memref.cast past its consuming subview when
1884 /// `canFoldIntoConsumerOp` is true.
1885 ///
1886 /// Example:
1887 /// ```
1888 ///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
1889 ///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
1890 ///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
1891 /// ```
1892 /// is rewritten into:
1893 /// ```
1894 ///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
1895 ///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
1896 ///     memref<3x4xf32, offset:?, strides:[?, 1]>
1897 /// ```
1898 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
1899 public:
1900   using OpRewritePattern<SubViewOp>::OpRewritePattern;
1901 
1902   LogicalResult matchAndRewrite(SubViewOp subViewOp,
1903                                 PatternRewriter &rewriter) const override {
1904     // Any constant operand, just return to let SubViewOpConstantFolder kick in.
1905     if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
1906           return matchPattern(operand, matchConstantIndex());
1907         }))
1908       return failure();
1909 
1910     auto castOp = subViewOp.source().getDefiningOp<CastOp>();
1911     if (!castOp)
1912       return failure();
1913 
1914     if (!CastOp::canFoldIntoConsumerOp(castOp))
1915       return failure();
1916 
1917     /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
1918     /// the cast source operand type and the SubViewOp static information. This
1919     /// is the resulting type if the MemRefCastOp were folded.
1920     auto resultType = getCanonicalSubViewResultType(
1921         subViewOp.getType().getRank(),
1922         castOp.source().getType().cast<MemRefType>(),
1923         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
1924         subViewOp.getMixedStrides());
1925     Value newSubView = rewriter.create<SubViewOp>(
1926         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
1927         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
1928         subViewOp.static_sizes(), subViewOp.static_strides());
1929     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
1930                                         newSubView);
1931     return success();
1932   }
1933 };
1934 } // namespace
1935 
1936 /// Return the canonical type of the result of a subview.
1937 struct SubViewReturnTypeCanonicalizer {
1938   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
1939                         ArrayRef<OpFoldResult> mixedSizes,
1940                         ArrayRef<OpFoldResult> mixedStrides) {
1941     return getCanonicalSubViewResultType(op.getType().getRank(),
1942                                          op.getSourceType(), mixedOffsets,
1943                                          mixedSizes, mixedStrides);
1944   }
1945 };
1946 
1947 /// A canonicalizer wrapper to replace SubViewOps.
1948 struct SubViewCanonicalizer {
1949   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
1950     rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
1951   }
1952 };
1953 
1954 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
1955                                             MLIRContext *context) {
1956   results
1957       .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
1958                SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
1959            SubViewOpMemRefCastFolder>(context);
1960 }
1961 
1962 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
1963   auto resultShapedType = getResult().getType().cast<ShapedType>();
1964   auto sourceShapedType = source().getType().cast<ShapedType>();
1965 
1966   if (resultShapedType.hasStaticShape() &&
1967       resultShapedType == sourceShapedType) {
1968     return getViewSource();
1969   }
1970 
1971   return {};
1972 }
1973 
1974 //===----------------------------------------------------------------------===//
1975 // TensorLoadOp
1976 //===----------------------------------------------------------------------===//
1977 
1978 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
1979   if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
1980     // Approximate alias analysis by conservatively folding only when no there
1981     // is no interleaved operation.
1982     if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
1983         bufferCast->getNextNode() == this->getOperation())
1984       return bufferCast.tensor();
1985   return {};
1986 }
1987 
1988 //===----------------------------------------------------------------------===//
1989 // TransposeOp
1990 //===----------------------------------------------------------------------===//
1991 
1992 /// Build a strided memref type by applying `permutationMap` tp `memRefType`.
1993 static MemRefType inferTransposeResultType(MemRefType memRefType,
1994                                            AffineMap permutationMap) {
1995   auto rank = memRefType.getRank();
1996   auto originalSizes = memRefType.getShape();
1997   // Compute permuted sizes.
1998   SmallVector<int64_t, 4> sizes(rank, 0);
1999   for (auto en : llvm::enumerate(permutationMap.getResults()))
2000     sizes[en.index()] =
2001         originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
2002 
2003   // Compute permuted strides.
2004   int64_t offset;
2005   SmallVector<int64_t, 4> strides;
2006   auto res = getStridesAndOffset(memRefType, strides, offset);
2007   assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
2008   (void)res;
2009   auto map =
2010       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
2011   map = permutationMap ? map.compose(permutationMap) : map;
2012   return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
2013 }
2014 
2015 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
2016                         AffineMapAttr permutation,
2017                         ArrayRef<NamedAttribute> attrs) {
2018   auto permutationMap = permutation.getValue();
2019   assert(permutationMap);
2020 
2021   auto memRefType = in.getType().cast<MemRefType>();
2022   // Compute result type.
2023   MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
2024 
2025   build(b, result, resultType, in, attrs);
2026   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
2027 }
2028 
2029 // transpose $in $permutation attr-dict : type($in) `to` type(results)
2030 static void print(OpAsmPrinter &p, TransposeOp op) {
2031   p << "memref.transpose " << op.in() << " " << op.permutation();
2032   p.printOptionalAttrDict(op->getAttrs(),
2033                           {TransposeOp::getPermutationAttrName()});
2034   p << " : " << op.in().getType() << " to " << op.getType();
2035 }
2036 
2037 static ParseResult parseTransposeOp(OpAsmParser &parser,
2038                                     OperationState &result) {
2039   OpAsmParser::OperandType in;
2040   AffineMap permutation;
2041   MemRefType srcType, dstType;
2042   if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
2043       parser.parseOptionalAttrDict(result.attributes) ||
2044       parser.parseColonType(srcType) ||
2045       parser.resolveOperand(in, srcType, result.operands) ||
2046       parser.parseKeywordType("to", dstType) ||
2047       parser.addTypeToList(dstType, result.types))
2048     return failure();
2049 
2050   result.addAttribute(TransposeOp::getPermutationAttrName(),
2051                       AffineMapAttr::get(permutation));
2052   return success();
2053 }
2054 
2055 static LogicalResult verify(TransposeOp op) {
2056   if (!op.permutation().isPermutation())
2057     return op.emitOpError("expected a permutation map");
2058   if (op.permutation().getNumDims() != op.getShapedType().getRank())
2059     return op.emitOpError(
2060         "expected a permutation map of same rank as the input");
2061 
2062   auto srcType = op.in().getType().cast<MemRefType>();
2063   auto dstType = op.getType().cast<MemRefType>();
2064   auto transposedType = inferTransposeResultType(srcType, op.permutation());
2065   if (dstType != transposedType)
2066     return op.emitOpError("output type ")
2067            << dstType << " does not match transposed input type " << srcType
2068            << ", " << transposedType;
2069   return success();
2070 }
2071 
2072 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
2073   if (succeeded(foldMemRefCast(*this)))
2074     return getResult();
2075   return {};
2076 }
2077 
2078 //===----------------------------------------------------------------------===//
2079 // ViewOp
2080 //===----------------------------------------------------------------------===//
2081 
2082 static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
2083   OpAsmParser::OperandType srcInfo;
2084   SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
2085   SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
2086   auto indexType = parser.getBuilder().getIndexType();
2087   Type srcType, dstType;
2088   llvm::SMLoc offsetLoc;
2089   if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
2090       parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
2091     return failure();
2092 
2093   if (offsetInfo.size() != 1)
2094     return parser.emitError(offsetLoc) << "expects 1 offset operand";
2095 
2096   return failure(
2097       parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
2098       parser.parseOptionalAttrDict(result.attributes) ||
2099       parser.parseColonType(srcType) ||
2100       parser.resolveOperand(srcInfo, srcType, result.operands) ||
2101       parser.resolveOperands(offsetInfo, indexType, result.operands) ||
2102       parser.resolveOperands(sizesInfo, indexType, result.operands) ||
2103       parser.parseKeywordType("to", dstType) ||
2104       parser.addTypeToList(dstType, result.types));
2105 }
2106 
2107 static void print(OpAsmPrinter &p, ViewOp op) {
2108   p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
2109   p.printOperand(op.byte_shift());
2110   p << "][" << op.sizes() << ']';
2111   p.printOptionalAttrDict(op->getAttrs());
2112   p << " : " << op.getOperand(0).getType() << " to " << op.getType();
2113 }
2114 
2115 static LogicalResult verify(ViewOp op) {
2116   auto baseType = op.getOperand(0).getType().cast<MemRefType>();
2117   auto viewType = op.getType();
2118 
2119   // The base memref should have identity layout map (or none).
2120   if (baseType.getAffineMaps().size() > 1 ||
2121       (baseType.getAffineMaps().size() == 1 &&
2122        !baseType.getAffineMaps()[0].isIdentity()))
2123     return op.emitError("unsupported map for base memref type ") << baseType;
2124 
2125   // The result memref should have identity layout map (or none).
2126   if (viewType.getAffineMaps().size() > 1 ||
2127       (viewType.getAffineMaps().size() == 1 &&
2128        !viewType.getAffineMaps()[0].isIdentity()))
2129     return op.emitError("unsupported map for result memref type ") << viewType;
2130 
2131   // The base memref and the view memref should be in the same memory space.
2132   if (baseType.getMemorySpace() != viewType.getMemorySpace())
2133     return op.emitError("different memory spaces specified for base memref "
2134                         "type ")
2135            << baseType << " and view memref type " << viewType;
2136 
2137   // Verify that we have the correct number of sizes for the result type.
2138   unsigned numDynamicDims = viewType.getNumDynamicDims();
2139   if (op.sizes().size() != numDynamicDims)
2140     return op.emitError("incorrect number of size operands for type ")
2141            << viewType;
2142 
2143   return success();
2144 }
2145 
2146 Value ViewOp::getViewSource() { return source(); }
2147 
2148 namespace {
2149 
2150 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
2151   using OpRewritePattern<ViewOp>::OpRewritePattern;
2152 
2153   LogicalResult matchAndRewrite(ViewOp viewOp,
2154                                 PatternRewriter &rewriter) const override {
2155     // Return if none of the operands are constants.
2156     if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
2157           return matchPattern(operand, matchConstantIndex());
2158         }))
2159       return failure();
2160 
2161     // Get result memref type.
2162     auto memrefType = viewOp.getType();
2163 
2164     // Get offset from old memref view type 'memRefType'.
2165     int64_t oldOffset;
2166     SmallVector<int64_t, 4> oldStrides;
2167     if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
2168       return failure();
2169     assert(oldOffset == 0 && "Expected 0 offset");
2170 
2171     SmallVector<Value, 4> newOperands;
2172 
2173     // Offset cannot be folded into result type.
2174 
2175     // Fold any dynamic dim operands which are produced by a constant.
2176     SmallVector<int64_t, 4> newShapeConstants;
2177     newShapeConstants.reserve(memrefType.getRank());
2178 
2179     unsigned dynamicDimPos = 0;
2180     unsigned rank = memrefType.getRank();
2181     for (unsigned dim = 0, e = rank; dim < e; ++dim) {
2182       int64_t dimSize = memrefType.getDimSize(dim);
2183       // If this is already static dimension, keep it.
2184       if (!ShapedType::isDynamic(dimSize)) {
2185         newShapeConstants.push_back(dimSize);
2186         continue;
2187       }
2188       auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
2189       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
2190         // Dynamic shape dimension will be folded.
2191         newShapeConstants.push_back(constantIndexOp.getValue());
2192       } else {
2193         // Dynamic shape dimension not folded; copy operand from old memref.
2194         newShapeConstants.push_back(dimSize);
2195         newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
2196       }
2197       dynamicDimPos++;
2198     }
2199 
2200     // Create new memref type with constant folded dims.
2201     MemRefType newMemRefType =
2202         MemRefType::Builder(memrefType).setShape(newShapeConstants);
2203     // Nothing new, don't fold.
2204     if (newMemRefType == memrefType)
2205       return failure();
2206 
2207     // Create new ViewOp.
2208     auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
2209                                              viewOp.getOperand(0),
2210                                              viewOp.byte_shift(), newOperands);
2211     // Insert a cast so we have the same type as the old memref type.
2212     rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
2213     return success();
2214   }
2215 };
2216 
2217 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
2218   using OpRewritePattern<ViewOp>::OpRewritePattern;
2219 
2220   LogicalResult matchAndRewrite(ViewOp viewOp,
2221                                 PatternRewriter &rewriter) const override {
2222     Value memrefOperand = viewOp.getOperand(0);
2223     CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
2224     if (!memrefCastOp)
2225       return failure();
2226     Value allocOperand = memrefCastOp.getOperand();
2227     AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
2228     if (!allocOp)
2229       return failure();
2230     rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
2231                                         viewOp.byte_shift(), viewOp.sizes());
2232     return success();
2233   }
2234 };
2235 
2236 } // end anonymous namespace
2237 
2238 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2239                                          MLIRContext *context) {
2240   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
2241 }
2242 
2243 //===----------------------------------------------------------------------===//
2244 // TableGen'd op method definitions
2245 //===----------------------------------------------------------------------===//
2246 
2247 #define GET_OP_CLASSES
2248 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
2249