1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MemRef/IR/MemRef.h"
10 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/Dialect/Utils/StaticValueUtils.h"
15 #include "mlir/IR/AffineMap.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Interfaces/InferTypeOpInterface.h"
22 #include "mlir/Interfaces/ViewLikeInterface.h"
23 #include "llvm/ADT/STLExtras.h"
24 
25 using namespace mlir;
26 using namespace mlir::memref;
27 
28 /// Materialize a single constant operation from a given attribute value with
29 /// the desired resultant type.
30 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
31                                               Attribute value, Type type,
32                                               Location loc) {
33   return builder.create<mlir::ConstantOp>(loc, type, value);
34 }
35 
36 //===----------------------------------------------------------------------===//
37 // Common canonicalization pattern support logic
38 //===----------------------------------------------------------------------===//
39 
40 /// This is a common class used for patterns of the form
41 /// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
42 /// into the root operation directly.
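///
/// As a rough sketch (SSA names and types are illustrative), this enables
/// rewrites such as:
///
/// ```mlir
///   %0 = memref.cast %arg0 : memref<16xf32> to memref<?xf32>
///   memref.dealloc %0 : memref<?xf32>
/// ```
///
/// into:
///
/// ```mlir
///   memref.dealloc %arg0 : memref<16xf32>
/// ```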
43 static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) {
44   bool folded = false;
45   for (OpOperand &operand : op->getOpOperands()) {
46     auto cast = operand.get().getDefiningOp<CastOp>();
47     if (cast && operand.get() != inner &&
48         !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
49       operand.set(cast.getOperand());
50       folded = true;
51     }
52   }
53   return success(folded);
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // Helpers for GlobalOp
58 //===----------------------------------------------------------------------===//
59 
60 static Type getTensorTypeFromMemRefType(Type type) {
61   if (auto memref = type.dyn_cast<MemRefType>())
62     return RankedTensorType::get(memref.getShape(), memref.getElementType());
63   if (auto memref = type.dyn_cast<UnrankedMemRefType>())
64     return UnrankedTensorType::get(memref.getElementType());
65   return NoneType::get(type.getContext());
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // AllocOp / AllocaOp
70 //===----------------------------------------------------------------------===//
71 
72 template <typename AllocLikeOp>
73 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
74   static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
75                 "applies to only alloc or alloca");
76   auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
77   if (!memRefType)
78     return op.emitOpError("result must be a memref");
79 
80   if (static_cast<int64_t>(op.dynamicSizes().size()) !=
81       memRefType.getNumDynamicDims())
82     return op.emitOpError("dimension operand count does not equal memref "
83                           "dynamic dimension count");
84 
85   unsigned numSymbols = 0;
86   if (!memRefType.getAffineMaps().empty())
87     numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
88   if (op.symbolOperands().size() != numSymbols)
89     return op.emitOpError("symbol operand count does not equal memref symbol "
90                           "count: expected ")
91            << numSymbols << ", got " << op.symbolOperands().size();
92 
93   return success();
94 }
95 
96 static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }
97 
98 static LogicalResult verify(AllocaOp op) {
99   // An alloca op needs to have an ancestor with an allocation scope trait.
100   if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
101     return op.emitOpError(
102         "requires an ancestor op with AutomaticAllocationScope trait");
103 
104   return verifyAllocLikeOp(op);
105 }
106 
107 namespace {
/// Fold constant dimension operands into an alloc-like operation.
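///
/// As a rough sketch (SSA names and types are illustrative):
///
/// ```mlir
///   %c16 = constant 16 : index
///   %0 = memref.alloc(%c16, %n) : memref<?x?xf32>
/// ```
///
/// may canonicalize to:
///
/// ```mlir
///   %1 = memref.alloc(%n) : memref<16x?xf32>
///   %0 = memref.cast %1 : memref<16x?xf32> to memref<?x?xf32>
/// ```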
109 template <typename AllocLikeOp>
110 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
111   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
112 
113   LogicalResult matchAndRewrite(AllocLikeOp alloc,
114                                 PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants.  If so, we can
    // substitute and drop them.
117     if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
118           return matchPattern(operand, matchConstantIndex());
119         }))
120       return failure();
121 
122     auto memrefType = alloc.getType();
123 
124     // Ok, we have one or more constant operands.  Collect the non-constant ones
125     // and keep track of the resultant memref type to build.
126     SmallVector<int64_t, 4> newShapeConstants;
127     newShapeConstants.reserve(memrefType.getRank());
128     SmallVector<Value, 4> dynamicSizes;
129 
130     unsigned dynamicDimPos = 0;
131     for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
132       int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
134       if (dimSize != -1) {
135         newShapeConstants.push_back(dimSize);
136         continue;
137       }
138       auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
139       auto *defOp = dynamicSize.getDefiningOp();
140       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
141         // Dynamic shape dimension will be folded.
142         newShapeConstants.push_back(constantIndexOp.getValue());
143       } else {
144         // Dynamic shape dimension not folded; copy dynamicSize from old memref.
145         newShapeConstants.push_back(-1);
146         dynamicSizes.push_back(dynamicSize);
147       }
148       dynamicDimPos++;
149     }
150 
151     // Create new memref type (which will have fewer dynamic dimensions).
152     MemRefType newMemRefType =
153         MemRefType::Builder(memrefType).setShape(newShapeConstants);
154     assert(static_cast<int64_t>(dynamicSizes.size()) ==
155            newMemRefType.getNumDynamicDims());
156 
157     // Create and insert the alloc op for the new memref.
158     auto newAlloc = rewriter.create<AllocLikeOp>(
159         alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
160         alloc.alignmentAttr());
161     // Insert a cast so we have the same type as the old alloc.
162     auto resultCast =
163         rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
164 
165     rewriter.replaceOp(alloc, {resultCast});
166     return success();
167   }
168 };
169 
170 /// Fold alloc operations with no users or only store and dealloc uses.
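///
/// As a rough sketch (SSA names and types are illustrative), an allocation
/// like the following has no observable effect and is erased together with
/// its store and dealloc uses:
///
/// ```mlir
///   %0 = memref.alloc() : memref<4xf32>
///   memref.store %v, %0[%i] : memref<4xf32>
///   memref.dealloc %0 : memref<4xf32>
/// ```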
171 template <typename T>
172 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
173   using OpRewritePattern<T>::OpRewritePattern;
174 
175   LogicalResult matchAndRewrite(T alloc,
176                                 PatternRewriter &rewriter) const override {
177     if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
178           if (auto storeOp = dyn_cast<StoreOp>(op))
179             return storeOp.value() == alloc;
180           return !isa<DeallocOp>(op);
181         }))
182       return failure();
183 
184     for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
185       rewriter.eraseOp(user);
186 
187     rewriter.eraseOp(alloc);
188     return success();
189   }
190 };
191 } // end anonymous namespace.
192 
193 Optional<Operation *> AllocOp::buildDealloc(OpBuilder &builder, Value alloc) {
194   return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
195       .getOperation();
196 }
197 
198 Optional<Value> AllocOp::buildClone(OpBuilder &builder, Value alloc) {
199   return builder.create<memref::CloneOp>(alloc.getLoc(), alloc).getResult();
200 }
201 
202 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
203                                           MLIRContext *context) {
204   results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
205 }
206 
207 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
208                                            MLIRContext *context) {
209   results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
210       context);
211 }
212 
213 //===----------------------------------------------------------------------===//
214 // AllocaScopeOp
215 //===----------------------------------------------------------------------===//
216 
217 static void print(OpAsmPrinter &p, AllocaScopeOp &op) {
218   bool printBlockTerminators = false;
219 
220   p << " ";
221   if (!op.results().empty()) {
222     p << " -> (" << op.getResultTypes() << ")";
223     printBlockTerminators = true;
224   }
225   p.printRegion(op.bodyRegion(),
226                 /*printEntryBlockArgs=*/false,
227                 /*printBlockTerminators=*/printBlockTerminators);
228   p.printOptionalAttrDict(op->getAttrs());
229 }
230 
231 static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
232                                       OperationState &result) {
233   // Create a region for the body.
234   result.regions.reserve(1);
235   Region *bodyRegion = result.addRegion();
236 
237   // Parse optional results type list.
238   if (parser.parseOptionalArrowTypeList(result.types))
239     return failure();
240 
241   // Parse the body region.
242   if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
243     return failure();
244   AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
245                                   result.location);
246 
247   // Parse the optional attribute list.
248   if (parser.parseOptionalAttrDict(result.attributes))
249     return failure();
250 
251   return success();
252 }
253 
254 static LogicalResult verify(AllocaScopeOp op) {
255   if (failed(RegionBranchOpInterface::verifyTypes(op)))
256     return failure();
257 
258   return success();
259 }
260 
261 void AllocaScopeOp::getSuccessorRegions(
262     Optional<unsigned> index, ArrayRef<Attribute> operands,
263     SmallVectorImpl<RegionSuccessor> &regions) {
264   if (index.hasValue()) {
265     regions.push_back(RegionSuccessor(getResults()));
266     return;
267   }
268 
269   regions.push_back(RegionSuccessor(&bodyRegion()));
270 }
271 
272 //===----------------------------------------------------------------------===//
273 // AssumeAlignmentOp
274 //===----------------------------------------------------------------------===//
275 
276 static LogicalResult verify(AssumeAlignmentOp op) {
277   unsigned alignment = op.alignment();
278   if (!llvm::isPowerOf2_32(alignment))
279     return op.emitOpError("alignment must be power of 2");
280   return success();
281 }
282 
283 //===----------------------------------------------------------------------===//
284 // BufferCastOp
285 //===----------------------------------------------------------------------===//
286 
287 OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
288   if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
289     if (tensorLoad.memref().getType() == getType())
290       return tensorLoad.memref();
291   return {};
292 }
293 
294 namespace {
/// Replace tensor.cast + memref.buffer_cast by memref.buffer_cast +
/// memref.cast.
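///
/// As a rough sketch (SSA names and types are illustrative):
///
/// ```mlir
///   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %1 = memref.buffer_cast %0 : memref<?xf32>
/// ```
///
/// may rewrite to:
///
/// ```mlir
///   %2 = memref.buffer_cast %t : memref<4xf32>
///   %1 = memref.cast %2 : memref<4xf32> to memref<?xf32>
/// ```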
296 struct BufferCast : public OpRewritePattern<BufferCastOp> {
297   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
298 
299   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
300                                 PatternRewriter &rewriter) const final {
301     auto tensorCastOperand =
302         bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
303     if (!tensorCastOperand)
304       return failure();
305     auto srcTensorType =
306         tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
307     if (!srcTensorType)
308       return failure();
309     auto memrefType = MemRefType::get(srcTensorType.getShape(),
310                                       srcTensorType.getElementType());
311     Value memref = rewriter.create<BufferCastOp>(
312         bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
313     rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
314                                         memref);
315     return success();
316   }
317 };
318 
/// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
/// type mismatches prevent `BufferCastOp::fold` from kicking in.
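///
/// As a rough sketch for the guaranteed cast-compatible case (SSA names and
/// types are illustrative):
///
/// ```mlir
///   %0 = memref.tensor_load %m : memref<4xf32>
///   %1 = memref.buffer_cast %0 : memref<?xf32>
/// ```
///
/// may rewrite to:
///
/// ```mlir
///   %1 = memref.cast %m : memref<4xf32> to memref<?xf32>
/// ```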
321 struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
322   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
323 
324   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
325                                 PatternRewriter &rewriter) const final {
326     auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
327     // Bail unless we have a tensor_load + memref.buffer_cast with different
328     // types. `BufferCastOp::fold` handles the same type case.
329     if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
330       return failure();
331     // If types are definitely not cast-compatible, bail.
332     if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
333                                    bufferCast.getType()))
334       return failure();
335 
    // We already know that the types are potentially cast-compatible. However,
    // if the affine maps are different, we may need to use a copy if we go
    // from a dynamic to a static offset or stride (the canonicalization cannot
    // know at this point that the cast is really compatible).
340     auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
341       int64_t sourceOffset, targetOffset;
342       SmallVector<int64_t, 4> sourceStrides, targetStrides;
343       if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
344           failed(getStridesAndOffset(target, targetStrides, targetOffset)))
345         return false;
346       auto dynamicToStatic = [](int64_t a, int64_t b) {
347         return a == MemRefType::getDynamicStrideOrOffset() &&
348                b != MemRefType::getDynamicStrideOrOffset();
349       };
350       if (dynamicToStatic(sourceOffset, targetOffset))
351         return false;
352       for (auto it : zip(sourceStrides, targetStrides))
353         if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
354           return false;
355       return true;
356     };
357 
358     auto tensorLoadType = tensorLoad.memref().getType().dyn_cast<MemRefType>();
359     auto bufferCastType = bufferCast.getType().dyn_cast<MemRefType>();
360     if (tensorLoadType && bufferCastType &&
361         !isGuaranteedCastCompatible(tensorLoadType, bufferCastType)) {
362       MemRefType resultType = bufferCastType;
363       auto loc = bufferCast.getLoc();
364       SmallVector<Value, 4> dynamicOperands;
365       for (int i = 0; i < resultType.getRank(); ++i) {
366         if (resultType.getShape()[i] != ShapedType::kDynamicSize)
367           continue;
368         auto index = rewriter.createOrFold<ConstantIndexOp>(loc, i);
369         Value size = rewriter.create<tensor::DimOp>(loc, tensorLoad, index);
370         dynamicOperands.push_back(size);
371       }
372       auto copy =
373           rewriter.create<memref::AllocOp>(loc, resultType, dynamicOperands);
374       rewriter.create<CopyOp>(loc, tensorLoad.memref(), copy);
375       rewriter.replaceOp(bufferCast, {copy});
376     } else
377       rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
378                                           tensorLoad.memref());
379     return success();
380   }
381 };
382 
383 } // namespace
384 
385 void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
386                                                MLIRContext *context) {
387   results.add<BufferCast, TensorLoadToMemRef>(context);
388 }
389 
390 //===----------------------------------------------------------------------===//
391 // CastOp
392 //===----------------------------------------------------------------------===//
393 
394 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
396 /// and implement canonicalization patterns for ops in different dialects that
397 /// may consume the results of memref.cast operations. Such foldable memref.cast
398 /// operations are typically inserted as `view` and `subview` ops are
399 /// canonicalized, to preserve the type compatibility of their uses.
400 ///
401 /// Returns true when all conditions are met:
402 /// 1. source and result are ranked memrefs with strided semantics and same
403 /// element type and rank.
/// 2. each of the source's sizes, offset and strides has at least as much
/// static information as the corresponding size, offset or stride of the
/// result, i.e. the cast never converts a dynamic value into a static one.
406 ///
407 /// Example 1:
408 /// ```mlir
409 ///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
410 ///   %2 = consumer %1 ... : memref<?x?xf32> ...
411 /// ```
412 ///
413 /// may fold into:
414 ///
415 /// ```mlir
416 ///   %2 = consumer %0 ... : memref<8x16xf32> ...
417 /// ```
418 ///
419 /// Example 2:
420 /// ```
421 ///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
422 ///          to memref<?x?xf32>
423 ///   consumer %1 : memref<?x?xf32> ...
424 /// ```
425 ///
426 /// may fold into:
427 ///
428 /// ```
429 ///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
430 /// ```
431 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
432   MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
433   MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
434 
435   // Requires ranked MemRefType.
436   if (!sourceType || !resultType)
437     return false;
438 
439   // Requires same elemental type.
440   if (sourceType.getElementType() != resultType.getElementType())
441     return false;
442 
443   // Requires same rank.
444   if (sourceType.getRank() != resultType.getRank())
445     return false;
446 
447   // Only fold casts between strided memref forms.
448   int64_t sourceOffset, resultOffset;
449   SmallVector<int64_t, 4> sourceStrides, resultStrides;
450   if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
451       failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
452     return false;
453 
454   // If cast is towards more static sizes along any dimension, don't fold.
455   for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
456     auto ss = std::get<0>(it), st = std::get<1>(it);
457     if (ss != st)
458       if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
459         return false;
460   }
461 
462   // If cast is towards more static offset along any dimension, don't fold.
463   if (sourceOffset != resultOffset)
464     if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
465         !MemRefType::isDynamicStrideOrOffset(resultOffset))
466       return false;
467 
468   // If cast is towards more static strides along any dimension, don't fold.
469   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
470     auto ss = std::get<0>(it), st = std::get<1>(it);
471     if (ss != st)
472       if (MemRefType::isDynamicStrideOrOffset(ss) &&
473           !MemRefType::isDynamicStrideOrOffset(st))
474         return false;
475   }
476 
477   return true;
478 }
479 
480 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
481   if (inputs.size() != 1 || outputs.size() != 1)
482     return false;
483   Type a = inputs.front(), b = outputs.front();
484   auto aT = a.dyn_cast<MemRefType>();
485   auto bT = b.dyn_cast<MemRefType>();
486 
487   auto uaT = a.dyn_cast<UnrankedMemRefType>();
488   auto ubT = b.dyn_cast<UnrankedMemRefType>();
489 
490   if (aT && bT) {
491     if (aT.getElementType() != bT.getElementType())
492       return false;
493     if (aT.getAffineMaps() != bT.getAffineMaps()) {
494       int64_t aOffset, bOffset;
495       SmallVector<int64_t, 4> aStrides, bStrides;
496       if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
497           failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
498           aStrides.size() != bStrides.size())
499         return false;
500 
      // A stride or the offset is compatible if both values are static and
      // equal, or if either one is dynamic (see the description of memref.cast
      // for details).
505       auto checkCompatible = [](int64_t a, int64_t b) {
506         return (a == MemRefType::getDynamicStrideOrOffset() ||
507                 b == MemRefType::getDynamicStrideOrOffset() || a == b);
508       };
509       if (!checkCompatible(aOffset, bOffset))
510         return false;
511       for (auto aStride : enumerate(aStrides))
512         if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
513           return false;
514     }
515     if (aT.getMemorySpace() != bT.getMemorySpace())
516       return false;
517 
518     // They must have the same rank, and any specified dimensions must match.
519     if (aT.getRank() != bT.getRank())
520       return false;
521 
522     for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
523       int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
524       if (aDim != -1 && bDim != -1 && aDim != bDim)
525         return false;
526     }
527     return true;
528   } else {
529     if (!aT && !uaT)
530       return false;
531     if (!bT && !ubT)
532       return false;
533     // Unranked to unranked casting is unsupported
534     if (uaT && ubT)
535       return false;
536 
537     auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
538     auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
539     if (aEltType != bEltType)
540       return false;
541 
542     auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
543     auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
544     if (aMemSpace != bMemSpace)
545       return false;
546 
547     return true;
548   }
549 
550   return false;
551 }
552 
553 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
554   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
555 }
556 
557 //===----------------------------------------------------------------------===//
558 // CloneOp
559 //===----------------------------------------------------------------------===//
560 
561 void CloneOp::getEffects(
562     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
563         &effects) {
564   effects.emplace_back(MemoryEffects::Read::get(), input(),
565                        SideEffects::DefaultResource::get());
566   effects.emplace_back(MemoryEffects::Write::get(), output(),
567                        SideEffects::DefaultResource::get());
568   effects.emplace_back(MemoryEffects::Allocate::get(), output(),
569                        SideEffects::DefaultResource::get());
570 }
571 
572 namespace {
573 /// Merge the clone and its source (by converting the clone to a cast) when
574 /// possible.
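///
/// As a rough sketch (SSA names and types are illustrative, and the source is
/// assumed not to be deallocated in the same block):
///
/// ```mlir
///   %1 = memref.clone %0 : memref<?xf32> to memref<?xf32>
///   ...
///   memref.dealloc %1 : memref<?xf32>
/// ```
///
/// may rewrite to:
///
/// ```mlir
///   %1 = memref.cast %0 : memref<?xf32> to memref<?xf32>
///   ...
/// ```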
575 struct SimplifyClones : public OpRewritePattern<CloneOp> {
576   using OpRewritePattern<CloneOp>::OpRewritePattern;
577 
578   LogicalResult matchAndRewrite(CloneOp cloneOp,
579                                 PatternRewriter &rewriter) const override {
580     if (cloneOp.use_empty()) {
581       rewriter.eraseOp(cloneOp);
582       return success();
583     }
584 
585     Value source = cloneOp.input();
586 
587     // This only finds dealloc operations for the immediate value. It should
588     // also consider aliases. That would also make the safety check below
589     // redundant.
590     llvm::Optional<Operation *> maybeCloneDeallocOp =
591         findDealloc(cloneOp.output());
    // Skip if either of them has more than one deallocate operation.
593     if (!maybeCloneDeallocOp.hasValue())
594       return failure();
595     llvm::Optional<Operation *> maybeSourceDeallocOp = findDealloc(source);
596     if (!maybeSourceDeallocOp.hasValue())
597       return failure();
598     Operation *cloneDeallocOp = *maybeCloneDeallocOp;
599     Operation *sourceDeallocOp = *maybeSourceDeallocOp;
600 
601     // If both are deallocated in the same block, their in-block lifetimes
602     // might not fully overlap, so we cannot decide which one to drop.
603     if (cloneDeallocOp && sourceDeallocOp &&
604         cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
605       return failure();
606 
607     Block *currentBlock = cloneOp->getBlock();
608     Operation *redundantDealloc = nullptr;
609     if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
610       redundantDealloc = cloneDeallocOp;
611     } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
612       redundantDealloc = sourceDeallocOp;
613     }
614 
615     if (!redundantDealloc)
616       return failure();
617 
    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of the
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
623     for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
624          pos = pos->getNextNode()) {
625       auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
626       if (!effectInterface)
627         continue;
628       if (effectInterface.hasEffect<MemoryEffects::Free>())
629         return failure();
630     }
631 
632     rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
633                                                 source);
634     rewriter.eraseOp(redundantDealloc);
635     return success();
636   }
637 };
638 
639 } // end anonymous namespace.
640 
void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
644 }
645 
646 OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
647   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
648 }
649 
650 Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
651   return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
652       .getOperation();
653 }
654 
655 Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
656   return builder.create<memref::CloneOp>(alloc.getLoc(), alloc).getResult();
657 }
658 
659 //===----------------------------------------------------------------------===//
660 // DeallocOp
661 //===----------------------------------------------------------------------===//
662 
663 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
664                               SmallVectorImpl<OpFoldResult> &results) {
665   /// dealloc(memrefcast) -> dealloc
666   return foldMemRefCast(*this);
667 }
668 
669 //===----------------------------------------------------------------------===//
670 // DimOp
671 //===----------------------------------------------------------------------===//
672 
673 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
674                   int64_t index) {
675   auto loc = result.location;
676   Value indexValue = builder.create<ConstantIndexOp>(loc, index);
677   build(builder, result, source, indexValue);
678 }
679 
680 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
681                   Value index) {
682   auto indexTy = builder.getIndexType();
683   build(builder, result, indexTy, source, index);
684 }
685 
686 Optional<int64_t> DimOp::getConstantIndex() {
687   if (auto constantOp = index().getDefiningOp<ConstantOp>())
688     return constantOp.getValue().cast<IntegerAttr>().getInt();
689   return {};
690 }
691 
692 static LogicalResult verify(DimOp op) {
693   // Assume unknown index to be in range.
694   Optional<int64_t> index = op.getConstantIndex();
695   if (!index.hasValue())
696     return success();
697 
698   // Check that constant index is not knowingly out of range.
699   auto type = op.source().getType();
700   if (auto memrefType = type.dyn_cast<MemRefType>()) {
701     if (index.getValue() >= memrefType.getRank())
702       return op.emitOpError("index is out of range");
703   } else if (type.isa<UnrankedMemRefType>()) {
704     // Assume index to be in range.
705   } else {
706     llvm_unreachable("expected operand with memref type");
707   }
708   return success();
709 }
710 
/// Return a map whose keys are the elements of `vals` and whose values are the
/// number of occurrences of each element. Use std::map, since the `vals` here
/// are strides and the dynamic stride value is the same as the tombstone value
/// for `DenseMap<int64_t>`.
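///
/// For example, getNumOccurences({8, 1, 1}) maps 8 to 1 and 1 to 2.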
715 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
716   std::map<int64_t, unsigned> numOccurences;
717   for (auto val : vals)
718     numOccurences[val]++;
719   return numOccurences;
720 }
721 
/// Given the un-rank-reduced type of a subview result and the rank-reduced
/// result type, computes the dropped dimensions. This accounts for cases where
/// there are multiple unit-dims, but only a subset of those are dropped. For
/// MemRefTypes these can be disambiguated using the strides: if a dimension is
/// dropped, its stride must be dropped too.
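///
/// As a rough sketch (strides are illustrative), with
///   originalType = memref<4x1x8xf32> (strides [8, 8, 1]),
///   reducedType  = memref<4x8xf32>   (strides [8, 1]), and
///   staticSizes  = [4, 1, 8],
/// the returned set is {1}: the stride 8 occurs twice in the original type but
/// only once in the reduced type, so the unit dimension carrying that stride
/// is the one that was dropped.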
727 static llvm::Optional<llvm::SmallDenseSet<unsigned>>
728 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
729                                ArrayAttr staticSizes) {
730   llvm::SmallDenseSet<unsigned> unusedDims;
731   if (originalType.getRank() == reducedType.getRank())
732     return unusedDims;
733 
734   for (auto dim : llvm::enumerate(staticSizes))
735     if (dim.value().cast<IntegerAttr>().getInt() == 1)
736       unusedDims.insert(dim.index());
737   SmallVector<int64_t> originalStrides, candidateStrides;
738   int64_t originalOffset, candidateOffset;
739   if (failed(
740           getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
741       failed(
742           getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
743     return llvm::None;
744 
  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. This is particularly important when more than one of the
  // dims is 1. Track the number of occurrences of each stride in the original
  // type and in the candidate type; the stride of a dropped dim must not be
  // present in the candidate type. Note that multiple dimensions could have
  // the same size, so we don't need to figure out exactly which dim
  // corresponds to which stride; we just need to verify that the number of
  // repetitions of a stride in the candidate plus the number of dropped dims
  // with that stride equals the number of repetitions in the original.
754   std::map<int64_t, unsigned> currUnaccountedStrides =
755       getNumOccurences(originalStrides);
756   std::map<int64_t, unsigned> candidateStridesNumOccurences =
757       getNumOccurences(candidateStrides);
758   llvm::SmallDenseSet<unsigned> prunedUnusedDims;
759   for (unsigned dim : unusedDims) {
760     int64_t originalStride = originalStrides[dim];
761     if (currUnaccountedStrides[originalStride] >
762         candidateStridesNumOccurences[originalStride]) {
763       // This dim can be treated as dropped.
764       currUnaccountedStrides[originalStride]--;
765       continue;
766     }
767     if (currUnaccountedStrides[originalStride] ==
768         candidateStridesNumOccurences[originalStride]) {
769       // The stride for this is not dropped. Keep as is.
770       prunedUnusedDims.insert(dim);
771       continue;
772     }
773     if (currUnaccountedStrides[originalStride] <
774         candidateStridesNumOccurences[originalStride]) {
      // This should never happen. We can't have a stride in the rank-reduced
      // type that wasn't in the original one.
777       return llvm::None;
778     }
779   }
780 
781   for (auto prunedDim : prunedUnusedDims)
782     unusedDims.erase(prunedDim);
783   if (unusedDims.size() + reducedType.getRank() != originalType.getRank())
784     return llvm::None;
785   return unusedDims;
786 }
787 
788 llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() {
789   MemRefType sourceType = getSourceType();
790   MemRefType resultType = getType();
791   llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
792       computeMemRefRankReductionMask(sourceType, resultType, static_sizes());
793   assert(unusedDims && "unable to find unused dims of subview");
794   return *unusedDims;
795 }
796 
797 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
798   // All forms of folding require a known index.
799   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
800   if (!index)
801     return {};
802 
803   // Folding for unranked types (UnrankedMemRefType) is not supported.
804   auto memrefType = source().getType().dyn_cast<MemRefType>();
805   if (!memrefType)
806     return {};
807 
808   // Fold if the shape extent along the given index is known.
809   if (!memrefType.isDynamicDim(index.getInt())) {
810     Builder builder(getContext());
811     return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
812   }
813 
814   // The size at the given index is now known to be a dynamic size.
815   unsigned unsignedIndex = index.getValue().getZExtValue();
816 
817   // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
818   Operation *definingOp = source().getDefiningOp();
819 
820   if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
821     return *(alloc.getDynamicSizes().begin() +
822              memrefType.getDynamicDimIndex(unsignedIndex));
823 
824   if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
825     return *(alloca.getDynamicSizes().begin() +
826              memrefType.getDynamicDimIndex(unsignedIndex));
827 
828   if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
829     return *(view.getDynamicSizes().begin() +
830              memrefType.getDynamicDimIndex(unsignedIndex));
831 
832   if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
833     llvm::SmallDenseSet<unsigned> unusedDims = subview.getDroppedDims();
834     unsigned resultIndex = 0;
835     unsigned sourceRank = subview.getSourceType().getRank();
836     unsigned sourceIndex = 0;
837     for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
838       if (unusedDims.count(i))
839         continue;
840       if (resultIndex == unsignedIndex) {
841         sourceIndex = i;
842         break;
843       }
844       resultIndex++;
845     }
846     assert(subview.isDynamicSize(sourceIndex) &&
847            "expected dynamic subview size");
848     return subview.getDynamicSize(sourceIndex);
849   }
850 
851   if (auto sizeInterface =
852           dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
853     assert(sizeInterface.isDynamicSize(unsignedIndex) &&
854            "Expected dynamic subview size");
855     return sizeInterface.getDynamicSize(unsignedIndex);
856   }
857 
858   // dim(memrefcast) -> dim
859   if (succeeded(foldMemRefCast(*this)))
860     return getResult();
861 
862   return {};
863 }
864 
865 namespace {
/// Fold a dim of a memref.reshape operation into a load from the reshape's
/// shape operand.
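///
/// As a rough sketch (SSA names and types are illustrative):
///
/// ```mlir
///   %r = memref.reshape %src(%shape)
///        : (memref<?x?xf32>, memref<2xindex>) -> memref<?x?xf32>
///   %d = memref.dim %r, %c1 : memref<?x?xf32>
/// ```
///
/// may rewrite to:
///
/// ```mlir
///   %d = memref.load %shape[%c1] : memref<2xindex>
/// ```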
868 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
869   using OpRewritePattern<DimOp>::OpRewritePattern;
870 
871   LogicalResult matchAndRewrite(DimOp dim,
872                                 PatternRewriter &rewriter) const override {
873     auto reshape = dim.source().getDefiningOp<ReshapeOp>();
874 
875     if (!reshape)
876       return failure();
877 
878     // Place the load directly after the reshape to ensure that the shape memref
879     // was not mutated.
880     rewriter.setInsertionPointAfter(reshape);
881     Location loc = dim.getLoc();
882     Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
883     if (load.getType() != dim.getType())
884       load = rewriter.create<IndexCastOp>(loc, dim.getType(), load);
885     rewriter.replaceOp(dim, load);
886     return success();
887   }
888 };
889 
/// Fold a dim of a memref.buffer_cast into a tensor.dim of the cast's tensor
/// operand.
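///
/// As a rough sketch (SSA names and types are illustrative):
///
/// ```mlir
///   %m = memref.buffer_cast %t : memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>
/// ```
///
/// may rewrite to:
///
/// ```mlir
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
/// ```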
891 struct DimOfCastOp : public OpRewritePattern<DimOp> {
892   using OpRewritePattern<DimOp>::OpRewritePattern;
893 
894   LogicalResult matchAndRewrite(DimOp dimOp,
895                                 PatternRewriter &rewriter) const override {
896     auto castOp = dimOp.source().getDefiningOp<BufferCastOp>();
897     if (!castOp)
898       return failure();
899     Value newSource = castOp.getOperand();
900     rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
901     return success();
902   }
903 };
904 } // end anonymous namespace.
905 
906 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
907                                         MLIRContext *context) {
908   results.add<DimOfMemRefReshape, DimOfCastOp>(context);
909 }
910 
911 // ---------------------------------------------------------------------------
912 // DmaStartOp
913 // ---------------------------------------------------------------------------
914 
915 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
916                        Value srcMemRef, ValueRange srcIndices, Value destMemRef,
917                        ValueRange destIndices, Value numElements,
918                        Value tagMemRef, ValueRange tagIndices, Value stride,
919                        Value elementsPerStride) {
920   result.addOperands(srcMemRef);
921   result.addOperands(srcIndices);
922   result.addOperands(destMemRef);
923   result.addOperands(destIndices);
924   result.addOperands({numElements, tagMemRef});
925   result.addOperands(tagIndices);
926   if (stride)
927     result.addOperands({stride, elementsPerStride});
928 }
929 
930 static void print(OpAsmPrinter &p, DmaStartOp op) {
931   p << " " << op.getSrcMemRef() << '[' << op.getSrcIndices() << "], "
932     << op.getDstMemRef() << '[' << op.getDstIndices() << "], "
933     << op.getNumElements() << ", " << op.getTagMemRef() << '['
934     << op.getTagIndices() << ']';
935   if (op.isStrided())
936     p << ", " << op.getStride() << ", " << op.getNumElementsPerStride();
937 
938   p.printOptionalAttrDict(op->getAttrs());
939   p << " : " << op.getSrcMemRef().getType() << ", "
940     << op.getDstMemRef().getType() << ", " << op.getTagMemRef().getType();
941 }
942 
943 // Parse DmaStartOp.
944 // Ex:
945 //   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
946 //                       %tag[%index], %stride, %num_elt_per_stride :
947 //                     : memref<3076 x f32, 0>,
948 //                       memref<1024 x f32, 2>,
949 //                       memref<1 x i32>
950 //
951 static ParseResult parseDmaStartOp(OpAsmParser &parser,
952                                    OperationState &result) {
953   OpAsmParser::OperandType srcMemRefInfo;
954   SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
955   OpAsmParser::OperandType dstMemRefInfo;
956   SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
957   OpAsmParser::OperandType numElementsInfo;
958   OpAsmParser::OperandType tagMemrefInfo;
959   SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
960   SmallVector<OpAsmParser::OperandType, 2> strideInfo;
961 
962   SmallVector<Type, 3> types;
963   auto indexType = parser.getBuilder().getIndexType();
964 
965   // Parse and resolve the following list of operands:
966   // *) source memref followed by its indices (in square brackets).
967   // *) destination memref followed by its indices (in square brackets).
  // *) number of elements to transfer.
969   if (parser.parseOperand(srcMemRefInfo) ||
970       parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
971       parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
972       parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
973       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
974       parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
975       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
976     return failure();
977 
978   // Parse optional stride and elements per stride.
979   if (parser.parseTrailingOperandList(strideInfo))
980     return failure();
981 
982   bool isStrided = strideInfo.size() == 2;
983   if (!strideInfo.empty() && !isStrided) {
984     return parser.emitError(parser.getNameLoc(),
985                             "expected two stride related operands");
986   }
987 
988   if (parser.parseColonTypeList(types))
989     return failure();
990   if (types.size() != 3)
991     return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
992 
993   if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
994       parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
995       parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
996       parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
997       // size should be an index.
998       parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
999       parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1000       // tag indices should be index.
1001       parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1002     return failure();
1003 
1004   if (isStrided) {
1005     if (parser.resolveOperands(strideInfo, indexType, result.operands))
1006       return failure();
1007   }
1008 
1009   return success();
1010 }
1011 
1012 static LogicalResult verify(DmaStartOp op) {
1013   unsigned numOperands = op.getNumOperands();
1014 
1015   // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1016   // the number of elements.
1017   if (numOperands < 4)
1018     return op.emitOpError("expected at least 4 operands");
1019 
1020   // Check types of operands. The order of these calls is important: the later
1021   // calls rely on some type properties to compute the operand position.
1022   // 1. Source memref.
1023   if (!op.getSrcMemRef().getType().isa<MemRefType>())
1024     return op.emitOpError("expected source to be of memref type");
1025   if (numOperands < op.getSrcMemRefRank() + 4)
1026     return op.emitOpError()
1027            << "expected at least " << op.getSrcMemRefRank() + 4 << " operands";
1028   if (!op.getSrcIndices().empty() &&
1029       !llvm::all_of(op.getSrcIndices().getTypes(),
1030                     [](Type t) { return t.isIndex(); }))
1031     return op.emitOpError("expected source indices to be of index type");
1032 
1033   // 2. Destination memref.
1034   if (!op.getDstMemRef().getType().isa<MemRefType>())
1035     return op.emitOpError("expected destination to be of memref type");
1036   unsigned numExpectedOperands =
1037       op.getSrcMemRefRank() + op.getDstMemRefRank() + 4;
1038   if (numOperands < numExpectedOperands)
1039     return op.emitOpError()
1040            << "expected at least " << numExpectedOperands << " operands";
1041   if (!op.getDstIndices().empty() &&
1042       !llvm::all_of(op.getDstIndices().getTypes(),
1043                     [](Type t) { return t.isIndex(); }))
1044     return op.emitOpError("expected destination indices to be of index type");
1045 
1046   // 3. Number of elements.
1047   if (!op.getNumElements().getType().isIndex())
1048     return op.emitOpError("expected num elements to be of index type");
1049 
1050   // 4. Tag memref.
1051   if (!op.getTagMemRef().getType().isa<MemRefType>())
1052     return op.emitOpError("expected tag to be of memref type");
1053   numExpectedOperands += op.getTagMemRefRank();
1054   if (numOperands < numExpectedOperands)
1055     return op.emitOpError()
1056            << "expected at least " << numExpectedOperands << " operands";
1057   if (!op.getTagIndices().empty() &&
1058       !llvm::all_of(op.getTagIndices().getTypes(),
1059                     [](Type t) { return t.isIndex(); }))
1060     return op.emitOpError("expected tag indices to be of index type");
1061 
1062   // Optional stride-related operands must be either both present or both
1063   // absent.
1064   if (numOperands != numExpectedOperands &&
1065       numOperands != numExpectedOperands + 2)
1066     return op.emitOpError("incorrect number of operands");
1067 
1068   // 5. Strides.
1069   if (op.isStrided()) {
1070     if (!op.getStride().getType().isIndex() ||
1071         !op.getNumElementsPerStride().getType().isIndex())
1072       return op.emitOpError(
1073           "expected stride and num elements per stride to be of type index");
1074   }
1075 
1076   return success();
1077 }
1078 
1079 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1080                                SmallVectorImpl<OpFoldResult> &results) {
1081   /// dma_start(memrefcast) -> dma_start
1082   return foldMemRefCast(*this);
1083 }
1084 
1085 // ---------------------------------------------------------------------------
1086 // DmaWaitOp
1087 // ---------------------------------------------------------------------------
1088 
1089 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1090                               SmallVectorImpl<OpFoldResult> &results) {
1091   /// dma_wait(memrefcast) -> dma_wait
1092   return foldMemRefCast(*this);
1093 }
1094 
1095 static LogicalResult verify(DmaWaitOp op) {
1096   // Check that the number of tag indices matches the tagMemRef rank.
1097   unsigned numTagIndices = op.tagIndices().size();
1098   unsigned tagMemRefRank = op.getTagMemRefRank();
1099   if (numTagIndices != tagMemRefRank)
1100     return op.emitOpError() << "expected tagIndices to have the same number of "
1101                                "elements as the tagMemRef rank, expected "
1102                             << tagMemRefRank << ", but got " << numTagIndices;
1103   return success();
1104 }
1105 
1106 //===----------------------------------------------------------------------===//
1107 // GlobalOp
1108 //===----------------------------------------------------------------------===//
1109 
1110 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1111                                                    TypeAttr type,
1112                                                    Attribute initialValue) {
1113   p << type;
1114   if (!op.isExternal()) {
1115     p << " = ";
1116     if (op.isUninitialized())
1117       p << "uninitialized";
1118     else
1119       p.printAttributeWithoutType(initialValue);
1120   }
1121 }
1122 
1123 static ParseResult
1124 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1125                                        Attribute &initialValue) {
1126   Type type;
1127   if (parser.parseType(type))
1128     return failure();
1129 
1130   auto memrefType = type.dyn_cast<MemRefType>();
1131   if (!memrefType || !memrefType.hasStaticShape())
1132     return parser.emitError(parser.getNameLoc())
1133            << "type should be static shaped memref, but got " << type;
1134   typeAttr = TypeAttr::get(type);
1135 
1136   if (parser.parseOptionalEqual())
1137     return success();
1138 
1139   if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1140     initialValue = UnitAttr::get(parser.getContext());
1141     return success();
1142   }
1143 
1144   Type tensorType = getTensorTypeFromMemRefType(memrefType);
1145   if (parser.parseAttribute(initialValue, tensorType))
1146     return failure();
1147   if (!initialValue.isa<ElementsAttr>())
1148     return parser.emitError(parser.getNameLoc())
1149            << "initial value should be a unit or elements attribute";
1150   return success();
1151 }
1152 
1153 static LogicalResult verify(GlobalOp op) {
1154   auto memrefType = op.type().dyn_cast<MemRefType>();
1155   if (!memrefType || !memrefType.hasStaticShape())
1156     return op.emitOpError("type should be static shaped memref, but got ")
1157            << op.type();
1158 
1159   // Verify that the initial value, if present, is either a unit attribute or
1160   // an elements attribute.
1161   if (op.initial_value().hasValue()) {
1162     Attribute initValue = op.initial_value().getValue();
1163     if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
1164       return op.emitOpError("initial value should be a unit or elements "
1165                             "attribute, but got ")
1166              << initValue;
1167 
1168     // Check that the type of the initial value is compatible with the type of
1169     // the global variable.
1170     if (initValue.isa<ElementsAttr>()) {
1171       Type initType = initValue.getType();
1172       Type tensorType = getTensorTypeFromMemRefType(memrefType);
1173       if (initType != tensorType)
1174         return op.emitOpError("initial value expected to be of type ")
1175                << tensorType << ", but was of type " << initType;
1176     }
1177   }
1178 
1179   // TODO: verify visibility for declarations.
1180   return success();
1181 }
1182 
1183 //===----------------------------------------------------------------------===//
1184 // GetGlobalOp
1185 //===----------------------------------------------------------------------===//
1186 
1187 LogicalResult
1188 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is the same as the type of the referenced
  // memref.global op.
1191   auto global =
1192       symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
1193   if (!global)
1194     return emitOpError("'")
1195            << name() << "' does not reference a valid global memref";
1196 
1197   Type resultType = result().getType();
1198   if (global.type() != resultType)
1199     return emitOpError("result type ")
1200            << resultType << " does not match type " << global.type()
1201            << " of the global memref @" << name();
1202   return success();
1203 }
1204 
1205 //===----------------------------------------------------------------------===//
1206 // LoadOp
1207 //===----------------------------------------------------------------------===//
1208 
1209 static LogicalResult verify(LoadOp op) {
1210   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1211     return op.emitOpError("incorrect number of indices for load");
1212   return success();
1213 }
1214 
1215 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1216   /// load(memrefcast) -> load
1217   if (succeeded(foldMemRefCast(*this)))
1218     return getResult();
1219   return OpFoldResult();
1220 }
1221 
1222 namespace {
/// Fold a load on a buffer_cast operation into a tensor.extract on the
/// corresponding tensor.
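///
/// As a rough sketch (SSA names and types are illustrative):
///
/// ```mlir
///   %m = memref.buffer_cast %t : memref<?xf32>
///   %v = memref.load %m[%i] : memref<?xf32>
/// ```
///
/// may rewrite to:
///
/// ```mlir
///   %v = tensor.extract %t[%i] : tensor<?xf32>
/// ```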
1225 struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
1226   using OpRewritePattern<LoadOp>::OpRewritePattern;
1227 
1228   LogicalResult matchAndRewrite(LoadOp load,
1229                                 PatternRewriter &rewriter) const override {
1230     auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
1231     if (!buffercast)
1232       return failure();
1233 
1234     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
1235                                                    load.indices());
1236     return success();
1237   }
1238 };
1239 } // end anonymous namespace.
1240 
1241 void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
1242                                          MLIRContext *context) {
1243   results.add<LoadOfBufferCast>(context);
1244 }
1245 
1246 //===----------------------------------------------------------------------===//
1247 // PrefetchOp
1248 //===----------------------------------------------------------------------===//
1249 
1250 static void print(OpAsmPrinter &p, PrefetchOp op) {
1251   p << " " << op.memref() << '[';
1252   p.printOperands(op.indices());
1253   p << ']' << ", " << (op.isWrite() ? "write" : "read");
1254   p << ", locality<" << op.localityHint();
1255   p << ">, " << (op.isDataCache() ? "data" : "instr");
1256   p.printOptionalAttrDict(
1257       op->getAttrs(),
1258       /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1259   p << " : " << op.getMemRefType();
1260 }
1261 
1262 static ParseResult parsePrefetchOp(OpAsmParser &parser,
1263                                    OperationState &result) {
1264   OpAsmParser::OperandType memrefInfo;
1265   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
1266   IntegerAttr localityHint;
1267   MemRefType type;
1268   StringRef readOrWrite, cacheType;
1269 
1270   auto indexTy = parser.getBuilder().getIndexType();
1271   auto i32Type = parser.getBuilder().getIntegerType(32);
1272   if (parser.parseOperand(memrefInfo) ||
1273       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1274       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1275       parser.parseComma() || parser.parseKeyword("locality") ||
1276       parser.parseLess() ||
1277       parser.parseAttribute(localityHint, i32Type, "localityHint",
1278                             result.attributes) ||
1279       parser.parseGreater() || parser.parseComma() ||
1280       parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1281       parser.resolveOperand(memrefInfo, type, result.operands) ||
1282       parser.resolveOperands(indexInfo, indexTy, result.operands))
1283     return failure();
1284 
1285   if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1286     return parser.emitError(parser.getNameLoc(),
1287                             "rw specifier has to be 'read' or 'write'");
1288   result.addAttribute(
1289       PrefetchOp::getIsWriteAttrName(),
1290       parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1291 
1292   if (!cacheType.equals("data") && !cacheType.equals("instr"))
1293     return parser.emitError(parser.getNameLoc(),
1294                             "cache type has to be 'data' or 'instr'");
1295 
1296   result.addAttribute(
1297       PrefetchOp::getIsDataCacheAttrName(),
1298       parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1299 
1300   return success();
1301 }
1302 
1303 static LogicalResult verify(PrefetchOp op) {
1304   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1305     return op.emitOpError("too few indices");
1306 
1307   return success();
1308 }
1309 
1310 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
1311                                SmallVectorImpl<OpFoldResult> &results) {
1312   // prefetch(memrefcast) -> prefetch
1313   return foldMemRefCast(*this);
1314 }
1315 
1316 //===----------------------------------------------------------------------===//
1317 // ReinterpretCastOp
1318 //===----------------------------------------------------------------------===//
1319 
/// Build a ReinterpretCastOp from mixed static and dynamic entries:
/// `staticOffsets`, `staticSizes` and `staticStrides` are automatically filled
/// with sentinel values in the positions that correspond to dynamic entries.
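///
/// As a rough sketch, the built operation prints along these lines (values and
/// types are illustrative):
///
/// ```mlir
///   %dst = memref.reinterpret_cast %src to
///            offset: [%o], sizes: [10, %s1], strides: [%st0, 1]
///          : memref<?x?xf32> to memref<10x?xf32, offset: ?, strides: [?, 1]>
/// ```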
1323 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1324                               MemRefType resultType, Value source,
1325                               OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1326                               ArrayRef<OpFoldResult> strides,
1327                               ArrayRef<NamedAttribute> attrs) {
1328   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1329   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1330   dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
1331                              ShapedType::kDynamicStrideOrOffset);
1332   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1333                              ShapedType::kDynamicSize);
1334   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1335                              ShapedType::kDynamicStrideOrOffset);
1336   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1337         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1338         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1339   result.addAttributes(attrs);
1340 }
1341 
1342 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1343                               MemRefType resultType, Value source,
1344                               int64_t offset, ArrayRef<int64_t> sizes,
1345                               ArrayRef<int64_t> strides,
1346                               ArrayRef<NamedAttribute> attrs) {
1347   SmallVector<OpFoldResult> sizeValues =
1348       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1349         return b.getI64IntegerAttr(v);
1350       }));
1351   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1352       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1353         return b.getI64IntegerAttr(v);
1354       }));
1355   build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1356         strideValues, attrs);
1357 }
1358 
1359 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1360                               MemRefType resultType, Value source, Value offset,
1361                               ValueRange sizes, ValueRange strides,
1362                               ArrayRef<NamedAttribute> attrs) {
1363   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1364       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1365   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1366       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1367   build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1368 }
1369 
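// For orientation when reading the verifier below, a memref.reinterpret_cast
// built by the overloads above typically prints roughly as follows (names and
// shapes are illustrative only):
//
//   %dst = memref.reinterpret_cast %src to
//     offset: [0], sizes: [%size0, 10], strides: [1, %stride1]
//     : memref<?x?xf32> to memref<?x10xf32, offset: 0, strides: [1, ?]>
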
1370 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1371 // completed automatically, like we have for subview and extract_slice.
1372 static LogicalResult verify(ReinterpretCastOp op) {
1373   // The source and result memrefs should be in the same memory space.
1374   auto srcType = op.source().getType().cast<BaseMemRefType>();
1375   auto resultType = op.getType().cast<MemRefType>();
1376   if (srcType.getMemorySpace() != resultType.getMemorySpace())
1377     return op.emitError("different memory spaces specified for source type ")
1378            << srcType << " and result memref type " << resultType;
1379   if (srcType.getElementType() != resultType.getElementType())
1380     return op.emitError("different element types specified for source type ")
1381            << srcType << " and result memref type " << resultType;
1382 
1383   // Match sizes in result memref type and in static_sizes attribute.
1384   for (auto &en :
1385        llvm::enumerate(llvm::zip(resultType.getShape(),
1386                                  extractFromI64ArrayAttr(op.static_sizes())))) {
1387     int64_t resultSize = std::get<0>(en.value());
1388     int64_t expectedSize = std::get<1>(en.value());
1389     if (resultSize != expectedSize)
1390       return op.emitError("expected result type with size = ")
1391              << expectedSize << " instead of " << resultSize
1392              << " in dim = " << en.index();
1393   }
1394 
1395   // Match offset and strides in static_offset and static_strides attributes if
1396   // result memref type has an affine map specified.
1397   if (!resultType.getAffineMaps().empty()) {
1398     int64_t resultOffset;
1399     SmallVector<int64_t, 4> resultStrides;
1400     if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1401       return failure();
1402 
1403     // Match offset in result memref type and in static_offsets attribute.
1404     int64_t expectedOffset =
1405         extractFromI64ArrayAttr(op.static_offsets()).front();
1406     if (resultOffset != expectedOffset)
1407       return op.emitError("expected result type with offset = ")
1408              << expectedOffset << " instead of " << resultOffset;
1409 
1410     // Match strides in result memref type and in static_strides attribute.
1411     for (auto &en : llvm::enumerate(llvm::zip(
1412              resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
1413       int64_t resultStride = std::get<0>(en.value());
1414       int64_t expectedStride = std::get<1>(en.value());
1415       if (resultStride != expectedStride)
1416         return op.emitError("expected result type with stride = ")
1417                << expectedStride << " instead of " << resultStride
1418                << " in dim = " << en.index();
1419     }
1420   }
1421   return success();
1422 }
1423 
1424 //===----------------------------------------------------------------------===//
1425 // Reassociative reshape ops
1426 //===----------------------------------------------------------------------===//
1427 
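// For orientation, the reassociative reshape ops handled in this section look
// roughly as follows (shapes are illustrative only):
//
//   %c = memref.collapse_shape %a [[0, 1], [2]]
//       : memref<4x5x6xf32> into memref<20x6xf32>
//   %e = memref.expand_shape %c [[0, 1], [2]]
//       : memref<20x6xf32> into memref<4x5x6xf32>
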
1428 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1429   return getSymbolLessAffineMaps(getReassociationExprs());
1430 }
1431 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1432   return convertReassociationIndicesToExprs(getContext(),
1433                                             getReassociationIndices());
1434 }
1435 
1436 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1437   return getSymbolLessAffineMaps(getReassociationExprs());
1438 }
1439 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1440   return convertReassociationIndicesToExprs(getContext(),
1441                                             getReassociationIndices());
1442 }
1443 
1444 static void print(OpAsmPrinter &p, ExpandShapeOp op) {
1445   ::mlir::printReshapeOp<ExpandShapeOp>(p, op);
1446 }
1447 
1448 static void print(OpAsmPrinter &p, CollapseShapeOp op) {
1449   ::mlir::printReshapeOp<CollapseShapeOp>(p, op);
1450 }
1451 
1452 /// Detect whether memref dims [dim, dim + extent) can be reshaped without
1453 /// copies.
1454 static bool isReshapableDimBand(unsigned dim, unsigned extent,
1455                                 ArrayRef<int64_t> sizes,
1456                                 ArrayRef<AffineExpr> strides) {
1457   assert(sizes.size() == strides.size() && "mismatched ranks");
1458   // Note: the loop below iterates with `idx + 1 < e` so that the accesses
1459   // to `sizes[idx + 1]` and `strides[idx + 1]` stay in bounds.
1460   for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) {
1461     // Only bands of static shapes are reshapable. This is due to the fact that
1462     // there is no relation between dynamic sizes and dynamic strides: we do not
1463     // have enough information to know whether a "-1" size corresponds to the
1464     // proper symbol in the AffineExpr of a stride.
1465     if (ShapedType::isDynamic(sizes[idx + 1]))
1466       return false;
1467     // TODO: Refine this by passing the proper nDims and nSymbols so we can
1468     // simplify on the fly and catch more reshapable cases.
1469     if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
1470       return false;
1471   }
1472   return true;
1473 }
1474 
1475 /// Compute the MemRefType obtained by applying the `reassociation` (which is
1476 /// expected to be valid) to `type`.
1477 /// If `type` is a contiguous MemRefType, this always produces a contiguous
1478 /// MemRefType.
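///
/// As an illustrative sketch: collapsing a contiguous memref<4x5x6xf32> with
/// reassociation [[0, 1], [2]] yields memref<20x6xf32>, since the sizes
/// within each contiguous band are multiplied together.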
1479 static MemRefType
1480 computeReshapeCollapsedType(MemRefType type,
1481                             ArrayRef<AffineMap> reassociation) {
1482   auto sizes = type.getShape();
1483   AffineExpr offset;
1484   SmallVector<AffineExpr, 4> strides;
1485   auto status = getStridesAndOffset(type, strides, offset);
1486   (void)status;
1487   assert(succeeded(status) && "expected strided memref");
1488 
1489   SmallVector<int64_t, 4> newSizes;
1490   newSizes.reserve(reassociation.size());
1491   SmallVector<AffineExpr, 4> newStrides;
1492   newStrides.reserve(reassociation.size());
1493 
1494   // Use the fact that reassociation is valid to simplify the logic: only use
1495   // each map's rank.
1496   assert(isReassociationValid(reassociation) && "invalid reassociation");
1497   unsigned currentDim = 0;
1498   for (AffineMap m : reassociation) {
1499     unsigned dim = m.getNumResults();
1500     int64_t size = 1;
1501     AffineExpr stride = strides[currentDim + dim - 1];
1502     if (!isReshapableDimBand(currentDim, dim, sizes, strides)) {
1503       size = ShapedType::kDynamicSize;
1504       stride = AffineExpr();
1505     } else {
1506       for (unsigned d = 0; d < dim; ++d)
1507         size *= sizes[currentDim + d];
1508     }
1509     newSizes.push_back(size);
1510     newStrides.push_back(stride);
1511     currentDim += dim;
1512   }
1513 
1514   // Early-exit: if `type` is contiguous, the result must be contiguous.
1515   if (canonicalizeStridedLayout(type).getAffineMaps().empty())
1516     return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({});
1517 
1518   // Convert back to int64_t because we don't have enough information to create
1519   // new strided layouts from AffineExpr only. This corresponds to a case where
1520   // copies may be necessary.
1521   int64_t intOffset = ShapedType::kDynamicStrideOrOffset;
1522   if (auto o = offset.dyn_cast<AffineConstantExpr>())
1523     intOffset = o.getValue();
1524   SmallVector<int64_t, 4> intStrides;
1525   intStrides.reserve(strides.size());
1526   for (auto stride : newStrides) {
1527     if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>())
1528       intStrides.push_back(cst.getValue());
1529     else
1530       intStrides.push_back(ShapedType::kDynamicStrideOrOffset);
1531   }
1532   auto layout =
1533       makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
1534   return canonicalizeStridedLayout(
1535       MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
1536 }
1537 
1538 void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1539                           ArrayRef<ReassociationIndices> reassociation,
1540                           ArrayRef<NamedAttribute> attrs) {
1541   auto memRefType = src.getType().cast<MemRefType>();
1542   auto resultType = computeReshapeCollapsedType(
1543       memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1544                       b.getContext(), reassociation)));
1545   build(b, result, resultType, src, attrs);
1546   result.addAttribute(getReassociationAttrName(),
1547                       getReassociationIndicesAttribute(b, reassociation));
1548 }
1549 
1550 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1551                             ArrayRef<ReassociationIndices> reassociation,
1552                             ArrayRef<NamedAttribute> attrs) {
1553   auto memRefType = src.getType().cast<MemRefType>();
1554   auto resultType = computeReshapeCollapsedType(
1555       memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1556                       b.getContext(), reassociation)));
1557   build(b, result, resultType, src, attrs);
1558   result.addAttribute(getReassociationAttrName(),
1559                       getReassociationIndicesAttribute(b, reassociation));
1560 }
1561 
1562 template <typename ReshapeOp,
1563           bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value>
1564 static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
1565                                      MemRefType collapsedType) {
1566   if (failed(
1567           verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
1568     return failure();
1569   auto maps = op.getReassociationMaps();
1570   MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
1571   if (collapsedType != expectedType)
1572     return op.emitOpError("expected collapsed type to be ")
1573            << expectedType << ", but got " << collapsedType;
1574   return success();
1575 }
1576 
1577 static LogicalResult verify(ExpandShapeOp op) {
1578   return verifyReshapeOp(op, op.getResultType(), op.getSrcType());
1579 }
1580 
1581 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1582                                                 MLIRContext *context) {
1583   results.add<CollapseReshapeOps<ExpandShapeOp>,
1584               CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
1585 }
1586 
1587 static LogicalResult verify(CollapseShapeOp op) {
1588   return verifyReshapeOp(op, op.getSrcType(), op.getResultType());
1589 }
1590 
1591 struct CollapseShapeOpMemRefCastFolder
1592     : public OpRewritePattern<CollapseShapeOp> {
1593 public:
1594   using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
1595 
1596   LogicalResult matchAndRewrite(CollapseShapeOp op,
1597                                 PatternRewriter &rewriter) const override {
1598     auto cast = op.getOperand().getDefiningOp<CastOp>();
1599     if (!cast)
1600       return failure();
1601 
1602     if (!CastOp::canFoldIntoConsumerOp(cast))
1603       return failure();
1604 
1605     Type newResultType = computeReshapeCollapsedType(
1606         cast.getOperand().getType().cast<MemRefType>(),
1607         op.getReassociationMaps());
1608 
1609     if (newResultType == op.getResultType()) {
1610       rewriter.updateRootInPlace(
1611           op, [&]() { op.srcMutable().assign(cast.source()); });
1612     } else {
1613       Value newOp = rewriter.create<CollapseShapeOp>(
1614           op->getLoc(), cast.source(), op.getReassociationIndices());
1615       rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
1616     }
1617     return success();
1618   }
1619 };
1620 
1621 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1622                                                   MLIRContext *context) {
1623   results.add<CollapseReshapeOps<CollapseShapeOp>,
1624               CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
1625               CollapseShapeOpMemRefCastFolder>(context);
1626 }
1627 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
1628   if (succeeded(foldMemRefCast(*this)))
1629     return getResult();
1630   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
1631 }
1632 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
1633   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
1634 }
1635 
1636 //===----------------------------------------------------------------------===//
1637 // ReshapeOp
1638 //===----------------------------------------------------------------------===//
1639 
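// For orientation when reading the verifier below, memref.reshape takes the
// target shape as a 1-D shape operand, roughly as follows (names and types
// are illustrative only):
//
//   %dst = memref.reshape %src(%shape)
//       : (memref<*xf32>, memref<1xindex>) -> memref<?xf32>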
1640 static LogicalResult verify(ReshapeOp op) {
1641   Type operandType = op.source().getType();
1642   Type resultType = op.result().getType();
1643 
1644   Type operandElementType = operandType.cast<ShapedType>().getElementType();
1645   Type resultElementType = resultType.cast<ShapedType>().getElementType();
1646   if (operandElementType != resultElementType)
1647     return op.emitOpError("element types of source and destination memref "
1648                           "types should be the same");
1649 
1650   if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
1651     if (!operandMemRefType.getAffineMaps().empty())
1652       return op.emitOpError(
1653           "source memref type should have identity affine map");
1654 
1655   int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
1656   auto resultMemRefType = resultType.dyn_cast<MemRefType>();
1657   if (resultMemRefType) {
1658     if (!resultMemRefType.getAffineMaps().empty())
1659       return op.emitOpError(
1660           "result memref type should have identity affine map");
1661     if (shapeSize == ShapedType::kDynamicSize)
1662       return op.emitOpError("cannot use shape operand with dynamic length to "
1663                             "reshape to statically-ranked memref type");
1664     if (shapeSize != resultMemRefType.getRank())
1665       return op.emitOpError(
1666           "length of shape operand differs from the result's memref rank");
1667   }
1668   return success();
1669 }
1670 
1671 //===----------------------------------------------------------------------===//
1672 // StoreOp
1673 //===----------------------------------------------------------------------===//
1674 
1675 static LogicalResult verify(StoreOp op) {
1676   if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
1677     return op.emitOpError("store index operand count not equal to memref rank");
1678 
1679   return success();
1680 }
1681 
1682 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
1683                             SmallVectorImpl<OpFoldResult> &results) {
1684   // store(memrefcast) -> store
1685   return foldMemRefCast(*this, getValueToStore());
1686 }
1687 
1688 //===----------------------------------------------------------------------===//
1689 // SubViewOp
1690 //===----------------------------------------------------------------------===//
1691 
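// For orientation, a subview with one dynamic offset looks roughly as follows
// (names and shapes are illustrative only):
//
//   %1 = memref.subview %0[%off, 4][4, 4][1, 1]
//       : memref<8x16xf32> to memref<4x4xf32, offset: ?, strides: [16, 1]>
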
1692 namespace {
1693 /// Helpers to write more idiomatic operations.
1694 namespace saturated_arith {
1695 struct Wrapper {
1696   explicit Wrapper(int64_t v) : v(v) {}
1697   operator int64_t() { return v; }
1698   int64_t v;
1699 };
1700 Wrapper operator+(Wrapper a, int64_t b) {
1701   if (ShapedType::isDynamicStrideOrOffset(a) ||
1702       ShapedType::isDynamicStrideOrOffset(b))
1703     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1704   return Wrapper(a.v + b);
1705 }
1706 Wrapper operator*(Wrapper a, int64_t b) {
1707   if (ShapedType::isDynamicStrideOrOffset(a) ||
1708       ShapedType::isDynamicStrideOrOffset(b))
1709     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1710   return Wrapper(a.v * b);
1711 }
1712 } // end namespace saturated_arith
1713 } // end namespace
1714 
1715 /// A subview result type can be fully inferred from the source type and the
1716 /// static representation of offsets, sizes and strides. Special sentinels
1717 /// encode the dynamic case.
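///
/// As an illustrative sketch: for a source memref<8x16xf32> (strides [16, 1],
/// offset 0), offsets [2, 3], sizes [4, 4] and strides [1, 2] give
/// offset = 0 + 2 * 16 + 3 * 1 = 35 and strides [16 * 1, 1 * 2] = [16, 2],
/// i.e. memref<4x4xf32, offset: 35, strides: [16, 2]>.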
1718 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1719                                 ArrayRef<int64_t> leadingStaticOffsets,
1720                                 ArrayRef<int64_t> leadingStaticSizes,
1721                                 ArrayRef<int64_t> leadingStaticStrides) {
1722   // A subview may specify only a leading subset of offset/sizes/strides in
1723   // which case we complete with offset=0, sizes from memref type and strides=1.
1724   unsigned rank = sourceMemRefType.getRank();
1725   assert(leadingStaticOffsets.size() <= rank &&
1726          "unexpected leadingStaticOffsets overflow");
1727   assert(leadingStaticSizes.size() <= rank &&
1728          "unexpected leadingStaticSizes overflow");
1729   assert(leadingStaticStrides.size() <= rank &&
1730          "unexpected leadingStaticStrides overflow");
1731   auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
1732   auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
1733   auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
1734   unsigned numTrailingOffsets = rank - staticOffsets.size();
1735   unsigned numTrailingSizes = rank - staticSizes.size();
1736   unsigned numTrailingStrides = rank - staticStrides.size();
1737   staticOffsets.append(numTrailingOffsets, 0);
1738   llvm::append_range(staticSizes,
1739                      sourceMemRefType.getShape().take_back(numTrailingSizes));
1740   staticStrides.append(numTrailingStrides, 1);
1741 
1742   // Extract source offset and strides.
1743   int64_t sourceOffset;
1744   SmallVector<int64_t, 4> sourceStrides;
1745   auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
1746   assert(succeeded(res) && "SubViewOp expected strided memref type");
1747   (void)res;
1748 
1749   // Compute target offset whose value is:
1750   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
1751   int64_t targetOffset = sourceOffset;
1752   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
1753     auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
1754     using namespace saturated_arith;
1755     targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * sourceStride;
1756   }
1757 
1758   // Compute target stride whose value is:
1759   //   `sourceStrides_i * staticStrides_i`.
1760   SmallVector<int64_t, 4> targetStrides;
1761   targetStrides.reserve(staticOffsets.size());
1762   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
1763     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
1764     using namespace saturated_arith;
1765     targetStrides.push_back(Wrapper(sourceStride) * staticStride);
1766   }
1767 
1768   // The type is now known.
1769   return MemRefType::get(
1770       staticSizes, sourceMemRefType.getElementType(),
1771       makeStridedLinearLayoutMap(targetStrides, targetOffset,
1772                                  sourceMemRefType.getContext()),
1773       sourceMemRefType.getMemorySpace());
1774 }
1775 
1776 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1777                                 ArrayRef<OpFoldResult> leadingStaticOffsets,
1778                                 ArrayRef<OpFoldResult> leadingStaticSizes,
1779                                 ArrayRef<OpFoldResult> leadingStaticStrides) {
1780   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1781   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1782   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1783                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1784   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1785                              ShapedType::kDynamicSize);
1786   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1787                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1788   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1789                                     staticSizes, staticStrides)
1790       .cast<MemRefType>();
1791 }
1792 
1793 Type SubViewOp::inferRankReducedResultType(
1794     unsigned resultRank, MemRefType sourceMemRefType,
1795     ArrayRef<int64_t> leadingStaticOffsets,
1796     ArrayRef<int64_t> leadingStaticSizes,
1797     ArrayRef<int64_t> leadingStaticStrides) {
1798   auto inferredType =
1799       inferResultType(sourceMemRefType, leadingStaticOffsets,
1800                       leadingStaticSizes, leadingStaticStrides)
1801           .cast<MemRefType>();
1802   assert(inferredType.getRank() >= resultRank && "invalid rank reduction");
1803   int rankDiff = inferredType.getRank() - resultRank;
1804   if (rankDiff > 0) {
1805     auto shape = inferredType.getShape();
1806     llvm::SmallDenseSet<unsigned> dimsToProject;
1807     mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
1808     SmallVector<int64_t> projectedShape;
1809     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1810       if (!dimsToProject.contains(pos))
1811         projectedShape.push_back(shape[pos]);
1812 
1813     AffineMap map;
1814     auto maps = inferredType.getAffineMaps();
1815     if (!maps.empty() && maps.front())
1816       map = getProjectedMap(maps.front(), dimsToProject);
1817     inferredType =
1818         MemRefType::get(projectedShape, inferredType.getElementType(), map,
1819                         inferredType.getMemorySpace());
1820   }
1821   return inferredType;
1822 }
1823 
1824 Type SubViewOp::inferRankReducedResultType(
1825     unsigned resultRank, MemRefType sourceMemRefType,
1826     ArrayRef<OpFoldResult> leadingStaticOffsets,
1827     ArrayRef<OpFoldResult> leadingStaticSizes,
1828     ArrayRef<OpFoldResult> leadingStaticStrides) {
1829   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1830   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1831   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1832                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1833   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1834                              ShapedType::kDynamicSize);
1835   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1836                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1837   return SubViewOp::inferRankReducedResultType(
1838       resultRank, sourceMemRefType, staticOffsets, staticSizes,
1839       staticStrides);
1840 }
1841 // Build a SubViewOp with mixed static and dynamic entries and custom result
1842 // type. If the type passed is nullptr, it is inferred.
1843 void SubViewOp::build(OpBuilder &b, OperationState &result,
1844                       MemRefType resultType, Value source,
1845                       ArrayRef<OpFoldResult> offsets,
1846                       ArrayRef<OpFoldResult> sizes,
1847                       ArrayRef<OpFoldResult> strides,
1848                       ArrayRef<NamedAttribute> attrs) {
1849   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1850   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1851   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1852                              ShapedType::kDynamicStrideOrOffset);
1853   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1854                              ShapedType::kDynamicSize);
1855   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1856                              ShapedType::kDynamicStrideOrOffset);
1857   auto sourceMemRefType = source.getType().cast<MemRefType>();
1858   // Structuring implementation this way avoids duplication between builders.
1859   if (!resultType) {
1860     resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1861                                             staticSizes, staticStrides)
1862                      .cast<MemRefType>();
1863   }
1864   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1865         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1866         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1867   result.addAttributes(attrs);
1868 }
1869 
1870 // Build a SubViewOp with mixed static and dynamic entries and inferred result
1871 // type.
1872 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1873                       ArrayRef<OpFoldResult> offsets,
1874                       ArrayRef<OpFoldResult> sizes,
1875                       ArrayRef<OpFoldResult> strides,
1876                       ArrayRef<NamedAttribute> attrs) {
1877   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1878 }
1879 
1880 // Build a SubViewOp with static entries and inferred result type.
1881 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1882                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1883                       ArrayRef<int64_t> strides,
1884                       ArrayRef<NamedAttribute> attrs) {
1885   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1886       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1887         return b.getI64IntegerAttr(v);
1888       }));
1889   SmallVector<OpFoldResult> sizeValues =
1890       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1891         return b.getI64IntegerAttr(v);
1892       }));
1893   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1894       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1895         return b.getI64IntegerAttr(v);
1896       }));
1897   build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
1898 }
1899 
1900 // Build a SubViewOp with static entries and custom result type. If the
1901 // type passed is nullptr, it is inferred.
1902 void SubViewOp::build(OpBuilder &b, OperationState &result,
1903                       MemRefType resultType, Value source,
1904                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1905                       ArrayRef<int64_t> strides,
1906                       ArrayRef<NamedAttribute> attrs) {
1907   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1908       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1909         return b.getI64IntegerAttr(v);
1910       }));
1911   SmallVector<OpFoldResult> sizeValues =
1912       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1913         return b.getI64IntegerAttr(v);
1914       }));
1915   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1916       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1917         return b.getI64IntegerAttr(v);
1918       }));
1919   build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
1920         attrs);
1921 }
1922 
1923 // Build a SubViewOp with dynamic entries and custom result type. If the type
1924 // passed is nullptr, it is inferred.
1925 void SubViewOp::build(OpBuilder &b, OperationState &result,
1926                       MemRefType resultType, Value source, ValueRange offsets,
1927                       ValueRange sizes, ValueRange strides,
1928                       ArrayRef<NamedAttribute> attrs) {
1929   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1930       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1931   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1932       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1933   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1934       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1935   build(b, result, resultType, source, offsetValues, sizeValues,
        strideValues, attrs);
1936 }
1937 
1938 // Build a SubViewOp with dynamic entries and inferred result type.
1939 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1940                       ValueRange offsets, ValueRange sizes, ValueRange strides,
1941                       ArrayRef<NamedAttribute> attrs) {
1942   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1943 }
1944 
1945 /// For ViewLikeOpInterface.
1946 Value SubViewOp::getViewSource() { return source(); }
1947 
1948 enum SubViewVerificationResult {
1949   Success,
1950   RankTooLarge,
1951   SizeMismatch,
1952   ElemTypeMismatch,
1953   MemSpaceMismatch,
1954   AffineMapMismatch
1955 };
1956 
1957 /// Checks if `originalType` can be rank-reduced to `candidateReducedType`.
1958 /// This function is a slight variant of the `is subsequence` algorithm,
1959 /// where every non-matching dimension must be 1.
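/// For example (illustrative only), memref<1x4x1x8xf32> can be rank-reduced
/// to memref<4x8xf32> by dropping the unit dimensions, but not to
/// memref<4x4xf32>.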
1960 static SubViewVerificationResult
1961 isRankReducedType(Type originalType, Type candidateReducedType,
1962                   ArrayAttr staticSizes, std::string *errMsg = nullptr) {
1963   if (originalType == candidateReducedType)
1964     return SubViewVerificationResult::Success;
1965   if (!originalType.isa<MemRefType>())
1966     return SubViewVerificationResult::Success;
1967   if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
1968     return SubViewVerificationResult::Success;
1969 
1970   ShapedType originalShapedType = originalType.cast<ShapedType>();
1971   ShapedType candidateReducedShapedType =
1972       candidateReducedType.cast<ShapedType>();
1973 
1974   // Rank and size logic is valid for all ShapedTypes.
1975   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
1976   ArrayRef<int64_t> candidateReducedShape =
1977       candidateReducedShapedType.getShape();
1978   unsigned originalRank = originalShape.size(),
1979            candidateReducedRank = candidateReducedShape.size();
1980   if (candidateReducedRank > originalRank)
1981     return SubViewVerificationResult::RankTooLarge;
1982 
1983   MemRefType original = originalType.cast<MemRefType>();
1984   MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
1985 
1986   auto optionalUnusedDimsMask =
1987       computeMemRefRankReductionMask(original, candidateReduced, staticSizes);
1988 
1989   // Sizes cannot be matched if no rank-reduction mask could be computed.
1990   if (!optionalUnusedDimsMask.hasValue())
1991     return SubViewVerificationResult::SizeMismatch;
1992 
1993   if (originalShapedType.getElementType() !=
1994       candidateReducedShapedType.getElementType())
1995     return SubViewVerificationResult::ElemTypeMismatch;
1996 
1997   // The memory spaces of the original and the reduced memrefs must match.
1998   if (original.getMemorySpace() != candidateReduced.getMemorySpace())
1999     return SubViewVerificationResult::MemSpaceMismatch;
2000   return SubViewVerificationResult::Success;
2001 }
2002 
2003 template <typename OpTy>
2004 static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
2005                                             OpTy op, Type expectedType,
2006                                             StringRef errMsg = "") {
2007   auto memrefType = expectedType.cast<ShapedType>();
2008   switch (result) {
2009   case SubViewVerificationResult::Success:
2010     return success();
2011   case SubViewVerificationResult::RankTooLarge:
2012     return op.emitError("expected result rank to be smaller or equal to ")
2013            << "the source rank. " << errMsg;
2014   case SubViewVerificationResult::SizeMismatch:
2015     return op.emitError("expected result type to be ")
2016            << expectedType
2017            << " or a rank-reduced version. (mismatch of result sizes) "
2018            << errMsg;
2019   case SubViewVerificationResult::ElemTypeMismatch:
2020     return op.emitError("expected result element type to be ")
2021            << memrefType.getElementType() << errMsg;
2022   case SubViewVerificationResult::MemSpaceMismatch:
2023     return op.emitError("expected result and source memory spaces to match.")
2024            << errMsg;
2025   case SubViewVerificationResult::AffineMapMismatch:
2026     return op.emitError("expected result type to be ")
2027            << expectedType
2028            << " or a rank-reduced version. (mismatch of result affine map) "
2029            << errMsg;
2030   }
2031   llvm_unreachable("unexpected subview verification result");
2032 }
2033 
2034 /// Verifier for SubViewOp.
2035 static LogicalResult verify(SubViewOp op) {
2036   MemRefType baseType = op.getSourceType();
2037   MemRefType subViewType = op.getType();
2038 
2039   // The base memref and the view memref should be in the same memory space.
2040   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2041     return op.emitError("different memory spaces specified for base memref "
2042                         "type ")
2043            << baseType << " and subview memref type " << subViewType;
2044 
2045   // Verify that the base memref type has a strided layout map.
2046   if (!isStrided(baseType))
2047     return op.emitError("base type ") << baseType << " is not strided";
2048 
2049   // Verify result type against inferred type.
2050   auto expectedType = SubViewOp::inferResultType(
2051       baseType, extractFromI64ArrayAttr(op.static_offsets()),
2052       extractFromI64ArrayAttr(op.static_sizes()),
2053       extractFromI64ArrayAttr(op.static_strides()));
2054 
2055   std::string errMsg;
2056   auto result =
2057       isRankReducedType(expectedType, subViewType, op.static_sizes(), &errMsg);
2058   return produceSubViewErrorMsg(result, op, expectedType, errMsg);
2059 }
2060 
2061 raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
2062   return os << "range " << range.offset << ":" << range.size << ":"
2063             << range.stride;
2064 }
2065 
2066 /// Return the list of Range (i.e. offset, size, stride). Each Range
2067 /// entry contains either the dynamic value or a ConstantIndexOp constructed
2068 /// with `b` at location `loc`.
2069 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
2070                                               OpBuilder &b, Location loc) {
2071   std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
2072   assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
2073   assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
2074   SmallVector<Range, 8> res;
2075   unsigned rank = ranks[0];
2076   res.reserve(rank);
2077   for (unsigned idx = 0; idx < rank; ++idx) {
2078     Value offset =
2079         op.isDynamicOffset(idx)
2080             ? op.getDynamicOffset(idx)
2081             : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
2082     Value size = op.isDynamicSize(idx)
2083                      ? op.getDynamicSize(idx)
2084                      : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
2085     Value stride =
2086         op.isDynamicStride(idx)
2087             ? op.getDynamicStride(idx)
2088             : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
2089     res.emplace_back(Range{offset, size, stride});
2090   }
2091   return res;
2092 }
2093 
2094 /// Infer the canonical type of the result of a subview operation. Returns the
2095 /// rank-reduced type when it has rank `resultRank`, and the non-rank-reduced
2096 /// inferred type otherwise.
2097 static MemRefType
2098 getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
2099                               ArrayRef<OpFoldResult> mixedOffsets,
2100                               ArrayRef<OpFoldResult> mixedSizes,
2101                               ArrayRef<OpFoldResult> mixedStrides) {
2102   auto resultType =
2103       SubViewOp::inferRankReducedResultType(
2104           resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
2105           .cast<MemRefType>();
2106   if (resultType.getRank() != resultRank) {
2107     resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
2108                                             mixedSizes, mixedStrides)
2109                      .cast<MemRefType>();
2110   }
2111   return resultType;
2112 }
2113 
2114 namespace {
2115 /// Pattern to rewrite a subview op with MemRefCast arguments.
2116 /// This essentially pushes memref.cast past its consuming subview when
2117 /// `canFoldIntoConsumerOp` is true.
2118 ///
2119 /// Example:
2120 /// ```
2121 ///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
2122 ///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
2123 ///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
2124 /// ```
2125 /// is rewritten into:
2126 /// ```
2127 ///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
2128 ///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
2129 ///     memref<3x4xf32, offset:?, strides:[?, 1]>
2130 /// ```
2131 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
2132 public:
2133   using OpRewritePattern<SubViewOp>::OpRewritePattern;
2134 
2135   LogicalResult matchAndRewrite(SubViewOp subViewOp,
2136                                 PatternRewriter &rewriter) const override {
2137     // Any constant operand, just return to let SubViewOpConstantFolder kick in.
2138     if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
2139           return matchPattern(operand, matchConstantIndex());
2140         }))
2141       return failure();
2142 
2143     auto castOp = subViewOp.source().getDefiningOp<CastOp>();
2144     if (!castOp)
2145       return failure();
2146 
2147     if (!CastOp::canFoldIntoConsumerOp(castOp))
2148       return failure();
2149 
2150     // Deduce the result type of the SubViewOp by applying
2151     // `getCanonicalSubViewResultType` to the cast source operand type and
2152     // the SubViewOp static information, i.e. the type if the cast were folded.
2153     auto resultType = getCanonicalSubViewResultType(
2154         subViewOp.getType().getRank(),
2155         castOp.source().getType().cast<MemRefType>(),
2156         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
2157         subViewOp.getMixedStrides());
2158     Value newSubView = rewriter.create<SubViewOp>(
2159         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
2160         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
2161         subViewOp.static_sizes(), subViewOp.static_strides());
2162     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
2163                                         newSubView);
2164     return success();
2165   }
2166 };
2167 } // namespace
2168 
2169 /// Return the canonical type of the result of a subview.
2170 struct SubViewReturnTypeCanonicalizer {
2171   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
2172                         ArrayRef<OpFoldResult> mixedSizes,
2173                         ArrayRef<OpFoldResult> mixedStrides) {
2174     return getCanonicalSubViewResultType(op.getType().getRank(),
2175                                          op.getSourceType(), mixedOffsets,
2176                                          mixedSizes, mixedStrides);
2177   }
2178 };
2179 
2180 /// A canonicalizer wrapper to replace SubViewOps.
2181 struct SubViewCanonicalizer {
2182   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
2183     rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
2184   }
2185 };
2186 
2187 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2188                                             MLIRContext *context) {
2189   results
2190       .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
2191                SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
2192            SubViewOpMemRefCastFolder>(context);
2193 }
2194 
2195 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
2196   auto resultShapedType = getResult().getType().cast<ShapedType>();
2197   auto sourceShapedType = source().getType().cast<ShapedType>();
2198 
2199   if (resultShapedType.hasStaticShape() &&
2200       resultShapedType == sourceShapedType) {
2201     return getViewSource();
2202   }
2203 
2204   return {};
2205 }
2206 
2207 //===----------------------------------------------------------------------===//
2208 // TensorLoadOp
2209 //===----------------------------------------------------------------------===//
2210 
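// For orientation, the fold below rewrites the adjacent pair (names are
// illustrative only)
//
//   %m = memref.buffer_cast %t : memref<?xf32>
//   %0 = memref.tensor_load %m : memref<?xf32>
//
// into a direct use of %t, but only when nothing is interleaved between them.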
2211 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
2212   if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
2213     // Approximate alias analysis by conservatively folding only when there
2214     // is no interleaved operation.
2215     if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
2216         bufferCast->getNextNode() == this->getOperation())
2217       return bufferCast.tensor();
2218   return {};
2219 }
2220 
2221 namespace {
2222 struct DimOfTensorLoadFolder : public OpRewritePattern<tensor::DimOp> {
2223   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
2224 
2225   LogicalResult matchAndRewrite(tensor::DimOp dimOp,
2226                                 PatternRewriter &rewriter) const override {
2227     auto tensorLoadOp = dimOp.source().getDefiningOp<TensorLoadOp>();
2228     if (!tensorLoadOp)
2229       return failure();
2230 
2231     rewriter.replaceOpWithNewOp<DimOp>(dimOp, tensorLoadOp.memref(),
2232                                        dimOp.index());
2233     return success();
2234   }
2235 };
2236 } // namespace
2237 
2238 void TensorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
2239                                                MLIRContext *context) {
2240   results.add<DimOfTensorLoadFolder>(context);
2241 }
2242 
2243 //===----------------------------------------------------------------------===//
2244 // TransposeOp
2245 //===----------------------------------------------------------------------===//
2246 
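// As an illustrative sketch of the type inference below: transposing a
// row-major memref<4x8xf32> with permutation (i, j) -> (j, i) yields
//
//   memref<8x4xf32, affine_map<(d0, d1) -> (d1 * 8 + d0)>>
//
// because the strides [8, 1] are permuted along with the sizes.
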
2247 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
2248 static MemRefType inferTransposeResultType(MemRefType memRefType,
2249                                            AffineMap permutationMap) {
2250   auto rank = memRefType.getRank();
2251   auto originalSizes = memRefType.getShape();
2252   // Compute permuted sizes.
2253   SmallVector<int64_t, 4> sizes(rank, 0);
2254   for (auto en : llvm::enumerate(permutationMap.getResults()))
2255     sizes[en.index()] =
2256         originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
2257 
2258   // Compute permuted strides.
2259   int64_t offset;
2260   SmallVector<int64_t, 4> strides;
2261   auto res = getStridesAndOffset(memRefType, strides, offset);
2262   assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
2263   (void)res;
2264   auto map =
2265       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
2266   map = permutationMap ? map.compose(permutationMap) : map;
2267   return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
2268 }
2269 
2270 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
2271                         AffineMapAttr permutation,
2272                         ArrayRef<NamedAttribute> attrs) {
2273   auto permutationMap = permutation.getValue();
2274   assert(permutationMap);
2275 
2276   auto memRefType = in.getType().cast<MemRefType>();
2277   // Compute result type.
2278   MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
2279 
2280   build(b, result, resultType, in, attrs);
2281   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
2282 }
2283 
2284 // transpose $in $permutation attr-dict : type($in) `to` type(results)
2285 static void print(OpAsmPrinter &p, TransposeOp op) {
2286   p << " " << op.in() << " " << op.permutation();
2287   p.printOptionalAttrDict(op->getAttrs(),
2288                           {TransposeOp::getPermutationAttrName()});
2289   p << " : " << op.in().getType() << " to " << op.getType();
2290 }
2291 
2292 static ParseResult parseTransposeOp(OpAsmParser &parser,
2293                                     OperationState &result) {
2294   OpAsmParser::OperandType in;
2295   AffineMap permutation;
2296   MemRefType srcType, dstType;
2297   if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
2298       parser.parseOptionalAttrDict(result.attributes) ||
2299       parser.parseColonType(srcType) ||
2300       parser.resolveOperand(in, srcType, result.operands) ||
2301       parser.parseKeywordType("to", dstType) ||
2302       parser.addTypeToList(dstType, result.types))
2303     return failure();
2304 
2305   result.addAttribute(TransposeOp::getPermutationAttrName(),
2306                       AffineMapAttr::get(permutation));
2307   return success();
2308 }
2309 
2310 static LogicalResult verify(TransposeOp op) {
2311   if (!op.permutation().isPermutation())
2312     return op.emitOpError("expected a permutation map");
2313   if (op.permutation().getNumDims() != op.getShapedType().getRank())
2314     return op.emitOpError(
2315         "expected a permutation map of same rank as the input");
2316 
2317   auto srcType = op.in().getType().cast<MemRefType>();
2318   auto dstType = op.getType().cast<MemRefType>();
2319   auto transposedType = inferTransposeResultType(srcType, op.permutation());
2320   if (dstType != transposedType)
2321     return op.emitOpError("output type ")
2322            << dstType << " does not match transposed input type " << srcType
2323            << ", " << transposedType;
2324   return success();
2325 }
2326 
2327 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
2328   if (succeeded(foldMemRefCast(*this)))
2329     return getResult();
2330   return {};
2331 }
2332 
2333 //===----------------------------------------------------------------------===//
2334 // ViewOp
2335 //===----------------------------------------------------------------------===//
2336 
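// For orientation when reading the parser and printer below, a memref.view
// takes a byte shift followed by the dynamic sizes, roughly as follows (names
// and shapes are illustrative only):
//
//   %1 = memref.view %0[%byte_shift][%size0]
//       : memref<2048xi8> to memref<?x4xf32>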
2337 static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
2338   OpAsmParser::OperandType srcInfo;
2339   SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
2340   SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
2341   auto indexType = parser.getBuilder().getIndexType();
2342   Type srcType, dstType;
2343   llvm::SMLoc offsetLoc;
2344   if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
2345       parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
2346     return failure();
2347 
2348   if (offsetInfo.size() != 1)
2349     return parser.emitError(offsetLoc) << "expects 1 offset operand";
2350 
2351   return failure(
2352       parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
2353       parser.parseOptionalAttrDict(result.attributes) ||
2354       parser.parseColonType(srcType) ||
2355       parser.resolveOperand(srcInfo, srcType, result.operands) ||
2356       parser.resolveOperands(offsetInfo, indexType, result.operands) ||
2357       parser.resolveOperands(sizesInfo, indexType, result.operands) ||
2358       parser.parseKeywordType("to", dstType) ||
2359       parser.addTypeToList(dstType, result.types));
2360 }
2361 
2362 static void print(OpAsmPrinter &p, ViewOp op) {
2363   p << ' ' << op.getOperand(0) << '[';
2364   p.printOperand(op.byte_shift());
2365   p << "][" << op.sizes() << ']';
2366   p.printOptionalAttrDict(op->getAttrs());
2367   p << " : " << op.getOperand(0).getType() << " to " << op.getType();
2368 }
2369 
2370 static LogicalResult verify(ViewOp op) {
2371   auto baseType = op.getOperand(0).getType().cast<MemRefType>();
2372   auto viewType = op.getType();
2373 
2374   // The base memref should have identity layout map (or none).
2375   if (baseType.getAffineMaps().size() > 1 ||
2376       (baseType.getAffineMaps().size() == 1 &&
2377        !baseType.getAffineMaps()[0].isIdentity()))
2378     return op.emitError("unsupported map for base memref type ") << baseType;
2379 
2380   // The result memref should have identity layout map (or none).
2381   if (viewType.getAffineMaps().size() > 1 ||
2382       (viewType.getAffineMaps().size() == 1 &&
2383        !viewType.getAffineMaps()[0].isIdentity()))
2384     return op.emitError("unsupported map for result memref type ") << viewType;
2385 
2386   // The base memref and the view memref should be in the same memory space.
2387   if (baseType.getMemorySpace() != viewType.getMemorySpace())
2388     return op.emitError("different memory spaces specified for base memref "
2389                         "type ")
2390            << baseType << " and view memref type " << viewType;
2391 
2392   // Verify that we have the correct number of sizes for the result type.
2393   unsigned numDynamicDims = viewType.getNumDynamicDims();
2394   if (op.sizes().size() != numDynamicDims)
2395     return op.emitError("incorrect number of size operands for type ")
2396            << viewType;
2397 
2398   return success();
2399 }
2400 
2401 Value ViewOp::getViewSource() { return source(); }
2402 
2403 namespace {
2404 
2405 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
2406   using OpRewritePattern<ViewOp>::OpRewritePattern;
2407 
2408   LogicalResult matchAndRewrite(ViewOp viewOp,
2409                                 PatternRewriter &rewriter) const override {
2410     // Return if none of the operands are constants.
2411     if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
2412           return matchPattern(operand, matchConstantIndex());
2413         }))
2414       return failure();
2415 
2416     // Get result memref type.
2417     auto memrefType = viewOp.getType();
2418 
2419     // Get offset from the old memref view type 'memrefType'.
2420     int64_t oldOffset;
2421     SmallVector<int64_t, 4> oldStrides;
2422     if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
2423       return failure();
2424     assert(oldOffset == 0 && "Expected 0 offset");
2425 
2426     SmallVector<Value, 4> newOperands;
2427 
2428     // Offset cannot be folded into result type.
2429 
2430     // Fold any dynamic dim operands which are produced by a constant.
2431     SmallVector<int64_t, 4> newShapeConstants;
2432     newShapeConstants.reserve(memrefType.getRank());
2433 
2434     unsigned dynamicDimPos = 0;
2435     unsigned rank = memrefType.getRank();
2436     for (unsigned dim = 0, e = rank; dim < e; ++dim) {
2437       int64_t dimSize = memrefType.getDimSize(dim);
2438       // If this is already a static dimension, keep it.
2439       if (!ShapedType::isDynamic(dimSize)) {
2440         newShapeConstants.push_back(dimSize);
2441         continue;
2442       }
2443       auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
2444       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
2445         // Dynamic shape dimension will be folded.
2446         newShapeConstants.push_back(constantIndexOp.getValue());
2447       } else {
2448         // Dynamic shape dimension not folded; copy operand from old memref.
2449         newShapeConstants.push_back(dimSize);
2450         newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
2451       }
2452       dynamicDimPos++;
2453     }
2454 
2455     // Create new memref type with constant folded dims.
2456     MemRefType newMemRefType =
2457         MemRefType::Builder(memrefType).setShape(newShapeConstants);
2458     // Nothing new, don't fold.
2459     if (newMemRefType == memrefType)
2460       return failure();
2461 
2462     // Create new ViewOp.
2463     auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
2464                                              viewOp.getOperand(0),
2465                                              viewOp.byte_shift(), newOperands);
2466     // Insert a cast so we have the same type as the old memref type.
2467     rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
2468     return success();
2469   }
2470 };
2471 
2472 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
2473   using OpRewritePattern<ViewOp>::OpRewritePattern;
2474 
2475   LogicalResult matchAndRewrite(ViewOp viewOp,
2476                                 PatternRewriter &rewriter) const override {
2477     Value memrefOperand = viewOp.getOperand(0);
2478     CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
2479     if (!memrefCastOp)
2480       return failure();
2481     Value allocOperand = memrefCastOp.getOperand();
2482     AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
2483     if (!allocOp)
2484       return failure();
2485     rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
2486                                         viewOp.byte_shift(), viewOp.sizes());
2487     return success();
2488   }
2489 };
2490 
2491 } // end anonymous namespace
2492 
2493 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2494                                          MLIRContext *context) {
2495   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
2496 }
2497 
2498 //===----------------------------------------------------------------------===//
2499 // TableGen'd op method definitions
2500 //===----------------------------------------------------------------------===//
2501 
2502 #define GET_OP_CLASSES
2503 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
2504