1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MemRef/IR/MemRef.h"
10 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Interfaces/InferTypeOpInterface.h"
21 #include "mlir/Interfaces/ViewLikeInterface.h"
22 #include "llvm/ADT/STLExtras.h"
23 
24 using namespace mlir;
25 using namespace mlir::memref;
26 
27 /// Materialize a single constant operation from a given attribute value with
28 /// the desired resultant type.
29 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
30                                               Attribute value, Type type,
31                                               Location loc) {
32   return builder.create<mlir::ConstantOp>(loc, type, value);
33 }
34 
35 /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
36 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
37   return llvm::to_vector<4>(
38       llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
39         return a.cast<IntegerAttr>().getInt();
40       }));
41 }
42 
43 /// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
44 /// it is a Value or into `staticVec` if it is an IntegerAttr.
45 /// In the case of a Value, a copy of the `sentinel` value is also pushed to
46 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
47 /// come from an AttrSizedOperandSegments trait.
48 static void dispatchIndexOpFoldResult(OpFoldResult ofr,
49                                       SmallVectorImpl<Value> &dynamicVec,
50                                       SmallVectorImpl<int64_t> &staticVec,
51                                       int64_t sentinel) {
52   if (auto v = ofr.dyn_cast<Value>()) {
53     dynamicVec.push_back(v);
54     staticVec.push_back(sentinel);
55     return;
56   }
57   APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
58   staticVec.push_back(apInt.getSExtValue());
59 }
60 
61 static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
62                                        SmallVectorImpl<Value> &dynamicVec,
63                                        SmallVectorImpl<int64_t> &staticVec,
64                                        int64_t sentinel) {
65   for (auto ofr : ofrs)
66     dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
67 }
68 
69 //===----------------------------------------------------------------------===//
70 // Common canonicalization pattern support logic
71 //===----------------------------------------------------------------------===//
72 
73 /// This is a common class used for patterns of the form
74 /// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
75 /// into the root operation directly.
76 static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) {
77   bool folded = false;
78   for (OpOperand &operand : op->getOpOperands()) {
79     auto cast = operand.get().getDefiningOp<CastOp>();
80     if (cast && operand.get() != inner &&
81         !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
82       operand.set(cast.getOperand());
83       folded = true;
84     }
85   }
86   return success(folded);
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // Helpers for GlobalOp
91 //===----------------------------------------------------------------------===//
92 
93 static Type getTensorTypeFromMemRefType(Type type) {
94   if (auto memref = type.dyn_cast<MemRefType>())
95     return RankedTensorType::get(memref.getShape(), memref.getElementType());
96   if (auto memref = type.dyn_cast<UnrankedMemRefType>())
97     return UnrankedTensorType::get(memref.getElementType());
98   return NoneType::get(type.getContext());
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // AllocOp / AllocaOp
103 //===----------------------------------------------------------------------===//
104 
105 template <typename AllocLikeOp>
106 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
107   static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
108                 "applies to only alloc or alloca");
109   auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
110   if (!memRefType)
111     return op.emitOpError("result must be a memref");
112 
113   if (static_cast<int64_t>(op.dynamicSizes().size()) !=
114       memRefType.getNumDynamicDims())
115     return op.emitOpError("dimension operand count does not equal memref "
116                           "dynamic dimension count");
117 
118   unsigned numSymbols = 0;
119   if (!memRefType.getAffineMaps().empty())
120     numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
121   if (op.symbolOperands().size() != numSymbols)
122     return op.emitOpError(
123         "symbol operand count does not equal memref symbol count");
124 
125   return success();
126 }
127 
// AllocOp shares all of its verification logic with AllocaOp.
static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }
129 
130 static LogicalResult verify(AllocaOp op) {
131   // An alloca op needs to have an ancestor with an allocation scope trait.
132   if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
133     return op.emitOpError(
134         "requires an ancestor op with AutomaticAllocationScope trait");
135 
136   return verifyAllocLikeOp(op);
137 }
138 
139 namespace {
140 /// Fold constant dimensions into an alloc like operation.
141 template <typename AllocLikeOp>
142 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
143   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
144 
145   LogicalResult matchAndRewrite(AllocLikeOp alloc,
146                                 PatternRewriter &rewriter) const override {
147     // Check to see if any dimensions operands are constants.  If so, we can
148     // substitute and drop them.
149     if (llvm::none_of(alloc.getOperands(), [](Value operand) {
150           return matchPattern(operand, matchConstantIndex());
151         }))
152       return failure();
153 
154     auto memrefType = alloc.getType();
155 
156     // Ok, we have one or more constant operands.  Collect the non-constant ones
157     // and keep track of the resultant memref type to build.
158     SmallVector<int64_t, 4> newShapeConstants;
159     newShapeConstants.reserve(memrefType.getRank());
160     SmallVector<Value, 4> newOperands;
161 
162     unsigned dynamicDimPos = 0;
163     for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
164       int64_t dimSize = memrefType.getDimSize(dim);
165       // If this is already static dimension, keep it.
166       if (dimSize != -1) {
167         newShapeConstants.push_back(dimSize);
168         continue;
169       }
170       auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
171       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
172         // Dynamic shape dimension will be folded.
173         newShapeConstants.push_back(constantIndexOp.getValue());
174       } else {
175         // Dynamic shape dimension not folded; copy operand from old memref.
176         newShapeConstants.push_back(-1);
177         newOperands.push_back(alloc.getOperand(dynamicDimPos));
178       }
179       dynamicDimPos++;
180     }
181 
182     // Create new memref type (which will have fewer dynamic dimensions).
183     MemRefType newMemRefType =
184         MemRefType::Builder(memrefType).setShape(newShapeConstants);
185     assert(static_cast<int64_t>(newOperands.size()) ==
186            newMemRefType.getNumDynamicDims());
187 
188     // Create and insert the alloc op for the new memref.
189     auto newAlloc = rewriter.create<AllocLikeOp>(
190         alloc.getLoc(), newMemRefType, newOperands, alloc.alignmentAttr());
191     // Insert a cast so we have the same type as the old alloc.
192     auto resultCast =
193         rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
194 
195     rewriter.replaceOp(alloc, {resultCast});
196     return success();
197   }
198 };
199 
200 /// Fold alloc operations with no users or only store and dealloc uses.
201 template <typename T>
202 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
203   using OpRewritePattern<T>::OpRewritePattern;
204 
205   LogicalResult matchAndRewrite(T alloc,
206                                 PatternRewriter &rewriter) const override {
207     if (llvm::any_of(alloc->getUsers(), [](Operation *op) {
208           return !isa<StoreOp, DeallocOp>(op);
209         }))
210       return failure();
211 
212     for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
213       rewriter.eraseOp(user);
214 
215     rewriter.eraseOp(alloc);
216     return success();
217   }
218 };
219 } // end anonymous namespace.
220 
221 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
222                                           MLIRContext *context) {
223   results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
224 }
225 
226 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
227                                            MLIRContext *context) {
228   results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
229       context);
230 }
231 
232 //===----------------------------------------------------------------------===//
233 // AllocaScopeOp
234 //===----------------------------------------------------------------------===//
235 
/// Prints an alloca_scope op: the op name, an optional result type list,
/// the body region, and the optional attribute dictionary.
static void print(OpAsmPrinter &p, AllocaScopeOp &op) {
  bool printBlockTerminators = false;

  p << AllocaScopeOp::getOperationName() << " ";
  if (!op.results().empty()) {
    p << " -> (" << op.getResultTypes() << ")";
    // When results are present, the terminators carry the yielded values
    // and must be printed explicitly.
    printBlockTerminators = true;
  }
  p.printRegion(op.bodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict(op->getAttrs());
}
249 
/// Parses an alloca_scope op: an optional `-> (types)` result list followed
/// by the body region and an optional attribute dictionary. Mirrors the
/// custom printer above.
static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
                                      OperationState &result) {
  // Create a region for the body.
  result.regions.reserve(1);
  Region *bodyRegion = result.addRegion();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the body region.
  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  // Make sure the region has a terminator even when it was elided in the
  // custom syntax.
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
272 
273 static LogicalResult verify(AllocaScopeOp op) {
274   if (failed(RegionBranchOpInterface::verifyTypes(op)))
275     return failure();
276 
277   return success();
278 }
279 
280 void AllocaScopeOp::getSuccessorRegions(
281     Optional<unsigned> index, ArrayRef<Attribute> operands,
282     SmallVectorImpl<RegionSuccessor> &regions) {
283   if (index.hasValue()) {
284     regions.push_back(RegionSuccessor(getResults()));
285     return;
286   }
287 
288   regions.push_back(RegionSuccessor(&bodyRegion()));
289 }
290 
291 //===----------------------------------------------------------------------===//
292 // AssumeAlignmentOp
293 //===----------------------------------------------------------------------===//
294 
295 static LogicalResult verify(AssumeAlignmentOp op) {
296   unsigned alignment = op.alignment();
297   if (!llvm::isPowerOf2_32(alignment))
298     return op.emitOpError("alignment must be power of 2");
299   return success();
300 }
301 
302 //===----------------------------------------------------------------------===//
303 // BufferCastOp
304 //===----------------------------------------------------------------------===//
305 
306 OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
307   if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
308     if (tensorLoad.memref().getType() == getType())
309       return tensorLoad.memref();
310   return {};
311 }
312 
313 namespace {
314 /// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast.
315 struct BufferCast : public OpRewritePattern<BufferCastOp> {
316   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
317 
318   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
319                                 PatternRewriter &rewriter) const final {
320     auto tensorCastOperand =
321         bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
322     if (!tensorCastOperand)
323       return failure();
324     auto srcTensorType =
325         tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
326     if (!srcTensorType)
327       return failure();
328     auto memrefType = MemRefType::get(srcTensorType.getShape(),
329                                       srcTensorType.getElementType());
330     Value memref = rewriter.create<BufferCastOp>(
331         bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
332     rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
333                                         memref);
334     return success();
335   }
336 };
337 
338 /// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
339 /// type mismatches prevent `BufferCastOp::fold` to kick in.
340 struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
341   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
342 
343   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
344                                 PatternRewriter &rewriter) const final {
345     auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
346     // Bail unless we have a tensor_load + memref.buffer_cast with different
347     // types. `BufferCastOp::fold` handles the same type case.
348     if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
349       return failure();
350     // If types are not cast-compatible, bail.
351     if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
352                                    bufferCast.getType()))
353       return failure();
354     rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
355                                         tensorLoad.memref());
356     return success();
357   }
358 };
359 
360 } // namespace
361 
362 void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
363                                                MLIRContext *context) {
364   results.add<BufferCast, TensorLoadToMemRef>(context);
365 }
366 
367 //===----------------------------------------------------------------------===//
368 // CastOp
369 //===----------------------------------------------------------------------===//
370 
371 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
373 /// and implement canonicalization patterns for ops in different dialects that
374 /// may consume the results of memref.cast operations. Such foldable memref.cast
375 /// operations are typically inserted as `view` and `subview` ops are
376 /// canonicalized, to preserve the type compatibility of their uses.
377 ///
378 /// Returns true when all conditions are met:
379 /// 1. source and result are ranked memrefs with strided semantics and same
380 /// element type and rank.
381 /// 2. each of the source's size, offset or stride has more static information
382 /// than the corresponding result's size, offset or stride.
383 ///
384 /// Example 1:
385 /// ```mlir
386 ///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
387 ///   %2 = consumer %1 ... : memref<?x?xf32> ...
388 /// ```
389 ///
390 /// may fold into:
391 ///
392 /// ```mlir
393 ///   %2 = consumer %0 ... : memref<8x16xf32> ...
394 /// ```
395 ///
396 /// Example 2:
397 /// ```
398 ///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
399 ///          to memref<?x?xf32>
400 ///   consumer %1 : memref<?x?xf32> ...
401 /// ```
402 ///
403 /// may fold into:
404 ///
405 /// ```
406 ///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
407 /// ```
408 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
409   MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
410   MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
411 
412   // Requires ranked MemRefType.
413   if (!sourceType || !resultType)
414     return false;
415 
416   // Requires same elemental type.
417   if (sourceType.getElementType() != resultType.getElementType())
418     return false;
419 
420   // Requires same rank.
421   if (sourceType.getRank() != resultType.getRank())
422     return false;
423 
424   // Only fold casts between strided memref forms.
425   int64_t sourceOffset, resultOffset;
426   SmallVector<int64_t, 4> sourceStrides, resultStrides;
427   if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
428       failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
429     return false;
430 
431   // If cast is towards more static sizes along any dimension, don't fold.
432   for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
433     auto ss = std::get<0>(it), st = std::get<1>(it);
434     if (ss != st)
435       if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
436         return false;
437   }
438 
439   // If cast is towards more static offset along any dimension, don't fold.
440   if (sourceOffset != resultOffset)
441     if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
442         !MemRefType::isDynamicStrideOrOffset(resultOffset))
443       return false;
444 
445   // If cast is towards more static strides along any dimension, don't fold.
446   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
447     auto ss = std::get<0>(it), st = std::get<1>(it);
448     if (ss != st)
449       if (MemRefType::isDynamicStrideOrOffset(ss) &&
450           !MemRefType::isDynamicStrideOrOffset(st))
451         return false;
452   }
453 
454   return true;
455 }
456 
/// Returns true if a single-value cast from `inputs` to `outputs` is legal:
/// ranked<->ranked with compatible layout/shape, or ranked<->unranked with
/// matching element type and memory space.
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<MemRefType>();
  auto bT = b.dyn_cast<MemRefType>();

  auto uaT = a.dyn_cast<UnrankedMemRefType>();
  auto ubT = b.dyn_cast<UnrankedMemRefType>();

  if (aT && bT) {
    // Ranked -> ranked cast.
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getAffineMaps() != bT.getAffineMaps()) {
      // When the layout maps differ textually, both must still be strided
      // forms of the same arity.
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
          failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same. They are also compatible if either one is dynamic (see
      // description of MemRefCastOp for details).
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (a == MemRefType::getDynamicStrideOrOffset() ||
                b == MemRefType::getDynamicStrideOrOffset() || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (auto aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      // -1 encodes a dynamic dimension; only static-vs-static mismatches
      // are incompatible.
      if (aDim != -1 && bDim != -1 && aDim != bDim)
        return false;
    }
    return true;
  } else {
    // Mixed ranked/unranked case: at least one side must be some memref.
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    if (aMemSpace != bMemSpace)
      return false;

    return true;
  }

  // Unreachable: both branches above return.
  return false;
}
529 
530 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
531   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
532 }
533 
534 //===----------------------------------------------------------------------===//
535 // CloneOp
536 //===----------------------------------------------------------------------===//
537 
538 void CloneOp::getEffects(
539     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
540         &effects) {
541   effects.emplace_back(MemoryEffects::Read::get(), input(),
542                        SideEffects::DefaultResource::get());
543   effects.emplace_back(MemoryEffects::Write::get(), output(),
544                        SideEffects::DefaultResource::get());
545 }
546 
547 namespace {
/// Simplify clone operations: erase clones without uses, and replace a clone
/// by a cast of its source when one of the deallocations in the clone's
/// block is provably redundant.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    // A clone whose result is never used can simply be removed.
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    Operation *cloneDeallocOp = findDealloc(cloneOp.output());
    Operation *sourceDeallocOp = findDealloc(source);

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    // Pick whichever dealloc lives in the clone's own block as the one to
    // remove; preferring the clone's dealloc when both qualify.
    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations inbetween
    // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
    // of source before the uses of the clone. With alias information, we could
    // restrict this to only fail of the dealloc's operand is an alias
    // of the source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    // Replace the clone with a cast of its source and drop the now-redundant
    // dealloc.
    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};
605 
606 } // end anonymous namespace.
607 
608 void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
609                                           MLIRContext *context) {
610   results.insert<SimplifyClones>(context);
611 }
612 
613 OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
614   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
615 }
616 
617 //===----------------------------------------------------------------------===//
618 // DeallocOp
619 //===----------------------------------------------------------------------===//
620 
LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  // DeallocOp has no results, so `results` is intentionally left empty; the
  // fold only rewrites the operand in place.
  return foldMemRefCast(*this);
}
626 
627 //===----------------------------------------------------------------------===//
628 // DimOp
629 //===----------------------------------------------------------------------===//
630 
631 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
632                   int64_t index) {
633   auto loc = result.location;
634   Value indexValue = builder.create<ConstantIndexOp>(loc, index);
635   build(builder, result, memref, indexValue);
636 }
637 
638 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
639                   Value index) {
640   auto indexTy = builder.getIndexType();
641   build(builder, result, indexTy, memref, index);
642 }
643 
644 Optional<int64_t> DimOp::getConstantIndex() {
645   if (auto constantOp = index().getDefiningOp<ConstantOp>())
646     return constantOp.getValue().cast<IntegerAttr>().getInt();
647   return {};
648 }
649 
650 static LogicalResult verify(DimOp op) {
651   // Assume unknown index to be in range.
652   Optional<int64_t> index = op.getConstantIndex();
653   if (!index.hasValue())
654     return success();
655 
656   // Check that constant index is not knowingly out of range.
657   auto type = op.memrefOrTensor().getType();
658   if (auto memrefType = type.dyn_cast<MemRefType>()) {
659     if (index.getValue() >= memrefType.getRank())
660       return op.emitOpError("index is out of range");
661   } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
662     if (index.getValue() >= tensorType.getRank())
663       return op.emitOpError("index is out of range");
664   } else if (type.isa<UnrankedMemRefType>() || type.isa<UnrankedTensorType>()) {
665     // Assume index to be in range.
666   } else {
667     llvm_unreachable("expected operand with memref type");
668   }
669   return success();
670 }
671 
/// Folds memref.dim by, in order: reading a static extent off the type,
/// looking through tensor_load, querying tensor.generate / subtensor /
/// alloc-like / view-like producers for the matching dynamic size operand,
/// and finally folding through memref.cast.
OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();

  // All forms of folding require a known index.
  if (!index)
    return {};

  auto argTy = memrefOrTensor().getType();
  // Fold if the shape extent along the given index is known.
  if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
    // Folding for unranked types (UnrankedMemRefType) is not supported.
    if (!shapedTy.hasRank())
      return {};
    if (!shapedTy.isDynamicDim(index.getInt())) {
      Builder builder(getContext());
      return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]);
    }
  }

  Operation *definingOp = memrefOrTensor().getDefiningOp();

  // dim(memref.tensor_load(memref)) -> dim(memref)
  // Note: this rewrites the operand in place rather than returning a new
  // value.
  if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) {
    setOperand(0, tensorLoadOp.memref());
    return getResult();
  }

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        fromElements.getResult().getType().cast<RankedTensorType>();
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(resultType.getShape()[index.getInt()] ==
           RankedTensorType::kDynamicSize);

    // Find the operand of the fromElements that corresponds to this index.
    // Dynamic extents appear in operand order, one per dynamic dimension.
    auto dynExtents = fromElements.dynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (dim == RankedTensorType::kDynamicSize)
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto subtensor = dyn_cast_or_null<mlir::SubTensorOp>(definingOp)) {
    assert(subtensor.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subtensor size");
    return subtensor.getDynamicSize(unsignedIndex);
  }

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  auto memrefType = argTy.dyn_cast<MemRefType>();
  if (!memrefType)
    return {};

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  // Covers SubViewOp and other ops implementing the offset/size/stride
  // interface.
  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}
756 
757 namespace {
758 /// Fold dim of a memref reshape operation to a load into the reshape's shape
759 /// operand.
760 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
761   using OpRewritePattern<DimOp>::OpRewritePattern;
762 
763   LogicalResult matchAndRewrite(DimOp dim,
764                                 PatternRewriter &rewriter) const override {
765     auto reshape = dim.memrefOrTensor().getDefiningOp<ReshapeOp>();
766 
767     if (!reshape)
768       return failure();
769 
770     // Place the load directly after the reshape to ensure that the shape memref
771     // was not mutated.
772     rewriter.setInsertionPointAfter(reshape);
773     Location loc = dim.getLoc();
774     Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
775     if (load.getType() != dim.getType())
776       load = rewriter.create<IndexCastOp>(loc, dim.getType(), load);
777     rewriter.replaceOp(dim, load);
778     return success();
779   }
780 };
781 
782 /// Fold dim of a dim of a cast into the dim of the source of the tensor cast.
783 template <typename CastOpTy>
784 struct DimOfCastOp : public OpRewritePattern<DimOp> {
785   using OpRewritePattern<DimOp>::OpRewritePattern;
786 
787   LogicalResult matchAndRewrite(DimOp dimOp,
788                                 PatternRewriter &rewriter) const override {
789     auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>();
790     if (!castOp)
791       return failure();
792     Value newSource = castOp.getOperand();
793     rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
794     return success();
795   }
796 };
797 
/// Helper method to get the `Value` that is the shape of the `resultIdx`-th
/// result at dimension `dimIndex` from the `ShapedTypeOpInterface`.
/// Returns nullptr when the owning op does not implement the interface, when
/// neither reification method succeeds, or when the reified shapes do not
/// cover the requested result/dimension.
/// TODO(ravishankarm): This is better put as a interface utility method
/// somewhere, but that would imply the interface will depend on the `tensor`
/// dialect. Ideally maybe a utility method in the `tensor` dialect.
static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
                                            int64_t dimIndex) {
  unsigned resultNumber = result.getResultNumber();
  auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
  Location loc = result.getOwner()->getLoc();
  if (!shapedTypeOp)
    return nullptr;

  // The interface exposes two methods, one that returns the shape of all the
  // results as `Value` and other that returns the shape as a list of
  // `SmallVector<Value>`. The former takes precedence over the latter. So first
  // check if the op implements the first interface method or the second, and
  // get the value to use appropriately.
  SmallVector<Value> reifiedResultShapes;
  if (succeeded(shapedTypeOp.reifyReturnTypeShapes(
          builder, result.getOwner()->getOperands(), reifiedResultShapes))) {
    if (reifiedResultShapes.size() <= resultNumber)
      return nullptr;
    Value resultShape = reifiedResultShapes[resultNumber];
    // The reified shape must be a ranked tensor of index values for it to be
    // usable as a tensor.extract source.
    auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
    if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
      return nullptr;
    return builder.create<tensor::ExtractOp>(
        loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
  }

  SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
  if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
          builder, reifiedResultShapesPerDim)))
    return nullptr;
  // The per-dim reification must describe every dimension of this result;
  // otherwise indexing with `dimIndex` below would be out of bounds.
  if (reifiedResultShapesPerDim.size() <= resultNumber ||
      reifiedResultShapesPerDim[resultNumber].size() !=
          static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
    return nullptr;
  OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
  // Materialize static extents as index constants; dynamic extents are
  // already SSA values.
  if (auto attr = valueOrAttr.dyn_cast<Attribute>())
    return builder.createOrFold<ConstantIndexOp>(
        loc, attr.cast<IntegerAttr>().getInt());
  return valueOrAttr.get<Value>();
}
843 
844 /// Fold dim of an operation that implements the InferShapedTypeOpInterface
845 struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> {
846   using OpRewritePattern<DimOp>::OpRewritePattern;
847 
848   LogicalResult matchAndRewrite(DimOp dimOp,
849                                 PatternRewriter &rewriter) const override {
850     OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
851     if (!dimValue)
852       return failure();
853     auto shapedTypeOp =
854         dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
855     if (!shapedTypeOp)
856       return failure();
857 
858     Optional<int64_t> dimIndex = dimOp.getConstantIndex();
859     if (!dimIndex)
860       return failure();
861     Value replacement =
862         getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
863     if (!replacement)
864       return failure();
865     rewriter.replaceOp(dimOp, replacement);
866     return success();
867   }
868 };
869 } // end anonymous namespace.
870 
871 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
872                                         MLIRContext *context) {
873   results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
874               DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(context);
875 }
876 
877 // ---------------------------------------------------------------------------
878 // DmaStartOp
879 // ---------------------------------------------------------------------------
880 
881 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
882                        Value srcMemRef, ValueRange srcIndices, Value destMemRef,
883                        ValueRange destIndices, Value numElements,
884                        Value tagMemRef, ValueRange tagIndices, Value stride,
885                        Value elementsPerStride) {
886   result.addOperands(srcMemRef);
887   result.addOperands(srcIndices);
888   result.addOperands(destMemRef);
889   result.addOperands(destIndices);
890   result.addOperands({numElements, tagMemRef});
891   result.addOperands(tagIndices);
892   if (stride)
893     result.addOperands({stride, elementsPerStride});
894 }
895 
896 void DmaStartOp::print(OpAsmPrinter &p) {
897   p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices()
898     << "], " << getDstMemRef() << '[' << getDstIndices() << "], "
899     << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices()
900     << ']';
901   if (isStrided())
902     p << ", " << getStride() << ", " << getNumElementsPerStride();
903 
904   p.printOptionalAttrDict((*this)->getAttrs());
905   p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
906     << ", " << getTagMemRef().getType();
907 }
908 
909 // Parse DmaStartOp.
910 // Ex:
911 //   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
912 //                       %tag[%index], %stride, %num_elt_per_stride :
913 //                     : memref<3076 x f32, 0>,
914 //                       memref<1024 x f32, 2>,
915 //                       memref<1 x i32>
916 //
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
  OpAsmParser::OperandType dstMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
  OpAsmParser::OperandType numElementsInfo;
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  // Only the three memref types are spelled out in the assembly; all index
  // and size operands are implicitly of index type.
  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) dma size in KiB.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  // The stride and elements-per-stride operands must come as a pair.
  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  // Exactly the src, dst and tag memref types are expected after the colon.
  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}
976 
// Verifies operand counts and types. Operand positions are not fixed: they
// depend on the ranks of the three memrefs, so the expected operand count is
// accumulated incrementally as each memref's rank becomes known.
LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
  // the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!getSrcMemRef().getType().isa<MemRefType>())
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!getDstMemRef().getType().isa<MemRefType>())
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}
1042 
LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  // Folds memref.cast feeding any memref operand into a direct use of the
  // cast's source; the op itself is otherwise unchanged.
  return foldMemRefCast(*this);
}
1048 
1049 // ---------------------------------------------------------------------------
1050 // DmaWaitOp
1051 // ---------------------------------------------------------------------------
1052 
1053 void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
1054                       Value tagMemRef, ValueRange tagIndices,
1055                       Value numElements) {
1056   result.addOperands(tagMemRef);
1057   result.addOperands(tagIndices);
1058   result.addOperands(numElements);
1059 }
1060 
1061 void DmaWaitOp::print(OpAsmPrinter &p) {
1062   p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices()
1063     << "], " << getNumElements();
1064   p.printOptionalAttrDict((*this)->getAttrs());
1065   p << " : " << getTagMemRef().getType();
1066 }
1067 
1068 // Parse DmaWaitOp.
1069 // Eg:
1070 //   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
1071 //
ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
  Type type;
  auto indexType = parser.getBuilder().getIndexType();
  OpAsmParser::OperandType numElementsInfo;

  // Parse tag memref, its indices, and dma size.
  // Only the tag memref type is spelled out; the indices and the element
  // count are implicitly of index type.
  if (parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
      parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  return success();
}
1091 
LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  // Folds a memref.cast feeding the tag operand into a direct use of the
  // cast's source.
  return foldMemRefCast(*this);
}
1097 
// Verifies the operand count and types; the number of expected operands
// depends on the rank of the tag memref.
LogicalResult DmaWaitOp::verify() {
  // Mandatory non-variadic operands are tag and the number of elements.
  if (getNumOperands() < 2)
    return emitOpError() << "expected at least 2 operands";

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError() << "expected tag to be of memref type";

  // Exactly one index per tag-memref dimension, plus tag and numElements.
  if (getNumOperands() != 2 + getTagMemRefRank())
    return emitOpError() << "expected " << 2 + getTagMemRefRank()
                         << " operands";

  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError() << "expected tag indices to be of index type";

  if (!getNumElements().getType().isIndex())
    return emitOpError()
           << "expected the number of elements to be of index type";

  return success();
}
1123 
1124 //===----------------------------------------------------------------------===//
1125 // GlobalOp
1126 //===----------------------------------------------------------------------===//
1127 
1128 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1129                                                    TypeAttr type,
1130                                                    Attribute initialValue) {
1131   p << type;
1132   if (!op.isExternal()) {
1133     p << " = ";
1134     if (op.isUninitialized())
1135       p << "uninitialized";
1136     else
1137       p.printAttributeWithoutType(initialValue);
1138   }
1139 }
1140 
/// Parses the "type [= initial-value]" tail of a memref.global op. The type
/// must be a statically shaped memref. A missing `=` clause means an external
/// declaration; `uninitialized` is encoded as a UnitAttr; otherwise the
/// initial value is an elements attribute typed with the equivalent tensor
/// type.
static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                       Attribute &initialValue) {
  Type type;
  if (parser.parseType(type))
    return failure();

  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

  // No `=`: an external global without an initializer.
  if (parser.parseOptionalEqual())
    return success();

  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
    initialValue = UnitAttr::get(parser.getBuilder().getContext());
    return success();
  }

  // Elements attributes are typed with tensor types, so parse the initial
  // value against the tensor equivalent of the memref type.
  Type tensorType = getTensorTypeFromMemRefType(memrefType);
  if (parser.parseAttribute(initialValue, tensorType))
    return failure();
  if (!initialValue.isa<ElementsAttr>())
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  return success();
}
1170 
/// Verifies a memref.global op: the type must be a statically shaped memref,
/// and the (optional) initial value must be a unit attribute (meaning
/// "uninitialized") or an elements attribute whose type matches the tensor
/// equivalent of the memref type.
static LogicalResult verify(GlobalOp op) {
  auto memrefType = op.type().dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return op.emitOpError("type should be static shaped memref, but got ")
           << op.type();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (op.initial_value().hasValue()) {
    Attribute initValue = op.initial_value().getValue();
    if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
      return op.emitOpError("initial value should be a unit or elements "
                            "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (initValue.isa<ElementsAttr>()) {
      Type initType = initValue.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return op.emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  // TODO: verify visibility for declarations.
  return success();
}
1200 
1201 //===----------------------------------------------------------------------===//
1202 // GetGlobalOp
1203 //===----------------------------------------------------------------------===//
1204 
1205 LogicalResult
1206 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1207   // Verify that the result type is same as the type of the referenced
1208   // memref.global op.
1209   auto global =
1210       symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
1211   if (!global)
1212     return emitOpError("'")
1213            << name() << "' does not reference a valid global memref";
1214 
1215   Type resultType = result().getType();
1216   if (global.type() != resultType)
1217     return emitOpError("result type ")
1218            << resultType << " does not match type " << global.type()
1219            << " of the global memref @" << name();
1220   return success();
1221 }
1222 
1223 //===----------------------------------------------------------------------===//
1224 // LoadOp
1225 //===----------------------------------------------------------------------===//
1226 
1227 static LogicalResult verify(LoadOp op) {
1228   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1229     return op.emitOpError("incorrect number of indices for load");
1230   return success();
1231 }
1232 
1233 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1234   /// load(memrefcast) -> load
1235   if (succeeded(foldMemRefCast(*this)))
1236     return getResult();
1237   return OpFoldResult();
1238 }
1239 
1240 namespace {
1241 /// Fold a load on a buffer_cast operation into an tensor.extract on the
1242 /// corresponding tensor.
1243 struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
1244   using OpRewritePattern<LoadOp>::OpRewritePattern;
1245 
1246   LogicalResult matchAndRewrite(LoadOp load,
1247                                 PatternRewriter &rewriter) const override {
1248     auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
1249     if (!buffercast)
1250       return failure();
1251 
1252     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
1253                                                    load.indices());
1254     return success();
1255   }
1256 };
1257 } // end anonymous namespace.
1258 
1259 void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
1260                                          MLIRContext *context) {
1261   results.add<LoadOfBufferCast>(context);
1262 }
1263 
1264 //===----------------------------------------------------------------------===//
1265 // PrefetchOp
1266 //===----------------------------------------------------------------------===//
1267 
1268 static void print(OpAsmPrinter &p, PrefetchOp op) {
1269   p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
1270   p.printOperands(op.indices());
1271   p << ']' << ", " << (op.isWrite() ? "write" : "read");
1272   p << ", locality<" << op.localityHint();
1273   p << ">, " << (op.isDataCache() ? "data" : "instr");
1274   p.printOptionalAttrDict(
1275       op->getAttrs(),
1276       /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1277   p << " : " << op.getMemRefType();
1278 }
1279 
/// Parses: memref[indices], read|write, locality<N>, data|instr attr-dict :
/// memref-type. The rw specifier and cache type are stored as the boolean
/// attributes `isWrite` and `isDataCache`.
static ParseResult parsePrefetchOp(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::OperandType memrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  IntegerAttr localityHint;
  MemRefType type;
  StringRef readOrWrite, cacheType;

  auto indexTy = parser.getBuilder().getIndexType();
  auto i32Type = parser.getBuilder().getIntegerType(32);
  if (parser.parseOperand(memrefInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(localityHint, i32Type, "localityHint",
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(indexInfo, indexTy, result.operands))
    return failure();

  // The rw keyword is encoded as the boolean `isWrite` attribute.
  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(
      PrefetchOp::getIsWriteAttrName(),
      parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));

  // The cache-type keyword is encoded as the boolean `isDataCache` attribute.
  if (!cacheType.equals("data") && !cacheType.equals("instr"))
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");

  result.addAttribute(
      PrefetchOp::getIsDataCacheAttrName(),
      parser.getBuilder().getBoolAttr(cacheType.equals("data")));

  return success();
}
1320 
1321 static LogicalResult verify(PrefetchOp op) {
1322   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1323     return op.emitOpError("too few indices");
1324 
1325   return success();
1326 }
1327 
LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  // prefetch(memrefcast) -> prefetch
  // Folds a memref.cast feeding the memref operand into a direct use of the
  // cast's source.
  return foldMemRefCast(*this);
}
1333 
1334 //===----------------------------------------------------------------------===//
1335 // ReinterpretCastOp
1336 //===----------------------------------------------------------------------===//
1337 
1338 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1339 /// `staticSizes` and `staticStrides` are automatically filled with
1340 /// source-memref-rank sentinel values that encode dynamic entries.
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  // Split each OpFoldResult list into static entries (IntegerAttr, recorded
  // in the static vector) and dynamic entries (Value operands); a dynamic
  // entry leaves the given sentinel value in the static vector.
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}
1359 
1360 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1361                               MemRefType resultType, Value source,
1362                               int64_t offset, ArrayRef<int64_t> sizes,
1363                               ArrayRef<int64_t> strides,
1364                               ArrayRef<NamedAttribute> attrs) {
1365   SmallVector<OpFoldResult> sizeValues =
1366       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1367         return b.getI64IntegerAttr(v);
1368       }));
1369   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1370       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1371         return b.getI64IntegerAttr(v);
1372       }));
1373   build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1374         strideValues, attrs);
1375 }
1376 
1377 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1378                               MemRefType resultType, Value source, Value offset,
1379                               ValueRange sizes, ValueRange strides,
1380                               ArrayRef<NamedAttribute> attrs) {
1381   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1382       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1383   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1384       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1385   build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1386 }
1387 
1388 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1389 // completed automatically, like we have for subview and subtensor.
// TODO: ponder whether we want to allow missing trailing sizes/strides that are
// completed automatically, like we have for subview and subtensor.
/// Verifies that the result type is consistent with the static_sizes,
/// static_offsets and static_strides attributes, and that source and result
/// agree on element type and memory space.
static LogicalResult verify(ReinterpretCastOp op) {
  // The source and result memrefs should be in the same memory space.
  auto srcType = op.source().getType().cast<BaseMemRefType>();
  auto resultType = op.getType().cast<MemRefType>();
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return op.emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (srcType.getElementType() != resultType.getElementType())
    return op.emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;

  // Match sizes in result memref type and in static_sizes attribute.
  for (auto &en :
       llvm::enumerate(llvm::zip(resultType.getShape(),
                                 extractFromI64ArrayAttr(op.static_sizes())))) {
    int64_t resultSize = std::get<0>(en.value());
    int64_t expectedSize = std::get<1>(en.value());
    if (resultSize != expectedSize)
      return op.emitError("expected result type with size = ")
             << expectedSize << " instead of " << resultSize
             << " in dim = " << en.index();
  }

  // Match offset and strides in static_offset and static_strides attributes if
  // result memref type has an affine map specified.
  if (!resultType.getAffineMaps().empty()) {
    int64_t resultOffset;
    SmallVector<int64_t, 4> resultStrides;
    if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
      return failure();

    // Match offset in result memref type and in static_offsets attribute.
    int64_t expectedOffset =
        extractFromI64ArrayAttr(op.static_offsets()).front();
    if (resultOffset != expectedOffset)
      return op.emitError("expected result type with offset = ")
             << resultOffset << " instead of " << expectedOffset;

    // Match strides in result memref type and in static_strides attribute.
    for (auto &en : llvm::enumerate(llvm::zip(
             resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
      int64_t resultStride = std::get<0>(en.value());
      int64_t expectedStride = std::get<1>(en.value());
      if (resultStride != expectedStride)
        return op.emitError("expected result type with stride = ")
               << expectedStride << " instead of " << resultStride
               << " in dim = " << en.index();
    }
  }
  return success();
}
1441 
1442 //===----------------------------------------------------------------------===//
1443 // ReshapeOp
1444 //===----------------------------------------------------------------------===//
1445 
/// Verifies a memref.reshape op: source and result must share an element
/// type, memref operands/results must use the identity layout, and a
/// statically ranked result requires a statically sized shape operand whose
/// length equals the result rank.
static LogicalResult verify(ReshapeOp op) {
  Type operandType = op.source().getType();
  Type resultType = op.result().getType();

  Type operandElementType = operandType.cast<ShapedType>().getElementType();
  Type resultElementType = resultType.cast<ShapedType>().getElementType();
  if (operandElementType != resultElementType)
    return op.emitOpError("element types of source and destination memref "
                          "types should be the same");

  if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
    if (!operandMemRefType.getAffineMaps().empty())
      return op.emitOpError(
          "source memref type should have identity affine map");

  // Length of the 1-D shape operand; kDynamicSize if it is dynamically sized.
  int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
  auto resultMemRefType = resultType.dyn_cast<MemRefType>();
  if (resultMemRefType) {
    if (!resultMemRefType.getAffineMaps().empty())
      return op.emitOpError(
          "result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamicSize)
      return op.emitOpError("cannot use shape operand with dynamic length to "
                            "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
      return op.emitOpError(
          "length of shape operand differs from the result's memref rank");
  }
  return success();
}
1476 
1477 //===----------------------------------------------------------------------===//
1478 // StoreOp
1479 //===----------------------------------------------------------------------===//
1480 
1481 static LogicalResult verify(StoreOp op) {
1482   if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
1483     return op.emitOpError("store index operand count not equal to memref rank");
1484 
1485   return success();
1486 }
1487 
/// Fold a memref.cast feeding this store's memref operand into the store.
LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
                            SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  // NOTE(review): `getValueToStore()` appears to be passed so the stored-value
  // operand is exempt from the cast-folding rewrite — confirm against
  // foldMemRefCast's definition.
  return foldMemRefCast(*this, getValueToStore());
}
1493 
1494 //===----------------------------------------------------------------------===//
1495 // SubViewOp
1496 //===----------------------------------------------------------------------===//
1497 
1498 namespace {
1499 /// Helpers to write more idiomatic operations.
1500 namespace saturated_arith {
1501 struct Wrapper {
1502   explicit Wrapper(int64_t v) : v(v) {}
1503   operator int64_t() { return v; }
1504   int64_t v;
1505 };
1506 Wrapper operator+(Wrapper a, int64_t b) {
1507   if (ShapedType::isDynamicStrideOrOffset(a) ||
1508       ShapedType::isDynamicStrideOrOffset(b))
1509     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1510   return Wrapper(a.v + b);
1511 }
1512 Wrapper operator*(Wrapper a, int64_t b) {
1513   if (ShapedType::isDynamicStrideOrOffset(a) ||
1514       ShapedType::isDynamicStrideOrOffset(b))
1515     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1516   return Wrapper(a.v * b);
1517 }
1518 } // end namespace saturated_arith
1519 } // end namespace
1520 
1521 /// A subview result type can be fully inferred from the source type and the
1522 /// static representation of offsets, sizes and strides. Special sentinels
1523 /// encode the dynamic case.
1524 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1525                                 ArrayRef<int64_t> leadingStaticOffsets,
1526                                 ArrayRef<int64_t> leadingStaticSizes,
1527                                 ArrayRef<int64_t> leadingStaticStrides) {
1528   // A subview may specify only a leading subset of offset/sizes/strides in
1529   // which case we complete with offset=0, sizes from memref type and strides=1.
1530   unsigned rank = sourceMemRefType.getRank();
1531   assert(leadingStaticOffsets.size() <= rank &&
1532          "unexpected leadingStaticOffsets overflow");
1533   assert(leadingStaticSizes.size() <= rank &&
1534          "unexpected leadingStaticSizes overflow");
1535   assert(leadingStaticStrides.size() <= rank &&
1536          "unexpected leadingStaticStrides overflow");
1537   auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
1538   auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
1539   auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
1540   unsigned numTrailingOffsets = rank - staticOffsets.size();
1541   unsigned numTrailingSizes = rank - staticSizes.size();
1542   unsigned numTrailingStrides = rank - staticStrides.size();
1543   staticOffsets.append(numTrailingOffsets, 0);
1544   llvm::append_range(staticSizes,
1545                      sourceMemRefType.getShape().take_back(numTrailingSizes));
1546   staticStrides.append(numTrailingStrides, 1);
1547 
1548   // Extract source offset and strides.
1549   int64_t sourceOffset;
1550   SmallVector<int64_t, 4> sourceStrides;
1551   auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
1552   assert(succeeded(res) && "SubViewOp expected strided memref type");
1553   (void)res;
1554 
1555   // Compute target offset whose value is:
1556   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
1557   int64_t targetOffset = sourceOffset;
1558   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
1559     auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
1560     using namespace saturated_arith;
1561     targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride;
1562   }
1563 
1564   // Compute target stride whose value is:
1565   //   `sourceStrides_i * staticStrides_i`.
1566   SmallVector<int64_t, 4> targetStrides;
1567   targetStrides.reserve(staticOffsets.size());
1568   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
1569     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
1570     using namespace saturated_arith;
1571     targetStrides.push_back(Wrapper(sourceStride) * staticStride);
1572   }
1573 
1574   // The type is now known.
1575   return MemRefType::get(
1576       staticSizes, sourceMemRefType.getElementType(),
1577       makeStridedLinearLayoutMap(targetStrides, targetOffset,
1578                                  sourceMemRefType.getContext()),
1579       sourceMemRefType.getMemorySpace());
1580 }
1581 
1582 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1583                                 ArrayRef<OpFoldResult> leadingStaticOffsets,
1584                                 ArrayRef<OpFoldResult> leadingStaticSizes,
1585                                 ArrayRef<OpFoldResult> leadingStaticStrides) {
1586   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1587   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1588   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1589                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1590   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1591                              ShapedType::kDynamicSize);
1592   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1593                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1594   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1595                                     staticSizes, staticStrides)
1596       .cast<MemRefType>();
1597 }
1598 
1599 Type SubViewOp::inferRankReducedResultType(
1600     unsigned resultRank, MemRefType sourceRankedTensorType,
1601     ArrayRef<int64_t> leadingStaticOffsets,
1602     ArrayRef<int64_t> leadingStaticSizes,
1603     ArrayRef<int64_t> leadingStaticStrides) {
1604   auto inferredType =
1605       inferResultType(sourceRankedTensorType, leadingStaticOffsets,
1606                       leadingStaticSizes, leadingStaticStrides)
1607           .cast<MemRefType>();
1608   assert(inferredType.getRank() >= resultRank && "expected ");
1609   int rankDiff = inferredType.getRank() - resultRank;
1610   if (rankDiff > 0) {
1611     auto shape = inferredType.getShape();
1612     llvm::SmallDenseSet<unsigned> dimsToProject;
1613     mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
1614     SmallVector<int64_t> projectedShape;
1615     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1616       if (!dimsToProject.contains(pos))
1617         projectedShape.push_back(shape[pos]);
1618 
1619     AffineMap map;
1620     auto maps = inferredType.getAffineMaps();
1621     if (!maps.empty() && maps.front())
1622       map = getProjectedMap(maps.front(), dimsToProject);
1623     inferredType =
1624         MemRefType::get(projectedShape, inferredType.getElementType(), map,
1625                         inferredType.getMemorySpace());
1626   }
1627   return inferredType;
1628 }
1629 
1630 Type SubViewOp::inferRankReducedResultType(
1631     unsigned resultRank, MemRefType sourceRankedTensorType,
1632     ArrayRef<OpFoldResult> leadingStaticOffsets,
1633     ArrayRef<OpFoldResult> leadingStaticSizes,
1634     ArrayRef<OpFoldResult> leadingStaticStrides) {
1635   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1636   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1637   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1638                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1639   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1640                              ShapedType::kDynamicSize);
1641   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1642                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1643   return SubViewOp::inferRankReducedResultType(
1644       resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1645       staticStrides);
1646 }
1647 // Build a SubViewOp with mixed static and dynamic entries and custom result
1648 // type. If the type passed is nullptr, it is inferred.
1649 void SubViewOp::build(OpBuilder &b, OperationState &result,
1650                       MemRefType resultType, Value source,
1651                       ArrayRef<OpFoldResult> offsets,
1652                       ArrayRef<OpFoldResult> sizes,
1653                       ArrayRef<OpFoldResult> strides,
1654                       ArrayRef<NamedAttribute> attrs) {
1655   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1656   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1657   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1658                              ShapedType::kDynamicStrideOrOffset);
1659   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1660                              ShapedType::kDynamicSize);
1661   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1662                              ShapedType::kDynamicStrideOrOffset);
1663   auto sourceMemRefType = source.getType().cast<MemRefType>();
1664   // Structuring implementation this way avoids duplication between builders.
1665   if (!resultType) {
1666     resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1667                                             staticSizes, staticStrides)
1668                      .cast<MemRefType>();
1669   }
1670   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1671         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1672         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1673   result.addAttributes(attrs);
1674 }
1675 
// Build a SubViewOp with mixed static and dynamic entries and inferred result
// type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  // Passing a null MemRefType makes the delegated builder infer the result
  // type from `source` and the static entries.
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}
1685 
1686 // Build a SubViewOp with static entries and inferred result type.
1687 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1688                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1689                       ArrayRef<int64_t> strides,
1690                       ArrayRef<NamedAttribute> attrs) {
1691   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1692       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1693         return b.getI64IntegerAttr(v);
1694       }));
1695   SmallVector<OpFoldResult> sizeValues =
1696       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1697         return b.getI64IntegerAttr(v);
1698       }));
1699   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1700       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1701         return b.getI64IntegerAttr(v);
1702       }));
1703   build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
1704 }
1705 
1706 // Build a SubViewOp with dynamic entries and custom result type. If the
1707 // type passed is nullptr, it is inferred.
1708 void SubViewOp::build(OpBuilder &b, OperationState &result,
1709                       MemRefType resultType, Value source,
1710                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1711                       ArrayRef<int64_t> strides,
1712                       ArrayRef<NamedAttribute> attrs) {
1713   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1714       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1715         return b.getI64IntegerAttr(v);
1716       }));
1717   SmallVector<OpFoldResult> sizeValues =
1718       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1719         return b.getI64IntegerAttr(v);
1720       }));
1721   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1722       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1723         return b.getI64IntegerAttr(v);
1724       }));
1725   build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
1726         attrs);
1727 }
1728 
1729 // Build a SubViewOp with dynamic entries and custom result type. If the type
1730 // passed is nullptr, it is inferred.
1731 void SubViewOp::build(OpBuilder &b, OperationState &result,
1732                       MemRefType resultType, Value source, ValueRange offsets,
1733                       ValueRange sizes, ValueRange strides,
1734                       ArrayRef<NamedAttribute> attrs) {
1735   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1736       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1737   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1738       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1739   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1740       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1741   build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
1742 }
1743 
// Build a SubViewOp with dynamic entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ValueRange offsets, ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  // A null result type requests type inference in the delegated builder.
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}
1750 
/// For ViewLikeOpInterface: a subview is a view into its `source` memref.
Value SubViewOp::getViewSource() { return source(); }
1753 
/// Outcome of checking whether a subview result type is a valid (possibly
/// rank-reduced) version of the inferred type; see isRankReducedType.
enum SubViewVerificationResult {
  Success,          // types match, possibly after rank reduction
  RankTooLarge,     // candidate rank exceeds the original rank
  SizeMismatch,     // sizes cannot be matched by dropping unit dims
  ElemTypeMismatch, // element types differ
  MemSpaceMismatch, // memory spaces differ
  AffineMapMismatch // layout map does not match the projected original
};
1762 
/// Checks if the `original` type can be rank reduced to the candidate
/// `reduced` type. This function is a slight variant of an `is subsequence`
/// algorithm, where a non-matching dimension must be a static 1.
static SubViewVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
                  std::string *errMsg = nullptr) {
  // Identical types trivially succeed.
  if (originalType == candidateReducedType)
    return SubViewVerificationResult::Success;
  // The rank-reduction rules below only constrain memref pairs.
  if (!originalType.isa<MemRefType>())
    return SubViewVerificationResult::Success;
  if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
    return SubViewVerificationResult::Success;

  ShapedType originalShapedType = originalType.cast<ShapedType>();
  ShapedType candidateReducedShapedType =
      candidateReducedType.cast<ShapedType>();

  // Rank and size logic is valid for all ShapedTypes.
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SubViewVerificationResult::RankTooLarge;

  // Compute which original dimensions must be dropped to obtain the candidate
  // shape (per the header comment, non-matching dims must be static 1).
  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // Sizes cannot be matched in case empty vector is returned.
  if (!optionalUnusedDimsMask.hasValue())
    return SubViewVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SubViewVerificationResult::ElemTypeMismatch;

  // Strided layout logic is relevant for MemRefType only.
  MemRefType original = originalType.cast<MemRefType>();
  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
  if (original.getMemorySpace() != candidateReduced.getMemorySpace())
    return SubViewVerificationResult::MemSpaceMismatch;

  // Project the dropped dimensions out of the original strided layout and
  // compare against the candidate's layout.
  llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
  auto inferredType =
      getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
  AffineMap candidateLayout;
  if (candidateReduced.getAffineMaps().empty())
    candidateLayout = getStridedLinearLayoutMap(candidateReduced);
  else
    candidateLayout = candidateReduced.getAffineMaps().front();
  assert(inferredType.getNumResults() == 1 &&
         candidateLayout.getNumResults() == 1);
  // The two layouts must agree on dim/symbol counts before their expressions
  // can be compared.
  if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
      inferredType.getNumDims() != candidateLayout.getNumDims()) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  // Check that the difference of the affine maps simplifies to 0.
  AffineExpr diffExpr =
      inferredType.getResult(0) - candidateLayout.getResult(0);
  diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
                                inferredType.getNumSymbols());
  auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
  if (!(cst && cst.getValue() == 0)) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  return SubViewVerificationResult::Success;
}
1839 
1840 template <typename OpTy>
1841 static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
1842                                             OpTy op, Type expectedType,
1843                                             StringRef errMsg = "") {
1844   auto memrefType = expectedType.cast<ShapedType>();
1845   switch (result) {
1846   case SubViewVerificationResult::Success:
1847     return success();
1848   case SubViewVerificationResult::RankTooLarge:
1849     return op.emitError("expected result rank to be smaller or equal to ")
1850            << "the source rank. " << errMsg;
1851   case SubViewVerificationResult::SizeMismatch:
1852     return op.emitError("expected result type to be ")
1853            << expectedType
1854            << " or a rank-reduced version. (mismatch of result sizes) "
1855            << errMsg;
1856   case SubViewVerificationResult::ElemTypeMismatch:
1857     return op.emitError("expected result element type to be ")
1858            << memrefType.getElementType() << errMsg;
1859   case SubViewVerificationResult::MemSpaceMismatch:
1860     return op.emitError("expected result and source memory spaces to match.")
1861            << errMsg;
1862   case SubViewVerificationResult::AffineMapMismatch:
1863     return op.emitError("expected result type to be ")
1864            << expectedType
1865            << " or a rank-reduced version. (mismatch of result affine map) "
1866            << errMsg;
1867   }
1868   llvm_unreachable("unexpected subview verification result");
1869 }
1870 
1871 /// Verifier for SubViewOp.
1872 static LogicalResult verify(SubViewOp op) {
1873   MemRefType baseType = op.getSourceType();
1874   MemRefType subViewType = op.getType();
1875 
1876   // The base memref and the view memref should be in the same memory space.
1877   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
1878     return op.emitError("different memory spaces specified for base memref "
1879                         "type ")
1880            << baseType << " and subview memref type " << subViewType;
1881 
1882   // Verify that the base memref type has a strided layout map.
1883   if (!isStrided(baseType))
1884     return op.emitError("base type ") << baseType << " is not strided";
1885 
1886   // Verify result type against inferred type.
1887   auto expectedType = SubViewOp::inferResultType(
1888       baseType, extractFromI64ArrayAttr(op.static_offsets()),
1889       extractFromI64ArrayAttr(op.static_sizes()),
1890       extractFromI64ArrayAttr(op.static_strides()));
1891 
1892   std::string errMsg;
1893   auto result = isRankReducedType(expectedType, subViewType, &errMsg);
1894   return produceSubViewErrorMsg(result, op, expectedType, errMsg);
1895 }
1896 
1897 raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
1898   return os << "range " << range.offset << ":" << range.size << ":"
1899             << range.stride;
1900 }
1901 
1902 /// Return the list of Range (i.e. offset, size, stride). Each Range
1903 /// entry contains either the dynamic value or a ConstantIndexOp constructed
1904 /// with `b` at location `loc`.
1905 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
1906                                               OpBuilder &b, Location loc) {
1907   std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
1908   assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
1909   assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
1910   SmallVector<Range, 8> res;
1911   unsigned rank = ranks[0];
1912   res.reserve(rank);
1913   for (unsigned idx = 0; idx < rank; ++idx) {
1914     Value offset =
1915         op.isDynamicOffset(idx)
1916             ? op.getDynamicOffset(idx)
1917             : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
1918     Value size = op.isDynamicSize(idx)
1919                      ? op.getDynamicSize(idx)
1920                      : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
1921     Value stride =
1922         op.isDynamicStride(idx)
1923             ? op.getDynamicStride(idx)
1924             : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
1925     res.emplace_back(Range{offset, size, stride});
1926   }
1927   return res;
1928 }
1929 
1930 /// Infer the canonical type of the result of a subview operation. Returns a
1931 /// type with rank `resultRank` that is either the rank of the rank-reduced
1932 /// type, or the non-rank-reduced type.
1933 static MemRefType
1934 getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
1935                               ArrayRef<OpFoldResult> mixedOffsets,
1936                               ArrayRef<OpFoldResult> mixedSizes,
1937                               ArrayRef<OpFoldResult> mixedStrides) {
1938   auto resultType =
1939       SubViewOp::inferRankReducedResultType(
1940           resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
1941           .cast<MemRefType>();
1942   if (resultType.getRank() != resultRank) {
1943     resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
1944                                             mixedSizes, mixedStrides)
1945                      .cast<MemRefType>();
1946   }
1947   return resultType;
1948 }
1949 
namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
///     memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let SubViewOpConstantFolder kick in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // The pattern only applies when the subview's source comes from a cast.
    auto castOp = subViewOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    // Bail out unless the cast may be absorbed by a consuming op.
    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
    /// the cast source operand type and the SubViewOp static information. This
    /// is the resulting type if the MemRefCastOp were folded.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType().getRank(),
        castOp.source().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    // Recreate the subview directly on the cast's source, then cast the new
    // result back to the original subview type so existing users still see
    // the type they expect.
    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};
} // namespace
2004 
/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  // Recompute the (possibly rank-reduced) result type once the mixed
  // offsets/sizes/strides have been folded to their canonical form.
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSubViewResultType(op.getType().getRank(),
                                         op.getSourceType(), mixedOffsets,
                                         mixedSizes, mixedStrides);
  }
};
2015 
/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
  // Replace `op` by `newOp` followed by a cast back to `op`'s original result
  // type, so existing users keep seeing the type they expect.
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
  }
};
2022 
void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Two canonicalizations: fold constant offset/size/stride operands into the
  // static attributes, and push a producing memref.cast past the subview.
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder>(context);
}
2030 
2031 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
2032   auto resultShapedType = getResult().getType().cast<ShapedType>();
2033   auto sourceShapedType = source().getType().cast<ShapedType>();
2034 
2035   if (resultShapedType.hasStaticShape() &&
2036       resultShapedType == sourceShapedType) {
2037     return getViewSource();
2038   }
2039 
2040   return {};
2041 }
2042 
2043 //===----------------------------------------------------------------------===//
2044 // TensorLoadOp
2045 //===----------------------------------------------------------------------===//
2046 
/// Fold tensor_load(buffer_cast(x)) -> x.
OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
  if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
    // Approximate alias analysis by conservatively folding only when the
    // buffer_cast immediately precedes this op in the same block, i.e. there
    // is no interleaved operation.
    if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
        bufferCast->getNextNode() == this->getOperation())
      return bufferCast.tensor();
  return {};
}
2056 
2057 //===----------------------------------------------------------------------===//
2058 // TransposeOp
2059 //===----------------------------------------------------------------------===//
2060 
/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes: result dim i takes the size of the source dim the
  // i-th map result refers to.
  SmallVector<int64_t, 4> sizes(rank, 0);
  for (auto en : llvm::enumerate(permutationMap.getResults()))
    sizes[en.index()] =
        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];

  // Compute permuted strides.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(memRefType, strides, offset);
  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
  (void)res;
  // Compose the source's strided layout with the permutation so the strides
  // follow their dimensions.
  auto map =
      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
  map = permutationMap ? map.compose(permutationMap) : map;
  return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
}
2083 
2084 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
2085                         AffineMapAttr permutation,
2086                         ArrayRef<NamedAttribute> attrs) {
2087   auto permutationMap = permutation.getValue();
2088   assert(permutationMap);
2089 
2090   auto memRefType = in.getType().cast<MemRefType>();
2091   // Compute result type.
2092   MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
2093 
2094   build(b, result, resultType, in, attrs);
2095   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
2096 }
2097 
2098 // transpose $in $permutation attr-dict : type($in) `to` type(results)
2099 static void print(OpAsmPrinter &p, TransposeOp op) {
2100   p << "memref.transpose " << op.in() << " " << op.permutation();
2101   p.printOptionalAttrDict(op->getAttrs(),
2102                           {TransposeOp::getPermutationAttrName()});
2103   p << " : " << op.in().getType() << " to " << op.getType();
2104 }
2105 
/// Parse: operand affine-map attr-dict `:` src-type `to` dst-type
/// (mirrors the printer above).
static ParseResult parseTransposeOp(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  // Each step consumes tokens in printed order; short-circuiting stops at the
  // first failure.
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  // The permutation is stored as an attribute, not an operand.
  result.addAttribute(TransposeOp::getPermutationAttrName(),
                      AffineMapAttr::get(permutation));
  return success();
}
2123 
/// Verifies a TransposeOp: the map must be a true permutation whose rank
/// matches the input memref, and the result type must equal the type
/// inferred by inferTransposeResultType.
static LogicalResult verify(TransposeOp op) {
  // The affine map must be bijective over its dimensions.
  if (!op.permutation().isPermutation())
    return op.emitOpError("expected a permutation map");
  // Every input dimension must be covered by the permutation.
  if (op.permutation().getNumDims() != op.getShapedType().getRank())
    return op.emitOpError(
        "expected a permutation map of same rank as the input");

  auto srcType = op.in().getType().cast<MemRefType>();
  auto dstType = op.getType().cast<MemRefType>();
  // Recompute the expected result type and require an exact match. The
  // diagnostic prints the original input type followed by the inferred
  // (expected) transposed type.
  auto transposedType = inferTransposeResultType(srcType, op.permutation());
  if (dstType != transposedType)
    return op.emitOpError("output type ")
           << dstType << " does not match transposed input type " << srcType
           << ", " << transposedType;
  return success();
}
2140 
2141 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
2142   if (succeeded(foldMemRefCast(*this)))
2143     return getResult();
2144   return {};
2145 }
2146 
2147 //===----------------------------------------------------------------------===//
2148 // ViewOp
2149 //===----------------------------------------------------------------------===//
2150 
/// Parses: view $source `[` $byte_shift `]` `[` $sizes `]` attr-dict
///         `:` type($source) `to` type(results)
static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcInfo;
  SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
  SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
  auto indexType = parser.getBuilder().getIndexType();
  Type srcType, dstType;
  llvm::SMLoc offsetLoc;
  // Parse the source operand and the bracketed byte-shift operand list,
  // remembering the list's location for the diagnostic below.
  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
    return failure();

  // Exactly one byte-shift operand is accepted.
  if (offsetInfo.size() != 1)
    return parser.emitError(offsetLoc) << "expects 1 offset operand";

  // Operands are resolved in the fixed order source, offset, sizes — this
  // order defines the op's operand list, so it must not be changed.
  return failure(
      parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(srcInfo, srcType, result.operands) ||
      parser.resolveOperands(offsetInfo, indexType, result.operands) ||
      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types));
}
2175 
2176 static void print(OpAsmPrinter &p, ViewOp op) {
2177   p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
2178   p.printOperand(op.byte_shift());
2179   p << "][" << op.sizes() << ']';
2180   p.printOptionalAttrDict(op->getAttrs());
2181   p << " : " << op.getOperand(0).getType() << " to " << op.getType();
2182 }
2183 
/// Verifies a ViewOp: both the base and result memrefs must have an identity
/// (or absent) layout map, the memory spaces must match, and there must be
/// exactly one size operand per dynamic dimension of the result type.
static LogicalResult verify(ViewOp op) {
  auto baseType = op.getOperand(0).getType().cast<MemRefType>();
  auto viewType = op.getType();

  // The base memref should have identity layout map (or none).
  if (baseType.getAffineMaps().size() > 1 ||
      (baseType.getAffineMaps().size() == 1 &&
       !baseType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (viewType.getAffineMaps().size() > 1 ||
      (viewType.getAffineMaps().size() == 1 &&
       !viewType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type:
  // one index operand per dynamic dimension.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (op.sizes().size() != numDynamicDims)
    return op.emitError("incorrect number of size operands for type ")
           << viewType;

  return success();
}
2214 
2215 Value ViewOp::getViewSource() { return source(); }
2216 
2217 namespace {
2218 
/// Canonicalization pattern that folds constant-defined size operands into
/// the static shape of the view's result type. A memref.cast back to the
/// original (more dynamic) type is inserted so existing uses are unaffected.
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from old memref view type 'memRefType'.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    // View result types are verified to have identity layout, so the
    // computed offset must be zero.
    assert(oldOffset == 0 && "Expected 0 offset");

    // Size operands that could not be folded; becomes the new op's sizes.
    SmallVector<Value, 4> newOperands;

    // Offset cannot be folded into result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    // dynamicDimPos tracks the position within viewOp.sizes() — there is one
    // size operand per dynamic dimension, in dimension order.
    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        // dimSize is the dynamic sentinel here, so the dim stays dynamic.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
                                             viewOp.getOperand(0),
                                             viewOp.byte_shift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
    return success();
  }
};
2285 
2286 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
2287   using OpRewritePattern<ViewOp>::OpRewritePattern;
2288 
2289   LogicalResult matchAndRewrite(ViewOp viewOp,
2290                                 PatternRewriter &rewriter) const override {
2291     Value memrefOperand = viewOp.getOperand(0);
2292     CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
2293     if (!memrefCastOp)
2294       return failure();
2295     Value allocOperand = memrefCastOp.getOperand();
2296     AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
2297     if (!allocOp)
2298       return failure();
2299     rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
2300                                         viewOp.byte_shift(), viewOp.sizes());
2301     return success();
2302   }
2303 };
2304 
2305 } // end anonymous namespace
2306 
2307 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2308                                          MLIRContext *context) {
2309   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
2310 }
2311 
2312 //===----------------------------------------------------------------------===//
2313 // TableGen'd op method definitions
2314 //===----------------------------------------------------------------------===//
2315 
2316 #define GET_OP_CLASSES
2317 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
2318