//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::memref;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<mlir::ConstantOp>(loc, type, value);
}

/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
  return llvm::to_vector<4>(
      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
        return a.cast<IntegerAttr>().getInt();
      }));
}

/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
/// it is a Value or into `staticVec` if it is an IntegerAttr.
/// In the case of a Value, a copy of the `sentinel` value is also pushed to
/// `staticVec`. This is useful to extract mixed static and dynamic entries that
/// come from an AttrSizedOperandSegments trait.
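/// For example (illustrative only): dispatching the mixed entries
/// {%d0, 4, %d1} with `sentinel` = ShapedType::kDynamicSize appends
/// {%d0, %d1} to `dynamicVec` and {kDynamicSize, 4, kDynamicSize} to
/// `staticVec`.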
static void dispatchIndexOpFoldResult(OpFoldResult ofr,
                                      SmallVectorImpl<Value> &dynamicVec,
                                      SmallVectorImpl<int64_t> &staticVec,
                                      int64_t sentinel) {
  if (auto v = ofr.dyn_cast<Value>()) {
    dynamicVec.push_back(v);
    staticVec.push_back(sentinel);
    return;
  }
  APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
  staticVec.push_back(apInt.getSExtValue());
}

static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
                                       SmallVectorImpl<Value> &dynamicVec,
                                       SmallVectorImpl<int64_t> &staticVec,
                                       int64_t sentinel) {
  for (auto ofr : ofrs)
    dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
/// into the root operation directly.
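/// For example (illustrative sketch):
/// ```mlir
///   %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
///   memref.dealloc %0 : memref<?xf32>
/// ```
/// folds to:
/// ```mlir
///   memref.dealloc %arg : memref<8xf32>
/// ```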
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// Helpers for GlobalOp
//===----------------------------------------------------------------------===//

static Type getTensorTypeFromMemRefType(Type type) {
  if (auto memref = type.dyn_cast<MemRefType>())
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (static_cast<int64_t>(op.dynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getAffineMaps().empty())
    numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
  if (op.symbolOperands().size() != numSymbols)
    return op.emitOpError(
        "symbol operand count does not equal memref symbol count");

  return success();
}

static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }

static LogicalResult verify(AllocaOp op) {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return op.emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(op);
}

namespace {
/// Fold constant dimensions into an alloc like operation.
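/// For example (illustrative sketch; %c16 is assumed to be a constant index):
/// ```mlir
///   %0 = memref.alloc(%c16, %n) : memref<?x?xf32>
/// ```
/// may be rewritten to:
/// ```mlir
///   %1 = memref.alloc(%n) : memref<16x?xf32>
///   %0 = memref.cast %1 : memref<16x?xf32> to memref<?x?xf32>
/// ```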
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants.  If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands.  Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> newOperands;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (dimSize != -1) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(-1);
        newOperands.push_back(alloc.getOperand(dynamicDimPos));
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(static_cast<int64_t>(newOperands.size()) ==
           newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, newOperands, alloc.alignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    auto resultCast =
        rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());

    rewriter.replaceOp(alloc, {resultCast});
    return success();
  }
};

/// Fold alloc operations with no uses. Alloc has side effects on the heap,
/// but can still be deleted if it has zero uses.
struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
  using OpRewritePattern<AllocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocOp alloc,
                                PatternRewriter &rewriter) const override {
    if (alloc.use_empty()) {
      rewriter.eraseOp(alloc);
      return success();
    }
    return failure();
  }
};
} // end anonymous namespace.

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>>(context);
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AssumeAlignmentOp op) {
  unsigned alignment = op.alignment();
  if (!llvm::isPowerOf2_32(alignment))
    return op.emitOpError("alignment must be power of 2");
  return success();
}

//===----------------------------------------------------------------------===//
// BufferCastOp
//===----------------------------------------------------------------------===//

OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
  if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
    if (tensorLoad.memref().getType() == getType())
      return tensorLoad.memref();
  return {};
}

namespace {
/// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast.
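/// For example (illustrative sketch):
/// ```mlir
///   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %1 = memref.buffer_cast %0 : memref<?xf32>
/// ```
/// becomes:
/// ```mlir
///   %2 = memref.buffer_cast %t : memref<4xf32>
///   %1 = memref.cast %2 : memref<4xf32> to memref<?xf32>
/// ```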
struct BufferCast : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<BufferCastOp>(
        bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        memref);
    return success();
  }
};
/// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
/// type mismatches prevent `BufferCastOp::fold` from kicking in.
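/// For example (illustrative sketch):
/// ```mlir
///   %0 = memref.tensor_load %m : memref<?xf32>
///   %1 = memref.buffer_cast %0 : memref<4xf32>
/// ```
/// becomes:
/// ```mlir
///   %1 = memref.cast %m : memref<?xf32> to memref<4xf32>
/// ```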
struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
    // Bail unless we have a tensor_load + memref.buffer_cast with different
    // types. `BufferCastOp::fold` handles the same type case.
    if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
      return failure();
    // If types are not cast-compatible, bail.
    if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
                                   bufferCast.getType()))
      return failure();
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        tensorLoad.memref());
    return success();
  }
};

} // namespace

void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<BufferCast, TensorLoadToMemRef>(context);
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable memref.cast
/// operations are typically inserted as `view` and `subview` ops are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
/// element type and rank.
/// 2. each of the source's size, offset or stride has more static information
/// than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
        return false;
  }

  // If cast is towards more static offset along any dimension, don't fold.
  if (sourceOffset != resultOffset)
    if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
        !MemRefType::isDynamicStrideOrOffset(resultOffset))
      return false;

  // If cast is towards more static strides along any dimension, don't fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (MemRefType::isDynamicStrideOrOffset(ss) &&
          !MemRefType::isDynamicStrideOrOffset(st))
        return false;
  }

  return true;
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<MemRefType>();
  auto bT = b.dyn_cast<MemRefType>();

  auto uaT = a.dyn_cast<UnrankedMemRefType>();
  auto ubT = b.dyn_cast<UnrankedMemRefType>();

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getAffineMaps() != bT.getAffineMaps()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
          failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same. They are also compatible if either one is dynamic (see
      // description of MemRefCastOp for details).
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (a == MemRefType::getDynamicStrideOrOffset() ||
                b == MemRefType::getDynamicStrideOrOffset() || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (auto aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (aDim != -1 && bDim != -1 && aDim != bDim)
        return false;
    }
    return true;
  } else {
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    if (aMemSpace != bMemSpace)
      return false;

    return true;
  }

  return false;
}

OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(CloneOp op) { return success(); }

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
}

namespace {
/// Remove redundant clone operations: a clone with no uses is erased, and a
/// clone is folded away (replaced by its source, with the matching dealloc
/// erased) when the dealloc of its source or result makes the copy
/// unnecessary.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // Removes the clone operation and the corresponding dealloc and alloc
    // operation (if any).
    auto tryRemoveClone = [&](Operation *sourceOp, Operation *dealloc,
                              Operation *alloc) {
      if (!sourceOp || !dealloc || !alloc ||
          alloc->getBlock() != dealloc->getBlock())
        return false;
      rewriter.replaceOp(cloneOp, source);
      rewriter.eraseOp(dealloc);
      return true;
    };

    // Removes unnecessary clones that are derived from the result of the clone
    // op.
    Operation *deallocOp = findDealloc(cloneOp.output());
    Operation *sourceOp = source.getDefiningOp();
    if (tryRemoveClone(sourceOp, deallocOp, sourceOp))
      return success();

    // Removes unnecessary clones that are derived from the source of the clone
    // op.
    deallocOp = findDealloc(source);
    if (tryRemoveClone(sourceOp, deallocOp, cloneOp))
      return success();

    return failure();
  }
};

} // end anonymous namespace.

void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<SimplifyClones>(context);
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//
namespace {
/// Fold Dealloc operations that are deallocating an AllocOp that is only used
/// by other Dealloc operations.
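/// For example (illustrative sketch):
/// ```mlir
///   %0 = memref.alloc() : memref<4xf32>
///   memref.dealloc %0 : memref<4xf32>
/// ```
/// the dealloc can be erased here (and the now-unused alloc is later removed
/// by SimplifyDeadAlloc).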
struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp dealloc,
                                PatternRewriter &rewriter) const override {
    // Check that the memref operand's defining operation is an AllocOp.
    Value memref = dealloc.memref();
    if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp()))
      return failure();

    // Check that all of the uses of the AllocOp are other DeallocOps.
    for (auto *user : memref.getUsers())
      if (!isa<DeallocOp>(user))
        return failure();

    // Erase the dealloc operation.
    rewriter.eraseOp(dealloc);
    return success();
  }
};
} // end anonymous namespace.

static LogicalResult verify(DeallocOp op) {
  if (!op.memref().getType().isa<MemRefType>())
    return op.emitOpError("operand must be a memref");
  return success();
}

void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<SimplifyDeadDealloc>(context);
}

LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<ConstantIndexOp>(loc, index);
  build(builder, result, memref, indexValue);
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
                  Value index) {
  auto indexTy = builder.getIndexType();
  build(builder, result, indexTy, memref, index);
}

Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = index().getDefiningOp<ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}

static LogicalResult verify(DimOp op) {
  // Assume unknown index to be in range.
  Optional<int64_t> index = op.getConstantIndex();
  if (!index.hasValue())
    return success();

  // Check that constant index is not knowingly out of range.
  auto type = op.memrefOrTensor().getType();
  if (auto memrefType = type.dyn_cast<MemRefType>()) {
    if (index.getValue() >= memrefType.getRank())
      return op.emitOpError("index is out of range");
  } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
    if (index.getValue() >= tensorType.getRank())
      return op.emitOpError("index is out of range");
  } else if (type.isa<UnrankedMemRefType>() || type.isa<UnrankedTensorType>()) {
    // Assume index to be in range.
  } else {
    llvm_unreachable("expected operand with memref or tensor type");
  }
  return success();
}

OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();

  // All forms of folding require a known index.
  if (!index)
    return {};

  auto argTy = memrefOrTensor().getType();
  // Fold if the shape extent along the given index is known.
  if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
    // Folding for unranked types (UnrankedMemRefType) is not supported.
    if (!shapedTy.hasRank())
      return {};
    if (!shapedTy.isDynamicDim(index.getInt())) {
      Builder builder(getContext());
      return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]);
    }
  }

  Operation *definingOp = memrefOrTensor().getDefiningOp();

  // dim(memref.tensor_load(memref)) -> dim(memref)
  if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) {
    setOperand(0, tensorLoadOp.memref());
    return getResult();
  }

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        fromElements.getResult().getType().cast<RankedTensorType>();
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(resultType.getShape()[index.getInt()] ==
           RankedTensorType::kDynamicSize);

    // Find the operand of the fromElements that corresponds to this index.
    auto dynExtents = fromElements.dynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (dim == RankedTensorType::kDynamicSize)
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto subtensor = dyn_cast_or_null<mlir::SubTensorOp>(definingOp)) {
    assert(subtensor.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subtensor size");
    return subtensor.getDynamicSize(unsignedIndex);
  }

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  auto memrefType = argTy.dyn_cast<MemRefType>();
  if (!memrefType)
    return {};

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    assert(subview.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return subview.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
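/// For example (illustrative sketch):
/// ```mlir
///   %r = memref.reshape %src(%shape)
///          : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
///   %d = memref.dim %r, %c1 : memref<?x?xf32>
/// ```
/// is rewritten to:
/// ```mlir
///   %d = memref.load %shape[%c1] : memref<2xindex>
/// ```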
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.memrefOrTensor().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return failure();

    // Place the load directly after the reshape to ensure that the shape memref
    // was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
                                        llvm::makeArrayRef({dim.index()}));
    return success();
  }
};
/// Fold dim of a cast operation into the dim of the source of the cast.
template <typename CastOpTy>
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};
/// Helper method to get the `Value` that is the shape of the given `result`
/// at dimension `dimIndex`, using the `InferShapedTypeOpInterface` of the
/// defining op.
/// TODO(ravishankarm): This is better put as an interface utility method
/// somewhere, but that would imply the interface will depend on the `tensor`
/// dialect. Ideally maybe a utility method in the `tensor` dialect.
static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
                                            int64_t dimIndex) {
  unsigned resultNumber = result.getResultNumber();
  auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
  Location loc = result.getOwner()->getLoc();
  if (!shapedTypeOp)
    return nullptr;

  // The interface exposes two methods, one that returns the shape of all the
  // results as `Value` and another that returns the shape as a list of
  // `SmallVector<Value>`. The former takes precedence over the latter. So first
  // check if the op implements the first interface method or the second, and
  // get the value to use appropriately.
  SmallVector<Value> reifiedResultShapes;
  if (succeeded(
          shapedTypeOp.reifyReturnTypeShapes(builder, reifiedResultShapes))) {
    if (reifiedResultShapes.size() <= resultNumber)
      return nullptr;
    Value resultShape = reifiedResultShapes[resultNumber];
    auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
    if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
      return nullptr;
    return builder.create<tensor::ExtractOp>(
        loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
  }

  SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
  if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
          builder, reifiedResultShapesPerDim)))
    return nullptr;
  if (reifiedResultShapesPerDim.size() <= resultNumber ||
      reifiedResultShapesPerDim[resultNumber].size() !=
          static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
    return nullptr;
  OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
  if (auto attr = valueOrAttr.dyn_cast<Attribute>())
    return builder.createOrFold<ConstantIndexOp>(
        loc, attr.cast<IntegerAttr>().getInt());
  return valueOrAttr.get<Value>();
}

/// Fold dim of an operation that implements the InferShapedTypeOpInterface
struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
    if (!dimValue)
      return failure();
    auto shapedTypeOp =
        dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
    if (!shapedTypeOp)
      return failure();

    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
    if (!dimIndex)
      return failure();
    Value replacement =
        getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
    if (!replacement)
      return failure();
    rewriter.replaceOp(dimOp, replacement);
    return success();
  }
};
} // end anonymous namespace.

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
              DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(context);
}

// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------

void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  if (stride)
    result.addOperands({stride, elementsPerStride});
}

void DmaStartOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices()
    << "], " << getDstMemRef() << '[' << getDstIndices() << "], "
    << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices()
    << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}

// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride :
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
//
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
  OpAsmParser::OperandType dstMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
  OpAsmParser::OperandType numElementsInfo;
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) dma size (number of elements to transfer).
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}

LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
  // the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!getSrcMemRef().getType().isa<MemRefType>())
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!getDstMemRef().getType().isa<MemRefType>())
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // DMAs are expected to be between different memory spaces.
  if (getSrcMemorySpace() == getDstMemorySpace())
    return emitOpError("DMA should be between different memory spaces");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------

void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
                      Value tagMemRef, ValueRange tagIndices,
                      Value numElements) {
  result.addOperands(tagMemRef);
  result.addOperands(tagIndices);
  result.addOperands(numElements);
}

void DmaWaitOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices()
    << "], " << getNumElements();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getTagMemRef().getType();
}

// Parse DmaWaitOp.
// Eg:
//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
//
ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
  Type type;
  auto indexType = parser.getBuilder().getIndexType();
  OpAsmParser::OperandType numElementsInfo;

  // Parse tag memref, its indices, and dma size.
  if (parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
      parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  return success();
}

LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::verify() {
  // Mandatory non-variadic operands are tag and the number of elements.
  if (getNumOperands() < 2)
    return emitOpError() << "expected at least 2 operands";

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError() << "expected tag to be of memref type";

  if (getNumOperands() != 2 + getTagMemRefRank())
    return emitOpError() << "expected " << 2 + getTagMemRefRank()
                         << " operands";

  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError() << "expected tag indices to be of index type";

  if (!getNumElements().getType().isIndex())
    return emitOpError()
           << "expected the number of elements to be of index type";

  return success();
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
                                                   TypeAttr type,
                                                   Attribute initialValue) {
  p << type;
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
    else
      p.printAttributeWithoutType(initialValue);
  }
}

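// Examples of the type-and-initial-value syntax handled below (illustrative):
//   memref.global "private" @gv : memref<2xf32> = dense<[0.0, 2.0]>
//   memref.global @gv2 : memref<4xi32> = uninitialized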
static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                       Attribute &initialValue) {
  Type type;
  if (parser.parseType(type))
    return failure();

  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

  if (parser.parseOptionalEqual())
    return success();

  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
    initialValue = UnitAttr::get(parser.getBuilder().getContext());
    return success();
  }

  Type tensorType = getTensorTypeFromMemRefType(memrefType);
  if (parser.parseAttribute(initialValue, tensorType))
    return failure();
  if (!initialValue.isa<ElementsAttr>())
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  return success();
}

static LogicalResult verify(GlobalOp op) {
  auto memrefType = op.type().dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return op.emitOpError("type should be static shaped memref, but got ")
           << op.type();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (op.initial_value().hasValue()) {
    Attribute initValue = op.initial_value().getValue();
    if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
      return op.emitOpError("initial value should be a unit or elements "
                            "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (initValue.isa<ElementsAttr>()) {
      Type initType = initValue.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return op.emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  // TODO: verify visibility for declarations.
  return success();
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is the same as the type of the referenced
  // memref.global op.
  auto global =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
  if (!global)
    return emitOpError("'")
           << name() << "' does not reference a valid global memref";

  Type resultType = result().getType();
  if (global.type() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.type()
           << " of the global memref @" << name();
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(LoadOp op) {
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("incorrect number of indices for load");
  return success();
}

OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

namespace {
/// Fold a load on a buffer_cast operation into a tensor.extract on the
/// corresponding tensor.
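/// For example (illustrative sketch):
/// ```mlir
///   %m = memref.buffer_cast %t : memref<4xf32>
///   %v = memref.load %m[%i] : memref<4xf32>
/// ```
/// becomes:
/// ```mlir
///   %v = tensor.extract %t[%i] : tensor<4xf32>
/// ```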
struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
  using OpRewritePattern<LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
    if (!buffercast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
                                                   load.indices());
    return success();
  }
};
} // end anonymous namespace.

void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<LoadOfBufferCast>(context);
}

//===----------------------------------------------------------------------===//
// PrefetchOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, PrefetchOp op) {
  p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
  p.printOperands(op.indices());
  p << ']' << ", " << (op.isWrite() ? "write" : "read");
  p << ", locality<" << op.localityHint();
  p << ">, " << (op.isDataCache() ? "data" : "instr");
  p.printOptionalAttrDict(
      op->getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
  p << " : " << op.getMemRefType();
}

static ParseResult parsePrefetchOp(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::OperandType memrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  IntegerAttr localityHint;
  MemRefType type;
  StringRef readOrWrite, cacheType;

  auto indexTy = parser.getBuilder().getIndexType();
  auto i32Type = parser.getBuilder().getIntegerType(32);
  if (parser.parseOperand(memrefInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(localityHint, i32Type, "localityHint",
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(indexInfo, indexTy, result.operands))
    return failure();

  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(
      PrefetchOp::getIsWriteAttrName(),
      parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));

  if (!cacheType.equals("data") && !cacheType.equals("instr"))
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");

  result.addAttribute(
      PrefetchOp::getIsDataCacheAttrName(),
      parser.getBuilder().getBoolAttr(cacheType.equals("data")));

  return success();
}

static LogicalResult verify(PrefetchOp op) {
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("too few indices");

  return success();
}

LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  // prefetch(memrefcast) -> prefetch
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// ReinterpretCastOp
//===----------------------------------------------------------------------===//

/// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
/// `staticSizes` and `staticStrides` are automatically filled with
/// source-memref-rank sentinel values that encode dynamic entries.
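/// For example (illustrative only): building with offset = %o,
/// sizes = {%s, 4} and strides = {4, 1} results in static attribute arrays
/// [kDynamicStrideOrOffset], [kDynamicSize, 4], [4, 1] and dynamic operand
/// lists {%o}, {%s}, {}.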
1296 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1297                               MemRefType resultType, Value source,
1298                               OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1299                               ArrayRef<OpFoldResult> strides,
1300                               ArrayRef<NamedAttribute> attrs) {
1301   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1302   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1303   dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
1304                              ShapedType::kDynamicStrideOrOffset);
1305   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1306                              ShapedType::kDynamicSize);
1307   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1308                              ShapedType::kDynamicStrideOrOffset);
1309   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1310         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1311         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1312   result.addAttributes(attrs);
1313 }
1314 
1315 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1316                               MemRefType resultType, Value source,
1317                               int64_t offset, ArrayRef<int64_t> sizes,
1318                               ArrayRef<int64_t> strides,
1319                               ArrayRef<NamedAttribute> attrs) {
1320   SmallVector<OpFoldResult> sizeValues =
1321       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1322         return b.getI64IntegerAttr(v);
1323       }));
1324   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1325       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1326         return b.getI64IntegerAttr(v);
1327       }));
1328   build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1329         strideValues, attrs);
1330 }
1331 
1332 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1333                               MemRefType resultType, Value source, Value offset,
1334                               ValueRange sizes, ValueRange strides,
1335                               ArrayRef<NamedAttribute> attrs) {
1336   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1337       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1338   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1339       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1340   build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1341 }
1342 
1343 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1344 // completed automatically, like we have for subview and subtensor.
1345 static LogicalResult verify(ReinterpretCastOp op) {
1346   // The source and result memrefs should be in the same memory space.
1347   auto srcType = op.source().getType().cast<BaseMemRefType>();
1348   auto resultType = op.getType().cast<MemRefType>();
1349   if (srcType.getMemorySpace() != resultType.getMemorySpace())
1350     return op.emitError("different memory spaces specified for source type ")
1351            << srcType << " and result memref type " << resultType;
1352   if (srcType.getElementType() != resultType.getElementType())
1353     return op.emitError("different element types specified for source type ")
1354            << srcType << " and result memref type " << resultType;
1355 
1356   // Match sizes in result memref type and in static_sizes attribute.
1357   for (auto &en :
1358        llvm::enumerate(llvm::zip(resultType.getShape(),
1359                                  extractFromI64ArrayAttr(op.static_sizes())))) {
1360     int64_t resultSize = std::get<0>(en.value());
1361     int64_t expectedSize = std::get<1>(en.value());
1362     if (resultSize != expectedSize)
1363       return op.emitError("expected result type with size = ")
1364              << expectedSize << " instead of " << resultSize
1365              << " in dim = " << en.index();
1366   }
1367 
  // Match offset and strides in static_offsets and static_strides attributes
  // if the result memref type has an affine map specified.
1370   if (!resultType.getAffineMaps().empty()) {
1371     int64_t resultOffset;
1372     SmallVector<int64_t, 4> resultStrides;
1373     if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1374       return failure();
1375 
1376     // Match offset in result memref type and in static_offsets attribute.
1377     int64_t expectedOffset =
1378         extractFromI64ArrayAttr(op.static_offsets()).front();
    if (resultOffset != expectedOffset)
      return op.emitError("expected result type with offset = ")
             << expectedOffset << " instead of " << resultOffset;
1382 
1383     // Match strides in result memref type and in static_strides attribute.
1384     for (auto &en : llvm::enumerate(llvm::zip(
1385              resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
1386       int64_t resultStride = std::get<0>(en.value());
1387       int64_t expectedStride = std::get<1>(en.value());
1388       if (resultStride != expectedStride)
1389         return op.emitError("expected result type with stride = ")
1390                << expectedStride << " instead of " << resultStride
1391                << " in dim = " << en.index();
1392     }
1393   }
1394   return success();
1395 }
1396 
1397 //===----------------------------------------------------------------------===//
1398 // ReshapeOp
1399 //===----------------------------------------------------------------------===//
1400 
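// For orientation, memref.reshape takes the target shape as a 1-D shape
// operand; the textual form is roughly (illustrative, not normative):
//
//   %dst = memref.reshape %src(%shape)
//       : (memref<4x5xf32>, memref<1xindex>) -> memref<20xf32>
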
1401 static LogicalResult verify(ReshapeOp op) {
1402   Type operandType = op.source().getType();
1403   Type resultType = op.result().getType();
1404 
1405   Type operandElementType = operandType.cast<ShapedType>().getElementType();
1406   Type resultElementType = resultType.cast<ShapedType>().getElementType();
1407   if (operandElementType != resultElementType)
1408     return op.emitOpError("element types of source and destination memref "
1409                           "types should be the same");
1410 
1411   if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
1412     if (!operandMemRefType.getAffineMaps().empty())
1413       return op.emitOpError(
1414           "source memref type should have identity affine map");
1415 
1416   int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
1417   auto resultMemRefType = resultType.dyn_cast<MemRefType>();
1418   if (resultMemRefType) {
1419     if (!resultMemRefType.getAffineMaps().empty())
1420       return op.emitOpError(
1421           "result memref type should have identity affine map");
1422     if (shapeSize == ShapedType::kDynamicSize)
1423       return op.emitOpError("cannot use shape operand with dynamic length to "
1424                             "reshape to statically-ranked memref type");
1425     if (shapeSize != resultMemRefType.getRank())
1426       return op.emitOpError(
1427           "length of shape operand differs from the result's memref rank");
1428   }
1429   return success();
1430 }
1431 
1432 //===----------------------------------------------------------------------===//
1433 // StoreOp
1434 //===----------------------------------------------------------------------===//
1435 
1436 static LogicalResult verify(StoreOp op) {
1437   if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
1438     return op.emitOpError("store index operand count not equal to memref rank");
1439 
1440   return success();
1441 }
1442 
1443 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
1444                             SmallVectorImpl<OpFoldResult> &results) {
1445   /// store(memrefcast) -> store
1446   return foldMemRefCast(*this);
1447 }
1448 
1449 //===----------------------------------------------------------------------===//
1450 // SubViewOp
1451 //===----------------------------------------------------------------------===//
1452 
1453 namespace {
/// Helpers to write more idiomatic operations: arithmetic on offsets and
/// strides that saturates to the dynamic sentinel as soon as any operand is
/// dynamic.
1455 namespace saturated_arith {
1456 struct Wrapper {
1457   explicit Wrapper(int64_t v) : v(v) {}
1458   operator int64_t() { return v; }
1459   int64_t v;
1460 };
1461 Wrapper operator+(Wrapper a, int64_t b) {
1462   if (ShapedType::isDynamicStrideOrOffset(a) ||
1463       ShapedType::isDynamicStrideOrOffset(b))
1464     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1465   return Wrapper(a.v + b);
1466 }
1467 Wrapper operator*(Wrapper a, int64_t b) {
1468   if (ShapedType::isDynamicStrideOrOffset(a) ||
1469       ShapedType::isDynamicStrideOrOffset(b))
1470     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1471   return Wrapper(a.v * b);
1472 }
1473 } // end namespace saturated_arith
1474 } // end namespace
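// For instance, with the wrappers above, `Wrapper(4) *
// ShapedType::kDynamicStrideOrOffset` and
// `Wrapper(ShapedType::kDynamicStrideOrOffset) + 3` both yield the dynamic
// sentinel, so a single dynamic operand propagates through the offset and
// stride computations below.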
1475 
1476 /// A subview result type can be fully inferred from the source type and the
1477 /// static representation of offsets, sizes and strides. Special sentinels
1478 /// encode the dynamic case.
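/// For example (illustrative), given a source `memref<8x16x4xf32>` (strides
/// [64, 4, 1], offset 0), static offsets [3, 4, 2], sizes [4, 4, 4] and
/// strides [1, 1, 1], the inferred type is
/// `memref<4x4x4xf32, offset: 210, strides: [64, 4, 1]>`,
/// since 3 * 64 + 4 * 4 + 2 * 1 = 210.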
1479 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1480                                 ArrayRef<int64_t> leadingStaticOffsets,
1481                                 ArrayRef<int64_t> leadingStaticSizes,
1482                                 ArrayRef<int64_t> leadingStaticStrides) {
1483   // A subview may specify only a leading subset of offset/sizes/strides in
1484   // which case we complete with offset=0, sizes from memref type and strides=1.
1485   unsigned rank = sourceMemRefType.getRank();
1486   assert(leadingStaticOffsets.size() <= rank &&
1487          "unexpected leadingStaticOffsets overflow");
1488   assert(leadingStaticSizes.size() <= rank &&
1489          "unexpected leadingStaticSizes overflow");
1490   assert(leadingStaticStrides.size() <= rank &&
1491          "unexpected leadingStaticStrides overflow");
1492   auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
1493   auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
1494   auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
1495   unsigned numTrailingOffsets = rank - staticOffsets.size();
1496   unsigned numTrailingSizes = rank - staticSizes.size();
1497   unsigned numTrailingStrides = rank - staticStrides.size();
1498   staticOffsets.append(numTrailingOffsets, 0);
1499   llvm::append_range(staticSizes,
1500                      sourceMemRefType.getShape().take_back(numTrailingSizes));
1501   staticStrides.append(numTrailingStrides, 1);
1502 
1503   // Extract source offset and strides.
1504   int64_t sourceOffset;
1505   SmallVector<int64_t, 4> sourceStrides;
1506   auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
1507   assert(succeeded(res) && "SubViewOp expected strided memref type");
1508   (void)res;
1509 
1510   // Compute target offset whose value is:
1511   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
1512   int64_t targetOffset = sourceOffset;
1513   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    using namespace saturated_arith;
    targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * sourceStride;
1517   }
1518 
1519   // Compute target stride whose value is:
1520   //   `sourceStrides_i * staticStrides_i`.
1521   SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticStrides.size());
1523   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
1524     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
1525     using namespace saturated_arith;
1526     targetStrides.push_back(Wrapper(sourceStride) * staticStride);
1527   }
1528 
1529   // The type is now known.
1530   return MemRefType::get(
1531       staticSizes, sourceMemRefType.getElementType(),
1532       makeStridedLinearLayoutMap(targetStrides, targetOffset,
1533                                  sourceMemRefType.getContext()),
1534       sourceMemRefType.getMemorySpace());
1535 }
1536 
1537 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1538                                 ArrayRef<OpFoldResult> leadingStaticOffsets,
1539                                 ArrayRef<OpFoldResult> leadingStaticSizes,
1540                                 ArrayRef<OpFoldResult> leadingStaticStrides) {
1541   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1542   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1543   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1544                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1545   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1546                              ShapedType::kDynamicSize);
1547   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1548                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1549   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1550                                     staticSizes, staticStrides)
1551       .cast<MemRefType>();
1552 }
1553 
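/// A rank-reduced variant of the above: the fully inferred type is computed
/// first and then unit dimensions are projected away until `resultRank` is
/// reached. For example (illustrative), if the inferred type is
/// `memref<1x16x4xf32, offset: 0, strides: [64, 4, 1]>` and `resultRank` is 2,
/// the unit dimension is dropped and the result is
/// `memref<16x4xf32, offset: 0, strides: [4, 1]>`.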
1554 Type SubViewOp::inferRankReducedResultType(
1555     unsigned resultRank, MemRefType sourceRankedTensorType,
1556     ArrayRef<int64_t> leadingStaticOffsets,
1557     ArrayRef<int64_t> leadingStaticSizes,
1558     ArrayRef<int64_t> leadingStaticStrides) {
1559   auto inferredType =
1560       inferResultType(sourceRankedTensorType, leadingStaticOffsets,
1561                       leadingStaticSizes, leadingStaticStrides)
1562           .cast<MemRefType>();
  assert(inferredType.getRank() >= resultRank &&
         "expected the result rank to be no larger than the inferred rank");
1564   int rankDiff = inferredType.getRank() - resultRank;
1565   if (rankDiff > 0) {
1566     auto shape = inferredType.getShape();
1567     llvm::SmallDenseSet<unsigned> dimsToProject;
1568     mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
1569     SmallVector<int64_t> projectedShape;
1570     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1571       if (!dimsToProject.contains(pos))
1572         projectedShape.push_back(shape[pos]);
1573 
1574     AffineMap map;
1575     auto maps = inferredType.getAffineMaps();
1576     if (!maps.empty() && maps.front())
1577       map = getProjectedMap(maps.front(), dimsToProject);
1578     inferredType =
1579         MemRefType::get(projectedShape, inferredType.getElementType(), map,
1580                         inferredType.getMemorySpace());
1581   }
1582   return inferredType;
1583 }
1584 
1585 Type SubViewOp::inferRankReducedResultType(
1586     unsigned resultRank, MemRefType sourceRankedTensorType,
1587     ArrayRef<OpFoldResult> leadingStaticOffsets,
1588     ArrayRef<OpFoldResult> leadingStaticSizes,
1589     ArrayRef<OpFoldResult> leadingStaticStrides) {
1590   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1591   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1592   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1593                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1594   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1595                              ShapedType::kDynamicSize);
1596   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1597                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1598   return SubViewOp::inferRankReducedResultType(
1599       resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1600       staticStrides);
1601 }

// Build a SubViewOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
1604 void SubViewOp::build(OpBuilder &b, OperationState &result,
1605                       MemRefType resultType, Value source,
1606                       ArrayRef<OpFoldResult> offsets,
1607                       ArrayRef<OpFoldResult> sizes,
1608                       ArrayRef<OpFoldResult> strides,
1609                       ArrayRef<NamedAttribute> attrs) {
1610   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1611   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1612   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1613                              ShapedType::kDynamicStrideOrOffset);
1614   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1615                              ShapedType::kDynamicSize);
1616   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1617                              ShapedType::kDynamicStrideOrOffset);
1618   auto sourceMemRefType = source.getType().cast<MemRefType>();
1619   // Structuring implementation this way avoids duplication between builders.
1620   if (!resultType) {
1621     resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1622                                             staticSizes, staticStrides)
1623                      .cast<MemRefType>();
1624   }
1625   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1626         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1627         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1628   result.addAttributes(attrs);
1629 }
1630 
1631 // Build a SubViewOp with mixed static and dynamic entries and inferred result
1632 // type.
1633 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1634                       ArrayRef<OpFoldResult> offsets,
1635                       ArrayRef<OpFoldResult> sizes,
1636                       ArrayRef<OpFoldResult> strides,
1637                       ArrayRef<NamedAttribute> attrs) {
1638   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1639 }
1640 
1641 // Build a SubViewOp with static entries and inferred result type.
1642 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1643                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1644                       ArrayRef<int64_t> strides,
1645                       ArrayRef<NamedAttribute> attrs) {
1646   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1647       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1648         return b.getI64IntegerAttr(v);
1649       }));
1650   SmallVector<OpFoldResult> sizeValues =
1651       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1652         return b.getI64IntegerAttr(v);
1653       }));
1654   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1655       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1656         return b.getI64IntegerAttr(v);
1657       }));
1658   build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
1659 }
1660 
// Build a SubViewOp with static entries and custom result type. If the
// type passed is nullptr, it is inferred.
1663 void SubViewOp::build(OpBuilder &b, OperationState &result,
1664                       MemRefType resultType, Value source,
1665                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1666                       ArrayRef<int64_t> strides,
1667                       ArrayRef<NamedAttribute> attrs) {
1668   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1669       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1670         return b.getI64IntegerAttr(v);
1671       }));
1672   SmallVector<OpFoldResult> sizeValues =
1673       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1674         return b.getI64IntegerAttr(v);
1675       }));
1676   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1677       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1678         return b.getI64IntegerAttr(v);
1679       }));
1680   build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
1681         attrs);
1682 }
1683 
1684 // Build a SubViewOp with dynamic entries and custom result type. If the type
1685 // passed is nullptr, it is inferred.
1686 void SubViewOp::build(OpBuilder &b, OperationState &result,
1687                       MemRefType resultType, Value source, ValueRange offsets,
1688                       ValueRange sizes, ValueRange strides,
1689                       ArrayRef<NamedAttribute> attrs) {
1690   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1691       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1692   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1693       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1694   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1695       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
1697 }
1698 
1699 // Build a SubViewOp with dynamic entries and inferred result type.
1700 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1701                       ValueRange offsets, ValueRange sizes, ValueRange strides,
1702                       ArrayRef<NamedAttribute> attrs) {
1703   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1704 }
1705 
1706 /// For ViewLikeOpInterface.
1707 Value SubViewOp::getViewSource() { return source(); }
1708 
1709 enum SubViewVerificationResult {
1710   Success,
1711   RankTooLarge,
1712   SizeMismatch,
1713   ElemTypeMismatch,
1714   MemSpaceMismatch,
1715   AffineMapMismatch
1716 };
1717 
/// Checks if `originalType` can be rank reduced to `candidateReducedType`.
/// This function is a slight variant of the "is a subsequence" algorithm,
/// where every dimension dropped from the original shape must be 1.
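/// For example (illustrative), `memref<1x4x5xf32>` can be rank reduced to
/// `memref<4x5xf32>` because the only dropped dimension has size 1, whereas
/// reducing it to `memref<4x4xf32>` is rejected with a size mismatch.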
1721 static SubViewVerificationResult
1722 isRankReducedType(Type originalType, Type candidateReducedType,
1723                   std::string *errMsg = nullptr) {
1724   if (originalType == candidateReducedType)
1725     return SubViewVerificationResult::Success;
1726   if (!originalType.isa<MemRefType>())
1727     return SubViewVerificationResult::Success;
1728   if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
1729     return SubViewVerificationResult::Success;
1730 
1731   ShapedType originalShapedType = originalType.cast<ShapedType>();
1732   ShapedType candidateReducedShapedType =
1733       candidateReducedType.cast<ShapedType>();
1734 
1735   // Rank and size logic is valid for all ShapedTypes.
1736   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
1737   ArrayRef<int64_t> candidateReducedShape =
1738       candidateReducedShapedType.getShape();
1739   unsigned originalRank = originalShape.size(),
1740            candidateReducedRank = candidateReducedShape.size();
1741   if (candidateReducedRank > originalRank)
1742     return SubViewVerificationResult::RankTooLarge;
1743 
1744   auto optionalUnusedDimsMask =
1745       computeRankReductionMask(originalShape, candidateReducedShape);
1746 
  // If no rank-reduction mask could be computed, the sizes cannot be matched.
1748   if (!optionalUnusedDimsMask.hasValue())
1749     return SubViewVerificationResult::SizeMismatch;
1750 
1751   if (originalShapedType.getElementType() !=
1752       candidateReducedShapedType.getElementType())
1753     return SubViewVerificationResult::ElemTypeMismatch;
1754 
1755   // Strided layout logic is relevant for MemRefType only.
1756   MemRefType original = originalType.cast<MemRefType>();
1757   MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
1758   if (original.getMemorySpace() != candidateReduced.getMemorySpace())
1759     return SubViewVerificationResult::MemSpaceMismatch;
1760 
1761   llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
1762   auto inferredType =
1763       getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
1764   AffineMap candidateLayout;
1765   if (candidateReduced.getAffineMaps().empty())
1766     candidateLayout = getStridedLinearLayoutMap(candidateReduced);
1767   else
1768     candidateLayout = candidateReduced.getAffineMaps().front();
1769   assert(inferredType.getNumResults() == 1 &&
1770          candidateLayout.getNumResults() == 1);
1771   if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
1772       inferredType.getNumDims() != candidateLayout.getNumDims()) {
1773     if (errMsg) {
1774       llvm::raw_string_ostream os(*errMsg);
1775       os << "inferred type: " << inferredType;
1776     }
1777     return SubViewVerificationResult::AffineMapMismatch;
1778   }
1779   // Check that the difference of the affine maps simplifies to 0.
1780   AffineExpr diffExpr =
1781       inferredType.getResult(0) - candidateLayout.getResult(0);
1782   diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
1783                                 inferredType.getNumSymbols());
1784   auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
1785   if (!(cst && cst.getValue() == 0)) {
1786     if (errMsg) {
1787       llvm::raw_string_ostream os(*errMsg);
1788       os << "inferred type: " << inferredType;
1789     }
1790     return SubViewVerificationResult::AffineMapMismatch;
1791   }
1792   return SubViewVerificationResult::Success;
1793 }
1794 
1795 template <typename OpTy>
1796 static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
1797                                             OpTy op, Type expectedType,
1798                                             StringRef errMsg = "") {
1799   auto memrefType = expectedType.cast<ShapedType>();
1800   switch (result) {
1801   case SubViewVerificationResult::Success:
1802     return success();
1803   case SubViewVerificationResult::RankTooLarge:
    return op.emitError("expected result rank to be smaller than or equal to ")
           << "the source rank. " << errMsg;
1806   case SubViewVerificationResult::SizeMismatch:
1807     return op.emitError("expected result type to be ")
1808            << expectedType
1809            << " or a rank-reduced version. (mismatch of result sizes) "
1810            << errMsg;
1811   case SubViewVerificationResult::ElemTypeMismatch:
1812     return op.emitError("expected result element type to be ")
1813            << memrefType.getElementType() << errMsg;
1814   case SubViewVerificationResult::MemSpaceMismatch:
1815     return op.emitError("expected result and source memory spaces to match.")
1816            << errMsg;
1817   case SubViewVerificationResult::AffineMapMismatch:
1818     return op.emitError("expected result type to be ")
1819            << expectedType
1820            << " or a rank-reduced version. (mismatch of result affine map) "
1821            << errMsg;
1822   }
1823   llvm_unreachable("unexpected subview verification result");
1824 }
1825 
1826 /// Verifier for SubViewOp.
1827 static LogicalResult verify(SubViewOp op) {
1828   MemRefType baseType = op.getSourceType();
1829   MemRefType subViewType = op.getType();
1830 
1831   // The base memref and the view memref should be in the same memory space.
1832   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
1833     return op.emitError("different memory spaces specified for base memref "
1834                         "type ")
1835            << baseType << " and subview memref type " << subViewType;
1836 
1837   // Verify that the base memref type has a strided layout map.
1838   if (!isStrided(baseType))
1839     return op.emitError("base type ") << baseType << " is not strided";
1840 
1841   // Verify result type against inferred type.
1842   auto expectedType = SubViewOp::inferResultType(
1843       baseType, extractFromI64ArrayAttr(op.static_offsets()),
1844       extractFromI64ArrayAttr(op.static_sizes()),
1845       extractFromI64ArrayAttr(op.static_strides()));
1846 
1847   std::string errMsg;
1848   auto result = isRankReducedType(expectedType, subViewType, &errMsg);
1849   return produceSubViewErrorMsg(result, op, expectedType, errMsg);
1850 }
1851 
1852 raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
1853   return os << "range " << range.offset << ":" << range.size << ":"
1854             << range.stride;
1855 }
1856 
1857 /// Return the list of Range (i.e. offset, size, stride). Each Range
1858 /// entry contains either the dynamic value or a ConstantIndexOp constructed
1859 /// with `b` at location `loc`.
1860 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
1861                                               OpBuilder &b, Location loc) {
1862   std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
1863   assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
1864   assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
1865   SmallVector<Range, 8> res;
1866   unsigned rank = ranks[0];
1867   res.reserve(rank);
1868   for (unsigned idx = 0; idx < rank; ++idx) {
1869     Value offset =
1870         op.isDynamicOffset(idx)
1871             ? op.getDynamicOffset(idx)
1872             : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
1873     Value size = op.isDynamicSize(idx)
1874                      ? op.getDynamicSize(idx)
1875                      : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
1876     Value stride =
1877         op.isDynamicStride(idx)
1878             ? op.getDynamicStride(idx)
1879             : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
1880     res.emplace_back(Range{offset, size, stride});
1881   }
1882   return res;
1883 }
1884 
1885 namespace {
1886 /// Pattern to rewrite a subview op with MemRefCast arguments.
1887 /// This essentially pushes memref.cast past its consuming subview when
1888 /// `canFoldIntoConsumerOp` is true.
1889 ///
1890 /// Example:
1891 /// ```
1892 ///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
1893 ///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
1894 ///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
1895 /// ```
1896 /// is rewritten into:
1897 /// ```
1898 ///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
1899 ///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
1900 ///     memref<3x4xf32, offset:?, strides:[?, 1]>
1901 /// ```
1902 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
1903 public:
1904   using OpRewritePattern<SubViewOp>::OpRewritePattern;
1905 
1906   LogicalResult matchAndRewrite(SubViewOp subViewOp,
1907                                 PatternRewriter &rewriter) const override {
    // If any operand is a constant, bail out and let the constant-argument
    // folder registered in getCanonicalizationPatterns kick in first.
1909     if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
1910           return matchPattern(operand, matchConstantIndex());
1911         }))
1912       return failure();
1913 
1914     auto castOp = subViewOp.source().getDefiningOp<CastOp>();
1915     if (!castOp)
1916       return failure();
1917 
1918     if (!CastOp::canFoldIntoConsumerOp(castOp))
1919       return failure();
1920 
    /// Deduce the result type of the SubViewOp using
    /// `inferRankReducedResultType` on the cast source operand type and the
    /// SubViewOp static information. This is the resulting type if the
    /// MemRefCastOp were folded.
1924     auto resultType = SubViewOp::inferRankReducedResultType(
1925         subViewOp.getType().getRank(),
1926         castOp.source().getType().cast<MemRefType>(),
1927         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
1928         subViewOp.getMixedStrides());
1929     Value newSubView = rewriter.create<SubViewOp>(
1930         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
1931         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
1932         subViewOp.static_sizes(), subViewOp.static_strides());
1933     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
1934                                         newSubView);
1935     return success();
1936   }
1937 };
1938 } // namespace
1939 
/// A canonicalizer wrapper to replace SubViewOps: the canonicalized op is cast
/// back to the original result type.
1941 struct SubViewCanonicalizer {
1942   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
1943     rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
1944   }
1945 };
1946 
1947 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
1948                                             MLIRContext *context) {
1949   results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
1950                   SubViewOp, SubViewCanonicalizer>,
1951               SubViewOpMemRefCastFolder>(context);
1952 }
1953 
1954 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
1955   auto resultShapedType = getResult().getType().cast<ShapedType>();
1956   auto sourceShapedType = source().getType().cast<ShapedType>();
1957 
1958   if (resultShapedType.hasStaticShape() &&
1959       resultShapedType == sourceShapedType) {
1960     return getViewSource();
1961   }
1962 
1963   return {};
1964 }
1965 
1966 //===----------------------------------------------------------------------===//
1967 // TensorLoadOp
1968 //===----------------------------------------------------------------------===//
1969 
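// The only fold performed here collapses a tensor_load of a buffer_cast,
// roughly (illustrative):
//
//   %m = memref.buffer_cast %t : memref<?xf32>
//   %r = memref.tensor_load %m : memref<?xf32>
//
// into %t, and only when the two ops are directly adjacent in the same block
// (see the conservative check below).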
1970 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
1971   if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
1974     if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
1975         bufferCast->getNextNode() == this->getOperation())
1976       return bufferCast.tensor();
1977   return {};
1978 }
1979 
1980 //===----------------------------------------------------------------------===//
1981 // TransposeOp
1982 //===----------------------------------------------------------------------===//
1983 
/// Build a strided memref type by applying `permutationMap` to `memRefType`.
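/// For example (illustrative), transposing `memref<4x8xf32>` (strides [8, 1])
/// with the permutation `(d0, d1) -> (d1, d0)` yields a `memref<8x4xf32>`
/// whose strided layout has strides [1, 8], i.e. the same underlying buffer
/// traversed with the two dimensions swapped.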
1985 static MemRefType inferTransposeResultType(MemRefType memRefType,
1986                                            AffineMap permutationMap) {
1987   auto rank = memRefType.getRank();
1988   auto originalSizes = memRefType.getShape();
1989   // Compute permuted sizes.
1990   SmallVector<int64_t, 4> sizes(rank, 0);
1991   for (auto en : llvm::enumerate(permutationMap.getResults()))
1992     sizes[en.index()] =
1993         originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
1994 
1995   // Compute permuted strides.
1996   int64_t offset;
1997   SmallVector<int64_t, 4> strides;
1998   auto res = getStridesAndOffset(memRefType, strides, offset);
1999   assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
2000   (void)res;
2001   auto map =
2002       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
2003   map = permutationMap ? map.compose(permutationMap) : map;
2004   return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
2005 }
2006 
2007 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
2008                         AffineMapAttr permutation,
2009                         ArrayRef<NamedAttribute> attrs) {
2010   auto permutationMap = permutation.getValue();
2011   assert(permutationMap);
2012 
2013   auto memRefType = in.getType().cast<MemRefType>();
2014   // Compute result type.
2015   MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
2016 
2017   build(b, result, resultType, in, attrs);
2018   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
2019 }
2020 
2021 // transpose $in $permutation attr-dict : type($in) `to` type(results)
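// e.g. (illustrative):
//   memref.transpose %v (d0, d1) -> (d1, d0)
//       : memref<?x?xf32> to memref<?x?xf32, #map>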
2022 static void print(OpAsmPrinter &p, TransposeOp op) {
2023   p << "memref.transpose " << op.in() << " " << op.permutation();
2024   p.printOptionalAttrDict(op->getAttrs(),
2025                           {TransposeOp::getPermutationAttrName()});
2026   p << " : " << op.in().getType() << " to " << op.getType();
2027 }
2028 
2029 static ParseResult parseTransposeOp(OpAsmParser &parser,
2030                                     OperationState &result) {
2031   OpAsmParser::OperandType in;
2032   AffineMap permutation;
2033   MemRefType srcType, dstType;
2034   if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
2035       parser.parseOptionalAttrDict(result.attributes) ||
2036       parser.parseColonType(srcType) ||
2037       parser.resolveOperand(in, srcType, result.operands) ||
2038       parser.parseKeywordType("to", dstType) ||
2039       parser.addTypeToList(dstType, result.types))
2040     return failure();
2041 
2042   result.addAttribute(TransposeOp::getPermutationAttrName(),
2043                       AffineMapAttr::get(permutation));
2044   return success();
2045 }
2046 
2047 static LogicalResult verify(TransposeOp op) {
2048   if (!op.permutation().isPermutation())
2049     return op.emitOpError("expected a permutation map");
2050   if (op.permutation().getNumDims() != op.getShapedType().getRank())
2051     return op.emitOpError(
2052         "expected a permutation map of same rank as the input");
2053 
2054   auto srcType = op.in().getType().cast<MemRefType>();
2055   auto dstType = op.getType().cast<MemRefType>();
2056   auto transposedType = inferTransposeResultType(srcType, op.permutation());
2057   if (dstType != transposedType)
2058     return op.emitOpError("output type ")
2059            << dstType << " does not match transposed input type " << srcType
2060            << ", " << transposedType;
2061   return success();
2062 }
2063 
2064 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
2065   if (succeeded(foldMemRefCast(*this)))
2066     return getResult();
2067   return {};
2068 }
2069 
2070 //===----------------------------------------------------------------------===//
2071 // ViewOp
2072 //===----------------------------------------------------------------------===//
2073 
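// For orientation, a view reinterprets a flat byte buffer starting at a
// dynamic byte shift as an N-D memref whose dynamic sizes are passed as
// operands; the textual form is roughly (illustrative, not normative):
//
//   %v = memref.view %buf[%byte_shift][%size0, %size1]
//       : memref<2048xi8> to memref<?x?xf32>
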
2074 static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
2075   OpAsmParser::OperandType srcInfo;
2076   SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
2077   SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
2078   auto indexType = parser.getBuilder().getIndexType();
2079   Type srcType, dstType;
2080   llvm::SMLoc offsetLoc;
2081   if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
2082       parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
2083     return failure();
2084 
2085   if (offsetInfo.size() != 1)
2086     return parser.emitError(offsetLoc) << "expects 1 offset operand";
2087 
2088   return failure(
2089       parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
2090       parser.parseOptionalAttrDict(result.attributes) ||
2091       parser.parseColonType(srcType) ||
2092       parser.resolveOperand(srcInfo, srcType, result.operands) ||
2093       parser.resolveOperands(offsetInfo, indexType, result.operands) ||
2094       parser.resolveOperands(sizesInfo, indexType, result.operands) ||
2095       parser.parseKeywordType("to", dstType) ||
2096       parser.addTypeToList(dstType, result.types));
2097 }
2098 
2099 static void print(OpAsmPrinter &p, ViewOp op) {
2100   p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
2101   p.printOperand(op.byte_shift());
2102   p << "][" << op.sizes() << ']';
2103   p.printOptionalAttrDict(op->getAttrs());
2104   p << " : " << op.getOperand(0).getType() << " to " << op.getType();
2105 }
2106 
2107 static LogicalResult verify(ViewOp op) {
2108   auto baseType = op.getOperand(0).getType().cast<MemRefType>();
2109   auto viewType = op.getType();
2110 
2111   // The base memref should have identity layout map (or none).
2112   if (baseType.getAffineMaps().size() > 1 ||
2113       (baseType.getAffineMaps().size() == 1 &&
2114        !baseType.getAffineMaps()[0].isIdentity()))
2115     return op.emitError("unsupported map for base memref type ") << baseType;
2116 
2117   // The result memref should have identity layout map (or none).
2118   if (viewType.getAffineMaps().size() > 1 ||
2119       (viewType.getAffineMaps().size() == 1 &&
2120        !viewType.getAffineMaps()[0].isIdentity()))
2121     return op.emitError("unsupported map for result memref type ") << viewType;
2122 
2123   // The base memref and the view memref should be in the same memory space.
2124   if (baseType.getMemorySpace() != viewType.getMemorySpace())
2125     return op.emitError("different memory spaces specified for base memref "
2126                         "type ")
2127            << baseType << " and view memref type " << viewType;
2128 
2129   // Verify that we have the correct number of sizes for the result type.
2130   unsigned numDynamicDims = viewType.getNumDynamicDims();
2131   if (op.sizes().size() != numDynamicDims)
2132     return op.emitError("incorrect number of size operands for type ")
2133            << viewType;
2134 
2135   return success();
2136 }
2137 
2138 Value ViewOp::getViewSource() { return source(); }
2139 
2140 namespace {
2141 
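/// Fold constant-valued size operands of a ViewOp into a more static result
/// type, then cast the new view back to the original result type.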
2142 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
2143   using OpRewritePattern<ViewOp>::OpRewritePattern;
2144 
2145   LogicalResult matchAndRewrite(ViewOp viewOp,
2146                                 PatternRewriter &rewriter) const override {
2147     // Return if none of the operands are constants.
2148     if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
2149           return matchPattern(operand, matchConstantIndex());
2150         }))
2151       return failure();
2152 
2153     // Get result memref type.
2154     auto memrefType = viewOp.getType();
2155 
    // Get offset from old memref view type `memrefType`.
2157     int64_t oldOffset;
2158     SmallVector<int64_t, 4> oldStrides;
2159     if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
2160       return failure();
2161     assert(oldOffset == 0 && "Expected 0 offset");
2162 
2163     SmallVector<Value, 4> newOperands;
2164 
2165     // Offset cannot be folded into result type.
2166 
2167     // Fold any dynamic dim operands which are produced by a constant.
2168     SmallVector<int64_t, 4> newShapeConstants;
2169     newShapeConstants.reserve(memrefType.getRank());
2170 
2171     unsigned dynamicDimPos = 0;
2172     unsigned rank = memrefType.getRank();
2173     for (unsigned dim = 0, e = rank; dim < e; ++dim) {
2174       int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
2176       if (!ShapedType::isDynamic(dimSize)) {
2177         newShapeConstants.push_back(dimSize);
2178         continue;
2179       }
2180       auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
2181       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
2182         // Dynamic shape dimension will be folded.
2183         newShapeConstants.push_back(constantIndexOp.getValue());
2184       } else {
2185         // Dynamic shape dimension not folded; copy operand from old memref.
2186         newShapeConstants.push_back(dimSize);
2187         newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
2188       }
2189       dynamicDimPos++;
2190     }
2191 
2192     // Create new memref type with constant folded dims.
2193     MemRefType newMemRefType =
2194         MemRefType::Builder(memrefType).setShape(newShapeConstants);
2195     // Nothing new, don't fold.
2196     if (newMemRefType == memrefType)
2197       return failure();
2198 
2199     // Create new ViewOp.
2200     auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
2201                                              viewOp.getOperand(0),
2202                                              viewOp.byte_shift(), newOperands);
2203     // Insert a cast so we have the same type as the old memref type.
2204     rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
2205     return success();
2206   }
2207 };
2208 
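/// Fold a ViewOp whose source is a CastOp of an AllocOp by viewing the
/// allocated buffer directly.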
2209 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
2210   using OpRewritePattern<ViewOp>::OpRewritePattern;
2211 
2212   LogicalResult matchAndRewrite(ViewOp viewOp,
2213                                 PatternRewriter &rewriter) const override {
2214     Value memrefOperand = viewOp.getOperand(0);
2215     CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
2216     if (!memrefCastOp)
2217       return failure();
2218     Value allocOperand = memrefCastOp.getOperand();
2219     AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
2220     if (!allocOp)
2221       return failure();
2222     rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
2223                                         viewOp.byte_shift(), viewOp.sizes());
2224     return success();
2225   }
2226 };
2227 
2228 } // end anonymous namespace
2229 
2230 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2231                                          MLIRContext *context) {
2232   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
2233 }
2234 
2235 //===----------------------------------------------------------------------===//
2236 // TableGen'd op method definitions
2237 //===----------------------------------------------------------------------===//
2238 
2239 #define GET_OP_CLASSES
2240 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
2241