1444822d7SSean Silva //===----------------------------------------------------------------------===//
2444822d7SSean Silva //
3444822d7SSean Silva // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4444822d7SSean Silva // See https://llvm.org/LICENSE.txt for license information.
5444822d7SSean Silva // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6444822d7SSean Silva //
7444822d7SSean Silva //===----------------------------------------------------------------------===//
8444822d7SSean Silva
9eda6f907SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10ead11072SRiver Riddle #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
11eda6f907SRiver Riddle #include "mlir/Dialect/Complex/IR/Complex.h"
124f5eb53eSOkwan Kwon #include "mlir/Dialect/Tensor/IR/Tensor.h"
13b618880eSAlexander Belyaev #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
140813700dSMatthias Springer #include "mlir/Dialect/Utils/StaticValueUtils.h"
15be7352c0SSean Silva #include "mlir/IR/BlockAndValueMapping.h"
16444822d7SSean Silva #include "mlir/IR/Builders.h"
17a08b750cSNicolas Vasilache #include "mlir/IR/BuiltinAttributeInterfaces.h"
18be7352c0SSean Silva #include "mlir/IR/Matchers.h"
19444822d7SSean Silva #include "mlir/IR/TypeUtilities.h"
20444822d7SSean Silva #include "llvm/ADT/STLExtras.h"
216635c12aSBenjamin Kramer #include "llvm/ADT/SmallBitVector.h"
22444822d7SSean Silva
23444822d7SSean Silva using namespace mlir;
24444822d7SSean Silva using namespace mlir::tensor;
25444822d7SSean Silva
26c0a6318dSMatthias Springer /// Materialize a single constant operation from a given attribute value with
27c0a6318dSMatthias Springer /// the desired resultant type.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)28c0a6318dSMatthias Springer Operation *TensorDialect::materializeConstant(OpBuilder &builder,
29c0a6318dSMatthias Springer Attribute value, Type type,
30c0a6318dSMatthias Springer Location loc) {
31a54f4eaeSMogball if (arith::ConstantOp::isBuildableWith(value, type))
32a54f4eaeSMogball return builder.create<arith::ConstantOp>(loc, value, type);
332c7b0685SFrederik Gossen if (complex::ConstantOp::isBuildableWith(value, type))
342c7b0685SFrederik Gossen return builder.create<complex::ConstantOp>(loc, type,
352c7b0685SFrederik Gossen value.cast<ArrayAttr>());
36a54f4eaeSMogball return nullptr;
37c0a6318dSMatthias Springer }
38c0a6318dSMatthias Springer
39444822d7SSean Silva //===----------------------------------------------------------------------===//
40129d6e55SSean Silva // CastOp
41129d6e55SSean Silva //===----------------------------------------------------------------------===//
42129d6e55SSean Silva
433f89e339SAlex Zinenko /// Returns true if `target` is a ranked tensor type that preserves static
443f89e339SAlex Zinenko /// information available in the `source` ranked tensor type.
preservesStaticInformation(Type source,Type target)453f89e339SAlex Zinenko bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
463f89e339SAlex Zinenko auto sourceType = source.dyn_cast<RankedTensorType>();
473f89e339SAlex Zinenko auto targetType = target.dyn_cast<RankedTensorType>();
483f89e339SAlex Zinenko
493f89e339SAlex Zinenko // Requires RankedTensorType.
503f89e339SAlex Zinenko if (!sourceType || !targetType)
513f89e339SAlex Zinenko return false;
523f89e339SAlex Zinenko
533f89e339SAlex Zinenko // Requires same elemental type.
543f89e339SAlex Zinenko if (sourceType.getElementType() != targetType.getElementType())
553f89e339SAlex Zinenko return false;
563f89e339SAlex Zinenko
573f89e339SAlex Zinenko // Requires same rank.
583f89e339SAlex Zinenko if (sourceType.getRank() != targetType.getRank())
593f89e339SAlex Zinenko return false;
603f89e339SAlex Zinenko
613f89e339SAlex Zinenko // If cast is towards more static sizes along any dimension, don't fold.
623f89e339SAlex Zinenko for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
633f89e339SAlex Zinenko if (!ShapedType::isDynamic(std::get<0>(t)) &&
643f89e339SAlex Zinenko ShapedType::isDynamic(std::get<1>(t)))
653f89e339SAlex Zinenko return false;
663f89e339SAlex Zinenko }
673f89e339SAlex Zinenko
683f89e339SAlex Zinenko return true;
693f89e339SAlex Zinenko }
703f89e339SAlex Zinenko
71129d6e55SSean Silva /// Determines whether tensor::CastOp casts to a more dynamic version of the
72129d6e55SSean Silva /// source tensor. This is useful to fold a tensor.cast into a consuming op and
73129d6e55SSean Silva /// implement canonicalization patterns for ops in different dialects that may
74129d6e55SSean Silva /// consume the results of tensor.cast operations. Such foldable tensor.cast
75060208b4SMatthias Springer /// operations are typically inserted as `slice` ops and are canonicalized,
76129d6e55SSean Silva /// to preserve the type compatibility of their uses.
77129d6e55SSean Silva ///
78129d6e55SSean Silva /// Returns true when all conditions are met:
79129d6e55SSean Silva /// 1. source and result are ranked tensors with same element type and rank.
80129d6e55SSean Silva /// 2. the tensor type has more static information than the result
81129d6e55SSean Silva ///
82129d6e55SSean Silva /// Example:
83129d6e55SSean Silva /// ```mlir
84129d6e55SSean Silva /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
85129d6e55SSean Silva /// %2 = consumer %1 ... : tensor<?x?xf32> ...
86129d6e55SSean Silva /// ```
87129d6e55SSean Silva ///
88129d6e55SSean Silva /// folds into:
89129d6e55SSean Silva ///
90129d6e55SSean Silva /// ```mlir
91129d6e55SSean Silva /// %2 = consumer %0 ... : tensor<8x16xf32> ...
92129d6e55SSean Silva /// ```
canFoldIntoConsumerOp(CastOp castOp)93129d6e55SSean Silva bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
94129d6e55SSean Silva if (!castOp)
95129d6e55SSean Silva return false;
96129d6e55SSean Silva
973f89e339SAlex Zinenko // Can fold if the source of cast has at least as much static information as
983f89e339SAlex Zinenko // its results.
993f89e339SAlex Zinenko return preservesStaticInformation(castOp.getType(),
1002d70eff8SJacques Pienaar castOp.getSource().getType());
101129d6e55SSean Silva }
102129d6e55SSean Silva
103589eac65SMahesh Ravishankar /// Determines whether the tensor::CastOp casts to a more static version of the
104589eac65SMahesh Ravishankar /// source tensor. This is useful to fold into a producing op and implement
105589eac65SMahesh Ravishankar /// canonicaliation patterns with the `tensor.cast` op as the root, but producer
106589eac65SMahesh Ravishankar /// being from different dialects. Returns true when all conditions are met:
107589eac65SMahesh Ravishankar /// 1. source and result and ranked tensors with same element type and rank.
108589eac65SMahesh Ravishankar /// 2. the result type has more static information than the source.
109589eac65SMahesh Ravishankar ///
110589eac65SMahesh Ravishankar /// Example:
111589eac65SMahesh Ravishankar /// ```mlir
112589eac65SMahesh Ravishankar /// %1 = producer ... : tensor<?x?xf32>
113589eac65SMahesh Ravishankar /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
114589eac65SMahesh Ravishankar /// ```
115589eac65SMahesh Ravishankar ///
116589eac65SMahesh Ravishankar /// can be canonicalized to :
117589eac65SMahesh Ravishankar ///
118589eac65SMahesh Ravishankar /// ```mlir
119589eac65SMahesh Ravishankar /// %2 = producer ... : tensor<8x16xf32>
120589eac65SMahesh Ravishankar /// ```
121589eac65SMahesh Ravishankar /// Not all ops might be canonicalizable this way, but for those that can be,
122589eac65SMahesh Ravishankar /// this method provides a check that it is worth doing the canonicalization.
canFoldIntoProducerOp(CastOp castOp)123589eac65SMahesh Ravishankar bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
124589eac65SMahesh Ravishankar if (!castOp)
125589eac65SMahesh Ravishankar return false;
1262d70eff8SJacques Pienaar return preservesStaticInformation(castOp.getSource().getType(),
127589eac65SMahesh Ravishankar castOp.getType());
128589eac65SMahesh Ravishankar }
129589eac65SMahesh Ravishankar
1300ee4bf15SNicolas Vasilache /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
1310ee4bf15SNicolas Vasilache /// that can be folded.
foldTensorCast(Operation * op)1320ee4bf15SNicolas Vasilache LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
1330ee4bf15SNicolas Vasilache bool folded = false;
1340ee4bf15SNicolas Vasilache for (OpOperand &operand : op->getOpOperands()) {
1350ee4bf15SNicolas Vasilache auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
1360ee4bf15SNicolas Vasilache if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
1370ee4bf15SNicolas Vasilache operand.set(castOp.getOperand());
1380ee4bf15SNicolas Vasilache folded = true;
1390ee4bf15SNicolas Vasilache }
1400ee4bf15SNicolas Vasilache }
1410ee4bf15SNicolas Vasilache return success(folded);
1420ee4bf15SNicolas Vasilache }
1430ee4bf15SNicolas Vasilache
areCastCompatible(TypeRange inputs,TypeRange outputs)1446ccf2d62SRiver Riddle bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1456ccf2d62SRiver Riddle if (inputs.size() != 1 || outputs.size() != 1)
1466ccf2d62SRiver Riddle return false;
1476ccf2d62SRiver Riddle Type a = inputs.front(), b = outputs.front();
148129d6e55SSean Silva auto aT = a.dyn_cast<TensorType>();
149129d6e55SSean Silva auto bT = b.dyn_cast<TensorType>();
150129d6e55SSean Silva if (!aT || !bT)
151129d6e55SSean Silva return false;
152129d6e55SSean Silva
153129d6e55SSean Silva if (aT.getElementType() != bT.getElementType())
154129d6e55SSean Silva return false;
155129d6e55SSean Silva
156129d6e55SSean Silva return succeeded(verifyCompatibleShape(aT, bT));
157129d6e55SSean Silva }
158129d6e55SSean Silva
159129d6e55SSean Silva /// Compute a TensorType that has the joined shape knowledge of the two
160129d6e55SSean Silva /// given TensorTypes. The element types need to match.
joinShapes(TensorType one,TensorType two)161129d6e55SSean Silva static TensorType joinShapes(TensorType one, TensorType two) {
162129d6e55SSean Silva assert(one.getElementType() == two.getElementType());
163129d6e55SSean Silva
164129d6e55SSean Silva if (!one.hasRank())
165129d6e55SSean Silva return two;
166129d6e55SSean Silva if (!two.hasRank())
167129d6e55SSean Silva return one;
168129d6e55SSean Silva
169129d6e55SSean Silva int64_t rank = one.getRank();
170129d6e55SSean Silva if (rank != two.getRank())
171129d6e55SSean Silva return {};
172129d6e55SSean Silva
173129d6e55SSean Silva SmallVector<int64_t, 4> join;
174129d6e55SSean Silva join.reserve(rank);
175129d6e55SSean Silva for (int64_t i = 0; i < rank; ++i) {
176129d6e55SSean Silva if (one.isDynamicDim(i)) {
177129d6e55SSean Silva join.push_back(two.getDimSize(i));
178129d6e55SSean Silva continue;
179129d6e55SSean Silva }
180129d6e55SSean Silva if (two.isDynamicDim(i)) {
181129d6e55SSean Silva join.push_back(one.getDimSize(i));
182129d6e55SSean Silva continue;
183129d6e55SSean Silva }
184129d6e55SSean Silva if (one.getDimSize(i) != two.getDimSize(i))
185129d6e55SSean Silva return {};
186129d6e55SSean Silva join.push_back(one.getDimSize(i));
187129d6e55SSean Silva }
188129d6e55SSean Silva return RankedTensorType::get(join, one.getElementType());
189129d6e55SSean Silva }
190129d6e55SSean Silva
191129d6e55SSean Silva namespace {
192129d6e55SSean Silva
193129d6e55SSean Silva /// Replaces chains of two tensor.cast operations by a single tensor.cast
194129d6e55SSean Silva /// operation if doing so does not remove runtime constraints.
195129d6e55SSean Silva struct ChainedTensorCast : public OpRewritePattern<CastOp> {
196129d6e55SSean Silva using OpRewritePattern<CastOp>::OpRewritePattern;
197129d6e55SSean Silva
matchAndRewrite__anon3fb9f79f0111::ChainedTensorCast198129d6e55SSean Silva LogicalResult matchAndRewrite(CastOp tensorCast,
199129d6e55SSean Silva PatternRewriter &rewriter) const final {
200129d6e55SSean Silva auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
201129d6e55SSean Silva
202129d6e55SSean Silva if (!tensorCastOperand)
203129d6e55SSean Silva return failure();
204129d6e55SSean Silva
205129d6e55SSean Silva auto sourceType =
206129d6e55SSean Silva tensorCastOperand.getOperand().getType().cast<TensorType>();
207129d6e55SSean Silva auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
208129d6e55SSean Silva auto resultType = tensorCast.getType().cast<TensorType>();
209129d6e55SSean Silva
210129d6e55SSean Silva // We can remove the intermediate cast if joining all three produces the
211129d6e55SSean Silva // same result as just joining the source and result shapes.
212129d6e55SSean Silva auto firstJoin =
213129d6e55SSean Silva joinShapes(joinShapes(sourceType, intermediateType), resultType);
214129d6e55SSean Silva
215129d6e55SSean Silva // The join might not exist if the cast sequence would fail at runtime.
216129d6e55SSean Silva if (!firstJoin)
217129d6e55SSean Silva return failure();
218129d6e55SSean Silva
219129d6e55SSean Silva // The newJoin always exists if the above join exists, it might just contain
220129d6e55SSean Silva // less information. If so, we cannot drop the intermediate cast, as doing
221129d6e55SSean Silva // so would remove runtime checks.
222129d6e55SSean Silva auto newJoin = joinShapes(sourceType, resultType);
223129d6e55SSean Silva if (firstJoin != newJoin)
224129d6e55SSean Silva return failure();
225129d6e55SSean Silva
226129d6e55SSean Silva rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
227129d6e55SSean Silva tensorCastOperand.getOperand());
228129d6e55SSean Silva return success();
229129d6e55SSean Silva }
230129d6e55SSean Silva };
231129d6e55SSean Silva
/// Fold tensor.cast into tensor.extract_slice producer.
/// Example:
/// ```
/// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
///   tensor<128x512xf32> to tensor<?x512xf32>
/// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
/// ```
/// ->
/// ```
/// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
///   tensor<128x512xf32> to tensor<16x512xf32>
/// ```
struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

    // Only fire when the cast refines an extract_slice result to a more
    // static shape, and the shape actually changes.
    if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
        tensorCast.getType().getShape() == tensorCast.getSource()
                                               .getType()
                                               .cast<RankedTensorType>()
                                               .getShape())
      return failure();

    SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
    // Dimensions dropped by a rank-reducing extract_slice do not appear in
    // the cast's result type; the mask marks them so they can be skipped.
    auto dimMask = computeRankReductionMask(
        extractFromI64ArrayAttr(extractOperand.getStaticSizes()),
        extractOperand.getType().getShape());
    // `dimIndex` walks the (non-reduced) result dimensions in lockstep with
    // the surviving entries of `sizes`.
    size_t dimIndex = 0;
    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))
        continue;
      int64_t dim = tensorCast.getType().getShape()[dimIndex++];
      // Leave sizes that stay dynamic in the cast result untouched.
      if (ShapedType::isDynamic(dim))
        continue;
      // Replace the dynamic slice size with the static size from the cast.
      sizes[i] = rewriter.getIndexAttr(dim);
    }

    // Rebuild the slice with refined sizes and the cast's result type.
    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, tensorCast.getType().cast<RankedTensorType>(),
        extractOperand.getSource(), extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());
    return success();
  }
};
280f2676b15SThomas Raoux
281129d6e55SSean Silva } // namespace
282129d6e55SSean Silva
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)283dc4e913bSChris Lattner void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
284129d6e55SSean Silva MLIRContext *context) {
285f2676b15SThomas Raoux results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
286129d6e55SSean Silva }
287129d6e55SSean Silva
288129d6e55SSean Silva //===----------------------------------------------------------------------===//
289c0a6318dSMatthias Springer // DimOp
290c0a6318dSMatthias Springer //===----------------------------------------------------------------------===//
291c0a6318dSMatthias Springer
build(OpBuilder & builder,OperationState & result,Value source,int64_t index)292c0a6318dSMatthias Springer void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
293c0a6318dSMatthias Springer int64_t index) {
294c0a6318dSMatthias Springer auto loc = result.location;
295a54f4eaeSMogball Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
296c0a6318dSMatthias Springer build(builder, result, source, indexValue);
297c0a6318dSMatthias Springer }
298c0a6318dSMatthias Springer
getConstantIndex()299c0a6318dSMatthias Springer Optional<int64_t> DimOp::getConstantIndex() {
3002d70eff8SJacques Pienaar if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
301cfb72fd3SJacques Pienaar return constantOp.getValue().cast<IntegerAttr>().getInt();
302c0a6318dSMatthias Springer return {};
303c0a6318dSMatthias Springer }
304c0a6318dSMatthias Springer
verify()305b98dc035SRiver Riddle LogicalResult DimOp::verify() {
306c0a6318dSMatthias Springer // Assume unknown index to be in range.
307b98dc035SRiver Riddle Optional<int64_t> index = getConstantIndex();
308037f0995SKazu Hirata if (!index)
309c0a6318dSMatthias Springer return success();
310c0a6318dSMatthias Springer
311c0a6318dSMatthias Springer // Check that constant index is not knowingly out of range.
3122d70eff8SJacques Pienaar auto type = getSource().getType();
313c0a6318dSMatthias Springer if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
3146d5fc1e3SKazu Hirata if (*index >= tensorType.getRank())
315b98dc035SRiver Riddle return emitOpError("index is out of range");
316c0a6318dSMatthias Springer } else if (type.isa<UnrankedTensorType>()) {
317c0a6318dSMatthias Springer // Assume index to be in range.
318c0a6318dSMatthias Springer } else {
319c0a6318dSMatthias Springer llvm_unreachable("expected operand with tensor type");
320c0a6318dSMatthias Springer }
321c0a6318dSMatthias Springer return success();
322c0a6318dSMatthias Springer }
323c0a6318dSMatthias Springer
/// Folds tensor.dim in several ways: to a constant when the extent is static,
/// to a tensor.generate dynamic-extent operand, to an extract_slice dynamic
/// size, or by bypassing a tensor.cast on the source operand.
OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  // All forms of folding require a known index (operand 1 is the index).
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!index)
    return {};

  // Folding for unranked types (UnrankedTensorType) is not supported.
  auto tensorType = getSource().getType().dyn_cast<RankedTensorType>();
  if (!tensorType)
    return {};

  // Fold if the shape extent along the given index is known.
  if (!tensorType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
  }

  Operation *definingOp = getSource().getDefiningOp();

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        fromElements.getResult().getType().cast<RankedTensorType>();
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    // Find the operand of the fromElements that corresponds to this index.
    // Dynamic extents appear in operand order, one per dynamic dimension, so
    // skip one operand for each dynamic dimension before the queried one.
    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
    // `resolve-shaped-type-result-dims` pass.
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};
    }
  }

  // dim(cast) -> dim: rewrite the source operand in place, then return our
  // own result to signal the fold.
  if (succeeded(foldTensorCast(*this)))
    return getResult();

  return {};
}
378c0a6318dSMatthias Springer
379c0a6318dSMatthias Springer namespace {
380c0a6318dSMatthias Springer /// Fold dim of a cast into the dim of the source of the tensor cast.
381c0a6318dSMatthias Springer struct DimOfCastOp : public OpRewritePattern<DimOp> {
382c0a6318dSMatthias Springer using OpRewritePattern<DimOp>::OpRewritePattern;
383c0a6318dSMatthias Springer
matchAndRewrite__anon3fb9f79f0211::DimOfCastOp384c0a6318dSMatthias Springer LogicalResult matchAndRewrite(DimOp dimOp,
385c0a6318dSMatthias Springer PatternRewriter &rewriter) const override {
3862d70eff8SJacques Pienaar auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
387c0a6318dSMatthias Springer if (!castOp)
388c0a6318dSMatthias Springer return failure();
389c0a6318dSMatthias Springer Value newSource = castOp.getOperand();
3902d70eff8SJacques Pienaar rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
391c0a6318dSMatthias Springer return success();
392c0a6318dSMatthias Springer }
393c0a6318dSMatthias Springer };
394be0a7e9fSMehdi Amini } // namespace
395c0a6318dSMatthias Springer
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)396c0a6318dSMatthias Springer void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
397c0a6318dSMatthias Springer MLIRContext *context) {
398c0a6318dSMatthias Springer results.add<DimOfCastOp>(context);
399c0a6318dSMatthias Springer }
400c0a6318dSMatthias Springer
401c0a6318dSMatthias Springer //===----------------------------------------------------------------------===//
402444822d7SSean Silva // ExtractOp
403444822d7SSean Silva //===----------------------------------------------------------------------===//
404444822d7SSean Silva
verify()405b98dc035SRiver Riddle LogicalResult ExtractOp::verify() {
406444822d7SSean Silva // Verify the # indices match if we have a ranked type.
4072d70eff8SJacques Pienaar if (auto tensorType = getTensor().getType().dyn_cast<RankedTensorType>())
4082d70eff8SJacques Pienaar if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
409b98dc035SRiver Riddle return emitOpError("incorrect number of indices for extract_element");
410444822d7SSean Silva
411444822d7SSean Silva return success();
412444822d7SSean Silva }
413444822d7SSean Silva
fold(ArrayRef<Attribute> operands)414444822d7SSean Silva OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
415444822d7SSean Silva // If this is a splat elements attribute, simply return the value. All of the
416444822d7SSean Silva // elements of a splat attribute are the same.
4178544523dSMatthias Springer if (Attribute tensor = operands.front())
418444822d7SSean Silva if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
419ae40d625SRiver Riddle return splatTensor.getSplatValue<Attribute>();
420444822d7SSean Silva
4218544523dSMatthias Springer // Collect the constant indices into the tensor.
422444822d7SSean Silva SmallVector<uint64_t, 8> indices;
423444822d7SSean Silva for (Attribute indice : llvm::drop_begin(operands, 1)) {
424444822d7SSean Silva if (!indice || !indice.isa<IntegerAttr>())
425444822d7SSean Silva return {};
426444822d7SSean Silva indices.push_back(indice.cast<IntegerAttr>().getInt());
427444822d7SSean Silva }
428444822d7SSean Silva
4298544523dSMatthias Springer // Fold extract(from_elements(...)).
4302d70eff8SJacques Pienaar if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
4318544523dSMatthias Springer auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
4328544523dSMatthias Springer auto rank = tensorType.getRank();
4338544523dSMatthias Springer assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
4348544523dSMatthias Springer "rank mismatch");
4358544523dSMatthias Springer int flatIndex = 0;
4368544523dSMatthias Springer int stride = 1;
4378544523dSMatthias Springer for (int i = rank - 1; i >= 0; --i) {
4388544523dSMatthias Springer if (i < rank - 1)
4398544523dSMatthias Springer stride *= tensorType.getDimSize(i);
4408544523dSMatthias Springer flatIndex += indices[i] * stride;
4418544523dSMatthias Springer }
4428544523dSMatthias Springer // Prevent out of bounds accesses. This can happen in invalid code that will
4438544523dSMatthias Springer // never execute.
4442d70eff8SJacques Pienaar if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
4458544523dSMatthias Springer flatIndex < 0)
4468544523dSMatthias Springer return {};
4472d70eff8SJacques Pienaar return fromElementsOp.getElements()[flatIndex];
4488544523dSMatthias Springer }
4498544523dSMatthias Springer
450444822d7SSean Silva // If this is an elements attribute, query the value at the given indices.
4518544523dSMatthias Springer if (Attribute tensor = operands.front()) {
452444822d7SSean Silva auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
453444822d7SSean Silva if (elementsAttr && elementsAttr.isValidIndex(indices))
454ae40d625SRiver Riddle return elementsAttr.getValues<Attribute>()[indices];
4558544523dSMatthias Springer }
4568544523dSMatthias Springer
457444822d7SSean Silva return {};
458444822d7SSean Silva }
459444822d7SSean Silva
460444822d7SSean Silva //===----------------------------------------------------------------------===//
461be7352c0SSean Silva // FromElementsOp
462be7352c0SSean Silva //===----------------------------------------------------------------------===//
463be7352c0SSean Silva
build(OpBuilder & builder,OperationState & result,Type resultType,ValueRange elements)464be7352c0SSean Silva void FromElementsOp::build(OpBuilder &builder, OperationState &result,
465f77e9f87SAlexander Belyaev Type resultType, ValueRange elements) {
466be7352c0SSean Silva result.addOperands(elements);
467f77e9f87SAlexander Belyaev result.addTypes(resultType);
468be7352c0SSean Silva }
469be7352c0SSean Silva
build(OpBuilder & builder,OperationState & result,ValueRange elements)470be7352c0SSean Silva void FromElementsOp::build(OpBuilder &builder, OperationState &result,
471be7352c0SSean Silva ValueRange elements) {
472be7352c0SSean Silva assert(!elements.empty() && "expected at least one element");
473f77e9f87SAlexander Belyaev Type resultType = RankedTensorType::get(
474f77e9f87SAlexander Belyaev {static_cast<int64_t>(elements.size())}, elements.front().getType());
475f77e9f87SAlexander Belyaev build(builder, result, resultType, elements);
476be7352c0SSean Silva }
477be7352c0SSean Silva
fold(ArrayRef<Attribute> operands)4787b52aeadSBenjamin Kramer OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
4797b52aeadSBenjamin Kramer if (!llvm::is_contained(operands, nullptr))
4807b52aeadSBenjamin Kramer return DenseElementsAttr::get(getType(), operands);
4817b52aeadSBenjamin Kramer return {};
4827b52aeadSBenjamin Kramer }
4837b52aeadSBenjamin Kramer
484be7352c0SSean Silva namespace {
485be7352c0SSean Silva
4867c984be2SRob Suderman // Pushes the index_casts that occur before extractions to after the extract.
4877c984be2SRob Suderman // This minimizes type conversion in some cases and enables the extract
4887c984be2SRob Suderman // canonicalizer. This changes:
4897c984be2SRob Suderman //
4907c984be2SRob Suderman // %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
4917c984be2SRob Suderman // %extract = tensor.extract %cast[%index] : tensor<1xindex>
4927c984be2SRob Suderman //
4937c984be2SRob Suderman // to the following:
4947c984be2SRob Suderman //
4957c984be2SRob Suderman // %extract = tensor.extract %tensor[%index] : tensor<1xindex>
4967c984be2SRob Suderman // %cast = arith.index_cast %extract : i32 to index
4977c984be2SRob Suderman //
4987c984be2SRob Suderman // to just %element.
4997c984be2SRob Suderman //
5007c984be2SRob Suderman // Consider expanding this to a template and handle all tensor cast operations.
5017c984be2SRob Suderman struct ExtractElementFromIndexCast
5027c984be2SRob Suderman : public OpRewritePattern<tensor::ExtractOp> {
5037c984be2SRob Suderman using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
5047c984be2SRob Suderman
matchAndRewrite__anon3fb9f79f0311::ExtractElementFromIndexCast5057c984be2SRob Suderman LogicalResult matchAndRewrite(tensor::ExtractOp extract,
5067c984be2SRob Suderman PatternRewriter &rewriter) const final {
5077c984be2SRob Suderman Location loc = extract.getLoc();
5082d70eff8SJacques Pienaar auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
5097c984be2SRob Suderman if (!indexCast)
5107c984be2SRob Suderman return failure();
5117c984be2SRob Suderman
5127c984be2SRob Suderman Type elementTy = getElementTypeOrSelf(indexCast.getIn());
5137c984be2SRob Suderman
5147c984be2SRob Suderman auto newExtract = rewriter.create<tensor::ExtractOp>(
5152d70eff8SJacques Pienaar loc, elementTy, indexCast.getIn(), extract.getIndices());
5167c984be2SRob Suderman
5177c984be2SRob Suderman rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
5187c984be2SRob Suderman newExtract);
5197c984be2SRob Suderman
5207c984be2SRob Suderman return success();
5217c984be2SRob Suderman }
5227c984be2SRob Suderman };
5237c984be2SRob Suderman
524be7352c0SSean Silva } // namespace
525be7352c0SSean Silva
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)526dc4e913bSChris Lattner void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
527dc4e913bSChris Lattner MLIRContext *context) {
5288544523dSMatthias Springer results.add<ExtractElementFromIndexCast>(context);
529be7352c0SSean Silva }
530be7352c0SSean Silva
531be7352c0SSean Silva //===----------------------------------------------------------------------===//
532b4baccc2SHanhan Wang // InsertOp
533b4baccc2SHanhan Wang //===----------------------------------------------------------------------===//
534b4baccc2SHanhan Wang
verify()535b98dc035SRiver Riddle LogicalResult InsertOp::verify() {
536b4baccc2SHanhan Wang // Verify the # indices match if we have a ranked type.
5372d70eff8SJacques Pienaar if (auto destType = getDest().getType().dyn_cast<RankedTensorType>())
5382d70eff8SJacques Pienaar if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
539b98dc035SRiver Riddle return emitOpError("incorrect number of indices");
540b4baccc2SHanhan Wang return success();
541b4baccc2SHanhan Wang }
542b4baccc2SHanhan Wang
fold(ArrayRef<Attribute> operands)543b4baccc2SHanhan Wang OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
544b4baccc2SHanhan Wang Attribute scalar = operands[0];
545b4baccc2SHanhan Wang Attribute dest = operands[1];
546b4baccc2SHanhan Wang if (scalar && dest)
547b4baccc2SHanhan Wang if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
548ae40d625SRiver Riddle if (scalar == splatDest.getSplatValue<Attribute>())
549b4baccc2SHanhan Wang return dest;
550b4baccc2SHanhan Wang return {};
551b4baccc2SHanhan Wang }
552b4baccc2SHanhan Wang
553b4baccc2SHanhan Wang //===----------------------------------------------------------------------===//
554be7352c0SSean Silva // GenerateOp
555be7352c0SSean Silva //===----------------------------------------------------------------------===//
556be7352c0SSean Silva
reifyResultShapes(OpBuilder & builder,ReifiedRankedShapedTypeDims & reifiedReturnShapes)557fdb41a22SMatthias Springer LogicalResult GenerateOp::reifyResultShapes(
558fdb41a22SMatthias Springer OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
559fdb41a22SMatthias Springer reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
560fdb41a22SMatthias Springer int idx = 0;
561fdb41a22SMatthias Springer for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
562fdb41a22SMatthias Springer if (getType().isDynamicDim(dim)) {
563fdb41a22SMatthias Springer reifiedReturnShapes[0][dim] = getOperand(idx++);
564fdb41a22SMatthias Springer } else {
565fdb41a22SMatthias Springer reifiedReturnShapes[0][dim] = builder.create<arith::ConstantIndexOp>(
566fdb41a22SMatthias Springer getLoc(), getType().getDimSize(dim));
567fdb41a22SMatthias Springer }
568fdb41a22SMatthias Springer }
569fdb41a22SMatthias Springer return success();
570fdb41a22SMatthias Springer }
571fdb41a22SMatthias Springer
verify()572b98dc035SRiver Riddle LogicalResult GenerateOp::verify() {
573be7352c0SSean Silva // Ensure that the tensor type has as many dynamic dimensions as are specified
574be7352c0SSean Silva // by the operands.
575b98dc035SRiver Riddle RankedTensorType resultTy = getType().cast<RankedTensorType>();
576b98dc035SRiver Riddle if (getNumOperands() != resultTy.getNumDynamicDims())
577b98dc035SRiver Riddle return emitError("must have as many index operands as dynamic extents "
578be7352c0SSean Silva "in the result type");
579be7352c0SSean Silva
580ed645f63SChia-hung Duan return success();
581ed645f63SChia-hung Duan }
582ed645f63SChia-hung Duan
verifyRegions()583ed645f63SChia-hung Duan LogicalResult GenerateOp::verifyRegions() {
584ed645f63SChia-hung Duan RankedTensorType resultTy = getType().cast<RankedTensorType>();
585be7352c0SSean Silva // Ensure that region arguments span the index space.
5862d70eff8SJacques Pienaar if (!llvm::all_of(getBody().getArgumentTypes(),
587be7352c0SSean Silva [](Type ty) { return ty.isIndex(); }))
588b98dc035SRiver Riddle return emitError("all body arguments must be index");
5892d70eff8SJacques Pienaar if (getBody().getNumArguments() != resultTy.getRank())
590b98dc035SRiver Riddle return emitError("must have one body argument per input dimension");
591be7352c0SSean Silva
592be7352c0SSean Silva // Ensure that the region yields an element of the right type.
5932d70eff8SJacques Pienaar auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
594fd0c6f53SAlexander Belyaev
5952d70eff8SJacques Pienaar if (yieldOp.getValue().getType() != resultTy.getElementType())
596b98dc035SRiver Riddle return emitOpError(
597be7352c0SSean Silva "body must be terminated with a `yield` operation of the tensor "
598be7352c0SSean Silva "element type");
599be7352c0SSean Silva
600be7352c0SSean Silva return success();
601be7352c0SSean Silva }
602be7352c0SSean Silva
build(OpBuilder & b,OperationState & result,Type resultTy,ValueRange dynamicExtents,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilder)603be7352c0SSean Silva void GenerateOp::build(
604be7352c0SSean Silva OpBuilder &b, OperationState &result, Type resultTy,
605be7352c0SSean Silva ValueRange dynamicExtents,
606be7352c0SSean Silva function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
607be7352c0SSean Silva build(b, result, resultTy, dynamicExtents);
608be7352c0SSean Silva
609be7352c0SSean Silva // Build and populate body.
610be7352c0SSean Silva OpBuilder::InsertionGuard guard(b);
611be7352c0SSean Silva Region *bodyRegion = result.regions.front().get();
612be7352c0SSean Silva auto rank = resultTy.cast<RankedTensorType>().getRank();
613be7352c0SSean Silva SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
614e084679fSRiver Riddle SmallVector<Location, 2> argumentLocs(rank, result.location);
615be7352c0SSean Silva Block *bodyBlock =
616e084679fSRiver Riddle b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
617be7352c0SSean Silva bodyBuilder(b, result.location, bodyBlock->getArguments());
618be7352c0SSean Silva }
619be7352c0SSean Silva
620be7352c0SSean Silva namespace {
621be7352c0SSean Silva
622be7352c0SSean Silva /// Canonicalizes tensor.generate operations with a constant
623be7352c0SSean Silva /// operand into the equivalent operation with the operand expressed in the
624be7352c0SSean Silva /// result type, instead. We also insert a type cast to make sure that the
625be7352c0SSean Silva /// resulting IR is still well-typed.
626be7352c0SSean Silva struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
627be7352c0SSean Silva using OpRewritePattern<GenerateOp>::OpRewritePattern;
628be7352c0SSean Silva
matchAndRewrite__anon3fb9f79f0511::StaticTensorGenerate629be7352c0SSean Silva LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
630be7352c0SSean Silva PatternRewriter &rewriter) const final {
631be7352c0SSean Silva auto resultType =
632be7352c0SSean Silva tensorFromElements.getResult().getType().cast<RankedTensorType>();
633be7352c0SSean Silva
634be7352c0SSean Silva if (resultType.hasStaticShape())
635be7352c0SSean Silva return failure();
636be7352c0SSean Silva
637be7352c0SSean Silva SmallVector<Value, 4> newOperands;
638be7352c0SSean Silva SmallVector<int64_t, 4> newShape;
6392d70eff8SJacques Pienaar auto operandsIt = tensorFromElements.getDynamicExtents().begin();
640be7352c0SSean Silva
641be7352c0SSean Silva for (int64_t dim : resultType.getShape()) {
642676bfb2aSRiver Riddle if (!ShapedType::isDynamic(dim)) {
643be7352c0SSean Silva newShape.push_back(dim);
644be7352c0SSean Silva continue;
645be7352c0SSean Silva }
646be7352c0SSean Silva APInt index;
647be7352c0SSean Silva if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
648676bfb2aSRiver Riddle newShape.push_back(ShapedType::kDynamicSize);
649be7352c0SSean Silva newOperands.push_back(*operandsIt++);
650be7352c0SSean Silva continue;
651be7352c0SSean Silva }
652be7352c0SSean Silva newShape.push_back(index.getSExtValue());
653be7352c0SSean Silva operandsIt++;
654be7352c0SSean Silva }
655be7352c0SSean Silva
6562d70eff8SJacques Pienaar if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
657be7352c0SSean Silva return failure();
658be7352c0SSean Silva
659be7352c0SSean Silva auto loc = tensorFromElements.getLoc();
660be7352c0SSean Silva auto newOp = rewriter.create<GenerateOp>(
661be7352c0SSean Silva loc, RankedTensorType::get(newShape, resultType.getElementType()),
662be7352c0SSean Silva newOperands);
6632d70eff8SJacques Pienaar rewriter.inlineRegionBefore(tensorFromElements.getBody(), newOp.getBody(),
6642d70eff8SJacques Pienaar newOp.getBody().begin());
665be7352c0SSean Silva rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
666be7352c0SSean Silva newOp);
667be7352c0SSean Silva return success();
668be7352c0SSean Silva }
669be7352c0SSean Silva };
670be7352c0SSean Silva
671be7352c0SSean Silva /// Canonicalizes the pattern of the form
672be7352c0SSean Silva ///
673be7352c0SSean Silva /// %tensor = tensor.generate %x {
674d75c3e83SRiver Riddle /// ^bb0(%arg0: index):
675be7352c0SSean Silva /// <computation>
676be7352c0SSean Silva /// yield %1 : index
677be7352c0SSean Silva /// } : tensor<?xindex>
678be7352c0SSean Silva /// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xi32>
679be7352c0SSean Silva ///
680be7352c0SSean Silva /// to just <computation> with %arg0 replaced by %c0. We only do this if the
681be7352c0SSean Silva /// tensor.generate operation has no side-effects.
682be7352c0SSean Silva struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
683be7352c0SSean Silva using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
684be7352c0SSean Silva
matchAndRewrite__anon3fb9f79f0511::ExtractFromTensorGenerate685be7352c0SSean Silva LogicalResult matchAndRewrite(tensor::ExtractOp extract,
686be7352c0SSean Silva PatternRewriter &rewriter) const final {
6872d70eff8SJacques Pienaar auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
688be7352c0SSean Silva if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
689be7352c0SSean Silva return failure();
690be7352c0SSean Silva
691be7352c0SSean Silva BlockAndValueMapping mapping;
692eca86cb2SJacques Pienaar Block *body = &tensorFromElements.getBody().front();
6932d70eff8SJacques Pienaar mapping.map(body->getArguments(), extract.getIndices());
694be7352c0SSean Silva for (auto &op : body->without_terminator())
695be7352c0SSean Silva rewriter.clone(op, mapping);
696be7352c0SSean Silva
697be7352c0SSean Silva auto yield = cast<YieldOp>(body->getTerminator());
698be7352c0SSean Silva
6992d70eff8SJacques Pienaar rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
700be7352c0SSean Silva return success();
701be7352c0SSean Silva }
702be7352c0SSean Silva };
703be7352c0SSean Silva
704be7352c0SSean Silva /// Canonicalizes the pattern of the form
705be7352c0SSean Silva ///
706be7352c0SSean Silva /// %val = tensor.cast %source : : tensor<?xi32> to tensor<2xi32>
707be7352c0SSean Silva /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
708be7352c0SSean Silva ///
709be7352c0SSean Silva /// to
710be7352c0SSean Silva ///
711be7352c0SSean Silva /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
712be7352c0SSean Silva struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
713be7352c0SSean Silva using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
714be7352c0SSean Silva
matchAndRewrite__anon3fb9f79f0511::ExtractFromTensorCast715be7352c0SSean Silva LogicalResult matchAndRewrite(tensor::ExtractOp extract,
716be7352c0SSean Silva PatternRewriter &rewriter) const final {
7172d70eff8SJacques Pienaar auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
718be7352c0SSean Silva if (!tensorCast)
719be7352c0SSean Silva return failure();
720be7352c0SSean Silva
7212d70eff8SJacques Pienaar rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
7222d70eff8SJacques Pienaar extract, tensorCast.getSource(), extract.getIndices());
723be7352c0SSean Silva return success();
724be7352c0SSean Silva }
725be7352c0SSean Silva };
726be7352c0SSean Silva
727be7352c0SSean Silva } // namespace
728be7352c0SSean Silva
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)729dc4e913bSChris Lattner void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
730be7352c0SSean Silva MLIRContext *context) {
731be7352c0SSean Silva // TODO: Move extract patterns to tensor::ExtractOp.
732dc4e913bSChris Lattner results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
733be7352c0SSean Silva StaticTensorGenerate>(context);
734be7352c0SSean Silva }
735be7352c0SSean Silva
736be7352c0SSean Silva //===----------------------------------------------------------------------===//
73715f8f3e2SAlexander Belyaev // RankOp
73815f8f3e2SAlexander Belyaev //===----------------------------------------------------------------------===//
73915f8f3e2SAlexander Belyaev
fold(ArrayRef<Attribute> operands)74015f8f3e2SAlexander Belyaev OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
74115f8f3e2SAlexander Belyaev // Constant fold rank when the rank of the operand is known.
74215f8f3e2SAlexander Belyaev auto type = getOperand().getType();
74315f8f3e2SAlexander Belyaev auto shapedType = type.dyn_cast<ShapedType>();
74415f8f3e2SAlexander Belyaev if (shapedType && shapedType.hasRank())
74515f8f3e2SAlexander Belyaev return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
74615f8f3e2SAlexander Belyaev return IntegerAttr();
74715f8f3e2SAlexander Belyaev }
74815f8f3e2SAlexander Belyaev
74915f8f3e2SAlexander Belyaev //===----------------------------------------------------------------------===//
7500724911dSAlexander Belyaev // ReshapeOp
7510724911dSAlexander Belyaev //===----------------------------------------------------------------------===//
7520724911dSAlexander Belyaev
getNumElements(ShapedType type)75302b6fb21SMehdi Amini static int64_t getNumElements(ShapedType type) {
7540724911dSAlexander Belyaev int64_t numElements = 1;
7550724911dSAlexander Belyaev for (auto dim : type.getShape())
7560724911dSAlexander Belyaev numElements *= dim;
7570724911dSAlexander Belyaev return numElements;
7580724911dSAlexander Belyaev }
7590724911dSAlexander Belyaev
verify()760b98dc035SRiver Riddle LogicalResult ReshapeOp::verify() {
7612d70eff8SJacques Pienaar TensorType operandType = getSource().getType().cast<TensorType>();
7622d70eff8SJacques Pienaar TensorType resultType = getResult().getType().cast<TensorType>();
7630724911dSAlexander Belyaev
7640724911dSAlexander Belyaev if (operandType.getElementType() != resultType.getElementType())
765b98dc035SRiver Riddle return emitOpError("element types of source and destination tensor "
7660724911dSAlexander Belyaev "types should be the same");
7670724911dSAlexander Belyaev
7682d70eff8SJacques Pienaar int64_t shapeSize =
7692d70eff8SJacques Pienaar getShape().getType().cast<RankedTensorType>().getDimSize(0);
7700724911dSAlexander Belyaev auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
7710724911dSAlexander Belyaev auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
7720724911dSAlexander Belyaev
7730724911dSAlexander Belyaev if (resultRankedType) {
7740724911dSAlexander Belyaev if (operandRankedType && resultRankedType.hasStaticShape() &&
7750724911dSAlexander Belyaev operandRankedType.hasStaticShape()) {
77602b6fb21SMehdi Amini if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
777b98dc035SRiver Riddle return emitOpError("source and destination tensor should have the "
7780724911dSAlexander Belyaev "same number of elements");
7790724911dSAlexander Belyaev }
780676bfb2aSRiver Riddle if (ShapedType::isDynamic(shapeSize))
781b98dc035SRiver Riddle return emitOpError("cannot use shape operand with dynamic length to "
7820724911dSAlexander Belyaev "reshape to statically-ranked tensor type");
7830724911dSAlexander Belyaev if (shapeSize != resultRankedType.getRank())
784b98dc035SRiver Riddle return emitOpError(
7850724911dSAlexander Belyaev "length of shape operand differs from the result's tensor rank");
7860724911dSAlexander Belyaev }
7870724911dSAlexander Belyaev return success();
7880724911dSAlexander Belyaev }
7890724911dSAlexander Belyaev
7900724911dSAlexander Belyaev //===----------------------------------------------------------------------===//
791b618880eSAlexander Belyaev // Reassociative reshape ops
792b618880eSAlexander Belyaev //===----------------------------------------------------------------------===//
793b618880eSAlexander Belyaev
getReassociationMaps()794b618880eSAlexander Belyaev SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
795b618880eSAlexander Belyaev return getSymbolLessAffineMaps(getReassociationExprs());
796b618880eSAlexander Belyaev }
getReassociationExprs()797b618880eSAlexander Belyaev SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
798b618880eSAlexander Belyaev return convertReassociationIndicesToExprs(getContext(),
799b618880eSAlexander Belyaev getReassociationIndices());
800b618880eSAlexander Belyaev }
801b618880eSAlexander Belyaev
getReassociationMaps()802b618880eSAlexander Belyaev SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
803b618880eSAlexander Belyaev return getSymbolLessAffineMaps(getReassociationExprs());
804b618880eSAlexander Belyaev }
getReassociationExprs()805b618880eSAlexander Belyaev SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
806b618880eSAlexander Belyaev return convertReassociationIndicesToExprs(getContext(),
807b618880eSAlexander Belyaev getReassociationIndices());
808b618880eSAlexander Belyaev }
809b618880eSAlexander Belyaev
810b618880eSAlexander Belyaev /// Compute the RankedTensorType obtained by applying `reassociation` to `type`.
811b618880eSAlexander Belyaev static RankedTensorType
computeTensorReshapeCollapsedType(RankedTensorType type,ArrayRef<AffineMap> reassociation)812b618880eSAlexander Belyaev computeTensorReshapeCollapsedType(RankedTensorType type,
813b618880eSAlexander Belyaev ArrayRef<AffineMap> reassociation) {
814b618880eSAlexander Belyaev auto shape = type.getShape();
815b618880eSAlexander Belyaev SmallVector<int64_t, 4> newShape;
816b618880eSAlexander Belyaev newShape.reserve(reassociation.size());
817b618880eSAlexander Belyaev
818b618880eSAlexander Belyaev // Use the fact that reassociation is valid to simplify the logic: only use
819b618880eSAlexander Belyaev // each map's rank.
820b618880eSAlexander Belyaev assert(isReassociationValid(reassociation) && "invalid reassociation");
821b618880eSAlexander Belyaev unsigned currentDim = 0;
822b618880eSAlexander Belyaev for (AffineMap m : reassociation) {
823b618880eSAlexander Belyaev unsigned dim = m.getNumResults();
824b618880eSAlexander Belyaev auto band = shape.slice(currentDim, dim);
825b618880eSAlexander Belyaev int64_t size = 1;
826b618880eSAlexander Belyaev if (llvm::is_contained(band, ShapedType::kDynamicSize))
827b618880eSAlexander Belyaev size = ShapedType::kDynamicSize;
828b618880eSAlexander Belyaev else
829b618880eSAlexander Belyaev for (unsigned d = 0; d < dim; ++d)
830b618880eSAlexander Belyaev size *= shape[currentDim + d];
831b618880eSAlexander Belyaev newShape.push_back(size);
832b618880eSAlexander Belyaev currentDim += dim;
833b618880eSAlexander Belyaev }
834b618880eSAlexander Belyaev
835b618880eSAlexander Belyaev return RankedTensorType::get(newShape, type.getElementType());
836b618880eSAlexander Belyaev }
837b618880eSAlexander Belyaev
build(OpBuilder & b,OperationState & result,Value src,ArrayRef<ReassociationIndices> reassociation,ArrayRef<NamedAttribute> attrs)838b618880eSAlexander Belyaev void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
839b618880eSAlexander Belyaev ArrayRef<ReassociationIndices> reassociation,
840b618880eSAlexander Belyaev ArrayRef<NamedAttribute> attrs) {
841b618880eSAlexander Belyaev auto resultType = computeTensorReshapeCollapsedType(
842b618880eSAlexander Belyaev src.getType().cast<RankedTensorType>(),
843b618880eSAlexander Belyaev getSymbolLessAffineMaps(
844b618880eSAlexander Belyaev convertReassociationIndicesToExprs(b.getContext(), reassociation)));
845b618880eSAlexander Belyaev build(b, result, resultType, src, attrs);
846eca86cb2SJacques Pienaar result.addAttribute(getReassociationAttrStrName(),
847b618880eSAlexander Belyaev getReassociationIndicesAttribute(b, reassociation));
848b618880eSAlexander Belyaev }
849b618880eSAlexander Belyaev
850e7d3ba10SAart Bik // Checks if types are the same, but ignoring encoding on ranked tensors.
isSameTypesWithoutEncoding(Type tp1,Type tp2)851e7d3ba10SAart Bik static bool isSameTypesWithoutEncoding(Type tp1, Type tp2) {
852e7d3ba10SAart Bik if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
853e7d3ba10SAart Bik if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
854e7d3ba10SAart Bik return rtp1.getShape() == rtp2.getShape() &&
855e7d3ba10SAart Bik rtp1.getElementType() == rtp2.getElementType();
856e7d3ba10SAart Bik return false;
857e7d3ba10SAart Bik }
858e7d3ba10SAart Bik // Default implementation.
859e7d3ba10SAart Bik return tp1 == tp2;
860e7d3ba10SAart Bik }
861e7d3ba10SAart Bik
862b618880eSAlexander Belyaev template <typename TensorReshapeOp, bool isExpansion = std::is_same<
863b618880eSAlexander Belyaev TensorReshapeOp, ExpandShapeOp>::value>
verifyTensorReshapeOp(TensorReshapeOp op,RankedTensorType expandedType,RankedTensorType collapsedType)864b618880eSAlexander Belyaev static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
865b618880eSAlexander Belyaev RankedTensorType expandedType,
866b618880eSAlexander Belyaev RankedTensorType collapsedType) {
867b618880eSAlexander Belyaev if (failed(
868b618880eSAlexander Belyaev verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
869b618880eSAlexander Belyaev return failure();
870b618880eSAlexander Belyaev
871b618880eSAlexander Belyaev auto maps = op.getReassociationMaps();
872b618880eSAlexander Belyaev RankedTensorType expectedType =
873b618880eSAlexander Belyaev computeTensorReshapeCollapsedType(expandedType, maps);
874e7d3ba10SAart Bik if (!isSameTypesWithoutEncoding(collapsedType, expectedType))
875b618880eSAlexander Belyaev return op.emitOpError("expected collapsed type to be ")
876b618880eSAlexander Belyaev << expectedType << ", but got " << collapsedType;
877b618880eSAlexander Belyaev return success();
878b618880eSAlexander Belyaev }
879b618880eSAlexander Belyaev
verify()880b98dc035SRiver Riddle LogicalResult ExpandShapeOp::verify() {
881b98dc035SRiver Riddle return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
882b618880eSAlexander Belyaev }
883b618880eSAlexander Belyaev
verify()884b98dc035SRiver Riddle LogicalResult CollapseShapeOp::verify() {
885b98dc035SRiver Riddle return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
886b618880eSAlexander Belyaev }
887b618880eSAlexander Belyaev
888b618880eSAlexander Belyaev namespace {
889b618880eSAlexander Belyaev /// Reshape of a splat constant can be replaced with a constant of the result
890b618880eSAlexander Belyaev /// type.
891b618880eSAlexander Belyaev template <typename TensorReshapeOp>
892b618880eSAlexander Belyaev struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
893b618880eSAlexander Belyaev using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
matchAndRewrite__anon3fb9f79f0611::FoldReshapeWithConstant894b618880eSAlexander Belyaev LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
895b618880eSAlexander Belyaev PatternRewriter &rewriter) const override {
896b618880eSAlexander Belyaev DenseElementsAttr attr;
8972d70eff8SJacques Pienaar if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
898b618880eSAlexander Belyaev return failure();
899b618880eSAlexander Belyaev if (!attr || !attr.isSplat())
900b618880eSAlexander Belyaev return failure();
901b618880eSAlexander Belyaev DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
902f21896f2SChris Lattner reshapeOp.getResultType(), attr.getRawData());
903b618880eSAlexander Belyaev rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
904b618880eSAlexander Belyaev return success();
905b618880eSAlexander Belyaev }
906b618880eSAlexander Belyaev };
907b618880eSAlexander Belyaev
908d81a3c51SRob Suderman /// Reshape of a FromElements can be replaced with a FromElements of the result
909d81a3c51SRob Suderman /// type
910d81a3c51SRob Suderman template <typename TensorReshapeOp>
911d81a3c51SRob Suderman struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
912d81a3c51SRob Suderman using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
matchAndRewrite__anon3fb9f79f0611::FoldReshapeWithFromElements913d81a3c51SRob Suderman LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
914d81a3c51SRob Suderman PatternRewriter &rewriter) const override {
915d81a3c51SRob Suderman auto fromElements =
9162d70eff8SJacques Pienaar reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
917d81a3c51SRob Suderman if (!fromElements)
918d81a3c51SRob Suderman return failure();
919d81a3c51SRob Suderman
920d81a3c51SRob Suderman auto shapedTy = reshapeOp.getType().template cast<ShapedType>();
921d81a3c51SRob Suderman
922d81a3c51SRob Suderman if (!shapedTy.hasStaticShape())
923d81a3c51SRob Suderman return failure();
924d81a3c51SRob Suderman
925d81a3c51SRob Suderman rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
9262d70eff8SJacques Pienaar fromElements.getElements());
927d81a3c51SRob Suderman return success();
928d81a3c51SRob Suderman }
929d81a3c51SRob Suderman };
930d81a3c51SRob Suderman
931b618880eSAlexander Belyaev } // namespace
932b618880eSAlexander Belyaev
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)933b618880eSAlexander Belyaev void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
934b618880eSAlexander Belyaev MLIRContext *context) {
935747b10beSAlexander Belyaev results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
936747b10beSAlexander Belyaev ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
937d81a3c51SRob Suderman FoldReshapeWithConstant<ExpandShapeOp>,
938d81a3c51SRob Suderman FoldReshapeWithFromElements<ExpandShapeOp>>(context);
939b618880eSAlexander Belyaev }
940b618880eSAlexander Belyaev
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)941b618880eSAlexander Belyaev void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
942b618880eSAlexander Belyaev MLIRContext *context) {
943747b10beSAlexander Belyaev results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
944747b10beSAlexander Belyaev ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
945d81a3c51SRob Suderman FoldReshapeWithConstant<CollapseShapeOp>,
946d81a3c51SRob Suderman FoldReshapeWithFromElements<CollapseShapeOp>>(context);
947b618880eSAlexander Belyaev }
948b618880eSAlexander Belyaev
fold(ArrayRef<Attribute> operands)949b618880eSAlexander Belyaev OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
950b618880eSAlexander Belyaev return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
951b618880eSAlexander Belyaev }
fold(ArrayRef<Attribute> operands)952b618880eSAlexander Belyaev OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
953b618880eSAlexander Belyaev return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
954b618880eSAlexander Belyaev }
955b618880eSAlexander Belyaev
956b618880eSAlexander Belyaev //===----------------------------------------------------------------------===//
957060208b4SMatthias Springer // ExtractSliceOp
958060208b4SMatthias Springer //===----------------------------------------------------------------------===//
959060208b4SMatthias Springer
960741f8f2bSNicolas Vasilache /// An extract_slice result type can be inferred, when it is not
961741f8f2bSNicolas Vasilache /// rank-reduced, from the source type and the static representation of
962741f8f2bSNicolas Vasilache /// offsets, sizes and strides. Special sentinels encode the dynamic case.
inferResultType(ShapedType sourceShapedTensorType,ArrayRef<int64_t> staticOffsets,ArrayRef<int64_t> staticSizes,ArrayRef<int64_t> staticStrides)9637df7586aSMaheshRavishankar RankedTensorType ExtractSliceOp::inferResultType(
964741f8f2bSNicolas Vasilache ShapedType sourceShapedTensorType, ArrayRef<int64_t> staticOffsets,
9657df7586aSMaheshRavishankar ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
966060208b4SMatthias Springer // An extract_slice op may specify only a leading subset of offset/sizes/
967060208b4SMatthias Springer // strides in which case we complete with offset=0, sizes from memref type and
968060208b4SMatthias Springer // strides=1.
969741f8f2bSNicolas Vasilache assert(static_cast<int64_t>(staticSizes.size()) ==
970741f8f2bSNicolas Vasilache sourceShapedTensorType.getRank() &&
9717df7586aSMaheshRavishankar "unexpected staticSizes not equal to rank of source");
972060208b4SMatthias Springer return RankedTensorType::get(staticSizes,
973741f8f2bSNicolas Vasilache sourceShapedTensorType.getElementType());
974060208b4SMatthias Springer }
975060208b4SMatthias Springer
inferResultType(ShapedType sourceShapedTensorType,ArrayRef<OpFoldResult> offsets,ArrayRef<OpFoldResult> sizes,ArrayRef<OpFoldResult> strides)9767df7586aSMaheshRavishankar RankedTensorType ExtractSliceOp::inferResultType(
977741f8f2bSNicolas Vasilache ShapedType sourceShapedTensorType, ArrayRef<OpFoldResult> offsets,
9787df7586aSMaheshRavishankar ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
979060208b4SMatthias Springer SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
980060208b4SMatthias Springer SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
9817df7586aSMaheshRavishankar dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
9827df7586aSMaheshRavishankar ShapedType::kDynamicStrideOrOffset);
9837df7586aSMaheshRavishankar dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
984060208b4SMatthias Springer ShapedType::kDynamicSize);
9857df7586aSMaheshRavishankar dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
9867df7586aSMaheshRavishankar ShapedType::kDynamicStrideOrOffset);
987741f8f2bSNicolas Vasilache return ExtractSliceOp::inferResultType(sourceShapedTensorType, staticOffsets,
988060208b4SMatthias Springer staticSizes, staticStrides);
989060208b4SMatthias Springer }
990060208b4SMatthias Springer
991741f8f2bSNicolas Vasilache /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
992741f8f2bSNicolas Vasilache /// number of sizes), drop as many size 1 as needed to produce an inferred type
993741f8f2bSNicolas Vasilache /// with the desired rank.
994741f8f2bSNicolas Vasilache ///
995741f8f2bSNicolas Vasilache /// Note that there may be multiple ways to compute this rank-reduced type:
996741f8f2bSNicolas Vasilache /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
997741f8f2bSNicolas Vasilache ///
998741f8f2bSNicolas Vasilache /// To disambiguate, this function always drops the first 1 sizes occurrences.
inferCanonicalRankReducedResultType(unsigned desiredResultRank,RankedTensorType sourceRankedTensorType,ArrayRef<int64_t> offsets,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides)999741f8f2bSNicolas Vasilache RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
1000741f8f2bSNicolas Vasilache unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
10017df7586aSMaheshRavishankar ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
10027df7586aSMaheshRavishankar ArrayRef<int64_t> strides) {
1003741f8f2bSNicolas Vasilache // Type inferred in the absence of rank-reducing behavior.
1004060208b4SMatthias Springer auto inferredType =
10057df7586aSMaheshRavishankar inferResultType(sourceRankedTensorType, offsets, sizes, strides)
1006060208b4SMatthias Springer .cast<RankedTensorType>();
1007741f8f2bSNicolas Vasilache int rankDiff = inferredType.getRank() - desiredResultRank;
1008060208b4SMatthias Springer if (rankDiff > 0) {
1009060208b4SMatthias Springer auto shape = inferredType.getShape();
10106635c12aSBenjamin Kramer llvm::SmallBitVector dimsToProject =
10116635c12aSBenjamin Kramer getPositionsOfShapeOne(rankDiff, shape);
1012060208b4SMatthias Springer SmallVector<int64_t> projectedShape;
1013741f8f2bSNicolas Vasilache // Best effort rank-reducing: drop 1s in order.
1014060208b4SMatthias Springer for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
10156635c12aSBenjamin Kramer if (!dimsToProject.test(pos))
1016060208b4SMatthias Springer projectedShape.push_back(shape[pos]);
1017060208b4SMatthias Springer inferredType =
1018060208b4SMatthias Springer RankedTensorType::get(projectedShape, inferredType.getElementType());
1019060208b4SMatthias Springer }
1020060208b4SMatthias Springer return inferredType;
1021060208b4SMatthias Springer }
1022060208b4SMatthias Springer
inferCanonicalRankReducedResultType(unsigned desiredResultRank,RankedTensorType sourceRankedTensorType,ArrayRef<OpFoldResult> offsets,ArrayRef<OpFoldResult> sizes,ArrayRef<OpFoldResult> strides)1023741f8f2bSNicolas Vasilache RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
1024741f8f2bSNicolas Vasilache unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
10257df7586aSMaheshRavishankar ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
10267df7586aSMaheshRavishankar ArrayRef<OpFoldResult> strides) {
1027060208b4SMatthias Springer SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1028060208b4SMatthias Springer SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
10297df7586aSMaheshRavishankar dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
10307df7586aSMaheshRavishankar ShapedType::kDynamicStrideOrOffset);
10317df7586aSMaheshRavishankar dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1032060208b4SMatthias Springer ShapedType::kDynamicSize);
10337df7586aSMaheshRavishankar dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
10347df7586aSMaheshRavishankar ShapedType::kDynamicStrideOrOffset);
1035741f8f2bSNicolas Vasilache return ExtractSliceOp::inferCanonicalRankReducedResultType(
1036741f8f2bSNicolas Vasilache desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1037060208b4SMatthias Springer staticStrides);
1038060208b4SMatthias Springer }
1039060208b4SMatthias Springer
1040060208b4SMatthias Springer /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
1041060208b4SMatthias Springer /// result type. If the type passed is nullptr, it is inferred.
build(OpBuilder & b,OperationState & result,RankedTensorType resultType,Value source,ArrayRef<OpFoldResult> offsets,ArrayRef<OpFoldResult> sizes,ArrayRef<OpFoldResult> strides,ArrayRef<NamedAttribute> attrs)1042060208b4SMatthias Springer void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
1043060208b4SMatthias Springer RankedTensorType resultType, Value source,
1044060208b4SMatthias Springer ArrayRef<OpFoldResult> offsets,
1045060208b4SMatthias Springer ArrayRef<OpFoldResult> sizes,
1046060208b4SMatthias Springer ArrayRef<OpFoldResult> strides,
1047060208b4SMatthias Springer ArrayRef<NamedAttribute> attrs) {
1048060208b4SMatthias Springer SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1049060208b4SMatthias Springer SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1050060208b4SMatthias Springer dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1051060208b4SMatthias Springer ShapedType::kDynamicStrideOrOffset);
1052060208b4SMatthias Springer dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1053060208b4SMatthias Springer ShapedType::kDynamicSize);
1054060208b4SMatthias Springer dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1055060208b4SMatthias Springer ShapedType::kDynamicStrideOrOffset);
1056060208b4SMatthias Springer auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
1057060208b4SMatthias Springer // Structuring implementation this way avoids duplication between builders.
1058060208b4SMatthias Springer if (!resultType) {
1059060208b4SMatthias Springer resultType =
1060060208b4SMatthias Springer ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
1061060208b4SMatthias Springer staticSizes, staticStrides)
1062060208b4SMatthias Springer .cast<RankedTensorType>();
1063060208b4SMatthias Springer }
1064060208b4SMatthias Springer build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1065060208b4SMatthias Springer dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1066060208b4SMatthias Springer b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1067060208b4SMatthias Springer result.addAttributes(attrs);
1068060208b4SMatthias Springer }
1069060208b4SMatthias Springer
1070060208b4SMatthias Springer /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
1071060208b4SMatthias Springer /// result type.
build(OpBuilder & b,OperationState & result,Value source,ArrayRef<OpFoldResult> offsets,ArrayRef<OpFoldResult> sizes,ArrayRef<OpFoldResult> strides,ArrayRef<NamedAttribute> attrs)1072060208b4SMatthias Springer void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1073060208b4SMatthias Springer ArrayRef<OpFoldResult> offsets,
1074060208b4SMatthias Springer ArrayRef<OpFoldResult> sizes,
1075060208b4SMatthias Springer ArrayRef<OpFoldResult> strides,
1076060208b4SMatthias Springer ArrayRef<NamedAttribute> attrs) {
1077060208b4SMatthias Springer build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
1078060208b4SMatthias Springer }
1079060208b4SMatthias Springer
1080060208b4SMatthias Springer /// Build an ExtractSliceOp with dynamic entries and custom result type. If the
1081060208b4SMatthias Springer /// type passed is nullptr, it is inferred.
build(OpBuilder & b,OperationState & result,RankedTensorType resultType,Value source,ValueRange offsets,ValueRange sizes,ValueRange strides,ArrayRef<NamedAttribute> attrs)1082060208b4SMatthias Springer void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
1083060208b4SMatthias Springer RankedTensorType resultType, Value source,
1084060208b4SMatthias Springer ValueRange offsets, ValueRange sizes,
1085060208b4SMatthias Springer ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1086060208b4SMatthias Springer SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1087060208b4SMatthias Springer llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1088060208b4SMatthias Springer SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1089060208b4SMatthias Springer llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1090060208b4SMatthias Springer SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1091060208b4SMatthias Springer llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1092060208b4SMatthias Springer build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
1093060208b4SMatthias Springer }
1094060208b4SMatthias Springer
1095060208b4SMatthias Springer /// Build an ExtractSliceOp with dynamic entries and inferred result type.
build(OpBuilder & b,OperationState & result,Value source,ValueRange offsets,ValueRange sizes,ValueRange strides,ArrayRef<NamedAttribute> attrs)1096060208b4SMatthias Springer void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1097060208b4SMatthias Springer ValueRange offsets, ValueRange sizes,
1098060208b4SMatthias Springer ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1099060208b4SMatthias Springer build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
1100060208b4SMatthias Springer }
1101060208b4SMatthias Springer
1102060208b4SMatthias Springer template <typename OpTy>
produceSliceErrorMsg(SliceVerificationResult result,OpTy op,Type expectedType)1103060208b4SMatthias Springer static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
1104a08b750cSNicolas Vasilache OpTy op, Type expectedType) {
1105060208b4SMatthias Springer auto memrefType = expectedType.cast<ShapedType>();
1106060208b4SMatthias Springer switch (result) {
1107060208b4SMatthias Springer case SliceVerificationResult::Success:
1108060208b4SMatthias Springer return success();
1109060208b4SMatthias Springer case SliceVerificationResult::RankTooLarge:
1110a08b750cSNicolas Vasilache return op.emitError("expected rank to be smaller or equal to ")
1111a08b750cSNicolas Vasilache << "the other rank. ";
1112060208b4SMatthias Springer case SliceVerificationResult::SizeMismatch:
1113a08b750cSNicolas Vasilache return op.emitError("expected type to be ")
1114a08b750cSNicolas Vasilache << expectedType << " or a rank-reduced version. (size mismatch) ";
1115060208b4SMatthias Springer case SliceVerificationResult::ElemTypeMismatch:
1116a08b750cSNicolas Vasilache return op.emitError("expected element type to be ")
1117a08b750cSNicolas Vasilache << memrefType.getElementType();
1118a08b750cSNicolas Vasilache default:
1119060208b4SMatthias Springer llvm_unreachable("unexpected extract_slice op verification result");
1120060208b4SMatthias Springer }
1121a08b750cSNicolas Vasilache }
1122060208b4SMatthias Springer
1123060208b4SMatthias Springer /// Verifier for ExtractSliceOp.
verify()1124b98dc035SRiver Riddle LogicalResult ExtractSliceOp::verify() {
1125060208b4SMatthias Springer // Verify result type against inferred type.
1126*c9fb3c6eSNicolas Vasilache RankedTensorType expectedType = ExtractSliceOp::inferResultType(
1127b98dc035SRiver Riddle getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
1128*c9fb3c6eSNicolas Vasilache SliceVerificationResult result = isRankReducedType(expectedType, getType());
1129b98dc035SRiver Riddle return produceSliceErrorMsg(result, *this, expectedType);
1130060208b4SMatthias Springer }
1131060208b4SMatthias Springer
getDroppedDims()11326635c12aSBenjamin Kramer llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
113342819463SMaheshRavishankar ArrayRef<int64_t> resultShape = getType().getShape();
113442819463SMaheshRavishankar SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
11356635c12aSBenjamin Kramer llvm::SmallBitVector droppedDims(mixedSizes.size());
113642819463SMaheshRavishankar unsigned shapePos = 0;
1137e4853be2SMehdi Amini for (const auto &size : enumerate(mixedSizes)) {
113842819463SMaheshRavishankar Optional<int64_t> sizeVal = getConstantIntValue(size.value());
113942819463SMaheshRavishankar // If the size is not 1, or if the current matched dimension of the result
114042819463SMaheshRavishankar // is the same static shape as the size value (which is 1), then the
114142819463SMaheshRavishankar // dimension is preserved.
11426d5fc1e3SKazu Hirata if (!sizeVal || *sizeVal != 1 ||
114342819463SMaheshRavishankar (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
114442819463SMaheshRavishankar shapePos++;
114542819463SMaheshRavishankar continue;
114642819463SMaheshRavishankar }
11476635c12aSBenjamin Kramer droppedDims.set(size.index());
114842819463SMaheshRavishankar }
114942819463SMaheshRavishankar return droppedDims;
115042819463SMaheshRavishankar }
115142819463SMaheshRavishankar
reifyResultShapes(OpBuilder & builder,ReifiedRankedShapedTypeDims & reifiedReturnShapes)115242819463SMaheshRavishankar LogicalResult ExtractSliceOp::reifyResultShapes(
115342819463SMaheshRavishankar OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
115442819463SMaheshRavishankar reifiedReturnShapes.resize(1);
115542819463SMaheshRavishankar reifiedReturnShapes[0].reserve(getType().getRank());
115642819463SMaheshRavishankar SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
11576635c12aSBenjamin Kramer llvm::SmallBitVector droppedDims = getDroppedDims();
115842819463SMaheshRavishankar Location loc = getLoc();
1159e4853be2SMehdi Amini for (const auto &size : enumerate(mixedSizes)) {
11606635c12aSBenjamin Kramer if (droppedDims.test(size.index()))
116142819463SMaheshRavishankar continue;
116242819463SMaheshRavishankar if (auto attr = size.value().dyn_cast<Attribute>()) {
1163a54f4eaeSMogball reifiedReturnShapes[0].push_back(builder.create<arith::ConstantIndexOp>(
116442819463SMaheshRavishankar loc, attr.cast<IntegerAttr>().getInt()));
116542819463SMaheshRavishankar continue;
116642819463SMaheshRavishankar }
116742819463SMaheshRavishankar reifiedReturnShapes[0].push_back(size.value().get<Value>());
116842819463SMaheshRavishankar }
116942819463SMaheshRavishankar return success();
117042819463SMaheshRavishankar }
117142819463SMaheshRavishankar
1172060208b4SMatthias Springer namespace {
1173060208b4SMatthias Springer /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
1174060208b4SMatthias Springer /// This essentially pushes memref_cast past its consuming slice when
1175060208b4SMatthias Springer /// `canFoldIntoConsumerOp` is true.
1176060208b4SMatthias Springer ///
1177060208b4SMatthias Springer /// Example:
1178060208b4SMatthias Springer /// ```
1179060208b4SMatthias Springer /// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
1180060208b4SMatthias Springer /// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
1181060208b4SMatthias Springer /// tensor<3x4xf32>
1182060208b4SMatthias Springer /// ```
1183060208b4SMatthias Springer /// is rewritten into:
1184060208b4SMatthias Springer /// ```
1185060208b4SMatthias Springer /// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
1186060208b4SMatthias Springer /// tensor<3x4xf32> %1 = tensor.cast %0: tensor<3x4xf32> to tensor<3x4xf32>
1187060208b4SMatthias Springer /// ```
1188060208b4SMatthias Springer class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
1189060208b4SMatthias Springer public:
1190060208b4SMatthias Springer using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
1191060208b4SMatthias Springer
matchAndRewrite(ExtractSliceOp sliceOp,PatternRewriter & rewriter) const1192060208b4SMatthias Springer LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
1193060208b4SMatthias Springer PatternRewriter &rewriter) const override {
1194741f8f2bSNicolas Vasilache // Any constant operand, just return to let the constant folder kick in.
1195060208b4SMatthias Springer if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
1196060208b4SMatthias Springer return matchPattern(operand, matchConstantIndex());
1197060208b4SMatthias Springer }))
1198060208b4SMatthias Springer return failure();
1199060208b4SMatthias Springer
12002d70eff8SJacques Pienaar auto castOp = sliceOp.getSource().getDefiningOp<tensor::CastOp>();
1201060208b4SMatthias Springer if (!castOp)
1202060208b4SMatthias Springer return failure();
1203060208b4SMatthias Springer
1204060208b4SMatthias Springer if (!canFoldIntoConsumerOp(castOp))
1205060208b4SMatthias Springer return failure();
1206060208b4SMatthias Springer
1207060208b4SMatthias Springer /// Deduce the type of the result to use for the canonicalized operation.
1208741f8f2bSNicolas Vasilache RankedTensorType resultType =
1209741f8f2bSNicolas Vasilache ExtractSliceOp::inferCanonicalRankReducedResultType(
1210060208b4SMatthias Springer sliceOp.getType().getRank(), sliceOp.getSourceType(),
1211060208b4SMatthias Springer sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
1212060208b4SMatthias Springer sliceOp.getMixedStrides());
1213060208b4SMatthias Springer Value newSlice = rewriter.create<ExtractSliceOp>(
12142d70eff8SJacques Pienaar sliceOp.getLoc(), resultType, castOp.getSource(), sliceOp.getOffsets(),
12152d70eff8SJacques Pienaar sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
12162d70eff8SJacques Pienaar sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
1217060208b4SMatthias Springer rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
1218060208b4SMatthias Springer newSlice);
1219060208b4SMatthias Springer return success();
1220060208b4SMatthias Springer }
1221060208b4SMatthias Springer };
12224c901bf4SOkwan Kwon
12234c901bf4SOkwan Kwon /// Slice elements from `values` into `outValues`. `counts` represents the
12244c901bf4SOkwan Kwon /// numbers of elements to stride in the original values for each dimension.
12254c901bf4SOkwan Kwon /// The output values can be used to construct a DenseElementsAttr.
12264c901bf4SOkwan Kwon template <typename IterTy, typename ElemTy>
sliceElements(IterTy values,ArrayRef<int64_t> counts,ArrayRef<int64_t> offsets,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides,llvm::SmallVectorImpl<ElemTy> * outValues)12274c901bf4SOkwan Kwon static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
12284c901bf4SOkwan Kwon ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
12294c901bf4SOkwan Kwon ArrayRef<int64_t> strides,
12304c901bf4SOkwan Kwon llvm::SmallVectorImpl<ElemTy> *outValues) {
12314c901bf4SOkwan Kwon assert(offsets.size() == sizes.size());
12324c901bf4SOkwan Kwon assert(offsets.size() == strides.size());
12334c901bf4SOkwan Kwon if (offsets.empty())
12344c901bf4SOkwan Kwon return;
12354c901bf4SOkwan Kwon
12364c901bf4SOkwan Kwon int64_t offset = offsets.front();
12374c901bf4SOkwan Kwon int64_t size = sizes.front();
12384c901bf4SOkwan Kwon int64_t stride = strides.front();
12394c901bf4SOkwan Kwon if (offsets.size() == 1) {
12404c901bf4SOkwan Kwon for (int64_t i = 0; i < size; ++i, offset += stride)
12414c901bf4SOkwan Kwon outValues->push_back(*(values + offset));
12424c901bf4SOkwan Kwon
12434c901bf4SOkwan Kwon return;
12444c901bf4SOkwan Kwon }
12454c901bf4SOkwan Kwon
12464c901bf4SOkwan Kwon for (int64_t i = 0; i < size; ++i, offset += stride) {
12474c901bf4SOkwan Kwon auto begin = values + offset * counts.front();
12484c901bf4SOkwan Kwon sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
12494c901bf4SOkwan Kwon offsets.drop_front(), sizes.drop_front(),
12504c901bf4SOkwan Kwon strides.drop_front(), outValues);
12514c901bf4SOkwan Kwon }
12524c901bf4SOkwan Kwon }
12534c901bf4SOkwan Kwon
/// Fold arith.constant and tensor.extract_slice into arith.constant. The folded
/// operation might introduce more constant data; Users can control their
/// heuristics by the control function.
class ConstantOpExtractSliceFolder final
    : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  /// `controlFn` decides, per op, whether the fold is worth the extra
  /// constant data it may create.
  ConstantOpExtractSliceFolder(MLIRContext *context,
                               ControlConstantExtractSliceFusionFn controlFn)
      : OpRewritePattern<ExtractSliceOp>(context),
        controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(ExtractSliceOp op,
                                PatternRewriter &rewriter) const override {
    // The slice source must be a dense constant.
    DenseElementsAttr attr;
    if (!matchPattern(op.getSource(), m_Constant(&attr)))
      return failure();

    // A constant splat is handled by fold().
    if (attr.isSplat())
      return failure();

    // Dynamic result shape is not supported.
    auto sourceType = op.getSource().getType().cast<ShapedType>();
    auto resultType = op.getResult().getType().cast<ShapedType>();
    if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
      return failure();

    // Customized control over the folding.
    if (!controlFn(op))
      return failure();

    int64_t count = sourceType.getNumElements();
    if (count == 0)
      return failure();

    // Check if there are any dynamic parts, which are not supported.
    auto offsets = extractFromI64ArrayAttr(op.getStaticOffsets());
    if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
      return failure();
    auto sizes = extractFromI64ArrayAttr(op.getStaticSizes());
    if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
      return failure();
    auto strides = extractFromI64ArrayAttr(op.getStaticStrides());
    if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
      return failure();

    // Compute the stride for each dimension: counts[d] is the number of
    // elements spanned by one step along dimension d (row-major), obtained
    // by repeatedly dividing the running element count by each dim size.
    SmallVector<int64_t> counts;
    ArrayRef<int64_t> shape = sourceType.getShape();
    counts.reserve(shape.size());
    for (int64_t v : shape) {
      count = count / v;
      counts.push_back(count);
    }

    // New attribute constructed by the sliced values.
    DenseElementsAttr newAttr;

    // Only dense integer and float element attributes are handled; any other
    // element kind falls through and the pattern fails below.
    if (auto elems = attr.dyn_cast<DenseIntElementsAttr>()) {
      SmallVector<APInt> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
      newAttr = DenseElementsAttr::get(resultType, outValues);
    } else if (auto elems = attr.dyn_cast<DenseFPElementsAttr>()) {
      SmallVector<APFloat> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
      newAttr = DenseElementsAttr::get(resultType, outValues);
    }

    if (newAttr) {
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
      return success();
    }

    return failure();
  }

private:
  /// This additionally controls whether the fold happens or not. Users can
  /// impose their heuristics in the function.
  ControlConstantExtractSliceFusionFn controlFn;
};
13414c901bf4SOkwan Kwon
1342060208b4SMatthias Springer } // namespace
1343060208b4SMatthias Springer
populateFoldConstantExtractSlicePatterns(RewritePatternSet & patterns,const ControlConstantExtractSliceFusionFn & controlFn)13444c901bf4SOkwan Kwon void mlir::tensor::populateFoldConstantExtractSlicePatterns(
13454c901bf4SOkwan Kwon RewritePatternSet &patterns,
13464c901bf4SOkwan Kwon const ControlConstantExtractSliceFusionFn &controlFn) {
13474c901bf4SOkwan Kwon patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
13484c901bf4SOkwan Kwon }
13494c901bf4SOkwan Kwon
1350060208b4SMatthias Springer /// Return the canonical type of the result of an extract_slice op.
1351060208b4SMatthias Springer struct SliceReturnTypeCanonicalizer {
operator ()SliceReturnTypeCanonicalizer1352060208b4SMatthias Springer RankedTensorType operator()(ExtractSliceOp op,
1353060208b4SMatthias Springer ArrayRef<OpFoldResult> mixedOffsets,
1354060208b4SMatthias Springer ArrayRef<OpFoldResult> mixedSizes,
1355060208b4SMatthias Springer ArrayRef<OpFoldResult> mixedStrides) {
1356741f8f2bSNicolas Vasilache return ExtractSliceOp::inferCanonicalRankReducedResultType(
1357741f8f2bSNicolas Vasilache op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
1358741f8f2bSNicolas Vasilache mixedStrides);
1359060208b4SMatthias Springer }
1360060208b4SMatthias Springer };
1361060208b4SMatthias Springer
1362060208b4SMatthias Springer /// A canonicalizer wrapper to replace ExtractSliceOps.
1363060208b4SMatthias Springer struct SliceCanonicalizer {
operator ()SliceCanonicalizer1364060208b4SMatthias Springer void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
1365060208b4SMatthias Springer ExtractSliceOp newOp) {
1366060208b4SMatthias Springer Value replacement = newOp.getResult();
1367060208b4SMatthias Springer if (replacement.getType() != op.getType())
1368060208b4SMatthias Springer replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
1369060208b4SMatthias Springer replacement);
1370060208b4SMatthias Springer rewriter.replaceOp(op, replacement);
1371060208b4SMatthias Springer }
1372060208b4SMatthias Springer };
1373060208b4SMatthias Springer
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1374060208b4SMatthias Springer void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1375060208b4SMatthias Springer MLIRContext *context) {
1376060208b4SMatthias Springer results.add<
1377060208b4SMatthias Springer OpWithOffsetSizesAndStridesConstantArgumentFolder<
1378060208b4SMatthias Springer ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
1379060208b4SMatthias Springer ExtractSliceOpCastFolder>(context);
1380060208b4SMatthias Springer }
1381060208b4SMatthias Springer
1382060208b4SMatthias Springer //
1383060208b4SMatthias Springer static LogicalResult
foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,ShapedType shapedType)1384060208b4SMatthias Springer foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
1385060208b4SMatthias Springer ShapedType shapedType) {
1386060208b4SMatthias Springer OpBuilder b(op.getContext());
1387060208b4SMatthias Springer for (OpFoldResult ofr : op.getMixedOffsets())
13880813700dSMatthias Springer if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
1389060208b4SMatthias Springer return failure();
1390060208b4SMatthias Springer // Rank-reducing noops only need to inspect the leading dimensions: llvm::zip
1391060208b4SMatthias Springer // is appropriate.
1392060208b4SMatthias Springer auto shape = shapedType.getShape();
1393060208b4SMatthias Springer for (auto it : llvm::zip(op.getMixedSizes(), shape))
13940813700dSMatthias Springer if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
1395060208b4SMatthias Springer return failure();
1396060208b4SMatthias Springer for (OpFoldResult ofr : op.getMixedStrides())
13970813700dSMatthias Springer if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
1398060208b4SMatthias Springer return failure();
1399060208b4SMatthias Springer return success();
1400060208b4SMatthias Springer }
1401060208b4SMatthias Springer
1402bdd37c9fSLei Zhang /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice,
1403bdd37c9fSLei Zhang /// we can return the InsertSliceOp's source directly.
1404bdd37c9fSLei Zhang // TODO: This only checks the immediate producer; extend to go up the
1405bdd37c9fSLei Zhang // insert/extract chain if the slices are disjoint.
foldExtractAfterInsertSlice(ExtractSliceOp extractOp)1406bdd37c9fSLei Zhang static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
14072d70eff8SJacques Pienaar auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
1408bdd37c9fSLei Zhang
1409bdd37c9fSLei Zhang auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
14102d70eff8SJacques Pienaar if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
1411bdd37c9fSLei Zhang insertOp.isSameAs(extractOp, isSame))
14122d70eff8SJacques Pienaar return insertOp.getSource();
1413bdd37c9fSLei Zhang
1414bdd37c9fSLei Zhang return {};
1415bdd37c9fSLei Zhang }
1416bdd37c9fSLei Zhang
fold(ArrayRef<Attribute> operands)1417f79f430dSOkwan Kwon OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
1418f79f430dSOkwan Kwon if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
14192d70eff8SJacques Pienaar auto resultType = getResult().getType().cast<ShapedType>();
1420f79f430dSOkwan Kwon if (resultType.hasStaticShape())
1421f79f430dSOkwan Kwon return splat.resizeSplat(resultType);
1422f79f430dSOkwan Kwon }
1423060208b4SMatthias Springer if (getSourceType() == getType() &&
1424060208b4SMatthias Springer succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
14252d70eff8SJacques Pienaar return this->getSource();
1426bdd37c9fSLei Zhang if (Value slice = foldExtractAfterInsertSlice(*this))
1427bdd37c9fSLei Zhang return slice;
14284c901bf4SOkwan Kwon
1429060208b4SMatthias Springer return OpFoldResult();
1430060208b4SMatthias Springer }
1431060208b4SMatthias Springer
createCanonicalRankReducingExtractSliceOp(OpBuilder & b,Location loc,Value tensor,RankedTensorType targetType)1432aa373180SNicolas Vasilache Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
1433aa373180SNicolas Vasilache OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
1434aa373180SNicolas Vasilache auto rankedTensorType = tensor.getType().cast<RankedTensorType>();
1435aa373180SNicolas Vasilache unsigned rank = rankedTensorType.getRank();
1436aa373180SNicolas Vasilache auto shape = rankedTensorType.getShape();
1437aa373180SNicolas Vasilache SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
1438aa373180SNicolas Vasilache SmallVector<OpFoldResult> sizes;
1439aa373180SNicolas Vasilache for (unsigned i = 0, e = rank; i < e; ++i) {
1440aa373180SNicolas Vasilache OpFoldResult dim;
1441aa373180SNicolas Vasilache if (rankedTensorType.isDynamicDim(i))
1442aa373180SNicolas Vasilache dim = b.createOrFold<tensor::DimOp>(
1443aa373180SNicolas Vasilache loc, tensor, b.create<arith::ConstantIndexOp>(loc, i));
1444aa373180SNicolas Vasilache else
1445aa373180SNicolas Vasilache dim = b.getIndexAttr(shape[i]);
1446aa373180SNicolas Vasilache sizes.push_back(dim);
1447aa373180SNicolas Vasilache }
1448aa373180SNicolas Vasilache SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
1449aa373180SNicolas Vasilache return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
1450aa373180SNicolas Vasilache offsets, sizes, strides);
1451aa373180SNicolas Vasilache }
1452aa373180SNicolas Vasilache
1453060208b4SMatthias Springer //===----------------------------------------------------------------------===//
1454060208b4SMatthias Springer // InsertSliceOp
1455060208b4SMatthias Springer //===----------------------------------------------------------------------===//
1456060208b4SMatthias Springer
1457060208b4SMatthias Springer // Build a InsertSliceOp with mixed static and dynamic entries.
build(OpBuilder & b,OperationState & result,Value source,Value dest,ArrayRef<OpFoldResult> offsets,ArrayRef<OpFoldResult> sizes,ArrayRef<OpFoldResult> strides,ArrayRef<NamedAttribute> attrs)1458060208b4SMatthias Springer void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1459060208b4SMatthias Springer Value dest, ArrayRef<OpFoldResult> offsets,
1460060208b4SMatthias Springer ArrayRef<OpFoldResult> sizes,
1461060208b4SMatthias Springer ArrayRef<OpFoldResult> strides,
1462060208b4SMatthias Springer ArrayRef<NamedAttribute> attrs) {
1463060208b4SMatthias Springer SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1464060208b4SMatthias Springer SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1465060208b4SMatthias Springer dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1466060208b4SMatthias Springer ShapedType::kDynamicStrideOrOffset);
1467060208b4SMatthias Springer dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1468060208b4SMatthias Springer ShapedType::kDynamicSize);
1469060208b4SMatthias Springer dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1470060208b4SMatthias Springer ShapedType::kDynamicStrideOrOffset);
1471060208b4SMatthias Springer build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
1472060208b4SMatthias Springer dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1473060208b4SMatthias Springer b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1474060208b4SMatthias Springer result.addAttributes(attrs);
1475060208b4SMatthias Springer }
1476060208b4SMatthias Springer
1477060208b4SMatthias Springer // Build a InsertSliceOp with dynamic entries.
build(OpBuilder & b,OperationState & result,Value source,Value dest,ValueRange offsets,ValueRange sizes,ValueRange strides,ArrayRef<NamedAttribute> attrs)1478060208b4SMatthias Springer void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1479060208b4SMatthias Springer Value dest, ValueRange offsets, ValueRange sizes,
1480060208b4SMatthias Springer ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1481060208b4SMatthias Springer SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1482060208b4SMatthias Springer llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1483060208b4SMatthias Springer SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1484060208b4SMatthias Springer llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1485060208b4SMatthias Springer SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1486060208b4SMatthias Springer llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1487060208b4SMatthias Springer build(b, result, source, dest, offsetValues, sizeValues, strideValues);
1488060208b4SMatthias Springer }
1489060208b4SMatthias Springer
/// Rank-reducing type verification for both InsertSliceOp and
/// ParallelInsertSliceOp.
///
/// Infers the slice type that an extract_slice with `staticOffsets`,
/// `staticSizes` and `staticStrides` would produce from `dstType`, and checks
/// that `srcType` is a legal (possibly rank-reduced) version of it. If
/// `expectedType` is non-null it is set to the inferred type so callers can
/// use it in diagnostics.
static SliceVerificationResult
verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
                    ArrayAttr staticOffsets, ArrayAttr staticSizes,
                    ArrayAttr staticStrides,
                    ShapedType *expectedType = nullptr) {
  // insert_slice is the inverse of extract_slice, use the same type inference.
  RankedTensorType expected = ExtractSliceOp::inferResultType(
      dstType, extractFromI64ArrayAttr(staticOffsets),
      extractFromI64ArrayAttr(staticSizes),
      extractFromI64ArrayAttr(staticStrides));
  if (expectedType)
    *expectedType = expected;
  return isRankReducedType(expected, srcType);
}
1506cd0d095cSIvan Butygin
/// Verifier for InsertSliceOp.
///
/// Checks that the source type is a valid (possibly rank-reducing) slice of
/// the destination type according to the static offsets/sizes/strides, and
/// emits a diagnostic based on the inferred type otherwise.
LogicalResult InsertSliceOp::verify() {
  ShapedType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
  return produceSliceErrorMsg(result, *this, expectedType);
}
1515a08b750cSNicolas Vasilache
1516bdd37c9fSLei Zhang /// If we have two consecutive InsertSliceOp writing to the same slice, we
1517bdd37c9fSLei Zhang /// can mutate the second InsertSliceOp's destination to the first one's.
1518*c9fb3c6eSNicolas Vasilache /// This works similarly when the second op is a ParallelInsertSliceOp.
1519bdd37c9fSLei Zhang ///
1520bdd37c9fSLei Zhang /// Example:
1521bdd37c9fSLei Zhang ///
1522bdd37c9fSLei Zhang /// ```mlir
1523bdd37c9fSLei Zhang /// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
1524bdd37c9fSLei Zhang /// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
1525bdd37c9fSLei Zhang /// ```
1526bdd37c9fSLei Zhang ///
1527bdd37c9fSLei Zhang /// folds into:
1528bdd37c9fSLei Zhang ///
1529bdd37c9fSLei Zhang /// ```mlir
1530bdd37c9fSLei Zhang /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
1531bdd37c9fSLei Zhang /// ```
1532*c9fb3c6eSNicolas Vasilache ///
1533*c9fb3c6eSNicolas Vasilache /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
1534*c9fb3c6eSNicolas Vasilache template <typename InsertOpTy>
foldInsertAfterInsertSlice(InsertOpTy insertOp)1535*c9fb3c6eSNicolas Vasilache static LogicalResult foldInsertAfterInsertSlice(InsertOpTy insertOp) {
1536*c9fb3c6eSNicolas Vasilache auto prevInsertOp = insertOp.getDest().template getDefiningOp<InsertOpTy>();
1537bdd37c9fSLei Zhang
1538bdd37c9fSLei Zhang auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
1539bdd37c9fSLei Zhang if (!prevInsertOp ||
15402d70eff8SJacques Pienaar prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
1541bdd37c9fSLei Zhang !prevInsertOp.isSameAs(insertOp, isSame))
1542bdd37c9fSLei Zhang return failure();
1543bdd37c9fSLei Zhang
15442d70eff8SJacques Pienaar insertOp.getDestMutable().assign(prevInsertOp.getDest());
1545bdd37c9fSLei Zhang return success();
1546bdd37c9fSLei Zhang }
1547bdd37c9fSLei Zhang
1548*c9fb3c6eSNicolas Vasilache /// Same logic for folding InsertSliceOp and ParallelInsertSliceOp, the return
1549*c9fb3c6eSNicolas Vasilache /// type varies though so we wrap it in a FailureOr.
1550*c9fb3c6eSNicolas Vasilache ///
1551*c9fb3c6eSNicolas Vasilache /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
1552*c9fb3c6eSNicolas Vasilache template <typename InsertOpTy>
foldInsertOp(InsertOpTy insertOp,ArrayRef<Attribute>)1553*c9fb3c6eSNicolas Vasilache FailureOr<OpFoldResult> foldInsertOp(InsertOpTy insertOp, ArrayRef<Attribute>) {
1554*c9fb3c6eSNicolas Vasilache if (insertOp.getSourceType().hasStaticShape() &&
1555*c9fb3c6eSNicolas Vasilache insertOp.getDestType().hasStaticShape() &&
1556*c9fb3c6eSNicolas Vasilache insertOp.getSourceType() == insertOp.getDestType() &&
1557*c9fb3c6eSNicolas Vasilache succeeded(foldIdentityOffsetSizeAndStrideOpInterface(
1558*c9fb3c6eSNicolas Vasilache insertOp, insertOp.getDestType())))
1559*c9fb3c6eSNicolas Vasilache return static_cast<OpFoldResult>(insertOp.getSource());
1560*c9fb3c6eSNicolas Vasilache if (succeeded(foldInsertAfterInsertSlice(insertOp))) {
1561*c9fb3c6eSNicolas Vasilache // InsertSliceOp has 1 result but ParallelInsertSliceOp has none and should
1562*c9fb3c6eSNicolas Vasilache // return OpFoldResult().
1563*c9fb3c6eSNicolas Vasilache if (std::is_same<InsertOpTy, InsertSliceOp>::value)
1564*c9fb3c6eSNicolas Vasilache return static_cast<OpFoldResult>(insertOp->getResult(0));
1565*c9fb3c6eSNicolas Vasilache else
1566060208b4SMatthias Springer return OpFoldResult();
1567060208b4SMatthias Springer }
1568*c9fb3c6eSNicolas Vasilache return failure();
1569*c9fb3c6eSNicolas Vasilache }
1570*c9fb3c6eSNicolas Vasilache
fold(ArrayRef<Attribute> operands)1571*c9fb3c6eSNicolas Vasilache OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute> operands) {
1572*c9fb3c6eSNicolas Vasilache auto maybeOpFoldResult = foldInsertOp(*this, operands);
1573*c9fb3c6eSNicolas Vasilache return failed(maybeOpFoldResult) ? OpFoldResult() : *maybeOpFoldResult;
1574*c9fb3c6eSNicolas Vasilache }
1575060208b4SMatthias Springer
reifyResultShapes(OpBuilder & builder,ReifiedRankedShapedTypeDims & reifiedReturnShapes)15769afc0657SMaheshRavishankar LogicalResult InsertSliceOp::reifyResultShapes(
15779afc0657SMaheshRavishankar OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1578f2b5e438SMaheshRavishankar reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
1579f2b5e438SMaheshRavishankar for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1580f2b5e438SMaheshRavishankar reifiedReturnShapes[0][dim] =
15812d70eff8SJacques Pienaar builder.createOrFold<tensor::DimOp>(getLoc(), getDest(), dim);
1582f2b5e438SMaheshRavishankar }
1583f2b5e438SMaheshRavishankar return success();
1584f2b5e438SMaheshRavishankar }
1585f2b5e438SMaheshRavishankar
1586060208b4SMatthias Springer namespace {
/// Pattern to rewrite a insert_slice op with constant arguments.
///
/// Any offset/size/stride operand that is a constant index is folded into the
/// corresponding static attribute, and the source is cast to the canonical
/// rank-reduced type when that changes.
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
class InsertSliceOpConstantArgumentFolder final
    : public OpRewritePattern<InsertOpTy> {
public:
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    // No constant operand, just return.
    if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // At least one of offsets/sizes/strides is a new constant.
    // Form the new list of operands and constant attributes from the
    // existing.
    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
    canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
    canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
    canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);

    // Create the new op in canonical form.
    auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
        insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
        mixedOffsets, mixedSizes, mixedStrides);
    Value toInsert = insertSliceOp.getSource();
    if (sourceType != insertSliceOp.getSourceType()) {
      OpBuilder::InsertionGuard g(rewriter);
      // The only difference between InsertSliceOp and ParallelInsertSliceOp
      // is that the insertion point is just before the ParallelCombiningOp in
      // the parallel case.
      if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                 sourceType, toInsert);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
        mixedSizes, mixedStrides);
    return success();
  }
};
1635060208b4SMatthias Springer
/// Fold tensor_casts with insert_slice operations. If the source or destination
/// tensor is a tensor_cast that removes static type information, the cast is
/// folded into the insert_slice operation. E.g.:
///
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
/// ```
///
/// Note: When folding a cast on the destination tensor, the result of the
/// insert_slice operation is casted to ensure that the type of the result did
/// not change.
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    // Bail out if any offset/size/stride operand is a constant index.
    // NOTE(review): presumably this lets InsertSliceOpConstantArgumentFolder
    // canonicalize those operands first -- confirm.
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Returns the source of `v`'s defining tensor.cast when that cast can be
    // folded into a consumer; llvm::None otherwise.
    auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
      auto castOp = v.getDefiningOp<tensor::CastOp>();
      if (!castOp || !canFoldIntoConsumerOp(castOp))
        return llvm::None;
      return castOp.getSource();
    };
    Optional<Value> sourceCastSource =
        getSourceOfCastOp(insertSliceOp.getSource());
    Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.getDest());
    if (!sourceCastSource && !destCastSource)
      return failure();

    // Substitute the pre-cast values wherever a foldable cast was found.
    auto src =
        (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
    auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
    auto srcType = src.getType().template cast<ShapedType>();
    auto dstType = dst.getType().template cast<ShapedType>();
    // The new source/destination pair must still form a valid insert_slice.
    if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                            insertSliceOp.getStaticSizes(),
                            insertSliceOp.getStaticStrides()) !=
        SliceVerificationResult::Success)
      return failure();

    Operation *replacement = rewriter.create<InsertOpTy>(
        insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
        insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());

    // In the parallel case there is no result and so nothing to cast.
    bool isParallelInsert =
        std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
    if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
      // Cast the result back so the op's users keep seeing the original type.
      replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                    insertSliceOp.getDestType(),
                                                    replacement->getResult(0));
    }
    rewriter.replaceOp(insertSliceOp, replacement->getResults());
    return success();
  }
};
1706ebf35370SMatthias Springer
1707ebf35370SMatthias Springer /// If additional static type information can be deduced from a insert_slice's
1708ebf35370SMatthias Springer /// size operands, insert an explicit cast of the op's source operand. This
1709ebf35370SMatthias Springer /// enables other canonicalization patterns that are matching for tensor_cast
1710ebf35370SMatthias Springer /// ops such as `ForOpTensorCastFolder` in SCF.
1711ebf35370SMatthias Springer ///
1712ebf35370SMatthias Springer /// Example:
1713ebf35370SMatthias Springer ///
1714ebf35370SMatthias Springer /// ```mlir
1715ebf35370SMatthias Springer /// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
1716ebf35370SMatthias Springer /// : tensor<?x?xf32> into ...
1717ebf35370SMatthias Springer /// ```
1718ebf35370SMatthias Springer ///
1719ebf35370SMatthias Springer /// folds into:
1720ebf35370SMatthias Springer ///
1721ebf35370SMatthias Springer /// ```mlir
1722ebf35370SMatthias Springer /// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
1723ebf35370SMatthias Springer /// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
1724ebf35370SMatthias Springer /// : tensor<64x64xf32> into ...
1725ebf35370SMatthias Springer /// ```
1726*c9fb3c6eSNicolas Vasilache ///
1727*c9fb3c6eSNicolas Vasilache /// This patterns works with both InsertSliceOp and ParallelInsertSliceOp.
1728*c9fb3c6eSNicolas Vasilache template <typename InsertOpTy>
1729ebf35370SMatthias Springer struct InsertSliceOpSourceCastInserter final
1730*c9fb3c6eSNicolas Vasilache : public OpRewritePattern<InsertOpTy> {
1731*c9fb3c6eSNicolas Vasilache using OpRewritePattern<InsertOpTy>::OpRewritePattern;
1732ebf35370SMatthias Springer
matchAndRewrite__anon3fb9f79f1111::InsertSliceOpSourceCastInserter1733*c9fb3c6eSNicolas Vasilache LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
1734ebf35370SMatthias Springer PatternRewriter &rewriter) const override {
1735ebf35370SMatthias Springer RankedTensorType srcType = insertSliceOp.getSourceType();
1736*c9fb3c6eSNicolas Vasilache if (srcType.getRank() != insertSliceOp.getDestType().getRank())
1737ebf35370SMatthias Springer return failure();
1738ebf35370SMatthias Springer SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
1739ebf35370SMatthias Springer srcType.getShape().end());
1740f2e945a3SNicolas Vasilache for (int64_t i = 0; i < srcType.getRank(); ++i) {
1741f2e945a3SNicolas Vasilache if (Optional<int64_t> constInt =
1742f2e945a3SNicolas Vasilache getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
1743f2e945a3SNicolas Vasilache newSrcShape[i] = *constInt;
1744f2e945a3SNicolas Vasilache }
174582dd977bSNicolas Vasilache
1746ebf35370SMatthias Springer RankedTensorType newSrcType =
1747ebf35370SMatthias Springer RankedTensorType::get(newSrcShape, srcType.getElementType());
174882dd977bSNicolas Vasilache if (srcType == newSrcType ||
174982dd977bSNicolas Vasilache !preservesStaticInformation(srcType, newSrcType) ||
175082dd977bSNicolas Vasilache !tensor::CastOp::areCastCompatible(srcType, newSrcType))
1751ebf35370SMatthias Springer return failure();
1752ebf35370SMatthias Springer
175382dd977bSNicolas Vasilache // newSrcType is:
175482dd977bSNicolas Vasilache // 1) Different from srcType.
175582dd977bSNicolas Vasilache // 2) "More static" than srcType.
175682dd977bSNicolas Vasilache // 3) Cast-compatible with srcType.
175782dd977bSNicolas Vasilache // Insert the cast.
1758*c9fb3c6eSNicolas Vasilache OpBuilder::InsertionGuard g(rewriter);
1759*c9fb3c6eSNicolas Vasilache // The only difference between InsertSliceOp and ParallelInsertSliceOp is
1760*c9fb3c6eSNicolas Vasilache // the the insertion point is just before the ParallelCombiningOp in the
1761*c9fb3c6eSNicolas Vasilache // parallel case.
1762*c9fb3c6eSNicolas Vasilache if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
1763*c9fb3c6eSNicolas Vasilache rewriter.setInsertionPoint(insertSliceOp->getParentOp());
1764ebf35370SMatthias Springer Value cast = rewriter.create<tensor::CastOp>(
17652d70eff8SJacques Pienaar insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
1766*c9fb3c6eSNicolas Vasilache rewriter.replaceOpWithNewOp<InsertOpTy>(
17672d70eff8SJacques Pienaar insertSliceOp, cast, insertSliceOp.getDest(),
1768ebf35370SMatthias Springer insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
1769ebf35370SMatthias Springer insertSliceOp.getMixedStrides());
1770*c9fb3c6eSNicolas Vasilache cast.getDefiningOp()->getParentOfType<ModuleOp>().dump();
1771ebf35370SMatthias Springer return success();
1772ebf35370SMatthias Springer }
1773ebf35370SMatthias Springer };
1774060208b4SMatthias Springer } // namespace
1775060208b4SMatthias Springer
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1776060208b4SMatthias Springer void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1777060208b4SMatthias Springer MLIRContext *context) {
1778*c9fb3c6eSNicolas Vasilache results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
1779*c9fb3c6eSNicolas Vasilache InsertSliceOpCastFolder<InsertSliceOp>,
1780*c9fb3c6eSNicolas Vasilache InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
1781060208b4SMatthias Springer }
1782060208b4SMatthias Springer
createCanonicalRankReducingInsertSliceOp(OpBuilder & b,Location loc,Value tensor,Value dest)1783aa373180SNicolas Vasilache Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
1784aa373180SNicolas Vasilache Location loc,
1785aa373180SNicolas Vasilache Value tensor,
1786aa373180SNicolas Vasilache Value dest) {
1787aa373180SNicolas Vasilache auto rankedTensorType = dest.getType().cast<RankedTensorType>();
1788aa373180SNicolas Vasilache unsigned rank = rankedTensorType.getRank();
1789aa373180SNicolas Vasilache auto shape = rankedTensorType.getShape();
1790aa373180SNicolas Vasilache SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
1791aa373180SNicolas Vasilache SmallVector<OpFoldResult> sizes;
1792aa373180SNicolas Vasilache for (unsigned i = 0, e = rank; i < e; ++i) {
1793aa373180SNicolas Vasilache OpFoldResult dim;
1794aa373180SNicolas Vasilache if (rankedTensorType.isDynamicDim(i))
1795aa373180SNicolas Vasilache dim = b.createOrFold<tensor::DimOp>(
1796aa373180SNicolas Vasilache loc, dest, b.create<arith::ConstantIndexOp>(loc, i));
1797aa373180SNicolas Vasilache else
1798aa373180SNicolas Vasilache dim = b.getIndexAttr(shape[i]);
1799aa373180SNicolas Vasilache sizes.push_back(dim);
1800aa373180SNicolas Vasilache }
1801aa373180SNicolas Vasilache SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
1802aa373180SNicolas Vasilache return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
1803aa373180SNicolas Vasilache sizes, strides);
1804aa373180SNicolas Vasilache }
1805aa373180SNicolas Vasilache
1806060208b4SMatthias Springer //===----------------------------------------------------------------------===//
1807fd0c6f53SAlexander Belyaev // PadOp
1808fd0c6f53SAlexander Belyaev //===----------------------------------------------------------------------===//
1809fd0c6f53SAlexander Belyaev
// TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
// supports optional types.
/// Printer hook for the custom<InferType> assembly directive. The inferred
/// type is intentionally not printed, so this is a no-op.
void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
                    Type typeToInfer, Type typeToInferFrom) {}
1814fd0c6f53SAlexander Belyaev
/// Parser hook for the custom<InferType> assembly directive. When the
/// optional operand is present, its type is inferred to be `typeToInferFrom`;
/// otherwise `typeToInfer` is left untouched. Never fails.
ParseResult parseInferType(OpAsmParser &parser,
                           Optional<OpAsmParser::UnresolvedOperand> optOperand,
                           Type &typeToInfer, Type typeToInferFrom) {
  if (optOperand)
    typeToInfer = typeToInferFrom;
  return success();
}
1822fd0c6f53SAlexander Belyaev
verify()1823b98dc035SRiver Riddle LogicalResult PadOp::verify() {
18242d70eff8SJacques Pienaar auto sourceType = getSource().getType().cast<RankedTensorType>();
18252d70eff8SJacques Pienaar auto resultType = getResult().getType().cast<RankedTensorType>();
18262d70eff8SJacques Pienaar auto expectedType = PadOp::inferResultType(
18272d70eff8SJacques Pienaar sourceType, extractFromI64ArrayAttr(getStaticLow()),
18282d70eff8SJacques Pienaar extractFromI64ArrayAttr(getStaticHigh()));
1829fd0c6f53SAlexander Belyaev for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
1830fd0c6f53SAlexander Belyaev if (resultType.getDimSize(i) == expectedType.getDimSize(i))
1831fd0c6f53SAlexander Belyaev continue;
1832fd0c6f53SAlexander Belyaev if (expectedType.isDynamicDim(i))
1833fd0c6f53SAlexander Belyaev continue;
1834b98dc035SRiver Riddle return emitError("specified type ")
1835fd0c6f53SAlexander Belyaev << resultType << " does not match the inferred type "
1836fd0c6f53SAlexander Belyaev << expectedType;
1837fd0c6f53SAlexander Belyaev }
1838fd0c6f53SAlexander Belyaev
1839ed645f63SChia-hung Duan return success();
1840ed645f63SChia-hung Duan }
1841ed645f63SChia-hung Duan
/// Region verifier for PadOp: the padding-value block must take one index
/// argument per result dimension and yield a value of the result's element
/// type.
LogicalResult PadOp::verifyRegions() {
  auto &region = getRegion();
  unsigned rank = getResult().getType().cast<RankedTensorType>().getRank();
  Block &block = region.front();
  // One block argument (the padded position) per result dimension.
  if (block.getNumArguments() != rank)
    return emitError("expected the block to have ") << rank << " arguments";

  // Note: the number and type of yield values are checked in the YieldOp.
  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
    if (!en.value().isIndex())
      return emitOpError("expected block argument ")
             << (en.index() + 1) << " to be an index";
  }

  // Ensure that the region yields an element of the right type.
  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
  if (yieldOp.getValue().getType() !=
      getType().cast<ShapedType>().getElementType())
    return emitOpError("expected yield type to match shape element type");

  return success();
}
1864fd0c6f53SAlexander Belyaev
inferResultType(RankedTensorType sourceType,ArrayRef<int64_t> staticLow,ArrayRef<int64_t> staticHigh,ArrayRef<int64_t> resultShape)1865fd0c6f53SAlexander Belyaev RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
1866fd0c6f53SAlexander Belyaev ArrayRef<int64_t> staticLow,
1867fd0c6f53SAlexander Belyaev ArrayRef<int64_t> staticHigh,
1868fd0c6f53SAlexander Belyaev ArrayRef<int64_t> resultShape) {
1869fd0c6f53SAlexander Belyaev unsigned rank = sourceType.getRank();
1870fd0c6f53SAlexander Belyaev assert(staticLow.size() == rank && "unexpected staticLow size mismatch");
1871fd0c6f53SAlexander Belyaev assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch");
1872fd0c6f53SAlexander Belyaev assert((resultShape.empty() || resultShape.size() == rank) &&
1873fd0c6f53SAlexander Belyaev "unexpected resultShape size mismatch");
1874fd0c6f53SAlexander Belyaev
1875fd0c6f53SAlexander Belyaev SmallVector<int64_t, 4> inferredShape;
1876fd0c6f53SAlexander Belyaev for (auto i : llvm::seq<unsigned>(0, rank)) {
1877fd0c6f53SAlexander Belyaev if (sourceType.isDynamicDim(i) ||
1878fd0c6f53SAlexander Belyaev staticLow[i] == ShapedType::kDynamicSize ||
1879fd0c6f53SAlexander Belyaev staticHigh[i] == ShapedType::kDynamicSize) {
1880fd0c6f53SAlexander Belyaev inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamicSize
1881fd0c6f53SAlexander Belyaev : resultShape[i]);
1882fd0c6f53SAlexander Belyaev } else {
1883fd0c6f53SAlexander Belyaev int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
1884fd0c6f53SAlexander Belyaev assert((resultShape.empty() || size == resultShape[i] ||
1885fd0c6f53SAlexander Belyaev resultShape[i] == ShapedType::kDynamicSize) &&
1886fd0c6f53SAlexander Belyaev "mismatch between inferred shape and result shape");
1887fd0c6f53SAlexander Belyaev inferredShape.push_back(size);
1888fd0c6f53SAlexander Belyaev }
1889fd0c6f53SAlexander Belyaev }
1890fd0c6f53SAlexander Belyaev
1891fd0c6f53SAlexander Belyaev return RankedTensorType::get(inferredShape, sourceType.getElementType());
1892fd0c6f53SAlexander Belyaev }
1893fd0c6f53SAlexander Belyaev
build(OpBuilder & b,OperationState & result,Value source,ArrayRef<int64_t> staticLow,ArrayRef<int64_t> staticHigh,ValueRange low,ValueRange high,bool nofold,ArrayRef<NamedAttribute> attrs)1894fd0c6f53SAlexander Belyaev void PadOp::build(OpBuilder &b, OperationState &result, Value source,
1895fd0c6f53SAlexander Belyaev ArrayRef<int64_t> staticLow, ArrayRef<int64_t> staticHigh,
1896fd0c6f53SAlexander Belyaev ValueRange low, ValueRange high, bool nofold,
1897fd0c6f53SAlexander Belyaev ArrayRef<NamedAttribute> attrs) {
1898fd0c6f53SAlexander Belyaev auto sourceType = source.getType().cast<RankedTensorType>();
1899fd0c6f53SAlexander Belyaev auto resultType = inferResultType(sourceType, staticLow, staticHigh);
1900fd0c6f53SAlexander Belyaev build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
1901fd0c6f53SAlexander Belyaev b.getI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr());
1902fd0c6f53SAlexander Belyaev result.addAttributes(attrs);
1903fd0c6f53SAlexander Belyaev }
1904fd0c6f53SAlexander Belyaev
build(OpBuilder & b,OperationState & result,Value source,ValueRange low,ValueRange high,bool nofold,ArrayRef<NamedAttribute> attrs)1905fd0c6f53SAlexander Belyaev void PadOp::build(OpBuilder &b, OperationState &result, Value source,
1906fd0c6f53SAlexander Belyaev ValueRange low, ValueRange high, bool nofold,
1907fd0c6f53SAlexander Belyaev ArrayRef<NamedAttribute> attrs) {
1908fd0c6f53SAlexander Belyaev auto sourceType = source.getType().cast<RankedTensorType>();
1909fd0c6f53SAlexander Belyaev unsigned rank = sourceType.getRank();
1910fd0c6f53SAlexander Belyaev SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamicSize);
1911fd0c6f53SAlexander Belyaev build(b, result, source, staticVector, staticVector, low, high, nofold,
1912fd0c6f53SAlexander Belyaev attrs);
1913fd0c6f53SAlexander Belyaev }
1914fd0c6f53SAlexander Belyaev
build(OpBuilder & b,OperationState & result,Type resultType,Value source,ArrayRef<OpFoldResult> low,ArrayRef<OpFoldResult> high,bool nofold,ArrayRef<NamedAttribute> attrs)1915fd0c6f53SAlexander Belyaev void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
1916fd0c6f53SAlexander Belyaev Value source, ArrayRef<OpFoldResult> low,
1917fd0c6f53SAlexander Belyaev ArrayRef<OpFoldResult> high, bool nofold,
1918fd0c6f53SAlexander Belyaev ArrayRef<NamedAttribute> attrs) {
1919fd0c6f53SAlexander Belyaev assert(resultType.isa<RankedTensorType>());
1920fd0c6f53SAlexander Belyaev auto sourceType = source.getType().cast<RankedTensorType>();
1921fd0c6f53SAlexander Belyaev SmallVector<Value, 4> dynamicLow, dynamicHigh;
1922fd0c6f53SAlexander Belyaev SmallVector<int64_t, 4> staticLow, staticHigh;
1923fd0c6f53SAlexander Belyaev // staticLow and staticHigh have full information of the padding config.
1924fd0c6f53SAlexander Belyaev // This will grow staticLow and staticHigh with 1 value. If the config is
1925fd0c6f53SAlexander Belyaev // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
1926fd0c6f53SAlexander Belyaev // value as well.
1927fd0c6f53SAlexander Belyaev dispatchIndexOpFoldResults(low, dynamicLow, staticLow,
1928fd0c6f53SAlexander Belyaev ShapedType::kDynamicSize);
1929fd0c6f53SAlexander Belyaev dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh,
1930fd0c6f53SAlexander Belyaev ShapedType::kDynamicSize);
1931fd0c6f53SAlexander Belyaev if (!resultType) {
1932fd0c6f53SAlexander Belyaev resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
1933fd0c6f53SAlexander Belyaev }
1934fd0c6f53SAlexander Belyaev build(b, result, resultType, source, dynamicLow, dynamicHigh,
1935fd0c6f53SAlexander Belyaev b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
1936fd0c6f53SAlexander Belyaev nofold ? b.getUnitAttr() : UnitAttr());
1937fd0c6f53SAlexander Belyaev result.addAttributes(attrs);
1938fd0c6f53SAlexander Belyaev }
1939fd0c6f53SAlexander Belyaev
getPaddedDims()1940973dbe20Sgysit llvm::SmallBitVector PadOp::getPaddedDims() {
1941973dbe20Sgysit llvm::SmallBitVector paddedDims(getSourceType().getRank());
1942973dbe20Sgysit auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
1943973dbe20Sgysit for (const auto &en : enumerate(paddingWidths))
1944973dbe20Sgysit if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
1945973dbe20Sgysit paddedDims.set(en.index());
1946973dbe20Sgysit };
1947973dbe20Sgysit extractPaddedDims(getMixedLowPad());
1948973dbe20Sgysit extractPaddedDims(getMixedHighPad());
1949973dbe20Sgysit return paddedDims;
1950973dbe20Sgysit }
1951973dbe20Sgysit
1952fd0c6f53SAlexander Belyaev namespace {
1953fd0c6f53SAlexander Belyaev // Folds tensor.pad when padding is static zeros and the attribute
1954fd0c6f53SAlexander Belyaev // doesn't request otherwise.
1955fd0c6f53SAlexander Belyaev struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
1956fd0c6f53SAlexander Belyaev using OpRewritePattern<PadOp>::OpRewritePattern;
1957fd0c6f53SAlexander Belyaev
matchAndRewrite__anon3fb9f79f1611::FoldStaticZeroPadding1958fd0c6f53SAlexander Belyaev LogicalResult matchAndRewrite(PadOp padTensorOp,
1959fd0c6f53SAlexander Belyaev PatternRewriter &rewriter) const override {
1960fd0c6f53SAlexander Belyaev if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
1961fd0c6f53SAlexander Belyaev return failure();
19622d70eff8SJacques Pienaar if (padTensorOp.getNofold())
1963fd0c6f53SAlexander Belyaev return failure();
1964fd0c6f53SAlexander Belyaev rewriter.replaceOpWithNewOp<tensor::CastOp>(
19652d70eff8SJacques Pienaar padTensorOp, padTensorOp.getResult().getType(),
19662d70eff8SJacques Pienaar padTensorOp.getSource());
1967fd0c6f53SAlexander Belyaev return success();
1968fd0c6f53SAlexander Belyaev }
1969fd0c6f53SAlexander Belyaev };
1970fd0c6f53SAlexander Belyaev
1971fd0c6f53SAlexander Belyaev // Fold CastOp into PadOp when adding static information.
struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    // Only applies when the pad's source is produced by a tensor.cast that
    // can be folded into its consumer (presumably canFoldIntoConsumerOp
    // also rejects a null castOp — confirm against its definition).
    auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
    if (!tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    // Re-infer the result type from the cast's source type (which may carry
    // more static information), constrained by the current result shape.
    auto newResultType = PadOp::inferResultType(
        castOp.getSource().getType().cast<RankedTensorType>(),
        extractFromI64ArrayAttr(padTensorOp.getStaticLow()),
        extractFromI64ArrayAttr(padTensorOp.getStaticHigh()),
        padTensorOp.getResultType().getShape());

    if (newResultType == padTensorOp.getResultType()) {
      // Result type unchanged: just bypass the cast in place.
      rewriter.updateRootInPlace(padTensorOp, [&]() {
        padTensorOp.getSourceMutable().assign(castOp.getSource());
      });
    } else {
      // Result type refined: rebuild the pad with the new type and cast back
      // to the original result type so existing users stay valid.
      auto newOp = rewriter.create<PadOp>(
          padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
          padTensorOp.getLow(), padTensorOp.getHigh(),
          padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
          padTensorOp.getNofold());
      // Preserve the pad body (the region yielding the padding value).
      BlockAndValueMapping mapper;
      padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          padTensorOp, padTensorOp.getResultType(), newOp);
    }
    return success();
  }
};
2006fd0c6f53SAlexander Belyaev
2007fd0c6f53SAlexander Belyaev // Fold CastOp using the result of PadOp back into the latter if it adds
2008fd0c6f53SAlexander Belyaev // static information.
struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    // The pad's result must feed exactly one user, and that user must be a
    // tensor.cast.
    if (!padTensorOp.getResult().hasOneUse())
      return failure();
    auto tensorCastOp =
        dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
    if (!tensorCastOp)
      return failure();
    // Only fold when the cast destination keeps all static information of
    // the pad result type (i.e. the cast only adds static information).
    if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
                                            tensorCastOp.getDest().getType()))
      return failure();

    // Rebuild the pad with the cast's more static result type, moving the
    // original pad body into the new op.
    auto replacementOp = rewriter.create<PadOp>(
        padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
        padTensorOp.getSource(), padTensorOp.getLow(), padTensorOp.getHigh(),
        padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
        padTensorOp.getNofold());
    replacementOp.getRegion().takeBody(padTensorOp.getRegion());

    // Replace the pad first (its only user is the cast), then the cast.
    rewriter.replaceOp(padTensorOp, replacementOp.getResult());
    rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
    return success();
  }
};
2036973dbe20Sgysit
2037973dbe20Sgysit /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
2038973dbe20Sgysit /// different dimensions. The pattern applies if the following preconditions
2039973dbe20Sgysit /// hold:
2040973dbe20Sgysit /// 1) the tensor::ExtractSliceOps are not rank-reducing,
2041973dbe20Sgysit /// 2) the tensor::ExtractSliceOps have only unit-strides,
2042973dbe20Sgysit /// 3) the tensor::PadOps perform only high-padding,
2043973dbe20Sgysit /// 4) the tensor::PadOps have the same constant padding value,
2044973dbe20Sgysit /// 5) the tensor::PadOps do not have common padding dimensions,
2045973dbe20Sgysit /// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
2046973dbe20Sgysit /// zero-offset for every dimension.
2047973dbe20Sgysit /// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for the
2048973dbe20Sgysit /// padded source dimensions.
2049973dbe20Sgysit ///
2050973dbe20Sgysit /// Example:
2051973dbe20Sgysit ///
2052973dbe20Sgysit /// ```mlir
2053973dbe20Sgysit /// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
2054973dbe20Sgysit /// : tensor<64x64xf32> to tensor<?x64xf32>
2055973dbe20Sgysit /// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
2056973dbe20Sgysit /// } : tensor<?x64xf32> to tensor<8x64xf32>
2057973dbe20Sgysit /// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
2058973dbe20Sgysit /// : tensor<8x64xf32> to tensor<8x?xf32>
2059973dbe20Sgysit /// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
2060973dbe20Sgysit /// } : tensor<8x?xf32> to tensor<8x4xf32>
2061973dbe20Sgysit /// ```
2062973dbe20Sgysit ///
2063973dbe20Sgysit /// folds into:
2064973dbe20Sgysit ///
2065973dbe20Sgysit /// ```mlir
2066973dbe20Sgysit /// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
2067973dbe20Sgysit /// : tensor<64x64xf32> to tensor<?x?xf32>
2068973dbe20Sgysit /// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
2069973dbe20Sgysit /// } : tensor<?x?xf32> to tensor<8x4xf32>
2070973dbe20Sgysit /// ```
struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padOp,
                                PatternRewriter &rewriter) const override {
    // Match the chain padOp(innerSliceOp(outerPadOp(outerSliceOp(x)))).
    // The outer pad must be foldable (no nofold attribute).
    auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!innerSliceOp)
      return failure();
    auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
    if (!outerPadOp || outerPadOp.getNofold())
      return failure();
    auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!outerSliceOp)
      return failure();

    // 1) Fail if the chain is rank-reducing.
    int64_t rank = padOp.getSourceType().getRank();
    if (outerSliceOp.getSourceType().getRank() != rank) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold rank-reducing chain");
    }

    // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
    if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold non-unit stride ExtractSliceOps");
    }

    // 3) Fail if the tensor::PadOps have non-zero low padding.
    if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold PadOps with low padding");
    }

    // 4) Fail if the tensor::PadOps padding values do not match.
    // Both pads must yield the same constant (ConstantLike) padding value.
    Attribute innerAttr, outerAttr;
    Value innerValue = padOp.getConstantPaddingValue();
    Value outerValue = outerPadOp.getConstantPaddingValue();
    if (!innerValue || !outerValue ||
        !matchPattern(innerValue, m_Constant(&innerAttr)) ||
        !matchPattern(outerValue, m_Constant(&outerAttr)) ||
        innerAttr != outerAttr) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with different padding values");
    }

    // 5) Fail if a dimension is padded by both tensor::PadOps.
    // getPaddedDims conservatively marks dynamic widths as padded.
    llvm::SmallBitVector innerDims = padOp.getPaddedDims();
    llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
    if (innerDims.anyCommon(outerDims)) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with common padding dimensions");
    }

    // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
    // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
    // for every dimension, and use the offset the other pair. Fail if no
    // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
    // exists.
    SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
    for (auto &en : enumerate(newOffsets)) {
      OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
      OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
      if (!innerDims.test(en.index()) &&
          (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
        en.value() = outerOffset;
        continue;
      }
      if (!outerDims.test(en.index()) &&
          (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
        en.value() = innerOffset;
        continue;
      }
      return rewriter.notifyMatchFailure(
          padOp, "cannot find zero-offset and zero-padding pair");
    }

    // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size of
    // the outer tensor::ExtractSliceOp for the dimensions padded by the outer
    // tensor::PadOp and fail if the size of the inner tensor::ExtractSliceOp
    // does not match the size of the padded dimension. Otherwise, take the size
    // of the inner tensor::ExtractSliceOp.
    SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
    for (auto &en : enumerate(newSizes)) {
      if (!outerDims.test(en.index()))
        continue;
      OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
      int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
      assert(!ShapedType::isDynamic(sourceSize) &&
             "expected padded dimension to have a static size");
      if (getConstantIntValue(sliceSize) != sourceSize) {
        return rewriter.notifyMatchFailure(
            padOp, "cannot fold since the inner ExtractSliceOp size does not "
                   "match the size of the outer padding");
      }
      en.value() = outerSliceOp.getMixedSizes()[en.index()];
    }

    // Combine the high paddings of the two tensor::PadOps.
    // Precondition 5) guarantees the two bit vectors are disjoint, so each
    // dimension takes its padding from at most one of the pads.
    SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
    for (auto &en : enumerate(newHighPad)) {
      if (innerDims.test(en.index()))
        newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
      if (outerDims.test(en.index()))
        newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
    }

    // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs the
    // two paddings in one step.
    auto newSliceOp = rewriter.create<ExtractSliceOp>(
        padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
        innerSliceOp.getMixedStrides());
    auto newPadOp = rewriter.create<PadOp>(
        padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
        padOp.getMixedLowPad(), newHighPad, padOp.getNofold());
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
2192973dbe20Sgysit
2193fd0c6f53SAlexander Belyaev } // namespace
2194fd0c6f53SAlexander Belyaev
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2195fd0c6f53SAlexander Belyaev void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
2196fd0c6f53SAlexander Belyaev MLIRContext *context) {
2197973dbe20Sgysit results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
2198973dbe20Sgysit FoldOrthogonalPaddings>(context);
2199fd0c6f53SAlexander Belyaev }
2200fd0c6f53SAlexander Belyaev
2201fd0c6f53SAlexander Belyaev /// Return the padding value of the PadOp if it constant. In this context,
2202fd0c6f53SAlexander Belyaev /// "constant" means an actual constant or "defined outside of the block".
2203fd0c6f53SAlexander Belyaev ///
2204fd0c6f53SAlexander Belyaev /// Values are considered constant in three cases:
2205fd0c6f53SAlexander Belyaev /// - A ConstantLike value.
2206fd0c6f53SAlexander Belyaev /// - A basic block argument from a different block.
2207fd0c6f53SAlexander Belyaev /// - A value defined outside of the block.
2208fd0c6f53SAlexander Belyaev ///
2209fd0c6f53SAlexander Belyaev /// If the padding value is not constant, an empty Value is returned.
getConstantPaddingValue()2210fd0c6f53SAlexander Belyaev Value PadOp::getConstantPaddingValue() {
2211fd0c6f53SAlexander Belyaev auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
2212fd0c6f53SAlexander Belyaev if (!yieldOp)
2213fd0c6f53SAlexander Belyaev return {};
22142d70eff8SJacques Pienaar Value padValue = yieldOp.getValue();
2215fd0c6f53SAlexander Belyaev // Check if yield value is a constant.
2216fd0c6f53SAlexander Belyaev if (matchPattern(padValue, m_Constant()))
2217fd0c6f53SAlexander Belyaev return padValue;
2218fd0c6f53SAlexander Belyaev // Check if yield value is defined inside the PadOp block.
2219fd0c6f53SAlexander Belyaev if (padValue.getParentBlock() == &getRegion().front())
2220fd0c6f53SAlexander Belyaev return {};
2221fd0c6f53SAlexander Belyaev // Else: Yield value defined outside of the PadOp block.
2222fd0c6f53SAlexander Belyaev return padValue;
2223fd0c6f53SAlexander Belyaev }
2224fd0c6f53SAlexander Belyaev
fold(ArrayRef<Attribute>)2225fd0c6f53SAlexander Belyaev OpFoldResult PadOp::fold(ArrayRef<Attribute>) {
2226fd0c6f53SAlexander Belyaev if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
22272d70eff8SJacques Pienaar !getNofold())
22282d70eff8SJacques Pienaar return getSource();
2229fd0c6f53SAlexander Belyaev return {};
2230fd0c6f53SAlexander Belyaev }
2231fd0c6f53SAlexander Belyaev
2232fd0c6f53SAlexander Belyaev //===----------------------------------------------------------------------===//
22337fbf55c9SNicolas Vasilache // ParallelInsertSliceOp
22347fbf55c9SNicolas Vasilache //===----------------------------------------------------------------------===//
22357fbf55c9SNicolas Vasilache
getTiedOpResult()22367fbf55c9SNicolas Vasilache OpResult ParallelInsertSliceOp::getTiedOpResult() {
22377fbf55c9SNicolas Vasilache ParallelCombiningOpInterface parallelCombiningParent =
22387fbf55c9SNicolas Vasilache getParallelCombiningParent();
22397fbf55c9SNicolas Vasilache for (const auto &it :
22407fbf55c9SNicolas Vasilache llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
22417fbf55c9SNicolas Vasilache Operation &nextOp = it.value();
22427fbf55c9SNicolas Vasilache if (&nextOp == getOperation())
22437fbf55c9SNicolas Vasilache return parallelCombiningParent.getParentResult(it.index());
22447fbf55c9SNicolas Vasilache }
22457fbf55c9SNicolas Vasilache llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
22467fbf55c9SNicolas Vasilache }
22477fbf55c9SNicolas Vasilache
22487fbf55c9SNicolas Vasilache // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
build(OpBuilder & b,OperationState & result,Value source,Value dest,ArrayRef<OpFoldResult> offsets,ArrayRef<OpFoldResult> sizes,ArrayRef<OpFoldResult> strides,ArrayRef<NamedAttribute> attrs)22497fbf55c9SNicolas Vasilache void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
22507fbf55c9SNicolas Vasilache Value source, Value dest,
22517fbf55c9SNicolas Vasilache ArrayRef<OpFoldResult> offsets,
22527fbf55c9SNicolas Vasilache ArrayRef<OpFoldResult> sizes,
22537fbf55c9SNicolas Vasilache ArrayRef<OpFoldResult> strides,
22547fbf55c9SNicolas Vasilache ArrayRef<NamedAttribute> attrs) {
22557fbf55c9SNicolas Vasilache SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
22567fbf55c9SNicolas Vasilache SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
22577fbf55c9SNicolas Vasilache dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
22587fbf55c9SNicolas Vasilache ShapedType::kDynamicStrideOrOffset);
22597fbf55c9SNicolas Vasilache dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
22607fbf55c9SNicolas Vasilache ShapedType::kDynamicSize);
22617fbf55c9SNicolas Vasilache dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
22627fbf55c9SNicolas Vasilache ShapedType::kDynamicStrideOrOffset);
22637fbf55c9SNicolas Vasilache build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
22647fbf55c9SNicolas Vasilache dynamicStrides, b.getI64ArrayAttr(staticOffsets),
22657fbf55c9SNicolas Vasilache b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
22667fbf55c9SNicolas Vasilache result.addAttributes(attrs);
22677fbf55c9SNicolas Vasilache }
22687fbf55c9SNicolas Vasilache
22697fbf55c9SNicolas Vasilache // Build a ParallelInsertSliceOp with dynamic entries.
build(OpBuilder & b,OperationState & result,Value source,Value dest,ValueRange offsets,ValueRange sizes,ValueRange strides,ArrayRef<NamedAttribute> attrs)22707fbf55c9SNicolas Vasilache void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
22717fbf55c9SNicolas Vasilache Value source, Value dest, ValueRange offsets,
22727fbf55c9SNicolas Vasilache ValueRange sizes, ValueRange strides,
22737fbf55c9SNicolas Vasilache ArrayRef<NamedAttribute> attrs) {
22747fbf55c9SNicolas Vasilache SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
22757fbf55c9SNicolas Vasilache llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
22767fbf55c9SNicolas Vasilache SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
22777fbf55c9SNicolas Vasilache llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
22787fbf55c9SNicolas Vasilache SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
22797fbf55c9SNicolas Vasilache llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
22807fbf55c9SNicolas Vasilache build(b, result, source, dest, offsetValues, sizeValues, strideValues);
22817fbf55c9SNicolas Vasilache }
22827fbf55c9SNicolas Vasilache
verify()22837fbf55c9SNicolas Vasilache LogicalResult ParallelInsertSliceOp::verify() {
22847fbf55c9SNicolas Vasilache if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
22857fbf55c9SNicolas Vasilache return this->emitError("expected ParallelCombiningOpInterface parent, got:")
22867fbf55c9SNicolas Vasilache << *(getOperation()->getParentOp());
2287*c9fb3c6eSNicolas Vasilache
2288*c9fb3c6eSNicolas Vasilache ShapedType expectedType;
2289*c9fb3c6eSNicolas Vasilache SliceVerificationResult result =
2290*c9fb3c6eSNicolas Vasilache verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
2291*c9fb3c6eSNicolas Vasilache getStaticSizes(), getStaticStrides(), &expectedType);
2292*c9fb3c6eSNicolas Vasilache return produceSliceErrorMsg(result, *this, expectedType);
22937fbf55c9SNicolas Vasilache }
22947fbf55c9SNicolas Vasilache
22957fbf55c9SNicolas Vasilache namespace {
/// Pattern to rewrite a parallel_insert_slice op with constant arguments.
///
/// NOTE(review): this class does not appear to be registered anywhere in
/// this file — getCanonicalizationPatterns below registers the generic
/// InsertSliceOp* folders instead. Confirm whether it is still needed or
/// can be removed.
class ParallelInsertSliceOpConstantArgumentFolder final
    : public OpRewritePattern<ParallelInsertSliceOp> {
public:
  using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override {
    // No constant operand, just return.
    if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // At least one of offsets/sizes/strides is a new constant.
    // Form the new list of operands and constant attributes from the
    // existing.
    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
    canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
    canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
    canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);

    // Create the new op in canonical form.
    auto sourceType =
        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
            insertSliceOp.getSourceType().getRank(),
            insertSliceOp.getDestType(), mixedOffsets, mixedSizes,
            mixedStrides);
    Value toInsert = insertSliceOp.getSource();
    if (sourceType != insertSliceOp.getSourceType()) {
      OpBuilder::InsertionGuard g(rewriter);
      // The cast is created above the parallel combining parent — presumably
      // so it dominates uses inside the parent's region; confirm.
      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                 sourceType, toInsert);
    }
    rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
        mixedSizes, mixedStrides);
    return success();
  }
};
23397fbf55c9SNicolas Vasilache } // namespace
23407fbf55c9SNicolas Vasilache
/// Fold hook: delegates entirely to the shared foldInsertOp helper
/// (defined earlier in this file, outside this view — presumably the same
/// folding used by InsertSliceOp; confirm against its definition).
LogicalResult
ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
                            SmallVectorImpl<OpFoldResult> &results) {
  // `results` is not populated here; foldInsertOp handles the folding.
  return foldInsertOp(*this, operands);
}
23467fbf55c9SNicolas Vasilache
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)23477fbf55c9SNicolas Vasilache void ParallelInsertSliceOp::getCanonicalizationPatterns(
23487fbf55c9SNicolas Vasilache RewritePatternSet &results, MLIRContext *context) {
2349*c9fb3c6eSNicolas Vasilache results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
2350*c9fb3c6eSNicolas Vasilache InsertSliceOpCastFolder<ParallelInsertSliceOp>,
2351*c9fb3c6eSNicolas Vasilache InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
23527fbf55c9SNicolas Vasilache }
23537fbf55c9SNicolas Vasilache
23547fbf55c9SNicolas Vasilache //===----------------------------------------------------------------------===//
23556a8ba318SRiver Riddle // SplatOp
23566a8ba318SRiver Riddle //===----------------------------------------------------------------------===//
23576a8ba318SRiver Riddle
fold(ArrayRef<Attribute> operands)23586a8ba318SRiver Riddle OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
23596a8ba318SRiver Riddle auto constOperand = operands.front();
23606a8ba318SRiver Riddle if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
23616a8ba318SRiver Riddle return {};
23626a8ba318SRiver Riddle
23636a8ba318SRiver Riddle // SplatElementsAttr::get treats single value for second arg as being a splat.
23646a8ba318SRiver Riddle return SplatElementsAttr::get(getType(), {constOperand});
23656a8ba318SRiver Riddle }
23666a8ba318SRiver Riddle
23676a8ba318SRiver Riddle //===----------------------------------------------------------------------===//
2368444822d7SSean Silva // TableGen'd op method definitions
2369444822d7SSean Silva //===----------------------------------------------------------------------===//
2370444822d7SSean Silva
2371444822d7SSean Silva #define GET_OP_CLASSES
2372444822d7SSean Silva #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
2373