1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MemRef/IR/MemRef.h"
10 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Interfaces/InferTypeOpInterface.h"
21 #include "llvm/ADT/STLExtras.h"
22 
23 using namespace mlir;
24 using namespace mlir::memref;
25 
26 /// Materialize a single constant operation from a given attribute value with
27 /// the desired resultant type.
28 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
29                                               Attribute value, Type type,
30                                               Location loc) {
31   return builder.create<mlir::ConstantOp>(loc, type, value);
32 }
33 
34 /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
35 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
36   return llvm::to_vector<4>(
37       llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
38         return a.cast<IntegerAttr>().getInt();
39       }));
40 }
41 
42 /// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
43 /// it is a Value or into `staticVec` if it is an IntegerAttr.
44 /// In the case of a Value, a copy of the `sentinel` value is also pushed to
45 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
46 /// come from an AttrSizedOperandSegments trait.
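/// For example (illustrative values), dispatching the list `{%d0, 4, %d1}`
/// with a sentinel of -1 appends `%d0` and `%d1` to `dynamicVec` and
/// `{-1, 4, -1}` to `staticVec`.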
47 static void dispatchIndexOpFoldResult(OpFoldResult ofr,
48                                       SmallVectorImpl<Value> &dynamicVec,
49                                       SmallVectorImpl<int64_t> &staticVec,
50                                       int64_t sentinel) {
51   if (auto v = ofr.dyn_cast<Value>()) {
52     dynamicVec.push_back(v);
53     staticVec.push_back(sentinel);
54     return;
55   }
56   APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
57   staticVec.push_back(apInt.getSExtValue());
58 }
59 
60 static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
61                                        SmallVectorImpl<Value> &dynamicVec,
62                                        SmallVectorImpl<int64_t> &staticVec,
63                                        int64_t sentinel) {
64   for (auto ofr : ofrs)
65     dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // Common canonicalization pattern support logic
70 //===----------------------------------------------------------------------===//
71 
/// This is a common utility used for patterns of the form
/// "someop(memref.cast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
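///
/// For example (an illustrative sketch, assuming `someop` accepts both types):
///
/// ```mlir
///   %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
///   "someop"(%0) : (memref<?xf32>) -> ()
/// ```
///
/// is folded so that `someop` consumes `%arg` directly.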
75 static LogicalResult foldMemRefCast(Operation *op) {
76   bool folded = false;
77   for (OpOperand &operand : op->getOpOperands()) {
78     auto cast = operand.get().getDefiningOp<CastOp>();
79     if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
80       operand.set(cast.getOperand());
81       folded = true;
82     }
83   }
84   return success(folded);
85 }
86 
87 //===----------------------------------------------------------------------===//
88 // Helpers for GlobalOp
89 //===----------------------------------------------------------------------===//
90 
91 static Type getTensorTypeFromMemRefType(Type type) {
92   if (auto memref = type.dyn_cast<MemRefType>())
93     return RankedTensorType::get(memref.getShape(), memref.getElementType());
94   if (auto memref = type.dyn_cast<UnrankedMemRefType>())
95     return UnrankedTensorType::get(memref.getElementType());
96   return NoneType::get(type.getContext());
97 }
98 
99 //===----------------------------------------------------------------------===//
100 // AllocOp / AllocaOp
101 //===----------------------------------------------------------------------===//
102 
103 template <typename AllocLikeOp>
104 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
105   static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
106                 "applies to only alloc or alloca");
107   auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
108   if (!memRefType)
109     return op.emitOpError("result must be a memref");
110 
111   if (static_cast<int64_t>(op.dynamicSizes().size()) !=
112       memRefType.getNumDynamicDims())
113     return op.emitOpError("dimension operand count does not equal memref "
114                           "dynamic dimension count");
115 
116   unsigned numSymbols = 0;
117   if (!memRefType.getAffineMaps().empty())
118     numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
119   if (op.symbolOperands().size() != numSymbols)
120     return op.emitOpError(
121         "symbol operand count does not equal memref symbol count");
122 
123   return success();
124 }
125 
126 static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }
127 
128 static LogicalResult verify(AllocaOp op) {
129   // An alloca op needs to have an ancestor with an allocation scope trait.
130   if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
131     return op.emitOpError(
132         "requires an ancestor op with AutomaticAllocationScope trait");
133 
134   return verifyAllocLikeOp(op);
135 }
136 
137 namespace {
/// Fold constant dimension operands into an alloc-like operation.
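///
/// A sketch of the rewrite (illustrative IR, not taken from a test):
///
/// ```mlir
///   %c8 = constant 8 : index
///   %0 = memref.alloc(%c8) : memref<?xf32>
/// ```
///
/// becomes:
///
/// ```mlir
///   %0 = memref.alloc() : memref<8xf32>
///   %1 = memref.cast %0 : memref<8xf32> to memref<?xf32>
/// ```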
139 template <typename AllocLikeOp>
140 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
141   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
142 
143   LogicalResult matchAndRewrite(AllocLikeOp alloc,
144                                 PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants.  If so, we can
    // substitute and drop them.
147     if (llvm::none_of(alloc.getOperands(), [](Value operand) {
148           return matchPattern(operand, matchConstantIndex());
149         }))
150       return failure();
151 
152     auto memrefType = alloc.getType();
153 
154     // Ok, we have one or more constant operands.  Collect the non-constant ones
155     // and keep track of the resultant memref type to build.
156     SmallVector<int64_t, 4> newShapeConstants;
157     newShapeConstants.reserve(memrefType.getRank());
158     SmallVector<Value, 4> newOperands;
159 
160     unsigned dynamicDimPos = 0;
161     for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
162       int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
164       if (dimSize != -1) {
165         newShapeConstants.push_back(dimSize);
166         continue;
167       }
168       auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
169       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
170         // Dynamic shape dimension will be folded.
171         newShapeConstants.push_back(constantIndexOp.getValue());
172       } else {
173         // Dynamic shape dimension not folded; copy operand from old memref.
174         newShapeConstants.push_back(-1);
175         newOperands.push_back(alloc.getOperand(dynamicDimPos));
176       }
177       dynamicDimPos++;
178     }
179 
180     // Create new memref type (which will have fewer dynamic dimensions).
181     MemRefType newMemRefType =
182         MemRefType::Builder(memrefType).setShape(newShapeConstants);
183     assert(static_cast<int64_t>(newOperands.size()) ==
184            newMemRefType.getNumDynamicDims());
185 
186     // Create and insert the alloc op for the new memref.
187     auto newAlloc = rewriter.create<AllocLikeOp>(
188         alloc.getLoc(), newMemRefType, newOperands, alloc.alignmentAttr());
189     // Insert a cast so we have the same type as the old alloc.
190     auto resultCast =
191         rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
192 
193     rewriter.replaceOp(alloc, {resultCast});
194     return success();
195   }
196 };
197 
/// Fold alloc operations that have no uses or whose only uses are store and
/// dealloc operations.
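///
/// For example (illustrative IR), the following alloc, whose only uses are a
/// store and a dealloc, can be erased together with its uses:
///
/// ```mlir
///   %0 = memref.alloc() : memref<4xf32>
///   memref.store %v, %0[%i] : memref<4xf32>
///   memref.dealloc %0 : memref<4xf32>
/// ```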
199 template <typename T>
200 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
201   using OpRewritePattern<T>::OpRewritePattern;
202 
203   LogicalResult matchAndRewrite(T alloc,
204                                 PatternRewriter &rewriter) const override {
205     if (llvm::any_of(alloc->getUsers(), [](Operation *op) {
206           return !isa<StoreOp, DeallocOp>(op);
207         }))
208       return failure();
209 
210     for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
211       rewriter.eraseOp(user);
212 
213     rewriter.eraseOp(alloc);
214     return success();
215   }
216 };
217 } // end anonymous namespace.
218 
219 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
220                                           MLIRContext *context) {
221   results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
222 }
223 
224 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
225                                            MLIRContext *context) {
226   results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
227       context);
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // AssumeAlignmentOp
232 //===----------------------------------------------------------------------===//
233 
234 static LogicalResult verify(AssumeAlignmentOp op) {
235   unsigned alignment = op.alignment();
236   if (!llvm::isPowerOf2_32(alignment))
237     return op.emitOpError("alignment must be power of 2");
238   return success();
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // BufferCastOp
243 //===----------------------------------------------------------------------===//
244 
245 OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
246   if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
247     if (tensorLoad.memref().getType() == getType())
248       return tensorLoad.memref();
249   return {};
250 }
251 
252 namespace {
/// Replace tensor.cast + memref.buffer_cast by memref.buffer_cast +
/// memref.cast.
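///
/// A sketch of the rewrite (illustrative IR):
///
/// ```mlir
///   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %1 = memref.buffer_cast %0 : memref<?xf32>
/// ```
///
/// becomes:
///
/// ```mlir
///   %0 = memref.buffer_cast %t : memref<4xf32>
///   %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
/// ```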
254 struct BufferCast : public OpRewritePattern<BufferCastOp> {
255   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
256 
257   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
258                                 PatternRewriter &rewriter) const final {
259     auto tensorCastOperand =
260         bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
261     if (!tensorCastOperand)
262       return failure();
263     auto srcTensorType =
264         tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
265     if (!srcTensorType)
266       return failure();
267     auto memrefType = MemRefType::get(srcTensorType.getShape(),
268                                       srcTensorType.getElementType());
269     Value memref = rewriter.create<BufferCastOp>(
270         bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
271     rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
272                                         memref);
273     return success();
274   }
275 };
276 
/// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
/// a type mismatch prevents `BufferCastOp::fold` from kicking in.
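///
/// For example (illustrative IR with mismatching but cast-compatible types):
///
/// ```mlir
///   %0 = memref.tensor_load %m : memref<?xf32>
///   %1 = memref.buffer_cast %0 : memref<4xf32>
/// ```
///
/// folds to:
///
/// ```mlir
///   %1 = memref.cast %m : memref<?xf32> to memref<4xf32>
/// ```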
279 struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
280   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
281 
282   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
283                                 PatternRewriter &rewriter) const final {
284     auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
285     // Bail unless we have a tensor_load + memref.buffer_cast with different
286     // types. `BufferCastOp::fold` handles the same type case.
287     if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
288       return failure();
289     // If types are not cast-compatible, bail.
290     if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
291                                    bufferCast.getType()))
292       return failure();
293     rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
294                                         tensorLoad.memref());
295     return success();
296   }
297 };
298 
299 } // namespace
300 
301 void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
302                                                MLIRContext *context) {
303   results.add<BufferCast, TensorLoadToMemRef>(context);
304 }
305 
306 //===----------------------------------------------------------------------===//
307 // CastOp
308 //===----------------------------------------------------------------------===//
309 
310 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
312 /// and implement canonicalization patterns for ops in different dialects that
313 /// may consume the results of memref.cast operations. Such foldable memref.cast
314 /// operations are typically inserted as `view` and `subview` ops are
315 /// canonicalized, to preserve the type compatibility of their uses.
316 ///
317 /// Returns true when all conditions are met:
318 /// 1. source and result are ranked memrefs with strided semantics and same
319 /// element type and rank.
/// 2. each of the source's size, offset or stride has at least as much static
/// information as the corresponding result's size, offset or stride.
322 ///
323 /// Example 1:
324 /// ```mlir
325 ///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
326 ///   %2 = consumer %1 ... : memref<?x?xf32> ...
327 /// ```
328 ///
329 /// may fold into:
330 ///
331 /// ```mlir
332 ///   %2 = consumer %0 ... : memref<8x16xf32> ...
333 /// ```
334 ///
335 /// Example 2:
/// ```mlir
337 ///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
338 ///          to memref<?x?xf32>
339 ///   consumer %1 : memref<?x?xf32> ...
340 /// ```
341 ///
342 /// may fold into:
343 ///
/// ```mlir
345 ///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
346 /// ```
347 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
348   MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
349   MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
350 
351   // Requires ranked MemRefType.
352   if (!sourceType || !resultType)
353     return false;
354 
355   // Requires same elemental type.
356   if (sourceType.getElementType() != resultType.getElementType())
357     return false;
358 
359   // Requires same rank.
360   if (sourceType.getRank() != resultType.getRank())
361     return false;
362 
363   // Only fold casts between strided memref forms.
364   int64_t sourceOffset, resultOffset;
365   SmallVector<int64_t, 4> sourceStrides, resultStrides;
366   if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
367       failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
368     return false;
369 
370   // If cast is towards more static sizes along any dimension, don't fold.
371   for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
372     auto ss = std::get<0>(it), st = std::get<1>(it);
373     if (ss != st)
374       if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
375         return false;
376   }
377 
378   // If cast is towards more static offset along any dimension, don't fold.
379   if (sourceOffset != resultOffset)
380     if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
381         !MemRefType::isDynamicStrideOrOffset(resultOffset))
382       return false;
383 
384   // If cast is towards more static strides along any dimension, don't fold.
385   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
386     auto ss = std::get<0>(it), st = std::get<1>(it);
387     if (ss != st)
388       if (MemRefType::isDynamicStrideOrOffset(ss) &&
389           !MemRefType::isDynamicStrideOrOffset(st))
390         return false;
391   }
392 
393   return true;
394 }
395 
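/// For illustration, `memref<?x16xf32>` and `memref<4x?xf32>` are
/// cast-compatible (same rank and element type, and each pair of sizes either
/// matches or involves a dynamic size), whereas memrefs that differ in element
/// type, rank, or memory space are not.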
396 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
397   if (inputs.size() != 1 || outputs.size() != 1)
398     return false;
399   Type a = inputs.front(), b = outputs.front();
400   auto aT = a.dyn_cast<MemRefType>();
401   auto bT = b.dyn_cast<MemRefType>();
402 
403   auto uaT = a.dyn_cast<UnrankedMemRefType>();
404   auto ubT = b.dyn_cast<UnrankedMemRefType>();
405 
406   if (aT && bT) {
407     if (aT.getElementType() != bT.getElementType())
408       return false;
409     if (aT.getAffineMaps() != bT.getAffineMaps()) {
410       int64_t aOffset, bOffset;
411       SmallVector<int64_t, 4> aStrides, bStrides;
412       if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
413           failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
414           aStrides.size() != bStrides.size())
415         return false;
416 
417       // Strides along a dimension/offset are compatible if the value in the
418       // source memref is static and the value in the target memref is the
419       // same. They are also compatible if either one is dynamic (see
      // description of CastOp for details).
421       auto checkCompatible = [](int64_t a, int64_t b) {
422         return (a == MemRefType::getDynamicStrideOrOffset() ||
423                 b == MemRefType::getDynamicStrideOrOffset() || a == b);
424       };
425       if (!checkCompatible(aOffset, bOffset))
426         return false;
427       for (auto aStride : enumerate(aStrides))
428         if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
429           return false;
430     }
431     if (aT.getMemorySpace() != bT.getMemorySpace())
432       return false;
433 
434     // They must have the same rank, and any specified dimensions must match.
435     if (aT.getRank() != bT.getRank())
436       return false;
437 
438     for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
439       int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
440       if (aDim != -1 && bDim != -1 && aDim != bDim)
441         return false;
442     }
443     return true;
444   } else {
445     if (!aT && !uaT)
446       return false;
447     if (!bT && !ubT)
448       return false;
449     // Unranked to unranked casting is unsupported
450     if (uaT && ubT)
451       return false;
452 
453     auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
454     auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
455     if (aEltType != bEltType)
456       return false;
457 
458     auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
459     auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
460     if (aMemSpace != bMemSpace)
461       return false;
462 
463     return true;
464   }
465 
466   return false;
467 }
468 
469 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
470   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
471 }
472 
473 //===----------------------------------------------------------------------===//
474 // CloneOp
475 //===----------------------------------------------------------------------===//
476 
477 void CloneOp::getEffects(
478     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
479         &effects) {
480   effects.emplace_back(MemoryEffects::Read::get(), input(),
481                        SideEffects::DefaultResource::get());
482   effects.emplace_back(MemoryEffects::Write::get(), output(),
483                        SideEffects::DefaultResource::get());
484 }
485 
486 namespace {
/// Remove clone operations that are either unused or whose source/result is
/// immediately deallocated: the clone is replaced by its source and the
/// corresponding dealloc is erased.
489 struct SimplifyClones : public OpRewritePattern<CloneOp> {
490   using OpRewritePattern<CloneOp>::OpRewritePattern;
491 
492   LogicalResult matchAndRewrite(CloneOp cloneOp,
493                                 PatternRewriter &rewriter) const override {
494     if (cloneOp.use_empty()) {
495       rewriter.eraseOp(cloneOp);
496       return success();
497     }
498 
499     Value source = cloneOp.input();
500 
    // Replaces the clone with `source` and erases the corresponding dealloc,
    // provided the given alloc and dealloc are placed in the same block.
503     auto tryRemoveClone = [&](Operation *sourceOp, Operation *dealloc,
504                               Operation *alloc) {
505       if (!sourceOp || !dealloc || !alloc ||
506           alloc->getBlock() != dealloc->getBlock())
507         return false;
508       rewriter.replaceOp(cloneOp, source);
509       rewriter.eraseOp(dealloc);
510       return true;
511     };
512 
513     // Removes unnecessary clones that are derived from the result of the clone
514     // op.
515     Operation *deallocOp = findDealloc(cloneOp.output());
516     Operation *sourceOp = source.getDefiningOp();
517     if (tryRemoveClone(sourceOp, deallocOp, sourceOp))
518       return success();
519 
520     // Removes unnecessary clones that are derived from the source of the clone
521     // op.
522     deallocOp = findDealloc(source);
523     if (tryRemoveClone(sourceOp, deallocOp, cloneOp))
524       return success();
525 
526     return failure();
527   }
528 };
529 
530 } // end anonymous namespace.
531 
void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}
536 
537 OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
538   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
539 }
540 
541 //===----------------------------------------------------------------------===//
542 // DeallocOp
543 //===----------------------------------------------------------------------===//
544 
545 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
546                               SmallVectorImpl<OpFoldResult> &results) {
547   /// dealloc(memrefcast) -> dealloc
548   return foldMemRefCast(*this);
549 }
550 
551 //===----------------------------------------------------------------------===//
552 // DimOp
553 //===----------------------------------------------------------------------===//
554 
555 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
556                   int64_t index) {
557   auto loc = result.location;
558   Value indexValue = builder.create<ConstantIndexOp>(loc, index);
559   build(builder, result, memref, indexValue);
560 }
561 
562 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
563                   Value index) {
564   auto indexTy = builder.getIndexType();
565   build(builder, result, indexTy, memref, index);
566 }
567 
568 Optional<int64_t> DimOp::getConstantIndex() {
569   if (auto constantOp = index().getDefiningOp<ConstantOp>())
570     return constantOp.getValue().cast<IntegerAttr>().getInt();
571   return {};
572 }
573 
574 static LogicalResult verify(DimOp op) {
575   // Assume unknown index to be in range.
576   Optional<int64_t> index = op.getConstantIndex();
577   if (!index.hasValue())
578     return success();
579 
580   // Check that constant index is not knowingly out of range.
581   auto type = op.memrefOrTensor().getType();
582   if (auto memrefType = type.dyn_cast<MemRefType>()) {
583     if (index.getValue() >= memrefType.getRank())
584       return op.emitOpError("index is out of range");
585   } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
586     if (index.getValue() >= tensorType.getRank())
587       return op.emitOpError("index is out of range");
588   } else if (type.isa<UnrankedMemRefType>() || type.isa<UnrankedTensorType>()) {
589     // Assume index to be in range.
590   } else {
    llvm_unreachable("expected operand with memref or tensor type");
592   }
593   return success();
594 }
595 
596 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
597   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
598 
599   // All forms of folding require a known index.
600   if (!index)
601     return {};
602 
603   auto argTy = memrefOrTensor().getType();
604   // Fold if the shape extent along the given index is known.
605   if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
606     // Folding for unranked types (UnrankedMemRefType) is not supported.
607     if (!shapedTy.hasRank())
608       return {};
609     if (!shapedTy.isDynamicDim(index.getInt())) {
610       Builder builder(getContext());
611       return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]);
612     }
613   }
614 
615   Operation *definingOp = memrefOrTensor().getDefiningOp();
616 
617   // dim(memref.tensor_load(memref)) -> dim(memref)
618   if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) {
619     setOperand(0, tensorLoadOp.memref());
620     return getResult();
621   }
622 
623   // Fold dim to the operand of tensor.generate.
624   if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
625     auto resultType =
626         fromElements.getResult().getType().cast<RankedTensorType>();
627     // The case where the type encodes the size of the dimension is handled
628     // above.
629     assert(resultType.getShape()[index.getInt()] ==
630            RankedTensorType::kDynamicSize);
631 
632     // Find the operand of the fromElements that corresponds to this index.
633     auto dynExtents = fromElements.dynamicExtents().begin();
634     for (auto dim : resultType.getShape().take_front(index.getInt()))
635       if (dim == RankedTensorType::kDynamicSize)
636         dynExtents++;
637 
638     return Value{*dynExtents};
639   }
640 
641   // The size at the given index is now known to be a dynamic size.
642   unsigned unsignedIndex = index.getValue().getZExtValue();
643 
644   if (auto subtensor = dyn_cast_or_null<mlir::SubTensorOp>(definingOp)) {
645     assert(subtensor.isDynamicSize(unsignedIndex) &&
646            "Expected dynamic subtensor size");
647     return subtensor.getDynamicSize(unsignedIndex);
648   }
649 
650   // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
651   auto memrefType = argTy.dyn_cast<MemRefType>();
652   if (!memrefType)
653     return {};
654 
655   if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
656     return *(alloc.getDynamicSizes().begin() +
657              memrefType.getDynamicDimIndex(unsignedIndex));
658 
659   if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
660     return *(alloca.getDynamicSizes().begin() +
661              memrefType.getDynamicDimIndex(unsignedIndex));
662 
663   if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
664     return *(view.getDynamicSizes().begin() +
665              memrefType.getDynamicDimIndex(unsignedIndex));
666 
667   if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
668     assert(subview.isDynamicSize(unsignedIndex) &&
669            "Expected dynamic subview size");
670     return subview.getDynamicSize(unsignedIndex);
671   }
672 
673   // dim(memrefcast) -> dim
674   if (succeeded(foldMemRefCast(*this)))
675     return getResult();
676 
677   return {};
678 }
679 
680 namespace {
/// Fold dim of a memref.reshape operation into a load from the reshape's shape
/// operand.
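///
/// A sketch of the rewrite (illustrative IR):
///
/// ```mlir
///   %r = memref.reshape %src(%shape)
///       : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
///   %d = memref.dim %r, %c1 : memref<?x?xf32>
/// ```
///
/// becomes a `memref.load %shape[%c1]` placed right after the reshape.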
683 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
684   using OpRewritePattern<DimOp>::OpRewritePattern;
685 
686   LogicalResult matchAndRewrite(DimOp dim,
687                                 PatternRewriter &rewriter) const override {
688     auto reshape = dim.memrefOrTensor().getDefiningOp<ReshapeOp>();
689 
690     if (!reshape)
691       return failure();
692 
693     // Place the load directly after the reshape to ensure that the shape memref
694     // was not mutated.
695     rewriter.setInsertionPointAfter(reshape);
696     rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
697                                         llvm::makeArrayRef({dim.index()}));
698     return success();
699   }
700 };
701 
/// Fold dim of a cast into the dim of the cast's source.
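///
/// For example (illustrative IR):
///
/// ```mlir
///   %0 = tensor.cast %t : tensor<?x?xf32> to tensor<4x?xf32>
///   %d = memref.dim %0, %c1 : tensor<4x?xf32>
/// ```
///
/// is rewritten to take the dim of the cast's source `%t` instead.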
703 template <typename CastOpTy>
704 struct DimOfCastOp : public OpRewritePattern<DimOp> {
705   using OpRewritePattern<DimOp>::OpRewritePattern;
706 
707   LogicalResult matchAndRewrite(DimOp dimOp,
708                                 PatternRewriter &rewriter) const override {
709     auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>();
710     if (!castOp)
711       return failure();
712     Value newSource = castOp.getOperand();
713     rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
714     return success();
715   }
716 };
717 
/// Helper method to get the `Value` that is the extent of dimension `dimIndex`
/// of `result`, obtained via the `InferShapedTypeOpInterface` of its owner.
720 /// TODO(ravishankarm): This is better put as a interface utility method
721 /// somewhere, but that would imply the interface will depend on the `tensor`
722 /// dialect. Ideally maybe a utility method in the `tensor` dialect.
723 static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
724                                             int64_t dimIndex) {
725   unsigned resultNumber = result.getResultNumber();
726   auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
727   Location loc = result.getOwner()->getLoc();
728   if (!shapedTypeOp)
729     return nullptr;
730 
  // The interface exposes two methods: one that returns the shapes of all
  // results as shape tensor `Value`s, and another that returns, per result, a
  // `SmallVector<Value>` of per-dimension extents. The former takes precedence
  // over the latter, so first check which method the op implements and extract
  // the value to use accordingly.
736   SmallVector<Value> reifiedResultShapes;
737   if (succeeded(
738           shapedTypeOp.reifyReturnTypeShapes(builder, reifiedResultShapes))) {
739     if (reifiedResultShapes.size() <= resultNumber)
740       return nullptr;
741     Value resultShape = reifiedResultShapes[resultNumber];
742     auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
743     if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
744       return nullptr;
745     return builder.create<tensor::ExtractOp>(
746         loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
747   }
748 
749   SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
750   if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
751           builder, reifiedResultShapesPerDim)))
752     return nullptr;
753   if (reifiedResultShapesPerDim.size() <= resultNumber ||
754       reifiedResultShapesPerDim[resultNumber].size() !=
755           static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
756     return nullptr;
757   OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
758   if (auto attr = valueOrAttr.dyn_cast<Attribute>())
759     return builder.createOrFold<ConstantIndexOp>(
760         loc, attr.cast<IntegerAttr>().getInt());
761   return valueOrAttr.get<Value>();
762 }
763 
764 /// Fold dim of an operation that implements the InferShapedTypeOpInterface
765 struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> {
766   using OpRewritePattern<DimOp>::OpRewritePattern;
767 
768   LogicalResult matchAndRewrite(DimOp dimOp,
769                                 PatternRewriter &rewriter) const override {
770     OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
771     if (!dimValue)
772       return failure();
773     auto shapedTypeOp =
774         dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
775     if (!shapedTypeOp)
776       return failure();
777 
778     Optional<int64_t> dimIndex = dimOp.getConstantIndex();
779     if (!dimIndex)
780       return failure();
781     Value replacement =
782         getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
783     if (!replacement)
784       return failure();
785     rewriter.replaceOp(dimOp, replacement);
786     return success();
787   }
788 };
789 } // end anonymous namespace.
790 
791 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
792                                         MLIRContext *context) {
793   results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
794               DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(context);
795 }
796 
797 // ---------------------------------------------------------------------------
798 // DmaStartOp
799 // ---------------------------------------------------------------------------
800 
801 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
802                        Value srcMemRef, ValueRange srcIndices, Value destMemRef,
803                        ValueRange destIndices, Value numElements,
804                        Value tagMemRef, ValueRange tagIndices, Value stride,
805                        Value elementsPerStride) {
806   result.addOperands(srcMemRef);
807   result.addOperands(srcIndices);
808   result.addOperands(destMemRef);
809   result.addOperands(destIndices);
810   result.addOperands({numElements, tagMemRef});
811   result.addOperands(tagIndices);
812   if (stride)
813     result.addOperands({stride, elementsPerStride});
814 }
815 
816 void DmaStartOp::print(OpAsmPrinter &p) {
817   p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices()
818     << "], " << getDstMemRef() << '[' << getDstIndices() << "], "
819     << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices()
820     << ']';
821   if (isStrided())
822     p << ", " << getStride() << ", " << getNumElementsPerStride();
823 
824   p.printOptionalAttrDict((*this)->getAttrs());
825   p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
826     << ", " << getTagMemRef().getType();
827 }
828 
829 // Parse DmaStartOp.
830 // Ex:
831 //   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
832 //                       %tag[%index], %stride, %num_elt_per_stride :
833 //                     : memref<3076 x f32, 0>,
834 //                       memref<1024 x f32, 2>,
835 //                       memref<1 x i32>
836 //
837 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
838   OpAsmParser::OperandType srcMemRefInfo;
839   SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
840   OpAsmParser::OperandType dstMemRefInfo;
841   SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
842   OpAsmParser::OperandType numElementsInfo;
843   OpAsmParser::OperandType tagMemrefInfo;
844   SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
845   SmallVector<OpAsmParser::OperandType, 2> strideInfo;
846 
847   SmallVector<Type, 3> types;
848   auto indexType = parser.getBuilder().getIndexType();
849 
850   // Parse and resolve the following list of operands:
851   // *) source memref followed by its indices (in square brackets).
852   // *) destination memref followed by its indices (in square brackets).
  // *) number of elements to transfer.
854   if (parser.parseOperand(srcMemRefInfo) ||
855       parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
856       parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
857       parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
858       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
859       parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
860       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
861     return failure();
862 
863   // Parse optional stride and elements per stride.
864   if (parser.parseTrailingOperandList(strideInfo))
865     return failure();
866 
867   bool isStrided = strideInfo.size() == 2;
868   if (!strideInfo.empty() && !isStrided) {
869     return parser.emitError(parser.getNameLoc(),
870                             "expected two stride related operands");
871   }
872 
873   if (parser.parseColonTypeList(types))
874     return failure();
875   if (types.size() != 3)
876     return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
877 
878   if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
879       parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
880       parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
881       parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
882       // size should be an index.
883       parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
884       parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
885       // tag indices should be index.
886       parser.resolveOperands(tagIndexInfos, indexType, result.operands))
887     return failure();
888 
889   if (isStrided) {
890     if (parser.resolveOperands(strideInfo, indexType, result.operands))
891       return failure();
892   }
893 
894   return success();
895 }
896 
897 LogicalResult DmaStartOp::verify() {
898   unsigned numOperands = getNumOperands();
899 
900   // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
901   // the number of elements.
902   if (numOperands < 4)
903     return emitOpError("expected at least 4 operands");
904 
905   // Check types of operands. The order of these calls is important: the later
906   // calls rely on some type properties to compute the operand position.
907   // 1. Source memref.
908   if (!getSrcMemRef().getType().isa<MemRefType>())
909     return emitOpError("expected source to be of memref type");
910   if (numOperands < getSrcMemRefRank() + 4)
911     return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
912                          << " operands";
913   if (!getSrcIndices().empty() &&
914       !llvm::all_of(getSrcIndices().getTypes(),
915                     [](Type t) { return t.isIndex(); }))
916     return emitOpError("expected source indices to be of index type");
917 
918   // 2. Destination memref.
919   if (!getDstMemRef().getType().isa<MemRefType>())
920     return emitOpError("expected destination to be of memref type");
921   unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
922   if (numOperands < numExpectedOperands)
923     return emitOpError() << "expected at least " << numExpectedOperands
924                          << " operands";
925   if (!getDstIndices().empty() &&
926       !llvm::all_of(getDstIndices().getTypes(),
927                     [](Type t) { return t.isIndex(); }))
928     return emitOpError("expected destination indices to be of index type");
929 
930   // 3. Number of elements.
931   if (!getNumElements().getType().isIndex())
932     return emitOpError("expected num elements to be of index type");
933 
934   // 4. Tag memref.
935   if (!getTagMemRef().getType().isa<MemRefType>())
936     return emitOpError("expected tag to be of memref type");
937   numExpectedOperands += getTagMemRefRank();
938   if (numOperands < numExpectedOperands)
939     return emitOpError() << "expected at least " << numExpectedOperands
940                          << " operands";
941   if (!getTagIndices().empty() &&
942       !llvm::all_of(getTagIndices().getTypes(),
943                     [](Type t) { return t.isIndex(); }))
944     return emitOpError("expected tag indices to be of index type");
945 
  // DMAs are only supported between different memory spaces.
947   if (getSrcMemorySpace() == getDstMemorySpace())
948     return emitOpError("DMA should be between different memory spaces");
949 
950   // Optional stride-related operands must be either both present or both
951   // absent.
952   if (numOperands != numExpectedOperands &&
953       numOperands != numExpectedOperands + 2)
954     return emitOpError("incorrect number of operands");
955 
956   // 5. Strides.
957   if (isStrided()) {
958     if (!getStride().getType().isIndex() ||
959         !getNumElementsPerStride().getType().isIndex())
960       return emitOpError(
961           "expected stride and num elements per stride to be of type index");
962   }
963 
964   return success();
965 }
966 
967 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
968                                SmallVectorImpl<OpFoldResult> &results) {
969   /// dma_start(memrefcast) -> dma_start
970   return foldMemRefCast(*this);
971 }
972 
973 // ---------------------------------------------------------------------------
974 // DmaWaitOp
975 // ---------------------------------------------------------------------------
976 
977 void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
978                       Value tagMemRef, ValueRange tagIndices,
979                       Value numElements) {
980   result.addOperands(tagMemRef);
981   result.addOperands(tagIndices);
982   result.addOperands(numElements);
983 }
984 
985 void DmaWaitOp::print(OpAsmPrinter &p) {
986   p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices()
987     << "], " << getNumElements();
988   p.printOptionalAttrDict((*this)->getAttrs());
989   p << " : " << getTagMemRef().getType();
990 }
991 
992 // Parse DmaWaitOp.
993 // Eg:
994 //   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
995 //
996 ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
997   OpAsmParser::OperandType tagMemrefInfo;
998   SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
999   Type type;
1000   auto indexType = parser.getBuilder().getIndexType();
1001   OpAsmParser::OperandType numElementsInfo;
1002 
1003   // Parse tag memref, its indices, and dma size.
1004   if (parser.parseOperand(tagMemrefInfo) ||
1005       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
1006       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1007       parser.parseColonType(type) ||
1008       parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
1009       parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
1010       parser.resolveOperand(numElementsInfo, indexType, result.operands))
1011     return failure();
1012 
1013   return success();
1014 }
1015 
1016 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1017                               SmallVectorImpl<OpFoldResult> &results) {
1018   /// dma_wait(memrefcast) -> dma_wait
1019   return foldMemRefCast(*this);
1020 }
1021 
1022 LogicalResult DmaWaitOp::verify() {
1023   // Mandatory non-variadic operands are tag and the number of elements.
1024   if (getNumOperands() < 2)
1025     return emitOpError() << "expected at least 2 operands";
1026 
1027   // Check types of operands. The order of these calls is important: the later
1028   // calls rely on some type properties to compute the operand position.
1029   if (!getTagMemRef().getType().isa<MemRefType>())
1030     return emitOpError() << "expected tag to be of memref type";
1031 
1032   if (getNumOperands() != 2 + getTagMemRefRank())
1033     return emitOpError() << "expected " << 2 + getTagMemRefRank()
1034                          << " operands";
1035 
1036   if (!getTagIndices().empty() &&
1037       !llvm::all_of(getTagIndices().getTypes(),
1038                     [](Type t) { return t.isIndex(); }))
1039     return emitOpError() << "expected tag indices to be of index type";
1040 
1041   if (!getNumElements().getType().isIndex())
1042     return emitOpError()
1043            << "expected the number of elements to be of index type";
1044 
1045   return success();
1046 }
1047 
1048 //===----------------------------------------------------------------------===//
1049 // GlobalOp
1050 //===----------------------------------------------------------------------===//
1051 
1052 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1053                                                    TypeAttr type,
1054                                                    Attribute initialValue) {
1055   p << type;
1056   if (!op.isExternal()) {
1057     p << " = ";
1058     if (op.isUninitialized())
1059       p << "uninitialized";
1060     else
1061       p.printAttributeWithoutType(initialValue);
1062   }
1063 }
1064 
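// A sketch of the forms accepted by the custom parser below (illustrative
// names and values):
//
//   memref.global @ext : memref<4xi32>
//   memref.global @zero : memref<2xf32> = dense<0.0>
//   memref.global @buf : memref<8xf32> = uninitialized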
1065 static ParseResult
1066 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1067                                        Attribute &initialValue) {
1068   Type type;
1069   if (parser.parseType(type))
1070     return failure();
1071 
1072   auto memrefType = type.dyn_cast<MemRefType>();
1073   if (!memrefType || !memrefType.hasStaticShape())
1074     return parser.emitError(parser.getNameLoc())
1075            << "type should be static shaped memref, but got " << type;
1076   typeAttr = TypeAttr::get(type);
1077 
1078   if (parser.parseOptionalEqual())
1079     return success();
1080 
1081   if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1082     initialValue = UnitAttr::get(parser.getBuilder().getContext());
1083     return success();
1084   }
1085 
1086   Type tensorType = getTensorTypeFromMemRefType(memrefType);
1087   if (parser.parseAttribute(initialValue, tensorType))
1088     return failure();
1089   if (!initialValue.isa<ElementsAttr>())
1090     return parser.emitError(parser.getNameLoc())
1091            << "initial value should be a unit or elements attribute";
1092   return success();
1093 }
1094 
1095 static LogicalResult verify(GlobalOp op) {
1096   auto memrefType = op.type().dyn_cast<MemRefType>();
1097   if (!memrefType || !memrefType.hasStaticShape())
1098     return op.emitOpError("type should be static shaped memref, but got ")
1099            << op.type();
1100 
1101   // Verify that the initial value, if present, is either a unit attribute or
1102   // an elements attribute.
1103   if (op.initial_value().hasValue()) {
1104     Attribute initValue = op.initial_value().getValue();
1105     if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
1106       return op.emitOpError("initial value should be a unit or elements "
1107                             "attribute, but got ")
1108              << initValue;
1109 
1110     // Check that the type of the initial value is compatible with the type of
1111     // the global variable.
1112     if (initValue.isa<ElementsAttr>()) {
1113       Type initType = initValue.getType();
1114       Type tensorType = getTensorTypeFromMemRefType(memrefType);
1115       if (initType != tensorType)
1116         return op.emitOpError("initial value expected to be of type ")
1117                << tensorType << ", but was of type " << initType;
1118     }
1119   }
1120 
1121   // TODO: verify visibility for declarations.
1122   return success();
1123 }
1124 
1125 //===----------------------------------------------------------------------===//
1126 // GetGlobalOp
1127 //===----------------------------------------------------------------------===//
1128 
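// For illustration, a well-formed reference looks like (names are made up):
//
//   memref.global "private" @gv : memref<2xf32> = dense<[0.0, 1.0]>
//   %0 = memref.get_global @gv : memref<2xf32>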
1129 LogicalResult
1130 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is the same as the type of the referenced
  // memref.global op.
1133   auto global =
1134       symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
1135   if (!global)
1136     return emitOpError("'")
1137            << name() << "' does not reference a valid global memref";
1138 
1139   Type resultType = result().getType();
1140   if (global.type() != resultType)
1141     return emitOpError("result type ")
1142            << resultType << " does not match type " << global.type()
1143            << " of the global memref @" << name();
1144   return success();
1145 }
1146 
1147 //===----------------------------------------------------------------------===//
1148 // LoadOp
1149 //===----------------------------------------------------------------------===//
1150 
1151 static LogicalResult verify(LoadOp op) {
1152   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1153     return op.emitOpError("incorrect number of indices for load");
1154   return success();
1155 }
1156 
1157 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1158   /// load(memrefcast) -> load
1159   if (succeeded(foldMemRefCast(*this)))
1160     return getResult();
1161   return OpFoldResult();
1162 }
1163 
1164 namespace {
/// Fold a load on a buffer_cast operation into a tensor.extract on the
/// corresponding tensor.
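///
/// For example (illustrative IR):
///
/// ```mlir
///   %m = memref.buffer_cast %t : memref<4xf32>
///   %v = memref.load %m[%i] : memref<4xf32>
/// ```
///
/// becomes `%v = tensor.extract %t[%i] : tensor<4xf32>`.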
1167 struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
1168   using OpRewritePattern<LoadOp>::OpRewritePattern;
1169 
1170   LogicalResult matchAndRewrite(LoadOp load,
1171                                 PatternRewriter &rewriter) const override {
1172     auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
1173     if (!buffercast)
1174       return failure();
1175 
1176     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
1177                                                    load.indices());
1178     return success();
1179   }
1180 };
1181 } // end anonymous namespace.
1182 
1183 void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
1184                                          MLIRContext *context) {
1185   results.add<LoadOfBufferCast>(context);
1186 }
1187 
1188 //===----------------------------------------------------------------------===//
1189 // PrefetchOp
1190 //===----------------------------------------------------------------------===//
1191 
1192 static void print(OpAsmPrinter &p, PrefetchOp op) {
1193   p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
1194   p.printOperands(op.indices());
1195   p << ']' << ", " << (op.isWrite() ? "write" : "read");
1196   p << ", locality<" << op.localityHint();
1197   p << ">, " << (op.isDataCache() ? "data" : "instr");
1198   p.printOptionalAttrDict(
1199       op->getAttrs(),
1200       /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1201   p << " : " << op.getMemRefType();
1202 }
1203 
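// A sketch of the form accepted by the custom printer/parser below
// (illustrative operands):
//
//   memref.prefetch %A[%i, %j], read, locality<3>, data : memref<16x16xf32>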
1204 static ParseResult parsePrefetchOp(OpAsmParser &parser,
1205                                    OperationState &result) {
1206   OpAsmParser::OperandType memrefInfo;
1207   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
1208   IntegerAttr localityHint;
1209   MemRefType type;
1210   StringRef readOrWrite, cacheType;
1211 
1212   auto indexTy = parser.getBuilder().getIndexType();
1213   auto i32Type = parser.getBuilder().getIntegerType(32);
1214   if (parser.parseOperand(memrefInfo) ||
1215       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1216       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1217       parser.parseComma() || parser.parseKeyword("locality") ||
1218       parser.parseLess() ||
1219       parser.parseAttribute(localityHint, i32Type, "localityHint",
1220                             result.attributes) ||
1221       parser.parseGreater() || parser.parseComma() ||
1222       parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1223       parser.resolveOperand(memrefInfo, type, result.operands) ||
1224       parser.resolveOperands(indexInfo, indexTy, result.operands))
1225     return failure();
1226 
1227   if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1228     return parser.emitError(parser.getNameLoc(),
1229                             "rw specifier has to be 'read' or 'write'");
1230   result.addAttribute(
1231       PrefetchOp::getIsWriteAttrName(),
1232       parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1233 
1234   if (!cacheType.equals("data") && !cacheType.equals("instr"))
1235     return parser.emitError(parser.getNameLoc(),
1236                             "cache type has to be 'data' or 'instr'");
1237 
1238   result.addAttribute(
1239       PrefetchOp::getIsDataCacheAttrName(),
1240       parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1241 
1242   return success();
1243 }
1244 
1245 static LogicalResult verify(PrefetchOp op) {
1246   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1247     return op.emitOpError("too few indices");
1248 
1249   return success();
1250 }
1251 
1252 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
1253                                SmallVectorImpl<OpFoldResult> &results) {
1254   // prefetch(memrefcast) -> prefetch
1255   return foldMemRefCast(*this);
1256 }
1257 
1258 //===----------------------------------------------------------------------===//
1259 // ReinterpretCastOp
1260 //===----------------------------------------------------------------------===//
1261 
/// Build a ReinterpretCastOp with mixed static and dynamic entries: the
/// `static_offsets`, `static_sizes` and `static_strides` attributes are
/// automatically filled with sentinel values for the entries that are dynamic.
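///
/// For example (illustrative IR produced through this builder):
///
/// ```mlir
///   %0 = memref.reinterpret_cast %src to offset: [0], sizes: [%d, 4],
///        strides: [4, 1] : memref<?xf32> to memref<?x4xf32>
/// ```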
1265 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1266                               MemRefType resultType, Value source,
1267                               OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1268                               ArrayRef<OpFoldResult> strides,
1269                               ArrayRef<NamedAttribute> attrs) {
1270   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1271   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1272   dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
1273                              ShapedType::kDynamicStrideOrOffset);
1274   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1275                              ShapedType::kDynamicSize);
1276   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1277                              ShapedType::kDynamicStrideOrOffset);
1278   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1279         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1280         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1281   result.addAttributes(attrs);
1282 }
1283 
1284 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1285                               MemRefType resultType, Value source,
1286                               int64_t offset, ArrayRef<int64_t> sizes,
1287                               ArrayRef<int64_t> strides,
1288                               ArrayRef<NamedAttribute> attrs) {
1289   SmallVector<OpFoldResult> sizeValues =
1290       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1291         return b.getI64IntegerAttr(v);
1292       }));
1293   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1294       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1295         return b.getI64IntegerAttr(v);
1296       }));
1297   build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1298         strideValues, attrs);
1299 }
1300 
1301 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1302                               MemRefType resultType, Value source, Value offset,
1303                               ValueRange sizes, ValueRange strides,
1304                               ArrayRef<NamedAttribute> attrs) {
1305   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1306       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1307   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1308       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1309   build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1310 }
1311 
1312 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1313 // completed automatically, like we have for subview and subtensor.
1314 static LogicalResult verify(ReinterpretCastOp op) {
1315   // The source and result memrefs should be in the same memory space.
1316   auto srcType = op.source().getType().cast<BaseMemRefType>();
1317   auto resultType = op.getType().cast<MemRefType>();
1318   if (srcType.getMemorySpace() != resultType.getMemorySpace())
1319     return op.emitError("different memory spaces specified for source type ")
1320            << srcType << " and result memref type " << resultType;
1321   if (srcType.getElementType() != resultType.getElementType())
1322     return op.emitError("different element types specified for source type ")
1323            << srcType << " and result memref type " << resultType;
1324 
1325   // Match sizes in result memref type and in static_sizes attribute.
1326   for (auto &en :
1327        llvm::enumerate(llvm::zip(resultType.getShape(),
1328                                  extractFromI64ArrayAttr(op.static_sizes())))) {
1329     int64_t resultSize = std::get<0>(en.value());
1330     int64_t expectedSize = std::get<1>(en.value());
1331     if (resultSize != expectedSize)
1332       return op.emitError("expected result type with size = ")
1333              << expectedSize << " instead of " << resultSize
1334              << " in dim = " << en.index();
1335   }
1336 
1337   // Match offset and strides in static_offset and static_strides attributes if
1338   // result memref type has an affine map specified.
1339   if (!resultType.getAffineMaps().empty()) {
1340     int64_t resultOffset;
1341     SmallVector<int64_t, 4> resultStrides;
1342     if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1343       return failure();
1344 
1345     // Match offset in result memref type and in static_offsets attribute.
1346     int64_t expectedOffset =
1347         extractFromI64ArrayAttr(op.static_offsets()).front();
1348     if (resultOffset != expectedOffset)
      return op.emitError("expected result type with offset = ")
             << expectedOffset << " instead of " << resultOffset;
1351 
1352     // Match strides in result memref type and in static_strides attribute.
1353     for (auto &en : llvm::enumerate(llvm::zip(
1354              resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
1355       int64_t resultStride = std::get<0>(en.value());
1356       int64_t expectedStride = std::get<1>(en.value());
1357       if (resultStride != expectedStride)
1358         return op.emitError("expected result type with stride = ")
1359                << expectedStride << " instead of " << resultStride
1360                << " in dim = " << en.index();
1361     }
1362   }
1363   return success();
1364 }
1365 
1366 //===----------------------------------------------------------------------===//
1367 // ReshapeOp
1368 //===----------------------------------------------------------------------===//
1369 
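// For illustration (not from the op documentation), a well-formed reshape to a
// statically ranked result looks roughly like
//   %dst = memref.reshape %src(%shape)
//            : (memref<4x4xf32>, memref<2xindex>) -> memref<2x8xf32>
// where both memrefs use the identity layout and the length of the shape
// operand (here 2) equals the result rank.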
1370 static LogicalResult verify(ReshapeOp op) {
1371   Type operandType = op.source().getType();
1372   Type resultType = op.result().getType();
1373 
1374   Type operandElementType = operandType.cast<ShapedType>().getElementType();
1375   Type resultElementType = resultType.cast<ShapedType>().getElementType();
1376   if (operandElementType != resultElementType)
1377     return op.emitOpError("element types of source and destination memref "
1378                           "types should be the same");
1379 
1380   if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
1381     if (!operandMemRefType.getAffineMaps().empty())
1382       return op.emitOpError(
1383           "source memref type should have identity affine map");
1384 
1385   int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
1386   auto resultMemRefType = resultType.dyn_cast<MemRefType>();
1387   if (resultMemRefType) {
1388     if (!resultMemRefType.getAffineMaps().empty())
1389       return op.emitOpError(
1390           "result memref type should have identity affine map");
1391     if (shapeSize == ShapedType::kDynamicSize)
1392       return op.emitOpError("cannot use shape operand with dynamic length to "
1393                             "reshape to statically-ranked memref type");
1394     if (shapeSize != resultMemRefType.getRank())
1395       return op.emitOpError(
1396           "length of shape operand differs from the result's memref rank");
1397   }
1398   return success();
1399 }
1400 
1401 //===----------------------------------------------------------------------===//
1402 // StoreOp
1403 //===----------------------------------------------------------------------===//
1404 
1405 static LogicalResult verify(StoreOp op) {
1406   if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
1407     return op.emitOpError("store index operand count not equal to memref rank");
1408 
1409   return success();
1410 }
1411 
1412 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
1413                             SmallVectorImpl<OpFoldResult> &results) {
1414   /// store(memrefcast) -> store
1415   return foldMemRefCast(*this);
1416 }
1417 
1418 //===----------------------------------------------------------------------===//
1419 // SubViewOp
1420 //===----------------------------------------------------------------------===//
1421 
1422 namespace {
1423 /// Helpers to write more idiomatic operations.
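/// For illustration, arithmetic saturates to the dynamic sentinel, e.g.
///   Wrapper(4) * 2                                  == 8
///   Wrapper(ShapedType::kDynamicStrideOrOffset) * 2 == kDynamicStrideOrOffset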
1424 namespace saturated_arith {
1425 struct Wrapper {
1426   explicit Wrapper(int64_t v) : v(v) {}
1427   operator int64_t() { return v; }
1428   int64_t v;
1429 };
1430 Wrapper operator+(Wrapper a, int64_t b) {
1431   if (ShapedType::isDynamicStrideOrOffset(a) ||
1432       ShapedType::isDynamicStrideOrOffset(b))
1433     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1434   return Wrapper(a.v + b);
1435 }
1436 Wrapper operator*(Wrapper a, int64_t b) {
1437   if (ShapedType::isDynamicStrideOrOffset(a) ||
1438       ShapedType::isDynamicStrideOrOffset(b))
1439     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1440   return Wrapper(a.v * b);
1441 }
1442 } // end namespace saturated_arith
1443 } // end namespace
1444 
1445 /// A subview result type can be fully inferred from the source type and the
1446 /// static representation of offsets, sizes and strides. Special sentinels
1447 /// encode the dynamic case.
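///
/// For example (illustrative only), with a contiguous source
/// memref<8x16x4xf32> (strides [64, 4, 1], offset 0), offsets [0, 2, 0],
/// sizes [4, 4, 4] and strides [1, 1, 1], the inferred type is
/// memref<4x4x4xf32, offset: 8, strides: [64, 4, 1]>.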
1448 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1449                                 ArrayRef<int64_t> leadingStaticOffsets,
1450                                 ArrayRef<int64_t> leadingStaticSizes,
1451                                 ArrayRef<int64_t> leadingStaticStrides) {
1452   // A subview may specify only a leading subset of offset/sizes/strides in
1453   // which case we complete with offset=0, sizes from memref type and strides=1.
1454   unsigned rank = sourceMemRefType.getRank();
1455   assert(leadingStaticOffsets.size() <= rank &&
1456          "unexpected leadingStaticOffsets overflow");
1457   assert(leadingStaticSizes.size() <= rank &&
1458          "unexpected leadingStaticSizes overflow");
1459   assert(leadingStaticStrides.size() <= rank &&
1460          "unexpected leadingStaticStrides overflow");
1461   auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
1462   auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
1463   auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
1464   unsigned numTrailingOffsets = rank - staticOffsets.size();
1465   unsigned numTrailingSizes = rank - staticSizes.size();
1466   unsigned numTrailingStrides = rank - staticStrides.size();
1467   staticOffsets.append(numTrailingOffsets, 0);
1468   llvm::append_range(staticSizes,
1469                      sourceMemRefType.getShape().take_back(numTrailingSizes));
1470   staticStrides.append(numTrailingStrides, 1);
1471 
1472   // Extract source offset and strides.
1473   int64_t sourceOffset;
1474   SmallVector<int64_t, 4> sourceStrides;
1475   auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
1476   assert(succeeded(res) && "SubViewOp expected strided memref type");
1477   (void)res;
1478 
1479   // Compute target offset whose value is:
1480   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
1481   int64_t targetOffset = sourceOffset;
1482   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    using namespace saturated_arith;
    targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * sourceStride;
1486   }
1487 
1488   // Compute target stride whose value is:
1489   //   `sourceStrides_i * staticStrides_i`.
1490   SmallVector<int64_t, 4> targetStrides;
1491   targetStrides.reserve(staticOffsets.size());
1492   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
1493     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
1494     using namespace saturated_arith;
1495     targetStrides.push_back(Wrapper(sourceStride) * staticStride);
1496   }
1497 
1498   // The type is now known.
1499   return MemRefType::get(
1500       staticSizes, sourceMemRefType.getElementType(),
1501       makeStridedLinearLayoutMap(targetStrides, targetOffset,
1502                                  sourceMemRefType.getContext()),
1503       sourceMemRefType.getMemorySpace());
1504 }
1505 
1506 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1507                                 ArrayRef<OpFoldResult> leadingStaticOffsets,
1508                                 ArrayRef<OpFoldResult> leadingStaticSizes,
1509                                 ArrayRef<OpFoldResult> leadingStaticStrides) {
1510   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1511   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1512   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1513                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1514   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1515                              ShapedType::kDynamicSize);
1516   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1517                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1518   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1519                                     staticSizes, staticStrides)
1520       .cast<MemRefType>();
1521 }
1522 
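/// Infer the result type of a rank-reducing subview: the type is inferred as
/// above, then unit dimensions are dropped (from the front) until `resultRank`
/// is reached. For example (illustrative only), with source memref<8x16x4xf32>,
/// offsets [0, 0, 0], sizes [1, 16, 4], strides [1, 1, 1] and resultRank 2,
/// the unit dimension is projected away and the inferred type is
/// memref<16x4xf32, offset: 0, strides: [4, 1]>.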
1523 Type SubViewOp::inferRankReducedResultType(
1524     unsigned resultRank, MemRefType sourceRankedTensorType,
1525     ArrayRef<int64_t> leadingStaticOffsets,
1526     ArrayRef<int64_t> leadingStaticSizes,
1527     ArrayRef<int64_t> leadingStaticStrides) {
1528   auto inferredType =
1529       inferResultType(sourceRankedTensorType, leadingStaticOffsets,
1530                       leadingStaticSizes, leadingStaticStrides)
1531           .cast<MemRefType>();
  assert(inferredType.getRank() >= resultRank &&
         "expected inferred rank to be no smaller than resultRank");
1533   int rankDiff = inferredType.getRank() - resultRank;
1534   if (rankDiff > 0) {
1535     auto shape = inferredType.getShape();
1536     llvm::SmallDenseSet<unsigned> dimsToProject;
1537     mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
1538     SmallVector<int64_t> projectedShape;
1539     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1540       if (!dimsToProject.contains(pos))
1541         projectedShape.push_back(shape[pos]);
1542 
1543     AffineMap map;
1544     auto maps = inferredType.getAffineMaps();
1545     if (!maps.empty() && maps.front())
1546       map = getProjectedMap(maps.front(), dimsToProject);
1547     inferredType =
1548         MemRefType::get(projectedShape, inferredType.getElementType(), map,
1549                         inferredType.getMemorySpace());
1550   }
1551   return inferredType;
1552 }
1553 
1554 Type SubViewOp::inferRankReducedResultType(
1555     unsigned resultRank, MemRefType sourceRankedTensorType,
1556     ArrayRef<OpFoldResult> leadingStaticOffsets,
1557     ArrayRef<OpFoldResult> leadingStaticSizes,
1558     ArrayRef<OpFoldResult> leadingStaticStrides) {
1559   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1560   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1561   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1562                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1563   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1564                              ShapedType::kDynamicSize);
1565   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1566                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1567   return SubViewOp::inferRankReducedResultType(
1568       resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1569       staticStrides);
1570 }

// Build a SubViewOp with mixed static and dynamic entries and custom result
1572 // type. If the type passed is nullptr, it is inferred.
1573 void SubViewOp::build(OpBuilder &b, OperationState &result,
1574                       MemRefType resultType, Value source,
1575                       ArrayRef<OpFoldResult> offsets,
1576                       ArrayRef<OpFoldResult> sizes,
1577                       ArrayRef<OpFoldResult> strides,
1578                       ArrayRef<NamedAttribute> attrs) {
1579   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1580   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1581   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1582                              ShapedType::kDynamicStrideOrOffset);
1583   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1584                              ShapedType::kDynamicSize);
1585   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1586                              ShapedType::kDynamicStrideOrOffset);
1587   auto sourceMemRefType = source.getType().cast<MemRefType>();
1588   // Structuring implementation this way avoids duplication between builders.
1589   if (!resultType) {
1590     resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1591                                             staticSizes, staticStrides)
1592                      .cast<MemRefType>();
1593   }
1594   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1595         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1596         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1597   result.addAttributes(attrs);
1598 }
1599 
1600 // Build a SubViewOp with mixed static and dynamic entries and inferred result
1601 // type.
1602 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1603                       ArrayRef<OpFoldResult> offsets,
1604                       ArrayRef<OpFoldResult> sizes,
1605                       ArrayRef<OpFoldResult> strides,
1606                       ArrayRef<NamedAttribute> attrs) {
1607   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1608 }
1609 
1610 // Build a SubViewOp with static entries and inferred result type.
1611 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1612                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1613                       ArrayRef<int64_t> strides,
1614                       ArrayRef<NamedAttribute> attrs) {
1615   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1616       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1617         return b.getI64IntegerAttr(v);
1618       }));
1619   SmallVector<OpFoldResult> sizeValues =
1620       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1621         return b.getI64IntegerAttr(v);
1622       }));
1623   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1624       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1625         return b.getI64IntegerAttr(v);
1626       }));
1627   build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
1628 }
1629 
// Build a SubViewOp with static entries and custom result type. If the
// type passed is nullptr, it is inferred.
1632 void SubViewOp::build(OpBuilder &b, OperationState &result,
1633                       MemRefType resultType, Value source,
1634                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1635                       ArrayRef<int64_t> strides,
1636                       ArrayRef<NamedAttribute> attrs) {
1637   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1638       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1639         return b.getI64IntegerAttr(v);
1640       }));
1641   SmallVector<OpFoldResult> sizeValues =
1642       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1643         return b.getI64IntegerAttr(v);
1644       }));
1645   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1646       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1647         return b.getI64IntegerAttr(v);
1648       }));
1649   build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
1650         attrs);
1651 }
1652 
1653 // Build a SubViewOp with dynamic entries and custom result type. If the type
1654 // passed is nullptr, it is inferred.
1655 void SubViewOp::build(OpBuilder &b, OperationState &result,
1656                       MemRefType resultType, Value source, ValueRange offsets,
1657                       ValueRange sizes, ValueRange strides,
1658                       ArrayRef<NamedAttribute> attrs) {
1659   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1660       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1661   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1662       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1663   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1664       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
1666 }
1667 
1668 // Build a SubViewOp with dynamic entries and inferred result type.
1669 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1670                       ValueRange offsets, ValueRange sizes, ValueRange strides,
1671                       ArrayRef<NamedAttribute> attrs) {
1672   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1673 }
1674 
1675 /// For ViewLikeOpInterface.
1676 Value SubViewOp::getViewSource() { return source(); }
1677 
1678 enum SubViewVerificationResult {
1679   Success,
1680   RankTooLarge,
1681   SizeMismatch,
1682   ElemTypeMismatch,
1683   MemSpaceMismatch,
1684   AffineMapMismatch
1685 };
1686 
/// Checks if `originalType` can be rank-reduced to `candidateReducedType`.
/// This function is a slight variant of the `is subsequence` algorithm where
/// the non-matching dimensions must be 1.
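///
/// For illustration, memref<1x16x4xf32> can be rank-reduced to
/// memref<16x4xf32> (the unit dimension is dropped), but not to
/// memref<16x2xf32> (the sizes cannot be matched).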
1690 static SubViewVerificationResult
1691 isRankReducedType(Type originalType, Type candidateReducedType,
1692                   std::string *errMsg = nullptr) {
1693   if (originalType == candidateReducedType)
1694     return SubViewVerificationResult::Success;
1695   if (!originalType.isa<MemRefType>())
1696     return SubViewVerificationResult::Success;
1697   if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
1698     return SubViewVerificationResult::Success;
1699 
1700   ShapedType originalShapedType = originalType.cast<ShapedType>();
1701   ShapedType candidateReducedShapedType =
1702       candidateReducedType.cast<ShapedType>();
1703 
1704   // Rank and size logic is valid for all ShapedTypes.
1705   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
1706   ArrayRef<int64_t> candidateReducedShape =
1707       candidateReducedShapedType.getShape();
1708   unsigned originalRank = originalShape.size(),
1709            candidateReducedRank = candidateReducedShape.size();
1710   if (candidateReducedRank > originalRank)
1711     return SubViewVerificationResult::RankTooLarge;
1712 
1713   auto optionalUnusedDimsMask =
1714       computeRankReductionMask(originalShape, candidateReducedShape);
1715 
  // If no rank-reduction mask could be computed, the sizes cannot be matched.
1717   if (!optionalUnusedDimsMask.hasValue())
1718     return SubViewVerificationResult::SizeMismatch;
1719 
1720   if (originalShapedType.getElementType() !=
1721       candidateReducedShapedType.getElementType())
1722     return SubViewVerificationResult::ElemTypeMismatch;
1723 
1724   // Strided layout logic is relevant for MemRefType only.
1725   MemRefType original = originalType.cast<MemRefType>();
1726   MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
1727   if (original.getMemorySpace() != candidateReduced.getMemorySpace())
1728     return SubViewVerificationResult::MemSpaceMismatch;
1729 
1730   llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
1731   auto inferredType =
1732       getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
1733   AffineMap candidateLayout;
1734   if (candidateReduced.getAffineMaps().empty())
1735     candidateLayout = getStridedLinearLayoutMap(candidateReduced);
1736   else
1737     candidateLayout = candidateReduced.getAffineMaps().front();
1738   assert(inferredType.getNumResults() == 1 &&
1739          candidateLayout.getNumResults() == 1);
1740   if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
1741       inferredType.getNumDims() != candidateLayout.getNumDims()) {
1742     if (errMsg) {
1743       llvm::raw_string_ostream os(*errMsg);
1744       os << "inferred type: " << inferredType;
1745     }
1746     return SubViewVerificationResult::AffineMapMismatch;
1747   }
1748   // Check that the difference of the affine maps simplifies to 0.
1749   AffineExpr diffExpr =
1750       inferredType.getResult(0) - candidateLayout.getResult(0);
1751   diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
1752                                 inferredType.getNumSymbols());
1753   auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
1754   if (!(cst && cst.getValue() == 0)) {
1755     if (errMsg) {
1756       llvm::raw_string_ostream os(*errMsg);
1757       os << "inferred type: " << inferredType;
1758     }
1759     return SubViewVerificationResult::AffineMapMismatch;
1760   }
1761   return SubViewVerificationResult::Success;
1762 }
1763 
1764 template <typename OpTy>
1765 static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
1766                                             OpTy op, Type expectedType,
1767                                             StringRef errMsg = "") {
1768   auto memrefType = expectedType.cast<ShapedType>();
1769   switch (result) {
1770   case SubViewVerificationResult::Success:
1771     return success();
1772   case SubViewVerificationResult::RankTooLarge:
1773     return op.emitError("expected result rank to be smaller or equal to ")
1774            << "the source rank. " << errMsg;
1775   case SubViewVerificationResult::SizeMismatch:
1776     return op.emitError("expected result type to be ")
1777            << expectedType
1778            << " or a rank-reduced version. (mismatch of result sizes) "
1779            << errMsg;
1780   case SubViewVerificationResult::ElemTypeMismatch:
1781     return op.emitError("expected result element type to be ")
1782            << memrefType.getElementType() << errMsg;
1783   case SubViewVerificationResult::MemSpaceMismatch:
1784     return op.emitError("expected result and source memory spaces to match.")
1785            << errMsg;
1786   case SubViewVerificationResult::AffineMapMismatch:
1787     return op.emitError("expected result type to be ")
1788            << expectedType
1789            << " or a rank-reduced version. (mismatch of result affine map) "
1790            << errMsg;
1791   }
1792   llvm_unreachable("unexpected subview verification result");
1793 }
1794 
1795 /// Verifier for SubViewOp.
1796 static LogicalResult verify(SubViewOp op) {
1797   MemRefType baseType = op.getSourceType();
1798   MemRefType subViewType = op.getType();
1799 
1800   // The base memref and the view memref should be in the same memory space.
1801   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
1802     return op.emitError("different memory spaces specified for base memref "
1803                         "type ")
1804            << baseType << " and subview memref type " << subViewType;
1805 
1806   // Verify that the base memref type has a strided layout map.
1807   if (!isStrided(baseType))
1808     return op.emitError("base type ") << baseType << " is not strided";
1809 
1810   // Verify result type against inferred type.
1811   auto expectedType = SubViewOp::inferResultType(
1812       baseType, extractFromI64ArrayAttr(op.static_offsets()),
1813       extractFromI64ArrayAttr(op.static_sizes()),
1814       extractFromI64ArrayAttr(op.static_strides()));
1815 
1816   std::string errMsg;
1817   auto result = isRankReducedType(expectedType, subViewType, &errMsg);
1818   return produceSubViewErrorMsg(result, op, expectedType, errMsg);
1819 }
1820 
1821 raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
1822   return os << "range " << range.offset << ":" << range.size << ":"
1823             << range.stride;
1824 }
1825 
1826 /// Return the list of Range (i.e. offset, size, stride). Each Range
1827 /// entry contains either the dynamic value or a ConstantIndexOp constructed
1828 /// with `b` at location `loc`.
1829 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
1830                                               OpBuilder &b, Location loc) {
1831   std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
1832   assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
1833   assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
1834   SmallVector<Range, 8> res;
1835   unsigned rank = ranks[0];
1836   res.reserve(rank);
1837   for (unsigned idx = 0; idx < rank; ++idx) {
1838     Value offset =
1839         op.isDynamicOffset(idx)
1840             ? op.getDynamicOffset(idx)
1841             : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
1842     Value size = op.isDynamicSize(idx)
1843                      ? op.getDynamicSize(idx)
1844                      : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
1845     Value stride =
1846         op.isDynamicStride(idx)
1847             ? op.getDynamicStride(idx)
1848             : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
1849     res.emplace_back(Range{offset, size, stride});
1850   }
1851   return res;
1852 }
1853 
1854 /// Infer the canonical type of the result of a subview operation. Returns a
1855 /// type with rank `resultRank` that is either the rank of the rank-reduced
1856 /// type, or the non-rank-reduced type.
1857 static MemRefType
1858 getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
1859                               ArrayRef<OpFoldResult> mixedOffsets,
1860                               ArrayRef<OpFoldResult> mixedSizes,
1861                               ArrayRef<OpFoldResult> mixedStrides) {
1862   auto resultType =
1863       SubViewOp::inferRankReducedResultType(
1864           resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
1865           .cast<MemRefType>();
1866   if (resultType.getRank() != resultRank) {
1867     resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
1868                                             mixedSizes, mixedStrides)
1869                      .cast<MemRefType>();
1870   }
1871   return resultType;
1872 }
1873 
1874 namespace {
1875 /// Pattern to rewrite a subview op with MemRefCast arguments.
1876 /// This essentially pushes memref.cast past its consuming subview when
1877 /// `canFoldIntoConsumerOp` is true.
1878 ///
1879 /// Example:
1880 /// ```
1881 ///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
1882 ///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
1883 ///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
1884 /// ```
1885 /// is rewritten into:
1886 /// ```
///   %0 = memref.subview %V[0, 0][3, 4][1, 1] :
///     memref<16x16xf32> to memref<3x4xf32, offset:0, strides:[16, 1]>
///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
///     memref<3x4xf32, offset:?, strides:[?, 1]>
1890 /// ```
1891 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
1892 public:
1893   using OpRewritePattern<SubViewOp>::OpRewritePattern;
1894 
1895   LogicalResult matchAndRewrite(SubViewOp subViewOp,
1896                                 PatternRewriter &rewriter) const override {
    // If any operand is a constant, bail out and let the constant-argument
    // folder pattern kick in instead.
1898     if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
1899           return matchPattern(operand, matchConstantIndex());
1900         }))
1901       return failure();
1902 
1903     auto castOp = subViewOp.source().getDefiningOp<CastOp>();
1904     if (!castOp)
1905       return failure();
1906 
1907     if (!CastOp::canFoldIntoConsumerOp(castOp))
1908       return failure();
1909 
    /// Deduce the result type of the SubViewOp using
    /// `getCanonicalSubViewResultType` on the cast source operand type and the
    /// SubViewOp static information. This is the type that would result if the
    /// memref.cast were folded away.
1913     auto resultType = getCanonicalSubViewResultType(
1914         subViewOp.getType().getRank(),
1915         castOp.source().getType().cast<MemRefType>(),
1916         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
1917         subViewOp.getMixedStrides());
1918     Value newSubView = rewriter.create<SubViewOp>(
1919         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
1920         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
1921         subViewOp.static_sizes(), subViewOp.static_strides());
1922     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
1923                                         newSubView);
1924     return success();
1925   }
1926 };
1927 } // namespace
1928 
1929 /// Return the canonical type of the result of a subview.
1930 struct SubViewReturnTypeCanonicalizer {
1931   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
1932                         ArrayRef<OpFoldResult> mixedSizes,
1933                         ArrayRef<OpFoldResult> mixedStrides) {
1934     return getCanonicalSubViewResultType(op.getType().getRank(),
1935                                          op.getSourceType(), mixedOffsets,
1936                                          mixedSizes, mixedStrides);
1937   }
1938 };
1939 
1940 /// A canonicalizer wrapper to replace SubViewOps.
1941 struct SubViewCanonicalizer {
1942   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
1943     rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
1944   }
1945 };
1946 
1947 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
1948                                             MLIRContext *context) {
1949   results
1950       .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
1951                SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
1952            SubViewOpMemRefCastFolder>(context);
1953 }
1954 
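// For illustration (operand names hypothetical), a subview that spans the
// whole statically shaped source, e.g.
//   %1 = memref.subview %0[0, 0][4, 4][1, 1]
//          : memref<4x4xf32> to memref<4x4xf32>
// folds to %0 directly.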
1955 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
1956   auto resultShapedType = getResult().getType().cast<ShapedType>();
1957   auto sourceShapedType = source().getType().cast<ShapedType>();
1958 
1959   if (resultShapedType.hasStaticShape() &&
1960       resultShapedType == sourceShapedType) {
1961     return getViewSource();
1962   }
1963 
1964   return {};
1965 }
1966 
1967 //===----------------------------------------------------------------------===//
1968 // TensorLoadOp
1969 //===----------------------------------------------------------------------===//
1970 
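// For illustration (names hypothetical), an adjacent round-trip such as
//   %m = memref.buffer_cast %t : memref<4xf32>
//   %u = memref.tensor_load %m : memref<4xf32>
// folds %u to %t, provided no operation is interleaved between the two.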
1971 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
1972   if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
1975     if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
1976         bufferCast->getNextNode() == this->getOperation())
1977       return bufferCast.tensor();
1978   return {};
1979 }
1980 
1981 //===----------------------------------------------------------------------===//
1982 // TransposeOp
1983 //===----------------------------------------------------------------------===//
1984 
/// Build a strided memref type by applying `permutationMap` to `memRefType`.
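/// For example (illustrative only), applying (d0, d1) -> (d1, d0) to
/// memref<3x4xf32> (strides [4, 1]) yields memref<4x3xf32> with layout
/// (d0, d1) -> (d1 * 4 + d0), i.e. strides [1, 4].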
1986 static MemRefType inferTransposeResultType(MemRefType memRefType,
1987                                            AffineMap permutationMap) {
1988   auto rank = memRefType.getRank();
1989   auto originalSizes = memRefType.getShape();
1990   // Compute permuted sizes.
1991   SmallVector<int64_t, 4> sizes(rank, 0);
1992   for (auto en : llvm::enumerate(permutationMap.getResults()))
1993     sizes[en.index()] =
1994         originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
1995 
1996   // Compute permuted strides.
1997   int64_t offset;
1998   SmallVector<int64_t, 4> strides;
1999   auto res = getStridesAndOffset(memRefType, strides, offset);
2000   assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
2001   (void)res;
2002   auto map =
2003       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
2004   map = permutationMap ? map.compose(permutationMap) : map;
2005   return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
2006 }
2007 
2008 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
2009                         AffineMapAttr permutation,
2010                         ArrayRef<NamedAttribute> attrs) {
2011   auto permutationMap = permutation.getValue();
2012   assert(permutationMap);
2013 
2014   auto memRefType = in.getType().cast<MemRefType>();
2015   // Compute result type.
2016   MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
2017 
2018   build(b, result, resultType, in, attrs);
2019   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
2020 }
2021 
2022 // transpose $in $permutation attr-dict : type($in) `to` type(results)
2023 static void print(OpAsmPrinter &p, TransposeOp op) {
2024   p << "memref.transpose " << op.in() << " " << op.permutation();
2025   p.printOptionalAttrDict(op->getAttrs(),
2026                           {TransposeOp::getPermutationAttrName()});
2027   p << " : " << op.in().getType() << " to " << op.getType();
2028 }
2029 
2030 static ParseResult parseTransposeOp(OpAsmParser &parser,
2031                                     OperationState &result) {
2032   OpAsmParser::OperandType in;
2033   AffineMap permutation;
2034   MemRefType srcType, dstType;
2035   if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
2036       parser.parseOptionalAttrDict(result.attributes) ||
2037       parser.parseColonType(srcType) ||
2038       parser.resolveOperand(in, srcType, result.operands) ||
2039       parser.parseKeywordType("to", dstType) ||
2040       parser.addTypeToList(dstType, result.types))
2041     return failure();
2042 
2043   result.addAttribute(TransposeOp::getPermutationAttrName(),
2044                       AffineMapAttr::get(permutation));
2045   return success();
2046 }
2047 
2048 static LogicalResult verify(TransposeOp op) {
2049   if (!op.permutation().isPermutation())
2050     return op.emitOpError("expected a permutation map");
2051   if (op.permutation().getNumDims() != op.getShapedType().getRank())
2052     return op.emitOpError(
2053         "expected a permutation map of same rank as the input");
2054 
2055   auto srcType = op.in().getType().cast<MemRefType>();
2056   auto dstType = op.getType().cast<MemRefType>();
2057   auto transposedType = inferTransposeResultType(srcType, op.permutation());
2058   if (dstType != transposedType)
2059     return op.emitOpError("output type ")
2060            << dstType << " does not match transposed input type " << srcType
2061            << ", " << transposedType;
2062   return success();
2063 }
2064 
2065 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
2066   if (succeeded(foldMemRefCast(*this)))
2067     return getResult();
2068   return {};
2069 }
2070 
2071 //===----------------------------------------------------------------------===//
2072 // ViewOp
2073 //===----------------------------------------------------------------------===//
2074 
2075 static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
2076   OpAsmParser::OperandType srcInfo;
2077   SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
2078   SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
2079   auto indexType = parser.getBuilder().getIndexType();
2080   Type srcType, dstType;
2081   llvm::SMLoc offsetLoc;
2082   if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
2083       parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
2084     return failure();
2085 
2086   if (offsetInfo.size() != 1)
2087     return parser.emitError(offsetLoc) << "expects 1 offset operand";
2088 
2089   return failure(
2090       parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
2091       parser.parseOptionalAttrDict(result.attributes) ||
2092       parser.parseColonType(srcType) ||
2093       parser.resolveOperand(srcInfo, srcType, result.operands) ||
2094       parser.resolveOperands(offsetInfo, indexType, result.operands) ||
2095       parser.resolveOperands(sizesInfo, indexType, result.operands) ||
2096       parser.parseKeywordType("to", dstType) ||
2097       parser.addTypeToList(dstType, result.types));
2098 }
2099 
2100 static void print(OpAsmPrinter &p, ViewOp op) {
2101   p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
2102   p.printOperand(op.byte_shift());
2103   p << "][" << op.sizes() << ']';
2104   p.printOptionalAttrDict(op->getAttrs());
2105   p << " : " << op.getOperand(0).getType() << " to " << op.getType();
2106 }
2107 
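// For illustration (names hypothetical), a well-formed view is roughly
//   %v = memref.view %buf[%offset][%size]
//          : memref<2048xi8> to memref<?x4xf32>
// with exactly one size operand per dynamic dimension of the result type.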
2108 static LogicalResult verify(ViewOp op) {
2109   auto baseType = op.getOperand(0).getType().cast<MemRefType>();
2110   auto viewType = op.getType();
2111 
2112   // The base memref should have identity layout map (or none).
2113   if (baseType.getAffineMaps().size() > 1 ||
2114       (baseType.getAffineMaps().size() == 1 &&
2115        !baseType.getAffineMaps()[0].isIdentity()))
2116     return op.emitError("unsupported map for base memref type ") << baseType;
2117 
2118   // The result memref should have identity layout map (or none).
2119   if (viewType.getAffineMaps().size() > 1 ||
2120       (viewType.getAffineMaps().size() == 1 &&
2121        !viewType.getAffineMaps()[0].isIdentity()))
2122     return op.emitError("unsupported map for result memref type ") << viewType;
2123 
2124   // The base memref and the view memref should be in the same memory space.
2125   if (baseType.getMemorySpace() != viewType.getMemorySpace())
2126     return op.emitError("different memory spaces specified for base memref "
2127                         "type ")
2128            << baseType << " and view memref type " << viewType;
2129 
2130   // Verify that we have the correct number of sizes for the result type.
2131   unsigned numDynamicDims = viewType.getNumDynamicDims();
2132   if (op.sizes().size() != numDynamicDims)
2133     return op.emitError("incorrect number of size operands for type ")
2134            << viewType;
2135 
2136   return success();
2137 }
2138 
2139 Value ViewOp::getViewSource() { return source(); }
2140 
2141 namespace {
2142 
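/// Folds dynamic size operands of a ViewOp that are defined by constants into
/// the result type, then casts back to the original type. For illustration
/// (names hypothetical), with %c16 = constant 16 : index,
///   %v = memref.view %buf[%off][%c16] : memref<2048xi8> to memref<?xf32>
/// is rewritten to
///   %0 = memref.view %buf[%off][] : memref<2048xi8> to memref<16xf32>
///   %v = memref.cast %0 : memref<16xf32> to memref<?xf32>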
2143 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
2144   using OpRewritePattern<ViewOp>::OpRewritePattern;
2145 
2146   LogicalResult matchAndRewrite(ViewOp viewOp,
2147                                 PatternRewriter &rewriter) const override {
2148     // Return if none of the operands are constants.
2149     if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
2150           return matchPattern(operand, matchConstantIndex());
2151         }))
2152       return failure();
2153 
2154     // Get result memref type.
2155     auto memrefType = viewOp.getType();
2156 
2157     // Get offset from old memref view type 'memRefType'.
2158     int64_t oldOffset;
2159     SmallVector<int64_t, 4> oldStrides;
2160     if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
2161       return failure();
2162     assert(oldOffset == 0 && "Expected 0 offset");
2163 
2164     SmallVector<Value, 4> newOperands;
2165 
2166     // Offset cannot be folded into result type.
2167 
2168     // Fold any dynamic dim operands which are produced by a constant.
2169     SmallVector<int64_t, 4> newShapeConstants;
2170     newShapeConstants.reserve(memrefType.getRank());
2171 
2172     unsigned dynamicDimPos = 0;
2173     unsigned rank = memrefType.getRank();
2174     for (unsigned dim = 0, e = rank; dim < e; ++dim) {
2175       int64_t dimSize = memrefType.getDimSize(dim);
2176       // If this is already static dimension, keep it.
2177       if (!ShapedType::isDynamic(dimSize)) {
2178         newShapeConstants.push_back(dimSize);
2179         continue;
2180       }
2181       auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
2182       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
2183         // Dynamic shape dimension will be folded.
2184         newShapeConstants.push_back(constantIndexOp.getValue());
2185       } else {
2186         // Dynamic shape dimension not folded; copy operand from old memref.
2187         newShapeConstants.push_back(dimSize);
2188         newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
2189       }
2190       dynamicDimPos++;
2191     }
2192 
2193     // Create new memref type with constant folded dims.
2194     MemRefType newMemRefType =
2195         MemRefType::Builder(memrefType).setShape(newShapeConstants);
2196     // Nothing new, don't fold.
2197     if (newMemRefType == memrefType)
2198       return failure();
2199 
2200     // Create new ViewOp.
2201     auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
2202                                              viewOp.getOperand(0),
2203                                              viewOp.byte_shift(), newOperands);
2204     // Insert a cast so we have the same type as the old memref type.
2205     rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
2206     return success();
2207   }
2208 };
2209 
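/// Replaces view(cast(alloc)) with view(alloc): the cast carries no extra
/// information, so it can be bypassed. A sketch (operand names hypothetical):
///   %1 = memref.cast %0 : memref<2048xi8> to memref<?xi8>
///   %2 = memref.view %1[%off][%size] : memref<?xi8> to memref<?xf32>
/// becomes a view of %0 directly.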
2210 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
2211   using OpRewritePattern<ViewOp>::OpRewritePattern;
2212 
2213   LogicalResult matchAndRewrite(ViewOp viewOp,
2214                                 PatternRewriter &rewriter) const override {
2215     Value memrefOperand = viewOp.getOperand(0);
2216     CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
2217     if (!memrefCastOp)
2218       return failure();
2219     Value allocOperand = memrefCastOp.getOperand();
2220     AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
2221     if (!allocOp)
2222       return failure();
2223     rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
2224                                         viewOp.byte_shift(), viewOp.sizes());
2225     return success();
2226   }
2227 };
2228 
2229 } // end anonymous namespace
2230 
2231 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2232                                          MLIRContext *context) {
2233   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
2234 }
2235 
2236 //===----------------------------------------------------------------------===//
2237 // TableGen'd op method definitions
2238 //===----------------------------------------------------------------------===//
2239 
2240 #define GET_OP_CLASSES
2241 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
2242