1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
10 #include "mlir/Dialect/Tensor/IR/Tensor.h"
11 #include "mlir/Dialect/Utils/StaticValueUtils.h"
12 #include "mlir/IR/BlockAndValueMapping.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/TypeUtilities.h"
17 #include "llvm/ADT/STLExtras.h"
18 
19 using namespace mlir;
20 using namespace mlir::tensor;
21 
22 /// Materialize a single constant operation from a given attribute value with
23 /// the desired resultant type.
24 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
25                                               Attribute value, Type type,
26                                               Location loc) {
27   return builder.create<mlir::ConstantOp>(loc, type, value);
28 }
29 
30 //===----------------------------------------------------------------------===//
31 // CastOp
32 //===----------------------------------------------------------------------===//
33 
/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects
/// that may consume the results of tensor.cast operations. Such foldable
/// tensor.cast operations are typically inserted as `slice` ops and are
/// canonicalized to preserve the type compatibility of their uses.
///
/// Returns true when all of the following conditions are met:
/// 1. source and result are ranked tensors with the same element type and
///    rank.
/// 2. the source type has more static information than the result type.
44 ///
45 /// Example:
46 /// ```mlir
47 ///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
48 ///   %2 = consumer %1 ... : tensor<?x?xf32> ...
49 /// ```
50 ///
51 /// folds into:
52 ///
53 /// ```mlir
54 ///   %2 = consumer %0 ... : tensor<8x16xf32> ...
55 /// ```
56 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
57   if (!castOp)
58     return false;
59 
60   RankedTensorType sourceType =
61       castOp.source().getType().dyn_cast<RankedTensorType>();
62   RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
63 
64   // Requires RankedTensorType.
65   if (!sourceType || !resultType)
66     return false;
67 
68   // Requires same elemental type.
69   if (sourceType.getElementType() != resultType.getElementType())
70     return false;
71 
72   // Requires same rank.
73   if (sourceType.getRank() != resultType.getRank())
74     return false;
75 
  // If the cast is towards more static sizes along any dimension, don't fold.
77   for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
78     if (ShapedType::isDynamic(std::get<0>(t)) &&
79         !ShapedType::isDynamic(std::get<1>(t)))
80       return false;
81   }
82 
83   return true;
84 }
85 
/// Folds any operand of `op` that is produced by a tensor.cast which can be
/// folded into `op`. Returns success if at least one operand was folded.
88 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
89   bool folded = false;
90   for (OpOperand &operand : op->getOpOperands()) {
91     auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
92     if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
93       operand.set(castOp.getOperand());
94       folded = true;
95     }
96   }
97   return success(folded);
98 }
99 
100 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
101   if (inputs.size() != 1 || outputs.size() != 1)
102     return false;
103   Type a = inputs.front(), b = outputs.front();
104   auto aT = a.dyn_cast<TensorType>();
105   auto bT = b.dyn_cast<TensorType>();
106   if (!aT || !bT)
107     return false;
108 
109   if (aT.getElementType() != bT.getElementType())
110     return false;
111 
112   return succeeded(verifyCompatibleShape(aT, bT));
113 }
114 
/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match. Returns a null type
/// if the shapes cannot be joined (the ranks or static sizes conflict).
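///
/// For illustration (hypothetical shapes): joining tensor<?x8xf32> with
/// tensor<4x?xf32> yields tensor<4x8xf32>, while joining tensor<4x8xf32>
/// with tensor<5x8xf32> yields a null type because the static sizes
/// disagree.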
117 static TensorType joinShapes(TensorType one, TensorType two) {
118   assert(one.getElementType() == two.getElementType());
119 
120   if (!one.hasRank())
121     return two;
122   if (!two.hasRank())
123     return one;
124 
125   int64_t rank = one.getRank();
126   if (rank != two.getRank())
127     return {};
128 
129   SmallVector<int64_t, 4> join;
130   join.reserve(rank);
131   for (int64_t i = 0; i < rank; ++i) {
132     if (one.isDynamicDim(i)) {
133       join.push_back(two.getDimSize(i));
134       continue;
135     }
136     if (two.isDynamicDim(i)) {
137       join.push_back(one.getDimSize(i));
138       continue;
139     }
140     if (one.getDimSize(i) != two.getDimSize(i))
141       return {};
142     join.push_back(one.getDimSize(i));
143   }
144   return RankedTensorType::get(join, one.getElementType());
145 }
146 
147 namespace {
148 
149 /// Replaces chains of two tensor.cast operations by a single tensor.cast
150 /// operation if doing so does not remove runtime constraints.
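///
/// For illustration (hypothetical shapes):
///
/// ```mlir
///   %1 = tensor.cast %0 : tensor<4x?xf32> to tensor<?x?xf32>
///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<?x8xf32>
/// ```
///
/// can be rewritten into a single cast, because joining source, intermediate
/// and result shapes gives the same tensor<4x8xf32> as joining just source
/// and result:
///
/// ```mlir
///   %2 = tensor.cast %0 : tensor<4x?xf32> to tensor<?x8xf32>
/// ```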
151 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
152   using OpRewritePattern<CastOp>::OpRewritePattern;
153 
154   LogicalResult matchAndRewrite(CastOp tensorCast,
155                                 PatternRewriter &rewriter) const final {
156     auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
157 
158     if (!tensorCastOperand)
159       return failure();
160 
161     auto sourceType =
162         tensorCastOperand.getOperand().getType().cast<TensorType>();
163     auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
164     auto resultType = tensorCast.getType().cast<TensorType>();
165 
166     // We can remove the intermediate cast if joining all three produces the
167     // same result as just joining the source and result shapes.
168     auto firstJoin =
169         joinShapes(joinShapes(sourceType, intermediateType), resultType);
170 
171     // The join might not exist if the cast sequence would fail at runtime.
172     if (!firstJoin)
173       return failure();
174 
    // The newJoin always exists if the above join exists; it might just
    // contain less information. If so, we cannot drop the intermediate cast,
    // as doing so would remove runtime checks.
178     auto newJoin = joinShapes(sourceType, resultType);
179     if (firstJoin != newJoin)
180       return failure();
181 
182     rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
183                                         tensorCastOperand.getOperand());
184     return success();
185   }
186 };
187 
188 } // namespace
189 
190 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
191                                          MLIRContext *context) {
192   results.add<ChainedTensorCast>(context);
193 }
194 
195 //===----------------------------------------------------------------------===//
196 // DimOp
197 //===----------------------------------------------------------------------===//
198 
199 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
200                   int64_t index) {
201   auto loc = result.location;
202   Value indexValue = builder.create<ConstantIndexOp>(loc, index);
203   build(builder, result, source, indexValue);
204 }
205 
206 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
207                   Value index) {
208   auto indexTy = builder.getIndexType();
209   build(builder, result, indexTy, source, index);
210 }
211 
212 Optional<int64_t> DimOp::getConstantIndex() {
213   if (auto constantOp = index().getDefiningOp<ConstantOp>())
214     return constantOp.getValue().cast<IntegerAttr>().getInt();
215   return {};
216 }
217 
218 static LogicalResult verify(DimOp op) {
219   // Assume unknown index to be in range.
220   Optional<int64_t> index = op.getConstantIndex();
221   if (!index.hasValue())
222     return success();
223 
  // Check that the constant index is not statically known to be out of range.
225   auto type = op.source().getType();
226   if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
227     if (index.getValue() >= tensorType.getRank())
228       return op.emitOpError("index is out of range");
229   } else if (type.isa<UnrankedTensorType>()) {
230     // Assume index to be in range.
231   } else {
232     llvm_unreachable("expected operand with tensor type");
233   }
234   return success();
235 }
236 
237 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
238   // All forms of folding require a known index.
239   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
240   if (!index)
241     return {};
242 
243   // Folding for unranked types (UnrankedTensorType) is not supported.
244   auto tensorType = source().getType().dyn_cast<RankedTensorType>();
245   if (!tensorType)
246     return {};
247 
248   // Fold if the shape extent along the given index is known.
249   if (!tensorType.isDynamicDim(index.getInt())) {
250     Builder builder(getContext());
251     return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
252   }
253 
254   Operation *definingOp = source().getDefiningOp();
255 
256   // Fold dim to the operand of tensor.generate.
257   if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
258     auto resultType =
259         fromElements.getResult().getType().cast<RankedTensorType>();
260     // The case where the type encodes the size of the dimension is handled
261     // above.
262     assert(resultType.getShape()[index.getInt()] ==
263            RankedTensorType::kDynamicSize);
264 
265     // Find the operand of the fromElements that corresponds to this index.
266     auto dynExtents = fromElements.dynamicExtents().begin();
267     for (auto dim : resultType.getShape().take_front(index.getInt()))
268       if (dim == RankedTensorType::kDynamicSize)
269         dynExtents++;
270 
271     return Value{*dynExtents};
272   }
273 
274   // The size at the given index is now known to be a dynamic size.
275   unsigned unsignedIndex = index.getValue().getZExtValue();
276 
277   if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
278     assert(sliceOp.isDynamicSize(unsignedIndex) &&
279            "Expected dynamic slice size");
280     return sliceOp.getDynamicSize(unsignedIndex);
281   }
282 
283   // dim(cast) -> dim
284   if (succeeded(foldTensorCast(*this)))
285     return getResult();
286 
287   return {};
288 }
289 
290 namespace {
291 /// Fold dim of a cast into the dim of the source of the tensor cast.
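///
/// For example (hypothetical types):
///
/// ```mlir
///   %0 = tensor.cast %t : tensor<4x?xf32> to tensor<?x?xf32>
///   %d = tensor.dim %0, %idx : tensor<?x?xf32>
/// ```
///
/// becomes
///
/// ```mlir
///   %d = tensor.dim %t, %idx : tensor<4x?xf32>
/// ```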
292 struct DimOfCastOp : public OpRewritePattern<DimOp> {
293   using OpRewritePattern<DimOp>::OpRewritePattern;
294 
295   LogicalResult matchAndRewrite(DimOp dimOp,
296                                 PatternRewriter &rewriter) const override {
297     auto castOp = dimOp.source().getDefiningOp<CastOp>();
298     if (!castOp)
299       return failure();
300     Value newSource = castOp.getOperand();
301     rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
302     return success();
303   }
304 };
} // namespace
306 
307 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
308                                         MLIRContext *context) {
309   results.add<DimOfCastOp>(context);
310 }
311 
312 //===----------------------------------------------------------------------===//
313 // ExtractOp
314 //===----------------------------------------------------------------------===//
315 
316 static LogicalResult verify(ExtractOp op) {
317   // Verify the # indices match if we have a ranked type.
318   if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
319     if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
320       return op.emitOpError("incorrect number of indices for extract_element");
321 
322   return success();
323 }
324 
325 OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
326   // The tensor operand must be a known constant.
327   Attribute tensor = operands.front();
328   if (!tensor)
329     return {};
330   // If this is a splat elements attribute, simply return the value. All of the
331   // elements of a splat attribute are the same.
332   if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
333     return splatTensor.getSplatValue();
334 
335   // Otherwise, collect the constant indices into the tensor.
336   SmallVector<uint64_t, 8> indices;
337   for (Attribute indice : llvm::drop_begin(operands, 1)) {
338     if (!indice || !indice.isa<IntegerAttr>())
339       return {};
340     indices.push_back(indice.cast<IntegerAttr>().getInt());
341   }
342 
343   // If this is an elements attribute, query the value at the given indices.
344   auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
345   if (elementsAttr && elementsAttr.isValidIndex(indices))
346     return elementsAttr.getValue(indices);
347   return {};
348 }
349 
350 //===----------------------------------------------------------------------===//
351 // FromElementsOp
352 //===----------------------------------------------------------------------===//
353 
354 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
355                            Type elementType, ValueRange elements) {
356   Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
357                                         elementType);
358   result.addOperands(elements);
359   result.addTypes(resultTy);
360 }
361 
362 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
363                            ValueRange elements) {
364   assert(!elements.empty() && "expected at least one element");
365   build(builder, result, elements.front().getType(), elements);
366 }
367 
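/// Folds tensor.from_elements into a constant when every element operand is
/// itself a constant attribute; e.g. (hypothetically) two constant i32
/// operands fold to a dense<...> : tensor<2xi32> attribute.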
368 OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
369   if (!llvm::is_contained(operands, nullptr))
370     return DenseElementsAttr::get(getType(), operands);
371   return {};
372 }
373 
374 namespace {
375 
376 // Canonicalizes the pattern of the form
377 //
378 // %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
379 // %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
380 //
381 // to just %element.
382 struct ExtractElementFromTensorFromElements
383     : public OpRewritePattern<tensor::ExtractOp> {
384   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
385 
386   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
387                                 PatternRewriter &rewriter) const final {
388     if (extract.indices().size() != 1)
389       return failure();
390 
391     auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
392     if (tensorFromElements == nullptr)
393       return failure();
394 
395     APInt index;
396     if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
397       return failure();
398     // Prevent out of bounds accesses. This can happen in invalid code that will
399     // never execute.
400     if (tensorFromElements->getNumOperands() <= index.getZExtValue() ||
401         index.getSExtValue() < 0)
402       return failure();
403     rewriter.replaceOp(extract,
404                        tensorFromElements.getOperand(index.getZExtValue()));
405     return success();
406   }
407 };
408 
409 } // namespace
410 
411 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
412                                                  MLIRContext *context) {
413   results.add<ExtractElementFromTensorFromElements>(context);
414 }
415 
416 //===----------------------------------------------------------------------===//
417 // InsertOp
418 //===----------------------------------------------------------------------===//
419 
420 static LogicalResult verify(InsertOp op) {
421   // Verify the # indices match if we have a ranked type.
422   if (auto destType = op.dest().getType().dyn_cast<RankedTensorType>())
423     if (destType.getRank() != static_cast<int64_t>(op.indices().size()))
424       return op.emitOpError("incorrect number of indices");
425   return success();
426 }
427 
428 OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
429   Attribute scalar = operands[0];
430   Attribute dest = operands[1];
431   if (scalar && dest)
432     if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
433       if (scalar == splatDest.getSplatValue())
434         return dest;
435   return {};
436 }
437 
438 //===----------------------------------------------------------------------===//
439 // GenerateOp
440 //===----------------------------------------------------------------------===//
441 
442 static LogicalResult verify(GenerateOp op) {
443   // Ensure that the tensor type has as many dynamic dimensions as are specified
444   // by the operands.
445   RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
446   if (op.getNumOperands() != resultTy.getNumDynamicDims())
447     return op.emitError("must have as many index operands as dynamic extents "
448                         "in the result type");
449 
450   // Ensure that region arguments span the index space.
451   if (!llvm::all_of(op.body().getArgumentTypes(),
452                     [](Type ty) { return ty.isIndex(); }))
453     return op.emitError("all body arguments must be index");
454   if (op.body().getNumArguments() != resultTy.getRank())
455     return op.emitError("must have one body argument per input dimension");
456 
457   // Ensure that the region yields an element of the right type.
458   auto yieldOp =
459       llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
460   if (yieldOp.value().getType() != resultTy.getElementType())
461     return op.emitOpError(
462         "body must be terminated with a `yield` operation of the tensor "
463         "element type");
464 
465   return success();
466 }
467 
468 void GenerateOp::build(
469     OpBuilder &b, OperationState &result, Type resultTy,
470     ValueRange dynamicExtents,
471     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
472   build(b, result, resultTy, dynamicExtents);
473 
474   // Build and populate body.
475   OpBuilder::InsertionGuard guard(b);
476   Region *bodyRegion = result.regions.front().get();
477   auto rank = resultTy.cast<RankedTensorType>().getRank();
478   SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
479   Block *bodyBlock =
480       b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
481   bodyBuilder(b, result.location, bodyBlock->getArguments());
482 }
483 
484 namespace {
485 
/// Canonicalizes tensor.generate operations with constant extent operands
/// into an equivalent operation whose result type encodes those extents
/// statically. A tensor.cast back to the original result type is inserted so
/// that the resulting IR remains well-typed.
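///
/// For illustration (a hypothetical example):
///
/// ```mlir
///   %size = constant 8 : index
///   %0 = tensor.generate %size, %n {
///     ...
///   } : tensor<?x?xf32>
/// ```
///
/// becomes
///
/// ```mlir
///   %0 = tensor.generate %n {
///     ...
///   } : tensor<8x?xf32>
///   %1 = tensor.cast %0 : tensor<8x?xf32> to tensor<?x?xf32>
/// ```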
490 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
491   using OpRewritePattern<GenerateOp>::OpRewritePattern;
492 
493   LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
494                                 PatternRewriter &rewriter) const final {
495     auto resultType =
496         tensorFromElements.getResult().getType().cast<RankedTensorType>();
497 
498     if (resultType.hasStaticShape())
499       return failure();
500 
501     SmallVector<Value, 4> newOperands;
502     SmallVector<int64_t, 4> newShape;
503     auto operandsIt = tensorFromElements.dynamicExtents().begin();
504 
505     for (int64_t dim : resultType.getShape()) {
506       if (dim != RankedTensorType::kDynamicSize) {
507         newShape.push_back(dim);
508         continue;
509       }
510       APInt index;
511       if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
512         newShape.push_back(RankedTensorType::kDynamicSize);
513         newOperands.push_back(*operandsIt++);
514         continue;
515       }
516       newShape.push_back(index.getSExtValue());
517       operandsIt++;
518     }
519 
520     if (newOperands.size() == tensorFromElements.dynamicExtents().size())
521       return failure();
522 
523     auto loc = tensorFromElements.getLoc();
524     auto newOp = rewriter.create<GenerateOp>(
525         loc, RankedTensorType::get(newShape, resultType.getElementType()),
526         newOperands);
527     rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
528                                 newOp.body().begin());
529     rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
530                                                 newOp);
531     return success();
532   }
533 };
534 
535 /// Canonicalizes the pattern of the form
536 ///
537 /// %tensor = tensor.generate %x {
538 ///   ^bb0(%arg0: index):  // no predecessors
539 ///   <computation>
540 ///   yield %1 : index
541 /// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
543 ///
544 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
545 /// tensor.generate operation has no side-effects.
546 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
547   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
548 
549   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
550                                 PatternRewriter &rewriter) const final {
551     auto tensorFromElements = extract.tensor().getDefiningOp<GenerateOp>();
552     if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
553       return failure();
554 
555     BlockAndValueMapping mapping;
556     Block *body = tensorFromElements.getBody();
557     mapping.map(body->getArguments(), extract.indices());
558     for (auto &op : body->without_terminator())
559       rewriter.clone(op, mapping);
560 
561     auto yield = cast<YieldOp>(body->getTerminator());
562 
563     rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
564     return success();
565   }
566 };
567 
568 /// Canonicalizes the pattern of the form
569 ///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
571 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
572 ///
573 /// to
574 ///
575 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
576 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
577   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
578 
579   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
580                                 PatternRewriter &rewriter) const final {
581     auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
582     if (!tensorCast)
583       return failure();
584 
585     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
586                                                    extract.indices());
587     return success();
588   }
589 };
590 
591 } // namespace
592 
593 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
594                                              MLIRContext *context) {
595   // TODO: Move extract patterns to tensor::ExtractOp.
596   results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
597               StaticTensorGenerate>(context);
598 }
599 
600 //===----------------------------------------------------------------------===//
601 // ReshapeOp
602 //===----------------------------------------------------------------------===//
603 
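/// Returns the number of elements in `type`, i.e. the product of its static
/// dimension sizes. Callers must ensure the shape is fully static.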
static int64_t getNumElements(ShapedType type) {
605   int64_t numElements = 1;
606   for (auto dim : type.getShape())
607     numElements *= dim;
608   return numElements;
609 }
610 
611 static LogicalResult verify(ReshapeOp op) {
612   TensorType operandType = op.source().getType().cast<TensorType>();
613   TensorType resultType = op.result().getType().cast<TensorType>();
614 
615   if (operandType.getElementType() != resultType.getElementType())
616     return op.emitOpError("element types of source and destination tensor "
617                           "types should be the same");
618 
619   int64_t shapeSize =
620       op.shape().getType().cast<RankedTensorType>().getDimSize(0);
621   auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
622   auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
623 
624   if (resultRankedType) {
625     if (operandRankedType && resultRankedType.hasStaticShape() &&
626         operandRankedType.hasStaticShape()) {
      if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
628         return op.emitOpError("source and destination tensor should have the "
629                               "same number of elements");
630     }
631     if (shapeSize == TensorType::kDynamicSize)
632       return op.emitOpError("cannot use shape operand with dynamic length to "
633                             "reshape to statically-ranked tensor type");
634     if (shapeSize != resultRankedType.getRank())
635       return op.emitOpError(
636           "length of shape operand differs from the result's tensor rank");
637   }
638   return success();
639 }
640 
641 //===----------------------------------------------------------------------===//
642 // ExtractSliceOp
643 //===----------------------------------------------------------------------===//
644 
645 /// An extract_slice op result type can be fully inferred from the source type
646 /// and the static representation of offsets, sizes and strides. Special
647 /// sentinels encode the dynamic case.
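///
/// For example (a hypothetical case): with a tensor<8x16x4xf32> source and
/// leading static sizes [4, 4], the inferred result type is
/// tensor<4x4x4xf32>; a size equal to ShapedType::kDynamicSize instead
/// yields a '?' in the corresponding dimension.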
648 Type ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
649                                      ArrayRef<int64_t> leadingStaticOffsets,
650                                      ArrayRef<int64_t> leadingStaticSizes,
651                                      ArrayRef<int64_t> leadingStaticStrides) {
  // An extract_slice op may specify only a leading subset of offsets/sizes/
  // strides, in which case we complete with offset=0, sizes from the source
  // tensor type, and strides=1.
655   unsigned rank = sourceRankedTensorType.getRank();
656   assert(leadingStaticSizes.size() <= rank &&
657          "unexpected leadingStaticSizes overflow");
658   auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
659   unsigned numTrailingSizes = rank - staticSizes.size();
660   llvm::append_range(staticSizes, sourceRankedTensorType.getShape().take_back(
661                                       numTrailingSizes));
662   return RankedTensorType::get(staticSizes,
663                                sourceRankedTensorType.getElementType());
664 }
665 
666 Type ExtractSliceOp::inferResultType(
667     RankedTensorType sourceRankedTensorType,
668     ArrayRef<OpFoldResult> leadingStaticOffsets,
669     ArrayRef<OpFoldResult> leadingStaticSizes,
670     ArrayRef<OpFoldResult> leadingStaticStrides) {
671   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
672   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
673   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
674                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
675   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
676                              ShapedType::kDynamicSize);
677   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
678                              staticStrides, ShapedType::kDynamicStrideOrOffset);
679   return ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
680                                          staticSizes, staticStrides);
681 }
682 
/// An extract_slice op result type can be fully inferred from the source type
/// and the static representation of offsets, sizes and strides. Special
/// sentinels encode the dynamic case. This variant additionally drops
/// size-1 dimensions from the inferred type until its rank is `resultRank`.
686 Type ExtractSliceOp::inferRankReducedResultType(
687     unsigned resultRank, RankedTensorType sourceRankedTensorType,
688     ArrayRef<int64_t> leadingStaticOffsets,
689     ArrayRef<int64_t> leadingStaticSizes,
690     ArrayRef<int64_t> leadingStaticStrides) {
691   auto inferredType =
692       inferResultType(sourceRankedTensorType, leadingStaticOffsets,
693                       leadingStaticSizes, leadingStaticStrides)
694           .cast<RankedTensorType>();
695   int rankDiff = inferredType.getRank() - resultRank;
696   if (rankDiff > 0) {
697     auto shape = inferredType.getShape();
698     llvm::SmallDenseSet<unsigned> dimsToProject;
699     mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
700     SmallVector<int64_t> projectedShape;
701     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
702       if (!dimsToProject.contains(pos))
703         projectedShape.push_back(shape[pos]);
704     inferredType =
705         RankedTensorType::get(projectedShape, inferredType.getElementType());
706   }
707   return inferredType;
708 }
709 
710 Type ExtractSliceOp::inferRankReducedResultType(
711     unsigned resultRank, RankedTensorType sourceRankedTensorType,
712     ArrayRef<OpFoldResult> leadingStaticOffsets,
713     ArrayRef<OpFoldResult> leadingStaticSizes,
714     ArrayRef<OpFoldResult> leadingStaticStrides) {
715   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
716   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
717   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
718                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
719   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
720                              ShapedType::kDynamicSize);
721   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
722                              staticStrides, ShapedType::kDynamicStrideOrOffset);
723   return ExtractSliceOp::inferRankReducedResultType(
724       resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
725       staticStrides);
726 }
727 
728 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
729 /// result type. If the type passed is nullptr, it is inferred.
730 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
731                            RankedTensorType resultType, Value source,
732                            ArrayRef<OpFoldResult> offsets,
733                            ArrayRef<OpFoldResult> sizes,
734                            ArrayRef<OpFoldResult> strides,
735                            ArrayRef<NamedAttribute> attrs) {
736   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
737   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
738   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
739                              ShapedType::kDynamicStrideOrOffset);
740   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
741                              ShapedType::kDynamicSize);
742   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
743                              ShapedType::kDynamicStrideOrOffset);
744   auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
  // Structuring the implementation this way avoids duplication between
  // builders.
746   if (!resultType) {
747     resultType =
748         ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
749                                         staticSizes, staticStrides)
750             .cast<RankedTensorType>();
751   }
752   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
753         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
754         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
755   result.addAttributes(attrs);
756 }
757 
758 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
759 /// result type.
760 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
761                            ArrayRef<OpFoldResult> offsets,
762                            ArrayRef<OpFoldResult> sizes,
763                            ArrayRef<OpFoldResult> strides,
764                            ArrayRef<NamedAttribute> attrs) {
765   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
766 }
767 
768 /// Build an ExtractSliceOp with dynamic entries and custom result type. If the
769 /// type passed is nullptr, it is inferred.
770 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
771                            RankedTensorType resultType, Value source,
772                            ValueRange offsets, ValueRange sizes,
773                            ValueRange strides, ArrayRef<NamedAttribute> attrs) {
774   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
775       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
776   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
777       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
778   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
779       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
780   build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
781 }
782 
783 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
784 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
785                            ValueRange offsets, ValueRange sizes,
786                            ValueRange strides, ArrayRef<NamedAttribute> attrs) {
787   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
788 }
789 
790 enum SliceVerificationResult {
791   Success,
792   RankTooLarge,
793   SizeMismatch,
794   ElemTypeMismatch,
795 };
796 
/// Checks whether the `originalType` can be rank-reduced to the
/// `candidateReducedType`. This function is a slight variant of the
/// `is subsequence` algorithm, where every non-matching dimension must be 1.
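///
/// For example (hypothetical types), tensor<1x4x1x8xf32> can be rank-reduced
/// to tensor<4x8xf32> or tensor<4x1x8xf32>, but not to tensor<8x4xf32>
/// (the sizes would have to be reordered) or tensor<4xf32> (the size-8
/// dimension cannot be dropped).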
800 static SliceVerificationResult
801 isRankReducedType(Type originalType, Type candidateReducedType,
802                   std::string *errMsg = nullptr) {
803   if (originalType == candidateReducedType)
804     return SliceVerificationResult::Success;
805   if (!originalType.isa<RankedTensorType>())
806     return SliceVerificationResult::Success;
807   if (originalType.isa<RankedTensorType>() &&
808       !candidateReducedType.isa<RankedTensorType>())
809     return SliceVerificationResult::Success;
810 
811   ShapedType originalShapedType = originalType.cast<ShapedType>();
812   ShapedType candidateReducedShapedType =
813       candidateReducedType.cast<ShapedType>();
814 
815   // Rank and size logic is valid for all ShapedTypes.
816   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
817   ArrayRef<int64_t> candidateReducedShape =
818       candidateReducedShapedType.getShape();
819   unsigned originalRank = originalShape.size(),
820            candidateReducedRank = candidateReducedShape.size();
821   if (candidateReducedRank > originalRank)
822     return SliceVerificationResult::RankTooLarge;
823 
824   auto optionalUnusedDimsMask =
825       computeRankReductionMask(originalShape, candidateReducedShape);
826 
  // Sizes cannot be matched if no rank-reduction mask could be computed.
828   if (!optionalUnusedDimsMask.hasValue())
829     return SliceVerificationResult::SizeMismatch;
830 
831   if (originalShapedType.getElementType() !=
832       candidateReducedShapedType.getElementType())
833     return SliceVerificationResult::ElemTypeMismatch;
834 
  return SliceVerificationResult::Success;
840 }
841 
842 template <typename OpTy>
843 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
844                                           OpTy op, Type expectedType,
845                                           StringRef errMsg = "") {
  auto shapedType = expectedType.cast<ShapedType>();
847   switch (result) {
848   case SliceVerificationResult::Success:
849     return success();
850   case SliceVerificationResult::RankTooLarge:
851     return op.emitError("expected result rank to be smaller or equal to ")
852            << "the source rank. " << errMsg;
853   case SliceVerificationResult::SizeMismatch:
854     return op.emitError("expected result type to be ")
855            << expectedType
856            << " or a rank-reduced version. (mismatch of result sizes) "
857            << errMsg;
858   case SliceVerificationResult::ElemTypeMismatch:
859     return op.emitError("expected result element type to be ")
           << shapedType.getElementType() << errMsg;
861   }
862   llvm_unreachable("unexpected extract_slice op verification result");
863 }
864 
865 /// Verifier for ExtractSliceOp.
866 static LogicalResult verify(ExtractSliceOp op) {
867   // Verify result type against inferred type.
868   auto expectedType = ExtractSliceOp::inferResultType(
869       op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()),
870       extractFromI64ArrayAttr(op.static_sizes()),
871       extractFromI64ArrayAttr(op.static_strides()));
872   auto result = isRankReducedType(expectedType, op.getType());
873   return produceSliceErrorMsg(result, op, expectedType);
874 }
875 
/// Infer the canonical type of the result of an extract_slice op. Returns the
/// rank-reduced inferred type when it has rank `resultRank`, and the
/// non-rank-reduced inferred type otherwise.
879 static RankedTensorType
880 getCanonicalSliceResultType(unsigned resultRank, RankedTensorType sourceType,
881                             ArrayRef<OpFoldResult> mixedOffsets,
882                             ArrayRef<OpFoldResult> mixedSizes,
883                             ArrayRef<OpFoldResult> mixedStrides) {
884   auto resultType =
885       ExtractSliceOp::inferRankReducedResultType(
886           resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
887           .cast<RankedTensorType>();
888   if (resultType.getRank() != resultRank) {
889     resultType = ExtractSliceOp::inferResultType(sourceType, mixedOffsets,
890                                                  mixedSizes, mixedStrides)
891                      .cast<RankedTensorType>();
892   }
893   return resultType;
894 }
895 
896 namespace {
897 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
/// This essentially pushes the tensor.cast past its consuming slice when
/// `canFoldIntoConsumerOp` is true.
900 ///
901 /// Example:
902 /// ```
903 ///   %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
904 ///   %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
905 ///   tensor<3x4xf32>
906 /// ```
907 /// is rewritten into:
908 /// ```
///   %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
///        tensor<3x4xf32>
///   %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
911 /// ```
912 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
913 public:
914   using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
915 
916   LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
917                                 PatternRewriter &rewriter) const override {
    // If any operand is a constant index, bail and let the constant-argument
    // folder (OpWithOffsetSizesAndStridesConstantArgumentFolder) kick in
    // first.
919     if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
920           return matchPattern(operand, matchConstantIndex());
921         }))
922       return failure();
923 
924     auto castOp = sliceOp.source().getDefiningOp<tensor::CastOp>();
925     if (!castOp)
926       return failure();
927 
928     if (!canFoldIntoConsumerOp(castOp))
929       return failure();
930 
931     /// Deduce the type of the result to use for the canonicalized operation.
932     RankedTensorType resultType = getCanonicalSliceResultType(
933         sliceOp.getType().getRank(), sliceOp.getSourceType(),
934         sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
935         sliceOp.getMixedStrides());
936     Value newSlice = rewriter.create<ExtractSliceOp>(
937         sliceOp.getLoc(), resultType, castOp.source(), sliceOp.offsets(),
938         sliceOp.sizes(), sliceOp.strides(), sliceOp.static_offsets(),
939         sliceOp.static_sizes(), sliceOp.static_strides());
940     rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
941                                                 newSlice);
942     return success();
943   }
944 };
945 } // namespace
946 
947 /// Return the canonical type of the result of an extract_slice op.
948 struct SliceReturnTypeCanonicalizer {
949   RankedTensorType operator()(ExtractSliceOp op,
950                               ArrayRef<OpFoldResult> mixedOffsets,
951                               ArrayRef<OpFoldResult> mixedSizes,
952                               ArrayRef<OpFoldResult> mixedStrides) {
953     return getCanonicalSliceResultType(op.getType().getRank(),
954                                        op.getSourceType(), mixedOffsets,
955                                        mixedSizes, mixedStrides);
956   }
957 };
958 
959 /// A canonicalizer wrapper to replace ExtractSliceOps.
960 struct SliceCanonicalizer {
961   void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
962                   ExtractSliceOp newOp) {
963     Value replacement = newOp.getResult();
964     if (replacement.getType() != op.getType())
965       replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
966                                                     replacement);
967     rewriter.replaceOp(op, replacement);
968   }
969 };
970 
971 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
972                                                  MLIRContext *context) {
973   results.add<
974       OpWithOffsetSizesAndStridesConstantArgumentFolder<
975           ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
976       ExtractSliceOpCastFolder>(context);
977 }
978 
/// Returns success if `op` is a no-op slice of `shapedType`: all offsets are
/// 0, all sizes match the corresponding dimensions of `shapedType`, and all
/// strides are 1.
980 static LogicalResult
981 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
982                                            ShapedType shapedType) {
983   OpBuilder b(op.getContext());
984   for (OpFoldResult ofr : op.getMixedOffsets())
985     if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
986       return failure();
  // Rank-reducing no-ops only need to inspect the leading dimensions:
  // llvm::zip is appropriate.
989   auto shape = shapedType.getShape();
990   for (auto it : llvm::zip(op.getMixedSizes(), shape))
991     if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
992       return failure();
993   for (OpFoldResult ofr : op.getMixedStrides())
994     if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
995       return failure();
996   return success();
997 }
998 
999 OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
1000   if (getSourceType() == getType() &&
1001       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
1002     return this->source();
1003   return OpFoldResult();
1004 }
1005 
1006 //===----------------------------------------------------------------------===//
1007 // InsertSliceOp
1008 //===----------------------------------------------------------------------===//
1009 
// Build an InsertSliceOp with mixed static and dynamic entries.
1011 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1012                           Value dest, ArrayRef<OpFoldResult> offsets,
1013                           ArrayRef<OpFoldResult> sizes,
1014                           ArrayRef<OpFoldResult> strides,
1015                           ArrayRef<NamedAttribute> attrs) {
1016   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1017   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1018   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1019                              ShapedType::kDynamicStrideOrOffset);
1020   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1021                              ShapedType::kDynamicSize);
1022   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1023                              ShapedType::kDynamicStrideOrOffset);
1024   build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
1025         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1026         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1027   result.addAttributes(attrs);
1028 }
1029 
// Build an InsertSliceOp with dynamic entries.
1031 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1032                           Value dest, ValueRange offsets, ValueRange sizes,
1033                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1034   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1035       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1036   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1037       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1038   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1039       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1040   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
1041 }
1042 
1043 OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
1044   if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
1045       getSourceType() == getType() &&
1046       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
1047     return this->source();
1048   return OpFoldResult();
1049 }
1050 
1051 namespace {
/// Pattern to rewrite an insert_slice op with constant arguments.
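///
/// For illustration (hypothetical values):
///
/// ```mlir
///   %c0 = constant 0 : index
///   %r = tensor.insert_slice %a into %b[%c0, 1] [2, 2] [1, 1]
///       : tensor<2x2xf32> into tensor<4x4xf32>
/// ```
///
/// is rewritten so that the constant offset is carried as a static attribute:
///
/// ```mlir
///   %r = tensor.insert_slice %a into %b[0, 1] [2, 2] [1, 1]
///       : tensor<2x2xf32> into tensor<4x4xf32>
/// ```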
1053 class InsertSliceOpConstantArgumentFolder final
1054     : public OpRewritePattern<InsertSliceOp> {
1055 public:
1056   using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
1057 
1058   LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
1059                                 PatternRewriter &rewriter) const override {
1060     // No constant operand, just return.
1061     if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
1062           return matchPattern(operand, matchConstantIndex());
1063         }))
1064       return failure();
1065 
    // At least one of offsets/sizes/strides is a new constant.
    // Form the new list of operands and constant attributes from the
    // existing ones.
1069     SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
1070     SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
1071     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
1072     canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
1073     canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
1074     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
1075 
1076     // Create the new op in canonical form.
1077     rewriter.replaceOpWithNewOp<InsertSliceOp>(
1078         insertSliceOp, insertSliceOp.source(), insertSliceOp.dest(),
1079         mixedOffsets, mixedSizes, mixedStrides);
1080     return success();
1081   }
1082 };
1083 
/// Folds tensor.cast ops into insert_slice operations (on either the source
/// or the destination operand).
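///
/// For illustration (hypothetical types):
///
/// ```mlir
///   %0 = tensor.cast %a : tensor<2x2xf32> to tensor<?x?xf32>
///   %1 = tensor.insert_slice %0 into %b[0, 1] [2, 2] [1, 1]
///       : tensor<?x?xf32> into tensor<4x4xf32>
/// ```
///
/// becomes
///
/// ```mlir
///   %1 = tensor.insert_slice %a into %b[0, 1] [2, 2] [1, 1]
///       : tensor<2x2xf32> into tensor<4x4xf32>
/// ```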
1085 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
1086   using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
1087 
1088   LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
1089                                 PatternRewriter &rewriter) const override {
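    // Bail on constant offset/size/stride operands; the constant-argument
    // folder (InsertSliceOpConstantArgumentFolder) handles those first.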
1090     if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
1091           return matchPattern(operand, matchConstantIndex());
1092         }))
1093       return failure();
1094 
1095     auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
1096       auto castOp = v.getDefiningOp<tensor::CastOp>();
1097       if (!castOp || !canFoldIntoConsumerOp(castOp))
1098         return llvm::None;
1099       return castOp.source();
1100     };
1101     Optional<Value> sourceCastSource =
1102         getSourceOfCastOp(insertSliceOp.source());
1103     Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.dest());
1104     if (!sourceCastSource && !destCastSource)
1105       return failure();
1106 
1107     Value replacement = rewriter.create<InsertSliceOp>(
1108         insertSliceOp.getLoc(),
1109         (sourceCastSource ? *sourceCastSource : insertSliceOp.source()),
1110         (destCastSource ? *destCastSource : insertSliceOp.dest()),
1111         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
1112         insertSliceOp.getMixedStrides());
1113 
1114     if (replacement.getType() != insertSliceOp.getType()) {
1115       replacement = rewriter.create<tensor::CastOp>(
1116           insertSliceOp.getLoc(), insertSliceOp.getType(), replacement);
1117     }
1118     rewriter.replaceOp(insertSliceOp, replacement);
1119     return success();
1120   }
1121 };
1122 } // namespace
1123 
1124 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1125                                                 MLIRContext *context) {
1126   results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder>(
1127       context);
1128 }
1129 
1130 //===----------------------------------------------------------------------===//
1131 // TableGen'd op method definitions
1132 //===----------------------------------------------------------------------===//
1133 
1134 #define GET_OP_CLASSES
1135 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
1136