1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MemRef/IR/MemRef.h"
10 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Interfaces/InferTypeOpInterface.h"
21 #include "llvm/ADT/STLExtras.h"
22 
23 using namespace mlir;
24 using namespace mlir::memref;
25 
26 /// Materialize a single constant operation from a given attribute value with
27 /// the desired resultant type.
28 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
29                                               Attribute value, Type type,
30                                               Location loc) {
31   return builder.create<mlir::ConstantOp>(loc, type, value);
32 }
33 
34 /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
35 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
36   return llvm::to_vector<4>(
37       llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
38         return a.cast<IntegerAttr>().getInt();
39       }));
40 }
41 
/// Helper function to dispatch an OpFoldResult into `dynamicVec` if it is a
/// Value, or into `staticVec` if it is an IntegerAttr.
44 /// In the case of a Value, a copy of the `sentinel` value is also pushed to
45 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
46 /// come from an AttrSizedOperandSegments trait.
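/// For example (illustrative), dispatching the entries `[%v, 4]` with a
/// sentinel of -1 yields `dynamicVec = [%v]` and `staticVec = [-1, 4]`.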
47 static void dispatchIndexOpFoldResult(OpFoldResult ofr,
48                                       SmallVectorImpl<Value> &dynamicVec,
49                                       SmallVectorImpl<int64_t> &staticVec,
50                                       int64_t sentinel) {
51   if (auto v = ofr.dyn_cast<Value>()) {
52     dynamicVec.push_back(v);
53     staticVec.push_back(sentinel);
54     return;
55   }
56   APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
57   staticVec.push_back(apInt.getSExtValue());
58 }
59 
60 static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
61                                        SmallVectorImpl<Value> &dynamicVec,
62                                        SmallVectorImpl<int64_t> &staticVec,
63                                        int64_t sentinel) {
64   for (auto ofr : ofrs)
65     dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // Common canonicalization pattern support logic
70 //===----------------------------------------------------------------------===//
71 
/// This is a common helper used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
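///
/// Illustrative example, with DeallocOp as the root operation:
///
/// ```mlir
///   %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
///   memref.dealloc %0 : memref<?xf32>
/// ```
///
/// folds to:
///
/// ```mlir
///   memref.dealloc %arg : memref<8xf32>
/// ```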
75 static LogicalResult foldMemRefCast(Operation *op) {
76   bool folded = false;
77   for (OpOperand &operand : op->getOpOperands()) {
78     auto cast = operand.get().getDefiningOp<CastOp>();
79     if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
80       operand.set(cast.getOperand());
81       folded = true;
82     }
83   }
84   return success(folded);
85 }
86 
87 //===----------------------------------------------------------------------===//
88 // Helpers for GlobalOp
89 //===----------------------------------------------------------------------===//
90 
91 static Type getTensorTypeFromMemRefType(Type type) {
92   if (auto memref = type.dyn_cast<MemRefType>())
93     return RankedTensorType::get(memref.getShape(), memref.getElementType());
94   if (auto memref = type.dyn_cast<UnrankedMemRefType>())
95     return UnrankedTensorType::get(memref.getElementType());
96   return NoneType::get(type.getContext());
97 }
98 
99 //===----------------------------------------------------------------------===//
100 // AllocOp / AllocaOp
101 //===----------------------------------------------------------------------===//
102 
103 template <typename AllocLikeOp>
104 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies only to alloc or alloca");
107   auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
108   if (!memRefType)
109     return op.emitOpError("result must be a memref");
110 
111   if (static_cast<int64_t>(op.dynamicSizes().size()) !=
112       memRefType.getNumDynamicDims())
113     return op.emitOpError("dimension operand count does not equal memref "
114                           "dynamic dimension count");
115 
116   unsigned numSymbols = 0;
117   if (!memRefType.getAffineMaps().empty())
118     numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
119   if (op.symbolOperands().size() != numSymbols)
120     return op.emitOpError(
121         "symbol operand count does not equal memref symbol count");
122 
123   return success();
124 }
125 
126 static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }
127 
128 static LogicalResult verify(AllocaOp op) {
129   // An alloca op needs to have an ancestor with an allocation scope trait.
130   if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
131     return op.emitOpError(
132         "requires an ancestor op with AutomaticAllocationScope trait");
133 
134   return verifyAllocLikeOp(op);
135 }
136 
137 namespace {
/// Fold constant dimensions into an alloc-like operation.
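///
/// Illustrative example, where `%n` is a placeholder for a non-constant index
/// value:
///
/// ```mlir
///   %c8 = constant 8 : index
///   %0 = memref.alloc(%c8, %n) : memref<?x?xf32>
/// ```
///
/// is rewritten into:
///
/// ```mlir
///   %1 = memref.alloc(%n) : memref<8x?xf32>
///   %0 = memref.cast %1 : memref<8x?xf32> to memref<?x?xf32>
/// ```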
139 template <typename AllocLikeOp>
140 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
141   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
142 
143   LogicalResult matchAndRewrite(AllocLikeOp alloc,
144                                 PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants.  If so, we can
    // substitute and drop them.
147     if (llvm::none_of(alloc.getOperands(), [](Value operand) {
148           return matchPattern(operand, matchConstantIndex());
149         }))
150       return failure();
151 
152     auto memrefType = alloc.getType();
153 
154     // Ok, we have one or more constant operands.  Collect the non-constant ones
155     // and keep track of the resultant memref type to build.
156     SmallVector<int64_t, 4> newShapeConstants;
157     newShapeConstants.reserve(memrefType.getRank());
158     SmallVector<Value, 4> newOperands;
159 
160     unsigned dynamicDimPos = 0;
161     for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
162       int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
164       if (dimSize != -1) {
165         newShapeConstants.push_back(dimSize);
166         continue;
167       }
168       auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
169       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
170         // Dynamic shape dimension will be folded.
171         newShapeConstants.push_back(constantIndexOp.getValue());
172       } else {
173         // Dynamic shape dimension not folded; copy operand from old memref.
174         newShapeConstants.push_back(-1);
175         newOperands.push_back(alloc.getOperand(dynamicDimPos));
176       }
177       dynamicDimPos++;
178     }
179 
180     // Create new memref type (which will have fewer dynamic dimensions).
181     MemRefType newMemRefType =
182         MemRefType::Builder(memrefType).setShape(newShapeConstants);
183     assert(static_cast<int64_t>(newOperands.size()) ==
184            newMemRefType.getNumDynamicDims());
185 
186     // Create and insert the alloc op for the new memref.
187     auto newAlloc = rewriter.create<AllocLikeOp>(
188         alloc.getLoc(), newMemRefType, newOperands, alloc.alignmentAttr());
189     // Insert a cast so we have the same type as the old alloc.
190     auto resultCast =
191         rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
192 
193     rewriter.replaceOp(alloc, {resultCast});
194     return success();
195   }
196 };
197 
198 /// Fold alloc operations with no users or only store and dealloc uses.
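///
/// Illustrative example, where all three operations are erased because the
/// only users of the alloc are a store and a dealloc (`%cst` and `%i` are
/// placeholder values):
///
/// ```mlir
///   %0 = memref.alloc() : memref<4xf32>
///   memref.store %cst, %0[%i] : memref<4xf32>
///   memref.dealloc %0 : memref<4xf32>
/// ```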
199 template <typename T>
200 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
201   using OpRewritePattern<T>::OpRewritePattern;
202 
203   LogicalResult matchAndRewrite(T alloc,
204                                 PatternRewriter &rewriter) const override {
205     if (llvm::any_of(alloc->getUsers(), [](Operation *op) {
206           return !isa<StoreOp, DeallocOp>(op);
207         }))
208       return failure();
209 
210     for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
211       rewriter.eraseOp(user);
212 
213     rewriter.eraseOp(alloc);
214     return success();
215   }
216 };
217 } // end anonymous namespace.
218 
219 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
220                                           MLIRContext *context) {
221   results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
222 }
223 
224 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
225                                            MLIRContext *context) {
226   results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
227       context);
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // AssumeAlignmentOp
232 //===----------------------------------------------------------------------===//
233 
234 static LogicalResult verify(AssumeAlignmentOp op) {
235   unsigned alignment = op.alignment();
236   if (!llvm::isPowerOf2_32(alignment))
237     return op.emitOpError("alignment must be power of 2");
238   return success();
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // BufferCastOp
243 //===----------------------------------------------------------------------===//
244 
245 OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
246   if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
247     if (tensorLoad.memref().getType() == getType())
248       return tensorLoad.memref();
249   return {};
250 }
251 
252 namespace {
/// Replace tensor.cast + memref.buffer_cast by memref.buffer_cast +
/// memref.cast.
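///
/// Illustrative example:
///
/// ```mlir
///   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %1 = memref.buffer_cast %0 : memref<?xf32>
/// ```
///
/// becomes:
///
/// ```mlir
///   %0 = memref.buffer_cast %t : memref<4xf32>
///   %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
/// ```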
254 struct BufferCast : public OpRewritePattern<BufferCastOp> {
255   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
256 
257   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
258                                 PatternRewriter &rewriter) const final {
259     auto tensorCastOperand =
260         bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
261     if (!tensorCastOperand)
262       return failure();
263     auto srcTensorType =
264         tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
265     if (!srcTensorType)
266       return failure();
267     auto memrefType = MemRefType::get(srcTensorType.getShape(),
268                                       srcTensorType.getElementType());
269     Value memref = rewriter.create<BufferCastOp>(
270         bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
271     rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
272                                         memref);
273     return success();
274   }
275 };
276 
/// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
/// type mismatches prevent `BufferCastOp::fold` from kicking in.
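///
/// Illustrative example:
///
/// ```mlir
///   %0 = memref.tensor_load %m : memref<4xf32>
///   %1 = memref.buffer_cast %0 : memref<?xf32>
/// ```
///
/// becomes:
///
/// ```mlir
///   %1 = memref.cast %m : memref<4xf32> to memref<?xf32>
/// ```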
279 struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
280   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
281 
282   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
283                                 PatternRewriter &rewriter) const final {
284     auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
285     // Bail unless we have a tensor_load + memref.buffer_cast with different
286     // types. `BufferCastOp::fold` handles the same type case.
287     if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
288       return failure();
289     // If types are not cast-compatible, bail.
290     if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
291                                    bufferCast.getType()))
292       return failure();
293     rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
294                                         tensorLoad.memref());
295     return success();
296   }
297 };
298 
299 } // namespace
300 
301 void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
302                                                MLIRContext *context) {
303   results.add<BufferCast, TensorLoadToMemRef>(context);
304 }
305 
306 //===----------------------------------------------------------------------===//
307 // CastOp
308 //===----------------------------------------------------------------------===//
309 
310 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
312 /// and implement canonicalization patterns for ops in different dialects that
313 /// may consume the results of memref.cast operations. Such foldable memref.cast
314 /// operations are typically inserted as `view` and `subview` ops are
315 /// canonicalized, to preserve the type compatibility of their uses.
316 ///
317 /// Returns true when all conditions are met:
318 /// 1. source and result are ranked memrefs with strided semantics and same
319 /// element type and rank.
/// 2. each of the source's sizes, offset, and strides has at least as much
/// static information as the corresponding entry of the result.
322 ///
323 /// Example 1:
324 /// ```mlir
325 ///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
326 ///   %2 = consumer %1 ... : memref<?x?xf32> ...
327 /// ```
328 ///
329 /// may fold into:
330 ///
331 /// ```mlir
332 ///   %2 = consumer %0 ... : memref<8x16xf32> ...
333 /// ```
334 ///
335 /// Example 2:
336 /// ```
337 ///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
338 ///          to memref<?x?xf32>
339 ///   consumer %1 : memref<?x?xf32> ...
340 /// ```
341 ///
342 /// may fold into:
343 ///
344 /// ```
345 ///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
346 /// ```
347 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
348   MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
349   MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
350 
351   // Requires ranked MemRefType.
352   if (!sourceType || !resultType)
353     return false;
354 
355   // Requires same elemental type.
356   if (sourceType.getElementType() != resultType.getElementType())
357     return false;
358 
359   // Requires same rank.
360   if (sourceType.getRank() != resultType.getRank())
361     return false;
362 
363   // Only fold casts between strided memref forms.
364   int64_t sourceOffset, resultOffset;
365   SmallVector<int64_t, 4> sourceStrides, resultStrides;
366   if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
367       failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
368     return false;
369 
370   // If cast is towards more static sizes along any dimension, don't fold.
371   for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
372     auto ss = std::get<0>(it), st = std::get<1>(it);
373     if (ss != st)
374       if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
375         return false;
376   }
377 
  // If cast is towards a more static offset, don't fold.
379   if (sourceOffset != resultOffset)
380     if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
381         !MemRefType::isDynamicStrideOrOffset(resultOffset))
382       return false;
383 
384   // If cast is towards more static strides along any dimension, don't fold.
385   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
386     auto ss = std::get<0>(it), st = std::get<1>(it);
387     if (ss != st)
388       if (MemRefType::isDynamicStrideOrOffset(ss) &&
389           !MemRefType::isDynamicStrideOrOffset(st))
390         return false;
391   }
392 
393   return true;
394 }
395 
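/// Returns true when the two given types can legally be cast between each
/// other with `memref.cast`. Illustrative (non-exhaustive) examples:
///   - memref<8x16xf32> and memref<?x?xf32>: compatible.
///   - memref<*xf32> and memref<8xf32>: compatible (ranked <-> unranked).
///   - memref<8xf32> and memref<8xi32>: incompatible (element types differ).
///   - memref<*xf32> and memref<*xf32>: incompatible (unranked to unranked).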
396 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
397   if (inputs.size() != 1 || outputs.size() != 1)
398     return false;
399   Type a = inputs.front(), b = outputs.front();
400   auto aT = a.dyn_cast<MemRefType>();
401   auto bT = b.dyn_cast<MemRefType>();
402 
403   auto uaT = a.dyn_cast<UnrankedMemRefType>();
404   auto ubT = b.dyn_cast<UnrankedMemRefType>();
405 
406   if (aT && bT) {
407     if (aT.getElementType() != bT.getElementType())
408       return false;
409     if (aT.getAffineMaps() != bT.getAffineMaps()) {
410       int64_t aOffset, bOffset;
411       SmallVector<int64_t, 4> aStrides, bStrides;
412       if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
413           failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
414           aStrides.size() != bStrides.size())
415         return false;
416 
      // A stride along a given dimension (and similarly the offset) is
      // compatible if the source and target values are equal, or if either
      // one is dynamic (see the description of CastOp for details).
421       auto checkCompatible = [](int64_t a, int64_t b) {
422         return (a == MemRefType::getDynamicStrideOrOffset() ||
423                 b == MemRefType::getDynamicStrideOrOffset() || a == b);
424       };
425       if (!checkCompatible(aOffset, bOffset))
426         return false;
427       for (auto aStride : enumerate(aStrides))
428         if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
429           return false;
430     }
431     if (aT.getMemorySpace() != bT.getMemorySpace())
432       return false;
433 
434     // They must have the same rank, and any specified dimensions must match.
435     if (aT.getRank() != bT.getRank())
436       return false;
437 
438     for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
439       int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
440       if (aDim != -1 && bDim != -1 && aDim != bDim)
441         return false;
442     }
443     return true;
444   } else {
445     if (!aT && !uaT)
446       return false;
447     if (!bT && !ubT)
448       return false;
449     // Unranked to unranked casting is unsupported
450     if (uaT && ubT)
451       return false;
452 
453     auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
454     auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
455     if (aEltType != bEltType)
456       return false;
457 
458     auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
459     auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
460     if (aMemSpace != bMemSpace)
461       return false;
462 
463     return true;
464   }
465 
466   return false;
467 }
468 
469 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
470   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
471 }
472 
473 //===----------------------------------------------------------------------===//
474 // CloneOp
475 //===----------------------------------------------------------------------===//
476 
477 static LogicalResult verify(CloneOp op) { return success(); }
478 
479 void CloneOp::getEffects(
480     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
481         &effects) {
482   effects.emplace_back(MemoryEffects::Read::get(), input(),
483                        SideEffects::DefaultResource::get());
484   effects.emplace_back(MemoryEffects::Write::get(), output(),
485                        SideEffects::DefaultResource::get());
486 }
487 
488 namespace {
/// Remove clone operations that are unused, or whose result (or source) is
/// deallocated in the same block as the corresponding allocation; in that
/// case the clone is replaced by its source and the redundant dealloc is
/// erased.
491 struct SimplifyClones : public OpRewritePattern<CloneOp> {
492   using OpRewritePattern<CloneOp>::OpRewritePattern;
493 
494   LogicalResult matchAndRewrite(CloneOp cloneOp,
495                                 PatternRewriter &rewriter) const override {
496     if (cloneOp.use_empty()) {
497       rewriter.eraseOp(cloneOp);
498       return success();
499     }
500 
501     Value source = cloneOp.input();
502 
    // Replaces the clone with its source and removes the corresponding
    // dealloc, provided the allocation and deallocation are in the same block.
505     auto tryRemoveClone = [&](Operation *sourceOp, Operation *dealloc,
506                               Operation *alloc) {
507       if (!sourceOp || !dealloc || !alloc ||
508           alloc->getBlock() != dealloc->getBlock())
509         return false;
510       rewriter.replaceOp(cloneOp, source);
511       rewriter.eraseOp(dealloc);
512       return true;
513     };
514 
    // Case 1: the dealloc of the clone's result is in the same block as the
    // operation that defines the source.
517     Operation *deallocOp = findDealloc(cloneOp.output());
518     Operation *sourceOp = source.getDefiningOp();
519     if (tryRemoveClone(sourceOp, deallocOp, sourceOp))
520       return success();
521 
    // Case 2: the dealloc of the clone's source is in the same block as the
    // clone itself.
524     deallocOp = findDealloc(source);
525     if (tryRemoveClone(sourceOp, deallocOp, cloneOp))
526       return success();
527 
528     return failure();
529   }
530 };
531 
532 } // end anonymous namespace.
533 
534 void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
535                                           MLIRContext *context) {
536   results.insert<SimplifyClones>(context);
537 }
538 
539 OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
540   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
541 }
542 
543 //===----------------------------------------------------------------------===//
544 // DeallocOp
545 //===----------------------------------------------------------------------===//
546 
547 static LogicalResult verify(DeallocOp op) {
548   if (!op.memref().getType().isa<MemRefType>())
549     return op.emitOpError("operand must be a memref");
550   return success();
551 }
552 
553 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
554                               SmallVectorImpl<OpFoldResult> &results) {
555   /// dealloc(memrefcast) -> dealloc
556   return foldMemRefCast(*this);
557 }
558 
559 //===----------------------------------------------------------------------===//
560 // DimOp
561 //===----------------------------------------------------------------------===//
562 
563 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
564                   int64_t index) {
565   auto loc = result.location;
566   Value indexValue = builder.create<ConstantIndexOp>(loc, index);
567   build(builder, result, memref, indexValue);
568 }
569 
570 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
571                   Value index) {
572   auto indexTy = builder.getIndexType();
573   build(builder, result, indexTy, memref, index);
574 }
575 
576 Optional<int64_t> DimOp::getConstantIndex() {
577   if (auto constantOp = index().getDefiningOp<ConstantOp>())
578     return constantOp.getValue().cast<IntegerAttr>().getInt();
579   return {};
580 }
581 
582 static LogicalResult verify(DimOp op) {
583   // Assume unknown index to be in range.
584   Optional<int64_t> index = op.getConstantIndex();
585   if (!index.hasValue())
586     return success();
587 
588   // Check that constant index is not knowingly out of range.
589   auto type = op.memrefOrTensor().getType();
590   if (auto memrefType = type.dyn_cast<MemRefType>()) {
591     if (index.getValue() >= memrefType.getRank())
592       return op.emitOpError("index is out of range");
593   } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
594     if (index.getValue() >= tensorType.getRank())
595       return op.emitOpError("index is out of range");
596   } else if (type.isa<UnrankedMemRefType>() || type.isa<UnrankedTensorType>()) {
597     // Assume index to be in range.
598   } else {
    llvm_unreachable("expected operand with memref or tensor type");
600   }
601   return success();
602 }
603 
604 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
605   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
606 
607   // All forms of folding require a known index.
608   if (!index)
609     return {};
610 
611   auto argTy = memrefOrTensor().getType();
612   // Fold if the shape extent along the given index is known.
613   if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
614     // Folding for unranked types (UnrankedMemRefType) is not supported.
615     if (!shapedTy.hasRank())
616       return {};
617     if (!shapedTy.isDynamicDim(index.getInt())) {
618       Builder builder(getContext());
619       return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]);
620     }
621   }
622 
623   Operation *definingOp = memrefOrTensor().getDefiningOp();
624 
625   // dim(memref.tensor_load(memref)) -> dim(memref)
626   if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) {
627     setOperand(0, tensorLoadOp.memref());
628     return getResult();
629   }
630 
631   // Fold dim to the operand of tensor.generate.
632   if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
633     auto resultType =
634         fromElements.getResult().getType().cast<RankedTensorType>();
635     // The case where the type encodes the size of the dimension is handled
636     // above.
637     assert(resultType.getShape()[index.getInt()] ==
638            RankedTensorType::kDynamicSize);
639 
640     // Find the operand of the fromElements that corresponds to this index.
641     auto dynExtents = fromElements.dynamicExtents().begin();
642     for (auto dim : resultType.getShape().take_front(index.getInt()))
643       if (dim == RankedTensorType::kDynamicSize)
644         dynExtents++;
645 
646     return Value{*dynExtents};
647   }
648 
649   // The size at the given index is now known to be a dynamic size.
650   unsigned unsignedIndex = index.getValue().getZExtValue();
651 
652   if (auto subtensor = dyn_cast_or_null<mlir::SubTensorOp>(definingOp)) {
653     assert(subtensor.isDynamicSize(unsignedIndex) &&
654            "Expected dynamic subtensor size");
655     return subtensor.getDynamicSize(unsignedIndex);
656   }
657 
  // Fold dim to the size argument for an `AllocOp`, `AllocaOp`, `ViewOp`, or
  // `SubViewOp`.
659   auto memrefType = argTy.dyn_cast<MemRefType>();
660   if (!memrefType)
661     return {};
662 
663   if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
664     return *(alloc.getDynamicSizes().begin() +
665              memrefType.getDynamicDimIndex(unsignedIndex));
666 
667   if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
668     return *(alloca.getDynamicSizes().begin() +
669              memrefType.getDynamicDimIndex(unsignedIndex));
670 
671   if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
672     return *(view.getDynamicSizes().begin() +
673              memrefType.getDynamicDimIndex(unsignedIndex));
674 
675   if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
676     assert(subview.isDynamicSize(unsignedIndex) &&
677            "Expected dynamic subview size");
678     return subview.getDynamicSize(unsignedIndex);
679   }
680 
681   // dim(memrefcast) -> dim
682   if (succeeded(foldMemRefCast(*this)))
683     return getResult();
684 
685   return {};
686 }
687 
688 namespace {
/// Fold dim of a memref.reshape operation into a load from the reshape's
/// shape operand.
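///
/// Illustrative example:
///
/// ```mlir
///   %0 = memref.reshape %src(%shape)
///       : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
///   %1 = memref.dim %0, %c1 : memref<?x?xf32>
/// ```
///
/// becomes:
///
/// ```mlir
///   %1 = memref.load %shape[%c1] : memref<2xindex>
/// ```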
691 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
692   using OpRewritePattern<DimOp>::OpRewritePattern;
693 
694   LogicalResult matchAndRewrite(DimOp dim,
695                                 PatternRewriter &rewriter) const override {
696     auto reshape = dim.memrefOrTensor().getDefiningOp<ReshapeOp>();
697 
698     if (!reshape)
699       return failure();
700 
701     // Place the load directly after the reshape to ensure that the shape memref
702     // was not mutated.
703     rewriter.setInsertionPointAfter(reshape);
704     rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
705                                         llvm::makeArrayRef({dim.index()}));
706     return success();
707   }
708 };
709 
/// Fold dim of a cast into the dim of the source of the cast.
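///
/// Illustrative example, with `tensor.cast` as the cast operation:
///
/// ```mlir
///   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %1 = memref.dim %0, %c0 : tensor<?xf32>
/// ```
///
/// becomes:
///
/// ```mlir
///   %1 = memref.dim %t, %c0 : tensor<4xf32>
/// ```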
711 template <typename CastOpTy>
712 struct DimOfCastOp : public OpRewritePattern<DimOp> {
713   using OpRewritePattern<DimOp>::OpRewritePattern;
714 
715   LogicalResult matchAndRewrite(DimOp dimOp,
716                                 PatternRewriter &rewriter) const override {
717     auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>();
718     if (!castOp)
719       return failure();
720     Value newSource = castOp.getOperand();
721     rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
722     return success();
723   }
724 };
725 
/// Helper method to get the `Value` that is the shape of the given `result`
/// at dimension `dimIndex`, using the `InferShapedTypeOpInterface`.
/// TODO(ravishankarm): This is better put as an interface utility method
/// somewhere, but that would imply the interface would depend on the `tensor`
/// dialect. Ideally maybe a utility method in the `tensor` dialect.
731 static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
732                                             int64_t dimIndex) {
733   unsigned resultNumber = result.getResultNumber();
734   auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
735   Location loc = result.getOwner()->getLoc();
736   if (!shapedTypeOp)
737     return nullptr;
738 
  // The interface exposes two methods: one that returns the shape of each
  // result as a single `Value` (a shape tensor), and another that returns
  // the shape of each result as a `SmallVector<Value>` of per-dimension
  // extents. The former takes precedence over the latter, so first check if
  // the op implements the first method and fall back to the second otherwise.
744   SmallVector<Value> reifiedResultShapes;
745   if (succeeded(
746           shapedTypeOp.reifyReturnTypeShapes(builder, reifiedResultShapes))) {
747     if (reifiedResultShapes.size() <= resultNumber)
748       return nullptr;
749     Value resultShape = reifiedResultShapes[resultNumber];
750     auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
751     if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
752       return nullptr;
753     return builder.create<tensor::ExtractOp>(
754         loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
755   }
756 
757   SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
758   if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
759           builder, reifiedResultShapesPerDim)))
760     return nullptr;
761   if (reifiedResultShapesPerDim.size() <= resultNumber ||
762       reifiedResultShapesPerDim[resultNumber].size() !=
763           static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
764     return nullptr;
765   OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
766   if (auto attr = valueOrAttr.dyn_cast<Attribute>())
767     return builder.createOrFold<ConstantIndexOp>(
768         loc, attr.cast<IntegerAttr>().getInt());
769   return valueOrAttr.get<Value>();
770 }
771 
772 /// Fold dim of an operation that implements the InferShapedTypeOpInterface
773 struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> {
774   using OpRewritePattern<DimOp>::OpRewritePattern;
775 
776   LogicalResult matchAndRewrite(DimOp dimOp,
777                                 PatternRewriter &rewriter) const override {
778     OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
779     if (!dimValue)
780       return failure();
781     auto shapedTypeOp =
782         dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
783     if (!shapedTypeOp)
784       return failure();
785 
786     Optional<int64_t> dimIndex = dimOp.getConstantIndex();
787     if (!dimIndex)
788       return failure();
789     Value replacement =
790         getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
791     if (!replacement)
792       return failure();
793     rewriter.replaceOp(dimOp, replacement);
794     return success();
795   }
796 };
797 } // end anonymous namespace.
798 
799 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
800                                         MLIRContext *context) {
801   results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
802               DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(context);
803 }
804 
805 // ---------------------------------------------------------------------------
806 // DmaStartOp
807 // ---------------------------------------------------------------------------
808 
809 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
810                        Value srcMemRef, ValueRange srcIndices, Value destMemRef,
811                        ValueRange destIndices, Value numElements,
812                        Value tagMemRef, ValueRange tagIndices, Value stride,
813                        Value elementsPerStride) {
814   result.addOperands(srcMemRef);
815   result.addOperands(srcIndices);
816   result.addOperands(destMemRef);
817   result.addOperands(destIndices);
818   result.addOperands({numElements, tagMemRef});
819   result.addOperands(tagIndices);
820   if (stride)
821     result.addOperands({stride, elementsPerStride});
822 }
823 
824 void DmaStartOp::print(OpAsmPrinter &p) {
825   p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices()
826     << "], " << getDstMemRef() << '[' << getDstIndices() << "], "
827     << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices()
828     << ']';
829   if (isStrided())
830     p << ", " << getStride() << ", " << getNumElementsPerStride();
831 
832   p.printOptionalAttrDict((*this)->getAttrs());
833   p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
834     << ", " << getTagMemRef().getType();
835 }
836 
837 // Parse DmaStartOp.
838 // Ex:
839 //   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
840 //                       %tag[%index], %stride, %num_elt_per_stride :
841 //                     : memref<3076 x f32, 0>,
842 //                       memref<1024 x f32, 2>,
843 //                       memref<1 x i32>
844 //
845 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
846   OpAsmParser::OperandType srcMemRefInfo;
847   SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
848   OpAsmParser::OperandType dstMemRefInfo;
849   SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
850   OpAsmParser::OperandType numElementsInfo;
851   OpAsmParser::OperandType tagMemrefInfo;
852   SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
853   SmallVector<OpAsmParser::OperandType, 2> strideInfo;
854 
855   SmallVector<Type, 3> types;
856   auto indexType = parser.getBuilder().getIndexType();
857 
858   // Parse and resolve the following list of operands:
859   // *) source memref followed by its indices (in square brackets).
860   // *) destination memref followed by its indices (in square brackets).
  // *) number of elements to transfer.
862   if (parser.parseOperand(srcMemRefInfo) ||
863       parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
864       parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
865       parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
866       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
867       parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
868       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
869     return failure();
870 
871   // Parse optional stride and elements per stride.
872   if (parser.parseTrailingOperandList(strideInfo))
873     return failure();
874 
875   bool isStrided = strideInfo.size() == 2;
876   if (!strideInfo.empty() && !isStrided) {
877     return parser.emitError(parser.getNameLoc(),
878                             "expected two stride related operands");
879   }
880 
881   if (parser.parseColonTypeList(types))
882     return failure();
883   if (types.size() != 3)
884     return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
885 
886   if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
887       parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
888       parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
889       parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
890       // size should be an index.
891       parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
892       parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
893       // tag indices should be index.
894       parser.resolveOperands(tagIndexInfos, indexType, result.operands))
895     return failure();
896 
897   if (isStrided) {
898     if (parser.resolveOperands(strideInfo, indexType, result.operands))
899       return failure();
900   }
901 
902   return success();
903 }
904 
905 LogicalResult DmaStartOp::verify() {
906   unsigned numOperands = getNumOperands();
907 
908   // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
909   // the number of elements.
910   if (numOperands < 4)
911     return emitOpError("expected at least 4 operands");
912 
913   // Check types of operands. The order of these calls is important: the later
914   // calls rely on some type properties to compute the operand position.
915   // 1. Source memref.
916   if (!getSrcMemRef().getType().isa<MemRefType>())
917     return emitOpError("expected source to be of memref type");
918   if (numOperands < getSrcMemRefRank() + 4)
919     return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
920                          << " operands";
921   if (!getSrcIndices().empty() &&
922       !llvm::all_of(getSrcIndices().getTypes(),
923                     [](Type t) { return t.isIndex(); }))
924     return emitOpError("expected source indices to be of index type");
925 
926   // 2. Destination memref.
927   if (!getDstMemRef().getType().isa<MemRefType>())
928     return emitOpError("expected destination to be of memref type");
929   unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
930   if (numOperands < numExpectedOperands)
931     return emitOpError() << "expected at least " << numExpectedOperands
932                          << " operands";
933   if (!getDstIndices().empty() &&
934       !llvm::all_of(getDstIndices().getTypes(),
935                     [](Type t) { return t.isIndex(); }))
936     return emitOpError("expected destination indices to be of index type");
937 
938   // 3. Number of elements.
939   if (!getNumElements().getType().isIndex())
940     return emitOpError("expected num elements to be of index type");
941 
942   // 4. Tag memref.
943   if (!getTagMemRef().getType().isa<MemRefType>())
944     return emitOpError("expected tag to be of memref type");
945   numExpectedOperands += getTagMemRefRank();
946   if (numOperands < numExpectedOperands)
947     return emitOpError() << "expected at least " << numExpectedOperands
948                          << " operands";
949   if (!getTagIndices().empty() &&
950       !llvm::all_of(getTagIndices().getTypes(),
951                     [](Type t) { return t.isIndex(); }))
952     return emitOpError("expected tag indices to be of index type");
953 
  // DMAs are only supported between different memory spaces.
955   if (getSrcMemorySpace() == getDstMemorySpace())
956     return emitOpError("DMA should be between different memory spaces");
957 
958   // Optional stride-related operands must be either both present or both
959   // absent.
960   if (numOperands != numExpectedOperands &&
961       numOperands != numExpectedOperands + 2)
962     return emitOpError("incorrect number of operands");
963 
964   // 5. Strides.
965   if (isStrided()) {
966     if (!getStride().getType().isIndex() ||
967         !getNumElementsPerStride().getType().isIndex())
968       return emitOpError(
969           "expected stride and num elements per stride to be of type index");
970   }
971 
972   return success();
973 }
974 
975 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
976                                SmallVectorImpl<OpFoldResult> &results) {
977   /// dma_start(memrefcast) -> dma_start
978   return foldMemRefCast(*this);
979 }
980 
981 // ---------------------------------------------------------------------------
982 // DmaWaitOp
983 // ---------------------------------------------------------------------------
984 
985 void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
986                       Value tagMemRef, ValueRange tagIndices,
987                       Value numElements) {
988   result.addOperands(tagMemRef);
989   result.addOperands(tagIndices);
990   result.addOperands(numElements);
991 }
992 
993 void DmaWaitOp::print(OpAsmPrinter &p) {
994   p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices()
995     << "], " << getNumElements();
996   p.printOptionalAttrDict((*this)->getAttrs());
997   p << " : " << getTagMemRef().getType();
998 }
999 
1000 // Parse DmaWaitOp.
1001 // Eg:
1002 //   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
1003 //
1004 ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
1005   OpAsmParser::OperandType tagMemrefInfo;
1006   SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
1007   Type type;
1008   auto indexType = parser.getBuilder().getIndexType();
1009   OpAsmParser::OperandType numElementsInfo;
1010 
1011   // Parse tag memref, its indices, and dma size.
1012   if (parser.parseOperand(tagMemrefInfo) ||
1013       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
1014       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1015       parser.parseColonType(type) ||
1016       parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
1017       parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
1018       parser.resolveOperand(numElementsInfo, indexType, result.operands))
1019     return failure();
1020 
1021   return success();
1022 }
1023 
1024 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1025                               SmallVectorImpl<OpFoldResult> &results) {
1026   /// dma_wait(memrefcast) -> dma_wait
1027   return foldMemRefCast(*this);
1028 }
1029 
1030 LogicalResult DmaWaitOp::verify() {
1031   // Mandatory non-variadic operands are tag and the number of elements.
1032   if (getNumOperands() < 2)
1033     return emitOpError() << "expected at least 2 operands";
1034 
1035   // Check types of operands. The order of these calls is important: the later
1036   // calls rely on some type properties to compute the operand position.
1037   if (!getTagMemRef().getType().isa<MemRefType>())
1038     return emitOpError() << "expected tag to be of memref type";
1039 
1040   if (getNumOperands() != 2 + getTagMemRefRank())
1041     return emitOpError() << "expected " << 2 + getTagMemRefRank()
1042                          << " operands";
1043 
1044   if (!getTagIndices().empty() &&
1045       !llvm::all_of(getTagIndices().getTypes(),
1046                     [](Type t) { return t.isIndex(); }))
1047     return emitOpError() << "expected tag indices to be of index type";
1048 
1049   if (!getNumElements().getType().isIndex())
1050     return emitOpError()
1051            << "expected the number of elements to be of index type";
1052 
1053   return success();
1054 }
1055 
1056 //===----------------------------------------------------------------------===//
1057 // GlobalOp
1058 //===----------------------------------------------------------------------===//
1059 
1060 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1061                                                    TypeAttr type,
1062                                                    Attribute initialValue) {
1063   p << type;
1064   if (!op.isExternal()) {
1065     p << " = ";
1066     if (op.isUninitialized())
1067       p << "uninitialized";
1068     else
1069       p.printAttributeWithoutType(initialValue);
1070   }
1071 }
1072 
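// Illustrative examples of the type / initial-value forms handled below (see
// the GlobalOp definition for the authoritative syntax):
//
//   memref.global @x : memref<2xf32> = dense<[0.0, 2.0]>
//   memref.global @y : memref<4xi32> = uninitialized
//   memref.global constant @c : memref<2xi32> = dense<[1, 4]>
//   memref.global @ext : memref<2xf32>            // external declaration
//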
1073 static ParseResult
1074 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1075                                        Attribute &initialValue) {
1076   Type type;
1077   if (parser.parseType(type))
1078     return failure();
1079 
1080   auto memrefType = type.dyn_cast<MemRefType>();
1081   if (!memrefType || !memrefType.hasStaticShape())
1082     return parser.emitError(parser.getNameLoc())
1083            << "type should be static shaped memref, but got " << type;
1084   typeAttr = TypeAttr::get(type);
1085 
1086   if (parser.parseOptionalEqual())
1087     return success();
1088 
1089   if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1090     initialValue = UnitAttr::get(parser.getBuilder().getContext());
1091     return success();
1092   }
1093 
1094   Type tensorType = getTensorTypeFromMemRefType(memrefType);
1095   if (parser.parseAttribute(initialValue, tensorType))
1096     return failure();
1097   if (!initialValue.isa<ElementsAttr>())
1098     return parser.emitError(parser.getNameLoc())
1099            << "initial value should be a unit or elements attribute";
1100   return success();
1101 }
1102 
1103 static LogicalResult verify(GlobalOp op) {
1104   auto memrefType = op.type().dyn_cast<MemRefType>();
1105   if (!memrefType || !memrefType.hasStaticShape())
1106     return op.emitOpError("type should be static shaped memref, but got ")
1107            << op.type();
1108 
1109   // Verify that the initial value, if present, is either a unit attribute or
1110   // an elements attribute.
1111   if (op.initial_value().hasValue()) {
1112     Attribute initValue = op.initial_value().getValue();
1113     if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
1114       return op.emitOpError("initial value should be a unit or elements "
1115                             "attribute, but got ")
1116              << initValue;
1117 
1118     // Check that the type of the initial value is compatible with the type of
1119     // the global variable.
1120     if (initValue.isa<ElementsAttr>()) {
1121       Type initType = initValue.getType();
1122       Type tensorType = getTensorTypeFromMemRefType(memrefType);
1123       if (initType != tensorType)
1124         return op.emitOpError("initial value expected to be of type ")
1125                << tensorType << ", but was of type " << initType;
1126     }
1127   }
1128 
1129   // TODO: verify visibility for declarations.
1130   return success();
1131 }
1132 
1133 //===----------------------------------------------------------------------===//
1134 // GetGlobalOp
1135 //===----------------------------------------------------------------------===//
1136 
1137 LogicalResult
1138 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is the same as the type of the referenced
  // memref.global op.
1141   auto global =
1142       symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
1143   if (!global)
1144     return emitOpError("'")
1145            << name() << "' does not reference a valid global memref";
1146 
1147   Type resultType = result().getType();
1148   if (global.type() != resultType)
1149     return emitOpError("result type ")
1150            << resultType << " does not match type " << global.type()
1151            << " of the global memref @" << name();
1152   return success();
1153 }
1154 
1155 //===----------------------------------------------------------------------===//
1156 // LoadOp
1157 //===----------------------------------------------------------------------===//
1158 
1159 static LogicalResult verify(LoadOp op) {
1160   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1161     return op.emitOpError("incorrect number of indices for load");
1162   return success();
1163 }
1164 
1165 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1166   /// load(memrefcast) -> load
1167   if (succeeded(foldMemRefCast(*this)))
1168     return getResult();
1169   return OpFoldResult();
1170 }
1171 
1172 namespace {
/// Fold a load on a buffer_cast operation into a tensor.extract on the
/// corresponding tensor.
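///
/// Illustrative example:
///
/// ```mlir
///   %0 = memref.buffer_cast %t : memref<?xf32>
///   %1 = memref.load %0[%i] : memref<?xf32>
/// ```
///
/// becomes:
///
/// ```mlir
///   %1 = tensor.extract %t[%i] : tensor<?xf32>
/// ```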
1175 struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
1176   using OpRewritePattern<LoadOp>::OpRewritePattern;
1177 
1178   LogicalResult matchAndRewrite(LoadOp load,
1179                                 PatternRewriter &rewriter) const override {
1180     auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
1181     if (!buffercast)
1182       return failure();
1183 
1184     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
1185                                                    load.indices());
1186     return success();
1187   }
1188 };
1189 } // end anonymous namespace.
1190 
1191 void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
1192                                          MLIRContext *context) {
1193   results.add<LoadOfBufferCast>(context);
1194 }
1195 
1196 //===----------------------------------------------------------------------===//
1197 // PrefetchOp
1198 //===----------------------------------------------------------------------===//
1199 
1200 static void print(OpAsmPrinter &p, PrefetchOp op) {
1201   p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
1202   p.printOperands(op.indices());
1203   p << ']' << ", " << (op.isWrite() ? "write" : "read");
1204   p << ", locality<" << op.localityHint();
1205   p << ">, " << (op.isDataCache() ? "data" : "instr");
1206   p.printOptionalAttrDict(
1207       op->getAttrs(),
1208       /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1209   p << " : " << op.getMemRefType();
1210 }
1211 
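// Parse PrefetchOp.
// E.g. (illustrative):
//   memref.prefetch %0[%i, %j], read, locality<3>, data : memref<400x400xi32>
//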
1212 static ParseResult parsePrefetchOp(OpAsmParser &parser,
1213                                    OperationState &result) {
1214   OpAsmParser::OperandType memrefInfo;
1215   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
1216   IntegerAttr localityHint;
1217   MemRefType type;
1218   StringRef readOrWrite, cacheType;
1219 
1220   auto indexTy = parser.getBuilder().getIndexType();
1221   auto i32Type = parser.getBuilder().getIntegerType(32);
1222   if (parser.parseOperand(memrefInfo) ||
1223       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1224       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1225       parser.parseComma() || parser.parseKeyword("locality") ||
1226       parser.parseLess() ||
1227       parser.parseAttribute(localityHint, i32Type, "localityHint",
1228                             result.attributes) ||
1229       parser.parseGreater() || parser.parseComma() ||
1230       parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1231       parser.resolveOperand(memrefInfo, type, result.operands) ||
1232       parser.resolveOperands(indexInfo, indexTy, result.operands))
1233     return failure();
1234 
1235   if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1236     return parser.emitError(parser.getNameLoc(),
1237                             "rw specifier has to be 'read' or 'write'");
1238   result.addAttribute(
1239       PrefetchOp::getIsWriteAttrName(),
1240       parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1241 
1242   if (!cacheType.equals("data") && !cacheType.equals("instr"))
1243     return parser.emitError(parser.getNameLoc(),
1244                             "cache type has to be 'data' or 'instr'");
1245 
1246   result.addAttribute(
1247       PrefetchOp::getIsDataCacheAttrName(),
1248       parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1249 
1250   return success();
1251 }
1252 
1253 static LogicalResult verify(PrefetchOp op) {
1254   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1255     return op.emitOpError("too few indices");
1256 
1257   return success();
1258 }
1259 
1260 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
1261                                SmallVectorImpl<OpFoldResult> &results) {
1262   // prefetch(memrefcast) -> prefetch
1263   return foldMemRefCast(*this);
1264 }
1265 
1266 //===----------------------------------------------------------------------===//
1267 // ReinterpretCastOp
1268 //===----------------------------------------------------------------------===//
1269 
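// For reference, the textual form of the op created by the builders below
// looks roughly as follows (illustrative; see the ReinterpretCastOp
// definition for the authoritative assembly format):
//
//   memref.reinterpret_cast %src to
//     offset: [%offset], sizes: [%size0, 10], strides: [%stride0, 1]
//     : memref<*xf32> to memref<?x10xf32, offset: ?, strides: [?, 1]>
//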
/// Build a ReinterpretCastOp with mixed static and dynamic entries:
/// `staticOffsets`, `staticSizes` and `staticStrides` are automatically
/// filled with sentinel values for the entries that are provided as dynamic
/// `Value`s.
1273 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1274                               MemRefType resultType, Value source,
1275                               OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1276                               ArrayRef<OpFoldResult> strides,
1277                               ArrayRef<NamedAttribute> attrs) {
1278   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1279   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1280   dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
1281                              ShapedType::kDynamicStrideOrOffset);
1282   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1283                              ShapedType::kDynamicSize);
1284   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1285                              ShapedType::kDynamicStrideOrOffset);
1286   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1287         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1288         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1289   result.addAttributes(attrs);
1290 }
1291 
1292 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1293                               MemRefType resultType, Value source,
1294                               int64_t offset, ArrayRef<int64_t> sizes,
1295                               ArrayRef<int64_t> strides,
1296                               ArrayRef<NamedAttribute> attrs) {
1297   SmallVector<OpFoldResult> sizeValues =
1298       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1299         return b.getI64IntegerAttr(v);
1300       }));
1301   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1302       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1303         return b.getI64IntegerAttr(v);
1304       }));
1305   build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1306         strideValues, attrs);
1307 }
1308 
1309 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1310                               MemRefType resultType, Value source, Value offset,
1311                               ValueRange sizes, ValueRange strides,
1312                               ArrayRef<NamedAttribute> attrs) {
1313   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1314       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1315   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1316       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1317   build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1318 }
1319 
1320 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1321 // completed automatically, like we have for subview and subtensor.
1322 static LogicalResult verify(ReinterpretCastOp op) {
1323   // The source and result memrefs should be in the same memory space.
1324   auto srcType = op.source().getType().cast<BaseMemRefType>();
1325   auto resultType = op.getType().cast<MemRefType>();
1326   if (srcType.getMemorySpace() != resultType.getMemorySpace())
1327     return op.emitError("different memory spaces specified for source type ")
1328            << srcType << " and result memref type " << resultType;
1329   if (srcType.getElementType() != resultType.getElementType())
1330     return op.emitError("different element types specified for source type ")
1331            << srcType << " and result memref type " << resultType;
1332 
1333   // Match sizes in result memref type and in static_sizes attribute.
1334   for (auto &en :
1335        llvm::enumerate(llvm::zip(resultType.getShape(),
1336                                  extractFromI64ArrayAttr(op.static_sizes())))) {
1337     int64_t resultSize = std::get<0>(en.value());
1338     int64_t expectedSize = std::get<1>(en.value());
1339     if (resultSize != expectedSize)
1340       return op.emitError("expected result type with size = ")
1341              << expectedSize << " instead of " << resultSize
1342              << " in dim = " << en.index();
1343   }
1344 
  // Match offset and strides in the static_offsets and static_strides
  // attributes if the result memref type has an affine map specified.
1347   if (!resultType.getAffineMaps().empty()) {
1348     int64_t resultOffset;
1349     SmallVector<int64_t, 4> resultStrides;
1350     if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1351       return failure();
1352 
1353     // Match offset in result memref type and in static_offsets attribute.
1354     int64_t expectedOffset =
1355         extractFromI64ArrayAttr(op.static_offsets()).front();
1356     if (resultOffset != expectedOffset)
      return op.emitError("expected result type with offset = ")
             << expectedOffset << " instead of " << resultOffset;
1359 
1360     // Match strides in result memref type and in static_strides attribute.
1361     for (auto &en : llvm::enumerate(llvm::zip(
1362              resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
1363       int64_t resultStride = std::get<0>(en.value());
1364       int64_t expectedStride = std::get<1>(en.value());
1365       if (resultStride != expectedStride)
1366         return op.emitError("expected result type with stride = ")
1367                << expectedStride << " instead of " << resultStride
1368                << " in dim = " << en.index();
1369     }
1370   }
1371   return success();
1372 }
1373 
1374 //===----------------------------------------------------------------------===//
1375 // ReshapeOp
1376 //===----------------------------------------------------------------------===//
1377 
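/// Verifier for ReshapeOp. As an illustrative (hypothetical) example, a
/// well-formed reshape looks like:
/// ```
///   %dst = memref.reshape %src(%shape)
///     : (memref<?xf32>, memref<2xindex>) -> memref<?x?xf32>
/// ```
/// where the shape operand has static length 2, matching the result rank.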
1378 static LogicalResult verify(ReshapeOp op) {
1379   Type operandType = op.source().getType();
1380   Type resultType = op.result().getType();
1381 
1382   Type operandElementType = operandType.cast<ShapedType>().getElementType();
1383   Type resultElementType = resultType.cast<ShapedType>().getElementType();
1384   if (operandElementType != resultElementType)
1385     return op.emitOpError("element types of source and destination memref "
1386                           "types should be the same");
1387 
1388   if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
1389     if (!operandMemRefType.getAffineMaps().empty())
1390       return op.emitOpError(
1391           "source memref type should have identity affine map");
1392 
1393   int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
1394   auto resultMemRefType = resultType.dyn_cast<MemRefType>();
1395   if (resultMemRefType) {
1396     if (!resultMemRefType.getAffineMaps().empty())
1397       return op.emitOpError(
1398           "result memref type should have identity affine map");
1399     if (shapeSize == ShapedType::kDynamicSize)
1400       return op.emitOpError("cannot use shape operand with dynamic length to "
1401                             "reshape to statically-ranked memref type");
1402     if (shapeSize != resultMemRefType.getRank())
1403       return op.emitOpError(
1404           "length of shape operand differs from the result's memref rank");
1405   }
1406   return success();
1407 }
1408 
1409 //===----------------------------------------------------------------------===//
1410 // StoreOp
1411 //===----------------------------------------------------------------------===//
1412 
1413 static LogicalResult verify(StoreOp op) {
1414   if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
1415     return op.emitOpError("store index operand count not equal to memref rank");
1416 
1417   return success();
1418 }
1419 
1420 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
1421                             SmallVectorImpl<OpFoldResult> &results) {
  // store(memrefcast) -> store
1423   return foldMemRefCast(*this);
1424 }
1425 
1426 //===----------------------------------------------------------------------===//
1427 // SubViewOp
1428 //===----------------------------------------------------------------------===//
1429 
1430 namespace {
/// Helpers to write more idiomatic operations: saturating arithmetic on
/// offsets and strides, where any operand equal to the dynamic sentinel makes
/// the result dynamic.
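/// For illustration (hypothetical values): `Wrapper(4) * 2 + 3` evaluates to
/// 11, while `Wrapper(ShapedType::kDynamicStrideOrOffset) * 2 + 3` saturates
/// and stays equal to `ShapedType::kDynamicStrideOrOffset`.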
1432 namespace saturated_arith {
1433 struct Wrapper {
1434   explicit Wrapper(int64_t v) : v(v) {}
1435   operator int64_t() { return v; }
1436   int64_t v;
1437 };
1438 Wrapper operator+(Wrapper a, int64_t b) {
1439   if (ShapedType::isDynamicStrideOrOffset(a) ||
1440       ShapedType::isDynamicStrideOrOffset(b))
1441     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1442   return Wrapper(a.v + b);
1443 }
1444 Wrapper operator*(Wrapper a, int64_t b) {
1445   if (ShapedType::isDynamicStrideOrOffset(a) ||
1446       ShapedType::isDynamicStrideOrOffset(b))
1447     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1448   return Wrapper(a.v * b);
1449 }
1450 } // end namespace saturated_arith
1451 } // end namespace
1452 
1453 /// A subview result type can be fully inferred from the source type and the
1454 /// static representation of offsets, sizes and strides. Special sentinels
1455 /// encode the dynamic case.
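/// For example (illustrative): for a contiguous `memref<8x16xf32>` source
/// (strides [16, 1], offset 0), static offsets [2, 4], sizes [4, 4] and
/// strides [1, 2] infer
/// `memref<4x4xf32, affine_map<(d0, d1) -> (d0 * 16 + d1 * 2 + 36)>>`:
/// the target offset is 0 + 2 * 16 + 4 * 1 = 36 and the target strides are
/// [16 * 1, 1 * 2] = [16, 2].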
1456 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1457                                 ArrayRef<int64_t> leadingStaticOffsets,
1458                                 ArrayRef<int64_t> leadingStaticSizes,
1459                                 ArrayRef<int64_t> leadingStaticStrides) {
  // A subview may specify only a leading subset of offsets/sizes/strides, in
  // which case the missing entries are completed with offset=0, sizes from the
  // memref type, and strides=1.
1462   unsigned rank = sourceMemRefType.getRank();
1463   assert(leadingStaticOffsets.size() <= rank &&
1464          "unexpected leadingStaticOffsets overflow");
1465   assert(leadingStaticSizes.size() <= rank &&
1466          "unexpected leadingStaticSizes overflow");
1467   assert(leadingStaticStrides.size() <= rank &&
1468          "unexpected leadingStaticStrides overflow");
1469   auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
1470   auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
1471   auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
1472   unsigned numTrailingOffsets = rank - staticOffsets.size();
1473   unsigned numTrailingSizes = rank - staticSizes.size();
1474   unsigned numTrailingStrides = rank - staticStrides.size();
1475   staticOffsets.append(numTrailingOffsets, 0);
1476   llvm::append_range(staticSizes,
1477                      sourceMemRefType.getShape().take_back(numTrailingSizes));
1478   staticStrides.append(numTrailingStrides, 1);
1479 
1480   // Extract source offset and strides.
1481   int64_t sourceOffset;
1482   SmallVector<int64_t, 4> sourceStrides;
1483   auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
1484   assert(succeeded(res) && "SubViewOp expected strided memref type");
1485   (void)res;
1486 
1487   // Compute target offset whose value is:
1488   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
1489   int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    using namespace saturated_arith;
    targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * sourceStride;
  }
1495 
1496   // Compute target stride whose value is:
1497   //   `sourceStrides_i * staticStrides_i`.
1498   SmallVector<int64_t, 4> targetStrides;
1499   targetStrides.reserve(staticOffsets.size());
1500   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
1501     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
1502     using namespace saturated_arith;
1503     targetStrides.push_back(Wrapper(sourceStride) * staticStride);
1504   }
1505 
1506   // The type is now known.
1507   return MemRefType::get(
1508       staticSizes, sourceMemRefType.getElementType(),
1509       makeStridedLinearLayoutMap(targetStrides, targetOffset,
1510                                  sourceMemRefType.getContext()),
1511       sourceMemRefType.getMemorySpace());
1512 }
1513 
1514 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1515                                 ArrayRef<OpFoldResult> leadingStaticOffsets,
1516                                 ArrayRef<OpFoldResult> leadingStaticSizes,
1517                                 ArrayRef<OpFoldResult> leadingStaticStrides) {
1518   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1519   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1520   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1521                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1522   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1523                              ShapedType::kDynamicSize);
1524   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1525                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1526   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1527                                     staticSizes, staticStrides)
1528       .cast<MemRefType>();
1529 }
1530 
1531 Type SubViewOp::inferRankReducedResultType(
1532     unsigned resultRank, MemRefType sourceRankedTensorType,
1533     ArrayRef<int64_t> leadingStaticOffsets,
1534     ArrayRef<int64_t> leadingStaticSizes,
1535     ArrayRef<int64_t> leadingStaticStrides) {
1536   auto inferredType =
1537       inferResultType(sourceRankedTensorType, leadingStaticOffsets,
1538                       leadingStaticSizes, leadingStaticStrides)
1539           .cast<MemRefType>();
  assert(inferredType.getRank() >= resultRank &&
         "expected the inferred rank to be greater than or equal to the "
         "result rank");
1541   int rankDiff = inferredType.getRank() - resultRank;
1542   if (rankDiff > 0) {
1543     auto shape = inferredType.getShape();
1544     llvm::SmallDenseSet<unsigned> dimsToProject;
1545     mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
1546     SmallVector<int64_t> projectedShape;
1547     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1548       if (!dimsToProject.contains(pos))
1549         projectedShape.push_back(shape[pos]);
1550 
1551     AffineMap map;
1552     auto maps = inferredType.getAffineMaps();
1553     if (!maps.empty() && maps.front())
1554       map = getProjectedMap(maps.front(), dimsToProject);
1555     inferredType =
1556         MemRefType::get(projectedShape, inferredType.getElementType(), map,
1557                         inferredType.getMemorySpace());
1558   }
1559   return inferredType;
1560 }
1561 
1562 Type SubViewOp::inferRankReducedResultType(
1563     unsigned resultRank, MemRefType sourceRankedTensorType,
1564     ArrayRef<OpFoldResult> leadingStaticOffsets,
1565     ArrayRef<OpFoldResult> leadingStaticSizes,
1566     ArrayRef<OpFoldResult> leadingStaticStrides) {
1567   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1568   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1569   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1570                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1571   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1572                              ShapedType::kDynamicSize);
1573   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1574                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1575   return SubViewOp::inferRankReducedResultType(
1576       resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1577       staticStrides);
1578 }

// Build a SubViewOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
1581 void SubViewOp::build(OpBuilder &b, OperationState &result,
1582                       MemRefType resultType, Value source,
1583                       ArrayRef<OpFoldResult> offsets,
1584                       ArrayRef<OpFoldResult> sizes,
1585                       ArrayRef<OpFoldResult> strides,
1586                       ArrayRef<NamedAttribute> attrs) {
1587   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1588   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1589   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1590                              ShapedType::kDynamicStrideOrOffset);
1591   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1592                              ShapedType::kDynamicSize);
1593   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1594                              ShapedType::kDynamicStrideOrOffset);
1595   auto sourceMemRefType = source.getType().cast<MemRefType>();
1596   // Structuring implementation this way avoids duplication between builders.
1597   if (!resultType) {
1598     resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1599                                             staticSizes, staticStrides)
1600                      .cast<MemRefType>();
1601   }
1602   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1603         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1604         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1605   result.addAttributes(attrs);
1606 }
1607 
1608 // Build a SubViewOp with mixed static and dynamic entries and inferred result
1609 // type.
1610 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1611                       ArrayRef<OpFoldResult> offsets,
1612                       ArrayRef<OpFoldResult> sizes,
1613                       ArrayRef<OpFoldResult> strides,
1614                       ArrayRef<NamedAttribute> attrs) {
1615   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1616 }
1617 
1618 // Build a SubViewOp with static entries and inferred result type.
1619 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1620                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1621                       ArrayRef<int64_t> strides,
1622                       ArrayRef<NamedAttribute> attrs) {
1623   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1624       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1625         return b.getI64IntegerAttr(v);
1626       }));
1627   SmallVector<OpFoldResult> sizeValues =
1628       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1629         return b.getI64IntegerAttr(v);
1630       }));
1631   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1632       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1633         return b.getI64IntegerAttr(v);
1634       }));
1635   build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
1636 }
1637 
// Build a SubViewOp with static entries and custom result type. If the type
// passed is nullptr, it is inferred.
1640 void SubViewOp::build(OpBuilder &b, OperationState &result,
1641                       MemRefType resultType, Value source,
1642                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1643                       ArrayRef<int64_t> strides,
1644                       ArrayRef<NamedAttribute> attrs) {
1645   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1646       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1647         return b.getI64IntegerAttr(v);
1648       }));
1649   SmallVector<OpFoldResult> sizeValues =
1650       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1651         return b.getI64IntegerAttr(v);
1652       }));
1653   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1654       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1655         return b.getI64IntegerAttr(v);
1656       }));
1657   build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
1658         attrs);
1659 }
1660 
1661 // Build a SubViewOp with dynamic entries and custom result type. If the type
1662 // passed is nullptr, it is inferred.
1663 void SubViewOp::build(OpBuilder &b, OperationState &result,
1664                       MemRefType resultType, Value source, ValueRange offsets,
1665                       ValueRange sizes, ValueRange strides,
1666                       ArrayRef<NamedAttribute> attrs) {
1667   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1668       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1669   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1670       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1671   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1672       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
1674 }
1675 
1676 // Build a SubViewOp with dynamic entries and inferred result type.
1677 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1678                       ValueRange offsets, ValueRange sizes, ValueRange strides,
1679                       ArrayRef<NamedAttribute> attrs) {
1680   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1681 }
1682 
1683 /// For ViewLikeOpInterface.
1684 Value SubViewOp::getViewSource() { return source(); }
1685 
1686 enum SubViewVerificationResult {
1687   Success,
1688   RankTooLarge,
1689   SizeMismatch,
1690   ElemTypeMismatch,
1691   MemSpaceMismatch,
1692   AffineMapMismatch
1693 };
1694 
/// Checks if `originalType` can be rank-reduced to `candidateReducedType`.
/// This function is a slight variant of the `is subsequence` algorithm, where
/// the non-matching dimensions must be statically equal to 1.
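/// For example (illustrative): `memref<4x1x4xf32>` can be rank-reduced to
/// `memref<4x4xf32>` by dropping the unit dimension, but not to
/// `memref<4x3xf32>`, whose sizes cannot be matched.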
1698 static SubViewVerificationResult
1699 isRankReducedType(Type originalType, Type candidateReducedType,
1700                   std::string *errMsg = nullptr) {
1701   if (originalType == candidateReducedType)
1702     return SubViewVerificationResult::Success;
1703   if (!originalType.isa<MemRefType>())
1704     return SubViewVerificationResult::Success;
1705   if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
1706     return SubViewVerificationResult::Success;
1707 
1708   ShapedType originalShapedType = originalType.cast<ShapedType>();
1709   ShapedType candidateReducedShapedType =
1710       candidateReducedType.cast<ShapedType>();
1711 
1712   // Rank and size logic is valid for all ShapedTypes.
1713   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
1714   ArrayRef<int64_t> candidateReducedShape =
1715       candidateReducedShapedType.getShape();
1716   unsigned originalRank = originalShape.size(),
1717            candidateReducedRank = candidateReducedShape.size();
1718   if (candidateReducedRank > originalRank)
1719     return SubViewVerificationResult::RankTooLarge;
1720 
1721   auto optionalUnusedDimsMask =
1722       computeRankReductionMask(originalShape, candidateReducedShape);
1723 
  // If no rank-reduction mask could be computed, the sizes do not match.
1725   if (!optionalUnusedDimsMask.hasValue())
1726     return SubViewVerificationResult::SizeMismatch;
1727 
1728   if (originalShapedType.getElementType() !=
1729       candidateReducedShapedType.getElementType())
1730     return SubViewVerificationResult::ElemTypeMismatch;
1731 
1732   // Strided layout logic is relevant for MemRefType only.
1733   MemRefType original = originalType.cast<MemRefType>();
1734   MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
1735   if (original.getMemorySpace() != candidateReduced.getMemorySpace())
1736     return SubViewVerificationResult::MemSpaceMismatch;
1737 
1738   llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
1739   auto inferredType =
1740       getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
1741   AffineMap candidateLayout;
1742   if (candidateReduced.getAffineMaps().empty())
1743     candidateLayout = getStridedLinearLayoutMap(candidateReduced);
1744   else
1745     candidateLayout = candidateReduced.getAffineMaps().front();
1746   assert(inferredType.getNumResults() == 1 &&
1747          candidateLayout.getNumResults() == 1);
1748   if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
1749       inferredType.getNumDims() != candidateLayout.getNumDims()) {
1750     if (errMsg) {
1751       llvm::raw_string_ostream os(*errMsg);
1752       os << "inferred type: " << inferredType;
1753     }
1754     return SubViewVerificationResult::AffineMapMismatch;
1755   }
1756   // Check that the difference of the affine maps simplifies to 0.
1757   AffineExpr diffExpr =
1758       inferredType.getResult(0) - candidateLayout.getResult(0);
1759   diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
1760                                 inferredType.getNumSymbols());
1761   auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
1762   if (!(cst && cst.getValue() == 0)) {
1763     if (errMsg) {
1764       llvm::raw_string_ostream os(*errMsg);
1765       os << "inferred type: " << inferredType;
1766     }
1767     return SubViewVerificationResult::AffineMapMismatch;
1768   }
1769   return SubViewVerificationResult::Success;
1770 }
1771 
1772 template <typename OpTy>
1773 static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
1774                                             OpTy op, Type expectedType,
1775                                             StringRef errMsg = "") {
1776   auto memrefType = expectedType.cast<ShapedType>();
1777   switch (result) {
1778   case SubViewVerificationResult::Success:
1779     return success();
1780   case SubViewVerificationResult::RankTooLarge:
1781     return op.emitError("expected result rank to be smaller or equal to ")
1782            << "the source rank. " << errMsg;
1783   case SubViewVerificationResult::SizeMismatch:
1784     return op.emitError("expected result type to be ")
1785            << expectedType
1786            << " or a rank-reduced version. (mismatch of result sizes) "
1787            << errMsg;
1788   case SubViewVerificationResult::ElemTypeMismatch:
1789     return op.emitError("expected result element type to be ")
1790            << memrefType.getElementType() << errMsg;
1791   case SubViewVerificationResult::MemSpaceMismatch:
1792     return op.emitError("expected result and source memory spaces to match.")
1793            << errMsg;
1794   case SubViewVerificationResult::AffineMapMismatch:
1795     return op.emitError("expected result type to be ")
1796            << expectedType
1797            << " or a rank-reduced version. (mismatch of result affine map) "
1798            << errMsg;
1799   }
1800   llvm_unreachable("unexpected subview verification result");
1801 }
1802 
1803 /// Verifier for SubViewOp.
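/// For example (illustrative), the rank-reducing subview
/// ```
///   %1 = memref.subview %0[0, 0, 0][1, 16, 4][1, 1, 1]
///     : memref<8x16x4xf32> to memref<16x4xf32>
/// ```
/// verifies because `memref<16x4xf32>` is a rank-reduced version of the
/// inferred `memref<1x16x4xf32>` type: the unit dimension is dropped and the
/// strided layouts match.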
1804 static LogicalResult verify(SubViewOp op) {
1805   MemRefType baseType = op.getSourceType();
1806   MemRefType subViewType = op.getType();
1807 
1808   // The base memref and the view memref should be in the same memory space.
1809   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
1810     return op.emitError("different memory spaces specified for base memref "
1811                         "type ")
1812            << baseType << " and subview memref type " << subViewType;
1813 
1814   // Verify that the base memref type has a strided layout map.
1815   if (!isStrided(baseType))
1816     return op.emitError("base type ") << baseType << " is not strided";
1817 
1818   // Verify result type against inferred type.
1819   auto expectedType = SubViewOp::inferResultType(
1820       baseType, extractFromI64ArrayAttr(op.static_offsets()),
1821       extractFromI64ArrayAttr(op.static_sizes()),
1822       extractFromI64ArrayAttr(op.static_strides()));
1823 
1824   std::string errMsg;
1825   auto result = isRankReducedType(expectedType, subViewType, &errMsg);
1826   return produceSubViewErrorMsg(result, op, expectedType, errMsg);
1827 }
1828 
1829 raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
1830   return os << "range " << range.offset << ":" << range.size << ":"
1831             << range.stride;
1832 }
1833 
1834 /// Return the list of Range (i.e. offset, size, stride). Each Range
1835 /// entry contains either the dynamic value or a ConstantIndexOp constructed
1836 /// with `b` at location `loc`.
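/// For example (illustrative): for `memref.subview %0[%i, 4][8, %s][1, 1]`,
/// this returns the ranges {%i, 8, 1} and {4, %s, 1}, where the static entries
/// 4, 8 and 1 are materialized as ConstantIndexOp results.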
1837 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
1838                                               OpBuilder &b, Location loc) {
1839   std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
1840   assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
1841   assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
1842   SmallVector<Range, 8> res;
1843   unsigned rank = ranks[0];
1844   res.reserve(rank);
1845   for (unsigned idx = 0; idx < rank; ++idx) {
1846     Value offset =
1847         op.isDynamicOffset(idx)
1848             ? op.getDynamicOffset(idx)
1849             : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
1850     Value size = op.isDynamicSize(idx)
1851                      ? op.getDynamicSize(idx)
1852                      : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
1853     Value stride =
1854         op.isDynamicStride(idx)
1855             ? op.getDynamicStride(idx)
1856             : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
1857     res.emplace_back(Range{offset, size, stride});
1858   }
1859   return res;
1860 }
1861 
1862 /// Infer the canonical type of the result of a subview operation. Returns a
1863 /// type with rank `resultRank` that is either the rank of the rank-reduced
1864 /// type, or the non-rank-reduced type.
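/// For example (illustrative): for a `memref<8x16xf32>` source sliced with
/// sizes [1, 4], a `resultRank` of 1 selects the rank-reduced `memref<4xf32>`
/// form of the result (the unit dimension is dropped), whereas a `resultRank`
/// of 2 keeps the full-rank `memref<1x4xf32>` form.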
1865 static MemRefType
1866 getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
1867                               ArrayRef<OpFoldResult> mixedOffsets,
1868                               ArrayRef<OpFoldResult> mixedSizes,
1869                               ArrayRef<OpFoldResult> mixedStrides) {
1870   auto resultType =
1871       SubViewOp::inferRankReducedResultType(
1872           resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
1873           .cast<MemRefType>();
1874   if (resultType.getRank() != resultRank) {
1875     resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
1876                                             mixedSizes, mixedStrides)
1877                      .cast<MemRefType>();
1878   }
1879   return resultType;
1880 }
1881 
1882 namespace {
1883 /// Pattern to rewrite a subview op with MemRefCast arguments.
1884 /// This essentially pushes memref.cast past its consuming subview when
1885 /// `canFoldIntoConsumerOp` is true.
1886 ///
1887 /// Example:
1888 /// ```
1889 ///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
1890 ///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
1891 ///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
1892 /// ```
1893 /// is rewritten into:
1894 /// ```
1895 ///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
1896 ///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
1897 ///     memref<3x4xf32, offset:?, strides:[?, 1]>
1898 /// ```
1899 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
1900 public:
1901   using OpRewritePattern<SubViewOp>::OpRewritePattern;
1902 
1903   LogicalResult matchAndRewrite(SubViewOp subViewOp,
1904                                 PatternRewriter &rewriter) const override {
    // If any operand is a constant, bail out and let the
    // OpWithOffsetSizesAndStridesConstantArgumentFolder pattern kick in first.
1906     if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
1907           return matchPattern(operand, matchConstantIndex());
1908         }))
1909       return failure();
1910 
1911     auto castOp = subViewOp.source().getDefiningOp<CastOp>();
1912     if (!castOp)
1913       return failure();
1914 
1915     if (!CastOp::canFoldIntoConsumerOp(castOp))
1916       return failure();
1917 
    // Compute the result type of the SubViewOp after folding the CastOp, using
    // `getCanonicalSubViewResultType` on the cast's source operand type and
    // the SubViewOp's static information.
1921     auto resultType = getCanonicalSubViewResultType(
1922         subViewOp.getType().getRank(),
1923         castOp.source().getType().cast<MemRefType>(),
1924         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
1925         subViewOp.getMixedStrides());
1926     Value newSubView = rewriter.create<SubViewOp>(
1927         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
1928         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
1929         subViewOp.static_sizes(), subViewOp.static_strides());
1930     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
1931                                         newSubView);
1932     return success();
1933   }
1934 };
1935 } // namespace
1936 
1937 /// Return the canonical type of the result of a subview.
1938 struct SubViewReturnTypeCanonicalizer {
1939   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
1940                         ArrayRef<OpFoldResult> mixedSizes,
1941                         ArrayRef<OpFoldResult> mixedStrides) {
1942     return getCanonicalSubViewResultType(op.getType().getRank(),
1943                                          op.getSourceType(), mixedOffsets,
1944                                          mixedSizes, mixedStrides);
1945   }
1946 };
1947 
1948 /// A canonicalizer wrapper to replace SubViewOps.
1949 struct SubViewCanonicalizer {
1950   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
1951     rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
1952   }
1953 };
1954 
1955 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
1956                                             MLIRContext *context) {
1957   results
1958       .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
1959                SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
1960            SubViewOpMemRefCastFolder>(context);
1961 }
1962 
1963 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
1964   auto resultShapedType = getResult().getType().cast<ShapedType>();
1965   auto sourceShapedType = source().getType().cast<ShapedType>();
1966 
1967   if (resultShapedType.hasStaticShape() &&
1968       resultShapedType == sourceShapedType) {
1969     return getViewSource();
1970   }
1971 
1972   return {};
1973 }
1974 
1975 //===----------------------------------------------------------------------===//
1976 // TensorLoadOp
1977 //===----------------------------------------------------------------------===//
1978 
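/// Fold tensor_load(buffer_cast(x)) -> x. For example (illustrative; the SSA
/// names are hypothetical):
/// ```
///   %m = memref.buffer_cast %t : memref<?xf32>
///   %r = memref.tensor_load %m : memref<?xf32>
/// ```
/// folds `%r` to `%t` when the two operations are adjacent in the same block.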
1979 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
1980   if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
1983     if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
1984         bufferCast->getNextNode() == this->getOperation())
1985       return bufferCast.tensor();
1986   return {};
1987 }
1988 
1989 //===----------------------------------------------------------------------===//
1990 // TransposeOp
1991 //===----------------------------------------------------------------------===//
1992 
/// Build a strided memref type by applying `permutationMap` to `memRefType`.
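/// For example (illustrative): applying the permutation
/// `(d0, d1) -> (d1, d0)` to `memref<3x5xf32>` (strides [5, 1], offset 0)
/// yields `memref<5x3xf32, affine_map<(d0, d1) -> (d1 * 5 + d0)>>`.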
1994 static MemRefType inferTransposeResultType(MemRefType memRefType,
1995                                            AffineMap permutationMap) {
1996   auto rank = memRefType.getRank();
1997   auto originalSizes = memRefType.getShape();
1998   // Compute permuted sizes.
1999   SmallVector<int64_t, 4> sizes(rank, 0);
2000   for (auto en : llvm::enumerate(permutationMap.getResults()))
2001     sizes[en.index()] =
2002         originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
2003 
2004   // Compute permuted strides.
2005   int64_t offset;
2006   SmallVector<int64_t, 4> strides;
2007   auto res = getStridesAndOffset(memRefType, strides, offset);
2008   assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
2009   (void)res;
2010   auto map =
2011       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
2012   map = permutationMap ? map.compose(permutationMap) : map;
2013   return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
2014 }
2015 
2016 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
2017                         AffineMapAttr permutation,
2018                         ArrayRef<NamedAttribute> attrs) {
2019   auto permutationMap = permutation.getValue();
2020   assert(permutationMap);
2021 
2022   auto memRefType = in.getType().cast<MemRefType>();
2023   // Compute result type.
2024   MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
2025 
2026   build(b, result, resultType, in, attrs);
2027   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
2028 }
2029 
2030 // transpose $in $permutation attr-dict : type($in) `to` type(results)
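// For example (illustrative):
//   memref.transpose %0 (d0, d1) -> (d1, d0)
//     : memref<3x5xf32> to memref<5x3xf32, affine_map<(d0, d1) -> (d1 * 5 + d0)>>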
2031 static void print(OpAsmPrinter &p, TransposeOp op) {
2032   p << "memref.transpose " << op.in() << " " << op.permutation();
2033   p.printOptionalAttrDict(op->getAttrs(),
2034                           {TransposeOp::getPermutationAttrName()});
2035   p << " : " << op.in().getType() << " to " << op.getType();
2036 }
2037 
2038 static ParseResult parseTransposeOp(OpAsmParser &parser,
2039                                     OperationState &result) {
2040   OpAsmParser::OperandType in;
2041   AffineMap permutation;
2042   MemRefType srcType, dstType;
2043   if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
2044       parser.parseOptionalAttrDict(result.attributes) ||
2045       parser.parseColonType(srcType) ||
2046       parser.resolveOperand(in, srcType, result.operands) ||
2047       parser.parseKeywordType("to", dstType) ||
2048       parser.addTypeToList(dstType, result.types))
2049     return failure();
2050 
2051   result.addAttribute(TransposeOp::getPermutationAttrName(),
2052                       AffineMapAttr::get(permutation));
2053   return success();
2054 }
2055 
2056 static LogicalResult verify(TransposeOp op) {
2057   if (!op.permutation().isPermutation())
2058     return op.emitOpError("expected a permutation map");
2059   if (op.permutation().getNumDims() != op.getShapedType().getRank())
2060     return op.emitOpError(
2061         "expected a permutation map of same rank as the input");
2062 
2063   auto srcType = op.in().getType().cast<MemRefType>();
2064   auto dstType = op.getType().cast<MemRefType>();
2065   auto transposedType = inferTransposeResultType(srcType, op.permutation());
2066   if (dstType != transposedType)
    return op.emitOpError("output type ")
           << dstType << " does not match transposed input type " << srcType
           << " (expected " << transposedType << ")";
2070   return success();
2071 }
2072 
2073 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
2074   if (succeeded(foldMemRefCast(*this)))
2075     return getResult();
2076   return {};
2077 }
2078 
2079 //===----------------------------------------------------------------------===//
2080 // ViewOp
2081 //===----------------------------------------------------------------------===//
2082 
2083 static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
2084   OpAsmParser::OperandType srcInfo;
2085   SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
2086   SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
2087   auto indexType = parser.getBuilder().getIndexType();
2088   Type srcType, dstType;
2089   llvm::SMLoc offsetLoc;
2090   if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
2091       parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
2092     return failure();
2093 
2094   if (offsetInfo.size() != 1)
2095     return parser.emitError(offsetLoc) << "expects 1 offset operand";
2096 
2097   return failure(
2098       parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
2099       parser.parseOptionalAttrDict(result.attributes) ||
2100       parser.parseColonType(srcType) ||
2101       parser.resolveOperand(srcInfo, srcType, result.operands) ||
2102       parser.resolveOperands(offsetInfo, indexType, result.operands) ||
2103       parser.resolveOperands(sizesInfo, indexType, result.operands) ||
2104       parser.parseKeywordType("to", dstType) ||
2105       parser.addTypeToList(dstType, result.types));
2106 }
2107 
2108 static void print(OpAsmPrinter &p, ViewOp op) {
2109   p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
2110   p.printOperand(op.byte_shift());
2111   p << "][" << op.sizes() << ']';
2112   p.printOptionalAttrDict(op->getAttrs());
2113   p << " : " << op.getOperand(0).getType() << " to " << op.getType();
2114 }
2115 
2116 static LogicalResult verify(ViewOp op) {
2117   auto baseType = op.getOperand(0).getType().cast<MemRefType>();
2118   auto viewType = op.getType();
2119 
2120   // The base memref should have identity layout map (or none).
2121   if (baseType.getAffineMaps().size() > 1 ||
2122       (baseType.getAffineMaps().size() == 1 &&
2123        !baseType.getAffineMaps()[0].isIdentity()))
2124     return op.emitError("unsupported map for base memref type ") << baseType;
2125 
2126   // The result memref should have identity layout map (or none).
2127   if (viewType.getAffineMaps().size() > 1 ||
2128       (viewType.getAffineMaps().size() == 1 &&
2129        !viewType.getAffineMaps()[0].isIdentity()))
2130     return op.emitError("unsupported map for result memref type ") << viewType;
2131 
2132   // The base memref and the view memref should be in the same memory space.
2133   if (baseType.getMemorySpace() != viewType.getMemorySpace())
2134     return op.emitError("different memory spaces specified for base memref "
2135                         "type ")
2136            << baseType << " and view memref type " << viewType;
2137 
2138   // Verify that we have the correct number of sizes for the result type.
2139   unsigned numDynamicDims = viewType.getNumDynamicDims();
2140   if (op.sizes().size() != numDynamicDims)
2141     return op.emitError("incorrect number of size operands for type ")
2142            << viewType;
2143 
2144   return success();
2145 }
2146 
2147 Value ViewOp::getViewSource() { return source(); }
2148 
2149 namespace {
2150 
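/// Fold constant size operands of a ViewOp into its result type. For example
/// (illustrative; the SSA names are hypothetical):
/// ```
///   %c16 = constant 16 : index
///   %1 = memref.view %0[%offset][%c16, %s]
///     : memref<2048xi8> to memref<?x?xf32>
/// ```
/// is rewritten into a view with result type `memref<16x?xf32>`, followed by a
/// memref.cast back to the original `memref<?x?xf32>` type.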
2151 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
2152   using OpRewritePattern<ViewOp>::OpRewritePattern;
2153 
2154   LogicalResult matchAndRewrite(ViewOp viewOp,
2155                                 PatternRewriter &rewriter) const override {
2156     // Return if none of the operands are constants.
2157     if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
2158           return matchPattern(operand, matchConstantIndex());
2159         }))
2160       return failure();
2161 
2162     // Get result memref type.
2163     auto memrefType = viewOp.getType();
2164 
2165     // Get offset from old memref view type 'memRefType'.
2166     int64_t oldOffset;
2167     SmallVector<int64_t, 4> oldStrides;
2168     if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
2169       return failure();
2170     assert(oldOffset == 0 && "Expected 0 offset");
2171 
2172     SmallVector<Value, 4> newOperands;
2173 
2174     // Offset cannot be folded into result type.
2175 
2176     // Fold any dynamic dim operands which are produced by a constant.
2177     SmallVector<int64_t, 4> newShapeConstants;
2178     newShapeConstants.reserve(memrefType.getRank());
2179 
2180     unsigned dynamicDimPos = 0;
2181     unsigned rank = memrefType.getRank();
2182     for (unsigned dim = 0, e = rank; dim < e; ++dim) {
2183       int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
2185       if (!ShapedType::isDynamic(dimSize)) {
2186         newShapeConstants.push_back(dimSize);
2187         continue;
2188       }
2189       auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
2190       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
2191         // Dynamic shape dimension will be folded.
2192         newShapeConstants.push_back(constantIndexOp.getValue());
2193       } else {
2194         // Dynamic shape dimension not folded; copy operand from old memref.
2195         newShapeConstants.push_back(dimSize);
2196         newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
2197       }
2198       dynamicDimPos++;
2199     }
2200 
2201     // Create new memref type with constant folded dims.
2202     MemRefType newMemRefType =
2203         MemRefType::Builder(memrefType).setShape(newShapeConstants);
2204     // Nothing new, don't fold.
2205     if (newMemRefType == memrefType)
2206       return failure();
2207 
2208     // Create new ViewOp.
2209     auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
2210                                              viewOp.getOperand(0),
2211                                              viewOp.byte_shift(), newOperands);
2212     // Insert a cast so we have the same type as the old memref type.
2213     rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
2214     return success();
2215   }
2216 };
2217 
2218 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
2219   using OpRewritePattern<ViewOp>::OpRewritePattern;
2220 
2221   LogicalResult matchAndRewrite(ViewOp viewOp,
2222                                 PatternRewriter &rewriter) const override {
2223     Value memrefOperand = viewOp.getOperand(0);
2224     CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
2225     if (!memrefCastOp)
2226       return failure();
2227     Value allocOperand = memrefCastOp.getOperand();
2228     AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
2229     if (!allocOp)
2230       return failure();
2231     rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
2232                                         viewOp.byte_shift(), viewOp.sizes());
2233     return success();
2234   }
2235 };
2236 
2237 } // end anonymous namespace
2238 
2239 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2240                                          MLIRContext *context) {
2241   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
2242 }
2243 
2244 //===----------------------------------------------------------------------===//
2245 // TableGen'd op method definitions
2246 //===----------------------------------------------------------------------===//
2247 
2248 #define GET_OP_CLASSES
2249 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
2250