1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MemRef/IR/MemRef.h"
10 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Interfaces/InferTypeOpInterface.h"
21 #include "mlir/Interfaces/ViewLikeInterface.h"
22 #include "llvm/ADT/STLExtras.h"
23 
24 using namespace mlir;
25 using namespace mlir::memref;
26 
27 /// Materialize a single constant operation from a given attribute value with
28 /// the desired resultant type.
29 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
30                                               Attribute value, Type type,
31                                               Location loc) {
32   return builder.create<mlir::ConstantOp>(loc, type, value);
33 }
34 
35 /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
36 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
37   return llvm::to_vector<4>(
38       llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
39         return a.cast<IntegerAttr>().getInt();
40       }));
41 }
42 
/// Helper function to dispatch an OpFoldResult into `dynamicVec` if it is a
/// Value, or into `staticVec` if it is an IntegerAttr.
45 /// In the case of a Value, a copy of the `sentinel` value is also pushed to
46 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
47 /// come from an AttrSizedOperandSegments trait.
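/// For illustration (hypothetical values), dispatching the entries
/// [%d0, IntegerAttr(4), %d1] with a sentinel of -1 yields
/// dynamicVec = [%d0, %d1] and staticVec = [-1, 4, -1].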
48 static void dispatchIndexOpFoldResult(OpFoldResult ofr,
49                                       SmallVectorImpl<Value> &dynamicVec,
50                                       SmallVectorImpl<int64_t> &staticVec,
51                                       int64_t sentinel) {
52   if (auto v = ofr.dyn_cast<Value>()) {
53     dynamicVec.push_back(v);
54     staticVec.push_back(sentinel);
55     return;
56   }
57   APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
58   staticVec.push_back(apInt.getSExtValue());
59 }
60 
61 static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
62                                        SmallVectorImpl<Value> &dynamicVec,
63                                        SmallVectorImpl<int64_t> &staticVec,
64                                        int64_t sentinel) {
65   for (auto ofr : ofrs)
66     dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
67 }
68 
69 //===----------------------------------------------------------------------===//
70 // Common canonicalization pattern support logic
71 //===----------------------------------------------------------------------===//
72 
/// This is a common utility used for patterns of the form
/// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
/// into the root operation directly.
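/// For example (illustrative, with a hypothetical %arg):
///
/// ```mlir
///   %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
///   memref.dealloc %0 : memref<?xf32>
/// ```
///
/// is folded so that the dealloc operates directly on %arg.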
76 static LogicalResult foldMemRefCast(Operation *op) {
77   bool folded = false;
78   for (OpOperand &operand : op->getOpOperands()) {
79     auto cast = operand.get().getDefiningOp<CastOp>();
80     if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
81       operand.set(cast.getOperand());
82       folded = true;
83     }
84   }
85   return success(folded);
86 }
87 
88 //===----------------------------------------------------------------------===//
89 // Helpers for GlobalOp
90 //===----------------------------------------------------------------------===//
91 
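/// Returns the tensor type equivalent of the given memref type, e.g.
/// memref<4x?xf32> -> tensor<4x?xf32> and memref<*xf32> -> tensor<*xf32>;
/// returns NoneType for non-memref types.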
92 static Type getTensorTypeFromMemRefType(Type type) {
93   if (auto memref = type.dyn_cast<MemRefType>())
94     return RankedTensorType::get(memref.getShape(), memref.getElementType());
95   if (auto memref = type.dyn_cast<UnrankedMemRefType>())
96     return UnrankedTensorType::get(memref.getElementType());
97   return NoneType::get(type.getContext());
98 }
99 
100 //===----------------------------------------------------------------------===//
101 // AllocOp / AllocaOp
102 //===----------------------------------------------------------------------===//
103 
104 template <typename AllocLikeOp>
105 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
106   static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
107                 "applies to only alloc or alloca");
108   auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
109   if (!memRefType)
110     return op.emitOpError("result must be a memref");
111 
112   if (static_cast<int64_t>(op.dynamicSizes().size()) !=
113       memRefType.getNumDynamicDims())
114     return op.emitOpError("dimension operand count does not equal memref "
115                           "dynamic dimension count");
116 
117   unsigned numSymbols = 0;
118   if (!memRefType.getAffineMaps().empty())
119     numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
120   if (op.symbolOperands().size() != numSymbols)
121     return op.emitOpError(
122         "symbol operand count does not equal memref symbol count");
123 
124   return success();
125 }
126 
127 static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }
128 
129 static LogicalResult verify(AllocaOp op) {
130   // An alloca op needs to have an ancestor with an allocation scope trait.
131   if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
132     return op.emitOpError(
133         "requires an ancestor op with AutomaticAllocationScope trait");
134 
135   return verifyAllocLikeOp(op);
136 }
137 
138 namespace {
/// Fold constant dimensions into an alloc-like operation.
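/// For example (illustrative), assuming %c4 is a constant index with value 4:
///
/// ```mlir
///   %0 = memref.alloc(%c4, %d) : memref<?x?xf32>
/// ```
///
/// is rewritten to allocate a memref<4x?xf32> with only %d as operand, plus a
/// memref.cast back to memref<?x?xf32> to preserve the original result type.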
140 template <typename AllocLikeOp>
141 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
142   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
143 
144   LogicalResult matchAndRewrite(AllocLikeOp alloc,
145                                 PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants.  If so, we can
147     // substitute and drop them.
148     if (llvm::none_of(alloc.getOperands(), [](Value operand) {
149           return matchPattern(operand, matchConstantIndex());
150         }))
151       return failure();
152 
153     auto memrefType = alloc.getType();
154 
155     // Ok, we have one or more constant operands.  Collect the non-constant ones
156     // and keep track of the resultant memref type to build.
157     SmallVector<int64_t, 4> newShapeConstants;
158     newShapeConstants.reserve(memrefType.getRank());
159     SmallVector<Value, 4> newOperands;
160 
161     unsigned dynamicDimPos = 0;
162     for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
163       int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
165       if (dimSize != -1) {
166         newShapeConstants.push_back(dimSize);
167         continue;
168       }
169       auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
170       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
171         // Dynamic shape dimension will be folded.
172         newShapeConstants.push_back(constantIndexOp.getValue());
173       } else {
174         // Dynamic shape dimension not folded; copy operand from old memref.
175         newShapeConstants.push_back(-1);
176         newOperands.push_back(alloc.getOperand(dynamicDimPos));
177       }
178       dynamicDimPos++;
179     }
180 
181     // Create new memref type (which will have fewer dynamic dimensions).
182     MemRefType newMemRefType =
183         MemRefType::Builder(memrefType).setShape(newShapeConstants);
184     assert(static_cast<int64_t>(newOperands.size()) ==
185            newMemRefType.getNumDynamicDims());
186 
187     // Create and insert the alloc op for the new memref.
188     auto newAlloc = rewriter.create<AllocLikeOp>(
189         alloc.getLoc(), newMemRefType, newOperands, alloc.alignmentAttr());
190     // Insert a cast so we have the same type as the old alloc.
191     auto resultCast =
192         rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
193 
194     rewriter.replaceOp(alloc, {resultCast});
195     return success();
196   }
197 };
198 
199 /// Fold alloc operations with no users or only store and dealloc uses.
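/// For example (illustrative, %v and %i defined elsewhere):
///
/// ```mlir
///   %0 = memref.alloc() : memref<4xf32>
///   memref.store %v, %0[%i] : memref<4xf32>
///   memref.dealloc %0 : memref<4xf32>
/// ```
///
/// is erased entirely, since the stored value is never read.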
200 template <typename T>
201 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
202   using OpRewritePattern<T>::OpRewritePattern;
203 
204   LogicalResult matchAndRewrite(T alloc,
205                                 PatternRewriter &rewriter) const override {
206     if (llvm::any_of(alloc->getUsers(), [](Operation *op) {
207           return !isa<StoreOp, DeallocOp>(op);
208         }))
209       return failure();
210 
211     for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
212       rewriter.eraseOp(user);
213 
214     rewriter.eraseOp(alloc);
215     return success();
216   }
217 };
218 } // end anonymous namespace.
219 
220 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
221                                           MLIRContext *context) {
222   results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
223 }
224 
225 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
226                                            MLIRContext *context) {
227   results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
228       context);
229 }
230 
231 //===----------------------------------------------------------------------===//
232 // AssumeAlignmentOp
233 //===----------------------------------------------------------------------===//
234 
235 static LogicalResult verify(AssumeAlignmentOp op) {
236   unsigned alignment = op.alignment();
237   if (!llvm::isPowerOf2_32(alignment))
238     return op.emitOpError("alignment must be power of 2");
239   return success();
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // BufferCastOp
244 //===----------------------------------------------------------------------===//
245 
246 OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
247   if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
248     if (tensorLoad.memref().getType() == getType())
249       return tensorLoad.memref();
250   return {};
251 }
252 
253 namespace {
/// Replace tensor.cast + buffer_cast by buffer_cast + memref.cast.
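/// For example (illustrative):
///
/// ```mlir
///   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %1 = memref.buffer_cast %0 : memref<?xf32>
/// ```
///
/// is rewritten to
///
/// ```mlir
///   %0 = memref.buffer_cast %t : memref<4xf32>
///   %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
/// ```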
255 struct BufferCast : public OpRewritePattern<BufferCastOp> {
256   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
257 
258   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
259                                 PatternRewriter &rewriter) const final {
260     auto tensorCastOperand =
261         bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
262     if (!tensorCastOperand)
263       return failure();
264     auto srcTensorType =
265         tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
266     if (!srcTensorType)
267       return failure();
268     auto memrefType = MemRefType::get(srcTensorType.getShape(),
269                                       srcTensorType.getElementType());
270     Value memref = rewriter.create<BufferCastOp>(
271         bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
272     rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
273                                         memref);
274     return success();
275   }
276 };
277 
/// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
/// a type mismatch prevents `BufferCastOp::fold` from kicking in.
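/// For example (illustrative):
///
/// ```mlir
///   %0 = memref.tensor_load %m : memref<4xf32>
///   %1 = memref.buffer_cast %0 : memref<?xf32>
/// ```
///
/// becomes
///
/// ```mlir
///   %1 = memref.cast %m : memref<4xf32> to memref<?xf32>
/// ```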
280 struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
281   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
282 
283   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
284                                 PatternRewriter &rewriter) const final {
285     auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
286     // Bail unless we have a tensor_load + memref.buffer_cast with different
287     // types. `BufferCastOp::fold` handles the same type case.
288     if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
289       return failure();
290     // If types are not cast-compatible, bail.
291     if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
292                                    bufferCast.getType()))
293       return failure();
294     rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
295                                         tensorLoad.memref());
296     return success();
297   }
298 };
299 
300 } // namespace
301 
302 void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
303                                                MLIRContext *context) {
304   results.add<BufferCast, TensorLoadToMemRef>(context);
305 }
306 
307 //===----------------------------------------------------------------------===//
308 // CastOp
309 //===----------------------------------------------------------------------===//
310 
311 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
313 /// and implement canonicalization patterns for ops in different dialects that
314 /// may consume the results of memref.cast operations. Such foldable memref.cast
315 /// operations are typically inserted as `view` and `subview` ops are
316 /// canonicalized, to preserve the type compatibility of their uses.
317 ///
318 /// Returns true when all conditions are met:
319 /// 1. source and result are ranked memrefs with strided semantics and same
320 /// element type and rank.
/// 2. each of the source's size, offset and stride has at least as much static
/// information as the corresponding result's size, offset and stride.
323 ///
324 /// Example 1:
325 /// ```mlir
326 ///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
327 ///   %2 = consumer %1 ... : memref<?x?xf32> ...
328 /// ```
329 ///
330 /// may fold into:
331 ///
332 /// ```mlir
333 ///   %2 = consumer %0 ... : memref<8x16xf32> ...
334 /// ```
335 ///
336 /// Example 2:
337 /// ```
338 ///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
339 ///          to memref<?x?xf32>
340 ///   consumer %1 : memref<?x?xf32> ...
341 /// ```
342 ///
343 /// may fold into:
344 ///
345 /// ```
346 ///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
347 /// ```
348 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
349   MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
350   MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
351 
352   // Requires ranked MemRefType.
353   if (!sourceType || !resultType)
354     return false;
355 
  // Requires same element type.
357   if (sourceType.getElementType() != resultType.getElementType())
358     return false;
359 
360   // Requires same rank.
361   if (sourceType.getRank() != resultType.getRank())
362     return false;
363 
364   // Only fold casts between strided memref forms.
365   int64_t sourceOffset, resultOffset;
366   SmallVector<int64_t, 4> sourceStrides, resultStrides;
367   if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
368       failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
369     return false;
370 
371   // If cast is towards more static sizes along any dimension, don't fold.
372   for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
373     auto ss = std::get<0>(it), st = std::get<1>(it);
374     if (ss != st)
375       if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
376         return false;
377   }
378 
  // If cast is towards a more static offset, don't fold.
380   if (sourceOffset != resultOffset)
381     if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
382         !MemRefType::isDynamicStrideOrOffset(resultOffset))
383       return false;
384 
385   // If cast is towards more static strides along any dimension, don't fold.
386   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
387     auto ss = std::get<0>(it), st = std::get<1>(it);
388     if (ss != st)
389       if (MemRefType::isDynamicStrideOrOffset(ss) &&
390           !MemRefType::isDynamicStrideOrOffset(st))
391         return false;
392   }
393 
394   return true;
395 }
396 
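/// Returns true if the two given types may appear as the source and result
/// types of a memref.cast: e.g. memref<8x?xf32> and memref<?x4xf32> are
/// compatible (same rank and element type, and static dimensions agree where
/// both are static), as is a cast between a ranked and an unranked memref
/// with the same element type and memory space.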
397 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
398   if (inputs.size() != 1 || outputs.size() != 1)
399     return false;
400   Type a = inputs.front(), b = outputs.front();
401   auto aT = a.dyn_cast<MemRefType>();
402   auto bT = b.dyn_cast<MemRefType>();
403 
404   auto uaT = a.dyn_cast<UnrankedMemRefType>();
405   auto ubT = b.dyn_cast<UnrankedMemRefType>();
406 
407   if (aT && bT) {
408     if (aT.getElementType() != bT.getElementType())
409       return false;
410     if (aT.getAffineMaps() != bT.getAffineMaps()) {
411       int64_t aOffset, bOffset;
412       SmallVector<int64_t, 4> aStrides, bStrides;
413       if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
414           failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
415           aStrides.size() != bStrides.size())
416         return false;
417 
      // A stride along a dimension (or the offset) is compatible if the
      // values in the source and target memrefs are both static and equal.
      // They are also compatible if either one is dynamic (see the
      // description of memref.cast for details).
422       auto checkCompatible = [](int64_t a, int64_t b) {
423         return (a == MemRefType::getDynamicStrideOrOffset() ||
424                 b == MemRefType::getDynamicStrideOrOffset() || a == b);
425       };
426       if (!checkCompatible(aOffset, bOffset))
427         return false;
428       for (auto aStride : enumerate(aStrides))
429         if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
430           return false;
431     }
432     if (aT.getMemorySpace() != bT.getMemorySpace())
433       return false;
434 
435     // They must have the same rank, and any specified dimensions must match.
436     if (aT.getRank() != bT.getRank())
437       return false;
438 
439     for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
440       int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
441       if (aDim != -1 && bDim != -1 && aDim != bDim)
442         return false;
443     }
444     return true;
445   } else {
446     if (!aT && !uaT)
447       return false;
448     if (!bT && !ubT)
449       return false;
    // Unranked to unranked casting is unsupported.
451     if (uaT && ubT)
452       return false;
453 
454     auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
455     auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
456     if (aEltType != bEltType)
457       return false;
458 
459     auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
460     auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
461     if (aMemSpace != bMemSpace)
462       return false;
463 
464     return true;
465   }
466 
467   return false;
468 }
469 
470 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
471   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
472 }
473 
474 //===----------------------------------------------------------------------===//
475 // CloneOp
476 //===----------------------------------------------------------------------===//
477 
478 void CloneOp::getEffects(
479     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
480         &effects) {
481   effects.emplace_back(MemoryEffects::Read::get(), input(),
482                        SideEffects::DefaultResource::get());
483   effects.emplace_back(MemoryEffects::Write::get(), output(),
484                        SideEffects::DefaultResource::get());
485 }
486 
487 namespace {
/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
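/// For example (illustrative), when the clone's result is deallocated in the
/// same block and no other frees occur in between:
///
/// ```mlir
///   %1 = memref.clone %0 : memref<?xf32> to memref<?xf32>
///   "use"(%1) : (memref<?xf32>) -> ()
///   memref.dealloc %1 : memref<?xf32>
/// ```
///
/// the clone is replaced by a memref.cast of %0 and the dealloc is erased.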
490 struct SimplifyClones : public OpRewritePattern<CloneOp> {
491   using OpRewritePattern<CloneOp>::OpRewritePattern;
492 
493   LogicalResult matchAndRewrite(CloneOp cloneOp,
494                                 PatternRewriter &rewriter) const override {
495     if (cloneOp.use_empty()) {
496       rewriter.eraseOp(cloneOp);
497       return success();
498     }
499 
500     Value source = cloneOp.input();
501 
502     // This only finds dealloc operations for the immediate value. It should
503     // also consider aliases. That would also make the safety check below
504     // redundant.
505     Operation *cloneDeallocOp = findDealloc(cloneOp.output());
506     Operation *sourceDeallocOp = findDealloc(source);
507 
508     // If both are deallocated in the same block, their in-block lifetimes
509     // might not fully overlap, so we cannot decide which one to drop.
510     if (cloneDeallocOp && sourceDeallocOp &&
511         cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
512       return failure();
513 
514     Block *currentBlock = cloneOp->getBlock();
515     Operation *redundantDealloc = nullptr;
516     if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
517       redundantDealloc = cloneDeallocOp;
518     } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
519       redundantDealloc = sourceDeallocOp;
520     }
521 
522     if (!redundantDealloc)
523       return failure();
524 
    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias
    // of the source.
530     for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
531          pos = pos->getNextNode()) {
532       auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
533       if (!effectInterface)
534         continue;
535       if (effectInterface.hasEffect<MemoryEffects::Free>())
536         return failure();
537     }
538 
539     rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
540                                                 source);
541     rewriter.eraseOp(redundantDealloc);
542     return success();
543   }
544 };
545 
546 } // end anonymous namespace.
547 
void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}
552 
553 OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
554   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
555 }
556 
557 //===----------------------------------------------------------------------===//
558 // DeallocOp
559 //===----------------------------------------------------------------------===//
560 
561 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
562                               SmallVectorImpl<OpFoldResult> &results) {
563   /// dealloc(memrefcast) -> dealloc
564   return foldMemRefCast(*this);
565 }
566 
567 //===----------------------------------------------------------------------===//
568 // DimOp
569 //===----------------------------------------------------------------------===//
570 
571 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
572                   int64_t index) {
573   auto loc = result.location;
574   Value indexValue = builder.create<ConstantIndexOp>(loc, index);
575   build(builder, result, memref, indexValue);
576 }
577 
578 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
579                   Value index) {
580   auto indexTy = builder.getIndexType();
581   build(builder, result, indexTy, memref, index);
582 }
583 
584 Optional<int64_t> DimOp::getConstantIndex() {
585   if (auto constantOp = index().getDefiningOp<ConstantOp>())
586     return constantOp.getValue().cast<IntegerAttr>().getInt();
587   return {};
588 }
589 
590 static LogicalResult verify(DimOp op) {
591   // Assume unknown index to be in range.
592   Optional<int64_t> index = op.getConstantIndex();
593   if (!index.hasValue())
594     return success();
595 
  // Check that the constant index is not known to be out of range.
597   auto type = op.memrefOrTensor().getType();
598   if (auto memrefType = type.dyn_cast<MemRefType>()) {
599     if (index.getValue() >= memrefType.getRank())
600       return op.emitOpError("index is out of range");
601   } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
602     if (index.getValue() >= tensorType.getRank())
603       return op.emitOpError("index is out of range");
604   } else if (type.isa<UnrankedMemRefType>() || type.isa<UnrankedTensorType>()) {
605     // Assume index to be in range.
606   } else {
607     llvm_unreachable("expected operand with memref type");
608   }
609   return success();
610 }
611 
612 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
613   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
614 
615   // All forms of folding require a known index.
616   if (!index)
617     return {};
618 
619   auto argTy = memrefOrTensor().getType();
620   // Fold if the shape extent along the given index is known.
621   if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
    // Folding for unranked memref or tensor types is not supported.
623     if (!shapedTy.hasRank())
624       return {};
625     if (!shapedTy.isDynamicDim(index.getInt())) {
626       Builder builder(getContext());
627       return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]);
628     }
629   }
630 
631   Operation *definingOp = memrefOrTensor().getDefiningOp();
632 
633   // dim(memref.tensor_load(memref)) -> dim(memref)
634   if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) {
635     setOperand(0, tensorLoadOp.memref());
636     return getResult();
637   }
638 
639   // Fold dim to the operand of tensor.generate.
640   if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
641     auto resultType =
642         fromElements.getResult().getType().cast<RankedTensorType>();
643     // The case where the type encodes the size of the dimension is handled
644     // above.
645     assert(resultType.getShape()[index.getInt()] ==
646            RankedTensorType::kDynamicSize);
647 
648     // Find the operand of the fromElements that corresponds to this index.
649     auto dynExtents = fromElements.dynamicExtents().begin();
650     for (auto dim : resultType.getShape().take_front(index.getInt()))
651       if (dim == RankedTensorType::kDynamicSize)
652         dynExtents++;
653 
654     return Value{*dynExtents};
655   }
656 
657   // The size at the given index is now known to be a dynamic size.
658   unsigned unsignedIndex = index.getValue().getZExtValue();
659 
660   if (auto subtensor = dyn_cast_or_null<mlir::SubTensorOp>(definingOp)) {
661     assert(subtensor.isDynamicSize(unsignedIndex) &&
662            "Expected dynamic subtensor size");
663     return subtensor.getDynamicSize(unsignedIndex);
664   }
665 
  // Fold dim to the size argument for an `AllocOp`, `AllocaOp`, `ViewOp`, or
  // `SubViewOp`.
667   auto memrefType = argTy.dyn_cast<MemRefType>();
668   if (!memrefType)
669     return {};
670 
671   if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
672     return *(alloc.getDynamicSizes().begin() +
673              memrefType.getDynamicDimIndex(unsignedIndex));
674 
675   if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
676     return *(alloca.getDynamicSizes().begin() +
677              memrefType.getDynamicDimIndex(unsignedIndex));
678 
679   if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
680     return *(view.getDynamicSizes().begin() +
681              memrefType.getDynamicDimIndex(unsignedIndex));
682 
683   if (auto sizeInterface =
684           dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
685     assert(sizeInterface.isDynamicSize(unsignedIndex) &&
686            "Expected dynamic subview size");
687     return sizeInterface.getDynamicSize(unsignedIndex);
688   }
689 
690   // dim(memrefcast) -> dim
691   if (succeeded(foldMemRefCast(*this)))
692     return getResult();
693 
694   return {};
695 }
696 
697 namespace {
/// Fold a dim of a memref.reshape operation into a load from the reshape's
/// shape operand.
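/// For example (illustrative, %c1 being a constant index):
///
/// ```mlir
///   %r = memref.reshape %src(%shape)
///       : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
///   %d = memref.dim %r, %c1 : memref<?x?xf32>
/// ```
///
/// is rewritten so that %d becomes
/// `memref.load %shape[%c1] : memref<2xindex>`, inserted right after the
/// reshape.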
700 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
701   using OpRewritePattern<DimOp>::OpRewritePattern;
702 
703   LogicalResult matchAndRewrite(DimOp dim,
704                                 PatternRewriter &rewriter) const override {
705     auto reshape = dim.memrefOrTensor().getDefiningOp<ReshapeOp>();
706 
707     if (!reshape)
708       return failure();
709 
710     // Place the load directly after the reshape to ensure that the shape memref
711     // was not mutated.
712     rewriter.setInsertionPointAfter(reshape);
713     rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
714                                         llvm::makeArrayRef({dim.index()}));
715     return success();
716   }
717 };
718 
/// Fold dim of a cast into the dim of the source of the cast.
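/// For example (illustrative), with a tensor.cast producer:
///
/// ```mlir
///   %0 = tensor.cast %t : tensor<4x?xf32> to tensor<?x?xf32>
///   %d = memref.dim %0, %c1 : tensor<?x?xf32>
/// ```
///
/// becomes `%d = memref.dim %t, %c1 : tensor<4x?xf32>`.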
720 template <typename CastOpTy>
721 struct DimOfCastOp : public OpRewritePattern<DimOp> {
722   using OpRewritePattern<DimOp>::OpRewritePattern;
723 
724   LogicalResult matchAndRewrite(DimOp dimOp,
725                                 PatternRewriter &rewriter) const override {
726     auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>();
727     if (!castOp)
728       return failure();
729     Value newSource = castOp.getOperand();
730     rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
731     return success();
732   }
733 };
734 
/// Helper method to get the `Value` that is the shape of `result` at
/// dimension `dimIndex`, using the `InferShapedTypeOpInterface`.
/// TODO(ravishankarm): This is better put as an interface utility method
/// somewhere, but that would imply the interface will depend on the `tensor`
/// dialect. Ideally maybe a utility method in the `tensor` dialect.
740 static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
741                                             int64_t dimIndex) {
742   unsigned resultNumber = result.getResultNumber();
743   auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
744   Location loc = result.getOwner()->getLoc();
745   if (!shapedTypeOp)
746     return nullptr;
747 
  // The interface exposes two methods: one that reifies the shapes of all the
  // results as a list of `Value`s (one shape tensor per result), and another
  // that reifies the shape of each result as a list of per-dimension values.
  // The former takes precedence over the latter, so first check whether the op
  // implements the first method and fall back to the second otherwise.
753   SmallVector<Value> reifiedResultShapes;
754   if (succeeded(shapedTypeOp.reifyReturnTypeShapes(
755           builder, result.getOwner()->getOperands(), reifiedResultShapes))) {
756     if (reifiedResultShapes.size() <= resultNumber)
757       return nullptr;
758     Value resultShape = reifiedResultShapes[resultNumber];
759     auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
760     if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
761       return nullptr;
762     return builder.create<tensor::ExtractOp>(
763         loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
764   }
765 
766   SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
767   if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
768           builder, reifiedResultShapesPerDim)))
769     return nullptr;
770   if (reifiedResultShapesPerDim.size() <= resultNumber ||
771       reifiedResultShapesPerDim[resultNumber].size() !=
772           static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
773     return nullptr;
774   OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
775   if (auto attr = valueOrAttr.dyn_cast<Attribute>())
776     return builder.createOrFold<ConstantIndexOp>(
777         loc, attr.cast<IntegerAttr>().getInt());
778   return valueOrAttr.get<Value>();
779 }
780 
/// Fold dim of an operation that implements the InferShapedTypeOpInterface.
782 struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> {
783   using OpRewritePattern<DimOp>::OpRewritePattern;
784 
785   LogicalResult matchAndRewrite(DimOp dimOp,
786                                 PatternRewriter &rewriter) const override {
787     OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
788     if (!dimValue)
789       return failure();
790     auto shapedTypeOp =
791         dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
792     if (!shapedTypeOp)
793       return failure();
794 
795     Optional<int64_t> dimIndex = dimOp.getConstantIndex();
796     if (!dimIndex)
797       return failure();
798     Value replacement =
799         getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
800     if (!replacement)
801       return failure();
802     rewriter.replaceOp(dimOp, replacement);
803     return success();
804   }
805 };
806 } // end anonymous namespace.
807 
808 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
809                                         MLIRContext *context) {
810   results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
811               DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(context);
812 }
813 
814 // ---------------------------------------------------------------------------
815 // DmaStartOp
816 // ---------------------------------------------------------------------------
817 
818 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
819                        Value srcMemRef, ValueRange srcIndices, Value destMemRef,
820                        ValueRange destIndices, Value numElements,
821                        Value tagMemRef, ValueRange tagIndices, Value stride,
822                        Value elementsPerStride) {
823   result.addOperands(srcMemRef);
824   result.addOperands(srcIndices);
825   result.addOperands(destMemRef);
826   result.addOperands(destIndices);
827   result.addOperands({numElements, tagMemRef});
828   result.addOperands(tagIndices);
829   if (stride)
830     result.addOperands({stride, elementsPerStride});
831 }
832 
833 void DmaStartOp::print(OpAsmPrinter &p) {
834   p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices()
835     << "], " << getDstMemRef() << '[' << getDstIndices() << "], "
836     << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices()
837     << ']';
838   if (isStrided())
839     p << ", " << getStride() << ", " << getNumElementsPerStride();
840 
841   p.printOptionalAttrDict((*this)->getAttrs());
842   p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
843     << ", " << getTagMemRef().getType();
844 }
845 
846 // Parse DmaStartOp.
847 // Ex:
//   memref.dma_start %src[%i, %j], %dst[%k, %l], %size,
//                    %tag[%index], %stride, %num_elt_per_stride
//                      : memref<3076 x f32, 0>,
//                        memref<1024 x f32, 2>,
//                        memref<1 x i32>
853 //
854 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
855   OpAsmParser::OperandType srcMemRefInfo;
856   SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
857   OpAsmParser::OperandType dstMemRefInfo;
858   SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
859   OpAsmParser::OperandType numElementsInfo;
860   OpAsmParser::OperandType tagMemrefInfo;
861   SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
862   SmallVector<OpAsmParser::OperandType, 2> strideInfo;
863 
864   SmallVector<Type, 3> types;
865   auto indexType = parser.getBuilder().getIndexType();
866 
867   // Parse and resolve the following list of operands:
868   // *) source memref followed by its indices (in square brackets).
869   // *) destination memref followed by its indices (in square brackets).
  // *) dma size (number of elements to transfer).
871   if (parser.parseOperand(srcMemRefInfo) ||
872       parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
873       parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
874       parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
875       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
876       parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
877       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
878     return failure();
879 
880   // Parse optional stride and elements per stride.
881   if (parser.parseTrailingOperandList(strideInfo))
882     return failure();
883 
884   bool isStrided = strideInfo.size() == 2;
885   if (!strideInfo.empty() && !isStrided) {
886     return parser.emitError(parser.getNameLoc(),
887                             "expected two stride related operands");
888   }
889 
890   if (parser.parseColonTypeList(types))
891     return failure();
892   if (types.size() != 3)
893     return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
894 
895   if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
896       parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
897       parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
898       parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
899       // size should be an index.
900       parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
901       parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
902       // tag indices should be index.
903       parser.resolveOperands(tagIndexInfos, indexType, result.operands))
904     return failure();
905 
906   if (isStrided) {
907     if (parser.resolveOperands(strideInfo, indexType, result.operands))
908       return failure();
909   }
910 
911   return success();
912 }
913 
914 LogicalResult DmaStartOp::verify() {
915   unsigned numOperands = getNumOperands();
916 
917   // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
918   // the number of elements.
919   if (numOperands < 4)
920     return emitOpError("expected at least 4 operands");
921 
922   // Check types of operands. The order of these calls is important: the later
923   // calls rely on some type properties to compute the operand position.
924   // 1. Source memref.
925   if (!getSrcMemRef().getType().isa<MemRefType>())
926     return emitOpError("expected source to be of memref type");
927   if (numOperands < getSrcMemRefRank() + 4)
928     return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
929                          << " operands";
930   if (!getSrcIndices().empty() &&
931       !llvm::all_of(getSrcIndices().getTypes(),
932                     [](Type t) { return t.isIndex(); }))
933     return emitOpError("expected source indices to be of index type");
934 
935   // 2. Destination memref.
936   if (!getDstMemRef().getType().isa<MemRefType>())
937     return emitOpError("expected destination to be of memref type");
938   unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
939   if (numOperands < numExpectedOperands)
940     return emitOpError() << "expected at least " << numExpectedOperands
941                          << " operands";
942   if (!getDstIndices().empty() &&
943       !llvm::all_of(getDstIndices().getTypes(),
944                     [](Type t) { return t.isIndex(); }))
945     return emitOpError("expected destination indices to be of index type");
946 
947   // 3. Number of elements.
948   if (!getNumElements().getType().isIndex())
949     return emitOpError("expected num elements to be of index type");
950 
951   // 4. Tag memref.
952   if (!getTagMemRef().getType().isa<MemRefType>())
953     return emitOpError("expected tag to be of memref type");
954   numExpectedOperands += getTagMemRefRank();
955   if (numOperands < numExpectedOperands)
956     return emitOpError() << "expected at least " << numExpectedOperands
957                          << " operands";
958   if (!getTagIndices().empty() &&
959       !llvm::all_of(getTagIndices().getTypes(),
960                     [](Type t) { return t.isIndex(); }))
961     return emitOpError("expected tag indices to be of index type");
962 
963   // Optional stride-related operands must be either both present or both
964   // absent.
965   if (numOperands != numExpectedOperands &&
966       numOperands != numExpectedOperands + 2)
967     return emitOpError("incorrect number of operands");
968 
969   // 5. Strides.
970   if (isStrided()) {
971     if (!getStride().getType().isIndex() ||
972         !getNumElementsPerStride().getType().isIndex())
973       return emitOpError(
974           "expected stride and num elements per stride to be of type index");
975   }
976 
977   return success();
978 }
979 
980 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
981                                SmallVectorImpl<OpFoldResult> &results) {
982   /// dma_start(memrefcast) -> dma_start
983   return foldMemRefCast(*this);
984 }
985 
986 // ---------------------------------------------------------------------------
987 // DmaWaitOp
988 // ---------------------------------------------------------------------------
989 
990 void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
991                       Value tagMemRef, ValueRange tagIndices,
992                       Value numElements) {
993   result.addOperands(tagMemRef);
994   result.addOperands(tagIndices);
995   result.addOperands(numElements);
996 }
997 
998 void DmaWaitOp::print(OpAsmPrinter &p) {
999   p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices()
1000     << "], " << getNumElements();
1001   p.printOptionalAttrDict((*this)->getAttrs());
1002   p << " : " << getTagMemRef().getType();
1003 }
1004 
1005 // Parse DmaWaitOp.
// Ex:
//   memref.dma_wait %tag[%index], %num_elements
//       : memref<1 x i32, affine_map<(d0) -> (d0)>, 4>
1008 //
1009 ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
1010   OpAsmParser::OperandType tagMemrefInfo;
1011   SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
1012   Type type;
1013   auto indexType = parser.getBuilder().getIndexType();
1014   OpAsmParser::OperandType numElementsInfo;
1015 
1016   // Parse tag memref, its indices, and dma size.
1017   if (parser.parseOperand(tagMemrefInfo) ||
1018       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
1019       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1020       parser.parseColonType(type) ||
1021       parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
1022       parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
1023       parser.resolveOperand(numElementsInfo, indexType, result.operands))
1024     return failure();
1025 
1026   return success();
1027 }
1028 
1029 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1030                               SmallVectorImpl<OpFoldResult> &results) {
1031   /// dma_wait(memrefcast) -> dma_wait
1032   return foldMemRefCast(*this);
1033 }
1034 
1035 LogicalResult DmaWaitOp::verify() {
1036   // Mandatory non-variadic operands are tag and the number of elements.
1037   if (getNumOperands() < 2)
1038     return emitOpError() << "expected at least 2 operands";
1039 
1040   // Check types of operands. The order of these calls is important: the later
1041   // calls rely on some type properties to compute the operand position.
1042   if (!getTagMemRef().getType().isa<MemRefType>())
1043     return emitOpError() << "expected tag to be of memref type";
1044 
1045   if (getNumOperands() != 2 + getTagMemRefRank())
1046     return emitOpError() << "expected " << 2 + getTagMemRefRank()
1047                          << " operands";
1048 
1049   if (!getTagIndices().empty() &&
1050       !llvm::all_of(getTagIndices().getTypes(),
1051                     [](Type t) { return t.isIndex(); }))
1052     return emitOpError() << "expected tag indices to be of index type";
1053 
1054   if (!getNumElements().getType().isIndex())
1055     return emitOpError()
1056            << "expected the number of elements to be of index type";
1057 
1058   return success();
1059 }
1060 
1061 //===----------------------------------------------------------------------===//
1062 // GlobalOp
1063 //===----------------------------------------------------------------------===//
1064 
1065 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1066                                                    TypeAttr type,
1067                                                    Attribute initialValue) {
1068   p << type;
1069   if (!op.isExternal()) {
1070     p << " = ";
1071     if (op.isUninitialized())
1072       p << "uninitialized";
1073     else
1074       p.printAttributeWithoutType(initialValue);
1075   }
1076 }
1077 
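// Parses the type and optional initial value of a memref.global, e.g.
// (illustrative):
//
//   memref.global @x : memref<2xf32> = dense<[0.0, 2.0]>
//   memref.global @y : memref<4xi32> = uninitialized
//   memref.global "private" @z : memref<2xf32>   // external declaration
//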
1078 static ParseResult
1079 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1080                                        Attribute &initialValue) {
1081   Type type;
1082   if (parser.parseType(type))
1083     return failure();
1084 
1085   auto memrefType = type.dyn_cast<MemRefType>();
1086   if (!memrefType || !memrefType.hasStaticShape())
1087     return parser.emitError(parser.getNameLoc())
1088            << "type should be static shaped memref, but got " << type;
1089   typeAttr = TypeAttr::get(type);
1090 
1091   if (parser.parseOptionalEqual())
1092     return success();
1093 
1094   if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1095     initialValue = UnitAttr::get(parser.getBuilder().getContext());
1096     return success();
1097   }
1098 
1099   Type tensorType = getTensorTypeFromMemRefType(memrefType);
1100   if (parser.parseAttribute(initialValue, tensorType))
1101     return failure();
1102   if (!initialValue.isa<ElementsAttr>())
1103     return parser.emitError(parser.getNameLoc())
1104            << "initial value should be a unit or elements attribute";
1105   return success();
1106 }
1107 
1108 static LogicalResult verify(GlobalOp op) {
1109   auto memrefType = op.type().dyn_cast<MemRefType>();
1110   if (!memrefType || !memrefType.hasStaticShape())
1111     return op.emitOpError("type should be static shaped memref, but got ")
1112            << op.type();
1113 
1114   // Verify that the initial value, if present, is either a unit attribute or
1115   // an elements attribute.
1116   if (op.initial_value().hasValue()) {
1117     Attribute initValue = op.initial_value().getValue();
1118     if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
1119       return op.emitOpError("initial value should be a unit or elements "
1120                             "attribute, but got ")
1121              << initValue;
1122 
1123     // Check that the type of the initial value is compatible with the type of
1124     // the global variable.
1125     if (initValue.isa<ElementsAttr>()) {
1126       Type initType = initValue.getType();
1127       Type tensorType = getTensorTypeFromMemRefType(memrefType);
1128       if (initType != tensorType)
1129         return op.emitOpError("initial value expected to be of type ")
1130                << tensorType << ", but was of type " << initType;
1131     }
1132   }
1133 
1134   // TODO: verify visibility for declarations.
1135   return success();
1136 }
1137 
1138 //===----------------------------------------------------------------------===//
1139 // GetGlobalOp
1140 //===----------------------------------------------------------------------===//
1141 
1142 LogicalResult
1143 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is the same as the type of the referenced
  // memref.global op.
1146   auto global =
1147       symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
1148   if (!global)
1149     return emitOpError("'")
1150            << name() << "' does not reference a valid global memref";
1151 
1152   Type resultType = result().getType();
1153   if (global.type() != resultType)
1154     return emitOpError("result type ")
1155            << resultType << " does not match type " << global.type()
1156            << " of the global memref @" << name();
1157   return success();
1158 }
1159 
1160 //===----------------------------------------------------------------------===//
1161 // LoadOp
1162 //===----------------------------------------------------------------------===//
1163 
1164 static LogicalResult verify(LoadOp op) {
1165   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1166     return op.emitOpError("incorrect number of indices for load");
1167   return success();
1168 }
1169 
1170 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1171   /// load(memrefcast) -> load
1172   if (succeeded(foldMemRefCast(*this)))
1173     return getResult();
1174   return OpFoldResult();
1175 }
1176 
1177 namespace {
/// Fold a load on a buffer_cast operation into a tensor.extract on the
/// corresponding tensor.
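/// For example (illustrative):
///
/// ```mlir
///   %m = memref.buffer_cast %t : memref<?xf32>
///   %v = memref.load %m[%i] : memref<?xf32>
/// ```
///
/// becomes `%v = tensor.extract %t[%i] : tensor<?xf32>`.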
1180 struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
1181   using OpRewritePattern<LoadOp>::OpRewritePattern;
1182 
1183   LogicalResult matchAndRewrite(LoadOp load,
1184                                 PatternRewriter &rewriter) const override {
1185     auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
1186     if (!buffercast)
1187       return failure();
1188 
1189     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
1190                                                    load.indices());
1191     return success();
1192   }
1193 };
1194 } // end anonymous namespace.
1195 
1196 void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
1197                                          MLIRContext *context) {
1198   results.add<LoadOfBufferCast>(context);
1199 }
1200 
1201 //===----------------------------------------------------------------------===//
1202 // PrefetchOp
1203 //===----------------------------------------------------------------------===//
1204 
1205 static void print(OpAsmPrinter &p, PrefetchOp op) {
1206   p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
1207   p.printOperands(op.indices());
1208   p << ']' << ", " << (op.isWrite() ? "write" : "read");
1209   p << ", locality<" << op.localityHint();
1210   p << ">, " << (op.isDataCache() ? "data" : "instr");
1211   p.printOptionalAttrDict(
1212       op->getAttrs(),
1213       /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1214   p << " : " << op.getMemRefType();
1215 }
1216 
1217 static ParseResult parsePrefetchOp(OpAsmParser &parser,
1218                                    OperationState &result) {
1219   OpAsmParser::OperandType memrefInfo;
1220   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
1221   IntegerAttr localityHint;
1222   MemRefType type;
1223   StringRef readOrWrite, cacheType;
1224 
1225   auto indexTy = parser.getBuilder().getIndexType();
1226   auto i32Type = parser.getBuilder().getIntegerType(32);
1227   if (parser.parseOperand(memrefInfo) ||
1228       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1229       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1230       parser.parseComma() || parser.parseKeyword("locality") ||
1231       parser.parseLess() ||
1232       parser.parseAttribute(localityHint, i32Type, "localityHint",
1233                             result.attributes) ||
1234       parser.parseGreater() || parser.parseComma() ||
1235       parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1236       parser.resolveOperand(memrefInfo, type, result.operands) ||
1237       parser.resolveOperands(indexInfo, indexTy, result.operands))
1238     return failure();
1239 
1240   if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1241     return parser.emitError(parser.getNameLoc(),
1242                             "rw specifier has to be 'read' or 'write'");
1243   result.addAttribute(
1244       PrefetchOp::getIsWriteAttrName(),
1245       parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1246 
1247   if (!cacheType.equals("data") && !cacheType.equals("instr"))
1248     return parser.emitError(parser.getNameLoc(),
1249                             "cache type has to be 'data' or 'instr'");
1250 
1251   result.addAttribute(
1252       PrefetchOp::getIsDataCacheAttrName(),
1253       parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1254 
1255   return success();
1256 }
1257 
1258 static LogicalResult verify(PrefetchOp op) {
1259   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1260     return op.emitOpError("too few indices");
1261 
1262   return success();
1263 }
1264 
1265 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
1266                                SmallVectorImpl<OpFoldResult> &results) {
1267   // prefetch(memrefcast) -> prefetch
1268   return foldMemRefCast(*this);
1269 }
1270 
1271 //===----------------------------------------------------------------------===//
1272 // ReinterpretCastOp
1273 //===----------------------------------------------------------------------===//
1274 
/// Build a ReinterpretCastOp with mixed static and dynamic entries: the
/// `static_offsets`, `static_sizes` and `static_strides` attributes are
/// populated from `offset`, `sizes` and `strides`, with sentinel values
/// encoding the dynamic entries.
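/// For illustration (hypothetical values), building with offset = 0 (as an
/// attribute), sizes = [%d, 4] and strides = [4, 1] yields static_offsets =
/// [0], static_sizes = [kDynamicSize, 4] with %d as the only dynamic size
/// operand, and static_strides = [4, 1].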
1278 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1279                               MemRefType resultType, Value source,
1280                               OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1281                               ArrayRef<OpFoldResult> strides,
1282                               ArrayRef<NamedAttribute> attrs) {
1283   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1284   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1285   dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
1286                              ShapedType::kDynamicStrideOrOffset);
1287   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1288                              ShapedType::kDynamicSize);
1289   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1290                              ShapedType::kDynamicStrideOrOffset);
1291   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1292         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1293         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1294   result.addAttributes(attrs);
1295 }
1296 
1297 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1298                               MemRefType resultType, Value source,
1299                               int64_t offset, ArrayRef<int64_t> sizes,
1300                               ArrayRef<int64_t> strides,
1301                               ArrayRef<NamedAttribute> attrs) {
1302   SmallVector<OpFoldResult> sizeValues =
1303       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1304         return b.getI64IntegerAttr(v);
1305       }));
1306   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1307       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1308         return b.getI64IntegerAttr(v);
1309       }));
1310   build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1311         strideValues, attrs);
1312 }
1313 
1314 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1315                               MemRefType resultType, Value source, Value offset,
1316                               ValueRange sizes, ValueRange strides,
1317                               ArrayRef<NamedAttribute> attrs) {
1318   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1319       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1320   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1321       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1322   build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1323 }
1324 
1325 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1326 // completed automatically, like we have for subview and subtensor.
1327 static LogicalResult verify(ReinterpretCastOp op) {
1328   // The source and result memrefs should be in the same memory space.
1329   auto srcType = op.source().getType().cast<BaseMemRefType>();
1330   auto resultType = op.getType().cast<MemRefType>();
1331   if (srcType.getMemorySpace() != resultType.getMemorySpace())
1332     return op.emitError("different memory spaces specified for source type ")
1333            << srcType << " and result memref type " << resultType;
1334   if (srcType.getElementType() != resultType.getElementType())
1335     return op.emitError("different element types specified for source type ")
1336            << srcType << " and result memref type " << resultType;
1337 
1338   // Match sizes in result memref type and in static_sizes attribute.
1339   for (auto &en :
1340        llvm::enumerate(llvm::zip(resultType.getShape(),
1341                                  extractFromI64ArrayAttr(op.static_sizes())))) {
1342     int64_t resultSize = std::get<0>(en.value());
1343     int64_t expectedSize = std::get<1>(en.value());
1344     if (resultSize != expectedSize)
1345       return op.emitError("expected result type with size = ")
1346              << expectedSize << " instead of " << resultSize
1347              << " in dim = " << en.index();
1348   }
1349 
1350   // Match offset and strides in static_offset and static_strides attributes if
1351   // result memref type has an affine map specified.
1352   if (!resultType.getAffineMaps().empty()) {
1353     int64_t resultOffset;
1354     SmallVector<int64_t, 4> resultStrides;
1355     if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1356       return failure();
1357 
1358     // Match offset in result memref type and in static_offsets attribute.
1359     int64_t expectedOffset =
1360         extractFromI64ArrayAttr(op.static_offsets()).front();
    if (resultOffset != expectedOffset)
      return op.emitError("expected result type with offset = ")
             << expectedOffset << " instead of " << resultOffset;
1364 
1365     // Match strides in result memref type and in static_strides attribute.
1366     for (auto &en : llvm::enumerate(llvm::zip(
1367              resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
1368       int64_t resultStride = std::get<0>(en.value());
1369       int64_t expectedStride = std::get<1>(en.value());
1370       if (resultStride != expectedStride)
1371         return op.emitError("expected result type with stride = ")
1372                << expectedStride << " instead of " << resultStride
1373                << " in dim = " << en.index();
1374     }
1375   }
1376   return success();
1377 }
1378 
1379 //===----------------------------------------------------------------------===//
1380 // ReshapeOp
1381 //===----------------------------------------------------------------------===//
1382 
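// A consequence of the checks below, with illustrative types: reshaping to a
// ranked result such as memref<?x4xf32> requires a shape operand whose static
// length is 2 (e.g. memref<2xindex>), matching the result rank; a shape
// operand of dynamic length is only compatible with an unranked result.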
1383 static LogicalResult verify(ReshapeOp op) {
1384   Type operandType = op.source().getType();
1385   Type resultType = op.result().getType();
1386 
1387   Type operandElementType = operandType.cast<ShapedType>().getElementType();
1388   Type resultElementType = resultType.cast<ShapedType>().getElementType();
1389   if (operandElementType != resultElementType)
1390     return op.emitOpError("element types of source and destination memref "
1391                           "types should be the same");
1392 
1393   if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
1394     if (!operandMemRefType.getAffineMaps().empty())
1395       return op.emitOpError(
1396           "source memref type should have identity affine map");
1397 
1398   int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
1399   auto resultMemRefType = resultType.dyn_cast<MemRefType>();
1400   if (resultMemRefType) {
1401     if (!resultMemRefType.getAffineMaps().empty())
1402       return op.emitOpError(
1403           "result memref type should have identity affine map");
1404     if (shapeSize == ShapedType::kDynamicSize)
1405       return op.emitOpError("cannot use shape operand with dynamic length to "
1406                             "reshape to statically-ranked memref type");
1407     if (shapeSize != resultMemRefType.getRank())
1408       return op.emitOpError(
1409           "length of shape operand differs from the result's memref rank");
1410   }
1411   return success();
1412 }
1413 
1414 //===----------------------------------------------------------------------===//
1415 // StoreOp
1416 //===----------------------------------------------------------------------===//
1417 
1418 static LogicalResult verify(StoreOp op) {
1419   if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
1420     return op.emitOpError("store index operand count not equal to memref rank");
1421 
1422   return success();
1423 }
1424 
1425 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
1426                             SmallVectorImpl<OpFoldResult> &results) {
  // store(memrefcast) -> store
1428   return foldMemRefCast(*this);
1429 }
1430 
1431 //===----------------------------------------------------------------------===//
1432 // SubViewOp
1433 //===----------------------------------------------------------------------===//
1434 
1435 namespace {
/// Helpers for saturating arithmetic on static offsets and strides: any
/// operation involving the dynamic sentinel yields the dynamic sentinel.
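/// For example (illustrative): Wrapper(4) * 2 yields 8, whereas
/// Wrapper(ShapedType::kDynamicStrideOrOffset) * 2 yields the dynamic sentinel
/// again.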
1437 namespace saturated_arith {
1438 struct Wrapper {
1439   explicit Wrapper(int64_t v) : v(v) {}
1440   operator int64_t() { return v; }
1441   int64_t v;
1442 };
1443 Wrapper operator+(Wrapper a, int64_t b) {
1444   if (ShapedType::isDynamicStrideOrOffset(a) ||
1445       ShapedType::isDynamicStrideOrOffset(b))
1446     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1447   return Wrapper(a.v + b);
1448 }
1449 Wrapper operator*(Wrapper a, int64_t b) {
1450   if (ShapedType::isDynamicStrideOrOffset(a) ||
1451       ShapedType::isDynamicStrideOrOffset(b))
1452     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1453   return Wrapper(a.v * b);
1454 }
1455 } // end namespace saturated_arith
1456 } // end namespace
1457 
1458 /// A subview result type can be fully inferred from the source type and the
1459 /// static representation of offsets, sizes and strides. Special sentinels
1460 /// encode the dynamic case.
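/// For example (a sketch with illustrative values): a subview of a
/// memref<8x16x4xf32> source (strides [64, 4, 1], offset 0) taken with static
/// offsets [3, 4, 2], sizes [4, 4, 4] and strides [1, 1, 1] infers the type
/// memref<4x4x4xf32, offset: 210, strides: [64, 4, 1]>, since the new offset
/// is 0 + 3 * 64 + 4 * 4 + 2 * 1 = 210 and the strides are unchanged.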
1461 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1462                                 ArrayRef<int64_t> leadingStaticOffsets,
1463                                 ArrayRef<int64_t> leadingStaticSizes,
1464                                 ArrayRef<int64_t> leadingStaticStrides) {
1465   // A subview may specify only a leading subset of offset/sizes/strides in
1466   // which case we complete with offset=0, sizes from memref type and strides=1.
1467   unsigned rank = sourceMemRefType.getRank();
1468   assert(leadingStaticOffsets.size() <= rank &&
1469          "unexpected leadingStaticOffsets overflow");
1470   assert(leadingStaticSizes.size() <= rank &&
1471          "unexpected leadingStaticSizes overflow");
1472   assert(leadingStaticStrides.size() <= rank &&
1473          "unexpected leadingStaticStrides overflow");
1474   auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
1475   auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
1476   auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
1477   unsigned numTrailingOffsets = rank - staticOffsets.size();
1478   unsigned numTrailingSizes = rank - staticSizes.size();
1479   unsigned numTrailingStrides = rank - staticStrides.size();
1480   staticOffsets.append(numTrailingOffsets, 0);
1481   llvm::append_range(staticSizes,
1482                      sourceMemRefType.getShape().take_back(numTrailingSizes));
1483   staticStrides.append(numTrailingStrides, 1);
1484 
1485   // Extract source offset and strides.
1486   int64_t sourceOffset;
1487   SmallVector<int64_t, 4> sourceStrides;
1488   auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
1489   assert(succeeded(res) && "SubViewOp expected strided memref type");
1490   (void)res;
1491 
1492   // Compute target offset whose value is:
1493   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
1494   int64_t targetOffset = sourceOffset;
1495   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    using namespace saturated_arith;
    targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * sourceStride;
1499   }
1500 
1501   // Compute target stride whose value is:
1502   //   `sourceStrides_i * staticStrides_i`.
1503   SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticStrides.size());
1505   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
1506     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
1507     using namespace saturated_arith;
1508     targetStrides.push_back(Wrapper(sourceStride) * staticStride);
1509   }
1510 
1511   // The type is now known.
1512   return MemRefType::get(
1513       staticSizes, sourceMemRefType.getElementType(),
1514       makeStridedLinearLayoutMap(targetStrides, targetOffset,
1515                                  sourceMemRefType.getContext()),
1516       sourceMemRefType.getMemorySpace());
1517 }
1518 
1519 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1520                                 ArrayRef<OpFoldResult> leadingStaticOffsets,
1521                                 ArrayRef<OpFoldResult> leadingStaticSizes,
1522                                 ArrayRef<OpFoldResult> leadingStaticStrides) {
1523   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1524   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1525   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1526                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1527   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1528                              ShapedType::kDynamicSize);
1529   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1530                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1531   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1532                                     staticSizes, staticStrides)
1533       .cast<MemRefType>();
1534 }
1535 
Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceMemRefType,
    ArrayRef<int64_t> leadingStaticOffsets,
    ArrayRef<int64_t> leadingStaticSizes,
    ArrayRef<int64_t> leadingStaticStrides) {
  auto inferredType =
      inferResultType(sourceMemRefType, leadingStaticOffsets,
                      leadingStaticSizes, leadingStaticStrides)
          .cast<MemRefType>();
  assert(inferredType.getRank() >= resultRank &&
         "expected the inferred rank to be greater than or equal to the "
         "requested result rank");
1546   int rankDiff = inferredType.getRank() - resultRank;
1547   if (rankDiff > 0) {
1548     auto shape = inferredType.getShape();
1549     llvm::SmallDenseSet<unsigned> dimsToProject;
1550     mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
1551     SmallVector<int64_t> projectedShape;
1552     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1553       if (!dimsToProject.contains(pos))
1554         projectedShape.push_back(shape[pos]);
1555 
1556     AffineMap map;
1557     auto maps = inferredType.getAffineMaps();
1558     if (!maps.empty() && maps.front())
1559       map = getProjectedMap(maps.front(), dimsToProject);
1560     inferredType =
1561         MemRefType::get(projectedShape, inferredType.getElementType(), map,
1562                         inferredType.getMemorySpace());
1563   }
1564   return inferredType;
1565 }
1566 
1567 Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceMemRefType,
1569     ArrayRef<OpFoldResult> leadingStaticOffsets,
1570     ArrayRef<OpFoldResult> leadingStaticSizes,
1571     ArrayRef<OpFoldResult> leadingStaticStrides) {
1572   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1573   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1574   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1575                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1576   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1577                              ShapedType::kDynamicSize);
1578   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1579                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1580   return SubViewOp::inferRankReducedResultType(
      resultRank, sourceMemRefType, staticOffsets, staticSizes,
1582       staticStrides);
1583 }

// Build a SubViewOp with mixed static and dynamic entries and custom result
1585 // type. If the type passed is nullptr, it is inferred.
1586 void SubViewOp::build(OpBuilder &b, OperationState &result,
1587                       MemRefType resultType, Value source,
1588                       ArrayRef<OpFoldResult> offsets,
1589                       ArrayRef<OpFoldResult> sizes,
1590                       ArrayRef<OpFoldResult> strides,
1591                       ArrayRef<NamedAttribute> attrs) {
1592   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1593   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1594   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1595                              ShapedType::kDynamicStrideOrOffset);
1596   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1597                              ShapedType::kDynamicSize);
1598   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1599                              ShapedType::kDynamicStrideOrOffset);
1600   auto sourceMemRefType = source.getType().cast<MemRefType>();
1601   // Structuring implementation this way avoids duplication between builders.
1602   if (!resultType) {
1603     resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1604                                             staticSizes, staticStrides)
1605                      .cast<MemRefType>();
1606   }
1607   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1608         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1609         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1610   result.addAttributes(attrs);
1611 }
1612 
1613 // Build a SubViewOp with mixed static and dynamic entries and inferred result
1614 // type.
1615 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1616                       ArrayRef<OpFoldResult> offsets,
1617                       ArrayRef<OpFoldResult> sizes,
1618                       ArrayRef<OpFoldResult> strides,
1619                       ArrayRef<NamedAttribute> attrs) {
1620   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1621 }
1622 
1623 // Build a SubViewOp with static entries and inferred result type.
1624 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1625                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1626                       ArrayRef<int64_t> strides,
1627                       ArrayRef<NamedAttribute> attrs) {
1628   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1629       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1630         return b.getI64IntegerAttr(v);
1631       }));
1632   SmallVector<OpFoldResult> sizeValues =
1633       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1634         return b.getI64IntegerAttr(v);
1635       }));
1636   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1637       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1638         return b.getI64IntegerAttr(v);
1639       }));
1640   build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
1641 }
1642 
// Build a SubViewOp with static entries and custom result type. If the
// type passed is nullptr, it is inferred.
1645 void SubViewOp::build(OpBuilder &b, OperationState &result,
1646                       MemRefType resultType, Value source,
1647                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1648                       ArrayRef<int64_t> strides,
1649                       ArrayRef<NamedAttribute> attrs) {
1650   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1651       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1652         return b.getI64IntegerAttr(v);
1653       }));
1654   SmallVector<OpFoldResult> sizeValues =
1655       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1656         return b.getI64IntegerAttr(v);
1657       }));
1658   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1659       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1660         return b.getI64IntegerAttr(v);
1661       }));
1662   build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
1663         attrs);
1664 }
1665 
1666 // Build a SubViewOp with dynamic entries and custom result type. If the type
1667 // passed is nullptr, it is inferred.
1668 void SubViewOp::build(OpBuilder &b, OperationState &result,
1669                       MemRefType resultType, Value source, ValueRange offsets,
1670                       ValueRange sizes, ValueRange strides,
1671                       ArrayRef<NamedAttribute> attrs) {
1672   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1673       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1674   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1675       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1676   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1677       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
1679 }
1680 
1681 // Build a SubViewOp with dynamic entries and inferred result type.
1682 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1683                       ValueRange offsets, ValueRange sizes, ValueRange strides,
1684                       ArrayRef<NamedAttribute> attrs) {
1685   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1686 }
1687 
1688 /// For ViewLikeOpInterface.
1689 Value SubViewOp::getViewSource() { return source(); }
1690 
1691 enum SubViewVerificationResult {
1692   Success,
1693   RankTooLarge,
1694   SizeMismatch,
1695   ElemTypeMismatch,
1696   MemSpaceMismatch,
1697   AffineMapMismatch
1698 };
1699 
/// Checks if `originalType` can be rank-reduced to `candidateReducedType`.
/// This is a slight variant of the "is a subsequence" check, where every
/// dimension dropped from the original shape must be statically 1.
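/// For example (illustrative): memref<1x4x1x8xf32> can be rank-reduced to
/// memref<4x8xf32> because the two dropped dimensions are statically 1, but it
/// cannot be rank-reduced to memref<4x4xf32>.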
1703 static SubViewVerificationResult
1704 isRankReducedType(Type originalType, Type candidateReducedType,
1705                   std::string *errMsg = nullptr) {
1706   if (originalType == candidateReducedType)
1707     return SubViewVerificationResult::Success;
1708   if (!originalType.isa<MemRefType>())
1709     return SubViewVerificationResult::Success;
1710   if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
1711     return SubViewVerificationResult::Success;
1712 
1713   ShapedType originalShapedType = originalType.cast<ShapedType>();
1714   ShapedType candidateReducedShapedType =
1715       candidateReducedType.cast<ShapedType>();
1716 
1717   // Rank and size logic is valid for all ShapedTypes.
1718   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
1719   ArrayRef<int64_t> candidateReducedShape =
1720       candidateReducedShapedType.getShape();
1721   unsigned originalRank = originalShape.size(),
1722            candidateReducedRank = candidateReducedShape.size();
1723   if (candidateReducedRank > originalRank)
1724     return SubViewVerificationResult::RankTooLarge;
1725 
1726   auto optionalUnusedDimsMask =
1727       computeRankReductionMask(originalShape, candidateReducedShape);
1728 
  // If no rank-reduction mask could be computed, the sizes cannot be matched.
1730   if (!optionalUnusedDimsMask.hasValue())
1731     return SubViewVerificationResult::SizeMismatch;
1732 
1733   if (originalShapedType.getElementType() !=
1734       candidateReducedShapedType.getElementType())
1735     return SubViewVerificationResult::ElemTypeMismatch;
1736 
1737   // Strided layout logic is relevant for MemRefType only.
1738   MemRefType original = originalType.cast<MemRefType>();
1739   MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
1740   if (original.getMemorySpace() != candidateReduced.getMemorySpace())
1741     return SubViewVerificationResult::MemSpaceMismatch;
1742 
1743   llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
1744   auto inferredType =
1745       getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
1746   AffineMap candidateLayout;
1747   if (candidateReduced.getAffineMaps().empty())
1748     candidateLayout = getStridedLinearLayoutMap(candidateReduced);
1749   else
1750     candidateLayout = candidateReduced.getAffineMaps().front();
1751   assert(inferredType.getNumResults() == 1 &&
1752          candidateLayout.getNumResults() == 1);
1753   if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
1754       inferredType.getNumDims() != candidateLayout.getNumDims()) {
1755     if (errMsg) {
1756       llvm::raw_string_ostream os(*errMsg);
1757       os << "inferred type: " << inferredType;
1758     }
1759     return SubViewVerificationResult::AffineMapMismatch;
1760   }
1761   // Check that the difference of the affine maps simplifies to 0.
1762   AffineExpr diffExpr =
1763       inferredType.getResult(0) - candidateLayout.getResult(0);
1764   diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
1765                                 inferredType.getNumSymbols());
1766   auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
1767   if (!(cst && cst.getValue() == 0)) {
1768     if (errMsg) {
1769       llvm::raw_string_ostream os(*errMsg);
1770       os << "inferred type: " << inferredType;
1771     }
1772     return SubViewVerificationResult::AffineMapMismatch;
1773   }
1774   return SubViewVerificationResult::Success;
1775 }
1776 
1777 template <typename OpTy>
1778 static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
1779                                             OpTy op, Type expectedType,
1780                                             StringRef errMsg = "") {
1781   auto memrefType = expectedType.cast<ShapedType>();
1782   switch (result) {
1783   case SubViewVerificationResult::Success:
1784     return success();
1785   case SubViewVerificationResult::RankTooLarge:
1786     return op.emitError("expected result rank to be smaller or equal to ")
1787            << "the source rank. " << errMsg;
1788   case SubViewVerificationResult::SizeMismatch:
1789     return op.emitError("expected result type to be ")
1790            << expectedType
1791            << " or a rank-reduced version. (mismatch of result sizes) "
1792            << errMsg;
1793   case SubViewVerificationResult::ElemTypeMismatch:
1794     return op.emitError("expected result element type to be ")
1795            << memrefType.getElementType() << errMsg;
1796   case SubViewVerificationResult::MemSpaceMismatch:
1797     return op.emitError("expected result and source memory spaces to match.")
1798            << errMsg;
1799   case SubViewVerificationResult::AffineMapMismatch:
1800     return op.emitError("expected result type to be ")
1801            << expectedType
1802            << " or a rank-reduced version. (mismatch of result affine map) "
1803            << errMsg;
1804   }
1805   llvm_unreachable("unexpected subview verification result");
1806 }
1807 
1808 /// Verifier for SubViewOp.
1809 static LogicalResult verify(SubViewOp op) {
1810   MemRefType baseType = op.getSourceType();
1811   MemRefType subViewType = op.getType();
1812 
1813   // The base memref and the view memref should be in the same memory space.
1814   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
1815     return op.emitError("different memory spaces specified for base memref "
1816                         "type ")
1817            << baseType << " and subview memref type " << subViewType;
1818 
1819   // Verify that the base memref type has a strided layout map.
1820   if (!isStrided(baseType))
1821     return op.emitError("base type ") << baseType << " is not strided";
1822 
1823   // Verify result type against inferred type.
1824   auto expectedType = SubViewOp::inferResultType(
1825       baseType, extractFromI64ArrayAttr(op.static_offsets()),
1826       extractFromI64ArrayAttr(op.static_sizes()),
1827       extractFromI64ArrayAttr(op.static_strides()));
1828 
1829   std::string errMsg;
1830   auto result = isRankReducedType(expectedType, subViewType, &errMsg);
1831   return produceSubViewErrorMsg(result, op, expectedType, errMsg);
1832 }
1833 
1834 raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
1835   return os << "range " << range.offset << ":" << range.size << ":"
1836             << range.stride;
1837 }
1838 
1839 /// Return the list of Range (i.e. offset, size, stride). Each Range
1840 /// entry contains either the dynamic value or a ConstantIndexOp constructed
1841 /// with `b` at location `loc`.
1842 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
1843                                               OpBuilder &b, Location loc) {
1844   std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
1845   assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
1846   assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
1847   SmallVector<Range, 8> res;
1848   unsigned rank = ranks[0];
1849   res.reserve(rank);
1850   for (unsigned idx = 0; idx < rank; ++idx) {
1851     Value offset =
1852         op.isDynamicOffset(idx)
1853             ? op.getDynamicOffset(idx)
1854             : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
1855     Value size = op.isDynamicSize(idx)
1856                      ? op.getDynamicSize(idx)
1857                      : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
1858     Value stride =
1859         op.isDynamicStride(idx)
1860             ? op.getDynamicStride(idx)
1861             : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
1862     res.emplace_back(Range{offset, size, stride});
1863   }
1864   return res;
1865 }
1866 
/// Infer the canonical type of the result of a subview operation. Returns the
/// rank-reduced inferred type if its rank equals `resultRank`, and the
/// non-rank-reduced inferred type otherwise.
1870 static MemRefType
1871 getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
1872                               ArrayRef<OpFoldResult> mixedOffsets,
1873                               ArrayRef<OpFoldResult> mixedSizes,
1874                               ArrayRef<OpFoldResult> mixedStrides) {
1875   auto resultType =
1876       SubViewOp::inferRankReducedResultType(
1877           resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
1878           .cast<MemRefType>();
1879   if (resultType.getRank() != resultRank) {
1880     resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
1881                                             mixedSizes, mixedStrides)
1882                      .cast<MemRefType>();
1883   }
1884   return resultType;
1885 }
1886 
1887 namespace {
1888 /// Pattern to rewrite a subview op with MemRefCast arguments.
1889 /// This essentially pushes memref.cast past its consuming subview when
1890 /// `canFoldIntoConsumerOp` is true.
1891 ///
1892 /// Example:
1893 /// ```
1894 ///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
1895 ///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
1896 ///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
1897 /// ```
1898 /// is rewritten into:
1899 /// ```
///   %0 = memref.subview %V[0, 0][3, 4][1, 1] :
///     memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
1901 ///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
1902 ///     memref<3x4xf32, offset:?, strides:[?, 1]>
1903 /// ```
1904 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
1905 public:
1906   using OpRewritePattern<SubViewOp>::OpRewritePattern;
1907 
1908   LogicalResult matchAndRewrite(SubViewOp subViewOp,
1909                                 PatternRewriter &rewriter) const override {
    // If any operand is constant, bail out and let SubViewOpConstantFolder
    // kick in first.
1911     if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
1912           return matchPattern(operand, matchConstantIndex());
1913         }))
1914       return failure();
1915 
1916     auto castOp = subViewOp.source().getDefiningOp<CastOp>();
1917     if (!castOp)
1918       return failure();
1919 
1920     if (!CastOp::canFoldIntoConsumerOp(castOp))
1921       return failure();
1922 
    // Deduce the result type of the SubViewOp using
    // `getCanonicalSubViewResultType` on the cast source operand type and the
    // SubViewOp static information. This is the resulting type if the
    // memref.cast were folded.
1926     auto resultType = getCanonicalSubViewResultType(
1927         subViewOp.getType().getRank(),
1928         castOp.source().getType().cast<MemRefType>(),
1929         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
1930         subViewOp.getMixedStrides());
1931     Value newSubView = rewriter.create<SubViewOp>(
1932         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
1933         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
1934         subViewOp.static_sizes(), subViewOp.static_strides());
1935     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
1936                                         newSubView);
1937     return success();
1938   }
1939 };
1940 } // namespace
1941 
1942 /// Return the canonical type of the result of a subview.
1943 struct SubViewReturnTypeCanonicalizer {
1944   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
1945                         ArrayRef<OpFoldResult> mixedSizes,
1946                         ArrayRef<OpFoldResult> mixedStrides) {
1947     return getCanonicalSubViewResultType(op.getType().getRank(),
1948                                          op.getSourceType(), mixedOffsets,
1949                                          mixedSizes, mixedStrides);
1950   }
1951 };
1952 
1953 /// A canonicalizer wrapper to replace SubViewOps.
1954 struct SubViewCanonicalizer {
1955   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
1956     rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
1957   }
1958 };
1959 
1960 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
1961                                             MLIRContext *context) {
1962   results
1963       .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
1964                SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
1965            SubViewOpMemRefCastFolder>(context);
1966 }
1967 
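/// Fold a no-op subview: when the statically shaped result type is identical
/// to the source type, the op simply forwards its source.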
1968 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
1969   auto resultShapedType = getResult().getType().cast<ShapedType>();
1970   auto sourceShapedType = source().getType().cast<ShapedType>();
1971 
1972   if (resultShapedType.hasStaticShape() &&
1973       resultShapedType == sourceShapedType) {
1974     return getViewSource();
1975   }
1976 
1977   return {};
1978 }
1979 
1980 //===----------------------------------------------------------------------===//
1981 // TensorLoadOp
1982 //===----------------------------------------------------------------------===//
1983 
1984 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
1985   if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
1988     if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
1989         bufferCast->getNextNode() == this->getOperation())
1990       return bufferCast.tensor();
1991   return {};
1992 }
1993 
1994 //===----------------------------------------------------------------------===//
1995 // TransposeOp
1996 //===----------------------------------------------------------------------===//
1997 
/// Build a strided memref type by applying `permutationMap` to `memRefType`.
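/// For example (illustrative): applying the permutation (d0, d1) -> (d1, d0)
/// to memref<3x4xf32> (strides [4, 1], offset 0) yields a memref<4x3xf32>
/// whose layout map is (d0, d1) -> (d0 + d1 * 4), i.e. the sizes become
/// [4, 3] and the strides become [1, 4].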
1999 static MemRefType inferTransposeResultType(MemRefType memRefType,
2000                                            AffineMap permutationMap) {
2001   auto rank = memRefType.getRank();
2002   auto originalSizes = memRefType.getShape();
2003   // Compute permuted sizes.
2004   SmallVector<int64_t, 4> sizes(rank, 0);
2005   for (auto en : llvm::enumerate(permutationMap.getResults()))
2006     sizes[en.index()] =
2007         originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
2008 
2009   // Compute permuted strides.
2010   int64_t offset;
2011   SmallVector<int64_t, 4> strides;
2012   auto res = getStridesAndOffset(memRefType, strides, offset);
2013   assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
2014   (void)res;
2015   auto map =
2016       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
2017   map = permutationMap ? map.compose(permutationMap) : map;
2018   return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
2019 }
2020 
2021 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
2022                         AffineMapAttr permutation,
2023                         ArrayRef<NamedAttribute> attrs) {
2024   auto permutationMap = permutation.getValue();
2025   assert(permutationMap);
2026 
2027   auto memRefType = in.getType().cast<MemRefType>();
2028   // Compute result type.
2029   MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
2030 
2031   build(b, result, resultType, in, attrs);
2032   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
2033 }
2034 
2035 // transpose $in $permutation attr-dict : type($in) `to` type(results)
2036 static void print(OpAsmPrinter &p, TransposeOp op) {
2037   p << "memref.transpose " << op.in() << " " << op.permutation();
2038   p.printOptionalAttrDict(op->getAttrs(),
2039                           {TransposeOp::getPermutationAttrName()});
2040   p << " : " << op.in().getType() << " to " << op.getType();
2041 }
2042 
2043 static ParseResult parseTransposeOp(OpAsmParser &parser,
2044                                     OperationState &result) {
2045   OpAsmParser::OperandType in;
2046   AffineMap permutation;
2047   MemRefType srcType, dstType;
2048   if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
2049       parser.parseOptionalAttrDict(result.attributes) ||
2050       parser.parseColonType(srcType) ||
2051       parser.resolveOperand(in, srcType, result.operands) ||
2052       parser.parseKeywordType("to", dstType) ||
2053       parser.addTypeToList(dstType, result.types))
2054     return failure();
2055 
2056   result.addAttribute(TransposeOp::getPermutationAttrName(),
2057                       AffineMapAttr::get(permutation));
2058   return success();
2059 }
2060 
2061 static LogicalResult verify(TransposeOp op) {
2062   if (!op.permutation().isPermutation())
2063     return op.emitOpError("expected a permutation map");
2064   if (op.permutation().getNumDims() != op.getShapedType().getRank())
2065     return op.emitOpError(
2066         "expected a permutation map of same rank as the input");
2067 
2068   auto srcType = op.in().getType().cast<MemRefType>();
2069   auto dstType = op.getType().cast<MemRefType>();
2070   auto transposedType = inferTransposeResultType(srcType, op.permutation());
2071   if (dstType != transposedType)
2072     return op.emitOpError("output type ")
2073            << dstType << " does not match transposed input type " << srcType
2074            << ", " << transposedType;
2075   return success();
2076 }
2077 
2078 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
2079   if (succeeded(foldMemRefCast(*this)))
2080     return getResult();
2081   return {};
2082 }
2083 
2084 //===----------------------------------------------------------------------===//
2085 // ViewOp
2086 //===----------------------------------------------------------------------===//
2087 
2088 static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
2089   OpAsmParser::OperandType srcInfo;
2090   SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
2091   SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
2092   auto indexType = parser.getBuilder().getIndexType();
2093   Type srcType, dstType;
2094   llvm::SMLoc offsetLoc;
2095   if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
2096       parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
2097     return failure();
2098 
2099   if (offsetInfo.size() != 1)
2100     return parser.emitError(offsetLoc) << "expects 1 offset operand";
2101 
2102   return failure(
2103       parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
2104       parser.parseOptionalAttrDict(result.attributes) ||
2105       parser.parseColonType(srcType) ||
2106       parser.resolveOperand(srcInfo, srcType, result.operands) ||
2107       parser.resolveOperands(offsetInfo, indexType, result.operands) ||
2108       parser.resolveOperands(sizesInfo, indexType, result.operands) ||
2109       parser.parseKeywordType("to", dstType) ||
2110       parser.addTypeToList(dstType, result.types));
2111 }
2112 
2113 static void print(OpAsmPrinter &p, ViewOp op) {
2114   p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
2115   p.printOperand(op.byte_shift());
2116   p << "][" << op.sizes() << ']';
2117   p.printOptionalAttrDict(op->getAttrs());
2118   p << " : " << op.getOperand(0).getType() << " to " << op.getType();
2119 }
2120 
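// For illustration (operands and types chosen arbitrarily): in
//   %v = memref.view %buf[%shift][%size] : memref<2048xi8> to memref<?x4xf32>
// the single size operand matches the single dynamic dimension of the result
// type, and the verifier below additionally requires identity layouts and
// matching memory spaces for the base and result memrefs.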
2121 static LogicalResult verify(ViewOp op) {
2122   auto baseType = op.getOperand(0).getType().cast<MemRefType>();
2123   auto viewType = op.getType();
2124 
2125   // The base memref should have identity layout map (or none).
2126   if (baseType.getAffineMaps().size() > 1 ||
2127       (baseType.getAffineMaps().size() == 1 &&
2128        !baseType.getAffineMaps()[0].isIdentity()))
2129     return op.emitError("unsupported map for base memref type ") << baseType;
2130 
2131   // The result memref should have identity layout map (or none).
2132   if (viewType.getAffineMaps().size() > 1 ||
2133       (viewType.getAffineMaps().size() == 1 &&
2134        !viewType.getAffineMaps()[0].isIdentity()))
2135     return op.emitError("unsupported map for result memref type ") << viewType;
2136 
2137   // The base memref and the view memref should be in the same memory space.
2138   if (baseType.getMemorySpace() != viewType.getMemorySpace())
2139     return op.emitError("different memory spaces specified for base memref "
2140                         "type ")
2141            << baseType << " and view memref type " << viewType;
2142 
2143   // Verify that we have the correct number of sizes for the result type.
2144   unsigned numDynamicDims = viewType.getNumDynamicDims();
2145   if (op.sizes().size() != numDynamicDims)
2146     return op.emitError("incorrect number of size operands for type ")
2147            << viewType;
2148 
2149   return success();
2150 }
2151 
2152 Value ViewOp::getViewSource() { return source(); }
2153 
2154 namespace {
2155 
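/// Folds dynamic size operands of a ViewOp that are produced by constants into
/// the result type, inserting a CastOp back to the original type. For example
/// (illustrative), a view producing memref<?x?xf32> with constant size
/// operands 4 and 8 is rewritten to produce memref<4x8xf32>, followed by a
/// memref.cast to memref<?x?xf32>.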
2156 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
2157   using OpRewritePattern<ViewOp>::OpRewritePattern;
2158 
2159   LogicalResult matchAndRewrite(ViewOp viewOp,
2160                                 PatternRewriter &rewriter) const override {
2161     // Return if none of the operands are constants.
2162     if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
2163           return matchPattern(operand, matchConstantIndex());
2164         }))
2165       return failure();
2166 
2167     // Get result memref type.
2168     auto memrefType = viewOp.getType();
2169 
2170     // Get offset from old memref view type 'memRefType'.
2171     int64_t oldOffset;
2172     SmallVector<int64_t, 4> oldStrides;
2173     if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
2174       return failure();
2175     assert(oldOffset == 0 && "Expected 0 offset");
2176 
2177     SmallVector<Value, 4> newOperands;
2178 
2179     // Offset cannot be folded into result type.
2180 
2181     // Fold any dynamic dim operands which are produced by a constant.
2182     SmallVector<int64_t, 4> newShapeConstants;
2183     newShapeConstants.reserve(memrefType.getRank());
2184 
2185     unsigned dynamicDimPos = 0;
2186     unsigned rank = memrefType.getRank();
2187     for (unsigned dim = 0, e = rank; dim < e; ++dim) {
2188       int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
2190       if (!ShapedType::isDynamic(dimSize)) {
2191         newShapeConstants.push_back(dimSize);
2192         continue;
2193       }
2194       auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
2195       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
2196         // Dynamic shape dimension will be folded.
2197         newShapeConstants.push_back(constantIndexOp.getValue());
2198       } else {
2199         // Dynamic shape dimension not folded; copy operand from old memref.
2200         newShapeConstants.push_back(dimSize);
2201         newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
2202       }
2203       dynamicDimPos++;
2204     }
2205 
2206     // Create new memref type with constant folded dims.
2207     MemRefType newMemRefType =
2208         MemRefType::Builder(memrefType).setShape(newShapeConstants);
2209     // Nothing new, don't fold.
2210     if (newMemRefType == memrefType)
2211       return failure();
2212 
2213     // Create new ViewOp.
2214     auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
2215                                              viewOp.getOperand(0),
2216                                              viewOp.byte_shift(), newOperands);
2217     // Insert a cast so we have the same type as the old memref type.
2218     rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
2219     return success();
2220   }
2221 };
2222 
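/// Bypasses a memref.cast whose source is an AllocOp: the view is recreated
/// directly on the allocated memref.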
2223 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
2224   using OpRewritePattern<ViewOp>::OpRewritePattern;
2225 
2226   LogicalResult matchAndRewrite(ViewOp viewOp,
2227                                 PatternRewriter &rewriter) const override {
2228     Value memrefOperand = viewOp.getOperand(0);
2229     CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
2230     if (!memrefCastOp)
2231       return failure();
2232     Value allocOperand = memrefCastOp.getOperand();
2233     AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
2234     if (!allocOp)
2235       return failure();
2236     rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
2237                                         viewOp.byte_shift(), viewOp.sizes());
2238     return success();
2239   }
2240 };
2241 
2242 } // end anonymous namespace
2243 
2244 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2245                                          MLIRContext *context) {
2246   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
2247 }
2248 
2249 //===----------------------------------------------------------------------===//
2250 // TableGen'd op method definitions
2251 //===----------------------------------------------------------------------===//
2252 
2253 #define GET_OP_CLASSES
2254 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
2255