1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
10 #include "mlir/Dialect/Tensor/IR/Tensor.h"
11 #include "mlir/IR/BlockAndValueMapping.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/Matchers.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/IR/TypeUtilities.h"
16 #include "llvm/ADT/STLExtras.h"
17 
18 using namespace mlir;
19 using namespace mlir::tensor;
20 
21 //===----------------------------------------------------------------------===//
22 // CastOp
23 //===----------------------------------------------------------------------===//
24 
/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of tensor.cast operations. Such foldable
/// tensor.cast operations are typically inserted as `slice` ops and are
/// canonicalized to preserve the type compatibility of their uses.
///
/// Returns true when all of the following conditions are met:
/// 1. source and result are ranked tensors with the same element type and
///    rank.
/// 2. the source tensor type has more static information than the result
///    type.
35 ///
36 /// Example:
37 /// ```mlir
38 ///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
39 ///   %2 = consumer %1 ... : tensor<?x?xf32> ...
40 /// ```
41 ///
42 /// folds into:
43 ///
44 /// ```mlir
45 ///   %2 = consumer %0 ... : tensor<8x16xf32> ...
46 /// ```
47 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
48   if (!castOp)
49     return false;
50 
51   RankedTensorType sourceType =
52       castOp.source().getType().dyn_cast<RankedTensorType>();
53   RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
54 
55   // Requires RankedTensorType.
56   if (!sourceType || !resultType)
57     return false;
58 
  // Requires the same element type.
60   if (sourceType.getElementType() != resultType.getElementType())
61     return false;
62 
63   // Requires same rank.
64   if (sourceType.getRank() != resultType.getRank())
65     return false;
66 
67   // If cast is towards more static sizes along any dimension, don't fold.
68   for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
69     if (ShapedType::isDynamic(std::get<0>(t)) &&
70         !ShapedType::isDynamic(std::get<1>(t)))
71       return false;
72   }
73 
74   return true;
75 }
76 
77 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
78 /// that can be folded.
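/// For example (illustrative), given
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
/// calling this on the consumer replaces its use of %1 with %0 in place; the
/// caller is responsible for ensuring the consumer still verifies with the
/// more static operand type.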
79 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
80   bool folded = false;
81   for (OpOperand &operand : op->getOpOperands()) {
82     auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
83     if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
84       operand.set(castOp.getOperand());
85       folded = true;
86     }
87   }
88   return success(folded);
89 }
90 
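/// Returns true when the given types describe a 1-1 cast between tensor types
/// with the same element type and compatible shapes. For example
/// (illustrative), tensor<?x?xf32> and tensor<4x4xf32> are cast compatible,
/// while tensor<4x4xf32> and tensor<5x4xf32> (shape mismatch) or
/// tensor<4x4xf32> and tensor<4x4xi32> (element type mismatch) are not.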
91 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
92   if (inputs.size() != 1 || outputs.size() != 1)
93     return false;
94   Type a = inputs.front(), b = outputs.front();
95   auto aT = a.dyn_cast<TensorType>();
96   auto bT = b.dyn_cast<TensorType>();
97   if (!aT || !bT)
98     return false;
99 
100   if (aT.getElementType() != bT.getElementType())
101     return false;
102 
103   return succeeded(verifyCompatibleShape(aT, bT));
104 }
105 
106 /// Compute a TensorType that has the joined shape knowledge of the two
107 /// given TensorTypes. The element types need to match.
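/// Example (illustrative): joining tensor<?x8xf32> and tensor<4x?xf32> yields
/// tensor<4x8xf32>; joining tensor<4x8xf32> and tensor<5x?xf32> yields a null
/// type because the static sizes 4 and 5 conflict.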
108 static TensorType joinShapes(TensorType one, TensorType two) {
109   assert(one.getElementType() == two.getElementType());
110 
111   if (!one.hasRank())
112     return two;
113   if (!two.hasRank())
114     return one;
115 
116   int64_t rank = one.getRank();
117   if (rank != two.getRank())
118     return {};
119 
120   SmallVector<int64_t, 4> join;
121   join.reserve(rank);
122   for (int64_t i = 0; i < rank; ++i) {
123     if (one.isDynamicDim(i)) {
124       join.push_back(two.getDimSize(i));
125       continue;
126     }
127     if (two.isDynamicDim(i)) {
128       join.push_back(one.getDimSize(i));
129       continue;
130     }
131     if (one.getDimSize(i) != two.getDimSize(i))
132       return {};
133     join.push_back(one.getDimSize(i));
134   }
135   return RankedTensorType::get(join, one.getElementType());
136 }
137 
138 namespace {
139 
140 /// Replaces chains of two tensor.cast operations by a single tensor.cast
141 /// operation if doing so does not remove runtime constraints.
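/// Example (illustrative):
/// ```mlir
///   %1 = tensor.cast %0 : tensor<4x4xf32> to tensor<?x?xf32>
///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<4x?xf32>
/// ```
/// folds to a single cast:
/// ```mlir
///   %2 = tensor.cast %0 : tensor<4x4xf32> to tensor<4x?xf32>
/// ```
/// whereas the chain tensor<?x?xf32> -> tensor<4x4xf32> -> tensor<?x?xf32> is
/// not folded, because dropping the intermediate cast would remove the runtime
/// check that %0 indeed has shape 4x4.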
142 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
143   using OpRewritePattern<CastOp>::OpRewritePattern;
144 
145   LogicalResult matchAndRewrite(CastOp tensorCast,
146                                 PatternRewriter &rewriter) const final {
147     auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
148 
149     if (!tensorCastOperand)
150       return failure();
151 
152     auto sourceType =
153         tensorCastOperand.getOperand().getType().cast<TensorType>();
154     auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
155     auto resultType = tensorCast.getType().cast<TensorType>();
156 
157     // We can remove the intermediate cast if joining all three produces the
158     // same result as just joining the source and result shapes.
159     auto firstJoin =
160         joinShapes(joinShapes(sourceType, intermediateType), resultType);
161 
162     // The join might not exist if the cast sequence would fail at runtime.
163     if (!firstJoin)
164       return failure();
165 
    // If the above join exists, then newJoin exists as well; it might just
    // contain less information. If so, we cannot drop the intermediate cast,
    // as doing so would remove runtime checks.
169     auto newJoin = joinShapes(sourceType, resultType);
170     if (firstJoin != newJoin)
171       return failure();
172 
173     rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
174                                         tensorCastOperand.getOperand());
175     return success();
176   }
177 };
178 
179 } // namespace
180 
181 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
182                                          MLIRContext *context) {
183   results.add<ChainedTensorCast>(context);
184 }
185 
186 //===----------------------------------------------------------------------===//
187 // ExtractOp
188 //===----------------------------------------------------------------------===//
189 
190 static LogicalResult verify(ExtractOp op) {
191   // Verify the # indices match if we have a ranked type.
192   if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
193     if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
194       return op.emitOpError("incorrect number of indices for extract_element");
195 
196   return success();
197 }
198 
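/// Folds extraction from a constant tensor. For example (illustrative),
/// extracting at constant indices [1, 2] from a dense constant returns the
/// element attribute stored at that position; extracting from a splat
/// constant returns the splat value regardless of the indices.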
199 OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
200   // The tensor operand must be a known constant.
201   Attribute tensor = operands.front();
202   if (!tensor)
203     return {};
204   // If this is a splat elements attribute, simply return the value. All of the
205   // elements of a splat attribute are the same.
206   if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
207     return splatTensor.getSplatValue();
208 
209   // Otherwise, collect the constant indices into the tensor.
210   SmallVector<uint64_t, 8> indices;
  for (Attribute index : llvm::drop_begin(operands, 1)) {
    if (!index || !index.isa<IntegerAttr>())
      return {};
    indices.push_back(index.cast<IntegerAttr>().getInt());
215   }
216 
217   // If this is an elements attribute, query the value at the given indices.
218   auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
219   if (elementsAttr && elementsAttr.isValidIndex(indices))
220     return elementsAttr.getValue(indices);
221   return {};
222 }
223 
224 //===----------------------------------------------------------------------===//
225 // FromElementsOp
226 //===----------------------------------------------------------------------===//
227 
228 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
229                            Type elementType, ValueRange elements) {
230   Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
231                                         elementType);
232   result.addOperands(elements);
233   result.addTypes(resultTy);
234 }
235 
236 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
237                            ValueRange elements) {
238   assert(!elements.empty() && "expected at least one element");
239   build(builder, result, elements.front().getType(), elements);
240 }
241 
242 OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
243   if (!llvm::is_contained(operands, nullptr))
244     return DenseElementsAttr::get(getType(), operands);
245   return {};
246 }
247 
248 namespace {
249 
250 // Canonicalizes the pattern of the form
251 //
252 // %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
253 // %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
254 //
255 // to just %element.
256 struct ExtractElementFromTensorFromElements
257     : public OpRewritePattern<tensor::ExtractOp> {
258   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
259 
260   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
261                                 PatternRewriter &rewriter) const final {
262     if (extract.indices().size() != 1)
263       return failure();
264 
265     auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
266     if (tensorFromElements == nullptr)
267       return failure();
268 
269     APInt index;
270     if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
271       return failure();
272     // Prevent out of bounds accesses. This can happen in invalid code that will
273     // never execute.
274     if (tensorFromElements->getNumOperands() <= index.getZExtValue() ||
275         index.getSExtValue() < 0)
276       return failure();
277     rewriter.replaceOp(extract,
278                        tensorFromElements.getOperand(index.getZExtValue()));
279     return success();
280   }
281 };
282 
283 } // namespace
284 
285 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
286                                                  MLIRContext *context) {
287   results.add<ExtractElementFromTensorFromElements>(context);
288 }
289 
290 //===----------------------------------------------------------------------===//
291 // InsertOp
292 //===----------------------------------------------------------------------===//
293 
294 static LogicalResult verify(InsertOp op) {
295   // Verify the # indices match if we have a ranked type.
296   if (auto destType = op.dest().getType().dyn_cast<RankedTensorType>())
297     if (destType.getRank() != static_cast<int64_t>(op.indices().size()))
298       return op.emitOpError("incorrect number of indices");
299   return success();
300 }
301 
302 OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
303   Attribute scalar = operands[0];
304   Attribute dest = operands[1];
305   if (scalar && dest)
306     if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
307       if (scalar == splatDest.getSplatValue())
308         return dest;
309   return {};
310 }
311 
312 //===----------------------------------------------------------------------===//
313 // GenerateOp
314 //===----------------------------------------------------------------------===//
315 
316 static LogicalResult verify(GenerateOp op) {
317   // Ensure that the tensor type has as many dynamic dimensions as are specified
318   // by the operands.
319   RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
320   if (op.getNumOperands() != resultTy.getNumDynamicDims())
321     return op.emitError("must have as many index operands as dynamic extents "
322                         "in the result type");
323 
324   // Ensure that region arguments span the index space.
325   if (!llvm::all_of(op.body().getArgumentTypes(),
326                     [](Type ty) { return ty.isIndex(); }))
327     return op.emitError("all body arguments must be index");
328   if (op.body().getNumArguments() != resultTy.getRank())
329     return op.emitError("must have one body argument per input dimension");
330 
331   // Ensure that the region yields an element of the right type.
332   auto yieldOp =
333       llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
334   if (yieldOp.value().getType() != resultTy.getElementType())
335     return op.emitOpError(
336         "body must be terminated with a `yield` operation of the tensor "
337         "element type");
338 
339   return success();
340 }
341 
342 void GenerateOp::build(
343     OpBuilder &b, OperationState &result, Type resultTy,
344     ValueRange dynamicExtents,
345     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
346   build(b, result, resultTy, dynamicExtents);
347 
348   // Build and populate body.
349   OpBuilder::InsertionGuard guard(b);
350   Region *bodyRegion = result.regions.front().get();
351   auto rank = resultTy.cast<RankedTensorType>().getRank();
352   SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
353   Block *bodyBlock =
354       b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
355   bodyBuilder(b, result.location, bodyBlock->getArguments());
356 }
357 
358 namespace {
359 
/// Canonicalizes tensor.generate operations with constant extent operands
/// into the equivalent operation with those extents reflected in the result
/// type. A tensor.cast back to the original result type is inserted so that
/// the resulting IR remains well-typed.
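///
/// Example (illustrative, assuming the single extent operand %size is the
/// constant index 5):
/// ```mlir
///   %res = tensor.generate %size {
///   ^bb0(%arg0: index):
///     ...
///   } : tensor<?xf32>
/// ```
/// is rewritten into:
/// ```mlir
///   %gen = tensor.generate {
///   ^bb0(%arg0: index):
///     ...
///   } : tensor<5xf32>
///   %res = tensor.cast %gen : tensor<5xf32> to tensor<?xf32>
/// ```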
364 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
365   using OpRewritePattern<GenerateOp>::OpRewritePattern;
366 
367   LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
368                                 PatternRewriter &rewriter) const final {
369     auto resultType =
370         tensorFromElements.getResult().getType().cast<RankedTensorType>();
371 
372     if (resultType.hasStaticShape())
373       return failure();
374 
375     SmallVector<Value, 4> newOperands;
376     SmallVector<int64_t, 4> newShape;
377     auto operandsIt = tensorFromElements.dynamicExtents().begin();
378 
379     for (int64_t dim : resultType.getShape()) {
380       if (dim != RankedTensorType::kDynamicSize) {
381         newShape.push_back(dim);
382         continue;
383       }
384       APInt index;
385       if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
386         newShape.push_back(RankedTensorType::kDynamicSize);
387         newOperands.push_back(*operandsIt++);
388         continue;
389       }
390       newShape.push_back(index.getSExtValue());
391       operandsIt++;
392     }
393 
394     if (newOperands.size() == tensorFromElements.dynamicExtents().size())
395       return failure();
396 
397     auto loc = tensorFromElements.getLoc();
398     auto newOp = rewriter.create<GenerateOp>(
399         loc, RankedTensorType::get(newShape, resultType.getElementType()),
400         newOperands);
401     rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
402                                 newOp.body().begin());
403     rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
404                                                 newOp);
405     return success();
406   }
407 };
408 
409 /// Canonicalizes the pattern of the form
410 ///
411 /// %tensor = tensor.generate %x {
412 ///   ^bb0(%arg0: index):  // no predecessors
413 ///   <computation>
414 ///   yield %1 : index
415 /// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
417 ///
418 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
419 /// tensor.generate operation has no side-effects.
420 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
421   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
422 
423   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
424                                 PatternRewriter &rewriter) const final {
425     auto tensorFromElements = extract.tensor().getDefiningOp<GenerateOp>();
426     if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
427       return failure();
428 
429     BlockAndValueMapping mapping;
430     Block *body = tensorFromElements.getBody();
431     mapping.map(body->getArguments(), extract.indices());
432     for (auto &op : body->without_terminator())
433       rewriter.clone(op, mapping);
434 
435     auto yield = cast<YieldOp>(body->getTerminator());
436 
437     rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
438     return success();
439   }
440 };
441 
442 /// Canonicalizes the pattern of the form
443 ///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
445 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
446 ///
447 /// to
448 ///
449 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
450 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
451   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
452 
453   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
454                                 PatternRewriter &rewriter) const final {
455     auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
456     if (!tensorCast)
457       return failure();
458 
459     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
460                                                    extract.indices());
461     return success();
462   }
463 };
464 
465 } // namespace
466 
467 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
468                                              MLIRContext *context) {
469   // TODO: Move extract patterns to tensor::ExtractOp.
470   results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
471               StaticTensorGenerate>(context);
472 }
473 
474 //===----------------------------------------------------------------------===//
475 // ReshapeOp
476 //===----------------------------------------------------------------------===//
477 
478 static int64_t GetNumElements(ShapedType type) {
479   int64_t numElements = 1;
480   for (auto dim : type.getShape())
481     numElements *= dim;
482   return numElements;
483 }
484 
485 static LogicalResult verify(ReshapeOp op) {
486   TensorType operandType = op.source().getType().cast<TensorType>();
487   TensorType resultType = op.result().getType().cast<TensorType>();
488 
489   if (operandType.getElementType() != resultType.getElementType())
490     return op.emitOpError("element types of source and destination tensor "
491                           "types should be the same");
492 
493   int64_t shapeSize =
494       op.shape().getType().cast<RankedTensorType>().getDimSize(0);
495   auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
496   auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
497 
498   if (resultRankedType) {
499     if (operandRankedType && resultRankedType.hasStaticShape() &&
500         operandRankedType.hasStaticShape()) {
501       if (GetNumElements(operandRankedType) != GetNumElements(resultRankedType))
502         return op.emitOpError("source and destination tensor should have the "
503                               "same number of elements");
504     }
505     if (shapeSize == TensorType::kDynamicSize)
506       return op.emitOpError("cannot use shape operand with dynamic length to "
507                             "reshape to statically-ranked tensor type");
508     if (shapeSize != resultRankedType.getRank())
509       return op.emitOpError(
510           "length of shape operand differs from the result's tensor rank");
511   }
512   return success();
513 }
514 
515 //===----------------------------------------------------------------------===//
516 // ExtractSliceOp
517 //===----------------------------------------------------------------------===//
518 
519 /// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
520 /// it is a Value or into `staticVec` if it is an IntegerAttr.
521 /// In the case of a Value, a copy of the `sentinel` value is also pushed to
522 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
523 /// come from an AttrSizedOperandSegments trait.
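/// For example (illustrative), dispatching the OpFoldResults {%v, 4} with
/// sentinel ShapedType::kDynamicSize appends %v to `dynamicVec` and the pair
/// {kDynamicSize, 4} to `staticVec`.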
524 static void dispatchIndexOpFoldResult(OpFoldResult ofr,
525                                       SmallVectorImpl<Value> &dynamicVec,
526                                       SmallVectorImpl<int64_t> &staticVec,
527                                       int64_t sentinel) {
528   if (auto v = ofr.dyn_cast<Value>()) {
529     dynamicVec.push_back(v);
530     staticVec.push_back(sentinel);
531     return;
532   }
533   APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
534   staticVec.push_back(apInt.getSExtValue());
535 }
536 
537 static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
538                                        SmallVectorImpl<Value> &dynamicVec,
539                                        SmallVectorImpl<int64_t> &staticVec,
540                                        int64_t sentinel) {
541   for (auto ofr : ofrs)
542     dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
543 }
544 
545 /// An extract_slice op result type can be fully inferred from the source type
546 /// and the static representation of offsets, sizes and strides. Special
547 /// sentinels encode the dynamic case.
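///
/// For example (illustrative), extracting from a tensor<10x20x30xf32> with
/// static sizes {4, ShapedType::kDynamicSize, 8} infers the result type
/// tensor<4x?x8xf32>; the offsets and strides do not affect the inferred
/// shape.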
548 Type ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
549                                      ArrayRef<int64_t> leadingStaticOffsets,
550                                      ArrayRef<int64_t> leadingStaticSizes,
551                                      ArrayRef<int64_t> leadingStaticStrides) {
  // An extract_slice op may specify only a leading subset of offsets/sizes/
  // strides, in which case we complete with offset=0, sizes from the source
  // tensor type, and strides=1.
555   unsigned rank = sourceRankedTensorType.getRank();
556   assert(leadingStaticSizes.size() <= rank &&
557          "unexpected leadingStaticSizes overflow");
558   auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
559   unsigned numTrailingSizes = rank - staticSizes.size();
560   llvm::append_range(staticSizes, sourceRankedTensorType.getShape().take_back(
561                                       numTrailingSizes));
562   return RankedTensorType::get(staticSizes,
563                                sourceRankedTensorType.getElementType());
564 }
565 
566 /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
567 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
568   return llvm::to_vector<4>(
569       llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
570         return a.cast<IntegerAttr>().getInt();
571       }));
572 }
573 
574 Type ExtractSliceOp::inferResultType(
575     RankedTensorType sourceRankedTensorType,
576     ArrayRef<OpFoldResult> leadingStaticOffsets,
577     ArrayRef<OpFoldResult> leadingStaticSizes,
578     ArrayRef<OpFoldResult> leadingStaticStrides) {
579   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
580   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
581   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
582                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
583   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
584                              ShapedType::kDynamicSize);
585   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
586                              staticStrides, ShapedType::kDynamicStrideOrOffset);
587   return ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
588                                          staticSizes, staticStrides);
589 }
590 
591 /// An extract_slice op result type can be fully inferred from the source type
592 /// and the static representation of offsets, sizes and strides. Special
593 /// sentinels encode the dynamic case.
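///
/// For example (illustrative), with inferred sizes
/// {1, ShapedType::kDynamicSize, 1, 8} and `resultRank` 2, the two unit
/// dimensions are projected away and the rank-reduced result type is
/// tensor<?x8xf32>.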
594 Type ExtractSliceOp::inferRankReducedResultType(
595     unsigned resultRank, RankedTensorType sourceRankedTensorType,
596     ArrayRef<int64_t> leadingStaticOffsets,
597     ArrayRef<int64_t> leadingStaticSizes,
598     ArrayRef<int64_t> leadingStaticStrides) {
599   auto inferredType =
600       inferResultType(sourceRankedTensorType, leadingStaticOffsets,
601                       leadingStaticSizes, leadingStaticStrides)
602           .cast<RankedTensorType>();
603   int rankDiff = inferredType.getRank() - resultRank;
604   if (rankDiff > 0) {
605     auto shape = inferredType.getShape();
606     llvm::SmallDenseSet<unsigned> dimsToProject;
607     mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
608     SmallVector<int64_t> projectedShape;
609     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
610       if (!dimsToProject.contains(pos))
611         projectedShape.push_back(shape[pos]);
612     inferredType =
613         RankedTensorType::get(projectedShape, inferredType.getElementType());
614   }
615   return inferredType;
616 }
617 
618 Type ExtractSliceOp::inferRankReducedResultType(
619     unsigned resultRank, RankedTensorType sourceRankedTensorType,
620     ArrayRef<OpFoldResult> leadingStaticOffsets,
621     ArrayRef<OpFoldResult> leadingStaticSizes,
622     ArrayRef<OpFoldResult> leadingStaticStrides) {
623   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
624   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
625   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
626                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
627   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
628                              ShapedType::kDynamicSize);
629   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
630                              staticStrides, ShapedType::kDynamicStrideOrOffset);
631   return ExtractSliceOp::inferRankReducedResultType(
632       resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
633       staticStrides);
634 }
635 
636 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
637 /// result type. If the type passed is nullptr, it is inferred.
638 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
639                            RankedTensorType resultType, Value source,
640                            ArrayRef<OpFoldResult> offsets,
641                            ArrayRef<OpFoldResult> sizes,
642                            ArrayRef<OpFoldResult> strides,
643                            ArrayRef<NamedAttribute> attrs) {
644   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
645   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
646   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
647                              ShapedType::kDynamicStrideOrOffset);
648   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
649                              ShapedType::kDynamicSize);
650   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
651                              ShapedType::kDynamicStrideOrOffset);
652   auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
653   // Structuring implementation this way avoids duplication between builders.
654   if (!resultType) {
655     resultType =
656         ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
657                                         staticSizes, staticStrides)
658             .cast<RankedTensorType>();
659   }
660   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
661         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
662         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
663   result.addAttributes(attrs);
664 }
665 
666 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
667 /// result type.
668 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
669                            ArrayRef<OpFoldResult> offsets,
670                            ArrayRef<OpFoldResult> sizes,
671                            ArrayRef<OpFoldResult> strides,
672                            ArrayRef<NamedAttribute> attrs) {
673   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
674 }
675 
676 /// Build an ExtractSliceOp with dynamic entries and custom result type. If the
677 /// type passed is nullptr, it is inferred.
678 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
679                            RankedTensorType resultType, Value source,
680                            ValueRange offsets, ValueRange sizes,
681                            ValueRange strides, ArrayRef<NamedAttribute> attrs) {
682   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
683       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
684   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
685       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
686   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
687       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
688   build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
689 }
690 
691 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
692 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
693                            ValueRange offsets, ValueRange sizes,
694                            ValueRange strides, ArrayRef<NamedAttribute> attrs) {
695   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
696 }
697 
698 enum SliceVerificationResult {
699   Success,
700   RankTooLarge,
701   SizeMismatch,
702   ElemTypeMismatch,
703 };
704 
/// Checks if `originalType` can be rank-reduced to `candidateReducedType`.
/// This function is a slight variant of the `is subsequence` algorithm, where
/// every non-matching dimension must be 1.
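///
/// For example (illustrative), tensor<1x4x1x8xf32> can be rank-reduced to
/// tensor<4x1x8xf32> or tensor<4x8xf32>, but not to tensor<4x4xf32> (size
/// mismatch) or tensor<4x8xi32> (element type mismatch).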
708 static SliceVerificationResult
709 isRankReducedType(Type originalType, Type candidateReducedType,
710                   std::string *errMsg = nullptr) {
  if (originalType == candidateReducedType)
    return SliceVerificationResult::Success;
  if (!originalType.isa<RankedTensorType>() ||
      !candidateReducedType.isa<RankedTensorType>())
    return SliceVerificationResult::Success;
718 
719   ShapedType originalShapedType = originalType.cast<ShapedType>();
720   ShapedType candidateReducedShapedType =
721       candidateReducedType.cast<ShapedType>();
722 
723   // Rank and size logic is valid for all ShapedTypes.
724   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
725   ArrayRef<int64_t> candidateReducedShape =
726       candidateReducedShapedType.getShape();
727   unsigned originalRank = originalShape.size(),
728            candidateReducedRank = candidateReducedShape.size();
729   if (candidateReducedRank > originalRank)
730     return SliceVerificationResult::RankTooLarge;
731 
732   auto optionalUnusedDimsMask =
733       computeRankReductionMask(originalShape, candidateReducedShape);
734 
735   // Sizes cannot be matched in case empty vector is returned.
736   if (!optionalUnusedDimsMask.hasValue())
737     return SliceVerificationResult::SizeMismatch;
738 
739   if (originalShapedType.getElementType() !=
740       candidateReducedShapedType.getElementType())
741     return SliceVerificationResult::ElemTypeMismatch;
742 
  return SliceVerificationResult::Success;
748 }
749 
750 template <typename OpTy>
751 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
752                                           OpTy op, Type expectedType,
753                                           StringRef errMsg = "") {
  auto shapedType = expectedType.cast<ShapedType>();
755   switch (result) {
756   case SliceVerificationResult::Success:
757     return success();
758   case SliceVerificationResult::RankTooLarge:
759     return op.emitError("expected result rank to be smaller or equal to ")
760            << "the source rank. " << errMsg;
761   case SliceVerificationResult::SizeMismatch:
762     return op.emitError("expected result type to be ")
763            << expectedType
764            << " or a rank-reduced version. (mismatch of result sizes) "
765            << errMsg;
766   case SliceVerificationResult::ElemTypeMismatch:
767     return op.emitError("expected result element type to be ")
           << shapedType.getElementType() << errMsg;
769   }
770   llvm_unreachable("unexpected extract_slice op verification result");
771 }
772 
773 /// Verifier for ExtractSliceOp.
774 static LogicalResult verify(ExtractSliceOp op) {
775   // Verify result type against inferred type.
776   auto expectedType = ExtractSliceOp::inferResultType(
777       op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()),
778       extractFromI64ArrayAttr(op.static_sizes()),
779       extractFromI64ArrayAttr(op.static_strides()));
780   auto result = isRankReducedType(expectedType, op.getType());
781   return produceSliceErrorMsg(result, op, expectedType);
782 }
783 
784 /// Infer the canonical type of the result of an extract_slice op. Returns a
785 /// type with rank `resultRank` that is either the rank of the rank-reduced
786 /// type, or the non-rank-reduced type.
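///
/// For example (illustrative, assuming an f32 source), for mixed sizes {1, 4}
/// and `resultRank` 2 the non-rank-reduced type tensor<1x4xf32> is returned;
/// with `resultRank` 1 the rank-reduced type tensor<4xf32> is returned.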
787 static RankedTensorType
788 getCanonicalSliceResultType(unsigned resultRank, RankedTensorType sourceType,
789                             ArrayRef<OpFoldResult> mixedOffsets,
790                             ArrayRef<OpFoldResult> mixedSizes,
791                             ArrayRef<OpFoldResult> mixedStrides) {
792   auto resultType =
793       ExtractSliceOp::inferRankReducedResultType(
794           resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
795           .cast<RankedTensorType>();
796   if (resultType.getRank() != resultRank) {
797     resultType = ExtractSliceOp::inferResultType(sourceType, mixedOffsets,
798                                                  mixedSizes, mixedStrides)
799                      .cast<RankedTensorType>();
800   }
801   return resultType;
802 }
803 
804 namespace {
/// Pattern to rewrite an extract_slice op with tensor::CastOp arguments.
/// This essentially pushes the tensor.cast past its consuming slice when
/// `canFoldIntoConsumerOp` is true.
808 ///
809 /// Example:
810 /// ```
811 ///   %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
812 ///   %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
813 ///   tensor<3x4xf32>
814 /// ```
815 /// is rewritten into:
816 /// ```
///   %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
///     tensor<3x4xf32>
///   %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
819 /// ```
820 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
821 public:
822   using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
823 
824   LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
825                                 PatternRewriter &rewriter) const override {
    // If any operand is a constant index, bail and let the constant argument
    // folder (OpWithOffsetSizesAndStridesConstantArgumentFolder) kick in first.
827     if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
828           return matchPattern(operand, matchConstantIndex());
829         }))
830       return failure();
831 
832     auto castOp = sliceOp.source().getDefiningOp<tensor::CastOp>();
833     if (!castOp)
834       return failure();
835 
836     if (!canFoldIntoConsumerOp(castOp))
837       return failure();
838 
839     /// Deduce the type of the result to use for the canonicalized operation.
840     RankedTensorType resultType = getCanonicalSliceResultType(
841         sliceOp.getType().getRank(), sliceOp.getSourceType(),
842         sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
843         sliceOp.getMixedStrides());
844     Value newSlice = rewriter.create<ExtractSliceOp>(
845         sliceOp.getLoc(), resultType, castOp.source(), sliceOp.offsets(),
846         sliceOp.sizes(), sliceOp.strides(), sliceOp.static_offsets(),
847         sliceOp.static_sizes(), sliceOp.static_strides());
848     rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
849                                                 newSlice);
850     return success();
851   }
852 };
853 } // namespace
854 
855 /// Return the canonical type of the result of an extract_slice op.
856 struct SliceReturnTypeCanonicalizer {
857   RankedTensorType operator()(ExtractSliceOp op,
858                               ArrayRef<OpFoldResult> mixedOffsets,
859                               ArrayRef<OpFoldResult> mixedSizes,
860                               ArrayRef<OpFoldResult> mixedStrides) {
861     return getCanonicalSliceResultType(op.getType().getRank(),
862                                        op.getSourceType(), mixedOffsets,
863                                        mixedSizes, mixedStrides);
864   }
865 };
866 
867 /// A canonicalizer wrapper to replace ExtractSliceOps.
868 struct SliceCanonicalizer {
869   void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
870                   ExtractSliceOp newOp) {
871     Value replacement = newOp.getResult();
872     if (replacement.getType() != op.getType())
873       replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
874                                                     replacement);
875     rewriter.replaceOp(op, replacement);
876   }
877 };
878 
879 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
880                                                  MLIRContext *context) {
881   results.add<
882       OpWithOffsetSizesAndStridesConstantArgumentFolder<
883           ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
884       ExtractSliceOpCastFolder>(context);
885 }
886 
887 //
888 static LogicalResult
889 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
890                                            ShapedType shapedType) {
891   OpBuilder b(op.getContext());
892   for (OpFoldResult ofr : op.getMixedOffsets())
893     if (!isEqualConstantIntOrValue(ofr, b.getIndexAttr(0)))
894       return failure();
895   // Rank-reducing noops only need to inspect the leading dimensions: llvm::zip
896   // is appropriate.
897   auto shape = shapedType.getShape();
898   for (auto it : llvm::zip(op.getMixedSizes(), shape))
899     if (!isEqualConstantIntOrValue(std::get<0>(it),
900                                    b.getIndexAttr(std::get<1>(it))))
901       return failure();
902   for (OpFoldResult ofr : op.getMixedStrides())
903     if (!isEqualConstantIntOrValue(ofr, b.getIndexAttr(1)))
904       return failure();
905   return success();
906 }
907 
908 OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
909   if (getSourceType() == getType() &&
910       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
911     return this->source();
912   return OpFoldResult();
913 }
914 
915 //===----------------------------------------------------------------------===//
916 // InsertSliceOp
917 //===----------------------------------------------------------------------===//
918 
// Build an InsertSliceOp with mixed static and dynamic entries.
920 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
921                           Value dest, ArrayRef<OpFoldResult> offsets,
922                           ArrayRef<OpFoldResult> sizes,
923                           ArrayRef<OpFoldResult> strides,
924                           ArrayRef<NamedAttribute> attrs) {
925   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
926   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
927   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
928                              ShapedType::kDynamicStrideOrOffset);
929   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
930                              ShapedType::kDynamicSize);
931   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
932                              ShapedType::kDynamicStrideOrOffset);
933   build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
934         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
935         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
936   result.addAttributes(attrs);
937 }
938 
// Build an InsertSliceOp with dynamic entries.
940 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
941                           Value dest, ValueRange offsets, ValueRange sizes,
942                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
943   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
944       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
945   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
946       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
947   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
948       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
949   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
950 }
951 
952 OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
953   if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
954       getSourceType() == getType() &&
955       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
956     return this->source();
957   return OpFoldResult();
958 }
959 
960 namespace {
/// Pattern to rewrite an insert_slice op with constant arguments.
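///
/// Example (illustrative; %c0 is a constant index):
/// ```mlir
///   %r = tensor.insert_slice %src into %dst[%c0, 1] [2, 2] [1, 1]
///       : tensor<2x2xf32> into tensor<8x8xf32>
/// ```
/// is rewritten so that the dynamic %c0 offset becomes the static offset 0.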
962 class InsertSliceOpConstantArgumentFolder final
963     : public OpRewritePattern<InsertSliceOp> {
964 public:
965   using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
966 
967   LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
968                                 PatternRewriter &rewriter) const override {
969     // No constant operand, just return.
970     if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
971           return matchPattern(operand, matchConstantIndex());
972         }))
973       return failure();
974 
975     // At least one of offsets/sizes/strides is a new constant.
    // Form the new list of operands and constant attributes from the
    // existing ones.
978     SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
979     SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
980     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
981     canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
982     canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
983     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
984 
985     // Create the new op in canonical form.
986     rewriter.replaceOpWithNewOp<InsertSliceOp>(
987         insertSliceOp, insertSliceOp.source(), insertSliceOp.dest(),
988         mixedOffsets, mixedSizes, mixedStrides);
989     return success();
990   }
991 };
992 
/// Fold tensor.cast ops (on the source and/or the destination) into
/// insert_slice operations.
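///
/// Example (illustrative):
/// ```mlir
///   %0 = tensor.cast %src : tensor<4xf32> to tensor<?xf32>
///   %1 = tensor.insert_slice %0 into %dst[0] [4] [1]
///       : tensor<?xf32> into tensor<8xf32>
/// ```
/// is rewritten to insert %src directly, avoiding the intermediate cast.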
994 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
995   using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
996 
997   LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
998                                 PatternRewriter &rewriter) const override {
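    // If any operand is a constant index, bail; the constant argument folder
    // above is expected to canonicalize it first.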
999     if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
1000           return matchPattern(operand, matchConstantIndex());
1001         }))
1002       return failure();
1003 
1004     auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
1005       auto castOp = v.getDefiningOp<tensor::CastOp>();
1006       if (!castOp || !canFoldIntoConsumerOp(castOp))
1007         return llvm::None;
1008       return castOp.source();
1009     };
1010     Optional<Value> sourceCastSource =
1011         getSourceOfCastOp(insertSliceOp.source());
1012     Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.dest());
1013     if (!sourceCastSource && !destCastSource)
1014       return failure();
1015 
1016     Value replacement = rewriter.create<InsertSliceOp>(
1017         insertSliceOp.getLoc(),
1018         (sourceCastSource ? *sourceCastSource : insertSliceOp.source()),
1019         (destCastSource ? *destCastSource : insertSliceOp.dest()),
1020         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
1021         insertSliceOp.getMixedStrides());
1022 
1023     if (replacement.getType() != insertSliceOp.getType()) {
1024       replacement = rewriter.create<tensor::CastOp>(
1025           insertSliceOp.getLoc(), insertSliceOp.getType(), replacement);
1026     }
1027     rewriter.replaceOp(insertSliceOp, replacement);
1028     return success();
1029   }
1030 };
1031 } // namespace
1032 
1033 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1034                                                 MLIRContext *context) {
1035   results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder>(
1036       context);
1037 }
1038 
1039 //===----------------------------------------------------------------------===//
1040 // TableGen'd op method definitions
1041 //===----------------------------------------------------------------------===//
1042 
1043 #define GET_OP_CLASSES
1044 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
1045