1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/MemRef/IR/MemRef.h"
11 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
12 #include "mlir/Dialect/StandardOps/IR/Ops.h"
13 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/Dialect/Utils/StaticValueUtils.h"
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/Interfaces/InferTypeOpInterface.h"
23 #include "mlir/Interfaces/ViewLikeInterface.h"
24 #include "llvm/ADT/STLExtras.h"
25 
26 using namespace mlir;
27 using namespace mlir::memref;
28 
29 /// Materialize a single constant operation from a given attribute value with
30 /// the desired resultant type.
31 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
32                                               Attribute value, Type type,
33                                               Location loc) {
34   if (arith::ConstantOp::isBuildableWith(value, type))
35     return builder.create<arith::ConstantOp>(loc, value, type);
36   if (ConstantOp::isBuildableWith(value, type))
37     return builder.create<ConstantOp>(loc, value, type);
38   return nullptr;
39 }
40 
41 //===----------------------------------------------------------------------===//
42 // Common canonicalization pattern support logic
43 //===----------------------------------------------------------------------===//
44 
/// This is a common helper used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
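///
/// For example (an illustrative sketch; the root op can be any
/// memref-consuming operation whose fold hook calls this helper):
///
///   %0 = memref.cast %src : memref<8x16xf32> to memref<?x?xf32>
///   memref.dealloc %0 : memref<?x?xf32>
///
/// folds to:
///
///   memref.dealloc %src : memref<8x16xf32>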
48 static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) {
49   bool folded = false;
50   for (OpOperand &operand : op->getOpOperands()) {
51     auto cast = operand.get().getDefiningOp<CastOp>();
52     if (cast && operand.get() != inner &&
53         !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
54       operand.set(cast.getOperand());
55       folded = true;
56     }
57   }
58   return success(folded);
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // Helpers for GlobalOp
63 //===----------------------------------------------------------------------===//
64 
65 static Type getTensorTypeFromMemRefType(Type type) {
66   if (auto memref = type.dyn_cast<MemRefType>())
67     return RankedTensorType::get(memref.getShape(), memref.getElementType());
68   if (auto memref = type.dyn_cast<UnrankedMemRefType>())
69     return UnrankedTensorType::get(memref.getElementType());
70   return NoneType::get(type.getContext());
71 }
72 
73 //===----------------------------------------------------------------------===//
74 // AllocOp / AllocaOp
75 //===----------------------------------------------------------------------===//
76 
77 template <typename AllocLikeOp>
78 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
79   static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
80                 "applies to only alloc or alloca");
81   auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
82   if (!memRefType)
83     return op.emitOpError("result must be a memref");
84 
85   if (static_cast<int64_t>(op.dynamicSizes().size()) !=
86       memRefType.getNumDynamicDims())
87     return op.emitOpError("dimension operand count does not equal memref "
88                           "dynamic dimension count");
89 
90   unsigned numSymbols = 0;
91   if (!memRefType.getLayout().isIdentity())
92     numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
93   if (op.symbolOperands().size() != numSymbols)
94     return op.emitOpError("symbol operand count does not equal memref symbol "
95                           "count: expected ")
96            << numSymbols << ", got " << op.symbolOperands().size();
97 
98   return success();
99 }
100 
101 static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }
102 
103 static LogicalResult verify(AllocaOp op) {
  // An alloca op needs an ancestor with the AutomaticAllocationScope trait.
105   if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
106     return op.emitOpError(
107         "requires an ancestor op with AutomaticAllocationScope trait");
108 
109   return verifyAllocLikeOp(op);
110 }
111 
112 namespace {
/// Fold constant dimensions into an alloc-like operation.
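///
/// For example (illustrative IR, assuming a single constant dynamic size):
///
///   %c16 = arith.constant 16 : index
///   %0 = memref.alloc(%c16) : memref<?xf32>
///
/// is rewritten to:
///
///   %1 = memref.alloc() : memref<16xf32>
///   %0 = memref.cast %1 : memref<16xf32> to memref<?xf32>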
114 template <typename AllocLikeOp>
115 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
116   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
117 
118   LogicalResult matchAndRewrite(AllocLikeOp alloc,
119                                 PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants. If so, we can
    // substitute and drop them.
122     if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
123           return matchPattern(operand, matchConstantIndex());
124         }))
125       return failure();
126 
127     auto memrefType = alloc.getType();
128 
129     // Ok, we have one or more constant operands.  Collect the non-constant ones
130     // and keep track of the resultant memref type to build.
131     SmallVector<int64_t, 4> newShapeConstants;
132     newShapeConstants.reserve(memrefType.getRank());
133     SmallVector<Value, 4> dynamicSizes;
134 
135     unsigned dynamicDimPos = 0;
136     for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
137       int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
139       if (dimSize != -1) {
140         newShapeConstants.push_back(dimSize);
141         continue;
142       }
143       auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
144       auto *defOp = dynamicSize.getDefiningOp();
145       if (auto constantIndexOp =
146               dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
147         // Dynamic shape dimension will be folded.
148         newShapeConstants.push_back(constantIndexOp.value());
149       } else {
150         // Dynamic shape dimension not folded; copy dynamicSize from old memref.
151         newShapeConstants.push_back(-1);
152         dynamicSizes.push_back(dynamicSize);
153       }
154       dynamicDimPos++;
155     }
156 
157     // Create new memref type (which will have fewer dynamic dimensions).
158     MemRefType newMemRefType =
159         MemRefType::Builder(memrefType).setShape(newShapeConstants);
160     assert(static_cast<int64_t>(dynamicSizes.size()) ==
161            newMemRefType.getNumDynamicDims());
162 
163     // Create and insert the alloc op for the new memref.
164     auto newAlloc = rewriter.create<AllocLikeOp>(
165         alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
166         alloc.alignmentAttr());
167     // Insert a cast so we have the same type as the old alloc.
168     auto resultCast =
169         rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
170 
171     rewriter.replaceOp(alloc, {resultCast});
172     return success();
173   }
174 };
175 
176 /// Fold alloc operations with no users or only store and dealloc uses.
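///
/// For example (illustrative), the following allocation is erased together
/// with its store and dealloc uses, since the buffer is never read:
///
///   %0 = memref.alloc() : memref<16xf32>
///   memref.store %cst, %0[%c0] : memref<16xf32>
///   memref.dealloc %0 : memref<16xf32>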
177 template <typename T>
178 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
179   using OpRewritePattern<T>::OpRewritePattern;
180 
181   LogicalResult matchAndRewrite(T alloc,
182                                 PatternRewriter &rewriter) const override {
183     if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
184           if (auto storeOp = dyn_cast<StoreOp>(op))
185             return storeOp.value() == alloc;
186           return !isa<DeallocOp>(op);
187         }))
188       return failure();
189 
190     for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
191       rewriter.eraseOp(user);
192 
193     rewriter.eraseOp(alloc);
194     return success();
195   }
196 };
197 } // end anonymous namespace.
198 
199 Optional<Operation *> AllocOp::buildDealloc(OpBuilder &builder, Value alloc) {
200   return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
201       .getOperation();
202 }
203 
204 Optional<Value> AllocOp::buildClone(OpBuilder &builder, Value alloc) {
205   return builder.create<memref::CloneOp>(alloc.getLoc(), alloc).getResult();
206 }
207 
208 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
209                                           MLIRContext *context) {
210   results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
211 }
212 
213 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
214                                            MLIRContext *context) {
215   results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
216       context);
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // AllocaScopeOp
221 //===----------------------------------------------------------------------===//
222 
223 static void print(OpAsmPrinter &p, AllocaScopeOp &op) {
224   bool printBlockTerminators = false;
225 
226   p << " ";
227   if (!op.results().empty()) {
228     p << " -> (" << op.getResultTypes() << ")";
229     printBlockTerminators = true;
230   }
231   p.printRegion(op.bodyRegion(),
232                 /*printEntryBlockArgs=*/false,
233                 /*printBlockTerminators=*/printBlockTerminators);
234   p.printOptionalAttrDict(op->getAttrs());
235 }
236 
237 static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
238                                       OperationState &result) {
239   // Create a region for the body.
240   result.regions.reserve(1);
241   Region *bodyRegion = result.addRegion();
242 
243   // Parse optional results type list.
244   if (parser.parseOptionalArrowTypeList(result.types))
245     return failure();
246 
247   // Parse the body region.
248   if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
249     return failure();
250   AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
251                                   result.location);
252 
253   // Parse the optional attribute list.
254   if (parser.parseOptionalAttrDict(result.attributes))
255     return failure();
256 
257   return success();
258 }
259 
260 static LogicalResult verify(AllocaScopeOp op) {
261   if (failed(RegionBranchOpInterface::verifyTypes(op)))
262     return failure();
263 
264   return success();
265 }
266 
267 void AllocaScopeOp::getSuccessorRegions(
268     Optional<unsigned> index, ArrayRef<Attribute> operands,
269     SmallVectorImpl<RegionSuccessor> &regions) {
270   if (index.hasValue()) {
271     regions.push_back(RegionSuccessor(getResults()));
272     return;
273   }
274 
275   regions.push_back(RegionSuccessor(&bodyRegion()));
276 }
277 
278 //===----------------------------------------------------------------------===//
279 // AssumeAlignmentOp
280 //===----------------------------------------------------------------------===//
281 
282 static LogicalResult verify(AssumeAlignmentOp op) {
283   unsigned alignment = op.alignment();
284   if (!llvm::isPowerOf2_32(alignment))
285     return op.emitOpError("alignment must be power of 2");
286   return success();
287 }
288 
289 //===----------------------------------------------------------------------===//
290 // BufferCastOp
291 //===----------------------------------------------------------------------===//
292 
293 OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
294   if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
295     if (tensorLoad.memref().getType() == getType())
296       return tensorLoad.memref();
297   return {};
298 }
299 
300 namespace {
/// Replace tensor.cast + memref.buffer_cast by memref.buffer_cast +
/// memref.cast.
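///
/// For example (illustrative, assuming a rank-preserving tensor.cast):
///
///   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %1 = memref.buffer_cast %0 : memref<?xf32>
///
/// is rewritten to:
///
///   %0 = memref.buffer_cast %t : memref<4xf32>
///   %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>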
302 struct BufferCast : public OpRewritePattern<BufferCastOp> {
303   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
304 
305   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
306                                 PatternRewriter &rewriter) const final {
307     auto tensorCastOperand =
308         bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
309     if (!tensorCastOperand)
310       return failure();
311     auto srcTensorType =
312         tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
313     if (!srcTensorType)
314       return failure();
315     auto memrefType = MemRefType::get(srcTensorType.getShape(),
316                                       srcTensorType.getElementType());
317     Value memref = rewriter.create<BufferCastOp>(
318         bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
319     rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
320                                         memref);
321     return success();
322   }
323 };
324 
/// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
/// type mismatches prevent `BufferCastOp::fold` from kicking in.
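///
/// For example (illustrative, when the cast is guaranteed to be valid):
///
///   %0 = memref.tensor_load %m : memref<?xf32>
///   %1 = memref.buffer_cast %0 : memref<4xf32>
///
/// is rewritten to:
///
///   %1 = memref.cast %m : memref<?xf32> to memref<4xf32>
///
/// When the cast may be invalid (e.g. a dynamic offset or stride would become
/// static), an alloc + copy of the source memref is emitted instead.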
327 struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
328   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
329 
330   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
331                                 PatternRewriter &rewriter) const final {
332     auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
333     // Bail unless we have a tensor_load + memref.buffer_cast with different
334     // types. `BufferCastOp::fold` handles the same type case.
335     if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
336       return failure();
337     // If types are definitely not cast-compatible, bail.
338     if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
339                                    bufferCast.getType()))
340       return failure();
341 
342     // We already know that the types are potentially cast-compatible. However
343     // in case the affine maps are different, we may need to use a copy if we go
344     // from dynamic to static offset or stride (the canonicalization cannot know
345     // at this point that it is really cast compatible).
346     auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
347       int64_t sourceOffset, targetOffset;
348       SmallVector<int64_t, 4> sourceStrides, targetStrides;
349       if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
350           failed(getStridesAndOffset(target, targetStrides, targetOffset)))
351         return false;
352       auto dynamicToStatic = [](int64_t a, int64_t b) {
353         return a == MemRefType::getDynamicStrideOrOffset() &&
354                b != MemRefType::getDynamicStrideOrOffset();
355       };
356       if (dynamicToStatic(sourceOffset, targetOffset))
357         return false;
358       for (auto it : zip(sourceStrides, targetStrides))
359         if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
360           return false;
361       return true;
362     };
363 
364     auto tensorLoadType = tensorLoad.memref().getType().dyn_cast<MemRefType>();
365     auto bufferCastType = bufferCast.getType().dyn_cast<MemRefType>();
366     if (tensorLoadType && bufferCastType &&
367         !isGuaranteedCastCompatible(tensorLoadType, bufferCastType)) {
368       MemRefType resultType = bufferCastType;
369       auto loc = bufferCast.getLoc();
370       SmallVector<Value, 4> dynamicOperands;
371       for (int i = 0; i < resultType.getRank(); ++i) {
372         if (resultType.getShape()[i] != ShapedType::kDynamicSize)
373           continue;
374         auto index = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
375         Value size = rewriter.create<tensor::DimOp>(loc, tensorLoad, index);
376         dynamicOperands.push_back(size);
377       }
378       auto copy =
379           rewriter.create<memref::AllocOp>(loc, resultType, dynamicOperands);
380       rewriter.create<CopyOp>(loc, tensorLoad.memref(), copy);
381       rewriter.replaceOp(bufferCast, {copy});
382     } else
383       rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
384                                           tensorLoad.memref());
385     return success();
386   }
387 };
388 
389 } // namespace
390 
391 void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
392                                                MLIRContext *context) {
393   results.add<BufferCast, TensorLoadToMemRef>(context);
394 }
395 
396 //===----------------------------------------------------------------------===//
397 // CastOp
398 //===----------------------------------------------------------------------===//
399 
400 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
402 /// and implement canonicalization patterns for ops in different dialects that
403 /// may consume the results of memref.cast operations. Such foldable memref.cast
404 /// operations are typically inserted as `view` and `subview` ops are
405 /// canonicalized, to preserve the type compatibility of their uses.
406 ///
407 /// Returns true when all conditions are met:
408 /// 1. source and result are ranked memrefs with strided semantics and same
409 /// element type and rank.
410 /// 2. each of the source's size, offset or stride has more static information
411 /// than the corresponding result's size, offset or stride.
412 ///
413 /// Example 1:
414 /// ```mlir
415 ///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
416 ///   %2 = consumer %1 ... : memref<?x?xf32> ...
417 /// ```
418 ///
419 /// may fold into:
420 ///
421 /// ```mlir
422 ///   %2 = consumer %0 ... : memref<8x16xf32> ...
423 /// ```
424 ///
425 /// Example 2:
426 /// ```
427 ///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
428 ///          to memref<?x?xf32>
429 ///   consumer %1 : memref<?x?xf32> ...
430 /// ```
431 ///
432 /// may fold into:
433 ///
434 /// ```
435 ///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
436 /// ```
437 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
438   MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
439   MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
440 
441   // Requires ranked MemRefType.
442   if (!sourceType || !resultType)
443     return false;
444 
445   // Requires same elemental type.
446   if (sourceType.getElementType() != resultType.getElementType())
447     return false;
448 
449   // Requires same rank.
450   if (sourceType.getRank() != resultType.getRank())
451     return false;
452 
453   // Only fold casts between strided memref forms.
454   int64_t sourceOffset, resultOffset;
455   SmallVector<int64_t, 4> sourceStrides, resultStrides;
456   if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
457       failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
458     return false;
459 
460   // If cast is towards more static sizes along any dimension, don't fold.
461   for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
462     auto ss = std::get<0>(it), st = std::get<1>(it);
463     if (ss != st)
464       if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
465         return false;
466   }
467 
468   // If cast is towards more static offset along any dimension, don't fold.
469   if (sourceOffset != resultOffset)
470     if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
471         !MemRefType::isDynamicStrideOrOffset(resultOffset))
472       return false;
473 
474   // If cast is towards more static strides along any dimension, don't fold.
475   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
476     auto ss = std::get<0>(it), st = std::get<1>(it);
477     if (ss != st)
478       if (MemRefType::isDynamicStrideOrOffset(ss) &&
479           !MemRefType::isDynamicStrideOrOffset(st))
480         return false;
481   }
482 
483   return true;
484 }
485 
486 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
487   if (inputs.size() != 1 || outputs.size() != 1)
488     return false;
489   Type a = inputs.front(), b = outputs.front();
490   auto aT = a.dyn_cast<MemRefType>();
491   auto bT = b.dyn_cast<MemRefType>();
492 
493   auto uaT = a.dyn_cast<UnrankedMemRefType>();
494   auto ubT = b.dyn_cast<UnrankedMemRefType>();
495 
496   if (aT && bT) {
497     if (aT.getElementType() != bT.getElementType())
498       return false;
499     if (aT.getLayout() != bT.getLayout()) {
500       int64_t aOffset, bOffset;
501       SmallVector<int64_t, 4> aStrides, bStrides;
502       if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
503           failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
504           aStrides.size() != bStrides.size())
505         return false;
506 
507       // Strides along a dimension/offset are compatible if the value in the
508       // source memref is static and the value in the target memref is the
509       // same. They are also compatible if either one is dynamic (see
      // description of memref.cast for details).
511       auto checkCompatible = [](int64_t a, int64_t b) {
512         return (a == MemRefType::getDynamicStrideOrOffset() ||
513                 b == MemRefType::getDynamicStrideOrOffset() || a == b);
514       };
515       if (!checkCompatible(aOffset, bOffset))
516         return false;
517       for (auto aStride : enumerate(aStrides))
518         if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
519           return false;
520     }
521     if (aT.getMemorySpace() != bT.getMemorySpace())
522       return false;
523 
524     // They must have the same rank, and any specified dimensions must match.
525     if (aT.getRank() != bT.getRank())
526       return false;
527 
528     for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
529       int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
530       if (aDim != -1 && bDim != -1 && aDim != bDim)
531         return false;
532     }
533     return true;
534   } else {
535     if (!aT && !uaT)
536       return false;
537     if (!bT && !ubT)
538       return false;
539     // Unranked to unranked casting is unsupported
540     if (uaT && ubT)
541       return false;
542 
543     auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
544     auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
545     if (aEltType != bEltType)
546       return false;
547 
548     auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
549     auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
550     if (aMemSpace != bMemSpace)
551       return false;
552 
553     return true;
554   }
555 
556   return false;
557 }
558 
559 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
560   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
561 }
562 
563 //===----------------------------------------------------------------------===//
564 // CloneOp
565 //===----------------------------------------------------------------------===//
566 
567 void CloneOp::getEffects(
568     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
569         &effects) {
570   effects.emplace_back(MemoryEffects::Read::get(), input(),
571                        SideEffects::DefaultResource::get());
572   effects.emplace_back(MemoryEffects::Write::get(), output(),
573                        SideEffects::DefaultResource::get());
574   effects.emplace_back(MemoryEffects::Allocate::get(), output(),
575                        SideEffects::DefaultResource::get());
576 }
577 
578 namespace {
579 /// Merge the clone and its source (by converting the clone to a cast) when
580 /// possible.
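///
/// For example (an illustrative sketch, assuming the source buffer is
/// deallocated in another block so that the clone's dealloc is redundant):
///
///   %1 = memref.clone %0 : memref<?xf32> to memref<?xf32>
///   consumer %1 ...
///   memref.dealloc %1 : memref<?xf32>
///
/// becomes:
///
///   %1 = memref.cast %0 : memref<?xf32> to memref<?xf32>
///   consumer %1 ...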
581 struct SimplifyClones : public OpRewritePattern<CloneOp> {
582   using OpRewritePattern<CloneOp>::OpRewritePattern;
583 
584   LogicalResult matchAndRewrite(CloneOp cloneOp,
585                                 PatternRewriter &rewriter) const override {
586     if (cloneOp.use_empty()) {
587       rewriter.eraseOp(cloneOp);
588       return success();
589     }
590 
591     Value source = cloneOp.input();
592 
593     // This only finds dealloc operations for the immediate value. It should
594     // also consider aliases. That would also make the safety check below
595     // redundant.
596     llvm::Optional<Operation *> maybeCloneDeallocOp =
597         findDealloc(cloneOp.output());
    // Skip if either of them has more than one deallocation.
599     if (!maybeCloneDeallocOp.hasValue())
600       return failure();
601     llvm::Optional<Operation *> maybeSourceDeallocOp = findDealloc(source);
602     if (!maybeSourceDeallocOp.hasValue())
603       return failure();
604     Operation *cloneDeallocOp = *maybeCloneDeallocOp;
605     Operation *sourceDeallocOp = *maybeSourceDeallocOp;
606 
607     // If both are deallocated in the same block, their in-block lifetimes
608     // might not fully overlap, so we cannot decide which one to drop.
609     if (cloneDeallocOp && sourceDeallocOp &&
610         cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
611       return failure();
612 
613     Block *currentBlock = cloneOp->getBlock();
614     Operation *redundantDealloc = nullptr;
615     if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
616       redundantDealloc = cloneDeallocOp;
617     } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
618       redundantDealloc = sourceDeallocOp;
619     }
620 
621     if (!redundantDealloc)
622       return failure();
623 
    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of the
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
629     for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
630          pos = pos->getNextNode()) {
631       auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
632       if (!effectInterface)
633         continue;
634       if (effectInterface.hasEffect<MemoryEffects::Free>())
635         return failure();
636     }
637 
638     rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
639                                                 source);
640     rewriter.eraseOp(redundantDealloc);
641     return success();
642   }
643 };
644 
645 } // end anonymous namespace.
646 
647 void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
648                                           MLIRContext *context) {
649   results.insert<SimplifyClones>(context);
650 }
651 
652 OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
653   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
654 }
655 
656 Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
657   return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
658       .getOperation();
659 }
660 
661 Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
662   return builder.create<memref::CloneOp>(alloc.getLoc(), alloc).getResult();
663 }
664 
665 //===----------------------------------------------------------------------===//
666 // DeallocOp
667 //===----------------------------------------------------------------------===//
668 
669 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
670                               SmallVectorImpl<OpFoldResult> &results) {
671   /// dealloc(memrefcast) -> dealloc
672   return foldMemRefCast(*this);
673 }
674 
675 //===----------------------------------------------------------------------===//
676 // DimOp
677 //===----------------------------------------------------------------------===//
678 
679 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
680                   int64_t index) {
681   auto loc = result.location;
682   Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
683   build(builder, result, source, indexValue);
684 }
685 
686 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
687                   Value index) {
688   auto indexTy = builder.getIndexType();
689   build(builder, result, indexTy, source, index);
690 }
691 
692 Optional<int64_t> DimOp::getConstantIndex() {
693   if (auto constantOp = index().getDefiningOp<arith::ConstantOp>())
694     return constantOp.getValue().cast<IntegerAttr>().getInt();
695   return {};
696 }
697 
698 static LogicalResult verify(DimOp op) {
699   // Assume unknown index to be in range.
700   Optional<int64_t> index = op.getConstantIndex();
701   if (!index.hasValue())
702     return success();
703 
  // Check that the constant index is not known to be out of range.
705   auto type = op.source().getType();
706   if (auto memrefType = type.dyn_cast<MemRefType>()) {
707     if (index.getValue() >= memrefType.getRank())
708       return op.emitOpError("index is out of range");
709   } else if (type.isa<UnrankedMemRefType>()) {
710     // Assume index to be in range.
711   } else {
712     llvm_unreachable("expected operand with memref type");
713   }
714   return success();
715 }
716 
/// Return a map with key being elements in `vals` and data being number of
/// occurrences of it. Use std::map, since the `vals` here are strides and the
/// dynamic stride value is the same as the tombstone value for
/// `DenseMap<int64_t>`.
721 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
722   std::map<int64_t, unsigned> numOccurences;
723   for (auto val : vals)
724     numOccurences[val]++;
725   return numOccurences;
726 }
727 
/// Given the original (non-rank-reduced) subview result type and the
/// rank-reduced result type, computes the dropped dimensions. This accounts for
/// cases where there are multiple unit-dims, but only a subset of those are
/// dropped. For MemRefTypes these can be disambiguated using the strides. If a
/// dimension is dropped the stride must be dropped too.
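///
/// For example (illustrative): reducing a subview result with static sizes
/// [1, 4, 1] from a contiguous memref<1x4x1xf32> (strides [4, 1, 1]) to
/// memref<4x1xf32> (strides [1, 1]) reports dim 0 as dropped: its stride 4 no
/// longer occurs in the reduced type, whereas stride 1 still occurs twice, so
/// the trailing unit dim is kept.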
733 static llvm::Optional<llvm::SmallDenseSet<unsigned>>
734 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
735                                ArrayAttr staticSizes) {
736   llvm::SmallDenseSet<unsigned> unusedDims;
737   if (originalType.getRank() == reducedType.getRank())
738     return unusedDims;
739 
740   for (auto dim : llvm::enumerate(staticSizes))
741     if (dim.value().cast<IntegerAttr>().getInt() == 1)
742       unusedDims.insert(dim.index());
743   SmallVector<int64_t> originalStrides, candidateStrides;
744   int64_t originalOffset, candidateOffset;
745   if (failed(
746           getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
747       failed(
748           getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
749     return llvm::None;
750 
  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. This is particularly important when more than one of the dims
  // is 1. Track the number of occurrences of the strides in the original type
  // and the candidate type. For each unused dim, that stride should not be
  // present in the candidate type. Note that there could be multiple dimensions
  // that have the same size. We don't need to figure out exactly which dim
  // corresponds to which stride; we just need to verify that the number of
  // repetitions of a stride in the original + the number of unused dims with
  // that stride == the number of repetitions of that stride in the candidate.
760   std::map<int64_t, unsigned> currUnaccountedStrides =
761       getNumOccurences(originalStrides);
762   std::map<int64_t, unsigned> candidateStridesNumOccurences =
763       getNumOccurences(candidateStrides);
764   llvm::SmallDenseSet<unsigned> prunedUnusedDims;
765   for (unsigned dim : unusedDims) {
766     int64_t originalStride = originalStrides[dim];
767     if (currUnaccountedStrides[originalStride] >
768         candidateStridesNumOccurences[originalStride]) {
769       // This dim can be treated as dropped.
770       currUnaccountedStrides[originalStride]--;
771       continue;
772     }
773     if (currUnaccountedStrides[originalStride] ==
774         candidateStridesNumOccurences[originalStride]) {
775       // The stride for this is not dropped. Keep as is.
776       prunedUnusedDims.insert(dim);
777       continue;
778     }
779     if (currUnaccountedStrides[originalStride] <
780         candidateStridesNumOccurences[originalStride]) {
      // This should never happen. Can't have a stride in the reduced-rank type
      // that wasn't in the original one.
783       return llvm::None;
784     }
785   }
786 
787   for (auto prunedDim : prunedUnusedDims)
788     unusedDims.erase(prunedDim);
789   if (unusedDims.size() + reducedType.getRank() != originalType.getRank())
790     return llvm::None;
791   return unusedDims;
792 }
793 
794 llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() {
795   MemRefType sourceType = getSourceType();
796   MemRefType resultType = getType();
797   llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
798       computeMemRefRankReductionMask(sourceType, resultType, static_sizes());
799   assert(unusedDims && "unable to find unused dims of subview");
800   return *unusedDims;
801 }
802 
803 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
804   // All forms of folding require a known index.
805   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
806   if (!index)
807     return {};
808 
809   // Folding for unranked types (UnrankedMemRefType) is not supported.
810   auto memrefType = source().getType().dyn_cast<MemRefType>();
811   if (!memrefType)
812     return {};
813 
814   // Fold if the shape extent along the given index is known.
815   if (!memrefType.isDynamicDim(index.getInt())) {
816     Builder builder(getContext());
817     return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
818   }
819 
820   // The size at the given index is now known to be a dynamic size.
821   unsigned unsignedIndex = index.getValue().getZExtValue();
822 
823   // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
824   Operation *definingOp = source().getDefiningOp();
825 
826   if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
827     return *(alloc.getDynamicSizes().begin() +
828              memrefType.getDynamicDimIndex(unsignedIndex));
829 
830   if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
831     return *(alloca.getDynamicSizes().begin() +
832              memrefType.getDynamicDimIndex(unsignedIndex));
833 
834   if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
835     return *(view.getDynamicSizes().begin() +
836              memrefType.getDynamicDimIndex(unsignedIndex));
837 
838   if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
839     llvm::SmallDenseSet<unsigned> unusedDims = subview.getDroppedDims();
840     unsigned resultIndex = 0;
841     unsigned sourceRank = subview.getSourceType().getRank();
842     unsigned sourceIndex = 0;
843     for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
844       if (unusedDims.count(i))
845         continue;
846       if (resultIndex == unsignedIndex) {
847         sourceIndex = i;
848         break;
849       }
850       resultIndex++;
851     }
852     assert(subview.isDynamicSize(sourceIndex) &&
853            "expected dynamic subview size");
854     return subview.getDynamicSize(sourceIndex);
855   }
856 
857   if (auto sizeInterface =
858           dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
859     assert(sizeInterface.isDynamicSize(unsignedIndex) &&
860            "Expected dynamic subview size");
861     return sizeInterface.getDynamicSize(unsignedIndex);
862   }
863 
864   // dim(memrefcast) -> dim
865   if (succeeded(foldMemRefCast(*this)))
866     return getResult();
867 
868   return {};
869 }
870 
871 namespace {
/// Fold a dim of a memref.reshape operation into a load from the reshape's
/// shape operand.
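///
/// For example (illustrative):
///
///   %r = memref.reshape %src(%shape)
///       : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
///   %d = memref.dim %r, %c1 : memref<?x?xf32>
///
/// is rewritten to:
///
///   %d = memref.load %shape[%c1] : memref<2xindex>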
874 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
875   using OpRewritePattern<DimOp>::OpRewritePattern;
876 
877   LogicalResult matchAndRewrite(DimOp dim,
878                                 PatternRewriter &rewriter) const override {
879     auto reshape = dim.source().getDefiningOp<ReshapeOp>();
880 
881     if (!reshape)
882       return failure();
883 
884     // Place the load directly after the reshape to ensure that the shape memref
885     // was not mutated.
886     rewriter.setInsertionPointAfter(reshape);
887     Location loc = dim.getLoc();
888     Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
889     if (load.getType() != dim.getType())
890       load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
891     rewriter.replaceOp(dim, load);
892     return success();
893   }
894 };
895 
/// Fold a dim of a memref.buffer_cast into a tensor.dim of the source tensor.
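///
/// For example (illustrative):
///
///   %m = memref.buffer_cast %t : memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>
///
/// is rewritten to:
///
///   %d = tensor.dim %t, %c0 : tensor<?xf32>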
897 struct DimOfCastOp : public OpRewritePattern<DimOp> {
898   using OpRewritePattern<DimOp>::OpRewritePattern;
899 
900   LogicalResult matchAndRewrite(DimOp dimOp,
901                                 PatternRewriter &rewriter) const override {
902     auto castOp = dimOp.source().getDefiningOp<BufferCastOp>();
903     if (!castOp)
904       return failure();
905     Value newSource = castOp.getOperand();
906     rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
907     return success();
908   }
909 };
910 } // end anonymous namespace.
911 
912 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
913                                         MLIRContext *context) {
914   results.add<DimOfMemRefReshape, DimOfCastOp>(context);
915 }
916 
917 // ---------------------------------------------------------------------------
918 // DmaStartOp
919 // ---------------------------------------------------------------------------
920 
921 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
922                        Value srcMemRef, ValueRange srcIndices, Value destMemRef,
923                        ValueRange destIndices, Value numElements,
924                        Value tagMemRef, ValueRange tagIndices, Value stride,
925                        Value elementsPerStride) {
926   result.addOperands(srcMemRef);
927   result.addOperands(srcIndices);
928   result.addOperands(destMemRef);
929   result.addOperands(destIndices);
930   result.addOperands({numElements, tagMemRef});
931   result.addOperands(tagIndices);
932   if (stride)
933     result.addOperands({stride, elementsPerStride});
934 }
935 
936 static void print(OpAsmPrinter &p, DmaStartOp op) {
937   p << " " << op.getSrcMemRef() << '[' << op.getSrcIndices() << "], "
938     << op.getDstMemRef() << '[' << op.getDstIndices() << "], "
939     << op.getNumElements() << ", " << op.getTagMemRef() << '['
940     << op.getTagIndices() << ']';
941   if (op.isStrided())
942     p << ", " << op.getStride() << ", " << op.getNumElementsPerStride();
943 
944   p.printOptionalAttrDict(op->getAttrs());
945   p << " : " << op.getSrcMemRef().getType() << ", "
946     << op.getDstMemRef().getType() << ", " << op.getTagMemRef().getType();
947 }
948 
949 // Parse DmaStartOp.
950 // Ex:
951 //   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
952 //                       %tag[%index], %stride, %num_elt_per_stride :
953 //                     : memref<3076 x f32, 0>,
954 //                       memref<1024 x f32, 2>,
955 //                       memref<1 x i32>
956 //
957 static ParseResult parseDmaStartOp(OpAsmParser &parser,
958                                    OperationState &result) {
959   OpAsmParser::OperandType srcMemRefInfo;
960   SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
961   OpAsmParser::OperandType dstMemRefInfo;
962   SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
963   OpAsmParser::OperandType numElementsInfo;
964   OpAsmParser::OperandType tagMemrefInfo;
965   SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
966   SmallVector<OpAsmParser::OperandType, 2> strideInfo;
967 
968   SmallVector<Type, 3> types;
969   auto indexType = parser.getBuilder().getIndexType();
970 
971   // Parse and resolve the following list of operands:
972   // *) source memref followed by its indices (in square brackets).
973   // *) destination memref followed by its indices (in square brackets).
  // *) number of elements to transfer.
975   if (parser.parseOperand(srcMemRefInfo) ||
976       parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
977       parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
978       parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
979       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
980       parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
981       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
982     return failure();
983 
984   // Parse optional stride and elements per stride.
985   if (parser.parseTrailingOperandList(strideInfo))
986     return failure();
987 
988   bool isStrided = strideInfo.size() == 2;
989   if (!strideInfo.empty() && !isStrided) {
990     return parser.emitError(parser.getNameLoc(),
991                             "expected two stride related operands");
992   }
993 
994   if (parser.parseColonTypeList(types))
995     return failure();
996   if (types.size() != 3)
997     return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
998 
999   if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1000       parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1001       parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1002       parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1003       // size should be an index.
1004       parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1005       parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1006       // tag indices should be index.
1007       parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1008     return failure();
1009 
1010   if (isStrided) {
1011     if (parser.resolveOperands(strideInfo, indexType, result.operands))
1012       return failure();
1013   }
1014 
1015   return success();
1016 }
1017 
1018 static LogicalResult verify(DmaStartOp op) {
1019   unsigned numOperands = op.getNumOperands();
1020 
1021   // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1022   // the number of elements.
1023   if (numOperands < 4)
1024     return op.emitOpError("expected at least 4 operands");
1025 
1026   // Check types of operands. The order of these calls is important: the later
1027   // calls rely on some type properties to compute the operand position.
1028   // 1. Source memref.
1029   if (!op.getSrcMemRef().getType().isa<MemRefType>())
1030     return op.emitOpError("expected source to be of memref type");
1031   if (numOperands < op.getSrcMemRefRank() + 4)
1032     return op.emitOpError()
1033            << "expected at least " << op.getSrcMemRefRank() + 4 << " operands";
1034   if (!op.getSrcIndices().empty() &&
1035       !llvm::all_of(op.getSrcIndices().getTypes(),
1036                     [](Type t) { return t.isIndex(); }))
1037     return op.emitOpError("expected source indices to be of index type");
1038 
1039   // 2. Destination memref.
1040   if (!op.getDstMemRef().getType().isa<MemRefType>())
1041     return op.emitOpError("expected destination to be of memref type");
1042   unsigned numExpectedOperands =
1043       op.getSrcMemRefRank() + op.getDstMemRefRank() + 4;
1044   if (numOperands < numExpectedOperands)
1045     return op.emitOpError()
1046            << "expected at least " << numExpectedOperands << " operands";
1047   if (!op.getDstIndices().empty() &&
1048       !llvm::all_of(op.getDstIndices().getTypes(),
1049                     [](Type t) { return t.isIndex(); }))
1050     return op.emitOpError("expected destination indices to be of index type");
1051 
1052   // 3. Number of elements.
1053   if (!op.getNumElements().getType().isIndex())
1054     return op.emitOpError("expected num elements to be of index type");
1055 
1056   // 4. Tag memref.
1057   if (!op.getTagMemRef().getType().isa<MemRefType>())
1058     return op.emitOpError("expected tag to be of memref type");
1059   numExpectedOperands += op.getTagMemRefRank();
1060   if (numOperands < numExpectedOperands)
1061     return op.emitOpError()
1062            << "expected at least " << numExpectedOperands << " operands";
1063   if (!op.getTagIndices().empty() &&
1064       !llvm::all_of(op.getTagIndices().getTypes(),
1065                     [](Type t) { return t.isIndex(); }))
1066     return op.emitOpError("expected tag indices to be of index type");
1067 
1068   // Optional stride-related operands must be either both present or both
1069   // absent.
1070   if (numOperands != numExpectedOperands &&
1071       numOperands != numExpectedOperands + 2)
1072     return op.emitOpError("incorrect number of operands");
1073 
1074   // 5. Strides.
1075   if (op.isStrided()) {
1076     if (!op.getStride().getType().isIndex() ||
1077         !op.getNumElementsPerStride().getType().isIndex())
1078       return op.emitOpError(
1079           "expected stride and num elements per stride to be of type index");
1080   }
1081 
1082   return success();
1083 }
1084 
1085 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1086                                SmallVectorImpl<OpFoldResult> &results) {
1087   /// dma_start(memrefcast) -> dma_start
1088   return foldMemRefCast(*this);
1089 }
1090 
1091 // ---------------------------------------------------------------------------
1092 // DmaWaitOp
1093 // ---------------------------------------------------------------------------
1094 
1095 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1096                               SmallVectorImpl<OpFoldResult> &results) {
1097   /// dma_wait(memrefcast) -> dma_wait
1098   return foldMemRefCast(*this);
1099 }
1100 
1101 static LogicalResult verify(DmaWaitOp op) {
1102   // Check that the number of tag indices matches the tagMemRef rank.
1103   unsigned numTagIndices = op.tagIndices().size();
1104   unsigned tagMemRefRank = op.getTagMemRefRank();
1105   if (numTagIndices != tagMemRefRank)
1106     return op.emitOpError() << "expected tagIndices to have the same number of "
1107                                "elements as the tagMemRef rank, expected "
1108                             << tagMemRefRank << ", but got " << numTagIndices;
1109   return success();
1110 }
1111 
1112 //===----------------------------------------------------------------------===//
1113 // GlobalOp
1114 //===----------------------------------------------------------------------===//
1115 
1116 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1117                                                    TypeAttr type,
1118                                                    Attribute initialValue) {
1119   p << type;
1120   if (!op.isExternal()) {
1121     p << " = ";
1122     if (op.isUninitialized())
1123       p << "uninitialized";
1124     else
1125       p.printAttributeWithoutType(initialValue);
1126   }
1127 }
1128 
1129 static ParseResult
1130 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1131                                        Attribute &initialValue) {
1132   Type type;
1133   if (parser.parseType(type))
1134     return failure();
1135 
1136   auto memrefType = type.dyn_cast<MemRefType>();
1137   if (!memrefType || !memrefType.hasStaticShape())
1138     return parser.emitError(parser.getNameLoc())
1139            << "type should be static shaped memref, but got " << type;
1140   typeAttr = TypeAttr::get(type);
1141 
1142   if (parser.parseOptionalEqual())
1143     return success();
1144 
1145   if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1146     initialValue = UnitAttr::get(parser.getContext());
1147     return success();
1148   }
1149 
1150   Type tensorType = getTensorTypeFromMemRefType(memrefType);
1151   if (parser.parseAttribute(initialValue, tensorType))
1152     return failure();
1153   if (!initialValue.isa<ElementsAttr>())
1154     return parser.emitError(parser.getNameLoc())
1155            << "initial value should be a unit or elements attribute";
1156   return success();
1157 }
1158 
1159 static LogicalResult verify(GlobalOp op) {
1160   auto memrefType = op.type().dyn_cast<MemRefType>();
1161   if (!memrefType || !memrefType.hasStaticShape())
1162     return op.emitOpError("type should be static shaped memref, but got ")
1163            << op.type();
1164 
1165   // Verify that the initial value, if present, is either a unit attribute or
1166   // an elements attribute.
1167   if (op.initial_value().hasValue()) {
1168     Attribute initValue = op.initial_value().getValue();
1169     if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
1170       return op.emitOpError("initial value should be a unit or elements "
1171                             "attribute, but got ")
1172              << initValue;
1173 
1174     // Check that the type of the initial value is compatible with the type of
1175     // the global variable.
1176     if (initValue.isa<ElementsAttr>()) {
1177       Type initType = initValue.getType();
1178       Type tensorType = getTensorTypeFromMemRefType(memrefType);
1179       if (initType != tensorType)
1180         return op.emitOpError("initial value expected to be of type ")
1181                << tensorType << ", but was of type " << initType;
1182     }
1183   }
1184 
1185   if (Optional<uint64_t> alignAttr = op.alignment()) {
1186     uint64_t alignment = alignAttr.getValue();
1187 
1188     if (!llvm::isPowerOf2_64(alignment))
1189       return op->emitError() << "alignment attribute value " << alignment
1190                              << " is not a power of 2";
1191   }
1192 
1193   // TODO: verify visibility for declarations.
1194   return success();
1195 }
1196 
1197 //===----------------------------------------------------------------------===//
1198 // GetGlobalOp
1199 //===----------------------------------------------------------------------===//
1200 
1201 LogicalResult
1202 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is the same as the type of the referenced
  // memref.global op.
1205   auto global =
1206       symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
1207   if (!global)
1208     return emitOpError("'")
1209            << name() << "' does not reference a valid global memref";
1210 
1211   Type resultType = result().getType();
1212   if (global.type() != resultType)
1213     return emitOpError("result type ")
1214            << resultType << " does not match type " << global.type()
1215            << " of the global memref @" << name();
1216   return success();
1217 }
1218 
1219 //===----------------------------------------------------------------------===//
1220 // LoadOp
1221 //===----------------------------------------------------------------------===//
1222 
1223 static LogicalResult verify(LoadOp op) {
1224   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1225     return op.emitOpError("incorrect number of indices for load");
1226   return success();
1227 }
1228 
1229 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1230   /// load(memrefcast) -> load
1231   if (succeeded(foldMemRefCast(*this)))
1232     return getResult();
1233   return OpFoldResult();
1234 }
1235 
1236 namespace {
/// Fold a load on a buffer_cast operation into a tensor.extract on the
/// corresponding tensor.
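///
/// For example (illustrative):
///
///   %m = memref.buffer_cast %t : memref<4xf32>
///   %v = memref.load %m[%i] : memref<4xf32>
///
/// is rewritten to:
///
///   %v = tensor.extract %t[%i] : tensor<4xf32>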
1239 struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
1240   using OpRewritePattern<LoadOp>::OpRewritePattern;
1241 
1242   LogicalResult matchAndRewrite(LoadOp load,
1243                                 PatternRewriter &rewriter) const override {
1244     auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
1245     if (!buffercast)
1246       return failure();
1247 
1248     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
1249                                                    load.indices());
1250     return success();
1251   }
1252 };
1253 } // end anonymous namespace.
1254 
1255 void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
1256                                          MLIRContext *context) {
1257   results.add<LoadOfBufferCast>(context);
1258 }
1259 
1260 //===----------------------------------------------------------------------===//
1261 // PrefetchOp
1262 //===----------------------------------------------------------------------===//
1263 
1264 static void print(OpAsmPrinter &p, PrefetchOp op) {
1265   p << " " << op.memref() << '[';
1266   p.printOperands(op.indices());
1267   p << ']' << ", " << (op.isWrite() ? "write" : "read");
1268   p << ", locality<" << op.localityHint();
1269   p << ">, " << (op.isDataCache() ? "data" : "instr");
1270   p.printOptionalAttrDict(
1271       op->getAttrs(),
1272       /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1273   p << " : " << op.getMemRefType();
1274 }
1275 
1276 static ParseResult parsePrefetchOp(OpAsmParser &parser,
1277                                    OperationState &result) {
1278   OpAsmParser::OperandType memrefInfo;
1279   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
1280   IntegerAttr localityHint;
1281   MemRefType type;
1282   StringRef readOrWrite, cacheType;
1283 
1284   auto indexTy = parser.getBuilder().getIndexType();
1285   auto i32Type = parser.getBuilder().getIntegerType(32);
1286   if (parser.parseOperand(memrefInfo) ||
1287       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1288       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1289       parser.parseComma() || parser.parseKeyword("locality") ||
1290       parser.parseLess() ||
1291       parser.parseAttribute(localityHint, i32Type, "localityHint",
1292                             result.attributes) ||
1293       parser.parseGreater() || parser.parseComma() ||
1294       parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1295       parser.resolveOperand(memrefInfo, type, result.operands) ||
1296       parser.resolveOperands(indexInfo, indexTy, result.operands))
1297     return failure();
1298 
1299   if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1300     return parser.emitError(parser.getNameLoc(),
1301                             "rw specifier has to be 'read' or 'write'");
1302   result.addAttribute(
1303       PrefetchOp::getIsWriteAttrName(),
1304       parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1305 
1306   if (!cacheType.equals("data") && !cacheType.equals("instr"))
1307     return parser.emitError(parser.getNameLoc(),
1308                             "cache type has to be 'data' or 'instr'");
1309 
1310   result.addAttribute(
1311       PrefetchOp::getIsDataCacheAttrName(),
1312       parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1313 
1314   return success();
1315 }
1316 
1317 static LogicalResult verify(PrefetchOp op) {
1318   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1319     return op.emitOpError("too few indices");
1320 
1321   return success();
1322 }
1323 
1324 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
1325                                SmallVectorImpl<OpFoldResult> &results) {
1326   // prefetch(memrefcast) -> prefetch
1327   return foldMemRefCast(*this);
1328 }
1329 
1330 //===----------------------------------------------------------------------===//
1331 // ReinterpretCastOp
1332 //===----------------------------------------------------------------------===//
1333 
/// Build a ReinterpretCastOp with mixed static and dynamic entries:
/// `staticOffsets`, `staticSizes` and `staticStrides` are automatically filled
/// with sentinel values that encode dynamic entries.
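/// For example (illustrative): with `sizes = {IntegerAttr(10), Value %d}`,
/// `dispatchIndexOpFoldResults` produces `staticSizes = [10, kDynamicSize]`
/// and `dynamicSizes = [%d]`, matching what the generated builder expects.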
1337 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1338                               MemRefType resultType, Value source,
1339                               OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1340                               ArrayRef<OpFoldResult> strides,
1341                               ArrayRef<NamedAttribute> attrs) {
1342   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1343   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1344   dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
1345                              ShapedType::kDynamicStrideOrOffset);
1346   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1347                              ShapedType::kDynamicSize);
1348   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1349                              ShapedType::kDynamicStrideOrOffset);
1350   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1351         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1352         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1353   result.addAttributes(attrs);
1354 }
1355 
1356 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1357                               MemRefType resultType, Value source,
1358                               int64_t offset, ArrayRef<int64_t> sizes,
1359                               ArrayRef<int64_t> strides,
1360                               ArrayRef<NamedAttribute> attrs) {
1361   SmallVector<OpFoldResult> sizeValues =
1362       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1363         return b.getI64IntegerAttr(v);
1364       }));
1365   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1366       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1367         return b.getI64IntegerAttr(v);
1368       }));
1369   build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1370         strideValues, attrs);
1371 }
1372 
1373 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1374                               MemRefType resultType, Value source, Value offset,
1375                               ValueRange sizes, ValueRange strides,
1376                               ArrayRef<NamedAttribute> attrs) {
1377   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1378       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1379   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1380       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1381   build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1382 }
1383 
1384 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1385 // completed automatically, like we have for subview and extract_slice.
1386 static LogicalResult verify(ReinterpretCastOp op) {
1387   // The source and result memrefs should be in the same memory space.
1388   auto srcType = op.source().getType().cast<BaseMemRefType>();
1389   auto resultType = op.getType().cast<MemRefType>();
1390   if (srcType.getMemorySpace() != resultType.getMemorySpace())
1391     return op.emitError("different memory spaces specified for source type ")
1392            << srcType << " and result memref type " << resultType;
1393   if (srcType.getElementType() != resultType.getElementType())
1394     return op.emitError("different element types specified for source type ")
1395            << srcType << " and result memref type " << resultType;
1396 
1397   // Match sizes in result memref type and in static_sizes attribute.
1398   for (auto &en :
1399        llvm::enumerate(llvm::zip(resultType.getShape(),
1400                                  extractFromI64ArrayAttr(op.static_sizes())))) {
1401     int64_t resultSize = std::get<0>(en.value());
1402     int64_t expectedSize = std::get<1>(en.value());
1403     if (resultSize != expectedSize)
1404       return op.emitError("expected result type with size = ")
1405              << expectedSize << " instead of " << resultSize
1406              << " in dim = " << en.index();
1407   }
1408 
  // Match offset and strides in the static_offsets and static_strides
  // attributes if the result memref type has a non-identity layout.
1411   if (!resultType.getLayout().isIdentity()) {
1412     int64_t resultOffset;
1413     SmallVector<int64_t, 4> resultStrides;
1414     if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1415       return failure();
1416 
1417     // Match offset in result memref type and in static_offsets attribute.
1418     int64_t expectedOffset =
1419         extractFromI64ArrayAttr(op.static_offsets()).front();
    if (resultOffset != expectedOffset)
      return op.emitError("expected result type with offset = ")
             << expectedOffset << " instead of " << resultOffset;
1423 
1424     // Match strides in result memref type and in static_strides attribute.
1425     for (auto &en : llvm::enumerate(llvm::zip(
1426              resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
1427       int64_t resultStride = std::get<0>(en.value());
1428       int64_t expectedStride = std::get<1>(en.value());
1429       if (resultStride != expectedStride)
1430         return op.emitError("expected result type with stride = ")
1431                << expectedStride << " instead of " << resultStride
1432                << " in dim = " << en.index();
1433     }
1434   }
1435   return success();
1436 }
1437 
1438 //===----------------------------------------------------------------------===//
1439 // Reassociative reshape ops
1440 //===----------------------------------------------------------------------===//
1441 
1442 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1443   return getSymbolLessAffineMaps(getReassociationExprs());
1444 }
1445 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1446   return convertReassociationIndicesToExprs(getContext(),
1447                                             getReassociationIndices());
1448 }
1449 
1450 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1451   return getSymbolLessAffineMaps(getReassociationExprs());
1452 }
1453 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1454   return convertReassociationIndicesToExprs(getContext(),
1455                                             getReassociationIndices());
1456 }
1457 
1458 static void print(OpAsmPrinter &p, ExpandShapeOp op) {
1459   ::mlir::printReshapeOp<ExpandShapeOp>(p, op);
1460 }
1461 
1462 static void print(OpAsmPrinter &p, CollapseShapeOp op) {
1463   ::mlir::printReshapeOp<CollapseShapeOp>(p, op);
1464 }
1465 
1466 /// Detect whether memref dims [dim, dim + extent) can be reshaped without
1467 /// copies.
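/// For example (illustrative): with sizes = [4, 8] and strides = [8, 1] the
/// band [0, 2) is reshapable since strides[0] == strides[1] * sizes[1]
/// (8 == 1 * 8); with padded rows, e.g. strides = [16, 1], it is not
/// (16 != 1 * 8) and a copy would be needed.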
1468 static bool isReshapableDimBand(unsigned dim, unsigned extent,
1469                                 ArrayRef<int64_t> sizes,
1470                                 ArrayRef<AffineExpr> strides) {
1471   assert(sizes.size() == strides.size() && "mismatched ranks");
  // The loop bound is `idx + 1 < e` because `sizes[idx + 1]` and
  // `strides[idx + 1]` are accessed below.
  for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) {
    // Only bands of static shapes are reshapable. This is due to the fact that
    // there is no relation between dynamic sizes and dynamic strides: we do not
    // have enough information to know whether a "-1" size corresponds to the
    // proper symbol in the AffineExpr of a stride.
    if (ShapedType::isDynamic(sizes[idx + 1]))
1480       return false;
1481     // TODO: Refine this by passing the proper nDims and nSymbols so we can
1482     // simplify on the fly and catch more reshapable cases.
1483     if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
1484       return false;
1485   }
1486   return true;
1487 }
1488 
1489 /// Compute the MemRefType obtained by applying the `reassociation` (which is
1490 /// expected to be valid) to `type`.
/// If `type` is a contiguous MemRefType, this always produces a contiguous
/// MemRefType.
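/// For example (illustrative): collapsing memref<4x8xf32> with reassociation
/// [(d0, d1)] yields memref<32xf32>; with a non-contiguous source such as
/// memref<4x8xf32, offset: 0, strides: [16, 1]> the band is not reshapable,
/// so the collapsed dimension becomes dynamic.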
1493 static MemRefType
1494 computeReshapeCollapsedType(MemRefType type,
1495                             ArrayRef<AffineMap> reassociation) {
1496   auto sizes = type.getShape();
1497   AffineExpr offset;
1498   SmallVector<AffineExpr, 4> strides;
1499   auto status = getStridesAndOffset(type, strides, offset);
1500   (void)status;
1501   assert(succeeded(status) && "expected strided memref");
1502 
1503   SmallVector<int64_t, 4> newSizes;
1504   newSizes.reserve(reassociation.size());
1505   SmallVector<AffineExpr, 4> newStrides;
1506   newStrides.reserve(reassociation.size());
1507 
1508   // Use the fact that reassociation is valid to simplify the logic: only use
1509   // each map's rank.
1510   assert(isReassociationValid(reassociation) && "invalid reassociation");
1511   unsigned currentDim = 0;
1512   for (AffineMap m : reassociation) {
1513     unsigned dim = m.getNumResults();
1514     int64_t size = 1;
1515     AffineExpr stride = strides[currentDim + dim - 1];
1516     if (!isReshapableDimBand(currentDim, dim, sizes, strides)) {
1517       size = ShapedType::kDynamicSize;
1518       stride = AffineExpr();
1519     } else {
1520       for (unsigned d = 0; d < dim; ++d)
1521         size *= sizes[currentDim + d];
1522     }
1523     newSizes.push_back(size);
1524     newStrides.push_back(stride);
1525     currentDim += dim;
1526   }
1527 
1528   // Early-exit: if `type` is contiguous, the result must be contiguous.
1529   if (canonicalizeStridedLayout(type).getLayout().isIdentity())
1530     return MemRefType::Builder(type).setShape(newSizes).setLayout({});
1531 
1532   // Convert back to int64_t because we don't have enough information to create
1533   // new strided layouts from AffineExpr only. This corresponds to a case where
1534   // copies may be necessary.
1535   int64_t intOffset = ShapedType::kDynamicStrideOrOffset;
1536   if (auto o = offset.dyn_cast<AffineConstantExpr>())
1537     intOffset = o.getValue();
1538   SmallVector<int64_t, 4> intStrides;
1539   intStrides.reserve(strides.size());
1540   for (auto stride : newStrides) {
1541     if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>())
1542       intStrides.push_back(cst.getValue());
1543     else
1544       intStrides.push_back(ShapedType::kDynamicStrideOrOffset);
1545   }
1546   auto layout =
1547       makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
1548   return canonicalizeStridedLayout(
1549       MemRefType::Builder(type).setShape(newSizes).setLayout(
1550           AffineMapAttr::get(layout)));
1551 }
1552 
1553 void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1554                           ArrayRef<ReassociationIndices> reassociation,
1555                           ArrayRef<NamedAttribute> attrs) {
1556   auto memRefType = src.getType().cast<MemRefType>();
1557   auto resultType = computeReshapeCollapsedType(
1558       memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1559                       b.getContext(), reassociation)));
1560   build(b, result, resultType, src, attrs);
1561   result.addAttribute(getReassociationAttrName(),
1562                       getReassociationIndicesAttribute(b, reassociation));
1563 }
1564 
1565 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1566                             ArrayRef<ReassociationIndices> reassociation,
1567                             ArrayRef<NamedAttribute> attrs) {
1568   auto memRefType = src.getType().cast<MemRefType>();
1569   auto resultType = computeReshapeCollapsedType(
1570       memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1571                       b.getContext(), reassociation)));
1572   build(b, result, resultType, src, attrs);
1573   result.addAttribute(getReassociationAttrName(),
1574                       getReassociationIndicesAttribute(b, reassociation));
1575 }
1576 
1577 template <typename ReshapeOp,
1578           bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value>
1579 static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
1580                                      MemRefType collapsedType) {
1581   if (failed(
1582           verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
1583     return failure();
1584   auto maps = op.getReassociationMaps();
1585   MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
1586   if (collapsedType != expectedType)
1587     return op.emitOpError("expected collapsed type to be ")
1588            << expectedType << ", but got " << collapsedType;
1589   return success();
1590 }
1591 
1592 static LogicalResult verify(ExpandShapeOp op) {
1593   return verifyReshapeOp(op, op.getResultType(), op.getSrcType());
1594 }
1595 
1596 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1597                                                 MLIRContext *context) {
1598   results.add<CollapseReshapeOps<ExpandShapeOp>,
1599               CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
1600 }
1601 
1602 static LogicalResult verify(CollapseShapeOp op) {
1603   return verifyReshapeOp(op, op.getSrcType(), op.getResultType());
1604 }
1605 
1606 struct CollapseShapeOpMemRefCastFolder
1607     : public OpRewritePattern<CollapseShapeOp> {
1608 public:
1609   using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
1610 
1611   LogicalResult matchAndRewrite(CollapseShapeOp op,
1612                                 PatternRewriter &rewriter) const override {
1613     auto cast = op.getOperand().getDefiningOp<CastOp>();
1614     if (!cast)
1615       return failure();
1616 
1617     if (!CastOp::canFoldIntoConsumerOp(cast))
1618       return failure();
1619 
1620     Type newResultType = computeReshapeCollapsedType(
1621         cast.getOperand().getType().cast<MemRefType>(),
1622         op.getReassociationMaps());
1623 
1624     if (newResultType == op.getResultType()) {
1625       rewriter.updateRootInPlace(
1626           op, [&]() { op.srcMutable().assign(cast.source()); });
1627     } else {
1628       Value newOp = rewriter.create<CollapseShapeOp>(
1629           op->getLoc(), cast.source(), op.getReassociationIndices());
1630       rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
1631     }
1632     return success();
1633   }
1634 };
1635 
1636 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1637                                                   MLIRContext *context) {
1638   results.add<CollapseReshapeOps<CollapseShapeOp>,
1639               CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
1640               CollapseShapeOpMemRefCastFolder>(context);
1641 }
1642 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
1643   if (succeeded(foldMemRefCast(*this)))
1644     return getResult();
1645   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
1646 }
1647 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
1648   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
1649 }
1650 
1651 //===----------------------------------------------------------------------===//
1652 // ReshapeOp
1653 //===----------------------------------------------------------------------===//
1654 
1655 static LogicalResult verify(ReshapeOp op) {
1656   Type operandType = op.source().getType();
1657   Type resultType = op.result().getType();
1658 
1659   Type operandElementType = operandType.cast<ShapedType>().getElementType();
1660   Type resultElementType = resultType.cast<ShapedType>().getElementType();
1661   if (operandElementType != resultElementType)
1662     return op.emitOpError("element types of source and destination memref "
1663                           "types should be the same");
1664 
1665   if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
1666     if (!operandMemRefType.getLayout().isIdentity())
1667       return op.emitOpError(
1668           "source memref type should have identity affine map");
1669 
1670   int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
1671   auto resultMemRefType = resultType.dyn_cast<MemRefType>();
1672   if (resultMemRefType) {
1673     if (!resultMemRefType.getLayout().isIdentity())
1674       return op.emitOpError(
1675           "result memref type should have identity affine map");
1676     if (shapeSize == ShapedType::kDynamicSize)
1677       return op.emitOpError("cannot use shape operand with dynamic length to "
1678                             "reshape to statically-ranked memref type");
1679     if (shapeSize != resultMemRefType.getRank())
1680       return op.emitOpError(
1681           "length of shape operand differs from the result's memref rank");
1682   }
1683   return success();
1684 }
1685 
1686 //===----------------------------------------------------------------------===//
1687 // StoreOp
1688 //===----------------------------------------------------------------------===//
1689 
1690 static LogicalResult verify(StoreOp op) {
1691   if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
1692     return op.emitOpError("store index operand count not equal to memref rank");
1693 
1694   return success();
1695 }
1696 
1697 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
1698                             SmallVectorImpl<OpFoldResult> &results) {
  // store(memrefcast) -> store
1700   return foldMemRefCast(*this, getValueToStore());
1701 }
1702 
1703 //===----------------------------------------------------------------------===//
1704 // SubViewOp
1705 //===----------------------------------------------------------------------===//
1706 
1707 namespace {
1708 /// Helpers to write more idiomatic operations.
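/// The wrappers below saturate to the dynamic sentinel: if either operand is
/// ShapedType::kDynamicStrideOrOffset, the result is kDynamicStrideOrOffset,
/// e.g. Wrapper(kDynamicStrideOrOffset) * 4 == kDynamicStrideOrOffset.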
1709 namespace saturated_arith {
1710 struct Wrapper {
1711   explicit Wrapper(int64_t v) : v(v) {}
1712   operator int64_t() { return v; }
1713   int64_t v;
1714 };
1715 Wrapper operator+(Wrapper a, int64_t b) {
1716   if (ShapedType::isDynamicStrideOrOffset(a) ||
1717       ShapedType::isDynamicStrideOrOffset(b))
1718     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1719   return Wrapper(a.v + b);
1720 }
1721 Wrapper operator*(Wrapper a, int64_t b) {
1722   if (ShapedType::isDynamicStrideOrOffset(a) ||
1723       ShapedType::isDynamicStrideOrOffset(b))
1724     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1725   return Wrapper(a.v * b);
1726 }
1727 } // end namespace saturated_arith
1728 } // end namespace
1729 
1730 /// A subview result type can be fully inferred from the source type and the
1731 /// static representation of offsets, sizes and strides. Special sentinels
1732 /// encode the dynamic case.
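/// For example (illustrative): for a source memref<8x16xf32> (strides [16, 1],
/// offset 0) with offsets [2, 4], sizes [4, 4] and strides [1, 2], the result
/// offset is 0 + 2 * 16 + 4 * 1 = 36, the result strides are [16, 2], and the
/// result shape is the requested sizes [4, 4].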
1733 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1734                                 ArrayRef<int64_t> leadingStaticOffsets,
1735                                 ArrayRef<int64_t> leadingStaticSizes,
1736                                 ArrayRef<int64_t> leadingStaticStrides) {
1737   // A subview may specify only a leading subset of offset/sizes/strides in
1738   // which case we complete with offset=0, sizes from memref type and strides=1.
1739   unsigned rank = sourceMemRefType.getRank();
1740   assert(leadingStaticOffsets.size() <= rank &&
1741          "unexpected leadingStaticOffsets overflow");
1742   assert(leadingStaticSizes.size() <= rank &&
1743          "unexpected leadingStaticSizes overflow");
1744   assert(leadingStaticStrides.size() <= rank &&
1745          "unexpected leadingStaticStrides overflow");
1746   auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
1747   auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
1748   auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
1749   unsigned numTrailingOffsets = rank - staticOffsets.size();
1750   unsigned numTrailingSizes = rank - staticSizes.size();
1751   unsigned numTrailingStrides = rank - staticStrides.size();
1752   staticOffsets.append(numTrailingOffsets, 0);
1753   llvm::append_range(staticSizes,
1754                      sourceMemRefType.getShape().take_back(numTrailingSizes));
1755   staticStrides.append(numTrailingStrides, 1);
1756 
1757   // Extract source offset and strides.
1758   int64_t sourceOffset;
1759   SmallVector<int64_t, 4> sourceStrides;
1760   auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
1761   assert(succeeded(res) && "SubViewOp expected strided memref type");
1762   (void)res;
1763 
1764   // Compute target offset whose value is:
1765   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
1766   int64_t targetOffset = sourceOffset;
1767   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    using namespace saturated_arith;
    targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * sourceStride;
1771   }
1772 
1773   // Compute target stride whose value is:
1774   //   `sourceStrides_i * staticStrides_i`.
1775   SmallVector<int64_t, 4> targetStrides;
1776   targetStrides.reserve(staticOffsets.size());
1777   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
1778     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
1779     using namespace saturated_arith;
1780     targetStrides.push_back(Wrapper(sourceStride) * staticStride);
1781   }
1782 
1783   // The type is now known.
1784   return MemRefType::get(
1785       staticSizes, sourceMemRefType.getElementType(),
1786       makeStridedLinearLayoutMap(targetStrides, targetOffset,
1787                                  sourceMemRefType.getContext()),
1788       sourceMemRefType.getMemorySpace());
1789 }
1790 
1791 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1792                                 ArrayRef<OpFoldResult> leadingStaticOffsets,
1793                                 ArrayRef<OpFoldResult> leadingStaticSizes,
1794                                 ArrayRef<OpFoldResult> leadingStaticStrides) {
1795   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1796   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1797   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1798                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1799   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1800                              ShapedType::kDynamicSize);
1801   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1802                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1803   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1804                                     staticSizes, staticStrides)
1805       .cast<MemRefType>();
1806 }
1807 
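/// Infer the result type as in `inferResultType` and, if the inferred rank
/// exceeds `resultRank`, project away that many unit dimensions (illustrative
/// example: an inferred memref<1x4xf32, ...> with `resultRank` = 1 drops the
/// unit dimension, yielding a rank-1 memref<4xf32, ...>).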
Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceMemRefType,
    ArrayRef<int64_t> leadingStaticOffsets,
    ArrayRef<int64_t> leadingStaticSizes,
    ArrayRef<int64_t> leadingStaticStrides) {
  auto inferredType =
      inferResultType(sourceMemRefType, leadingStaticOffsets,
                      leadingStaticSizes, leadingStaticStrides)
          .cast<MemRefType>();
  assert(inferredType.getRank() >= resultRank &&
         "expected inferred rank to be no less than the result rank");
1818   int rankDiff = inferredType.getRank() - resultRank;
1819   if (rankDiff > 0) {
1820     auto shape = inferredType.getShape();
1821     llvm::SmallDenseSet<unsigned> dimsToProject;
1822     mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
1823     SmallVector<int64_t> projectedShape;
1824     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1825       if (!dimsToProject.contains(pos))
1826         projectedShape.push_back(shape[pos]);
1827 
1828     AffineMap map = inferredType.getLayout().getAffineMap();
1829     if (!map.isIdentity())
1830       map = getProjectedMap(map, dimsToProject);
1831     inferredType =
1832         MemRefType::get(projectedShape, inferredType.getElementType(), map,
1833                         inferredType.getMemorySpace());
1834   }
1835   return inferredType;
1836 }
1837 
Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceMemRefType,
    ArrayRef<OpFoldResult> leadingStaticOffsets,
    ArrayRef<OpFoldResult> leadingStaticSizes,
    ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferRankReducedResultType(
      resultRank, sourceMemRefType, staticOffsets, staticSizes, staticStrides);
}

1855 // Build a SubViewOp with mixed static and dynamic entries and custom result
1856 // type. If the type passed is nullptr, it is inferred.
1857 void SubViewOp::build(OpBuilder &b, OperationState &result,
1858                       MemRefType resultType, Value source,
1859                       ArrayRef<OpFoldResult> offsets,
1860                       ArrayRef<OpFoldResult> sizes,
1861                       ArrayRef<OpFoldResult> strides,
1862                       ArrayRef<NamedAttribute> attrs) {
1863   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1864   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1865   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1866                              ShapedType::kDynamicStrideOrOffset);
1867   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1868                              ShapedType::kDynamicSize);
1869   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1870                              ShapedType::kDynamicStrideOrOffset);
1871   auto sourceMemRefType = source.getType().cast<MemRefType>();
1872   // Structuring implementation this way avoids duplication between builders.
1873   if (!resultType) {
1874     resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1875                                             staticSizes, staticStrides)
1876                      .cast<MemRefType>();
1877   }
1878   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1879         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1880         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1881   result.addAttributes(attrs);
1882 }
1883 
1884 // Build a SubViewOp with mixed static and dynamic entries and inferred result
1885 // type.
1886 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1887                       ArrayRef<OpFoldResult> offsets,
1888                       ArrayRef<OpFoldResult> sizes,
1889                       ArrayRef<OpFoldResult> strides,
1890                       ArrayRef<NamedAttribute> attrs) {
1891   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1892 }
1893 
1894 // Build a SubViewOp with static entries and inferred result type.
1895 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1896                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1897                       ArrayRef<int64_t> strides,
1898                       ArrayRef<NamedAttribute> attrs) {
1899   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1900       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1901         return b.getI64IntegerAttr(v);
1902       }));
1903   SmallVector<OpFoldResult> sizeValues =
1904       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1905         return b.getI64IntegerAttr(v);
1906       }));
1907   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1908       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1909         return b.getI64IntegerAttr(v);
1910       }));
1911   build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
1912 }
1913 
// Build a SubViewOp with static entries and custom result type. If the
// type passed is nullptr, it is inferred.
1916 void SubViewOp::build(OpBuilder &b, OperationState &result,
1917                       MemRefType resultType, Value source,
1918                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1919                       ArrayRef<int64_t> strides,
1920                       ArrayRef<NamedAttribute> attrs) {
1921   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1922       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1923         return b.getI64IntegerAttr(v);
1924       }));
1925   SmallVector<OpFoldResult> sizeValues =
1926       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1927         return b.getI64IntegerAttr(v);
1928       }));
1929   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1930       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1931         return b.getI64IntegerAttr(v);
1932       }));
1933   build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
1934         attrs);
1935 }
1936 
1937 // Build a SubViewOp with dynamic entries and custom result type. If the type
1938 // passed is nullptr, it is inferred.
1939 void SubViewOp::build(OpBuilder &b, OperationState &result,
1940                       MemRefType resultType, Value source, ValueRange offsets,
1941                       ValueRange sizes, ValueRange strides,
1942                       ArrayRef<NamedAttribute> attrs) {
1943   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1944       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1945   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1946       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1947   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1948       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
1950 }
1951 
1952 // Build a SubViewOp with dynamic entries and inferred result type.
1953 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1954                       ValueRange offsets, ValueRange sizes, ValueRange strides,
1955                       ArrayRef<NamedAttribute> attrs) {
1956   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1957 }
1958 
1959 /// For ViewLikeOpInterface.
1960 Value SubViewOp::getViewSource() { return source(); }
1961 
1962 enum SubViewVerificationResult {
1963   Success,
1964   RankTooLarge,
1965   SizeMismatch,
1966   ElemTypeMismatch,
1967   MemSpaceMismatch,
1968   AffineMapMismatch
1969 };
1970 
/// Checks if `originalType` can be rank-reduced to `candidateReducedType`.
/// This function is a slight variant of the `is subsequence` algorithm, where
/// the non-matching dimensions must be 1.
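/// For example (illustrative): memref<1x4x1x8xf32> can be rank-reduced to
/// memref<4x8xf32> or memref<4x1x8xf32>, but not to memref<4x4xf32>.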
1974 static SubViewVerificationResult
1975 isRankReducedType(Type originalType, Type candidateReducedType,
1976                   ArrayAttr staticSizes, std::string *errMsg = nullptr) {
1977   if (originalType == candidateReducedType)
1978     return SubViewVerificationResult::Success;
  if (!originalType.isa<MemRefType>())
    return SubViewVerificationResult::Success;
  if (!candidateReducedType.isa<MemRefType>())
    return SubViewVerificationResult::Success;
1983 
1984   ShapedType originalShapedType = originalType.cast<ShapedType>();
1985   ShapedType candidateReducedShapedType =
1986       candidateReducedType.cast<ShapedType>();
1987 
1988   // Rank and size logic is valid for all ShapedTypes.
1989   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
1990   ArrayRef<int64_t> candidateReducedShape =
1991       candidateReducedShapedType.getShape();
1992   unsigned originalRank = originalShape.size(),
1993            candidateReducedRank = candidateReducedShape.size();
1994   if (candidateReducedRank > originalRank)
1995     return SubViewVerificationResult::RankTooLarge;
1996 
1997   MemRefType original = originalType.cast<MemRefType>();
1998   MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
1999 
2000   auto optionalUnusedDimsMask =
2001       computeMemRefRankReductionMask(original, candidateReduced, staticSizes);
2002 
  // Sizes cannot be matched if no rank-reduction mask could be computed.
  if (!optionalUnusedDimsMask.hasValue())
2005     return SubViewVerificationResult::SizeMismatch;
2006 
2007   if (originalShapedType.getElementType() !=
2008       candidateReducedShapedType.getElementType())
2009     return SubViewVerificationResult::ElemTypeMismatch;
2010 
2011   // Strided layout logic is relevant for MemRefType only.
2012   if (original.getMemorySpace() != candidateReduced.getMemorySpace())
2013     return SubViewVerificationResult::MemSpaceMismatch;
2014   return SubViewVerificationResult::Success;
2015 }
2016 
2017 template <typename OpTy>
2018 static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
2019                                             OpTy op, Type expectedType,
2020                                             StringRef errMsg = "") {
2021   auto memrefType = expectedType.cast<ShapedType>();
2022   switch (result) {
2023   case SubViewVerificationResult::Success:
2024     return success();
2025   case SubViewVerificationResult::RankTooLarge:
2026     return op.emitError("expected result rank to be smaller or equal to ")
2027            << "the source rank. " << errMsg;
2028   case SubViewVerificationResult::SizeMismatch:
2029     return op.emitError("expected result type to be ")
2030            << expectedType
2031            << " or a rank-reduced version. (mismatch of result sizes) "
2032            << errMsg;
2033   case SubViewVerificationResult::ElemTypeMismatch:
2034     return op.emitError("expected result element type to be ")
2035            << memrefType.getElementType() << errMsg;
2036   case SubViewVerificationResult::MemSpaceMismatch:
2037     return op.emitError("expected result and source memory spaces to match.")
2038            << errMsg;
2039   case SubViewVerificationResult::AffineMapMismatch:
2040     return op.emitError("expected result type to be ")
2041            << expectedType
2042            << " or a rank-reduced version. (mismatch of result affine map) "
2043            << errMsg;
2044   }
2045   llvm_unreachable("unexpected subview verification result");
2046 }
2047 
2048 /// Verifier for SubViewOp.
2049 static LogicalResult verify(SubViewOp op) {
2050   MemRefType baseType = op.getSourceType();
2051   MemRefType subViewType = op.getType();
2052 
2053   // The base memref and the view memref should be in the same memory space.
2054   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2055     return op.emitError("different memory spaces specified for base memref "
2056                         "type ")
2057            << baseType << " and subview memref type " << subViewType;
2058 
2059   // Verify that the base memref type has a strided layout map.
2060   if (!isStrided(baseType))
2061     return op.emitError("base type ") << baseType << " is not strided";
2062 
2063   // Verify result type against inferred type.
2064   auto expectedType = SubViewOp::inferResultType(
2065       baseType, extractFromI64ArrayAttr(op.static_offsets()),
2066       extractFromI64ArrayAttr(op.static_sizes()),
2067       extractFromI64ArrayAttr(op.static_strides()));
2068 
2069   std::string errMsg;
2070   auto result =
2071       isRankReducedType(expectedType, subViewType, op.static_sizes(), &errMsg);
2072   return produceSubViewErrorMsg(result, op, expectedType, errMsg);
2073 }
2074 
2075 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
2076   return os << "range " << range.offset << ":" << range.size << ":"
2077             << range.stride;
2078 }
2079 
2080 /// Return the list of Range (i.e. offset, size, stride). Each Range
2081 /// entry contains either the dynamic value or a ConstantIndexOp constructed
2082 /// with `b` at location `loc`.
2083 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
2084                                               OpBuilder &b, Location loc) {
2085   std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
2086   assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
2087   assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
2088   SmallVector<Range, 8> res;
2089   unsigned rank = ranks[0];
2090   res.reserve(rank);
2091   for (unsigned idx = 0; idx < rank; ++idx) {
2092     Value offset =
2093         op.isDynamicOffset(idx)
2094             ? op.getDynamicOffset(idx)
2095             : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
2096     Value size =
2097         op.isDynamicSize(idx)
2098             ? op.getDynamicSize(idx)
2099             : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
2100     Value stride =
2101         op.isDynamicStride(idx)
2102             ? op.getDynamicStride(idx)
2103             : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
2104     res.emplace_back(Range{offset, size, stride});
2105   }
2106   return res;
2107 }
2108 
/// Infer the canonical type of the result of a subview operation. Returns the
/// rank-reduced type if it has rank `resultRank`, and otherwise falls back to
/// the non-rank-reduced inferred type.
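/// For example (illustrative): with a memref<8x16x4xf32> source, mixed sizes
/// [1, 8, 4] and `resultRank` = 2, the rank-reduced inference drops the unit
/// dimension and that rank-2 type is returned directly.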
2112 static MemRefType
2113 getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
2114                               ArrayRef<OpFoldResult> mixedOffsets,
2115                               ArrayRef<OpFoldResult> mixedSizes,
2116                               ArrayRef<OpFoldResult> mixedStrides) {
2117   auto resultType =
2118       SubViewOp::inferRankReducedResultType(
2119           resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
2120           .cast<MemRefType>();
2121   if (resultType.getRank() != resultRank) {
2122     resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
2123                                             mixedSizes, mixedStrides)
2124                      .cast<MemRefType>();
2125   }
2126   return resultType;
2127 }
2128 
2129 namespace {
2130 /// Pattern to rewrite a subview op with MemRefCast arguments.
2131 /// This essentially pushes memref.cast past its consuming subview when
2132 /// `canFoldIntoConsumerOp` is true.
2133 ///
2134 /// Example:
2135 /// ```
2136 ///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
2137 ///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
2138 ///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
2139 /// ```
2140 /// is rewritten into:
2141 /// ```
2142 ///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
2143 ///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
2144 ///     memref<3x4xf32, offset:?, strides:[?, 1]>
2145 /// ```
2146 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
2147 public:
2148   using OpRewritePattern<SubViewOp>::OpRewritePattern;
2149 
2150   LogicalResult matchAndRewrite(SubViewOp subViewOp,
2151                                 PatternRewriter &rewriter) const override {
2152     // Any constant operand, just return to let SubViewOpConstantFolder kick in.
2153     if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
2154           return matchPattern(operand, matchConstantIndex());
2155         }))
2156       return failure();
2157 
2158     auto castOp = subViewOp.source().getDefiningOp<CastOp>();
2159     if (!castOp)
2160       return failure();
2161 
2162     if (!CastOp::canFoldIntoConsumerOp(castOp))
2163       return failure();
2164 
    /// Deduce the result type of the SubViewOp by applying
    /// `getCanonicalSubViewResultType` to the cast source operand type and the
    /// SubViewOp static information. This is the resulting type if the
    /// memref.cast were folded away.
2168     auto resultType = getCanonicalSubViewResultType(
2169         subViewOp.getType().getRank(),
2170         castOp.source().getType().cast<MemRefType>(),
2171         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
2172         subViewOp.getMixedStrides());
2173     Value newSubView = rewriter.create<SubViewOp>(
2174         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
2175         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
2176         subViewOp.static_sizes(), subViewOp.static_strides());
2177     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
2178                                         newSubView);
2179     return success();
2180   }
2181 };
2182 } // namespace
2183 
2184 /// Return the canonical type of the result of a subview.
2185 struct SubViewReturnTypeCanonicalizer {
2186   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
2187                         ArrayRef<OpFoldResult> mixedSizes,
2188                         ArrayRef<OpFoldResult> mixedStrides) {
2189     return getCanonicalSubViewResultType(op.getType().getRank(),
2190                                          op.getSourceType(), mixedOffsets,
2191                                          mixedSizes, mixedStrides);
2192   }
2193 };
2194 
2195 /// A canonicalizer wrapper to replace SubViewOps.
2196 struct SubViewCanonicalizer {
2197   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
2198     rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
2199   }
2200 };
2201 
2202 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2203                                             MLIRContext *context) {
2204   results
2205       .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
2206                SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
2207            SubViewOpMemRefCastFolder>(context);
2208 }
2209 
2210 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
2211   auto resultShapedType = getResult().getType().cast<ShapedType>();
2212   auto sourceShapedType = source().getType().cast<ShapedType>();
2213 
2214   if (resultShapedType.hasStaticShape() &&
2215       resultShapedType == sourceShapedType) {
2216     return getViewSource();
2217   }
2218 
2219   return {};
2220 }
2221 
2222 //===----------------------------------------------------------------------===//
2223 // TensorLoadOp
2224 //===----------------------------------------------------------------------===//
2225 
2226 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
2227   if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
2230     if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
2231         bufferCast->getNextNode() == this->getOperation())
2232       return bufferCast.tensor();
2233   return {};
2234 }
2235 
2236 namespace {
2237 struct DimOfTensorLoadFolder : public OpRewritePattern<tensor::DimOp> {
2238   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
2239 
2240   LogicalResult matchAndRewrite(tensor::DimOp dimOp,
2241                                 PatternRewriter &rewriter) const override {
2242     auto tensorLoadOp = dimOp.source().getDefiningOp<TensorLoadOp>();
2243     if (!tensorLoadOp)
2244       return failure();
2245 
2246     rewriter.replaceOpWithNewOp<DimOp>(dimOp, tensorLoadOp.memref(),
2247                                        dimOp.index());
2248     return success();
2249   }
2250 };
2251 } // namespace
2252 
2253 void TensorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
2254                                                MLIRContext *context) {
2255   results.add<DimOfTensorLoadFolder>(context);
2256 }
2257 
2258 //===----------------------------------------------------------------------===//
2259 // TransposeOp
2260 //===----------------------------------------------------------------------===//
2261 
/// Build a strided memref type by applying `permutationMap` to `memRefType`.
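/// For example (illustrative): applying (d0, d1) -> (d1, d0) to
/// memref<3x4xf32> (strides [4, 1]) produces a memref<4x3xf32> whose layout
/// map is (d0, d1) -> (d1 * 4 + d0), i.e. strides [1, 4].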
2263 static MemRefType inferTransposeResultType(MemRefType memRefType,
2264                                            AffineMap permutationMap) {
2265   auto rank = memRefType.getRank();
2266   auto originalSizes = memRefType.getShape();
2267   // Compute permuted sizes.
2268   SmallVector<int64_t, 4> sizes(rank, 0);
2269   for (auto en : llvm::enumerate(permutationMap.getResults()))
2270     sizes[en.index()] =
2271         originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
2272 
2273   // Compute permuted strides.
2274   int64_t offset;
2275   SmallVector<int64_t, 4> strides;
2276   auto res = getStridesAndOffset(memRefType, strides, offset);
2277   assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
2278   (void)res;
2279   auto map =
2280       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
2281   map = permutationMap ? map.compose(permutationMap) : map;
2282   return MemRefType::Builder(memRefType)
2283       .setShape(sizes)
2284       .setLayout(AffineMapAttr::get(map));
2285 }
2286 
2287 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
2288                         AffineMapAttr permutation,
2289                         ArrayRef<NamedAttribute> attrs) {
2290   auto permutationMap = permutation.getValue();
2291   assert(permutationMap);
2292 
2293   auto memRefType = in.getType().cast<MemRefType>();
2294   // Compute result type.
2295   MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
2296 
2297   build(b, result, resultType, in, attrs);
2298   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
2299 }
2300 
2301 // transpose $in $permutation attr-dict : type($in) `to` type(results)
2302 static void print(OpAsmPrinter &p, TransposeOp op) {
2303   p << " " << op.in() << " " << op.permutation();
2304   p.printOptionalAttrDict(op->getAttrs(),
2305                           {TransposeOp::getPermutationAttrName()});
2306   p << " : " << op.in().getType() << " to " << op.getType();
2307 }
2308 
2309 static ParseResult parseTransposeOp(OpAsmParser &parser,
2310                                     OperationState &result) {
2311   OpAsmParser::OperandType in;
2312   AffineMap permutation;
2313   MemRefType srcType, dstType;
2314   if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
2315       parser.parseOptionalAttrDict(result.attributes) ||
2316       parser.parseColonType(srcType) ||
2317       parser.resolveOperand(in, srcType, result.operands) ||
2318       parser.parseKeywordType("to", dstType) ||
2319       parser.addTypeToList(dstType, result.types))
2320     return failure();
2321 
2322   result.addAttribute(TransposeOp::getPermutationAttrName(),
2323                       AffineMapAttr::get(permutation));
2324   return success();
2325 }
2326 
2327 static LogicalResult verify(TransposeOp op) {
2328   if (!op.permutation().isPermutation())
2329     return op.emitOpError("expected a permutation map");
2330   if (op.permutation().getNumDims() != op.getShapedType().getRank())
2331     return op.emitOpError(
2332         "expected a permutation map of same rank as the input");
2333 
2334   auto srcType = op.in().getType().cast<MemRefType>();
2335   auto dstType = op.getType().cast<MemRefType>();
2336   auto transposedType = inferTransposeResultType(srcType, op.permutation());
2337   if (dstType != transposedType)
2338     return op.emitOpError("output type ")
2339            << dstType << " does not match transposed input type " << srcType
2340            << ", " << transposedType;
2341   return success();
2342 }
2343 
2344 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
2345   if (succeeded(foldMemRefCast(*this)))
2346     return getResult();
2347   return {};
2348 }
2349 
2350 //===----------------------------------------------------------------------===//
2351 // ViewOp
2352 //===----------------------------------------------------------------------===//
2353 
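// The custom assembly handled below looks, illustratively, like:
//
//   memref.view %src[%byte_shift][%size0, %size1]
//     : memref<2048xi8> to memref<?x?x4xf32>
//
// with one byte-shift operand and one size operand per dynamic result dim.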
2354 static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
2355   OpAsmParser::OperandType srcInfo;
2356   SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
2357   SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
2358   auto indexType = parser.getBuilder().getIndexType();
2359   Type srcType, dstType;
2360   llvm::SMLoc offsetLoc;
2361   if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
2362       parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
2363     return failure();
2364 
2365   if (offsetInfo.size() != 1)
2366     return parser.emitError(offsetLoc) << "expects 1 offset operand";
2367 
2368   return failure(
2369       parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
2370       parser.parseOptionalAttrDict(result.attributes) ||
2371       parser.parseColonType(srcType) ||
2372       parser.resolveOperand(srcInfo, srcType, result.operands) ||
2373       parser.resolveOperands(offsetInfo, indexType, result.operands) ||
2374       parser.resolveOperands(sizesInfo, indexType, result.operands) ||
2375       parser.parseKeywordType("to", dstType) ||
2376       parser.addTypeToList(dstType, result.types));
2377 }
2378 
2379 static void print(OpAsmPrinter &p, ViewOp op) {
2380   p << ' ' << op.getOperand(0) << '[';
2381   p.printOperand(op.byte_shift());
2382   p << "][" << op.sizes() << ']';
2383   p.printOptionalAttrDict(op->getAttrs());
2384   p << " : " << op.getOperand(0).getType() << " to " << op.getType();
2385 }
2386 
2387 static LogicalResult verify(ViewOp op) {
2388   auto baseType = op.getOperand(0).getType().cast<MemRefType>();
2389   auto viewType = op.getType();
2390 
2391   // The base memref should have identity layout map (or none).
2392   if (!baseType.getLayout().isIdentity())
2393     return op.emitError("unsupported map for base memref type ") << baseType;
2394 
2395   // The result memref should have identity layout map (or none).
2396   if (!viewType.getLayout().isIdentity())
2397     return op.emitError("unsupported map for result memref type ") << viewType;
2398 
2399   // The base memref and the view memref should be in the same memory space.
2400   if (baseType.getMemorySpace() != viewType.getMemorySpace())
2401     return op.emitError("different memory spaces specified for base memref "
2402                         "type ")
2403            << baseType << " and view memref type " << viewType;
2404 
2405   // Verify that we have the correct number of sizes for the result type.
2406   unsigned numDynamicDims = viewType.getNumDynamicDims();
2407   if (op.sizes().size() != numDynamicDims)
2408     return op.emitError("incorrect number of size operands for type ")
2409            << viewType;
2410 
2411   return success();
2412 }
2413 
2414 Value ViewOp::getViewSource() { return source(); }
2415 
2416 namespace {
2417 
2418 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
2419   using OpRewritePattern<ViewOp>::OpRewritePattern;
2420 
2421   LogicalResult matchAndRewrite(ViewOp viewOp,
2422                                 PatternRewriter &rewriter) const override {
2423     // Return if none of the operands are constants.
2424     if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
2425           return matchPattern(operand, matchConstantIndex());
2426         }))
2427       return failure();
2428 
2429     // Get result memref type.
2430     auto memrefType = viewOp.getType();
2431 
    // Get offset from old memref view type 'memrefType'.
2433     int64_t oldOffset;
2434     SmallVector<int64_t, 4> oldStrides;
2435     if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
2436       return failure();
2437     assert(oldOffset == 0 && "Expected 0 offset");
2438 
2439     SmallVector<Value, 4> newOperands;
2440 
2441     // Offset cannot be folded into result type.
2442 
2443     // Fold any dynamic dim operands which are produced by a constant.
2444     SmallVector<int64_t, 4> newShapeConstants;
2445     newShapeConstants.reserve(memrefType.getRank());
2446 
2447     unsigned dynamicDimPos = 0;
2448     unsigned rank = memrefType.getRank();
2449     for (unsigned dim = 0, e = rank; dim < e; ++dim) {
2450       int64_t dimSize = memrefType.getDimSize(dim);
2451       // If this is already static dimension, keep it.
2452       if (!ShapedType::isDynamic(dimSize)) {
2453         newShapeConstants.push_back(dimSize);
2454         continue;
2455       }
2456       auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
2457       if (auto constantIndexOp =
2458               dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
2459         // Dynamic shape dimension will be folded.
2460         newShapeConstants.push_back(constantIndexOp.value());
2461       } else {
2462         // Dynamic shape dimension not folded; copy operand from old memref.
2463         newShapeConstants.push_back(dimSize);
2464         newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
2465       }
2466       dynamicDimPos++;
2467     }
2468 
2469     // Create new memref type with constant folded dims.
2470     MemRefType newMemRefType =
2471         MemRefType::Builder(memrefType).setShape(newShapeConstants);
2472     // Nothing new, don't fold.
2473     if (newMemRefType == memrefType)
2474       return failure();
2475 
2476     // Create new ViewOp.
2477     auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
2478                                              viewOp.getOperand(0),
2479                                              viewOp.byte_shift(), newOperands);
2480     // Insert a cast so we have the same type as the old memref type.
2481     rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
2482     return success();
2483   }
2484 };
2485 
2486 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
2487   using OpRewritePattern<ViewOp>::OpRewritePattern;
2488 
2489   LogicalResult matchAndRewrite(ViewOp viewOp,
2490                                 PatternRewriter &rewriter) const override {
2491     Value memrefOperand = viewOp.getOperand(0);
2492     CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
2493     if (!memrefCastOp)
2494       return failure();
2495     Value allocOperand = memrefCastOp.getOperand();
2496     AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
2497     if (!allocOp)
2498       return failure();
2499     rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
2500                                         viewOp.byte_shift(), viewOp.sizes());
2501     return success();
2502   }
2503 };
2504 
2505 } // end anonymous namespace
2506 
2507 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2508                                          MLIRContext *context) {
2509   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
2510 }
2511 
2512 //===----------------------------------------------------------------------===//
2513 // TableGen'd op method definitions
2514 //===----------------------------------------------------------------------===//
2515 
2516 #define GET_OP_CLASSES
2517 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
2518