1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MemRef/IR/MemRef.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "llvm/ADT/STLExtras.h"
20 
21 using namespace mlir;
22 using namespace mlir::memref;
23 
24 /// Materialize a single constant operation from a given attribute value with
25 /// the desired resultant type.
26 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
27                                               Attribute value, Type type,
28                                               Location loc) {
29   return builder.create<mlir::ConstantOp>(loc, type, value);
30 }
31 
32 /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
33 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
34   return llvm::to_vector<4>(
35       llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
36         return a.cast<IntegerAttr>().getInt();
37       }));
38 }
39 
40 /// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
41 /// it is a Value or into `staticVec` if it is an IntegerAttr.
42 /// In the case of a Value, a copy of the `sentinel` value is also pushed to
43 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
44 /// come from an AttrSizedOperandSegments trait.
45 static void dispatchIndexOpFoldResult(OpFoldResult ofr,
46                                       SmallVectorImpl<Value> &dynamicVec,
47                                       SmallVectorImpl<int64_t> &staticVec,
48                                       int64_t sentinel) {
49   if (auto v = ofr.dyn_cast<Value>()) {
50     dynamicVec.push_back(v);
51     staticVec.push_back(sentinel);
52     return;
53   }
54   APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
55   staticVec.push_back(apInt.getSExtValue());
56 }
57 
58 static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
59                                        SmallVectorImpl<Value> &dynamicVec,
60                                        SmallVectorImpl<int64_t> &staticVec,
61                                        int64_t sentinel) {
62   for (auto ofr : ofrs)
63     dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // Common canonicalization pattern support logic
68 //===----------------------------------------------------------------------===//
69 
/// This is a common helper used for patterns of the form
/// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
/// into the root operation directly. Casts whose source is an unranked memref
/// are left in place.
73 static LogicalResult foldMemRefCast(Operation *op) {
74   bool folded = false;
75   for (OpOperand &operand : op->getOpOperands()) {
76     auto cast = operand.get().getDefiningOp<CastOp>();
77     if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
78       operand.set(cast.getOperand());
79       folded = true;
80     }
81   }
82   return success(folded);
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // Helpers for GlobalOp
87 //===----------------------------------------------------------------------===//
88 
89 static Type getTensorTypeFromMemRefType(Type type) {
90   if (auto memref = type.dyn_cast<MemRefType>())
91     return RankedTensorType::get(memref.getShape(), memref.getElementType());
92   if (auto memref = type.dyn_cast<UnrankedMemRefType>())
93     return UnrankedTensorType::get(memref.getElementType());
94   return NoneType::get(type.getContext());
95 }
96 
97 //===----------------------------------------------------------------------===//
98 // AllocOp / AllocaOp
99 //===----------------------------------------------------------------------===//
100 
101 template <typename AllocLikeOp>
102 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
103   static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
104                 "applies to only alloc or alloca");
105   auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
106   if (!memRefType)
107     return op.emitOpError("result must be a memref");
108 
109   if (static_cast<int64_t>(op.dynamicSizes().size()) !=
110       memRefType.getNumDynamicDims())
111     return op.emitOpError("dimension operand count does not equal memref "
112                           "dynamic dimension count");
113 
114   unsigned numSymbols = 0;
115   if (!memRefType.getAffineMaps().empty())
116     numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
117   if (op.symbolOperands().size() != numSymbols)
118     return op.emitOpError(
119         "symbol operand count does not equal memref symbol count");
120 
121   return success();
122 }
123 
124 static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }
125 
126 static LogicalResult verify(AllocaOp op) {
127   // An alloca op needs to have an ancestor with an allocation scope trait.
128   if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
129     return op.emitOpError(
130         "requires an ancestor op with AutomaticAllocationScope trait");
131 
132   return verifyAllocLikeOp(op);
133 }
134 
135 namespace {
136 /// Fold constant dimensions into an alloc like operation.
137 template <typename AllocLikeOp>
138 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
139   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
140 
141   LogicalResult matchAndRewrite(AllocLikeOp alloc,
142                                 PatternRewriter &rewriter) const override {
143     // Check to see if any dimensions operands are constants.  If so, we can
144     // substitute and drop them.
145     if (llvm::none_of(alloc.getOperands(), [](Value operand) {
146           return matchPattern(operand, matchConstantIndex());
147         }))
148       return failure();
149 
150     auto memrefType = alloc.getType();
151 
152     // Ok, we have one or more constant operands.  Collect the non-constant ones
153     // and keep track of the resultant memref type to build.
154     SmallVector<int64_t, 4> newShapeConstants;
155     newShapeConstants.reserve(memrefType.getRank());
156     SmallVector<Value, 4> newOperands;
157 
158     unsigned dynamicDimPos = 0;
159     for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
160       int64_t dimSize = memrefType.getDimSize(dim);
161       // If this is already static dimension, keep it.
162       if (dimSize != -1) {
163         newShapeConstants.push_back(dimSize);
164         continue;
165       }
166       auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
167       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
168         // Dynamic shape dimension will be folded.
169         newShapeConstants.push_back(constantIndexOp.getValue());
170       } else {
171         // Dynamic shape dimension not folded; copy operand from old memref.
172         newShapeConstants.push_back(-1);
173         newOperands.push_back(alloc.getOperand(dynamicDimPos));
174       }
175       dynamicDimPos++;
176     }
177 
178     // Create new memref type (which will have fewer dynamic dimensions).
179     MemRefType newMemRefType =
180         MemRefType::Builder(memrefType).setShape(newShapeConstants);
181     assert(static_cast<int64_t>(newOperands.size()) ==
182            newMemRefType.getNumDynamicDims());
183 
184     // Create and insert the alloc op for the new memref.
185     auto newAlloc = rewriter.create<AllocLikeOp>(alloc.getLoc(), newMemRefType,
186                                                  newOperands, IntegerAttr());
187     // Insert a cast so we have the same type as the old alloc.
188     auto resultCast =
189         rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
190 
191     rewriter.replaceOp(alloc, {resultCast});
192     return success();
193   }
194 };
195 
196 /// Fold alloc operations with no uses. Alloc has side effects on the heap,
197 /// but can still be deleted if it has zero uses.
198 struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
199   using OpRewritePattern<AllocOp>::OpRewritePattern;
200 
201   LogicalResult matchAndRewrite(AllocOp alloc,
202                                 PatternRewriter &rewriter) const override {
203     if (alloc.use_empty()) {
204       rewriter.eraseOp(alloc);
205       return success();
206     }
207     return failure();
208   }
209 };
210 } // end anonymous namespace.
211 
212 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
213                                           MLIRContext *context) {
214   results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc>(context);
215 }
216 
217 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
218                                            MLIRContext *context) {
219   results.add<SimplifyAllocConst<AllocaOp>>(context);
220 }
221 
222 //===----------------------------------------------------------------------===//
223 // AssumeAlignmentOp
224 //===----------------------------------------------------------------------===//
225 
226 static LogicalResult verify(AssumeAlignmentOp op) {
227   unsigned alignment = op.alignment();
228   if (!llvm::isPowerOf2_32(alignment))
229     return op.emitOpError("alignment must be power of 2");
230   return success();
231 }
232 
233 //===----------------------------------------------------------------------===//
234 // BufferCastOp
235 //===----------------------------------------------------------------------===//
236 
237 OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
238   if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
239     if (tensorLoad.memref().getType() == getType())
240       return tensorLoad.memref();
241   return {};
242 }
243 
244 namespace {
245 /// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast.
246 struct BufferCast : public OpRewritePattern<BufferCastOp> {
247   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
248 
249   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
250                                 PatternRewriter &rewriter) const final {
251     auto tensorCastOperand =
252         bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
253     if (!tensorCastOperand)
254       return failure();
255     auto srcTensorType =
256         tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
257     if (!srcTensorType)
258       return failure();
259     auto memrefType = MemRefType::get(srcTensorType.getShape(),
260                                       srcTensorType.getElementType());
261     Value memref = rewriter.create<BufferCastOp>(
262         bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
263     rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
264                                         memref);
265     return success();
266   }
267 };
268 
269 /// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
270 /// type mismatches prevent `BufferCastOp::fold` to kick in.
271 struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
272   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
273 
274   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
275                                 PatternRewriter &rewriter) const final {
276     auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
277     // Bail unless we have a tensor_load + memref.buffer_cast with different
278     // types. `BufferCastOp::fold` handles the same type case.
279     if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
280       return failure();
281     // If types are not cast-compatible, bail.
282     if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
283                                    bufferCast.getType()))
284       return failure();
285     rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
286                                         tensorLoad.memref());
287     return success();
288   }
289 };
290 
291 } // namespace
292 
293 void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
294                                                MLIRContext *context) {
295   results.add<BufferCast, TensorLoadToMemRef>(context);
296 }
297 
298 //===----------------------------------------------------------------------===//
299 // CastOp
300 //===----------------------------------------------------------------------===//
301 
302 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
304 /// and implement canonicalization patterns for ops in different dialects that
305 /// may consume the results of memref.cast operations. Such foldable memref.cast
306 /// operations are typically inserted as `view` and `subview` ops are
307 /// canonicalized, to preserve the type compatibility of their uses.
308 ///
309 /// Returns true when all conditions are met:
310 /// 1. source and result are ranked memrefs with strided semantics and same
311 /// element type and rank.
312 /// 2. each of the source's size, offset or stride has more static information
313 /// than the corresponding result's size, offset or stride.
314 ///
315 /// Example 1:
316 /// ```mlir
317 ///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
318 ///   %2 = consumer %1 ... : memref<?x?xf32> ...
319 /// ```
320 ///
321 /// may fold into:
322 ///
323 /// ```mlir
324 ///   %2 = consumer %0 ... : memref<8x16xf32> ...
325 /// ```
326 ///
327 /// Example 2:
328 /// ```
329 ///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
330 ///          to memref<?x?xf32>
331 ///   consumer %1 : memref<?x?xf32> ...
332 /// ```
333 ///
334 /// may fold into:
335 ///
336 /// ```
337 ///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
338 /// ```
339 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
340   MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
341   MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
342 
343   // Requires ranked MemRefType.
344   if (!sourceType || !resultType)
345     return false;
346 
347   // Requires same elemental type.
348   if (sourceType.getElementType() != resultType.getElementType())
349     return false;
350 
351   // Requires same rank.
352   if (sourceType.getRank() != resultType.getRank())
353     return false;
354 
355   // Only fold casts between strided memref forms.
356   int64_t sourceOffset, resultOffset;
357   SmallVector<int64_t, 4> sourceStrides, resultStrides;
358   if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
359       failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
360     return false;
361 
362   // If cast is towards more static sizes along any dimension, don't fold.
363   for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
364     auto ss = std::get<0>(it), st = std::get<1>(it);
365     if (ss != st)
366       if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
367         return false;
368   }
369 
370   // If cast is towards more static offset along any dimension, don't fold.
371   if (sourceOffset != resultOffset)
372     if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
373         !MemRefType::isDynamicStrideOrOffset(resultOffset))
374       return false;
375 
376   // If cast is towards more static strides along any dimension, don't fold.
377   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
378     auto ss = std::get<0>(it), st = std::get<1>(it);
379     if (ss != st)
380       if (MemRefType::isDynamicStrideOrOffset(ss) &&
381           !MemRefType::isDynamicStrideOrOffset(st))
382         return false;
383   }
384 
385   return true;
386 }
387 
388 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
389   if (inputs.size() != 1 || outputs.size() != 1)
390     return false;
391   Type a = inputs.front(), b = outputs.front();
392   auto aT = a.dyn_cast<MemRefType>();
393   auto bT = b.dyn_cast<MemRefType>();
394 
395   auto uaT = a.dyn_cast<UnrankedMemRefType>();
396   auto ubT = b.dyn_cast<UnrankedMemRefType>();
397 
398   if (aT && bT) {
399     if (aT.getElementType() != bT.getElementType())
400       return false;
401     if (aT.getAffineMaps() != bT.getAffineMaps()) {
402       int64_t aOffset, bOffset;
403       SmallVector<int64_t, 4> aStrides, bStrides;
404       if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
405           failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
406           aStrides.size() != bStrides.size())
407         return false;
408 
409       // Strides along a dimension/offset are compatible if the value in the
410       // source memref is static and the value in the target memref is the
411       // same. They are also compatible if either one is dynamic (see
412       // description of MemRefCastOp for details).
413       auto checkCompatible = [](int64_t a, int64_t b) {
414         return (a == MemRefType::getDynamicStrideOrOffset() ||
415                 b == MemRefType::getDynamicStrideOrOffset() || a == b);
416       };
417       if (!checkCompatible(aOffset, bOffset))
418         return false;
419       for (auto aStride : enumerate(aStrides))
420         if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
421           return false;
422     }
423     if (aT.getMemorySpace() != bT.getMemorySpace())
424       return false;
425 
426     // They must have the same rank, and any specified dimensions must match.
427     if (aT.getRank() != bT.getRank())
428       return false;
429 
430     for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
431       int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
432       if (aDim != -1 && bDim != -1 && aDim != bDim)
433         return false;
434     }
435     return true;
436   } else {
437     if (!aT && !uaT)
438       return false;
439     if (!bT && !ubT)
440       return false;
441     // Unranked to unranked casting is unsupported
442     if (uaT && ubT)
443       return false;
444 
445     auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
446     auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
447     if (aEltType != bEltType)
448       return false;
449 
450     auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
451     auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
452     if (aMemSpace != bMemSpace)
453       return false;
454 
455     return true;
456   }
457 
458   return false;
459 }
460 
461 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
462   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
463 }
464 
465 //===----------------------------------------------------------------------===//
466 // DeallocOp
467 //===----------------------------------------------------------------------===//
468 namespace {
469 /// Fold Dealloc operations that are deallocating an AllocOp that is only used
470 /// by other Dealloc operations.
471 struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
472   using OpRewritePattern<DeallocOp>::OpRewritePattern;
473 
474   LogicalResult matchAndRewrite(DeallocOp dealloc,
475                                 PatternRewriter &rewriter) const override {
476     // Check that the memref operand's defining operation is an AllocOp.
477     Value memref = dealloc.memref();
478     if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp()))
479       return failure();
480 
481     // Check that all of the uses of the AllocOp are other DeallocOps.
482     for (auto *user : memref.getUsers())
483       if (!isa<DeallocOp>(user))
484         return failure();
485 
486     // Erase the dealloc operation.
487     rewriter.eraseOp(dealloc);
488     return success();
489   }
490 };
491 } // end anonymous namespace.
492 
493 static LogicalResult verify(DeallocOp op) {
494   if (!op.memref().getType().isa<MemRefType>())
495     return op.emitOpError("operand must be a memref");
496   return success();
497 }
498 
499 void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
500                                             MLIRContext *context) {
501   results.add<SimplifyDeadDealloc>(context);
502 }
503 
504 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
505                               SmallVectorImpl<OpFoldResult> &results) {
506   /// dealloc(memrefcast) -> dealloc
507   return foldMemRefCast(*this);
508 }
509 
510 //===----------------------------------------------------------------------===//
511 // DimOp
512 //===----------------------------------------------------------------------===//
513 
514 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
515                   int64_t index) {
516   auto loc = result.location;
517   Value indexValue = builder.create<ConstantIndexOp>(loc, index);
518   build(builder, result, memref, indexValue);
519 }
520 
521 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
522                   Value index) {
523   auto indexTy = builder.getIndexType();
524   build(builder, result, indexTy, memref, index);
525 }
526 
527 Optional<int64_t> DimOp::getConstantIndex() {
528   if (auto constantOp = index().getDefiningOp<ConstantOp>())
529     return constantOp.getValue().cast<IntegerAttr>().getInt();
530   return {};
531 }
532 
533 static LogicalResult verify(DimOp op) {
534   // Assume unknown index to be in range.
535   Optional<int64_t> index = op.getConstantIndex();
536   if (!index.hasValue())
537     return success();
538 
539   // Check that constant index is not knowingly out of range.
540   auto type = op.memrefOrTensor().getType();
541   if (auto memrefType = type.dyn_cast<MemRefType>()) {
542     if (index.getValue() >= memrefType.getRank())
543       return op.emitOpError("index is out of range");
544   } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
545     if (index.getValue() >= tensorType.getRank())
546       return op.emitOpError("index is out of range");
547   } else if (type.isa<UnrankedMemRefType>() || type.isa<UnrankedTensorType>()) {
548     // Assume index to be in range.
549   } else {
550     llvm_unreachable("expected operand with memref type");
551   }
552   return success();
553 }
554 
555 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
556   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
557 
558   // All forms of folding require a known index.
559   if (!index)
560     return {};
561 
562   auto argTy = memrefOrTensor().getType();
563   // Fold if the shape extent along the given index is known.
564   if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
565     // Folding for unranked types (UnrankedMemRefType) is not supported.
566     if (!shapedTy.hasRank())
567       return {};
568     if (!shapedTy.isDynamicDim(index.getInt())) {
569       Builder builder(getContext());
570       return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]);
571     }
572   }
573 
574   Operation *definingOp = memrefOrTensor().getDefiningOp();
575 
576   // dim(memref.tensor_load(memref)) -> dim(memref)
577   if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) {
578     setOperand(0, tensorLoadOp.memref());
579     return getResult();
580   }
581 
582   // Fold dim to the operand of tensor.generate.
583   if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
584     auto resultType =
585         fromElements.getResult().getType().cast<RankedTensorType>();
586     // The case where the type encodes the size of the dimension is handled
587     // above.
588     assert(resultType.getShape()[index.getInt()] ==
589            RankedTensorType::kDynamicSize);
590 
591     // Find the operand of the fromElements that corresponds to this index.
592     auto dynExtents = fromElements.dynamicExtents().begin();
593     for (auto dim : resultType.getShape().take_front(index.getInt()))
594       if (dim == RankedTensorType::kDynamicSize)
595         dynExtents++;
596 
597     return Value{*dynExtents};
598   }
599 
600   // The size at the given index is now known to be a dynamic size.
601   unsigned unsignedIndex = index.getValue().getZExtValue();
602 
603   if (auto subtensor = dyn_cast_or_null<mlir::SubTensorOp>(definingOp)) {
604     assert(subtensor.isDynamicSize(unsignedIndex) &&
605            "Expected dynamic subtensor size");
606     return subtensor.getDynamicSize(unsignedIndex);
607   }
608 
609   // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
610   auto memrefType = argTy.dyn_cast<MemRefType>();
611   if (!memrefType)
612     return {};
613 
614   if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
615     return *(alloc.getDynamicSizes().begin() +
616              memrefType.getDynamicDimIndex(unsignedIndex));
617 
618   if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
619     return *(alloca.getDynamicSizes().begin() +
620              memrefType.getDynamicDimIndex(unsignedIndex));
621 
622   if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
623     return *(view.getDynamicSizes().begin() +
624              memrefType.getDynamicDimIndex(unsignedIndex));
625 
626   if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
627     assert(subview.isDynamicSize(unsignedIndex) &&
628            "Expected dynamic subview size");
629     return subview.getDynamicSize(unsignedIndex);
630   }
631 
632   // dim(memrefcast) -> dim
633   if (succeeded(foldMemRefCast(*this)))
634     return getResult();
635 
636   return {};
637 }
638 
639 namespace {
640 /// Fold dim of a memref reshape operation to a load into the reshape's shape
641 /// operand.
642 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
643   using OpRewritePattern<DimOp>::OpRewritePattern;
644 
645   LogicalResult matchAndRewrite(DimOp dim,
646                                 PatternRewriter &rewriter) const override {
647     auto reshape = dim.memrefOrTensor().getDefiningOp<ReshapeOp>();
648 
649     if (!reshape)
650       return failure();
651 
652     // Place the load directly after the reshape to ensure that the shape memref
653     // was not mutated.
654     rewriter.setInsertionPointAfter(reshape);
655     rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
656                                         llvm::makeArrayRef({dim.index()}));
657     return success();
658   }
659 };
660 
661 /// Fold dim of a dim of a cast into the dim of the source of the tensor cast.
662 template <typename CastOpTy>
663 struct DimOfCastOp : public OpRewritePattern<DimOp> {
664   using OpRewritePattern<DimOp>::OpRewritePattern;
665 
666   LogicalResult matchAndRewrite(DimOp dimOp,
667                                 PatternRewriter &rewriter) const override {
668     auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>();
669     if (!castOp)
670       return failure();
671     Value newSource = castOp.getOperand();
672     rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
673     return success();
674   }
675 };
676 } // end anonymous namespace.
677 
678 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
679                                         MLIRContext *context) {
680   results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
681               DimOfCastOp<tensor::CastOp>>(context);
682 }
683 
684 // ---------------------------------------------------------------------------
685 // DmaStartOp
686 // ---------------------------------------------------------------------------
687 
688 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
689                        Value srcMemRef, ValueRange srcIndices, Value destMemRef,
690                        ValueRange destIndices, Value numElements,
691                        Value tagMemRef, ValueRange tagIndices, Value stride,
692                        Value elementsPerStride) {
693   result.addOperands(srcMemRef);
694   result.addOperands(srcIndices);
695   result.addOperands(destMemRef);
696   result.addOperands(destIndices);
697   result.addOperands({numElements, tagMemRef});
698   result.addOperands(tagIndices);
699   if (stride)
700     result.addOperands({stride, elementsPerStride});
701 }
702 
703 void DmaStartOp::print(OpAsmPrinter &p) {
704   p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices()
705     << "], " << getDstMemRef() << '[' << getDstIndices() << "], "
706     << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices()
707     << ']';
708   if (isStrided())
709     p << ", " << getStride() << ", " << getNumElementsPerStride();
710 
711   p.printOptionalAttrDict((*this)->getAttrs());
712   p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
713     << ", " << getTagMemRef().getType();
714 }
715 
// Parse DmaStartOp. The op has no results; the trailing stride pair is
// optional.
// Ex:
//   memref.dma_start %src[%i, %j], %dst[%k, %l], %size,
//                    %tag[%index], %stride, %num_elt_per_stride
//                  : memref<3076 x f32, 0>,
//                    memref<1024 x f32, 2>,
//                    memref<1 x i32>
//
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  // Unresolved operands, in the order they appear in the textual form.
  OpAsmParser::OperandType srcMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
  OpAsmParser::OperandType dstMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
  OpAsmParser::OperandType numElementsInfo;
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
  // Optional trailing (stride, elements-per-stride) pair; empty if absent.
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  // Must end up holding exactly three types: source, destination and tag
  // memref, in that order.
  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) number of elements to transfer.
  // *) tag memref followed by its indices (in square brackets).
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  // The stride operands come as a pair or not at all.
  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  // Resolve in the same order DmaStartOp::build adds operands; all indices,
  // the element count and the stride pair are index-typed.
  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  // No result types are added: the op produces no results.
  return success();
}
783 
/// Verifies the operand list of a DmaStartOp, laid out as:
///   src memref, src indices..., dst memref, dst indices..., num elements,
///   tag memref, tag indices..., [stride, num elements per stride]
/// The number of indices depends on each memref's rank, so the checks run
/// incrementally: a rank-derived position is only used after the operand count
/// has been shown to be large enough to contain it.
LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
  // the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!getSrcMemRef().getType().isa<MemRefType>())
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!getDstMemRef().getType().isa<MemRefType>())
    return emitOpError("expected destination to be of memref type");
  // Running minimum operand count: both index lists plus the four mandatory
  // operands.
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Only DMAs between different memory spaces are supported.
  if (getSrcMemorySpace() == getDstMemorySpace())
    return emitOpError("DMA should be between different memory spaces");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}
853 
854 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
855                                SmallVectorImpl<OpFoldResult> &results) {
856   /// dma_start(memrefcast) -> dma_start
857   return foldMemRefCast(*this);
858 }
859 
860 // ---------------------------------------------------------------------------
861 // DmaWaitOp
862 // ---------------------------------------------------------------------------
863 
864 void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
865                       Value tagMemRef, ValueRange tagIndices,
866                       Value numElements) {
867   result.addOperands(tagMemRef);
868   result.addOperands(tagIndices);
869   result.addOperands(numElements);
870 }
871 
872 void DmaWaitOp::print(OpAsmPrinter &p) {
873   p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices()
874     << "], " << getNumElements();
875   p.printOptionalAttrDict((*this)->getAttrs());
876   p << " : " << getTagMemRef().getType();
877 }
878 
879 // Parse DmaWaitOp.
880 // Eg:
881 //   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
882 //
883 ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
884   OpAsmParser::OperandType tagMemrefInfo;
885   SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
886   Type type;
887   auto indexType = parser.getBuilder().getIndexType();
888   OpAsmParser::OperandType numElementsInfo;
889 
890   // Parse tag memref, its indices, and dma size.
891   if (parser.parseOperand(tagMemrefInfo) ||
892       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
893       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
894       parser.parseColonType(type) ||
895       parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
896       parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
897       parser.resolveOperand(numElementsInfo, indexType, result.operands))
898     return failure();
899 
900   return success();
901 }
902 
903 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
904                               SmallVectorImpl<OpFoldResult> &results) {
905   /// dma_wait(memrefcast) -> dma_wait
906   return foldMemRefCast(*this);
907 }
908 
909 LogicalResult DmaWaitOp::verify() {
910   // Mandatory non-variadic operands are tag and the number of elements.
911   if (getNumOperands() < 2)
912     return emitOpError() << "expected at least 2 operands";
913 
914   // Check types of operands. The order of these calls is important: the later
915   // calls rely on some type properties to compute the operand position.
916   if (!getTagMemRef().getType().isa<MemRefType>())
917     return emitOpError() << "expected tag to be of memref type";
918 
919   if (getNumOperands() != 2 + getTagMemRefRank())
920     return emitOpError() << "expected " << 2 + getTagMemRefRank()
921                          << " operands";
922 
923   if (!getTagIndices().empty() &&
924       !llvm::all_of(getTagIndices().getTypes(),
925                     [](Type t) { return t.isIndex(); }))
926     return emitOpError() << "expected tag indices to be of index type";
927 
928   if (!getNumElements().getType().isIndex())
929     return emitOpError()
930            << "expected the number of elements to be of index type";
931 
932   return success();
933 }
934 
935 //===----------------------------------------------------------------------===//
936 // GlobalOp
937 //===----------------------------------------------------------------------===//
938 
939 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
940                                                    TypeAttr type,
941                                                    Attribute initialValue) {
942   p << type;
943   if (!op.isExternal()) {
944     p << " = ";
945     if (op.isUninitialized())
946       p << "uninitialized";
947     else
948       p.printAttributeWithoutType(initialValue);
949   }
950 }
951 
952 static ParseResult
953 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
954                                        Attribute &initialValue) {
955   Type type;
956   if (parser.parseType(type))
957     return failure();
958 
959   auto memrefType = type.dyn_cast<MemRefType>();
960   if (!memrefType || !memrefType.hasStaticShape())
961     return parser.emitError(parser.getNameLoc())
962            << "type should be static shaped memref, but got " << type;
963   typeAttr = TypeAttr::get(type);
964 
965   if (parser.parseOptionalEqual())
966     return success();
967 
968   if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
969     initialValue = UnitAttr::get(parser.getBuilder().getContext());
970     return success();
971   }
972 
973   Type tensorType = getTensorTypeFromMemRefType(memrefType);
974   if (parser.parseAttribute(initialValue, tensorType))
975     return failure();
976   if (!initialValue.isa<ElementsAttr>())
977     return parser.emitError(parser.getNameLoc())
978            << "initial value should be a unit or elements attribute";
979   return success();
980 }
981 
982 static LogicalResult verify(GlobalOp op) {
983   auto memrefType = op.type().dyn_cast<MemRefType>();
984   if (!memrefType || !memrefType.hasStaticShape())
985     return op.emitOpError("type should be static shaped memref, but got ")
986            << op.type();
987 
988   // Verify that the initial value, if present, is either a unit attribute or
989   // an elements attribute.
990   if (op.initial_value().hasValue()) {
991     Attribute initValue = op.initial_value().getValue();
992     if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
993       return op.emitOpError("initial value should be a unit or elements "
994                             "attribute, but got ")
995              << initValue;
996 
997     // Check that the type of the initial value is compatible with the type of
998     // the global variable.
999     if (initValue.isa<ElementsAttr>()) {
1000       Type initType = initValue.getType();
1001       Type tensorType = getTensorTypeFromMemRefType(memrefType);
1002       if (initType != tensorType)
1003         return op.emitOpError("initial value expected to be of type ")
1004                << tensorType << ", but was of type " << initType;
1005     }
1006   }
1007 
1008   // TODO: verify visibility for declarations.
1009   return success();
1010 }
1011 
1012 //===----------------------------------------------------------------------===//
1013 // GetGlobalOp
1014 //===----------------------------------------------------------------------===//
1015 
1016 LogicalResult
1017 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1018   // Verify that the result type is same as the type of the referenced
1019   // memref.global op.
1020   auto global =
1021       symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
1022   if (!global)
1023     return emitOpError("'")
1024            << name() << "' does not reference a valid global memref";
1025 
1026   Type resultType = result().getType();
1027   if (global.type() != resultType)
1028     return emitOpError("result type ")
1029            << resultType << " does not match type " << global.type()
1030            << " of the global memref @" << name();
1031   return success();
1032 }
1033 
1034 //===----------------------------------------------------------------------===//
1035 // LoadOp
1036 //===----------------------------------------------------------------------===//
1037 
1038 static LogicalResult verify(LoadOp op) {
1039   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1040     return op.emitOpError("incorrect number of indices for load");
1041   return success();
1042 }
1043 
1044 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1045   /// load(memrefcast) -> load
1046   if (succeeded(foldMemRefCast(*this)))
1047     return getResult();
1048   return OpFoldResult();
1049 }
1050 
1051 namespace {
1052 /// Fold a load on a buffer_cast operation into an tensor.extract on the
1053 /// corresponding tensor.
1054 struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
1055   using OpRewritePattern<LoadOp>::OpRewritePattern;
1056 
1057   LogicalResult matchAndRewrite(LoadOp load,
1058                                 PatternRewriter &rewriter) const override {
1059     auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
1060     if (!buffercast)
1061       return failure();
1062 
1063     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
1064                                                    load.indices());
1065     return success();
1066   }
1067 };
1068 } // end anonymous namespace.
1069 
1070 void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
1071                                          MLIRContext *context) {
1072   results.add<LoadOfBufferCast>(context);
1073 }
1074 
1075 //===----------------------------------------------------------------------===//
1076 // PrefetchOp
1077 //===----------------------------------------------------------------------===//
1078 
1079 static void print(OpAsmPrinter &p, PrefetchOp op) {
1080   p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
1081   p.printOperands(op.indices());
1082   p << ']' << ", " << (op.isWrite() ? "write" : "read");
1083   p << ", locality<" << op.localityHint();
1084   p << ">, " << (op.isDataCache() ? "data" : "instr");
1085   p.printOptionalAttrDict(
1086       op->getAttrs(),
1087       /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1088   p << " : " << op.getMemRefType();
1089 }
1090 
1091 static ParseResult parsePrefetchOp(OpAsmParser &parser,
1092                                    OperationState &result) {
1093   OpAsmParser::OperandType memrefInfo;
1094   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
1095   IntegerAttr localityHint;
1096   MemRefType type;
1097   StringRef readOrWrite, cacheType;
1098 
1099   auto indexTy = parser.getBuilder().getIndexType();
1100   auto i32Type = parser.getBuilder().getIntegerType(32);
1101   if (parser.parseOperand(memrefInfo) ||
1102       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1103       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1104       parser.parseComma() || parser.parseKeyword("locality") ||
1105       parser.parseLess() ||
1106       parser.parseAttribute(localityHint, i32Type, "localityHint",
1107                             result.attributes) ||
1108       parser.parseGreater() || parser.parseComma() ||
1109       parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1110       parser.resolveOperand(memrefInfo, type, result.operands) ||
1111       parser.resolveOperands(indexInfo, indexTy, result.operands))
1112     return failure();
1113 
1114   if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1115     return parser.emitError(parser.getNameLoc(),
1116                             "rw specifier has to be 'read' or 'write'");
1117   result.addAttribute(
1118       PrefetchOp::getIsWriteAttrName(),
1119       parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1120 
1121   if (!cacheType.equals("data") && !cacheType.equals("instr"))
1122     return parser.emitError(parser.getNameLoc(),
1123                             "cache type has to be 'data' or 'instr'");
1124 
1125   result.addAttribute(
1126       PrefetchOp::getIsDataCacheAttrName(),
1127       parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1128 
1129   return success();
1130 }
1131 
1132 static LogicalResult verify(PrefetchOp op) {
1133   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1134     return op.emitOpError("too few indices");
1135 
1136   return success();
1137 }
1138 
1139 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
1140                                SmallVectorImpl<OpFoldResult> &results) {
1141   // prefetch(memrefcast) -> prefetch
1142   return foldMemRefCast(*this);
1143 }
1144 
1145 //===----------------------------------------------------------------------===//
1146 // ReinterpretCastOp
1147 //===----------------------------------------------------------------------===//
1148 
1149 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1150 /// `staticSizes` and `staticStrides` are automatically filled with
1151 /// source-memref-rank sentinel values that encode dynamic entries.
1152 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1153                               MemRefType resultType, Value source,
1154                               OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1155                               ArrayRef<OpFoldResult> strides,
1156                               ArrayRef<NamedAttribute> attrs) {
1157   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1158   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1159   dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
1160                              ShapedType::kDynamicStrideOrOffset);
1161   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1162                              ShapedType::kDynamicSize);
1163   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1164                              ShapedType::kDynamicStrideOrOffset);
1165   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1166         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1167         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1168   result.addAttributes(attrs);
1169 }
1170 
1171 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1172                               MemRefType resultType, Value source,
1173                               int64_t offset, ArrayRef<int64_t> sizes,
1174                               ArrayRef<int64_t> strides,
1175                               ArrayRef<NamedAttribute> attrs) {
1176   SmallVector<OpFoldResult> sizeValues =
1177       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1178         return b.getI64IntegerAttr(v);
1179       }));
1180   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1181       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1182         return b.getI64IntegerAttr(v);
1183       }));
1184   build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1185         strideValues, attrs);
1186 }
1187 
1188 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1189                               MemRefType resultType, Value source, Value offset,
1190                               ValueRange sizes, ValueRange strides,
1191                               ArrayRef<NamedAttribute> attrs) {
1192   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1193       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1194   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1195       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1196   build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1197 }
1198 
1199 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1200 // completed automatically, like we have for subview and subtensor.
1201 static LogicalResult verify(ReinterpretCastOp op) {
1202   // The source and result memrefs should be in the same memory space.
1203   auto srcType = op.source().getType().cast<BaseMemRefType>();
1204   auto resultType = op.getType().cast<MemRefType>();
1205   if (srcType.getMemorySpace() != resultType.getMemorySpace())
1206     return op.emitError("different memory spaces specified for source type ")
1207            << srcType << " and result memref type " << resultType;
1208   if (srcType.getElementType() != resultType.getElementType())
1209     return op.emitError("different element types specified for source type ")
1210            << srcType << " and result memref type " << resultType;
1211 
1212   // Match sizes in result memref type and in static_sizes attribute.
1213   for (auto &en :
1214        llvm::enumerate(llvm::zip(resultType.getShape(),
1215                                  extractFromI64ArrayAttr(op.static_sizes())))) {
1216     int64_t resultSize = std::get<0>(en.value());
1217     int64_t expectedSize = std::get<1>(en.value());
1218     if (resultSize != expectedSize)
1219       return op.emitError("expected result type with size = ")
1220              << expectedSize << " instead of " << resultSize
1221              << " in dim = " << en.index();
1222   }
1223 
1224   // Match offset and strides in static_offset and static_strides attributes if
1225   // result memref type has an affine map specified.
1226   if (!resultType.getAffineMaps().empty()) {
1227     int64_t resultOffset;
1228     SmallVector<int64_t, 4> resultStrides;
1229     if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1230       return failure();
1231 
1232     // Match offset in result memref type and in static_offsets attribute.
1233     int64_t expectedOffset =
1234         extractFromI64ArrayAttr(op.static_offsets()).front();
1235     if (resultOffset != expectedOffset)
1236       return op.emitError("expected result type with offset = ")
1237              << resultOffset << " instead of " << expectedOffset;
1238 
1239     // Match strides in result memref type and in static_strides attribute.
1240     for (auto &en : llvm::enumerate(llvm::zip(
1241              resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
1242       int64_t resultStride = std::get<0>(en.value());
1243       int64_t expectedStride = std::get<1>(en.value());
1244       if (resultStride != expectedStride)
1245         return op.emitError("expected result type with stride = ")
1246                << expectedStride << " instead of " << resultStride
1247                << " in dim = " << en.index();
1248     }
1249   }
1250   return success();
1251 }
1252 
1253 //===----------------------------------------------------------------------===//
1254 // ReshapeOp
1255 //===----------------------------------------------------------------------===//
1256 
1257 static LogicalResult verify(ReshapeOp op) {
1258   Type operandType = op.source().getType();
1259   Type resultType = op.result().getType();
1260 
1261   Type operandElementType = operandType.cast<ShapedType>().getElementType();
1262   Type resultElementType = resultType.cast<ShapedType>().getElementType();
1263   if (operandElementType != resultElementType)
1264     return op.emitOpError("element types of source and destination memref "
1265                           "types should be the same");
1266 
1267   if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
1268     if (!operandMemRefType.getAffineMaps().empty())
1269       return op.emitOpError(
1270           "source memref type should have identity affine map");
1271 
1272   int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
1273   auto resultMemRefType = resultType.dyn_cast<MemRefType>();
1274   if (resultMemRefType) {
1275     if (!resultMemRefType.getAffineMaps().empty())
1276       return op.emitOpError(
1277           "result memref type should have identity affine map");
1278     if (shapeSize == ShapedType::kDynamicSize)
1279       return op.emitOpError("cannot use shape operand with dynamic length to "
1280                             "reshape to statically-ranked memref type");
1281     if (shapeSize != resultMemRefType.getRank())
1282       return op.emitOpError(
1283           "length of shape operand differs from the result's memref rank");
1284   }
1285   return success();
1286 }
1287 
1288 //===----------------------------------------------------------------------===//
1289 // StoreOp
1290 //===----------------------------------------------------------------------===//
1291 
1292 static LogicalResult verify(StoreOp op) {
1293   if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
1294     return op.emitOpError("store index operand count not equal to memref rank");
1295 
1296   return success();
1297 }
1298 
1299 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
1300                             SmallVectorImpl<OpFoldResult> &results) {
1301   /// store(memrefcast) -> store
1302   return foldMemRefCast(*this);
1303 }
1304 
1305 //===----------------------------------------------------------------------===//
1306 // SubViewOp
1307 //===----------------------------------------------------------------------===//
1308 
1309 namespace {
1310 /// Helpers to write more idiomatic operations.
1311 namespace saturated_arith {
1312 struct Wrapper {
1313   explicit Wrapper(int64_t v) : v(v) {}
1314   operator int64_t() { return v; }
1315   int64_t v;
1316 };
1317 Wrapper operator+(Wrapper a, int64_t b) {
1318   if (ShapedType::isDynamicStrideOrOffset(a) ||
1319       ShapedType::isDynamicStrideOrOffset(b))
1320     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1321   return Wrapper(a.v + b);
1322 }
1323 Wrapper operator*(Wrapper a, int64_t b) {
1324   if (ShapedType::isDynamicStrideOrOffset(a) ||
1325       ShapedType::isDynamicStrideOrOffset(b))
1326     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1327   return Wrapper(a.v * b);
1328 }
1329 } // end namespace saturated_arith
1330 } // end namespace
1331 
1332 /// A subview result type can be fully inferred from the source type and the
1333 /// static representation of offsets, sizes and strides. Special sentinels
1334 /// encode the dynamic case.
1335 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1336                                 ArrayRef<int64_t> leadingStaticOffsets,
1337                                 ArrayRef<int64_t> leadingStaticSizes,
1338                                 ArrayRef<int64_t> leadingStaticStrides) {
1339   // A subview may specify only a leading subset of offset/sizes/strides in
1340   // which case we complete with offset=0, sizes from memref type and strides=1.
1341   unsigned rank = sourceMemRefType.getRank();
1342   assert(leadingStaticOffsets.size() <= rank &&
1343          "unexpected leadingStaticOffsets overflow");
1344   assert(leadingStaticSizes.size() <= rank &&
1345          "unexpected leadingStaticSizes overflow");
1346   assert(leadingStaticStrides.size() <= rank &&
1347          "unexpected leadingStaticStrides overflow");
1348   auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
1349   auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
1350   auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
1351   unsigned numTrailingOffsets = rank - staticOffsets.size();
1352   unsigned numTrailingSizes = rank - staticSizes.size();
1353   unsigned numTrailingStrides = rank - staticStrides.size();
1354   staticOffsets.append(numTrailingOffsets, 0);
1355   llvm::append_range(staticSizes,
1356                      sourceMemRefType.getShape().take_back(numTrailingSizes));
1357   staticStrides.append(numTrailingStrides, 1);
1358 
1359   // Extract source offset and strides.
1360   int64_t sourceOffset;
1361   SmallVector<int64_t, 4> sourceStrides;
1362   auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
1363   assert(succeeded(res) && "SubViewOp expected strided memref type");
1364   (void)res;
1365 
1366   // Compute target offset whose value is:
1367   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
1368   int64_t targetOffset = sourceOffset;
1369   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
1370     auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
1371     using namespace saturated_arith;
1372     targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride;
1373   }
1374 
1375   // Compute target stride whose value is:
1376   //   `sourceStrides_i * staticStrides_i`.
1377   SmallVector<int64_t, 4> targetStrides;
1378   targetStrides.reserve(staticOffsets.size());
1379   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
1380     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
1381     using namespace saturated_arith;
1382     targetStrides.push_back(Wrapper(sourceStride) * staticStride);
1383   }
1384 
1385   // The type is now known.
1386   return MemRefType::get(
1387       staticSizes, sourceMemRefType.getElementType(),
1388       makeStridedLinearLayoutMap(targetStrides, targetOffset,
1389                                  sourceMemRefType.getContext()),
1390       sourceMemRefType.getMemorySpace());
1391 }
1392 
1393 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1394                                 ArrayRef<OpFoldResult> leadingStaticOffsets,
1395                                 ArrayRef<OpFoldResult> leadingStaticSizes,
1396                                 ArrayRef<OpFoldResult> leadingStaticStrides) {
1397   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1398   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1399   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1400                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1401   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1402                              ShapedType::kDynamicSize);
1403   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1404                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1405   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1406                                     staticSizes, staticStrides)
1407       .cast<MemRefType>();
1408 }
1409 
1410 Type SubViewOp::inferRankReducedResultType(
1411     unsigned resultRank, MemRefType sourceRankedTensorType,
1412     ArrayRef<int64_t> leadingStaticOffsets,
1413     ArrayRef<int64_t> leadingStaticSizes,
1414     ArrayRef<int64_t> leadingStaticStrides) {
1415   auto inferredType =
1416       inferResultType(sourceRankedTensorType, leadingStaticOffsets,
1417                       leadingStaticSizes, leadingStaticStrides)
1418           .cast<MemRefType>();
1419   assert(inferredType.getRank() >= resultRank && "expected ");
1420   int rankDiff = inferredType.getRank() - resultRank;
1421   if (rankDiff > 0) {
1422     auto shape = inferredType.getShape();
1423     llvm::SmallDenseSet<unsigned> dimsToProject;
1424     mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
1425     SmallVector<int64_t> projectedShape;
1426     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1427       if (!dimsToProject.contains(pos))
1428         projectedShape.push_back(shape[pos]);
1429 
1430     AffineMap map;
1431     auto maps = inferredType.getAffineMaps();
1432     if (!maps.empty() && maps.front())
1433       map = getProjectedMap(maps.front(), dimsToProject);
1434     inferredType =
1435         MemRefType::get(projectedShape, inferredType.getElementType(), map,
1436                         inferredType.getMemorySpace());
1437   }
1438   return inferredType;
1439 }
1440 
1441 Type SubViewOp::inferRankReducedResultType(
1442     unsigned resultRank, MemRefType sourceRankedTensorType,
1443     ArrayRef<OpFoldResult> leadingStaticOffsets,
1444     ArrayRef<OpFoldResult> leadingStaticSizes,
1445     ArrayRef<OpFoldResult> leadingStaticStrides) {
1446   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1447   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1448   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1449                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1450   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1451                              ShapedType::kDynamicSize);
1452   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1453                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1454   return SubViewOp::inferRankReducedResultType(
1455       resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1456       staticStrides);
1457 }
1458 // Build a SubViewOp with mixed static and dynamic entries and custom result
1459 // type. If the type passed is nullptr, it is inferred.
1460 void SubViewOp::build(OpBuilder &b, OperationState &result,
1461                       MemRefType resultType, Value source,
1462                       ArrayRef<OpFoldResult> offsets,
1463                       ArrayRef<OpFoldResult> sizes,
1464                       ArrayRef<OpFoldResult> strides,
1465                       ArrayRef<NamedAttribute> attrs) {
1466   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1467   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1468   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1469                              ShapedType::kDynamicStrideOrOffset);
1470   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1471                              ShapedType::kDynamicSize);
1472   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1473                              ShapedType::kDynamicStrideOrOffset);
1474   auto sourceMemRefType = source.getType().cast<MemRefType>();
1475   // Structuring implementation this way avoids duplication between builders.
1476   if (!resultType) {
1477     resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1478                                             staticSizes, staticStrides)
1479                      .cast<MemRefType>();
1480   }
1481   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1482         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1483         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1484   result.addAttributes(attrs);
1485 }
1486 
1487 // Build a SubViewOp with mixed static and dynamic entries and inferred result
1488 // type.
1489 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1490                       ArrayRef<OpFoldResult> offsets,
1491                       ArrayRef<OpFoldResult> sizes,
1492                       ArrayRef<OpFoldResult> strides,
1493                       ArrayRef<NamedAttribute> attrs) {
1494   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1495 }
1496 
1497 // Build a SubViewOp with static entries and inferred result type.
1498 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1499                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1500                       ArrayRef<int64_t> strides,
1501                       ArrayRef<NamedAttribute> attrs) {
1502   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1503       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1504         return b.getI64IntegerAttr(v);
1505       }));
1506   SmallVector<OpFoldResult> sizeValues =
1507       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1508         return b.getI64IntegerAttr(v);
1509       }));
1510   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1511       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1512         return b.getI64IntegerAttr(v);
1513       }));
1514   build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
1515 }
1516 
1517 // Build a SubViewOp with dynamic entries and custom result type. If the
1518 // type passed is nullptr, it is inferred.
1519 void SubViewOp::build(OpBuilder &b, OperationState &result,
1520                       MemRefType resultType, Value source,
1521                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1522                       ArrayRef<int64_t> strides,
1523                       ArrayRef<NamedAttribute> attrs) {
1524   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1525       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1526         return b.getI64IntegerAttr(v);
1527       }));
1528   SmallVector<OpFoldResult> sizeValues =
1529       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1530         return b.getI64IntegerAttr(v);
1531       }));
1532   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1533       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1534         return b.getI64IntegerAttr(v);
1535       }));
1536   build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
1537         attrs);
1538 }
1539 
1540 // Build a SubViewOp with dynamic entries and custom result type. If the type
1541 // passed is nullptr, it is inferred.
1542 void SubViewOp::build(OpBuilder &b, OperationState &result,
1543                       MemRefType resultType, Value source, ValueRange offsets,
1544                       ValueRange sizes, ValueRange strides,
1545                       ArrayRef<NamedAttribute> attrs) {
1546   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1547       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1548   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1549       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1550   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1551       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1552   build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
1553 }
1554 
1555 // Build a SubViewOp with dynamic entries and inferred result type.
1556 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1557                       ValueRange offsets, ValueRange sizes, ValueRange strides,
1558                       ArrayRef<NamedAttribute> attrs) {
1559   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1560 }
1561 
1562 /// For ViewLikeOpInterface.
1563 Value SubViewOp::getViewSource() { return source(); }
1564 
/// Outcome of checking a subview result type against the expected type; each
/// failure kind is turned into its own diagnostic by produceSubViewErrorMsg.
enum SubViewVerificationResult {
  Success = 0,          // Types are compatible.
  RankTooLarge = 1,     // Candidate has more dimensions than the original.
  SizeMismatch = 2,     // Sizes cannot be matched by dropping dimensions.
  ElemTypeMismatch = 3, // Element types differ.
  MemSpaceMismatch = 4, // Memory spaces differ.
  AffineMapMismatch = 5 // Layout maps differ after projection.
};
1573 
1574 /// Checks if `original` Type type can be rank reduced to `reduced` type.
1575 /// This function is slight variant of `is subsequence` algorithm where
1576 /// not matching dimension must be 1.
1577 static SubViewVerificationResult
1578 isRankReducedType(Type originalType, Type candidateReducedType,
1579                   std::string *errMsg = nullptr) {
1580   if (originalType == candidateReducedType)
1581     return SubViewVerificationResult::Success;
1582   if (!originalType.isa<MemRefType>())
1583     return SubViewVerificationResult::Success;
1584   if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
1585     return SubViewVerificationResult::Success;
1586 
1587   ShapedType originalShapedType = originalType.cast<ShapedType>();
1588   ShapedType candidateReducedShapedType =
1589       candidateReducedType.cast<ShapedType>();
1590 
1591   // Rank and size logic is valid for all ShapedTypes.
1592   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
1593   ArrayRef<int64_t> candidateReducedShape =
1594       candidateReducedShapedType.getShape();
1595   unsigned originalRank = originalShape.size(),
1596            candidateReducedRank = candidateReducedShape.size();
1597   if (candidateReducedRank > originalRank)
1598     return SubViewVerificationResult::RankTooLarge;
1599 
1600   auto optionalUnusedDimsMask =
1601       computeRankReductionMask(originalShape, candidateReducedShape);
1602 
1603   // Sizes cannot be matched in case empty vector is returned.
1604   if (!optionalUnusedDimsMask.hasValue())
1605     return SubViewVerificationResult::SizeMismatch;
1606 
1607   if (originalShapedType.getElementType() !=
1608       candidateReducedShapedType.getElementType())
1609     return SubViewVerificationResult::ElemTypeMismatch;
1610 
1611   // Strided layout logic is relevant for MemRefType only.
1612   MemRefType original = originalType.cast<MemRefType>();
1613   MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
1614   if (original.getMemorySpace() != candidateReduced.getMemorySpace())
1615     return SubViewVerificationResult::MemSpaceMismatch;
1616 
1617   llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
1618   auto inferredType =
1619       getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
1620   AffineMap candidateLayout;
1621   if (candidateReduced.getAffineMaps().empty())
1622     candidateLayout = getStridedLinearLayoutMap(candidateReduced);
1623   else
1624     candidateLayout = candidateReduced.getAffineMaps().front();
1625   assert(inferredType.getNumResults() == 1 &&
1626          candidateLayout.getNumResults() == 1);
1627   if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
1628       inferredType.getNumDims() != candidateLayout.getNumDims()) {
1629     if (errMsg) {
1630       llvm::raw_string_ostream os(*errMsg);
1631       os << "inferred type: " << inferredType;
1632     }
1633     return SubViewVerificationResult::AffineMapMismatch;
1634   }
1635   // Check that the difference of the affine maps simplifies to 0.
1636   AffineExpr diffExpr =
1637       inferredType.getResult(0) - candidateLayout.getResult(0);
1638   diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
1639                                 inferredType.getNumSymbols());
1640   auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
1641   if (!(cst && cst.getValue() == 0)) {
1642     if (errMsg) {
1643       llvm::raw_string_ostream os(*errMsg);
1644       os << "inferred type: " << inferredType;
1645     }
1646     return SubViewVerificationResult::AffineMapMismatch;
1647   }
1648   return SubViewVerificationResult::Success;
1649 }
1650 
1651 template <typename OpTy>
1652 static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
1653                                             OpTy op, Type expectedType,
1654                                             StringRef errMsg = "") {
1655   auto memrefType = expectedType.cast<ShapedType>();
1656   switch (result) {
1657   case SubViewVerificationResult::Success:
1658     return success();
1659   case SubViewVerificationResult::RankTooLarge:
1660     return op.emitError("expected result rank to be smaller or equal to ")
1661            << "the source rank. " << errMsg;
1662   case SubViewVerificationResult::SizeMismatch:
1663     return op.emitError("expected result type to be ")
1664            << expectedType
1665            << " or a rank-reduced version. (mismatch of result sizes) "
1666            << errMsg;
1667   case SubViewVerificationResult::ElemTypeMismatch:
1668     return op.emitError("expected result element type to be ")
1669            << memrefType.getElementType() << errMsg;
1670   case SubViewVerificationResult::MemSpaceMismatch:
1671     return op.emitError("expected result and source memory spaces to match.")
1672            << errMsg;
1673   case SubViewVerificationResult::AffineMapMismatch:
1674     return op.emitError("expected result type to be ")
1675            << expectedType
1676            << " or a rank-reduced version. (mismatch of result affine map) "
1677            << errMsg;
1678   }
1679   llvm_unreachable("unexpected subview verification result");
1680 }
1681 
1682 /// Verifier for SubViewOp.
1683 static LogicalResult verify(SubViewOp op) {
1684   MemRefType baseType = op.getSourceType();
1685   MemRefType subViewType = op.getType();
1686 
1687   // The base memref and the view memref should be in the same memory space.
1688   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
1689     return op.emitError("different memory spaces specified for base memref "
1690                         "type ")
1691            << baseType << " and subview memref type " << subViewType;
1692 
1693   // Verify that the base memref type has a strided layout map.
1694   if (!isStrided(baseType))
1695     return op.emitError("base type ") << baseType << " is not strided";
1696 
1697   // Verify result type against inferred type.
1698   auto expectedType = SubViewOp::inferResultType(
1699       baseType, extractFromI64ArrayAttr(op.static_offsets()),
1700       extractFromI64ArrayAttr(op.static_sizes()),
1701       extractFromI64ArrayAttr(op.static_strides()));
1702 
1703   std::string errMsg;
1704   auto result = isRankReducedType(expectedType, subViewType, &errMsg);
1705   return produceSubViewErrorMsg(result, op, expectedType, errMsg);
1706 }
1707 
1708 raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
1709   return os << "range " << range.offset << ":" << range.size << ":"
1710             << range.stride;
1711 }
1712 
1713 /// Return the list of Range (i.e. offset, size, stride). Each Range
1714 /// entry contains either the dynamic value or a ConstantIndexOp constructed
1715 /// with `b` at location `loc`.
1716 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
1717                                               OpBuilder &b, Location loc) {
1718   std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
1719   assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
1720   assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
1721   SmallVector<Range, 8> res;
1722   unsigned rank = ranks[0];
1723   res.reserve(rank);
1724   for (unsigned idx = 0; idx < rank; ++idx) {
1725     Value offset =
1726         op.isDynamicOffset(idx)
1727             ? op.getDynamicOffset(idx)
1728             : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
1729     Value size = op.isDynamicSize(idx)
1730                      ? op.getDynamicSize(idx)
1731                      : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
1732     Value stride =
1733         op.isDynamicStride(idx)
1734             ? op.getDynamicStride(idx)
1735             : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
1736     res.emplace_back(Range{offset, size, stride});
1737   }
1738   return res;
1739 }
1740 
1741 namespace {
1742 /// Pattern to rewrite a subview op with MemRefCast arguments.
1743 /// This essentially pushes memref.cast past its consuming subview when
1744 /// `canFoldIntoConsumerOp` is true.
1745 ///
1746 /// Example:
1747 /// ```
1748 ///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
1749 ///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
1750 ///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
1751 /// ```
1752 /// is rewritten into:
1753 /// ```
1754 ///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
1755 ///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
1756 ///     memref<3x4xf32, offset:?, strides:[?, 1]>
1757 /// ```
1758 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
1759 public:
1760   using OpRewritePattern<SubViewOp>::OpRewritePattern;
1761 
1762   LogicalResult matchAndRewrite(SubViewOp subViewOp,
1763                                 PatternRewriter &rewriter) const override {
1764     // Any constant operand, just return to let SubViewOpConstantFolder kick in.
1765     if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
1766           return matchPattern(operand, matchConstantIndex());
1767         }))
1768       return failure();
1769 
1770     auto castOp = subViewOp.source().getDefiningOp<CastOp>();
1771     if (!castOp)
1772       return failure();
1773 
1774     if (!CastOp::canFoldIntoConsumerOp(castOp))
1775       return failure();
1776 
1777     /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
1778     /// the cast source operand type and the SubViewOp static information. This
1779     /// is the resulting type if the MemRefCastOp were folded.
1780     auto resultType = SubViewOp::inferRankReducedResultType(
1781         subViewOp.getType().getRank(),
1782         castOp.source().getType().cast<MemRefType>(),
1783         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
1784         subViewOp.getMixedStrides());
1785     Value newSubView = rewriter.create<SubViewOp>(
1786         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
1787         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
1788         subViewOp.static_sizes(), subViewOp.static_strides());
1789     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
1790                                         newSubView);
1791     return success();
1792   }
1793 };
1794 } // namespace
1795 
1796 /// A canonicalizer wrapper to replace SubViewOps.
1797 struct SubViewCanonicalizer {
1798   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
1799     rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
1800   }
1801 };
1802 
1803 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
1804                                             MLIRContext *context) {
1805   results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
1806                   SubViewOp, SubViewCanonicalizer>,
1807               SubViewOpMemRefCastFolder>(context);
1808 }
1809 
1810 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
1811   auto resultShapedType = getResult().getType().cast<ShapedType>();
1812   auto sourceShapedType = source().getType().cast<ShapedType>();
1813 
1814   if (resultShapedType.hasStaticShape() &&
1815       resultShapedType == sourceShapedType) {
1816     return getViewSource();
1817   }
1818 
1819   return {};
1820 }
1821 
1822 //===----------------------------------------------------------------------===//
1823 // TensorLoadOp
1824 //===----------------------------------------------------------------------===//
1825 
1826 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
1827   if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
1828     // Approximate alias analysis by conservatively folding only when no there
1829     // is no interleaved operation.
1830     if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
1831         bufferCast->getNextNode() == this->getOperation())
1832       return bufferCast.tensor();
1833   return {};
1834 }
1835 
1836 //===----------------------------------------------------------------------===//
1837 // TransposeOp
1838 //===----------------------------------------------------------------------===//
1839 
1840 /// Build a strided memref type by applying `permutationMap` tp `memRefType`.
1841 static MemRefType inferTransposeResultType(MemRefType memRefType,
1842                                            AffineMap permutationMap) {
1843   auto rank = memRefType.getRank();
1844   auto originalSizes = memRefType.getShape();
1845   // Compute permuted sizes.
1846   SmallVector<int64_t, 4> sizes(rank, 0);
1847   for (auto en : llvm::enumerate(permutationMap.getResults()))
1848     sizes[en.index()] =
1849         originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
1850 
1851   // Compute permuted strides.
1852   int64_t offset;
1853   SmallVector<int64_t, 4> strides;
1854   auto res = getStridesAndOffset(memRefType, strides, offset);
1855   assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
1856   (void)res;
1857   auto map =
1858       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
1859   map = permutationMap ? map.compose(permutationMap) : map;
1860   return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
1861 }
1862 
1863 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
1864                         AffineMapAttr permutation,
1865                         ArrayRef<NamedAttribute> attrs) {
1866   auto permutationMap = permutation.getValue();
1867   assert(permutationMap);
1868 
1869   auto memRefType = in.getType().cast<MemRefType>();
1870   // Compute result type.
1871   MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
1872 
1873   build(b, result, resultType, in, attrs);
1874   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
1875 }
1876 
1877 // transpose $in $permutation attr-dict : type($in) `to` type(results)
1878 static void print(OpAsmPrinter &p, TransposeOp op) {
1879   p << "memref.transpose " << op.in() << " " << op.permutation();
1880   p.printOptionalAttrDict(op->getAttrs(),
1881                           {TransposeOp::getPermutationAttrName()});
1882   p << " : " << op.in().getType() << " to " << op.getType();
1883 }
1884 
1885 static ParseResult parseTransposeOp(OpAsmParser &parser,
1886                                     OperationState &result) {
1887   OpAsmParser::OperandType in;
1888   AffineMap permutation;
1889   MemRefType srcType, dstType;
1890   if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
1891       parser.parseOptionalAttrDict(result.attributes) ||
1892       parser.parseColonType(srcType) ||
1893       parser.resolveOperand(in, srcType, result.operands) ||
1894       parser.parseKeywordType("to", dstType) ||
1895       parser.addTypeToList(dstType, result.types))
1896     return failure();
1897 
1898   result.addAttribute(TransposeOp::getPermutationAttrName(),
1899                       AffineMapAttr::get(permutation));
1900   return success();
1901 }
1902 
1903 static LogicalResult verify(TransposeOp op) {
1904   if (!op.permutation().isPermutation())
1905     return op.emitOpError("expected a permutation map");
1906   if (op.permutation().getNumDims() != op.getShapedType().getRank())
1907     return op.emitOpError(
1908         "expected a permutation map of same rank as the input");
1909 
1910   auto srcType = op.in().getType().cast<MemRefType>();
1911   auto dstType = op.getType().cast<MemRefType>();
1912   auto transposedType = inferTransposeResultType(srcType, op.permutation());
1913   if (dstType != transposedType)
1914     return op.emitOpError("output type ")
1915            << dstType << " does not match transposed input type " << srcType
1916            << ", " << transposedType;
1917   return success();
1918 }
1919 
1920 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
1921   if (succeeded(foldMemRefCast(*this)))
1922     return getResult();
1923   return {};
1924 }
1925 
1926 //===----------------------------------------------------------------------===//
1927 // ViewOp
1928 //===----------------------------------------------------------------------===//
1929 
1930 static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
1931   OpAsmParser::OperandType srcInfo;
1932   SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
1933   SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
1934   auto indexType = parser.getBuilder().getIndexType();
1935   Type srcType, dstType;
1936   llvm::SMLoc offsetLoc;
1937   if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
1938       parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
1939     return failure();
1940 
1941   if (offsetInfo.size() != 1)
1942     return parser.emitError(offsetLoc) << "expects 1 offset operand";
1943 
1944   return failure(
1945       parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
1946       parser.parseOptionalAttrDict(result.attributes) ||
1947       parser.parseColonType(srcType) ||
1948       parser.resolveOperand(srcInfo, srcType, result.operands) ||
1949       parser.resolveOperands(offsetInfo, indexType, result.operands) ||
1950       parser.resolveOperands(sizesInfo, indexType, result.operands) ||
1951       parser.parseKeywordType("to", dstType) ||
1952       parser.addTypeToList(dstType, result.types));
1953 }
1954 
1955 static void print(OpAsmPrinter &p, ViewOp op) {
1956   p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
1957   p.printOperand(op.byte_shift());
1958   p << "][" << op.sizes() << ']';
1959   p.printOptionalAttrDict(op->getAttrs());
1960   p << " : " << op.getOperand(0).getType() << " to " << op.getType();
1961 }
1962 
1963 static LogicalResult verify(ViewOp op) {
1964   auto baseType = op.getOperand(0).getType().cast<MemRefType>();
1965   auto viewType = op.getType();
1966 
1967   // The base memref should have identity layout map (or none).
1968   if (baseType.getAffineMaps().size() > 1 ||
1969       (baseType.getAffineMaps().size() == 1 &&
1970        !baseType.getAffineMaps()[0].isIdentity()))
1971     return op.emitError("unsupported map for base memref type ") << baseType;
1972 
1973   // The result memref should have identity layout map (or none).
1974   if (viewType.getAffineMaps().size() > 1 ||
1975       (viewType.getAffineMaps().size() == 1 &&
1976        !viewType.getAffineMaps()[0].isIdentity()))
1977     return op.emitError("unsupported map for result memref type ") << viewType;
1978 
1979   // The base memref and the view memref should be in the same memory space.
1980   if (baseType.getMemorySpace() != viewType.getMemorySpace())
1981     return op.emitError("different memory spaces specified for base memref "
1982                         "type ")
1983            << baseType << " and view memref type " << viewType;
1984 
1985   // Verify that we have the correct number of sizes for the result type.
1986   unsigned numDynamicDims = viewType.getNumDynamicDims();
1987   if (op.sizes().size() != numDynamicDims)
1988     return op.emitError("incorrect number of size operands for type ")
1989            << viewType;
1990 
1991   return success();
1992 }
1993 
1994 Value ViewOp::getViewSource() { return source(); }
1995 
1996 namespace {
1997 
1998 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
1999   using OpRewritePattern<ViewOp>::OpRewritePattern;
2000 
2001   LogicalResult matchAndRewrite(ViewOp viewOp,
2002                                 PatternRewriter &rewriter) const override {
2003     // Return if none of the operands are constants.
2004     if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
2005           return matchPattern(operand, matchConstantIndex());
2006         }))
2007       return failure();
2008 
2009     // Get result memref type.
2010     auto memrefType = viewOp.getType();
2011 
2012     // Get offset from old memref view type 'memRefType'.
2013     int64_t oldOffset;
2014     SmallVector<int64_t, 4> oldStrides;
2015     if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
2016       return failure();
2017     assert(oldOffset == 0 && "Expected 0 offset");
2018 
2019     SmallVector<Value, 4> newOperands;
2020 
2021     // Offset cannot be folded into result type.
2022 
2023     // Fold any dynamic dim operands which are produced by a constant.
2024     SmallVector<int64_t, 4> newShapeConstants;
2025     newShapeConstants.reserve(memrefType.getRank());
2026 
2027     unsigned dynamicDimPos = 0;
2028     unsigned rank = memrefType.getRank();
2029     for (unsigned dim = 0, e = rank; dim < e; ++dim) {
2030       int64_t dimSize = memrefType.getDimSize(dim);
2031       // If this is already static dimension, keep it.
2032       if (!ShapedType::isDynamic(dimSize)) {
2033         newShapeConstants.push_back(dimSize);
2034         continue;
2035       }
2036       auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
2037       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
2038         // Dynamic shape dimension will be folded.
2039         newShapeConstants.push_back(constantIndexOp.getValue());
2040       } else {
2041         // Dynamic shape dimension not folded; copy operand from old memref.
2042         newShapeConstants.push_back(dimSize);
2043         newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
2044       }
2045       dynamicDimPos++;
2046     }
2047 
2048     // Create new memref type with constant folded dims.
2049     MemRefType newMemRefType =
2050         MemRefType::Builder(memrefType).setShape(newShapeConstants);
2051     // Nothing new, don't fold.
2052     if (newMemRefType == memrefType)
2053       return failure();
2054 
2055     // Create new ViewOp.
2056     auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
2057                                              viewOp.getOperand(0),
2058                                              viewOp.byte_shift(), newOperands);
2059     // Insert a cast so we have the same type as the old memref type.
2060     rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
2061     return success();
2062   }
2063 };
2064 
2065 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
2066   using OpRewritePattern<ViewOp>::OpRewritePattern;
2067 
2068   LogicalResult matchAndRewrite(ViewOp viewOp,
2069                                 PatternRewriter &rewriter) const override {
2070     Value memrefOperand = viewOp.getOperand(0);
2071     CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
2072     if (!memrefCastOp)
2073       return failure();
2074     Value allocOperand = memrefCastOp.getOperand();
2075     AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
2076     if (!allocOp)
2077       return failure();
2078     rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
2079                                         viewOp.byte_shift(), viewOp.sizes());
2080     return success();
2081   }
2082 };
2083 
2084 } // end anonymous namespace
2085 
2086 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2087                                          MLIRContext *context) {
2088   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
2089 }
2090 
2091 //===----------------------------------------------------------------------===//
2092 // TableGen'd op method definitions
2093 //===----------------------------------------------------------------------===//
2094 
2095 #define GET_OP_CLASSES
2096 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
2097