1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MemRef/IR/MemRef.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "llvm/ADT/STLExtras.h"
20 
21 using namespace mlir;
22 using namespace mlir::memref;
23 
24 /// Materialize a single constant operation from a given attribute value with
25 /// the desired resultant type.
26 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
27                                               Attribute value, Type type,
28                                               Location loc) {
29   return builder.create<mlir::ConstantOp>(loc, type, value);
30 }
31 
32 /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
33 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
34   return llvm::to_vector<4>(
35       llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
36         return a.cast<IntegerAttr>().getInt();
37       }));
38 }
39 
40 /// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
41 /// it is a Value or into `staticVec` if it is an IntegerAttr.
42 /// In the case of a Value, a copy of the `sentinel` value is also pushed to
43 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
44 /// come from an AttrSizedOperandSegments trait.
45 static void dispatchIndexOpFoldResult(OpFoldResult ofr,
46                                       SmallVectorImpl<Value> &dynamicVec,
47                                       SmallVectorImpl<int64_t> &staticVec,
48                                       int64_t sentinel) {
49   if (auto v = ofr.dyn_cast<Value>()) {
50     dynamicVec.push_back(v);
51     staticVec.push_back(sentinel);
52     return;
53   }
54   APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
55   staticVec.push_back(apInt.getSExtValue());
56 }
57 
58 static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
59                                        SmallVectorImpl<Value> &dynamicVec,
60                                        SmallVectorImpl<int64_t> &staticVec,
61                                        int64_t sentinel) {
62   for (auto ofr : ofrs)
63     dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // Common canonicalization pattern support logic
68 //===----------------------------------------------------------------------===//
69 
70 /// This is a common class used for patterns of the form
71 /// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
72 /// into the root operation directly.
73 static LogicalResult foldMemRefCast(Operation *op) {
74   bool folded = false;
75   for (OpOperand &operand : op->getOpOperands()) {
76     auto cast = operand.get().getDefiningOp<CastOp>();
77     if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
78       operand.set(cast.getOperand());
79       folded = true;
80     }
81   }
82   return success(folded);
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // Helpers for GlobalOp
87 //===----------------------------------------------------------------------===//
88 
89 static Type getTensorTypeFromMemRefType(Type type) {
90   if (auto memref = type.dyn_cast<MemRefType>())
91     return RankedTensorType::get(memref.getShape(), memref.getElementType());
92   if (auto memref = type.dyn_cast<UnrankedMemRefType>())
93     return UnrankedTensorType::get(memref.getElementType());
94   return NoneType::get(type.getContext());
95 }
96 
97 //===----------------------------------------------------------------------===//
98 // AllocOp / AllocaOp
99 //===----------------------------------------------------------------------===//
100 
101 template <typename AllocLikeOp>
102 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
103   static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
104                 "applies to only alloc or alloca");
105   auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
106   if (!memRefType)
107     return op.emitOpError("result must be a memref");
108 
109   if (static_cast<int64_t>(op.dynamicSizes().size()) !=
110       memRefType.getNumDynamicDims())
111     return op.emitOpError("dimension operand count does not equal memref "
112                           "dynamic dimension count");
113 
114   unsigned numSymbols = 0;
115   if (!memRefType.getAffineMaps().empty())
116     numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
117   if (op.symbolOperands().size() != numSymbols)
118     return op.emitOpError(
119         "symbol operand count does not equal memref symbol count");
120 
121   return success();
122 }
123 
124 static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }
125 
126 static LogicalResult verify(AllocaOp op) {
127   // An alloca op needs to have an ancestor with an allocation scope trait.
128   if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
129     return op.emitOpError(
130         "requires an ancestor op with AutomaticAllocationScope trait");
131 
132   return verifyAllocLikeOp(op);
133 }
134 
135 namespace {
136 /// Fold constant dimensions into an alloc like operation.
137 template <typename AllocLikeOp>
138 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
139   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
140 
141   LogicalResult matchAndRewrite(AllocLikeOp alloc,
142                                 PatternRewriter &rewriter) const override {
143     // Check to see if any dimensions operands are constants.  If so, we can
144     // substitute and drop them.
145     if (llvm::none_of(alloc.getOperands(), [](Value operand) {
146           return matchPattern(operand, matchConstantIndex());
147         }))
148       return failure();
149 
150     auto memrefType = alloc.getType();
151 
152     // Ok, we have one or more constant operands.  Collect the non-constant ones
153     // and keep track of the resultant memref type to build.
154     SmallVector<int64_t, 4> newShapeConstants;
155     newShapeConstants.reserve(memrefType.getRank());
156     SmallVector<Value, 4> newOperands;
157 
158     unsigned dynamicDimPos = 0;
159     for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
160       int64_t dimSize = memrefType.getDimSize(dim);
161       // If this is already static dimension, keep it.
162       if (dimSize != -1) {
163         newShapeConstants.push_back(dimSize);
164         continue;
165       }
166       auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
167       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
168         // Dynamic shape dimension will be folded.
169         newShapeConstants.push_back(constantIndexOp.getValue());
170       } else {
171         // Dynamic shape dimension not folded; copy operand from old memref.
172         newShapeConstants.push_back(-1);
173         newOperands.push_back(alloc.getOperand(dynamicDimPos));
174       }
175       dynamicDimPos++;
176     }
177 
178     // Create new memref type (which will have fewer dynamic dimensions).
179     MemRefType newMemRefType =
180         MemRefType::Builder(memrefType).setShape(newShapeConstants);
181     assert(static_cast<int64_t>(newOperands.size()) ==
182            newMemRefType.getNumDynamicDims());
183 
184     // Create and insert the alloc op for the new memref.
185     auto newAlloc = rewriter.create<AllocLikeOp>(alloc.getLoc(), newMemRefType,
186                                                  newOperands, IntegerAttr());
187     // Insert a cast so we have the same type as the old alloc.
188     auto resultCast =
189         rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
190 
191     rewriter.replaceOp(alloc, {resultCast});
192     return success();
193   }
194 };
195 
196 /// Fold alloc operations with no uses. Alloc has side effects on the heap,
197 /// but can still be deleted if it has zero uses.
198 struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
199   using OpRewritePattern<AllocOp>::OpRewritePattern;
200 
201   LogicalResult matchAndRewrite(AllocOp alloc,
202                                 PatternRewriter &rewriter) const override {
203     if (alloc.use_empty()) {
204       rewriter.eraseOp(alloc);
205       return success();
206     }
207     return failure();
208   }
209 };
210 } // end anonymous namespace.
211 
212 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
213                                           MLIRContext *context) {
214   results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc>(context);
215 }
216 
217 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
218                                            MLIRContext *context) {
219   results.add<SimplifyAllocConst<AllocaOp>>(context);
220 }
221 
222 //===----------------------------------------------------------------------===//
223 // AssumeAlignmentOp
224 //===----------------------------------------------------------------------===//
225 
226 static LogicalResult verify(AssumeAlignmentOp op) {
227   unsigned alignment = op.alignment();
228   if (!llvm::isPowerOf2_32(alignment))
229     return op.emitOpError("alignment must be power of 2");
230   return success();
231 }
232 
233 //===----------------------------------------------------------------------===//
234 // BufferCastOp
235 //===----------------------------------------------------------------------===//
236 
237 OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
238   if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
239     if (tensorLoad.memref().getType() == getType())
240       return tensorLoad.memref();
241   return {};
242 }
243 
244 namespace {
245 /// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast.
246 struct BufferCast : public OpRewritePattern<BufferCastOp> {
247   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
248 
249   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
250                                 PatternRewriter &rewriter) const final {
251     auto tensorCastOperand =
252         bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
253     if (!tensorCastOperand)
254       return failure();
255     auto srcTensorType =
256         tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
257     if (!srcTensorType)
258       return failure();
259     auto memrefType = MemRefType::get(srcTensorType.getShape(),
260                                       srcTensorType.getElementType());
261     Value memref = rewriter.create<BufferCastOp>(
262         bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
263     rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
264                                         memref);
265     return success();
266   }
267 };
268 
269 /// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
270 /// type mismatches prevent `BufferCastOp::fold` to kick in.
271 struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
272   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
273 
274   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
275                                 PatternRewriter &rewriter) const final {
276     auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
277     // Bail unless we have a tensor_load + memref.buffer_cast with different
278     // types. `BufferCastOp::fold` handles the same type case.
279     if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
280       return failure();
281     // If types are not cast-compatible, bail.
282     if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
283                                    bufferCast.getType()))
284       return failure();
285     rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
286                                         tensorLoad.memref());
287     return success();
288   }
289 };
290 
291 } // namespace
292 
293 void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
294                                                MLIRContext *context) {
295   results.add<BufferCast, TensorLoadToMemRef>(context);
296 }
297 
298 //===----------------------------------------------------------------------===//
299 // CastOp
300 //===----------------------------------------------------------------------===//
301 
302 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
303 /// source memref. This is useful to to fold a memref.cast into a consuming op
304 /// and implement canonicalization patterns for ops in different dialects that
305 /// may consume the results of memref.cast operations. Such foldable memref.cast
306 /// operations are typically inserted as `view` and `subview` ops are
307 /// canonicalized, to preserve the type compatibility of their uses.
308 ///
309 /// Returns true when all conditions are met:
310 /// 1. source and result are ranked memrefs with strided semantics and same
311 /// element type and rank.
312 /// 2. each of the source's size, offset or stride has more static information
313 /// than the corresponding result's size, offset or stride.
314 ///
315 /// Example 1:
316 /// ```mlir
317 ///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
318 ///   %2 = consumer %1 ... : memref<?x?xf32> ...
319 /// ```
320 ///
321 /// may fold into:
322 ///
323 /// ```mlir
324 ///   %2 = consumer %0 ... : memref<8x16xf32> ...
325 /// ```
326 ///
327 /// Example 2:
328 /// ```
329 ///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
330 ///          to memref<?x?xf32>
331 ///   consumer %1 : memref<?x?xf32> ...
332 /// ```
333 ///
334 /// may fold into:
335 ///
336 /// ```
337 ///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
338 /// ```
339 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
340   MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
341   MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
342 
343   // Requires ranked MemRefType.
344   if (!sourceType || !resultType)
345     return false;
346 
347   // Requires same elemental type.
348   if (sourceType.getElementType() != resultType.getElementType())
349     return false;
350 
351   // Requires same rank.
352   if (sourceType.getRank() != resultType.getRank())
353     return false;
354 
355   // Only fold casts between strided memref forms.
356   int64_t sourceOffset, resultOffset;
357   SmallVector<int64_t, 4> sourceStrides, resultStrides;
358   if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
359       failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
360     return false;
361 
362   // If cast is towards more static sizes along any dimension, don't fold.
363   for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
364     auto ss = std::get<0>(it), st = std::get<1>(it);
365     if (ss != st)
366       if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
367         return false;
368   }
369 
370   // If cast is towards more static offset along any dimension, don't fold.
371   if (sourceOffset != resultOffset)
372     if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
373         !MemRefType::isDynamicStrideOrOffset(resultOffset))
374       return false;
375 
376   // If cast is towards more static strides along any dimension, don't fold.
377   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
378     auto ss = std::get<0>(it), st = std::get<1>(it);
379     if (ss != st)
380       if (MemRefType::isDynamicStrideOrOffset(ss) &&
381           !MemRefType::isDynamicStrideOrOffset(st))
382         return false;
383   }
384 
385   return true;
386 }
387 
388 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
389   if (inputs.size() != 1 || outputs.size() != 1)
390     return false;
391   Type a = inputs.front(), b = outputs.front();
392   auto aT = a.dyn_cast<MemRefType>();
393   auto bT = b.dyn_cast<MemRefType>();
394 
395   auto uaT = a.dyn_cast<UnrankedMemRefType>();
396   auto ubT = b.dyn_cast<UnrankedMemRefType>();
397 
398   if (aT && bT) {
399     if (aT.getElementType() != bT.getElementType())
400       return false;
401     if (aT.getAffineMaps() != bT.getAffineMaps()) {
402       int64_t aOffset, bOffset;
403       SmallVector<int64_t, 4> aStrides, bStrides;
404       if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
405           failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
406           aStrides.size() != bStrides.size())
407         return false;
408 
409       // Strides along a dimension/offset are compatible if the value in the
410       // source memref is static and the value in the target memref is the
411       // same. They are also compatible if either one is dynamic (see
412       // description of MemRefCastOp for details).
413       auto checkCompatible = [](int64_t a, int64_t b) {
414         return (a == MemRefType::getDynamicStrideOrOffset() ||
415                 b == MemRefType::getDynamicStrideOrOffset() || a == b);
416       };
417       if (!checkCompatible(aOffset, bOffset))
418         return false;
419       for (auto aStride : enumerate(aStrides))
420         if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
421           return false;
422     }
423     if (aT.getMemorySpaceAsInt() != bT.getMemorySpaceAsInt())
424       return false;
425 
426     // They must have the same rank, and any specified dimensions must match.
427     if (aT.getRank() != bT.getRank())
428       return false;
429 
430     for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
431       int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
432       if (aDim != -1 && bDim != -1 && aDim != bDim)
433         return false;
434     }
435     return true;
436   } else {
437     if (!aT && !uaT)
438       return false;
439     if (!bT && !ubT)
440       return false;
441     // Unranked to unranked casting is unsupported
442     if (uaT && ubT)
443       return false;
444 
445     auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
446     auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
447     if (aEltType != bEltType)
448       return false;
449 
450     auto aMemSpace =
451         (aT) ? aT.getMemorySpaceAsInt() : uaT.getMemorySpaceAsInt();
452     auto bMemSpace =
453         (bT) ? bT.getMemorySpaceAsInt() : ubT.getMemorySpaceAsInt();
454     if (aMemSpace != bMemSpace)
455       return false;
456 
457     return true;
458   }
459 
460   return false;
461 }
462 
463 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
464   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
465 }
466 
467 //===----------------------------------------------------------------------===//
468 // DeallocOp
469 //===----------------------------------------------------------------------===//
470 namespace {
471 /// Fold Dealloc operations that are deallocating an AllocOp that is only used
472 /// by other Dealloc operations.
473 struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
474   using OpRewritePattern<DeallocOp>::OpRewritePattern;
475 
476   LogicalResult matchAndRewrite(DeallocOp dealloc,
477                                 PatternRewriter &rewriter) const override {
478     // Check that the memref operand's defining operation is an AllocOp.
479     Value memref = dealloc.memref();
480     if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp()))
481       return failure();
482 
483     // Check that all of the uses of the AllocOp are other DeallocOps.
484     for (auto *user : memref.getUsers())
485       if (!isa<DeallocOp>(user))
486         return failure();
487 
488     // Erase the dealloc operation.
489     rewriter.eraseOp(dealloc);
490     return success();
491   }
492 };
493 } // end anonymous namespace.
494 
495 static LogicalResult verify(DeallocOp op) {
496   if (!op.memref().getType().isa<MemRefType>())
497     return op.emitOpError("operand must be a memref");
498   return success();
499 }
500 
501 void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
502                                             MLIRContext *context) {
503   results.add<SimplifyDeadDealloc>(context);
504 }
505 
506 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
507                               SmallVectorImpl<OpFoldResult> &results) {
508   /// dealloc(memrefcast) -> dealloc
509   return foldMemRefCast(*this);
510 }
511 
512 //===----------------------------------------------------------------------===//
513 // DimOp
514 //===----------------------------------------------------------------------===//
515 
516 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
517                   int64_t index) {
518   auto loc = result.location;
519   Value indexValue = builder.create<ConstantIndexOp>(loc, index);
520   build(builder, result, memref, indexValue);
521 }
522 
523 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
524                   Value index) {
525   auto indexTy = builder.getIndexType();
526   build(builder, result, indexTy, memref, index);
527 }
528 
529 Optional<int64_t> DimOp::getConstantIndex() {
530   if (auto constantOp = index().getDefiningOp<ConstantOp>())
531     return constantOp.getValue().cast<IntegerAttr>().getInt();
532   return {};
533 }
534 
535 static LogicalResult verify(DimOp op) {
536   // Assume unknown index to be in range.
537   Optional<int64_t> index = op.getConstantIndex();
538   if (!index.hasValue())
539     return success();
540 
541   // Check that constant index is not knowingly out of range.
542   auto type = op.memrefOrTensor().getType();
543   if (auto memrefType = type.dyn_cast<MemRefType>()) {
544     if (index.getValue() >= memrefType.getRank())
545       return op.emitOpError("index is out of range");
546   } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
547     if (index.getValue() >= tensorType.getRank())
548       return op.emitOpError("index is out of range");
549   } else if (type.isa<UnrankedMemRefType>() || type.isa<UnrankedTensorType>()) {
550     // Assume index to be in range.
551   } else {
552     llvm_unreachable("expected operand with memref type");
553   }
554   return success();
555 }
556 
557 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
558   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
559 
560   // All forms of folding require a known index.
561   if (!index)
562     return {};
563 
564   auto argTy = memrefOrTensor().getType();
565   // Fold if the shape extent along the given index is known.
566   if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
567     // Folding for unranked types (UnrankedMemRefType) is not supported.
568     if (!shapedTy.hasRank())
569       return {};
570     if (!shapedTy.isDynamicDim(index.getInt())) {
571       Builder builder(getContext());
572       return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]);
573     }
574   }
575 
576   Operation *definingOp = memrefOrTensor().getDefiningOp();
577 
578   // dim(memref.tensor_load(memref)) -> dim(memref)
579   if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) {
580     setOperand(0, tensorLoadOp.memref());
581     return getResult();
582   }
583 
584   // Fold dim to the operand of tensor.generate.
585   if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
586     auto resultType =
587         fromElements.getResult().getType().cast<RankedTensorType>();
588     // The case where the type encodes the size of the dimension is handled
589     // above.
590     assert(resultType.getShape()[index.getInt()] ==
591            RankedTensorType::kDynamicSize);
592 
593     // Find the operand of the fromElements that corresponds to this index.
594     auto dynExtents = fromElements.dynamicExtents().begin();
595     for (auto dim : resultType.getShape().take_front(index.getInt()))
596       if (dim == RankedTensorType::kDynamicSize)
597         dynExtents++;
598 
599     return Value{*dynExtents};
600   }
601 
602   // The size at the given index is now known to be a dynamic size.
603   unsigned unsignedIndex = index.getValue().getZExtValue();
604 
605   if (auto subtensor = dyn_cast_or_null<mlir::SubTensorOp>(definingOp)) {
606     assert(subtensor.isDynamicSize(unsignedIndex) &&
607            "Expected dynamic subtensor size");
608     return subtensor.getDynamicSize(unsignedIndex);
609   }
610 
611   // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
612   auto memrefType = argTy.dyn_cast<MemRefType>();
613   if (!memrefType)
614     return {};
615 
616   if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
617     return *(alloc.getDynamicSizes().begin() +
618              memrefType.getDynamicDimIndex(unsignedIndex));
619 
620   if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
621     return *(alloca.getDynamicSizes().begin() +
622              memrefType.getDynamicDimIndex(unsignedIndex));
623 
624   if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
625     return *(view.getDynamicSizes().begin() +
626              memrefType.getDynamicDimIndex(unsignedIndex));
627 
628   if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
629     assert(subview.isDynamicSize(unsignedIndex) &&
630            "Expected dynamic subview size");
631     return subview.getDynamicSize(unsignedIndex);
632   }
633 
634   // dim(memrefcast) -> dim
635   if (succeeded(foldMemRefCast(*this)))
636     return getResult();
637 
638   return {};
639 }
640 
641 namespace {
642 /// Fold dim of a memref reshape operation to a load into the reshape's shape
643 /// operand.
644 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
645   using OpRewritePattern<DimOp>::OpRewritePattern;
646 
647   LogicalResult matchAndRewrite(DimOp dim,
648                                 PatternRewriter &rewriter) const override {
649     auto reshape = dim.memrefOrTensor().getDefiningOp<ReshapeOp>();
650 
651     if (!reshape)
652       return failure();
653 
654     // Place the load directly after the reshape to ensure that the shape memref
655     // was not mutated.
656     rewriter.setInsertionPointAfter(reshape);
657     rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
658                                         llvm::makeArrayRef({dim.index()}));
659     return success();
660   }
661 };
662 
663 /// Fold dim of a dim of a cast into the dim of the source of the tensor cast.
664 template <typename CastOpTy>
665 struct DimOfCastOp : public OpRewritePattern<DimOp> {
666   using OpRewritePattern<DimOp>::OpRewritePattern;
667 
668   LogicalResult matchAndRewrite(DimOp dimOp,
669                                 PatternRewriter &rewriter) const override {
670     auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>();
671     if (!castOp)
672       return failure();
673     Value newSource = castOp.getOperand();
674     rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
675     return success();
676   }
677 };
678 } // end anonymous namespace.
679 
680 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
681                                         MLIRContext *context) {
682   results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
683               DimOfCastOp<tensor::CastOp>>(context);
684 }
685 
686 // ---------------------------------------------------------------------------
687 // DmaStartOp
688 // ---------------------------------------------------------------------------
689 
690 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
691                        Value srcMemRef, ValueRange srcIndices, Value destMemRef,
692                        ValueRange destIndices, Value numElements,
693                        Value tagMemRef, ValueRange tagIndices, Value stride,
694                        Value elementsPerStride) {
695   result.addOperands(srcMemRef);
696   result.addOperands(srcIndices);
697   result.addOperands(destMemRef);
698   result.addOperands(destIndices);
699   result.addOperands({numElements, tagMemRef});
700   result.addOperands(tagIndices);
701   if (stride)
702     result.addOperands({stride, elementsPerStride});
703 }
704 
705 void DmaStartOp::print(OpAsmPrinter &p) {
706   p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices()
707     << "], " << getDstMemRef() << '[' << getDstIndices() << "], "
708     << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices()
709     << ']';
710   if (isStrided())
711     p << ", " << getStride() << ", " << getNumElementsPerStride();
712 
713   p.printOptionalAttrDict((*this)->getAttrs());
714   p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
715     << ", " << getTagMemRef().getType();
716 }
717 
718 // Parse DmaStartOp.
719 // Ex:
720 //   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
721 //                       %tag[%index], %stride, %num_elt_per_stride :
722 //                     : memref<3076 x f32, 0>,
723 //                       memref<1024 x f32, 2>,
724 //                       memref<1 x i32>
725 //
726 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
727   OpAsmParser::OperandType srcMemRefInfo;
728   SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
729   OpAsmParser::OperandType dstMemRefInfo;
730   SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
731   OpAsmParser::OperandType numElementsInfo;
732   OpAsmParser::OperandType tagMemrefInfo;
733   SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
734   SmallVector<OpAsmParser::OperandType, 2> strideInfo;
735 
736   SmallVector<Type, 3> types;
737   auto indexType = parser.getBuilder().getIndexType();
738 
739   // Parse and resolve the following list of operands:
740   // *) source memref followed by its indices (in square brackets).
741   // *) destination memref followed by its indices (in square brackets).
742   // *) dma size in KiB.
743   if (parser.parseOperand(srcMemRefInfo) ||
744       parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
745       parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
746       parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
747       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
748       parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
749       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
750     return failure();
751 
752   // Parse optional stride and elements per stride.
753   if (parser.parseTrailingOperandList(strideInfo))
754     return failure();
755 
756   bool isStrided = strideInfo.size() == 2;
757   if (!strideInfo.empty() && !isStrided) {
758     return parser.emitError(parser.getNameLoc(),
759                             "expected two stride related operands");
760   }
761 
762   if (parser.parseColonTypeList(types))
763     return failure();
764   if (types.size() != 3)
765     return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
766 
767   if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
768       parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
769       parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
770       parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
771       // size should be an index.
772       parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
773       parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
774       // tag indices should be index.
775       parser.resolveOperands(tagIndexInfos, indexType, result.operands))
776     return failure();
777 
778   if (isStrided) {
779     if (parser.resolveOperands(strideInfo, indexType, result.operands))
780       return failure();
781   }
782 
783   return success();
784 }
785 
786 LogicalResult DmaStartOp::verify() {
787   unsigned numOperands = getNumOperands();
788 
789   // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
790   // the number of elements.
791   if (numOperands < 4)
792     return emitOpError("expected at least 4 operands");
793 
794   // Check types of operands. The order of these calls is important: the later
795   // calls rely on some type properties to compute the operand position.
796   // 1. Source memref.
797   if (!getSrcMemRef().getType().isa<MemRefType>())
798     return emitOpError("expected source to be of memref type");
799   if (numOperands < getSrcMemRefRank() + 4)
800     return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
801                          << " operands";
802   if (!getSrcIndices().empty() &&
803       !llvm::all_of(getSrcIndices().getTypes(),
804                     [](Type t) { return t.isIndex(); }))
805     return emitOpError("expected source indices to be of index type");
806 
807   // 2. Destination memref.
808   if (!getDstMemRef().getType().isa<MemRefType>())
809     return emitOpError("expected destination to be of memref type");
810   unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
811   if (numOperands < numExpectedOperands)
812     return emitOpError() << "expected at least " << numExpectedOperands
813                          << " operands";
814   if (!getDstIndices().empty() &&
815       !llvm::all_of(getDstIndices().getTypes(),
816                     [](Type t) { return t.isIndex(); }))
817     return emitOpError("expected destination indices to be of index type");
818 
819   // 3. Number of elements.
820   if (!getNumElements().getType().isIndex())
821     return emitOpError("expected num elements to be of index type");
822 
823   // 4. Tag memref.
824   if (!getTagMemRef().getType().isa<MemRefType>())
825     return emitOpError("expected tag to be of memref type");
826   numExpectedOperands += getTagMemRefRank();
827   if (numOperands < numExpectedOperands)
828     return emitOpError() << "expected at least " << numExpectedOperands
829                          << " operands";
830   if (!getTagIndices().empty() &&
831       !llvm::all_of(getTagIndices().getTypes(),
832                     [](Type t) { return t.isIndex(); }))
833     return emitOpError("expected tag indices to be of index type");
834 
835   // DMAs from different memory spaces supported.
836   if (getSrcMemorySpace() == getDstMemorySpace())
837     return emitOpError("DMA should be between different memory spaces");
838 
839   // Optional stride-related operands must be either both present or both
840   // absent.
841   if (numOperands != numExpectedOperands &&
842       numOperands != numExpectedOperands + 2)
843     return emitOpError("incorrect number of operands");
844 
845   // 5. Strides.
846   if (isStrided()) {
847     if (!getStride().getType().isIndex() ||
848         !getNumElementsPerStride().getType().isIndex())
849       return emitOpError(
850           "expected stride and num elements per stride to be of type index");
851   }
852 
853   return success();
854 }
855 
856 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
857                                SmallVectorImpl<OpFoldResult> &results) {
858   /// dma_start(memrefcast) -> dma_start
859   return foldMemRefCast(*this);
860 }
861 
862 // ---------------------------------------------------------------------------
863 // DmaWaitOp
864 // ---------------------------------------------------------------------------
865 
866 void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
867                       Value tagMemRef, ValueRange tagIndices,
868                       Value numElements) {
869   result.addOperands(tagMemRef);
870   result.addOperands(tagIndices);
871   result.addOperands(numElements);
872 }
873 
874 void DmaWaitOp::print(OpAsmPrinter &p) {
875   p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices()
876     << "], " << getNumElements();
877   p.printOptionalAttrDict((*this)->getAttrs());
878   p << " : " << getTagMemRef().getType();
879 }
880 
881 // Parse DmaWaitOp.
882 // Eg:
883 //   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
884 //
885 ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
886   OpAsmParser::OperandType tagMemrefInfo;
887   SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
888   Type type;
889   auto indexType = parser.getBuilder().getIndexType();
890   OpAsmParser::OperandType numElementsInfo;
891 
892   // Parse tag memref, its indices, and dma size.
893   if (parser.parseOperand(tagMemrefInfo) ||
894       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
895       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
896       parser.parseColonType(type) ||
897       parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
898       parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
899       parser.resolveOperand(numElementsInfo, indexType, result.operands))
900     return failure();
901 
902   return success();
903 }
904 
905 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
906                               SmallVectorImpl<OpFoldResult> &results) {
907   /// dma_wait(memrefcast) -> dma_wait
908   return foldMemRefCast(*this);
909 }
910 
911 LogicalResult DmaWaitOp::verify() {
912   // Mandatory non-variadic operands are tag and the number of elements.
913   if (getNumOperands() < 2)
914     return emitOpError() << "expected at least 2 operands";
915 
916   // Check types of operands. The order of these calls is important: the later
917   // calls rely on some type properties to compute the operand position.
918   if (!getTagMemRef().getType().isa<MemRefType>())
919     return emitOpError() << "expected tag to be of memref type";
920 
921   if (getNumOperands() != 2 + getTagMemRefRank())
922     return emitOpError() << "expected " << 2 + getTagMemRefRank()
923                          << " operands";
924 
925   if (!getTagIndices().empty() &&
926       !llvm::all_of(getTagIndices().getTypes(),
927                     [](Type t) { return t.isIndex(); }))
928     return emitOpError() << "expected tag indices to be of index type";
929 
930   if (!getNumElements().getType().isIndex())
931     return emitOpError()
932            << "expected the number of elements to be of index type";
933 
934   return success();
935 }
936 
937 //===----------------------------------------------------------------------===//
938 // GlobalOp
939 //===----------------------------------------------------------------------===//
940 
941 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
942                                                    TypeAttr type,
943                                                    Attribute initialValue) {
944   p << type;
945   if (!op.isExternal()) {
946     p << " = ";
947     if (op.isUninitialized())
948       p << "uninitialized";
949     else
950       p.printAttributeWithoutType(initialValue);
951   }
952 }
953 
954 static ParseResult
955 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
956                                        Attribute &initialValue) {
957   Type type;
958   if (parser.parseType(type))
959     return failure();
960 
961   auto memrefType = type.dyn_cast<MemRefType>();
962   if (!memrefType || !memrefType.hasStaticShape())
963     return parser.emitError(parser.getNameLoc())
964            << "type should be static shaped memref, but got " << type;
965   typeAttr = TypeAttr::get(type);
966 
967   if (parser.parseOptionalEqual())
968     return success();
969 
970   if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
971     initialValue = UnitAttr::get(parser.getBuilder().getContext());
972     return success();
973   }
974 
975   Type tensorType = getTensorTypeFromMemRefType(memrefType);
976   if (parser.parseAttribute(initialValue, tensorType))
977     return failure();
978   if (!initialValue.isa<ElementsAttr>())
979     return parser.emitError(parser.getNameLoc())
980            << "initial value should be a unit or elements attribute";
981   return success();
982 }
983 
984 static LogicalResult verify(GlobalOp op) {
985   auto memrefType = op.type().dyn_cast<MemRefType>();
986   if (!memrefType || !memrefType.hasStaticShape())
987     return op.emitOpError("type should be static shaped memref, but got ")
988            << op.type();
989 
990   // Verify that the initial value, if present, is either a unit attribute or
991   // an elements attribute.
992   if (op.initial_value().hasValue()) {
993     Attribute initValue = op.initial_value().getValue();
994     if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
995       return op.emitOpError("initial value should be a unit or elements "
996                             "attribute, but got ")
997              << initValue;
998 
999     // Check that the type of the initial value is compatible with the type of
1000     // the global variable.
1001     if (initValue.isa<ElementsAttr>()) {
1002       Type initType = initValue.getType();
1003       Type tensorType = getTensorTypeFromMemRefType(memrefType);
1004       if (initType != tensorType)
1005         return op.emitOpError("initial value expected to be of type ")
1006                << tensorType << ", but was of type " << initType;
1007     }
1008   }
1009 
1010   // TODO: verify visibility for declarations.
1011   return success();
1012 }
1013 
1014 //===----------------------------------------------------------------------===//
1015 // GetGlobalOp
1016 //===----------------------------------------------------------------------===//
1017 
1018 LogicalResult
1019 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1020   // Verify that the result type is same as the type of the referenced
1021   // memref.global op.
1022   auto global =
1023       symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
1024   if (!global)
1025     return emitOpError("'")
1026            << name() << "' does not reference a valid global memref";
1027 
1028   Type resultType = result().getType();
1029   if (global.type() != resultType)
1030     return emitOpError("result type ")
1031            << resultType << " does not match type " << global.type()
1032            << " of the global memref @" << name();
1033   return success();
1034 }
1035 
1036 //===----------------------------------------------------------------------===//
1037 // LoadOp
1038 //===----------------------------------------------------------------------===//
1039 
1040 static LogicalResult verify(LoadOp op) {
1041   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1042     return op.emitOpError("incorrect number of indices for load");
1043   return success();
1044 }
1045 
1046 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1047   /// load(memrefcast) -> load
1048   if (succeeded(foldMemRefCast(*this)))
1049     return getResult();
1050   return OpFoldResult();
1051 }
1052 
1053 namespace {
1054 /// Fold a load on a buffer_cast operation into an tensor.extract on the
1055 /// corresponding tensor.
1056 struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
1057   using OpRewritePattern<LoadOp>::OpRewritePattern;
1058 
1059   LogicalResult matchAndRewrite(LoadOp load,
1060                                 PatternRewriter &rewriter) const override {
1061     auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
1062     if (!buffercast)
1063       return failure();
1064 
1065     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
1066                                                    load.indices());
1067     return success();
1068   }
1069 };
1070 } // end anonymous namespace.
1071 
1072 void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
1073                                          MLIRContext *context) {
1074   results.add<LoadOfBufferCast>(context);
1075 }
1076 
1077 //===----------------------------------------------------------------------===//
1078 // PrefetchOp
1079 //===----------------------------------------------------------------------===//
1080 
1081 static void print(OpAsmPrinter &p, PrefetchOp op) {
1082   p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
1083   p.printOperands(op.indices());
1084   p << ']' << ", " << (op.isWrite() ? "write" : "read");
1085   p << ", locality<" << op.localityHint();
1086   p << ">, " << (op.isDataCache() ? "data" : "instr");
1087   p.printOptionalAttrDict(
1088       op->getAttrs(),
1089       /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1090   p << " : " << op.getMemRefType();
1091 }
1092 
1093 static ParseResult parsePrefetchOp(OpAsmParser &parser,
1094                                    OperationState &result) {
1095   OpAsmParser::OperandType memrefInfo;
1096   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
1097   IntegerAttr localityHint;
1098   MemRefType type;
1099   StringRef readOrWrite, cacheType;
1100 
1101   auto indexTy = parser.getBuilder().getIndexType();
1102   auto i32Type = parser.getBuilder().getIntegerType(32);
1103   if (parser.parseOperand(memrefInfo) ||
1104       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1105       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1106       parser.parseComma() || parser.parseKeyword("locality") ||
1107       parser.parseLess() ||
1108       parser.parseAttribute(localityHint, i32Type, "localityHint",
1109                             result.attributes) ||
1110       parser.parseGreater() || parser.parseComma() ||
1111       parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1112       parser.resolveOperand(memrefInfo, type, result.operands) ||
1113       parser.resolveOperands(indexInfo, indexTy, result.operands))
1114     return failure();
1115 
1116   if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1117     return parser.emitError(parser.getNameLoc(),
1118                             "rw specifier has to be 'read' or 'write'");
1119   result.addAttribute(
1120       PrefetchOp::getIsWriteAttrName(),
1121       parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1122 
1123   if (!cacheType.equals("data") && !cacheType.equals("instr"))
1124     return parser.emitError(parser.getNameLoc(),
1125                             "cache type has to be 'data' or 'instr'");
1126 
1127   result.addAttribute(
1128       PrefetchOp::getIsDataCacheAttrName(),
1129       parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1130 
1131   return success();
1132 }
1133 
1134 static LogicalResult verify(PrefetchOp op) {
1135   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1136     return op.emitOpError("too few indices");
1137 
1138   return success();
1139 }
1140 
1141 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
1142                                SmallVectorImpl<OpFoldResult> &results) {
1143   // prefetch(memrefcast) -> prefetch
1144   return foldMemRefCast(*this);
1145 }
1146 
1147 //===----------------------------------------------------------------------===//
1148 // ReinterpretCastOp
1149 //===----------------------------------------------------------------------===//
1150 
1151 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1152 /// `staticSizes` and `staticStrides` are automatically filled with
1153 /// source-memref-rank sentinel values that encode dynamic entries.
1154 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1155                               MemRefType resultType, Value source,
1156                               OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1157                               ArrayRef<OpFoldResult> strides,
1158                               ArrayRef<NamedAttribute> attrs) {
1159   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1160   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1161   dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
1162                              ShapedType::kDynamicStrideOrOffset);
1163   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1164                              ShapedType::kDynamicSize);
1165   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1166                              ShapedType::kDynamicStrideOrOffset);
1167   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1168         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1169         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1170   result.addAttributes(attrs);
1171 }
1172 
1173 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1174                               MemRefType resultType, Value source,
1175                               int64_t offset, ArrayRef<int64_t> sizes,
1176                               ArrayRef<int64_t> strides,
1177                               ArrayRef<NamedAttribute> attrs) {
1178   SmallVector<OpFoldResult> sizeValues =
1179       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1180         return b.getI64IntegerAttr(v);
1181       }));
1182   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1183       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1184         return b.getI64IntegerAttr(v);
1185       }));
1186   build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1187         strideValues, attrs);
1188 }
1189 
1190 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1191                               MemRefType resultType, Value source, Value offset,
1192                               ValueRange sizes, ValueRange strides,
1193                               ArrayRef<NamedAttribute> attrs) {
1194   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1195       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1196   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1197       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1198   build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1199 }
1200 
1201 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1202 // completed automatically, like we have for subview and subtensor.
1203 static LogicalResult verify(ReinterpretCastOp op) {
1204   // The source and result memrefs should be in the same memory space.
1205   auto srcType = op.source().getType().cast<BaseMemRefType>();
1206   auto resultType = op.getType().cast<MemRefType>();
1207   if (srcType.getMemorySpaceAsInt() != resultType.getMemorySpaceAsInt())
1208     return op.emitError("different memory spaces specified for source type ")
1209            << srcType << " and result memref type " << resultType;
1210   if (srcType.getElementType() != resultType.getElementType())
1211     return op.emitError("different element types specified for source type ")
1212            << srcType << " and result memref type " << resultType;
1213 
1214   // Match sizes in result memref type and in static_sizes attribute.
1215   for (auto &en :
1216        llvm::enumerate(llvm::zip(resultType.getShape(),
1217                                  extractFromI64ArrayAttr(op.static_sizes())))) {
1218     int64_t resultSize = std::get<0>(en.value());
1219     int64_t expectedSize = std::get<1>(en.value());
1220     if (resultSize != expectedSize)
1221       return op.emitError("expected result type with size = ")
1222              << expectedSize << " instead of " << resultSize
1223              << " in dim = " << en.index();
1224   }
1225 
1226   // Match offset and strides in static_offset and static_strides attributes if
1227   // result memref type has an affine map specified.
1228   if (!resultType.getAffineMaps().empty()) {
1229     int64_t resultOffset;
1230     SmallVector<int64_t, 4> resultStrides;
1231     if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1232       return failure();
1233 
1234     // Match offset in result memref type and in static_offsets attribute.
1235     int64_t expectedOffset =
1236         extractFromI64ArrayAttr(op.static_offsets()).front();
1237     if (resultOffset != expectedOffset)
1238       return op.emitError("expected result type with offset = ")
1239              << resultOffset << " instead of " << expectedOffset;
1240 
1241     // Match strides in result memref type and in static_strides attribute.
1242     for (auto &en : llvm::enumerate(llvm::zip(
1243              resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
1244       int64_t resultStride = std::get<0>(en.value());
1245       int64_t expectedStride = std::get<1>(en.value());
1246       if (resultStride != expectedStride)
1247         return op.emitError("expected result type with stride = ")
1248                << expectedStride << " instead of " << resultStride
1249                << " in dim = " << en.index();
1250     }
1251   }
1252   return success();
1253 }
1254 
1255 //===----------------------------------------------------------------------===//
1256 // ReshapeOp
1257 //===----------------------------------------------------------------------===//
1258 
1259 static LogicalResult verify(ReshapeOp op) {
1260   Type operandType = op.source().getType();
1261   Type resultType = op.result().getType();
1262 
1263   Type operandElementType = operandType.cast<ShapedType>().getElementType();
1264   Type resultElementType = resultType.cast<ShapedType>().getElementType();
1265   if (operandElementType != resultElementType)
1266     return op.emitOpError("element types of source and destination memref "
1267                           "types should be the same");
1268 
1269   if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
1270     if (!operandMemRefType.getAffineMaps().empty())
1271       return op.emitOpError(
1272           "source memref type should have identity affine map");
1273 
1274   int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
1275   auto resultMemRefType = resultType.dyn_cast<MemRefType>();
1276   if (resultMemRefType) {
1277     if (!resultMemRefType.getAffineMaps().empty())
1278       return op.emitOpError(
1279           "result memref type should have identity affine map");
1280     if (shapeSize == ShapedType::kDynamicSize)
1281       return op.emitOpError("cannot use shape operand with dynamic length to "
1282                             "reshape to statically-ranked memref type");
1283     if (shapeSize != resultMemRefType.getRank())
1284       return op.emitOpError(
1285           "length of shape operand differs from the result's memref rank");
1286   }
1287   return success();
1288 }
1289 
1290 //===----------------------------------------------------------------------===//
1291 // StoreOp
1292 //===----------------------------------------------------------------------===//
1293 
1294 static LogicalResult verify(StoreOp op) {
1295   if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
1296     return op.emitOpError("store index operand count not equal to memref rank");
1297 
1298   return success();
1299 }
1300 
1301 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
1302                             SmallVectorImpl<OpFoldResult> &results) {
1303   /// store(memrefcast) -> store
1304   return foldMemRefCast(*this);
1305 }
1306 
1307 //===----------------------------------------------------------------------===//
1308 // SubViewOp
1309 //===----------------------------------------------------------------------===//
1310 
1311 namespace {
1312 /// Helpers to write more idiomatic operations.
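/// Arithmetic on Wrapper saturates to ShapedType::kDynamicStrideOrOffset as
/// soon as either operand is dynamic; e.g. (illustrative)
/// Wrapper(kDynamicStrideOrOffset) * 4 stays dynamic, while Wrapper(2) * 4
/// evaluates to 8.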
1313 namespace saturated_arith {
1314 struct Wrapper {
1315   explicit Wrapper(int64_t v) : v(v) {}
1316   operator int64_t() { return v; }
1317   int64_t v;
1318 };
1319 Wrapper operator+(Wrapper a, int64_t b) {
1320   if (ShapedType::isDynamicStrideOrOffset(a) ||
1321       ShapedType::isDynamicStrideOrOffset(b))
1322     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1323   return Wrapper(a.v + b);
1324 }
1325 Wrapper operator*(Wrapper a, int64_t b) {
1326   if (ShapedType::isDynamicStrideOrOffset(a) ||
1327       ShapedType::isDynamicStrideOrOffset(b))
1328     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1329   return Wrapper(a.v * b);
1330 }
1331 } // end namespace saturated_arith
1332 } // end namespace
1333 
1334 /// A subview result type can be fully inferred from the source type and the
1335 /// static representation of offsets, sizes and strides. Special sentinels
1336 /// encode the dynamic case.
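///
/// Illustrative sketch (not an exhaustive specification): for a source
/// `memref<8x16xf32>` (strides [16, 1], offset 0) with static offsets [2, 4],
/// sizes [4, 4] and strides [1, 2], the inferred type is
/// `memref<4x4xf32, affine_map<(d0, d1) -> (d0 * 16 + d1 * 2 + 36)>>`.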
1337 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1338                                 ArrayRef<int64_t> leadingStaticOffsets,
1339                                 ArrayRef<int64_t> leadingStaticSizes,
1340                                 ArrayRef<int64_t> leadingStaticStrides) {
  // A subview may specify only a leading subset of offsets/sizes/strides, in
  // which case we complete with offset=0, sizes from the source memref type,
  // and strides=1.
1343   unsigned rank = sourceMemRefType.getRank();
1344   assert(leadingStaticOffsets.size() <= rank &&
1345          "unexpected leadingStaticOffsets overflow");
1346   assert(leadingStaticSizes.size() <= rank &&
1347          "unexpected leadingStaticSizes overflow");
1348   assert(leadingStaticStrides.size() <= rank &&
1349          "unexpected leadingStaticStrides overflow");
1350   auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
1351   auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
1352   auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
1353   unsigned numTrailingOffsets = rank - staticOffsets.size();
1354   unsigned numTrailingSizes = rank - staticSizes.size();
1355   unsigned numTrailingStrides = rank - staticStrides.size();
1356   staticOffsets.append(numTrailingOffsets, 0);
1357   llvm::append_range(staticSizes,
1358                      sourceMemRefType.getShape().take_back(numTrailingSizes));
1359   staticStrides.append(numTrailingStrides, 1);
1360 
1361   // Extract source offset and strides.
1362   int64_t sourceOffset;
1363   SmallVector<int64_t, 4> sourceStrides;
1364   auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
1365   assert(succeeded(res) && "SubViewOp expected strided memref type");
1366   (void)res;
1367 
1368   // Compute target offset whose value is:
1369   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
1370   int64_t targetOffset = sourceOffset;
1371   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    using namespace saturated_arith;
    targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * sourceStride;
1375   }
1376 
1377   // Compute target stride whose value is:
1378   //   `sourceStrides_i * staticStrides_i`.
1379   SmallVector<int64_t, 4> targetStrides;
1380   targetStrides.reserve(staticOffsets.size());
1381   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
1382     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
1383     using namespace saturated_arith;
1384     targetStrides.push_back(Wrapper(sourceStride) * staticStride);
1385   }
1386 
1387   // The type is now known.
1388   return MemRefType::get(
1389       staticSizes, sourceMemRefType.getElementType(),
1390       makeStridedLinearLayoutMap(targetStrides, targetOffset,
1391                                  sourceMemRefType.getContext()),
1392       sourceMemRefType.getMemorySpaceAsInt());
1393 }
1394 
1395 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1396                                 ArrayRef<OpFoldResult> leadingStaticOffsets,
1397                                 ArrayRef<OpFoldResult> leadingStaticSizes,
1398                                 ArrayRef<OpFoldResult> leadingStaticStrides) {
1399   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1400   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1401   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1402                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1403   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1404                              ShapedType::kDynamicSize);
1405   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1406                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1407   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1408                                     staticSizes, staticStrides)
1409       .cast<MemRefType>();
1410 }
1411 
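/// Infer the result type, then drop unit dimensions until the rank matches
/// `resultRank`. Illustrative sketch (assuming the projection helpers behave
/// as used below): a fully inferred `memref<1x16xf32, affine_map<(d0, d1) ->
/// (d0 * 16 + d1)>>` reduced to rank 1 becomes `memref<16xf32>` (identity
/// layout).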
Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceMemRefType,
    ArrayRef<int64_t> leadingStaticOffsets,
    ArrayRef<int64_t> leadingStaticSizes,
    ArrayRef<int64_t> leadingStaticStrides) {
  auto inferredType =
      inferResultType(sourceMemRefType, leadingStaticOffsets,
                      leadingStaticSizes, leadingStaticStrides)
          .cast<MemRefType>();
  assert(inferredType.getRank() >= resultRank &&
         "expected the inferred rank to be greater than or equal to the "
         "result rank");
1422   int rankDiff = inferredType.getRank() - resultRank;
1423   if (rankDiff > 0) {
1424     auto shape = inferredType.getShape();
1425     llvm::SmallDenseSet<unsigned> dimsToProject;
1426     mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
1427     SmallVector<int64_t> projectedShape;
1428     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1429       if (!dimsToProject.contains(pos))
1430         projectedShape.push_back(shape[pos]);
1431 
1432     AffineMap map;
1433     auto maps = inferredType.getAffineMaps();
1434     if (!maps.empty() && maps.front())
1435       map = getProjectedMap(maps.front(), dimsToProject);
1436     inferredType =
1437         MemRefType::get(projectedShape, inferredType.getElementType(), map,
1438                         inferredType.getMemorySpaceAsInt());
1439   }
1440   return inferredType;
1441 }
1442 
1443 Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceMemRefType,
1445     ArrayRef<OpFoldResult> leadingStaticOffsets,
1446     ArrayRef<OpFoldResult> leadingStaticSizes,
1447     ArrayRef<OpFoldResult> leadingStaticStrides) {
1448   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1449   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1450   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1451                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1452   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1453                              ShapedType::kDynamicSize);
1454   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1455                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1456   return SubViewOp::inferRankReducedResultType(
      resultRank, sourceMemRefType, staticOffsets, staticSizes,
1458       staticStrides);
}

// Build a SubViewOp with mixed static and dynamic entries and custom result
1461 // type. If the type passed is nullptr, it is inferred.
1462 void SubViewOp::build(OpBuilder &b, OperationState &result,
1463                       MemRefType resultType, Value source,
1464                       ArrayRef<OpFoldResult> offsets,
1465                       ArrayRef<OpFoldResult> sizes,
1466                       ArrayRef<OpFoldResult> strides,
1467                       ArrayRef<NamedAttribute> attrs) {
1468   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1469   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1470   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1471                              ShapedType::kDynamicStrideOrOffset);
1472   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1473                              ShapedType::kDynamicSize);
1474   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1475                              ShapedType::kDynamicStrideOrOffset);
1476   auto sourceMemRefType = source.getType().cast<MemRefType>();
1477   // Structuring implementation this way avoids duplication between builders.
1478   if (!resultType) {
1479     resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1480                                             staticSizes, staticStrides)
1481                      .cast<MemRefType>();
1482   }
1483   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1484         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1485         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1486   result.addAttributes(attrs);
1487 }
1488 
1489 // Build a SubViewOp with mixed static and dynamic entries and inferred result
1490 // type.
1491 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1492                       ArrayRef<OpFoldResult> offsets,
1493                       ArrayRef<OpFoldResult> sizes,
1494                       ArrayRef<OpFoldResult> strides,
1495                       ArrayRef<NamedAttribute> attrs) {
1496   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1497 }
1498 
1499 // Build a SubViewOp with static entries and inferred result type.
1500 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1501                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1502                       ArrayRef<int64_t> strides,
1503                       ArrayRef<NamedAttribute> attrs) {
1504   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1505       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1506         return b.getI64IntegerAttr(v);
1507       }));
1508   SmallVector<OpFoldResult> sizeValues =
1509       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1510         return b.getI64IntegerAttr(v);
1511       }));
1512   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1513       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1514         return b.getI64IntegerAttr(v);
1515       }));
1516   build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
1517 }
1518 
// Build a SubViewOp with static entries and custom result type. If the
// type passed is nullptr, it is inferred.
1521 void SubViewOp::build(OpBuilder &b, OperationState &result,
1522                       MemRefType resultType, Value source,
1523                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1524                       ArrayRef<int64_t> strides,
1525                       ArrayRef<NamedAttribute> attrs) {
1526   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1527       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1528         return b.getI64IntegerAttr(v);
1529       }));
1530   SmallVector<OpFoldResult> sizeValues =
1531       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1532         return b.getI64IntegerAttr(v);
1533       }));
1534   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1535       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1536         return b.getI64IntegerAttr(v);
1537       }));
1538   build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
1539         attrs);
1540 }
1541 
1542 // Build a SubViewOp with dynamic entries and custom result type. If the type
1543 // passed is nullptr, it is inferred.
1544 void SubViewOp::build(OpBuilder &b, OperationState &result,
1545                       MemRefType resultType, Value source, ValueRange offsets,
1546                       ValueRange sizes, ValueRange strides,
1547                       ArrayRef<NamedAttribute> attrs) {
1548   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1549       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1550   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1551       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1552   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1553       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
1555 }
1556 
1557 // Build a SubViewOp with dynamic entries and inferred result type.
1558 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1559                       ValueRange offsets, ValueRange sizes, ValueRange strides,
1560                       ArrayRef<NamedAttribute> attrs) {
1561   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1562 }
1563 
1564 /// For ViewLikeOpInterface.
1565 Value SubViewOp::getViewSource() { return source(); }
1566 
1567 enum SubViewVerificationResult {
1568   Success,
1569   RankTooLarge,
1570   SizeMismatch,
1571   ElemTypeMismatch,
1572   MemSpaceMismatch,
1573   AffineMapMismatch
1574 };
1575 
/// Checks if `originalType` can be rank-reduced to `candidateReducedType`.
/// This function is a slight variant of the `is subsequence` algorithm, where
/// the non-matching dimensions must all be 1.
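///
/// Illustrative sketch: `memref<1x4x1x8xf32>` can be rank-reduced to
/// `memref<4x8xf32>` (all dropped dimensions are 1), whereas reducing it to
/// `memref<4x4xf32>` fails with a size mismatch.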
1579 static SubViewVerificationResult
1580 isRankReducedType(Type originalType, Type candidateReducedType,
1581                   std::string *errMsg = nullptr) {
1582   if (originalType == candidateReducedType)
1583     return SubViewVerificationResult::Success;
1584   if (!originalType.isa<MemRefType>())
1585     return SubViewVerificationResult::Success;
1586   if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
1587     return SubViewVerificationResult::Success;
1588 
1589   ShapedType originalShapedType = originalType.cast<ShapedType>();
1590   ShapedType candidateReducedShapedType =
1591       candidateReducedType.cast<ShapedType>();
1592 
1593   // Rank and size logic is valid for all ShapedTypes.
1594   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
1595   ArrayRef<int64_t> candidateReducedShape =
1596       candidateReducedShapedType.getShape();
1597   unsigned originalRank = originalShape.size(),
1598            candidateReducedRank = candidateReducedShape.size();
1599   if (candidateReducedRank > originalRank)
1600     return SubViewVerificationResult::RankTooLarge;
1601 
1602   auto optionalUnusedDimsMask =
1603       computeRankReductionMask(originalShape, candidateReducedShape);
1604 
  // Sizes cannot be matched if no rank-reduction mask could be computed
  // (llvm::None returned).
1606   if (!optionalUnusedDimsMask.hasValue())
1607     return SubViewVerificationResult::SizeMismatch;
1608 
1609   if (originalShapedType.getElementType() !=
1610       candidateReducedShapedType.getElementType())
1611     return SubViewVerificationResult::ElemTypeMismatch;
1612 
1613   // Strided layout logic is relevant for MemRefType only.
1614   MemRefType original = originalType.cast<MemRefType>();
1615   MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
1616   if (original.getMemorySpaceAsInt() != candidateReduced.getMemorySpaceAsInt())
1617     return SubViewVerificationResult::MemSpaceMismatch;
1618 
1619   llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
1620   auto inferredType =
1621       getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
1622   AffineMap candidateLayout;
1623   if (candidateReduced.getAffineMaps().empty())
1624     candidateLayout = getStridedLinearLayoutMap(candidateReduced);
1625   else
1626     candidateLayout = candidateReduced.getAffineMaps().front();
1627   assert(inferredType.getNumResults() == 1 &&
1628          candidateLayout.getNumResults() == 1);
1629   if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
1630       inferredType.getNumDims() != candidateLayout.getNumDims()) {
1631     if (errMsg) {
1632       llvm::raw_string_ostream os(*errMsg);
1633       os << "inferred type: " << inferredType;
1634     }
1635     return SubViewVerificationResult::AffineMapMismatch;
1636   }
1637   // Check that the difference of the affine maps simplifies to 0.
1638   AffineExpr diffExpr =
1639       inferredType.getResult(0) - candidateLayout.getResult(0);
1640   diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
1641                                 inferredType.getNumSymbols());
1642   auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
1643   if (!(cst && cst.getValue() == 0)) {
1644     if (errMsg) {
1645       llvm::raw_string_ostream os(*errMsg);
1646       os << "inferred type: " << inferredType;
1647     }
1648     return SubViewVerificationResult::AffineMapMismatch;
1649   }
1650   return SubViewVerificationResult::Success;
1651 }
1652 
1653 template <typename OpTy>
1654 static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
1655                                             OpTy op, Type expectedType,
1656                                             StringRef errMsg = "") {
1657   auto memrefType = expectedType.cast<ShapedType>();
1658   switch (result) {
1659   case SubViewVerificationResult::Success:
1660     return success();
1661   case SubViewVerificationResult::RankTooLarge:
1662     return op.emitError("expected result rank to be smaller or equal to ")
1663            << "the source rank. " << errMsg;
1664   case SubViewVerificationResult::SizeMismatch:
1665     return op.emitError("expected result type to be ")
1666            << expectedType
1667            << " or a rank-reduced version. (mismatch of result sizes) "
1668            << errMsg;
1669   case SubViewVerificationResult::ElemTypeMismatch:
1670     return op.emitError("expected result element type to be ")
1671            << memrefType.getElementType() << errMsg;
1672   case SubViewVerificationResult::MemSpaceMismatch:
1673     return op.emitError("expected result and source memory spaces to match.")
1674            << errMsg;
1675   case SubViewVerificationResult::AffineMapMismatch:
1676     return op.emitError("expected result type to be ")
1677            << expectedType
1678            << " or a rank-reduced version. (mismatch of result affine map) "
1679            << errMsg;
1680   }
1681   llvm_unreachable("unexpected subview verification result");
1682 }
1683 
1684 /// Verifier for SubViewOp.
1685 static LogicalResult verify(SubViewOp op) {
1686   MemRefType baseType = op.getSourceType();
1687   MemRefType subViewType = op.getType();
1688 
1689   // The base memref and the view memref should be in the same memory space.
1690   if (baseType.getMemorySpaceAsInt() != subViewType.getMemorySpaceAsInt())
1691     return op.emitError("different memory spaces specified for base memref "
1692                         "type ")
1693            << baseType << " and subview memref type " << subViewType;
1694 
1695   // Verify that the base memref type has a strided layout map.
1696   if (!isStrided(baseType))
1697     return op.emitError("base type ") << baseType << " is not strided";
1698 
1699   // Verify result type against inferred type.
1700   auto expectedType = SubViewOp::inferResultType(
1701       baseType, extractFromI64ArrayAttr(op.static_offsets()),
1702       extractFromI64ArrayAttr(op.static_sizes()),
1703       extractFromI64ArrayAttr(op.static_strides()));
1704 
1705   std::string errMsg;
1706   auto result = isRankReducedType(expectedType, subViewType, &errMsg);
1707   return produceSubViewErrorMsg(result, op, expectedType, errMsg);
1708 }
1709 
1710 raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
1711   return os << "range " << range.offset << ":" << range.size << ":"
1712             << range.stride;
1713 }
1714 
1715 /// Return the list of Range (i.e. offset, size, stride). Each Range
1716 /// entry contains either the dynamic value or a ConstantIndexOp constructed
1717 /// with `b` at location `loc`.
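///
/// Illustrative sketch: for `memref.subview %src[%i, 0][4, %n][1, 1]` this
/// returns {Range{%i, %c4, %c1}, Range{%c0, %n, %c1}}, where %c0, %c1 and %c4
/// are `constant ... : index` ops created with `b` at `loc`.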
1718 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
1719                                               OpBuilder &b, Location loc) {
1720   std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
1721   assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
1722   assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
1723   SmallVector<Range, 8> res;
1724   unsigned rank = ranks[0];
1725   res.reserve(rank);
1726   for (unsigned idx = 0; idx < rank; ++idx) {
1727     Value offset =
1728         op.isDynamicOffset(idx)
1729             ? op.getDynamicOffset(idx)
1730             : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
1731     Value size = op.isDynamicSize(idx)
1732                      ? op.getDynamicSize(idx)
1733                      : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
1734     Value stride =
1735         op.isDynamicStride(idx)
1736             ? op.getDynamicStride(idx)
1737             : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
1738     res.emplace_back(Range{offset, size, stride});
1739   }
1740   return res;
1741 }
1742 
1743 namespace {
1744 /// Pattern to rewrite a subview op with MemRefCast arguments.
1745 /// This essentially pushes memref.cast past its consuming subview when
1746 /// `canFoldIntoConsumerOp` is true.
1747 ///
1748 /// Example:
1749 /// ```
1750 ///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
1751 ///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
1752 ///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
1753 /// ```
1754 /// is rewritten into:
1755 /// ```
1756 ///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
1757 ///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
1758 ///     memref<3x4xf32, offset:?, strides:[?, 1]>
1759 /// ```
1760 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
1761 public:
1762   using OpRewritePattern<SubViewOp>::OpRewritePattern;
1763 
1764   LogicalResult matchAndRewrite(SubViewOp subViewOp,
1765                                 PatternRewriter &rewriter) const override {
    // If any operand is a constant, just return to let SubViewOpConstantFolder
    // kick in.
1767     if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
1768           return matchPattern(operand, matchConstantIndex());
1769         }))
1770       return failure();
1771 
1772     auto castOp = subViewOp.source().getDefiningOp<CastOp>();
1773     if (!castOp)
1774       return failure();
1775 
1776     if (!CastOp::canFoldIntoConsumerOp(castOp))
1777       return failure();
1778 
    /// Deduce the result type of the SubViewOp using
    /// `inferRankReducedResultType` on the cast source operand type and the
    /// SubViewOp static information. This is the resulting type if the
    /// MemRefCastOp were folded.
1782     auto resultType = SubViewOp::inferRankReducedResultType(
1783         subViewOp.getType().getRank(),
1784         castOp.source().getType().cast<MemRefType>(),
1785         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
1786         subViewOp.getMixedStrides());
1787     Value newSubView = rewriter.create<SubViewOp>(
1788         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
1789         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
1790         subViewOp.static_sizes(), subViewOp.static_strides());
1791     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
1792                                         newSubView);
1793     return success();
1794   }
1795 };
1796 } // namespace
1797 
/// A canonicalizer wrapper that replaces a SubViewOp with a CastOp from the
/// newly created SubViewOp back to the original result type.
1799 struct SubViewCanonicalizer {
1800   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
1801     rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
1802   }
1803 };
1804 
1805 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
1806                                             MLIRContext *context) {
1807   results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
1808                   SubViewOp, SubViewCanonicalizer>,
1809               SubViewOpMemRefCastFolder>(context);
1810 }
1811 
1812 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
1813   auto resultShapedType = getResult().getType().cast<ShapedType>();
1814   auto sourceShapedType = source().getType().cast<ShapedType>();
1815 
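  // If the result type is statically shaped and identical to the source type,
  // the subview covers the whole source and folds to it.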
1816   if (resultShapedType.hasStaticShape() &&
1817       resultShapedType == sourceShapedType) {
1818     return getViewSource();
1819   }
1820 
1821   return {};
1822 }
1823 
1824 //===----------------------------------------------------------------------===//
1825 // TensorLoadOp
1826 //===----------------------------------------------------------------------===//
1827 
1828 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
1829   if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
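    // Illustrative sketch (assumed value names): with
    //   %m = memref.buffer_cast %t : memref<4xf32>
    //   %r = memref.tensor_load %m : memref<4xf32>
    // adjacent in the same block, %r folds to %t.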
1832     if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
1833         bufferCast->getNextNode() == this->getOperation())
1834       return bufferCast.tensor();
1835   return {};
1836 }
1837 
1838 //===----------------------------------------------------------------------===//
1839 // TransposeOp
1840 //===----------------------------------------------------------------------===//
1841 
/// Build a strided memref type by applying `permutationMap` to `memRefType`.
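///
/// Illustrative sketch: for `memRefType = memref<4x8xf32>` (strides [8, 1])
/// and `permutationMap = (d0, d1) -> (d1, d0)`, the inferred type is
/// `memref<8x4xf32, affine_map<(d0, d1) -> (d0 + d1 * 8)>>`.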
1843 static MemRefType inferTransposeResultType(MemRefType memRefType,
1844                                            AffineMap permutationMap) {
1845   auto rank = memRefType.getRank();
1846   auto originalSizes = memRefType.getShape();
1847   // Compute permuted sizes.
1848   SmallVector<int64_t, 4> sizes(rank, 0);
1849   for (auto en : llvm::enumerate(permutationMap.getResults()))
1850     sizes[en.index()] =
1851         originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
1852 
1853   // Compute permuted strides.
1854   int64_t offset;
1855   SmallVector<int64_t, 4> strides;
1856   auto res = getStridesAndOffset(memRefType, strides, offset);
1857   assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
1858   (void)res;
1859   auto map =
1860       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
1861   map = permutationMap ? map.compose(permutationMap) : map;
1862   return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
1863 }
1864 
1865 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
1866                         AffineMapAttr permutation,
1867                         ArrayRef<NamedAttribute> attrs) {
1868   auto permutationMap = permutation.getValue();
1869   assert(permutationMap);
1870 
1871   auto memRefType = in.getType().cast<MemRefType>();
1872   // Compute result type.
1873   MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
1874 
1875   build(b, result, resultType, in, attrs);
1876   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
1877 }
1878 
1879 // transpose $in $permutation attr-dict : type($in) `to` type(results)
1880 static void print(OpAsmPrinter &p, TransposeOp op) {
1881   p << "memref.transpose " << op.in() << " " << op.permutation();
1882   p.printOptionalAttrDict(op->getAttrs(),
1883                           {TransposeOp::getPermutationAttrName()});
1884   p << " : " << op.in().getType() << " to " << op.getType();
1885 }
1886 
1887 static ParseResult parseTransposeOp(OpAsmParser &parser,
1888                                     OperationState &result) {
1889   OpAsmParser::OperandType in;
1890   AffineMap permutation;
1891   MemRefType srcType, dstType;
1892   if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
1893       parser.parseOptionalAttrDict(result.attributes) ||
1894       parser.parseColonType(srcType) ||
1895       parser.resolveOperand(in, srcType, result.operands) ||
1896       parser.parseKeywordType("to", dstType) ||
1897       parser.addTypeToList(dstType, result.types))
1898     return failure();
1899 
1900   result.addAttribute(TransposeOp::getPermutationAttrName(),
1901                       AffineMapAttr::get(permutation));
1902   return success();
1903 }
1904 
1905 static LogicalResult verify(TransposeOp op) {
1906   if (!op.permutation().isPermutation())
1907     return op.emitOpError("expected a permutation map");
1908   if (op.permutation().getNumDims() != op.getShapedType().getRank())
1909     return op.emitOpError(
1910         "expected a permutation map of same rank as the input");
1911 
1912   auto srcType = op.in().getType().cast<MemRefType>();
1913   auto dstType = op.getType().cast<MemRefType>();
1914   auto transposedType = inferTransposeResultType(srcType, op.permutation());
1915   if (dstType != transposedType)
1916     return op.emitOpError("output type ")
1917            << dstType << " does not match transposed input type " << srcType
1918            << ", " << transposedType;
1919   return success();
1920 }
1921 
1922 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
1923   if (succeeded(foldMemRefCast(*this)))
1924     return getResult();
1925   return {};
1926 }
1927 
1928 //===----------------------------------------------------------------------===//
1929 // ViewOp
1930 //===----------------------------------------------------------------------===//
1931 
1932 static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
1933   OpAsmParser::OperandType srcInfo;
1934   SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
1935   SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
1936   auto indexType = parser.getBuilder().getIndexType();
1937   Type srcType, dstType;
1938   llvm::SMLoc offsetLoc;
1939   if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
1940       parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
1941     return failure();
1942 
1943   if (offsetInfo.size() != 1)
1944     return parser.emitError(offsetLoc) << "expects 1 offset operand";
1945 
1946   return failure(
1947       parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
1948       parser.parseOptionalAttrDict(result.attributes) ||
1949       parser.parseColonType(srcType) ||
1950       parser.resolveOperand(srcInfo, srcType, result.operands) ||
1951       parser.resolveOperands(offsetInfo, indexType, result.operands) ||
1952       parser.resolveOperands(sizesInfo, indexType, result.operands) ||
1953       parser.parseKeywordType("to", dstType) ||
1954       parser.addTypeToList(dstType, result.types));
1955 }
1956 
1957 static void print(OpAsmPrinter &p, ViewOp op) {
1958   p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
1959   p.printOperand(op.byte_shift());
1960   p << "][" << op.sizes() << ']';
1961   p.printOptionalAttrDict(op->getAttrs());
1962   p << " : " << op.getOperand(0).getType() << " to " << op.getType();
1963 }
1964 
1965 static LogicalResult verify(ViewOp op) {
1966   auto baseType = op.getOperand(0).getType().cast<MemRefType>();
1967   auto viewType = op.getType();
1968 
1969   // The base memref should have identity layout map (or none).
1970   if (baseType.getAffineMaps().size() > 1 ||
1971       (baseType.getAffineMaps().size() == 1 &&
1972        !baseType.getAffineMaps()[0].isIdentity()))
1973     return op.emitError("unsupported map for base memref type ") << baseType;
1974 
1975   // The result memref should have identity layout map (or none).
1976   if (viewType.getAffineMaps().size() > 1 ||
1977       (viewType.getAffineMaps().size() == 1 &&
1978        !viewType.getAffineMaps()[0].isIdentity()))
1979     return op.emitError("unsupported map for result memref type ") << viewType;
1980 
1981   // The base memref and the view memref should be in the same memory space.
1982   if (baseType.getMemorySpaceAsInt() != viewType.getMemorySpaceAsInt())
1983     return op.emitError("different memory spaces specified for base memref "
1984                         "type ")
1985            << baseType << " and view memref type " << viewType;
1986 
1987   // Verify that we have the correct number of sizes for the result type.
1988   unsigned numDynamicDims = viewType.getNumDynamicDims();
1989   if (op.sizes().size() != numDynamicDims)
1990     return op.emitError("incorrect number of size operands for type ")
1991            << viewType;
1992 
1993   return success();
1994 }
1995 
1996 Value ViewOp::getViewSource() { return source(); }
1997 
1998 namespace {
1999 
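/// Fold constant size operands of a ViewOp into its result type, inserting a
/// cast back to the original type (illustrative sketch, value names assumed):
/// ```
///   %c1024 = constant 1024 : index
///   %0 = memref.view %src[%off][%c1024] : memref<?xi8> to memref<?xf32>
/// ```
/// is rewritten into:
/// ```
///   %0 = memref.view %src[%off][] : memref<?xi8> to memref<1024xf32>
///   %1 = memref.cast %0 : memref<1024xf32> to memref<?xf32>
/// ```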
2000 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
2001   using OpRewritePattern<ViewOp>::OpRewritePattern;
2002 
2003   LogicalResult matchAndRewrite(ViewOp viewOp,
2004                                 PatternRewriter &rewriter) const override {
2005     // Return if none of the operands are constants.
2006     if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
2007           return matchPattern(operand, matchConstantIndex());
2008         }))
2009       return failure();
2010 
2011     // Get result memref type.
2012     auto memrefType = viewOp.getType();
2013 
2014     // Get offset from old memref view type 'memRefType'.
2015     int64_t oldOffset;
2016     SmallVector<int64_t, 4> oldStrides;
2017     if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
2018       return failure();
2019     assert(oldOffset == 0 && "Expected 0 offset");
2020 
2021     SmallVector<Value, 4> newOperands;
2022 
2023     // Offset cannot be folded into result type.
2024 
2025     // Fold any dynamic dim operands which are produced by a constant.
2026     SmallVector<int64_t, 4> newShapeConstants;
2027     newShapeConstants.reserve(memrefType.getRank());
2028 
2029     unsigned dynamicDimPos = 0;
2030     unsigned rank = memrefType.getRank();
2031     for (unsigned dim = 0, e = rank; dim < e; ++dim) {
2032       int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
2034       if (!ShapedType::isDynamic(dimSize)) {
2035         newShapeConstants.push_back(dimSize);
2036         continue;
2037       }
2038       auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
2039       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
2040         // Dynamic shape dimension will be folded.
2041         newShapeConstants.push_back(constantIndexOp.getValue());
2042       } else {
2043         // Dynamic shape dimension not folded; copy operand from old memref.
2044         newShapeConstants.push_back(dimSize);
2045         newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
2046       }
2047       dynamicDimPos++;
2048     }
2049 
2050     // Create new memref type with constant folded dims.
2051     MemRefType newMemRefType =
2052         MemRefType::Builder(memrefType).setShape(newShapeConstants);
2053     // Nothing new, don't fold.
2054     if (newMemRefType == memrefType)
2055       return failure();
2056 
2057     // Create new ViewOp.
2058     auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
2059                                              viewOp.getOperand(0),
2060                                              viewOp.byte_shift(), newOperands);
2061     // Insert a cast so we have the same type as the old memref type.
2062     rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
2063     return success();
2064   }
2065 };
2066 
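/// Remove a memref.cast between an AllocOp and its consuming ViewOp by
/// rebuilding the view directly on the allocated memref (illustrative sketch,
/// value names assumed):
/// ```
///   %a = memref.alloc() : memref<2048xi8>
///   %c = memref.cast %a : memref<2048xi8> to memref<?xi8>
///   %v = memref.view %c[%off][%sz] : memref<?xi8> to memref<?xf32>
/// ```
/// is rewritten into:
/// ```
///   %v = memref.view %a[%off][%sz] : memref<2048xi8> to memref<?xf32>
/// ```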
2067 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
2068   using OpRewritePattern<ViewOp>::OpRewritePattern;
2069 
2070   LogicalResult matchAndRewrite(ViewOp viewOp,
2071                                 PatternRewriter &rewriter) const override {
2072     Value memrefOperand = viewOp.getOperand(0);
2073     CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
2074     if (!memrefCastOp)
2075       return failure();
2076     Value allocOperand = memrefCastOp.getOperand();
2077     AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
2078     if (!allocOp)
2079       return failure();
2080     rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
2081                                         viewOp.byte_shift(), viewOp.sizes());
2082     return success();
2083   }
2084 };
2085 
2086 } // end anonymous namespace
2087 
2088 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2089                                          MLIRContext *context) {
2090   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
2091 }
2092 
2093 //===----------------------------------------------------------------------===//
2094 // TableGen'd op method definitions
2095 //===----------------------------------------------------------------------===//
2096 
2097 #define GET_OP_CLASSES
2098 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
2099