1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
13 #include "mlir/Dialect/Utils/StaticValueUtils.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Interfaces/InferTypeOpInterface.h"
21 #include "mlir/Interfaces/SideEffectInterfaces.h"
22 #include "mlir/Interfaces/ViewLikeInterface.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 
26 using namespace mlir;
27 using namespace mlir::memref;
28 
29 /// Materialize a single constant operation from a given attribute value with
30 /// the desired resultant type.
31 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
32                                               Attribute value, Type type,
33                                               Location loc) {
34   if (arith::ConstantOp::isBuildableWith(value, type))
35     return builder.create<arith::ConstantOp>(loc, value, type);
36   return nullptr;
37 }
38 
39 //===----------------------------------------------------------------------===//
40 // Common canonicalization pattern support logic
41 //===----------------------------------------------------------------------===//
42 
43 /// This is a common class used for patterns of the form
44 /// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
45 /// into the root operation directly.
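/// For example (illustrative IR; %arg, %i and %v are placeholders not defined
/// here), a fold relying on this helper may rewrite
///
/// ```mlir
///   %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
///   %1 = memref.load %0[%i] : memref<?xf32>
/// ```
///
/// into
///
/// ```mlir
///   %1 = memref.load %arg[%i] : memref<8xf32>
/// ```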
46 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
47   bool folded = false;
48   for (OpOperand &operand : op->getOpOperands()) {
49     auto cast = operand.get().getDefiningOp<CastOp>();
50     if (cast && operand.get() != inner &&
51         !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
52       operand.set(cast.getOperand());
53       folded = true;
54     }
55   }
56   return success(folded);
57 }
58 
59 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
60 /// type.
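/// For example (illustrative): memref<4x?xf32> maps to tensor<4x?xf32>, and
/// memref<*xf32> maps to tensor<*xf32>. Any other type maps to NoneType.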
61 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
62   if (auto memref = type.dyn_cast<MemRefType>())
63     return RankedTensorType::get(memref.getShape(), memref.getElementType());
64   if (auto memref = type.dyn_cast<UnrankedMemRefType>())
65     return UnrankedTensorType::get(memref.getElementType());
66   return NoneType::get(type.getContext());
67 }
68 
69 //===----------------------------------------------------------------------===//
70 // AllocOp / AllocaOp
71 //===----------------------------------------------------------------------===//
72 
73 template <typename AllocLikeOp>
74 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
75   static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
76                 "applies to only alloc or alloca");
77   auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
78   if (!memRefType)
79     return op.emitOpError("result must be a memref");
80 
81   if (static_cast<int64_t>(op.dynamicSizes().size()) !=
82       memRefType.getNumDynamicDims())
83     return op.emitOpError("dimension operand count does not equal memref "
84                           "dynamic dimension count");
85 
86   unsigned numSymbols = 0;
87   if (!memRefType.getLayout().isIdentity())
88     numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
89   if (op.symbolOperands().size() != numSymbols)
90     return op.emitOpError("symbol operand count does not equal memref symbol "
91                           "count: expected ")
92            << numSymbols << ", got " << op.symbolOperands().size();
93 
94   return success();
95 }
96 
97 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
98 
99 LogicalResult AllocaOp::verify() {
100   // An alloca op needs to have an ancestor with an allocation scope trait.
101   if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
102     return emitOpError(
103         "requires an ancestor op with AutomaticAllocationScope trait");
104 
105   return verifyAllocLikeOp(*this);
106 }
107 
108 namespace {
109 /// Fold constant dimensions into an alloc like operation.
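/// For example (illustrative IR; %c4 is assumed to be an arith.constant of
/// value 4 and %n an arbitrary index value), the pattern may rewrite
///
/// ```mlir
///   %0 = memref.alloc(%c4, %n) : memref<?x?xf32>
/// ```
///
/// into
///
/// ```mlir
///   %1 = memref.alloc(%n) : memref<4x?xf32>
///   %0 = memref.cast %1 : memref<4x?xf32> to memref<?x?xf32>
/// ```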
110 template <typename AllocLikeOp>
111 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
112   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
113 
114   LogicalResult matchAndRewrite(AllocLikeOp alloc,
115                                 PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants.  If so, we can
    // substitute and drop them.
118     if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
119           return matchPattern(operand, matchConstantIndex());
120         }))
121       return failure();
122 
123     auto memrefType = alloc.getType();
124 
125     // Ok, we have one or more constant operands.  Collect the non-constant ones
126     // and keep track of the resultant memref type to build.
127     SmallVector<int64_t, 4> newShapeConstants;
128     newShapeConstants.reserve(memrefType.getRank());
129     SmallVector<Value, 4> dynamicSizes;
130 
131     unsigned dynamicDimPos = 0;
132     for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
133       int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
135       if (dimSize != -1) {
136         newShapeConstants.push_back(dimSize);
137         continue;
138       }
139       auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
140       auto *defOp = dynamicSize.getDefiningOp();
141       if (auto constantIndexOp =
142               dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
143         // Dynamic shape dimension will be folded.
144         newShapeConstants.push_back(constantIndexOp.value());
145       } else {
146         // Dynamic shape dimension not folded; copy dynamicSize from old memref.
147         newShapeConstants.push_back(-1);
148         dynamicSizes.push_back(dynamicSize);
149       }
150       dynamicDimPos++;
151     }
152 
153     // Create new memref type (which will have fewer dynamic dimensions).
154     MemRefType newMemRefType =
155         MemRefType::Builder(memrefType).setShape(newShapeConstants);
156     assert(static_cast<int64_t>(dynamicSizes.size()) ==
157            newMemRefType.getNumDynamicDims());
158 
159     // Create and insert the alloc op for the new memref.
160     auto newAlloc = rewriter.create<AllocLikeOp>(
161         alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
162         alloc.alignmentAttr());
163     // Insert a cast so we have the same type as the old alloc.
164     auto resultCast =
165         rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc);
166 
167     rewriter.replaceOp(alloc, {resultCast});
168     return success();
169   }
170 };
171 
172 /// Fold alloc operations with no users or only store and dealloc uses.
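/// For example (illustrative IR; %v and %i are placeholders), the following
/// allocation is never read and is erased together with its store and dealloc
/// uses:
///
/// ```mlir
///   %0 = memref.alloc() : memref<16xf32>
///   memref.store %v, %0[%i] : memref<16xf32>
///   memref.dealloc %0 : memref<16xf32>
/// ```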
173 template <typename T>
174 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
175   using OpRewritePattern<T>::OpRewritePattern;
176 
177   LogicalResult matchAndRewrite(T alloc,
178                                 PatternRewriter &rewriter) const override {
179     if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
180           if (auto storeOp = dyn_cast<StoreOp>(op))
181             return storeOp.value() == alloc;
182           return !isa<DeallocOp>(op);
183         }))
184       return failure();
185 
186     for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
187       rewriter.eraseOp(user);
188 
189     rewriter.eraseOp(alloc);
190     return success();
191   }
192 };
193 } // namespace
194 
195 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
196                                           MLIRContext *context) {
197   results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
198 }
199 
200 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
201                                            MLIRContext *context) {
202   results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
203       context);
204 }
205 
206 //===----------------------------------------------------------------------===//
207 // AllocaScopeOp
208 //===----------------------------------------------------------------------===//
209 
210 void AllocaScopeOp::print(OpAsmPrinter &p) {
211   bool printBlockTerminators = false;
212 
213   p << ' ';
214   if (!results().empty()) {
215     p << " -> (" << getResultTypes() << ")";
216     printBlockTerminators = true;
217   }
218   p << ' ';
219   p.printRegion(bodyRegion(),
220                 /*printEntryBlockArgs=*/false,
221                 /*printBlockTerminators=*/printBlockTerminators);
222   p.printOptionalAttrDict((*this)->getAttrs());
223 }
224 
225 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
226   // Create a region for the body.
227   result.regions.reserve(1);
228   Region *bodyRegion = result.addRegion();
229 
230   // Parse optional results type list.
231   if (parser.parseOptionalArrowTypeList(result.types))
232     return failure();
233 
234   // Parse the body region.
235   if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
236     return failure();
237   AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
238                                   result.location);
239 
240   // Parse the optional attribute list.
241   if (parser.parseOptionalAttrDict(result.attributes))
242     return failure();
243 
244   return success();
245 }
246 
247 LogicalResult AllocaScopeOp::verify() {
248   return RegionBranchOpInterface::verifyTypes(*this);
249 }
250 
251 void AllocaScopeOp::getSuccessorRegions(
252     Optional<unsigned> index, ArrayRef<Attribute> operands,
253     SmallVectorImpl<RegionSuccessor> &regions) {
254   if (index.hasValue()) {
255     regions.push_back(RegionSuccessor(getResults()));
256     return;
257   }
258 
259   regions.push_back(RegionSuccessor(&bodyRegion()));
260 }
261 
/// Given an operation, return whether this op is guaranteed to allocate an
/// AutomaticAllocationScopeResource.
264 static bool isGuaranteedAutomaticAllocationScope(Operation *op) {
265   MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
266   if (!interface)
267     return false;
268   for (auto res : op->getResults()) {
269     if (auto effect =
270             interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
271       if (isa<SideEffects::AutomaticAllocationScopeResource>(
272               effect->getResource()))
273         return true;
274     }
275   }
276   return false;
277 }
278 
/// Given an operation, return whether this op could allocate an
/// AutomaticAllocationScopeResource.
281 static bool isPotentialAutomaticAllocationScope(Operation *op) {
282   MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
283   if (!interface)
284     return true;
285   for (auto res : op->getResults()) {
286     if (auto effect =
287             interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
288       if (isa<SideEffects::AutomaticAllocationScopeResource>(
289               effect->getResource()))
290         return true;
291     }
292   }
293   return false;
294 }
295 
/// Return whether this op is the last non-terminator op in a region. That is
/// to say, it is in a single-block region and is only followed by the
/// terminator. This prevents extending the lifetime of allocations.
300 static bool lastNonTerminatorInRegion(Operation *op) {
301   return op->getNextNode() == op->getBlock()->getTerminator() &&
302          op->getParentRegion()->getBlocks().size() == 1;
303 }
304 
305 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
306 /// or it contains no allocation.
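/// For example (illustrative IR), a scope that performs no allocation may be
/// inlined into its parent block, with the scope result replaced by the
/// yielded value:
///
/// ```mlir
///   %r = memref.alloca_scope -> f32 {
///     %v = arith.constant 0.0 : f32
///     memref.alloca_scope.return %v : f32
///   }
/// ```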
307 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
308   using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
309 
310   LogicalResult matchAndRewrite(AllocaScopeOp op,
311                                 PatternRewriter &rewriter) const override {
312     if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>()) {
313       bool hasPotentialAlloca =
314           op->walk([&](Operation *alloc) {
315               if (isPotentialAutomaticAllocationScope(alloc))
316                 return WalkResult::interrupt();
317               return WalkResult::skip();
318             }).wasInterrupted();
319       if (hasPotentialAlloca)
320         return failure();
321     }
322 
    // Only apply if this is the last non-terminator op in the block of a
    // single-block region, lest allocation lifetimes be extended.
326     if (!lastNonTerminatorInRegion(op))
327       return failure();
328 
329     Block *block = &op.getRegion().front();
330     Operation *terminator = block->getTerminator();
331     ValueRange results = terminator->getOperands();
332     rewriter.mergeBlockBefore(block, op);
333     rewriter.replaceOp(op, results);
334     rewriter.eraseOp(terminator);
335     return success();
336   }
337 };
338 
/// Hoist allocations out of an AllocaScopeOp into the enclosing automatic
/// allocation scope, if it is legal to move them (e.g. their operands are
/// available at the location they would be moved to).
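/// For example (illustrative IR; %lb, %ub and %c1 are assumed index values and
/// "test.use" stands for an arbitrary user), an alloca whose operands are all
/// available outside the loop may be hoisted above it, into the surrounding
/// function's allocation scope:
///
/// ```mlir
///   scf.for %i = %lb to %ub step %c1 {
///     memref.alloca_scope {
///       %buf = memref.alloca() : memref<16xf32>
///       "test.use"(%buf) : (memref<16xf32>) -> ()
///     }
///   }
/// ```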
342 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
343   using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
344 
345   LogicalResult matchAndRewrite(AllocaScopeOp op,
346                                 PatternRewriter &rewriter) const override {
347 
348     if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
349       return failure();
350 
351     Operation *lastParentWithoutScope = op->getParentOp();
352 
353     if (!lastParentWithoutScope ||
354         lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
355       return failure();
356 
    // Only apply if this is the last non-terminator op in the block of a
    // single-block region, lest allocation lifetimes be extended.
360     if (!lastNonTerminatorInRegion(op) ||
361         !lastNonTerminatorInRegion(lastParentWithoutScope))
362       return failure();
363 
364     while (!lastParentWithoutScope->getParentOp()
365                 ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
366       lastParentWithoutScope = lastParentWithoutScope->getParentOp();
367       if (!lastParentWithoutScope ||
368           !lastNonTerminatorInRegion(lastParentWithoutScope))
369         return failure();
370     }
371     assert(lastParentWithoutScope->getParentOp()
372                ->hasTrait<OpTrait::AutomaticAllocationScope>());
373 
374     Region *containingRegion = nullptr;
375     for (auto &r : lastParentWithoutScope->getRegions()) {
376       if (r.isAncestor(op->getParentRegion())) {
377         assert(containingRegion == nullptr &&
378                "only one region can contain the op");
379         containingRegion = &r;
380       }
381     }
382     assert(containingRegion && "op must be contained in a region");
383 
384     SmallVector<Operation *> toHoist;
385     op->walk([&](Operation *alloc) {
386       if (!isGuaranteedAutomaticAllocationScope(alloc))
387         return WalkResult::skip();
388 
389       // If any operand is not defined before the location of
390       // lastParentWithoutScope (i.e. where we would hoist to), skip.
391       if (llvm::any_of(alloc->getOperands(), [&](Value v) {
392             return containingRegion->isAncestor(v.getParentRegion());
393           }))
394         return WalkResult::skip();
395       toHoist.push_back(alloc);
396       return WalkResult::advance();
397     });
398 
    if (toHoist.empty())
      return failure();
    rewriter.setInsertionPoint(lastParentWithoutScope);
    for (Operation *alloc : toHoist) {
      Operation *cloned = rewriter.clone(*alloc);
      rewriter.replaceOp(alloc, cloned->getResults());
    }
406     return success();
407   }
408 };
409 
410 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
411                                                 MLIRContext *context) {
412   results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
413 }
414 
415 //===----------------------------------------------------------------------===//
416 // AssumeAlignmentOp
417 //===----------------------------------------------------------------------===//
418 
419 LogicalResult AssumeAlignmentOp::verify() {
420   if (!llvm::isPowerOf2_32(alignment()))
421     return emitOpError("alignment must be power of 2");
422   return success();
423 }
424 
425 //===----------------------------------------------------------------------===//
426 // CastOp
427 //===----------------------------------------------------------------------===//
428 
429 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
431 /// and implement canonicalization patterns for ops in different dialects that
432 /// may consume the results of memref.cast operations. Such foldable memref.cast
433 /// operations are typically inserted as `view` and `subview` ops are
434 /// canonicalized, to preserve the type compatibility of their uses.
435 ///
436 /// Returns true when all conditions are met:
437 /// 1. source and result are ranked memrefs with strided semantics and same
438 /// element type and rank.
439 /// 2. each of the source's size, offset or stride has more static information
440 /// than the corresponding result's size, offset or stride.
441 ///
442 /// Example 1:
443 /// ```mlir
444 ///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
445 ///   %2 = consumer %1 ... : memref<?x?xf32> ...
446 /// ```
447 ///
448 /// may fold into:
449 ///
450 /// ```mlir
451 ///   %2 = consumer %0 ... : memref<8x16xf32> ...
452 /// ```
453 ///
454 /// Example 2:
/// ```mlir
456 ///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
457 ///          to memref<?x?xf32>
458 ///   consumer %1 : memref<?x?xf32> ...
459 /// ```
460 ///
461 /// may fold into:
462 ///
/// ```mlir
464 ///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
465 /// ```
466 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
467   MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
468   MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
469 
470   // Requires ranked MemRefType.
471   if (!sourceType || !resultType)
472     return false;
473 
474   // Requires same elemental type.
475   if (sourceType.getElementType() != resultType.getElementType())
476     return false;
477 
478   // Requires same rank.
479   if (sourceType.getRank() != resultType.getRank())
480     return false;
481 
482   // Only fold casts between strided memref forms.
483   int64_t sourceOffset, resultOffset;
484   SmallVector<int64_t, 4> sourceStrides, resultStrides;
485   if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
486       failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
487     return false;
488 
489   // If cast is towards more static sizes along any dimension, don't fold.
490   for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
491     auto ss = std::get<0>(it), st = std::get<1>(it);
492     if (ss != st)
493       if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
494         return false;
495   }
496 
497   // If cast is towards more static offset along any dimension, don't fold.
498   if (sourceOffset != resultOffset)
499     if (ShapedType::isDynamicStrideOrOffset(sourceOffset) &&
500         !ShapedType::isDynamicStrideOrOffset(resultOffset))
501       return false;
502 
503   // If cast is towards more static strides along any dimension, don't fold.
504   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
505     auto ss = std::get<0>(it), st = std::get<1>(it);
506     if (ss != st)
507       if (ShapedType::isDynamicStrideOrOffset(ss) &&
508           !ShapedType::isDynamicStrideOrOffset(st))
509         return false;
510   }
511 
512   return true;
513 }
514 
515 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
516   if (inputs.size() != 1 || outputs.size() != 1)
517     return false;
518   Type a = inputs.front(), b = outputs.front();
519   auto aT = a.dyn_cast<MemRefType>();
520   auto bT = b.dyn_cast<MemRefType>();
521 
522   auto uaT = a.dyn_cast<UnrankedMemRefType>();
523   auto ubT = b.dyn_cast<UnrankedMemRefType>();
524 
525   if (aT && bT) {
526     if (aT.getElementType() != bT.getElementType())
527       return false;
528     if (aT.getLayout() != bT.getLayout()) {
529       int64_t aOffset, bOffset;
530       SmallVector<int64_t, 4> aStrides, bStrides;
531       if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
532           failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
533           aStrides.size() != bStrides.size())
534         return false;
535 
536       // Strides along a dimension/offset are compatible if the value in the
537       // source memref is static and the value in the target memref is the
538       // same. They are also compatible if either one is dynamic (see
539       // description of MemRefCastOp for details).
540       auto checkCompatible = [](int64_t a, int64_t b) {
541         return (a == MemRefType::getDynamicStrideOrOffset() ||
542                 b == MemRefType::getDynamicStrideOrOffset() || a == b);
543       };
544       if (!checkCompatible(aOffset, bOffset))
545         return false;
546       for (const auto &aStride : enumerate(aStrides))
547         if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
548           return false;
549     }
550     if (aT.getMemorySpace() != bT.getMemorySpace())
551       return false;
552 
553     // They must have the same rank, and any specified dimensions must match.
554     if (aT.getRank() != bT.getRank())
555       return false;
556 
557     for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
558       int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
559       if (aDim != -1 && bDim != -1 && aDim != bDim)
560         return false;
561     }
562     return true;
563   } else {
564     if (!aT && !uaT)
565       return false;
566     if (!bT && !ubT)
567       return false;
568     // Unranked to unranked casting is unsupported
569     if (uaT && ubT)
570       return false;
571 
572     auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
573     auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
574     if (aEltType != bEltType)
575       return false;
576 
577     auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
578     auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
579     return aMemSpace == bMemSpace;
580   }
581 
582   return false;
583 }
584 
585 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
586   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
587 }
588 
589 //===----------------------------------------------------------------------===//
590 // CopyOp
591 //===----------------------------------------------------------------------===//
592 
593 namespace {
594 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
595 /// and element type, the cast can be skipped. Such CastOps only cast the layout
596 /// of the type.
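/// For example (illustrative IR; the affine_map layout merely adds a dynamic
/// offset), the copy below may be rewritten to read directly from %src:
///
/// ```mlir
///   %0 = memref.cast %src
///       : memref<16xf32> to memref<16xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
///   memref.copy %0, %dst
///       : memref<16xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<16xf32>
/// ```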
597 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
598   using OpRewritePattern<CopyOp>::OpRewritePattern;
599 
600   LogicalResult matchAndRewrite(CopyOp copyOp,
601                                 PatternRewriter &rewriter) const override {
602     bool modified = false;
603 
604     // Check source.
605     if (auto castOp = copyOp.source().getDefiningOp<CastOp>()) {
606       auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
      auto toType = copyOp.source().getType().dyn_cast<MemRefType>();
608 
609       if (fromType && toType) {
610         if (fromType.getShape() == toType.getShape() &&
611             fromType.getElementType() == toType.getElementType()) {
612           rewriter.updateRootInPlace(
613               copyOp, [&] { copyOp.sourceMutable().assign(castOp.source()); });
614           modified = true;
615         }
616       }
617     }
618 
619     // Check target.
620     if (auto castOp = copyOp.target().getDefiningOp<CastOp>()) {
621       auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
      auto toType = copyOp.target().getType().dyn_cast<MemRefType>();
623 
624       if (fromType && toType) {
625         if (fromType.getShape() == toType.getShape() &&
626             fromType.getElementType() == toType.getElementType()) {
627           rewriter.updateRootInPlace(
628               copyOp, [&] { copyOp.targetMutable().assign(castOp.source()); });
629           modified = true;
630         }
631       }
632     }
633 
634     return success(modified);
635   }
636 };
637 
638 /// Fold memref.copy(%x, %x).
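/// For example (illustrative IR), the self-copy below is simply erased:
///
/// ```mlir
///   memref.copy %buf, %buf : memref<16xf32> to memref<16xf32>
/// ```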
639 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
640   using OpRewritePattern<CopyOp>::OpRewritePattern;
641 
642   LogicalResult matchAndRewrite(CopyOp copyOp,
643                                 PatternRewriter &rewriter) const override {
644     if (copyOp.source() != copyOp.target())
645       return failure();
646 
647     rewriter.eraseOp(copyOp);
648     return success();
649   }
650 };
651 } // namespace
652 
653 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
654                                          MLIRContext *context) {
655   results.add<FoldCopyOfCast, FoldSelfCopy>(context);
656 }
657 
658 LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands,
659                            SmallVectorImpl<OpFoldResult> &results) {
660   /// copy(memrefcast) -> copy
661   bool folded = false;
662   Operation *op = *this;
663   for (OpOperand &operand : op->getOpOperands()) {
664     auto castOp = operand.get().getDefiningOp<memref::CastOp>();
665     if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
666       operand.set(castOp.getOperand());
667       folded = true;
668     }
669   }
670   return success(folded);
671 }
672 
673 //===----------------------------------------------------------------------===//
674 // DeallocOp
675 //===----------------------------------------------------------------------===//
676 
677 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
678                               SmallVectorImpl<OpFoldResult> &results) {
679   /// dealloc(memrefcast) -> dealloc
680   return foldMemRefCast(*this);
681 }
682 
683 //===----------------------------------------------------------------------===//
684 // DimOp
685 //===----------------------------------------------------------------------===//
686 
687 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
688                   int64_t index) {
689   auto loc = result.location;
690   Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
691   build(builder, result, source, indexValue);
692 }
693 
694 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
695                   Value index) {
696   auto indexTy = builder.getIndexType();
697   build(builder, result, indexTy, source, index);
698 }
699 
700 Optional<int64_t> DimOp::getConstantIndex() {
701   if (auto constantOp = index().getDefiningOp<arith::ConstantOp>())
702     return constantOp.getValue().cast<IntegerAttr>().getInt();
703   return {};
704 }
705 
706 LogicalResult DimOp::verify() {
707   // Assume unknown index to be in range.
708   Optional<int64_t> index = getConstantIndex();
709   if (!index.hasValue())
710     return success();
711 
712   // Check that constant index is not knowingly out of range.
713   auto type = source().getType();
714   if (auto memrefType = type.dyn_cast<MemRefType>()) {
715     if (index.getValue() >= memrefType.getRank())
716       return emitOpError("index is out of range");
717   } else if (type.isa<UnrankedMemRefType>()) {
718     // Assume index to be in range.
719   } else {
720     llvm_unreachable("expected operand with memref type");
721   }
722   return success();
723 }
724 
/// Return a map whose keys are the elements of `vals` and whose values are
/// their number of occurrences. Use std::map, since the `vals` here are
/// strides and the dynamic stride value is the same as the tombstone value for
/// `DenseMap<int64_t>`.
729 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
730   std::map<int64_t, unsigned> numOccurences;
731   for (auto val : vals)
732     numOccurences[val]++;
733   return numOccurences;
734 }
735 
/// Given the `originalType` and a `reducedType` whose shape is assumed to be a
/// subset of `originalType` with some `1` entries erased, return the set of
/// indices that specifies which of the entries of the original shape are
/// dropped to obtain the reduced shape.
/// This accounts for cases where there are multiple unit-dims, but only a
/// subset of those are dropped. For MemRefTypes these can be disambiguated
/// using the strides. If a dimension is dropped, its stride must be dropped
/// too.
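/// For example (illustrative): reducing memref<1x16x1xf32> (strides
/// [16, 1, 1]) with sizes [1, 16, 1] to memref<16x1xf32> (strides [1, 1])
/// reports only dimension 0 as dropped; the trailing unit dimension keeps its
/// stride and is therefore preserved.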
743 static llvm::Optional<llvm::SmallBitVector>
744 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
745                                ArrayRef<OpFoldResult> sizes) {
746   llvm::SmallBitVector unusedDims(originalType.getRank());
747   if (originalType.getRank() == reducedType.getRank())
748     return unusedDims;
749 
750   for (const auto &dim : llvm::enumerate(sizes))
751     if (auto attr = dim.value().dyn_cast<Attribute>())
752       if (attr.cast<IntegerAttr>().getInt() == 1)
753         unusedDims.set(dim.index());
754 
755   SmallVector<int64_t> originalStrides, candidateStrides;
756   int64_t originalOffset, candidateOffset;
757   if (failed(
758           getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
759       failed(
760           getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
761     return llvm::None;
762 
  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. This is particularly important when more than one of the
  // dims is 1. Track the number of occurrences of the strides in the original
  // type and the candidate type. For each unused dim, that stride should not
  // be present in the candidate type. Note that there could be multiple
  // dimensions that have the same size. We don't need to figure out exactly
  // which dim corresponds to which stride; we just need to verify that the
  // number of repetitions of a stride in the original type equals the number
  // of repetitions of that stride in the candidate type plus the number of
  // dropped dims carrying that stride.
772   std::map<int64_t, unsigned> currUnaccountedStrides =
773       getNumOccurences(originalStrides);
774   std::map<int64_t, unsigned> candidateStridesNumOccurences =
775       getNumOccurences(candidateStrides);
776   for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
777     if (!unusedDims.test(dim))
778       continue;
779     int64_t originalStride = originalStrides[dim];
780     if (currUnaccountedStrides[originalStride] >
781         candidateStridesNumOccurences[originalStride]) {
782       // This dim can be treated as dropped.
783       currUnaccountedStrides[originalStride]--;
784       continue;
785     }
786     if (currUnaccountedStrides[originalStride] ==
787         candidateStridesNumOccurences[originalStride]) {
788       // The stride for this is not dropped. Keep as is.
789       unusedDims.reset(dim);
790       continue;
791     }
792     if (currUnaccountedStrides[originalStride] <
793         candidateStridesNumOccurences[originalStride]) {
      // This should never happen. Can't have a stride in the reduced-rank type
      // that wasn't in the original one.
796       return llvm::None;
797     }
798   }
799 
800   if ((int64_t)unusedDims.count() + reducedType.getRank() !=
801       originalType.getRank())
802     return llvm::None;
803   return unusedDims;
804 }
805 
806 llvm::SmallBitVector SubViewOp::getDroppedDims() {
807   MemRefType sourceType = getSourceType();
808   MemRefType resultType = getType();
809   llvm::Optional<llvm::SmallBitVector> unusedDims =
810       computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
811   assert(unusedDims && "unable to find unused dims of subview");
812   return *unusedDims;
813 }
814 
815 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
816   // All forms of folding require a known index.
817   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
818   if (!index)
819     return {};
820 
821   // Folding for unranked types (UnrankedMemRefType) is not supported.
822   auto memrefType = source().getType().dyn_cast<MemRefType>();
823   if (!memrefType)
824     return {};
825 
826   // Fold if the shape extent along the given index is known.
827   if (!memrefType.isDynamicDim(index.getInt())) {
828     Builder builder(getContext());
829     return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
830   }
831 
832   // The size at the given index is now known to be a dynamic size.
833   unsigned unsignedIndex = index.getValue().getZExtValue();
834 
835   // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
836   Operation *definingOp = source().getDefiningOp();
837 
838   if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
839     return *(alloc.getDynamicSizes().begin() +
840              memrefType.getDynamicDimIndex(unsignedIndex));
841 
842   if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
843     return *(alloca.getDynamicSizes().begin() +
844              memrefType.getDynamicDimIndex(unsignedIndex));
845 
846   if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
847     return *(view.getDynamicSizes().begin() +
848              memrefType.getDynamicDimIndex(unsignedIndex));
849 
850   if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
851     llvm::SmallBitVector unusedDims = subview.getDroppedDims();
852     unsigned resultIndex = 0;
853     unsigned sourceRank = subview.getSourceType().getRank();
854     unsigned sourceIndex = 0;
855     for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
856       if (unusedDims.test(i))
857         continue;
858       if (resultIndex == unsignedIndex) {
859         sourceIndex = i;
860         break;
861       }
862       resultIndex++;
863     }
864     assert(subview.isDynamicSize(sourceIndex) &&
865            "expected dynamic subview size");
866     return subview.getDynamicSize(sourceIndex);
867   }
868 
869   if (auto sizeInterface =
870           dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
871     assert(sizeInterface.isDynamicSize(unsignedIndex) &&
872            "Expected dynamic subview size");
873     return sizeInterface.getDynamicSize(unsignedIndex);
874   }
875 
876   // dim(memrefcast) -> dim
877   if (succeeded(foldMemRefCast(*this)))
878     return getResult();
879 
880   return {};
881 }
882 
883 namespace {
/// Fold a dim of a memref.reshape operation into a load from the reshape's
/// shape operand.
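/// For example (illustrative IR; %c1 is assumed to be an index constant):
///
/// ```mlir
///   %r = memref.reshape %src(%shape)
///       : (memref<?xf32>, memref<2xindex>) -> memref<?x?xf32>
///   %d = memref.dim %r, %c1 : memref<?x?xf32>
/// ```
///
/// may be rewritten so that %d is loaded directly from the shape operand:
///
/// ```mlir
///   %d = memref.load %shape[%c1] : memref<2xindex>
/// ```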
886 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
887   using OpRewritePattern<DimOp>::OpRewritePattern;
888 
889   LogicalResult matchAndRewrite(DimOp dim,
890                                 PatternRewriter &rewriter) const override {
891     auto reshape = dim.source().getDefiningOp<ReshapeOp>();
892 
893     if (!reshape)
894       return failure();
895 
896     // Place the load directly after the reshape to ensure that the shape memref
897     // was not mutated.
898     rewriter.setInsertionPointAfter(reshape);
899     Location loc = dim.getLoc();
900     Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
901     if (load.getType() != dim.getType())
902       load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
903     rewriter.replaceOp(dim, load);
904     return success();
905   }
906 };
907 
908 } // namespace
909 
910 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
911                                         MLIRContext *context) {
912   results.add<DimOfMemRefReshape>(context);
913 }
914 
915 // ---------------------------------------------------------------------------
916 // DmaStartOp
917 // ---------------------------------------------------------------------------
918 
919 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
920                        Value srcMemRef, ValueRange srcIndices, Value destMemRef,
921                        ValueRange destIndices, Value numElements,
922                        Value tagMemRef, ValueRange tagIndices, Value stride,
923                        Value elementsPerStride) {
924   result.addOperands(srcMemRef);
925   result.addOperands(srcIndices);
926   result.addOperands(destMemRef);
927   result.addOperands(destIndices);
928   result.addOperands({numElements, tagMemRef});
929   result.addOperands(tagIndices);
930   if (stride)
931     result.addOperands({stride, elementsPerStride});
932 }
933 
934 void DmaStartOp::print(OpAsmPrinter &p) {
935   p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
936     << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
937     << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
938   if (isStrided())
939     p << ", " << getStride() << ", " << getNumElementsPerStride();
940 
941   p.printOptionalAttrDict((*this)->getAttrs());
942   p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
943     << ", " << getTagMemRef().getType();
944 }
945 
946 // Parse DmaStartOp.
947 // Ex:
948 //   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
949 //                       %tag[%index], %stride, %num_elt_per_stride :
950 //                     : memref<3076 x f32, 0>,
951 //                       memref<1024 x f32, 2>,
952 //                       memref<1 x i32>
953 //
954 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
955   OpAsmParser::OperandType srcMemRefInfo;
956   SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
957   OpAsmParser::OperandType dstMemRefInfo;
958   SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
959   OpAsmParser::OperandType numElementsInfo;
960   OpAsmParser::OperandType tagMemrefInfo;
961   SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
962   SmallVector<OpAsmParser::OperandType, 2> strideInfo;
963 
964   SmallVector<Type, 3> types;
965   auto indexType = parser.getBuilder().getIndexType();
966 
967   // Parse and resolve the following list of operands:
968   // *) source memref followed by its indices (in square brackets).
969   // *) destination memref followed by its indices (in square brackets).
  // *) number of elements to transfer.
971   if (parser.parseOperand(srcMemRefInfo) ||
972       parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
973       parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
974       parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
975       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
976       parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
977       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
978     return failure();
979 
980   // Parse optional stride and elements per stride.
981   if (parser.parseTrailingOperandList(strideInfo))
982     return failure();
983 
984   bool isStrided = strideInfo.size() == 2;
985   if (!strideInfo.empty() && !isStrided) {
986     return parser.emitError(parser.getNameLoc(),
987                             "expected two stride related operands");
988   }
989 
990   if (parser.parseColonTypeList(types))
991     return failure();
992   if (types.size() != 3)
993     return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
994 
995   if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
996       parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
997       parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
998       parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
999       // size should be an index.
1000       parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1001       parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1002       // tag indices should be index.
1003       parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1004     return failure();
1005 
1006   if (isStrided) {
1007     if (parser.resolveOperands(strideInfo, indexType, result.operands))
1008       return failure();
1009   }
1010 
1011   return success();
1012 }
1013 
1014 LogicalResult DmaStartOp::verify() {
1015   unsigned numOperands = getNumOperands();
1016 
1017   // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1018   // the number of elements.
1019   if (numOperands < 4)
1020     return emitOpError("expected at least 4 operands");
1021 
1022   // Check types of operands. The order of these calls is important: the later
1023   // calls rely on some type properties to compute the operand position.
1024   // 1. Source memref.
1025   if (!getSrcMemRef().getType().isa<MemRefType>())
1026     return emitOpError("expected source to be of memref type");
1027   if (numOperands < getSrcMemRefRank() + 4)
1028     return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1029                          << " operands";
1030   if (!getSrcIndices().empty() &&
1031       !llvm::all_of(getSrcIndices().getTypes(),
1032                     [](Type t) { return t.isIndex(); }))
1033     return emitOpError("expected source indices to be of index type");
1034 
1035   // 2. Destination memref.
1036   if (!getDstMemRef().getType().isa<MemRefType>())
1037     return emitOpError("expected destination to be of memref type");
1038   unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1039   if (numOperands < numExpectedOperands)
1040     return emitOpError() << "expected at least " << numExpectedOperands
1041                          << " operands";
1042   if (!getDstIndices().empty() &&
1043       !llvm::all_of(getDstIndices().getTypes(),
1044                     [](Type t) { return t.isIndex(); }))
1045     return emitOpError("expected destination indices to be of index type");
1046 
1047   // 3. Number of elements.
1048   if (!getNumElements().getType().isIndex())
1049     return emitOpError("expected num elements to be of index type");
1050 
1051   // 4. Tag memref.
1052   if (!getTagMemRef().getType().isa<MemRefType>())
1053     return emitOpError("expected tag to be of memref type");
1054   numExpectedOperands += getTagMemRefRank();
1055   if (numOperands < numExpectedOperands)
1056     return emitOpError() << "expected at least " << numExpectedOperands
1057                          << " operands";
1058   if (!getTagIndices().empty() &&
1059       !llvm::all_of(getTagIndices().getTypes(),
1060                     [](Type t) { return t.isIndex(); }))
1061     return emitOpError("expected tag indices to be of index type");
1062 
1063   // Optional stride-related operands must be either both present or both
1064   // absent.
1065   if (numOperands != numExpectedOperands &&
1066       numOperands != numExpectedOperands + 2)
1067     return emitOpError("incorrect number of operands");
1068 
1069   // 5. Strides.
1070   if (isStrided()) {
1071     if (!getStride().getType().isIndex() ||
1072         !getNumElementsPerStride().getType().isIndex())
1073       return emitOpError(
1074           "expected stride and num elements per stride to be of type index");
1075   }
1076 
1077   return success();
1078 }
1079 
1080 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1081                                SmallVectorImpl<OpFoldResult> &results) {
1082   /// dma_start(memrefcast) -> dma_start
1083   return foldMemRefCast(*this);
1084 }
1085 
1086 // ---------------------------------------------------------------------------
1087 // DmaWaitOp
1088 // ---------------------------------------------------------------------------
1089 
1090 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1091                               SmallVectorImpl<OpFoldResult> &results) {
1092   /// dma_wait(memrefcast) -> dma_wait
1093   return foldMemRefCast(*this);
1094 }
1095 
1096 LogicalResult DmaWaitOp::verify() {
1097   // Check that the number of tag indices matches the tagMemRef rank.
1098   unsigned numTagIndices = tagIndices().size();
1099   unsigned tagMemRefRank = getTagMemRefRank();
1100   if (numTagIndices != tagMemRefRank)
1101     return emitOpError() << "expected tagIndices to have the same number of "
1102                             "elements as the tagMemRef rank, expected "
1103                          << tagMemRefRank << ", but got " << numTagIndices;
1104   return success();
1105 }
1106 
1107 //===----------------------------------------------------------------------===//
1108 // GenericAtomicRMWOp
1109 //===----------------------------------------------------------------------===//
1110 
1111 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1112                                Value memref, ValueRange ivs) {
1113   result.addOperands(memref);
1114   result.addOperands(ivs);
1115 
1116   if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
1117     Type elementType = memrefType.getElementType();
1118     result.addTypes(elementType);
1119 
1120     Region *bodyRegion = result.addRegion();
1121     bodyRegion->push_back(new Block());
1122     bodyRegion->addArgument(elementType, memref.getLoc());
1123   }
1124 }
1125 
1126 LogicalResult GenericAtomicRMWOp::verify() {
1127   auto &body = getRegion();
1128   if (body.getNumArguments() != 1)
1129     return emitOpError("expected single number of entry block arguments");
1130 
1131   if (getResult().getType() != body.getArgument(0).getType())
1132     return emitOpError("expected block argument of the same type result type");
1133 
1134   bool hasSideEffects =
1135       body.walk([&](Operation *nestedOp) {
1136             if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
1137               return WalkResult::advance();
1138             nestedOp->emitError(
1139                 "body of 'memref.generic_atomic_rmw' should contain "
1140                 "only operations with no side effects");
1141             return WalkResult::interrupt();
1142           })
1143           .wasInterrupted();
1144   return hasSideEffects ? failure() : success();
1145 }
1146 
1147 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1148                                       OperationState &result) {
1149   OpAsmParser::OperandType memref;
1150   Type memrefType;
1151   SmallVector<OpAsmParser::OperandType, 4> ivs;
1152 
1153   Type indexType = parser.getBuilder().getIndexType();
1154   if (parser.parseOperand(memref) ||
1155       parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1156       parser.parseColonType(memrefType) ||
1157       parser.resolveOperand(memref, memrefType, result.operands) ||
1158       parser.resolveOperands(ivs, indexType, result.operands))
1159     return failure();
1160 
1161   Region *body = result.addRegion();
1162   if (parser.parseRegion(*body, llvm::None, llvm::None) ||
1163       parser.parseOptionalAttrDict(result.attributes))
1164     return failure();
1165   result.types.push_back(memrefType.cast<MemRefType>().getElementType());
1166   return success();
1167 }
1168 
1169 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1170   p << ' ' << memref() << "[" << indices() << "] : " << memref().getType()
1171     << ' ';
1172   p.printRegion(getRegion());
1173   p.printOptionalAttrDict((*this)->getAttrs());
1174 }
1175 
1176 //===----------------------------------------------------------------------===//
1177 // AtomicYieldOp
1178 //===----------------------------------------------------------------------===//
1179 
1180 LogicalResult AtomicYieldOp::verify() {
1181   Type parentType = (*this)->getParentOp()->getResultTypes().front();
1182   Type resultType = result().getType();
1183   if (parentType != resultType)
1184     return emitOpError() << "types mismatch between yield op: " << resultType
1185                          << " and its parent: " << parentType;
1186   return success();
1187 }
1188 
1189 //===----------------------------------------------------------------------===//
1190 // GlobalOp
1191 //===----------------------------------------------------------------------===//
1192 
1193 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1194                                                    TypeAttr type,
1195                                                    Attribute initialValue) {
1196   p << type;
1197   if (!op.isExternal()) {
1198     p << " = ";
1199     if (op.isUninitialized())
1200       p << "uninitialized";
1201     else
1202       p.printAttributeWithoutType(initialValue);
1203   }
1204 }
1205 
1206 static ParseResult
1207 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1208                                        Attribute &initialValue) {
1209   Type type;
1210   if (parser.parseType(type))
1211     return failure();
1212 
1213   auto memrefType = type.dyn_cast<MemRefType>();
1214   if (!memrefType || !memrefType.hasStaticShape())
1215     return parser.emitError(parser.getNameLoc())
1216            << "type should be static shaped memref, but got " << type;
1217   typeAttr = TypeAttr::get(type);
1218 
1219   if (parser.parseOptionalEqual())
1220     return success();
1221 
1222   if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1223     initialValue = UnitAttr::get(parser.getContext());
1224     return success();
1225   }
1226 
1227   Type tensorType = getTensorTypeFromMemRefType(memrefType);
1228   if (parser.parseAttribute(initialValue, tensorType))
1229     return failure();
1230   if (!initialValue.isa<ElementsAttr>())
1231     return parser.emitError(parser.getNameLoc())
1232            << "initial value should be a unit or elements attribute";
1233   return success();
1234 }
1235 
1236 LogicalResult GlobalOp::verify() {
1237   auto memrefType = type().dyn_cast<MemRefType>();
1238   if (!memrefType || !memrefType.hasStaticShape())
1239     return emitOpError("type should be static shaped memref, but got ")
1240            << type();
1241 
1242   // Verify that the initial value, if present, is either a unit attribute or
1243   // an elements attribute.
1244   if (initial_value().hasValue()) {
1245     Attribute initValue = initial_value().getValue();
1246     if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
1247       return emitOpError("initial value should be a unit or elements "
1248                          "attribute, but got ")
1249              << initValue;
1250 
1251     // Check that the type of the initial value is compatible with the type of
1252     // the global variable.
1253     if (initValue.isa<ElementsAttr>()) {
1254       Type initType = initValue.getType();
1255       Type tensorType = getTensorTypeFromMemRefType(memrefType);
1256       if (initType != tensorType)
1257         return emitOpError("initial value expected to be of type ")
1258                << tensorType << ", but was of type " << initType;
1259     }
1260   }
1261 
1262   if (Optional<uint64_t> alignAttr = alignment()) {
1263     uint64_t alignment = alignAttr.getValue();
1264 
1265     if (!llvm::isPowerOf2_64(alignment))
1266       return emitError() << "alignment attribute value " << alignment
1267                          << " is not a power of 2";
1268   }
1269 
1270   // TODO: verify visibility for declarations.
1271   return success();
1272 }
1273 
1274 //===----------------------------------------------------------------------===//
1275 // GetGlobalOp
1276 //===----------------------------------------------------------------------===//
1277 
1278 LogicalResult
1279 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1280   // Verify that the result type is same as the type of the referenced
1281   // memref.global op.
1282   auto global =
1283       symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
1284   if (!global)
1285     return emitOpError("'")
1286            << name() << "' does not reference a valid global memref";
1287 
1288   Type resultType = result().getType();
1289   if (global.type() != resultType)
1290     return emitOpError("result type ")
1291            << resultType << " does not match type " << global.type()
1292            << " of the global memref @" << name();
1293   return success();
1294 }
1295 
1296 //===----------------------------------------------------------------------===//
1297 // LoadOp
1298 //===----------------------------------------------------------------------===//
1299 
1300 LogicalResult LoadOp::verify() {
1301   if (getNumOperands() != 1 + getMemRefType().getRank())
1302     return emitOpError("incorrect number of indices for load");
1303   return success();
1304 }
1305 
1306 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1307   /// load(memrefcast) -> load
1308   if (succeeded(foldMemRefCast(*this)))
1309     return getResult();
1310   return OpFoldResult();
1311 }
1312 
1313 //===----------------------------------------------------------------------===//
1314 // PrefetchOp
1315 //===----------------------------------------------------------------------===//
1316 
1317 void PrefetchOp::print(OpAsmPrinter &p) {
1318   p << " " << memref() << '[';
1319   p.printOperands(indices());
1320   p << ']' << ", " << (isWrite() ? "write" : "read");
1321   p << ", locality<" << localityHint();
1322   p << ">, " << (isDataCache() ? "data" : "instr");
1323   p.printOptionalAttrDict(
1324       (*this)->getAttrs(),
1325       /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1326   p << " : " << getMemRefType();
1327 }
1328 
1329 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1330   OpAsmParser::OperandType memrefInfo;
1331   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
1332   IntegerAttr localityHint;
1333   MemRefType type;
1334   StringRef readOrWrite, cacheType;
1335 
1336   auto indexTy = parser.getBuilder().getIndexType();
1337   auto i32Type = parser.getBuilder().getIntegerType(32);
1338   if (parser.parseOperand(memrefInfo) ||
1339       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1340       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1341       parser.parseComma() || parser.parseKeyword("locality") ||
1342       parser.parseLess() ||
1343       parser.parseAttribute(localityHint, i32Type, "localityHint",
1344                             result.attributes) ||
1345       parser.parseGreater() || parser.parseComma() ||
1346       parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1347       parser.resolveOperand(memrefInfo, type, result.operands) ||
1348       parser.resolveOperands(indexInfo, indexTy, result.operands))
1349     return failure();
1350 
1351   if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1352     return parser.emitError(parser.getNameLoc(),
1353                             "rw specifier has to be 'read' or 'write'");
1354   result.addAttribute(
1355       PrefetchOp::getIsWriteAttrName(),
1356       parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1357 
1358   if (!cacheType.equals("data") && !cacheType.equals("instr"))
1359     return parser.emitError(parser.getNameLoc(),
1360                             "cache type has to be 'data' or 'instr'");
1361 
1362   result.addAttribute(
1363       PrefetchOp::getIsDataCacheAttrName(),
1364       parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1365 
1366   return success();
1367 }
1368 
1369 LogicalResult PrefetchOp::verify() {
1370   if (getNumOperands() != 1 + getMemRefType().getRank())
1371     return emitOpError("too few indices");
1372 
1373   return success();
1374 }
1375 
1376 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
1377                                SmallVectorImpl<OpFoldResult> &results) {
1378   // prefetch(memrefcast) -> prefetch
1379   return foldMemRefCast(*this);
1380 }
1381 
1382 //===----------------------------------------------------------------------===//
1383 // RankOp
1384 //===----------------------------------------------------------------------===//
1385 
1386 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
1387   // Constant fold rank when the rank of the operand is known.
1388   auto type = getOperand().getType();
1389   auto shapedType = type.dyn_cast<ShapedType>();
1390   if (shapedType && shapedType.hasRank())
1391     return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1392   return IntegerAttr();
1393 }
1394 
1395 //===----------------------------------------------------------------------===//
1396 // ReinterpretCastOp
1397 //===----------------------------------------------------------------------===//
1398 
/// Build a ReinterpretCastOp with mixed static and dynamic entries: the
/// `static_offsets`, `static_sizes` and `static_strides` attributes are
/// derived from `offset`, `sizes` and `strides`, with sentinel values encoding
/// any dynamic entries.
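///
/// For illustration only (a sketch; `b`, `loc`, `src`, `dynSize` and
/// `resultType` are assumed to exist in the surrounding code):
///   b.create<ReinterpretCastOp>(
///       loc, resultType, src, /*offset=*/b.getIndexAttr(0),
///       /*sizes=*/ArrayRef<OpFoldResult>{dynSize, b.getIndexAttr(4)},
///       /*strides=*/ArrayRef<OpFoldResult>{b.getIndexAttr(4),
///                                          b.getIndexAttr(1)});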
1402 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1403                               MemRefType resultType, Value source,
1404                               OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1405                               ArrayRef<OpFoldResult> strides,
1406                               ArrayRef<NamedAttribute> attrs) {
1407   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1408   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1409   dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
1410                              ShapedType::kDynamicStrideOrOffset);
1411   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1412                              ShapedType::kDynamicSize);
1413   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1414                              ShapedType::kDynamicStrideOrOffset);
1415   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1416         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1417         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1418   result.addAttributes(attrs);
1419 }
1420 
1421 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1422                               MemRefType resultType, Value source,
1423                               int64_t offset, ArrayRef<int64_t> sizes,
1424                               ArrayRef<int64_t> strides,
1425                               ArrayRef<NamedAttribute> attrs) {
1426   SmallVector<OpFoldResult> sizeValues =
1427       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1428         return b.getI64IntegerAttr(v);
1429       }));
1430   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1431       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1432         return b.getI64IntegerAttr(v);
1433       }));
1434   build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1435         strideValues, attrs);
1436 }
1437 
1438 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1439                               MemRefType resultType, Value source, Value offset,
1440                               ValueRange sizes, ValueRange strides,
1441                               ArrayRef<NamedAttribute> attrs) {
1442   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1443       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1444   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1445       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1446   build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1447 }
1448 
1449 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1450 // completed automatically, like we have for subview and extract_slice.
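// For example (an illustrative sketch), the following verifies because the
// static offset, sizes and strides all match the identity layout of the
// result type:
//   %dst = memref.reinterpret_cast %src to
//       offset: [0], sizes: [10, 10], strides: [10, 1]
//       : memref<?x?xf32> to memref<10x10xf32>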
1451 LogicalResult ReinterpretCastOp::verify() {
1452   // The source and result memrefs should be in the same memory space.
1453   auto srcType = source().getType().cast<BaseMemRefType>();
1454   auto resultType = getType().cast<MemRefType>();
1455   if (srcType.getMemorySpace() != resultType.getMemorySpace())
1456     return emitError("different memory spaces specified for source type ")
1457            << srcType << " and result memref type " << resultType;
1458   if (srcType.getElementType() != resultType.getElementType())
1459     return emitError("different element types specified for source type ")
1460            << srcType << " and result memref type " << resultType;
1461 
1462   // Match sizes in result memref type and in static_sizes attribute.
1463   for (auto &en : llvm::enumerate(llvm::zip(
1464            resultType.getShape(), extractFromI64ArrayAttr(static_sizes())))) {
1465     int64_t resultSize = std::get<0>(en.value());
1466     int64_t expectedSize = std::get<1>(en.value());
1467     if (!ShapedType::isDynamic(resultSize) &&
1468         !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
1469       return emitError("expected result type with size = ")
1470              << expectedSize << " instead of " << resultSize
1471              << " in dim = " << en.index();
1472   }
1473 
  // Match the offset and strides in the static_offsets and static_strides
  // attributes. If the result memref type has no affine map specified, this
  // assumes an identity layout.
1477   int64_t resultOffset;
1478   SmallVector<int64_t, 4> resultStrides;
1479   if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1480     return emitError("expected result type to have strided layout but found ")
1481            << resultType;
1482 
1483   // Match offset in result memref type and in static_offsets attribute.
1484   int64_t expectedOffset = extractFromI64ArrayAttr(static_offsets()).front();
1485   if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
1486       !ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
1487       resultOffset != expectedOffset)
1488     return emitError("expected result type with offset = ")
1489            << resultOffset << " instead of " << expectedOffset;
1490 
1491   // Match strides in result memref type and in static_strides attribute.
1492   for (auto &en : llvm::enumerate(llvm::zip(
1493            resultStrides, extractFromI64ArrayAttr(static_strides())))) {
1494     int64_t resultStride = std::get<0>(en.value());
1495     int64_t expectedStride = std::get<1>(en.value());
1496     if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
1497         !ShapedType::isDynamicStrideOrOffset(expectedStride) &&
1498         resultStride != expectedStride)
1499       return emitError("expected result type with stride = ")
1500              << expectedStride << " instead of " << resultStride
1501              << " in dim = " << en.index();
1502   }
1503 
1504   return success();
1505 }
1506 
1507 //===----------------------------------------------------------------------===//
1508 // Reassociative reshape ops
1509 //===----------------------------------------------------------------------===//
1510 
1511 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1512   return getSymbolLessAffineMaps(getReassociationExprs());
1513 }
1514 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1515   return convertReassociationIndicesToExprs(getContext(),
1516                                             getReassociationIndices());
1517 }
1518 
1519 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1520   return getSymbolLessAffineMaps(getReassociationExprs());
1521 }
1522 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1523   return convertReassociationIndicesToExprs(getContext(),
1524                                             getReassociationIndices());
1525 }
1526 
1527 /// Detect whether memref dims [dim, dim + extent) can be reshaped without
1528 /// copies.
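/// For example (a sketch): a band with sizes [4, 8] and strides [8, 1] is
/// contiguous because strides[0] == strides[1] * sizes[1] (8 == 1 * 8), so it
/// can be collapsed without a copy; with strides [10, 1] and the same sizes it
/// cannot.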
1529 static bool isReshapableDimBand(unsigned dim, unsigned extent,
1530                                 ArrayRef<int64_t> sizes,
1531                                 ArrayRef<AffineExpr> strides) {
  // Bands of extent one are trivially reshapable: nothing actually changes.
1533   if (extent == 1)
1534     return true;
1535   // Otherwise, the size of the first dimension needs to be known.
1536   if (ShapedType::isDynamic(sizes[dim]))
1537     return false;
1538   assert(sizes.size() == strides.size() && "mismatched ranks");
  // Iterate while `idx + 1 < e` because each step reads sizes[idx + 1] and
  // strides[idx + 1]; stopping one element early avoids indexing out of
  // bounds.
1541   for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) {
1542     // Only bands of static shapes are reshapable. This is due to the fact that
1543     // there is no relation between dynamic sizes and dynamic strides: we do not
1544     // have enough information to know whether a "-1" size corresponds to the
1545     // proper symbol in the AffineExpr of a stride.
1546     if (ShapedType::isDynamic(sizes[idx + 1]))
1547       return false;
1548     // TODO: Refine this by passing the proper nDims and nSymbols so we can
1549     // simplify on the fly and catch more reshapable cases.
1550     if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
1551       return false;
1552   }
1553   return true;
1554 }
1555 
1556 /// Compute the MemRefType obtained by applying the `reassociation` (which is
1557 /// expected to be valid) to `type`.
/// If `type` is a contiguous MemRefType, this always produces a contiguous
/// MemRefType.
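///
/// For example (illustrative): collapsing a contiguous memref<4x8xf32> with
/// reassociation [[0, 1]] yields memref<32xf32>; for a non-contiguous band the
/// collapsed size and stride become dynamic, which corresponds to a case where
/// a copy may be required.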
1560 static MemRefType
1561 computeReshapeCollapsedType(MemRefType type,
1562                             ArrayRef<AffineMap> reassociation) {
1563   auto sizes = type.getShape();
1564   AffineExpr offset;
1565   SmallVector<AffineExpr, 4> strides;
1566   auto status = getStridesAndOffset(type, strides, offset);
1567   auto isIdentityLayout = type.getLayout().isIdentity();
1568   (void)status;
1569   assert(succeeded(status) && "expected strided memref");
1570 
1571   SmallVector<int64_t, 4> newSizes;
1572   newSizes.reserve(reassociation.size());
1573   SmallVector<AffineExpr, 4> newStrides;
1574   newStrides.reserve(reassociation.size());
1575 
1576   // Use the fact that reassociation is valid to simplify the logic: only use
1577   // each map's rank.
1578   assert(isReassociationValid(reassociation) && "invalid reassociation");
1579   unsigned currentDim = 0;
1580   for (AffineMap m : reassociation) {
1581     unsigned dim = m.getNumResults();
1582     int64_t size = 1;
1583     AffineExpr stride = strides[currentDim + dim - 1];
1584     if (isIdentityLayout ||
1585         isReshapableDimBand(currentDim, dim, sizes, strides)) {
1586       for (unsigned d = 0; d < dim; ++d) {
1587         int64_t currentSize = sizes[currentDim + d];
1588         if (ShapedType::isDynamic(currentSize)) {
1589           size = ShapedType::kDynamicSize;
1590           break;
1591         }
1592         size *= currentSize;
1593       }
1594     } else {
1595       size = ShapedType::kDynamicSize;
1596       stride = AffineExpr();
1597     }
1598     newSizes.push_back(size);
1599     newStrides.push_back(stride);
1600     currentDim += dim;
1601   }
1602 
1603   // Early-exit: if `type` is contiguous, the result must be contiguous.
1604   if (canonicalizeStridedLayout(type).getLayout().isIdentity())
1605     return MemRefType::Builder(type).setShape(newSizes).setLayout({});
1606 
1607   // Convert back to int64_t because we don't have enough information to create
1608   // new strided layouts from AffineExpr only. This corresponds to a case where
1609   // copies may be necessary.
1610   int64_t intOffset = ShapedType::kDynamicStrideOrOffset;
1611   if (auto o = offset.dyn_cast<AffineConstantExpr>())
1612     intOffset = o.getValue();
1613   SmallVector<int64_t, 4> intStrides;
1614   intStrides.reserve(strides.size());
1615   for (auto stride : newStrides) {
1616     if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>())
1617       intStrides.push_back(cst.getValue());
1618     else
1619       intStrides.push_back(ShapedType::kDynamicStrideOrOffset);
1620   }
1621   auto layout =
1622       makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
1623   return canonicalizeStridedLayout(
1624       MemRefType::Builder(type).setShape(newSizes).setLayout(
1625           AffineMapAttr::get(layout)));
1626 }
1627 
1628 void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1629                           ArrayRef<ReassociationIndices> reassociation,
1630                           ArrayRef<NamedAttribute> attrs) {
1631   auto memRefType = src.getType().cast<MemRefType>();
1632   auto resultType = computeReshapeCollapsedType(
1633       memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1634                       b.getContext(), reassociation)));
1635   build(b, result, resultType, src, attrs);
1636   result.addAttribute(getReassociationAttrName(),
1637                       getReassociationIndicesAttribute(b, reassociation));
1638 }
1639 
1640 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1641                             ArrayRef<ReassociationIndices> reassociation,
1642                             ArrayRef<NamedAttribute> attrs) {
1643   auto memRefType = src.getType().cast<MemRefType>();
1644   auto resultType = computeReshapeCollapsedType(
1645       memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1646                       b.getContext(), reassociation)));
1647   build(b, result, resultType, src, attrs);
1648   result.addAttribute(getReassociationAttrName(),
1649                       getReassociationIndicesAttribute(b, reassociation));
1650 }
1651 
1652 template <typename ReshapeOp,
1653           bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value>
1654 static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
1655                                      MemRefType collapsedType) {
1656   if (failed(
1657           verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
1658     return failure();
1659   auto maps = op.getReassociationMaps();
1660   MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
1661   if (collapsedType != expectedType)
1662     return op.emitOpError("expected collapsed type to be ")
1663            << expectedType << ", but got " << collapsedType;
1664   return success();
1665 }
1666 
1667 LogicalResult ExpandShapeOp::verify() {
1668   return verifyReshapeOp(*this, getResultType(), getSrcType());
1669 }
1670 
1671 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1672                                                 MLIRContext *context) {
1673   results.add<CollapseReshapeOps<ExpandShapeOp>,
1674               CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
1675 }
1676 
1677 LogicalResult CollapseShapeOp::verify() {
1678   return verifyReshapeOp(*this, getSrcType(), getResultType());
1679 }
1680 
1681 struct CollapseShapeOpMemRefCastFolder
1682     : public OpRewritePattern<CollapseShapeOp> {
1683 public:
1684   using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
1685 
1686   LogicalResult matchAndRewrite(CollapseShapeOp op,
1687                                 PatternRewriter &rewriter) const override {
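    // For example (an illustrative sketch):
    //   %0 = memref.cast %arg : memref<4x8xf32> to memref<?x?xf32>
    //   %1 = memref.collapse_shape %0 [[0, 1]]
    //       : memref<?x?xf32> into memref<?xf32>
    // is rewritten to collapse %arg directly, with a memref.cast reinserted if
    // the collapsed type differs from the original result type.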
1688     auto cast = op.getOperand().getDefiningOp<CastOp>();
1689     if (!cast)
1690       return failure();
1691 
1692     if (!CastOp::canFoldIntoConsumerOp(cast))
1693       return failure();
1694 
1695     Type newResultType = computeReshapeCollapsedType(
1696         cast.getOperand().getType().cast<MemRefType>(),
1697         op.getReassociationMaps());
1698 
1699     if (newResultType == op.getResultType()) {
1700       rewriter.updateRootInPlace(
1701           op, [&]() { op.srcMutable().assign(cast.source()); });
1702     } else {
1703       Value newOp = rewriter.create<CollapseShapeOp>(
1704           op->getLoc(), cast.source(), op.getReassociationIndices());
1705       rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
1706     }
1707     return success();
1708   }
1709 };
1710 
1711 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1712                                                   MLIRContext *context) {
1713   results.add<CollapseReshapeOps<CollapseShapeOp>,
1714               CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
1715               CollapseShapeOpMemRefCastFolder>(context);
1716 }

OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
}

OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
}
1723 
1724 //===----------------------------------------------------------------------===//
1725 // ReshapeOp
1726 //===----------------------------------------------------------------------===//
1727 
1728 LogicalResult ReshapeOp::verify() {
1729   Type operandType = source().getType();
1730   Type resultType = result().getType();
1731 
1732   Type operandElementType = operandType.cast<ShapedType>().getElementType();
1733   Type resultElementType = resultType.cast<ShapedType>().getElementType();
1734   if (operandElementType != resultElementType)
1735     return emitOpError("element types of source and destination memref "
1736                        "types should be the same");
1737 
1738   if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
1739     if (!operandMemRefType.getLayout().isIdentity())
1740       return emitOpError("source memref type should have identity affine map");
1741 
1742   int64_t shapeSize = shape().getType().cast<MemRefType>().getDimSize(0);
1743   auto resultMemRefType = resultType.dyn_cast<MemRefType>();
1744   if (resultMemRefType) {
1745     if (!resultMemRefType.getLayout().isIdentity())
1746       return emitOpError("result memref type should have identity affine map");
1747     if (shapeSize == ShapedType::kDynamicSize)
1748       return emitOpError("cannot use shape operand with dynamic length to "
1749                          "reshape to statically-ranked memref type");
1750     if (shapeSize != resultMemRefType.getRank())
1751       return emitOpError(
1752           "length of shape operand differs from the result's memref rank");
1753   }
1754   return success();
1755 }
1756 
1757 //===----------------------------------------------------------------------===//
1758 // StoreOp
1759 //===----------------------------------------------------------------------===//
1760 
1761 LogicalResult StoreOp::verify() {
1762   if (getNumOperands() != 2 + getMemRefType().getRank())
1763     return emitOpError("store index operand count not equal to memref rank");
1764 
1765   return success();
1766 }
1767 
1768 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
1769                             SmallVectorImpl<OpFoldResult> &results) {
1770   /// store(memrefcast) -> store
1771   return foldMemRefCast(*this, getValueToStore());
1772 }
1773 
1774 //===----------------------------------------------------------------------===//
1775 // SubViewOp
1776 //===----------------------------------------------------------------------===//
1777 
1778 namespace {
1779 /// Helpers to write more idiomatic operations.
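/// For example (illustrative): `Wrapper(stride) * size` yields
/// ShapedType::kDynamicStrideOrOffset whenever either operand is dynamic, and
/// the plain product otherwise.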
1780 namespace saturated_arith {
1781 struct Wrapper {
1782   explicit Wrapper(int64_t v) : v(v) {}
1783   operator int64_t() { return v; }
1784   int64_t v;
1785 };
1786 Wrapper operator+(Wrapper a, int64_t b) {
1787   if (ShapedType::isDynamicStrideOrOffset(a) ||
1788       ShapedType::isDynamicStrideOrOffset(b))
1789     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1790   return Wrapper(a.v + b);
1791 }
1792 Wrapper operator*(Wrapper a, int64_t b) {
1793   if (ShapedType::isDynamicStrideOrOffset(a) ||
1794       ShapedType::isDynamicStrideOrOffset(b))
1795     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1796   return Wrapper(a.v * b);
1797 }
1798 } // namespace saturated_arith
1799 } // namespace
1800 
1801 /// A subview result type can be fully inferred from the source type and the
1802 /// static representation of offsets, sizes and strides. Special sentinels
1803 /// encode the dynamic case.
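///
/// For example (illustrative): a subview of a memref<8x16xf32> (strides
/// [16, 1], offset 0) taken with offsets [2, 4], sizes [4, 4] and strides
/// [1, 2] has offset 2 * 16 + 4 * 1 = 36 and strides [16 * 1, 1 * 2] =
/// [16, 2].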
1804 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1805                                 ArrayRef<int64_t> staticOffsets,
1806                                 ArrayRef<int64_t> staticSizes,
1807                                 ArrayRef<int64_t> staticStrides) {
1808   unsigned rank = sourceMemRefType.getRank();
1809   (void)rank;
1810   assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
1811   assert(staticSizes.size() == rank && "staticSizes length mismatch");
1812   assert(staticStrides.size() == rank && "staticStrides length mismatch");
1813 
1814   // Extract source offset and strides.
1815   int64_t sourceOffset;
1816   SmallVector<int64_t, 4> sourceStrides;
1817   auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
1818   assert(succeeded(res) && "SubViewOp expected strided memref type");
1819   (void)res;
1820 
1821   // Compute target offset whose value is:
1822   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
1823   int64_t targetOffset = sourceOffset;
1824   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
1825     auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
1826     using namespace saturated_arith;
1827     targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride;
1828   }
1829 
1830   // Compute target stride whose value is:
1831   //   `sourceStrides_i * staticStrides_i`.
1832   SmallVector<int64_t, 4> targetStrides;
1833   targetStrides.reserve(staticOffsets.size());
1834   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
1835     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
1836     using namespace saturated_arith;
1837     targetStrides.push_back(Wrapper(sourceStride) * staticStride);
1838   }
1839 
1840   // The type is now known.
1841   return MemRefType::get(
1842       staticSizes, sourceMemRefType.getElementType(),
1843       makeStridedLinearLayoutMap(targetStrides, targetOffset,
1844                                  sourceMemRefType.getContext()),
1845       sourceMemRefType.getMemorySpace());
1846 }
1847 
1848 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1849                                 ArrayRef<OpFoldResult> offsets,
1850                                 ArrayRef<OpFoldResult> sizes,
1851                                 ArrayRef<OpFoldResult> strides) {
1852   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1853   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1854   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1855                              ShapedType::kDynamicStrideOrOffset);
1856   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1857                              ShapedType::kDynamicSize);
1858   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1859                              ShapedType::kDynamicStrideOrOffset);
1860   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1861                                     staticSizes, staticStrides);
1862 }
1863 
1864 Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
1865                                            MemRefType sourceRankedTensorType,
1866                                            ArrayRef<int64_t> offsets,
1867                                            ArrayRef<int64_t> sizes,
1868                                            ArrayRef<int64_t> strides) {
1869   auto inferredType =
1870       inferResultType(sourceRankedTensorType, offsets, sizes, strides)
1871           .cast<MemRefType>();
  assert(inferredType.getRank() >= resultRank &&
         "expected the result rank to be at most the inferred rank");
1873   int rankDiff = inferredType.getRank() - resultRank;
1874   if (rankDiff > 0) {
1875     auto shape = inferredType.getShape();
1876     llvm::SmallBitVector dimsToProject =
1877         getPositionsOfShapeOne(rankDiff, shape);
1878     SmallVector<int64_t> projectedShape;
1879     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1880       if (!dimsToProject.test(pos))
1881         projectedShape.push_back(shape[pos]);
1882 
1883     AffineMap map = inferredType.getLayout().getAffineMap();
1884     if (!map.isIdentity())
1885       map = getProjectedMap(map, dimsToProject);
1886     inferredType =
1887         MemRefType::get(projectedShape, inferredType.getElementType(), map,
1888                         inferredType.getMemorySpace());
1889   }
1890   return inferredType;
1891 }
1892 
1893 Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
1894                                            MemRefType sourceRankedTensorType,
1895                                            ArrayRef<OpFoldResult> offsets,
1896                                            ArrayRef<OpFoldResult> sizes,
1897                                            ArrayRef<OpFoldResult> strides) {
1898   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1899   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1900   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1901                              ShapedType::kDynamicStrideOrOffset);
1902   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1903                              ShapedType::kDynamicSize);
1904   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1905                              ShapedType::kDynamicStrideOrOffset);
1906   return SubViewOp::inferRankReducedResultType(
1907       resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1908       staticStrides);
1909 }

// Build a SubViewOp with mixed static and dynamic entries and custom result
1911 // type. If the type passed is nullptr, it is inferred.
1912 void SubViewOp::build(OpBuilder &b, OperationState &result,
1913                       MemRefType resultType, Value source,
1914                       ArrayRef<OpFoldResult> offsets,
1915                       ArrayRef<OpFoldResult> sizes,
1916                       ArrayRef<OpFoldResult> strides,
1917                       ArrayRef<NamedAttribute> attrs) {
1918   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1919   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1920   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1921                              ShapedType::kDynamicStrideOrOffset);
1922   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1923                              ShapedType::kDynamicSize);
1924   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1925                              ShapedType::kDynamicStrideOrOffset);
1926   auto sourceMemRefType = source.getType().cast<MemRefType>();
1927   // Structuring implementation this way avoids duplication between builders.
1928   if (!resultType) {
1929     resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1930                                             staticSizes, staticStrides)
1931                      .cast<MemRefType>();
1932   }
1933   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1934         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1935         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1936   result.addAttributes(attrs);
1937 }
1938 
1939 // Build a SubViewOp with mixed static and dynamic entries and inferred result
1940 // type.
1941 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1942                       ArrayRef<OpFoldResult> offsets,
1943                       ArrayRef<OpFoldResult> sizes,
1944                       ArrayRef<OpFoldResult> strides,
1945                       ArrayRef<NamedAttribute> attrs) {
1946   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1947 }
1948 
1949 // Build a SubViewOp with static entries and inferred result type.
1950 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1951                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1952                       ArrayRef<int64_t> strides,
1953                       ArrayRef<NamedAttribute> attrs) {
1954   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1955       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1956         return b.getI64IntegerAttr(v);
1957       }));
1958   SmallVector<OpFoldResult> sizeValues =
1959       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1960         return b.getI64IntegerAttr(v);
1961       }));
1962   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1963       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1964         return b.getI64IntegerAttr(v);
1965       }));
1966   build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
1967 }
1968 
// Build a SubViewOp with static entries and custom result type. If the
// type passed is nullptr, it is inferred.
1971 void SubViewOp::build(OpBuilder &b, OperationState &result,
1972                       MemRefType resultType, Value source,
1973                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1974                       ArrayRef<int64_t> strides,
1975                       ArrayRef<NamedAttribute> attrs) {
1976   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1977       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1978         return b.getI64IntegerAttr(v);
1979       }));
1980   SmallVector<OpFoldResult> sizeValues =
1981       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1982         return b.getI64IntegerAttr(v);
1983       }));
1984   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1985       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1986         return b.getI64IntegerAttr(v);
1987       }));
1988   build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
1989         attrs);
1990 }
1991 
1992 // Build a SubViewOp with dynamic entries and custom result type. If the type
1993 // passed is nullptr, it is inferred.
1994 void SubViewOp::build(OpBuilder &b, OperationState &result,
1995                       MemRefType resultType, Value source, ValueRange offsets,
1996                       ValueRange sizes, ValueRange strides,
1997                       ArrayRef<NamedAttribute> attrs) {
1998   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1999       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2000   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2001       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2002   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2003       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
2005 }
2006 
2007 // Build a SubViewOp with dynamic entries and inferred result type.
2008 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2009                       ValueRange offsets, ValueRange sizes, ValueRange strides,
2010                       ArrayRef<NamedAttribute> attrs) {
2011   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2012 }
2013 
2014 /// For ViewLikeOpInterface.
2015 Value SubViewOp::getViewSource() { return source(); }
2016 
2017 /// Return true if t1 and t2 have equal offsets (both dynamic or of same static
2018 /// value).
2019 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2020   AffineExpr t1Offset, t2Offset;
2021   SmallVector<AffineExpr> t1Strides, t2Strides;
2022   auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2023   auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2024   return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2025 }
2026 
/// Checks if `originalType` can be rank-reduced to `candidateRankReducedType`.
/// This function is a slight variant of the "is subsequence" algorithm, where
/// each non-matching dimension must have size 1.
2030 static SliceVerificationResult
2031 isRankReducedMemRefType(MemRefType originalType,
2032                         MemRefType candidateRankReducedType,
2033                         ArrayRef<OpFoldResult> sizes) {
2034   auto partialRes = isRankReducedType(originalType, candidateRankReducedType);
2035   if (partialRes != SliceVerificationResult::Success)
2036     return partialRes;
2037 
2038   auto optionalUnusedDimsMask = computeMemRefRankReductionMask(
2039       originalType, candidateRankReducedType, sizes);
2040 
  // The sizes cannot be matched if no mask is returned.
2042   if (!optionalUnusedDimsMask.hasValue())
2043     return SliceVerificationResult::LayoutMismatch;
2044 
2045   if (originalType.getMemorySpace() !=
2046       candidateRankReducedType.getMemorySpace())
2047     return SliceVerificationResult::MemSpaceMismatch;
2048 
2049   // No amount of stride dropping can reconcile incompatible offsets.
2050   if (!haveCompatibleOffsets(originalType, candidateRankReducedType))
2051     return SliceVerificationResult::LayoutMismatch;
2052 
2053   return SliceVerificationResult::Success;
2054 }
2055 
2056 template <typename OpTy>
2057 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
2058                                             OpTy op, Type expectedType) {
2059   auto memrefType = expectedType.cast<ShapedType>();
2060   switch (result) {
2061   case SliceVerificationResult::Success:
2062     return success();
2063   case SliceVerificationResult::RankTooLarge:
2064     return op.emitError("expected result rank to be smaller or equal to ")
2065            << "the source rank. ";
2066   case SliceVerificationResult::SizeMismatch:
2067     return op.emitError("expected result type to be ")
2068            << expectedType
2069            << " or a rank-reduced version. (mismatch of result sizes) ";
2070   case SliceVerificationResult::ElemTypeMismatch:
2071     return op.emitError("expected result element type to be ")
2072            << memrefType.getElementType();
2073   case SliceVerificationResult::MemSpaceMismatch:
2074     return op.emitError("expected result and source memory spaces to match.");
2075   case SliceVerificationResult::LayoutMismatch:
2076     return op.emitError("expected result type to be ")
2077            << expectedType
2078            << " or a rank-reduced version. (mismatch of result layout) ";
2079   }
2080   llvm_unreachable("unexpected subview verification result");
2081 }
2082 
2083 /// Verifier for SubViewOp.
2084 LogicalResult SubViewOp::verify() {
2085   MemRefType baseType = getSourceType();
2086   MemRefType subViewType = getType();
2087 
2088   // The base memref and the view memref should be in the same memory space.
2089   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2090     return emitError("different memory spaces specified for base memref "
2091                      "type ")
2092            << baseType << " and subview memref type " << subViewType;
2093 
2094   // Verify that the base memref type has a strided layout map.
2095   if (!isStrided(baseType))
2096     return emitError("base type ") << baseType << " is not strided";
2097 
2098   // Verify result type against inferred type.
2099   auto expectedType = SubViewOp::inferResultType(
2100       baseType, extractFromI64ArrayAttr(static_offsets()),
2101       extractFromI64ArrayAttr(static_sizes()),
2102       extractFromI64ArrayAttr(static_strides()));
2103 
2104   auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
2105                                         subViewType, getMixedSizes());
2106   return produceSubViewErrorMsg(result, *this, expectedType);
2107 }
2108 
2109 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
2110   return os << "range " << range.offset << ":" << range.size << ":"
2111             << range.stride;
2112 }
2113 
2114 /// Return the list of Range (i.e. offset, size, stride). Each Range
2115 /// entry contains either the dynamic value or a ConstantIndexOp constructed
2116 /// with `b` at location `loc`.
2117 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
2118                                               OpBuilder &b, Location loc) {
2119   std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
2120   assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
2121   assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
2122   SmallVector<Range, 8> res;
2123   unsigned rank = ranks[0];
2124   res.reserve(rank);
2125   for (unsigned idx = 0; idx < rank; ++idx) {
2126     Value offset =
2127         op.isDynamicOffset(idx)
2128             ? op.getDynamicOffset(idx)
2129             : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
2130     Value size =
2131         op.isDynamicSize(idx)
2132             ? op.getDynamicSize(idx)
2133             : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
2134     Value stride =
2135         op.isDynamicStride(idx)
2136             ? op.getDynamicStride(idx)
2137             : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
2138     res.emplace_back(Range{offset, size, stride});
2139   }
2140   return res;
2141 }
2142 
2143 /// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
2144 /// deduce the result type for the given `sourceType`. Additionally, reduce the
2145 /// rank of the inferred result type if `currentResultType` is lower rank than
2146 /// `currentSourceType`. Use this signature if `sourceType` is updated together
2147 /// with the result type. In this case, it is important to compute the dropped
2148 /// dimensions using `currentSourceType` whose strides align with
2149 /// `currentResultType`.
2150 static MemRefType getCanonicalSubViewResultType(
2151     MemRefType currentResultType, MemRefType currentSourceType,
2152     MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
2153     ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
2154   auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
2155                                                        mixedSizes, mixedStrides)
2156                                 .cast<MemRefType>();
2157   llvm::Optional<llvm::SmallBitVector> unusedDims =
2158       computeMemRefRankReductionMask(currentSourceType, currentResultType,
2159                                      mixedSizes);
2160   // Return nullptr as failure mode.
2161   if (!unusedDims)
2162     return nullptr;
2163   SmallVector<int64_t> shape;
2164   for (const auto &sizes : llvm::enumerate(nonRankReducedType.getShape())) {
2165     if (unusedDims->test(sizes.index()))
2166       continue;
2167     shape.push_back(sizes.value());
2168   }
2169   AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap();
2170   if (!layoutMap.isIdentity())
2171     layoutMap = getProjectedMap(layoutMap, unusedDims.getValue());
2172   return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap,
2173                          nonRankReducedType.getMemorySpace());
2174 }
2175 
2176 /// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
2177 /// deduce the result type. Additionally, reduce the rank of the inferred result
2178 /// type if `currentResultType` is lower rank than `sourceType`.
2179 static MemRefType getCanonicalSubViewResultType(
2180     MemRefType currentResultType, MemRefType sourceType,
2181     ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
2182     ArrayRef<OpFoldResult> mixedStrides) {
2183   return getCanonicalSubViewResultType(currentResultType, sourceType,
2184                                        sourceType, mixedOffsets, mixedSizes,
2185                                        mixedStrides);
2186 }
2187 
/// Helper method to check if a `subview` operation is trivially a no-op. This
/// is the case if all offsets are zero, all strides are 1, and the source
/// shape is the same as the size of the subview. In such cases, the subview
/// can be folded into its source.
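///
/// For example (a sketch):
///   memref.subview %src[0, 0] [4, 8] [1, 1]
///       : memref<4x8xf32> to memref<4x8xf32>
/// covers the whole source and can be replaced by %src.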
2192 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
2193   if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
2194     return false;
2195 
2196   auto mixedOffsets = subViewOp.getMixedOffsets();
2197   auto mixedSizes = subViewOp.getMixedSizes();
2198   auto mixedStrides = subViewOp.getMixedStrides();
2199 
2200   // Check offsets are zero.
2201   if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
2202         Optional<int64_t> intValue = getConstantIntValue(ofr);
2203         return !intValue || intValue.getValue() != 0;
2204       }))
2205     return false;
2206 
2207   // Check strides are one.
2208   if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
2209         Optional<int64_t> intValue = getConstantIntValue(ofr);
2210         return !intValue || intValue.getValue() != 1;
2211       }))
2212     return false;
2213 
  // Check all size values are static and match the (static) source shape.
2215   ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
2216   for (const auto &size : llvm::enumerate(mixedSizes)) {
2217     Optional<int64_t> intValue = getConstantIntValue(size.value());
2218     if (!intValue || intValue.getValue() != sourceShape[size.index()])
2219       return false;
2220   }
2221   // All conditions met. The `SubViewOp` is foldable as a no-op.
2222   return true;
2223 }
2224 
2225 namespace {
2226 /// Pattern to rewrite a subview op with MemRefCast arguments.
2227 /// This essentially pushes memref.cast past its consuming subview when
2228 /// `canFoldIntoConsumerOp` is true.
2229 ///
2230 /// Example:
2231 /// ```
2232 ///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
2233 ///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
2234 ///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
2235 /// ```
2236 /// is rewritten into:
2237 /// ```
2238 ///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
2239 ///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
2240 ///     memref<3x4xf32, offset:?, strides:[?, 1]>
2241 /// ```
2242 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
2243 public:
2244   using OpRewritePattern<SubViewOp>::OpRewritePattern;
2245 
2246   LogicalResult matchAndRewrite(SubViewOp subViewOp,
2247                                 PatternRewriter &rewriter) const override {
2248     // Any constant operand, just return to let SubViewOpConstantFolder kick in.
2249     if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
2250           return matchPattern(operand, matchConstantIndex());
2251         }))
2252       return failure();
2253 
2254     auto castOp = subViewOp.source().getDefiningOp<CastOp>();
2255     if (!castOp)
2256       return failure();
2257 
2258     if (!CastOp::canFoldIntoConsumerOp(castOp))
2259       return failure();
2260 
2261     // Compute the SubViewOp result type after folding the MemRefCastOp. Use the
2262     // MemRefCastOp source operand type to infer the result type and the current
2263     // SubViewOp source operand type to compute the dropped dimensions if the
2264     // operation is rank-reducing.
2265     auto resultType = getCanonicalSubViewResultType(
2266         subViewOp.getType(), subViewOp.getSourceType(),
2267         castOp.source().getType().cast<MemRefType>(),
2268         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
2269         subViewOp.getMixedStrides());
2270     if (!resultType)
2271       return failure();
2272 
2273     Value newSubView = rewriter.create<SubViewOp>(
2274         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
2275         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
2276         subViewOp.static_sizes(), subViewOp.static_strides());
2277     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
2278                                         newSubView);
2279     return success();
2280   }
2281 };
2282 
/// Canonicalize subview ops that are no-ops. If the source shape is not the
/// same as the result shape (e.g. because the result type carries an
/// `affine_map` layout), the subview is replaced with a cast rather than
/// folded away entirely.
2285 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
2286 public:
2287   using OpRewritePattern<SubViewOp>::OpRewritePattern;
2288 
2289   LogicalResult matchAndRewrite(SubViewOp subViewOp,
2290                                 PatternRewriter &rewriter) const override {
2291     if (!isTrivialSubViewOp(subViewOp))
2292       return failure();
2293     if (subViewOp.getSourceType() == subViewOp.getType()) {
2294       rewriter.replaceOp(subViewOp, subViewOp.source());
2295       return success();
2296     }
2297     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
2298                                         subViewOp.source());
2299     return success();
2300   }
2301 };
2302 } // namespace
2303 
2304 /// Return the canonical type of the result of a subview.
2305 struct SubViewReturnTypeCanonicalizer {
2306   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
2307                         ArrayRef<OpFoldResult> mixedSizes,
2308                         ArrayRef<OpFoldResult> mixedStrides) {
2309     return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
2310                                          mixedOffsets, mixedSizes,
2311                                          mixedStrides);
2312   }
2313 };
2314 
2315 /// A canonicalizer wrapper to replace SubViewOps.
2316 struct SubViewCanonicalizer {
2317   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
2318     rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2319   }
2320 };
2321 
2322 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2323                                             MLIRContext *context) {
2324   results
2325       .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
2326                SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
2327            SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
2328 }
2329 
2330 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
2331   auto resultShapedType = getResult().getType().cast<ShapedType>();
2332   auto sourceShapedType = source().getType().cast<ShapedType>();
2333 
2334   if (resultShapedType.hasStaticShape() &&
2335       resultShapedType == sourceShapedType) {
2336     return getViewSource();
2337   }
2338 
2339   return {};
2340 }
2341 
2342 //===----------------------------------------------------------------------===//
2343 // TransposeOp
2344 //===----------------------------------------------------------------------===//
2345 
/// Build a strided memref type by applying `permutationMap` to `memRefType`.
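/// For example (illustrative): applying (d0, d1) -> (d1, d0) to
/// memref<3x4xf32> (strides [4, 1]) yields memref<4x3xf32> with strides
/// [1, 4], i.e. layout (d0, d1) -> (d0 + d1 * 4).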
2347 static MemRefType inferTransposeResultType(MemRefType memRefType,
2348                                            AffineMap permutationMap) {
2349   auto rank = memRefType.getRank();
2350   auto originalSizes = memRefType.getShape();
2351   // Compute permuted sizes.
2352   SmallVector<int64_t, 4> sizes(rank, 0);
2353   for (const auto &en : llvm::enumerate(permutationMap.getResults()))
2354     sizes[en.index()] =
2355         originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
2356 
2357   // Compute permuted strides.
2358   int64_t offset;
2359   SmallVector<int64_t, 4> strides;
2360   auto res = getStridesAndOffset(memRefType, strides, offset);
2361   assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
2362   (void)res;
2363   auto map =
2364       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
2365   map = permutationMap ? map.compose(permutationMap) : map;
2366   return MemRefType::Builder(memRefType)
2367       .setShape(sizes)
2368       .setLayout(AffineMapAttr::get(map));
2369 }
2370 
2371 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
2372                         AffineMapAttr permutation,
2373                         ArrayRef<NamedAttribute> attrs) {
2374   auto permutationMap = permutation.getValue();
2375   assert(permutationMap);
2376 
2377   auto memRefType = in.getType().cast<MemRefType>();
2378   // Compute result type.
2379   MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
2380 
2381   build(b, result, resultType, in, attrs);
2382   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
2383 }
2384 
2385 // transpose $in $permutation attr-dict : type($in) `to` type(results)
2386 void TransposeOp::print(OpAsmPrinter &p) {
2387   p << " " << in() << " " << permutation();
2388   p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
2389   p << " : " << in().getType() << " to " << getType();
2390 }
2391 
2392 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
2393   OpAsmParser::OperandType in;
2394   AffineMap permutation;
2395   MemRefType srcType, dstType;
2396   if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
2397       parser.parseOptionalAttrDict(result.attributes) ||
2398       parser.parseColonType(srcType) ||
2399       parser.resolveOperand(in, srcType, result.operands) ||
2400       parser.parseKeywordType("to", dstType) ||
2401       parser.addTypeToList(dstType, result.types))
2402     return failure();
2403 
2404   result.addAttribute(TransposeOp::getPermutationAttrName(),
2405                       AffineMapAttr::get(permutation));
2406   return success();
2407 }
2408 
2409 LogicalResult TransposeOp::verify() {
2410   if (!permutation().isPermutation())
2411     return emitOpError("expected a permutation map");
2412   if (permutation().getNumDims() != getShapedType().getRank())
2413     return emitOpError("expected a permutation map of same rank as the input");
2414 
2415   auto srcType = in().getType().cast<MemRefType>();
2416   auto dstType = getType().cast<MemRefType>();
2417   auto transposedType = inferTransposeResultType(srcType, permutation());
2418   if (dstType != transposedType)
2419     return emitOpError("output type ")
2420            << dstType << " does not match transposed input type " << srcType
2421            << ", " << transposedType;
2422   return success();
2423 }
2424 
2425 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
2426   if (succeeded(foldMemRefCast(*this)))
2427     return getResult();
2428   return {};
2429 }
2430 
2431 //===----------------------------------------------------------------------===//
2432 // ViewOp
2433 //===----------------------------------------------------------------------===//
2434 
2435 LogicalResult ViewOp::verify() {
2436   auto baseType = getOperand(0).getType().cast<MemRefType>();
2437   auto viewType = getType();
2438 
2439   // The base memref should have identity layout map (or none).
2440   if (!baseType.getLayout().isIdentity())
2441     return emitError("unsupported map for base memref type ") << baseType;
2442 
2443   // The result memref should have identity layout map (or none).
2444   if (!viewType.getLayout().isIdentity())
2445     return emitError("unsupported map for result memref type ") << viewType;
2446 
2447   // The base memref and the view memref should be in the same memory space.
2448   if (baseType.getMemorySpace() != viewType.getMemorySpace())
2449     return emitError("different memory spaces specified for base memref "
2450                      "type ")
2451            << baseType << " and view memref type " << viewType;
2452 
2453   // Verify that we have the correct number of sizes for the result type.
2454   unsigned numDynamicDims = viewType.getNumDynamicDims();
2455   if (sizes().size() != numDynamicDims)
2456     return emitError("incorrect number of size operands for type ") << viewType;
2457 
2458   return success();
2459 }
2460 
2461 Value ViewOp::getViewSource() { return source(); }
2462 
2463 namespace {
2464 
2465 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
2466   using OpRewritePattern<ViewOp>::OpRewritePattern;
2467 
2468   LogicalResult matchAndRewrite(ViewOp viewOp,
2469                                 PatternRewriter &rewriter) const override {
2470     // Return if none of the operands are constants.
2471     if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
2472           return matchPattern(operand, matchConstantIndex());
2473         }))
2474       return failure();
2475 
2476     // Get result memref type.
2477     auto memrefType = viewOp.getType();
2478 
    // Get the offset from the old memref view type `memrefType`.
2480     int64_t oldOffset;
2481     SmallVector<int64_t, 4> oldStrides;
2482     if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
2483       return failure();
2484     assert(oldOffset == 0 && "Expected 0 offset");
2485 
2486     SmallVector<Value, 4> newOperands;
2487 
2488     // Offset cannot be folded into result type.
2489 
2490     // Fold any dynamic dim operands which are produced by a constant.
2491     SmallVector<int64_t, 4> newShapeConstants;
2492     newShapeConstants.reserve(memrefType.getRank());
2493 
2494     unsigned dynamicDimPos = 0;
2495     unsigned rank = memrefType.getRank();
2496     for (unsigned dim = 0, e = rank; dim < e; ++dim) {
2497       int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
2499       if (!ShapedType::isDynamic(dimSize)) {
2500         newShapeConstants.push_back(dimSize);
2501         continue;
2502       }
2503       auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
2504       if (auto constantIndexOp =
2505               dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
2506         // Dynamic shape dimension will be folded.
2507         newShapeConstants.push_back(constantIndexOp.value());
2508       } else {
2509         // Dynamic shape dimension not folded; copy operand from old memref.
2510         newShapeConstants.push_back(dimSize);
2511         newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
2512       }
2513       dynamicDimPos++;
2514     }
2515 
2516     // Create new memref type with constant folded dims.
2517     MemRefType newMemRefType =
2518         MemRefType::Builder(memrefType).setShape(newShapeConstants);
2519     // Nothing new, don't fold.
2520     if (newMemRefType == memrefType)
2521       return failure();
2522 
2523     // Create new ViewOp.
2524     auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
2525                                              viewOp.getOperand(0),
2526                                              viewOp.byte_shift(), newOperands);
2527     // Insert a cast so we have the same type as the old memref type.
2528     rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
2529     return success();
2530   }
2531 };
2532 
2533 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
2534   using OpRewritePattern<ViewOp>::OpRewritePattern;
2535 
2536   LogicalResult matchAndRewrite(ViewOp viewOp,
2537                                 PatternRewriter &rewriter) const override {
2538     Value memrefOperand = viewOp.getOperand(0);
2539     CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
2540     if (!memrefCastOp)
2541       return failure();
2542     Value allocOperand = memrefCastOp.getOperand();
2543     AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
2544     if (!allocOp)
2545       return failure();
2546     rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
2547                                         viewOp.byte_shift(), viewOp.sizes());
2548     return success();
2549   }
2550 };
2551 
2552 } // namespace
2553 
2554 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2555                                          MLIRContext *context) {
2556   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
2557 }
2558 
2559 //===----------------------------------------------------------------------===//
2560 // AtomicRMWOp
2561 //===----------------------------------------------------------------------===//
2562 
2563 LogicalResult AtomicRMWOp::verify() {
2564   if (getMemRefType().getRank() != getNumOperands() - 2)
2565     return emitOpError(
2566         "expects the number of subscripts to be equal to memref rank");
2567   switch (kind()) {
2568   case arith::AtomicRMWKind::addf:
2569   case arith::AtomicRMWKind::maxf:
2570   case arith::AtomicRMWKind::minf:
2571   case arith::AtomicRMWKind::mulf:
2572     if (!value().getType().isa<FloatType>())
2573       return emitOpError() << "with kind '"
2574                            << arith::stringifyAtomicRMWKind(kind())
2575                            << "' expects a floating-point type";
2576     break;
2577   case arith::AtomicRMWKind::addi:
2578   case arith::AtomicRMWKind::maxs:
2579   case arith::AtomicRMWKind::maxu:
2580   case arith::AtomicRMWKind::mins:
2581   case arith::AtomicRMWKind::minu:
2582   case arith::AtomicRMWKind::muli:
2583   case arith::AtomicRMWKind::ori:
2584   case arith::AtomicRMWKind::andi:
2585     if (!value().getType().isa<IntegerType>())
2586       return emitOpError() << "with kind '"
2587                            << arith::stringifyAtomicRMWKind(kind())
2588                            << "' expects an integer type";
2589     break;
2590   default:
2591     break;
2592   }
2593   return success();
2594 }
2595 
2596 OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) {
2597   /// atomicrmw(memrefcast) -> atomicrmw
2598   if (succeeded(foldMemRefCast(*this, value())))
2599     return getResult();
2600   return OpFoldResult();
2601 }
2602 
2603 //===----------------------------------------------------------------------===//
2604 // TableGen'd op method definitions
2605 //===----------------------------------------------------------------------===//
2606 
2607 #define GET_OP_CLASSES
2608 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
2609