//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::tensor;

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted as `subtensor` ops and are canonicalized
/// to preserve the type compatibility of their uses.
///
/// Returns true when all of the following conditions are met:
/// 1. The source and result are ranked tensors with the same element type and
///    rank.
/// 2. The source type has at least as much static shape information as the
///    result type.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  RankedTensorType sourceType =
      castOp.source().getType().dyn_cast<RankedTensorType>();
  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();

  // Requires RankedTensorType.
  if (!sourceType || !resultType)
    return false;

  // Requires the same element type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires the same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // If the cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    if (ShapedType::isDynamic(std::get<0>(t)) &&
        !ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}

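/// Returns true if `inputs` and `outputs` each contain exactly one tensor
/// type with the same element type and a compatible shape, i.e. every pair of
/// static dimensions must agree, while dynamic dimensions are compatible with
/// anything. For example, `tensor<?x4xf32>` can be cast to or from
/// `tensor<8x4xf32>`, but `tensor<8x4xf32>` cannot be cast to
/// `tensor<8x5xf32>`.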
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<TensorType>();
  auto bT = b.dyn_cast<TensorType>();
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match. Returns a null type
/// if the two shapes have no join, i.e. are incompatible.
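///
/// For example, joining `tensor<?x8xf32>` with `tensor<4x?xf32>` yields
/// `tensor<4x8xf32>`, joining a ranked tensor with an unranked one yields the
/// ranked operand, and `tensor<2x?xf32>` has no join with `tensor<4x?xf32>`.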
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}

namespace {

/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
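///
/// For example:
///
/// ```mlir
///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<4x?xf32>
///   %3 = tensor.cast %2 : tensor<4x?xf32> to tensor<4x4xf32>
/// ```
///
/// folds into a single cast from `tensor<?x?xf32>` to `tensor<4x4xf32>`. By
/// contrast, casting `tensor<?x?xf32>` to `tensor<4x?xf32>` and then to
/// `tensor<?x4xf32>` is not folded, because the remaining single cast would
/// no longer check at runtime that the first dimension is 4.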
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        tensorCastOperand.getOperand().getType().cast<TensorType>();
    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
    auto resultType = tensorCast.getType().cast<TensorType>();

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();
    // If the above join exists, then newJoin exists as well; it might just
    // contain less information. If so, we cannot drop the intermediate cast,
    // as doing so would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};

} // namespace

void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  results.insert<ChainedTensorCast>(context);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ExtractOp op) {
  // Verify that the number of indices matches the rank, if the tensor type is
  // ranked.
  if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
    if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
      return op.emitOpError("incorrect number of indices for extract");

  return success();
}

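/// Fold an extract from a constant tensor: a splat constant folds to its
/// splat value, and a dense constant indexed by constants folds to the
/// element at those indices. A minimal sketch of the effect (values
/// illustrative):
///
/// ```mlir
///   %cst = constant dense<[1, 2]> : tensor<2xi32>
///   %c1 = constant 1 : index
///   // Folds to `constant 2 : i32`.
///   %e = tensor.extract %cst[%c1] : tensor<2xi32>
/// ```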
OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
  // The tensor operand must be a known constant.
  Attribute tensor = operands.front();
  if (!tensor)
    return {};
  // If this is a splat elements attribute, simply return the value. All of the
  // elements of a splat attribute are the same.
  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
    return splatTensor.getSplatValue();

  // Otherwise, collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indexAttr : llvm::drop_begin(operands, 1)) {
    if (!indexAttr || !indexAttr.isa<IntegerAttr>())
      return {};
    indices.push_back(indexAttr.cast<IntegerAttr>().getInt());
  }

  // If this is an elements attribute, query the value at the given indices.
  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
  if (elementsAttr && elementsAttr.isValidIndex(indices))
    return elementsAttr.getValue(indices);
  return {};
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           Type elementType, ValueRange elements) {
  Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
                                        elementType);
  result.addOperands(elements);
  result.addTypes(resultTy);
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  build(builder, result, elements.front().getType(), elements);
}

namespace {

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
///
/// to just %element.
struct ExtractElementFromTensorFromElements
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    if (extract.indices().size() != 1)
      return failure();

    auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
    if (tensorFromElements == nullptr)
      return failure();

    APInt index;
    if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
      return failure();
    // Prevent out-of-bounds accesses. Such accesses can occur in invalid IR
    // that is never executed.
    if (tensorFromElements->getNumOperands() <= index.getZExtValue() ||
        index.getSExtValue() < 0)
      return failure();
    rewriter.replaceOp(extract,
                       tensorFromElements.getOperand(index.getZExtValue()));
    return success();
  }
};

} // namespace

void FromElementsOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ExtractElementFromTensorFromElements>(context);
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

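/// Verify a GenerateOp against the constraints checked below. A sketch of a
/// well-formed op (names illustrative):
///
/// ```mlir
///   %t = tensor.generate %m, %n {
///   ^bb0(%i : index, %j : index):
///     <computation>
///     yield %elem : f32
///   } : tensor<?x?xf32>
/// ```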
static LogicalResult verify(GenerateOp op) {
  // Ensure that the tensor type has as many dynamic dimensions as are specified
  // by the operands.
  RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
  if (op.getNumOperands() != resultTy.getNumDynamicDims())
    return op.emitError("must have as many index operands as dynamic extents "
                        "in the result type");

  // Ensure that region arguments span the index space.
  if (!llvm::all_of(op.body().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return op.emitError("all body arguments must be index");
  if (op.body().getNumArguments() != resultTy.getRank())
    return op.emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp =
      llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
  if (yieldOp.value().getType() != resultTy.getElementType())
    return op.emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}

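/// Build a GenerateOp whose body is populated by `bodyBuilder`. Note that the
/// callback must create the terminating `yield` itself, since this builder
/// does not. A minimal usage sketch (all names hypothetical):
///
/// ```c++
///   b.create<GenerateOp>(
///       loc, RankedTensorType::get({ShapedType::kDynamicSize}, elemTy),
///       ValueRange{dynamicSize},
///       [&](OpBuilder &nested, Location nestedLoc, ValueRange indices) {
///         Value elem = ...; // Compute the element from `indices`.
///         nested.create<tensor::YieldOp>(nestedLoc, elem);
///       });
/// ```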
void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = resultTy.cast<RankedTensorType>().getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}

namespace {

/// Canonicalizes tensor.generate operations whose dynamic extents are defined
/// by constants into the equivalent operation with those extents folded into
/// the result type. A tensor.cast back to the original result type is
/// inserted to keep the IR well-typed.
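///
/// For example:
///
/// ```mlir
///   %c4 = constant 4 : index
///   %t = tensor.generate %c4 {
///   ^bb0(%i : index):
///     <computation>
///     yield %elem : f32
///   } : tensor<?xf32>
/// ```
///
/// becomes a tensor.generate with result type `tensor<4xf32>` and no dynamic
/// extent operands, followed by a cast from `tensor<4xf32>` back to
/// `tensor<?xf32>`.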
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                PatternRewriter &rewriter) const final {
    auto resultType =
        generateOp.getResult().getType().cast<RankedTensorType>();

    if (resultType.hasStaticShape())
      return failure();

    SmallVector<Value, 4> newOperands;
    SmallVector<int64_t, 4> newShape;
    auto operandsIt = generateOp.dynamicExtents().begin();

    for (int64_t dim : resultType.getShape()) {
      if (dim != RankedTensorType::kDynamicSize) {
        newShape.push_back(dim);
        continue;
      }
      APInt index;
      if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
        newShape.push_back(RankedTensorType::kDynamicSize);
        newOperands.push_back(*operandsIt++);
        continue;
      }
      newShape.push_back(index.getSExtValue());
      operandsIt++;
    }

    if (newOperands.size() == generateOp.dynamicExtents().size())
      return failure();

    auto loc = generateOp.getLoc();
    auto newOp = rewriter.create<GenerateOp>(
        loc, RankedTensorType::get(newShape, resultType.getElementType()),
        newOperands);
    rewriter.inlineRegionBefore(generateOp.body(), newOp.body(),
                                newOp.body().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp, resultType,
                                                newOp);
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):  // no predecessors
///   <computation>
///   yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorGenerate = extract.tensor().getDefiningOp<GenerateOp>();
    if (!tensorGenerate || !wouldOpBeTriviallyDead(tensorGenerate))
      return failure();

    BlockAndValueMapping mapping;
    Block *body = tensorGenerate.getBody();
    mapping.map(body->getArguments(), extract.indices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
                                                   extract.indices());
    return success();
  }
};

} // namespace

void GenerateOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
  // TODO: Move extract patterns to tensor::ExtractOp.
  results.insert<ExtractFromTensorGenerate, ExtractFromTensorCast,
                 StaticTensorGenerate>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"