1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
13 #include "mlir/Dialect/Utils/StaticValueUtils.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Interfaces/InferTypeOpInterface.h"
21 #include "mlir/Interfaces/SideEffectInterfaces.h"
22 #include "mlir/Interfaces/ViewLikeInterface.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 
26 using namespace mlir;
27 using namespace mlir::memref;
28 
29 namespace {
30 /// Idiomatic saturated operations on offsets, sizes and strides.
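/// For example (illustrative only): Wrapper::stride(4) * Wrapper::stride(2)
/// yields a static stride of 8, while multiplying by
/// Wrapper::stride(ShapedType::kDynamicStrideOrOffset) saturates, so
/// asStride() returns ShapedType::kDynamicStrideOrOffset.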
31 namespace saturated_arith {
32 struct Wrapper {
33   static Wrapper stride(int64_t v) {
34     return (ShapedType::isDynamicStrideOrOffset(v)) ? Wrapper{true, 0}
35                                                     : Wrapper{false, v};
36   }
37   static Wrapper offset(int64_t v) {
38     return (ShapedType::isDynamicStrideOrOffset(v)) ? Wrapper{true, 0}
39                                                     : Wrapper{false, v};
40   }
41   static Wrapper size(int64_t v) {
42     return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
43   }
44   int64_t asOffset() {
45     return saturated ? ShapedType::kDynamicStrideOrOffset : v;
46   }
47   int64_t asSize() { return saturated ? ShapedType::kDynamicSize : v; }
48   int64_t asStride() {
49     return saturated ? ShapedType::kDynamicStrideOrOffset : v;
50   }
51   bool operator==(Wrapper other) {
52     return (saturated && other.saturated) ||
53            (!saturated && !other.saturated && v == other.v);
54   }
55   bool operator!=(Wrapper other) { return !(*this == other); }
56   Wrapper operator+(Wrapper other) {
57     if (saturated || other.saturated)
58       return Wrapper{true, 0};
59     return Wrapper{false, other.v + v};
60   }
61   Wrapper operator*(Wrapper other) {
62     if (saturated || other.saturated)
63       return Wrapper{true, 0};
64     return Wrapper{false, other.v * v};
65   }
66   bool saturated;
67   int64_t v;
68 };
69 } // namespace saturated_arith
70 } // namespace
71 
72 /// Materialize a single constant operation from a given attribute value with
73 /// the desired resultant type.
74 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
75                                               Attribute value, Type type,
76                                               Location loc) {
77   if (arith::ConstantOp::isBuildableWith(value, type))
78     return builder.create<arith::ConstantOp>(loc, value, type);
79   return nullptr;
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // Common canonicalization pattern support logic
84 //===----------------------------------------------------------------------===//
85 
86 /// This is a common class used for patterns of the form
87 /// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
88 /// into the root operation directly.
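/// For example (illustrative IR only):
///   %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
///   memref.dealloc %0 : memref<?xf32>
/// folds to:
///   memref.dealloc %arg : memref<8xf32>
/// Casts from unranked memrefs are deliberately not folded.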
89 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
90   bool folded = false;
91   for (OpOperand &operand : op->getOpOperands()) {
92     auto cast = operand.get().getDefiningOp<CastOp>();
93     if (cast && operand.get() != inner &&
94         !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
95       operand.set(cast.getOperand());
96       folded = true;
97     }
98   }
99   return success(folded);
100 }
101 
102 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
103 /// type.
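/// For example, memref<4x?xf32> maps to tensor<4x?xf32>, memref<*xf32> maps
/// to tensor<*xf32>, and any non-memref type maps to NoneType.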
104 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
105   if (auto memref = type.dyn_cast<MemRefType>())
106     return RankedTensorType::get(memref.getShape(), memref.getElementType());
107   if (auto memref = type.dyn_cast<UnrankedMemRefType>())
108     return UnrankedTensorType::get(memref.getElementType());
109   return NoneType::get(type.getContext());
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // AllocOp / AllocaOp
114 //===----------------------------------------------------------------------===//
115 
116 template <typename AllocLikeOp>
117 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
118   static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
119                 "applies to only alloc or alloca");
120   auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
121   if (!memRefType)
122     return op.emitOpError("result must be a memref");
123 
124   if (static_cast<int64_t>(op.getDynamicSizes().size()) !=
125       memRefType.getNumDynamicDims())
126     return op.emitOpError("dimension operand count does not equal memref "
127                           "dynamic dimension count");
128 
129   unsigned numSymbols = 0;
130   if (!memRefType.getLayout().isIdentity())
131     numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
132   if (op.getSymbolOperands().size() != numSymbols)
133     return op.emitOpError("symbol operand count does not equal memref symbol "
134                           "count: expected ")
135            << numSymbols << ", got " << op.getSymbolOperands().size();
136 
137   return success();
138 }
139 
140 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
141 
142 LogicalResult AllocaOp::verify() {
143   // An alloca op needs to have an ancestor with an allocation scope trait.
144   if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
145     return emitOpError(
146         "requires an ancestor op with AutomaticAllocationScope trait");
147 
148   return verifyAllocLikeOp(*this);
149 }
150 
151 namespace {
152 /// Fold constant dimensions into an alloc like operation.
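/// For example (illustrative only):
///   %c8 = arith.constant 8 : index
///   %0 = memref.alloc(%c8, %d) : memref<?x?xf32>
/// becomes an alloc with the constant folded into the type, followed by a
/// cast back to the original type:
///   %1 = memref.alloc(%d) : memref<8x?xf32>
///   %0 = memref.cast %1 : memref<8x?xf32> to memref<?x?xf32>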
153 template <typename AllocLikeOp>
154 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
155   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
156 
157   LogicalResult matchAndRewrite(AllocLikeOp alloc,
158                                 PatternRewriter &rewriter) const override {
159     // Check to see if any dimension operands are constants.  If so, we can
160     // substitute and drop them.
161     if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
162           return matchPattern(operand, matchConstantIndex());
163         }))
164       return failure();
165 
166     auto memrefType = alloc.getType();
167 
168     // Ok, we have one or more constant operands.  Collect the non-constant ones
169     // and keep track of the resultant memref type to build.
170     SmallVector<int64_t, 4> newShapeConstants;
171     newShapeConstants.reserve(memrefType.getRank());
172     SmallVector<Value, 4> dynamicSizes;
173 
174     unsigned dynamicDimPos = 0;
175     for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
176       int64_t dimSize = memrefType.getDimSize(dim);
177       // If this is already a static dimension, keep it.
178       if (dimSize != -1) {
179         newShapeConstants.push_back(dimSize);
180         continue;
181       }
182       auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
183       auto *defOp = dynamicSize.getDefiningOp();
184       if (auto constantIndexOp =
185               dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
186         // Dynamic shape dimension will be folded.
187         newShapeConstants.push_back(constantIndexOp.value());
188       } else {
189         // Dynamic shape dimension not folded; copy dynamicSize from old memref.
190         newShapeConstants.push_back(-1);
191         dynamicSizes.push_back(dynamicSize);
192       }
193       dynamicDimPos++;
194     }
195 
196     // Create new memref type (which will have fewer dynamic dimensions).
197     MemRefType newMemRefType =
198         MemRefType::Builder(memrefType).setShape(newShapeConstants);
199     assert(static_cast<int64_t>(dynamicSizes.size()) ==
200            newMemRefType.getNumDynamicDims());
201 
202     // Create and insert the alloc op for the new memref.
203     auto newAlloc = rewriter.create<AllocLikeOp>(
204         alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
205         alloc.getAlignmentAttr());
206     // Insert a cast so we have the same type as the old alloc.
207     auto resultCast =
208         rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc);
209 
210     rewriter.replaceOp(alloc, {resultCast});
211     return success();
212   }
213 };
214 
215 /// Fold alloc operations with no users or only store and dealloc uses.
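/// For example (illustrative only), an alloc whose only uses are stores into
/// it and a dealloc:
///   %0 = memref.alloc() : memref<4xf32>
///   memref.store %v, %0[%i] : memref<4xf32>
///   memref.dealloc %0 : memref<4xf32>
/// is erased together with those uses. A store that uses the allocated value
/// as the stored value (rather than as the destination) keeps the alloc
/// alive.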
216 template <typename T>
217 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
218   using OpRewritePattern<T>::OpRewritePattern;
219 
220   LogicalResult matchAndRewrite(T alloc,
221                                 PatternRewriter &rewriter) const override {
222     if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
223           if (auto storeOp = dyn_cast<StoreOp>(op))
224             return storeOp.getValue() == alloc;
225           return !isa<DeallocOp>(op);
226         }))
227       return failure();
228 
229     for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
230       rewriter.eraseOp(user);
231 
232     rewriter.eraseOp(alloc);
233     return success();
234   }
235 };
236 } // namespace
237 
238 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
239                                           MLIRContext *context) {
240   results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
241 }
242 
243 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
244                                            MLIRContext *context) {
245   results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
246       context);
247 }
248 
249 //===----------------------------------------------------------------------===//
250 // AllocaScopeOp
251 //===----------------------------------------------------------------------===//
252 
253 void AllocaScopeOp::print(OpAsmPrinter &p) {
254   bool printBlockTerminators = false;
255 
256   p << ' ';
257   if (!getResults().empty()) {
258     p << " -> (" << getResultTypes() << ")";
259     printBlockTerminators = true;
260   }
261   p << ' ';
262   p.printRegion(getBodyRegion(),
263                 /*printEntryBlockArgs=*/false,
264                 /*printBlockTerminators=*/printBlockTerminators);
265   p.printOptionalAttrDict((*this)->getAttrs());
266 }
267 
268 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
269   // Create a region for the body.
270   result.regions.reserve(1);
271   Region *bodyRegion = result.addRegion();
272 
273   // Parse optional results type list.
274   if (parser.parseOptionalArrowTypeList(result.types))
275     return failure();
276 
277   // Parse the body region.
278   if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
279     return failure();
280   AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
281                                   result.location);
282 
283   // Parse the optional attribute list.
284   if (parser.parseOptionalAttrDict(result.attributes))
285     return failure();
286 
287   return success();
288 }
289 
290 void AllocaScopeOp::getSuccessorRegions(
291     Optional<unsigned> index, ArrayRef<Attribute> operands,
292     SmallVectorImpl<RegionSuccessor> &regions) {
293   if (index) {
294     regions.push_back(RegionSuccessor(getResults()));
295     return;
296   }
297 
298   regions.push_back(RegionSuccessor(&getBodyRegion()));
299 }
300 
301 /// Given an operation, return whether this op is guaranteed to
302 /// allocate an AutomaticAllocationScopeResource.
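/// For example, memref.alloca is guaranteed to allocate such a resource; ops
/// that do not implement MemoryEffectOpInterface are conservatively treated
/// as not guaranteed.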
303 static bool isGuaranteedAutomaticAllocation(Operation *op) {
304   MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
305   if (!interface)
306     return false;
307   for (auto res : op->getResults()) {
308     if (auto effect =
309             interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
310       if (isa<SideEffects::AutomaticAllocationScopeResource>(
311               effect->getResource()))
312         return true;
313     }
314   }
315   return false;
316 }
317 
318 /// Given an operation, return whether this op itself could
319 /// allocate an AutomaticAllocationScopeResource. Note that
320 /// this will not check whether an operation contained within
321 /// the op can allocate.
322 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
323   // This op itself doesn't create a stack allocation,
324   // the inner allocation should be handled separately.
325   if (op->hasTrait<OpTrait::HasRecursiveSideEffects>())
326     return false;
327   MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
328   if (!interface)
329     return true;
330   for (auto res : op->getResults()) {
331     if (auto effect =
332             interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
333       if (isa<SideEffects::AutomaticAllocationScopeResource>(
334               effect->getResource()))
335         return true;
336     }
337   }
338   return false;
339 }
340 
341 /// Return whether this op is the last non terminating op
342 /// in a region. That is to say, it is in a one-block region
343 /// and is only followed by a terminator. This prevents
344 /// extending the lifetime of allocations.
345 static bool lastNonTerminatorInRegion(Operation *op) {
346   return op->getNextNode() == op->getBlock()->getTerminator() &&
347          op->getParentRegion()->getBlocks().size() == 1;
348 }
349 
350 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
351 /// or it contains no allocation.
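/// For example (illustrative only), a scope with no stack allocation:
///   %r = memref.alloca_scope -> (index) {
///     %c0 = arith.constant 0 : index
///     memref.alloca_scope.return %c0 : index
///   }
/// is inlined into the parent block, with %r replaced by %c0.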
352 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
353   using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
354 
355   LogicalResult matchAndRewrite(AllocaScopeOp op,
356                                 PatternRewriter &rewriter) const override {
357     bool hasPotentialAlloca =
358         op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
359             if (alloc == op)
360               return WalkResult::advance();
361             if (isOpItselfPotentialAutomaticAllocation(alloc))
362               return WalkResult::interrupt();
363             if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
364               return WalkResult::skip();
365             return WalkResult::advance();
366           }).wasInterrupted();
367 
368     // If this contains no potential allocation, it is always legal to
369     // inline. Otherwise, consider two conditions:
370     if (hasPotentialAlloca) {
371       // If the parent isn't an allocation scope, or we are not the last
372       // non-terminator op in the parent, we will extend the lifetime.
373       if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
374         return failure();
375       if (!lastNonTerminatorInRegion(op))
376         return failure();
377     }
378 
379     Block *block = &op.getRegion().front();
380     Operation *terminator = block->getTerminator();
381     ValueRange results = terminator->getOperands();
382     rewriter.mergeBlockBefore(block, op);
383     rewriter.replaceOp(op, results);
384     rewriter.eraseOp(terminator);
385     return success();
386   }
387 };
388 
389 /// Move allocations into an allocation scope, if it is legal to
390 /// move them (e.g. their operands are available at the location
391 /// the op would be moved to).
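/// Roughly: guaranteed stack allocations inside the scope whose operands are
/// all available above the enclosing scope-less ancestor are cloned to just
/// before that ancestor, which itself sits directly under an op with the
/// AutomaticAllocationScope trait.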
392 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
393   using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
394 
395   LogicalResult matchAndRewrite(AllocaScopeOp op,
396                                 PatternRewriter &rewriter) const override {
397 
398     if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
399       return failure();
400 
401     Operation *lastParentWithoutScope = op->getParentOp();
402 
403     if (!lastParentWithoutScope ||
404         lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
405       return failure();
406 
407     // Only apply if this is the last non-terminator op in the block of a
408     // one-block region, lest the lifetime of the allocations be extended
409     // beyond the scope.
410     if (!lastNonTerminatorInRegion(op) ||
411         !lastNonTerminatorInRegion(lastParentWithoutScope))
412       return failure();
413 
414     while (!lastParentWithoutScope->getParentOp()
415                 ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
416       lastParentWithoutScope = lastParentWithoutScope->getParentOp();
417       if (!lastParentWithoutScope ||
418           !lastNonTerminatorInRegion(lastParentWithoutScope))
419         return failure();
420     }
421     assert(lastParentWithoutScope->getParentOp()
422                ->hasTrait<OpTrait::AutomaticAllocationScope>());
423 
424     Region *containingRegion = nullptr;
425     for (auto &r : lastParentWithoutScope->getRegions()) {
426       if (r.isAncestor(op->getParentRegion())) {
427         assert(containingRegion == nullptr &&
428                "only one region can contain the op");
429         containingRegion = &r;
430       }
431     }
432     assert(containingRegion && "op must be contained in a region");
433 
434     SmallVector<Operation *> toHoist;
435     op->walk([&](Operation *alloc) {
436       if (!isGuaranteedAutomaticAllocation(alloc))
437         return WalkResult::skip();
438 
439       // If any operand is not defined before the location of
440       // lastParentWithoutScope (i.e. where we would hoist to), skip.
441       if (llvm::any_of(alloc->getOperands(), [&](Value v) {
442             return containingRegion->isAncestor(v.getParentRegion());
443           }))
444         return WalkResult::skip();
445       toHoist.push_back(alloc);
446       return WalkResult::advance();
447     });
448 
449     if (toHoist.empty())
450       return failure();
451     rewriter.setInsertionPoint(lastParentWithoutScope);
452     for (auto *op : toHoist) {
453       auto *cloned = rewriter.clone(*op);
454       rewriter.replaceOp(op, cloned->getResults());
455     }
456     return success();
457   }
458 };
459 
460 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
461                                                 MLIRContext *context) {
462   results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
463 }
464 
465 //===----------------------------------------------------------------------===//
466 // AssumeAlignmentOp
467 //===----------------------------------------------------------------------===//
468 
469 LogicalResult AssumeAlignmentOp::verify() {
470   if (!llvm::isPowerOf2_32(getAlignment()))
471     return emitOpError("alignment must be power of 2");
472   return success();
473 }
474 
475 //===----------------------------------------------------------------------===//
476 // CastOp
477 //===----------------------------------------------------------------------===//
478 
479 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
480 /// source memref. This is useful to fold a memref.cast into a consuming op
481 /// and implement canonicalization patterns for ops in different dialects that
482 /// may consume the results of memref.cast operations. Such foldable memref.cast
483 /// operations are typically inserted as `view` and `subview` ops are
484 /// canonicalized, to preserve the type compatibility of their uses.
485 ///
486 /// Returns true when all conditions are met:
487 /// 1. source and result are ranked memrefs with strided semantics and same
488 /// element type and rank.
489 /// 2. each of the source's size, offset or stride has more static information
490 /// than the corresponding result's size, offset or stride.
491 ///
492 /// Example 1:
493 /// ```mlir
494 ///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
495 ///   %2 = consumer %1 ... : memref<?x?xf32> ...
496 /// ```
497 ///
498 /// may fold into:
499 ///
500 /// ```mlir
501 ///   %2 = consumer %0 ... : memref<8x16xf32> ...
502 /// ```
503 ///
504 /// Example 2:
505 /// ```
506 ///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
507 ///          to memref<?x?xf32>
508 ///   consumer %1 : memref<?x?xf32> ...
509 /// ```
510 ///
511 /// may fold into:
512 ///
513 /// ```
514 ///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
515 /// ```
516 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
517   MemRefType sourceType = castOp.getSource().getType().dyn_cast<MemRefType>();
518   MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
519 
520   // Requires ranked MemRefType.
521   if (!sourceType || !resultType)
522     return false;
523 
524   // Requires same elemental type.
525   if (sourceType.getElementType() != resultType.getElementType())
526     return false;
527 
528   // Requires same rank.
529   if (sourceType.getRank() != resultType.getRank())
530     return false;
531 
532   // Only fold casts between strided memref forms.
533   int64_t sourceOffset, resultOffset;
534   SmallVector<int64_t, 4> sourceStrides, resultStrides;
535   if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
536       failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
537     return false;
538 
539   // If cast is towards more static sizes along any dimension, don't fold.
540   for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
541     auto ss = std::get<0>(it), st = std::get<1>(it);
542     if (ss != st)
543       if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
544         return false;
545   }
546 
547   // If cast is towards a more static offset, don't fold.
548   if (sourceOffset != resultOffset)
549     if (ShapedType::isDynamicStrideOrOffset(sourceOffset) &&
550         !ShapedType::isDynamicStrideOrOffset(resultOffset))
551       return false;
552 
553   // If cast is towards more static strides along any dimension, don't fold.
554   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
555     auto ss = std::get<0>(it), st = std::get<1>(it);
556     if (ss != st)
557       if (ShapedType::isDynamicStrideOrOffset(ss) &&
558           !ShapedType::isDynamicStrideOrOffset(st))
559         return false;
560   }
561 
562   return true;
563 }
564 
565 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
566   if (inputs.size() != 1 || outputs.size() != 1)
567     return false;
568   Type a = inputs.front(), b = outputs.front();
569   auto aT = a.dyn_cast<MemRefType>();
570   auto bT = b.dyn_cast<MemRefType>();
571 
572   auto uaT = a.dyn_cast<UnrankedMemRefType>();
573   auto ubT = b.dyn_cast<UnrankedMemRefType>();
574 
575   if (aT && bT) {
576     if (aT.getElementType() != bT.getElementType())
577       return false;
578     if (aT.getLayout() != bT.getLayout()) {
579       int64_t aOffset, bOffset;
580       SmallVector<int64_t, 4> aStrides, bStrides;
581       if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
582           failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
583           aStrides.size() != bStrides.size())
584         return false;
585 
586       // Strides along a dimension/offset are compatible if the value in the
587       // source memref is static and the value in the target memref is the
588       // same. They are also compatible if either one is dynamic (see
589       // description of MemRefCastOp for details).
590       auto checkCompatible = [](int64_t a, int64_t b) {
591         return (a == MemRefType::getDynamicStrideOrOffset() ||
592                 b == MemRefType::getDynamicStrideOrOffset() || a == b);
593       };
594       if (!checkCompatible(aOffset, bOffset))
595         return false;
596       for (const auto &aStride : enumerate(aStrides))
597         if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
598           return false;
599     }
600     if (aT.getMemorySpace() != bT.getMemorySpace())
601       return false;
602 
603     // They must have the same rank, and any specified dimensions must match.
604     if (aT.getRank() != bT.getRank())
605       return false;
606 
607     for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
608       int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
609       if (aDim != -1 && bDim != -1 && aDim != bDim)
610         return false;
611     }
612     return true;
613   } else {
614     if (!aT && !uaT)
615       return false;
616     if (!bT && !ubT)
617       return false;
618     // Unranked to unranked casting is unsupported
619     if (uaT && ubT)
620       return false;
621 
622     auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
623     auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
624     if (aEltType != bEltType)
625       return false;
626 
627     auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
628     auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
629     return aMemSpace == bMemSpace;
630   }
631 
632   return false;
633 }
634 
635 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
636   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
637 }
638 
639 //===----------------------------------------------------------------------===//
640 // CopyOp
641 //===----------------------------------------------------------------------===//
642 
643 namespace {
644 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
645 /// and element type, the cast can be skipped. Such CastOps only cast the layout
646 /// of the type.
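/// For example (illustrative only), with #layout a compatible non-identity
/// strided layout:
///   %0 = memref.cast %arg : memref<8xf32> to memref<8xf32, #layout>
///   memref.copy %0, %out : memref<8xf32, #layout> to memref<8xf32>
/// is rewritten to copy directly from %arg:
///   memref.copy %arg, %out : memref<8xf32> to memref<8xf32>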
647 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
648   using OpRewritePattern<CopyOp>::OpRewritePattern;
649 
650   LogicalResult matchAndRewrite(CopyOp copyOp,
651                                 PatternRewriter &rewriter) const override {
652     bool modified = false;
653 
654     // Check source.
655     if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
656       auto fromType = castOp.getSource().getType().dyn_cast<MemRefType>();
657       auto toType = castOp.getType().dyn_cast<MemRefType>();
658 
659       if (fromType && toType) {
660         if (fromType.getShape() == toType.getShape() &&
661             fromType.getElementType() == toType.getElementType()) {
662           rewriter.updateRootInPlace(copyOp, [&] {
663             copyOp.getSourceMutable().assign(castOp.getSource());
664           });
665           modified = true;
666         }
667       }
668     }
669 
670     // Check target.
671     if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
672       auto fromType = castOp.getSource().getType().dyn_cast<MemRefType>();
673       auto toType = castOp.getType().dyn_cast<MemRefType>();
674 
675       if (fromType && toType) {
676         if (fromType.getShape() == toType.getShape() &&
677             fromType.getElementType() == toType.getElementType()) {
678           rewriter.updateRootInPlace(copyOp, [&] {
679             copyOp.getTargetMutable().assign(castOp.getSource());
680           });
681           modified = true;
682         }
683       }
684     }
685 
686     return success(modified);
687   }
688 };
689 
690 /// Fold memref.copy(%x, %x).
691 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
692   using OpRewritePattern<CopyOp>::OpRewritePattern;
693 
694   LogicalResult matchAndRewrite(CopyOp copyOp,
695                                 PatternRewriter &rewriter) const override {
696     if (copyOp.getSource() != copyOp.getTarget())
697       return failure();
698 
699     rewriter.eraseOp(copyOp);
700     return success();
701   }
702 };
703 } // namespace
704 
705 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
706                                          MLIRContext *context) {
707   results.add<FoldCopyOfCast, FoldSelfCopy>(context);
708 }
709 
710 LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands,
711                            SmallVectorImpl<OpFoldResult> &results) {
712   /// copy(memrefcast) -> copy
713   bool folded = false;
714   Operation *op = *this;
715   for (OpOperand &operand : op->getOpOperands()) {
716     auto castOp = operand.get().getDefiningOp<memref::CastOp>();
717     if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
718       operand.set(castOp.getOperand());
719       folded = true;
720     }
721   }
722   return success(folded);
723 }
724 
725 //===----------------------------------------------------------------------===//
726 // DeallocOp
727 //===----------------------------------------------------------------------===//
728 
729 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
730                               SmallVectorImpl<OpFoldResult> &results) {
731   /// dealloc(memrefcast) -> dealloc
732   return foldMemRefCast(*this);
733 }
734 
735 //===----------------------------------------------------------------------===//
736 // DimOp
737 //===----------------------------------------------------------------------===//
738 
739 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
740                   int64_t index) {
741   auto loc = result.location;
742   Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
743   build(builder, result, source, indexValue);
744 }
745 
746 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
747                   Value index) {
748   auto indexTy = builder.getIndexType();
749   build(builder, result, indexTy, source, index);
750 }
751 
752 Optional<int64_t> DimOp::getConstantIndex() {
753   if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
754     return constantOp.getValue().cast<IntegerAttr>().getInt();
755   return {};
756 }
757 
758 LogicalResult DimOp::verify() {
759   // Assume unknown index to be in range.
760   Optional<int64_t> index = getConstantIndex();
761   if (!index)
762     return success();
763 
764   // Check that constant index is not knowingly out of range.
765   auto type = getSource().getType();
766   if (auto memrefType = type.dyn_cast<MemRefType>()) {
767     if (*index >= memrefType.getRank())
768       return emitOpError("index is out of range");
769   } else if (type.isa<UnrankedMemRefType>()) {
770     // Assume index to be in range.
771   } else {
772     llvm_unreachable("expected operand with memref type");
773   }
774   return success();
775 }
776 
777 /// Return a map from each element in `vals` to its number of occurrences.
778 /// Use std::map, since the `vals` here are strides and the
779 /// dynamic stride value is the same as the tombstone value for
780 /// `DenseMap<int64_t>`.
781 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
782   std::map<int64_t, unsigned> numOccurences;
783   for (auto val : vals)
784     numOccurences[val]++;
785   return numOccurences;
786 }
787 
788 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed
789 /// to be a subset of `originalType` with some `1` entries erased, return the
790 /// set of indices that specifies which of the entries of `originalShape` are
791 /// dropped to obtain `reducedShape`.
792 /// This accounts for cases where there are multiple unit-dims, but only a
793 /// subset of those are dropped. For MemRefTypes these can be disambiguated
794 /// using the strides. If a dimension is dropped the stride must be dropped too.
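/// For example (illustrative only), reducing memref<1x4x1xf32> (strides
/// [4, 1, 1]) to memref<4x1xf32> (strides [1, 1]) with sizes [1, 4, 1] drops
/// only dimension 0: the unit dimension at index 2 is kept because its
/// stride still appears in the reduced type.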
795 static llvm::Optional<llvm::SmallBitVector>
796 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
797                                ArrayRef<OpFoldResult> sizes) {
798   llvm::SmallBitVector unusedDims(originalType.getRank());
799   if (originalType.getRank() == reducedType.getRank())
800     return unusedDims;
801 
802   for (const auto &dim : llvm::enumerate(sizes))
803     if (auto attr = dim.value().dyn_cast<Attribute>())
804       if (attr.cast<IntegerAttr>().getInt() == 1)
805         unusedDims.set(dim.index());
806 
807   // Early exit for the case where the number of unused dims matches the
808   // rank reduction.
809   if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
810       originalType.getRank())
811     return unusedDims;
812 
813   SmallVector<int64_t> originalStrides, candidateStrides;
814   int64_t originalOffset, candidateOffset;
815   if (failed(
816           getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
817       failed(
818           getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
819     return llvm::None;
820 
821   // For memrefs, a dimension is truly dropped if its corresponding stride is
822   // also dropped. This is particularly important when more than one of the
823   // dims is 1. Track the number of occurrences of the strides in the original
824   // type and the candidate type. For each unused dim, that stride should not
825   // be present in the candidate type. Note that there could be multiple
826   // dimensions that have the same size. We don't need to figure out exactly
827   // which dim corresponds to which stride; we just need to verify that the
828   // number of repetitions of a stride in the original + unused dims with that
829   // stride == the number of repetitions of that stride in the candidate.
830   std::map<int64_t, unsigned> currUnaccountedStrides =
831       getNumOccurences(originalStrides);
832   std::map<int64_t, unsigned> candidateStridesNumOccurences =
833       getNumOccurences(candidateStrides);
834   for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
835     if (!unusedDims.test(dim))
836       continue;
837     int64_t originalStride = originalStrides[dim];
838     if (currUnaccountedStrides[originalStride] >
839         candidateStridesNumOccurences[originalStride]) {
840       // This dim can be treated as dropped.
841       currUnaccountedStrides[originalStride]--;
842       continue;
843     }
844     if (currUnaccountedStrides[originalStride] ==
845         candidateStridesNumOccurences[originalStride]) {
846       // The stride for this is not dropped. Keep as is.
847       unusedDims.reset(dim);
848       continue;
849     }
850     if (currUnaccountedStrides[originalStride] <
851         candidateStridesNumOccurences[originalStride]) {
852       // This should never happen. We can't have a stride in the reduced rank
853       // type that wasn't in the original one.
854       return llvm::None;
855     }
856   }
857 
858   if ((int64_t)unusedDims.count() + reducedType.getRank() !=
859       originalType.getRank())
860     return llvm::None;
861   return unusedDims;
862 }
863 
864 llvm::SmallBitVector SubViewOp::getDroppedDims() {
865   MemRefType sourceType = getSourceType();
866   MemRefType resultType = getType();
867   llvm::Optional<llvm::SmallBitVector> unusedDims =
868       computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
869   assert(unusedDims && "unable to find unused dims of subview");
870   return *unusedDims;
871 }
872 
873 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
874   // All forms of folding require a known index.
875   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
876   if (!index)
877     return {};
878 
879   // Folding for unranked types (UnrankedMemRefType) is not supported.
880   auto memrefType = getSource().getType().dyn_cast<MemRefType>();
881   if (!memrefType)
882     return {};
883 
884   // Fold if the shape extent along the given index is known.
885   if (!memrefType.isDynamicDim(index.getInt())) {
886     Builder builder(getContext());
887     return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
888   }
889 
890   // The size at the given index is now known to be a dynamic size.
891   unsigned unsignedIndex = index.getValue().getZExtValue();
892 
893   // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
894   Operation *definingOp = getSource().getDefiningOp();
895 
896   if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
897     return *(alloc.getDynamicSizes().begin() +
898              memrefType.getDynamicDimIndex(unsignedIndex));
899 
900   if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
901     return *(alloca.getDynamicSizes().begin() +
902              memrefType.getDynamicDimIndex(unsignedIndex));
903 
904   if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
905     return *(view.getDynamicSizes().begin() +
906              memrefType.getDynamicDimIndex(unsignedIndex));
907 
908   if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
909     llvm::SmallBitVector unusedDims = subview.getDroppedDims();
910     unsigned resultIndex = 0;
911     unsigned sourceRank = subview.getSourceType().getRank();
912     unsigned sourceIndex = 0;
913     for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
914       if (unusedDims.test(i))
915         continue;
916       if (resultIndex == unsignedIndex) {
917         sourceIndex = i;
918         break;
919       }
920       resultIndex++;
921     }
922     assert(subview.isDynamicSize(sourceIndex) &&
923            "expected dynamic subview size");
924     return subview.getDynamicSize(sourceIndex);
925   }
926 
927   if (auto sizeInterface =
928           dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
929     assert(sizeInterface.isDynamicSize(unsignedIndex) &&
930            "Expected dynamic subview size");
931     return sizeInterface.getDynamicSize(unsignedIndex);
932   }
933 
934   // dim(memrefcast) -> dim
935   if (succeeded(foldMemRefCast(*this)))
936     return getResult();
937 
938   return {};
939 }
940 
941 namespace {
942 /// Fold dim of a memref reshape operation to a load into the reshape's shape
943 /// operand.
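/// For example (illustrative only):
///   %reshaped = memref.reshape %src(%shape)
///       : (memref<8xf32>, memref<2xindex>) -> memref<?x?xf32>
///   %d = memref.dim %reshaped, %c1 : memref<?x?xf32>
/// is rewritten to a load from the shape operand:
///   %d = memref.load %shape[%c1] : memref<2xindex>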
944 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
945   using OpRewritePattern<DimOp>::OpRewritePattern;
946 
947   LogicalResult matchAndRewrite(DimOp dim,
948                                 PatternRewriter &rewriter) const override {
949     auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
950 
951     if (!reshape)
952       return failure();
953 
954     // Place the load directly after the reshape to ensure that the shape memref
955     // was not mutated.
956     rewriter.setInsertionPointAfter(reshape);
957     Location loc = dim.getLoc();
958     Value load =
959         rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
960     if (load.getType() != dim.getType())
961       load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
962     rewriter.replaceOp(dim, load);
963     return success();
964   }
965 };
966 
967 } // namespace
968 
969 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
970                                         MLIRContext *context) {
971   results.add<DimOfMemRefReshape>(context);
972 }
973 
974 // ---------------------------------------------------------------------------
975 // DmaStartOp
976 // ---------------------------------------------------------------------------
977 
978 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
979                        Value srcMemRef, ValueRange srcIndices, Value destMemRef,
980                        ValueRange destIndices, Value numElements,
981                        Value tagMemRef, ValueRange tagIndices, Value stride,
982                        Value elementsPerStride) {
983   result.addOperands(srcMemRef);
984   result.addOperands(srcIndices);
985   result.addOperands(destMemRef);
986   result.addOperands(destIndices);
987   result.addOperands({numElements, tagMemRef});
988   result.addOperands(tagIndices);
989   if (stride)
990     result.addOperands({stride, elementsPerStride});
991 }
992 
993 void DmaStartOp::print(OpAsmPrinter &p) {
994   p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
995     << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
996     << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
997   if (isStrided())
998     p << ", " << getStride() << ", " << getNumElementsPerStride();
999 
1000   p.printOptionalAttrDict((*this)->getAttrs());
1001   p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1002     << ", " << getTagMemRef().getType();
1003 }
1004 
1005 // Parse DmaStartOp.
1006 // Ex:
1007 //   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1008 //                       %tag[%index], %stride, %num_elt_per_stride :
1009 //                     : memref<3076 x f32, 0>,
1010 //                       memref<1024 x f32, 2>,
1011 //                       memref<1 x i32>
1012 //
1013 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1014   OpAsmParser::UnresolvedOperand srcMemRefInfo;
1015   SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1016   OpAsmParser::UnresolvedOperand dstMemRefInfo;
1017   SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1018   OpAsmParser::UnresolvedOperand numElementsInfo;
1019   OpAsmParser::UnresolvedOperand tagMemrefInfo;
1020   SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1021   SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1022 
1023   SmallVector<Type, 3> types;
1024   auto indexType = parser.getBuilder().getIndexType();
1025 
1026   // Parse and resolve the following list of operands:
1027   // *) source memref followed by its indices (in square brackets).
1028   // *) destination memref followed by its indices (in square brackets).
1029   // *) number of elements being transferred.
1030   if (parser.parseOperand(srcMemRefInfo) ||
1031       parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1032       parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1033       parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1034       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1035       parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1036       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1037     return failure();
1038 
1039   // Parse optional stride and elements per stride.
1040   if (parser.parseTrailingOperandList(strideInfo))
1041     return failure();
1042 
1043   bool isStrided = strideInfo.size() == 2;
1044   if (!strideInfo.empty() && !isStrided) {
1045     return parser.emitError(parser.getNameLoc(),
1046                             "expected two stride related operands");
1047   }
1048 
1049   if (parser.parseColonTypeList(types))
1050     return failure();
1051   if (types.size() != 3)
1052     return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1053 
1054   if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1055       parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1056       parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1057       parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1058       // size should be an index.
1059       parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1060       parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1061       // tag indices should be index.
1062       parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1063     return failure();
1064 
1065   if (isStrided) {
1066     if (parser.resolveOperands(strideInfo, indexType, result.operands))
1067       return failure();
1068   }
1069 
1070   return success();
1071 }
1072 
1073 LogicalResult DmaStartOp::verify() {
1074   unsigned numOperands = getNumOperands();
1075 
1076   // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1077   // the number of elements.
1078   if (numOperands < 4)
1079     return emitOpError("expected at least 4 operands");
1080 
1081   // Check types of operands. The order of these calls is important: the later
1082   // calls rely on some type properties to compute the operand position.
1083   // 1. Source memref.
1084   if (!getSrcMemRef().getType().isa<MemRefType>())
1085     return emitOpError("expected source to be of memref type");
1086   if (numOperands < getSrcMemRefRank() + 4)
1087     return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1088                          << " operands";
1089   if (!getSrcIndices().empty() &&
1090       !llvm::all_of(getSrcIndices().getTypes(),
1091                     [](Type t) { return t.isIndex(); }))
1092     return emitOpError("expected source indices to be of index type");
1093 
1094   // 2. Destination memref.
1095   if (!getDstMemRef().getType().isa<MemRefType>())
1096     return emitOpError("expected destination to be of memref type");
1097   unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1098   if (numOperands < numExpectedOperands)
1099     return emitOpError() << "expected at least " << numExpectedOperands
1100                          << " operands";
1101   if (!getDstIndices().empty() &&
1102       !llvm::all_of(getDstIndices().getTypes(),
1103                     [](Type t) { return t.isIndex(); }))
1104     return emitOpError("expected destination indices to be of index type");
1105 
1106   // 3. Number of elements.
1107   if (!getNumElements().getType().isIndex())
1108     return emitOpError("expected num elements to be of index type");
1109 
1110   // 4. Tag memref.
1111   if (!getTagMemRef().getType().isa<MemRefType>())
1112     return emitOpError("expected tag to be of memref type");
1113   numExpectedOperands += getTagMemRefRank();
1114   if (numOperands < numExpectedOperands)
1115     return emitOpError() << "expected at least " << numExpectedOperands
1116                          << " operands";
1117   if (!getTagIndices().empty() &&
1118       !llvm::all_of(getTagIndices().getTypes(),
1119                     [](Type t) { return t.isIndex(); }))
1120     return emitOpError("expected tag indices to be of index type");
1121 
1122   // Optional stride-related operands must be either both present or both
1123   // absent.
1124   if (numOperands != numExpectedOperands &&
1125       numOperands != numExpectedOperands + 2)
1126     return emitOpError("incorrect number of operands");
1127 
1128   // 5. Strides.
1129   if (isStrided()) {
1130     if (!getStride().getType().isIndex() ||
1131         !getNumElementsPerStride().getType().isIndex())
1132       return emitOpError(
1133           "expected stride and num elements per stride to be of type index");
1134   }
1135 
1136   return success();
1137 }
1138 
1139 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1140                                SmallVectorImpl<OpFoldResult> &results) {
1141   /// dma_start(memrefcast) -> dma_start
1142   return foldMemRefCast(*this);
1143 }
1144 
1145 // ---------------------------------------------------------------------------
1146 // DmaWaitOp
1147 // ---------------------------------------------------------------------------
1148 
1149 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1150                               SmallVectorImpl<OpFoldResult> &results) {
1151   /// dma_wait(memrefcast) -> dma_wait
1152   return foldMemRefCast(*this);
1153 }
1154 
1155 LogicalResult DmaWaitOp::verify() {
1156   // Check that the number of tag indices matches the tagMemRef rank.
1157   unsigned numTagIndices = getTagIndices().size();
1158   unsigned tagMemRefRank = getTagMemRefRank();
1159   if (numTagIndices != tagMemRefRank)
1160     return emitOpError() << "expected tagIndices to have the same number of "
1161                             "elements as the tagMemRef rank, expected "
1162                          << tagMemRefRank << ", but got " << numTagIndices;
1163   return success();
1164 }
1165 
1166 //===----------------------------------------------------------------------===//
1167 // GenericAtomicRMWOp
1168 //===----------------------------------------------------------------------===//
1169 
1170 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1171                                Value memref, ValueRange ivs) {
1172   result.addOperands(memref);
1173   result.addOperands(ivs);
1174 
1175   if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
1176     Type elementType = memrefType.getElementType();
1177     result.addTypes(elementType);
1178 
1179     Region *bodyRegion = result.addRegion();
1180     bodyRegion->push_back(new Block());
1181     bodyRegion->addArgument(elementType, memref.getLoc());
1182   }
1183 }
1184 
1185 LogicalResult GenericAtomicRMWOp::verify() {
1186   auto &body = getRegion();
1187   if (body.getNumArguments() != 1)
1188     return emitOpError("expected single number of entry block arguments");
1189 
1190   if (getResult().getType() != body.getArgument(0).getType())
1191     return emitOpError("expected block argument of the same type result type");
1192 
1193   bool hasSideEffects =
1194       body.walk([&](Operation *nestedOp) {
1195             if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
1196               return WalkResult::advance();
1197             nestedOp->emitError(
1198                 "body of 'memref.generic_atomic_rmw' should contain "
1199                 "only operations with no side effects");
1200             return WalkResult::interrupt();
1201           })
1202           .wasInterrupted();
1203   return hasSideEffects ? failure() : success();
1204 }
1205 
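// Expected syntax (illustrative only):
//   %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
//   ^bb0(%current_value : f32):
//     %c1 = arith.constant 1.0 : f32
//     %new = arith.addf %current_value, %c1 : f32
//     memref.atomic_yield %new : f32
//   }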
1206 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1207                                       OperationState &result) {
1208   OpAsmParser::UnresolvedOperand memref;
1209   Type memrefType;
1210   SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1211 
1212   Type indexType = parser.getBuilder().getIndexType();
1213   if (parser.parseOperand(memref) ||
1214       parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1215       parser.parseColonType(memrefType) ||
1216       parser.resolveOperand(memref, memrefType, result.operands) ||
1217       parser.resolveOperands(ivs, indexType, result.operands))
1218     return failure();
1219 
1220   Region *body = result.addRegion();
1221   if (parser.parseRegion(*body, {}) ||
1222       parser.parseOptionalAttrDict(result.attributes))
1223     return failure();
1224   result.types.push_back(memrefType.cast<MemRefType>().getElementType());
1225   return success();
1226 }
1227 
1228 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1229   p << ' ' << getMemref() << "[" << getIndices()
1230     << "] : " << getMemref().getType() << ' ';
1231   p.printRegion(getRegion());
1232   p.printOptionalAttrDict((*this)->getAttrs());
1233 }
1234 
1235 //===----------------------------------------------------------------------===//
1236 // AtomicYieldOp
1237 //===----------------------------------------------------------------------===//
1238 
1239 LogicalResult AtomicYieldOp::verify() {
1240   Type parentType = (*this)->getParentOp()->getResultTypes().front();
1241   Type resultType = getResult().getType();
1242   if (parentType != resultType)
1243     return emitOpError() << "types mismatch between yield op: " << resultType
1244                          << " and its parent: " << parentType;
1245   return success();
1246 }
1247 
1248 //===----------------------------------------------------------------------===//
1249 // GlobalOp
1250 //===----------------------------------------------------------------------===//
1251 
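// The custom type-and-initial-value syntax covers, e.g. (illustrative only):
//   memref.global "private" constant @c : memref<2xi32> = dense<[0, 2]>
//   memref.global @z : memref<2xf32> = uninitialized
//   memref.global @external : memref<4xi32>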
1252 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1253                                                    TypeAttr type,
1254                                                    Attribute initialValue) {
1255   p << type;
1256   if (!op.isExternal()) {
1257     p << " = ";
1258     if (op.isUninitialized())
1259       p << "uninitialized";
1260     else
1261       p.printAttributeWithoutType(initialValue);
1262   }
1263 }
1264 
1265 static ParseResult
1266 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1267                                        Attribute &initialValue) {
1268   Type type;
1269   if (parser.parseType(type))
1270     return failure();
1271 
1272   auto memrefType = type.dyn_cast<MemRefType>();
1273   if (!memrefType || !memrefType.hasStaticShape())
1274     return parser.emitError(parser.getNameLoc())
1275            << "type should be static shaped memref, but got " << type;
1276   typeAttr = TypeAttr::get(type);
1277 
1278   if (parser.parseOptionalEqual())
1279     return success();
1280 
1281   if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1282     initialValue = UnitAttr::get(parser.getContext());
1283     return success();
1284   }
1285 
1286   Type tensorType = getTensorTypeFromMemRefType(memrefType);
1287   if (parser.parseAttribute(initialValue, tensorType))
1288     return failure();
1289   if (!initialValue.isa<ElementsAttr>())
1290     return parser.emitError(parser.getNameLoc())
1291            << "initial value should be a unit or elements attribute";
1292   return success();
1293 }
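// Illustrative sketch (not from the original source; symbol names made up):
// the type-and-initial-value syntax handled above appears in memref.global in
// forms roughly like
//
//   memref.global "private" @x : memref<2xf32> = dense<[0.0, 2.0]>
//   memref.global @y : memref<4xi32> = uninitialized
//   memref.global "private" constant @c : memref<2xf32> = dense<1.0>
//
// and a matching memref.get_global reference would be
//
//   %0 = memref.get_global @x : memref<2xf32>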
1294 
1295 LogicalResult GlobalOp::verify() {
1296   auto memrefType = getType().dyn_cast<MemRefType>();
1297   if (!memrefType || !memrefType.hasStaticShape())
1298     return emitOpError("type should be static shaped memref, but got ")
1299            << getType();
1300 
1301   // Verify that the initial value, if present, is either a unit attribute or
1302   // an elements attribute.
1303   if (getInitialValue().has_value()) {
1304     Attribute initValue = getInitialValue().value();
1305     if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
1306       return emitOpError("initial value should be a unit or elements "
1307                          "attribute, but got ")
1308              << initValue;
1309 
1310     // Check that the type of the initial value is compatible with the type of
1311     // the global variable.
1312     if (initValue.isa<ElementsAttr>()) {
1313       Type initType = initValue.getType();
1314       Type tensorType = getTensorTypeFromMemRefType(memrefType);
1315       if (initType != tensorType)
1316         return emitOpError("initial value expected to be of type ")
1317                << tensorType << ", but was of type " << initType;
1318     }
1319   }
1320 
1321   if (Optional<uint64_t> alignAttr = getAlignment()) {
1322     uint64_t alignment = *alignAttr;
1323 
1324     if (!llvm::isPowerOf2_64(alignment))
1325       return emitError() << "alignment attribute value " << alignment
1326                          << " is not a power of 2";
1327   }
1328 
1329   // TODO: verify visibility for declarations.
1330   return success();
1331 }
1332 
1333 ElementsAttr GlobalOp::getConstantInitValue() {
1334   auto initVal = getInitialValue();
1335   if (getConstant() && initVal.has_value())
1336     return initVal.value().cast<ElementsAttr>();
1337   return {};
1338 }
1339 
1340 //===----------------------------------------------------------------------===//
1341 // GetGlobalOp
1342 //===----------------------------------------------------------------------===//
1343 
1344 LogicalResult
1345 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1346   // Verify that the result type is the same as the type of the referenced
1347   // memref.global op.
1348   auto global =
1349       symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1350   if (!global)
1351     return emitOpError("'")
1352            << getName() << "' does not reference a valid global memref";
1353 
1354   Type resultType = getResult().getType();
1355   if (global.getType() != resultType)
1356     return emitOpError("result type ")
1357            << resultType << " does not match type " << global.getType()
1358            << " of the global memref @" << getName();
1359   return success();
1360 }
1361 
1362 //===----------------------------------------------------------------------===//
1363 // LoadOp
1364 //===----------------------------------------------------------------------===//
1365 
1366 LogicalResult LoadOp::verify() {
1367   if (getNumOperands() != 1 + getMemRefType().getRank())
1368     return emitOpError("incorrect number of indices for load");
1369   return success();
1370 }
1371 
1372 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1373   /// load(memrefcast) -> load
1374   if (succeeded(foldMemRefCast(*this)))
1375     return getResult();
1376   return OpFoldResult();
1377 }
1378 
1379 //===----------------------------------------------------------------------===//
1380 // PrefetchOp
1381 //===----------------------------------------------------------------------===//
1382 
1383 void PrefetchOp::print(OpAsmPrinter &p) {
1384   p << " " << getMemref() << '[';
1385   p.printOperands(getIndices());
1386   p << ']' << ", " << (getIsWrite() ? "write" : "read");
1387   p << ", locality<" << getLocalityHint();
1388   p << ">, " << (getIsDataCache() ? "data" : "instr");
1389   p.printOptionalAttrDict(
1390       (*this)->getAttrs(),
1391       /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1392   p << " : " << getMemRefType();
1393 }
1394 
1395 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1396   OpAsmParser::UnresolvedOperand memrefInfo;
1397   SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1398   IntegerAttr localityHint;
1399   MemRefType type;
1400   StringRef readOrWrite, cacheType;
1401 
1402   auto indexTy = parser.getBuilder().getIndexType();
1403   auto i32Type = parser.getBuilder().getIntegerType(32);
1404   if (parser.parseOperand(memrefInfo) ||
1405       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1406       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1407       parser.parseComma() || parser.parseKeyword("locality") ||
1408       parser.parseLess() ||
1409       parser.parseAttribute(localityHint, i32Type, "localityHint",
1410                             result.attributes) ||
1411       parser.parseGreater() || parser.parseComma() ||
1412       parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1413       parser.resolveOperand(memrefInfo, type, result.operands) ||
1414       parser.resolveOperands(indexInfo, indexTy, result.operands))
1415     return failure();
1416 
1417   if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1418     return parser.emitError(parser.getNameLoc(),
1419                             "rw specifier has to be 'read' or 'write'");
1420   result.addAttribute(
1421       PrefetchOp::getIsWriteAttrStrName(),
1422       parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1423 
1424   if (!cacheType.equals("data") && !cacheType.equals("instr"))
1425     return parser.emitError(parser.getNameLoc(),
1426                             "cache type has to be 'data' or 'instr'");
1427 
1428   result.addAttribute(
1429       PrefetchOp::getIsDataCacheAttrStrName(),
1430       parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1431 
1432   return success();
1433 }
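// Illustrative sketch (not from the original source; SSA names made up): the
// assembly form produced and consumed by the print/parse methods above looks
// roughly like
//
//   memref.prefetch %buf[%i, %j], read, locality<3>, data : memref<400x400xf32>
//
// where the rw specifier is 'read' or 'write', the locality hint is an i32
// attribute, and the cache-type keyword is 'data' or 'instr'.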
1434 
1435 LogicalResult PrefetchOp::verify() {
1436   if (getNumOperands() != 1 + getMemRefType().getRank())
1437     return emitOpError("too few indices");
1438 
1439   return success();
1440 }
1441 
1442 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
1443                                SmallVectorImpl<OpFoldResult> &results) {
1444   // prefetch(memrefcast) -> prefetch
1445   return foldMemRefCast(*this);
1446 }
1447 
1448 //===----------------------------------------------------------------------===//
1449 // RankOp
1450 //===----------------------------------------------------------------------===//
1451 
1452 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
1453   // Constant fold rank when the rank of the operand is known.
1454   auto type = getOperand().getType();
1455   auto shapedType = type.dyn_cast<ShapedType>();
1456   if (shapedType && shapedType.hasRank())
1457     return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1458   return IntegerAttr();
1459 }
1460 
1461 //===----------------------------------------------------------------------===//
1462 // ReinterpretCastOp
1463 //===----------------------------------------------------------------------===//
1464 
1465 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1466 /// `staticSizes` and `staticStrides` are automatically filled with
1467 /// source-memref-rank sentinel values that encode dynamic entries.
1468 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1469                               MemRefType resultType, Value source,
1470                               OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1471                               ArrayRef<OpFoldResult> strides,
1472                               ArrayRef<NamedAttribute> attrs) {
1473   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1474   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1475   dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
1476                              ShapedType::kDynamicStrideOrOffset);
1477   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1478                              ShapedType::kDynamicSize);
1479   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1480                              ShapedType::kDynamicStrideOrOffset);
1481   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1482         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1483         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1484   result.addAttributes(attrs);
1485 }
1486 
1487 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1488                               MemRefType resultType, Value source,
1489                               int64_t offset, ArrayRef<int64_t> sizes,
1490                               ArrayRef<int64_t> strides,
1491                               ArrayRef<NamedAttribute> attrs) {
1492   SmallVector<OpFoldResult> sizeValues =
1493       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1494         return b.getI64IntegerAttr(v);
1495       }));
1496   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1497       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1498         return b.getI64IntegerAttr(v);
1499       }));
1500   build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1501         strideValues, attrs);
1502 }
1503 
1504 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1505                               MemRefType resultType, Value source, Value offset,
1506                               ValueRange sizes, ValueRange strides,
1507                               ArrayRef<NamedAttribute> attrs) {
1508   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1509       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1510   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1511       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1512   build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1513 }
1514 
1515 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1516 // completed automatically, like we have for subview and extract_slice.
1517 LogicalResult ReinterpretCastOp::verify() {
1518   // The source and result memrefs should be in the same memory space.
1519   auto srcType = getSource().getType().cast<BaseMemRefType>();
1520   auto resultType = getType().cast<MemRefType>();
1521   if (srcType.getMemorySpace() != resultType.getMemorySpace())
1522     return emitError("different memory spaces specified for source type ")
1523            << srcType << " and result memref type " << resultType;
1524   if (srcType.getElementType() != resultType.getElementType())
1525     return emitError("different element types specified for source type ")
1526            << srcType << " and result memref type " << resultType;
1527 
1528   // Match sizes in result memref type and in static_sizes attribute.
1529   for (auto &en : llvm::enumerate(llvm::zip(
1530            resultType.getShape(), extractFromI64ArrayAttr(getStaticSizes())))) {
1531     int64_t resultSize = std::get<0>(en.value());
1532     int64_t expectedSize = std::get<1>(en.value());
1533     if (!ShapedType::isDynamic(resultSize) &&
1534         !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
1535       return emitError("expected result type with size = ")
1536              << expectedSize << " instead of " << resultSize
1537              << " in dim = " << en.index();
1538   }
1539 
1540   // Match offset and strides in static_offsets and static_strides attributes.
1541   // If the result memref type has no affine map specified, this assumes an
1542   // identity layout.
1543   int64_t resultOffset;
1544   SmallVector<int64_t, 4> resultStrides;
1545   if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1546     return emitError("expected result type to have strided layout but found ")
1547            << resultType;
1548 
1549   // Match offset in result memref type and in static_offsets attribute.
1550   int64_t expectedOffset = extractFromI64ArrayAttr(getStaticOffsets()).front();
1551   if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
1552       !ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
1553       resultOffset != expectedOffset)
1554     return emitError("expected result type with offset = ")
1555            << resultOffset << " instead of " << expectedOffset;
1556 
1557   // Match strides in result memref type and in static_strides attribute.
1558   for (auto &en : llvm::enumerate(llvm::zip(
1559            resultStrides, extractFromI64ArrayAttr(getStaticStrides())))) {
1560     int64_t resultStride = std::get<0>(en.value());
1561     int64_t expectedStride = std::get<1>(en.value());
1562     if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
1563         !ShapedType::isDynamicStrideOrOffset(expectedStride) &&
1564         resultStride != expectedStride)
1565       return emitError("expected result type with stride = ")
1566              << expectedStride << " instead of " << resultStride
1567              << " in dim = " << en.index();
1568   }
1569 
1570   return success();
1571 }
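// Illustrative sketch (not from the original source; SSA names made up): a
// reinterpret_cast whose static offsets/sizes/strides are consistent with the
// result type, e.g.
//
//   %dst = memref.reinterpret_cast %src to
//     offset: [0], sizes: [10, 10], strides: [10, 1]
//     : memref<?xf32> to memref<10x10xf32>
//
// passes the checks above: the result shape matches static_sizes, and the
// identity layout of the result implies offset 0 and strides [10, 1].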
1572 
1573 OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
1574   Value src = getSource();
1575   auto getPrevSrc = [&]() -> Value {
1576     // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1577     if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1578       return prev.getSource();
1579 
1580     // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1581     if (auto prev = src.getDefiningOp<CastOp>())
1582       return prev.getSource();
1583 
1584     // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1585     // are 0.
1586     if (auto prev = src.getDefiningOp<SubViewOp>())
1587       if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
1588             return isConstantIntValue(val, 0);
1589           }))
1590         return prev.getSource();
1591 
1592     return nullptr;
1593   };
1594 
1595   if (auto prevSrc = getPrevSrc()) {
1596     getSourceMutable().assign(prevSrc);
1597     return getResult();
1598   }
1599 
1600   return nullptr;
1601 }
1602 
1603 //===----------------------------------------------------------------------===//
1604 // Reassociative reshape ops
1605 //===----------------------------------------------------------------------===//
1606 
1607 /// Helper function for verifying the shape of ExpandShapeOp and CollapseShapeOp
1608 /// result and operand. Layout maps are verified separately.
1609 ///
1610 /// If `allowMultipleDynamicDimsPerGroup` is set, multiple dynamic dimensions are
1611 /// allowed in a reassociation group.
1612 static LogicalResult
1613 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
1614                      ArrayRef<int64_t> expandedShape,
1615                      ArrayRef<ReassociationIndices> reassociation,
1616                      bool allowMultipleDynamicDimsPerGroup) {
1617   // There must be one reassociation group per collapsed dimension.
1618   if (collapsedShape.size() != reassociation.size())
1619     return op->emitOpError("invalid number of reassociation groups: found ")
1620            << reassociation.size() << ", expected " << collapsedShape.size();
1621 
1622   // The next expected expanded dimension index (while iterating over
1623   // reassociation indices).
1624   int64_t nextDim = 0;
1625   for (const auto &it : llvm::enumerate(reassociation)) {
1626     ReassociationIndices group = it.value();
1627     int64_t collapsedDim = it.index();
1628 
1629     bool foundDynamic = false;
1630     for (int64_t expandedDim : group) {
1631       if (expandedDim != nextDim++)
1632         return op->emitOpError("reassociation indices must be contiguous");
1633 
1634       if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
1635         return op->emitOpError("reassociation index ")
1636                << expandedDim << " is out of bounds";
1637 
1638       // Check if there are multiple dynamic dims in a reassociation group.
1639       if (ShapedType::isDynamic(expandedShape[expandedDim])) {
1640         if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
1641           return op->emitOpError(
1642               "at most one dimension in a reassociation group may be dynamic");
1643         foundDynamic = true;
1644       }
1645     }
1646 
1647     // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
1648     if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
1649       return op->emitOpError("collapsed dim (")
1650              << collapsedDim
1651              << ") must be dynamic if and only if reassociation group is "
1652                 "dynamic";
1653 
1654     // If all dims in the reassociation group are static, the size of the
1655     // collapsed dim can be verified.
1656     if (!foundDynamic) {
1657       int64_t groupSize = 1;
1658       for (int64_t expandedDim : group)
1659         groupSize *= expandedShape[expandedDim];
1660       if (groupSize != collapsedShape[collapsedDim])
1661         return op->emitOpError("collapsed dim size (")
1662                << collapsedShape[collapsedDim]
1663                << ") must equal reassociation group size (" << groupSize << ")";
1664     }
1665   }
1666 
1667   if (collapsedShape.empty()) {
1668     // Rank 0: All expanded dimensions must be 1.
1669     for (int64_t d : expandedShape)
1670       if (d != 1)
1671         return op->emitOpError(
1672             "rank 0 memrefs can only be extended/collapsed with/from ones");
1673   } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
1674     // Rank >= 1: Number of dimensions among all reassociation groups must match
1675     // the result memref rank.
1676     return op->emitOpError("expanded rank (")
1677            << expandedShape.size()
1678            << ") inconsistent with number of reassociation indices (" << nextDim
1679            << ")";
1680   }
1681 
1682   return success();
1683 }
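// Illustrative sketch (not from the original source) of the checks above: with
// collapsedShape = [6, ?], expandedShape = [2, 3, ?, 4] and reassociation =
// [[0, 1], [2, 3]], the first group is fully static and 2 * 3 == 6, and the
// second group contains a dynamic dim exactly when the collapsed dim is
// dynamic, so verification succeeds. Using reassociation = [[0], [1, 2, 3]]
// instead would fail: collapsed dim 0 has size 6 but its group size is 2.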
1684 
1685 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1686   return getSymbolLessAffineMaps(getReassociationExprs());
1687 }
1688 
1689 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1690   return convertReassociationIndicesToExprs(getContext(),
1691                                             getReassociationIndices());
1692 }
1693 
1694 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1695   return getSymbolLessAffineMaps(getReassociationExprs());
1696 }
1697 
1698 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1699   return convertReassociationIndicesToExprs(getContext(),
1700                                             getReassociationIndices());
1701 }
1702 
1703 /// Compute the layout map after expanding a given source MemRef type with the
1704 /// specified reassociation indices.
1705 static FailureOr<AffineMap>
1706 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
1707                          ArrayRef<ReassociationIndices> reassociation) {
1708   int64_t srcOffset;
1709   SmallVector<int64_t> srcStrides;
1710   if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
1711     return failure();
1712   assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
1713 
1714   // 1-1 mapping between srcStrides and reassociation packs.
1715   // Each srcStride starts with the given value and gets expanded according to
1716   // the proper entries in resultShape.
1717   // Example:
1718   //   srcStrides     =                   [10000,  1 ,    100   ],
1719   //   reassociations =                   [  [0], [1], [2, 3, 4]],
1720   //   resultSizes    = [2, 5, 4, 3, 2] = [  [2], [5], [4, 3, 2]]
1721   //     -> For the purpose of stride calculation, the useful sizes are:
1722   //                    [x, x, x, 3, 2] = [  [x], [x], [x, 3, 2]].
1723   //   resultStrides = [10000, 1, 600, 200, 100]
1724   // Note that a stride does not get expanded along the first entry of each
1725   // shape pack.
1726   SmallVector<int64_t> reverseResultStrides;
1727   reverseResultStrides.reserve(resultShape.size());
1728   unsigned shapeIndex = resultShape.size() - 1;
1729   for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
1730     ReassociationIndices reassoc = std::get<0>(it);
1731     int64_t currentStrideToExpand = std::get<1>(it);
1732     for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
1733       using saturated_arith::Wrapper;
1734       reverseResultStrides.push_back(currentStrideToExpand);
1735       currentStrideToExpand = (Wrapper::stride(currentStrideToExpand) *
1736                                Wrapper::size(resultShape[shapeIndex--]))
1737                                   .asStride();
1738     }
1739   }
1740   auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
1741   resultStrides.resize(resultShape.size(), 1);
1742   return makeStridedLinearLayoutMap(resultStrides, srcOffset,
1743                                     srcType.getContext());
1744 }
1745 
1746 static FailureOr<MemRefType>
1747 computeExpandedType(MemRefType srcType, ArrayRef<int64_t> resultShape,
1748                     ArrayRef<ReassociationIndices> reassociation) {
1749   if (srcType.getLayout().isIdentity()) {
1750     // If the source is contiguous (i.e., no layout map specified), so is the
1751     // result.
1752     MemRefLayoutAttrInterface layout;
1753     return MemRefType::get(resultShape, srcType.getElementType(), layout,
1754                            srcType.getMemorySpace());
1755   }
1756 
1757   // Source may not be contiguous. Compute the layout map.
1758   FailureOr<AffineMap> computedLayout =
1759       computeExpandedLayoutMap(srcType, resultShape, reassociation);
1760   if (failed(computedLayout))
1761     return failure();
1762   auto computedType =
1763       MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
1764                       srcType.getMemorySpaceAsInt());
1765   return canonicalizeStridedLayout(computedType);
1766 }
1767 
1768 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1769                           ArrayRef<int64_t> resultShape, Value src,
1770                           ArrayRef<ReassociationIndices> reassociation) {
1771   // Only ranked memref source values are supported.
1772   auto srcType = src.getType().cast<MemRefType>();
1773   FailureOr<MemRefType> resultType =
1774       computeExpandedType(srcType, resultShape, reassociation);
1775   // Failure of this assertion usually indicates a problem with the source
1776   // type, e.g., could not get strides/offset.
1777   assert(succeeded(resultType) && "could not compute layout");
1778   build(builder, result, *resultType, src, reassociation);
1779 }
1780 
1781 LogicalResult ExpandShapeOp::verify() {
1782   MemRefType srcType = getSrcType();
1783   MemRefType resultType = getResultType();
1784 
1785   // Verify result shape.
1786   if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
1787                                   resultType.getShape(),
1788                                   getReassociationIndices(),
1789                                   /*allowMultipleDynamicDimsPerGroup=*/false)))
1790     return failure();
1791 
1792   // Compute expected result type (including layout map).
1793   FailureOr<MemRefType> expectedResultType = computeExpandedType(
1794       srcType, resultType.getShape(), getReassociationIndices());
1795   if (failed(expectedResultType))
1796     return emitOpError("invalid source layout map");
1797 
1798   // Check actual result type.
1799   auto canonicalizedResultType = canonicalizeStridedLayout(resultType);
1800   if (*expectedResultType != canonicalizedResultType)
1801     return emitOpError("expected expanded type to be ")
1802            << *expectedResultType << " but found " << canonicalizedResultType;
1803 
1804   return success();
1805 }
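// Illustrative sketch (not from the original source; SSA names made up): an
// expand_shape that the verifier above accepts, assuming an identity-layout
// source:
//
//   %0 = memref.expand_shape %src [[0, 1], [2]]
//       : memref<6x4xf32> into memref<2x3x4xf32>
//
// Each reassociation group collapses back to the corresponding source dim
// (2 * 3 == 6, 4 == 4), and the contiguous source yields a contiguous result.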
1806 
1807 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1808                                                 MLIRContext *context) {
1809   results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
1810               ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(
1811       context);
1812 }
1813 
1814 /// Compute the layout map after collapsing a given source MemRef type with the
1815 /// specified reassociation indices.
1816 ///
1817 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
1818 /// not possible to check this by inspecting a MemRefType in the general case.
1819 /// If non-contiguity cannot be checked statically, the collapse is assumed to
1820 /// be valid (and thus accepted by this function) unless `strict = true`.
1821 static FailureOr<AffineMap>
1822 computeCollapsedLayoutMap(MemRefType srcType,
1823                           ArrayRef<ReassociationIndices> reassociation,
1824                           bool strict = false) {
1825   int64_t srcOffset;
1826   SmallVector<int64_t> srcStrides;
1827   auto srcShape = srcType.getShape();
1828   if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
1829     return failure();
1830 
1831   // The result stride of a reassociation group is the stride of the last entry
1832   // of the reassociation. (TODO: Should be the minimum stride in the
1833   // reassociation because strides are not necessarily sorted. E.g., when using
1834   // memref.transpose.) Dimensions of size 1 should be skipped, because their
1835   // strides are meaningless and could have any arbitrary value.
1836   SmallVector<int64_t> resultStrides;
1837   resultStrides.reserve(reassociation.size());
1838   for (const ReassociationIndices &reassoc : reassociation) {
1839     ArrayRef<int64_t> ref = llvm::makeArrayRef(reassoc);
1840     while (srcShape[ref.back()] == 1 && ref.size() > 1)
1841       ref = ref.drop_back();
1842     if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
1843       resultStrides.push_back(srcStrides[ref.back()]);
1844     } else {
1845       // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
1846       // the corresponding stride may have to be skipped. (See above comment.)
1847       // Therefore, the result stride cannot be statically determined and must
1848       // be dynamic.
1849       resultStrides.push_back(ShapedType::kDynamicStrideOrOffset);
1850     }
1851   }
1852 
1853   // Validate that each reassociation group is contiguous.
1854   unsigned resultStrideIndex = resultStrides.size() - 1;
1855   for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
1856     auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
1857     using saturated_arith::Wrapper;
1858     auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]);
1859     for (int64_t idx : llvm::reverse(trailingReassocs)) {
1860       stride = stride * Wrapper::size(srcShape[idx]);
1861 
1862       // Both the source and the result stride must have the same static value.
1863       // In that case, we can be sure that the dimensions are collapsible
1864       // (because they are contiguous).
1865       //
1866       // One special case is a source dimension of size 1, which can never
1867       // produce non-contiguity.
1868       if (srcShape[idx] == 1)
1869         continue;
1870 
1871       // If `strict = false` (default during op verification), we accept cases
1872       // where one or both strides are dynamic. This is best effort: We reject
1873       // ops where obviously non-contiguous dims are collapsed, but accept ops
1874       // where we cannot be sure statically. Such ops may fail at runtime. See
1875       // the op documentation for details.
1876       auto srcStride = Wrapper::stride(srcStrides[idx - 1]);
1877       if (strict && (stride.saturated || srcStride.saturated))
1878         return failure();
1879 
1880       if (!stride.saturated && !srcStride.saturated && stride != srcStride)
1881         return failure();
1882     }
1883   }
1884   return makeStridedLinearLayoutMap(resultStrides, srcOffset,
1885                                     srcType.getContext());
1886 }
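// Illustrative sketch (not from the original source) of the check above: for
// srcShape = [2, 3, 4] with srcStrides = [12, 4, 1] and reassociation
// [[0, 1], [2]], the first group's result stride is srcStrides[1] = 4; the
// contiguity walk multiplies it by srcShape[1] = 3 and compares against
// srcStrides[0] = 12, which matches, so the dims are collapsible with result
// strides [4, 1]. With srcStrides = [24, 4, 1] (a non-contiguous outer dim),
// 4 * 3 = 12 != 24 and the function fails.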
1887 
1888 bool CollapseShapeOp::isGuaranteedCollapsible(
1889     MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
1890   // MemRefs with standard layout are always collapsible.
1891   if (srcType.getLayout().isIdentity())
1892     return true;
1893 
1894   return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
1895                                              /*strict=*/true));
1896 }
1897 
1898 static MemRefType
1899 computeCollapsedType(MemRefType srcType,
1900                      ArrayRef<ReassociationIndices> reassociation) {
1901   SmallVector<int64_t> resultShape;
1902   resultShape.reserve(reassociation.size());
1903   for (const ReassociationIndices &group : reassociation) {
1904     using saturated_arith::Wrapper;
1905     auto groupSize = Wrapper::size(1);
1906     for (int64_t srcDim : group)
1907       groupSize = groupSize * Wrapper::size(srcType.getDimSize(srcDim));
1908     resultShape.push_back(groupSize.asSize());
1909   }
1910 
1911   if (srcType.getLayout().isIdentity()) {
1912     // If the source is contiguous (i.e., no layout map specified), so is the
1913     // result.
1914     MemRefLayoutAttrInterface layout;
1915     return MemRefType::get(resultShape, srcType.getElementType(), layout,
1916                            srcType.getMemorySpace());
1917   }
1918 
1919   // Source may not be fully contiguous. Compute the layout map.
1920   // Note: Dimensions that are collapsed into a single dim are assumed to be
1921   // contiguous.
1922   FailureOr<AffineMap> computedLayout =
1923       computeCollapsedLayoutMap(srcType, reassociation);
1924   assert(succeeded(computedLayout) &&
1925          "invalid source layout map or collapsing non-contiguous dims");
1926   auto computedType =
1927       MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
1928                       srcType.getMemorySpaceAsInt());
1929   return canonicalizeStridedLayout(computedType);
1930 }
1931 
1932 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1933                             ArrayRef<ReassociationIndices> reassociation,
1934                             ArrayRef<NamedAttribute> attrs) {
1935   auto srcType = src.getType().cast<MemRefType>();
1936   MemRefType resultType = computeCollapsedType(srcType, reassociation);
1937   build(b, result, resultType, src, attrs);
1938   result.addAttribute(::mlir::getReassociationAttrName(),
1939                       getReassociationIndicesAttribute(b, reassociation));
1940 }
1941 
1942 LogicalResult CollapseShapeOp::verify() {
1943   MemRefType srcType = getSrcType();
1944   MemRefType resultType = getResultType();
1945 
1946   // Verify result shape.
1947   if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
1948                                   srcType.getShape(), getReassociationIndices(),
1949                                   /*allowMultipleDynamicDimsPerGroup=*/true)))
1950     return failure();
1951 
1952   // Compute expected result type (including layout map).
1953   MemRefType expectedResultType;
1954   if (srcType.getLayout().isIdentity()) {
1955     // If the source is contiguous (i.e., no layout map specified), so is the
1956     // result.
1957     MemRefLayoutAttrInterface layout;
1958     expectedResultType =
1959         MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
1960                         srcType.getMemorySpace());
1961   } else {
1962     // Source may not be fully contiguous. Compute the layout map.
1963     // Note: Dimensions that are collapsed into a single dim are assumed to be
1964     // contiguous.
1965     FailureOr<AffineMap> computedLayout =
1966         computeCollapsedLayoutMap(srcType, getReassociationIndices());
1967     if (failed(computedLayout))
1968       return emitOpError(
1969           "invalid source layout map or collapsing non-contiguous dims");
1970     auto computedType =
1971         MemRefType::get(resultType.getShape(), srcType.getElementType(),
1972                         *computedLayout, srcType.getMemorySpaceAsInt());
1973     expectedResultType = canonicalizeStridedLayout(computedType);
1974   }
1975 
1976   auto canonicalizedResultType = canonicalizeStridedLayout(resultType);
1977   if (expectedResultType != canonicalizedResultType)
1978     return emitOpError("expected collapsed type to be ")
1979            << expectedResultType << " but found " << canonicalizedResultType;
1980 
1981   return success();
1982 }
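// Illustrative sketch (not from the original source; SSA names made up): a
// collapse_shape the verifier above accepts, assuming an identity-layout
// source:
//
//   %0 = memref.collapse_shape %src [[0, 1], [2]]
//       : memref<2x3x4xf32> into memref<6x4xf32>
//
// The reassociation groups are contiguous in the source, so the collapsed
// result keeps an identity layout with shape [6, 4].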
1983 
1984 struct CollapseShapeOpMemRefCastFolder
1985     : public OpRewritePattern<CollapseShapeOp> {
1986 public:
1987   using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
1988 
1989   LogicalResult matchAndRewrite(CollapseShapeOp op,
1990                                 PatternRewriter &rewriter) const override {
1991     auto cast = op.getOperand().getDefiningOp<CastOp>();
1992     if (!cast)
1993       return failure();
1994 
1995     if (!CastOp::canFoldIntoConsumerOp(cast))
1996       return failure();
1997 
1998     Type newResultType =
1999         computeCollapsedType(cast.getOperand().getType().cast<MemRefType>(),
2000                              op.getReassociationIndices());
2001 
2002     if (newResultType == op.getResultType()) {
2003       rewriter.updateRootInPlace(
2004           op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2005     } else {
2006       Value newOp = rewriter.create<CollapseShapeOp>(
2007           op->getLoc(), cast.getSource(), op.getReassociationIndices());
2008       rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2009     }
2010     return success();
2011   }
2012 };
2013 
2014 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2015                                                   MLIRContext *context) {
2016   results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
2017               ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
2018               CollapseShapeOpMemRefCastFolder>(context);
2019 }
2020 
2021 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
2022   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
2023 }
2024 
2025 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
2026   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
2027 }
2028 
2029 //===----------------------------------------------------------------------===//
2030 // ReshapeOp
2031 //===----------------------------------------------------------------------===//
2032 
2033 LogicalResult ReshapeOp::verify() {
2034   Type operandType = getSource().getType();
2035   Type resultType = getResult().getType();
2036 
2037   Type operandElementType = operandType.cast<ShapedType>().getElementType();
2038   Type resultElementType = resultType.cast<ShapedType>().getElementType();
2039   if (operandElementType != resultElementType)
2040     return emitOpError("element types of source and destination memref "
2041                        "types should be the same");
2042 
2043   if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
2044     if (!operandMemRefType.getLayout().isIdentity())
2045       return emitOpError("source memref type should have identity affine map");
2046 
2047   int64_t shapeSize = getShape().getType().cast<MemRefType>().getDimSize(0);
2048   auto resultMemRefType = resultType.dyn_cast<MemRefType>();
2049   if (resultMemRefType) {
2050     if (!resultMemRefType.getLayout().isIdentity())
2051       return emitOpError("result memref type should have identity affine map");
2052     if (shapeSize == ShapedType::kDynamicSize)
2053       return emitOpError("cannot use shape operand with dynamic length to "
2054                          "reshape to statically-ranked memref type");
2055     if (shapeSize != resultMemRefType.getRank())
2056       return emitOpError(
2057           "length of shape operand differs from the result's memref rank");
2058   }
2059   return success();
2060 }
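// Illustrative sketch (not from the original source; SSA names made up): a
// reshape accepted by the verifier above, where the shape operand has static
// length 1 and the ranked result has the matching rank:
//
//   %dst = memref.reshape %src(%shape)
//       : (memref<4x5xf32>, memref<1xindex>) -> memref<?xf32>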
2061 
2062 //===----------------------------------------------------------------------===//
2063 // StoreOp
2064 //===----------------------------------------------------------------------===//
2065 
2066 LogicalResult StoreOp::verify() {
2067   if (getNumOperands() != 2 + getMemRefType().getRank())
2068     return emitOpError("store index operand count not equal to memref rank");
2069 
2070   return success();
2071 }
2072 
2073 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
2074                             SmallVectorImpl<OpFoldResult> &results) {
2075   /// store(memrefcast) -> store
2076   return foldMemRefCast(*this, getValueToStore());
2077 }
2078 
2079 //===----------------------------------------------------------------------===//
2080 // SubViewOp
2081 //===----------------------------------------------------------------------===//
2082 
2083 /// A subview result type can be fully inferred from the source type and the
2084 /// static representation of offsets, sizes and strides. Special sentinels
2085 /// encode the dynamic case.
2086 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2087                                 ArrayRef<int64_t> staticOffsets,
2088                                 ArrayRef<int64_t> staticSizes,
2089                                 ArrayRef<int64_t> staticStrides) {
2090   unsigned rank = sourceMemRefType.getRank();
2091   (void)rank;
2092   assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2093   assert(staticSizes.size() == rank && "staticSizes length mismatch");
2094   assert(staticStrides.size() == rank && "staticStrides length mismatch");
2095 
2096   // Extract source offset and strides.
2097   int64_t sourceOffset;
2098   SmallVector<int64_t, 4> sourceStrides;
2099   auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
2100   assert(succeeded(res) && "SubViewOp expected strided memref type");
2101   (void)res;
2102 
2103   // Compute target offset whose value is:
2104   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2105   int64_t targetOffset = sourceOffset;
2106   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2107     auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
2108     using saturated_arith::Wrapper;
2109     targetOffset =
2110         (Wrapper::offset(targetOffset) +
2111          Wrapper::offset(staticOffset) * Wrapper::stride(targetStride))
2112             .asOffset();
2113   }
2114 
2115   // Compute target stride whose value is:
2116   //   `sourceStrides_i * staticStrides_i`.
2117   SmallVector<int64_t, 4> targetStrides;
2118   targetStrides.reserve(staticOffsets.size());
2119   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2120     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2121     using saturated_arith::Wrapper;
2122     targetStrides.push_back(
2123         (Wrapper::stride(sourceStride) * Wrapper::stride(staticStride))
2124             .asStride());
2125   }
2126 
2127   // The type is now known.
2128   return MemRefType::get(
2129       staticSizes, sourceMemRefType.getElementType(),
2130       makeStridedLinearLayoutMap(targetStrides, targetOffset,
2131                                  sourceMemRefType.getContext()),
2132       sourceMemRefType.getMemorySpace());
2133 }
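// Illustrative sketch (not from the original source) of the arithmetic above:
// for a source memref<8x8xf32> (strides [8, 1], offset 0) with static offsets
// [2, 0], sizes [4, 4] and strides [1, 1], the inferred target offset is
// 0 + 2 * 8 + 0 * 1 = 16 and the target strides are [8 * 1, 1 * 1] = [8, 1],
// i.e. memref<4x4xf32, affine_map<(d0, d1) -> (d0 * 8 + d1 + 16)>>. Any
// dynamic operand saturates the corresponding entry to the dynamic sentinel.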
2134 
2135 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2136                                 ArrayRef<OpFoldResult> offsets,
2137                                 ArrayRef<OpFoldResult> sizes,
2138                                 ArrayRef<OpFoldResult> strides) {
2139   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2140   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2141   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
2142                              ShapedType::kDynamicStrideOrOffset);
2143   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
2144                              ShapedType::kDynamicSize);
2145   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
2146                              ShapedType::kDynamicStrideOrOffset);
2147   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2148                                     staticSizes, staticStrides);
2149 }
2150 
2151 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2152                                            MemRefType sourceRankedTensorType,
2153                                            ArrayRef<int64_t> offsets,
2154                                            ArrayRef<int64_t> sizes,
2155                                            ArrayRef<int64_t> strides) {
2156   auto inferredType =
2157       inferResultType(sourceRankedTensorType, offsets, sizes, strides)
2158           .cast<MemRefType>();
2159   assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2160          "expected ");
2161   if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2162     return inferredType;
2163 
2164   // Compute which dimensions are dropped.
2165   Optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2166       computeRankReductionMask(inferredType.getShape(), resultShape);
2167   assert(dimsToProject.has_value() && "invalid rank reduction");
2168   llvm::SmallBitVector dimsToProjectVector(inferredType.getRank());
2169   for (unsigned dim : *dimsToProject)
2170     dimsToProjectVector.set(dim);
2171 
2172   // Compute layout map and result type.
2173   AffineMap map = getProjectedMap(inferredType.getLayout().getAffineMap(),
2174                                   dimsToProjectVector);
2175   return MemRefType::get(resultShape, inferredType.getElementType(), map,
2176                          inferredType.getMemorySpace());
2177 }
2178 
2179 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2180                                            MemRefType sourceRankedTensorType,
2181                                            ArrayRef<OpFoldResult> offsets,
2182                                            ArrayRef<OpFoldResult> sizes,
2183                                            ArrayRef<OpFoldResult> strides) {
2184   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2185   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2186   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
2187                              ShapedType::kDynamicStrideOrOffset);
2188   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
2189                              ShapedType::kDynamicSize);
2190   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
2191                              ShapedType::kDynamicStrideOrOffset);
2192   return SubViewOp::inferRankReducedResultType(
2193       resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2194       staticStrides);
2195 }
2196 
2197 // Build a SubViewOp with mixed static and dynamic entries and custom result
2198 // type. If the type passed is nullptr, it is inferred.
2199 void SubViewOp::build(OpBuilder &b, OperationState &result,
2200                       MemRefType resultType, Value source,
2201                       ArrayRef<OpFoldResult> offsets,
2202                       ArrayRef<OpFoldResult> sizes,
2203                       ArrayRef<OpFoldResult> strides,
2204                       ArrayRef<NamedAttribute> attrs) {
2205   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2206   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2207   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
2208                              ShapedType::kDynamicStrideOrOffset);
2209   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
2210                              ShapedType::kDynamicSize);
2211   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
2212                              ShapedType::kDynamicStrideOrOffset);
2213   auto sourceMemRefType = source.getType().cast<MemRefType>();
2214   // Structuring implementation this way avoids duplication between builders.
2215   if (!resultType) {
2216     resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2217                                             staticSizes, staticStrides)
2218                      .cast<MemRefType>();
2219   }
2220   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2221         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
2222         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
2223   result.addAttributes(attrs);
2224 }
2225 
2226 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2227 // type.
2228 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2229                       ArrayRef<OpFoldResult> offsets,
2230                       ArrayRef<OpFoldResult> sizes,
2231                       ArrayRef<OpFoldResult> strides,
2232                       ArrayRef<NamedAttribute> attrs) {
2233   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2234 }
2235 
2236 // Build a SubViewOp with static entries and inferred result type.
2237 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2238                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2239                       ArrayRef<int64_t> strides,
2240                       ArrayRef<NamedAttribute> attrs) {
2241   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2242       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2243         return b.getI64IntegerAttr(v);
2244       }));
2245   SmallVector<OpFoldResult> sizeValues =
2246       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2247         return b.getI64IntegerAttr(v);
2248       }));
2249   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2250       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2251         return b.getI64IntegerAttr(v);
2252       }));
2253   build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2254 }
2255 
2256 // Build a SubViewOp with static entries and custom result type. If the
2257 // type passed is nullptr, it is inferred.
2258 void SubViewOp::build(OpBuilder &b, OperationState &result,
2259                       MemRefType resultType, Value source,
2260                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2261                       ArrayRef<int64_t> strides,
2262                       ArrayRef<NamedAttribute> attrs) {
2263   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2264       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2265         return b.getI64IntegerAttr(v);
2266       }));
2267   SmallVector<OpFoldResult> sizeValues =
2268       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2269         return b.getI64IntegerAttr(v);
2270       }));
2271   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2272       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2273         return b.getI64IntegerAttr(v);
2274       }));
2275   build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2276         attrs);
2277 }
2278 
2279 // Build a SubViewOp with dynamic entries and custom result type. If the type
2280 // passed is nullptr, it is inferred.
2281 void SubViewOp::build(OpBuilder &b, OperationState &result,
2282                       MemRefType resultType, Value source, ValueRange offsets,
2283                       ValueRange sizes, ValueRange strides,
2284                       ArrayRef<NamedAttribute> attrs) {
2285   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2286       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2287   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2288       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2289   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2290       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2291   build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2292 }
2293 
2294 // Build a SubViewOp with dynamic entries and inferred result type.
2295 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2296                       ValueRange offsets, ValueRange sizes, ValueRange strides,
2297                       ArrayRef<NamedAttribute> attrs) {
2298   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2299 }
2300 
2301 /// For ViewLikeOpInterface.
2302 Value SubViewOp::getViewSource() { return getSource(); }
2303 
2304 /// Return true if t1 and t2 have equal offsets (both dynamic or of same
2305 /// static value).
2306 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2307   AffineExpr t1Offset, t2Offset;
2308   SmallVector<AffineExpr> t1Strides, t2Strides;
2309   auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2310   auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2311   return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2312 }
2313 
2314 /// Checks if the `original` type can be rank-reduced to the `reduced` type.
2315 /// This function is a slight variant of the `is subsequence` algorithm, where
2316 /// a non-matching dimension must be 1.
2317 static SliceVerificationResult
2318 isRankReducedMemRefType(MemRefType originalType,
2319                         MemRefType candidateRankReducedType,
2320                         ArrayRef<OpFoldResult> sizes) {
2321   auto partialRes = isRankReducedType(originalType, candidateRankReducedType);
2322   if (partialRes != SliceVerificationResult::Success)
2323     return partialRes;
2324 
2325   auto optionalUnusedDimsMask = computeMemRefRankReductionMask(
2326       originalType, candidateRankReducedType, sizes);
2327 
2328   // The sizes cannot be matched if no rank-reduction mask could be computed.
2329   if (!optionalUnusedDimsMask)
2330     return SliceVerificationResult::LayoutMismatch;
2331 
2332   if (originalType.getMemorySpace() !=
2333       candidateRankReducedType.getMemorySpace())
2334     return SliceVerificationResult::MemSpaceMismatch;
2335 
2336   // No amount of stride dropping can reconcile incompatible offsets.
2337   if (!haveCompatibleOffsets(originalType, candidateRankReducedType))
2338     return SliceVerificationResult::LayoutMismatch;
2339 
2340   return SliceVerificationResult::Success;
2341 }
2342 
2343 template <typename OpTy>
2344 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
2345                                             OpTy op, Type expectedType) {
2346   auto memrefType = expectedType.cast<ShapedType>();
2347   switch (result) {
2348   case SliceVerificationResult::Success:
2349     return success();
2350   case SliceVerificationResult::RankTooLarge:
2351     return op.emitError("expected result rank to be smaller or equal to ")
2352            << "the source rank. ";
2353   case SliceVerificationResult::SizeMismatch:
2354     return op.emitError("expected result type to be ")
2355            << expectedType
2356            << " or a rank-reduced version. (mismatch of result sizes) ";
2357   case SliceVerificationResult::ElemTypeMismatch:
2358     return op.emitError("expected result element type to be ")
2359            << memrefType.getElementType();
2360   case SliceVerificationResult::MemSpaceMismatch:
2361     return op.emitError("expected result and source memory spaces to match.");
2362   case SliceVerificationResult::LayoutMismatch:
2363     return op.emitError("expected result type to be ")
2364            << expectedType
2365            << " or a rank-reduced version. (mismatch of result layout) ";
2366   }
2367   llvm_unreachable("unexpected subview verification result");
2368 }
2369 
2370 /// Verifier for SubViewOp.
verify()2371 LogicalResult SubViewOp::verify() {
2372   MemRefType baseType = getSourceType();
2373   MemRefType subViewType = getType();
2374 
2375   // The base memref and the view memref should be in the same memory space.
2376   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2377     return emitError("different memory spaces specified for base memref "
2378                      "type ")
2379            << baseType << " and subview memref type " << subViewType;
2380 
2381   // Verify that the base memref type has a strided layout map.
2382   if (!isStrided(baseType))
2383     return emitError("base type ") << baseType << " is not strided";
2384 
2385   // Verify result type against inferred type.
2386   auto expectedType = SubViewOp::inferResultType(
2387       baseType, extractFromI64ArrayAttr(getStaticOffsets()),
2388       extractFromI64ArrayAttr(getStaticSizes()),
2389       extractFromI64ArrayAttr(getStaticStrides()));
2390 
2391   auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
2392                                         subViewType, getMixedSizes());
2393   return produceSubViewErrorMsg(result, *this, expectedType);
2394 }
2395 
operator <<(raw_ostream & os,const Range & range)2396 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
2397   return os << "range " << range.offset << ":" << range.size << ":"
2398             << range.stride;
2399 }

/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
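///
/// A hypothetical usage sketch (the `subViewOp`, `builder`, and `loc` names
/// are assumptions for illustration, not taken from this file):
///
/// ```
///   SmallVector<Range, 8> ranges = getOrCreateRanges(subViewOp, builder, loc);
///   for (const Range &r : ranges)
///     llvm::dbgs() << r << "\n";
/// ```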
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size =
        op.isDynamicSize(idx)
            ? op.getDynamicSize(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
/// to deduce the result type for the given `sourceType`. Additionally, reduce
/// the rank of the inferred result type if `currentResultType` is of lower
/// rank than `currentSourceType`. Use this signature if `sourceType` is
/// updated together with the result type. In this case, it is important to
/// compute the dropped dimensions using `currentSourceType`, whose strides
/// align with `currentResultType`.
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType currentSourceType,
    MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
  auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
                                                       mixedSizes, mixedStrides)
                                .cast<MemRefType>();
  llvm::Optional<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(currentSourceType, currentResultType,
                                     mixedSizes);
  // Return nullptr as failure mode.
  if (!unusedDims)
    return nullptr;
  SmallVector<int64_t> shape;
  for (const auto &sizes : llvm::enumerate(nonRankReducedType.getShape())) {
    if (unusedDims->test(sizes.index()))
      continue;
    shape.push_back(sizes.value());
  }
  AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap();
  if (!layoutMap.isIdentity())
    layoutMap = getProjectedMap(layoutMap, *unusedDims);
  return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap,
                         nonRankReducedType.getMemorySpace());
}

/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
/// to deduce the result type. Additionally, reduce the rank of the inferred
/// result type if `currentResultType` is of lower rank than `sourceType`.
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType sourceType,
    ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
    ArrayRef<OpFoldResult> mixedStrides) {
  return getCanonicalSubViewResultType(currentResultType, sourceType,
                                       sourceType, mixedOffsets, mixedSizes,
                                       mixedStrides);
}

/// Helper method to check if a `subview` operation is trivially a no-op. This
/// is the case if all offsets are zero, all strides are 1, and the source
/// shape is the same as the sizes of the subview. In such cases, the subview
/// can be folded into its source.
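///
/// For example (an illustrative sketch):
///
/// ```
///   %0 = memref.subview %arg[0, 0][8, 16][1, 1]
///     : memref<8x16xf32> to memref<8x16xf32>
/// ```
///
/// takes the full extent of `%arg` and can be folded away.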
static bool isTrivialSubViewOp(SubViewOp subViewOp) {
  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
    return false;

  auto mixedOffsets = subViewOp.getMixedOffsets();
  auto mixedSizes = subViewOp.getMixedSizes();
  auto mixedStrides = subViewOp.getMixedStrides();

  // Check offsets are zero.
  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
        Optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 0;
      }))
    return false;

  // Check strides are one.
  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
        Optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.value() != 1;
      }))
    return false;

  // Check all size values are static and match the (static) source shape.
  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
  for (const auto &size : llvm::enumerate(mixedSizes)) {
    Optional<int64_t> intValue = getConstantIntValue(size.value());
    if (!intValue || *intValue != sourceShape[size.index()])
      return false;
  }
  // All conditions met. The `SubViewOp` is foldable as a no-op.
  return true;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
///     memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is a constant, bail out and let SubViewOpConstantFolder
    // kick in first.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    // Compute the SubViewOp result type after folding the MemRefCastOp. Use
    // the MemRefCastOp source operand type to infer the result type and the
    // current SubViewOp source operand type to compute the dropped dimensions
    // if the operation is rank-reducing.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType(), subViewOp.getSourceType(),
        castOp.getSource().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    if (!resultType)
      return failure();

    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.getSource(),
        subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
        subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
        subViewOp.getStaticStrides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};

/// Canonicalize subview ops that are no-ops: if the result type matches the
/// source type, the subview is replaced by its source; if the types differ
/// only in layout (e.g. due to use of `affine_map`), a memref.cast is
/// inserted instead.
class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    if (!isTrivialSubViewOp(subViewOp))
      return failure();
    if (subViewOp.getSourceType() == subViewOp.getType()) {
      rewriter.replaceOp(subViewOp, subViewOp.getSource());
      return success();
    }
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        subViewOp.getSource());
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
                                         mixedOffsets, mixedSizes,
                                         mixedStrides);
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
  }
};

void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
}

OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
  auto resultShapedType = getResult().getType().cast<ShapedType>();
  auto sourceShapedType = getSource().getType().cast<ShapedType>();

  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
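///
/// For example (an illustrative computation, not taken from a test case),
/// permuting `memref<4x8xf32>` (strides `[8, 1]`, offset 0) with the map
/// `(d0, d1) -> (d1, d0)` yields a `memref<8x4xf32>` whose layout map is
/// `(d0, d1) -> (d1 * 8 + d0)`.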
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
  SmallVector<int64_t, 4> sizes(rank, 0);
  for (const auto &en : llvm::enumerate(permutationMap.getResults()))
    sizes[en.index()] =
        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];

  // Compute permuted strides.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(memRefType, strides, offset);
  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
  (void)res;
  auto map =
      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
  map = permutationMap ? map.compose(permutationMap) : map;
  return MemRefType::Builder(memRefType)
      .setShape(sizes)
      .setLayout(AffineMapAttr::get(map));
}

void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = in.getType().cast<MemRefType>();
  // Compute result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  build(b, result, resultType, in, attrs);
  result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
void TransposeOp::print(OpAsmPrinter &p) {
  p << " " << getIn() << " " << getPermutation();
  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
  p << " : " << getIn().getType() << " to " << getType();
}

ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
                      AffineMapAttr::get(permutation));
  return success();
}

LogicalResult TransposeOp::verify() {
  if (!getPermutation().isPermutation())
    return emitOpError("expected a permutation map");
  if (getPermutation().getNumDims() != getShapedType().getRank())
    return emitOpError("expected a permutation map of same rank as the input");

  auto srcType = getIn().getType().cast<MemRefType>();
  auto dstType = getType().cast<MemRefType>();
  auto transposedType = inferTransposeResultType(srcType, getPermutation());
  if (dstType != transposedType)
    return emitOpError("output type ")
           << dstType << " does not match transposed input type " << srcType
           << ", " << transposedType;
  return success();
}

OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

LogicalResult ViewOp::verify() {
  auto baseType = getOperand(0).getType().cast<MemRefType>();
  auto viewType = getType();

  // The base memref should have identity layout map (or none).
  if (!baseType.getLayout().isIdentity())
    return emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (!viewType.getLayout().isIdentity())
    return emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (getSizes().size() != numDynamicDims)
    return emitError("incorrect number of size operands for type ") << viewType;

  return success();
}

Value ViewOp::getViewSource() { return getSource(); }

namespace {

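/// Folds dynamic size operands of a ViewOp that are produced by constants
/// into the result type, then inserts a cast back to the original result
/// type.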
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get the offset from the old memref view type `memrefType`.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Offset cannot be folded into result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp =
              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.value());
      } else {
        // Dynamic shape dimension not folded; copy the operand from the old
        // memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create the new memref type with constant-folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create the new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(
        viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
        viewOp.getByteShift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
    return success();
  }
};

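/// Folds a memref.cast of an alloc feeding a ViewOp by rebuilding the view
/// directly on the allocated memref.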
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.getByteShift(),
                                        viewOp.getSizes());
    return success();
  }
};

} // namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

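/// Verifier for AtomicRMWOp: the number of subscripts must equal the memref
/// rank, and the value type must match the kind (e.g. `addf` requires a
/// floating-point value, `addi` an integer value).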
LogicalResult AtomicRMWOp::verify() {
  if (getMemRefType().getRank() != getNumOperands() - 2)
    return emitOpError(
        "expects the number of subscripts to be equal to memref rank");
  switch (getKind()) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::maxf:
  case arith::AtomicRMWKind::minf:
  case arith::AtomicRMWKind::mulf:
    if (!getValue().getType().isa<FloatType>())
      return emitOpError() << "with kind '"
                           << arith::stringifyAtomicRMWKind(getKind())
                           << "' expects a floating-point type";
    break;
  case arith::AtomicRMWKind::addi:
  case arith::AtomicRMWKind::maxs:
  case arith::AtomicRMWKind::maxu:
  case arith::AtomicRMWKind::mins:
  case arith::AtomicRMWKind::minu:
  case arith::AtomicRMWKind::muli:
  case arith::AtomicRMWKind::ori:
  case arith::AtomicRMWKind::andi:
    if (!getValue().getType().isa<IntegerType>())
      return emitOpError() << "with kind '"
                           << arith::stringifyAtomicRMWKind(getKind())
                           << "' expects an integer type";
    break;
  default:
    break;
  }
  return success();
}

OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) {
  /// atomicrmw(memrefcast) -> atomicrmw
  if (succeeded(foldMemRefCast(*this, getValue())))
    return getResult();
  return OpFoldResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"