//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::tensor;

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of tensor.cast operations. Such foldable
/// tensor.cast operations are typically inserted as `subtensor` ops and are
/// canonicalized to preserve the type compatibility of their uses.
///
/// Returns true when all of the following conditions are met:
/// 1. source and result are ranked tensors with the same element type and
///    rank.
/// 2. the source type has more static information than the result type.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  RankedTensorType sourceType =
      castOp.source().getType().dyn_cast<RankedTensorType>();
  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();

  // Requires RankedTensorType.
  if (!sourceType || !resultType)
    return false;

  // Requires the same element type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires the same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // If the cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    if (ShapedType::isDynamic(std::get<0>(t)) &&
        !ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}

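/// Returns true if `a` and `b` are tensor types with the same element type
/// and compatible shapes: each pair of static dimensions must agree, while a
/// dynamic dimension or an unranked type is compatible with anything. For
/// example, tensor<?x4xf32> and tensor<8x4xf32> are cast compatible, whereas
/// tensor<8x4xf32> and tensor<8x5xf32> are not.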
bool CastOp::areCastCompatible(Type a, Type b) {
  auto aT = a.dyn_cast<TensorType>();
  auto bT = b.dyn_cast<TensorType>();
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

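/// Defers to the generic cast folder, which removes the cast when it is a
/// no-op, i.e. when the source and result types are identical.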
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  return impl::foldCastOp(*this);
}

/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match.
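///
/// Returns a null type if the shapes cannot be joined, i.e. if the types
/// have different ranks or disagree on a static dimension. For example:
///   join(tensor<?x16xf32>, tensor<8x?xf32>)  -> tensor<8x16xf32>
///   join(tensor<8x16xf32>, tensor<4x16xf32>) -> null (conflicting sizes)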
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}

namespace {

/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
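///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<4x?xf32> to tensor<?x?xf32>
///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<?x8xf32>
/// ```
/// is rewritten into:
/// ```mlir
///   %2 = tensor.cast %0 : tensor<4x?xf32> to tensor<?x8xf32>
/// ```
/// The intermediate cast is kept when dropping it would lose a runtime
/// check, e.g. when the chain casts tensor<?x?xf32> through tensor<4x?xf32>
/// to tensor<?x8xf32>; the intermediate cast constrains the first dimension
/// to 4, which the direct cast would not.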
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        tensorCastOperand.getOperand().getType().cast<TensorType>();
    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
    auto resultType = tensorCast.getType().cast<TensorType>();

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // The newJoin always exists if the above join exists; it might just
    // contain less information. If so, we cannot drop the intermediate cast,
    // as doing so would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};

} // namespace

void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  results.insert<ChainedTensorCast>(context);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ExtractOp op) {
  // Verify the number of indices matches the rank if the tensor type is
  // ranked.
  if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
    if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
      return op.emitOpError("incorrect number of indices for extract");

  return success();
}

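/// Folds an extract from a constant tensor, e.g.:
/// ```mlir
///   %cst = constant dense<[1, 2, 3]> : tensor<3xi32>
///   %c1 = constant 1 : index
///   %2 = tensor.extract %cst[%c1] : tensor<3xi32>
/// ```
/// folds %2 to the constant 2 : i32. A splat constant folds regardless of
/// the index values, since every element is the same.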
OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
  // The tensor operand must be a known constant.
  Attribute tensor = operands.front();
  if (!tensor)
    return {};
  // If this is a splat elements attribute, simply return the value. All of
  // the elements of a splat attribute are the same.
  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
    return splatTensor.getSplatValue();

  // Otherwise, collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute index : llvm::drop_begin(operands, 1)) {
    if (!index || !index.isa<IntegerAttr>())
      return {};
    indices.push_back(index.cast<IntegerAttr>().getInt());
  }

  // If this is an elements attribute, query the value at the given indices.
  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
  if (elementsAttr && elementsAttr.isValidIndex(indices))
    return elementsAttr.getValue(indices);
  return {};
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"