1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
13 #include "mlir/Dialect/Utils/StaticValueUtils.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Interfaces/InferTypeOpInterface.h"
21 #include "mlir/Interfaces/SideEffectInterfaces.h"
22 #include "mlir/Interfaces/ViewLikeInterface.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 
26 using namespace mlir;
27 using namespace mlir::memref;
28 
29 namespace {
30 /// Idiomatic saturated operations on offsets, sizes and strides.
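/// For example (an illustrative sketch of the semantics implemented below):
///   (Wrapper::size(4) * Wrapper::size(2)).asSize() yields 8, while
///   (Wrapper::size(4) * Wrapper::size(ShapedType::kDynamicSize)).asSize()
///   saturates and yields ShapedType::kDynamicSize again.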
31 namespace saturated_arith {
32 struct Wrapper {
33   static Wrapper stride(int64_t v) {
34     return (ShapedType::isDynamicStrideOrOffset(v)) ? Wrapper{true, 0}
35                                                     : Wrapper{false, v};
36   }
37   static Wrapper offset(int64_t v) {
38     return (ShapedType::isDynamicStrideOrOffset(v)) ? Wrapper{true, 0}
39                                                     : Wrapper{false, v};
40   }
41   static Wrapper size(int64_t v) {
42     return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
43   }
44   int64_t asOffset() {
45     return saturated ? ShapedType::kDynamicStrideOrOffset : v;
46   }
47   int64_t asSize() { return saturated ? ShapedType::kDynamicSize : v; }
48   int64_t asStride() {
49     return saturated ? ShapedType::kDynamicStrideOrOffset : v;
50   }
51   bool operator==(Wrapper other) {
52     return (saturated && other.saturated) ||
53            (!saturated && !other.saturated && v == other.v);
54   }
55   bool operator!=(Wrapper other) { return !(*this == other); }
56   Wrapper operator+(Wrapper other) {
57     if (saturated || other.saturated)
58       return Wrapper{true, 0};
59     return Wrapper{false, other.v + v};
60   }
61   Wrapper operator*(Wrapper other) {
62     if (saturated || other.saturated)
63       return Wrapper{true, 0};
64     return Wrapper{false, other.v * v};
65   }
66   bool saturated;
67   int64_t v;
68 };
69 } // namespace saturated_arith
70 } // namespace
71 
72 /// Materialize a single constant operation from a given attribute value with
73 /// the desired resultant type.
74 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
75                                               Attribute value, Type type,
76                                               Location loc) {
77   if (arith::ConstantOp::isBuildableWith(value, type))
78     return builder.create<arith::ConstantOp>(loc, value, type);
79   return nullptr;
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // Common canonicalization pattern support logic
84 //===----------------------------------------------------------------------===//
85 
/// This is a common utility used for patterns of the form
/// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
/// into the root operation directly.
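///
/// For example (illustrative IR; the dealloc consumer is just one possibility):
///
/// ```mlir
///   %0 = memref.cast %arg0 : memref<8xf32> to memref<?xf32>
///   memref.dealloc %0 : memref<?xf32>
/// ```
///
/// folds to:
///
/// ```mlir
///   memref.dealloc %arg0 : memref<8xf32>
/// ```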
89 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
90   bool folded = false;
91   for (OpOperand &operand : op->getOpOperands()) {
92     auto cast = operand.get().getDefiningOp<CastOp>();
93     if (cast && operand.get() != inner &&
94         !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
95       operand.set(cast.getOperand());
96       folded = true;
97     }
98   }
99   return success(folded);
100 }
101 
102 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
103 /// type.
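/// For example (illustrative): `memref<4x?xf32>` maps to `tensor<4x?xf32>`,
/// `memref<*xf32>` maps to `tensor<*xf32>`, and any non-memref type maps to
/// the `none` type.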
104 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
105   if (auto memref = type.dyn_cast<MemRefType>())
106     return RankedTensorType::get(memref.getShape(), memref.getElementType());
107   if (auto memref = type.dyn_cast<UnrankedMemRefType>())
108     return UnrankedTensorType::get(memref.getElementType());
109   return NoneType::get(type.getContext());
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // AllocOp / AllocaOp
114 //===----------------------------------------------------------------------===//
115 
116 template <typename AllocLikeOp>
117 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
118   static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
119                 "applies to only alloc or alloca");
120   auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
121   if (!memRefType)
122     return op.emitOpError("result must be a memref");
123 
124   if (static_cast<int64_t>(op.dynamicSizes().size()) !=
125       memRefType.getNumDynamicDims())
126     return op.emitOpError("dimension operand count does not equal memref "
127                           "dynamic dimension count");
128 
129   unsigned numSymbols = 0;
130   if (!memRefType.getLayout().isIdentity())
131     numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
132   if (op.symbolOperands().size() != numSymbols)
133     return op.emitOpError("symbol operand count does not equal memref symbol "
134                           "count: expected ")
135            << numSymbols << ", got " << op.symbolOperands().size();
136 
137   return success();
138 }
139 
140 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
141 
142 LogicalResult AllocaOp::verify() {
143   // An alloca op needs to have an ancestor with an allocation scope trait.
144   if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
145     return emitOpError(
146         "requires an ancestor op with AutomaticAllocationScope trait");
147 
148   return verifyAllocLikeOp(*this);
149 }
150 
151 namespace {
/// Fold constant dimensions into an alloc-like operation.
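/// For example (illustrative IR with hypothetical values):
///
/// ```mlir
///   %c4 = arith.constant 4 : index
///   %0 = memref.alloc(%c4, %n) : memref<?x?xf32>
/// ```
///
/// may be rewritten to:
///
/// ```mlir
///   %1 = memref.alloc(%n) : memref<4x?xf32>
///   %0 = memref.cast %1 : memref<4x?xf32> to memref<?x?xf32>
/// ```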
153 template <typename AllocLikeOp>
154 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
155   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
156 
157   LogicalResult matchAndRewrite(AllocLikeOp alloc,
158                                 PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants.  If so, we can
    // substitute and drop them.
161     if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
162           return matchPattern(operand, matchConstantIndex());
163         }))
164       return failure();
165 
166     auto memrefType = alloc.getType();
167 
168     // Ok, we have one or more constant operands.  Collect the non-constant ones
169     // and keep track of the resultant memref type to build.
170     SmallVector<int64_t, 4> newShapeConstants;
171     newShapeConstants.reserve(memrefType.getRank());
172     SmallVector<Value, 4> dynamicSizes;
173 
174     unsigned dynamicDimPos = 0;
175     for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
176       int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
178       if (dimSize != -1) {
179         newShapeConstants.push_back(dimSize);
180         continue;
181       }
182       auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
183       auto *defOp = dynamicSize.getDefiningOp();
184       if (auto constantIndexOp =
185               dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
186         // Dynamic shape dimension will be folded.
187         newShapeConstants.push_back(constantIndexOp.value());
188       } else {
189         // Dynamic shape dimension not folded; copy dynamicSize from old memref.
190         newShapeConstants.push_back(-1);
191         dynamicSizes.push_back(dynamicSize);
192       }
193       dynamicDimPos++;
194     }
195 
196     // Create new memref type (which will have fewer dynamic dimensions).
197     MemRefType newMemRefType =
198         MemRefType::Builder(memrefType).setShape(newShapeConstants);
199     assert(static_cast<int64_t>(dynamicSizes.size()) ==
200            newMemRefType.getNumDynamicDims());
201 
202     // Create and insert the alloc op for the new memref.
203     auto newAlloc = rewriter.create<AllocLikeOp>(
204         alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
205         alloc.alignmentAttr());
206     // Insert a cast so we have the same type as the old alloc.
207     auto resultCast =
208         rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc);
209 
210     rewriter.replaceOp(alloc, {resultCast});
211     return success();
212   }
213 };
214 
215 /// Fold alloc operations with no users or only store and dealloc uses.
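/// For example (illustrative IR), the following allocation is never read and
/// can be erased together with its store and dealloc uses:
///
/// ```mlir
///   %0 = memref.alloc() : memref<8xf32>
///   memref.store %cst, %0[%i] : memref<8xf32>
///   memref.dealloc %0 : memref<8xf32>
/// ```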
216 template <typename T>
217 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
218   using OpRewritePattern<T>::OpRewritePattern;
219 
220   LogicalResult matchAndRewrite(T alloc,
221                                 PatternRewriter &rewriter) const override {
222     if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
223           if (auto storeOp = dyn_cast<StoreOp>(op))
224             return storeOp.value() == alloc;
225           return !isa<DeallocOp>(op);
226         }))
227       return failure();
228 
229     for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
230       rewriter.eraseOp(user);
231 
232     rewriter.eraseOp(alloc);
233     return success();
234   }
235 };
236 } // namespace
237 
238 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
239                                           MLIRContext *context) {
240   results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
241 }
242 
243 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
244                                            MLIRContext *context) {
245   results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
246       context);
247 }
248 
249 //===----------------------------------------------------------------------===//
250 // AllocaScopeOp
251 //===----------------------------------------------------------------------===//
252 
253 void AllocaScopeOp::print(OpAsmPrinter &p) {
254   bool printBlockTerminators = false;
255 
256   p << ' ';
257   if (!results().empty()) {
258     p << " -> (" << getResultTypes() << ")";
259     printBlockTerminators = true;
260   }
261   p << ' ';
262   p.printRegion(bodyRegion(),
263                 /*printEntryBlockArgs=*/false,
264                 /*printBlockTerminators=*/printBlockTerminators);
265   p.printOptionalAttrDict((*this)->getAttrs());
266 }
267 
268 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
269   // Create a region for the body.
270   result.regions.reserve(1);
271   Region *bodyRegion = result.addRegion();
272 
273   // Parse optional results type list.
274   if (parser.parseOptionalArrowTypeList(result.types))
275     return failure();
276 
277   // Parse the body region.
278   if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
279     return failure();
280   AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
281                                   result.location);
282 
283   // Parse the optional attribute list.
284   if (parser.parseOptionalAttrDict(result.attributes))
285     return failure();
286 
287   return success();
288 }
289 
290 void AllocaScopeOp::getSuccessorRegions(
291     Optional<unsigned> index, ArrayRef<Attribute> operands,
292     SmallVectorImpl<RegionSuccessor> &regions) {
293   if (index.hasValue()) {
294     regions.push_back(RegionSuccessor(getResults()));
295     return;
296   }
297 
298   regions.push_back(RegionSuccessor(&bodyRegion()));
299 }
300 
/// Given an operation, return whether this op is guaranteed to
/// allocate an AutomaticAllocationScopeResource.
303 static bool isGuaranteedAutomaticAllocation(Operation *op) {
304   MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
305   if (!interface)
306     return false;
307   for (auto res : op->getResults()) {
308     if (auto effect =
309             interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
310       if (isa<SideEffects::AutomaticAllocationScopeResource>(
311               effect->getResource()))
312         return true;
313     }
314   }
315   return false;
316 }
317 
318 /// Given an operation, return whether this op itself could
319 /// allocate an AutomaticAllocationScopeResource. Note that
320 /// this will not check whether an operation contained within
321 /// the op can allocate.
322 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
323   // This op itself doesn't create a stack allocation,
324   // the inner allocation should be handled separately.
325   if (op->hasTrait<OpTrait::HasRecursiveSideEffects>())
326     return false;
327   MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
328   if (!interface)
329     return true;
330   for (auto res : op->getResults()) {
331     if (auto effect =
332             interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
333       if (isa<SideEffects::AutomaticAllocationScopeResource>(
334               effect->getResource()))
335         return true;
336     }
337   }
338   return false;
339 }
340 
/// Return whether this op is the last non-terminating op
/// in a region. That is to say, it is in a one-block region
/// and is only followed by a terminator. This check is used
/// to avoid extending the lifetime of allocations.
345 static bool lastNonTerminatorInRegion(Operation *op) {
346   return op->getNextNode() == op->getBlock()->getTerminator() &&
347          op->getParentRegion()->getBlocks().size() == 1;
348 }
349 
350 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
351 /// or it contains no allocation.
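/// For example (illustrative IR), when the parent op already provides an
/// AutomaticAllocationScope and the scope is the last non-terminator op in a
/// one-block region:
///
/// ```mlir
///   memref.alloca_scope {
///     %0 = memref.alloca() : memref<4xf32>
///     ...
///   }
/// ```
///
/// may be removed, splicing its body into the parent block.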
352 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
353   using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
354 
355   LogicalResult matchAndRewrite(AllocaScopeOp op,
356                                 PatternRewriter &rewriter) const override {
357     bool hasPotentialAlloca =
358         op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
359             if (alloc == op)
360               return WalkResult::advance();
361             if (isOpItselfPotentialAutomaticAllocation(alloc))
362               return WalkResult::interrupt();
363             if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
364               return WalkResult::skip();
365             return WalkResult::advance();
366           }).wasInterrupted();
367 
368     // If this contains no potential allocation, it is always legal to
369     // inline. Otherwise, consider two conditions:
370     if (hasPotentialAlloca) {
371       // If the parent isn't an allocation scope, or we are not the last
372       // non-terminator op in the parent, we will extend the lifetime.
373       if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
374         return failure();
375       if (!lastNonTerminatorInRegion(op))
376         return failure();
377     }
378 
379     Block *block = &op.getRegion().front();
380     Operation *terminator = block->getTerminator();
381     ValueRange results = terminator->getOperands();
382     rewriter.mergeBlockBefore(block, op);
383     rewriter.replaceOp(op, results);
384     rewriter.eraseOp(terminator);
385     return success();
386   }
387 };
388 
389 /// Move allocations into an allocation scope, if it is legal to
390 /// move them (e.g. their operands are available at the location
391 /// the op would be moved to).
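/// For example (illustrative IR, assuming the enclosing function provides the
/// AutomaticAllocationScope and both ops are the last non-terminators in their
/// blocks):
///
/// ```mlir
///   scf.if %cond {
///     memref.alloca_scope {
///       %0 = memref.alloca() : memref<4xf32>
///       ...
///     }
///   }
/// ```
///
/// may be rewritten so the alloca is cloned right before the scf.if and the
/// uses inside the scope refer to the hoisted value.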
392 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
393   using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
394 
395   LogicalResult matchAndRewrite(AllocaScopeOp op,
396                                 PatternRewriter &rewriter) const override {
397 
398     if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
399       return failure();
400 
401     Operation *lastParentWithoutScope = op->getParentOp();
402 
403     if (!lastParentWithoutScope ||
404         lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
405       return failure();
406 
    // Only apply if this is the last non-terminator op in the block
    // (lest allocation lifetimes be extended) of a one-block region.
410     if (!lastNonTerminatorInRegion(op) ||
411         !lastNonTerminatorInRegion(lastParentWithoutScope))
412       return failure();
413 
414     while (!lastParentWithoutScope->getParentOp()
415                 ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
416       lastParentWithoutScope = lastParentWithoutScope->getParentOp();
417       if (!lastParentWithoutScope ||
418           !lastNonTerminatorInRegion(lastParentWithoutScope))
419         return failure();
420     }
421     assert(lastParentWithoutScope->getParentOp()
422                ->hasTrait<OpTrait::AutomaticAllocationScope>());
423 
424     Region *containingRegion = nullptr;
425     for (auto &r : lastParentWithoutScope->getRegions()) {
426       if (r.isAncestor(op->getParentRegion())) {
427         assert(containingRegion == nullptr &&
428                "only one region can contain the op");
429         containingRegion = &r;
430       }
431     }
432     assert(containingRegion && "op must be contained in a region");
433 
434     SmallVector<Operation *> toHoist;
435     op->walk([&](Operation *alloc) {
436       if (!isGuaranteedAutomaticAllocation(alloc))
437         return WalkResult::skip();
438 
439       // If any operand is not defined before the location of
440       // lastParentWithoutScope (i.e. where we would hoist to), skip.
441       if (llvm::any_of(alloc->getOperands(), [&](Value v) {
442             return containingRegion->isAncestor(v.getParentRegion());
443           }))
444         return WalkResult::skip();
445       toHoist.push_back(alloc);
446       return WalkResult::advance();
447     });
448 
449     if (toHoist.empty())
450       return failure();
451     rewriter.setInsertionPoint(lastParentWithoutScope);
452     for (auto *op : toHoist) {
453       auto *cloned = rewriter.clone(*op);
454       rewriter.replaceOp(op, cloned->getResults());
455     }
456     return success();
457   }
458 };
459 
460 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
461                                                 MLIRContext *context) {
462   results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
463 }
464 
465 //===----------------------------------------------------------------------===//
466 // AssumeAlignmentOp
467 //===----------------------------------------------------------------------===//
468 
469 LogicalResult AssumeAlignmentOp::verify() {
470   if (!llvm::isPowerOf2_32(alignment()))
471     return emitOpError("alignment must be power of 2");
472   return success();
473 }
474 
475 //===----------------------------------------------------------------------===//
476 // CastOp
477 //===----------------------------------------------------------------------===//
478 
479 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
481 /// and implement canonicalization patterns for ops in different dialects that
482 /// may consume the results of memref.cast operations. Such foldable memref.cast
483 /// operations are typically inserted as `view` and `subview` ops are
484 /// canonicalized, to preserve the type compatibility of their uses.
485 ///
486 /// Returns true when all conditions are met:
487 /// 1. source and result are ranked memrefs with strided semantics and same
488 /// element type and rank.
489 /// 2. each of the source's size, offset or stride has more static information
490 /// than the corresponding result's size, offset or stride.
491 ///
492 /// Example 1:
493 /// ```mlir
494 ///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
495 ///   %2 = consumer %1 ... : memref<?x?xf32> ...
496 /// ```
497 ///
498 /// may fold into:
499 ///
500 /// ```mlir
501 ///   %2 = consumer %0 ... : memref<8x16xf32> ...
502 /// ```
503 ///
504 /// Example 2:
505 /// ```
506 ///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
507 ///          to memref<?x?xf32>
508 ///   consumer %1 : memref<?x?xf32> ...
509 /// ```
510 ///
511 /// may fold into:
512 ///
513 /// ```
514 ///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
515 /// ```
516 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
517   MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
518   MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
519 
520   // Requires ranked MemRefType.
521   if (!sourceType || !resultType)
522     return false;
523 
524   // Requires same elemental type.
525   if (sourceType.getElementType() != resultType.getElementType())
526     return false;
527 
528   // Requires same rank.
529   if (sourceType.getRank() != resultType.getRank())
530     return false;
531 
532   // Only fold casts between strided memref forms.
533   int64_t sourceOffset, resultOffset;
534   SmallVector<int64_t, 4> sourceStrides, resultStrides;
535   if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
536       failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
537     return false;
538 
539   // If cast is towards more static sizes along any dimension, don't fold.
540   for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
541     auto ss = std::get<0>(it), st = std::get<1>(it);
542     if (ss != st)
543       if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
544         return false;
545   }
546 
  // If cast is towards a more static offset, don't fold.
548   if (sourceOffset != resultOffset)
549     if (ShapedType::isDynamicStrideOrOffset(sourceOffset) &&
550         !ShapedType::isDynamicStrideOrOffset(resultOffset))
551       return false;
552 
553   // If cast is towards more static strides along any dimension, don't fold.
554   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
555     auto ss = std::get<0>(it), st = std::get<1>(it);
556     if (ss != st)
557       if (ShapedType::isDynamicStrideOrOffset(ss) &&
558           !ShapedType::isDynamicStrideOrOffset(st))
559         return false;
560   }
561 
562   return true;
563 }
564 
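/// Two memref types are cast compatible when they have the same element type
/// and memory space and differ only in how much static shape/layout
/// information they carry, or when exactly one side is unranked. For example
/// (illustrative IR):
///
/// ```mlir
///   %0 = memref.cast %arg0 : memref<4x?xf32> to memref<?x?xf32>
///   %1 = memref.cast %arg1 : memref<4xf32> to memref<*xf32>
/// ```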
565 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
566   if (inputs.size() != 1 || outputs.size() != 1)
567     return false;
568   Type a = inputs.front(), b = outputs.front();
569   auto aT = a.dyn_cast<MemRefType>();
570   auto bT = b.dyn_cast<MemRefType>();
571 
572   auto uaT = a.dyn_cast<UnrankedMemRefType>();
573   auto ubT = b.dyn_cast<UnrankedMemRefType>();
574 
575   if (aT && bT) {
576     if (aT.getElementType() != bT.getElementType())
577       return false;
578     if (aT.getLayout() != bT.getLayout()) {
579       int64_t aOffset, bOffset;
580       SmallVector<int64_t, 4> aStrides, bStrides;
581       if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
582           failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
583           aStrides.size() != bStrides.size())
584         return false;
585 
586       // Strides along a dimension/offset are compatible if the value in the
587       // source memref is static and the value in the target memref is the
588       // same. They are also compatible if either one is dynamic (see
      // description of memref.cast for details).
590       auto checkCompatible = [](int64_t a, int64_t b) {
591         return (a == MemRefType::getDynamicStrideOrOffset() ||
592                 b == MemRefType::getDynamicStrideOrOffset() || a == b);
593       };
594       if (!checkCompatible(aOffset, bOffset))
595         return false;
596       for (const auto &aStride : enumerate(aStrides))
597         if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
598           return false;
599     }
600     if (aT.getMemorySpace() != bT.getMemorySpace())
601       return false;
602 
603     // They must have the same rank, and any specified dimensions must match.
604     if (aT.getRank() != bT.getRank())
605       return false;
606 
607     for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
608       int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
609       if (aDim != -1 && bDim != -1 && aDim != bDim)
610         return false;
611     }
612     return true;
613   } else {
614     if (!aT && !uaT)
615       return false;
616     if (!bT && !ubT)
617       return false;
618     // Unranked to unranked casting is unsupported
619     if (uaT && ubT)
620       return false;
621 
622     auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
623     auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
624     if (aEltType != bEltType)
625       return false;
626 
627     auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
628     auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
629     return aMemSpace == bMemSpace;
630   }
631 
632   return false;
633 }
634 
635 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
636   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
637 }
638 
639 //===----------------------------------------------------------------------===//
640 // CopyOp
641 //===----------------------------------------------------------------------===//
642 
643 namespace {
644 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
645 /// and element type, the cast can be skipped. Such CastOps only cast the layout
646 /// of the type.
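/// For example (illustrative IR; the layout map is hypothetical):
///
/// ```mlir
///   %0 = memref.cast %arg0
///          : memref<8xf32> to memref<8xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
///   memref.copy %0, %dst
///          : memref<8xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<8xf32>
/// ```
///
/// may be rewritten so the copy reads from %arg0 directly.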
647 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
648   using OpRewritePattern<CopyOp>::OpRewritePattern;
649 
650   LogicalResult matchAndRewrite(CopyOp copyOp,
651                                 PatternRewriter &rewriter) const override {
652     bool modified = false;
653 
654     // Check source.
655     if (auto castOp = copyOp.source().getDefiningOp<CastOp>()) {
      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
      auto toType = castOp.getType().dyn_cast<MemRefType>();
658 
659       if (fromType && toType) {
660         if (fromType.getShape() == toType.getShape() &&
661             fromType.getElementType() == toType.getElementType()) {
662           rewriter.updateRootInPlace(
663               copyOp, [&] { copyOp.sourceMutable().assign(castOp.source()); });
664           modified = true;
665         }
666       }
667     }
668 
669     // Check target.
670     if (auto castOp = copyOp.target().getDefiningOp<CastOp>()) {
      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
      auto toType = castOp.getType().dyn_cast<MemRefType>();
673 
674       if (fromType && toType) {
675         if (fromType.getShape() == toType.getShape() &&
676             fromType.getElementType() == toType.getElementType()) {
677           rewriter.updateRootInPlace(
678               copyOp, [&] { copyOp.targetMutable().assign(castOp.source()); });
679           modified = true;
680         }
681       }
682     }
683 
684     return success(modified);
685   }
686 };
687 
688 /// Fold memref.copy(%x, %x).
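/// Such a copy is a no-op and can be erased, e.g. (illustrative IR):
///
/// ```mlir
///   memref.copy %0, %0 : memref<8xf32> to memref<8xf32>
/// ```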
689 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
690   using OpRewritePattern<CopyOp>::OpRewritePattern;
691 
692   LogicalResult matchAndRewrite(CopyOp copyOp,
693                                 PatternRewriter &rewriter) const override {
694     if (copyOp.source() != copyOp.target())
695       return failure();
696 
697     rewriter.eraseOp(copyOp);
698     return success();
699   }
700 };
701 } // namespace
702 
703 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
704                                          MLIRContext *context) {
705   results.add<FoldCopyOfCast, FoldSelfCopy>(context);
706 }
707 
708 LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands,
709                            SmallVectorImpl<OpFoldResult> &results) {
710   /// copy(memrefcast) -> copy
711   bool folded = false;
712   Operation *op = *this;
713   for (OpOperand &operand : op->getOpOperands()) {
714     auto castOp = operand.get().getDefiningOp<memref::CastOp>();
715     if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
716       operand.set(castOp.getOperand());
717       folded = true;
718     }
719   }
720   return success(folded);
721 }
722 
723 //===----------------------------------------------------------------------===//
724 // DeallocOp
725 //===----------------------------------------------------------------------===//
726 
727 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
728                               SmallVectorImpl<OpFoldResult> &results) {
729   /// dealloc(memrefcast) -> dealloc
730   return foldMemRefCast(*this);
731 }
732 
733 //===----------------------------------------------------------------------===//
734 // DimOp
735 //===----------------------------------------------------------------------===//
736 
737 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
738                   int64_t index) {
739   auto loc = result.location;
740   Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
741   build(builder, result, source, indexValue);
742 }
743 
744 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
745                   Value index) {
746   auto indexTy = builder.getIndexType();
747   build(builder, result, indexTy, source, index);
748 }
749 
750 Optional<int64_t> DimOp::getConstantIndex() {
751   if (auto constantOp = index().getDefiningOp<arith::ConstantOp>())
752     return constantOp.getValue().cast<IntegerAttr>().getInt();
753   return {};
754 }
755 
756 LogicalResult DimOp::verify() {
757   // Assume unknown index to be in range.
758   Optional<int64_t> index = getConstantIndex();
759   if (!index.hasValue())
760     return success();
761 
762   // Check that constant index is not knowingly out of range.
763   auto type = source().getType();
764   if (auto memrefType = type.dyn_cast<MemRefType>()) {
765     if (index.getValue() >= memrefType.getRank())
766       return emitOpError("index is out of range");
767   } else if (type.isa<UnrankedMemRefType>()) {
768     // Assume index to be in range.
769   } else {
770     llvm_unreachable("expected operand with memref type");
771   }
772   return success();
773 }
774 
/// Return a map with keys being elements in `vals` and values being the number
/// of occurrences of each element. Use std::map, since the `vals` here are
/// strides and the dynamic stride value is the same as the tombstone value for
/// `DenseMap<int64_t>`.
779 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
780   std::map<int64_t, unsigned> numOccurences;
781   for (auto val : vals)
782     numOccurences[val]++;
783   return numOccurences;
784 }
785 
/// Given `originalType` and a `reducedType` whose shape is assumed to be that
/// of `originalType` with some unit (`1`) entries erased, return the set of
/// indices that specifies which entries of the original shape are dropped to
/// obtain the reduced shape.
/// This accounts for cases where there are multiple unit dims, but only a
/// subset of those are dropped. For MemRefTypes these can be disambiguated
/// using the strides: if a dimension is dropped, its stride must be dropped too.
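/// For example (illustrative), reducing `memref<1x4x1xf32>` with sizes
/// `[1, 4, 1]` to `memref<4xf32>` yields the mask {0, 2}, i.e. both unit
/// dimensions are dropped.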
793 static llvm::Optional<llvm::SmallBitVector>
794 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
795                                ArrayRef<OpFoldResult> sizes) {
796   llvm::SmallBitVector unusedDims(originalType.getRank());
797   if (originalType.getRank() == reducedType.getRank())
798     return unusedDims;
799 
800   for (const auto &dim : llvm::enumerate(sizes))
801     if (auto attr = dim.value().dyn_cast<Attribute>())
802       if (attr.cast<IntegerAttr>().getInt() == 1)
803         unusedDims.set(dim.index());
804 
  // Early exit for the case where the number of unused dims matches the rank
  // reduction.
807   if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
808       originalType.getRank())
809     return unusedDims;
810 
811   SmallVector<int64_t> originalStrides, candidateStrides;
812   int64_t originalOffset, candidateOffset;
813   if (failed(
814           getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
815       failed(
816           getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
817     return llvm::None;
818 
  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. This is particularly important when more than one of the
  // dims is 1. Track the number of occurrences of the strides in the original
  // type and the candidate type. For each unused dim, that stride should not be
  // present in the candidate type. Note that there could be multiple dimensions
  // that have the same size. We don't need to figure out exactly which dim
  // corresponds to which stride; we just need to verify that the number of
  // repetitions of a stride in the original type == the number of unused dims
  // with that stride + the number of repetitions of that stride in the
  // candidate type.
828   std::map<int64_t, unsigned> currUnaccountedStrides =
829       getNumOccurences(originalStrides);
830   std::map<int64_t, unsigned> candidateStridesNumOccurences =
831       getNumOccurences(candidateStrides);
832   for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
833     if (!unusedDims.test(dim))
834       continue;
835     int64_t originalStride = originalStrides[dim];
836     if (currUnaccountedStrides[originalStride] >
837         candidateStridesNumOccurences[originalStride]) {
838       // This dim can be treated as dropped.
839       currUnaccountedStrides[originalStride]--;
840       continue;
841     }
842     if (currUnaccountedStrides[originalStride] ==
843         candidateStridesNumOccurences[originalStride]) {
844       // The stride for this is not dropped. Keep as is.
845       unusedDims.reset(dim);
846       continue;
847     }
848     if (currUnaccountedStrides[originalStride] <
849         candidateStridesNumOccurences[originalStride]) {
      // This should never happen. We can't have a stride in the reduced-rank
      // type that wasn't in the original one.
852       return llvm::None;
853     }
854   }
855 
856   if ((int64_t)unusedDims.count() + reducedType.getRank() !=
857       originalType.getRank())
858     return llvm::None;
859   return unusedDims;
860 }
861 
862 llvm::SmallBitVector SubViewOp::getDroppedDims() {
863   MemRefType sourceType = getSourceType();
864   MemRefType resultType = getType();
865   llvm::Optional<llvm::SmallBitVector> unusedDims =
866       computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
867   assert(unusedDims && "unable to find unused dims of subview");
868   return *unusedDims;
869 }
870 
871 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
872   // All forms of folding require a known index.
873   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
874   if (!index)
875     return {};
876 
877   // Folding for unranked types (UnrankedMemRefType) is not supported.
878   auto memrefType = source().getType().dyn_cast<MemRefType>();
879   if (!memrefType)
880     return {};
881 
882   // Fold if the shape extent along the given index is known.
883   if (!memrefType.isDynamicDim(index.getInt())) {
884     Builder builder(getContext());
885     return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
886   }
887 
888   // The size at the given index is now known to be a dynamic size.
889   unsigned unsignedIndex = index.getValue().getZExtValue();
890 
891   // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
892   Operation *definingOp = source().getDefiningOp();
893 
894   if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
895     return *(alloc.getDynamicSizes().begin() +
896              memrefType.getDynamicDimIndex(unsignedIndex));
897 
898   if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
899     return *(alloca.getDynamicSizes().begin() +
900              memrefType.getDynamicDimIndex(unsignedIndex));
901 
902   if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
903     return *(view.getDynamicSizes().begin() +
904              memrefType.getDynamicDimIndex(unsignedIndex));
905 
906   if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
907     llvm::SmallBitVector unusedDims = subview.getDroppedDims();
908     unsigned resultIndex = 0;
909     unsigned sourceRank = subview.getSourceType().getRank();
910     unsigned sourceIndex = 0;
911     for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
912       if (unusedDims.test(i))
913         continue;
914       if (resultIndex == unsignedIndex) {
915         sourceIndex = i;
916         break;
917       }
918       resultIndex++;
919     }
920     assert(subview.isDynamicSize(sourceIndex) &&
921            "expected dynamic subview size");
922     return subview.getDynamicSize(sourceIndex);
923   }
924 
925   if (auto sizeInterface =
926           dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
927     assert(sizeInterface.isDynamicSize(unsignedIndex) &&
928            "Expected dynamic subview size");
929     return sizeInterface.getDynamicSize(unsignedIndex);
930   }
931 
932   // dim(memrefcast) -> dim
933   if (succeeded(foldMemRefCast(*this)))
934     return getResult();
935 
936   return {};
937 }
938 
939 namespace {
940 /// Fold dim of a memref reshape operation to a load into the reshape's shape
941 /// operand.
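/// For example (illustrative IR):
///
/// ```mlir
///   %reshaped = memref.reshape %src(%shape)
///       : (memref<?x?xf32>, memref<2xindex>) -> memref<?x?xf32>
///   %d = memref.dim %reshaped, %c0 : memref<?x?xf32>
/// ```
///
/// may be rewritten to load the extent directly from the shape operand:
///
/// ```mlir
///   %d = memref.load %shape[%c0] : memref<2xindex>
/// ```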
942 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
943   using OpRewritePattern<DimOp>::OpRewritePattern;
944 
945   LogicalResult matchAndRewrite(DimOp dim,
946                                 PatternRewriter &rewriter) const override {
947     auto reshape = dim.source().getDefiningOp<ReshapeOp>();
948 
949     if (!reshape)
950       return failure();
951 
952     // Place the load directly after the reshape to ensure that the shape memref
953     // was not mutated.
954     rewriter.setInsertionPointAfter(reshape);
955     Location loc = dim.getLoc();
956     Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
957     if (load.getType() != dim.getType())
958       load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
959     rewriter.replaceOp(dim, load);
960     return success();
961   }
962 };
963 
964 } // namespace
965 
966 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
967                                         MLIRContext *context) {
968   results.add<DimOfMemRefReshape>(context);
969 }
970 
971 // ---------------------------------------------------------------------------
972 // DmaStartOp
973 // ---------------------------------------------------------------------------
974 
975 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
976                        Value srcMemRef, ValueRange srcIndices, Value destMemRef,
977                        ValueRange destIndices, Value numElements,
978                        Value tagMemRef, ValueRange tagIndices, Value stride,
979                        Value elementsPerStride) {
980   result.addOperands(srcMemRef);
981   result.addOperands(srcIndices);
982   result.addOperands(destMemRef);
983   result.addOperands(destIndices);
984   result.addOperands({numElements, tagMemRef});
985   result.addOperands(tagIndices);
986   if (stride)
987     result.addOperands({stride, elementsPerStride});
988 }
989 
990 void DmaStartOp::print(OpAsmPrinter &p) {
991   p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
992     << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
993     << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
994   if (isStrided())
995     p << ", " << getStride() << ", " << getNumElementsPerStride();
996 
997   p.printOptionalAttrDict((*this)->getAttrs());
998   p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
999     << ", " << getTagMemRef().getType();
1000 }
1001 
// Parse DmaStartOp.
// Ex:
//   memref.dma_start %src[%i, %j], %dst[%k, %l], %size,
//                    %tag[%index], %stride, %num_elt_per_stride
//                      : memref<3076 x f32, 0>,
//                        memref<1024 x f32, 2>,
//                        memref<1 x i32>
//
1010 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1011   OpAsmParser::UnresolvedOperand srcMemRefInfo;
1012   SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1013   OpAsmParser::UnresolvedOperand dstMemRefInfo;
1014   SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1015   OpAsmParser::UnresolvedOperand numElementsInfo;
1016   OpAsmParser::UnresolvedOperand tagMemrefInfo;
1017   SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1018   SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1019 
1020   SmallVector<Type, 3> types;
1021   auto indexType = parser.getBuilder().getIndexType();
1022 
1023   // Parse and resolve the following list of operands:
1024   // *) source memref followed by its indices (in square brackets).
1025   // *) destination memref followed by its indices (in square brackets).
  // *) dma size: the number of elements to transfer.
1027   if (parser.parseOperand(srcMemRefInfo) ||
1028       parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1029       parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1030       parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1031       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1032       parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1033       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1034     return failure();
1035 
1036   // Parse optional stride and elements per stride.
1037   if (parser.parseTrailingOperandList(strideInfo))
1038     return failure();
1039 
1040   bool isStrided = strideInfo.size() == 2;
1041   if (!strideInfo.empty() && !isStrided) {
1042     return parser.emitError(parser.getNameLoc(),
1043                             "expected two stride related operands");
1044   }
1045 
1046   if (parser.parseColonTypeList(types))
1047     return failure();
1048   if (types.size() != 3)
1049     return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1050 
1051   if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1052       parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1053       parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1054       parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1055       // size should be an index.
1056       parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1057       parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1058       // tag indices should be index.
1059       parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1060     return failure();
1061 
1062   if (isStrided) {
1063     if (parser.resolveOperands(strideInfo, indexType, result.operands))
1064       return failure();
1065   }
1066 
1067   return success();
1068 }
1069 
1070 LogicalResult DmaStartOp::verify() {
1071   unsigned numOperands = getNumOperands();
1072 
1073   // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1074   // the number of elements.
1075   if (numOperands < 4)
1076     return emitOpError("expected at least 4 operands");
1077 
1078   // Check types of operands. The order of these calls is important: the later
1079   // calls rely on some type properties to compute the operand position.
1080   // 1. Source memref.
1081   if (!getSrcMemRef().getType().isa<MemRefType>())
1082     return emitOpError("expected source to be of memref type");
1083   if (numOperands < getSrcMemRefRank() + 4)
1084     return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1085                          << " operands";
1086   if (!getSrcIndices().empty() &&
1087       !llvm::all_of(getSrcIndices().getTypes(),
1088                     [](Type t) { return t.isIndex(); }))
1089     return emitOpError("expected source indices to be of index type");
1090 
1091   // 2. Destination memref.
1092   if (!getDstMemRef().getType().isa<MemRefType>())
1093     return emitOpError("expected destination to be of memref type");
1094   unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1095   if (numOperands < numExpectedOperands)
1096     return emitOpError() << "expected at least " << numExpectedOperands
1097                          << " operands";
1098   if (!getDstIndices().empty() &&
1099       !llvm::all_of(getDstIndices().getTypes(),
1100                     [](Type t) { return t.isIndex(); }))
1101     return emitOpError("expected destination indices to be of index type");
1102 
1103   // 3. Number of elements.
1104   if (!getNumElements().getType().isIndex())
1105     return emitOpError("expected num elements to be of index type");
1106 
1107   // 4. Tag memref.
1108   if (!getTagMemRef().getType().isa<MemRefType>())
1109     return emitOpError("expected tag to be of memref type");
1110   numExpectedOperands += getTagMemRefRank();
1111   if (numOperands < numExpectedOperands)
1112     return emitOpError() << "expected at least " << numExpectedOperands
1113                          << " operands";
1114   if (!getTagIndices().empty() &&
1115       !llvm::all_of(getTagIndices().getTypes(),
1116                     [](Type t) { return t.isIndex(); }))
1117     return emitOpError("expected tag indices to be of index type");
1118 
1119   // Optional stride-related operands must be either both present or both
1120   // absent.
1121   if (numOperands != numExpectedOperands &&
1122       numOperands != numExpectedOperands + 2)
1123     return emitOpError("incorrect number of operands");
1124 
1125   // 5. Strides.
1126   if (isStrided()) {
1127     if (!getStride().getType().isIndex() ||
1128         !getNumElementsPerStride().getType().isIndex())
1129       return emitOpError(
1130           "expected stride and num elements per stride to be of type index");
1131   }
1132 
1133   return success();
1134 }
1135 
1136 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1137                                SmallVectorImpl<OpFoldResult> &results) {
1138   /// dma_start(memrefcast) -> dma_start
1139   return foldMemRefCast(*this);
1140 }
1141 
1142 // ---------------------------------------------------------------------------
1143 // DmaWaitOp
1144 // ---------------------------------------------------------------------------
1145 
1146 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1147                               SmallVectorImpl<OpFoldResult> &results) {
1148   /// dma_wait(memrefcast) -> dma_wait
1149   return foldMemRefCast(*this);
1150 }
1151 
1152 LogicalResult DmaWaitOp::verify() {
1153   // Check that the number of tag indices matches the tagMemRef rank.
1154   unsigned numTagIndices = tagIndices().size();
1155   unsigned tagMemRefRank = getTagMemRefRank();
1156   if (numTagIndices != tagMemRefRank)
1157     return emitOpError() << "expected tagIndices to have the same number of "
1158                             "elements as the tagMemRef rank, expected "
1159                          << tagMemRefRank << ", but got " << numTagIndices;
1160   return success();
1161 }
1162 
1163 //===----------------------------------------------------------------------===//
1164 // GenericAtomicRMWOp
1165 //===----------------------------------------------------------------------===//
1166 
1167 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1168                                Value memref, ValueRange ivs) {
1169   result.addOperands(memref);
1170   result.addOperands(ivs);
1171 
1172   if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
1173     Type elementType = memrefType.getElementType();
1174     result.addTypes(elementType);
1175 
1176     Region *bodyRegion = result.addRegion();
1177     bodyRegion->push_back(new Block());
1178     bodyRegion->addArgument(elementType, memref.getLoc());
1179   }
1180 }
1181 
1182 LogicalResult GenericAtomicRMWOp::verify() {
1183   auto &body = getRegion();
1184   if (body.getNumArguments() != 1)
1185     return emitOpError("expected single number of entry block arguments");
1186 
1187   if (getResult().getType() != body.getArgument(0).getType())
1188     return emitOpError("expected block argument of the same type result type");
1189 
1190   bool hasSideEffects =
1191       body.walk([&](Operation *nestedOp) {
1192             if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
1193               return WalkResult::advance();
1194             nestedOp->emitError(
1195                 "body of 'memref.generic_atomic_rmw' should contain "
1196                 "only operations with no side effects");
1197             return WalkResult::interrupt();
1198           })
1199           .wasInterrupted();
1200   return hasSideEffects ? failure() : success();
1201 }
1202 
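// Parse GenericAtomicRMWOp.
// Ex (illustrative):
//
//   %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
//   ^bb0(%current_value : f32):
//     %c1 = arith.constant 1.0 : f32
//     %inc = arith.addf %c1, %current_value : f32
//     memref.atomic_yield %inc : f32
//   }
//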
1203 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1204                                       OperationState &result) {
1205   OpAsmParser::UnresolvedOperand memref;
1206   Type memrefType;
1207   SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1208 
1209   Type indexType = parser.getBuilder().getIndexType();
1210   if (parser.parseOperand(memref) ||
1211       parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1212       parser.parseColonType(memrefType) ||
1213       parser.resolveOperand(memref, memrefType, result.operands) ||
1214       parser.resolveOperands(ivs, indexType, result.operands))
1215     return failure();
1216 
1217   Region *body = result.addRegion();
1218   if (parser.parseRegion(*body, {}) ||
1219       parser.parseOptionalAttrDict(result.attributes))
1220     return failure();
1221   result.types.push_back(memrefType.cast<MemRefType>().getElementType());
1222   return success();
1223 }
1224 
1225 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1226   p << ' ' << memref() << "[" << indices() << "] : " << memref().getType()
1227     << ' ';
1228   p.printRegion(getRegion());
1229   p.printOptionalAttrDict((*this)->getAttrs());
1230 }
1231 
1232 //===----------------------------------------------------------------------===//
1233 // AtomicYieldOp
1234 //===----------------------------------------------------------------------===//
1235 
1236 LogicalResult AtomicYieldOp::verify() {
1237   Type parentType = (*this)->getParentOp()->getResultTypes().front();
1238   Type resultType = result().getType();
1239   if (parentType != resultType)
1240     return emitOpError() << "types mismatch between yield op: " << resultType
1241                          << " and its parent: " << parentType;
1242   return success();
1243 }
1244 
1245 //===----------------------------------------------------------------------===//
1246 // GlobalOp
1247 //===----------------------------------------------------------------------===//
1248 
1249 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1250                                                    TypeAttr type,
1251                                                    Attribute initialValue) {
1252   p << type;
1253   if (!op.isExternal()) {
1254     p << " = ";
1255     if (op.isUninitialized())
1256       p << "uninitialized";
1257     else
1258       p.printAttributeWithoutType(initialValue);
1259   }
1260 }
1261 
1262 static ParseResult
1263 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1264                                        Attribute &initialValue) {
1265   Type type;
1266   if (parser.parseType(type))
1267     return failure();
1268 
1269   auto memrefType = type.dyn_cast<MemRefType>();
1270   if (!memrefType || !memrefType.hasStaticShape())
1271     return parser.emitError(parser.getNameLoc())
1272            << "type should be static shaped memref, but got " << type;
1273   typeAttr = TypeAttr::get(type);
1274 
1275   if (parser.parseOptionalEqual())
1276     return success();
1277 
1278   if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1279     initialValue = UnitAttr::get(parser.getContext());
1280     return success();
1281   }
1282 
1283   Type tensorType = getTensorTypeFromMemRefType(memrefType);
1284   if (parser.parseAttribute(initialValue, tensorType))
1285     return failure();
1286   if (!initialValue.isa<ElementsAttr>())
1287     return parser.emitError(parser.getNameLoc())
1288            << "initial value should be a unit or elements attribute";
1289   return success();
1290 }
1291 
1292 LogicalResult GlobalOp::verify() {
1293   auto memrefType = type().dyn_cast<MemRefType>();
1294   if (!memrefType || !memrefType.hasStaticShape())
1295     return emitOpError("type should be static shaped memref, but got ")
1296            << type();
1297 
1298   // Verify that the initial value, if present, is either a unit attribute or
1299   // an elements attribute.
1300   if (initial_value().hasValue()) {
1301     Attribute initValue = initial_value().getValue();
1302     if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
1303       return emitOpError("initial value should be a unit or elements "
1304                          "attribute, but got ")
1305              << initValue;
1306 
1307     // Check that the type of the initial value is compatible with the type of
1308     // the global variable.
1309     if (initValue.isa<ElementsAttr>()) {
1310       Type initType = initValue.getType();
1311       Type tensorType = getTensorTypeFromMemRefType(memrefType);
1312       if (initType != tensorType)
1313         return emitOpError("initial value expected to be of type ")
1314                << tensorType << ", but was of type " << initType;
1315     }
1316   }
1317 
1318   if (Optional<uint64_t> alignAttr = alignment()) {
1319     uint64_t alignment = alignAttr.getValue();
1320 
1321     if (!llvm::isPowerOf2_64(alignment))
1322       return emitError() << "alignment attribute value " << alignment
1323                          << " is not a power of 2";
1324   }
1325 
1326   // TODO: verify visibility for declarations.
1327   return success();
1328 }
1329 
1330 ElementsAttr GlobalOp::getConstantInitValue() {
1331   auto initVal = initial_value();
1332   if (constant() && initVal.hasValue())
1333     return initVal.getValue().cast<ElementsAttr>();
1334   return {};
1335 }
1336 
1337 //===----------------------------------------------------------------------===//
1338 // GetGlobalOp
1339 //===----------------------------------------------------------------------===//
1340 
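// A get_global op references a memref.global symbol of matching type, e.g.
// (illustrative):
//
//   memref.global "private" constant @c : memref<2xi32> = dense<[1, 4]>
//   %0 = memref.get_global @c : memref<2xi32>
//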
1341 LogicalResult
1342 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is the same as the type of the referenced
  // memref.global op.
1345   auto global =
1346       symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
1347   if (!global)
1348     return emitOpError("'")
1349            << name() << "' does not reference a valid global memref";
1350 
1351   Type resultType = result().getType();
1352   if (global.type() != resultType)
1353     return emitOpError("result type ")
1354            << resultType << " does not match type " << global.type()
1355            << " of the global memref @" << name();
1356   return success();
1357 }
1358 
1359 //===----------------------------------------------------------------------===//
1360 // LoadOp
1361 //===----------------------------------------------------------------------===//
1362 
1363 LogicalResult LoadOp::verify() {
1364   if (getNumOperands() != 1 + getMemRefType().getRank())
1365     return emitOpError("incorrect number of indices for load");
1366   return success();
1367 }
1368 
1369 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1370   /// load(memrefcast) -> load
1371   if (succeeded(foldMemRefCast(*this)))
1372     return getResult();
1373   return OpFoldResult();
1374 }
1375 
1376 //===----------------------------------------------------------------------===//
1377 // PrefetchOp
1378 //===----------------------------------------------------------------------===//
1379 
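/// The custom assembly form handled by the printer/parser below looks roughly
/// like the following (illustrative):
///
/// ```
///   memref.prefetch %0[%i, %j], read, locality<3>, data : memref<400x400xf32>
/// ```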
1380 void PrefetchOp::print(OpAsmPrinter &p) {
1381   p << " " << memref() << '[';
1382   p.printOperands(indices());
1383   p << ']' << ", " << (isWrite() ? "write" : "read");
1384   p << ", locality<" << localityHint();
1385   p << ">, " << (isDataCache() ? "data" : "instr");
1386   p.printOptionalAttrDict(
1387       (*this)->getAttrs(),
1388       /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1389   p << " : " << getMemRefType();
1390 }
1391 
1392 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1393   OpAsmParser::UnresolvedOperand memrefInfo;
1394   SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1395   IntegerAttr localityHint;
1396   MemRefType type;
1397   StringRef readOrWrite, cacheType;
1398 
1399   auto indexTy = parser.getBuilder().getIndexType();
1400   auto i32Type = parser.getBuilder().getIntegerType(32);
1401   if (parser.parseOperand(memrefInfo) ||
1402       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1403       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1404       parser.parseComma() || parser.parseKeyword("locality") ||
1405       parser.parseLess() ||
1406       parser.parseAttribute(localityHint, i32Type, "localityHint",
1407                             result.attributes) ||
1408       parser.parseGreater() || parser.parseComma() ||
1409       parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1410       parser.resolveOperand(memrefInfo, type, result.operands) ||
1411       parser.resolveOperands(indexInfo, indexTy, result.operands))
1412     return failure();
1413 
1414   if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1415     return parser.emitError(parser.getNameLoc(),
1416                             "rw specifier has to be 'read' or 'write'");
1417   result.addAttribute(
1418       PrefetchOp::getIsWriteAttrName(),
1419       parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1420 
1421   if (!cacheType.equals("data") && !cacheType.equals("instr"))
1422     return parser.emitError(parser.getNameLoc(),
1423                             "cache type has to be 'data' or 'instr'");
1424 
1425   result.addAttribute(
1426       PrefetchOp::getIsDataCacheAttrName(),
1427       parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1428 
1429   return success();
1430 }
1431 
1432 LogicalResult PrefetchOp::verify() {
1433   if (getNumOperands() != 1 + getMemRefType().getRank())
1434     return emitOpError("too few indices");
1435 
1436   return success();
1437 }
1438 
1439 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
1440                                SmallVectorImpl<OpFoldResult> &results) {
1441   // prefetch(memrefcast) -> prefetch
1442   return foldMemRefCast(*this);
1443 }
1444 
1445 //===----------------------------------------------------------------------===//
1446 // RankOp
1447 //===----------------------------------------------------------------------===//
1448 
1449 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
1450   // Constant fold rank when the rank of the operand is known.
1451   auto type = getOperand().getType();
1452   auto shapedType = type.dyn_cast<ShapedType>();
1453   if (shapedType && shapedType.hasRank())
1454     return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1455   return IntegerAttr();
1456 }
1457 
1458 //===----------------------------------------------------------------------===//
1459 // ReinterpretCastOp
1460 //===----------------------------------------------------------------------===//
1461 
/// Build a ReinterpretCastOp with mixed static and dynamic entries: the
/// `static_offsets`, `static_sizes` and `static_strides` attributes are
/// derived from the OpFoldResult arguments, using sentinel values to encode
/// the dynamic entries.
1465 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1466                               MemRefType resultType, Value source,
1467                               OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1468                               ArrayRef<OpFoldResult> strides,
1469                               ArrayRef<NamedAttribute> attrs) {
1470   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1471   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1472   dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
1473                              ShapedType::kDynamicStrideOrOffset);
1474   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1475                              ShapedType::kDynamicSize);
1476   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1477                              ShapedType::kDynamicStrideOrOffset);
1478   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1479         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1480         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1481   result.addAttributes(attrs);
1482 }
1483 
1484 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1485                               MemRefType resultType, Value source,
1486                               int64_t offset, ArrayRef<int64_t> sizes,
1487                               ArrayRef<int64_t> strides,
1488                               ArrayRef<NamedAttribute> attrs) {
1489   SmallVector<OpFoldResult> sizeValues =
1490       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1491         return b.getI64IntegerAttr(v);
1492       }));
1493   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1494       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1495         return b.getI64IntegerAttr(v);
1496       }));
1497   build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1498         strideValues, attrs);
1499 }
1500 
1501 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1502                               MemRefType resultType, Value source, Value offset,
1503                               ValueRange sizes, ValueRange strides,
1504                               ArrayRef<NamedAttribute> attrs) {
1505   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1506       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1507   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1508       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1509   build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1510 }
1511 
1512 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1513 // completed automatically, like we have for subview and extract_slice.
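/// Illustrative example of an op this verifier accepts (the result has an
/// identity layout, i.e. offset 0 and strides [2, 1]):
///
/// ```
///   memref.reinterpret_cast %src to offset: [0], sizes: [10, 2],
///     strides: [2, 1] : memref<?x?xf32> to memref<10x2xf32>
/// ```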
1514 LogicalResult ReinterpretCastOp::verify() {
1515   // The source and result memrefs should be in the same memory space.
1516   auto srcType = source().getType().cast<BaseMemRefType>();
1517   auto resultType = getType().cast<MemRefType>();
1518   if (srcType.getMemorySpace() != resultType.getMemorySpace())
1519     return emitError("different memory spaces specified for source type ")
1520            << srcType << " and result memref type " << resultType;
1521   if (srcType.getElementType() != resultType.getElementType())
1522     return emitError("different element types specified for source type ")
1523            << srcType << " and result memref type " << resultType;
1524 
1525   // Match sizes in result memref type and in static_sizes attribute.
1526   for (auto &en : llvm::enumerate(llvm::zip(
1527            resultType.getShape(), extractFromI64ArrayAttr(static_sizes())))) {
1528     int64_t resultSize = std::get<0>(en.value());
1529     int64_t expectedSize = std::get<1>(en.value());
1530     if (!ShapedType::isDynamic(resultSize) &&
1531         !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
1532       return emitError("expected result type with size = ")
1533              << expectedSize << " instead of " << resultSize
1534              << " in dim = " << en.index();
1535   }
1536 
  // Match the offset and strides in the static_offsets and static_strides
  // attributes. If the result memref type has no affine map specified, an
  // identity layout is assumed.
1540   int64_t resultOffset;
1541   SmallVector<int64_t, 4> resultStrides;
1542   if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1543     return emitError("expected result type to have strided layout but found ")
1544            << resultType;
1545 
1546   // Match offset in result memref type and in static_offsets attribute.
1547   int64_t expectedOffset = extractFromI64ArrayAttr(static_offsets()).front();
1548   if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
1549       !ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
1550       resultOffset != expectedOffset)
1551     return emitError("expected result type with offset = ")
1552            << resultOffset << " instead of " << expectedOffset;
1553 
1554   // Match strides in result memref type and in static_strides attribute.
1555   for (auto &en : llvm::enumerate(llvm::zip(
1556            resultStrides, extractFromI64ArrayAttr(static_strides())))) {
1557     int64_t resultStride = std::get<0>(en.value());
1558     int64_t expectedStride = std::get<1>(en.value());
1559     if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
1560         !ShapedType::isDynamicStrideOrOffset(expectedStride) &&
1561         resultStride != expectedStride)
1562       return emitError("expected result type with stride = ")
1563              << expectedStride << " instead of " << resultStride
1564              << " in dim = " << en.index();
1565   }
1566 
1567   return success();
1568 }
1569 
1570 OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
1571   Value src = source();
1572   auto getPrevSrc = [&]() -> Value {
1573     // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1574     if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1575       return prev.source();
1576 
1577     // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1578     if (auto prev = src.getDefiningOp<CastOp>())
1579       return prev.source();
1580 
1581     // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1582     // are 0.
1583     if (auto prev = src.getDefiningOp<SubViewOp>())
1584       if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
1585             return isConstantIntValue(val, 0);
1586           }))
1587         return prev.source();
1588 
1589     return nullptr;
1590   };
1591 
1592   if (auto prevSrc = getPrevSrc()) {
1593     sourceMutable().assign(prevSrc);
1594     return getResult();
1595   }
1596 
1597   return nullptr;
1598 }
1599 
1600 //===----------------------------------------------------------------------===//
1601 // Reassociative reshape ops
1602 //===----------------------------------------------------------------------===//
1603 
/// Helper function for verifying the shape of ExpandShapeOp and CollapseShapeOp
/// result and operand. Layout maps are verified separately.
///
/// If `allowMultipleDynamicDimsPerGroup` is set, multiple dynamic dimensions
/// are allowed within a reassociation group.
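///
/// Illustrative example: with reassociation [[0, 1], [2]], the expanded shape
/// [2, 3, 4] corresponds to the collapsed shape [6, 4]; a collapsed shape of
/// [5, 4] would be rejected because 2 * 3 != 5.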
1609 static LogicalResult
1610 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
1611                      ArrayRef<int64_t> expandedShape,
1612                      ArrayRef<ReassociationIndices> reassociation,
1613                      bool allowMultipleDynamicDimsPerGroup) {
1614   // There must be one reassociation group per collapsed dimension.
1615   if (collapsedShape.size() != reassociation.size())
1616     return op->emitOpError("invalid number of reassociation groups: found ")
1617            << reassociation.size() << ", expected " << collapsedShape.size();
1618 
1619   // The next expected expanded dimension index (while iterating over
1620   // reassociation indices).
1621   int64_t nextDim = 0;
1622   for (const auto &it : llvm::enumerate(reassociation)) {
1623     ReassociationIndices group = it.value();
1624     int64_t collapsedDim = it.index();
1625 
1626     bool foundDynamic = false;
1627     for (int64_t expandedDim : group) {
1628       if (expandedDim != nextDim++)
1629         return op->emitOpError("reassociation indices must be contiguous");
1630 
1631       if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
1632         return op->emitOpError("reassociation index ")
1633                << expandedDim << " is out of bounds";
1634 
1635       // Check if there are multiple dynamic dims in a reassociation group.
1636       if (ShapedType::isDynamic(expandedShape[expandedDim])) {
1637         if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
1638           return op->emitOpError(
1639               "at most one dimension in a reassociation group may be dynamic");
1640         foundDynamic = true;
1641       }
1642     }
1643 
1644     // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
1645     if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
1646       return op->emitOpError("collapsed dim (")
1647              << collapsedDim
1648              << ") must be dynamic if and only if reassociation group is "
1649                 "dynamic";
1650 
1651     // If all dims in the reassociation group are static, the size of the
1652     // collapsed dim can be verified.
1653     if (!foundDynamic) {
1654       int64_t groupSize = 1;
1655       for (int64_t expandedDim : group)
1656         groupSize *= expandedShape[expandedDim];
1657       if (groupSize != collapsedShape[collapsedDim])
1658         return op->emitOpError("collapsed dim size (")
1659                << collapsedShape[collapsedDim]
1660                << ") must equal reassociation group size (" << groupSize << ")";
1661     }
1662   }
1663 
1664   if (collapsedShape.empty()) {
1665     // Rank 0: All expanded dimensions must be 1.
1666     for (int64_t d : expandedShape)
1667       if (d != 1)
1668         return op->emitOpError(
1669             "rank 0 memrefs can only be extended/collapsed with/from ones");
1670   } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
    // Rank >= 1: The number of dimensions among all reassociation groups must
    // match the expanded rank.
1673     return op->emitOpError("expanded rank (")
1674            << expandedShape.size()
1675            << ") inconsistent with number of reassociation indices (" << nextDim
1676            << ")";
1677   }
1678 
1679   return success();
1680 }
1681 
1682 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1683   return getSymbolLessAffineMaps(getReassociationExprs());
1684 }
1685 
1686 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1687   return convertReassociationIndicesToExprs(getContext(),
1688                                             getReassociationIndices());
1689 }
1690 
1691 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1692   return getSymbolLessAffineMaps(getReassociationExprs());
1693 }
1694 
1695 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1696   return convertReassociationIndicesToExprs(getContext(),
1697                                             getReassociationIndices());
1698 }
1699 
1700 /// Compute the layout map after expanding a given source MemRef type with the
1701 /// specified reassociation indices.
1702 static FailureOr<AffineMap>
1703 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
1704                          ArrayRef<ReassociationIndices> reassociation) {
1705   int64_t srcOffset;
1706   SmallVector<int64_t> srcStrides;
1707   if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
1708     return failure();
1709   assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
1710 
1711   // 1-1 mapping between srcStrides and reassociation packs.
1712   // Each srcStride starts with the given value and gets expanded according to
1713   // the proper entries in resultShape.
1714   // Example:
1715   //   srcStrides     =                   [10000,  1 ,    100   ],
1716   //   reassociations =                   [  [0], [1], [2, 3, 4]],
1717   //   resultSizes    = [2, 5, 4, 3, 2] = [  [2], [5], [4, 3, 2]]
1718   //     -> For the purpose of stride calculation, the useful sizes are:
1719   //                    [x, x, x, 3, 2] = [  [x], [x], [x, 3, 2]].
1720   //   resultStrides = [10000, 1, 600, 200, 100]
1721   // Note that a stride does not get expanded along the first entry of each
1722   // shape pack.
1723   SmallVector<int64_t> reverseResultStrides;
1724   reverseResultStrides.reserve(resultShape.size());
1725   unsigned shapeIndex = resultShape.size() - 1;
1726   for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
1727     ReassociationIndices reassoc = std::get<0>(it);
1728     int64_t currentStrideToExpand = std::get<1>(it);
1729     for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
1730       using saturated_arith::Wrapper;
1731       reverseResultStrides.push_back(currentStrideToExpand);
1732       currentStrideToExpand = (Wrapper::stride(currentStrideToExpand) *
1733                                Wrapper::size(resultShape[shapeIndex--]))
1734                                   .asStride();
1735     }
1736   }
1737   auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
1738   resultStrides.resize(resultShape.size(), 1);
1739   return makeStridedLinearLayoutMap(resultStrides, srcOffset,
1740                                     srcType.getContext());
1741 }
1742 
1743 static FailureOr<MemRefType>
1744 computeExpandedType(MemRefType srcType, ArrayRef<int64_t> resultShape,
1745                     ArrayRef<ReassociationIndices> reassociation) {
1746   if (srcType.getLayout().isIdentity()) {
1747     // If the source is contiguous (i.e., no layout map specified), so is the
1748     // result.
1749     MemRefLayoutAttrInterface layout;
1750     return MemRefType::get(resultShape, srcType.getElementType(), layout,
1751                            srcType.getMemorySpace());
1752   }
1753 
1754   // Source may not be contiguous. Compute the layout map.
1755   FailureOr<AffineMap> computedLayout =
1756       computeExpandedLayoutMap(srcType, resultShape, reassociation);
1757   if (failed(computedLayout))
1758     return failure();
1759   auto computedType =
1760       MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
1761                       srcType.getMemorySpaceAsInt());
1762   return canonicalizeStridedLayout(computedType);
1763 }
1764 
1765 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1766                           ArrayRef<int64_t> resultShape, Value src,
1767                           ArrayRef<ReassociationIndices> reassociation) {
1768   // Only ranked memref source values are supported.
1769   auto srcType = src.getType().cast<MemRefType>();
1770   FailureOr<MemRefType> resultType =
1771       computeExpandedType(srcType, resultShape, reassociation);
1772   // Failure of this assertion usually indicates a problem with the source
1773   // type, e.g., could not get strides/offset.
1774   assert(succeeded(resultType) && "could not compute layout");
1775   build(builder, result, *resultType, src, reassociation);
1776 }
1777 
1778 LogicalResult ExpandShapeOp::verify() {
1779   MemRefType srcType = getSrcType();
1780   MemRefType resultType = getResultType();
1781 
1782   // Verify result shape.
1783   if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
1784                                   resultType.getShape(),
1785                                   getReassociationIndices(),
1786                                   /*allowMultipleDynamicDimsPerGroup=*/false)))
1787     return failure();
1788 
1789   // Compute expected result type (including layout map).
1790   FailureOr<MemRefType> expectedResultType = computeExpandedType(
1791       srcType, resultType.getShape(), getReassociationIndices());
1792   if (failed(expectedResultType))
1793     return emitOpError("invalid source layout map");
1794 
1795   // Check actual result type.
1796   auto canonicalizedResultType = canonicalizeStridedLayout(resultType);
1797   if (*expectedResultType != canonicalizedResultType)
1798     return emitOpError("expected expanded type to be ")
1799            << *expectedResultType << " but found " << canonicalizedResultType;
1800 
1801   return success();
1802 }
1803 
1804 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1805                                                 MLIRContext *context) {
1806   results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
1807               ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(
1808       context);
1809 }
1810 
1811 /// Compute the layout map after collapsing a given source MemRef type with the
1812 /// specified reassociation indices.
1813 ///
1814 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
1815 /// not possible to check this by inspecting a MemRefType in the general case.
1816 /// If non-contiguity cannot be checked statically, the collapse is assumed to
1817 /// be valid (and thus accepted by this function) unless `strict = true`.
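///
/// Illustrative example: collapsing a source of shape 4x2 with strides [2, 1]
/// via reassociation [[0, 1]] yields the single stride [1] (the dimensions are
/// contiguous). With source strides [4, 1] the group is provably
/// non-contiguous (1 * 2 != 4), so the computation fails.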
1818 static FailureOr<AffineMap>
1819 computeCollapsedLayoutMap(MemRefType srcType,
1820                           ArrayRef<ReassociationIndices> reassociation,
1821                           bool strict = false) {
1822   int64_t srcOffset;
1823   SmallVector<int64_t> srcStrides;
1824   auto srcShape = srcType.getShape();
1825   if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
1826     return failure();
1827 
1828   // The result stride of a reassociation group is the stride of the last entry
1829   // of the reassociation. (TODO: Should be the minimum stride in the
1830   // reassociation because strides are not necessarily sorted. E.g., when using
1831   // memref.transpose.) Dimensions of size 1 should be skipped, because their
1832   // strides are meaningless and could have any arbitrary value.
1833   SmallVector<int64_t> resultStrides;
1834   resultStrides.reserve(reassociation.size());
1835   for (const ReassociationIndices &reassoc : reassociation) {
1836     ArrayRef<int64_t> ref = llvm::makeArrayRef(reassoc);
1837     while (srcShape[ref.back()] == 1 && ref.size() > 1)
1838       ref = ref.drop_back();
1839     if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
1840       resultStrides.push_back(srcStrides[ref.back()]);
1841     } else {
1842       // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
1843       // the corresponding stride may have to be skipped. (See above comment.)
1844       // Therefore, the result stride cannot be statically determined and must
1845       // be dynamic.
1846       resultStrides.push_back(ShapedType::kDynamicStrideOrOffset);
1847     }
1848   }
1849 
1850   // Validate that each reassociation group is contiguous.
1851   unsigned resultStrideIndex = resultStrides.size() - 1;
1852   for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
1853     auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
1854     using saturated_arith::Wrapper;
1855     auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]);
1856     for (int64_t idx : llvm::reverse(trailingReassocs)) {
1857       stride = stride * Wrapper::size(srcShape[idx]);
1858 
      // The source and result strides must have the same static value. In that
      // case, we can be sure that the dimensions are collapsible (because they
      // are contiguous).
      //
      // One special case is a source dimension of size `1`, which can never
      // introduce non-contiguity.
1865       if (srcShape[idx] == 1)
1866         continue;
1867 
1868       // If `strict = false` (default during op verification), we accept cases
1869       // where one or both strides are dynamic. This is best effort: We reject
1870       // ops where obviously non-contiguous dims are collapsed, but accept ops
1871       // where we cannot be sure statically. Such ops may fail at runtime. See
1872       // the op documentation for details.
1873       auto srcStride = Wrapper::stride(srcStrides[idx - 1]);
1874       if (strict && (stride.saturated || srcStride.saturated))
1875         return failure();
1876 
1877       if (!stride.saturated && !srcStride.saturated && stride != srcStride)
1878         return failure();
1879     }
1880   }
1881   return makeStridedLinearLayoutMap(resultStrides, srcOffset,
1882                                     srcType.getContext());
1883 }
1884 
1885 bool CollapseShapeOp::isGuaranteedCollapsible(
1886     MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
1887   // MemRefs with standard layout are always collapsible.
1888   if (srcType.getLayout().isIdentity())
1889     return true;
1890 
1891   return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
1892                                              /*strict=*/true));
1893 }
1894 
1895 static MemRefType
1896 computeCollapsedType(MemRefType srcType,
1897                      ArrayRef<ReassociationIndices> reassociation) {
1898   SmallVector<int64_t> resultShape;
1899   resultShape.reserve(reassociation.size());
1900   for (const ReassociationIndices &group : reassociation) {
1901     using saturated_arith::Wrapper;
1902     auto groupSize = Wrapper::size(1);
1903     for (int64_t srcDim : group)
1904       groupSize = groupSize * Wrapper::size(srcType.getDimSize(srcDim));
1905     resultShape.push_back(groupSize.asSize());
1906   }
1907 
1908   if (srcType.getLayout().isIdentity()) {
1909     // If the source is contiguous (i.e., no layout map specified), so is the
1910     // result.
1911     MemRefLayoutAttrInterface layout;
1912     return MemRefType::get(resultShape, srcType.getElementType(), layout,
1913                            srcType.getMemorySpace());
1914   }
1915 
1916   // Source may not be fully contiguous. Compute the layout map.
1917   // Note: Dimensions that are collapsed into a single dim are assumed to be
1918   // contiguous.
1919   FailureOr<AffineMap> computedLayout =
1920       computeCollapsedLayoutMap(srcType, reassociation);
1921   assert(succeeded(computedLayout) &&
1922          "invalid source layout map or collapsing non-contiguous dims");
1923   auto computedType =
1924       MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
1925                       srcType.getMemorySpaceAsInt());
1926   return canonicalizeStridedLayout(computedType);
1927 }
1928 
1929 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1930                             ArrayRef<ReassociationIndices> reassociation,
1931                             ArrayRef<NamedAttribute> attrs) {
1932   auto srcType = src.getType().cast<MemRefType>();
1933   MemRefType resultType = computeCollapsedType(srcType, reassociation);
1934   build(b, result, resultType, src, attrs);
1935   result.addAttribute(getReassociationAttrName(),
1936                       getReassociationIndicesAttribute(b, reassociation));
1937 }
1938 
1939 LogicalResult CollapseShapeOp::verify() {
1940   MemRefType srcType = getSrcType();
1941   MemRefType resultType = getResultType();
1942 
1943   // Verify result shape.
1944   if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
1945                                   srcType.getShape(), getReassociationIndices(),
1946                                   /*allowMultipleDynamicDimsPerGroup=*/true)))
1947     return failure();
1948 
1949   // Compute expected result type (including layout map).
1950   MemRefType expectedResultType;
1951   if (srcType.getLayout().isIdentity()) {
1952     // If the source is contiguous (i.e., no layout map specified), so is the
1953     // result.
1954     MemRefLayoutAttrInterface layout;
1955     expectedResultType =
1956         MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
1957                         srcType.getMemorySpace());
1958   } else {
1959     // Source may not be fully contiguous. Compute the layout map.
1960     // Note: Dimensions that are collapsed into a single dim are assumed to be
1961     // contiguous.
1962     FailureOr<AffineMap> computedLayout =
1963         computeCollapsedLayoutMap(srcType, getReassociationIndices());
1964     if (failed(computedLayout))
1965       return emitOpError(
1966           "invalid source layout map or collapsing non-contiguous dims");
1967     auto computedType =
1968         MemRefType::get(resultType.getShape(), srcType.getElementType(),
1969                         *computedLayout, srcType.getMemorySpaceAsInt());
1970     expectedResultType = canonicalizeStridedLayout(computedType);
1971   }
1972 
1973   auto canonicalizedResultType = canonicalizeStridedLayout(resultType);
1974   if (expectedResultType != canonicalizedResultType)
1975     return emitOpError("expected collapsed type to be ")
1976            << expectedResultType << " but found " << canonicalizedResultType;
1977 
1978   return success();
1979 }
1980 
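/// Pattern that folds a memref.cast feeding a collapse_shape when the cast can
/// be folded into its consumer: collapse_shape(cast(x)) is rewritten to
/// collapse_shape(x), followed by a cast back to the original result type when
/// the collapsed type changes.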
1981 struct CollapseShapeOpMemRefCastFolder
1982     : public OpRewritePattern<CollapseShapeOp> {
1983 public:
1984   using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
1985 
1986   LogicalResult matchAndRewrite(CollapseShapeOp op,
1987                                 PatternRewriter &rewriter) const override {
1988     auto cast = op.getOperand().getDefiningOp<CastOp>();
1989     if (!cast)
1990       return failure();
1991 
1992     if (!CastOp::canFoldIntoConsumerOp(cast))
1993       return failure();
1994 
1995     Type newResultType =
1996         computeCollapsedType(cast.getOperand().getType().cast<MemRefType>(),
1997                              op.getReassociationIndices());
1998 
1999     if (newResultType == op.getResultType()) {
2000       rewriter.updateRootInPlace(
2001           op, [&]() { op.srcMutable().assign(cast.source()); });
2002     } else {
2003       Value newOp = rewriter.create<CollapseShapeOp>(
2004           op->getLoc(), cast.source(), op.getReassociationIndices());
2005       rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2006     }
2007     return success();
2008   }
2009 };
2010 
2011 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2012                                                   MLIRContext *context) {
2013   results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
2014               ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
2015               CollapseShapeOpMemRefCastFolder>(context);
2016 }
2017 
2018 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
2019   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
2020 }
2021 
2022 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
2023   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
2024 }
2025 
2026 //===----------------------------------------------------------------------===//
2027 // ReshapeOp
2028 //===----------------------------------------------------------------------===//
2029 
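/// Illustrative example of a reshape this verifier accepts (identity layouts on
/// both sides; the shape operand length matches the result rank):
///
/// ```
///   %dst = memref.reshape %src(%shape)
///       : (memref<4x1xf32>, memref<1xi32>) -> memref<4xf32>
/// ```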
2030 LogicalResult ReshapeOp::verify() {
2031   Type operandType = source().getType();
2032   Type resultType = result().getType();
2033 
2034   Type operandElementType = operandType.cast<ShapedType>().getElementType();
2035   Type resultElementType = resultType.cast<ShapedType>().getElementType();
2036   if (operandElementType != resultElementType)
2037     return emitOpError("element types of source and destination memref "
2038                        "types should be the same");
2039 
2040   if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
2041     if (!operandMemRefType.getLayout().isIdentity())
2042       return emitOpError("source memref type should have identity affine map");
2043 
2044   int64_t shapeSize = shape().getType().cast<MemRefType>().getDimSize(0);
2045   auto resultMemRefType = resultType.dyn_cast<MemRefType>();
2046   if (resultMemRefType) {
2047     if (!resultMemRefType.getLayout().isIdentity())
2048       return emitOpError("result memref type should have identity affine map");
2049     if (shapeSize == ShapedType::kDynamicSize)
2050       return emitOpError("cannot use shape operand with dynamic length to "
2051                          "reshape to statically-ranked memref type");
2052     if (shapeSize != resultMemRefType.getRank())
2053       return emitOpError(
2054           "length of shape operand differs from the result's memref rank");
2055   }
2056   return success();
2057 }
2058 
2059 //===----------------------------------------------------------------------===//
2060 // StoreOp
2061 //===----------------------------------------------------------------------===//
2062 
2063 LogicalResult StoreOp::verify() {
2064   if (getNumOperands() != 2 + getMemRefType().getRank())
2065     return emitOpError("store index operand count not equal to memref rank");
2066 
2067   return success();
2068 }
2069 
2070 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
2071                             SmallVectorImpl<OpFoldResult> &results) {
2072   /// store(memrefcast) -> store
2073   return foldMemRefCast(*this, getValueToStore());
2074 }
2075 
2076 //===----------------------------------------------------------------------===//
2077 // SubViewOp
2078 //===----------------------------------------------------------------------===//
2079 
2080 /// A subview result type can be fully inferred from the source type and the
2081 /// static representation of offsets, sizes and strides. Special sentinels
2082 /// encode the dynamic case.
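///
/// Illustrative example: for a source memref<8x16x4xf32> with identity layout
/// (strides [64, 4, 1], offset 0), static offsets [0, 8, 0], sizes [8, 8, 4]
/// and strides [1, 1, 1], the inferred result has shape 8x8x4, strides
/// [64, 4, 1], and offset 0 * 64 + 8 * 4 + 0 * 1 = 32.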
2083 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2084                                 ArrayRef<int64_t> staticOffsets,
2085                                 ArrayRef<int64_t> staticSizes,
2086                                 ArrayRef<int64_t> staticStrides) {
2087   unsigned rank = sourceMemRefType.getRank();
2088   (void)rank;
2089   assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2090   assert(staticSizes.size() == rank && "staticSizes length mismatch");
2091   assert(staticStrides.size() == rank && "staticStrides length mismatch");
2092 
2093   // Extract source offset and strides.
2094   int64_t sourceOffset;
2095   SmallVector<int64_t, 4> sourceStrides;
2096   auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
2097   assert(succeeded(res) && "SubViewOp expected strided memref type");
2098   (void)res;
2099 
2100   // Compute target offset whose value is:
2101   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2102   int64_t targetOffset = sourceOffset;
2103   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    using saturated_arith::Wrapper;
    targetOffset =
        (Wrapper::offset(targetOffset) +
         Wrapper::offset(staticOffset) * Wrapper::stride(sourceStride))
            .asOffset();
2110   }
2111 
2112   // Compute target stride whose value is:
2113   //   `sourceStrides_i * staticStrides_i`.
2114   SmallVector<int64_t, 4> targetStrides;
2115   targetStrides.reserve(staticOffsets.size());
2116   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2117     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2118     using saturated_arith::Wrapper;
2119     targetStrides.push_back(
2120         (Wrapper::stride(sourceStride) * Wrapper::stride(staticStride))
2121             .asStride());
2122   }
2123 
2124   // The type is now known.
2125   return MemRefType::get(
2126       staticSizes, sourceMemRefType.getElementType(),
2127       makeStridedLinearLayoutMap(targetStrides, targetOffset,
2128                                  sourceMemRefType.getContext()),
2129       sourceMemRefType.getMemorySpace());
2130 }
2131 
2132 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2133                                 ArrayRef<OpFoldResult> offsets,
2134                                 ArrayRef<OpFoldResult> sizes,
2135                                 ArrayRef<OpFoldResult> strides) {
2136   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2137   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2138   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
2139                              ShapedType::kDynamicStrideOrOffset);
2140   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
2141                              ShapedType::kDynamicSize);
2142   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
2143                              ShapedType::kDynamicStrideOrOffset);
2144   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2145                                     staticSizes, staticStrides);
2146 }
2147 
2148 Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
2149                                            MemRefType sourceRankedTensorType,
2150                                            ArrayRef<int64_t> offsets,
2151                                            ArrayRef<int64_t> sizes,
2152                                            ArrayRef<int64_t> strides) {
2153   auto inferredType =
2154       inferResultType(sourceRankedTensorType, offsets, sizes, strides)
2155           .cast<MemRefType>();
  assert(inferredType.getRank() >= resultRank &&
         "expected the inferred rank to be greater than or equal to the "
         "result rank");
2157   int rankDiff = inferredType.getRank() - resultRank;
2158   if (rankDiff > 0) {
2159     auto shape = inferredType.getShape();
2160     llvm::SmallBitVector dimsToProject =
2161         getPositionsOfShapeOne(rankDiff, shape);
2162     SmallVector<int64_t> projectedShape;
2163     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2164       if (!dimsToProject.test(pos))
2165         projectedShape.push_back(shape[pos]);
2166 
2167     AffineMap map = inferredType.getLayout().getAffineMap();
2168     if (!map.isIdentity())
2169       map = getProjectedMap(map, dimsToProject);
2170     inferredType =
2171         MemRefType::get(projectedShape, inferredType.getElementType(), map,
2172                         inferredType.getMemorySpace());
2173   }
2174   return inferredType;
2175 }
2176 
2177 Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
2178                                            MemRefType sourceRankedTensorType,
2179                                            ArrayRef<OpFoldResult> offsets,
2180                                            ArrayRef<OpFoldResult> sizes,
2181                                            ArrayRef<OpFoldResult> strides) {
2182   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2183   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2184   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
2185                              ShapedType::kDynamicStrideOrOffset);
2186   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
2187                              ShapedType::kDynamicSize);
2188   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
2189                              ShapedType::kDynamicStrideOrOffset);
2190   return SubViewOp::inferRankReducedResultType(
2191       resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
2192       staticStrides);
}

2194 // Build a SubViewOp with mixed static and dynamic entries and custom result
2195 // type. If the type passed is nullptr, it is inferred.
2196 void SubViewOp::build(OpBuilder &b, OperationState &result,
2197                       MemRefType resultType, Value source,
2198                       ArrayRef<OpFoldResult> offsets,
2199                       ArrayRef<OpFoldResult> sizes,
2200                       ArrayRef<OpFoldResult> strides,
2201                       ArrayRef<NamedAttribute> attrs) {
2202   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2203   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2204   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
2205                              ShapedType::kDynamicStrideOrOffset);
2206   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
2207                              ShapedType::kDynamicSize);
2208   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
2209                              ShapedType::kDynamicStrideOrOffset);
2210   auto sourceMemRefType = source.getType().cast<MemRefType>();
2211   // Structuring implementation this way avoids duplication between builders.
2212   if (!resultType) {
2213     resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2214                                             staticSizes, staticStrides)
2215                      .cast<MemRefType>();
2216   }
2217   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2218         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
2219         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
2220   result.addAttributes(attrs);
2221 }
2222 
2223 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2224 // type.
2225 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2226                       ArrayRef<OpFoldResult> offsets,
2227                       ArrayRef<OpFoldResult> sizes,
2228                       ArrayRef<OpFoldResult> strides,
2229                       ArrayRef<NamedAttribute> attrs) {
2230   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2231 }
2232 
2233 // Build a SubViewOp with static entries and inferred result type.
2234 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2235                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2236                       ArrayRef<int64_t> strides,
2237                       ArrayRef<NamedAttribute> attrs) {
2238   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2239       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2240         return b.getI64IntegerAttr(v);
2241       }));
2242   SmallVector<OpFoldResult> sizeValues =
2243       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2244         return b.getI64IntegerAttr(v);
2245       }));
2246   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2247       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2248         return b.getI64IntegerAttr(v);
2249       }));
2250   build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2251 }
2252 
2253 // Build a SubViewOp with dynamic entries and custom result type. If the
2254 // type passed is nullptr, it is inferred.
2255 void SubViewOp::build(OpBuilder &b, OperationState &result,
2256                       MemRefType resultType, Value source,
2257                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2258                       ArrayRef<int64_t> strides,
2259                       ArrayRef<NamedAttribute> attrs) {
2260   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2261       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2262         return b.getI64IntegerAttr(v);
2263       }));
2264   SmallVector<OpFoldResult> sizeValues =
2265       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2266         return b.getI64IntegerAttr(v);
2267       }));
2268   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2269       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2270         return b.getI64IntegerAttr(v);
2271       }));
2272   build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2273         attrs);
2274 }
2275 
2276 // Build a SubViewOp with dynamic entries and custom result type. If the type
2277 // passed is nullptr, it is inferred.
2278 void SubViewOp::build(OpBuilder &b, OperationState &result,
2279                       MemRefType resultType, Value source, ValueRange offsets,
2280                       ValueRange sizes, ValueRange strides,
2281                       ArrayRef<NamedAttribute> attrs) {
2282   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2283       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2284   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2285       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2286   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2287       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2288   build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2289 }
2290 
2291 // Build a SubViewOp with dynamic entries and inferred result type.
2292 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2293                       ValueRange offsets, ValueRange sizes, ValueRange strides,
2294                       ArrayRef<NamedAttribute> attrs) {
2295   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2296 }
2297 
2298 /// For ViewLikeOpInterface.
2299 Value SubViewOp::getViewSource() { return source(); }
2300 
2301 /// Return true if t1 and t2 have equal offsets (both dynamic or of same
2302 /// static value).
2303 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2304   AffineExpr t1Offset, t2Offset;
2305   SmallVector<AffineExpr> t1Strides, t2Strides;
2306   auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2307   auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2308   return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2309 }
2310 
/// Checks if `originalType` can be rank-reduced to `candidateRankReducedType`.
/// This function is a slight variant of the `is subsequence` algorithm, where
/// each non-matching dimension must be 1.
2314 static SliceVerificationResult
2315 isRankReducedMemRefType(MemRefType originalType,
2316                         MemRefType candidateRankReducedType,
2317                         ArrayRef<OpFoldResult> sizes) {
2318   auto partialRes = isRankReducedType(originalType, candidateRankReducedType);
2319   if (partialRes != SliceVerificationResult::Success)
2320     return partialRes;
2321 
2322   auto optionalUnusedDimsMask = computeMemRefRankReductionMask(
2323       originalType, candidateRankReducedType, sizes);
2324 
  // The sizes cannot be matched if no rank-reduction mask could be computed.
2326   if (!optionalUnusedDimsMask.hasValue())
2327     return SliceVerificationResult::LayoutMismatch;
2328 
2329   if (originalType.getMemorySpace() !=
2330       candidateRankReducedType.getMemorySpace())
2331     return SliceVerificationResult::MemSpaceMismatch;
2332 
2333   // No amount of stride dropping can reconcile incompatible offsets.
2334   if (!haveCompatibleOffsets(originalType, candidateRankReducedType))
2335     return SliceVerificationResult::LayoutMismatch;
2336 
2337   return SliceVerificationResult::Success;
2338 }
2339 
2340 template <typename OpTy>
2341 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
2342                                             OpTy op, Type expectedType) {
2343   auto memrefType = expectedType.cast<ShapedType>();
2344   switch (result) {
2345   case SliceVerificationResult::Success:
2346     return success();
2347   case SliceVerificationResult::RankTooLarge:
2348     return op.emitError("expected result rank to be smaller or equal to ")
2349            << "the source rank. ";
2350   case SliceVerificationResult::SizeMismatch:
2351     return op.emitError("expected result type to be ")
2352            << expectedType
2353            << " or a rank-reduced version. (mismatch of result sizes) ";
2354   case SliceVerificationResult::ElemTypeMismatch:
2355     return op.emitError("expected result element type to be ")
2356            << memrefType.getElementType();
2357   case SliceVerificationResult::MemSpaceMismatch:
2358     return op.emitError("expected result and source memory spaces to match.");
2359   case SliceVerificationResult::LayoutMismatch:
2360     return op.emitError("expected result type to be ")
2361            << expectedType
2362            << " or a rank-reduced version. (mismatch of result layout) ";
2363   }
2364   llvm_unreachable("unexpected subview verification result");
2365 }
2366 
2367 /// Verifier for SubViewOp.
2368 LogicalResult SubViewOp::verify() {
2369   MemRefType baseType = getSourceType();
2370   MemRefType subViewType = getType();
2371 
2372   // The base memref and the view memref should be in the same memory space.
2373   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2374     return emitError("different memory spaces specified for base memref "
2375                      "type ")
2376            << baseType << " and subview memref type " << subViewType;
2377 
2378   // Verify that the base memref type has a strided layout map.
2379   if (!isStrided(baseType))
2380     return emitError("base type ") << baseType << " is not strided";
2381 
2382   // Verify result type against inferred type.
2383   auto expectedType = SubViewOp::inferResultType(
2384       baseType, extractFromI64ArrayAttr(static_offsets()),
2385       extractFromI64ArrayAttr(static_sizes()),
2386       extractFromI64ArrayAttr(static_strides()));
2387 
2388   auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
2389                                         subViewType, getMixedSizes());
2390   return produceSubViewErrorMsg(result, *this, expectedType);
2391 }
2392 
2393 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
2394   return os << "range " << range.offset << ":" << range.size << ":"
2395             << range.stride;
2396 }
2397 
2398 /// Return the list of Range (i.e. offset, size, stride). Each Range
2399 /// entry contains either the dynamic value or a ConstantIndexOp constructed
2400 /// with `b` at location `loc`.
2401 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
2402                                               OpBuilder &b, Location loc) {
2403   std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
2404   assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
2405   assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
2406   SmallVector<Range, 8> res;
2407   unsigned rank = ranks[0];
2408   res.reserve(rank);
2409   for (unsigned idx = 0; idx < rank; ++idx) {
2410     Value offset =
2411         op.isDynamicOffset(idx)
2412             ? op.getDynamicOffset(idx)
2413             : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
2414     Value size =
2415         op.isDynamicSize(idx)
2416             ? op.getDynamicSize(idx)
2417             : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
2418     Value stride =
2419         op.isDynamicStride(idx)
2420             ? op.getDynamicStride(idx)
2421             : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
2422     res.emplace_back(Range{offset, size, stride});
2423   }
2424   return res;
2425 }
2426 
2427 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
2428 /// to deduce the result type for the given `sourceType`. Additionally, reduce
2429 /// the rank of the inferred result type if `currentResultType` is lower rank
2430 /// than `currentSourceType`. Use this signature if `sourceType` is updated
2431 /// together with the result type. In this case, it is important to compute
2432 /// the dropped dimensions using `currentSourceType` whose strides align with
2433 /// `currentResultType`.
2434 static MemRefType getCanonicalSubViewResultType(
2435     MemRefType currentResultType, MemRefType currentSourceType,
2436     MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
2437     ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
2438   auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
2439                                                        mixedSizes, mixedStrides)
2440                                 .cast<MemRefType>();
2441   llvm::Optional<llvm::SmallBitVector> unusedDims =
2442       computeMemRefRankReductionMask(currentSourceType, currentResultType,
2443                                      mixedSizes);
2444   // Return nullptr as failure mode.
2445   if (!unusedDims)
2446     return nullptr;
2447   SmallVector<int64_t> shape;
2448   for (const auto &sizes : llvm::enumerate(nonRankReducedType.getShape())) {
2449     if (unusedDims->test(sizes.index()))
2450       continue;
2451     shape.push_back(sizes.value());
2452   }
2453   AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap();
2454   if (!layoutMap.isIdentity())
2455     layoutMap = getProjectedMap(layoutMap, unusedDims.getValue());
2456   return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap,
2457                          nonRankReducedType.getMemorySpace());
2458 }
2459 
2460 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
2461 /// to deduce the result type. Additionally, reduce the rank of the inferred
2462 /// result type if `currentResultType` is lower rank than `sourceType`.
2463 static MemRefType getCanonicalSubViewResultType(
2464     MemRefType currentResultType, MemRefType sourceType,
2465     ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
2466     ArrayRef<OpFoldResult> mixedStrides) {
2467   return getCanonicalSubViewResultType(currentResultType, sourceType,
2468                                        sourceType, mixedOffsets, mixedSizes,
2469                                        mixedStrides);
2470 }
2471 
/// Helper method to check if a `subview` operation is trivially a no-op. This
/// is the case if all offsets are zero, all strides are 1, and the source
/// shape is the same as the size of the subview. In such cases, the subview
/// can be folded into its source.
2476 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
2477   if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
2478     return false;
2479 
2480   auto mixedOffsets = subViewOp.getMixedOffsets();
2481   auto mixedSizes = subViewOp.getMixedSizes();
2482   auto mixedStrides = subViewOp.getMixedStrides();
2483 
2484   // Check offsets are zero.
2485   if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
2486         Optional<int64_t> intValue = getConstantIntValue(ofr);
2487         return !intValue || intValue.getValue() != 0;
2488       }))
2489     return false;
2490 
2491   // Check strides are one.
2492   if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
2493         Optional<int64_t> intValue = getConstantIntValue(ofr);
2494         return !intValue || intValue.getValue() != 1;
2495       }))
2496     return false;
2497 
  // Check that all size values are static and match the (static) source shape.
2499   ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
2500   for (const auto &size : llvm::enumerate(mixedSizes)) {
2501     Optional<int64_t> intValue = getConstantIntValue(size.value());
2502     if (!intValue || intValue.getValue() != sourceShape[size.index()])
2503       return false;
2504   }
2505   // All conditions met. The `SubViewOp` is foldable as a no-op.
2506   return true;
2507 }
2508 
2509 namespace {
2510 /// Pattern to rewrite a subview op with MemRefCast arguments.
2511 /// This essentially pushes memref.cast past its consuming subview when
2512 /// `canFoldIntoConsumerOp` is true.
2513 ///
2514 /// Example:
2515 /// ```
2516 ///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
2517 ///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
2518 ///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
2519 /// ```
2520 /// is rewritten into:
2521 /// ```
2522 ///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
2523 ///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
2524 ///     memref<3x4xf32, offset:?, strides:[?, 1]>
2525 /// ```
2526 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
2527 public:
2528   using OpRewritePattern<SubViewOp>::OpRewritePattern;
2529 
2530   LogicalResult matchAndRewrite(SubViewOp subViewOp,
2531                                 PatternRewriter &rewriter) const override {
    // If any operand is constant, bail out and let SubViewOpConstantFolder
    // kick in first.
2534     if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
2535           return matchPattern(operand, matchConstantIndex());
2536         }))
2537       return failure();
2538 
2539     auto castOp = subViewOp.source().getDefiningOp<CastOp>();
2540     if (!castOp)
2541       return failure();
2542 
2543     if (!CastOp::canFoldIntoConsumerOp(castOp))
2544       return failure();
2545 
2546     // Compute the SubViewOp result type after folding the MemRefCastOp. Use
2547     // the MemRefCastOp source operand type to infer the result type and the
2548     // current SubViewOp source operand type to compute the dropped dimensions
2549     // if the operation is rank-reducing.
2550     auto resultType = getCanonicalSubViewResultType(
2551         subViewOp.getType(), subViewOp.getSourceType(),
2552         castOp.source().getType().cast<MemRefType>(),
2553         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
2554         subViewOp.getMixedStrides());
2555     if (!resultType)
2556       return failure();
2557 
2558     Value newSubView = rewriter.create<SubViewOp>(
2559         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
2560         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
2561         subViewOp.static_sizes(), subViewOp.static_strides());
2562     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
2563                                         newSubView);
2564     return success();
2565   }
2566 };
2567 
/// Canonicalize subview ops that are no-ops: if the source type and the
/// result type match exactly, replace the subview with its source; if they
/// differ only in their layout (e.g. due to an explicit `affine_map`),
/// replace it with a `memref.cast` of the source.
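///
/// Example (illustrative types):
/// ```
///   %0 = memref.subview %arg[0, 0][4, 4][1, 1]
///     : memref<4x4xf32> to memref<4x4xf32, offset: 0, strides: [4, 1]>
/// ```
/// is rewritten into:
/// ```
///   %0 = memref.cast %arg
///     : memref<4x4xf32> to memref<4x4xf32, offset: 0, strides: [4, 1]>
/// ```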
2570 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
2571 public:
2572   using OpRewritePattern<SubViewOp>::OpRewritePattern;
2573 
2574   LogicalResult matchAndRewrite(SubViewOp subViewOp,
2575                                 PatternRewriter &rewriter) const override {
2576     if (!isTrivialSubViewOp(subViewOp))
2577       return failure();
2578     if (subViewOp.getSourceType() == subViewOp.getType()) {
2579       rewriter.replaceOp(subViewOp, subViewOp.source());
2580       return success();
2581     }
2582     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
2583                                         subViewOp.source());
2584     return success();
2585   }
2586 };
2587 } // namespace
2588 
2589 /// Return the canonical type of the result of a subview.
2590 struct SubViewReturnTypeCanonicalizer {
2591   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
2592                         ArrayRef<OpFoldResult> mixedSizes,
2593                         ArrayRef<OpFoldResult> mixedStrides) {
2594     return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
2595                                          mixedOffsets, mixedSizes,
2596                                          mixedStrides);
2597   }
2598 };
2599 
2600 /// A canonicalizer wrapper to replace SubViewOps.
2601 struct SubViewCanonicalizer {
2602   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
2603     rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2604   }
2605 };
2606 
2607 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2608                                             MLIRContext *context) {
2609   results
2610       .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
2611                SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
2612            SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
2613 }
2614 
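/// A subview whose static result type is identical to its source type is a
/// no-op and folds to its source.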
2615 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
2616   auto resultShapedType = getResult().getType().cast<ShapedType>();
2617   auto sourceShapedType = source().getType().cast<ShapedType>();
2618 
2619   if (resultShapedType.hasStaticShape() &&
2620       resultShapedType == sourceShapedType) {
2621     return getViewSource();
2622   }
2623 
2624   return {};
2625 }
2626 
2627 //===----------------------------------------------------------------------===//
2628 // TransposeOp
2629 //===----------------------------------------------------------------------===//
2630 
/// Build a strided memref type by applying `permutationMap` to `memRefType`.
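///
/// For example (illustrative), applying the permutation `(d0, d1) -> (d1, d0)`
/// to `memref<3x4xf32>` (strides `[4, 1]`) yields
/// `memref<4x3xf32, affine_map<(d0, d1) -> (d0 + d1 * 4)>>`.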
2632 static MemRefType inferTransposeResultType(MemRefType memRefType,
2633                                            AffineMap permutationMap) {
2634   auto rank = memRefType.getRank();
2635   auto originalSizes = memRefType.getShape();
2636   // Compute permuted sizes.
2637   SmallVector<int64_t, 4> sizes(rank, 0);
2638   for (const auto &en : llvm::enumerate(permutationMap.getResults()))
2639     sizes[en.index()] =
2640         originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
2641 
2642   // Compute permuted strides.
2643   int64_t offset;
2644   SmallVector<int64_t, 4> strides;
2645   auto res = getStridesAndOffset(memRefType, strides, offset);
2646   assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
2647   (void)res;
2648   auto map =
2649       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
2650   map = permutationMap ? map.compose(permutationMap) : map;
2651   return MemRefType::Builder(memRefType)
2652       .setShape(sizes)
2653       .setLayout(AffineMapAttr::get(map));
2654 }
2655 
2656 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
2657                         AffineMapAttr permutation,
2658                         ArrayRef<NamedAttribute> attrs) {
2659   auto permutationMap = permutation.getValue();
2660   assert(permutationMap);
2661 
2662   auto memRefType = in.getType().cast<MemRefType>();
2663   // Compute result type.
2664   MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
2665 
2666   build(b, result, resultType, in, attrs);
2667   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
2668 }
2669 
2670 // transpose $in $permutation attr-dict : type($in) `to` type(results)
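//
// For instance (illustrative IR):
//   %1 = memref.transpose %0 (i, j) -> (j, i)
//     : memref<?x?xf32>
//       to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>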
2671 void TransposeOp::print(OpAsmPrinter &p) {
2672   p << " " << in() << " " << permutation();
2673   p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
2674   p << " : " << in().getType() << " to " << getType();
2675 }
2676 
2677 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
2678   OpAsmParser::UnresolvedOperand in;
2679   AffineMap permutation;
2680   MemRefType srcType, dstType;
2681   if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
2682       parser.parseOptionalAttrDict(result.attributes) ||
2683       parser.parseColonType(srcType) ||
2684       parser.resolveOperand(in, srcType, result.operands) ||
2685       parser.parseKeywordType("to", dstType) ||
2686       parser.addTypeToList(dstType, result.types))
2687     return failure();
2688 
2689   result.addAttribute(TransposeOp::getPermutationAttrName(),
2690                       AffineMapAttr::get(permutation));
2691   return success();
2692 }
2693 
2694 LogicalResult TransposeOp::verify() {
2695   if (!permutation().isPermutation())
2696     return emitOpError("expected a permutation map");
2697   if (permutation().getNumDims() != getShapedType().getRank())
2698     return emitOpError("expected a permutation map of same rank as the input");
2699 
2700   auto srcType = in().getType().cast<MemRefType>();
2701   auto dstType = getType().cast<MemRefType>();
2702   auto transposedType = inferTransposeResultType(srcType, permutation());
2703   if (dstType != transposedType)
2704     return emitOpError("output type ")
2705            << dstType << " does not match transposed input type " << srcType
2706            << ", " << transposedType;
2707   return success();
2708 }
2709 
2710 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
2711   if (succeeded(foldMemRefCast(*this)))
2712     return getResult();
2713   return {};
2714 }
2715 
2716 //===----------------------------------------------------------------------===//
2717 // ViewOp
2718 //===----------------------------------------------------------------------===//
2719 
2720 LogicalResult ViewOp::verify() {
2721   auto baseType = getOperand(0).getType().cast<MemRefType>();
2722   auto viewType = getType();
2723 
2724   // The base memref should have identity layout map (or none).
2725   if (!baseType.getLayout().isIdentity())
2726     return emitError("unsupported map for base memref type ") << baseType;
2727 
2728   // The result memref should have identity layout map (or none).
2729   if (!viewType.getLayout().isIdentity())
2730     return emitError("unsupported map for result memref type ") << viewType;
2731 
2732   // The base memref and the view memref should be in the same memory space.
2733   if (baseType.getMemorySpace() != viewType.getMemorySpace())
2734     return emitError("different memory spaces specified for base memref "
2735                      "type ")
2736            << baseType << " and view memref type " << viewType;
2737 
2738   // Verify that we have the correct number of sizes for the result type.
2739   unsigned numDynamicDims = viewType.getNumDynamicDims();
2740   if (sizes().size() != numDynamicDims)
2741     return emitError("incorrect number of size operands for type ") << viewType;
2742 
2743   return success();
2744 }
2745 
2746 Value ViewOp::getViewSource() { return source(); }
2747 
2748 namespace {
2749 
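/// Folds `view` size operands that are defined by constant ops into the
/// result type, and casts the new, more static view back to the original
/// result type.
///
/// Example (illustrative, assuming `%c16 = arith.constant 16 : index`):
/// ```
///   %0 = memref.view %src[%offset][%c16] : memref<2048xi8> to memref<?xf32>
/// ```
/// is rewritten into:
/// ```
///   %1 = memref.view %src[%offset][] : memref<2048xi8> to memref<16xf32>
///   %0 = memref.cast %1 : memref<16xf32> to memref<?xf32>
/// ```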
2750 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
2751   using OpRewritePattern<ViewOp>::OpRewritePattern;
2752 
2753   LogicalResult matchAndRewrite(ViewOp viewOp,
2754                                 PatternRewriter &rewriter) const override {
2755     // Return if none of the operands are constants.
2756     if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
2757           return matchPattern(operand, matchConstantIndex());
2758         }))
2759       return failure();
2760 
2761     // Get result memref type.
2762     auto memrefType = viewOp.getType();
2763 
    // Get the strides and offset from the old memref view type `memrefType`.
2765     int64_t oldOffset;
2766     SmallVector<int64_t, 4> oldStrides;
2767     if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
2768       return failure();
2769     assert(oldOffset == 0 && "Expected 0 offset");
2770 
2771     SmallVector<Value, 4> newOperands;
2772 
2773     // Offset cannot be folded into result type.
2774 
2775     // Fold any dynamic dim operands which are produced by a constant.
2776     SmallVector<int64_t, 4> newShapeConstants;
2777     newShapeConstants.reserve(memrefType.getRank());
2778 
2779     unsigned dynamicDimPos = 0;
2780     unsigned rank = memrefType.getRank();
2781     for (unsigned dim = 0, e = rank; dim < e; ++dim) {
2782       int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
2784       if (!ShapedType::isDynamic(dimSize)) {
2785         newShapeConstants.push_back(dimSize);
2786         continue;
2787       }
2788       auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
2789       if (auto constantIndexOp =
2790               dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
2791         // Dynamic shape dimension will be folded.
2792         newShapeConstants.push_back(constantIndexOp.value());
2793       } else {
2794         // Dynamic shape dimension not folded; copy operand from old memref.
2795         newShapeConstants.push_back(dimSize);
2796         newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
2797       }
2798       dynamicDimPos++;
2799     }
2800 
2801     // Create new memref type with constant folded dims.
2802     MemRefType newMemRefType =
2803         MemRefType::Builder(memrefType).setShape(newShapeConstants);
2804     // Nothing new, don't fold.
2805     if (newMemRefType == memrefType)
2806       return failure();
2807 
2808     // Create new ViewOp.
2809     auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
2810                                              viewOp.getOperand(0),
2811                                              viewOp.byte_shift(), newOperands);
2812     // Insert a cast so we have the same type as the old memref type.
2813     rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
2814     return success();
2815   }
2816 };
2817 
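/// Folds a `memref.cast` between an `alloc` and its consuming `view` by
/// rebuilding the view directly on the allocated buffer.
///
/// Example (illustrative):
/// ```
///   %alloc = memref.alloc() : memref<2048xi8>
///   %cast = memref.cast %alloc : memref<2048xi8> to memref<?xi8>
///   %view = memref.view %cast[%offset][%size] : memref<?xi8> to memref<?xf32>
/// ```
/// is rewritten so that the view uses `%alloc` directly:
/// ```
///   %view = memref.view %alloc[%offset][%size]
///     : memref<2048xi8> to memref<?xf32>
/// ```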
2818 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
2819   using OpRewritePattern<ViewOp>::OpRewritePattern;
2820 
2821   LogicalResult matchAndRewrite(ViewOp viewOp,
2822                                 PatternRewriter &rewriter) const override {
2823     Value memrefOperand = viewOp.getOperand(0);
2824     CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
2825     if (!memrefCastOp)
2826       return failure();
2827     Value allocOperand = memrefCastOp.getOperand();
2828     AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
2829     if (!allocOp)
2830       return failure();
2831     rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
2832                                         viewOp.byte_shift(), viewOp.sizes());
2833     return success();
2834   }
2835 };
2836 
2837 } // namespace
2838 
2839 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2840                                          MLIRContext *context) {
2841   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
2842 }
2843 
2844 //===----------------------------------------------------------------------===//
2845 // AtomicRMWOp
2846 //===----------------------------------------------------------------------===//
2847 
2848 LogicalResult AtomicRMWOp::verify() {
2849   if (getMemRefType().getRank() != getNumOperands() - 2)
2850     return emitOpError(
2851         "expects the number of subscripts to be equal to memref rank");
2852   switch (kind()) {
2853   case arith::AtomicRMWKind::addf:
2854   case arith::AtomicRMWKind::maxf:
2855   case arith::AtomicRMWKind::minf:
2856   case arith::AtomicRMWKind::mulf:
2857     if (!value().getType().isa<FloatType>())
2858       return emitOpError() << "with kind '"
2859                            << arith::stringifyAtomicRMWKind(kind())
2860                            << "' expects a floating-point type";
2861     break;
2862   case arith::AtomicRMWKind::addi:
2863   case arith::AtomicRMWKind::maxs:
2864   case arith::AtomicRMWKind::maxu:
2865   case arith::AtomicRMWKind::mins:
2866   case arith::AtomicRMWKind::minu:
2867   case arith::AtomicRMWKind::muli:
2868   case arith::AtomicRMWKind::ori:
2869   case arith::AtomicRMWKind::andi:
2870     if (!value().getType().isa<IntegerType>())
2871       return emitOpError() << "with kind '"
2872                            << arith::stringifyAtomicRMWKind(kind())
2873                            << "' expects an integer type";
2874     break;
2875   default:
2876     break;
2877   }
2878   return success();
2879 }
2880 
2881 OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) {
2882   /// atomicrmw(memrefcast) -> atomicrmw
2883   if (succeeded(foldMemRefCast(*this, value())))
2884     return getResult();
2885   return OpFoldResult();
2886 }
2887 
2888 //===----------------------------------------------------------------------===//
2889 // TableGen'd op method definitions
2890 //===----------------------------------------------------------------------===//
2891 
2892 #define GET_OP_CLASSES
2893 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
2894