//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::tensor;

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted as `subtensor` ops and are canonicalized
/// to preserve the type compatibility of their uses.
///
/// Returns true when all of the following conditions are met:
/// 1. source and result are ranked tensors with the same element type and
///    rank;
/// 2. the source tensor type has at least as much static shape information as
///    the result type along every dimension.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  RankedTensorType sourceType =
      castOp.source().getType().dyn_cast<RankedTensorType>();
  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();

  // Requires RankedTensorType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    if (ShapedType::isDynamic(std::get<0>(t)) &&
        !ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}

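/// Two tensor types are cast-compatible when their element types match and
/// their shapes are compatible, i.e. every dimension that is static in both
/// types agrees. For example (illustrative types), tensor<?x4xf32> and
/// tensor<8x?xf32> are cast-compatible, while tensor<8x4xf32> and
/// tensor<8x5xf32> are not.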
bool CastOp::areCastCompatible(Type a, Type b) {
  auto aT = a.dyn_cast<TensorType>();
  auto bT = b.dyn_cast<TensorType>();
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  return impl::foldCastOp(*this);
}

/// Compute a TensorType that has the joined shape knowledge of the two given
/// TensorTypes. The element types need to match. Returns a null TensorType if
/// the ranks differ or any static extents conflict.
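///
/// For example (shapes chosen for illustration), joining tensor<?x8xf32> with
/// tensor<4x?xf32> yields tensor<4x8xf32>, whereas tensor<4x8xf32> and
/// tensor<5x8xf32> have no join.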
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}

namespace {

/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
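///
/// For example (shapes chosen for illustration):
///
/// ```mlir
///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
///   %2 = tensor.cast %1 : tensor<4x?xf32> to tensor<4x8xf32>
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
/// ```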
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        tensorCastOperand.getOperand().getType().cast<TensorType>();
    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
    auto resultType = tensorCast.getType().cast<TensorType>();

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // The newJoin always exists if the above join exists; it might just
    // contain less information. If so, we cannot drop the intermediate cast,
    // as doing so would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};

} // namespace

void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  results.insert<ChainedTensorCast>(context);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ExtractOp op) {
  // Verify the # indices match if we have a ranked type.
  if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
    if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
      return op.emitOpError("incorrect number of indices for extract");

  return success();
}

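/// Folds tensor.extract when the tensor operand is a known constant: a splat
/// constant folds to its splat value, and a dense constant indexed by constant
/// indices folds to the element at those indices. For example (values chosen
/// for illustration), extracting index [1] from a constant
/// dense<[0, 1, 2]> : tensor<3xi32> folds to the attribute 1 : i32.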
OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
  // The tensor operand must be a known constant.
  Attribute tensor = operands.front();
  if (!tensor)
    return {};
  // If this is a splat elements attribute, simply return the value. All of the
  // elements of a splat attribute are the same.
  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
    return splatTensor.getSplatValue();

  // Otherwise, collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : llvm::drop_begin(operands, 1)) {
    if (!indice || !indice.isa<IntegerAttr>())
      return {};
    indices.push_back(indice.cast<IntegerAttr>().getInt());
  }

  // If this is an elements attribute, query the value at the given indices.
  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
  if (elementsAttr && elementsAttr.isValidIndex(indices))
    return elementsAttr.getValue(indices);
  return {};
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           Type elementType, ValueRange elements) {
  Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
                                        elementType);
  result.addOperands(elements);
  result.addTypes(resultTy);
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  build(builder, result, elements.front().getType(), elements);
}

namespace {

// Canonicalizes the pattern of the form
//
// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
//
// to just %element.
struct ExtractElementFromTensorFromElements
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    if (extract.indices().size() != 1)
      return failure();

    auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
    if (tensorFromElements == nullptr)
      return failure();

    APInt index;
    if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
      return failure();
    // Guard against out-of-bounds constant indices, which would otherwise
    // assert in getOperand below.
    if (index.getZExtValue() >= tensorFromElements.getNumOperands())
      return failure();
    rewriter.replaceOp(extract,
                       tensorFromElements.getOperand(index.getZExtValue()));
    return success();
  }
};

} // namespace

void FromElementsOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ExtractElementFromTensorFromElements>(context);
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

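/// A well-formed tensor.generate has one index operand per dynamic extent of
/// its result type, one index block argument per result dimension, and a body
/// that yields a value of the result element type, e.g. (illustrative):
///
/// %tensor = tensor.generate %m, %n {
///   ^bb0(%i: index, %j: index):  // no predecessors
///   <computation>
///   yield %elem : f32
/// } : tensor<?x?xf32>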
static LogicalResult verify(GenerateOp op) {
  // Ensure that the tensor type has as many dynamic dimensions as are specified
  // by the operands.
  RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
  if (op.getNumOperands() != resultTy.getNumDynamicDims())
    return op.emitError("must have as many index operands as dynamic extents "
                        "in the result type");

  // Ensure that region arguments span the index space.
  if (!llvm::all_of(op.body().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return op.emitError("all body arguments must be index");
  if (op.body().getNumArguments() != resultTy.getRank())
    return op.emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp =
      llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
  if (yieldOp.value().getType() != resultTy.getElementType())
    return op.emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}

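/// Builds a GenerateOp and populates its body via `bodyBuilder`, which is
/// handed the body's index arguments and is expected to create the
/// terminating tensor.yield. A minimal usage sketch (the names `builder`,
/// `loc`, and `dynamicSize` are assumed to be in scope):
///
///   auto type = RankedTensorType::get({ShapedType::kDynamicSize},
///                                     builder.getIndexType());
///   Value gen = builder.create<GenerateOp>(
///       loc, type, ValueRange{dynamicSize},
///       [](OpBuilder &b, Location nestedLoc, ValueRange indices) {
///         // Yield the iteration index itself as the element value.
///         b.create<tensor::YieldOp>(nestedLoc, indices[0]);
///       });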
void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = resultTy.cast<RankedTensorType>().getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}

namespace {

/// Canonicalizes tensor.generate operations with constant `dynamicExtents`
/// operands into the equivalent operation with those extents folded into the
/// result type. A tensor.cast back to the original type is inserted so that
/// the resulting IR stays well-typed.
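///
/// For example (shapes and values chosen for illustration):
///
/// %size = constant 8 : index
/// %tensor = tensor.generate %x, %size {
///   ^bb0(%i: index, %j: index):  // no predecessors
///   <computation>
///   yield %elem : f32
/// } : tensor<?x?xf32>
///
/// is rewritten so that the second extent becomes static (tensor<?x8xf32>),
/// followed by a tensor.cast back to tensor<?x?xf32> so that existing uses
/// keep their original type.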
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
                                PatternRewriter &rewriter) const final {
    auto resultType =
        tensorFromElements.getResult().getType().cast<RankedTensorType>();

    if (resultType.hasStaticShape())
      return failure();

    SmallVector<Value, 4> newOperands;
    SmallVector<int64_t, 4> newShape;
    auto operandsIt = tensorFromElements.dynamicExtents().begin();

    for (int64_t dim : resultType.getShape()) {
      if (dim != RankedTensorType::kDynamicSize) {
        newShape.push_back(dim);
        continue;
      }
      APInt index;
      if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
        newShape.push_back(RankedTensorType::kDynamicSize);
        newOperands.push_back(*operandsIt++);
        continue;
      }
      newShape.push_back(index.getSExtValue());
      operandsIt++;
    }

    if (newOperands.size() == tensorFromElements.dynamicExtents().size())
      return failure();

    auto loc = tensorFromElements.getLoc();
    auto newOp = rewriter.create<GenerateOp>(
        loc, RankedTensorType::get(newShape, resultType.getElementType()),
        newOperands);
    rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
                                newOp.body().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
                                                newOp);
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):  // no predecessors
///   <computation>
///   yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.tensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
      return failure();

    BlockAndValueMapping mapping;
    Block *body = tensorFromElements.getBody();
    mapping.map(body->getArguments(), extract.indices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
                                                   extract.indices());
    return success();
  }
};

} // namespace

void GenerateOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
  // TODO: Move extract patterns to tensor::ExtractOp.
  results.insert<ExtractFromTensorGenerate, ExtractFromTensorCast,
                 StaticTensorGenerate>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"