1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MemRef/IR/MemRef.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Interfaces/InferTypeOpInterface.h"
20 #include "llvm/ADT/STLExtras.h"
21 
22 using namespace mlir;
23 using namespace mlir::memref;
24 
25 /// Materialize a single constant operation from a given attribute value with
26 /// the desired resultant type.
27 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
28                                               Attribute value, Type type,
29                                               Location loc) {
30   return builder.create<mlir::ConstantOp>(loc, type, value);
31 }
32 
33 /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
34 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
35   return llvm::to_vector<4>(
36       llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
37         return a.cast<IntegerAttr>().getInt();
38       }));
39 }
40 
41 /// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
42 /// it is a Value or into `staticVec` if it is an IntegerAttr.
43 /// In the case of a Value, a copy of the `sentinel` value is also pushed to
44 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
45 /// come from an AttrSizedOperandSegments trait.
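/// For example (an illustrative sketch): with `sentinel == -1`, dispatching
/// the mixed entries `[%v, 4]` appends `%v` to `dynamicVec` and `{-1, 4}` to
/// `staticVec`.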
46 static void dispatchIndexOpFoldResult(OpFoldResult ofr,
47                                       SmallVectorImpl<Value> &dynamicVec,
48                                       SmallVectorImpl<int64_t> &staticVec,
49                                       int64_t sentinel) {
50   if (auto v = ofr.dyn_cast<Value>()) {
51     dynamicVec.push_back(v);
52     staticVec.push_back(sentinel);
53     return;
54   }
55   APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
56   staticVec.push_back(apInt.getSExtValue());
57 }
58 
59 static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
60                                        SmallVectorImpl<Value> &dynamicVec,
61                                        SmallVectorImpl<int64_t> &staticVec,
62                                        int64_t sentinel) {
63   for (auto ofr : ofrs)
64     dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
65 }
66 
67 //===----------------------------------------------------------------------===//
68 // Common canonicalization pattern support logic
69 //===----------------------------------------------------------------------===//
70 
/// This is a common helper used for patterns of the form
72 /// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
73 /// into the root operation directly.
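///
/// For example (illustrative IR only):
///
/// ```mlir
///   %0 = memref.cast %arg0 : memref<8xf32> to memref<?xf32>
///   memref.dealloc %0 : memref<?xf32>
/// ```
///
/// may fold into:
///
/// ```mlir
///   memref.dealloc %arg0 : memref<8xf32>
/// ```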
74 static LogicalResult foldMemRefCast(Operation *op) {
75   bool folded = false;
76   for (OpOperand &operand : op->getOpOperands()) {
77     auto cast = operand.get().getDefiningOp<CastOp>();
78     if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
79       operand.set(cast.getOperand());
80       folded = true;
81     }
82   }
83   return success(folded);
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // Helpers for GlobalOp
88 //===----------------------------------------------------------------------===//
89 
90 static Type getTensorTypeFromMemRefType(Type type) {
91   if (auto memref = type.dyn_cast<MemRefType>())
92     return RankedTensorType::get(memref.getShape(), memref.getElementType());
93   if (auto memref = type.dyn_cast<UnrankedMemRefType>())
94     return UnrankedTensorType::get(memref.getElementType());
95   return NoneType::get(type.getContext());
96 }
97 
98 //===----------------------------------------------------------------------===//
99 // AllocOp / AllocaOp
100 //===----------------------------------------------------------------------===//
101 
102 template <typename AllocLikeOp>
103 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
104   static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
105                 "applies to only alloc or alloca");
106   auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
107   if (!memRefType)
108     return op.emitOpError("result must be a memref");
109 
110   if (static_cast<int64_t>(op.dynamicSizes().size()) !=
111       memRefType.getNumDynamicDims())
112     return op.emitOpError("dimension operand count does not equal memref "
113                           "dynamic dimension count");
114 
115   unsigned numSymbols = 0;
116   if (!memRefType.getAffineMaps().empty())
117     numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
118   if (op.symbolOperands().size() != numSymbols)
119     return op.emitOpError(
120         "symbol operand count does not equal memref symbol count");
121 
122   return success();
123 }
124 
125 static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }
126 
127 static LogicalResult verify(AllocaOp op) {
128   // An alloca op needs to have an ancestor with an allocation scope trait.
129   if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
130     return op.emitOpError(
131         "requires an ancestor op with AutomaticAllocationScope trait");
132 
133   return verifyAllocLikeOp(op);
134 }
135 
136 namespace {
137 /// Fold constant dimensions into an alloc like operation.
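///
/// For example (a rough sketch of the rewrite):
///
/// ```mlir
///   %c16 = constant 16 : index
///   %0 = memref.alloc(%c16, %n) : memref<?x?xf32>
/// ```
///
/// may canonicalize to:
///
/// ```mlir
///   %0 = memref.alloc(%n) : memref<16x?xf32>
///   %1 = memref.cast %0 : memref<16x?xf32> to memref<?x?xf32>
/// ```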
138 template <typename AllocLikeOp>
139 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
140   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
141 
142   LogicalResult matchAndRewrite(AllocLikeOp alloc,
143                                 PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants.  If so, we can
145     // substitute and drop them.
146     if (llvm::none_of(alloc.getOperands(), [](Value operand) {
147           return matchPattern(operand, matchConstantIndex());
148         }))
149       return failure();
150 
151     auto memrefType = alloc.getType();
152 
153     // Ok, we have one or more constant operands.  Collect the non-constant ones
154     // and keep track of the resultant memref type to build.
155     SmallVector<int64_t, 4> newShapeConstants;
156     newShapeConstants.reserve(memrefType.getRank());
157     SmallVector<Value, 4> newOperands;
158 
159     unsigned dynamicDimPos = 0;
160     for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
161       int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
163       if (dimSize != -1) {
164         newShapeConstants.push_back(dimSize);
165         continue;
166       }
167       auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
168       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
169         // Dynamic shape dimension will be folded.
170         newShapeConstants.push_back(constantIndexOp.getValue());
171       } else {
172         // Dynamic shape dimension not folded; copy operand from old memref.
173         newShapeConstants.push_back(-1);
174         newOperands.push_back(alloc.getOperand(dynamicDimPos));
175       }
176       dynamicDimPos++;
177     }
178 
179     // Create new memref type (which will have fewer dynamic dimensions).
180     MemRefType newMemRefType =
181         MemRefType::Builder(memrefType).setShape(newShapeConstants);
182     assert(static_cast<int64_t>(newOperands.size()) ==
183            newMemRefType.getNumDynamicDims());
184 
185     // Create and insert the alloc op for the new memref.
186     auto newAlloc = rewriter.create<AllocLikeOp>(alloc.getLoc(), newMemRefType,
187                                                  newOperands, IntegerAttr());
188     // Insert a cast so we have the same type as the old alloc.
189     auto resultCast =
190         rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
191 
192     rewriter.replaceOp(alloc, {resultCast});
193     return success();
194   }
195 };
196 
197 /// Fold alloc operations with no uses. Alloc has side effects on the heap,
198 /// but can still be deleted if it has zero uses.
199 struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
200   using OpRewritePattern<AllocOp>::OpRewritePattern;
201 
202   LogicalResult matchAndRewrite(AllocOp alloc,
203                                 PatternRewriter &rewriter) const override {
204     if (alloc.use_empty()) {
205       rewriter.eraseOp(alloc);
206       return success();
207     }
208     return failure();
209   }
210 };
211 } // end anonymous namespace.
212 
213 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
214                                           MLIRContext *context) {
215   results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc>(context);
216 }
217 
218 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
219                                            MLIRContext *context) {
220   results.add<SimplifyAllocConst<AllocaOp>>(context);
221 }
222 
223 //===----------------------------------------------------------------------===//
224 // AssumeAlignmentOp
225 //===----------------------------------------------------------------------===//
226 
227 static LogicalResult verify(AssumeAlignmentOp op) {
228   unsigned alignment = op.alignment();
229   if (!llvm::isPowerOf2_32(alignment))
230     return op.emitOpError("alignment must be power of 2");
231   return success();
232 }
233 
234 //===----------------------------------------------------------------------===//
235 // BufferCastOp
236 //===----------------------------------------------------------------------===//
237 
238 OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
239   if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
240     if (tensorLoad.memref().getType() == getType())
241       return tensorLoad.memref();
242   return {};
243 }
244 
245 namespace {
246 /// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast.
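///
/// For example (illustrative IR only):
///
/// ```mlir
///   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %1 = memref.buffer_cast %0 : memref<?xf32>
/// ```
///
/// may rewrite to:
///
/// ```mlir
///   %0 = memref.buffer_cast %t : memref<4xf32>
///   %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
/// ```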
247 struct BufferCast : public OpRewritePattern<BufferCastOp> {
248   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
249 
250   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
251                                 PatternRewriter &rewriter) const final {
252     auto tensorCastOperand =
253         bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
254     if (!tensorCastOperand)
255       return failure();
256     auto srcTensorType =
257         tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
258     if (!srcTensorType)
259       return failure();
260     auto memrefType = MemRefType::get(srcTensorType.getShape(),
261                                       srcTensorType.getElementType());
262     Value memref = rewriter.create<BufferCastOp>(
263         bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
264     rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
265                                         memref);
266     return success();
267   }
268 };
269 
270 /// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
/// type mismatches prevent `BufferCastOp::fold` from kicking in.
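///
/// For example (illustrative IR only, with a strided layout on the source):
///
/// ```mlir
///   %0 = memref.tensor_load %m : memref<?xf32, offset: ?, strides: [1]>
///   %1 = memref.buffer_cast %0 : memref<?xf32>
/// ```
///
/// may rewrite to:
///
/// ```mlir
///   %1 = memref.cast %m
///       : memref<?xf32, offset: ?, strides: [1]> to memref<?xf32>
/// ```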
272 struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
273   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
274 
275   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
276                                 PatternRewriter &rewriter) const final {
277     auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
278     // Bail unless we have a tensor_load + memref.buffer_cast with different
279     // types. `BufferCastOp::fold` handles the same type case.
280     if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
281       return failure();
282     // If types are not cast-compatible, bail.
283     if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
284                                    bufferCast.getType()))
285       return failure();
286     rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
287                                         tensorLoad.memref());
288     return success();
289   }
290 };
291 
292 } // namespace
293 
294 void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
295                                                MLIRContext *context) {
296   results.add<BufferCast, TensorLoadToMemRef>(context);
297 }
298 
299 //===----------------------------------------------------------------------===//
300 // CastOp
301 //===----------------------------------------------------------------------===//
302 
303 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
305 /// and implement canonicalization patterns for ops in different dialects that
306 /// may consume the results of memref.cast operations. Such foldable memref.cast
307 /// operations are typically inserted as `view` and `subview` ops are
308 /// canonicalized, to preserve the type compatibility of their uses.
309 ///
310 /// Returns true when all conditions are met:
311 /// 1. source and result are ranked memrefs with strided semantics and same
312 /// element type and rank.
313 /// 2. each of the source's size, offset or stride has more static information
314 /// than the corresponding result's size, offset or stride.
315 ///
316 /// Example 1:
317 /// ```mlir
318 ///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
319 ///   %2 = consumer %1 ... : memref<?x?xf32> ...
320 /// ```
321 ///
322 /// may fold into:
323 ///
324 /// ```mlir
325 ///   %2 = consumer %0 ... : memref<8x16xf32> ...
326 /// ```
327 ///
328 /// Example 2:
329 /// ```
330 ///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
331 ///          to memref<?x?xf32>
332 ///   consumer %1 : memref<?x?xf32> ...
333 /// ```
334 ///
335 /// may fold into:
336 ///
337 /// ```
338 ///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
339 /// ```
340 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
341   MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
342   MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
343 
344   // Requires ranked MemRefType.
345   if (!sourceType || !resultType)
346     return false;
347 
348   // Requires same elemental type.
349   if (sourceType.getElementType() != resultType.getElementType())
350     return false;
351 
352   // Requires same rank.
353   if (sourceType.getRank() != resultType.getRank())
354     return false;
355 
356   // Only fold casts between strided memref forms.
357   int64_t sourceOffset, resultOffset;
358   SmallVector<int64_t, 4> sourceStrides, resultStrides;
359   if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
360       failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
361     return false;
362 
363   // If cast is towards more static sizes along any dimension, don't fold.
364   for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
365     auto ss = std::get<0>(it), st = std::get<1>(it);
366     if (ss != st)
367       if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
368         return false;
369   }
370 
  // If cast is towards a more static offset, don't fold.
372   if (sourceOffset != resultOffset)
373     if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
374         !MemRefType::isDynamicStrideOrOffset(resultOffset))
375       return false;
376 
377   // If cast is towards more static strides along any dimension, don't fold.
378   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
379     auto ss = std::get<0>(it), st = std::get<1>(it);
380     if (ss != st)
381       if (MemRefType::isDynamicStrideOrOffset(ss) &&
382           !MemRefType::isDynamicStrideOrOffset(st))
383         return false;
384   }
385 
386   return true;
387 }
388 
389 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
390   if (inputs.size() != 1 || outputs.size() != 1)
391     return false;
392   Type a = inputs.front(), b = outputs.front();
393   auto aT = a.dyn_cast<MemRefType>();
394   auto bT = b.dyn_cast<MemRefType>();
395 
396   auto uaT = a.dyn_cast<UnrankedMemRefType>();
397   auto ubT = b.dyn_cast<UnrankedMemRefType>();
398 
399   if (aT && bT) {
400     if (aT.getElementType() != bT.getElementType())
401       return false;
402     if (aT.getAffineMaps() != bT.getAffineMaps()) {
403       int64_t aOffset, bOffset;
404       SmallVector<int64_t, 4> aStrides, bStrides;
405       if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
406           failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
407           aStrides.size() != bStrides.size())
408         return false;
409 
410       // Strides along a dimension/offset are compatible if the value in the
411       // source memref is static and the value in the target memref is the
412       // same. They are also compatible if either one is dynamic (see
413       // description of MemRefCastOp for details).
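      // E.g. (illustrative): offset pairs (4, 4), (4, dynamic) and
      // (dynamic, 4) are compatible, while (4, 8) is not.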
414       auto checkCompatible = [](int64_t a, int64_t b) {
415         return (a == MemRefType::getDynamicStrideOrOffset() ||
416                 b == MemRefType::getDynamicStrideOrOffset() || a == b);
417       };
418       if (!checkCompatible(aOffset, bOffset))
419         return false;
420       for (auto aStride : enumerate(aStrides))
421         if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
422           return false;
423     }
424     if (aT.getMemorySpace() != bT.getMemorySpace())
425       return false;
426 
427     // They must have the same rank, and any specified dimensions must match.
428     if (aT.getRank() != bT.getRank())
429       return false;
430 
431     for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
432       int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
433       if (aDim != -1 && bDim != -1 && aDim != bDim)
434         return false;
435     }
436     return true;
437   } else {
438     if (!aT && !uaT)
439       return false;
440     if (!bT && !ubT)
441       return false;
442     // Unranked to unranked casting is unsupported
443     if (uaT && ubT)
444       return false;
445 
446     auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
447     auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
448     if (aEltType != bEltType)
449       return false;
450 
451     auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
452     auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
453     if (aMemSpace != bMemSpace)
454       return false;
455 
456     return true;
457   }
458 
459   return false;
460 }
461 
462 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
463   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
464 }
465 
466 //===----------------------------------------------------------------------===//
467 // DeallocOp
468 //===----------------------------------------------------------------------===//
469 namespace {
470 /// Fold Dealloc operations that are deallocating an AllocOp that is only used
471 /// by other Dealloc operations.
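///
/// For example (illustrative IR only):
///
/// ```mlir
///   %0 = memref.alloc() : memref<4xf32>
///   memref.dealloc %0 : memref<4xf32>
/// ```
///
/// Here the alloc result is only ever deallocated, so the dealloc can be
/// erased (and the now-unused alloc is cleaned up separately).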
472 struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
473   using OpRewritePattern<DeallocOp>::OpRewritePattern;
474 
475   LogicalResult matchAndRewrite(DeallocOp dealloc,
476                                 PatternRewriter &rewriter) const override {
477     // Check that the memref operand's defining operation is an AllocOp.
478     Value memref = dealloc.memref();
479     if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp()))
480       return failure();
481 
482     // Check that all of the uses of the AllocOp are other DeallocOps.
483     for (auto *user : memref.getUsers())
484       if (!isa<DeallocOp>(user))
485         return failure();
486 
487     // Erase the dealloc operation.
488     rewriter.eraseOp(dealloc);
489     return success();
490   }
491 };
492 } // end anonymous namespace.
493 
494 static LogicalResult verify(DeallocOp op) {
495   if (!op.memref().getType().isa<MemRefType>())
496     return op.emitOpError("operand must be a memref");
497   return success();
498 }
499 
500 void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
501                                             MLIRContext *context) {
502   results.add<SimplifyDeadDealloc>(context);
503 }
504 
505 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
506                               SmallVectorImpl<OpFoldResult> &results) {
507   /// dealloc(memrefcast) -> dealloc
508   return foldMemRefCast(*this);
509 }
510 
511 //===----------------------------------------------------------------------===//
512 // DimOp
513 //===----------------------------------------------------------------------===//
514 
515 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
516                   int64_t index) {
517   auto loc = result.location;
518   Value indexValue = builder.create<ConstantIndexOp>(loc, index);
519   build(builder, result, memref, indexValue);
520 }
521 
522 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
523                   Value index) {
524   auto indexTy = builder.getIndexType();
525   build(builder, result, indexTy, memref, index);
526 }
527 
528 Optional<int64_t> DimOp::getConstantIndex() {
529   if (auto constantOp = index().getDefiningOp<ConstantOp>())
530     return constantOp.getValue().cast<IntegerAttr>().getInt();
531   return {};
532 }
533 
534 static LogicalResult verify(DimOp op) {
535   // Assume unknown index to be in range.
536   Optional<int64_t> index = op.getConstantIndex();
537   if (!index.hasValue())
538     return success();
539 
540   // Check that constant index is not knowingly out of range.
541   auto type = op.memrefOrTensor().getType();
542   if (auto memrefType = type.dyn_cast<MemRefType>()) {
543     if (index.getValue() >= memrefType.getRank())
544       return op.emitOpError("index is out of range");
545   } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
546     if (index.getValue() >= tensorType.getRank())
547       return op.emitOpError("index is out of range");
548   } else if (type.isa<UnrankedMemRefType>() || type.isa<UnrankedTensorType>()) {
549     // Assume index to be in range.
550   } else {
551     llvm_unreachable("expected operand with memref type");
552   }
553   return success();
554 }
555 
556 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
557   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
558 
559   // All forms of folding require a known index.
560   if (!index)
561     return {};
562 
563   auto argTy = memrefOrTensor().getType();
564   // Fold if the shape extent along the given index is known.
565   if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
566     // Folding for unranked types (UnrankedMemRefType) is not supported.
567     if (!shapedTy.hasRank())
568       return {};
569     if (!shapedTy.isDynamicDim(index.getInt())) {
570       Builder builder(getContext());
571       return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]);
572     }
573   }
574 
575   Operation *definingOp = memrefOrTensor().getDefiningOp();
576 
577   // dim(memref.tensor_load(memref)) -> dim(memref)
578   if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) {
579     setOperand(0, tensorLoadOp.memref());
580     return getResult();
581   }
582 
583   // Fold dim to the operand of tensor.generate.
584   if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
585     auto resultType =
586         fromElements.getResult().getType().cast<RankedTensorType>();
587     // The case where the type encodes the size of the dimension is handled
588     // above.
589     assert(resultType.getShape()[index.getInt()] ==
590            RankedTensorType::kDynamicSize);
591 
592     // Find the operand of the fromElements that corresponds to this index.
593     auto dynExtents = fromElements.dynamicExtents().begin();
594     for (auto dim : resultType.getShape().take_front(index.getInt()))
595       if (dim == RankedTensorType::kDynamicSize)
596         dynExtents++;
597 
598     return Value{*dynExtents};
599   }
600 
601   // The size at the given index is now known to be a dynamic size.
602   unsigned unsignedIndex = index.getValue().getZExtValue();
603 
604   if (auto subtensor = dyn_cast_or_null<mlir::SubTensorOp>(definingOp)) {
605     assert(subtensor.isDynamicSize(unsignedIndex) &&
606            "Expected dynamic subtensor size");
607     return subtensor.getDynamicSize(unsignedIndex);
608   }
609 
610   // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
611   auto memrefType = argTy.dyn_cast<MemRefType>();
612   if (!memrefType)
613     return {};
614 
615   if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
616     return *(alloc.getDynamicSizes().begin() +
617              memrefType.getDynamicDimIndex(unsignedIndex));
618 
619   if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
620     return *(alloca.getDynamicSizes().begin() +
621              memrefType.getDynamicDimIndex(unsignedIndex));
622 
623   if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
624     return *(view.getDynamicSizes().begin() +
625              memrefType.getDynamicDimIndex(unsignedIndex));
626 
627   if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
628     assert(subview.isDynamicSize(unsignedIndex) &&
629            "Expected dynamic subview size");
630     return subview.getDynamicSize(unsignedIndex);
631   }
632 
633   // dim(memrefcast) -> dim
634   if (succeeded(foldMemRefCast(*this)))
635     return getResult();
636 
637   return {};
638 }
639 
640 namespace {
641 /// Fold dim of a memref reshape operation to a load into the reshape's shape
642 /// operand.
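///
/// For example (illustrative IR only):
///
/// ```mlir
///   %dst = memref.reshape %src(%shape)
///       : (memref<?xf32>, memref<2xindex>) -> memref<?x?xf32>
///   %dim = memref.dim %dst, %idx : memref<?x?xf32>
/// ```
///
/// may rewrite to:
///
/// ```mlir
///   %dim = memref.load %shape[%idx] : memref<2xindex>
/// ```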
643 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
644   using OpRewritePattern<DimOp>::OpRewritePattern;
645 
646   LogicalResult matchAndRewrite(DimOp dim,
647                                 PatternRewriter &rewriter) const override {
648     auto reshape = dim.memrefOrTensor().getDefiningOp<ReshapeOp>();
649 
650     if (!reshape)
651       return failure();
652 
653     // Place the load directly after the reshape to ensure that the shape memref
654     // was not mutated.
655     rewriter.setInsertionPointAfter(reshape);
656     rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
657                                         llvm::makeArrayRef({dim.index()}));
658     return success();
659   }
660 };
661 
/// Fold dim of a cast into the dim of the source of the cast.
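///
/// For example (illustrative IR only):
///
/// ```mlir
///   %0 = tensor.cast %t : tensor<4x?xf32> to tensor<?x?xf32>
///   %dim = memref.dim %0, %idx : tensor<?x?xf32>
/// ```
///
/// may rewrite to:
///
/// ```mlir
///   %dim = memref.dim %t, %idx : tensor<4x?xf32>
/// ```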
663 template <typename CastOpTy>
664 struct DimOfCastOp : public OpRewritePattern<DimOp> {
665   using OpRewritePattern<DimOp>::OpRewritePattern;
666 
667   LogicalResult matchAndRewrite(DimOp dimOp,
668                                 PatternRewriter &rewriter) const override {
669     auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>();
670     if (!castOp)
671       return failure();
672     Value newSource = castOp.getOperand();
673     rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
674     return success();
675   }
676 };
677 
/// Helper method to get the `Value` that is the size at dimension `dimIndex`
/// of the given op `result`, using the `InferShapedTypeOpInterface`.
/// TODO(ravishankarm): This is better put as an interface utility method
681 /// somewhere, but that would imply the interface will depend on the `tensor`
682 /// dialect. Ideally maybe a utility method in the `tensor` dialect.
683 static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
684                                             int64_t dimIndex) {
685   unsigned resultNumber = result.getResultNumber();
686   auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
687   Location loc = result.getOwner()->getLoc();
688   if (!shapedTypeOp)
689     return nullptr;
690 
  // The interface exposes two methods: one that returns the shape of each
  // result as a single shape `Value`, and another that returns the shape of
  // each result as a list of `Value`s, one per dimension. The former takes
  // precedence over the latter, so first check if the op implements the first
  // interface method and fall back to the second otherwise.
696   SmallVector<Value> reifiedResultShapes;
697   if (succeeded(
698           shapedTypeOp.reifyReturnTypeShapes(builder, reifiedResultShapes))) {
699     if (reifiedResultShapes.size() <= resultNumber)
700       return nullptr;
701     Value resultShape = reifiedResultShapes[resultNumber];
702     auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
703     if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
704       return nullptr;
705     return builder.create<tensor::ExtractOp>(
706         loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
707   }
708 
709   SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
710   if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
711           builder, reifiedResultShapesPerDim)))
712     return nullptr;
713   if (reifiedResultShapesPerDim.size() <= resultNumber ||
714       reifiedResultShapesPerDim[resultNumber].size() !=
715           static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
716     return nullptr;
717   OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
718   if (auto attr = valueOrAttr.dyn_cast<Attribute>())
719     return builder.createOrFold<ConstantIndexOp>(
720         loc, attr.cast<IntegerAttr>().getInt());
721   return valueOrAttr.get<Value>();
722 }
723 
724 /// Fold dim of an operation that implements the InferShapedTypeOpInterface
725 struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> {
726   using OpRewritePattern<DimOp>::OpRewritePattern;
727 
728   LogicalResult matchAndRewrite(DimOp dimOp,
729                                 PatternRewriter &rewriter) const override {
730     OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
731     if (!dimValue)
732       return failure();
733     auto shapedTypeOp =
734         dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
735     if (!shapedTypeOp)
736       return failure();
737 
738     Optional<int64_t> dimIndex = dimOp.getConstantIndex();
739     if (!dimIndex)
740       return failure();
741     Value replacement =
742         getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
743     if (!replacement)
744       return failure();
745     rewriter.replaceOp(dimOp, replacement);
746     return success();
747   }
748 };
749 } // end anonymous namespace.
750 
751 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
752                                         MLIRContext *context) {
753   results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
754               DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(context);
755 }
756 
757 // ---------------------------------------------------------------------------
758 // DmaStartOp
759 // ---------------------------------------------------------------------------
760 
761 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
762                        Value srcMemRef, ValueRange srcIndices, Value destMemRef,
763                        ValueRange destIndices, Value numElements,
764                        Value tagMemRef, ValueRange tagIndices, Value stride,
765                        Value elementsPerStride) {
766   result.addOperands(srcMemRef);
767   result.addOperands(srcIndices);
768   result.addOperands(destMemRef);
769   result.addOperands(destIndices);
770   result.addOperands({numElements, tagMemRef});
771   result.addOperands(tagIndices);
772   if (stride)
773     result.addOperands({stride, elementsPerStride});
774 }
775 
776 void DmaStartOp::print(OpAsmPrinter &p) {
777   p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices()
778     << "], " << getDstMemRef() << '[' << getDstIndices() << "], "
779     << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices()
780     << ']';
781   if (isStrided())
782     p << ", " << getStride() << ", " << getNumElementsPerStride();
783 
784   p.printOptionalAttrDict((*this)->getAttrs());
785   p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
786     << ", " << getTagMemRef().getType();
787 }
788 
789 // Parse DmaStartOp.
790 // Ex:
//   memref.dma_start %src[%i, %j], %dst[%k, %l], %size,
//                    %tag[%index], %stride, %num_elt_per_stride
//                  : memref<3076 x f32, 0>,
//                    memref<1024 x f32, 2>,
//                    memref<1 x i32>
796 //
797 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
798   OpAsmParser::OperandType srcMemRefInfo;
799   SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
800   OpAsmParser::OperandType dstMemRefInfo;
801   SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
802   OpAsmParser::OperandType numElementsInfo;
803   OpAsmParser::OperandType tagMemrefInfo;
804   SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
805   SmallVector<OpAsmParser::OperandType, 2> strideInfo;
806 
807   SmallVector<Type, 3> types;
808   auto indexType = parser.getBuilder().getIndexType();
809 
810   // Parse and resolve the following list of operands:
811   // *) source memref followed by its indices (in square brackets).
812   // *) destination memref followed by its indices (in square brackets).
  // *) dma size (number of elements to transfer).
814   if (parser.parseOperand(srcMemRefInfo) ||
815       parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
816       parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
817       parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
818       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
819       parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
820       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
821     return failure();
822 
823   // Parse optional stride and elements per stride.
824   if (parser.parseTrailingOperandList(strideInfo))
825     return failure();
826 
827   bool isStrided = strideInfo.size() == 2;
828   if (!strideInfo.empty() && !isStrided) {
829     return parser.emitError(parser.getNameLoc(),
830                             "expected two stride related operands");
831   }
832 
833   if (parser.parseColonTypeList(types))
834     return failure();
835   if (types.size() != 3)
836     return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
837 
838   if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
839       parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
840       parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
841       parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
842       // size should be an index.
843       parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
844       parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
845       // tag indices should be index.
846       parser.resolveOperands(tagIndexInfos, indexType, result.operands))
847     return failure();
848 
849   if (isStrided) {
850     if (parser.resolveOperands(strideInfo, indexType, result.operands))
851       return failure();
852   }
853 
854   return success();
855 }
856 
857 LogicalResult DmaStartOp::verify() {
858   unsigned numOperands = getNumOperands();
859 
860   // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
861   // the number of elements.
862   if (numOperands < 4)
863     return emitOpError("expected at least 4 operands");
864 
865   // Check types of operands. The order of these calls is important: the later
866   // calls rely on some type properties to compute the operand position.
867   // 1. Source memref.
868   if (!getSrcMemRef().getType().isa<MemRefType>())
869     return emitOpError("expected source to be of memref type");
870   if (numOperands < getSrcMemRefRank() + 4)
871     return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
872                          << " operands";
873   if (!getSrcIndices().empty() &&
874       !llvm::all_of(getSrcIndices().getTypes(),
875                     [](Type t) { return t.isIndex(); }))
876     return emitOpError("expected source indices to be of index type");
877 
878   // 2. Destination memref.
879   if (!getDstMemRef().getType().isa<MemRefType>())
880     return emitOpError("expected destination to be of memref type");
881   unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
882   if (numOperands < numExpectedOperands)
883     return emitOpError() << "expected at least " << numExpectedOperands
884                          << " operands";
885   if (!getDstIndices().empty() &&
886       !llvm::all_of(getDstIndices().getTypes(),
887                     [](Type t) { return t.isIndex(); }))
888     return emitOpError("expected destination indices to be of index type");
889 
890   // 3. Number of elements.
891   if (!getNumElements().getType().isIndex())
892     return emitOpError("expected num elements to be of index type");
893 
894   // 4. Tag memref.
895   if (!getTagMemRef().getType().isa<MemRefType>())
896     return emitOpError("expected tag to be of memref type");
897   numExpectedOperands += getTagMemRefRank();
898   if (numOperands < numExpectedOperands)
899     return emitOpError() << "expected at least " << numExpectedOperands
900                          << " operands";
901   if (!getTagIndices().empty() &&
902       !llvm::all_of(getTagIndices().getTypes(),
903                     [](Type t) { return t.isIndex(); }))
904     return emitOpError("expected tag indices to be of index type");
905 
  // Only DMAs between different memory spaces are supported.
907   if (getSrcMemorySpace() == getDstMemorySpace())
908     return emitOpError("DMA should be between different memory spaces");
909 
910   // Optional stride-related operands must be either both present or both
911   // absent.
912   if (numOperands != numExpectedOperands &&
913       numOperands != numExpectedOperands + 2)
914     return emitOpError("incorrect number of operands");
915 
916   // 5. Strides.
917   if (isStrided()) {
918     if (!getStride().getType().isIndex() ||
919         !getNumElementsPerStride().getType().isIndex())
920       return emitOpError(
921           "expected stride and num elements per stride to be of type index");
922   }
923 
924   return success();
925 }
926 
927 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
928                                SmallVectorImpl<OpFoldResult> &results) {
929   /// dma_start(memrefcast) -> dma_start
930   return foldMemRefCast(*this);
931 }
932 
933 // ---------------------------------------------------------------------------
934 // DmaWaitOp
935 // ---------------------------------------------------------------------------
936 
937 void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
938                       Value tagMemRef, ValueRange tagIndices,
939                       Value numElements) {
940   result.addOperands(tagMemRef);
941   result.addOperands(tagIndices);
942   result.addOperands(numElements);
943 }
944 
945 void DmaWaitOp::print(OpAsmPrinter &p) {
946   p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices()
947     << "], " << getNumElements();
948   p.printOptionalAttrDict((*this)->getAttrs());
949   p << " : " << getTagMemRef().getType();
950 }
951 
952 // Parse DmaWaitOp.
// Ex:
//   memref.dma_wait %tag[%index], %num_elements
//     : memref<1 x i32, (d0) -> (d0), 4>
955 //
956 ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
957   OpAsmParser::OperandType tagMemrefInfo;
958   SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
959   Type type;
960   auto indexType = parser.getBuilder().getIndexType();
961   OpAsmParser::OperandType numElementsInfo;
962 
963   // Parse tag memref, its indices, and dma size.
964   if (parser.parseOperand(tagMemrefInfo) ||
965       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
966       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
967       parser.parseColonType(type) ||
968       parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
969       parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
970       parser.resolveOperand(numElementsInfo, indexType, result.operands))
971     return failure();
972 
973   return success();
974 }
975 
976 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
977                               SmallVectorImpl<OpFoldResult> &results) {
978   /// dma_wait(memrefcast) -> dma_wait
979   return foldMemRefCast(*this);
980 }
981 
982 LogicalResult DmaWaitOp::verify() {
983   // Mandatory non-variadic operands are tag and the number of elements.
984   if (getNumOperands() < 2)
985     return emitOpError() << "expected at least 2 operands";
986 
987   // Check types of operands. The order of these calls is important: the later
988   // calls rely on some type properties to compute the operand position.
989   if (!getTagMemRef().getType().isa<MemRefType>())
990     return emitOpError() << "expected tag to be of memref type";
991 
992   if (getNumOperands() != 2 + getTagMemRefRank())
993     return emitOpError() << "expected " << 2 + getTagMemRefRank()
994                          << " operands";
995 
996   if (!getTagIndices().empty() &&
997       !llvm::all_of(getTagIndices().getTypes(),
998                     [](Type t) { return t.isIndex(); }))
999     return emitOpError() << "expected tag indices to be of index type";
1000 
1001   if (!getNumElements().getType().isIndex())
1002     return emitOpError()
1003            << "expected the number of elements to be of index type";
1004 
1005   return success();
1006 }
1007 
1008 //===----------------------------------------------------------------------===//
1009 // GlobalOp
1010 //===----------------------------------------------------------------------===//
1011 
1012 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1013                                                    TypeAttr type,
1014                                                    Attribute initialValue) {
1015   p << type;
1016   if (!op.isExternal()) {
1017     p << " = ";
1018     if (op.isUninitialized())
1019       p << "uninitialized";
1020     else
1021       p.printAttributeWithoutType(initialValue);
1022   }
1023 }
1024 
1025 static ParseResult
1026 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1027                                        Attribute &initialValue) {
1028   Type type;
1029   if (parser.parseType(type))
1030     return failure();
1031 
1032   auto memrefType = type.dyn_cast<MemRefType>();
1033   if (!memrefType || !memrefType.hasStaticShape())
1034     return parser.emitError(parser.getNameLoc())
1035            << "type should be static shaped memref, but got " << type;
1036   typeAttr = TypeAttr::get(type);
1037 
1038   if (parser.parseOptionalEqual())
1039     return success();
1040 
1041   if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1042     initialValue = UnitAttr::get(parser.getBuilder().getContext());
1043     return success();
1044   }
1045 
1046   Type tensorType = getTensorTypeFromMemRefType(memrefType);
1047   if (parser.parseAttribute(initialValue, tensorType))
1048     return failure();
1049   if (!initialValue.isa<ElementsAttr>())
1050     return parser.emitError(parser.getNameLoc())
1051            << "initial value should be a unit or elements attribute";
1052   return success();
1053 }
1054 
1055 static LogicalResult verify(GlobalOp op) {
1056   auto memrefType = op.type().dyn_cast<MemRefType>();
1057   if (!memrefType || !memrefType.hasStaticShape())
1058     return op.emitOpError("type should be static shaped memref, but got ")
1059            << op.type();
1060 
1061   // Verify that the initial value, if present, is either a unit attribute or
1062   // an elements attribute.
1063   if (op.initial_value().hasValue()) {
1064     Attribute initValue = op.initial_value().getValue();
1065     if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
1066       return op.emitOpError("initial value should be a unit or elements "
1067                             "attribute, but got ")
1068              << initValue;
1069 
1070     // Check that the type of the initial value is compatible with the type of
1071     // the global variable.
1072     if (initValue.isa<ElementsAttr>()) {
1073       Type initType = initValue.getType();
1074       Type tensorType = getTensorTypeFromMemRefType(memrefType);
1075       if (initType != tensorType)
1076         return op.emitOpError("initial value expected to be of type ")
1077                << tensorType << ", but was of type " << initType;
1078     }
1079   }
1080 
1081   // TODO: verify visibility for declarations.
1082   return success();
1083 }
1084 
1085 //===----------------------------------------------------------------------===//
1086 // GetGlobalOp
1087 //===----------------------------------------------------------------------===//
1088 
1089 LogicalResult
1090 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1091   // Verify that the result type is same as the type of the referenced
1092   // memref.global op.
1093   auto global =
1094       symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
1095   if (!global)
1096     return emitOpError("'")
1097            << name() << "' does not reference a valid global memref";
1098 
1099   Type resultType = result().getType();
1100   if (global.type() != resultType)
1101     return emitOpError("result type ")
1102            << resultType << " does not match type " << global.type()
1103            << " of the global memref @" << name();
1104   return success();
1105 }
1106 
1107 //===----------------------------------------------------------------------===//
1108 // LoadOp
1109 //===----------------------------------------------------------------------===//
1110 
1111 static LogicalResult verify(LoadOp op) {
1112   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1113     return op.emitOpError("incorrect number of indices for load");
1114   return success();
1115 }
1116 
1117 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1118   /// load(memrefcast) -> load
1119   if (succeeded(foldMemRefCast(*this)))
1120     return getResult();
1121   return OpFoldResult();
1122 }
1123 
1124 namespace {
/// Fold a load on a buffer_cast operation into a tensor.extract on the
1126 /// corresponding tensor.
1127 struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
1128   using OpRewritePattern<LoadOp>::OpRewritePattern;
1129 
1130   LogicalResult matchAndRewrite(LoadOp load,
1131                                 PatternRewriter &rewriter) const override {
1132     auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
1133     if (!buffercast)
1134       return failure();
1135 
1136     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
1137                                                    load.indices());
1138     return success();
1139   }
1140 };
1141 } // end anonymous namespace.
1142 
1143 void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
1144                                          MLIRContext *context) {
1145   results.add<LoadOfBufferCast>(context);
1146 }
1147 
1148 //===----------------------------------------------------------------------===//
1149 // PrefetchOp
1150 //===----------------------------------------------------------------------===//
1151 
1152 static void print(OpAsmPrinter &p, PrefetchOp op) {
1153   p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
1154   p.printOperands(op.indices());
1155   p << ']' << ", " << (op.isWrite() ? "write" : "read");
1156   p << ", locality<" << op.localityHint();
1157   p << ">, " << (op.isDataCache() ? "data" : "instr");
1158   p.printOptionalAttrDict(
1159       op->getAttrs(),
1160       /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1161   p << " : " << op.getMemRefType();
1162 }
1163 
1164 static ParseResult parsePrefetchOp(OpAsmParser &parser,
1165                                    OperationState &result) {
1166   OpAsmParser::OperandType memrefInfo;
1167   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
1168   IntegerAttr localityHint;
1169   MemRefType type;
1170   StringRef readOrWrite, cacheType;
1171 
1172   auto indexTy = parser.getBuilder().getIndexType();
1173   auto i32Type = parser.getBuilder().getIntegerType(32);
1174   if (parser.parseOperand(memrefInfo) ||
1175       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1176       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1177       parser.parseComma() || parser.parseKeyword("locality") ||
1178       parser.parseLess() ||
1179       parser.parseAttribute(localityHint, i32Type, "localityHint",
1180                             result.attributes) ||
1181       parser.parseGreater() || parser.parseComma() ||
1182       parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1183       parser.resolveOperand(memrefInfo, type, result.operands) ||
1184       parser.resolveOperands(indexInfo, indexTy, result.operands))
1185     return failure();
1186 
1187   if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1188     return parser.emitError(parser.getNameLoc(),
1189                             "rw specifier has to be 'read' or 'write'");
1190   result.addAttribute(
1191       PrefetchOp::getIsWriteAttrName(),
1192       parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1193 
1194   if (!cacheType.equals("data") && !cacheType.equals("instr"))
1195     return parser.emitError(parser.getNameLoc(),
1196                             "cache type has to be 'data' or 'instr'");
1197 
1198   result.addAttribute(
1199       PrefetchOp::getIsDataCacheAttrName(),
1200       parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1201 
1202   return success();
1203 }
1204 
1205 static LogicalResult verify(PrefetchOp op) {
1206   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1207     return op.emitOpError("too few indices");
1208 
1209   return success();
1210 }
1211 
1212 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
1213                                SmallVectorImpl<OpFoldResult> &results) {
1214   // prefetch(memrefcast) -> prefetch
1215   return foldMemRefCast(*this);
1216 }
1217 
1218 //===----------------------------------------------------------------------===//
1219 // ReinterpretCastOp
1220 //===----------------------------------------------------------------------===//
1221 
/// Build a ReinterpretCastOp with mixed static and dynamic entries:
/// `staticOffsets`, `staticSizes` and `staticStrides` are automatically
/// filled with sentinel values for those entries that are passed in as
/// `Value`s.
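///
/// A hypothetical use of this builder (all names are illustrative only):
///
/// ```c++
///   // %off is a dynamic offset; the single size and stride are static.
///   SmallVector<OpFoldResult> sizes = {b.getI64IntegerAttr(16)};
///   SmallVector<OpFoldResult> strides = {b.getI64IntegerAttr(1)};
///   b.create<memref::ReinterpretCastOp>(loc, resultType, source,
///                                       /*offset=*/OpFoldResult(off),
///                                       sizes, strides);
/// ```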
1225 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1226                               MemRefType resultType, Value source,
1227                               OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1228                               ArrayRef<OpFoldResult> strides,
1229                               ArrayRef<NamedAttribute> attrs) {
1230   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1231   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1232   dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
1233                              ShapedType::kDynamicStrideOrOffset);
1234   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1235                              ShapedType::kDynamicSize);
1236   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1237                              ShapedType::kDynamicStrideOrOffset);
1238   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1239         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1240         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1241   result.addAttributes(attrs);
1242 }
1243 
1244 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1245                               MemRefType resultType, Value source,
1246                               int64_t offset, ArrayRef<int64_t> sizes,
1247                               ArrayRef<int64_t> strides,
1248                               ArrayRef<NamedAttribute> attrs) {
1249   SmallVector<OpFoldResult> sizeValues =
1250       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1251         return b.getI64IntegerAttr(v);
1252       }));
1253   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1254       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1255         return b.getI64IntegerAttr(v);
1256       }));
1257   build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1258         strideValues, attrs);
1259 }
1260 
1261 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1262                               MemRefType resultType, Value source, Value offset,
1263                               ValueRange sizes, ValueRange strides,
1264                               ArrayRef<NamedAttribute> attrs) {
1265   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1266       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1267   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1268       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1269   build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1270 }
1271 
1272 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1273 // completed automatically, like we have for subview and subtensor.
1274 static LogicalResult verify(ReinterpretCastOp op) {
1275   // The source and result memrefs should be in the same memory space.
1276   auto srcType = op.source().getType().cast<BaseMemRefType>();
1277   auto resultType = op.getType().cast<MemRefType>();
1278   if (srcType.getMemorySpace() != resultType.getMemorySpace())
1279     return op.emitError("different memory spaces specified for source type ")
1280            << srcType << " and result memref type " << resultType;
1281   if (srcType.getElementType() != resultType.getElementType())
1282     return op.emitError("different element types specified for source type ")
1283            << srcType << " and result memref type " << resultType;
1284 
1285   // Match sizes in result memref type and in static_sizes attribute.
1286   for (auto &en :
1287        llvm::enumerate(llvm::zip(resultType.getShape(),
1288                                  extractFromI64ArrayAttr(op.static_sizes())))) {
1289     int64_t resultSize = std::get<0>(en.value());
1290     int64_t expectedSize = std::get<1>(en.value());
1291     if (resultSize != expectedSize)
1292       return op.emitError("expected result type with size = ")
1293              << expectedSize << " instead of " << resultSize
1294              << " in dim = " << en.index();
1295   }
1296 
1297   // Match offset and strides in static_offset and static_strides attributes if
1298   // result memref type has an affine map specified.
1299   if (!resultType.getAffineMaps().empty()) {
1300     int64_t resultOffset;
1301     SmallVector<int64_t, 4> resultStrides;
1302     if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1303       return failure();
1304 
1305     // Match offset in result memref type and in static_offsets attribute.
1306     int64_t expectedOffset =
1307         extractFromI64ArrayAttr(op.static_offsets()).front();
    if (resultOffset != expectedOffset)
      return op.emitError("expected result type with offset = ")
             << expectedOffset << " instead of " << resultOffset;
1311 
1312     // Match strides in result memref type and in static_strides attribute.
1313     for (auto &en : llvm::enumerate(llvm::zip(
1314              resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
1315       int64_t resultStride = std::get<0>(en.value());
1316       int64_t expectedStride = std::get<1>(en.value());
1317       if (resultStride != expectedStride)
1318         return op.emitError("expected result type with stride = ")
1319                << expectedStride << " instead of " << resultStride
1320                << " in dim = " << en.index();
1321     }
1322   }
1323   return success();
1324 }
1325 
1326 //===----------------------------------------------------------------------===//
1327 // ReshapeOp
1328 //===----------------------------------------------------------------------===//
1329 
1330 static LogicalResult verify(ReshapeOp op) {
1331   Type operandType = op.source().getType();
1332   Type resultType = op.result().getType();
1333 
1334   Type operandElementType = operandType.cast<ShapedType>().getElementType();
1335   Type resultElementType = resultType.cast<ShapedType>().getElementType();
1336   if (operandElementType != resultElementType)
1337     return op.emitOpError("element types of source and destination memref "
1338                           "types should be the same");
1339 
1340   if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
1341     if (!operandMemRefType.getAffineMaps().empty())
1342       return op.emitOpError(
1343           "source memref type should have identity affine map");
1344 
1345   int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
1346   auto resultMemRefType = resultType.dyn_cast<MemRefType>();
1347   if (resultMemRefType) {
1348     if (!resultMemRefType.getAffineMaps().empty())
1349       return op.emitOpError(
1350           "result memref type should have identity affine map");
1351     if (shapeSize == ShapedType::kDynamicSize)
1352       return op.emitOpError("cannot use shape operand with dynamic length to "
1353                             "reshape to statically-ranked memref type");
1354     if (shapeSize != resultMemRefType.getRank())
1355       return op.emitOpError(
1356           "length of shape operand differs from the result's memref rank");
1357   }
1358   return success();
1359 }
1360 
1361 //===----------------------------------------------------------------------===//
1362 // StoreOp
1363 //===----------------------------------------------------------------------===//
1364 
1365 static LogicalResult verify(StoreOp op) {
1366   if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
1367     return op.emitOpError("store index operand count not equal to memref rank");
1368 
1369   return success();
1370 }
1371 
1372 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
1373                             SmallVectorImpl<OpFoldResult> &results) {
1374   /// store(memrefcast) -> store
1375   return foldMemRefCast(*this);
1376 }
1377 
1378 //===----------------------------------------------------------------------===//
1379 // SubViewOp
1380 //===----------------------------------------------------------------------===//
1381 
1382 namespace {
1383 /// Helpers to write more idiomatic operations.
1384 namespace saturated_arith {
1385 struct Wrapper {
1386   explicit Wrapper(int64_t v) : v(v) {}
1387   operator int64_t() { return v; }
1388   int64_t v;
1389 };
1390 Wrapper operator+(Wrapper a, int64_t b) {
1391   if (ShapedType::isDynamicStrideOrOffset(a) ||
1392       ShapedType::isDynamicStrideOrOffset(b))
1393     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1394   return Wrapper(a.v + b);
1395 }
1396 Wrapper operator*(Wrapper a, int64_t b) {
1397   if (ShapedType::isDynamicStrideOrOffset(a) ||
1398       ShapedType::isDynamicStrideOrOffset(b))
1399     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1400   return Wrapper(a.v * b);
1401 }
1402 } // end namespace saturated_arith
1403 } // end namespace
1404 
1405 /// A subview result type can be fully inferred from the source type and the
1406 /// static representation of offsets, sizes and strides. Special sentinels
1407 /// encode the dynamic case.
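///
/// For example (illustrative), a subview of a memref<8x16x4xf32> source with
/// static offsets [0, 2, 0], sizes [4, 4, 4] and strides [1, 1, 1] infers the
/// result type memref<4x4x4xf32, offset:8, strides:[64, 4, 1]>: the source
/// strides are [64, 4, 1], so the new offset is 0 + 2 * 4 = 8 and the strides
/// are unchanged.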
1408 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1409                                 ArrayRef<int64_t> leadingStaticOffsets,
1410                                 ArrayRef<int64_t> leadingStaticSizes,
1411                                 ArrayRef<int64_t> leadingStaticStrides) {
1412   // A subview may specify only a leading subset of offset/sizes/strides in
1413   // which case we complete with offset=0, sizes from memref type and strides=1.
1414   unsigned rank = sourceMemRefType.getRank();
1415   assert(leadingStaticOffsets.size() <= rank &&
1416          "unexpected leadingStaticOffsets overflow");
1417   assert(leadingStaticSizes.size() <= rank &&
1418          "unexpected leadingStaticSizes overflow");
1419   assert(leadingStaticStrides.size() <= rank &&
1420          "unexpected leadingStaticStrides overflow");
1421   auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
1422   auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
1423   auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
1424   unsigned numTrailingOffsets = rank - staticOffsets.size();
1425   unsigned numTrailingSizes = rank - staticSizes.size();
1426   unsigned numTrailingStrides = rank - staticStrides.size();
1427   staticOffsets.append(numTrailingOffsets, 0);
1428   llvm::append_range(staticSizes,
1429                      sourceMemRefType.getShape().take_back(numTrailingSizes));
1430   staticStrides.append(numTrailingStrides, 1);
1431 
1432   // Extract source offset and strides.
1433   int64_t sourceOffset;
1434   SmallVector<int64_t, 4> sourceStrides;
1435   auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
1436   assert(succeeded(res) && "SubViewOp expected strided memref type");
1437   (void)res;
1438 
1439   // Compute target offset whose value is:
1440   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
1441   int64_t targetOffset = sourceOffset;
1442   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    using namespace saturated_arith;
    targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * sourceStride;
1446   }
1447 
  // Compute the target strides whose values are:
  //   `sourceStrides_i * staticStrides_i`.
1450   SmallVector<int64_t, 4> targetStrides;
1451   targetStrides.reserve(staticOffsets.size());
1452   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
1453     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
1454     using namespace saturated_arith;
1455     targetStrides.push_back(Wrapper(sourceStride) * staticStride);
1456   }
1457 
1458   // The type is now known.
1459   return MemRefType::get(
1460       staticSizes, sourceMemRefType.getElementType(),
1461       makeStridedLinearLayoutMap(targetStrides, targetOffset,
1462                                  sourceMemRefType.getContext()),
1463       sourceMemRefType.getMemorySpace());
1464 }
1465 
1466 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1467                                 ArrayRef<OpFoldResult> leadingStaticOffsets,
1468                                 ArrayRef<OpFoldResult> leadingStaticSizes,
1469                                 ArrayRef<OpFoldResult> leadingStaticStrides) {
1470   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1471   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1472   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1473                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1474   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1475                              ShapedType::kDynamicSize);
1476   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1477                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1478   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1479                                     staticSizes, staticStrides)
1480       .cast<MemRefType>();
1481 }
1482 
Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceMemRefType,
    ArrayRef<int64_t> leadingStaticOffsets,
    ArrayRef<int64_t> leadingStaticSizes,
    ArrayRef<int64_t> leadingStaticStrides) {
  auto inferredType =
      inferResultType(sourceMemRefType, leadingStaticOffsets,
                      leadingStaticSizes, leadingStaticStrides)
          .cast<MemRefType>();
  assert(inferredType.getRank() >= resultRank &&
         "expected the result rank to be no larger than the inferred rank");
1493   int rankDiff = inferredType.getRank() - resultRank;
1494   if (rankDiff > 0) {
1495     auto shape = inferredType.getShape();
1496     llvm::SmallDenseSet<unsigned> dimsToProject;
1497     mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
1498     SmallVector<int64_t> projectedShape;
1499     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1500       if (!dimsToProject.contains(pos))
1501         projectedShape.push_back(shape[pos]);
1502 
1503     AffineMap map;
1504     auto maps = inferredType.getAffineMaps();
1505     if (!maps.empty() && maps.front())
1506       map = getProjectedMap(maps.front(), dimsToProject);
1507     inferredType =
1508         MemRefType::get(projectedShape, inferredType.getElementType(), map,
1509                         inferredType.getMemorySpace());
1510   }
1511   return inferredType;
1512 }
1513 
Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceMemRefType,
    ArrayRef<OpFoldResult> leadingStaticOffsets,
    ArrayRef<OpFoldResult> leadingStaticSizes,
    ArrayRef<OpFoldResult> leadingStaticStrides) {
1519   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1520   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1521   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1522                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1523   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1524                              ShapedType::kDynamicSize);
1525   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1526                              staticStrides, ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferRankReducedResultType(
      resultRank, sourceMemRefType, staticOffsets, staticSizes, staticStrides);
1530 }

// Build a SubViewOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
1533 void SubViewOp::build(OpBuilder &b, OperationState &result,
1534                       MemRefType resultType, Value source,
1535                       ArrayRef<OpFoldResult> offsets,
1536                       ArrayRef<OpFoldResult> sizes,
1537                       ArrayRef<OpFoldResult> strides,
1538                       ArrayRef<NamedAttribute> attrs) {
1539   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1540   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1541   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1542                              ShapedType::kDynamicStrideOrOffset);
1543   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1544                              ShapedType::kDynamicSize);
1545   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1546                              ShapedType::kDynamicStrideOrOffset);
1547   auto sourceMemRefType = source.getType().cast<MemRefType>();
1548   // Structuring implementation this way avoids duplication between builders.
1549   if (!resultType) {
1550     resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1551                                             staticSizes, staticStrides)
1552                      .cast<MemRefType>();
1553   }
1554   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1555         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1556         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1557   result.addAttributes(attrs);
1558 }
1559 
1560 // Build a SubViewOp with mixed static and dynamic entries and inferred result
1561 // type.
1562 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1563                       ArrayRef<OpFoldResult> offsets,
1564                       ArrayRef<OpFoldResult> sizes,
1565                       ArrayRef<OpFoldResult> strides,
1566                       ArrayRef<NamedAttribute> attrs) {
1567   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1568 }
1569 
1570 // Build a SubViewOp with static entries and inferred result type.
1571 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1572                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1573                       ArrayRef<int64_t> strides,
1574                       ArrayRef<NamedAttribute> attrs) {
1575   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1576       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1577         return b.getI64IntegerAttr(v);
1578       }));
1579   SmallVector<OpFoldResult> sizeValues =
1580       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1581         return b.getI64IntegerAttr(v);
1582       }));
1583   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1584       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1585         return b.getI64IntegerAttr(v);
1586       }));
1587   build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
1588 }
1589 
// Build a SubViewOp with static entries and custom result type. If the
// type passed is nullptr, it is inferred.
1592 void SubViewOp::build(OpBuilder &b, OperationState &result,
1593                       MemRefType resultType, Value source,
1594                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1595                       ArrayRef<int64_t> strides,
1596                       ArrayRef<NamedAttribute> attrs) {
1597   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1598       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1599         return b.getI64IntegerAttr(v);
1600       }));
1601   SmallVector<OpFoldResult> sizeValues =
1602       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1603         return b.getI64IntegerAttr(v);
1604       }));
1605   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1606       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1607         return b.getI64IntegerAttr(v);
1608       }));
1609   build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
1610         attrs);
1611 }
1612 
1613 // Build a SubViewOp with dynamic entries and custom result type. If the type
1614 // passed is nullptr, it is inferred.
1615 void SubViewOp::build(OpBuilder &b, OperationState &result,
1616                       MemRefType resultType, Value source, ValueRange offsets,
1617                       ValueRange sizes, ValueRange strides,
1618                       ArrayRef<NamedAttribute> attrs) {
1619   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1620       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1621   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1622       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1623   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1624       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
1626 }
1627 
1628 // Build a SubViewOp with dynamic entries and inferred result type.
1629 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1630                       ValueRange offsets, ValueRange sizes, ValueRange strides,
1631                       ArrayRef<NamedAttribute> attrs) {
1632   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1633 }
1634 
1635 /// For ViewLikeOpInterface.
1636 Value SubViewOp::getViewSource() { return source(); }
1637 
1638 enum SubViewVerificationResult {
1639   Success,
1640   RankTooLarge,
1641   SizeMismatch,
1642   ElemTypeMismatch,
1643   MemSpaceMismatch,
1644   AffineMapMismatch
1645 };
1646 
/// Checks if `originalType` can be rank-reduced to `candidateReducedType`.
/// This function is a slight variant of the `is subsequence` algorithm, where
/// dimensions that do not match must be 1.
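///
/// For example (illustrative), memref<1x3x4xf32> can be rank-reduced to
/// memref<3x4xf32> by dropping the unit dimension, whereas reducing it to
/// memref<4x3xf32> fails because the non-unit sizes do not form a
/// subsequence.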
1650 static SubViewVerificationResult
1651 isRankReducedType(Type originalType, Type candidateReducedType,
1652                   std::string *errMsg = nullptr) {
1653   if (originalType == candidateReducedType)
1654     return SubViewVerificationResult::Success;
1655   if (!originalType.isa<MemRefType>())
1656     return SubViewVerificationResult::Success;
1657   if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
1658     return SubViewVerificationResult::Success;
1659 
1660   ShapedType originalShapedType = originalType.cast<ShapedType>();
1661   ShapedType candidateReducedShapedType =
1662       candidateReducedType.cast<ShapedType>();
1663 
1664   // Rank and size logic is valid for all ShapedTypes.
1665   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
1666   ArrayRef<int64_t> candidateReducedShape =
1667       candidateReducedShapedType.getShape();
1668   unsigned originalRank = originalShape.size(),
1669            candidateReducedRank = candidateReducedShape.size();
1670   if (candidateReducedRank > originalRank)
1671     return SubViewVerificationResult::RankTooLarge;
1672 
1673   auto optionalUnusedDimsMask =
1674       computeRankReductionMask(originalShape, candidateReducedShape);
1675 
  // If no mask could be computed, the sizes cannot be matched.
1677   if (!optionalUnusedDimsMask.hasValue())
1678     return SubViewVerificationResult::SizeMismatch;
1679 
1680   if (originalShapedType.getElementType() !=
1681       candidateReducedShapedType.getElementType())
1682     return SubViewVerificationResult::ElemTypeMismatch;
1683 
1684   // Strided layout logic is relevant for MemRefType only.
1685   MemRefType original = originalType.cast<MemRefType>();
1686   MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
1687   if (original.getMemorySpace() != candidateReduced.getMemorySpace())
1688     return SubViewVerificationResult::MemSpaceMismatch;
1689 
1690   llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
1691   auto inferredType =
1692       getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
1693   AffineMap candidateLayout;
1694   if (candidateReduced.getAffineMaps().empty())
1695     candidateLayout = getStridedLinearLayoutMap(candidateReduced);
1696   else
1697     candidateLayout = candidateReduced.getAffineMaps().front();
1698   assert(inferredType.getNumResults() == 1 &&
1699          candidateLayout.getNumResults() == 1);
1700   if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
1701       inferredType.getNumDims() != candidateLayout.getNumDims()) {
1702     if (errMsg) {
1703       llvm::raw_string_ostream os(*errMsg);
1704       os << "inferred type: " << inferredType;
1705     }
1706     return SubViewVerificationResult::AffineMapMismatch;
1707   }
1708   // Check that the difference of the affine maps simplifies to 0.
1709   AffineExpr diffExpr =
1710       inferredType.getResult(0) - candidateLayout.getResult(0);
1711   diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
1712                                 inferredType.getNumSymbols());
1713   auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
1714   if (!(cst && cst.getValue() == 0)) {
1715     if (errMsg) {
1716       llvm::raw_string_ostream os(*errMsg);
1717       os << "inferred type: " << inferredType;
1718     }
1719     return SubViewVerificationResult::AffineMapMismatch;
1720   }
1721   return SubViewVerificationResult::Success;
1722 }
1723 
1724 template <typename OpTy>
1725 static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
1726                                             OpTy op, Type expectedType,
1727                                             StringRef errMsg = "") {
1728   auto memrefType = expectedType.cast<ShapedType>();
1729   switch (result) {
1730   case SubViewVerificationResult::Success:
1731     return success();
1732   case SubViewVerificationResult::RankTooLarge:
1733     return op.emitError("expected result rank to be smaller or equal to ")
1734            << "the source rank. " << errMsg;
1735   case SubViewVerificationResult::SizeMismatch:
1736     return op.emitError("expected result type to be ")
1737            << expectedType
1738            << " or a rank-reduced version. (mismatch of result sizes) "
1739            << errMsg;
1740   case SubViewVerificationResult::ElemTypeMismatch:
1741     return op.emitError("expected result element type to be ")
1742            << memrefType.getElementType() << errMsg;
1743   case SubViewVerificationResult::MemSpaceMismatch:
1744     return op.emitError("expected result and source memory spaces to match.")
1745            << errMsg;
1746   case SubViewVerificationResult::AffineMapMismatch:
1747     return op.emitError("expected result type to be ")
1748            << expectedType
1749            << " or a rank-reduced version. (mismatch of result affine map) "
1750            << errMsg;
1751   }
1752   llvm_unreachable("unexpected subview verification result");
1753 }
1754 
1755 /// Verifier for SubViewOp.
1756 static LogicalResult verify(SubViewOp op) {
1757   MemRefType baseType = op.getSourceType();
1758   MemRefType subViewType = op.getType();
1759 
1760   // The base memref and the view memref should be in the same memory space.
1761   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
1762     return op.emitError("different memory spaces specified for base memref "
1763                         "type ")
1764            << baseType << " and subview memref type " << subViewType;
1765 
1766   // Verify that the base memref type has a strided layout map.
1767   if (!isStrided(baseType))
1768     return op.emitError("base type ") << baseType << " is not strided";
1769 
1770   // Verify result type against inferred type.
1771   auto expectedType = SubViewOp::inferResultType(
1772       baseType, extractFromI64ArrayAttr(op.static_offsets()),
1773       extractFromI64ArrayAttr(op.static_sizes()),
1774       extractFromI64ArrayAttr(op.static_strides()));
1775 
1776   std::string errMsg;
1777   auto result = isRankReducedType(expectedType, subViewType, &errMsg);
1778   return produceSubViewErrorMsg(result, op, expectedType, errMsg);
1779 }
1780 
1781 raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
1782   return os << "range " << range.offset << ":" << range.size << ":"
1783             << range.stride;
1784 }
1785 
1786 /// Return the list of Range (i.e. offset, size, stride). Each Range
1787 /// entry contains either the dynamic value or a ConstantIndexOp constructed
1788 /// with `b` at location `loc`.
1789 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
1790                                               OpBuilder &b, Location loc) {
1791   std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
1792   assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
1793   assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
1794   SmallVector<Range, 8> res;
1795   unsigned rank = ranks[0];
1796   res.reserve(rank);
1797   for (unsigned idx = 0; idx < rank; ++idx) {
1798     Value offset =
1799         op.isDynamicOffset(idx)
1800             ? op.getDynamicOffset(idx)
1801             : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
1802     Value size = op.isDynamicSize(idx)
1803                      ? op.getDynamicSize(idx)
1804                      : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
1805     Value stride =
1806         op.isDynamicStride(idx)
1807             ? op.getDynamicStride(idx)
1808             : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
1809     res.emplace_back(Range{offset, size, stride});
1810   }
1811   return res;
1812 }
1813 
1814 namespace {
1815 /// Pattern to rewrite a subview op with MemRefCast arguments.
1816 /// This essentially pushes memref.cast past its consuming subview when
1817 /// `canFoldIntoConsumerOp` is true.
1818 ///
1819 /// Example:
1820 /// ```
1821 ///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
1822 ///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
1823 ///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
1824 /// ```
1825 /// is rewritten into:
1826 /// ```
///   %0 = memref.subview %V[0, 0][3, 4][1, 1] :
///     memref<16x16xf32> to memref<3x4xf32, offset:0, strides:[16, 1]>
1828 ///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
1829 ///     memref<3x4xf32, offset:?, strides:[?, 1]>
1830 /// ```
1831 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
1832 public:
1833   using OpRewritePattern<SubViewOp>::OpRewritePattern;
1834 
1835   LogicalResult matchAndRewrite(SubViewOp subViewOp,
1836                                 PatternRewriter &rewriter) const override {
    // If any operand is a constant, return so that the constant-argument
    // folder pattern gets a chance to kick in first.
1838     if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
1839           return matchPattern(operand, matchConstantIndex());
1840         }))
1841       return failure();
1842 
1843     auto castOp = subViewOp.source().getDefiningOp<CastOp>();
1844     if (!castOp)
1845       return failure();
1846 
1847     if (!CastOp::canFoldIntoConsumerOp(castOp))
1848       return failure();
1849 
    /// Deduce the resultType of the SubViewOp using
    /// `inferRankReducedResultType` on the cast source operand type and the
    /// SubViewOp static information. This is the resulting type if the
    /// MemRefCastOp were folded.
1853     auto resultType = SubViewOp::inferRankReducedResultType(
1854         subViewOp.getType().getRank(),
1855         castOp.source().getType().cast<MemRefType>(),
1856         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
1857         subViewOp.getMixedStrides());
1858     Value newSubView = rewriter.create<SubViewOp>(
1859         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
1860         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
1861         subViewOp.static_sizes(), subViewOp.static_strides());
1862     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
1863                                         newSubView);
1864     return success();
1865   }
1866 };
1867 } // namespace
1868 
1869 /// A canonicalizer wrapper to replace SubViewOps.
1870 struct SubViewCanonicalizer {
1871   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
1872     rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
1873   }
1874 };
1875 
1876 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
1877                                             MLIRContext *context) {
1878   results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
1879                   SubViewOp, SubViewCanonicalizer>,
1880               SubViewOpMemRefCastFolder>(context);
1881 }
1882 
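/// Fold a no-op subview: when the fully static result type is identical to the
/// source type, the subview spans the entire source and can be replaced by it.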
1883 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
1884   auto resultShapedType = getResult().getType().cast<ShapedType>();
1885   auto sourceShapedType = source().getType().cast<ShapedType>();
1886 
1887   if (resultShapedType.hasStaticShape() &&
1888       resultShapedType == sourceShapedType) {
1889     return getViewSource();
1890   }
1891 
1892   return {};
1893 }
1894 
1895 //===----------------------------------------------------------------------===//
1896 // TensorLoadOp
1897 //===----------------------------------------------------------------------===//
1898 
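/// Fold tensor_load(buffer_cast(x)) -> x when the buffer_cast immediately
/// precedes the tensor_load in the same block, e.g. (illustrative; SSA names
/// are placeholders):
///   %m = memref.buffer_cast %t : memref<?xf32>
///   %r = memref.tensor_load %m : memref<?xf32>  // folds to %t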
1899 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
1900   if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
1903     if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
1904         bufferCast->getNextNode() == this->getOperation())
1905       return bufferCast.tensor();
1906   return {};
1907 }
1908 
1909 //===----------------------------------------------------------------------===//
1910 // TransposeOp
1911 //===----------------------------------------------------------------------===//
1912 
/// Build a strided memref type by applying `permutationMap` to `memRefType`.
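/// For example (illustrative), permuting memref<3x4xf32> with the map
/// (d0, d1) -> (d1, d0) yields a memref<4x3xf32> whose layout map is
/// equivalent to (d0, d1) -> (d0 + d1 * 4).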
1914 static MemRefType inferTransposeResultType(MemRefType memRefType,
1915                                            AffineMap permutationMap) {
1916   auto rank = memRefType.getRank();
1917   auto originalSizes = memRefType.getShape();
1918   // Compute permuted sizes.
1919   SmallVector<int64_t, 4> sizes(rank, 0);
1920   for (auto en : llvm::enumerate(permutationMap.getResults()))
1921     sizes[en.index()] =
1922         originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
1923 
1924   // Compute permuted strides.
1925   int64_t offset;
1926   SmallVector<int64_t, 4> strides;
1927   auto res = getStridesAndOffset(memRefType, strides, offset);
1928   assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
1929   (void)res;
1930   auto map =
1931       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
1932   map = permutationMap ? map.compose(permutationMap) : map;
1933   return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
1934 }
1935 
1936 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
1937                         AffineMapAttr permutation,
1938                         ArrayRef<NamedAttribute> attrs) {
1939   auto permutationMap = permutation.getValue();
1940   assert(permutationMap);
1941 
1942   auto memRefType = in.getType().cast<MemRefType>();
1943   // Compute result type.
1944   MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
1945 
1946   build(b, result, resultType, in, attrs);
1947   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
1948 }
1949 
1950 // transpose $in $permutation attr-dict : type($in) `to` type(results)
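//
// For example (illustrative; SSA names are placeholders):
//   memref.transpose %0 (i, j) -> (j, i) : memref<?x?xf32> to
//     memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>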
1951 static void print(OpAsmPrinter &p, TransposeOp op) {
1952   p << "memref.transpose " << op.in() << " " << op.permutation();
1953   p.printOptionalAttrDict(op->getAttrs(),
1954                           {TransposeOp::getPermutationAttrName()});
1955   p << " : " << op.in().getType() << " to " << op.getType();
1956 }
1957 
1958 static ParseResult parseTransposeOp(OpAsmParser &parser,
1959                                     OperationState &result) {
1960   OpAsmParser::OperandType in;
1961   AffineMap permutation;
1962   MemRefType srcType, dstType;
1963   if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
1964       parser.parseOptionalAttrDict(result.attributes) ||
1965       parser.parseColonType(srcType) ||
1966       parser.resolveOperand(in, srcType, result.operands) ||
1967       parser.parseKeywordType("to", dstType) ||
1968       parser.addTypeToList(dstType, result.types))
1969     return failure();
1970 
1971   result.addAttribute(TransposeOp::getPermutationAttrName(),
1972                       AffineMapAttr::get(permutation));
1973   return success();
1974 }
1975 
1976 static LogicalResult verify(TransposeOp op) {
1977   if (!op.permutation().isPermutation())
1978     return op.emitOpError("expected a permutation map");
1979   if (op.permutation().getNumDims() != op.getShapedType().getRank())
1980     return op.emitOpError(
1981         "expected a permutation map of same rank as the input");
1982 
1983   auto srcType = op.in().getType().cast<MemRefType>();
1984   auto dstType = op.getType().cast<MemRefType>();
1985   auto transposedType = inferTransposeResultType(srcType, op.permutation());
1986   if (dstType != transposedType)
1987     return op.emitOpError("output type ")
1988            << dstType << " does not match transposed input type " << srcType
1989            << ", " << transposedType;
1990   return success();
1991 }
1992 
1993 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
1994   if (succeeded(foldMemRefCast(*this)))
1995     return getResult();
1996   return {};
1997 }
1998 
1999 //===----------------------------------------------------------------------===//
2000 // ViewOp
2001 //===----------------------------------------------------------------------===//
2002 
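// Illustrative example of the custom assembly form handled below (SSA names
// are placeholders):
//   %0 = memref.view %src[%offset][%sz0, %sz1]
//     : memref<2048xi8> to memref<?x?xf32>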
2003 static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
2004   OpAsmParser::OperandType srcInfo;
2005   SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
2006   SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
2007   auto indexType = parser.getBuilder().getIndexType();
2008   Type srcType, dstType;
2009   llvm::SMLoc offsetLoc;
2010   if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
2011       parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
2012     return failure();
2013 
2014   if (offsetInfo.size() != 1)
2015     return parser.emitError(offsetLoc) << "expects 1 offset operand";
2016 
2017   return failure(
2018       parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
2019       parser.parseOptionalAttrDict(result.attributes) ||
2020       parser.parseColonType(srcType) ||
2021       parser.resolveOperand(srcInfo, srcType, result.operands) ||
2022       parser.resolveOperands(offsetInfo, indexType, result.operands) ||
2023       parser.resolveOperands(sizesInfo, indexType, result.operands) ||
2024       parser.parseKeywordType("to", dstType) ||
2025       parser.addTypeToList(dstType, result.types));
2026 }
2027 
2028 static void print(OpAsmPrinter &p, ViewOp op) {
2029   p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
2030   p.printOperand(op.byte_shift());
2031   p << "][" << op.sizes() << ']';
2032   p.printOptionalAttrDict(op->getAttrs());
2033   p << " : " << op.getOperand(0).getType() << " to " << op.getType();
2034 }
2035 
2036 static LogicalResult verify(ViewOp op) {
2037   auto baseType = op.getOperand(0).getType().cast<MemRefType>();
2038   auto viewType = op.getType();
2039 
2040   // The base memref should have identity layout map (or none).
2041   if (baseType.getAffineMaps().size() > 1 ||
2042       (baseType.getAffineMaps().size() == 1 &&
2043        !baseType.getAffineMaps()[0].isIdentity()))
2044     return op.emitError("unsupported map for base memref type ") << baseType;
2045 
2046   // The result memref should have identity layout map (or none).
2047   if (viewType.getAffineMaps().size() > 1 ||
2048       (viewType.getAffineMaps().size() == 1 &&
2049        !viewType.getAffineMaps()[0].isIdentity()))
2050     return op.emitError("unsupported map for result memref type ") << viewType;
2051 
2052   // The base memref and the view memref should be in the same memory space.
2053   if (baseType.getMemorySpace() != viewType.getMemorySpace())
2054     return op.emitError("different memory spaces specified for base memref "
2055                         "type ")
2056            << baseType << " and view memref type " << viewType;
2057 
2058   // Verify that we have the correct number of sizes for the result type.
2059   unsigned numDynamicDims = viewType.getNumDynamicDims();
2060   if (op.sizes().size() != numDynamicDims)
2061     return op.emitError("incorrect number of size operands for type ")
2062            << viewType;
2063 
2064   return success();
2065 }
2066 
2067 Value ViewOp::getViewSource() { return source(); }
2068 
2069 namespace {
2070 
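/// Fold constant size operands of a ViewOp into its result type and cast back
/// to the original type.
///
/// Example (illustrative; SSA names are placeholders):
/// ```
///   %c64 = constant 64 : index
///   %0 = memref.view %src[%off][%c64] : memref<2048xi8> to memref<?xf32>
/// ```
/// is rewritten into:
/// ```
///   %0 = memref.view %src[%off][] : memref<2048xi8> to memref<64xf32>
///   %1 = memref.cast %0 : memref<64xf32> to memref<?xf32>
/// ```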
2071 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
2072   using OpRewritePattern<ViewOp>::OpRewritePattern;
2073 
2074   LogicalResult matchAndRewrite(ViewOp viewOp,
2075                                 PatternRewriter &rewriter) const override {
2076     // Return if none of the operands are constants.
2077     if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
2078           return matchPattern(operand, matchConstantIndex());
2079         }))
2080       return failure();
2081 
2082     // Get result memref type.
2083     auto memrefType = viewOp.getType();
2084 
    // Get the offset from the old memref view type 'memrefType'.
2086     int64_t oldOffset;
2087     SmallVector<int64_t, 4> oldStrides;
2088     if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
2089       return failure();
2090     assert(oldOffset == 0 && "Expected 0 offset");
2091 
2092     SmallVector<Value, 4> newOperands;
2093 
2094     // Offset cannot be folded into result type.
2095 
2096     // Fold any dynamic dim operands which are produced by a constant.
2097     SmallVector<int64_t, 4> newShapeConstants;
2098     newShapeConstants.reserve(memrefType.getRank());
2099 
2100     unsigned dynamicDimPos = 0;
2101     unsigned rank = memrefType.getRank();
2102     for (unsigned dim = 0, e = rank; dim < e; ++dim) {
2103       int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
2105       if (!ShapedType::isDynamic(dimSize)) {
2106         newShapeConstants.push_back(dimSize);
2107         continue;
2108       }
2109       auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
2110       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
2111         // Dynamic shape dimension will be folded.
2112         newShapeConstants.push_back(constantIndexOp.getValue());
2113       } else {
2114         // Dynamic shape dimension not folded; copy operand from old memref.
2115         newShapeConstants.push_back(dimSize);
2116         newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
2117       }
2118       dynamicDimPos++;
2119     }
2120 
2121     // Create new memref type with constant folded dims.
2122     MemRefType newMemRefType =
2123         MemRefType::Builder(memrefType).setShape(newShapeConstants);
2124     // Nothing new, don't fold.
2125     if (newMemRefType == memrefType)
2126       return failure();
2127 
2128     // Create new ViewOp.
2129     auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
2130                                              viewOp.getOperand(0),
2131                                              viewOp.byte_shift(), newOperands);
2132     // Insert a cast so we have the same type as the old memref type.
2133     rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
2134     return success();
2135   }
2136 };
2137 
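/// Rewrite memref.view of a memref.cast of an AllocOp so the view consumes the
/// alloc result directly and the intermediate cast is dropped.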
2138 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
2139   using OpRewritePattern<ViewOp>::OpRewritePattern;
2140 
2141   LogicalResult matchAndRewrite(ViewOp viewOp,
2142                                 PatternRewriter &rewriter) const override {
2143     Value memrefOperand = viewOp.getOperand(0);
2144     CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
2145     if (!memrefCastOp)
2146       return failure();
2147     Value allocOperand = memrefCastOp.getOperand();
2148     AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
2149     if (!allocOp)
2150       return failure();
2151     rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
2152                                         viewOp.byte_shift(), viewOp.sizes());
2153     return success();
2154   }
2155 };
2156 
2157 } // end anonymous namespace
2158 
2159 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2160                                          MLIRContext *context) {
2161   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
2162 }
2163 
2164 //===----------------------------------------------------------------------===//
2165 // TableGen'd op method definitions
2166 //===----------------------------------------------------------------------===//
2167 
2168 #define GET_OP_CLASSES
2169 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
2170