//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::tensor;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, value, type);
  if (complex::ConstantOp::isBuildableWith(value, type))
    return builder.create<complex::ConstantOp>(loc, type,
                                               value.cast<ArrayAttr>());
  return nullptr;
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Returns true if `target` is a ranked tensor type that preserves static
/// information available in the `source` ranked tensor type.
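/// Illustrative example: preservesStaticInformation(tensor<?x8xf32>,
/// tensor<4x8xf32>) returns true because every statically known size of the
/// source is still static in the target, while the reverse call returns false
/// since the static size 4 would be lost.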
bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
  auto sourceType = source.dyn_cast<RankedTensorType>();
  auto targetType = target.dyn_cast<RankedTensorType>();

  // Requires RankedTensorType.
  if (!sourceType || !targetType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != targetType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != targetType.getRank())
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
    if (!ShapedType::isDynamic(std::get<0>(t)) &&
        ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}

/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted as `slice` ops and are canonicalized to
/// preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked tensors with same element type and rank.
/// 2. the source type has at least as much static information as the result
///    type.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  // Can fold if the source of the cast has at least as much static information
  // as its result.
  return preservesStaticInformation(castOp.getType(),
                                    castOp.getSource().getType());
}

/// Determines whether the tensor::CastOp casts to a more static version of the
/// source tensor. This is useful to fold into a producing op and implement
/// canonicalization patterns with the `tensor.cast` op as the root, but with
/// the producer being from a different dialect. Returns true when all
/// conditions are met:
/// 1. source and result are ranked tensors with same element type and rank.
/// 2. the result type has more static information than the source.
///
/// Example:
/// ```mlir
///   %1 = producer ... : tensor<?x?xf32>
///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
/// ```
///
/// can be canonicalized to:
///
/// ```mlir
///   %2 = producer ... : tensor<8x16xf32>
/// ```
/// Not all ops might be canonicalizable this way, but for those that can be,
/// this method provides a check that it is worth doing the canonicalization.
bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
  if (!castOp)
    return false;
  return preservesStaticInformation(castOp.getSource().getType(),
                                    castOp.getType());
}

/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
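/// A typical use (see DimOp::fold below) is to call this from an op's `fold`
/// hook and return the op's own result when folding succeeded:
///
///   if (succeeded(foldTensorCast(*this)))
///     return getResult();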
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<TensorType>();
  auto bT = b.dyn_cast<TensorType>();
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match.
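/// Illustrative examples: joining tensor<?x8xf32> with tensor<4x?xf32> yields
/// tensor<4x8xf32>, while joining tensor<4x8xf32> with tensor<5x8xf32> yields
/// the null type because the static sizes conflict.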
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}

namespace {

/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
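/// Illustrative example:
///
/// ```mlir
///   %1 = tensor.cast %0 : tensor<4x8xf32> to tensor<?x8xf32>
///   %2 = tensor.cast %1 : tensor<?x8xf32> to tensor<?x?xf32>
/// ```
///
/// is rewritten into a single cast from tensor<4x8xf32> to tensor<?x?xf32>,
/// since dropping the intermediate cast loses no runtime check.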
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        tensorCastOperand.getOperand().getType().cast<TensorType>();
    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
    auto resultType = tensorCast.getType().cast<TensorType>();

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // The newJoin always exists if the above join exists; it might just
    // contain less information. If so, we cannot drop the intermediate cast,
    // as doing so would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};

/// Fold tensor.cast into tensor.extract_slice producer.
/// Example:
/// ```
///  %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
///    tensor<128x512xf32> to tensor<?x512xf32>
///  %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
/// ```
/// ->
/// ```
/// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
///   tensor<128x512xf32> to tensor<16x512xf32>
/// ```
struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

    if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
        tensorCast.getType().getShape() == tensorCast.getSource()
                                               .getType()
                                               .cast<RankedTensorType>()
                                               .getShape())
      return failure();

    SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
    auto dimMask = computeRankReductionMask(
        extractFromI64ArrayAttr(extractOperand.getStaticSizes()),
        extractOperand.getType().getShape());
    size_t dimIndex = 0;
    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))
        continue;
      int64_t dim = tensorCast.getType().getShape()[dimIndex++];
      if (ShapedType::isDynamic(dim))
        continue;
      sizes[i] = rewriter.getIndexAttr(dim);
    }

    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, tensorCast.getType().cast<RankedTensorType>(),
        extractOperand.getSource(), extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());
    return success();
  }
};

} // namespace

void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}

LogicalResult DimOp::verify() {
  // Assume unknown index to be in range.
  Optional<int64_t> index = getConstantIndex();
  if (!index)
    return success();

  // Check that constant index is not knowingly out of range.
  auto type = getSource().getType();
  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
    if (*index >= tensorType.getRank())
      return emitOpError("index is out of range");
  } else if (type.isa<UnrankedTensorType>()) {
    // Assume index to be in range.
  } else {
    llvm_unreachable("expected operand with tensor type");
  }
  return success();
}

OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  // All forms of folding require a known index.
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!index)
    return {};

  // Folding for unranked types (UnrankedTensorType) is not supported.
  auto tensorType = getSource().getType().dyn_cast<RankedTensorType>();
  if (!tensorType)
    return {};

  // Fold if the shape extent along the given index is known.
  if (!tensorType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
  }

  Operation *definingOp = getSource().getDefiningOp();

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        fromElements.getResult().getType().cast<RankedTensorType>();
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    // Find the operand of the fromElements that corresponds to this index.
    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    // Fold only for non-rank-reduced ops. For the rank-reduced version, rely
    // on the `resolve-shaped-type-result-dims` pass.
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};
    }
  }

  // dim(cast) -> dim
  if (succeeded(foldTensorCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a cast into the dim of the source of the tensor cast.
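/// Illustrative example:
///
///   %0 = tensor.cast %t : tensor<4x?xf32> to tensor<?x?xf32>
///   %d = tensor.dim %0, %c1 : tensor<?x?xf32>
///
/// is rewritten to
///
///   %d = tensor.dim %t, %c1 : tensor<4x?xf32>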
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
    return success();
  }
};
} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfCastOp>(context);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

LogicalResult ExtractOp::verify() {
  // Verify the # indices match if we have a ranked type.
  if (auto tensorType = getTensor().getType().dyn_cast<RankedTensorType>())
    if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
      return emitOpError("incorrect number of indices for extract_element");

  return success();
}

OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
  // If this is a splat elements attribute, simply return the value. All of the
  // elements of a splat attribute are the same.
  if (Attribute tensor = operands.front())
    if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
      return splatTensor.getSplatValue<Attribute>();

  // Collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : llvm::drop_begin(operands, 1)) {
    if (!indice || !indice.isa<IntegerAttr>())
      return {};
    indices.push_back(indice.cast<IntegerAttr>().getInt());
  }

  // Fold extract(from_elements(...)).
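  // Illustrative example: for a tensor<2x3xf32> built by tensor.from_elements
  // and constant indices [1, 2], the row-major flat index computed below is
  // 1 * 3 + 2 = 5.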
  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto rank = tensorType.getRank();
    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
           "rank mismatch");
    int flatIndex = 0;
    int stride = 1;
    for (int i = rank - 1; i >= 0; --i) {
      // The stride of dimension `i` is the product of the sizes of the
      // dimensions after it.
      if (i < rank - 1)
        stride *= tensorType.getDimSize(i + 1);
      flatIndex += indices[i] * stride;
    }
    // Prevent out-of-bounds accesses. This can happen in invalid code that
    // will never execute.
    if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
        flatIndex < 0)
      return {};
    return fromElementsOp.getElements()[flatIndex];
  }

  // If this is an elements attribute, query the value at the given indices.
  if (Attribute tensor = operands.front()) {
    auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
    if (elementsAttr && elementsAttr.isValidIndex(indices))
      return elementsAttr.getValues<Attribute>()[indices];
  }

  return {};
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           Type resultType, ValueRange elements) {
  result.addOperands(elements);
  result.addTypes(resultType);
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  Type resultType = RankedTensorType::get(
      {static_cast<int64_t>(elements.size())}, elements.front().getType());
  build(builder, result, resultType, elements);
}

OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
  if (!llvm::is_contained(operands, nullptr))
    return DenseElementsAttr::get(getType(), operands);
  return {};
}

namespace {

// Pushes the index_casts that occur before extractions to after the extract.
// This minimizes type conversion in some cases and enables the extract
// canonicalizer. This changes:
//
// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
// %extract = tensor.extract %cast[%index] : tensor<1xindex>
//
// to the following:
//
// %extract = tensor.extract %tensor[%index] : tensor<1xi32>
// %cast = arith.index_cast %extract : i32 to index
//
// Consider expanding this to a template and handling all tensor cast
// operations.
struct ExtractElementFromIndexCast
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    Location loc = extract.getLoc();
    auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
    if (!indexCast)
      return failure();

    Type elementTy = getElementTypeOrSelf(indexCast.getIn());

    auto newExtract = rewriter.create<tensor::ExtractOp>(
        loc, elementTy, indexCast.getIn(), extract.getIndices());

    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
                                                    newExtract);

    return success();
  }
};

} // namespace

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromIndexCast>(context);
}

//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//

LogicalResult InsertOp::verify() {
  // Verify the # indices match if we have a ranked type.
  if (auto destType = getDest().getType().dyn_cast<RankedTensorType>())
    if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
      return emitOpError("incorrect number of indices");
  return success();
}

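/// Folding note: inserting a value that equals the splat value of an all-splat
/// destination leaves the destination unchanged, so the destination attribute
/// can be returned directly.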
OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
  Attribute scalar = operands[0];
  Attribute dest = operands[1];
  if (scalar && dest)
    if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
      if (scalar == splatDest.getSplatValue<Attribute>())
        return dest;
  return {};
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

LogicalResult GenerateOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
  int idx = 0;
  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    if (getType().isDynamicDim(dim)) {
      reifiedReturnShapes[0][dim] = getOperand(idx++);
    } else {
      reifiedReturnShapes[0][dim] = builder.create<arith::ConstantIndexOp>(
          getLoc(), getType().getDimSize(dim));
    }
  }
  return success();
}

LogicalResult GenerateOp::verify() {
  // Ensure that the tensor type has as many dynamic dimensions as are
  // specified by the operands.
  RankedTensorType resultTy = getType().cast<RankedTensorType>();
  if (getNumOperands() != resultTy.getNumDynamicDims())
    return emitError("must have as many index operands as dynamic extents "
                     "in the result type");

  return success();
}

LogicalResult GenerateOp::verifyRegions() {
  RankedTensorType resultTy = getType().cast<RankedTensorType>();
  // Ensure that region arguments span the index space.
  if (!llvm::all_of(getBody().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return emitError("all body arguments must be index");
  if (getBody().getNumArguments() != resultTy.getRank())
    return emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());

  if (yieldOp.getValue().getType() != resultTy.getElementType())
    return emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}

void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = resultTy.cast<RankedTensorType>().getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  SmallVector<Location, 2> argumentLocs(rank, result.location);
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}

namespace {

/// Canonicalizes tensor.generate operations with a constant
/// operand into the equivalent operation with the operand expressed in the
/// result type, instead. We also insert a type cast to make sure that the
/// resulting IR is still well-typed.
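/// Illustrative example: with %c5 a constant index, a
/// `tensor.generate %c5, %n` producing tensor<?x?xf32> is rewritten to a
/// `tensor.generate %n` producing tensor<5x?xf32>, followed by a `tensor.cast`
/// back to tensor<?x?xf32>.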
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
                                PatternRewriter &rewriter) const final {
    auto resultType =
        tensorFromElements.getResult().getType().cast<RankedTensorType>();

    if (resultType.hasStaticShape())
      return failure();

    SmallVector<Value, 4> newOperands;
    SmallVector<int64_t, 4> newShape;
    auto operandsIt = tensorFromElements.getDynamicExtents().begin();

    for (int64_t dim : resultType.getShape()) {
      if (!ShapedType::isDynamic(dim)) {
        newShape.push_back(dim);
        continue;
      }
      APInt index;
      if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
        newShape.push_back(ShapedType::kDynamicSize);
        newOperands.push_back(*operandsIt++);
        continue;
      }
      newShape.push_back(index.getSExtValue());
      operandsIt++;
    }

    if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
      return failure();

    auto loc = tensorFromElements.getLoc();
    auto newOp = rewriter.create<GenerateOp>(
        loc, RankedTensorType::get(newShape, resultType.getElementType()),
        newOperands);
    rewriter.inlineRegionBefore(tensorFromElements.getBody(), newOp.getBody(),
                                newOp.getBody().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
                                                newOp);
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):
///   <computation>
///   yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
      return failure();

    BlockAndValueMapping mapping;
    Block *body = &tensorFromElements.getBody().front();
    mapping.map(body->getArguments(), extract.getIndices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.getSource(), extract.getIndices());
    return success();
  }
};

} // namespace

void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // TODO: Move extract patterns to tensor::ExtractOp.
  results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
              StaticTensorGenerate>(context);
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
  // Constant fold rank when the rank of the operand is known.
  auto type = getOperand().getType();
  auto shapedType = type.dyn_cast<ShapedType>();
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
  return IntegerAttr();
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

static int64_t getNumElements(ShapedType type) {
  int64_t numElements = 1;
  for (auto dim : type.getShape())
    numElements *= dim;
  return numElements;
}

LogicalResult ReshapeOp::verify() {
  TensorType operandType = getSource().getType().cast<TensorType>();
  TensorType resultType = getResult().getType().cast<TensorType>();

  if (operandType.getElementType() != resultType.getElementType())
    return emitOpError("element types of source and destination tensor "
                       "types should be the same");

  int64_t shapeSize =
      getShape().getType().cast<RankedTensorType>().getDimSize(0);
  auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
  auto operandRankedType = operandType.dyn_cast<RankedTensorType>();

  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
      if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
        return emitOpError("source and destination tensor should have the "
                           "same number of elements");
    }
    if (ShapedType::isDynamic(shapeSize))
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's tensor rank");
  }
  return success();
}

7900724911dSAlexander Belyaev //===----------------------------------------------------------------------===//
791b618880eSAlexander Belyaev // Reassociative reshape ops
792b618880eSAlexander Belyaev //===----------------------------------------------------------------------===//
793b618880eSAlexander Belyaev 
getReassociationMaps()794b618880eSAlexander Belyaev SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
795b618880eSAlexander Belyaev   return getSymbolLessAffineMaps(getReassociationExprs());
796b618880eSAlexander Belyaev }
getReassociationExprs()797b618880eSAlexander Belyaev SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
798b618880eSAlexander Belyaev   return convertReassociationIndicesToExprs(getContext(),
799b618880eSAlexander Belyaev                                             getReassociationIndices());
800b618880eSAlexander Belyaev }
801b618880eSAlexander Belyaev 
getReassociationMaps()802b618880eSAlexander Belyaev SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
803b618880eSAlexander Belyaev   return getSymbolLessAffineMaps(getReassociationExprs());
804b618880eSAlexander Belyaev }
getReassociationExprs()805b618880eSAlexander Belyaev SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
806b618880eSAlexander Belyaev   return convertReassociationIndicesToExprs(getContext(),
807b618880eSAlexander Belyaev                                             getReassociationIndices());
808b618880eSAlexander Belyaev }
809b618880eSAlexander Belyaev 
/// Compute the RankedTensorType obtained by applying `reassociation` to `type`.
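/// Illustrative example: collapsing tensor<2x3x?xf32> with reassociation
/// [[0, 1], [2]] yields tensor<6x?xf32>; any dynamic size within a group makes
/// the whole collapsed dimension dynamic.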
static RankedTensorType
computeTensorReshapeCollapsedType(RankedTensorType type,
                                  ArrayRef<AffineMap> reassociation) {
  auto shape = type.getShape();
  SmallVector<int64_t, 4> newShape;
  newShape.reserve(reassociation.size());

  // Use the fact that reassociation is valid to simplify the logic: only use
  // each map's rank.
  assert(isReassociationValid(reassociation) && "invalid reassociation");
  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    auto band = shape.slice(currentDim, dim);
    int64_t size = 1;
    if (llvm::is_contained(band, ShapedType::kDynamicSize))
      size = ShapedType::kDynamicSize;
    else
      for (unsigned d = 0; d < dim; ++d)
        size *= shape[currentDim + d];
    newShape.push_back(size);
    currentDim += dim;
  }

  return RankedTensorType::get(newShape, type.getElementType());
}

void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto resultType = computeTensorReshapeCollapsedType(
      src.getType().cast<RankedTensorType>(),
      getSymbolLessAffineMaps(
          convertReassociationIndicesToExprs(b.getContext(), reassociation)));
  build(b, result, resultType, src, attrs);
  result.addAttribute(getReassociationAttrStrName(),
                      getReassociationIndicesAttribute(b, reassociation));
}

// Checks if types are the same, but ignoring encoding on ranked tensors.
static bool isSameTypesWithoutEncoding(Type tp1, Type tp2) {
  if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
    if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
      return rtp1.getShape() == rtp2.getShape() &&
             rtp1.getElementType() == rtp2.getElementType();
    return false;
  }
  // Default implementation.
  return tp1 == tp2;
}

template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
                                           RankedTensorType expandedType,
                                           RankedTensorType collapsedType) {
  if (failed(
          verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
    return failure();

  auto maps = op.getReassociationMaps();
  RankedTensorType expectedType =
      computeTensorReshapeCollapsedType(expandedType, maps);
  if (!isSameTypesWithoutEncoding(collapsedType, expectedType))
    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;
  return success();
}

verify()880b98dc035SRiver Riddle LogicalResult ExpandShapeOp::verify() {
881b98dc035SRiver Riddle   return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
882b618880eSAlexander Belyaev }
883b618880eSAlexander Belyaev 
verify()884b98dc035SRiver Riddle LogicalResult CollapseShapeOp::verify() {
885b98dc035SRiver Riddle   return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
886b618880eSAlexander Belyaev }
887b618880eSAlexander Belyaev 
888b618880eSAlexander Belyaev namespace {
889b618880eSAlexander Belyaev /// Reshape of a splat constant can be replaced with a constant of the result
890b618880eSAlexander Belyaev /// type.
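///
/// For illustration only (hypothetical values and shapes):
///
/// ```mlir
///   %cst = arith.constant dense<1.0> : tensor<2x3xf32>
///   %0 = tensor.collapse_shape %cst [[0, 1]]
///       : tensor<2x3xf32> into tensor<6xf32>
/// ```
/// can fold into:
/// ```mlir
///   %0 = arith.constant dense<1.0> : tensor<6xf32>
/// ```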
891b618880eSAlexander Belyaev template <typename TensorReshapeOp>
892b618880eSAlexander Belyaev struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
893b618880eSAlexander Belyaev   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
matchAndRewrite__anon3fb9f79f0611::FoldReshapeWithConstant894b618880eSAlexander Belyaev   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
895b618880eSAlexander Belyaev                                 PatternRewriter &rewriter) const override {
896b618880eSAlexander Belyaev     DenseElementsAttr attr;
8972d70eff8SJacques Pienaar     if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
898b618880eSAlexander Belyaev       return failure();
899b618880eSAlexander Belyaev     if (!attr || !attr.isSplat())
900b618880eSAlexander Belyaev       return failure();
901b618880eSAlexander Belyaev     DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
902f21896f2SChris Lattner         reshapeOp.getResultType(), attr.getRawData());
903b618880eSAlexander Belyaev     rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
904b618880eSAlexander Belyaev     return success();
905b618880eSAlexander Belyaev   }
906b618880eSAlexander Belyaev };
907b618880eSAlexander Belyaev 
908d81a3c51SRob Suderman /// Reshape of a FromElements can be replaced with a FromElements of the result
909d81a3c51SRob Suderman /// type.
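///
/// For illustration only (hypothetical element values and shapes):
///
/// ```mlir
///   %0 = tensor.from_elements %a, %b, %c, %d : tensor<4xf32>
///   %1 = tensor.expand_shape %0 [[0, 1]] : tensor<4xf32> into tensor<2x2xf32>
/// ```
/// can fold into:
/// ```mlir
///   %1 = tensor.from_elements %a, %b, %c, %d : tensor<2x2xf32>
/// ```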
910d81a3c51SRob Suderman template <typename TensorReshapeOp>
911d81a3c51SRob Suderman struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
912d81a3c51SRob Suderman   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
matchAndRewrite__anon3fb9f79f0611::FoldReshapeWithFromElements913d81a3c51SRob Suderman   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
914d81a3c51SRob Suderman                                 PatternRewriter &rewriter) const override {
915d81a3c51SRob Suderman     auto fromElements =
9162d70eff8SJacques Pienaar         reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
917d81a3c51SRob Suderman     if (!fromElements)
918d81a3c51SRob Suderman       return failure();
919d81a3c51SRob Suderman 
920d81a3c51SRob Suderman     auto shapedTy = reshapeOp.getType().template cast<ShapedType>();
921d81a3c51SRob Suderman 
922d81a3c51SRob Suderman     if (!shapedTy.hasStaticShape())
923d81a3c51SRob Suderman       return failure();
924d81a3c51SRob Suderman 
925d81a3c51SRob Suderman     rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
9262d70eff8SJacques Pienaar                                                 fromElements.getElements());
927d81a3c51SRob Suderman     return success();
928d81a3c51SRob Suderman   }
929d81a3c51SRob Suderman };
930d81a3c51SRob Suderman 
931b618880eSAlexander Belyaev } // namespace
932b618880eSAlexander Belyaev 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)933b618880eSAlexander Belyaev void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
934b618880eSAlexander Belyaev                                                 MLIRContext *context) {
935747b10beSAlexander Belyaev   results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
936747b10beSAlexander Belyaev               ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
937d81a3c51SRob Suderman               FoldReshapeWithConstant<ExpandShapeOp>,
938d81a3c51SRob Suderman               FoldReshapeWithFromElements<ExpandShapeOp>>(context);
939b618880eSAlexander Belyaev }
940b618880eSAlexander Belyaev 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)941b618880eSAlexander Belyaev void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
942b618880eSAlexander Belyaev                                                   MLIRContext *context) {
943747b10beSAlexander Belyaev   results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
944747b10beSAlexander Belyaev               ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
945d81a3c51SRob Suderman               FoldReshapeWithConstant<CollapseShapeOp>,
946d81a3c51SRob Suderman               FoldReshapeWithFromElements<CollapseShapeOp>>(context);
947b618880eSAlexander Belyaev }
948b618880eSAlexander Belyaev 
fold(ArrayRef<Attribute> operands)949b618880eSAlexander Belyaev OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
950b618880eSAlexander Belyaev   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
951b618880eSAlexander Belyaev }
fold(ArrayRef<Attribute> operands)952b618880eSAlexander Belyaev OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
953b618880eSAlexander Belyaev   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
954b618880eSAlexander Belyaev }
955b618880eSAlexander Belyaev 
956b618880eSAlexander Belyaev //===----------------------------------------------------------------------===//
957060208b4SMatthias Springer // ExtractSliceOp
958060208b4SMatthias Springer //===----------------------------------------------------------------------===//
959060208b4SMatthias Springer 
960741f8f2bSNicolas Vasilache /// An extract_slice result type can be inferred, when it is not
961741f8f2bSNicolas Vasilache /// rank-reduced, from the source type and the static representation of
962741f8f2bSNicolas Vasilache /// offsets, sizes and strides. Special sentinels encode the dynamic case.
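/// For example (illustrative): a source tensor<8x16xf32> with static sizes
/// [4, kDynamicSize] infers the result type tensor<4x?xf32>.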
inferResultType(ShapedType sourceShapedTensorType,ArrayRef<int64_t> staticOffsets,ArrayRef<int64_t> staticSizes,ArrayRef<int64_t> staticStrides)9637df7586aSMaheshRavishankar RankedTensorType ExtractSliceOp::inferResultType(
964741f8f2bSNicolas Vasilache     ShapedType sourceShapedTensorType, ArrayRef<int64_t> staticOffsets,
9657df7586aSMaheshRavishankar     ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
966060208b4SMatthias Springer   // An extract_slice op may specify only a leading subset of offsets/sizes/
967060208b4SMatthias Springer   // strides, in which case we complete them with offset=0, sizes from the
968060208b4SMatthias Springer   // source tensor type, and strides=1.
969741f8f2bSNicolas Vasilache   assert(static_cast<int64_t>(staticSizes.size()) ==
970741f8f2bSNicolas Vasilache              sourceShapedTensorType.getRank() &&
9717df7586aSMaheshRavishankar          "unexpected staticSizes not equal to rank of source");
972060208b4SMatthias Springer   return RankedTensorType::get(staticSizes,
973741f8f2bSNicolas Vasilache                                sourceShapedTensorType.getElementType());
974060208b4SMatthias Springer }
975060208b4SMatthias Springer 
inferResultType(ShapedType sourceShapedTensorType,ArrayRef<OpFoldResult> offsets,ArrayRef<OpFoldResult> sizes,ArrayRef<OpFoldResult> strides)9767df7586aSMaheshRavishankar RankedTensorType ExtractSliceOp::inferResultType(
977741f8f2bSNicolas Vasilache     ShapedType sourceShapedTensorType, ArrayRef<OpFoldResult> offsets,
9787df7586aSMaheshRavishankar     ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
979060208b4SMatthias Springer   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
980060208b4SMatthias Springer   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
9817df7586aSMaheshRavishankar   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
9827df7586aSMaheshRavishankar                              ShapedType::kDynamicStrideOrOffset);
9837df7586aSMaheshRavishankar   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
984060208b4SMatthias Springer                              ShapedType::kDynamicSize);
9857df7586aSMaheshRavishankar   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
9867df7586aSMaheshRavishankar                              ShapedType::kDynamicStrideOrOffset);
987741f8f2bSNicolas Vasilache   return ExtractSliceOp::inferResultType(sourceShapedTensorType, staticOffsets,
988060208b4SMatthias Springer                                          staticSizes, staticStrides);
989060208b4SMatthias Springer }
990060208b4SMatthias Springer 
991741f8f2bSNicolas Vasilache /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
992741f8f2bSNicolas Vasilache /// number of sizes), drop as many size-1 dimensions as needed to produce an
993741f8f2bSNicolas Vasilache /// inferred type with the desired rank.
994741f8f2bSNicolas Vasilache ///
995741f8f2bSNicolas Vasilache /// Note that there may be multiple ways to compute this rank-reduced type:
996741f8f2bSNicolas Vasilache ///   e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
997741f8f2bSNicolas Vasilache ///
998741f8f2bSNicolas Vasilache /// To disambiguate, this function always drops the earliest occurrences of size 1.
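/// For example (illustrative): if the non-rank-reduced inferred type is
/// tensor<1x6x1xf32> and desiredResultRank is 2, the leading unit dimension is
/// dropped and the result type is tensor<6x1xf32>.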
inferCanonicalRankReducedResultType(unsigned desiredResultRank,RankedTensorType sourceRankedTensorType,ArrayRef<int64_t> offsets,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides)999741f8f2bSNicolas Vasilache RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
1000741f8f2bSNicolas Vasilache     unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
10017df7586aSMaheshRavishankar     ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
10027df7586aSMaheshRavishankar     ArrayRef<int64_t> strides) {
1003741f8f2bSNicolas Vasilache   // Type inferred in the absence of rank-reducing behavior.
1004060208b4SMatthias Springer   auto inferredType =
10057df7586aSMaheshRavishankar       inferResultType(sourceRankedTensorType, offsets, sizes, strides)
1006060208b4SMatthias Springer           .cast<RankedTensorType>();
1007741f8f2bSNicolas Vasilache   int rankDiff = inferredType.getRank() - desiredResultRank;
1008060208b4SMatthias Springer   if (rankDiff > 0) {
1009060208b4SMatthias Springer     auto shape = inferredType.getShape();
10106635c12aSBenjamin Kramer     llvm::SmallBitVector dimsToProject =
10116635c12aSBenjamin Kramer         getPositionsOfShapeOne(rankDiff, shape);
1012060208b4SMatthias Springer     SmallVector<int64_t> projectedShape;
1013741f8f2bSNicolas Vasilache     // Best effort rank-reducing: drop 1s in order.
1014060208b4SMatthias Springer     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
10156635c12aSBenjamin Kramer       if (!dimsToProject.test(pos))
1016060208b4SMatthias Springer         projectedShape.push_back(shape[pos]);
1017060208b4SMatthias Springer     inferredType =
1018060208b4SMatthias Springer         RankedTensorType::get(projectedShape, inferredType.getElementType());
1019060208b4SMatthias Springer   }
1020060208b4SMatthias Springer   return inferredType;
1021060208b4SMatthias Springer }
1022060208b4SMatthias Springer 
inferCanonicalRankReducedResultType(unsigned desiredResultRank,RankedTensorType sourceRankedTensorType,ArrayRef<OpFoldResult> offsets,ArrayRef<OpFoldResult> sizes,ArrayRef<OpFoldResult> strides)1023741f8f2bSNicolas Vasilache RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
1024741f8f2bSNicolas Vasilache     unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
10257df7586aSMaheshRavishankar     ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
10267df7586aSMaheshRavishankar     ArrayRef<OpFoldResult> strides) {
1027060208b4SMatthias Springer   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1028060208b4SMatthias Springer   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
10297df7586aSMaheshRavishankar   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
10307df7586aSMaheshRavishankar                              ShapedType::kDynamicStrideOrOffset);
10317df7586aSMaheshRavishankar   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1032060208b4SMatthias Springer                              ShapedType::kDynamicSize);
10337df7586aSMaheshRavishankar   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
10347df7586aSMaheshRavishankar                              ShapedType::kDynamicStrideOrOffset);
1035741f8f2bSNicolas Vasilache   return ExtractSliceOp::inferCanonicalRankReducedResultType(
1036741f8f2bSNicolas Vasilache       desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1037060208b4SMatthias Springer       staticStrides);
1038060208b4SMatthias Springer }
1039060208b4SMatthias Springer 
1040060208b4SMatthias Springer /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
1041060208b4SMatthias Springer /// result type. If the type passed is nullptr, it is inferred.
build(OpBuilder & b,OperationState & result,RankedTensorType resultType,Value source,ArrayRef<OpFoldResult> offsets,ArrayRef<OpFoldResult> sizes,ArrayRef<OpFoldResult> strides,ArrayRef<NamedAttribute> attrs)1042060208b4SMatthias Springer void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
1043060208b4SMatthias Springer                            RankedTensorType resultType, Value source,
1044060208b4SMatthias Springer                            ArrayRef<OpFoldResult> offsets,
1045060208b4SMatthias Springer                            ArrayRef<OpFoldResult> sizes,
1046060208b4SMatthias Springer                            ArrayRef<OpFoldResult> strides,
1047060208b4SMatthias Springer                            ArrayRef<NamedAttribute> attrs) {
1048060208b4SMatthias Springer   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1049060208b4SMatthias Springer   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1050060208b4SMatthias Springer   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1051060208b4SMatthias Springer                              ShapedType::kDynamicStrideOrOffset);
1052060208b4SMatthias Springer   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1053060208b4SMatthias Springer                              ShapedType::kDynamicSize);
1054060208b4SMatthias Springer   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1055060208b4SMatthias Springer                              ShapedType::kDynamicStrideOrOffset);
1056060208b4SMatthias Springer   auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
1057060208b4SMatthias Springer   // Structuring implementation this way avoids duplication between builders.
1058060208b4SMatthias Springer   if (!resultType) {
1059060208b4SMatthias Springer     resultType =
1060060208b4SMatthias Springer         ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
1061060208b4SMatthias Springer                                         staticSizes, staticStrides)
1062060208b4SMatthias Springer             .cast<RankedTensorType>();
1063060208b4SMatthias Springer   }
1064060208b4SMatthias Springer   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1065060208b4SMatthias Springer         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1066060208b4SMatthias Springer         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1067060208b4SMatthias Springer   result.addAttributes(attrs);
1068060208b4SMatthias Springer }
1069060208b4SMatthias Springer 
1070060208b4SMatthias Springer /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
1071060208b4SMatthias Springer /// result type.
build(OpBuilder & b,OperationState & result,Value source,ArrayRef<OpFoldResult> offsets,ArrayRef<OpFoldResult> sizes,ArrayRef<OpFoldResult> strides,ArrayRef<NamedAttribute> attrs)1072060208b4SMatthias Springer void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1073060208b4SMatthias Springer                            ArrayRef<OpFoldResult> offsets,
1074060208b4SMatthias Springer                            ArrayRef<OpFoldResult> sizes,
1075060208b4SMatthias Springer                            ArrayRef<OpFoldResult> strides,
1076060208b4SMatthias Springer                            ArrayRef<NamedAttribute> attrs) {
1077060208b4SMatthias Springer   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
1078060208b4SMatthias Springer }
1079060208b4SMatthias Springer 
1080060208b4SMatthias Springer /// Build an ExtractSliceOp with dynamic entries and custom result type. If the
1081060208b4SMatthias Springer /// type passed is nullptr, it is inferred.
build(OpBuilder & b,OperationState & result,RankedTensorType resultType,Value source,ValueRange offsets,ValueRange sizes,ValueRange strides,ArrayRef<NamedAttribute> attrs)1082060208b4SMatthias Springer void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
1083060208b4SMatthias Springer                            RankedTensorType resultType, Value source,
1084060208b4SMatthias Springer                            ValueRange offsets, ValueRange sizes,
1085060208b4SMatthias Springer                            ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1086060208b4SMatthias Springer   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1087060208b4SMatthias Springer       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1088060208b4SMatthias Springer   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1089060208b4SMatthias Springer       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1090060208b4SMatthias Springer   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1091060208b4SMatthias Springer       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1092060208b4SMatthias Springer   build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
1093060208b4SMatthias Springer }
1094060208b4SMatthias Springer 
1095060208b4SMatthias Springer /// Build an ExtractSliceOp with dynamic entries and inferred result type.
build(OpBuilder & b,OperationState & result,Value source,ValueRange offsets,ValueRange sizes,ValueRange strides,ArrayRef<NamedAttribute> attrs)1096060208b4SMatthias Springer void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1097060208b4SMatthias Springer                            ValueRange offsets, ValueRange sizes,
1098060208b4SMatthias Springer                            ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1099060208b4SMatthias Springer   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
1100060208b4SMatthias Springer }
1101060208b4SMatthias Springer 
1102060208b4SMatthias Springer template <typename OpTy>
produceSliceErrorMsg(SliceVerificationResult result,OpTy op,Type expectedType)1103060208b4SMatthias Springer static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
1104a08b750cSNicolas Vasilache                                           OpTy op, Type expectedType) {
1105060208b4SMatthias Springer   auto shapedType = expectedType.cast<ShapedType>();
1106060208b4SMatthias Springer   switch (result) {
1107060208b4SMatthias Springer   case SliceVerificationResult::Success:
1108060208b4SMatthias Springer     return success();
1109060208b4SMatthias Springer   case SliceVerificationResult::RankTooLarge:
1110a08b750cSNicolas Vasilache     return op.emitError("expected rank to be smaller or equal to ")
1111a08b750cSNicolas Vasilache            << "the other rank. ";
1112060208b4SMatthias Springer   case SliceVerificationResult::SizeMismatch:
1113a08b750cSNicolas Vasilache     return op.emitError("expected type to be ")
1114a08b750cSNicolas Vasilache            << expectedType << " or a rank-reduced version. (size mismatch) ";
1115060208b4SMatthias Springer   case SliceVerificationResult::ElemTypeMismatch:
1116a08b750cSNicolas Vasilache     return op.emitError("expected element type to be ")
1117a08b750cSNicolas Vasilache            << shapedType.getElementType();
1118a08b750cSNicolas Vasilache   default:
1119060208b4SMatthias Springer     llvm_unreachable("unexpected extract_slice op verification result");
1120060208b4SMatthias Springer   }
1121a08b750cSNicolas Vasilache }
1122060208b4SMatthias Springer 
1123060208b4SMatthias Springer /// Verifier for ExtractSliceOp.
verify()1124b98dc035SRiver Riddle LogicalResult ExtractSliceOp::verify() {
1125060208b4SMatthias Springer   // Verify result type against inferred type.
1126*c9fb3c6eSNicolas Vasilache   RankedTensorType expectedType = ExtractSliceOp::inferResultType(
1127b98dc035SRiver Riddle       getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
1128*c9fb3c6eSNicolas Vasilache   SliceVerificationResult result = isRankReducedType(expectedType, getType());
1129b98dc035SRiver Riddle   return produceSliceErrorMsg(result, *this, expectedType);
1130060208b4SMatthias Springer }
1131060208b4SMatthias Springer 
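/// Return a bit vector with one bit per slice dimension; a bit is set when the
/// corresponding size is statically 1 and that dimension is dropped from the
/// (rank-reduced) result type.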
getDroppedDims()11326635c12aSBenjamin Kramer llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
113342819463SMaheshRavishankar   ArrayRef<int64_t> resultShape = getType().getShape();
113442819463SMaheshRavishankar   SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
11356635c12aSBenjamin Kramer   llvm::SmallBitVector droppedDims(mixedSizes.size());
113642819463SMaheshRavishankar   unsigned shapePos = 0;
1137e4853be2SMehdi Amini   for (const auto &size : enumerate(mixedSizes)) {
113842819463SMaheshRavishankar     Optional<int64_t> sizeVal = getConstantIntValue(size.value());
113942819463SMaheshRavishankar     // The dimension is preserved if its size is not statically 1, or if the
114042819463SMaheshRavishankar     // size is 1 but the currently matched result dimension is also statically
114142819463SMaheshRavishankar     // 1 (i.e. the unit dimension is kept in the result).
11426d5fc1e3SKazu Hirata     if (!sizeVal || *sizeVal != 1 ||
114342819463SMaheshRavishankar         (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
114442819463SMaheshRavishankar       shapePos++;
114542819463SMaheshRavishankar       continue;
114642819463SMaheshRavishankar     }
11476635c12aSBenjamin Kramer     droppedDims.set(size.index());
114842819463SMaheshRavishankar   }
114942819463SMaheshRavishankar   return droppedDims;
115042819463SMaheshRavishankar }
115142819463SMaheshRavishankar 
reifyResultShapes(OpBuilder & builder,ReifiedRankedShapedTypeDims & reifiedReturnShapes)115242819463SMaheshRavishankar LogicalResult ExtractSliceOp::reifyResultShapes(
115342819463SMaheshRavishankar     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
115442819463SMaheshRavishankar   reifiedReturnShapes.resize(1);
115542819463SMaheshRavishankar   reifiedReturnShapes[0].reserve(getType().getRank());
115642819463SMaheshRavishankar   SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
11576635c12aSBenjamin Kramer   llvm::SmallBitVector droppedDims = getDroppedDims();
115842819463SMaheshRavishankar   Location loc = getLoc();
1159e4853be2SMehdi Amini   for (const auto &size : enumerate(mixedSizes)) {
11606635c12aSBenjamin Kramer     if (droppedDims.test(size.index()))
116142819463SMaheshRavishankar       continue;
116242819463SMaheshRavishankar     if (auto attr = size.value().dyn_cast<Attribute>()) {
1163a54f4eaeSMogball       reifiedReturnShapes[0].push_back(builder.create<arith::ConstantIndexOp>(
116442819463SMaheshRavishankar           loc, attr.cast<IntegerAttr>().getInt()));
116542819463SMaheshRavishankar       continue;
116642819463SMaheshRavishankar     }
116742819463SMaheshRavishankar     reifiedReturnShapes[0].push_back(size.value().get<Value>());
116842819463SMaheshRavishankar   }
116942819463SMaheshRavishankar   return success();
117042819463SMaheshRavishankar }
117142819463SMaheshRavishankar 
1172060208b4SMatthias Springer namespace {
1173060208b4SMatthias Springer /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
1174060208b4SMatthias Springer /// This essentially pushes the tensor.cast past its consuming slice when
1175060208b4SMatthias Springer /// `canFoldIntoConsumerOp` is true.
1176060208b4SMatthias Springer ///
1177060208b4SMatthias Springer /// Example:
1178060208b4SMatthias Springer /// ```
1179060208b4SMatthias Springer ///   %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
1180060208b4SMatthias Springer ///   %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
1181060208b4SMatthias Springer ///   tensor<3x4xf32>
1182060208b4SMatthias Springer /// ```
1183060208b4SMatthias Springer /// is rewritten into:
1184060208b4SMatthias Springer /// ```
1185060208b4SMatthias Springer ///   %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
1186060208b4SMatthias Springer ///   tensor<3x4xf32> %1 = tensor.cast %0: tensor<3x4xf32> to tensor<3x4xf32>
1187060208b4SMatthias Springer /// ```
1188060208b4SMatthias Springer class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
1189060208b4SMatthias Springer public:
1190060208b4SMatthias Springer   using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
1191060208b4SMatthias Springer 
matchAndRewrite(ExtractSliceOp sliceOp,PatternRewriter & rewriter) const1192060208b4SMatthias Springer   LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
1193060208b4SMatthias Springer                                 PatternRewriter &rewriter) const override {
1194741f8f2bSNicolas Vasilache     // If any operand is a constant, just return and let the constant folder kick in.
1195060208b4SMatthias Springer     if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
1196060208b4SMatthias Springer           return matchPattern(operand, matchConstantIndex());
1197060208b4SMatthias Springer         }))
1198060208b4SMatthias Springer       return failure();
1199060208b4SMatthias Springer 
12002d70eff8SJacques Pienaar     auto castOp = sliceOp.getSource().getDefiningOp<tensor::CastOp>();
1201060208b4SMatthias Springer     if (!castOp)
1202060208b4SMatthias Springer       return failure();
1203060208b4SMatthias Springer 
1204060208b4SMatthias Springer     if (!canFoldIntoConsumerOp(castOp))
1205060208b4SMatthias Springer       return failure();
1206060208b4SMatthias Springer 
1207060208b4SMatthias Springer     /// Deduce the type of the result to use for the canonicalized operation.
1208741f8f2bSNicolas Vasilache     RankedTensorType resultType =
1209741f8f2bSNicolas Vasilache         ExtractSliceOp::inferCanonicalRankReducedResultType(
1210060208b4SMatthias Springer             sliceOp.getType().getRank(), sliceOp.getSourceType(),
1211060208b4SMatthias Springer             sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
1212060208b4SMatthias Springer             sliceOp.getMixedStrides());
1213060208b4SMatthias Springer     Value newSlice = rewriter.create<ExtractSliceOp>(
12142d70eff8SJacques Pienaar         sliceOp.getLoc(), resultType, castOp.getSource(), sliceOp.getOffsets(),
12152d70eff8SJacques Pienaar         sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
12162d70eff8SJacques Pienaar         sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
1217060208b4SMatthias Springer     rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
1218060208b4SMatthias Springer                                                 newSlice);
1219060208b4SMatthias Springer     return success();
1220060208b4SMatthias Springer   }
1221060208b4SMatthias Springer };
12224c901bf4SOkwan Kwon 
12234c901bf4SOkwan Kwon /// Slice elements from `values` into `outValues`. `counts` holds the number of
12244c901bf4SOkwan Kwon /// elements to skip in `values` per step along each dimension (its linearized stride).
12254c901bf4SOkwan Kwon /// The output values can be used to construct a DenseElementsAttr.
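/// For example (illustrative): for a source of shape 3x4x5, `counts` is
/// [20, 5, 1].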
12264c901bf4SOkwan Kwon template <typename IterTy, typename ElemTy>
sliceElements(IterTy values,ArrayRef<int64_t> counts,ArrayRef<int64_t> offsets,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides,llvm::SmallVectorImpl<ElemTy> * outValues)12274c901bf4SOkwan Kwon static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
12284c901bf4SOkwan Kwon                           ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
12294c901bf4SOkwan Kwon                           ArrayRef<int64_t> strides,
12304c901bf4SOkwan Kwon                           llvm::SmallVectorImpl<ElemTy> *outValues) {
12314c901bf4SOkwan Kwon   assert(offsets.size() == sizes.size());
12324c901bf4SOkwan Kwon   assert(offsets.size() == strides.size());
12334c901bf4SOkwan Kwon   if (offsets.empty())
12344c901bf4SOkwan Kwon     return;
12354c901bf4SOkwan Kwon 
12364c901bf4SOkwan Kwon   int64_t offset = offsets.front();
12374c901bf4SOkwan Kwon   int64_t size = sizes.front();
12384c901bf4SOkwan Kwon   int64_t stride = strides.front();
12394c901bf4SOkwan Kwon   if (offsets.size() == 1) {
12404c901bf4SOkwan Kwon     for (int64_t i = 0; i < size; ++i, offset += stride)
12414c901bf4SOkwan Kwon       outValues->push_back(*(values + offset));
12424c901bf4SOkwan Kwon 
12434c901bf4SOkwan Kwon     return;
12444c901bf4SOkwan Kwon   }
12454c901bf4SOkwan Kwon 
12464c901bf4SOkwan Kwon   for (int64_t i = 0; i < size; ++i, offset += stride) {
12474c901bf4SOkwan Kwon     auto begin = values + offset * counts.front();
12484c901bf4SOkwan Kwon     sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
12494c901bf4SOkwan Kwon                                   offsets.drop_front(), sizes.drop_front(),
12504c901bf4SOkwan Kwon                                   strides.drop_front(), outValues);
12514c901bf4SOkwan Kwon   }
12524c901bf4SOkwan Kwon }
12534c901bf4SOkwan Kwon 
12544c901bf4SOkwan Kwon /// Fold a tensor.extract_slice of an arith.constant into a new arith.constant.
12554c901bf4SOkwan Kwon /// The folded operation may introduce additional constant data; users can
12564c901bf4SOkwan Kwon /// control this tradeoff through the control function.
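///
/// For illustration only (hypothetical constant values and slice parameters):
///
/// ```mlir
///   %cst = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
///   %0 = tensor.extract_slice %cst[0, 1] [2, 1] [1, 1]
///       : tensor<2x2xi32> to tensor<2x1xi32>
/// ```
/// may fold into:
/// ```mlir
///   %0 = arith.constant dense<[[1], [3]]> : tensor<2x1xi32>
/// ```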
12574c901bf4SOkwan Kwon class ConstantOpExtractSliceFolder final
12584c901bf4SOkwan Kwon     : public OpRewritePattern<ExtractSliceOp> {
12594c901bf4SOkwan Kwon public:
12604c901bf4SOkwan Kwon   using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
12614c901bf4SOkwan Kwon 
ConstantOpExtractSliceFolder(MLIRContext * context,ControlConstantExtractSliceFusionFn controlFn)12624c901bf4SOkwan Kwon   ConstantOpExtractSliceFolder(MLIRContext *context,
12634c901bf4SOkwan Kwon                                ControlConstantExtractSliceFusionFn controlFn)
12644c901bf4SOkwan Kwon       : OpRewritePattern<ExtractSliceOp>(context),
12654c901bf4SOkwan Kwon         controlFn(std::move(controlFn)) {}
12664c901bf4SOkwan Kwon 
matchAndRewrite(ExtractSliceOp op,PatternRewriter & rewriter) const12674c901bf4SOkwan Kwon   LogicalResult matchAndRewrite(ExtractSliceOp op,
12684c901bf4SOkwan Kwon                                 PatternRewriter &rewriter) const override {
12694c901bf4SOkwan Kwon     DenseElementsAttr attr;
12702d70eff8SJacques Pienaar     if (!matchPattern(op.getSource(), m_Constant(&attr)))
12714c901bf4SOkwan Kwon       return failure();
12724c901bf4SOkwan Kwon 
12734c901bf4SOkwan Kwon     // A constant splat is handled by fold().
12744c901bf4SOkwan Kwon     if (attr.isSplat())
12754c901bf4SOkwan Kwon       return failure();
12764c901bf4SOkwan Kwon 
12774c901bf4SOkwan Kwon     // Dynamic result shape is not supported.
12782d70eff8SJacques Pienaar     auto sourceType = op.getSource().getType().cast<ShapedType>();
12792d70eff8SJacques Pienaar     auto resultType = op.getResult().getType().cast<ShapedType>();
12804c901bf4SOkwan Kwon     if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
12814c901bf4SOkwan Kwon       return failure();
12824c901bf4SOkwan Kwon 
12834c901bf4SOkwan Kwon     // Customized control over the folding.
12844c901bf4SOkwan Kwon     if (!controlFn(op))
12854c901bf4SOkwan Kwon       return failure();
12864c901bf4SOkwan Kwon 
12874c901bf4SOkwan Kwon     int64_t count = sourceType.getNumElements();
12884c901bf4SOkwan Kwon     if (count == 0)
12894c901bf4SOkwan Kwon       return failure();
12904c901bf4SOkwan Kwon 
12914c901bf4SOkwan Kwon     // Check if there are any dynamic parts, which are not supported.
12922d70eff8SJacques Pienaar     auto offsets = extractFromI64ArrayAttr(op.getStaticOffsets());
12934c901bf4SOkwan Kwon     if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
12944c901bf4SOkwan Kwon       return failure();
12952d70eff8SJacques Pienaar     auto sizes = extractFromI64ArrayAttr(op.getStaticSizes());
12964c901bf4SOkwan Kwon     if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
12974c901bf4SOkwan Kwon       return failure();
12982d70eff8SJacques Pienaar     auto strides = extractFromI64ArrayAttr(op.getStaticStrides());
12994c901bf4SOkwan Kwon     if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
13004c901bf4SOkwan Kwon       return failure();
13014c901bf4SOkwan Kwon 
13024c901bf4SOkwan Kwon     // Compute the stride for each dimension.
13034c901bf4SOkwan Kwon     SmallVector<int64_t> counts;
13044c901bf4SOkwan Kwon     ArrayRef<int64_t> shape = sourceType.getShape();
13054c901bf4SOkwan Kwon     counts.reserve(shape.size());
13064c901bf4SOkwan Kwon     for (int64_t v : shape) {
13074c901bf4SOkwan Kwon       count = count / v;
13084c901bf4SOkwan Kwon       counts.push_back(count);
13094c901bf4SOkwan Kwon     }
13104c901bf4SOkwan Kwon 
13114c901bf4SOkwan Kwon     // New attribute constructed by the sliced values.
13124c901bf4SOkwan Kwon     DenseElementsAttr newAttr;
13134c901bf4SOkwan Kwon 
13144c901bf4SOkwan Kwon     if (auto elems = attr.dyn_cast<DenseIntElementsAttr>()) {
13154c901bf4SOkwan Kwon       SmallVector<APInt> outValues;
13164c901bf4SOkwan Kwon       outValues.reserve(sourceType.getNumElements());
13174c901bf4SOkwan Kwon       sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
13184c901bf4SOkwan Kwon           elems.begin(), counts, offsets, sizes, strides, &outValues);
13194c901bf4SOkwan Kwon       newAttr = DenseElementsAttr::get(resultType, outValues);
13204c901bf4SOkwan Kwon     } else if (auto elems = attr.dyn_cast<DenseFPElementsAttr>()) {
13214c901bf4SOkwan Kwon       SmallVector<APFloat> outValues;
13224c901bf4SOkwan Kwon       outValues.reserve(sourceType.getNumElements());
13234c901bf4SOkwan Kwon       sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
13244c901bf4SOkwan Kwon           elems.begin(), counts, offsets, sizes, strides, &outValues);
13254c901bf4SOkwan Kwon       newAttr = DenseElementsAttr::get(resultType, outValues);
13264c901bf4SOkwan Kwon     }
13274c901bf4SOkwan Kwon 
13284c901bf4SOkwan Kwon     if (newAttr) {
13294c901bf4SOkwan Kwon       rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
13304c901bf4SOkwan Kwon       return success();
13314c901bf4SOkwan Kwon     }
13324c901bf4SOkwan Kwon 
13334c901bf4SOkwan Kwon     return failure();
13344c901bf4SOkwan Kwon   }
13354c901bf4SOkwan Kwon 
13364c901bf4SOkwan Kwon private:
13374c901bf4SOkwan Kwon   /// This additionally controls whether the fold happens or not. Users can
13384c901bf4SOkwan Kwon   /// impose their heuristics in the function.
13394c901bf4SOkwan Kwon   ControlConstantExtractSliceFusionFn controlFn;
13404c901bf4SOkwan Kwon };
13414c901bf4SOkwan Kwon 
1342060208b4SMatthias Springer } // namespace
1343060208b4SMatthias Springer 
populateFoldConstantExtractSlicePatterns(RewritePatternSet & patterns,const ControlConstantExtractSliceFusionFn & controlFn)13444c901bf4SOkwan Kwon void mlir::tensor::populateFoldConstantExtractSlicePatterns(
13454c901bf4SOkwan Kwon     RewritePatternSet &patterns,
13464c901bf4SOkwan Kwon     const ControlConstantExtractSliceFusionFn &controlFn) {
13474c901bf4SOkwan Kwon   patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
13484c901bf4SOkwan Kwon }
13494c901bf4SOkwan Kwon 
1350060208b4SMatthias Springer /// Return the canonical type of the result of an extract_slice op.
1351060208b4SMatthias Springer struct SliceReturnTypeCanonicalizer {
operator ()SliceReturnTypeCanonicalizer1352060208b4SMatthias Springer   RankedTensorType operator()(ExtractSliceOp op,
1353060208b4SMatthias Springer                               ArrayRef<OpFoldResult> mixedOffsets,
1354060208b4SMatthias Springer                               ArrayRef<OpFoldResult> mixedSizes,
1355060208b4SMatthias Springer                               ArrayRef<OpFoldResult> mixedStrides) {
1356741f8f2bSNicolas Vasilache     return ExtractSliceOp::inferCanonicalRankReducedResultType(
1357741f8f2bSNicolas Vasilache         op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
1358741f8f2bSNicolas Vasilache         mixedStrides);
1359060208b4SMatthias Springer   }
1360060208b4SMatthias Springer };
1361060208b4SMatthias Springer 
1362060208b4SMatthias Springer /// A canonicalizer wrapper to replace ExtractSliceOps.
1363060208b4SMatthias Springer struct SliceCanonicalizer {
operator ()SliceCanonicalizer1364060208b4SMatthias Springer   void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
1365060208b4SMatthias Springer                   ExtractSliceOp newOp) {
1366060208b4SMatthias Springer     Value replacement = newOp.getResult();
1367060208b4SMatthias Springer     if (replacement.getType() != op.getType())
1368060208b4SMatthias Springer       replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
1369060208b4SMatthias Springer                                                     replacement);
1370060208b4SMatthias Springer     rewriter.replaceOp(op, replacement);
1371060208b4SMatthias Springer   }
1372060208b4SMatthias Springer };
1373060208b4SMatthias Springer 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1374060208b4SMatthias Springer void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1375060208b4SMatthias Springer                                                  MLIRContext *context) {
1376060208b4SMatthias Springer   results.add<
1377060208b4SMatthias Springer       OpWithOffsetSizesAndStridesConstantArgumentFolder<
1378060208b4SMatthias Springer           ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
1379060208b4SMatthias Springer       ExtractSliceOpCastFolder>(context);
1380060208b4SMatthias Springer }
1381060208b4SMatthias Springer 
1382060208b4SMatthias Springer /// Return success if `op` is a no-op slice of `shapedType`, i.e. all offsets are 0, all strides are 1, and the sizes match the type's shape.
1383060208b4SMatthias Springer static LogicalResult
foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,ShapedType shapedType)1384060208b4SMatthias Springer foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
1385060208b4SMatthias Springer                                            ShapedType shapedType) {
1386060208b4SMatthias Springer   OpBuilder b(op.getContext());
1387060208b4SMatthias Springer   for (OpFoldResult ofr : op.getMixedOffsets())
13880813700dSMatthias Springer     if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
1389060208b4SMatthias Springer       return failure();
1390060208b4SMatthias Springer   // Rank-reducing noops only need to inspect the leading dimensions: llvm::zip
1391060208b4SMatthias Springer   // is appropriate.
1392060208b4SMatthias Springer   auto shape = shapedType.getShape();
1393060208b4SMatthias Springer   for (auto it : llvm::zip(op.getMixedSizes(), shape))
13940813700dSMatthias Springer     if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
1395060208b4SMatthias Springer       return failure();
1396060208b4SMatthias Springer   for (OpFoldResult ofr : op.getMixedStrides())
13970813700dSMatthias Springer     if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
1398060208b4SMatthias Springer       return failure();
1399060208b4SMatthias Springer   return success();
1400060208b4SMatthias Springer }
1401060208b4SMatthias Springer 
1402bdd37c9fSLei Zhang /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice,
1403bdd37c9fSLei Zhang /// we can return the InsertSliceOp's source directly.
1404bdd37c9fSLei Zhang // TODO: This only checks the immediate producer; extend to go up the
1405bdd37c9fSLei Zhang // insert/extract chain if the slices are disjoint.
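///
/// For illustration only (hypothetical slice parameters):
///
/// ```mlir
///   %0 = tensor.insert_slice %slice into %input[0, 0] [4, 4] [1, 1]
///       : tensor<4x4xf32> into tensor<8x8xf32>
///   %1 = tensor.extract_slice %0[0, 0] [4, 4] [1, 1]
///       : tensor<8x8xf32> to tensor<4x4xf32>
/// ```
/// folds %1 to %slice directly.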
foldExtractAfterInsertSlice(ExtractSliceOp extractOp)1406bdd37c9fSLei Zhang static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
14072d70eff8SJacques Pienaar   auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
1408bdd37c9fSLei Zhang 
1409bdd37c9fSLei Zhang   auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
14102d70eff8SJacques Pienaar   if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
1411bdd37c9fSLei Zhang       insertOp.isSameAs(extractOp, isSame))
14122d70eff8SJacques Pienaar     return insertOp.getSource();
1413bdd37c9fSLei Zhang 
1414bdd37c9fSLei Zhang   return {};
1415bdd37c9fSLei Zhang }
1416bdd37c9fSLei Zhang 
fold(ArrayRef<Attribute> operands)1417f79f430dSOkwan Kwon OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
1418f79f430dSOkwan Kwon   if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
14192d70eff8SJacques Pienaar     auto resultType = getResult().getType().cast<ShapedType>();
1420f79f430dSOkwan Kwon     if (resultType.hasStaticShape())
1421f79f430dSOkwan Kwon       return splat.resizeSplat(resultType);
1422f79f430dSOkwan Kwon   }
1423060208b4SMatthias Springer   if (getSourceType() == getType() &&
1424060208b4SMatthias Springer       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
14252d70eff8SJacques Pienaar     return this->getSource();
1426bdd37c9fSLei Zhang   if (Value slice = foldExtractAfterInsertSlice(*this))
1427bdd37c9fSLei Zhang     return slice;
14284c901bf4SOkwan Kwon 
1429060208b4SMatthias Springer   return OpFoldResult();
1430060208b4SMatthias Springer }
1431060208b4SMatthias Springer 
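/// Create an extract_slice that reads the whole `tensor` (zero offsets, full
/// sizes, unit strides) and rank-reduces it to the given `targetType`.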
createCanonicalRankReducingExtractSliceOp(OpBuilder & b,Location loc,Value tensor,RankedTensorType targetType)1432aa373180SNicolas Vasilache Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
1433aa373180SNicolas Vasilache     OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
1434aa373180SNicolas Vasilache   auto rankedTensorType = tensor.getType().cast<RankedTensorType>();
1435aa373180SNicolas Vasilache   unsigned rank = rankedTensorType.getRank();
1436aa373180SNicolas Vasilache   auto shape = rankedTensorType.getShape();
1437aa373180SNicolas Vasilache   SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
1438aa373180SNicolas Vasilache   SmallVector<OpFoldResult> sizes;
1439aa373180SNicolas Vasilache   for (unsigned i = 0, e = rank; i < e; ++i) {
1440aa373180SNicolas Vasilache     OpFoldResult dim;
1441aa373180SNicolas Vasilache     if (rankedTensorType.isDynamicDim(i))
1442aa373180SNicolas Vasilache       dim = b.createOrFold<tensor::DimOp>(
1443aa373180SNicolas Vasilache           loc, tensor, b.create<arith::ConstantIndexOp>(loc, i));
1444aa373180SNicolas Vasilache     else
1445aa373180SNicolas Vasilache       dim = b.getIndexAttr(shape[i]);
1446aa373180SNicolas Vasilache     sizes.push_back(dim);
1447aa373180SNicolas Vasilache   }
1448aa373180SNicolas Vasilache   SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
1449aa373180SNicolas Vasilache   return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
1450aa373180SNicolas Vasilache                                                 offsets, sizes, strides);
1451aa373180SNicolas Vasilache }
1452aa373180SNicolas Vasilache 
1453060208b4SMatthias Springer //===----------------------------------------------------------------------===//
1454060208b4SMatthias Springer // InsertSliceOp
1455060208b4SMatthias Springer //===----------------------------------------------------------------------===//
1456060208b4SMatthias Springer 
1457060208b4SMatthias Springer // Build an InsertSliceOp with mixed static and dynamic entries.
build(OpBuilder & b,OperationState & result,Value source,Value dest,ArrayRef<OpFoldResult> offsets,ArrayRef<OpFoldResult> sizes,ArrayRef<OpFoldResult> strides,ArrayRef<NamedAttribute> attrs)1458060208b4SMatthias Springer void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1459060208b4SMatthias Springer                           Value dest, ArrayRef<OpFoldResult> offsets,
1460060208b4SMatthias Springer                           ArrayRef<OpFoldResult> sizes,
1461060208b4SMatthias Springer                           ArrayRef<OpFoldResult> strides,
1462060208b4SMatthias Springer                           ArrayRef<NamedAttribute> attrs) {
1463060208b4SMatthias Springer   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1464060208b4SMatthias Springer   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1465060208b4SMatthias Springer   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1466060208b4SMatthias Springer                              ShapedType::kDynamicStrideOrOffset);
1467060208b4SMatthias Springer   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1468060208b4SMatthias Springer                              ShapedType::kDynamicSize);
1469060208b4SMatthias Springer   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1470060208b4SMatthias Springer                              ShapedType::kDynamicStrideOrOffset);
1471060208b4SMatthias Springer   build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
1472060208b4SMatthias Springer         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1473060208b4SMatthias Springer         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1474060208b4SMatthias Springer   result.addAttributes(attrs);
1475060208b4SMatthias Springer }
1476060208b4SMatthias Springer 
1477060208b4SMatthias Springer // Build an InsertSliceOp with dynamic entries.
build(OpBuilder & b,OperationState & result,Value source,Value dest,ValueRange offsets,ValueRange sizes,ValueRange strides,ArrayRef<NamedAttribute> attrs)1478060208b4SMatthias Springer void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1479060208b4SMatthias Springer                           Value dest, ValueRange offsets, ValueRange sizes,
1480060208b4SMatthias Springer                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1481060208b4SMatthias Springer   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1482060208b4SMatthias Springer       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1483060208b4SMatthias Springer   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1484060208b4SMatthias Springer       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1485060208b4SMatthias Springer   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1486060208b4SMatthias Springer       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1487060208b4SMatthias Springer   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
1488060208b4SMatthias Springer }
1489060208b4SMatthias Springer 
1490*c9fb3c6eSNicolas Vasilache /// Rank-reducing type verification for both InsertSliceOp and
1491*c9fb3c6eSNicolas Vasilache /// ParallelInsertSliceOp.
1492cd0d095cSIvan Butygin static SliceVerificationResult
verifyInsertSliceOp(ShapedType srcType,ShapedType dstType,ArrayAttr staticOffsets,ArrayAttr staticSizes,ArrayAttr staticStrides,ShapedType * expectedType=nullptr)1493cd0d095cSIvan Butygin verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
1494cd0d095cSIvan Butygin                     ArrayAttr staticOffsets, ArrayAttr staticSizes,
1495cd0d095cSIvan Butygin                     ArrayAttr staticStrides,
1496cd0d095cSIvan Butygin                     ShapedType *expectedType = nullptr) {
1497cd0d095cSIvan Butygin   // insert_slice is the inverse of extract_slice, use the same type inference.
1498*c9fb3c6eSNicolas Vasilache   RankedTensorType expected = ExtractSliceOp::inferResultType(
1499741f8f2bSNicolas Vasilache       dstType, extractFromI64ArrayAttr(staticOffsets),
1500cd0d095cSIvan Butygin       extractFromI64ArrayAttr(staticSizes),
1501*c9fb3c6eSNicolas Vasilache       extractFromI64ArrayAttr(staticStrides));
1502cd0d095cSIvan Butygin   if (expectedType)
1503cd0d095cSIvan Butygin     *expectedType = expected;
1504cd0d095cSIvan Butygin   return isRankReducedType(expected, srcType);
1505cd0d095cSIvan Butygin }
1506cd0d095cSIvan Butygin 
1507a08b750cSNicolas Vasilache /// Verifier for InsertSliceOp.
verify()1508b98dc035SRiver Riddle LogicalResult InsertSliceOp::verify() {
1509cd0d095cSIvan Butygin   ShapedType expectedType;
1510*c9fb3c6eSNicolas Vasilache   SliceVerificationResult result =
15112d70eff8SJacques Pienaar       verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
15122d70eff8SJacques Pienaar                           getStaticSizes(), getStaticStrides(), &expectedType);
1513b98dc035SRiver Riddle   return produceSliceErrorMsg(result, *this, expectedType);
1514a08b750cSNicolas Vasilache }
1515a08b750cSNicolas Vasilache 
1516bdd37c9fSLei Zhang /// If we have two consecutive InsertSliceOps writing to the same slice, we
1517bdd37c9fSLei Zhang /// can replace the second InsertSliceOp's destination with the first one's.
1518*c9fb3c6eSNicolas Vasilache /// This works similarly when the second op is a ParallelInsertSliceOp.
1519bdd37c9fSLei Zhang ///
1520bdd37c9fSLei Zhang /// Example:
1521bdd37c9fSLei Zhang ///
1522bdd37c9fSLei Zhang /// ```mlir
1523bdd37c9fSLei Zhang ///   %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
1524bdd37c9fSLei Zhang ///   %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
1525bdd37c9fSLei Zhang /// ```
1526bdd37c9fSLei Zhang ///
1527bdd37c9fSLei Zhang /// folds into:
1528bdd37c9fSLei Zhang ///
1529bdd37c9fSLei Zhang /// ```mlir
1530bdd37c9fSLei Zhang ///   %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
1531bdd37c9fSLei Zhang /// ```
1532*c9fb3c6eSNicolas Vasilache ///
1533*c9fb3c6eSNicolas Vasilache /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
1534*c9fb3c6eSNicolas Vasilache template <typename InsertOpTy>
foldInsertAfterInsertSlice(InsertOpTy insertOp)1535*c9fb3c6eSNicolas Vasilache static LogicalResult foldInsertAfterInsertSlice(InsertOpTy insertOp) {
1536*c9fb3c6eSNicolas Vasilache   auto prevInsertOp = insertOp.getDest().template getDefiningOp<InsertOpTy>();
1537bdd37c9fSLei Zhang 
1538bdd37c9fSLei Zhang   auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
1539bdd37c9fSLei Zhang   if (!prevInsertOp ||
15402d70eff8SJacques Pienaar       prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
1541bdd37c9fSLei Zhang       !prevInsertOp.isSameAs(insertOp, isSame))
1542bdd37c9fSLei Zhang     return failure();
1543bdd37c9fSLei Zhang 
15442d70eff8SJacques Pienaar   insertOp.getDestMutable().assign(prevInsertOp.getDest());
1545bdd37c9fSLei Zhang   return success();
1546bdd37c9fSLei Zhang }
1547bdd37c9fSLei Zhang 
1548*c9fb3c6eSNicolas Vasilache /// Shared folding logic for InsertSliceOp and ParallelInsertSliceOp. The fold
1549*c9fb3c6eSNicolas Vasilache /// return type differs between the two, so the result is wrapped in a FailureOr.
1550*c9fb3c6eSNicolas Vasilache ///
1551*c9fb3c6eSNicolas Vasilache /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
1552*c9fb3c6eSNicolas Vasilache template <typename InsertOpTy>
foldInsertOp(InsertOpTy insertOp,ArrayRef<Attribute>)1553*c9fb3c6eSNicolas Vasilache FailureOr<OpFoldResult> foldInsertOp(InsertOpTy insertOp, ArrayRef<Attribute>) {
1554*c9fb3c6eSNicolas Vasilache   if (insertOp.getSourceType().hasStaticShape() &&
1555*c9fb3c6eSNicolas Vasilache       insertOp.getDestType().hasStaticShape() &&
1556*c9fb3c6eSNicolas Vasilache       insertOp.getSourceType() == insertOp.getDestType() &&
1557*c9fb3c6eSNicolas Vasilache       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(
1558*c9fb3c6eSNicolas Vasilache           insertOp, insertOp.getDestType())))
1559*c9fb3c6eSNicolas Vasilache     return static_cast<OpFoldResult>(insertOp.getSource());
1560*c9fb3c6eSNicolas Vasilache   if (succeeded(foldInsertAfterInsertSlice(insertOp))) {
1561*c9fb3c6eSNicolas Vasilache     // InsertSliceOp has 1 result but ParallelInsertSliceOp has none and should
1562*c9fb3c6eSNicolas Vasilache     // return OpFoldResult().
1563*c9fb3c6eSNicolas Vasilache     if (std::is_same<InsertOpTy, InsertSliceOp>::value)
1564*c9fb3c6eSNicolas Vasilache       return static_cast<OpFoldResult>(insertOp->getResult(0));
1565*c9fb3c6eSNicolas Vasilache     else
1566060208b4SMatthias Springer       return OpFoldResult();
1567060208b4SMatthias Springer   }
1568*c9fb3c6eSNicolas Vasilache   return failure();
1569*c9fb3c6eSNicolas Vasilache }
1570*c9fb3c6eSNicolas Vasilache 
fold(ArrayRef<Attribute> operands)1571*c9fb3c6eSNicolas Vasilache OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute> operands) {
1572*c9fb3c6eSNicolas Vasilache   auto maybeOpFoldResult = foldInsertOp(*this, operands);
1573*c9fb3c6eSNicolas Vasilache   return failed(maybeOpFoldResult) ? OpFoldResult() : *maybeOpFoldResult;
1574*c9fb3c6eSNicolas Vasilache }
1575060208b4SMatthias Springer 
reifyResultShapes(OpBuilder & builder,ReifiedRankedShapedTypeDims & reifiedReturnShapes)15769afc0657SMaheshRavishankar LogicalResult InsertSliceOp::reifyResultShapes(
15779afc0657SMaheshRavishankar     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1578f2b5e438SMaheshRavishankar   reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
1579f2b5e438SMaheshRavishankar   for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1580f2b5e438SMaheshRavishankar     reifiedReturnShapes[0][dim] =
15812d70eff8SJacques Pienaar         builder.createOrFold<tensor::DimOp>(getLoc(), getDest(), dim);
1582f2b5e438SMaheshRavishankar   }
1583f2b5e438SMaheshRavishankar   return success();
1584f2b5e438SMaheshRavishankar }
1585f2b5e438SMaheshRavishankar 
1586060208b4SMatthias Springer namespace {
1587060208b4SMatthias Springer /// Pattern to rewrite an insert_slice op with constant arguments.
1588*c9fb3c6eSNicolas Vasilache ///
1589*c9fb3c6eSNicolas Vasilache /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
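///
/// For illustration only (hypothetical operands and slice parameters):
///
/// ```mlir
///   %c0 = arith.constant 0 : index
///   %0 = tensor.insert_slice %src into %dst[%c0, %c0] [4, 4] [1, 1]
///       : tensor<4x4xf32> into tensor<8x8xf32>
/// ```
/// canonicalizes into:
/// ```mlir
///   %0 = tensor.insert_slice %src into %dst[0, 0] [4, 4] [1, 1]
///       : tensor<4x4xf32> into tensor<8x8xf32>
/// ```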
1590*c9fb3c6eSNicolas Vasilache template <typename InsertOpTy>
1591060208b4SMatthias Springer class InsertSliceOpConstantArgumentFolder final
1592*c9fb3c6eSNicolas Vasilache     : public OpRewritePattern<InsertOpTy> {
1593060208b4SMatthias Springer public:
1594*c9fb3c6eSNicolas Vasilache   using OpRewritePattern<InsertOpTy>::OpRewritePattern;
1595060208b4SMatthias Springer 
matchAndRewrite(InsertOpTy insertSliceOp,PatternRewriter & rewriter) const1596*c9fb3c6eSNicolas Vasilache   LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
1597060208b4SMatthias Springer                                 PatternRewriter &rewriter) const override {
1598060208b4SMatthias Springer     // No constant operand, just return.
1599060208b4SMatthias Springer     if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
1600060208b4SMatthias Springer           return matchPattern(operand, matchConstantIndex());
1601060208b4SMatthias Springer         }))
1602060208b4SMatthias Springer       return failure();
1603060208b4SMatthias Springer 
1604060208b4SMatthias Springer     // At least one of offsets/sizes/strides is a new constant.
1605060208b4SMatthias Springer     // Form the new list of operands and constant attributes from the
1606060208b4SMatthias Springer     // existing.
1607060208b4SMatthias Springer     SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
1608060208b4SMatthias Springer     SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
1609060208b4SMatthias Springer     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
1610060208b4SMatthias Springer     canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
1611060208b4SMatthias Springer     canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
1612060208b4SMatthias Springer     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
1613060208b4SMatthias Springer 
1614060208b4SMatthias Springer     // Create the new op in canonical form.
1615741f8f2bSNicolas Vasilache     auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
1616*c9fb3c6eSNicolas Vasilache         insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
1617060208b4SMatthias Springer         mixedOffsets, mixedSizes, mixedStrides);
16182d70eff8SJacques Pienaar     Value toInsert = insertSliceOp.getSource();
1619*c9fb3c6eSNicolas Vasilache     if (sourceType != insertSliceOp.getSourceType()) {
1620*c9fb3c6eSNicolas Vasilache       OpBuilder::InsertionGuard g(rewriter);
1621*c9fb3c6eSNicolas Vasilache       // The only difference between InsertSliceOp and ParallelInsertSliceOp is
1622*c9fb3c6eSNicolas Vasilache       // that the insertion point is just before the ParallelCombiningOp in the
1623*c9fb3c6eSNicolas Vasilache       // parallel case.
1624*c9fb3c6eSNicolas Vasilache       if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
1625*c9fb3c6eSNicolas Vasilache         rewriter.setInsertionPoint(insertSliceOp->getParentOp());
1626a08b750cSNicolas Vasilache       toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
1627a08b750cSNicolas Vasilache                                                  sourceType, toInsert);
1628*c9fb3c6eSNicolas Vasilache     }
1629*c9fb3c6eSNicolas Vasilache     rewriter.replaceOpWithNewOp<InsertOpTy>(
16302d70eff8SJacques Pienaar         insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
16312d70eff8SJacques Pienaar         mixedSizes, mixedStrides);
1632060208b4SMatthias Springer     return success();
1633060208b4SMatthias Springer   }
1634060208b4SMatthias Springer };
1635060208b4SMatthias Springer 
1636ebf35370SMatthias Springer /// Fold tensor_casts with insert_slice operations. If the source or destination
1637ebf35370SMatthias Springer /// tensor is a tensor_cast that removes static type information, the cast is
1638ebf35370SMatthias Springer /// folded into the insert_slice operation. E.g.:
1639ebf35370SMatthias Springer ///
1640ebf35370SMatthias Springer /// ```mlir
1641ebf35370SMatthias Springer ///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
1642ebf35370SMatthias Springer ///   %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
1643ebf35370SMatthias Springer /// ```
1644ebf35370SMatthias Springer ///
1645ebf35370SMatthias Springer /// folds into:
1646ebf35370SMatthias Springer ///
1647ebf35370SMatthias Springer /// ```mlir
1648ebf35370SMatthias Springer ///   %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
1649ebf35370SMatthias Springer /// ```
1650ebf35370SMatthias Springer ///
1651ebf35370SMatthias Springer /// Note: When folding a cast on the destination tensor, the result of the
1652ebf35370SMatthias Springer /// insert_slice operation is cast to ensure that the type of the result does
1653ebf35370SMatthias Springer /// not change.
1654*c9fb3c6eSNicolas Vasilache ///
1655*c9fb3c6eSNicolas Vasilache /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
1656*c9fb3c6eSNicolas Vasilache template <typename InsertOpTy>
1657*c9fb3c6eSNicolas Vasilache struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
1658*c9fb3c6eSNicolas Vasilache   using OpRewritePattern<InsertOpTy>::OpRewritePattern;
1659060208b4SMatthias Springer 
matchAndRewrite__anon3fb9f79f1111::InsertSliceOpCastFolder1660*c9fb3c6eSNicolas Vasilache   LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
1661060208b4SMatthias Springer                                 PatternRewriter &rewriter) const override {
1662060208b4SMatthias Springer     if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
1663060208b4SMatthias Springer           return matchPattern(operand, matchConstantIndex());
1664060208b4SMatthias Springer         }))
1665060208b4SMatthias Springer       return failure();
1666060208b4SMatthias Springer 
1667060208b4SMatthias Springer     auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
1668060208b4SMatthias Springer       auto castOp = v.getDefiningOp<tensor::CastOp>();
1669060208b4SMatthias Springer       if (!castOp || !canFoldIntoConsumerOp(castOp))
1670060208b4SMatthias Springer         return llvm::None;
16712d70eff8SJacques Pienaar       return castOp.getSource();
1672060208b4SMatthias Springer     };
1673060208b4SMatthias Springer     Optional<Value> sourceCastSource =
16742d70eff8SJacques Pienaar         getSourceOfCastOp(insertSliceOp.getSource());
16752d70eff8SJacques Pienaar     Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.getDest());
1676060208b4SMatthias Springer     if (!sourceCastSource && !destCastSource)
1677060208b4SMatthias Springer       return failure();
1678060208b4SMatthias Springer 
16792d70eff8SJacques Pienaar     auto src =
16802d70eff8SJacques Pienaar         (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
16812d70eff8SJacques Pienaar     auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
1682*c9fb3c6eSNicolas Vasilache     auto srcType = src.getType().template cast<ShapedType>();
1683*c9fb3c6eSNicolas Vasilache     auto dstType = dst.getType().template cast<ShapedType>();
16842d70eff8SJacques Pienaar     if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
16852d70eff8SJacques Pienaar                             insertSliceOp.getStaticSizes(),
16862d70eff8SJacques Pienaar                             insertSliceOp.getStaticStrides()) !=
1687cd0d095cSIvan Butygin         SliceVerificationResult::Success)
1688cd0d095cSIvan Butygin       return failure();
1689cd0d095cSIvan Butygin 
1690*c9fb3c6eSNicolas Vasilache     Operation *replacement = rewriter.create<InsertOpTy>(
1691cd0d095cSIvan Butygin         insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
1692cd0d095cSIvan Butygin         insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
1693060208b4SMatthias Springer 
1694*c9fb3c6eSNicolas Vasilache     // In the parallel case there is no result and so nothing to cast.
1695*c9fb3c6eSNicolas Vasilache     bool isParallelInsert =
1696*c9fb3c6eSNicolas Vasilache         std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
1697*c9fb3c6eSNicolas Vasilache     if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
1698*c9fb3c6eSNicolas Vasilache       replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
1699*c9fb3c6eSNicolas Vasilache                                                     insertSliceOp.getDestType(),
1700*c9fb3c6eSNicolas Vasilache                                                     replacement->getResult(0));
1701060208b4SMatthias Springer     }
1702*c9fb3c6eSNicolas Vasilache     rewriter.replaceOp(insertSliceOp, replacement->getResults());
1703060208b4SMatthias Springer     return success();
1704060208b4SMatthias Springer   }
1705060208b4SMatthias Springer };
1706ebf35370SMatthias Springer 
1707ebf35370SMatthias Springer /// If additional static type information can be deduced from an insert_slice's
1708ebf35370SMatthias Springer /// size operands, insert an explicit cast of the op's source operand. This
1709ebf35370SMatthias Springer /// enables other canonicalization patterns that match on tensor_cast ops, such
1710ebf35370SMatthias Springer /// as `ForOpTensorCastFolder` in SCF.
1711ebf35370SMatthias Springer ///
1712ebf35370SMatthias Springer /// Example:
1713ebf35370SMatthias Springer ///
1714ebf35370SMatthias Springer /// ```mlir
1715ebf35370SMatthias Springer ///   %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
1716ebf35370SMatthias Springer ///       : tensor<?x?xf32> into ...
1717ebf35370SMatthias Springer /// ```
1718ebf35370SMatthias Springer ///
1719ebf35370SMatthias Springer /// folds into:
1720ebf35370SMatthias Springer ///
1721ebf35370SMatthias Springer /// ```mlir
1722ebf35370SMatthias Springer ///   %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
1723ebf35370SMatthias Springer ///   %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
1724ebf35370SMatthias Springer ///       : tensor<64x64xf32> into ...
1725ebf35370SMatthias Springer /// ```
1726*c9fb3c6eSNicolas Vasilache ///
1727*c9fb3c6eSNicolas Vasilache /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
1728*c9fb3c6eSNicolas Vasilache template <typename InsertOpTy>
1729ebf35370SMatthias Springer struct InsertSliceOpSourceCastInserter final
1730*c9fb3c6eSNicolas Vasilache     : public OpRewritePattern<InsertOpTy> {
1731*c9fb3c6eSNicolas Vasilache   using OpRewritePattern<InsertOpTy>::OpRewritePattern;
1732ebf35370SMatthias Springer 
matchAndRewrite__anon3fb9f79f1111::InsertSliceOpSourceCastInserter1733*c9fb3c6eSNicolas Vasilache   LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
1734ebf35370SMatthias Springer                                 PatternRewriter &rewriter) const override {
1735ebf35370SMatthias Springer     RankedTensorType srcType = insertSliceOp.getSourceType();
1736*c9fb3c6eSNicolas Vasilache     if (srcType.getRank() != insertSliceOp.getDestType().getRank())
1737ebf35370SMatthias Springer       return failure();
1738ebf35370SMatthias Springer     SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
1739ebf35370SMatthias Springer                                      srcType.getShape().end());
1740f2e945a3SNicolas Vasilache     for (int64_t i = 0; i < srcType.getRank(); ++i) {
1741f2e945a3SNicolas Vasilache       if (Optional<int64_t> constInt =
1742f2e945a3SNicolas Vasilache               getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
1743f2e945a3SNicolas Vasilache         newSrcShape[i] = *constInt;
1744f2e945a3SNicolas Vasilache     }
174582dd977bSNicolas Vasilache 
1746ebf35370SMatthias Springer     RankedTensorType newSrcType =
1747ebf35370SMatthias Springer         RankedTensorType::get(newSrcShape, srcType.getElementType());
174882dd977bSNicolas Vasilache     if (srcType == newSrcType ||
174982dd977bSNicolas Vasilache         !preservesStaticInformation(srcType, newSrcType) ||
175082dd977bSNicolas Vasilache         !tensor::CastOp::areCastCompatible(srcType, newSrcType))
1751ebf35370SMatthias Springer       return failure();
1752ebf35370SMatthias Springer 
175382dd977bSNicolas Vasilache     // newSrcType is:
175482dd977bSNicolas Vasilache     //   1) Different from srcType.
175582dd977bSNicolas Vasilache     //   2) "More static" than srcType.
175682dd977bSNicolas Vasilache     //   3) Cast-compatible with srcType.
175782dd977bSNicolas Vasilache     // Insert the cast.
1758*c9fb3c6eSNicolas Vasilache     OpBuilder::InsertionGuard g(rewriter);
1759*c9fb3c6eSNicolas Vasilache     // The only difference between InsertSliceOp and ParallelInsertSliceOp is
1760*c9fb3c6eSNicolas Vasilache     // that the insertion point is just before the ParallelCombiningOp in the
1761*c9fb3c6eSNicolas Vasilache     // parallel case.
1762*c9fb3c6eSNicolas Vasilache     if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
1763*c9fb3c6eSNicolas Vasilache       rewriter.setInsertionPoint(insertSliceOp->getParentOp());
1764ebf35370SMatthias Springer     Value cast = rewriter.create<tensor::CastOp>(
17652d70eff8SJacques Pienaar         insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
1766*c9fb3c6eSNicolas Vasilache     rewriter.replaceOpWithNewOp<InsertOpTy>(
17672d70eff8SJacques Pienaar         insertSliceOp, cast, insertSliceOp.getDest(),
1768ebf35370SMatthias Springer         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
1769ebf35370SMatthias Springer         insertSliceOp.getMixedStrides());
1771ebf35370SMatthias Springer     return success();
1772ebf35370SMatthias Springer   }
1773ebf35370SMatthias Springer };
1774060208b4SMatthias Springer } // namespace
1775060208b4SMatthias Springer 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1776060208b4SMatthias Springer void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1777060208b4SMatthias Springer                                                 MLIRContext *context) {
1778*c9fb3c6eSNicolas Vasilache   results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
1779*c9fb3c6eSNicolas Vasilache               InsertSliceOpCastFolder<InsertSliceOp>,
1780*c9fb3c6eSNicolas Vasilache               InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
1781060208b4SMatthias Springer }
1782060208b4SMatthias Springer 
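/// Create an insert_slice that writes `tensor` over the full extent of `dest`
/// with zero offsets and unit strides. A rough sketch of the IR this helper
/// may produce, assuming (for illustration only) a rank-reducing insertion of
/// a tensor<4xf32> into a static tensor<1x4xf32> destination:
///
/// ```mlir
///   %r = tensor.insert_slice %t into %dest[0, 0] [1, 4] [1, 1]
///       : tensor<4xf32> into tensor<1x4xf32>
/// ```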
createCanonicalRankReducingInsertSliceOp(OpBuilder & b,Location loc,Value tensor,Value dest)1783aa373180SNicolas Vasilache Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
1784aa373180SNicolas Vasilache                                                              Location loc,
1785aa373180SNicolas Vasilache                                                              Value tensor,
1786aa373180SNicolas Vasilache                                                              Value dest) {
1787aa373180SNicolas Vasilache   auto rankedTensorType = dest.getType().cast<RankedTensorType>();
1788aa373180SNicolas Vasilache   unsigned rank = rankedTensorType.getRank();
1789aa373180SNicolas Vasilache   auto shape = rankedTensorType.getShape();
1790aa373180SNicolas Vasilache   SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
1791aa373180SNicolas Vasilache   SmallVector<OpFoldResult> sizes;
1792aa373180SNicolas Vasilache   for (unsigned i = 0, e = rank; i < e; ++i) {
1793aa373180SNicolas Vasilache     OpFoldResult dim;
1794aa373180SNicolas Vasilache     if (rankedTensorType.isDynamicDim(i))
1795aa373180SNicolas Vasilache       dim = b.createOrFold<tensor::DimOp>(
1796aa373180SNicolas Vasilache           loc, dest, b.create<arith::ConstantIndexOp>(loc, i));
1797aa373180SNicolas Vasilache     else
1798aa373180SNicolas Vasilache       dim = b.getIndexAttr(shape[i]);
1799aa373180SNicolas Vasilache     sizes.push_back(dim);
1800aa373180SNicolas Vasilache   }
1801aa373180SNicolas Vasilache   SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
1802aa373180SNicolas Vasilache   return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
1803aa373180SNicolas Vasilache                                                sizes, strides);
1804aa373180SNicolas Vasilache }
1805aa373180SNicolas Vasilache 
1806060208b4SMatthias Springer //===----------------------------------------------------------------------===//
1807fd0c6f53SAlexander Belyaev // PadOp
1808fd0c6f53SAlexander Belyaev //===----------------------------------------------------------------------===//
1809fd0c6f53SAlexander Belyaev 
1810fd0c6f53SAlexander Belyaev // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
1811fd0c6f53SAlexander Belyaev // supports optional types.
printInferType(OpAsmPrinter & printer,Operation * op,Value optOperand,Type typeToInfer,Type typeToInferFrom)1812fd0c6f53SAlexander Belyaev void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
1813fd0c6f53SAlexander Belyaev                     Type typeToInfer, Type typeToInferFrom) {}
1814fd0c6f53SAlexander Belyaev 
parseInferType(OpAsmParser & parser,Optional<OpAsmParser::UnresolvedOperand> optOperand,Type & typeToInfer,Type typeToInferFrom)1815fd0c6f53SAlexander Belyaev ParseResult parseInferType(OpAsmParser &parser,
1816e13d23bcSMarkus Böck                            Optional<OpAsmParser::UnresolvedOperand> optOperand,
1817fd0c6f53SAlexander Belyaev                            Type &typeToInfer, Type typeToInferFrom) {
1818fd0c6f53SAlexander Belyaev   if (optOperand)
1819fd0c6f53SAlexander Belyaev     typeToInfer = typeToInferFrom;
1820fd0c6f53SAlexander Belyaev   return success();
1821fd0c6f53SAlexander Belyaev }
1822fd0c6f53SAlexander Belyaev 
verify()1823b98dc035SRiver Riddle LogicalResult PadOp::verify() {
18242d70eff8SJacques Pienaar   auto sourceType = getSource().getType().cast<RankedTensorType>();
18252d70eff8SJacques Pienaar   auto resultType = getResult().getType().cast<RankedTensorType>();
18262d70eff8SJacques Pienaar   auto expectedType = PadOp::inferResultType(
18272d70eff8SJacques Pienaar       sourceType, extractFromI64ArrayAttr(getStaticLow()),
18282d70eff8SJacques Pienaar       extractFromI64ArrayAttr(getStaticHigh()));
1829fd0c6f53SAlexander Belyaev   for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
1830fd0c6f53SAlexander Belyaev     if (resultType.getDimSize(i) == expectedType.getDimSize(i))
1831fd0c6f53SAlexander Belyaev       continue;
1832fd0c6f53SAlexander Belyaev     if (expectedType.isDynamicDim(i))
1833fd0c6f53SAlexander Belyaev       continue;
1834b98dc035SRiver Riddle     return emitError("specified type ")
1835fd0c6f53SAlexander Belyaev            << resultType << " does not match the inferred type "
1836fd0c6f53SAlexander Belyaev            << expectedType;
1837fd0c6f53SAlexander Belyaev   }
1838fd0c6f53SAlexander Belyaev 
1839ed645f63SChia-hung Duan   return success();
1840ed645f63SChia-hung Duan }
1841ed645f63SChia-hung Duan 
verifyRegions()1842ed645f63SChia-hung Duan LogicalResult PadOp::verifyRegions() {
1843b98dc035SRiver Riddle   auto &region = getRegion();
18442d70eff8SJacques Pienaar   unsigned rank = getResult().getType().cast<RankedTensorType>().getRank();
1845fd0c6f53SAlexander Belyaev   Block &block = region.front();
1846fd0c6f53SAlexander Belyaev   if (block.getNumArguments() != rank)
1847b98dc035SRiver Riddle     return emitError("expected the block to have ") << rank << " arguments";
1848fd0c6f53SAlexander Belyaev 
1849fd0c6f53SAlexander Belyaev   // Note: the number and type of yield values are checked in the YieldOp.
1850fd0c6f53SAlexander Belyaev   for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
1851fd0c6f53SAlexander Belyaev     if (!en.value().isIndex())
1852b98dc035SRiver Riddle       return emitOpError("expected block argument ")
1853fd0c6f53SAlexander Belyaev              << (en.index() + 1) << " to be an index";
1854fd0c6f53SAlexander Belyaev   }
1855fd0c6f53SAlexander Belyaev 
1856fd0c6f53SAlexander Belyaev   // Ensure that the region yields an element of the right type.
1857fd0c6f53SAlexander Belyaev   auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
18582d70eff8SJacques Pienaar   if (yieldOp.getValue().getType() !=
1859b98dc035SRiver Riddle       getType().cast<ShapedType>().getElementType())
1860b98dc035SRiver Riddle     return emitOpError("expected yield type to match shape element type");
1861fd0c6f53SAlexander Belyaev 
1862fd0c6f53SAlexander Belyaev   return success();
1863fd0c6f53SAlexander Belyaev }
1864fd0c6f53SAlexander Belyaev 
inferResultType(RankedTensorType sourceType,ArrayRef<int64_t> staticLow,ArrayRef<int64_t> staticHigh,ArrayRef<int64_t> resultShape)1865fd0c6f53SAlexander Belyaev RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
1866fd0c6f53SAlexander Belyaev                                         ArrayRef<int64_t> staticLow,
1867fd0c6f53SAlexander Belyaev                                         ArrayRef<int64_t> staticHigh,
1868fd0c6f53SAlexander Belyaev                                         ArrayRef<int64_t> resultShape) {
1869fd0c6f53SAlexander Belyaev   unsigned rank = sourceType.getRank();
1870fd0c6f53SAlexander Belyaev   assert(staticLow.size() == rank && "unexpected staticLow size mismatch");
1871fd0c6f53SAlexander Belyaev   assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch");
1872fd0c6f53SAlexander Belyaev   assert((resultShape.empty() || resultShape.size() == rank) &&
1873fd0c6f53SAlexander Belyaev          "unexpected resultShape size mismatch");
1874fd0c6f53SAlexander Belyaev 
1875fd0c6f53SAlexander Belyaev   SmallVector<int64_t, 4> inferredShape;
1876fd0c6f53SAlexander Belyaev   for (auto i : llvm::seq<unsigned>(0, rank)) {
1877fd0c6f53SAlexander Belyaev     if (sourceType.isDynamicDim(i) ||
1878fd0c6f53SAlexander Belyaev         staticLow[i] == ShapedType::kDynamicSize ||
1879fd0c6f53SAlexander Belyaev         staticHigh[i] == ShapedType::kDynamicSize) {
1880fd0c6f53SAlexander Belyaev       inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamicSize
1881fd0c6f53SAlexander Belyaev                                                   : resultShape[i]);
1882fd0c6f53SAlexander Belyaev     } else {
1883fd0c6f53SAlexander Belyaev       int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
1884fd0c6f53SAlexander Belyaev       assert((resultShape.empty() || size == resultShape[i] ||
1885fd0c6f53SAlexander Belyaev               resultShape[i] == ShapedType::kDynamicSize) &&
1886fd0c6f53SAlexander Belyaev              "mismatch between inferred shape and result shape");
1887fd0c6f53SAlexander Belyaev       inferredShape.push_back(size);
1888fd0c6f53SAlexander Belyaev     }
1889fd0c6f53SAlexander Belyaev   }
1890fd0c6f53SAlexander Belyaev 
1891fd0c6f53SAlexander Belyaev   return RankedTensorType::get(inferredShape, sourceType.getElementType());
1892fd0c6f53SAlexander Belyaev }
1893fd0c6f53SAlexander Belyaev 
build(OpBuilder & b,OperationState & result,Value source,ArrayRef<int64_t> staticLow,ArrayRef<int64_t> staticHigh,ValueRange low,ValueRange high,bool nofold,ArrayRef<NamedAttribute> attrs)1894fd0c6f53SAlexander Belyaev void PadOp::build(OpBuilder &b, OperationState &result, Value source,
1895fd0c6f53SAlexander Belyaev                   ArrayRef<int64_t> staticLow, ArrayRef<int64_t> staticHigh,
1896fd0c6f53SAlexander Belyaev                   ValueRange low, ValueRange high, bool nofold,
1897fd0c6f53SAlexander Belyaev                   ArrayRef<NamedAttribute> attrs) {
1898fd0c6f53SAlexander Belyaev   auto sourceType = source.getType().cast<RankedTensorType>();
1899fd0c6f53SAlexander Belyaev   auto resultType = inferResultType(sourceType, staticLow, staticHigh);
1900fd0c6f53SAlexander Belyaev   build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
1901fd0c6f53SAlexander Belyaev         b.getI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr());
1902fd0c6f53SAlexander Belyaev   result.addAttributes(attrs);
1903fd0c6f53SAlexander Belyaev }
1904fd0c6f53SAlexander Belyaev 
build(OpBuilder & b,OperationState & result,Value source,ValueRange low,ValueRange high,bool nofold,ArrayRef<NamedAttribute> attrs)1905fd0c6f53SAlexander Belyaev void PadOp::build(OpBuilder &b, OperationState &result, Value source,
1906fd0c6f53SAlexander Belyaev                   ValueRange low, ValueRange high, bool nofold,
1907fd0c6f53SAlexander Belyaev                   ArrayRef<NamedAttribute> attrs) {
1908fd0c6f53SAlexander Belyaev   auto sourceType = source.getType().cast<RankedTensorType>();
1909fd0c6f53SAlexander Belyaev   unsigned rank = sourceType.getRank();
1910fd0c6f53SAlexander Belyaev   SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamicSize);
1911fd0c6f53SAlexander Belyaev   build(b, result, source, staticVector, staticVector, low, high, nofold,
1912fd0c6f53SAlexander Belyaev         attrs);
1913fd0c6f53SAlexander Belyaev }
1914fd0c6f53SAlexander Belyaev 
build(OpBuilder & b,OperationState & result,Type resultType,Value source,ArrayRef<OpFoldResult> low,ArrayRef<OpFoldResult> high,bool nofold,ArrayRef<NamedAttribute> attrs)1915fd0c6f53SAlexander Belyaev void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
1916fd0c6f53SAlexander Belyaev                   Value source, ArrayRef<OpFoldResult> low,
1917fd0c6f53SAlexander Belyaev                   ArrayRef<OpFoldResult> high, bool nofold,
1918fd0c6f53SAlexander Belyaev                   ArrayRef<NamedAttribute> attrs) {
1919fd0c6f53SAlexander Belyaev   assert(resultType.isa<RankedTensorType>());
1920fd0c6f53SAlexander Belyaev   auto sourceType = source.getType().cast<RankedTensorType>();
1921fd0c6f53SAlexander Belyaev   SmallVector<Value, 4> dynamicLow, dynamicHigh;
1922fd0c6f53SAlexander Belyaev   SmallVector<int64_t, 4> staticLow, staticHigh;
1923fd0c6f53SAlexander Belyaev   // staticLow and staticHigh have full information of the padding config.
1924fd0c6f53SAlexander Belyaev   // Each entry in `low`/`high` grows staticLow/staticHigh by one value. If the
1925fd0c6f53SAlexander Belyaev   // entry is dynamic (i.e. not a constant), dynamicLow/dynamicHigh grow by one
1926fd0c6f53SAlexander Belyaev   // value as well.
1927fd0c6f53SAlexander Belyaev   dispatchIndexOpFoldResults(low, dynamicLow, staticLow,
1928fd0c6f53SAlexander Belyaev                              ShapedType::kDynamicSize);
1929fd0c6f53SAlexander Belyaev   dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh,
1930fd0c6f53SAlexander Belyaev                              ShapedType::kDynamicSize);
1931fd0c6f53SAlexander Belyaev   if (!resultType) {
1932fd0c6f53SAlexander Belyaev     resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
1933fd0c6f53SAlexander Belyaev   }
1934fd0c6f53SAlexander Belyaev   build(b, result, resultType, source, dynamicLow, dynamicHigh,
1935fd0c6f53SAlexander Belyaev         b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
1936fd0c6f53SAlexander Belyaev         nofold ? b.getUnitAttr() : UnitAttr());
1937fd0c6f53SAlexander Belyaev   result.addAttributes(attrs);
1938fd0c6f53SAlexander Belyaev }
1939fd0c6f53SAlexander Belyaev 
getPaddedDims()1940973dbe20Sgysit llvm::SmallBitVector PadOp::getPaddedDims() {
1941973dbe20Sgysit   llvm::SmallBitVector paddedDims(getSourceType().getRank());
1942973dbe20Sgysit   auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
1943973dbe20Sgysit     for (const auto &en : enumerate(paddingWidths))
1944973dbe20Sgysit       if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
1945973dbe20Sgysit         paddedDims.set(en.index());
1946973dbe20Sgysit   };
1947973dbe20Sgysit   extractPaddedDims(getMixedLowPad());
1948973dbe20Sgysit   extractPaddedDims(getMixedHighPad());
1949973dbe20Sgysit   return paddedDims;
1950973dbe20Sgysit }
1951973dbe20Sgysit 
1952fd0c6f53SAlexander Belyaev namespace {
1953fd0c6f53SAlexander Belyaev // Folds tensor.pad away when all padding amounts are statically zero and the
1954fd0c6f53SAlexander Belyaev // nofold attribute is not set.
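//
// A sketch of the rewrite (illustrative types; not taken from a real test):
//
// ```mlir
//   %0 = tensor.pad %src low[0, 0] high[0, 0] {
//   ^bb0(%arg0: index, %arg1: index):
//     tensor.yield %cst : f32
//   } : tensor<4x4xf32> to tensor<4x4xf32>
// ```
//
// is replaced by a (trivial) cast of the source, which later cast
// canonicalizations can remove:
//
// ```mlir
//   %0 = tensor.cast %src : tensor<4x4xf32> to tensor<4x4xf32>
// ```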
1955fd0c6f53SAlexander Belyaev struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
1956fd0c6f53SAlexander Belyaev   using OpRewritePattern<PadOp>::OpRewritePattern;
1957fd0c6f53SAlexander Belyaev 
matchAndRewrite__anon3fb9f79f1611::FoldStaticZeroPadding1958fd0c6f53SAlexander Belyaev   LogicalResult matchAndRewrite(PadOp padTensorOp,
1959fd0c6f53SAlexander Belyaev                                 PatternRewriter &rewriter) const override {
1960fd0c6f53SAlexander Belyaev     if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
1961fd0c6f53SAlexander Belyaev       return failure();
19622d70eff8SJacques Pienaar     if (padTensorOp.getNofold())
1963fd0c6f53SAlexander Belyaev       return failure();
1964fd0c6f53SAlexander Belyaev     rewriter.replaceOpWithNewOp<tensor::CastOp>(
19652d70eff8SJacques Pienaar         padTensorOp, padTensorOp.getResult().getType(),
19662d70eff8SJacques Pienaar         padTensorOp.getSource());
1967fd0c6f53SAlexander Belyaev     return success();
1968fd0c6f53SAlexander Belyaev   }
1969fd0c6f53SAlexander Belyaev };
1970fd0c6f53SAlexander Belyaev 
1971fd0c6f53SAlexander Belyaev // Fold CastOp into PadOp when adding static information.
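//
// A sketch under assumed types (illustrative only):
//
// ```mlir
//   %0 = tensor.cast %src : tensor<8x16xf32> to tensor<?x?xf32>
//   %1 = tensor.pad %0 low[0, 0] high[2, 3] { ...
//     } : tensor<?x?xf32> to tensor<?x?xf32>
// ```
//
// becomes a pad of the original source with a more static result type,
// followed by a cast back to the original result type:
//
// ```mlir
//   %new = tensor.pad %src low[0, 0] high[2, 3] { ...
//     } : tensor<8x16xf32> to tensor<10x19xf32>
//   %1 = tensor.cast %new : tensor<10x19xf32> to tensor<?x?xf32>
// ```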
1972fd0c6f53SAlexander Belyaev struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
1973fd0c6f53SAlexander Belyaev   using OpRewritePattern<PadOp>::OpRewritePattern;
1974fd0c6f53SAlexander Belyaev 
matchAndRewrite__anon3fb9f79f1611::FoldSourceTensorCast1975fd0c6f53SAlexander Belyaev   LogicalResult matchAndRewrite(PadOp padTensorOp,
1976fd0c6f53SAlexander Belyaev                                 PatternRewriter &rewriter) const override {
19772d70eff8SJacques Pienaar     auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
1978fd0c6f53SAlexander Belyaev     if (!tensor::canFoldIntoConsumerOp(castOp))
1979fd0c6f53SAlexander Belyaev       return failure();
1980fd0c6f53SAlexander Belyaev 
1981fd0c6f53SAlexander Belyaev     auto newResultType = PadOp::inferResultType(
19822d70eff8SJacques Pienaar         castOp.getSource().getType().cast<RankedTensorType>(),
19832d70eff8SJacques Pienaar         extractFromI64ArrayAttr(padTensorOp.getStaticLow()),
19842d70eff8SJacques Pienaar         extractFromI64ArrayAttr(padTensorOp.getStaticHigh()),
1985fd0c6f53SAlexander Belyaev         padTensorOp.getResultType().getShape());
1986fd0c6f53SAlexander Belyaev 
1987fd0c6f53SAlexander Belyaev     if (newResultType == padTensorOp.getResultType()) {
1988fd0c6f53SAlexander Belyaev       rewriter.updateRootInPlace(padTensorOp, [&]() {
19892d70eff8SJacques Pienaar         padTensorOp.getSourceMutable().assign(castOp.getSource());
1990fd0c6f53SAlexander Belyaev       });
1991fd0c6f53SAlexander Belyaev     } else {
1992fd0c6f53SAlexander Belyaev       auto newOp = rewriter.create<PadOp>(
19932d70eff8SJacques Pienaar           padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
19942d70eff8SJacques Pienaar           padTensorOp.getLow(), padTensorOp.getHigh(),
19952d70eff8SJacques Pienaar           padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
19962d70eff8SJacques Pienaar           padTensorOp.getNofold());
1997fd0c6f53SAlexander Belyaev       BlockAndValueMapping mapper;
1998fd0c6f53SAlexander Belyaev       padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
1999fd0c6f53SAlexander Belyaev 
2000fd0c6f53SAlexander Belyaev       rewriter.replaceOpWithNewOp<tensor::CastOp>(
2001fd0c6f53SAlexander Belyaev           padTensorOp, padTensorOp.getResultType(), newOp);
2002fd0c6f53SAlexander Belyaev     }
2003fd0c6f53SAlexander Belyaev     return success();
2004fd0c6f53SAlexander Belyaev   }
2005fd0c6f53SAlexander Belyaev };
2006fd0c6f53SAlexander Belyaev 
2007fd0c6f53SAlexander Belyaev // Fold CastOp using the result of PadOp back into the latter if it adds
2008fd0c6f53SAlexander Belyaev // static information.
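//
// A sketch under assumed types (illustrative only):
//
// ```mlir
//   %0 = tensor.pad %src low[0, 0] high[%h0, %h1] { ...
//     } : tensor<8x16xf32> to tensor<?x?xf32>
//   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<10x19xf32>
// ```
//
// becomes a single pad with the more static result type:
//
// ```mlir
//   %1 = tensor.pad %src low[0, 0] high[%h0, %h1] { ...
//     } : tensor<8x16xf32> to tensor<10x19xf32>
// ```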
2009fd0c6f53SAlexander Belyaev struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
2010fd0c6f53SAlexander Belyaev   using OpRewritePattern<PadOp>::OpRewritePattern;
2011fd0c6f53SAlexander Belyaev 
matchAndRewrite__anon3fb9f79f1611::FoldTargetTensorCast2012fd0c6f53SAlexander Belyaev   LogicalResult matchAndRewrite(PadOp padTensorOp,
2013fd0c6f53SAlexander Belyaev                                 PatternRewriter &rewriter) const override {
20142d70eff8SJacques Pienaar     if (!padTensorOp.getResult().hasOneUse())
2015fd0c6f53SAlexander Belyaev       return failure();
2016fd0c6f53SAlexander Belyaev     auto tensorCastOp =
2017fd0c6f53SAlexander Belyaev         dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
2018fd0c6f53SAlexander Belyaev     if (!tensorCastOp)
2019fd0c6f53SAlexander Belyaev       return failure();
20202d70eff8SJacques Pienaar     if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
20212d70eff8SJacques Pienaar                                             tensorCastOp.getDest().getType()))
2022fd0c6f53SAlexander Belyaev       return failure();
2023fd0c6f53SAlexander Belyaev 
2024fd0c6f53SAlexander Belyaev     auto replacementOp = rewriter.create<PadOp>(
20252d70eff8SJacques Pienaar         padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
20262d70eff8SJacques Pienaar         padTensorOp.getSource(), padTensorOp.getLow(), padTensorOp.getHigh(),
20272d70eff8SJacques Pienaar         padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
20282d70eff8SJacques Pienaar         padTensorOp.getNofold());
20292d70eff8SJacques Pienaar     replacementOp.getRegion().takeBody(padTensorOp.getRegion());
2030fd0c6f53SAlexander Belyaev 
20312d70eff8SJacques Pienaar     rewriter.replaceOp(padTensorOp, replacementOp.getResult());
20322d70eff8SJacques Pienaar     rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
2033fd0c6f53SAlexander Belyaev     return success();
2034fd0c6f53SAlexander Belyaev   }
2035fd0c6f53SAlexander Belyaev };
2036973dbe20Sgysit 
2037973dbe20Sgysit /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
2038973dbe20Sgysit /// different dimensions. The pattern applies if the following preconditions
2039973dbe20Sgysit /// hold:
2040973dbe20Sgysit ///   1) the tensor::ExtractSliceOps are not rank-reducing,
2041973dbe20Sgysit ///   2) the tensor::ExtractSliceOps have only unit-strides,
2042973dbe20Sgysit ///   3) the tensor::PadOps perform only high-padding,
2043973dbe20Sgysit ///   4) the tensor::PadOps have the same constant padding value,
2044973dbe20Sgysit ///   5) the tensor::PadOps do not have common padding dimensions,
2045973dbe20Sgysit ///   6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
2046973dbe20Sgysit ///      zero-offset for every dimension.
2047973dbe20Sgysit ///   7) the tensor::ExtractSliceOp sizes match the source tensor sizes for the
2048973dbe20Sgysit ///      padded source dimensions.
2049973dbe20Sgysit ///
2050973dbe20Sgysit /// Example:
2051973dbe20Sgysit ///
2052973dbe20Sgysit /// ```mlir
2053973dbe20Sgysit ///   %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
2054973dbe20Sgysit ///       : tensor<64x64xf32> to tensor<?x64xf32>
2055973dbe20Sgysit ///   %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
2056973dbe20Sgysit ///     } : tensor<?x64xf32> to tensor<8x64xf32>
2057973dbe20Sgysit ///   %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
2058973dbe20Sgysit ///        : tensor<8x64xf32> to tensor<8x?xf32>
2059973dbe20Sgysit ///   %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
2060973dbe20Sgysit ///     } : tensor<8x?xf32> to tensor<8x4xf32>
2061973dbe20Sgysit /// ```
2062973dbe20Sgysit ///
2063973dbe20Sgysit /// folds into:
2064973dbe20Sgysit ///
2065973dbe20Sgysit /// ```mlir
2066973dbe20Sgysit ///   %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
2067973dbe20Sgysit ///        : tensor<64x64xf32> to tensor<?x?xf32>
2068973dbe20Sgysit ///   %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
2069973dbe20Sgysit ///     } : tensor<?x?xf32> to tensor<8x4xf32>
2070973dbe20Sgysit /// ```
2071973dbe20Sgysit struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
2072973dbe20Sgysit   using OpRewritePattern<PadOp>::OpRewritePattern;
2073973dbe20Sgysit 
matchAndRewrite__anon3fb9f79f1611::FoldOrthogonalPaddings2074973dbe20Sgysit   LogicalResult matchAndRewrite(PadOp padOp,
2075973dbe20Sgysit                                 PatternRewriter &rewriter) const override {
20762d70eff8SJacques Pienaar     auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
2077973dbe20Sgysit     if (!innerSliceOp)
2078973dbe20Sgysit       return failure();
20792d70eff8SJacques Pienaar     auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
20802d70eff8SJacques Pienaar     if (!outerPadOp || outerPadOp.getNofold())
2081973dbe20Sgysit       return failure();
20822d70eff8SJacques Pienaar     auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
2083973dbe20Sgysit     if (!outerSliceOp)
2084973dbe20Sgysit       return failure();
2085973dbe20Sgysit 
2086973dbe20Sgysit     // 1) Fail if the chain is rank-reducing.
2087973dbe20Sgysit     int64_t rank = padOp.getSourceType().getRank();
2088973dbe20Sgysit     if (outerSliceOp.getSourceType().getRank() != rank) {
2089973dbe20Sgysit       return rewriter.notifyMatchFailure(padOp,
2090973dbe20Sgysit                                          "cannot fold rank-reducing chain");
2091973dbe20Sgysit     }
2092973dbe20Sgysit 
2093973dbe20Sgysit     // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
2094973dbe20Sgysit     if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
2095973dbe20Sgysit       return rewriter.notifyMatchFailure(
2096973dbe20Sgysit           padOp, "cannot fold non-unit stride ExtractSliceOps");
2097973dbe20Sgysit     }
2098973dbe20Sgysit 
2099973dbe20Sgysit     // 3) Fail if the tensor::PadOps have non-zero low padding.
2100973dbe20Sgysit     if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
2101973dbe20Sgysit       return rewriter.notifyMatchFailure(padOp,
2102973dbe20Sgysit                                          "cannot fold PadOps with low padding");
2103973dbe20Sgysit     }
2104973dbe20Sgysit 
2105973dbe20Sgysit     // 4) Fail if the tensor::PadOps padding values do not match.
2106973dbe20Sgysit     Attribute innerAttr, outerAttr;
2107973dbe20Sgysit     Value innerValue = padOp.getConstantPaddingValue();
2108973dbe20Sgysit     Value outerValue = outerPadOp.getConstantPaddingValue();
2109973dbe20Sgysit     if (!innerValue || !outerValue ||
2110973dbe20Sgysit         !matchPattern(innerValue, m_Constant(&innerAttr)) ||
2111973dbe20Sgysit         !matchPattern(outerValue, m_Constant(&outerAttr)) ||
2112973dbe20Sgysit         innerAttr != outerAttr) {
2113973dbe20Sgysit       return rewriter.notifyMatchFailure(
2114973dbe20Sgysit           padOp, "cannot fold PadOps with different padding values");
2115973dbe20Sgysit     }
2116973dbe20Sgysit 
2117973dbe20Sgysit     // 5) Fail if a dimension is padded by both tensor::PadOps.
2118973dbe20Sgysit     llvm::SmallBitVector innerDims = padOp.getPaddedDims();
2119973dbe20Sgysit     llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
2120973dbe20Sgysit     if (innerDims.anyCommon(outerDims)) {
2121973dbe20Sgysit       return rewriter.notifyMatchFailure(
2122973dbe20Sgysit           padOp, "cannot fold PadOps with common padding dimensions");
2123973dbe20Sgysit     }
2124973dbe20Sgysit 
2125973dbe20Sgysit     // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
2126973dbe20Sgysit     // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
2127973dbe20Sgysit     // for every dimension, and use the offset of the other pair. Fail if no
2128973dbe20Sgysit     // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
2129973dbe20Sgysit     // exists.
2130973dbe20Sgysit     SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
2131973dbe20Sgysit     for (auto &en : enumerate(newOffsets)) {
2132973dbe20Sgysit       OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
2133973dbe20Sgysit       OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
2134973dbe20Sgysit       if (!innerDims.test(en.index()) &&
2135973dbe20Sgysit           (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
2136973dbe20Sgysit         en.value() = outerOffset;
2137973dbe20Sgysit         continue;
2138973dbe20Sgysit       }
2139973dbe20Sgysit       if (!outerDims.test(en.index()) &&
2140973dbe20Sgysit           (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
2141973dbe20Sgysit         en.value() = innerOffset;
2142973dbe20Sgysit         continue;
2143973dbe20Sgysit       }
2144973dbe20Sgysit       return rewriter.notifyMatchFailure(
2145973dbe20Sgysit           padOp, "cannot find zero-offset and zero-padding pair");
2146973dbe20Sgysit     }
2147973dbe20Sgysit 
2148973dbe20Sgysit     // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size of
2149973dbe20Sgysit     // the outer tensor::ExtractSliceOp for the dimensions padded by the outer
2150973dbe20Sgysit     // tensor::PadOp and fail if the size of the inner tensor::ExtractSliceOp
2151973dbe20Sgysit     // does not match the size of the padded dimension. Otherwise, take the size
2152973dbe20Sgysit     // of the inner tensor::ExtractSliceOp.
2153973dbe20Sgysit     SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
2154973dbe20Sgysit     for (auto &en : enumerate(newSizes)) {
2155973dbe20Sgysit       if (!outerDims.test(en.index()))
2156973dbe20Sgysit         continue;
2157973dbe20Sgysit       OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
2158973dbe20Sgysit       int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
2159973dbe20Sgysit       assert(!ShapedType::isDynamic(sourceSize) &&
2160973dbe20Sgysit              "expected padded dimension to have a static size");
2161973dbe20Sgysit       if (getConstantIntValue(sliceSize) != sourceSize) {
2162973dbe20Sgysit         return rewriter.notifyMatchFailure(
2163973dbe20Sgysit             padOp, "cannot fold since the inner ExtractSliceOp size does not "
2164973dbe20Sgysit                    "match the size of the outer padding");
2165973dbe20Sgysit       }
2166973dbe20Sgysit       en.value() = outerSliceOp.getMixedSizes()[en.index()];
2167973dbe20Sgysit     }
2168973dbe20Sgysit 
2169973dbe20Sgysit     // Combine the high paddings of the two tensor::PadOps.
2170973dbe20Sgysit     SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
2171973dbe20Sgysit     for (auto &en : enumerate(newHighPad)) {
2172973dbe20Sgysit       if (innerDims.test(en.index()))
2173973dbe20Sgysit         newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
2174973dbe20Sgysit       if (outerDims.test(en.index()))
2175973dbe20Sgysit         newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
2176973dbe20Sgysit     }
2177973dbe20Sgysit 
2178973dbe20Sgysit     // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs the
2179973dbe20Sgysit     // two paddings in one step.
2180973dbe20Sgysit     auto newSliceOp = rewriter.create<ExtractSliceOp>(
21812d70eff8SJacques Pienaar         padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
2182973dbe20Sgysit         innerSliceOp.getMixedStrides());
2183973dbe20Sgysit     auto newPadOp = rewriter.create<PadOp>(
2184973dbe20Sgysit         padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
21852d70eff8SJacques Pienaar         padOp.getMixedLowPad(), newHighPad, padOp.getNofold());
2186973dbe20Sgysit     rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
2187973dbe20Sgysit                                 newPadOp.getRegion().begin());
2188973dbe20Sgysit     rewriter.replaceOp(padOp, newPadOp.getResult());
2189973dbe20Sgysit     return success();
2190973dbe20Sgysit   }
2191973dbe20Sgysit };
2192973dbe20Sgysit 
2193fd0c6f53SAlexander Belyaev } // namespace
2194fd0c6f53SAlexander Belyaev 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2195fd0c6f53SAlexander Belyaev void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
2196fd0c6f53SAlexander Belyaev                                         MLIRContext *context) {
2197973dbe20Sgysit   results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
2198973dbe20Sgysit               FoldOrthogonalPaddings>(context);
2199fd0c6f53SAlexander Belyaev }
2200fd0c6f53SAlexander Belyaev 
2201fd0c6f53SAlexander Belyaev /// Return the padding value of the PadOp if it is constant. In this context,
2202fd0c6f53SAlexander Belyaev /// "constant" means an actual constant or "defined outside of the block".
2203fd0c6f53SAlexander Belyaev ///
2204fd0c6f53SAlexander Belyaev /// Values are considered constant in three cases:
2205fd0c6f53SAlexander Belyaev ///  - A ConstantLike value.
2206fd0c6f53SAlexander Belyaev ///  - A basic block argument from a different block.
2207fd0c6f53SAlexander Belyaev ///  - A value defined outside of the block.
2208fd0c6f53SAlexander Belyaev ///
2209fd0c6f53SAlexander Belyaev /// If the padding value is not constant, an empty Value is returned.
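///
/// For instance, in the following sketch (illustrative IR), %cst is a
/// constant padding value because it is defined outside the pad block and
/// yielded unchanged:
///
/// ```mlir
///   %cst = arith.constant 0.0 : f32
///   %0 = tensor.pad %src low[0, 0] high[1, 1] {
///   ^bb0(%i: index, %j: index):
///     tensor.yield %cst : f32
///   } : tensor<4x4xf32> to tensor<5x5xf32>
/// ```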
getConstantPaddingValue()2210fd0c6f53SAlexander Belyaev Value PadOp::getConstantPaddingValue() {
2211fd0c6f53SAlexander Belyaev   auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
2212fd0c6f53SAlexander Belyaev   if (!yieldOp)
2213fd0c6f53SAlexander Belyaev     return {};
22142d70eff8SJacques Pienaar   Value padValue = yieldOp.getValue();
2215fd0c6f53SAlexander Belyaev   // Check if yield value is a constant.
2216fd0c6f53SAlexander Belyaev   if (matchPattern(padValue, m_Constant()))
2217fd0c6f53SAlexander Belyaev     return padValue;
2218fd0c6f53SAlexander Belyaev   // Check if yield value is defined inside the PadOp block.
2219fd0c6f53SAlexander Belyaev   if (padValue.getParentBlock() == &getRegion().front())
2220fd0c6f53SAlexander Belyaev     return {};
2221fd0c6f53SAlexander Belyaev   // Else: Yield value defined outside of the PadOp block.
2222fd0c6f53SAlexander Belyaev   return padValue;
2223fd0c6f53SAlexander Belyaev }
2224fd0c6f53SAlexander Belyaev 
fold(ArrayRef<Attribute>)2225fd0c6f53SAlexander Belyaev OpFoldResult PadOp::fold(ArrayRef<Attribute>) {
2226fd0c6f53SAlexander Belyaev   if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
22272d70eff8SJacques Pienaar       !getNofold())
22282d70eff8SJacques Pienaar     return getSource();
2229fd0c6f53SAlexander Belyaev   return {};
2230fd0c6f53SAlexander Belyaev }
2231fd0c6f53SAlexander Belyaev 
2232fd0c6f53SAlexander Belyaev //===----------------------------------------------------------------------===//
22337fbf55c9SNicolas Vasilache // ParallelInsertSliceOp
22347fbf55c9SNicolas Vasilache //===----------------------------------------------------------------------===//
22357fbf55c9SNicolas Vasilache 
getTiedOpResult()22367fbf55c9SNicolas Vasilache OpResult ParallelInsertSliceOp::getTiedOpResult() {
22377fbf55c9SNicolas Vasilache   ParallelCombiningOpInterface parallelCombiningParent =
22387fbf55c9SNicolas Vasilache       getParallelCombiningParent();
22397fbf55c9SNicolas Vasilache   for (const auto &it :
22407fbf55c9SNicolas Vasilache        llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
22417fbf55c9SNicolas Vasilache     Operation &nextOp = it.value();
22427fbf55c9SNicolas Vasilache     if (&nextOp == getOperation())
22437fbf55c9SNicolas Vasilache       return parallelCombiningParent.getParentResult(it.index());
22447fbf55c9SNicolas Vasilache   }
22457fbf55c9SNicolas Vasilache   llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
22467fbf55c9SNicolas Vasilache }
22477fbf55c9SNicolas Vasilache 
22487fbf55c9SNicolas Vasilache // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
build(OpBuilder & b,OperationState & result,Value source,Value dest,ArrayRef<OpFoldResult> offsets,ArrayRef<OpFoldResult> sizes,ArrayRef<OpFoldResult> strides,ArrayRef<NamedAttribute> attrs)22497fbf55c9SNicolas Vasilache void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
22507fbf55c9SNicolas Vasilache                                   Value source, Value dest,
22517fbf55c9SNicolas Vasilache                                   ArrayRef<OpFoldResult> offsets,
22527fbf55c9SNicolas Vasilache                                   ArrayRef<OpFoldResult> sizes,
22537fbf55c9SNicolas Vasilache                                   ArrayRef<OpFoldResult> strides,
22547fbf55c9SNicolas Vasilache                                   ArrayRef<NamedAttribute> attrs) {
22557fbf55c9SNicolas Vasilache   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
22567fbf55c9SNicolas Vasilache   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
22577fbf55c9SNicolas Vasilache   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
22587fbf55c9SNicolas Vasilache                              ShapedType::kDynamicStrideOrOffset);
22597fbf55c9SNicolas Vasilache   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
22607fbf55c9SNicolas Vasilache                              ShapedType::kDynamicSize);
22617fbf55c9SNicolas Vasilache   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
22627fbf55c9SNicolas Vasilache                              ShapedType::kDynamicStrideOrOffset);
22637fbf55c9SNicolas Vasilache   build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
22647fbf55c9SNicolas Vasilache         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
22657fbf55c9SNicolas Vasilache         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
22667fbf55c9SNicolas Vasilache   result.addAttributes(attrs);
22677fbf55c9SNicolas Vasilache }
22687fbf55c9SNicolas Vasilache 
22697fbf55c9SNicolas Vasilache // Build a ParallelInsertSliceOp with dynamic entries.
build(OpBuilder & b,OperationState & result,Value source,Value dest,ValueRange offsets,ValueRange sizes,ValueRange strides,ArrayRef<NamedAttribute> attrs)22707fbf55c9SNicolas Vasilache void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
22717fbf55c9SNicolas Vasilache                                   Value source, Value dest, ValueRange offsets,
22727fbf55c9SNicolas Vasilache                                   ValueRange sizes, ValueRange strides,
22737fbf55c9SNicolas Vasilache                                   ArrayRef<NamedAttribute> attrs) {
22747fbf55c9SNicolas Vasilache   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
22757fbf55c9SNicolas Vasilache       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
22767fbf55c9SNicolas Vasilache   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
22777fbf55c9SNicolas Vasilache       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
22787fbf55c9SNicolas Vasilache   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
22797fbf55c9SNicolas Vasilache       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
22807fbf55c9SNicolas Vasilache   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
22817fbf55c9SNicolas Vasilache }
22827fbf55c9SNicolas Vasilache 
verify()22837fbf55c9SNicolas Vasilache LogicalResult ParallelInsertSliceOp::verify() {
22847fbf55c9SNicolas Vasilache   if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
22857fbf55c9SNicolas Vasilache     return this->emitError("expected ParallelCombiningOpInterface parent, got:")
22867fbf55c9SNicolas Vasilache            << *(getOperation()->getParentOp());
2287*c9fb3c6eSNicolas Vasilache 
2288*c9fb3c6eSNicolas Vasilache   ShapedType expectedType;
2289*c9fb3c6eSNicolas Vasilache   SliceVerificationResult result =
2290*c9fb3c6eSNicolas Vasilache       verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
2291*c9fb3c6eSNicolas Vasilache                           getStaticSizes(), getStaticStrides(), &expectedType);
2292*c9fb3c6eSNicolas Vasilache   return produceSliceErrorMsg(result, *this, expectedType);
22937fbf55c9SNicolas Vasilache }
22947fbf55c9SNicolas Vasilache 
22957fbf55c9SNicolas Vasilache namespace {
22967fbf55c9SNicolas Vasilache /// Pattern to rewrite a parallel_insert_slice op with constant arguments.
22977fbf55c9SNicolas Vasilache class ParallelInsertSliceOpConstantArgumentFolder final
22987fbf55c9SNicolas Vasilache     : public OpRewritePattern<ParallelInsertSliceOp> {
22997fbf55c9SNicolas Vasilache public:
23007fbf55c9SNicolas Vasilache   using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
23017fbf55c9SNicolas Vasilache 
matchAndRewrite(ParallelInsertSliceOp insertSliceOp,PatternRewriter & rewriter) const23027fbf55c9SNicolas Vasilache   LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
23037fbf55c9SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
23047fbf55c9SNicolas Vasilache     // No constant operand, just return.
23057fbf55c9SNicolas Vasilache     if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
23067fbf55c9SNicolas Vasilache           return matchPattern(operand, matchConstantIndex());
23077fbf55c9SNicolas Vasilache         }))
23087fbf55c9SNicolas Vasilache       return failure();
23097fbf55c9SNicolas Vasilache 
23107fbf55c9SNicolas Vasilache     // At least one of offsets/sizes/strides is a new constant.
23117fbf55c9SNicolas Vasilache     // Form the new list of operands and constant attributes from the
23127fbf55c9SNicolas Vasilache     // existing.
23137fbf55c9SNicolas Vasilache     SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
23147fbf55c9SNicolas Vasilache     SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
23157fbf55c9SNicolas Vasilache     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
23167fbf55c9SNicolas Vasilache     canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
23177fbf55c9SNicolas Vasilache     canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
23187fbf55c9SNicolas Vasilache     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
23197fbf55c9SNicolas Vasilache 
23207fbf55c9SNicolas Vasilache     // Create the new op in canonical form.
2321*c9fb3c6eSNicolas Vasilache     auto sourceType =
2322*c9fb3c6eSNicolas Vasilache         tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
2323*c9fb3c6eSNicolas Vasilache             insertSliceOp.getSourceType().getRank(),
2324*c9fb3c6eSNicolas Vasilache             insertSliceOp.getDestType(), mixedOffsets, mixedSizes,
2325*c9fb3c6eSNicolas Vasilache             mixedStrides);
2326*c9fb3c6eSNicolas Vasilache     Value toInsert = insertSliceOp.getSource();
2327*c9fb3c6eSNicolas Vasilache     if (sourceType != insertSliceOp.getSourceType()) {
2328*c9fb3c6eSNicolas Vasilache       OpBuilder::InsertionGuard g(rewriter);
2329*c9fb3c6eSNicolas Vasilache       rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2330*c9fb3c6eSNicolas Vasilache       toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2331*c9fb3c6eSNicolas Vasilache                                                  sourceType, toInsert);
2332*c9fb3c6eSNicolas Vasilache     }
23337fbf55c9SNicolas Vasilache     rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
2334*c9fb3c6eSNicolas Vasilache         insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2335*c9fb3c6eSNicolas Vasilache         mixedSizes, mixedStrides);
23367fbf55c9SNicolas Vasilache     return success();
23377fbf55c9SNicolas Vasilache   }
23387fbf55c9SNicolas Vasilache };
23397fbf55c9SNicolas Vasilache } // namespace
23407fbf55c9SNicolas Vasilache 
23417fbf55c9SNicolas Vasilache LogicalResult
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)23427fbf55c9SNicolas Vasilache ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
23437fbf55c9SNicolas Vasilache                             SmallVectorImpl<OpFoldResult> &results) {
2344*c9fb3c6eSNicolas Vasilache   return foldInsertOp(*this, operands);
23457fbf55c9SNicolas Vasilache }
23467fbf55c9SNicolas Vasilache 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)23477fbf55c9SNicolas Vasilache void ParallelInsertSliceOp::getCanonicalizationPatterns(
23487fbf55c9SNicolas Vasilache     RewritePatternSet &results, MLIRContext *context) {
2349*c9fb3c6eSNicolas Vasilache   results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
2350*c9fb3c6eSNicolas Vasilache               InsertSliceOpCastFolder<ParallelInsertSliceOp>,
2351*c9fb3c6eSNicolas Vasilache               InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
23527fbf55c9SNicolas Vasilache }
23537fbf55c9SNicolas Vasilache 
23547fbf55c9SNicolas Vasilache //===----------------------------------------------------------------------===//
23556a8ba318SRiver Riddle // SplatOp
23566a8ba318SRiver Riddle //===----------------------------------------------------------------------===//
23576a8ba318SRiver Riddle 
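/// As an illustrative sketch (assumed IR), splatting a constant scalar such as
///
/// ```mlir
///   %cst = arith.constant 1.0 : f32
///   %0 = tensor.splat %cst : tensor<4xf32>
/// ```
///
/// folds %0 to a SplatElementsAttr holding that constant.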
fold(ArrayRef<Attribute> operands)23586a8ba318SRiver Riddle OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
23596a8ba318SRiver Riddle   auto constOperand = operands.front();
23606a8ba318SRiver Riddle   if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
23616a8ba318SRiver Riddle     return {};
23626a8ba318SRiver Riddle 
23636a8ba318SRiver Riddle   // SplatElementsAttr::get treats a single value for the second arg as a splat.
23646a8ba318SRiver Riddle   return SplatElementsAttr::get(getType(), {constOperand});
23656a8ba318SRiver Riddle }
23666a8ba318SRiver Riddle 
23676a8ba318SRiver Riddle //===----------------------------------------------------------------------===//
2368444822d7SSean Silva // TableGen'd op method definitions
2369444822d7SSean Silva //===----------------------------------------------------------------------===//
2370444822d7SSean Silva 
2371444822d7SSean Silva #define GET_OP_CLASSES
2372444822d7SSean Silva #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
2373